<a href="https://colab.research.google.com/github/MehrdadDastouri/MNIST-GAN-Generative-Adversarial-Network-for-Digit-Generation/blob/main/Owner%20avatar%20MNIST-GAN-Generative-Adversarial-Network-for-Digit-Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from tqdm import tqdm

# Configuration
class Config:
    """Hyperparameters and run settings for the MNIST GAN, stored as class attributes."""

    # --- model dimensions ---
    latent_dim: int = 100           # length of the generator's noise vector z
    img_size: tuple = (1, 28, 28)   # (channels, height, width) of one MNIST image

    # --- optimisation ---
    batch_size: int = 64
    epochs: int = 50
    lr: float = 0.0002              # Adam learning rate

    # --- runtime ---
    # Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    sample_dir: str = "samples"     # where per-epoch sample grids are written

class Generator(nn.Module):
    """MLP generator that maps latent noise vectors to images.

    The final ``Tanh`` squashes outputs into [-1, 1], which matches the
    ``Normalize([0.5], [0.5])`` preprocessing applied to real images.

    Args:
        latent_dim: Length of each input noise vector ``z``.
        img_shape: Per-sample output shape ``(C, H, W)``. Defaults to the
            MNIST shape ``(1, 28, 28)``. Both the output layer size and the
            reshape in ``forward`` are derived from it, so they cannot
            disagree.
    """

    def __init__(self, latent_dim, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = tuple(img_shape)
        # Number of pixels per image, e.g. 1*28*28 = 784 for MNIST.
        out_features = torch.Size(self.img_shape).numel()
        # Layer order and names are unchanged, so existing state_dicts still load.
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, out_features),
            nn.Tanh()
        )

    def forward(self, z):
        """Map ``z`` of shape (batch, latent_dim) to images of shape (batch, *img_shape)."""
        img = self.model(z)
        return img.view(img.size(0), *self.img_shape)

class Discriminator(nn.Module):
    """MLP discriminator that scores each image with a probability of being real.

    Args:
        img_shape: Per-sample input shape ``(C, H, W)``. Defaults to the MNIST
            shape ``(1, 28, 28)``. The first layer's input size is derived
            from it instead of being hard-coded.
    """

    def __init__(self, img_shape=(1, 28, 28)):
        super().__init__()
        # Flattened pixel count, e.g. 784 for MNIST.
        in_features = torch.Size(tuple(img_shape)).numel()
        # Layer order and names are unchanged, so existing state_dicts still load.
        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # outputs in (0, 1), suitable for nn.BCELoss
        )

    def forward(self, img):
        """Flatten ``img`` per sample and return scores of shape (batch, 1)."""
        flattened = img.view(img.size(0), -1)
        return self.model(flattened)

