In [12]:
import torch
import matplotlib.pyplot as plt

def generate_gmm_pytorch(means, covariances, probabilities, num_samples=1000, seed=None):
    """
    Draws samples from a Gaussian Mixture Model (GMM) defined by the given means,
    covariances, and component probabilities.

    Component labels are drawn in a single multinomial call and each component is
    then sampled in one batch, so only one MultivariateNormal is built per
    component that was actually selected. When d == 2 the samples are also shown
    as a scatter plot via matplotlib.

    Parameters:
        means (torch.Tensor): A (N, d) tensor of Gaussian means.
        covariances (torch.Tensor): A (N, d, d) tensor of covariance matrices.
        probabilities (torch.Tensor): A (N,) tensor of component weights. They are
            rescaled to sum to 1, so unnormalized weights are accepted.
        num_samples (int): Total number of samples to generate. 0 yields empty tensors.
        seed (int, optional): If given, passed to torch.manual_seed, which reseeds
            the global PyTorch RNG before sampling.

    Returns:
        samples (torch.Tensor): A (num_samples, d) tensor of generated samples.
        labels (torch.Tensor): A (num_samples,) int64 tensor with the component
            index each sample was drawn from.
        means (torch.Tensor): Means of the Gaussian components (as passed in).
        covariances (torch.Tensor): Covariance matrices of the components (as passed in).
        probabilities (torch.Tensor): Normalized probability vector for the components.

    Raises:
        AssertionError: If covariances or probabilities do not match the shape
            implied by means.
        ValueError: If num_samples is negative.
    """
    if seed is not None:
        torch.manual_seed(seed)

    N, d = means.shape  # Number of components & dimensions

    # Ensure covariance matrices and probabilities are in the right format
    assert covariances.shape == (N, d, d), "Covariances must have shape (N, d, d)"
    assert probabilities.shape == (N,), "Probabilities must have shape (N,)"
    if num_samples < 0:
        raise ValueError("num_samples must be non-negative")

    # Normalize probabilities in case they don't sum to 1
    probabilities = probabilities / probabilities.sum()

    if num_samples == 0:
        # torch.multinomial rejects a zero sample count, so build the empty result directly
        labels = torch.empty(0, dtype=torch.long)
        samples = torch.empty(0, d, dtype=means.dtype, device=means.device)
    else:
        # Choose the component of every sample at once (i.i.d. draws, hence replacement)
        labels = torch.multinomial(probabilities, num_samples, replacement=True)
        samples = torch.empty(num_samples, d, dtype=means.dtype, device=means.device)

        for k in range(N):
            chosen = labels == k
            count = int(chosen.sum())
            if count == 0:
                # Component never selected: skip building its distribution
                continue
            component = torch.distributions.MultivariateNormal(
                means[k], covariance_matrix=covariances[k]
            )
            # One batched draw fills every slot assigned to component k
            samples[chosen] = component.sample((count,))

    # If d=2, plot the distribution
    if d == 2:
        plt.figure(figsize=(8, 6))
        for i in range(N):
            chosen = labels == i
            plt.scatter(samples[chosen, 0].numpy(), samples[chosen, 1].numpy(),
                        label=f'Component {i}', alpha=0.5)
        # detach() so that means which require grad can still be converted to NumPy
        centers = means.detach().cpu()
        plt.scatter(centers[:, 0].numpy(), centers[:, 1].numpy(), color='red', marker='x', s=100, label='Means')
        plt.legend()
        plt.title(f'Gaussian Mixture Model (PyTorch, d={d}, N={N})')
        plt.show()

    return samples, labels, means, covariances, probabilities


def generate_random_covariances(N, d, seed=None):
    """
    Builds N random positive-definite (d, d) covariance matrices.

    Each matrix is formed as A @ A.T for a standard-normal (d, d) matrix A,
    plus 0.1 on the diagonal.

    Parameters:
        N (int): Number of Gaussian components.
        d (int): Number of dimensions.
        seed (int, optional): If given, passed to torch.manual_seed before drawing.

    Returns:
        covariances (torch.Tensor): A (N, d, d) tensor of positive-definite covariance matrices.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Diagonal offset keeps every matrix strictly positive definite
    jitter = torch.eye(d) * 0.1

    def _random_spd():
        # A @ A.T is symmetric positive semi-definite for any real A
        factor = torch.randn(d, d)
        return factor @ factor.T + jitter

    return torch.stack([_random_spd() for _ in range(N)])  # Shape (N, d, d)

# Example usage
d = 3   # Dimensionality of each sample
N = 4   # Number of mixture components
num_samples = 2000  # Total number of draws

# Alternative: hand-specified parameters (these values need N = 3 and d = 2)
#means = torch.tensor([[2.0, 3.0], [-3.0, -2.0], [5.0, -4.0]])  # Shape (N, d)
#covariances = torch.stack([
#    torch.tensor([[1.0, 0.2], [0.2, 1.5]]),
#    torch.tensor([[1.5, -0.3], [-0.3, 1.0]]),
#    torch.tensor([[0.8, 0.1], [0.1, 0.9]])
#])  # Shape (N, d, d)

# Random component means and random positive-definite covariances
means = torch.randn(N, d)
covariances = generate_random_covariances(N, d)

# Mixture weights drawn from a flat Dirichlet, i.e. uniformly over the probability simplex
probabilities = torch.distributions.Dirichlet(torch.ones(N)).sample()

# Draw the GMM samples
samples, labels, means, covariances, probabilities = generate_gmm_pytorch(
    means, covariances, probabilities, num_samples=num_samples
)


In [13]:
samples.shape

torch.Size([2000, 3])

In [8]:
covariances.shape

torch.Size([3, 2, 2])