In [None]:
# Discriminator and Generator implementation from DCGAN paper

import torch
import torch.nn as nn
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from matplotlib.colors import Normalize
from spacy.kb.kb_in_memory import Writer
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms.v2 import Transform

class Discriminator(nn.Module):
    """DCGAN discriminator: a stack of strided convolutions ending in a sigmoid.

    Args:
        img_channels: number of channels in the input images.
        features_d: base channel width; the hidden layers use
            ``features_d``, ``2x``, ``4x`` and ``8x`` that many channels.

    The spatial sizes noted in the comments below hold for a 64 x 64 input,
    for which the output has shape ``(N, 1, 1, 1)``.
    """

    def __init__(self, img_channels, features_d):
        super().__init__()

        # Stem: the first layer has no batch norm (and keeps its conv bias).
        stem = [
            nn.Conv2d(img_channels, features_d, kernel_size=4, stride=2, padding=1),  # 64 -> 32
            nn.LeakyReLU(0.2),
        ]

        # Three down-sampling blocks, each doubling the channel count:
        # 32 -> 16 -> 8 -> 4 spatially.
        widths = [features_d, features_d * 2, features_d * 4, features_d * 8]
        body = [
            self._block(c_in, c_out, 4, 2, 1)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]

        # Head: collapse the 4 x 4 map to a single value squashed into [0, 1].
        head = [
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),  # 4 -> 1
            nn.Sigmoid(),
        ]

        self.disc = nn.Sequential(*stem, *body, *head)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a Conv2d -> BatchNorm2d -> LeakyReLU(0.2) unit.

        The convolution is created without a bias term.
        """
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        act = nn.LeakyReLU(0.2)
        return nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the sigmoid output."""
        return self.disc(x)

class Generator(nn.Module):
    """DCGAN generator: transposed convolutions that up-sample noise to an image.

    Args:
        channels_noise: number of channels of the latent input, which is
            expected as ``N x channels_noise x 1 x 1``.
        channels_img: number of channels in the generated image.
        features_g: base channel width; hidden layers use ``16x``, ``8x``,
            ``4x`` and ``2x`` that many channels.

    For a 1 x 1 latent input the output is ``N x channels_img x 64 x 64``
    with values in ``[-1, 1]`` (final Tanh).
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()

        # (in_channels, out_channels, stride, padding) for each up-sampling block.
        plan = [
            (channels_noise, features_g * 16, 1, 0),   # 1x1   -> 4x4
            (features_g * 16, features_g * 8, 2, 1),   # 4x4   -> 8x8
            (features_g * 8, features_g * 4, 2, 1),    # 8x8   -> 16x16
            (features_g * 4, features_g * 2, 2, 1),    # 16x16 -> 32x32
        ]
        layers = [
            self._block(c_in, c_out, 4, stride, pad)
            for c_in, c_out, stride, pad in plan
        ]

        # Output layer: 32x32 -> 64x64, no batch norm, Tanh squashes to [-1, 1].
        layers.append(
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            )
        )
        layers.append(nn.Tanh())

        self.gen = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a ConvTranspose2d -> BatchNorm2d -> ReLU unit.

        The transposed convolution is created without a bias term.
        """
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        act = nn.ReLU()
        return nn.Sequential(deconv, norm, act)

    def forward(self, x):
        """Map the latent tensor ``x`` to an image tensor."""
        return self.gen(x)


def initialize_weights(model):
    """Re-initialise conv and batch-norm weights in place from N(0, 0.02).

    Every ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.BatchNorm2d`` found in
    ``model.modules()`` has its ``weight`` overwritten with samples from a
    normal distribution with mean 0.0 and standard deviation 0.02. Biases are
    left untouched.

    Args:
        model: any ``nn.Module``; it is modified in place.

    Returns:
        None.
    """
    # NOTE(review): BatchNorm scales are also drawn around 0.0 here, whereas the
    # PyTorch DCGAN tutorial centres them on 1.0 - kept as-is to preserve the
    # existing training behaviour; confirm which is intended.
    for m in model.modules():
        if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            continue
        # BatchNorm2d(affine=False) has no learnable scale (weight is None);
        # skip it instead of failing on ``None.data``.
        if m.weight is None:
            continue
        nn.init.normal_(m.weight.data, 0.0, 0.02)


def test():
    """Smoke-test the output shapes of Discriminator and Generator.

    Feeds a random 8 x 3 x 64 x 64 batch through a Discriminator and a random
    8 x 100 x 1 x 1 latent batch through a Generator (both built with a base
    width of 8), asserts the resulting shapes and prints a success message.

    Raises:
        AssertionError: if either network returns an unexpected shape.
    """
    batch, channels, height, width = 8, 3, 64, 64
    latent_dim = 100

    # Discriminator: one score per image.
    images = torch.randn((batch, channels, height, width))
    critic = Discriminator(channels, 8)
    scores = critic(images)
    assert scores.shape == (batch, 1, 1, 1), "Discriminator test failed"

    # Generator: latent vectors in, full-size images out.
    generator = Generator(latent_dim, channels, 8)
    latent = torch.randn((batch, latent_dim, 1, 1))
    samples = generator(latent)
    assert samples.shape == (batch, channels, height, width), "Generator test failed"

    print("Success, tests passed!")


