In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal

Mixture Density Network (MDN)
----
***

In [4]:
class MDN(nn.Module):
    def __init__(self, input_dim, output_dim, num_gaussians):
        super(MDN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, num_gaussians * (2 * output_dim + 1))  # 2 for mean, variance and 1 for mixture weight
        )
        self.num_gaussians = num_gaussians
        self.output_dim = output_dim

    def forward(self, x):
        params = self.fc(x)
        # Split the output into means, log variances, and mixture weights
        means = params[:, :self.num_gaussians * self.output_dim].reshape(-1, self.num_gaussians, self.output_dim)
        log_vars = params[:, self.num_gaussians * self.output_dim:2 * self.num_gaussians * self.output_dim]
        log_vars = log_vars.reshape(-1, self.num_gaussians, self.output_dim)
        mixture_weights = params[:, -self.num_gaussians:]
        # Apply softmax for valid weights
        mixture_weights = torch.softmax(mixture_weights, dim=1)
        return means, log_vars, mixture_weights

In [5]:
def mdn_loss(mixture_weights, means, log_vars, target):
    """
    Is this literally just gaussian kde?
    """
    gaussians = Normal(means, torch.exp(log_vars))  # Create Gaussian distributions
    log_probs = gaussians.log_prob(target.unsqueeze(1).expand_as(means))  # Log likelihood of each Gaussian
    log_probs = log_probs.sum(dim=2)  # Sum over the target dimensions
    weighted_log_probs = log_probs + torch.log(mixture_weights)
    log_sum_exp = torch.logsumexp(weighted_log_probs, dim=1)  # Log-sum-exp trick to prevent underflow
    return -log_sum_exp.mean()

# Data: Simulate parameter-data pairs (theta, x)
# For simplicity, we simulate a 1D parameter theta and 1D observed data x
def generate_data(num_samples=1000):
    theta = torch.randn(num_samples, 1)  # Simulate theta ~ N(0, 1)
    x = theta + 0.1 * torch.randn(num_samples, 1)  # Simulate data: x = theta + noise
    return x, theta

# Train the MDN
def train_mdn(mdn, x_train, theta_train, epochs=1000, lr=0.001):
    optimizer = optim.Adam(mdn.parameters(), lr=lr)
    for epoch in range(epochs):
        optimizer.zero_grad()
        means, log_vars, mixture_weights = mdn(x_train)
        loss = mdn_loss(mixture_weights, means, log_vars, theta_train)
        loss.backward()
        optimizer.step()

        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item()}")

def posterior_estimation(mdn, x_observed):
    with torch.no_grad():
        means, log_vars, mixture_weights = mdn(x_observed)
        posterior = Normal(means, torch.exp(log_vars))
        return posterior, mixture_weights

In [6]:
# Example usage
input_dim = 1  # Dimension of observed data x
output_dim = 1  # Dimension of parameter theta
num_gaussians = 5  # Number of Gaussian mixtures

mdn = MDN(input_dim, output_dim, num_gaussians)

# Generate training data
x_train, theta_train = generate_data(num_samples=1000)

# Train the MDN to learn the posterior
train_mdn(mdn, x_train, theta_train, epochs=1000)

# Estimate the posterior for a new observed data point
x_observed = torch.tensor([[0.5]])  # Example observed data
posterior, mixture_weights = posterior_estimation(mdn, x_observed)
print(f"Posterior means: {posterior.mean}")
print(f"Posterior std: {posterior.stddev}")
print(f"Mixture weights: {mixture_weights}")

Epoch 0, Loss: 1.4296963214874268
Epoch 100, Loss: -0.9065782427787781
Epoch 200, Loss: -0.9227863550186157
Epoch 300, Loss: -0.931168794631958
Epoch 400, Loss: -0.943693995475769
Epoch 500, Loss: -0.9545442461967468
Epoch 600, Loss: -0.9415557384490967
Epoch 700, Loss: -0.9531027674674988
Epoch 800, Loss: -0.9710179567337036
Epoch 900, Loss: -0.9669299125671387
Posterior means: tensor([[[0.3181],
         [0.5272],
         [0.6455],
         [0.4514],
         [0.5808]]])
Posterior std: tensor([[[0.0330],
         [0.0542],
         [0.0099],
         [0.0511],
         [0.0806]]])
Mixture weights: tensor([[0.0806, 0.3058, 0.0257, 0.3464, 0.2415]])
