In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MDN(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_gaussians):
        super(MDN, self).__init__()
        
        # Define layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        
        # Outputs: means, stds, and mixing coefficients for each Gaussian
        self.pi_layer = nn.Linear(hidden_dim, num_gaussians)  # Mixing coefficients
        self.mu_layer = nn.Linear(hidden_dim, num_gaussians)  # Means
        self.sigma_layer = nn.Linear(hidden_dim, num_gaussians)  # Standard deviations
        
        self.num_gaussians = num_gaussians
    
    def forward(self, x):
        # Forward pass through hidden layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        
        # Output layers for GMM parameters
        pi = F.softmax(self.pi_layer(x), dim=1)  # Mixing coefficients (softmax for normalization)
        mu = self.mu_layer(x)  # Means of Gaussians
        sigma = torch.exp(self.sigma_layer(x))  # Std devs (exponential to ensure positivity)
        
        return pi, mu, sigma

# Example usage:
# Define input dimensions and model parameters
input_dim = 1  # Example: single feature input
hidden_dim = 20
num_gaussians = 3  # Number of Gaussian components in the mixture

# Instantiate the MDN model
mdn = MDN(input_dim, hidden_dim, num_gaussians)


In [2]:
def mdn_loss(pi, mu, sigma, target):
    # Reshape for broadcasting
    target = target.view(-1, 1)  # Reshape target to (batch_size, 1)

    # Calculate Gaussian probabilities for each component
    gaussian = torch.exp(-0.5 * ((target - mu) / sigma)**2) / (sigma * torch.sqrt(torch.tensor(2 * torch.pi)))
    
    # Weighted sum of Gaussians using pi as the weights
    weighted_gaussian = pi * gaussian
    probability_density = torch.sum(weighted_gaussian, dim=1)  # Sum over all components
    
    # Negative log-likelihood
    nll = -torch.log(probability_density + 1e-10)  # Add epsilon for numerical stability
    return torch.mean(nll)

# Example usage:
# Assume `inputs` is your input data tensor and `targets` are the ground-truth outputs
inputs = torch.randn(5, input_dim)  # Example input batch
targets = torch.randn(5)  # Example target batch

# Forward pass through MDN to get pi, mu, sigma
pi, mu, sigma = mdn(inputs)

# Calculate MDN loss
loss = mdn_loss(pi, mu, sigma, targets)
print(loss)


tensor(1.5812, grad_fn=<MeanBackward0>)


In [3]:
optimizer = torch.optim.Adam(mdn.parameters(), lr=0.001)

# Example training loop
for epoch in range(1000):
    optimizer.zero_grad()
    
    # Forward pass
    pi, mu, sigma = mdn(inputs)
    
    # Compute loss
    loss = mdn_loss(pi, mu, sigma, targets)
    
    # Backward pass and optimization
    loss.backward()
    optimizer.step()
    
    if epoch % 100 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')


Epoch 0, Loss: 1.5812355279922485
Epoch 100, Loss: 1.1394257545471191
Epoch 200, Loss: -2.604430913925171
Epoch 300, Loss: 2.9536476135253906
Epoch 400, Loss: 1.9405672550201416
Epoch 500, Loss: 1.45607590675354
Epoch 600, Loss: 1.2209335565567017
Epoch 700, Loss: 11.907623291015625
Epoch 800, Loss: 11.708788871765137
Epoch 900, Loss: 11.742825508117676
