# **Task 2**

In [None]:
!pip install torchmetrics[image] torch-fidelity


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import torchvision.utils as vutils
import os

# Generator of the SN-GAN (DCGAN-style, BatchNorm only; spectral normalization is applied in the discriminator, not here)
class Generator(nn.Module):
    """DCGAN-style generator that maps a latent vector to a 3x32x32 image.

    The network is a stack of transposed convolutions with BatchNorm and
    ReLU, finished by ``Tanh`` so the output values lie in [-1, 1].  No
    spectral normalization is applied in this module.

    Args:
        nz: Number of channels of the latent input.  Defaults to 100, the
            value that was previously hard-coded.
    """

    def __init__(self, nz=100):
        super().__init__()
        self.nz = nz
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, 512, 4, 1, 0, bias=False),  # Output: (512, 4, 4)
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),  # Output: (256, 8, 8)
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),  # Output: (128, 16, 16)
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False),  # Output: (3, 32, 32)
            nn.Tanh()
        )

    def forward(self, input):
        """Generate images from latent codes.

        Args:
            input: Latent tensor of shape (N, nz, 1, 1).

        Returns:
            Tensor of shape (N, 3, 32, 32) with values in [-1, 1].
        """
        return self.main(input)

# Spectral Normalization Discriminator
class Discriminator(nn.Module):
    """Spectrally normalized convolutional critic for 3x32x32 images.

    Four stride-2 convolutions, each wrapped in ``nn.utils.spectral_norm``
    and followed by LeakyReLU(0.2), reduce the input to a (512, 2, 2)
    feature map.  That map is flattened and projected to one raw logit per
    image (no sigmoid is applied here).
    """

    def __init__(self):
        super().__init__()

        def sn_conv(in_channels, out_channels):
            # 4x4 kernel with stride 2 and padding 1 halves height and width.
            conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)
            return nn.utils.spectral_norm(conv)

        # (in_channels, out_channels) per stage:
        # (3, 32, 32) -> (64, 16, 16) -> (128, 8, 8) -> (256, 4, 4) -> (512, 2, 2)
        channel_plan = ((3, 64), (64, 128), (128, 256), (256, 512))

        layers = []
        for in_channels, out_channels in channel_plan:
            layers.append(sn_conv(in_channels, out_channels))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Flatten())  # (512, 2, 2) -> (2048,)
        layers.append(nn.Linear(512 * 2 * 2, 1))

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return one unnormalized real/fake logit per image, shape (N, 1)."""
        logits = self.main(input)
        return logits

def main(num_epochs=5, batch_size=128, eval_batches=20):
    """Train the SN-GAN on CIFAR-10 and report FID / Inception Score per epoch.

    After every epoch a 5x5 grid of samples drawn from a fixed noise batch is
    written to ``sngan_generated_images/epoch_<n>.png`` and both metrics are
    computed on freshly generated images versus real images.

    Args:
        num_epochs: Number of passes over the training set.
        batch_size: Mini-batch size used for training and for evaluation.
        eval_batches: Number of mini-batches of real / generated images fed
            to the metrics after each epoch.

    Raises:
        ValueError: If ``eval_batches`` is smaller than 1.
    """
    if eval_batches < 1:
        raise ValueError("eval_batches must be at least 1")

    # Hyperparameters
    nz = 100  # Size of latent vector
    lr = 0.0002
    beta1 = 0.5

    # Create directory for saving images
    os.makedirs("sngan_generated_images", exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize FID and IS on the same device as the images they receive.
    # With normalize=True both metrics expect float images in [0, 1] and do
    # the conversion to uint8 themselves; passing uint8 tensors here would
    # overflow in that internal `* 255` step.
    fid = FrechetInceptionDistance(normalize=True).to(device)
    inception = InceptionScore(normalize=True).to(device)

    # DataLoader for CIFAR-10
    transform = transforms.Compose([
        transforms.Resize(32),  # Resize images to 32x32
        transforms.ToTensor(),
        # Map each of the three channels from [0, 1] to [-1, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataloader = DataLoader(
        datasets.CIFAR10('./data', download=True, transform=transform),
        batch_size=batch_size,
        shuffle=True
    )

    netG = Generator().to(device)
    netD = Discriminator().to(device)

    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    criterion = nn.BCEWithLogitsLoss()

    fixed_noise = torch.randn(25, nz, 1, 1, device=device)  # Fixed noise for generating 25 images

    for epoch in range(num_epochs):
        for i, (real, _) in enumerate(dataloader):
            real = real.to(device)

            # The last batch of an epoch may be smaller than `batch_size`,
            # so labels and noise are sized from the actual batch.
            cur_batch = real.size(0)
            real_label = torch.full((cur_batch, 1), 1.0, dtype=torch.float, device=device)
            fake_label = torch.full((cur_batch, 1), 0.0, dtype=torch.float, device=device)

            # Update Discriminator
            netD.zero_grad()
            errD_real = criterion(netD(real), real_label)
            errD_real.backward()

            noise = torch.randn(cur_batch, nz, 1, 1, device=device)
            fake = netG(noise)
            # detach(): the discriminator step must not backprop into G
            errD_fake = criterion(netD(fake.detach()), fake_label)
            errD_fake.backward()
            optimizerD.step()

            # Update Generator (wants D to label its samples as real)
            netG.zero_grad()
            errG = criterion(netD(fake), real_label)
            errG.backward()
            optimizerG.step()

            # Print training stats
            if i % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}] | Batch [{i+1}/{len(dataloader)}] | '
                      f'D Loss: {errD_real.item() + errD_fake.item()} | G Loss: {errG.item()}')

        # Save 25 images at the end of each epoch
        with torch.no_grad():
            fake_images = netG(fixed_noise).detach().cpu()
        vutils.save_image(fake_images, f"sngan_generated_images/epoch_{epoch+1}.png", normalize=True, nrow=5)

        # Compute FID and IS after each epoch over `eval_batches` mini-batches
        # instead of only the final (possibly tiny) training batch.  The
        # generator is left in the same mode as for the image grid above.
        with torch.no_grad():
            # range() comes first so zip stops without fetching an extra batch
            for _, (real, _labels) in zip(range(eval_batches), dataloader):
                real = real.to(device)
                noise = torch.randn(real.size(0), nz, 1, 1, device=device)
                fake = netG(noise)

                # [-1, 1] -> [0, 1] floats, as required by normalize=True
                real_unit = ((real + 1) / 2.0).clamp(0, 1)
                fake_unit = ((fake + 1) / 2.0).clamp(0, 1)

                fid.update(real_unit, real=True)
                fid.update(fake_unit, real=False)
                inception.update(fake_unit)

        fid_score = fid.compute()
        inception_score = inception.compute()  # (mean, std) over splits

        print(f'Epoch [{epoch+1}/{num_epochs}] - FID Score: {fid_score.item()}, Inception Score: {inception_score[0].item()}')

        # Reset FID and IS metrics after each epoch
        fid.reset()
        inception.reset()

# Script entry point: start training only when this file is executed
# directly (or run as a notebook cell), not when it is imported.
if __name__ == '__main__':
    main()


Collecting torch-fidelity
  Downloading torch_fidelity-0.3.0-py3-none-any.whl.metadata (2.0 kB)
Collecting torchmetrics[image]
  Downloading torchmetrics-1.4.2-py3-none-any.whl.metadata (19 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics[image])
  Downloading lightning_utilities-0.11.7-py3-none-any.whl.metadata (5.2 kB)
Downloading torch_fidelity-0.3.0-py3-none-any.whl (37 kB)
Downloading lightning_utilities-0.11.7-py3-none-any.whl (26 kB)
Downloading torchmetrics-1.4.2-py3-none-any.whl (869 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m869.2/869.2 kB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning-utilities, torchmetrics, torch-fidelity
Successfully installed lightning-utilities-0.11.7 torch-fidelity-0.3.0 torchmetrics-1.4.2


Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:00<00:00, 409MB/s]


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 83770407.24it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Epoch [1/5] | Batch [1/391] | D Loss: 1.3933030366897583 | G Loss: 0.7236005067825317
Epoch [1/5] | Batch [101/391] | D Loss: 0.8858745396137238 | G Loss: 1.714690923690796
Epoch [1/5] | Batch [201/391] | D Loss: 0.9719308018684387 | G Loss: 1.0440354347229004
Epoch [1/5] | Batch [301/391] | D Loss: 1.2604663968086243 | G Loss: 0.6766989827156067
Epoch [1/5] - FID Score: 364.4429931640625, Inception Score: 2.208373785018921
Epoch [2/5] | Batch [1/391] | D Loss: 1.23647540807724 | G Loss: 0.9379602074623108
Epoch [2/5] | Batch [101/391] | D Loss: 1.2632673978805542 | G Loss: 0.7409140467643738
Epoch [2/5] | Batch [201/391] | D Loss: 1.353330135345459 | G Loss: 0.8858590722084045
Epoch [2/5] | Batch [301/391] | D Loss: 1.3169527053833008 | G Loss: 0.858943521976471
Epoch [2/5] - FID Score: 334.59283447265625, Inception Score: 2.0675463676452637
Epoch [3/5] | Batch [1/391] | D Loss: 1.328259527683258 | G Loss: 0.8953539133071899
Epoch [3/