In [7]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Dataset
import torchvision.datasets as datasets
from torchvision.datasets import CIFAR10
import torch.optim as optim

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter  

import warnings
warnings.filterwarnings("ignore")

In [3]:
### Hyperparameters

device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 3072               # length of the latent noise vector fed to the Generator
image_dim = 32 * 32 * 3    # 3072 = flattened CIFAR-10 image (C=3, H=W=32); not referenced in the visible cells
batch_size = 32
num_epochs = 50

# Seed torch's RNGs (CPU and all CUDA devices) so weight init, noise and
# DataLoader shuffling are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [4]:
# dataset module

class CIFARDataset(Dataset):
    '''
    CIFAR-10 images (labels dropped) for unconditional GAN training.

    Wraps torchvision's CIFAR10. It downloads the data when requested and
    applies a transform to every image. __getitem__ returns only the image
    tensor. No train/validation splitting is performed here.

    Args:
        root: directory holding (or receiving) the CIFAR-10 files.
        download: if True, download the dataset when it is missing under `root`.
        transform: callable applied to each PIL image. Defaults to ToTensor()
            followed by a per-channel Normalize with MEAN / STD below.
        train: True for the training split, False for the test split. The value
            is passed straight to CIFAR10; the default keeps the original
            training-split behaviour.
    '''
    # Commonly cited per-channel (R, G, B) CIFAR-10 training-set statistics.
    MEAN = (0.49139968, 0.48215827, 0.44653124)
    STD = (0.24703233, 0.24348505, 0.26158768)

    def __init__(self, root = './data', download = True, transform = None, train = True):
        # Load (and, if requested, download) CIFAR-10.
        self.cifar = CIFAR10(root = root, train = train, download = download)

        # Default transformation if no specific transformation is provided.
        if transform is None:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(self.MEAN, self.STD),
            ])
        else:
            self.transform = transform

        # Identity index map over the underlying dataset; __len__ and
        # __getitem__ both go through it.
        self.indices = list(range(len(self.cifar)))

    def __len__(self):
        '''Number of images exposed by this dataset.'''
        return len(self.indices)

    def __getitem__(self, idx):
        '''Return the transformed image at position `idx`; the class label is discarded.'''
        img, _ = self.cifar[self.indices[idx]]

        if self.transform is not None:
            img = self.transform(img)

        return img

    def get_dataloader(self, batch_size = batch_size, shuffle = True):
        '''
        Wrap this dataset in a DataLoader.

        NOTE: the `batch_size` default is captured from the global hyperparameter
        when this class is defined, not when the method is called.
        '''
        return DataLoader(self, batch_size = batch_size, shuffle = shuffle)


NameError: name 'Dataset' is not defined

In [10]:
# Build the CIFAR-10 training set (downloaded into ./data if not already present)
# and a shuffled DataLoader over it.
train_dataset = CIFARDataset(root='./data', download=True)
train_dataloader = train_dataset.get_dataloader(shuffle=True)


In [14]:
# discriminator class
class Discriminator(nn.Module):
    '''
    DCGAN-style CNN discriminator.

    Maps a batch of square images to the probability that each one is real.

    Args:
        in_channels: number of image channels (3 for CIFAR-10).
        image_size: spatial side length of the input images (32 for CIFAR-10).
            It must be at least 8, because the three stride-2 convolutions
            shrink the side to image_size // 8.

    forward:
        in:  (N, in_channels, image_size, image_size)
        out: (N, 1) sigmoid probabilities; pair with nn.BCELoss.
    '''
    def __init__(self, in_channels = 3, image_size = 32):
        super().__init__()
        self.in_channels = in_channels

        # Each conv (kernel 4, stride 2, padding 1) maps side length s -> s // 2.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)

        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(256)

        # Three halvings: image_size -> image_size // 8 (4 for 32x32 inputs,
        # which reproduces the original 256 * 4 * 4 input size).
        feat_size = image_size // 8
        if feat_size < 1:
            raise ValueError(f"image_size must be >= 8, got {image_size}")
        self.fc = nn.Linear(256 * feat_size * feat_size, 1)

    def forward(self, x):
        '''Return per-image real/fake probabilities of shape (N, 1).'''
        x = F.leaky_relu(self.conv1(x), 0.2, inplace=True)
        x = F.leaky_relu(self.bn1(self.conv2(x)), 0.2, inplace=True)
        x = F.leaky_relu(self.bn2(self.conv3(x)), 0.2, inplace=True)
        # Flatten (N, 256, s, s) -> (N, 256*s*s) for the linear head.
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return torch.sigmoid(x)


