In [None]:
!pip install ipynb

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils as utils
import matplotlib.animation as animation
from IPython.display import HTML
import time
from torch.utils.data import Subset
import torchvision.models as models
import torch.nn.functional as F
from scipy import linalg
import pandas as pd
import os
from ipynb.fs.full.FID import calculate_fretchet


# Reproducibility and hardware selection.
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Optimiser settings (Adam learning rate and first-moment decay).
learning_rate = 2e-4
beta = 0.5

# Training schedule.
batch_size = 128
max_epochs = 50

# Data / model dimensions.
image_size = 64
channels_img = 3
noise_dim = 100
disc_features = 64
gen_features = 64

# Resize to image_size, convert to a tensor, then shift/scale every channel
# with mean 0.5 and std 0.5 so pixel values are centred on zero, in line with
# the Tanh output layer used by the generator defined in this file.
preprocess = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * channels_img, [0.5] * channels_img),
])

# CIFAR10 is downloaded into ./dataset/CIFAR10 if it is not already present.
dataset = datasets.CIFAR10(
    root="./dataset/CIFAR10",
    download=True,
    transform=preprocess,
)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)



class Discriminator(nn.Module):
    """Strided-convolution network that maps an image batch to a sigmoid score.

    The stack is: one biased 4x4/stride-2 convolution with LeakyReLU, three
    bias-free `block` stages that double the channel count each time, and a
    final 4x4/stride-2 convolution to a single channel followed by Sigmoid.
    For 64x64 inputs the output has shape (N, 1, 1, 1).

    Args:
        channels_img: Number of channels in the input images.
        features_d: Channel width of the first convolution; later stages use
            2x, 4x and 8x this value.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        widths = [features_d * factor for factor in (1, 2, 4, 8)]

        # Layers are created in network order so parameter names
        # (disc.0, disc.2.0, ...) follow the position in the stack.
        layers = [
            nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        for narrow, wide in zip(widths, widths[1:]):
            layers.append(self.block(narrow, wide, 4, 2, 1))
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))
        layers.append(nn.Sigmoid())

        self.disc = nn.Sequential(*layers)

    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a bias-free Conv2d followed by LeakyReLU(0.2)."""
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,
        )
        return nn.Sequential(conv, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Run `x` through the convolutional stack and return the scores."""
        scores = self.disc(x)
        return scores
    
   

class Generator(nn.Module):
    """Transposed-convolution network that turns noise maps into images.

    Four `block` stages (ConvTranspose2d + BatchNorm2d + ReLU) are followed
    by a final transposed convolution and Tanh. For a noise input of shape
    (N, channels_noise, 1, 1) the spatial size grows 1 -> 4 -> 8 -> 16 -> 32
    -> 64, so the output has shape (N, channels_img, 64, 64) with values in
    [-1, 1].

    Args:
        channels_noise: Number of channels of the input noise tensor.
        channels_img: Number of channels of the generated images.
        features_g: Base channel width; the stages use 16x, 8x, 4x and 2x
            this value.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()

        # (in_channels, out_channels, stride, padding) for each upsampling
        # stage; every stage uses a 4x4 kernel.
        stage_specs = [
            (channels_noise, features_g * 16, 1, 0),
            (features_g * 16, features_g * 8, 2, 1),
            (features_g * 8, features_g * 4, 2, 1),
            (features_g * 4, features_g * 2, 2, 1),
        ]
        layers = [
            self.block(c_in, c_out, 4, stride, pad)
            for c_in, c_out, stride, pad in stage_specs
        ]
        layers.append(
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            )
        )
        layers.append(nn.Tanh())

        self.net = nn.Sequential(*layers)

    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return bias-free ConvTranspose2d -> BatchNorm2d -> ReLU."""
        upsample = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,
        )
        # NOTE(review): in PyTorch, `momentum` is the weight given to the
        # *new* batch statistic when updating the running estimates, so 0.9
        # tracks the latest batch very closely - confirm this is intended.
        norm = nn.BatchNorm2d(out_channels, momentum=0.9)
        return nn.Sequential(upsample, norm, nn.ReLU())

    def forward(self, x):
        """Generate images from the noise tensor `x`."""
        images = self.net(x)
        return images
    

def initialize_weights(model):
    """Apply DCGAN-style initialisation to every matching submodule of `model`.

    * Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02); their
      biases, if any, are left as constructed.
    * BatchNorm2d scale is drawn from N(1, 0.02) and its shift is set to 0.
      Centring the scale on 1 rather than 0 keeps the normalised activations
      from being multiplied by a near-zero factor at the start of training.
      BatchNorm layers built with ``affine=False`` have no parameters and are
      skipped.

    Modules of any other type are not touched.

    Args:
        model: An ``nn.Module`` whose submodules are initialised in place.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

            
