In [1]:
# GAN in PyTorch - MNIST dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import os
import matplotlib.pyplot as plt

# Device: use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reproducibility: seed PyTorch's random number generators so weight
# initialisation, DataLoader shuffling and latent noise are repeatable
# across "Restart & Run All" (bit-exact GPU results are still not guaranteed).
seed = 42
torch.manual_seed(seed)

# Directories: sample grids are written here during training
os.makedirs("generated_images", exist_ok=True)

# Hyperparameters
latent_dim = 100          # length of the noise vector fed to the generator
img_shape = (1, 28, 28)   # (channels, height, width) of one image
batch_size = 64           # mini-batch size used by the DataLoader
epochs = 50               # number of full passes over the training set
lr = 0.0002               # Adam learning rate (shared by both optimizers)
beta1 = 0.5               # first Adam beta (shared by both optimizers)

# Transform: ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps
# them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Dataset: MNIST training split, downloaded into ./data on first use and
# reshuffled every epoch.
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST("data", train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True
)

# Generator
class Generator(nn.Module):
    """Fully connected generator mapping latent noise vectors to images.

    The network is Linear -> LeakyReLU, two Linear -> BatchNorm1d -> LeakyReLU
    stages, and a final Linear -> Tanh, so outputs lie in [-1, 1].

    Args:
        z_dim: Length of each input noise vector. When ``None`` (default) the
            module-level ``latent_dim`` hyperparameter is used.
        out_shape: Shape of one generated image, e.g. ``(1, 28, 28)``. When
            ``None`` (default) the module-level ``img_shape`` is used.
    """

    def __init__(self, z_dim=None, out_shape=None):
        super().__init__()
        # Fall back to the notebook-level hyperparameters so that a plain
        # `Generator()` builds exactly the same architecture as before.
        self.z_dim = latent_dim if z_dim is None else int(z_dim)
        self.out_shape = tuple(img_shape if out_shape is None else out_shape)

        # Width of the last layer is derived from the image shape instead of
        # being hard-coded, so it can never disagree with the reshape below.
        n_pixels = 1
        for dim in self.out_shape:
            n_pixels *= dim

        self.model = nn.Sequential(
            nn.Linear(self.z_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, n_pixels),
            nn.Tanh()
        )

    def forward(self, z):
        """Generate one image per row of ``z``.

        Args:
            z: Noise tensor of shape ``(batch, z_dim)``.

        Returns:
            Tensor of shape ``(batch, *out_shape)`` with values in [-1, 1].
        """
        img = self.model(z)
        return img.view(z.size(0), *self.out_shape)