# Run the shape smoke test when this code is executed as the main module
# (which includes running it as a notebook cell).
if __name__ == "__main__":
    test()



Success, tests passed!


In [None]:
# Training of the DCGAN network on the MNIST dataset, using the Discriminator and Generator defined in the cell above


# Keep this import: it rebinds `Normalize` to the torchvision transform. The
# notebook's first cell imported a different `Normalize` from matplotlib.colors.
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- Hyperparameters -------------------------------------------------------
LEARNING_RATE = 2e-4
BATCH_SIZE = 128
NUM_EPOCHS = 5
FEATURES_GEN = 64
FEATURES_DISC = 64
Z_DIM = 100
IMAGE_SIZE = 64
IMG_CHANNELS = 1

# ---- Data ------------------------------------------------------------------
# Resize to the 64x64 resolution the networks are built for and rescale pixel
# values from [0, 1] to [-1, 1] so real images match the generator's Tanh range.
# One mean/std entry per channel keeps this valid if IMG_CHANNELS changes.
transform_pipeline = Compose(
    [
        Resize(IMAGE_SIZE),
        ToTensor(),
        Normalize(
            [0.5 for _ in range(IMG_CHANNELS)],
            [0.5 for _ in range(IMG_CHANNELS)],
        ),
    ]
)

dataset = datasets.MNIST(root='dataset/', train=True, transform=transform_pipeline, download=True)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# ---- Models and optimisers -------------------------------------------------
gen = Generator(Z_DIM, IMG_CHANNELS, FEATURES_GEN).to(device)
disc = Discriminator(IMG_CHANNELS, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()

# ---- Logging ---------------------------------------------------------------
# A fixed latent batch so the logged samples are comparable across steps.
fixed_noise = torch.randn(32, Z_DIM, 1, 1, device=device)
Writer_fake = SummaryWriter("logs/fake")
Writer_real = SummaryWriter("logs/real")
step = 0

gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)

        # Draw as many latent vectors as there are real images; the last batch
        # of an epoch can be smaller than BATCH_SIZE.
        noise = torch.randn(real.shape[0], Z_DIM, 1, 1, device=device)
        fake = gen(noise)

        ### Train the discriminator: maximise log(D(real)) + log(1 - D(G(z))),
        ### which is the same as minimising the two BCE terms below.
        disc_real = disc(real).view(-1)  # D(real)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))  # -log(D(real))

        # detach() so this step does not back-propagate into the generator.
        disc_fake = disc(fake.detach()).view(-1)  # D(G(z))
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))  # -log(1 - D(G(z)))

        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Train the generator: maximise log(D(G(z))) by labelling the fakes
        ### as real, scored by the freshly updated discriminator.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)}  Loss D: {lossD:.4f}, Loss G: {lossG:.4f}")

            with torch.no_grad():
                fixed_fake = gen(fixed_noise)
                img_grid_real = torchvision.utils.make_grid(
                    real[:32], normalize=True
                )
                # Build the "fake" grid from generated samples, not from `real`.
                img_grid_fake = torchvision.utils.make_grid(
                    fixed_fake[:32], normalize=True
                )

                Writer_real.add_image("real", img_grid_real, global_step=step)
                Writer_fake.add_image("fake", img_grid_fake, global_step=step)

            step += 1

# Push any buffered events to disk so TensorBoard sees the final images.
Writer_real.flush()
Writer_fake.flush()

Epoch [0/5] Batch 0/469  Loss D: 0.6864, Loss G: 0.7920
Epoch [0/5] Batch 100/469  Loss D: 0.0148, Loss G: 4.1165
Epoch [0/5] Batch 200/469  Loss D: 0.7340, Loss G: 0.7049
Epoch [0/5] Batch 300/469  Loss D: 0.5275, Loss G: 1.0397
Epoch [0/5] Batch 400/469  Loss D: 0.5978, Loss G: 1.1456
Epoch [1/5] Batch 0/469  Loss D: 0.5944, Loss G: 0.8268
Epoch [1/5] Batch 100/469  Loss D: 0.6066, Loss G: 0.7971
Epoch [1/5] Batch 200/469  Loss D: 0.6361, Loss G: 0.8992
Epoch [1/5] Batch 300/469  Loss D: 0.6966, Loss G: 0.7966
Epoch [1/5] Batch 400/469  Loss D: 0.6705, Loss G: 0.7006
Epoch [2/5] Batch 0/469  Loss D: 0.5973, Loss G: 0.8740
Epoch [2/5] Batch 100/469  Loss D: 0.5943, Loss G: 0.8484
Epoch [2/5] Batch 200/469  Loss D: 0.6465, Loss G: 0.8421
Epoch [2/5] Batch 300/469  Loss D: 0.5943, Loss G: 1.0008
Epoch [2/5] Batch 400/469  Loss D: 0.8369, Loss G: 1.0922
Epoch [3/5] Batch 0/469  Loss D: 0.5717, Loss G: 1.3017
Epoch [3/5] Batch 100/469  Loss D: 0.5993, Loss G: 1.7477
Epoch [3/5] Batch 200/