In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Hyperparameters
latent_dim = 100
image_size = 28 * 28
batch_size = 64
epochs = 10
learning_rate = 0.0002

# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, image_size),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_size, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

# Initialize models
generator = Generator()
discriminator = Discriminator()

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)

# Loss function
criterion = nn.BCELoss()

# Data loader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

dataset = dsets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Real images
        real_imgs = imgs.view(imgs.size(0), -1)
        real_labels = torch.ones(imgs.size(0), 1)

        # Fake images
        z = torch.randn(imgs.size(0), latent_dim)
        fake_imgs = generator(z)
        fake_labels = torch.zeros(imgs.size(0), 1)

        # Train Discriminator
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(real_imgs), real_labels)
        fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        g_loss = criterion(discriminator(fake_imgs), real_labels)
        g_loss.backward()
        optimizer_G.step()

    print(f"Epoch [{epoch+1}/{epochs}]  Loss D: {d_loss.item()}, Loss G: {g_loss.item()}")

Epoch [1/10]  Loss D: 0.0463242307305336, Loss G: 9.219164848327637
Epoch [2/10]  Loss D: 0.10884708166122437, Loss G: 3.6527726650238037
Epoch [3/10]  Loss D: 0.2916882038116455, Loss G: 2.755937099456787
Epoch [4/10]  Loss D: 0.13854214549064636, Loss G: 8.743966102600098
Epoch [5/10]  Loss D: 0.15799231827259064, Loss G: 5.337985992431641
Epoch [6/10]  Loss D: 0.381470650434494, Loss G: 2.742551803588867
Epoch [7/10]  Loss D: 0.2382327914237976, Loss G: 3.993457555770874
Epoch [8/10]  Loss D: 0.08960343152284622, Loss G: 4.29990816116333
Epoch [9/10]  Loss D: 0.08877938985824585, Loss G: 4.640545845031738
Epoch [10/10]  Loss D: 0.5340791344642639, Loss G: 3.2380080223083496


In [2]:
# Save the generator model
torch.save(generator.state_dict(), 'generator.pth')