<a href="https://colab.research.google.com/github/Lanaanvar/Deep-Learning/blob/main/GAN_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid

# ---- Hyperparameters ----
latent_dim = 100          # length of the generator's input noise vector
img_shape = (1, 28, 28)   # (channels, height, width) of one MNIST image
batch_size = 64
epochs = 50
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Folder for the per-epoch sample grids; reused if it already exists.
os.makedirs("gan_images", exist_ok=True)

# ---- Data ----
# ToTensor yields pixels in [0, 1]; Normalize(mean=0.5, std=0.5) maps them to
# [-1, 1], the same range as the generator's final Tanh activation.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)

# MNIST training split, downloaded into the current directory if missing.
dataloader = DataLoader(
    datasets.MNIST(root=".", train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

# Generator Model
class Generator(nn.Module):
    """MLP generator mapping a latent noise vector to an image in [-1, 1].

    Args:
        z_dim: Length of the input noise vector. When omitted, the
            module-level ``latent_dim`` is used, so ``Generator()`` behaves
            as before.
        out_shape: Shape ``(C, H, W)`` of each generated image. When omitted,
            the module-level ``img_shape`` is used.

    ``forward(z)`` takes ``z`` of shape ``(batch, z_dim)`` and returns a
    tensor of shape ``(batch, *out_shape)``.
    """

    def __init__(self, z_dim=None, out_shape=None):
        super().__init__()
        # Fall back to the script-level hyperparameters for backward compatibility.
        if z_dim is None:
            z_dim = latent_dim
        if out_shape is None:
            out_shape = img_shape
        self.out_shape = tuple(out_shape)
        # Derive the output width from the image shape (784 for (1, 28, 28))
        # rather than hard-coding it separately from img_shape.
        out_features = math.prod(self.out_shape)
        self.model = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, out_features),
            nn.Tanh(),  # matches the [-1, 1] range of the normalized real images
        )

    def forward(self, z):
        img = self.model(z)
        return img.view(z.size(0), *self.out_shape)

# Discriminator Model
class Discriminator(nn.Module):
    """MLP discriminator scoring each image as real (close to 1) or fake (close to 0).

    Args:
        in_shape: Shape ``(C, H, W)`` of each input image. When omitted, the
            module-level ``img_shape`` is used, so ``Discriminator()`` behaves
            as before.

    ``forward(img)`` flattens every sample and returns a ``(batch, 1)`` tensor
    of sigmoid probabilities, the form ``nn.BCELoss`` expects.
    """

    def __init__(self, in_shape=None):
        super().__init__()
        if in_shape is None:
            in_shape = img_shape
        # Derive the input width from the image shape instead of hard-coding 784.
        in_features = math.prod(in_shape)
        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        # torch.flatten (unlike .view) also accepts non-contiguous inputs.
        flat = torch.flatten(img, start_dim=1)
        return self.model(flat)

# Instantiate both networks (generator first) and move them to `device`.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Binary cross-entropy on the discriminator's sigmoid outputs; each network
# gets its own Adam optimizer with lr=2e-4 and betas=(0.5, 0.999).
criterion = nn.BCELoss()
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)

# Training loop: each batch takes one generator step, then one discriminator step.
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Separate name so the `batch_size` hyperparameter is not overwritten
        # by the smaller final batch of an epoch.
        n_imgs = imgs.size(0)
        real_imgs = imgs.to(device)

        # Target labels (1 = real, 0 = fake), created directly on `device`.
        valid = torch.ones(n_imgs, 1, device=device)
        fake = torch.zeros(n_imgs, 1, device=device)

        # ---------------------
        #  Train Generator
        # ---------------------
        optimizer_G.zero_grad()

        z = torch.randn(n_imgs, latent_dim, device=device)
        gen_imgs = generator(z)

        # Generator is rewarded when the discriminator labels its images as real.
        g_loss = criterion(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # Also clears the gradients g_loss.backward() left on D's parameters.
        optimizer_D.zero_grad()

        real_loss = criterion(discriminator(real_imgs), valid)
        # detach(): the discriminator update must not backpropagate into G.
        fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        if i % 200 == 0:
            # Adjacent f-strings are concatenated, so no source indentation
            # leaks into the log line; epochs are shown 1-based like the files.
            print(
                f"Epoch [{epoch + 1}/{epochs}] Batch [{i}/{len(dataloader)}] "
                f"Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}"
            )

    # Save an 8x8 grid of fresh samples at the end of every epoch. The
    # generator stays in training mode, so BatchNorm uses this batch's stats.
    with torch.no_grad():
        z = torch.randn(64, latent_dim, device=device)
        generated = generator(z)
        # Tanh outputs lie in [-1, 1]: map that fixed range to [0, 1] instead
        # of stretching each grid by its own min/max.
        grid = make_grid(generated, nrow=8, normalize=True, value_range=(-1, 1))
        save_image(grid, f"gan_images/epoch_{epoch+1}.png")

print("🎉 Training complete. Check 'gan_images/' for generated digits.")


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 487kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.46MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.94MB/s]


Epoch [0/50] Batch [0/938]                   Loss D: 0.6876, Loss G: 0.7098
Epoch [0/50] Batch [200/938]                   Loss D: 0.5805, Loss G: 0.6060
Epoch [0/50] Batch [400/938]                   Loss D: 0.5649, Loss G: 1.3251
Epoch [0/50] Batch [600/938]                   Loss D: 0.4564, Loss G: 1.2234
Epoch [0/50] Batch [800/938]                   Loss D: 0.4717, Loss G: 0.9241
Epoch [1/50] Batch [0/938]                   Loss D: 0.6011, Loss G: 0.4539
Epoch [1/50] Batch [200/938]                   Loss D: 0.5128, Loss G: 0.5760
Epoch [1/50] Batch [400/938]                   Loss D: 0.6061, Loss G: 0.5210
Epoch [1/50] Batch [600/938]                   Loss D: 0.5292, Loss G: 1.5544
Epoch [1/50] Batch [800/938]                   Loss D: 0.4831, Loss G: 0.9838
Epoch [2/50] Batch [0/938]                   Loss D: 0.4915, Loss G: 1.7939
Epoch [2/50] Batch [200/938]                   Loss D: 0.4932, Loss G: 0.8531
Epoch [2/50] Batch [400/938]                   Loss D: 0.5156, Loss G: