In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import math

In [2]:
torch.manual_seed(42)
class dataset(Dataset):
    def __init__(self, size):
        self.means = torch.tensor([[0, 0], [2, 2], [-2, 2]], dtype=torch.float32)
        self.sigma = math.sqrt(2)
        self.len = size
        
    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        z1 = torch.normal(self.means[0], self.sigma)
        z2 = torch.normal(self.means[1], self.sigma)
        z3 = torch.normal(self.means[2], self.sigma)
        return z1
#         return 1/math.sqrt(3)*(z1+z2+z3)

In [3]:
samples = dataset(10000)

In [4]:
samples_loader = DataLoader(samples, batch_size = 50, shuffle = True)
total_mean = torch.tensor([0, 0], dtype = torch.float32)
for batch in samples_loader:
    total_mean[0] += batch[:,0].mean()
    total_mean[1] += batch[:,1].mean()
total_mean = total_mean / len(samples_loader)
total_mean

tensor([-0.0283, -0.0077])

In [None]:
# Generator network
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 2)
        )

    def forward(self, z):
        return self.model(z)

# Discriminator network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)


In [None]:
latent_dim = 10
# Build the networks
generator = Generator(latent_dim)
discriminator = Discriminator()

# Loss function and optimizers
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.001)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.001)

# Training the GAN
epochs = 100
batch_size = 50
losses_d = []
losses_g = []

for epoch in range(epochs):
    batch = 1
    running_d = 0.0
    running_g = 0.0
    for sample in samples_loader:
        # Generate real and fake samples
        real_samples = sample
        fake_samples = generator(torch.randn(batch_size, latent_dim))

        # Labels for real and fake samples
        real_labels = torch.ones((batch_size, 1))
        fake_labels = torch.zeros((batch_size, 1))

        # Train discriminator
        optimizer_D.zero_grad()
        d_loss_real = criterion(discriminator(real_samples), real_labels)
        d_loss_fake = criterion(discriminator(fake_samples.detach()), fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # Train generator
        optimizer_G.zero_grad()
        g_loss = criterion(discriminator(fake_samples), real_labels)
        g_loss.backward()
        optimizer_G.step()

        # Print progress
        print(f"Epoch {epoch}, batch {batch} D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")
        running_d += d_loss.item()
        running_g += g_loss.item()
        batch += 1
    losses_d.append(running_d / len(samples_loader))
    losses_g.append(running_g / len(samples_loader))



In [None]:
plt.plot(range(epochs), losses_d)
plt.title('D-Loss vs Epochs')
plt.show()

In [None]:
plt.plot(range(epochs), losses_g)
plt.title('G-Loss vs Epochs')
plt.show()

In [None]:
# Generate samples using the trained generator
generated_samples = generator(torch.randn(3000, latent_dim)).detach()
gen_mean = torch.tensor([0, 0], dtype = torch.float32)
gen_mean[0] = generated_samples[:,0].mean()
gen_mean[1] = generated_samples[:,1].mean()

In [None]:
# Plot the generated samples
plt.hist2d(generated_samples[:, 0], generated_samples[:, 1])
plt.colorbar()
plt.plot(total_mean[0], total_mean[1],label='Real Mean', marker='x', color='red')
plt.plot(gen_mean[0], gen_mean[1],label='Generated Mean', marker='x', color='black')
plt.legend()
plt.show()

In [None]:
# Create a 2D histogram
hist, x_edges, y_edges = np.histogram2d(generated_samples[:, 0], generated_samples[:, 1], bins=20)

# Get bin centers for x and y
x_centers = (x_edges[:-1] + x_edges[1:]) / 2
y_centers = (y_edges[:-1] + y_edges[1:]) / 2

# Create a meshgrid from bin centers
x_mesh, y_mesh = np.meshgrid(x_centers, y_centers)

# Create a 3D plot
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Plot the 3D surface
ax.plot_surface(x_mesh, y_mesh, hist.T, cmap='viridis')

# Add labels
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')
ax.set_zlabel('Frequency')

# Show the plot
plt.show()