# Build both networks on the selected device, then run the file's weight
# initialiser over each of them (generator first, discriminator second).
gen = Generator(noise_dim, channels_img, gen_features).to(device)
disc = Discriminator(channels_img, disc_features).to(device)
for network in (gen, disc):
    initialize_weights(network)

# One Adam optimiser per network, sharing learning rate and betas.
adam_betas = (beta, 0.999)
optimGenerator = optim.Adam(gen.parameters(), lr=learning_rate, betas=adam_betas)
optimDiscriminator = optim.Adam(disc.parameters(), lr=learning_rate, betas=adam_betas)

# Binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

# A fixed batch of 32 noise vectors, reused so that preview grids taken at
# different points in training are generated from the same inputs.
fixed_noise = torch.randn(32, noise_dim, 1, 1).to(device)
step = 0




# Training block
#
# Alternates one discriminator step and one generator step per mini-batch,
# records both losses every iteration, stores a grid of samples generated
# from `fixed_noise` every 500 iterations (and on the very last batch), and
# evaluates FID once per epoch on the last real/fake batch of that epoch.
gen.train()
disc.train()
GenLoss = []
DiscLoss = []
img_list = []
FID_list = []
best_fid = float('inf')
iters = 0
num_batches = len(dataloader)


# Epochs are numbered from 1 so they can be printed and compared directly.
for epoch in range(1, max_epochs + 1):

    for batch_idx, data in enumerate(dataloader):
        real = data[0].to(device)
        noise = torch.randn(batch_size, noise_dim, 1, 1).to(device)
        fake = gen(noise)

        # Discriminator step: real images are labelled 1, generated ones 0.
        # `fake` is detached so this backward pass does not reach the generator.
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        optimDiscriminator.step()

        # Generator step: the generator is rewarded when the (just updated)
        # discriminator labels its samples as real.
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        optimGenerator.step()
        GenLoss.append(loss_gen.detach().cpu())
        DiscLoss.append(loss_disc.detach().cpu())

        is_final_batch = (epoch == max_epochs) and (batch_idx == num_batches - 1)
        if (iters % 500 == 0) or is_final_batch:
            # Use a separate name for the preview samples: re-binding `fake`
            # here would hand the 32-image CPU preview batch to the FID
            # computation below whenever this branch runs on an epoch's last
            # batch (which always happens in the final epoch).
            with torch.no_grad():
                preview = gen(fixed_noise).detach().cpu()
            img_list.append(utils.make_grid(preview, padding=2, normalize=True))

        iters += 1

    # FID is only a metric, so the generated batch is passed without its
    # autograd graph.
    fretchet_dist = calculate_fretchet(real, fake.detach())
    FID_list.append(fretchet_dist)

    # Track the lowest FID seen so far; this is the place to save the
    # matching samples (e.g. with torchvision.utils.save_image) if wanted.
    if fretchet_dist < best_fid:
        best_fid = fretchet_dist

    if epoch%5 == 0:
        print( f'\nEpoch [{epoch}/{max_epochs}] Batch {batch_idx+1}/{len(dataloader)} \
                  Loss Discriminator: {loss_disc:.3f}, loss Generator: {loss_gen:.3f} FID:{fretchet_dist:.3f} ')


# Per-iteration generator and discriminator losses, saved as a PDF and shown.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.set_title("Generator and Discriminator Training Loss")
loss_ax.plot(GenLoss, label="Gen")
loss_ax.plot(DiscLoss, label="Disc")
loss_ax.set_xlabel("Iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
loss_fig.savefig('DCGAN_loss.pdf', format='pdf', bbox_inches='tight')
plt.show()


# Per-epoch FID values, saved as a PDF and shown.
fid_fig, fid_ax = plt.subplots(figsize=(10, 5))
fid_ax.set_title("FID Scores for DCGAN")
fid_ax.plot(FID_list, label="DCGAN")
fid_ax.set_xlabel("Epochs")
fid_ax.set_ylabel("FID")
fid_ax.legend()
fid_fig.savefig('DCGAN_FID.pdf', format='pdf', bbox_inches='tight')
plt.show()

# Persist the FID history (np.save appends the .npy extension).
np.save('DC_FID', FID_list)

