In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader


In [None]:
class GaussianMixtureDataset(Dataset):
    """2-D mixture of isotropic Gaussians whose means lie evenly on a circle.

    Each sample picks one of ``centers`` component means uniformly at random.
    Component ``k`` sits at angle ``k * 2 * pi / centers`` on a circle of
    ``radius`` around the origin. Gaussian noise with standard deviation
    ``std`` is then added independently to each coordinate.

    Sampling draws from NumPy's global random state. Call ``np.random.seed``
    beforehand if reproducible data is needed.

    Args:
        n_samples: Number of points to draw (must be >= 0).
        centers: Number of mixture components (must be >= 1).
        radius: Distance of every component mean from the origin.
        std: Per-coordinate standard deviation of the added noise.

    Raises:
        ValueError: If ``n_samples`` is negative or ``centers`` is below 1.

    Attributes:
        samples: ``float32`` tensor of shape ``(n_samples, 2)``.
    """

    def __init__(self, n_samples=10000, centers=8, radius=2.0, std=0.05):
        super().__init__()
        if n_samples < 0:
            raise ValueError(f"n_samples must be non-negative, got {n_samples}")
        if centers < 1:
            raise ValueError(f"centers must be at least 1, got {centers}")

        # Draw all component indices at once instead of looping in Python.
        component = np.random.randint(0, centers, size=n_samples)
        angles = component * 2 * np.pi / centers
        # (n_samples, 2) array of component means on the circle.
        means = radius * np.stack([np.cos(angles), np.sin(angles)], axis=1)
        # Vectorized noise keeps the (n_samples, 2) shape even when
        # n_samples == 0; the list-based original collapsed to shape (0,).
        points = means + np.random.randn(n_samples, 2) * std
        self.samples = torch.from_numpy(points.astype(np.float32))

    def __len__(self):
        """Return the number of generated points."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the point(s) at ``idx``; a single index yields shape ``(2,)``."""
        return self.samples[idx]

# Build the synthetic mixture (10,000 points) once, then wrap it in a
# loader that serves shuffled mini-batches of 512 points.
dataset = GaussianMixtureDataset(n_samples=10_000)
loader = DataLoader(
    dataset,
    batch_size=512,
    shuffle=True,
)