In [47]:
# 3D dataset
import numpy as np

rng = np.random.default_rng()

n = 1000
k = 3
d = 3

means = rng.uniform(-5, 5, size=(k, d))
covs = [np.diag(rng.uniform(0.4, 1.2, size=d) ** 2) for _ in range(k)]
weights = rng.dirichlet(np.ones(k))

assignments = rng.choice(k, size=n, p=weights)

X = np.vstack([
    rng.multivariate_normal(means[i], covs[i], size=(assignments == i).sum())
    for i in range(k)
])

np.save("dataset.npy", X)

In [48]:
# SGD

import torch
torch.set_default_dtype(torch.float64)

# Load dataset
X_np = np.load("dataset.npy")
X = torch.from_numpy(X_np).to(torch.float64)

num_samples, num_features = X.shape
num_comps = 3

def log_gauss(x, mean, log_std):
    inv = torch.exp(-2.0 * log_std)
    d = x[:, None, :] - mean[None, :, :]
    q = (d * d * inv[None, :, :]).sum(dim=-1)
    ld = (2.0 * log_std).sum(dim=-1)
    const = x.shape[1] * torch.log(torch.tensor(2.0*np.pi, dtype=x.dtype, device=x.device))
    return -0.5 * (const + ld[None, :] + q)

def responsibilities(x, mean, log_std, logits):
    log_pi = torch.log_softmax(logits, dim=-1)
    log_p = log_gauss(x, mean, log_std) + log_pi[None, :]
    log_Z = torch.logsumexp(log_p, dim=1, keepdim=True)
    return torch.exp(log_p - log_Z)

def nll(x, mean, log_std, logits):
    log_pi = torch.log_softmax(logits, dim=-1)
    log_p = log_gauss(x, mean, log_std) + log_pi[None, :]
    return -(torch.logsumexp(log_p, dim=1)).mean()


def batch_gradients(x, mean, log_std, logits):
    n = x.shape[0]
    r = responsibilities(x, mean, log_std, logits)
    Nk = r.sum(dim=0)
    d = x[:, None, :] - mean[None, :, :]
    inv = torch.exp(-2.0 * log_std)
    var = torch.exp(2.0 * log_std)

    g_mean = -(r[:, :, None] * d * inv[None, :, :]).sum(dim=0)
    term = (d * d) / var[None, :, :]
    g_log_std = -0.5 * (r[:, :, None] * (term - 1.0)).sum(dim=0)
    pi = torch.softmax(logits, dim=-1)
    g_logits = n * pi - Nk
    return g_mean, g_log_std, g_logits

idx = torch.randperm(num_samples)[:num_comps]
mean = X[idx].clone()
log_std = torch.zeros(num_comps, num_features)
logits = torch.zeros(num_comps)

# SGD Loop
epochs = 300
batch_size = 256
lr_mean = 0.02
lr_log_std = 0.01
lr_logits = 0.01

for epoch in range(epochs):
    perm = torch.randperm(num_samples)
    X_s = X[perm]
    for start in range(0, num_samples, batch_size):
        batch = X_s[start:start+batch_size]
        g_mean, g_log_std, g_logits = batch_gradients(batch, mean, log_std, logits)
        with torch.no_grad():
            mean   -= lr_mean   * g_mean
            log_std-= lr_log_std* g_log_std
            logits -= lr_logits * g_logits

print("Final NLL:", nll(X, mean, log_std, logits).item())

Final NLL: 5.706395627774542