In [12]:
class Generator(nn.Module):
    '''
    Generates 32x32 images from random latent vectors.

    Args:
        z_dim: length of the input latent vector.
        img_channels: number of output channels (3 for CIFAR-10).

    forward:
        in:  (N, z_dim) latent noise
        out: (N, img_channels, 32, 32). There is no output activation, so the
             values are unbounded.
    '''
    def __init__(self, z_dim, img_channels = 3):
        super().__init__()
        # The index layout of this Sequential fixes the state_dict keys
        # (gen.0 / gen.3 / gen.5). Keep it stable so saved checkpoints still load.
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 8 * 8 * 64),   # (N, z_dim) -> (N, 4096)
            nn.ReLU(),
            nn.Unflatten(1, (64, 8, 8)),    # (N, 4096) -> (N, 64, 8, 8)
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # -> (N, 32, 16, 16)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # -> (N, 16, 32, 32)
            nn.ReLU(),
        )
        # Final 3x3 conv keeps 32x32 and projects to the image channels.
        self.conv = nn.Conv2d(16, img_channels, kernel_size=3, padding=1)  # -> (N, img_channels, 32, 32)

    def forward(self, x):
        '''Map latent vectors of shape (N, z_dim) to images of shape (N, img_channels, 32, 32).'''
        return self.conv(self.gen(x))

In [None]:
# Build models, optimizers and loss, then train the GAN with alternating D/G updates.
disc = Discriminator().to(device)
gen = Generator(z_dim).to(device)

# Fixed latent vectors so the TensorBoard snapshots are comparable across epochs.
fixed_noise = torch.randn((batch_size, z_dim), device=device)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()  # the Discriminator already ends in a sigmoid

writer_fake = SummaryWriter("cifar_logs/fake")
writer_real = SummaryWriter("cifar_logs/real")
step = 0

try:
    for epoch in range(num_epochs):
        for batch_idx, real in enumerate(tqdm(train_dataloader)):
            # real: batched image tensor from CIFARDataset, presumably (B, 3, 32, 32).
            real = real.to(device)
            # Separate name so the global `batch_size` hyperparameter is not
            # overwritten by the smaller last batch of each epoch.
            cur_batch_size = real.shape[0]

            ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
            noise = torch.randn(cur_batch_size, z_dim, device=device)
            fake = gen(noise)
            disc_real = disc(real).view(-1)
            # torch.ones_like: targets of the same shape as the input, filled with ones (= "real").
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            # detach(): the discriminator update must not backprop into the generator.
            # This also removes the need for retain_graph=True.
            disc_fake = disc(fake.detach()).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            lossD = (lossD_real + lossD_fake) / 2
            disc.zero_grad()
            lossD.backward()
            opt_disc.step()

            ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
            ## Only fake samples matter here; they are re-scored by the just-updated discriminator.
            output = disc(fake).view(-1)
            lossG = criterion(output, torch.ones_like(output))
            gen.zero_grad()   # clears old gradients
            lossG.backward()  # computes new gradients through D into G
            opt_gen.step()    # updates generator weights only

            if batch_idx == 0:
                # tqdm.write keeps the progress bar intact, unlike print.
                tqdm.write(
                    f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(train_dataloader)} "
                    f"Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
                )

                with torch.no_grad():
                    fake = gen(fixed_noise)
                    img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                    img_grid_real = torchvision.utils.make_grid(real, normalize=True)

                    writer_fake.add_image(
                        "CIFAR Fake Images", img_grid_fake, global_step=step
                    )
                    writer_real.add_image(
                        "CIFAR Real Images", img_grid_real, global_step=step
                    )
                step += 1
        # Save the model checkpoints (overwritten every epoch).
        torch.save(disc.state_dict(), "cifar_disc.pth")
        torch.save(gen.state_dict(), "cifar_gen.pth")
finally:
    # Flush and close the event files even if training is interrupted.
    writer_fake.close()
    writer_real.close()
    

  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [0/50] Batch 0/1563                       Loss D: 0.7483, loss G: 4.2223


100%|██████████| 1563/1563 [01:06<00:00, 23.43it/s]
  0%|          | 1/1563 [00:00<04:34,  5.70it/s]

Epoch [1/50] Batch 0/1563                       Loss D: 0.0091, loss G: 6.1035