# Discriminator
class Discriminator(nn.Module):
    """Fully connected discriminator scoring images as real or generated.

    Three Linear layers with LeakyReLU activations, ending in a Sigmoid so
    the output is a value in [0, 1] per image.

    Args:
        in_shape: Shape of one input image, e.g. ``(1, 28, 28)``. When
            ``None`` (default) the module-level ``img_shape`` is used.
    """

    def __init__(self, in_shape=None):
        super().__init__()
        in_shape = tuple(img_shape if in_shape is None else in_shape)

        # Derive the flattened input width from the image shape rather than
        # hard-coding it, so it stays consistent with the generator's output.
        n_pixels = 1
        for dim in in_shape:
            n_pixels *= dim

        self.model = nn.Sequential(
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor whose first dimension is the batch; the remaining
                dimensions are flattened into one feature vector per image.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        # flatten() also accepts non-contiguous tensors, where view() raises.
        img_flat = img.flatten(start_dim=1)
        return self.model(img_flat)

# Build both networks (generator first, then discriminator) and move their
# parameters onto the selected device; Module.to() works in place.
generator = Generator()
generator.to(device)
discriminator = Discriminator()
discriminator.to(device)

# Binary cross-entropy on probabilities: the discriminator ends in a Sigmoid.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate and betas.
optimizer_G = optim.Adam(
    params=generator.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
)
optimizer_D = optim.Adam(
    params=discriminator.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
)

# Training loop: one generator step followed by one discriminator step per batch
for epoch in range(1, epochs + 1):
    for i, (imgs, _) in enumerate(dataloader):
        real_imgs = imgs.to(device)
        # Size of *this* batch. The DataLoader keeps the final partial batch,
        # so it can be smaller than the configured batch_size. A separate
        # name is used so the batch_size hyperparameter is not overwritten.
        cur_batch = real_imgs.size(0)

        # Real and Fake labels, allocated directly on the target device
        real = torch.ones(cur_batch, 1, device=device)
        fake = torch.zeros(cur_batch, 1, device=device)

        # Train Generator: push D(G(z)) towards the "real" label
        optimizer_G.zero_grad()
        z = torch.randn(cur_batch, latent_dim, device=device)
        gen_imgs = generator(z)
        g_loss = criterion(discriminator(gen_imgs), real)
        # This backward pass also leaves gradients on the discriminator's
        # parameters; they are cleared by optimizer_D.zero_grad() below.
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator: real images -> 1, generated images -> 0.
        # detach() stops this loss from back-propagating into the generator.
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(real_imgs), real)
        fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        if i % 400 == 0:
            print(f"Epoch [{epoch}/{epochs}] Batch [{i}/{len(dataloader)}]  Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}")

    # Save sample images after the first epoch and then every 10th epoch
    if epoch % 10 == 0 or epoch == 1:
        # Sample in eval mode so BatchNorm uses its running statistics (as it
        # will when the saved weights are loaded for inference) and sampling
        # does not update those statistics.
        generator.eval()
        with torch.no_grad():
            z = torch.randn(64, latent_dim, device=device)
            gen_imgs = generator(z)
            grid = make_grid(gen_imgs, nrow=8, normalize=True)
            save_image(grid, f"generated_images/epoch_{epoch}.png")
        # Back to training mode for the next epoch
        generator.train()

print("Training complete. Check 'generated_images' folder.")


100%|██████████| 9.91M/9.91M [00:00<00:00, 47.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.69MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.5MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.43MB/s]


Epoch [1/50] Batch [0/938]  Loss D: 0.6959, Loss G: 0.7119
Epoch [1/50] Batch [400/938]  Loss D: 0.4443, Loss G: 1.3666
Epoch [1/50] Batch [800/938]  Loss D: 0.5397, Loss G: 1.3088
Epoch [2/50] Batch [0/938]  Loss D: 0.4881, Loss G: 1.3272
Epoch [2/50] Batch [400/938]  Loss D: 0.5340, Loss G: 0.7166
Epoch [2/50] Batch [800/938]  Loss D: 0.4484, Loss G: 0.8659
Epoch [3/50] Batch [0/938]  Loss D: 0.4789, Loss G: 1.4419
Epoch [3/50] Batch [400/938]  Loss D: 0.5442, Loss G: 1.6351
Epoch [3/50] Batch [800/938]  Loss D: 0.5254, Loss G: 2.0803
Epoch [4/50] Batch [0/938]  Loss D: 0.3616, Loss G: 1.4333
Epoch [4/50] Batch [400/938]  Loss D: 0.5379, Loss G: 0.5648
Epoch [4/50] Batch [800/938]  Loss D: 0.4314, Loss G: 0.9286
Epoch [5/50] Batch [0/938]  Loss D: 0.4761, Loss G: 0.7578
Epoch [5/50] Batch [400/938]  Loss D: 0.4751, Loss G: 1.8608
Epoch [5/50] Batch [800/938]  Loss D: 0.7294, Loss G: 0.3736
Epoch [6/50] Batch [0/938]  Loss D: 0.4324, Loss G: 1.2577
Epoch [6/50] Batch [400/938]  Loss D

In [4]:
import shutil
# Pack everything under generated_images/ into gan_outputs.zip
shutil.make_archive(
    base_name="gan_outputs",
    format="zip",
    root_dir="generated_images",
)
from google.colab import files
# Hand the archive to the browser as a download (requires Google Colab)
files.download("gan_outputs.zip")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [5]:
# Persist only the generator's parameters and buffers (its state_dict),
# not the pickled module object itself.
torch.save(obj=generator.state_dict(), f="generator.pth")


In [6]:
# Imported in this cell so it does not depend on an earlier cell having been
# executed first (requires Google Colab).
from google.colab import files

# Hand generator.pth to the browser as a download
files.download("generator.pth")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>