In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
class ConditionalDiffusionModel(nn.Module):
    def __init__(self, input_dim, condition_dim, hidden_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(input_dim + condition_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def forward(self, x, condition):
        x = torch.cat([x, condition], dim=-1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

def add_noise(x, t, num_timesteps):
    alpha = 1 - t / num_timesteps
    noise = torch.randn_like(x)
    noisy_x = torch.sqrt(alpha) * x + torch.sqrt(1 - alpha) * noise
    return noisy_x, noise

In [3]:
def train(model, dataloader, optimizer, device, num_timesteps=100, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for x, condition in dataloader:
            x = x.to(device)
            condition = condition.to(device)
            optimizer.zero_grad()
            t = torch.randint(0, num_timesteps, (x.size(0),), device=device)
            noisy_x, noise = add_noise(x, t, num_timesteps)
            pred_noise = model(noisy_x, condition)
            loss = nn.functional.mse_loss(pred_noise, noise)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / len(dataloader):.4f}")

In [4]:
def generate_samples(model, device, condition, num_samples=1000, num_timesteps=100):
    model.eval()
    with torch.no_grad():
        x = torch.randn(num_samples, 2, device=device)
        condition = condition.to(device).unsqueeze(0).repeat(num_samples, 1)
        for t in reversed(range(num_timesteps)):
            t_tensor = torch.full((num_samples,), t, device=device, dtype=torch.long)
            pred_noise = model(x, condition)
            alpha = 1 - t / num_timesteps
            x = (x - torch.sqrt(1 - alpha) * pred_noise) / torch.sqrt(alpha)
    return x.cpu().numpy()

In [5]:
means = [
    np.array([4, 4]),
    np.array([-3, 3]),
    np.array([-5, -2]),
    np.array([6, -1])
]

covariances = [
    np.array([[2, 1], [1, 2]]),
    np.array([[1, -0.5], [-0.5, 1]]),
    np.array([[0.5, 0], [0, 0.5]]),
    np.array([[3, 0.8], [0.8, 3]])
]

In [6]:
def generate_gmm_data(n_samples=1000, means=None, covariances=None):
        np.random.seed(42)
        if means is None:
            means = [np.array([0, 0]), np.array([4, 4])]
        if covariances is None:
            covariances = [np.array([[1, 0], [0, 1]]), np.array([[2, 1], [1, 2]])]

        X = []
        labels = []
        for i, (mean, cov) in enumerate(zip(means, covariances)):
            samples = np.random.multivariate_normal(mean, cov, n_samples)
            X.append(samples)
            labels.extend([i] * n_samples)

        X = np.vstack(X)
        labels = np.array(labels)
        return torch.tensor(X, dtype=torch.float32), torch.tensor(labels, dtype=torch.float32)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

X, labels = generate_gmm_data(n_samples=1000)
dataset = torch.utils.data.TensorDataset(X, labels)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)

model = ConditionalDiffusionModel(input_dim=2, condition_dim=1).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

train(model, dataloader, optimizer, device, num_timesteps=100, num_epochs=10)

condition = torch.tensor([1.0], dtype=torch.float32)
samples = generate_samples(model, device, condition, num_samples=1000)

plt.figure(figsize=(10, 5))
plt.scatter(samples[:, 0], samples[:, 1], s=5, alpha=0.5, label="Generated Samples")
plt.scatter(X[:, 0], X[:, 1], s=5, alpha=0.5, label="Original Data")
plt.legend()
plt.title("Conditional Diffusion Model Generated Samples")
plt.show()

RuntimeError: The size of tensor a (128) must match the size of tensor b (2) at non-singleton dimension 1