100%|██████████| 1563/1563 [01:08<00:00, 22.80it/s]
  0%|          | 2/1563 [00:00<02:12, 11.74it/s]

Epoch [2/50] Batch 0/1563                       Loss D: 0.1460, loss G: 2.3351


100%|██████████| 1563/1563 [01:10<00:00, 22.11it/s]
  0%|          | 1/1563 [00:00<03:28,  7.50it/s]

Epoch [3/50] Batch 0/1563                       Loss D: 0.0462, loss G: 6.0918


100%|██████████| 1563/1563 [01:06<00:00, 23.41it/s]
  0%|          | 1/1563 [00:00<04:05,  6.37it/s]

Epoch [4/50] Batch 0/1563                       Loss D: 0.0795, loss G: 4.1715


100%|██████████| 1563/1563 [01:10<00:00, 22.14it/s]
  0%|          | 2/1563 [00:00<01:51, 14.02it/s]

Epoch [5/50] Batch 0/1563                       Loss D: 0.0522, loss G: 5.4194


100%|██████████| 1563/1563 [00:59<00:00, 26.11it/s]
  0%|          | 1/1563 [00:00<03:35,  7.26it/s]

Epoch [6/50] Batch 0/1563                       Loss D: 0.0406, loss G: 6.2365


100%|██████████| 1563/1563 [00:55<00:00, 28.12it/s]
  0%|          | 2/1563 [00:00<01:46, 14.70it/s]

Epoch [7/50] Batch 0/1563                       Loss D: 0.0823, loss G: 4.3812


100%|██████████| 1563/1563 [00:58<00:00, 26.73it/s]
  0%|          | 1/1563 [00:00<03:53,  6.70it/s]

Epoch [8/50] Batch 0/1563                       Loss D: 0.0407, loss G: 4.9966


100%|██████████| 1563/1563 [00:55<00:00, 27.95it/s]
  0%|          | 2/1563 [00:00<01:56, 13.40it/s]

Epoch [9/50] Batch 0/1563                       Loss D: 0.2573, loss G: 4.3222


100%|██████████| 1563/1563 [00:58<00:00, 26.50it/s]
  0%|          | 1/1563 [00:00<03:25,  7.59it/s]

Epoch [10/50] Batch 0/1563                       Loss D: 0.1098, loss G: 8.2010


100%|██████████| 1563/1563 [00:59<00:00, 26.08it/s]
  0%|          | 1/1563 [00:00<03:53,  6.70it/s]

Epoch [11/50] Batch 0/1563                       Loss D: 0.0572, loss G: 5.9131


100%|██████████| 1563/1563 [01:03<00:00, 24.64it/s]
  0%|          | 1/1563 [00:00<03:42,  7.03it/s]

Epoch [12/50] Batch 0/1563                       Loss D: 0.1050, loss G: 3.2216


100%|██████████| 1563/1563 [00:59<00:00, 26.47it/s]
  0%|          | 1/1563 [00:00<03:43,  7.00it/s]

Epoch [13/50] Batch 0/1563                       Loss D: 0.5196, loss G: 5.7392


100%|██████████| 1563/1563 [00:56<00:00, 27.49it/s]
  0%|          | 3/1563 [00:00<01:51, 13.97it/s]

Epoch [14/50] Batch 0/1563                       Loss D: 0.0471, loss G: 3.5573


100%|██████████| 1563/1563 [00:57<00:00, 27.31it/s]
  0%|          | 1/1563 [00:00<03:46,  6.88it/s]

Epoch [15/50] Batch 0/1563                       Loss D: 0.0315, loss G: 5.4091


100%|██████████| 1563/1563 [00:58<00:00, 26.92it/s]
  0%|          | 3/1563 [00:00<01:44, 14.87it/s]

Epoch [16/50] Batch 0/1563                       Loss D: 0.0597, loss G: 4.8622


100%|██████████| 1563/1563 [01:02<00:00, 24.88it/s]
  0%|          | 1/1563 [00:00<04:01,  6.48it/s]

Epoch [17/50] Batch 0/1563                       Loss D: 0.0094, loss G: 5.0590


100%|██████████| 1563/1563 [01:00<00:00, 25.85it/s]
  0%|          | 1/1563 [00:00<03:47,  6.87it/s]

Epoch [18/50] Batch 0/1563                       Loss D: 0.0075, loss G: 5.7026


