# **Task 1**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import torchvision.utils as vutils
import os

class Generator(nn.Module):
    """DCGAN-style generator that upsamples a latent vector to a 32x32 image.

    The network is a stack of four transposed convolutions (no biases) with
    ReLU activations in between and a final Tanh, so outputs lie in [-1, 1].

    Args:
        nz: Number of channels of the latent input; the input is expected to
            have shape (N, nz, 1, 1).
        ngf: Base number of feature maps. The hidden layers use ``4 * ngf``,
            ``2 * ngf`` and ``ngf`` channels respectively.
        nc: Number of channels of the generated image (3 for RGB).

    The defaults reproduce the original fixed architecture
    (100 -> 256 -> 128 -> 64 -> 3) and keep the same ``state_dict`` keys.
    """

    def __init__(self, nz: int = 100, ngf: int = 64, nc: int = 3):
        super().__init__()
        self.main = nn.Sequential(
            # (nz, 1, 1) -> (4*ngf, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
            nn.ReLU(True),
            # (4*ngf, 4, 4) -> (2*ngf, 8, 8)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.ReLU(True),
            # (2*ngf, 8, 8) -> (ngf, 16, 16)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.ReLU(True),
            # (ngf, 16, 16) -> (nc, 32, 32)
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Return generated images of shape (N, nc, 32, 32) for latent ``input``."""
        return self.main(input)

class Discriminator(nn.Module):
    """Convolutional discriminator for 32x32 images that outputs one raw logit.

    Four stride-2 convolutions (no biases) with LeakyReLU(0.2) halve the
    spatial size each time (32 -> 16 -> 8 -> 4 -> 2); the result is flattened
    and fed to a single linear unit. No sigmoid is applied, so the output is
    an unnormalized score suitable for ``nn.BCEWithLogitsLoss``.

    Args:
        nc: Number of channels of the input image (3 for RGB).
        ndf: Base number of feature maps. The convolutions use ``ndf``,
            ``2 * ndf``, ``4 * ndf`` and ``8 * ndf`` output channels.

    The defaults reproduce the original fixed architecture
    (3 -> 64 -> 128 -> 256 -> 512) and keep the same ``state_dict`` keys.
    The linear layer size assumes a 32x32 input.
    """

    def __init__(self, nc: int = 3, ndf: int = 64):
        super().__init__()
        self.main = nn.Sequential(
            # (nc, 32, 32) -> (ndf, 16, 16)
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf, 16, 16) -> (2*ndf, 8, 8)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (2*ndf, 8, 8) -> (4*ndf, 4, 4)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (4*ndf, 4, 4) -> (8*ndf, 2, 2)
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),  # (8*ndf, 2, 2) -> (8*ndf*2*2,)
            nn.Linear(ndf * 8 * 2 * 2, 1),  # single real/fake logit
        )

    def forward(self, input):
        """Return a (N, 1) tensor of real/fake logits for image batch ``input``."""
        return self.main(input)

def _to_uint8_images(images):
    """Map a float image batch from [-1, 1] to uint8 values in [0, 255]."""
    scaled = (images + 1) / 2.0 * 255
    return scaled.clamp(0, 255).to(torch.uint8)


def main():
    """Train the GAN on MNIST and report FID / Inception Score per epoch.

    Side effects:
        * downloads MNIST into ``./data`` if it is not already present;
        * writes a 5x5 grid of samples generated from a fixed noise batch to
          ``generated_images/epoch_<n>.png`` after every epoch;
        * prints losses for every batch and the metrics for every epoch.

    FID and Inception Score are accumulated from every 100th batch of the
    epoch and reset once the epoch's scores have been printed.
    """
    batch_size = 64
    nz = 100
    lr = 0.0002
    beta1 = 0.5
    num_epochs = 5

    # Create directory for saving images
    os.makedirs("generated_images", exist_ok=True)

    # The metrics are kept on the CPU; every tensor passed to them is moved
    # there explicitly so the script also works when training runs on CUDA.
    metric_device = torch.device("cpu")
    # Both metrics are fed uint8 images in [0, 255], which is what they expect
    # with normalize=False. With normalize=True, FID would rescale the uint8
    # input by 255 again and overflow.
    fid = FrechetInceptionDistance(normalize=False).to(metric_device)
    inception = InceptionScore().to(metric_device)

    # DataLoader for MNIST
    dataloader = DataLoader(
        datasets.MNIST('./data', download=True, transform=transforms.Compose([
            transforms.Resize(32),  # Resize images to 32x32
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))  # [0, 1] -> [-1, 1], matching the generator's Tanh
        ])),
        batch_size=batch_size,
        shuffle=True
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    netG = Generator().to(device)
    netD = Discriminator().to(device)

    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    # The discriminator returns raw logits, hence the with-logits loss.
    criterion = nn.BCEWithLogitsLoss()

    fixed_noise = torch.randn(25, nz, 1, 1, device=device)  # Noise vector for generating 25 images

    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            # Convert grayscale images to 3-channel RGB
            real_images = real_images.to(device).repeat(1, 3, 1, 1)

            # The last batch of an epoch may be smaller than batch_size.
            cur_batch = real_images.size(0)
            real_label = torch.full((cur_batch, 1), 1.0, dtype=torch.float, device=device)
            fake_label = torch.full((cur_batch, 1), 0.0, dtype=torch.float, device=device)

            # Update Discriminator: real batch first, then a detached fake batch
            netD.zero_grad()
            output = netD(real_images)
            errD_real = criterion(output, real_label)
            errD_real.backward()
            # Report D(x) as a probability rather than as the mean raw logit.
            D_x = torch.sigmoid(output).mean().item()

            noise = torch.randn(cur_batch, nz, 1, 1, device=device)
            fake = netG(noise)
            output = netD(fake.detach())  # detach: no generator gradients in the D step
            errD_fake = criterion(output, fake_label)
            errD_fake.backward()
            optimizerD.step()

            # Update Generator: try to make the discriminator label fakes as real
            netG.zero_grad()
            output = netD(fake)
            errG = criterion(output, real_label)
            errG.backward()
            optimizerG.step()

            # Feed every 100th batch to the FID / IS metrics
            if i % 100 == 0:
                with torch.no_grad():
                    real_uint8 = _to_uint8_images(real_images).to(metric_device)
                    fake_uint8 = _to_uint8_images(fake.detach()).to(metric_device)

                fid.update(real_uint8, real=True)
                fid.update(fake_uint8, real=False)
                inception.update(fake_uint8)

            print(f'Epoch [{epoch+1}/{num_epochs}] | '
                  f'Batch [{i+1}/{len(dataloader)}] | '
                  f'D Loss: {errD_real.item() + errD_fake.item()} | '
                  f'G Loss: {errG.item()} | '
                  f'D(x): {D_x}')

        # Save 25 images at the end of each epoch
        with torch.no_grad():
            fake_images = netG(fixed_noise).detach().cpu()
        vutils.save_image(fake_images, f"generated_images/epoch_{epoch+1}.png", normalize=True, nrow=5)

        # Compute FID and IS after each epoch
        fid_score = fid.compute()
        inception_score = inception.compute()  # tuple; element 0 is the score that gets reported

        print(f'Epoch [{epoch+1}/{num_epochs}] - FID Score: {fid_score.item()}, Inception Score: {inception_score[0].item()}')

        # Reset FID and Inception metrics after each epoch
        fid.reset()
        inception.reset()

# Run the training entry point only when executed as a script/notebook cell,
# not when this module is imported.
if __name__ == "__main__":
    main()


Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:00<00:00, 343MB/s]


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 15978422.39it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 488584.34it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4402059.88it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2770180.13it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch [1/5] | Batch [1/938] | D Loss: 1.3813143372535706 | G Loss: 0.6884726285934448 | D(x): 0.019158557057380676
Epoch [1/5] | Batch [2/938] | D Loss: 1.2815299034118652 | G Loss: 0.6875267028808594 | D(x): 0.23314513266086578
Epoch [1/5] | Batch [3/938] | D Loss: 1.1368731260299683 | G Loss: 0.6838809847831726 | D(x): 0.6024007797241211
Epoch [1/5] | Batch [4/938] | D Loss: 0.9364640116691589 | G Loss: 0.6716883778572083 | D(x): 1.3589485883712769
Epoch [1/5] | Batch [5/938] | D Loss: 0.7877803407609463 | G Loss: 0.6472336649894714 | D(x): 2.816690444946289
Epoch [1/5] | Batch [6/938] | D Loss: 0.7804920524358749 | G Loss: 0.621026873588562 | D(x): 4.528801441192627
Epoch [1/5] | Batch [7/938] | D Loss: 0.8234195876866579 | G Loss: 0.6062020659446716 | D(x): 5.660635948181152
Epoch [1/5] | Batch [8/938] | D Loss: 0.8476154105737805 | G Loss: 0.613395094871521 | D(x): 5.590493679046631
Epoch [1/5] | Batch [9/9