class GANTrainer:
    """Train the MNIST GAN built from ``Config``, ``Generator`` and ``Discriminator``.

    Construction creates the sample directory, builds both networks on
    ``config.device``, creates their Adam optimizers and builds the MNIST
    training loader. The loader downloads the dataset to ``./data`` if it is
    missing.
    """

    def __init__(self):
        self.config = Config()
        self._init_dirs()

        # Initialize models
        self.generator = Generator(self.config.latent_dim).to(self.config.device)
        self.discriminator = Discriminator().to(self.config.device)

        # Optimizers
        self._init_optimizers()

        # Loss function: the discriminator ends in Sigmoid, so plain BCE applies.
        self.adversarial_loss = nn.BCELoss()

        # Data loader
        self.dataloader = self._get_dataloader()

    def _init_optimizers(self):
        """Create Adam optimizers for G and D, both using ``config.lr``.

        Both use betas=(0.5, 0.999). The discriminator's learning rate must be
        ``config.lr``. Passing ``latent_dim`` (100) as the learning rate makes
        discriminator training diverge.
        """
        betas = (0.5, 0.999)
        self.optimizer_G = optim.Adam(
            self.generator.parameters(), lr=self.config.lr, betas=betas)
        self.optimizer_D = optim.Adam(
            self.discriminator.parameters(), lr=self.config.lr, betas=betas)

    def _init_dirs(self):
        """Create the sample output directory if it does not exist."""
        os.makedirs(self.config.sample_dir, exist_ok=True)

    def _get_dataloader(self):
        """Return a shuffled MNIST training loader with pixels scaled to [-1, 1]."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
        dataset = torchvision.datasets.MNIST(
            root='./data', train=True, download=True, transform=transform)
        return DataLoader(
            dataset, batch_size=self.config.batch_size, shuffle=True)

    def _save_samples(self, epoch):
        """Write a 4x4 grid of generated digits to ``<sample_dir>/epoch_<epoch>.png``.

        The generator runs in eval mode under ``torch.no_grad()``. This way
        sampling builds no autograd graph and does not update BatchNorm
        running statistics. The previous train/eval mode is restored afterwards.
        """
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                z = torch.randn(16, self.config.latent_dim, device=self.config.device)
                gen_imgs = self.generator(z)
        finally:
            self.generator.train(was_training)
        save_image(gen_imgs, os.path.join(
            self.config.sample_dir, f"epoch_{epoch}.png"), nrow=4, normalize=True)

    def train(self):
        """Run the adversarial training loop for ``config.epochs`` epochs.

        Each batch first updates the generator so that D labels its fakes as
        real. It then updates the discriminator on real images vs. the same
        fakes, detached so G receives no gradient. After every epoch a sample
        grid is saved. Every 10 epochs both state dicts are written to the
        current working directory.
        """
        device = self.config.device
        epochs = self.config.epochs
        for epoch in range(epochs):
            progress_bar = tqdm(self.dataloader, desc=f'Epoch {epoch+1}/{epochs}')

            for imgs, _ in progress_bar:
                batch_size = imgs.size(0)  # last batch may be smaller
                valid = torch.ones(batch_size, 1, device=device)
                fake = torch.zeros(batch_size, 1, device=device)
                real_imgs = imgs.to(device)

                # Train Generator
                self.optimizer_G.zero_grad()
                z = torch.randn(batch_size, self.config.latent_dim, device=device)
                gen_imgs = self.generator(z)
                g_loss = self.adversarial_loss(
                    self.discriminator(gen_imgs), valid)
                g_loss.backward()
                self.optimizer_G.step()

                # Train Discriminator
                self.optimizer_D.zero_grad()
                real_loss = self.adversarial_loss(
                    self.discriminator(real_imgs), valid)
                fake_loss = self.adversarial_loss(
                    self.discriminator(gen_imgs.detach()), fake)
                d_loss = (real_loss + fake_loss) / 2
                d_loss.backward()
                self.optimizer_D.step()

                progress_bar.set_postfix(
                    g_loss=g_loss.item(), d_loss=d_loss.item())

            # Save sample images at each epoch
            self._save_samples(epoch)

            # Save model checkpoints
            if (epoch+1) % 10 == 0:
                torch.save(self.generator.state_dict(),
                         f"generator_epoch_{epoch+1}.pth")
                torch.save(self.discriminator.state_dict(),
                         f"discriminator_epoch_{epoch+1}.pth")

if __name__ == "__main__":
    # Construct the trainer (creates the sample dir, builds the models and
    # optimizers, and prepares the MNIST loader), then run the training loop.
    gan_trainer = GANTrainer()
    gan_trainer.train()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:01<00:00, 5.24MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 154kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:01<00:00, 1.47MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 2.38MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



Epoch 1/50: 100%|██████████| 938/938 [01:00<00:00, 15.43it/s, d_loss=59.4, g_loss=81.2]
Epoch 2/50:  54%|█████▍    | 510/938 [00:33<00:24, 17.67it/s, d_loss=57, g_loss=85.9]