100%|██████████| 1563/1563 [01:06<00:00, 23.56it/s]
  0%|          | 4/1563 [00:00<01:21, 19.16it/s]

Epoch [19/50] Batch 0/1563                       Loss D: 0.4244, loss G: 9.2599


100%|██████████| 1563/1563 [00:56<00:00, 27.58it/s]
  0%|          | 1/1563 [00:00<03:34,  7.29it/s]

Epoch [20/50] Batch 0/1563                       Loss D: 0.0546, loss G: 5.1197


100%|██████████| 1563/1563 [00:53<00:00, 29.07it/s]
  0%|          | 2/1563 [00:00<01:51, 13.95it/s]

Epoch [21/50] Batch 0/1563                       Loss D: 0.0143, loss G: 5.6359


100%|██████████| 1563/1563 [00:56<00:00, 27.62it/s]
  0%|          | 1/1563 [00:00<03:07,  8.35it/s]

Epoch [22/50] Batch 0/1563                       Loss D: 0.0099, loss G: 8.7102


100%|██████████| 1563/1563 [00:56<00:00, 27.48it/s]
  0%|          | 1/1563 [00:00<03:33,  7.33it/s]

Epoch [23/50] Batch 0/1563                       Loss D: 0.0873, loss G: 4.7444


100%|██████████| 1563/1563 [01:35<00:00, 16.43it/s]
  0%|          | 1/1563 [00:00<06:21,  4.10it/s]

Epoch [24/50] Batch 0/1563                       Loss D: 0.0279, loss G: 3.5108


100%|██████████| 1563/1563 [01:33<00:00, 16.70it/s]
  0%|          | 1/1563 [00:00<06:38,  3.92it/s]

Epoch [25/50] Batch 0/1563                       Loss D: 0.0077, loss G: 5.2359


100%|██████████| 1563/1563 [01:23<00:00, 18.66it/s]
  0%|          | 1/1563 [00:00<04:02,  6.45it/s]

Epoch [26/50] Batch 0/1563                       Loss D: 0.0055, loss G: 9.0361


100%|██████████| 1563/1563 [01:19<00:00, 19.76it/s]
  0%|          | 2/1563 [00:00<03:51,  6.73it/s]

Epoch [27/50] Batch 0/1563                       Loss D: 0.0526, loss G: 3.6349


100%|██████████| 1563/1563 [01:23<00:00, 18.78it/s]
  0%|          | 1/1563 [00:00<05:15,  4.96it/s]

Epoch [28/50] Batch 0/1563                       Loss D: 0.0275, loss G: 9.7978


100%|██████████| 1563/1563 [01:33<00:00, 16.73it/s]
  0%|          | 1/1563 [00:00<04:14,  6.14it/s]

Epoch [29/50] Batch 0/1563                       Loss D: 0.1289, loss G: 2.9012


100%|██████████| 1563/1563 [01:29<00:00, 17.46it/s]
  0%|          | 1/1563 [00:00<05:37,  4.63it/s]

Epoch [30/50] Batch 0/1563                       Loss D: 0.0112, loss G: 6.0238


100%|██████████| 1563/1563 [01:33<00:00, 16.73it/s]
  0%|          | 1/1563 [00:00<07:12,  3.61it/s]

Epoch [31/50] Batch 0/1563                       Loss D: 0.0149, loss G: 5.4981


100%|██████████| 1563/1563 [01:21<00:00, 19.07it/s]
  0%|          | 1/1563 [00:00<04:39,  5.58it/s]

Epoch [32/50] Batch 0/1563                       Loss D: 0.1396, loss G: 5.7951


100%|██████████| 1563/1563 [01:23<00:00, 18.74it/s]
  0%|          | 1/1563 [00:00<05:05,  5.12it/s]

Epoch [33/50] Batch 0/1563                       Loss D: 0.3403, loss G: 1.4942


100%|██████████| 1563/1563 [01:21<00:00, 19.18it/s]
  0%|          | 1/1563 [00:00<04:35,  5.66it/s]

Epoch [34/50] Batch 0/1563                       Loss D: 0.0315, loss G: 8.1948


100%|██████████| 1563/1563 [01:28<00:00, 17.76it/s]
  0%|          | 1/1563 [00:00<05:04,  5.14it/s]

Epoch [35/50] Batch 0/1563                       Loss D: 0.0090, loss G: 4.8055


100%|██████████| 1563/1563 [01:28<00:00, 17.57it/s]
  0%|          | 1/1563 [00:00<05:47,  4.49it/s]

Epoch [36/50] Batch 0/1563                       Loss D: 0.0552, loss G: 6.3044


100%|██████████| 1563/1563 [01:28<00:00, 17.60it/s]
  0%|          | 1/1563 [00:00<04:37,  5.64it/s]

Epoch [37/50] Batch 0/1563                       Loss D: 0.0059, loss G: 8.3438


100%|██████████| 1563/1563 [01:30<00:00, 17.29it/s]
  0%|          | 1/1563 [00:00<03:52,  6.72it/s]

Epoch [38/50] Batch 0/1563                       Loss D: 0.0010, loss G: 8.3057


100%|██████████| 1563/1563 [01:16<00:00, 20.34it/s]
  0%|          | 1/1563 [00:00<05:20,  4.88it/s]

Epoch [39/50] Batch 0/1563                       Loss D: 0.0174, loss G: 3.5673


100%|██████████| 1563/1563 [01:17<00:00, 20.07it/s]
  0%|          | 1/1563 [00:00<04:42,  5.52it/s]

Epoch [40/50] Batch 0/1563                       Loss D: 0.0327, loss G: 7.1182


100%|██████████| 1563/1563 [01:15<00:00, 20.66it/s]
  0%|          | 1/1563 [00:00<04:45,  5.47it/s]

Epoch [41/50] Batch 0/1563                       Loss D: 0.2730, loss G: 3.2670


100%|██████████| 1563/1563 [01:16<00:00, 20.37it/s]
  0%|          | 1/1563 [00:00<03:37,  7.17it/s]

Epoch [42/50] Batch 0/1563                       Loss D: 0.2869, loss G: 9.3865


100%|██████████| 1563/1563 [01:16<00:00, 20.46it/s]
  0%|          | 1/1563 [00:00<04:26,  5.87it/s]

Epoch [43/50] Batch 0/1563                       Loss D: 0.0607, loss G: 4.0894


100%|██████████| 1563/1563 [01:16<00:00, 20.44it/s]
  0%|          | 1/1563 [00:00<03:31,  7.38it/s]

Epoch [44/50] Batch 0/1563                       Loss D: 0.0228, loss G: 7.9989


100%|██████████| 1563/1563 [01:12<00:00, 21.71it/s]
  0%|          | 1/1563 [00:00<04:29,  5.80it/s]

Epoch [45/50] Batch 0/1563                       Loss D: 0.0114, loss G: 6.8100


100%|██████████| 1563/1563 [01:04<00:00, 24.41it/s]
  0%|          | 3/1563 [00:00<01:42, 15.19it/s]

Epoch [46/50] Batch 0/1563                       Loss D: 0.2038, loss G: 3.2302


100%|██████████| 1563/1563 [01:07<00:00, 23.16it/s]
  0%|          | 1/1563 [00:00<03:55,  6.63it/s]

Epoch [47/50] Batch 0/1563                       Loss D: 0.0034, loss G: 8.3142


100%|██████████| 1563/1563 [01:04<00:00, 24.34it/s]
  0%|          | 1/1563 [00:00<04:12,  6.20it/s]

Epoch [48/50] Batch 0/1563                       Loss D: 0.0073, loss G: 5.5699


100%|██████████| 1563/1563 [01:04<00:00, 24.18it/s]
  0%|          | 1/1563 [00:00<03:38,  7.14it/s]

Epoch [49/50] Batch 0/1563                       Loss D: 0.0032, loss G: 8.5442


100%|██████████| 1563/1563 [01:11<00:00, 21.89it/s]


In [2]:
# Generate sample images from random noise and display them as an 8x8 grid.
# Requires the earlier cells (imports, hyperparameters, trained `gen`) to have run.
import torch

with torch.no_grad():
    noise = torch.randn(64, z_dim, device=device)
    fake_images = gen(noise)
    # normalize=True min-max rescales the grid into [0, 1] for display.
    img_grid = torchvision.utils.make_grid(fake_images, normalize=True)

fig, ax = plt.subplots(figsize=(8, 8))
# make_grid returns a (C, H, W) tensor; imshow expects an (H, W, C) array on the CPU.
ax.imshow(img_grid.permute(1, 2, 0).cpu().numpy())
ax.set_title("Generator samples (64 random latent vectors)")
ax.axis("off")
plt.show()

NameError: name 'z_dim' is not defined

In [None]:
# training loop

In [None]:
# inference block