**E9 333 Advanced Deep Representation Learning (2022)**

---


Assignment 01

---


Submission by: Dhruv Bhardwaj (SR 19280), Bhartendu Kumar (SR 19649), MTech AI

# Variational Auto Encoders (VAE)

## Standard Imports

In [None]:
import torch
torch.manual_seed(42)

from torch.nn import Conv2d, ConvTranspose2d, Linear, Embedding
from torch.nn import MaxPool2d, BatchNorm2d
from torch.nn import LeakyReLU, Tanh, ReLU, Sigmoid
from torch.nn import Module
from torch.nn import MSELoss
from torch import flatten
import numpy
import random
import os, os.path
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image

def seed_worker(worker_id):
    """Seed NumPy and the stdlib ``random`` module for a DataLoader worker.

    A 32-bit seed is derived from ``torch.initial_seed()`` and applied to both
    libraries. ``worker_id`` is accepted so the function can be used as a
    ``worker_init_fn`` but is not otherwise read.
    """
    derived_seed = torch.initial_seed() % 2**32
    for apply_seed in (numpy.random.seed, random.seed):
        apply_seed(derived_seed)
    return

# Dedicated torch RNG with a fixed seed; getDataloader hands it to DataLoader
# as `generator`.
g = torch.Generator()
g.manual_seed(42)

## Configurations for different tasks

### Vanilla VAE - CelebA Dataset

In [None]:
# Settings for the vanilla VAE experiment on the CelebA images.
# NOTE(review): the runner cells read these names from an imported `cfg`
# module; presumably this cell mirrors `config_celeba` — confirm.

# Data location, file filter, output folder and optimiser settings.
DATA_PATH   =   './datasets/img_align_celeba_resampled/'
FILE_EXTN   =   '.jpg'
SAVE_PATH   =   './logs/'
BATCH_SIZE  =   100
LEARN_RATE  =   0.001
EPOCHS      =   100

# Input image geometry (height, width, channels).
INPUT_H     = 64
INPUT_W     = 64
INPUT_CH    = 3

# Latent size, number of images to generate per checkpoint, and a data
# variance value (presumably the pixel variance of the dataset — TODO confirm).
LATENT_DIM  = 128
NUM_GENERATE_SAMPLES    =   100
DATA_VAR = 0.09

### Vanilla VAE - dsprites Dataset

In [None]:
# Settings for the vanilla VAE experiment on the dsprites images.
# NOTE(review): this cell rebinds the same names as the CelebA cell above it;
# whichever config cell runs last wins — confirm the intended run order.

# Data location, file filter, output folder and optimiser settings.
DATA_PATH   =   './datasets/dsprites/'
FILE_EXTN   =   '.jpg'
SAVE_PATH   =   './logs/'
BATCH_SIZE  =   100
LEARN_RATE  =   0.001
EPOCHS      =   100

# Input image geometry (height, width, channels); single channel here.
INPUT_H     = 64
INPUT_W     = 64
INPUT_CH    = 1

# Latent size, number of images to generate per checkpoint, and a data
# variance value (presumably the pixel variance of the dataset — TODO confirm).
LATENT_DIM  = 128
NUM_GENERATE_SAMPLES    =   100
DATA_VAR = 0.037

### VQ-VAE - Tiny ImageNet Dataset

In [None]:
# Settings for the VQ-VAE experiment on the Tiny ImageNet images.
# NOTE(review): the runner cells read these names from an imported `cfg`
# module; presumably this cell mirrors `config_tiny_imagenet` — confirm.
DATA_PATH   =   './datasets/tiny_imagenet/'

# File filter, output folder and optimiser settings.
FILE_EXTN   =   '.JPEG'
SAVE_PATH   =   './logs/'
BATCH_SIZE  =   128
LEARN_RATE  =   0.0002
EPOCHS      =   100

# Input image geometry (height, width, channels).
INPUT_H     = 64
INPUT_W     = 64
INPUT_CH    = 3

# Geometry of the encoder output that gets quantised (height, width, channels).
EMBED_H     = 4
EMBED_W     = 4
EMBED_CH    = 128

# Codebook size and the weight applied to the commitment loss.
NUM_EMBEDDINGS = 128
BETA            = 0.25
NUM_GENERATE_SAMPLES    =   100
DATA_VAR = 0.0765

# Settings for fitting a Gaussian mixture on the latent space after PCA.
N_GMM_COMPONENTS = 32
GMM_FIT_NUM_SAMPLES = 10000
GMM_GEN_NUM_SAMPLES = 100
GMM_NUM_PCA_FEATURES = 256

# Settings for fitting a vanilla VAE on the latent space.
VAE_FIT_NUM_SAMPLES =   100 # samples per batch
VAE_FIT_LEARN_RATE  =   0.001
VAE_FIT_EPOCHS      =   50

VAE_FIT_LATENT_DIM  = 32
VAE_GEN_NUM_SAMPLES    =   100
VAE_FIT_DATA_VAR = 0.077

## DataLoader Class and Utilities Definitions

In [None]:
class ImageDataset(Dataset):
    """Dataset over the files in one folder whose names end with ``extn``.

    Each item is the image read with ``read_image``, converted to float and
    divided by 255.
    """

    def __init__(self,img_folder, extn='.jpg'):
        self.img_folder = img_folder
        self.extn = extn
        # Keep only directory entries carrying the requested extension.
        self.img_list = list(
            filter(lambda fname: fname.endswith(self.extn), os.listdir(self.img_folder))
        )

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self,index):
        img_path = self.img_folder + '/' + self.img_list[index]
        return read_image(img_path).float() / 255.0

def getDataloader(data_path, batch_size, extn):
    """Return a shuffling DataLoader over the images found in ``data_path``.

    Args:
        data_path: folder scanned by ``ImageDataset``.
        batch_size: number of images per batch.
        extn: file-name suffix used to select images.

    The loader uses the module-level generator ``g`` and ``seed_worker`` so
    shuffling and worker seeding are driven by fixed seeds.
    """
    print('[INFO] DATA_PATH={}, BATCH_SIZE={}'.format(data_path,batch_size))
    dataset = ImageDataset(data_path, extn)
    print('[INFO] Found data set with {} samples'.format(len(dataset)))
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        worker_init_fn=seed_worker,
        generator=g,
    )

# Smoke test: build the loader and print every batch's size and variance.
# NOTE(review): `cfg` is not bound anywhere in this cell — confirm a config
# module is imported as `cfg` before this block runs.
if __name__ == '__main__':
    print(cfg.DATA_PATH)
    data = getDataloader(cfg.DATA_PATH, cfg.BATCH_SIZE, cfg.FILE_EXTN)
    for image_batch in data:        
        print(image_batch.size())
        print(torch.var(image_batch))

In [None]:
import datasets as DS
def save_image_to_file(epoch,image_tensor, save_path,ref_str=None):
    """Write a batch of images to ``<save_path>[ref_str]SAMPLE_IMGS_E<epoch>.jpg``.

    Args:
        epoch: value appended to the file name after ``SAMPLE_IMGS_E``.
        image_tensor: image batch passed to ``torchvision.utils.save_image``.
        save_path: string prefix of the output path (concatenated as-is, so it
            should end with a path separator).
        ref_str: optional tag inserted in front of ``SAMPLE_IMGS_E``.

    The images are laid out in a grid with 10 images per row.
    """
    # Imported here because nothing else in the file brings `save_image`
    # into scope; without it the call below raises NameError.
    from torchvision.utils import save_image

    print(image_tensor.size())
    tag = ref_str if ref_str is not None else ''
    filestr = save_path + tag + 'SAMPLE_IMGS_E' + str(epoch) + '.jpg'
    save_image(image_tensor, filestr, nrow=10)
    return

def return_random_batch_from_dir(img_folder, file_extn, num_samples):
    """Load a random selection of images from a folder as one float tensor.

    Args:
        img_folder: folder to scan.
        file_extn: file-name suffix used to select images.
        num_samples: number of images requested.

    Returns:
        A tensor with the selected images concatenated along dim 0 and scaled
        by 1/255, or an empty list when no image can be selected (no matching
        file, or ``num_samples`` below 1).

    If the folder holds fewer matching files than ``num_samples``, all of
    them are returned instead of letting ``random.sample`` raise ValueError.
    """
    img_list = [name for name in os.listdir(img_folder) if name.endswith(file_extn)]
    # random.sample refuses to draw more items than the population holds.
    count = min(num_samples, len(img_list))
    if count < 1:
        return []

    batch = []
    for name in random.sample(img_list, count):
        img = read_image(img_folder + '/' + name).float() / 255.0
        batch.append(img.unsqueeze(0))
    samples = torch.cat(batch)
    print(samples.size())
    return samples

## Vanilla VAE Class Definitions

In [None]:
### FOR USE WITH CELEBA/DSPRITES DATASET ###
class Encoder(Module):
    """Convolutional encoder with two linear heads.

    Four stride-2 stages (conv -> LeakyReLU -> BatchNorm) are followed by the
    heads ``fcMu`` and ``fcCov``; ``forward`` returns both head outputs
    concatenated along dim 1, so the result has ``2 * latent_dim`` columns.
    The heads expect 256*4*4 flattened features, which is what the four
    stride-2 stages produce for a 64x64 input.
    """

    def __init__(self, in_channels, latent_dim):
        super().__init__()

        self.latent_dim = latent_dim
        self.in_channels = in_channels
        self.output_dim = self.latent_dim

        # Registers conv1/bnorm1/relu1 ... conv4/bnorm4/relu4, in that order.
        widths = (self.in_channels, 32, 64, 128, 256)
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, 'conv{}'.format(stage),
                    Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=2, padding=1))
            setattr(self, 'bnorm{}'.format(stage), BatchNorm2d(c_out))
            setattr(self, 'relu{}'.format(stage), LeakyReLU())

        flat_features = 256*4*4
        self.fcMu = Linear(in_features=flat_features, out_features=self.output_dim)
        self.fcCov = Linear(in_features=flat_features, out_features=self.output_dim)

        # Keeps the second head's output non-negative.
        self.relu5 = ReLU()

    def forward(self, x):
        for stage in range(1, 5):
            conv = getattr(self, 'conv{}'.format(stage))
            act = getattr(self, 'relu{}'.format(stage))
            norm = getattr(self, 'bnorm{}'.format(stage))
            x = norm(act(conv(x)))

        features = flatten(x, start_dim=1)
        mean_part = self.fcMu(features)
        spread_part = self.relu5(self.fcCov(features))
        return torch.cat((mean_part, spread_part), dim=1)

class Decoder(Module):
    """Transposed-convolution decoder mapping a latent vector to an image.

    A linear layer expands the latent vector to a 256x4x4 map; four stride-2
    stages (transposed conv -> LeakyReLU -> BatchNorm) upsample it, and a
    final transposed conv plus tanh produces ``output_channels`` channels.
    The tanh output is rescaled so returned values lie in [0, 1].
    """

    def __init__(self, latent_dim, output_dim, output_channels):
        super().__init__()

        self.input_dim = latent_dim
        self.output_dim = output_dim
        self.output_channels = output_channels

        self.fc1 = Linear(self.input_dim, out_features=256*4*4)

        # Registers convT<s>/bnorm<s>/relu<s> for s in 1, 2, 3, 31, in order.
        upsampling = (('1', 256, 128), ('2', 128, 64), ('3', 64, 32), ('31', 32, 32))
        for suffix, c_in, c_out in upsampling:
            setattr(self, 'convT' + suffix,
                    ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=3,
                                    stride=2, padding=1, output_padding=1))
            setattr(self, 'bnorm' + suffix, BatchNorm2d(c_out))
            setattr(self, 'relu' + suffix, LeakyReLU())

        self.convT4 = ConvTranspose2d(in_channels=32, out_channels=self.output_channels,kernel_size=3, padding=1)
        self.tanh = Tanh()

    def forward(self,x):
        feature_map = torch.reshape(self.fc1(x), (-1, 256, 4, 4))

        for suffix in ('1', '2', '3', '31'):
            up = getattr(self, 'convT' + suffix)
            act = getattr(self, 'relu' + suffix)
            norm = getattr(self, 'bnorm' + suffix)
            feature_map = norm(act(up(feature_map)))

        out = self.tanh(self.convT4(feature_map))
        # Map the tanh range [-1, 1] onto [0, 1].
        return 0.5 + (0.5*out)

# Define VAE class
class Varn_AE(Module):
    """Variational auto-encoder built from this cell's ``Encoder``/``Decoder``.

    The encoder output is split into a mean ``mu`` and a diagonal covariance
    ``dCov`` (per-dimension variance). ``forward`` stores both on the
    instance so that ``criterion`` can compute the KL term afterwards.

    Args:
        input_dim: stored and forwarded to the decoder as ``output_dim``.
        latent_dim: size of the latent vector.
        in_channels: channel count of the input (and reconstructed) images.
        var_norm: stored on the instance; the current loss does not use it.
    """

    def __init__(self, input_dim, latent_dim, in_channels, var_norm):
        super().__init__()

        # Added to dCov so its log and square root stay finite at zero.
        self.eps = 0.000001
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.in_channels = in_channels
        self.var_norm = var_norm

        self.mse_loss = torch.nn.MSELoss(reduction='sum')
        self.kld_loss = self.kld_gaussian

        self.mse_loss_value = 0.0
        self.kl_loss_value = 0.0

        self.mu = torch.zeros(self.latent_dim)
        self.dCov = torch.zeros(self.latent_dim)

        self.encoder = Encoder(self.in_channels,self.latent_dim)
        self.decoder = Decoder(self.latent_dim, self.input_dim, self.in_channels)

        self.encoder.float()
        self.decoder.float()

        print()
        print('-'*59)
        print('ENCODER MODEL')
        print('-'*59)
        print(self.encoder)

        print('-'*59)
        print('DECODER MODEL')
        print('-'*59)
        print(self.decoder)
        print()
        return

    def forward(self,x):
        """Encode ``x``, draw one latent sample per input and decode it."""
        encoded = self.encoder(x)

        self.mu = encoded[:,:self.latent_dim]
        self.dCov = encoded[:,self.latent_dim:] + self.eps

        # Reparameterisation: z = mu + sigma * noise. `kld_gaussian` treats
        # dCov as a variance, so sigma is its square root; scaling the noise
        # by dCov itself would sample from a different distribution than the
        # one the KL term regularises.
        noise = torch.randn_like(self.dCov)
        z = self.mu + (noise*torch.sqrt(self.dCov))

        return self.decoder(z)

    def kld_gaussian(self):
        """KL divergence of N(mu, diag(dCov)) from N(0, I), summed over all entries."""
        mu_sq = torch.square(self.mu)
        loss_j = (1.0 + torch.log(self.dCov)) - mu_sq - self.dCov
        kld_loss = -1*0.5*torch.sum(loss_j)
        return kld_loss

    def criterion(self, x, x_hat):
        """Return ``(total, reconstruction, kl)`` for the latest forward pass.

        The reconstruction term is the summed squared error between ``x`` and
        ``x_hat``; it is not scaled by ``var_norm``.
        """
        x_hat = torch.reshape(x_hat, (x.size(0),-1))
        x = torch.reshape(x,(x.size(0),-1,))

        self.mse_loss_value = self.mse_loss(x,x_hat)
        self.kl_loss_value = self.kld_loss()

        return self.kl_loss_value + self.mse_loss_value, self.mse_loss_value, self.kl_loss_value

    def sample(self,num_samples=100, curr_device="cpu"):
        """Decode ``num_samples`` draws from a standard normal latent prior."""
        # Remember the decoder's mode so it is restored rather than forced
        # into training mode afterwards.
        was_training = self.decoder.training
        self.decoder.eval()
        try:
            z = torch.randn((num_samples,self.latent_dim),device=curr_device)
            samples = self.decoder(z)
        finally:
            self.decoder.train(was_training)
        return samples


In [None]:
### FOR USE WITH VQ-VAE ###
class Encoder(Module):
    """Convolutional encoder with two linear heads on 256 flattened features.

    Four stride-2 stages (conv -> LeakyReLU -> BatchNorm) are followed by the
    heads ``fcMu`` and ``fcCov``; ``forward`` returns both head outputs
    concatenated along dim 1. The heads take 256 input features, i.e. the
    last stage must leave a 1x1 map (which a 4x4 input does).
    """

    def __init__(self, in_channels, latent_dim):
        super().__init__()

        self.latent_dim = latent_dim
        self.in_channels = in_channels
        self.output_dim = self.latent_dim

        # Registers conv1/bnorm1/relu1 ... conv4/bnorm4/relu4, in that order.
        widths = (self.in_channels, 32, 64, 128, 256)
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, 'conv{}'.format(stage),
                    Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=2, padding=1))
            setattr(self, 'bnorm{}'.format(stage), BatchNorm2d(c_out))
            setattr(self, 'relu{}'.format(stage), LeakyReLU())

        self.fcMu = Linear(in_features=256, out_features=self.output_dim)
        self.fcCov = Linear(in_features=256, out_features=self.output_dim)
        # Keeps the second head's output non-negative.
        self.relu5 = ReLU()

    def forward(self, x):
        for stage in range(1, 5):
            conv = getattr(self, 'conv{}'.format(stage))
            act = getattr(self, 'relu{}'.format(stage))
            norm = getattr(self, 'bnorm{}'.format(stage))
            x = norm(act(conv(x)))

        features = flatten(x, start_dim=1)
        mean_part = self.fcMu(features)
        spread_part = self.relu5(self.fcCov(features))
        return torch.cat((mean_part, spread_part), dim=1)

class Decoder(Module):
    """Decoder that maps a latent vector to a 4x4 map without upsampling.

    A linear layer expands the latent vector to a 16x4x4 map; four stride-1
    stages (transposed conv -> LeakyReLU -> BatchNorm) reduce the channel
    count, and a final transposed conv plus tanh produces
    ``output_channels`` channels. The tanh output is rescaled so returned
    values lie in [0, 1].
    """

    def __init__(self, latent_dim, output_dim, output_channels):
        super().__init__()

        self.input_dim = latent_dim
        self.output_dim = output_dim
        self.output_channels = output_channels

        self.fc1 = Linear(self.input_dim, out_features=16*4*4)

        # Registers convT<s>/bnorm<s>/relu<s> for s in 1, 2, 3, 31, in order.
        stages = (('1', 16, 8), ('2', 8, 4), ('3', 4, 4), ('31', 4, 4))
        for suffix, c_in, c_out in stages:
            setattr(self, 'convT' + suffix,
                    ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1))
            setattr(self, 'bnorm' + suffix, BatchNorm2d(c_out))
            setattr(self, 'relu' + suffix, LeakyReLU())

        self.convT4 = ConvTranspose2d(in_channels=4, out_channels=self.output_channels,kernel_size=3, padding=1)
        self.tanh = Tanh()

    def forward(self,x):
        feature_map = torch.reshape(self.fc1(x), (-1, 16, 4, 4))

        for suffix in ('1', '2', '3', '31'):
            up = getattr(self, 'convT' + suffix)
            act = getattr(self, 'relu' + suffix)
            norm = getattr(self, 'bnorm' + suffix)
            feature_map = norm(act(up(feature_map)))

        out = self.tanh(self.convT4(feature_map))
        # Map the tanh range [-1, 1] onto [0, 1].
        return 0.5 + (0.5*out)

# Define VAE class
class Varn_AE(Module):
    """Variational auto-encoder built from this cell's ``Encoder``/``Decoder``.

    The encoder output is split into a mean ``mu`` and a diagonal covariance
    ``dCov`` (per-dimension variance). ``forward`` stores both on the
    instance so that ``criterion`` can compute the KL term afterwards.

    Args:
        input_dim: stored and forwarded to the decoder as ``output_dim``.
        latent_dim: size of the latent vector.
        in_channels: channel count of the input (and reconstructed) tensors.
        var_norm: stored on the instance; the current loss does not use it.
    """

    def __init__(self, input_dim, latent_dim, in_channels, var_norm):
        super().__init__()

        # Added to dCov so its log and square root stay finite at zero.
        self.eps = 0.000001
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.in_channels = in_channels
        self.var_norm = var_norm

        self.mse_loss = torch.nn.MSELoss(reduction='sum')
        self.kld_loss = self.kld_gaussian

        self.mse_loss_value = 0.0
        self.kl_loss_value = 0.0

        self.mu = torch.zeros(self.latent_dim)
        self.dCov = torch.zeros(self.latent_dim)

        self.encoder = Encoder(self.in_channels,self.latent_dim)
        self.decoder = Decoder(self.latent_dim, self.input_dim, self.in_channels)

        self.encoder.float()
        self.decoder.float()

        print()
        print('-'*59)
        print('ENCODER MODEL')
        print('-'*59)
        print(self.encoder)

        print('-'*59)
        print('DECODER MODEL')
        print('-'*59)
        print(self.decoder)
        print()
        return

    def forward(self,x):
        """Encode ``x``, draw one latent sample per input and decode it."""
        encoded = self.encoder(x)

        self.mu = encoded[:,:self.latent_dim]
        self.dCov = encoded[:,self.latent_dim:] + self.eps

        # Reparameterisation: z = mu + sigma * noise. `kld_gaussian` treats
        # dCov as a variance, so sigma is its square root; scaling the noise
        # by dCov itself would sample from a different distribution than the
        # one the KL term regularises.
        noise = torch.randn_like(self.dCov)
        z = self.mu + (noise*torch.sqrt(self.dCov))

        return self.decoder(z)

    def kld_gaussian(self):
        """KL divergence of N(mu, diag(dCov)) from N(0, I), summed over all entries."""
        mu_sq = torch.square(self.mu)
        loss_j = (1.0 + torch.log(self.dCov)) - mu_sq - self.dCov
        kld_loss = -1*0.5*torch.sum(loss_j)
        return kld_loss

    def criterion(self, x, x_hat):
        """Return ``(total, reconstruction, kl)`` for the latest forward pass.

        The reconstruction term is the summed squared error between ``x`` and
        ``x_hat``.
        """
        x_hat = torch.reshape(x_hat, (x.size(0),-1))
        x = torch.reshape(x,(x.size(0),-1,))

        self.mse_loss_value = self.mse_loss(x,x_hat)
        self.kl_loss_value = self.kld_loss()

        return self.kl_loss_value + self.mse_loss_value, self.mse_loss_value, self.kl_loss_value

    def sample(self,num_samples=100, curr_device="cpu"):
        """Decode ``num_samples`` draws from a standard normal latent prior."""
        # Remember the decoder's mode so it is restored rather than forced
        # into training mode afterwards.
        was_training = self.decoder.training
        self.decoder.eval()
        try:
            z = torch.randn((num_samples,self.latent_dim),device=curr_device)
            samples = self.decoder(z)
        finally:
            self.decoder.train(was_training)
        return samples


## VQ-VAE Class Definitions

In [None]:
class Encoder(Module):
    """Fully convolutional encoder used by the VQ-VAE.

    Four stride-2 stages (conv -> LeakyReLU -> BatchNorm) halve the spatial
    size each time; the last stage emits ``out_channels`` channels. There
    are no linear heads: ``forward`` returns the final feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Registers conv1/bnorm1/relu1 ... conv4/bnorm4/relu4, in that order.
        widths = (self.in_channels, 32, 64, 128, self.out_channels)
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, 'conv{}'.format(stage),
                    Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=2, padding=1))
            setattr(self, 'bnorm{}'.format(stage), BatchNorm2d(c_out))
            setattr(self, 'relu{}'.format(stage), LeakyReLU())

    def forward(self, x):
        out = x
        for stage in range(1, 5):
            conv = getattr(self, 'conv{}'.format(stage))
            act = getattr(self, 'relu{}'.format(stage))
            norm = getattr(self, 'bnorm{}'.format(stage))
            out = norm(act(conv(out)))
        return out

class Decoder(Module):
    """Transposed-convolution decoder used by the VQ-VAE.

    The input is reshaped to a 128x4x4 map, passed through four stride-2
    stages (transposed conv -> LeakyReLU -> BatchNorm) and a final transposed
    conv plus tanh that emits ``output_channels`` channels. The tanh output
    is rescaled so returned values lie in [0, 1].
    """

    def __init__(self, latent_dim, output_dim, output_channels):
        super().__init__()

        self.input_dim = latent_dim
        self.output_dim = output_dim
        self.output_channels = output_channels

        # Registers convT<s>/bnorm<s>/relu<s> for s in 1, 2, 3, 31, in order.
        upsampling = (('1', 128, 64), ('2', 64, 32), ('3', 32, 16), ('31', 16, 16))
        for suffix, c_in, c_out in upsampling:
            setattr(self, 'convT' + suffix,
                    ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=3,
                                    stride=2, padding=1, output_padding=1))
            setattr(self, 'bnorm' + suffix, BatchNorm2d(c_out))
            setattr(self, 'relu' + suffix, LeakyReLU())

        self.convT4 = ConvTranspose2d(in_channels=16, out_channels=self.output_channels,kernel_size=3, padding=1)
        self.tanh = Tanh()

    def forward(self,x):
        feature_map = torch.reshape(x, (-1, 128, 4, 4))

        for suffix in ('1', '2', '3', '31'):
            up = getattr(self, 'convT' + suffix)
            act = getattr(self, 'relu' + suffix)
            norm = getattr(self, 'bnorm' + suffix)
            feature_map = norm(act(up(feature_map)))

        out = self.tanh(self.convT4(feature_map))
        # Map the tanh range [-1, 1] onto [0, 1].
        return 0.5 + (0.5*out)

class VQ(Module):
    """Vector-quantisation layer with a learnable codebook.

    ``forward`` flattens a 4-D encoder output to rows of length
    ``size(2) * size(3)``, replaces every row by its nearest codebook vector
    and returns the quantised rows together with the codebook loss.

    Args:
        embedding_dim: length of each codebook vector; must equal
            ``size(2) * size(3)`` of the tensors given to ``forward``.
        num_embeddings: number of codebook vectors.
        beta: weight of the commitment term in the loss.
    """

    def __init__(self, embedding_dim, num_embeddings, beta):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        self.mse_loss = MSELoss(reduction='sum')

        self.embedding_loss_value = 0.0
        self.commitment_loss_value = 0.0
        self.vq_loss = 0.0

        self.dictionary = Embedding(self.num_embeddings, self.embedding_dim)
        self.dictionary.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)

        self.beta = beta
        return

    def findNearest(self,x):
        """Return, as an ``[N, 1]`` tensor, the index of the closest codebook row for each row of ``x``.

        ``x`` is ``[N, embedding_dim]``. Squared distances are computed as
        ``|x|^2 + |e|^2 - 2 x.e`` for every pair of input row and codebook row.
        """
        xref = self.dictionary.weight
        x_norm2 = torch.linalg.norm(x, dim=1,keepdim=True)**2          # [N, 1]
        xref_norm2 = torch.linalg.norm(xref, dim=1,keepdim=True)**2    # [NUM_EMBED, 1]

        # [N, 1] + [1, NUM_EMBED] broadcasts to the full distance matrix.
        dist = x_norm2 + torch.transpose(xref_norm2, 0, 1) - 2*torch.matmul(x, torch.transpose(xref, 0 , 1))

        return torch.argmin(dist, dim=1).unsqueeze(1)

    def forward(self,z_e):
        """Quantise ``z_e`` and return ``(z_q, vq_loss)``.

        ``z_q`` has shape ``[B*C, H*W]`` and carries straight-through
        gradients to ``z_e``. ``vq_loss`` is the embedding loss plus ``beta``
        times the commitment loss, both summed squared errors.
        """
        z_e = torch.reshape(z_e,(z_e.size(0)*z_e.size(1),z_e.size(2)*z_e.size(3)))
        nearest_idxs = self.findNearest(z_e)

        # Drop only the index column. A bare squeeze() would also drop a
        # size-1 row or embedding axis, after which the MSE below broadcasts
        # to the wrong shape.
        z_q = self.dictionary.weight[nearest_idxs.squeeze(1),:]

        # Embedding loss moves the codebook towards the encoder output;
        # commitment loss moves the encoder output towards the codebook.
        self.embedding_loss_value = self.mse_loss(z_e.detach(),z_q)
        self.commitment_loss_value = self.mse_loss(z_e,z_q.detach())

        self.vq_loss = self.embedding_loss_value + self.beta*self.commitment_loss_value

        # Straight-through estimator: forward value is z_q, gradient flows to z_e.
        z_q = z_e + (z_q - z_e).detach()

        return z_q, self.vq_loss

class VQ_Varn_AE(Module):
    """VQ-VAE: encoder -> vector quantiser -> decoder.

    Args:
        input_dim: stored and forwarded to the decoder as ``output_dim``.
        embedding_dim: length of each codebook vector in the ``VQ`` layer.
        in_channels: channel count of the input (and reconstructed) images.
        embed_channels: channel count of the encoder output.
        var_norm: stored on the instance; the current loss does not use it.
        num_embeddings: number of codebook vectors.
        beta: weight of the commitment term in the VQ loss.
    """

    def __init__(self, input_dim, embedding_dim, in_channels, embed_channels, var_norm, num_embeddings, beta):
        super().__init__()

        self.input_dim = input_dim
        self.in_channels = in_channels
        self.embed_channels = embed_channels
        self.var_norm = var_norm
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = beta

        # The attribute keeps its historical name, but the reconstruction
        # loss is binary cross-entropy. Referenced through `torch.nn` because
        # `BCELoss` is not among the names imported at the top of the file.
        self.mse_loss = torch.nn.BCELoss(reduction='sum')

        self.mse_loss_value = 0.0

        self.encoder = Encoder(self.in_channels,self.embed_channels)
        self.vq = VQ(self.embedding_dim, self.num_embeddings, self.beta)
        self.decoder = Decoder(self.embedding_dim, self.input_dim, self.in_channels)

        self.encoder.float()
        self.vq.float()
        self.decoder.float()

        print()
        print('-'*59)
        print('ENCODER MODEL')
        print('-'*59)
        print(self.encoder)

        print('-'*59)
        print('VQ LAYER')
        print('-'*59)
        print(self.vq)

        print('-'*59)
        print('DECODER MODEL')
        print('-'*59)
        print(self.decoder)
        print()

        return

    def forward(self,x):
        """Return ``(x_hat, vq_loss)`` for the input batch ``x``."""
        z_e = self.encoder(x)

        # Replace encoder outputs by their nearest codebook vectors.
        z_q, vq_loss = self.vq(z_e)

        x_hat = self.decoder(z_q)

        return x_hat, vq_loss

    def criterion(self, x, x_hat, vq_loss):
        """Return ``(total_loss, reconstruction_loss)``.

        The reconstruction loss is the summed binary cross-entropy between
        ``x_hat`` and ``x``; ``total_loss`` adds ``vq_loss`` to it.
        """
        x_hat = torch.reshape(x_hat, (x.size(0),-1))
        x = torch.reshape(x,(x.size(0),-1,))

        self.mse_loss_value = self.mse_loss(x_hat,x)

        total_loss = self.mse_loss_value + vq_loss

        return total_loss, self.mse_loss_value

    def sample(self, x_in=None, num_samples=100, curr_device="cpu"):
        """Reconstruct ``x_in`` with all sub-modules in evaluation mode.

        Returns ``None`` when ``x_in`` is ``None``. ``num_samples`` and
        ``curr_device`` are accepted but not used. The sub-modules' previous
        train/eval modes are restored on every exit path, including the
        early return.
        """
        parts = (self.decoder, self.vq, self.encoder)
        previous_modes = [part.training for part in parts]
        for part in parts:
            part.eval()

        try:
            if x_in is None:
                return None
            z_e = self.encoder(x_in)
            z_q, _ = self.vq(z_e)
            return self.decoder(z_q)
        finally:
            for part, mode in zip(parts, previous_modes):
                part.train(mode)



## Training Runner - VAE (Question 1)

In [None]:
import time
import datetime

import config_celeba as cfg

import utils as util

from datasets import getDataloader
from vae_2_CNN import Varn_AE
#from torch.distributions.multivariate_normal import MultivariateNormal
#########################LOGGER#########################
import sys

class Logger(object):
    """File-like tee: everything written goes to the original stdout and to a log file.

    The stream that is ``sys.stdout`` at construction time is kept as
    ``terminal``; the log file is opened in append mode.
    """

    def __init__(self, filename="Default.log"):
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # Forward the flush so log content reaches disk when callers
        # (e.g. print(..., flush=True)) ask for it.
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Flush and close the log file; the terminal stream is left open."""
        if not self.log.closed:
            self.log.flush()
            self.log.close()

# From here on everything printed is also appended to the experiment log
# under cfg.SAVE_PATH.
sys.stdout = Logger(cfg.SAVE_PATH + 'expt_1_celeba.txt')
#########################################################3

# Use the default CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
#########################################################3
def train_VAE():
    """Train a ``Varn_AE`` with Adam using the settings in ``cfg``.

    After every epoch whose mean reconstruction loss is below 500, a grid of
    generated samples and the whole model are written under
    ``cfg.SAVE_PATH``.

    Returns:
        The trained model.
    """
    print('-' * 59)
    torch.cuda.empty_cache()
    vae_model = Varn_AE(cfg.INPUT_H*cfg.INPUT_W, cfg.LATENT_DIM, cfg.INPUT_CH, cfg.DATA_VAR)
    vae_model.to(device)
    vae_model.train()

    optimizer = torch.optim.Adam(vae_model.parameters(), lr=cfg.LEARN_RATE)

    data = getDataloader(cfg.DATA_PATH,cfg.BATCH_SIZE, cfg.FILE_EXTN)
    # Normaliser for the reported losses (batches * batch size).
    N = len(data)*cfg.BATCH_SIZE
    print('-' * 59)
    print("Starting Training of model")
    epoch_times = []

    for epoch in range(1,cfg.EPOCHS+1):
        start_time = time.process_time()
        total_loss = 0.0
        mse_loss = 0.0
        kl_loss = 0.0
        for counter, image_batch in enumerate(data, start=1):
            # Images with more channels than configured are reduced to
            # their first channel.
            if(image_batch.size(1)>cfg.INPUT_CH):
                image_batch = image_batch[:,0,:,:].unsqueeze(1)
            # Transfer once and reuse for both the forward pass and the loss.
            image_batch = image_batch.to(device)

            optimizer.zero_grad()
            x_hat = vae_model(image_batch)

            loss,m,k = vae_model.criterion(image_batch, x_hat)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            mse_loss += m.item()
            kl_loss += k.item()
            if counter%500 == 0:
                print("Epoch {}......Step: {}/{}....... Loss={:12.5} (MSE Loss = {:12.5}, KL Loss= {:12.5})"
                .format(epoch, counter, len(data), total_loss/N,mse_loss/N,kl_loss/N))

        current_time = time.process_time()
        print("Epoch {}/{} Done, Loss = {:12.5} (MSE Loss = {:12.5}, KL Loss= {:12.5})"
        .format(epoch, cfg.EPOCHS, total_loss/N,mse_loss/N,kl_loss/N))

        # Format the number itself: applying a ".5" precision to str(value)
        # merely truncates the text to five characters.
        print("Total Time Elapsed={:12.5f} seconds".format(current_time-start_time))

        if((mse_loss/N) < 500):
            samples = vae_model.sample(cfg.NUM_GENERATE_SAMPLES,device)
            util.save_image_to_file(epoch,samples, cfg.SAVE_PATH)
            torch.save(vae_model, cfg.SAVE_PATH + 'MODEL_E' + str(epoch) + datetime.date.today().strftime("%B %d, %Y") + '.pth')

        epoch_times.append(current_time-start_time)
        print('-' * 59)

    print("Total Training Time={:12.5f} seconds".format(sum(epoch_times)))
    return vae_model

def compute_mv_normal(x,mu=None,var=1):
    """Log-density of each row of ``x`` under an isotropic Gaussian N(mu, var*I).

    Args:
        x: 2-D array of shape ``(n, d)``.
        mu: mean, broadcastable against a row of ``x`` (e.g. shape ``(d,)``
            or ``(1, d)``); defaults to the zero vector.
        var: scalar variance shared by all dimensions.

    Returns:
        Array of shape ``(n,)`` with the log-densities.

    NOTE(review): if ``mu`` has one row per row of ``x`` (with more than one
    row), rows are paired; the earlier loop combined every row of ``x`` with
    all rows of ``mu`` at once.
    """
    # Imported locally: the file imports the package as `numpy`, not `np`.
    import numpy as np

    x = np.asarray(x)
    n, d = x.shape[0], x.shape[1]
    if mu is None:
        mu = np.zeros((d,))

    # Log of the normalising constant (2*pi*var)^(-d/2).
    K = -0.5*d*np.log(2*np.pi*var)
    diff = np.asarray(x - mu, dtype=np.float64).reshape(n, -1)
    sq_dist = np.sum(diff*diff, axis=1)
    return K - sq_dist/(2*var)

def estimate_marginal_prob(vae_model, num_input_samples):
    """Estimate a marginal-probability value for each image of one batch.

    One batch of ``num_input_samples`` images is drawn from ``cfg.DATA_PATH``.
    For every image, two independent sets of
    ``cfg.NUM_LATENT_SAMPLES_PER_IP`` latent samples are drawn from the
    encoder output: the first set fits a Gaussian kernel density, the second
    is scored under that density, under an isotropic Gaussian prior
    (variance 0.584) and, after decoding, under an isotropic Gaussian around
    the image (variance 0.5). The returned value is the reciprocal of the
    mean of ``exp(log q - log p(z) - log p(x|z))``.

    Args:
        vae_model: model exposing ``encoder`` and ``decoder``; it is moved to
            ``device`` and left in evaluation mode.
        num_input_samples: batch size requested from the data loader.

    Returns:
        NumPy array of shape ``(num_input_samples,)``; it stays all zeros if
        the loader yields no batch.
    """
    # Imported locally: neither name is bound by the imports of this file.
    import numpy as np
    from sklearn.neighbors import KernelDensity

    vae_model.to(device)
    vae_model.eval()
    data = getDataloader(cfg.DATA_PATH,num_input_samples, cfg.FILE_EXTN)

    marginal_p = np.zeros((num_input_samples,))

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for image_batch in data:
            z = vae_model.encoder(image_batch.to(device))

            # The encoder output holds the mean in its first half and the
            # spread in its second half.
            n = int(z.size(1)/2)
            mu = z[:,:n].unsqueeze(1).expand(-1,cfg.NUM_LATENT_SAMPLES_PER_IP,-1)
            dCov = z[:,n:].unsqueeze(1).expand(-1,cfg.NUM_LATENT_SAMPLES_PER_IP,-1)

            # Two independent sample sets: one to fit q(z|x), one to score.
            eps = torch.randn_like(dCov)
            z_1 = mu + (eps*dCov)

            eps = torch.randn_like(dCov)
            z_2 = mu + (eps*dCov)

            for i in range(0,num_input_samples):
                x = image_batch[i,:,:,:].detach().numpy()

                z1_sample = z_1[i,:,:].squeeze().cpu().detach().numpy()
                q_z = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(z1_sample)

                x_hat = vae_model.decoder(z_2[i,:,:].squeeze()).cpu().detach().numpy()

                z2_sample = z_2[i,:,:].squeeze().cpu().detach().numpy()

                # All probabilities below are log-probabilities.
                q_prob = q_z.score_samples(z2_sample)
                p_prob = compute_mv_normal(z2_sample,None,0.584)
                x_hat = np.reshape(x_hat,(-1,cfg.INPUT_H*cfg.INPUT_W*cfg.INPUT_CH))
                x = np.reshape(np.expand_dims(x, axis=0),(-1,cfg.INPUT_H*cfg.INPUT_W*cfg.INPUT_CH))

                p_x_z_prob = compute_mv_normal(x_hat,x,0.5)

                logp = q_prob - (p_prob + p_x_z_prob)

                p_inv = np.mean(np.exp(logp))
                marginal_p[i] = 1/p_inv

            # Only the first batch is evaluated.
            break
    return marginal_p
    

## Training Runner - VQ VAE (Question 3)

In [None]:
import config_tiny_imagenet as cfg

import utils as util

from datasets import getDataloader
from vq_vae_2_CNN import VQ_Varn_AE
from vae_3_CNN import Varn_AE

from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA
#########################LOGGER#########################
import sys

class Logger(object):
    """File-like tee: everything written goes to the original stdout and to a log file.

    The stream that is ``sys.stdout`` at construction time is kept as
    ``terminal``; the log file is opened in append mode.
    """

    def __init__(self, filename="Default.log"):
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # Forward the flush so log content reaches disk when callers
        # (e.g. print(..., flush=True)) ask for it.
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Flush and close the log file; the terminal stream is left open."""
        if not self.log.closed:
            self.log.flush()
            self.log.close()

# From here on everything printed is also appended to the experiment log
# under cfg.SAVE_PATH.
sys.stdout = Logger(cfg.SAVE_PATH + 'expt_3_tinyimagenet.txt')
#########################################################3

# Selects the CUDA device with index 2 whenever CUDA is available.
# NOTE(review): this presumably fails later on machines with fewer than
# three GPUs — confirm the target hardware.
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
print(device)
#########################################################3
def train_VAE():
    """Train a ``VQ_Varn_AE`` with Adam using the settings in ``cfg``.

    After every epoch whose mean reconstruction loss is below 10000, a random
    batch of training images and its reconstruction are written as image
    grids, and the whole model is saved under ``cfg.SAVE_PATH``.

    Returns:
        The trained model.
    """
    print('-' * 59)
    torch.cuda.empty_cache()
    vae_model = VQ_Varn_AE(cfg.INPUT_H*cfg.INPUT_W, cfg.EMBED_H*cfg.EMBED_W, cfg.INPUT_CH, cfg.EMBED_CH ,cfg.DATA_VAR, cfg.NUM_EMBEDDINGS, cfg.BETA)
    vae_model.to(device)
    vae_model.train()

    optimizer = torch.optim.Adam(vae_model.parameters(), lr=cfg.LEARN_RATE)

    data = getDataloader(cfg.DATA_PATH,cfg.BATCH_SIZE, cfg.FILE_EXTN)
    # Normaliser for the reported losses (batches * batch size).
    N = len(data)*cfg.BATCH_SIZE
    print('-' * 59)
    print("Starting Training of model")
    epoch_times = []

    for epoch in range(1,cfg.EPOCHS+1):
        start_time = time.process_time()
        total_loss = 0.0
        mse_loss = 0.0
        vq = 0.0

        for counter, image_batch in enumerate(data, start=1):
            # Transfer once and reuse for both the forward pass and the loss.
            image_batch = image_batch.to(device)

            optimizer.zero_grad()
            # The model in this file returns (x_hat, vq_loss); the starred
            # target also tolerates variants that return extra values.
            x_hat, vq_loss, *_ = vae_model(image_batch)
            loss, m = vae_model.criterion(image_batch, x_hat, vq_loss)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            mse_loss += m.item()
            vq += vq_loss.item()

            if counter%500 == 0:
                print("Epoch {}......Step: {}/{}....... Loss={:12.5} (BCE Loss = {:12.5}, EMB + COMT Loss= {:12.5})"
                .format(epoch, counter, len(data), total_loss/N,mse_loss/N,vq/N))

        current_time = time.process_time()
        print("Epoch {}/{} Done, Loss = {:12.5} (BCE Loss = {:12.5}, EMB + COMT Loss= {:12.5})"
        .format(epoch, cfg.EPOCHS, total_loss/N,mse_loss/N,vq/N))

        # Format the number itself: applying a ".5" precision to str(value)
        # merely truncates the text to five characters.
        print("Total Time Elapsed={:12.5f} seconds".format(current_time-start_time))

        if((mse_loss/N) < 10000):
            samples = util.return_random_batch_from_dir(cfg.DATA_PATH, cfg.FILE_EXTN, cfg.NUM_GENERATE_SAMPLES)
            r_samples = vae_model.sample(samples.to(device), cfg.NUM_GENERATE_SAMPLES,device)
            util.save_image_to_file(epoch,r_samples, cfg.SAVE_PATH,'RECON')
            util.save_image_to_file(epoch, samples, cfg.SAVE_PATH, 'ORIG')
            torch.save(vae_model, cfg.SAVE_PATH + 'MODEL_E' + str(epoch) + datetime.date.today().strftime("%B %d, %Y") + '.pth')

        epoch_times.append(current_time-start_time)
        print('-' * 59)

    print("Total Training Time={:12.5f} seconds".format(sum(epoch_times)))
    return vae_model

def fit_vae_training_data(vq_vae_model):
    """Train a ``Varn_AE`` on the latents produced by an already-trained VQ-VAE.

    Every image batch is pushed through ``vq_vae_model.encoder`` and
    ``vq_vae_model.vq``; the first returned value, reshaped to
    ``(-1, cfg.EMBED_CH, cfg.EMBED_H, cfg.EMBED_W)``, is both the input and the
    reconstruction target of the new model. The VQ-VAE is put in eval mode and
    is never updated here.

    Args:
        vq_vae_model: Trained VQ-VAE exposing ``encoder`` and ``vq``.

    Returns:
        The trained ``Varn_AE`` instance (still in train mode).
    """
    vq_vae_model.eval()
    data = getDataloader(cfg.DATA_PATH, cfg.VAE_FIT_NUM_SAMPLES, cfg.FILE_EXTN)

    vae_model = Varn_AE(cfg.EMBED_H*cfg.EMBED_W, cfg.VAE_FIT_LATENT_DIM, cfg.EMBED_CH, cfg.VAE_FIT_DATA_VAR)
    vae_model.to(device)
    vae_model.train()
    optimizer = torch.optim.Adam(vae_model.parameters(), lr=cfg.VAE_FIT_LEARN_RATE)

    # Normaliser for the reported losses.
    # NOTE(review): presumably VAE_FIT_NUM_SAMPLES is the batch size, making N
    # roughly the number of samples per epoch -- confirm against getDataloader.
    N = len(data)*cfg.VAE_FIT_NUM_SAMPLES
    print('-' * 59)
    print("Starting Training of VAE model on latent space")
    epoch_times = []

    for epoch in range(1, cfg.VAE_FIT_EPOCHS+1):
        start_time = time.process_time()
        total_loss = 0.0
        mse_loss = 0.0
        kl_loss = 0.0
        counter = 0
        for image_batch in data:
            # The VQ-VAE is frozen: compute its latents without building an
            # autograd graph so backward() does not flow into its parameters.
            with torch.no_grad():
                z_q, _ = vq_vae_model.vq(vq_vae_model.encoder(image_batch.to(device)))
                z_q = torch.reshape(z_q, (-1, cfg.EMBED_CH, cfg.EMBED_H, cfg.EMBED_W))

            counter += 1
            optimizer.zero_grad()
            x_hat = vae_model.forward(z_q)
            loss, m, k = vae_model.criterion(z_q, x_hat)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            mse_loss += m.item()
            kl_loss += k.item()
            if counter % 500 == 0:
                print("Epoch {}......Step: {}/{}....... Loss={:12.5} (MSE Loss = {:12.5}, KL Loss= {:12.5})"
                .format(epoch, counter, len(data), total_loss/N, mse_loss/N, kl_loss/N))

        current_time = time.process_time()
        # Report against this stage's own epoch budget (was cfg.EPOCHS, the VQ-VAE budget).
        print("Epoch {}/{} Done, Loss = {:12.5} (MSE Loss = {:12.5}, KL Loss= {:12.5})"
        .format(epoch, cfg.VAE_FIT_EPOCHS, total_loss/N, mse_loss/N, kl_loss/N))

        # Format the float directly: a precision on a str truncates it to 5 characters.
        print("Total Time Elapsed={:12.5} seconds".format(current_time-start_time))

        # Only sample and checkpoint once the reconstruction error is below the threshold.
        if (mse_loss/N) < 500:
            samples = vae_model.sample(cfg.NUM_GENERATE_SAMPLES, device)
            util.save_image_to_file(epoch, samples, cfg.SAVE_PATH)
            torch.save(vae_model, cfg.SAVE_PATH + 'MODEL_E' + str(epoch) + datetime.date.today().strftime("%B %d, %Y") + '.pth')

        epoch_times.append(current_time-start_time)
        print('-' * 59)

    print("Total Training Time={:12.5} seconds".format(sum(epoch_times)))

    return vae_model

def fit_gmm_training_data(vq_vae_model):
    """Fit a PCA projection and a diagonal-covariance GMM on encoder outputs.

    Only the first batch delivered by the data loader is used: it is encoded
    with ``vq_vae_model.encoder``, flattened to
    ``(-1, cfg.EMBED_H*cfg.EMBED_W*cfg.EMBED_CH)``, reduced with a whitened PCA
    and then used to fit the mixture model.

    Args:
        vq_vae_model: Trained VQ-VAE exposing ``encoder``.

    Returns:
        Tuple ``(gmm, pca)`` of the fitted ``GaussianMixture`` and ``PCA``.
    """
    vq_vae_model.eval()

    gmm = GaussianMixture(n_components=cfg.N_GMM_COMPONENTS, covariance_type='diag',
                          max_iter=500, verbose=3, tol=0.0001,
                          verbose_interval=10,
                          random_state=42)

    pca = PCA(n_components=cfg.GMM_NUM_PCA_FEATURES, whiten=True)

    data = getDataloader(cfg.DATA_PATH, cfg.GMM_FIT_NUM_SAMPLES, cfg.FILE_EXTN)

    print('-' * 59)
    print("Generate latent space samples from training data")

    for image_batch in data:
        # Inference only: skip autograd bookkeeping for the whole batch.
        # Despite its name, z_q holds the raw encoder output (no quantisation step).
        with torch.no_grad():
            z_q = vq_vae_model.encoder(image_batch.to(device))
            z_q = torch.reshape(z_q, (-1, cfg.EMBED_H*cfg.EMBED_W*cfg.EMBED_CH))

        print("Reduce data dimensionality of latent samples:")
        reduced_z_q = pca.fit_transform(z_q.cpu().detach().numpy())
        print("Fit GMM on latent samples of this batch:")
        gmm.fit(reduced_z_q)
        # A single batch is all that is fitted.
        break

    return gmm, pca

def generate_samples_from_gmm_fit(gmm,pca,vq_vae_model, num_samples):
    """Draw latents from a fitted GMM, decode them and save the images.

    The GMM draws are mapped back through ``pca.inverse_transform``, reshaped
    to ``(num_samples*cfg.EMBED_CH, cfg.EMBED_H*cfg.EMBED_W)``, decoded with
    ``vq_vae_model.decoder`` and written via ``util.save_image_to_file`` with
    the fixed index 99 and the tag ``'GMM_FIT_'``.

    Args:
        gmm: Fitted mixture model providing ``sample``.
        pca: Fitted PCA providing ``inverse_transform``.
        vq_vae_model: Trained VQ-VAE exposing ``decoder``.
        num_samples: Number of latents to draw.
    """
    vq_vae_model.eval()

    # sample() returns a tuple; only its first element (the drawn points) is used.
    reduced_draws = gmm.sample(num_samples)[0]
    restored = pca.inverse_transform(reduced_draws)

    flat_shape = (num_samples*cfg.EMBED_CH, cfg.EMBED_H*cfg.EMBED_W)
    latent_samples = torch.from_numpy(restored).reshape(flat_shape)
    print(latent_samples.size())

    decoder_input = latent_samples.float().to(device)
    decoded = vq_vae_model.decoder(decoder_input)
    util.save_image_to_file(99, decoded, cfg.SAVE_PATH, 'GMM_FIT_')
    return None

def generate_samples_from_vae_fit(vq_vae_model, vae_model, num_samples):
    """Sample latents from the second-stage VAE, decode them and save the images.

    ``vae_model.sample`` provides the latents; they are reshaped to
    ``(num_samples*cfg.EMBED_CH, cfg.EMBED_H*cfg.EMBED_W)``, decoded with
    ``vq_vae_model.decoder`` and written via ``util.save_image_to_file`` with
    the fixed index 99 and the tag ``'VAE_FIT_'``.

    Args:
        vq_vae_model: Trained VQ-VAE exposing ``decoder``.
        vae_model: Trained latent-space VAE exposing ``sample``.
        num_samples: Number of latents to draw.
    """
    vq_vae_model.eval()
    vae_model.eval()

    drawn = vae_model.sample(num_samples, device)
    print(drawn.size())

    flat_shape = (num_samples*cfg.EMBED_CH, cfg.EMBED_H*cfg.EMBED_W)
    drawn = drawn.reshape(flat_shape)
    print(drawn.size())

    decoder_input = drawn.float().to(device)
    decoded = vq_vae_model.decoder(decoder_input)
    util.save_image_to_file(99, decoded, cfg.SAVE_PATH, 'VAE_FIT_')
    return None

if __name__ == '__main__':
    # Stage 1: train the VQ-VAE; train_VAE() returns the trained model.
    vq_vae_model = train_VAE()
    # Alternative to stage 1 (disabled): load a previously saved model instead of retraining.
    #vq_vae_model = torch.load('./logs/MODEL_E100September 25, 2022.pth')
    
    # Stage 2a (disabled): fit PCA + GMM on the latents and generate samples from that fit.
    #gmm, pca = fit_gmm_training_data(vq_vae_model)
    #generate_samples_from_gmm_fit(gmm,pca,vq_vae_model,cfg.GMM_GEN_NUM_SAMPLES)

    # Stage 2b (disabled): fit a second VAE on the latents and generate samples from it.
    #vae_model =fit_vae_training_data(vq_vae_model)
    #generate_samples_from_vae_fit(vq_vae_model, vae_model, cfg.VAE_GEN_NUM_SAMPLES)

# Generative Adversarial Networks (GAN)


## DCGAN


In [None]:
# DCGAN training cell.
# NOTE(review): this comment used to describe the SVHN dataset (32x32 colour house-number images,
# http://ufldl.stanford.edu/housenumbers/, split into train/test/extra), but the dataset code further
# down in this cell builds an ImageFolder from ./bitemojis_dataset -- confirm which dataset is intended.
# model_name is used below to build the TensorBoard directories and the checkpoint directory.
model_name = "dcgan1"
#check model saving path is there
# Imports

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.transforms as transformations
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import PIL.Image as Image
import torchvision.models as models
#numpy
import numpy as np
import tqdm
from ignite.metrics.gan import FID
# Hyperparameters and constants: for dataset and training

# --- runtime ---
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- optimisation ---
LEARNING_RATE = 0.0002  # one rate for both Adam optimizers (separate gen/disc rates are an option)
BATCH_SIZE = 3200
NUM_EPOCHS = 5

# --- data ---
IMAGE_SIZE = 64    # size passed to transforms.Resize
CHANNELS_IMG = 3   # image channels

# --- model ---
NOISE_DIM = 100       # channels of the generator's 1x1 noise input
FEATURES_GEN = 64     # base width multiplier of the generator
FEATURES_DISC = 64    # base width multiplier of the discriminator
# preparing Dataset
### Images are read with ImageFolder from ./bitemojis_dataset (relative to the current working
### directory), resized with Resize(IMAGE_SIZE), converted to tensors and normalized per channel
### with mean 0.5 / std 0.5.

# The pipeline is bound to `transform` (singular): the original bound it to `transforms`, which
# shadowed the torchvision module and left the `transform=transform` argument below undefined.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)
#get the dataset

dataset = datasets.ImageFolder(root=os.path.join(os.getcwd(), "bitemojis_dataset"), transform=transform)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

#print the total number of images in the dataset
print(len(dataset))
# print the shape of the images
print(dataset[0][0].shape)
# print the label of the image
print(dataset[0][1])
# Model
## generator


class Generator(nn.Module):
    """DCGAN generator: turns an ``N x channels_noise x 1 x 1`` noise tensor into images.

    Four transposed-convolution stages (each followed by BatchNorm and ReLU)
    grow the feature map 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32; a final
    transposed convolution doubles it once more and ``Tanh`` squashes the
    output.

    Args:
        channels_noise: Number of channels of the noise input.
        channels_img: Number of channels of the generated image.
        features_g: Base width; stages use 16x, 8x, 4x and 2x this value.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()
        # Channel widths seen by the four upsampling stages.
        widths = [channels_noise] + [features_g * mult for mult in (16, 8, 4, 2)]
        # (stride, padding) per stage: the first expands 1x1 to 4x4, the others double the size.
        geometry = [(1, 0), (2, 1), (2, 1), (2, 1)]
        stages = [
            self.generator_block_architecture(c_in, c_out, 4, stride, pad)
            for (c_in, c_out), (stride, pad) in zip(zip(widths, widths[1:]), geometry)
        ]
        to_image = nn.ConvTranspose2d(
            widths[-1], channels_img, kernel_size=4, stride=2, padding=1
        )
        self.net = nn.Sequential(*stages, to_image, nn.Tanh())

    def generator_block_architecture(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return one upsampling stage: bias-free ConvTranspose2d -> BatchNorm2d -> ReLU."""
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        normalise = nn.BatchNorm2d(out_channels)
        return nn.Sequential(upsample, normalise, nn.ReLU())

    def forward(self, x):
        """Run the noise batch ``x`` through the whole stack."""
        return self.net(x)

## discriminator
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an image batch to one probability per image.

    A strided convolution with LeakyReLU is followed by three
    Conv-BatchNorm-LeakyReLU stages that double the channel count each time;
    a final convolution reduces the map to a single channel and ``Sigmoid``
    squashes it into (0, 1).

    Args:
        channels_img: Number of channels of the input image.
        features_d: Base width; stages use 1x, 2x, 4x and 8x this value.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        layers = [
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Three downsampling stages: f -> 2f -> 4f -> 8f channels.
        for mult in (1, 2, 4):
            layers.append(
                self.dicriminator_block_architecture(features_d * mult, features_d * mult * 2, 4, 2, 1)
            )
        layers.append(nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0))
        layers.append(nn.Sigmoid())
        self.disc = nn.Sequential(*layers)

    def dicriminator_block_architecture(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return one downsampling stage: bias-free Conv2d -> BatchNorm2d -> LeakyReLU(0.2)."""
        downsample = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        normalise = nn.BatchNorm2d(out_channels)
        return nn.Sequential(downsample, normalise, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Score the image batch ``x``."""
        return self.disc(x)

--------
# Initialization : Model , Loss , Optimizer, data loader
### model
def initialize_weights(model):
    """Re-initialise ``model`` in place following the usual DCGAN recipe.

    Convolution and transposed-convolution weights are drawn from N(0, 0.02).
    BatchNorm scale factors are drawn from N(1, 0.02) and their shifts set to
    zero: a zero-centred scale (what this function did before) would shrink
    every normalised activation to almost nothing at the start of training.

    Args:
        model: Module whose sub-modules are visited via ``model.modules()``.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            # affine=False layers carry no weight/bias and are skipped.
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)
### data loader
dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

### networks: build both first, then re-initialise their weights (generator before discriminator)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)

### optimizer, loss
# Both networks use Adam with the same learning rate and betas=(0.5, 0.999).
opt_gen = optim.Adam(gen.parameters(), betas=(0.5, 0.999), lr=LEARNING_RATE)
opt_disc = optim.Adam(disc.parameters(), betas=(0.5, 0.999), lr=LEARNING_RATE)
criterion = nn.BCELoss()

### tensorboard
# Fixed noise batch (created on the CPU, then moved) reused for the per-epoch image grid.
fixed_noise = torch.randn((100, NOISE_DIM, 1, 1)).to(device)
# Scalars go under runs/<model_name>/loss, image grids under logs/<model_name>/{real,fake}.
writer_loss = SummaryWriter(f"runs/{model_name}/loss")
writer_real = SummaryWriter(f"logs/{model_name}/real")
writer_fake = SummaryWriter(f"logs/{model_name}/fake")
----
### initialize FID wrapper
# Single FID metric instance; the training loop calls update((real, fake)), compute() and reset() on it.
fid_score = FID()
#interpolate function to resize images to 299,299,3  which is the input size of inception network
def interpolate(batch):
    """Resize every image in ``batch`` to 299x299 and return them as one stacked tensor.

    Each image is converted to a PIL image, resized with bilinear resampling
    and converted back to a tensor.

    NOTE(review): ``ToPILImage`` is documented for float tensors in [0, 1];
    confirm that callers pass images in that range.

    Args:
        batch: Iterable of image tensors.

    Returns:
        Tensor stacking the resized images along a new first dimension.
    """
    to_pil = transformations.ToPILImage()
    to_tensor = transformations.ToTensor()
    target_size = (299, 299)
    resized = [
        to_tensor(to_pil(img).resize(target_size, Image.BILINEAR))
        for img in batch
    ]
    return torch.stack(resized)
# Training
gen.train()
disc.train()
step = 0        # epoch index: global_step of the per-epoch TensorBoard entries
batch_step = 0  # running index of logged batches: global_step of the per-batch entries


def _mean_input_grad_norm(score_fn, inputs):
    """Return the batch mean of the L2 norm of d score_fn(x) / d x at ``inputs``.

    A detached copy of ``inputs`` is the differentiation leaf, so neither the
    training graph nor the parameters' ``.grad`` buffers are touched.
    """
    probe = inputs.detach().clone().requires_grad_(True)
    scores = score_fn(probe).reshape(-1)
    (grad,) = torch.autograd.grad(
        outputs=scores,
        inputs=probe,
        grad_outputs=torch.ones_like(scores),
    )
    return grad.reshape(grad.shape[0], -1).norm(2, dim=1).mean()


def _to_unit_range(images):
    """Map images from [-1, 1] (Normalize(0.5, 0.5) / Tanh range) to [0, 1] for PIL conversion."""
    return ((images + 1) / 2).clamp(0, 1)


for epoch in range(NUM_EPOCHS):
    # Running sums of the per-batch losses, averaged at the end of the epoch.
    total_loss_gen = 0.0
    total_loss_disc = 0.0

    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        # Size the noise by the actual batch so the last (smaller) batch matches too.
        noise = torch.randn(real.shape[0], NOISE_DIM, 1, 1).to(device)
        fake = gen(noise)

        ### Train Discriminator
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ### Train Generator:
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        total_loss_gen += loss_gen.item()
        total_loss_disc += loss_disc.item()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 10 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                #BATCH LOSS---
                writer_loss.add_scalar("Generator loss Batch", loss_gen.item(), global_step=batch_step)
                writer_loss.add_scalar("Discriminator loss Batch", loss_disc.item(), global_step=batch_step)

                #FID--
                # FID of this batch only; images are rescaled to [0, 1] before the PIL round trip.
                real_images_fid = interpolate(_to_unit_range(real))
                fake_images_fid = interpolate(_to_unit_range(fake))
                fid_score.update((real_images_fid, fake_images_fid))
                computed_fid_score = fid_score.compute()
                print("FID score: ", computed_fid_score)
                writer_loss.add_scalar("FID Score DCGAN", computed_fid_score, global_step=batch_step)
                #reset the fid score
                fid_score.reset()
                ##FID--

            # Never reset: every logged batch gets its own global_step.
            batch_step += 1

    with torch.no_grad():
        fixed_fake = gen(fixed_noise)
        # take out upto 100 examples
        img_grid_real = torchvision.utils.make_grid(real[:100], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fixed_fake[:100], normalize=True)

        writer_real.add_image("Real", img_grid_real, global_step=step)
        writer_fake.add_image("Fake", img_grid_fake, global_step=step)

    #AVERAGE LOSS---
    avg_loss_gen = total_loss_gen / len(dataloader)
    avg_loss_disc = total_loss_disc / len(dataloader)
    writer_loss.add_scalar("Generator loss Epoch", avg_loss_gen, global_step=step)
    writer_loss.add_scalar("Discriminator loss Epoch", avg_loss_disc, global_step=step)
    #AVERAGE LOSS----

    # Gradient diagnostics on the last batch of the epoch. They are computed with fresh
    # forward passes outside no_grad: the training graphs were freed by backward() and the
    # training inputs never required grad.
    # gradient of the disc output with respect to the real input images
    grad_norm_real = _mean_input_grad_norm(disc, real)
    writer_loss.add_scalar("Gradient norm Disc Real DCGAN", grad_norm_real.item(), global_step=step)

    # gradient of the disc output with respect to the generated input images
    grad_norm_fake = _mean_input_grad_norm(disc, fake)
    writer_loss.add_scalar("Gradient norm Disc Fake DCGAN", grad_norm_fake.item(), global_step=step)

    # gradient of disc(gen(noise)) with respect to the input noise
    grad_norm_gen = _mean_input_grad_norm(lambda z: disc(gen(z)), noise)
    writer_loss.add_scalar("Gradient norm Generator DCGAN", grad_norm_gen.item(), global_step=step)

    # Norm of the gradients currently stored on the weights of the first two conv layers
    # of the discriminator (`disc`; the original referenced an undefined `critic`).
    for name, param in disc.named_parameters():
        if param.grad is None:
            continue
        if name == "disc.0.weight":
            writer_loss.add_scalar("Disc Gradient w.r.t 1st layer DCGAN", param.grad.norm().item(), global_step=step)
        elif name == "disc.2.0.weight":
            writer_loss.add_scalar("Disc Gradient w.r.t 2nd layer DCGAN", param.grad.norm().item(), global_step=step)

    step += 1

    #save the trained model (overwritten after every epoch)
    save_dir = "trained_models/" + model_name
    os.makedirs(save_dir, exist_ok=True)
    torch.save(gen.state_dict(), save_dir + "/gen.pth")
    torch.save(disc.state_dict(), save_dir + "/disc.pth")


#save the tensorboard
writer_real.close()
writer_fake.close()
writer_loss.close()

## Conditional WGAN

In [None]:
# we will train a conditional wgan on svhn dataset
# we will use the gradient penalty to stabilize the training
# model_name is used below to build the TensorBoard directories (runs/<name>/loss,
# logs/<name>/real, logs/<name>/fake) and the checkpoint directory trained_models/<name>.


model_name = "c_wgan7"
#check model saving path is there
# imports


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.transforms as transformations
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import PIL.Image as Image
import torchvision.models as models
#numpy
import numpy as np
import tqdm
from ignite.metrics.gan import FID
# hyperparameters for the dataset and the MOdel
# Hyperparameters etc.
# Prefer the second GPU ("cuda:1") as before, but fall back to "cuda:0" on single-GPU
# machines, where "cuda:1" is not a valid device, and to the CPU without CUDA.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"
LEARNING_RATE = 1e-4       # used by both Adam optimizers
BATCH_SIZE = 64000         # NOTE(review): very large for 64x64 RGB images -- confirm it fits in memory
IMAGE_SIZE = 64            # size passed to transforms.Resize and to the models
CHANNELS_IMG = 3
NUM_CLASSES = 10           # number of label values the embeddings are built for
GEN_EMBEDDING = 100        # size of the generator's label embedding
Z_DIM = 100                # channels of the generator's 1x1 noise input
NUM_EPOCHS = 50
FEATURES_CRITIC = 16       # base width multiplier of the critic
FEATURES_GEN = 16          # base width multiplier of the generator
CRITIC_ITERATIONS = 5      # critic updates per generator update
LAMBDA_GP = 10             # weight of the gradient penalty in the critic loss

# prepare the dataset
### SVHN is used here; its train, test and extra splits are concatenated into one bigger dataset.

# Preprocessing: Resize(IMAGE_SIZE) -> ToTensor -> per-channel Normalize(mean 0.5, std 0.5).
# NOTE: this rebinds the name `transforms` from the torchvision module to the composed pipeline.
transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
])

# One SVHN dataset per split (downloaded on demand), all with the same root and preprocessing.
train_dataset, test_dataset, extra_dataset = [
    datasets.SVHN(root="dataset_svhm/", split=split_name, transform=transforms, download=True)
    for split_name in ("train", "test", "extra")
]
# Concatenate the three splits and wrap them in a shuffling loader.
dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset, extra_dataset])
loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

# Sanity output: number of images, then shape and label of the first sample.
print(len(dataset))
print(dataset[0][0].shape)
print(dataset[0][1])
# Model

## generator
class Generator(nn.Module):
    """Conditional generator: noise plus a label embedding in, image out.

    The label is embedded into ``embed_size`` values, appended to the noise as
    extra 1x1 channels and pushed through four ConvTranspose-BatchNorm-ReLU
    stages, a final transposed convolution and ``Tanh``.

    Args:
        channels_noise: Channels of the ``N x channels_noise x 1 x 1`` noise input.
        channels_img: Channels of the generated image.
        features_g: Base width; stages use 16x, 8x, 4x and 2x this value.
        num_classes: Number of distinct labels the embedding supports.
        img_size: Stored on the instance as ``img_size``.
        embed_size: Size of the label embedding.
    """

    def __init__(self, channels_noise, channels_img, features_g , num_classes, img_size, embed_size):
        super().__init__()
        self.img_size = img_size

        # The first stage sees the noise channels plus the embedded label.
        widths = [channels_noise + embed_size] + [features_g * mult for mult in (16, 8, 4, 2)]
        # (stride, padding) per stage: the first expands 1x1 to 4x4, the others double the size.
        geometry = [(1, 0), (2, 1), (2, 1), (2, 1)]
        stages = [
            self.block_architecture_generator(c_in, c_out, 4, stride, pad)
            for (c_in, c_out), (stride, pad) in zip(zip(widths, widths[1:]), geometry)
        ]
        to_image = nn.ConvTranspose2d(
            widths[-1], channels_img, kernel_size=4, stride=2, padding=1
        )
        self.net = nn.Sequential(*stages, to_image, nn.Tanh())

        self.embed = nn.Embedding(num_classes, embed_size)

    def block_architecture_generator(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return one upsampling stage: bias-free ConvTranspose2d -> BatchNorm2d -> ReLU."""
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False,
        )
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x, labels):
        """Generate images from noise ``x`` conditioned on integer ``labels``."""
        # (N, embed_size) -> (N, embed_size, 1, 1) so it can be stacked onto the noise channels.
        label_channels = self.embed(labels)[:, :, None, None]
        conditioned = torch.cat((x, label_channels), dim=1)
        return self.net(conditioned)
class Discriminator(nn.Module):
    """Conditional critic: scores an image together with its label.

    The label is embedded into ``img_size * img_size`` values, reshaped into
    one extra image channel and concatenated to the input. A strided
    convolution with LeakyReLU, three Conv-InstanceNorm-LeakyReLU stages and a
    final convolution produce an unbounded score (no output activation).

    Args:
        channels_img: Channels of the input image (one more is added for the label).
        features_d: Base width; stages use 1x, 2x, 4x and 8x this value.
        num_classes: Number of distinct labels the embedding supports.
        img_size: Height and width of the input images.
    """

    def __init__(self, channels_img, features_d, num_classes, img_size):
        super().__init__()
        self.img_size = img_size

        layers = [
            nn.Conv2d(channels_img + 1, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Three downsampling stages: f -> 2f -> 4f -> 8f channels.
        for mult in (1, 2, 4):
            layers.append(
                self.block_architecture_critic(features_d * mult, features_d * mult * 2, 4, 2, 1)
            )
        layers.append(nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0))
        self.disc = nn.Sequential(*layers)

        # Label embedding used for conditioning: one img_size x img_size plane per class.
        self.embed = nn.Embedding(num_classes, img_size * img_size)

    def block_architecture_critic(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return one downsampling stage: bias-free Conv2d -> InstanceNorm2d(affine) -> LeakyReLU(0.2)."""
        downsample = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False,
        )
        normalise = nn.InstanceNorm2d(out_channels, affine=True)
        return nn.Sequential(downsample, normalise, nn.LeakyReLU(0.2))

    def forward(self, x, labels):
        """Score image batch ``x`` conditioned on integer ``labels``."""
        label_plane = self.embed(labels).view(labels.shape[0], 1, self.img_size, self.img_size)
        return self.disc(torch.cat((x, label_plane), dim=1))
--------------------------------
# model initialization

def initialize_weights(model):
    """Re-initialise ``model`` in place following the usual DCGAN recipe.

    Convolution and transposed-convolution weights are drawn from N(0, 0.02).
    BatchNorm scale factors are drawn from N(1, 0.02) and their shifts set to
    zero: a zero-centred scale (what this function did before) would shrink
    every normalised activation to almost nothing at the start of training.
    Other layer types (e.g. InstanceNorm2d, Embedding) keep their defaults.

    Args:
        model: Module whose sub-modules are visited via ``model.modules()``.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            # affine=False layers carry no weight/bias and are skipped.
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)


# networks: build both first, then re-initialise their weights (generator before critic)
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN, NUM_CLASSES, IMAGE_SIZE, GEN_EMBEDDING).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC, NUM_CLASSES, IMAGE_SIZE).to(device)
initialize_weights(gen)
initialize_weights(critic)

### optimizers
# Both networks use Adam with the same learning rate and betas=(0.0, 0.9).
opt_gen = optim.Adam(gen.parameters(), betas=(0.0, 0.9), lr=LEARNING_RATE)
opt_critic = optim.Adam(critic.parameters(), betas=(0.0, 0.9), lr=LEARNING_RATE)

### tensorboard
# Fixed noise (created on the CPU, then moved) and fixed random labels in 0..9,
# reused for the per-epoch grid of generated images.
fixed_noise = torch.randn((100, Z_DIM, 1, 1)).to(device)
fixed_labels = torch.randint(0, 10, (100,)).to(device)

# Scalars go under runs/<model_name>/loss, image grids under logs/<model_name>/{real,fake}.
writer_loss = SummaryWriter(f"runs/{model_name}/loss")
writer_real = SummaryWriter(f"logs/{model_name}/real")
writer_fake = SummaryWriter(f"logs/{model_name}/fake")
#
--------------------------------
### initialize FID wrapper
# Single FID metric instance; the training loop calls update((real, fake)), compute() and reset() on it.
fid_score = FID()
#interpolate function to resize images to 299,299,3  which is the input size of inception network
def interpolate(batch):
    """Resize every image in ``batch`` to 299x299 and return them as one stacked tensor.

    Each image is converted to a PIL image, resized with bilinear resampling
    and converted back to a tensor.

    NOTE(review): ``ToPILImage`` is documented for float tensors in [0, 1];
    confirm that callers pass images in that range.

    Args:
        batch: Iterable of image tensors.

    Returns:
        Tensor stacking the resized images along a new first dimension.
    """
    to_pil = transformations.ToPILImage()
    to_tensor = transformations.ToTensor()
    target_size = (299, 299)
    resized = [
        to_tensor(to_pil(img).resize(target_size, Image.BILINEAR))
        for img in batch
    ]
    return torch.stack(resized)
#TEST FID 

# y_pred, y = torch.rand(100, 3, 64, 64), torch.rand(100, 3, 64, 64)
# y_pred = interpolate(y_pred)
# y = interpolate(y)
# # m = FID()
# fid_score.update((y_pred, y))

# print('ignite batch FID', fid_score.compute())  # 8.98434072559458e-05
# #reset the fid score
# fid_score.reset()
## start training


# Put both networks into training mode before the optimisation loop starts.
gen.train()
critic.train()
#### gradient penalty function for WGAN-GP
def gradient_penalty(critic,labels, real, fake, device="cpu"):
    """Compute the WGAN-GP gradient penalty for a conditional critic.

    Random per-sample mixtures of ``real`` and ``fake`` are scored by
    ``critic``; the penalty is the batch mean of ``(||grad||_2 - 1) ** 2``,
    where ``grad`` is the gradient of the score with respect to the mixed
    image. The mixtures are detached leaves that require grad, so the function
    also works when ``fake`` (or ``real``) carries no autograd history. The
    penalty stays differentiable with respect to the critic's parameters.

    Args:
        critic: Callable ``critic(images, labels)`` returning one score per sample.
        labels: Labels passed through to ``critic`` unchanged.
        real: Real image batch of shape ``(N, C, H, W)``.
        fake: Generated image batch with the same shape as ``real``.
        device: Kept for backward compatibility; the mixing coefficients are
            created on ``real.device``.

    Returns:
        Scalar tensor holding the gradient penalty.
    """
    batch_size = real.shape[0]
    # One mixing coefficient per sample, broadcast over channels and pixels.
    alpha = torch.rand((batch_size, 1, 1, 1), device=real.device)
    interpolated_images = (real.detach() * alpha + fake.detach() * (1 - alpha)).requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, labels)

    # Gradient of the scores with respect to the mixed images; create_graph keeps the
    # result differentiable so the penalty can be back-propagated into the critic.
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)
step = 0

for epoch in range(NUM_EPOCHS):
    #we will track the total loss of the generator and critic for each epoch over the entire dataset
    #initialize the total loss of the generator and critic for each epoch to 0
    total_loss_gen = 0
    total_loss_critic = 0
    
    #move these to device

    #have no gradient for these losse
    # total_loss_gen = torch.no_grad()
    # total_loss_critic = torch.no_grad()
    
    # Target labels 
    for batch_idx, (real, labels) in enumerate(loader):
        #send labels to device
        labels = labels.to(device)
        batch_step = 0
        real = real.to(device)
        cur_batch_size = real.shape[0]
        
        

        # Train Critic: 
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            if len(noise) != len(labels):
                noise = noise[:len(labels)]
            fake = gen(noise, labels)
            critic_real = critic(real, labels).reshape(-1)
            critic_fake = critic(fake, labels).reshape(-1)
            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            
            
            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()
            
        #trained critic
        

        # Train Generator: 
        gen_fake = critic(fake, labels).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()
        
        
        
        
        
        #update the total loss of the generator and critic for each batch in the epoch
        #add just the value no gradients
        #have no gradient for these losse
        #add just the value no gradientsfrom the loss_gen tensor
      
        with torch.no_grad():
            total_loss_gen += loss_gen.item()
            total_loss_critic += loss_critic.item()
        
        # Print losses occasionally and print to tensorboard in a batch
        if batch_idx % 10 == 0 and batch_idx > 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )
            
            with torch.no_grad():
                
                #BATCH LOSS-----
                
                #write gen_loss and critic_loss to tensorboard
                writer_loss.add_scalar("Generator loss Batch", loss_gen, global_step=step)
                writer_loss.add_scalar("Critic loss Batch", loss_critic, global_step=step)
                
                #BATCH LOSS-------

                
                #FID--
                #calculate FID score of this batch
                #update the fid_score with real and fake images
                real_images_fid = interpolate(real)
                fake_images_fid = interpolate(fake)
                fid_score.update((real_images_fid, fake_images_fid))
                computed_fid_score = fid_score.compute()
                print("FID score: ", computed_fid_score)
                writer_loss.add_scalar("FID Score WGAN", computed_fid_score, global_step=batch_step)
                #reset the fid score
                fid_score.reset()
                ##FID--
                
                batch_step += 1
            
            
        
        

        # Print losses occasionally and print to tensorboard per epoch
        # if batch_idx % 100 == 0 and batch_idx > 0:
        
            ## PRINT for few epochs %10
            # print(
            #     f"Epoch [{epoch}/{NUM_EPOCHS}]  \
            #         Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            # )
            
    #calculate the Fréchet Inception Distance (FID) to evaluate the performance of the generator
    

    with torch.no_grad():
        fake = gen(fixed_noise, fixed_labels)
        # take out (up to) 32 examples
        img_grid_real = torchvision.utils.make_grid(real[:100], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:100], normalize=True)

        writer_real.add_image("Real", img_grid_real, global_step=step)
        writer_fake.add_image("Fake", img_grid_fake, global_step=step)
        
        
                        
        #AVERAGE LOSS---

        #get average loss of generator and critic for each epoch
        avg_loss_gen = total_loss_gen / len(loader)
        avg_loss_critic = total_loss_critic / len(loader)
        #write loss to tensorboard
        writer_loss.add_scalar("Generator loss Epoch", avg_loss_gen, global_step=batch_step)
        writer_loss.add_scalar("Critic loss Epoch", avg_loss_critic, global_step=batch_step)
        
        #AVERAGE LOSS----
        
        
        
        #we will plot the gradient of critic output with respect to the input image
        #get the gradient of the critic output with respect to the input image
        gradient = torch.autograd.grad(
        inputs=real,
        outputs=critic_real,
        grad_outputs=torch.ones_like(critic_real),
        create_graph=True,
        retain_graph=True,
        )[0]
        #flatten the gradient
        gradient = gradient.view(gradient.shape[0], -1)
        #get the norm of the gradient
        gradient_norm = gradient.norm(2, dim=1)
        #write gradient norm to tensorboard
        writer_loss.add_scalar("Gradient norm Critic Real", gradient_norm.mean(), global_step=step)
        
        #----------------
        #we will plot the gradient of critic output with respect to the input image
        #get the gradient of the critic output with respect to the input image
        gradient = torch.autograd.grad(
        inputs=fake,
        outputs=critic_fake,
        grad_outputs=torch.ones_like(critic_fake),
        create_graph=True,
        retain_graph=True,
        )[0]
        #flatten the gradient
        gradient = gradient.view(gradient.shape[0], -1)
        #get the norm of the gradient
        gradient_norm = gradient.norm(2, dim=1)
        #write gradient norm to tensorboard
        writer_loss.add_scalar("Gradient norm Critic Fake", gradient_norm.mean(), global_step=step)
        
        #----------------
        #we will plot the gradient of genrator output with respect to the input 
        #we will plot the gradient of genrator output with respect to the input 
        #get the gradient of the generator output with respect to the input noise
        gradient = torch.autograd.grad(
        inputs=noise,
        outputs=gen_fake,
        grad_outputs=torch.ones_like(gen_fake),
        create_graph=True,
        retain_graph=True,
        )[0]
        #flatten the gradient
        gradient = gradient.view(gradient.shape[0], -1)
        #get the norm of the gradient
        gradient_norm = gradient.norm(2, dim=1)
        #write gradient norm to tensorboard
        writer_loss.add_scalar("Gradient norm Generator", gradient_norm.mean(), global_step=step)
        
        #----------------

        
        
        # we will plot the gradient penalty
        writer_loss.add_scalar("GP", gp, global_step=step)
        #we will analyze for vanishing gradient
        writer_loss.add_scalar("Critic Real", critic_real.mean(), global_step=step)
        writer_loss.add_scalar("Critic Fake", critic_fake.mean(), global_step=step)
        
        #get the gradient of the critic for the parameters weights of first layer
        #we will write the norm of the gardient of weights of the first layer of the critic
        for name, param in critic.named_parameters():
            if name == "disc.0.weight":
                writer_loss.add_scalar("Critic Gradient w.r.t 1st layer WGAN", param.grad.norm(), global_step=step)
            #also plot the norm of gradient of 2nd layer
            elif name == "disc.2.0.weight":
                writer_loss.add_scalar("Critic Gradient w.r.t 2nd layer WGAN", param.grad.norm(), global_step=step)
        
       
       
        
        
        

    step += 1
        
        #save the trained model
        #check if trained_model folder exists
    if not os.path.exists("trained_models"):
        os.mkdir("trained_models")
    
    #now trained_model folder exists
    if not os.path.exists("trained_models/"+model_name):
        os.mkdir("trained_models/"+model_name)
    #check if "trained_models/"+model_name     
    torch.save(gen.state_dict(), "trained_models/"+model_name+"/gen.pth")
    torch.save(critic.state_dict(), "trained_models/"+model_name+"/critic.pth")
    
    
    

# Close the three log writers used by the loop above (presumably TensorBoard
# SummaryWriter instances created earlier in this cell -- verify).
writer_real.close()
writer_fake.close()
writer_loss.close()
# NOTE(review): the comments below announce an offline FID evaluation with
# pytorch-fid, but no code for it follows in this cell.
#calculate the Fréchet Inception Distance (FID) to evaluate the performance of the generator

#(https://github.com/mseitzer/pytorch-fid)
#imports for FID

# Print the generator's module tree (layers and their hyperparameters).
print(gen)


## Conditional DCGAN

In [None]:
# Train a conditional WGAN on the SVHN dataset, using a gradient penalty
# (WGAN-GP) to stabilise training.


# Run identifier: used further down to build the TensorBoard log directories
# and the "trained_models/<model_name>/" checkpoint directory.
model_name = "c_wgan7"
# NOTE: the checkpoint directory is created on demand inside the training loop.
# imports


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.transforms as transformations
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import PIL.Image as Image
import torchvision.models as models
#numpy
import numpy as np
import tqdm
from ignite.metrics.gan import FID
# Hyperparameters for the dataset and the model.

# Keep the original preference for the second GPU ("cuda:1"), but only when it
# actually exists: torch.cuda.is_available() is also True on single-GPU
# machines, where "cuda:1" is an invalid device ordinal.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

LEARNING_RATE = 1e-4      # shared by the generator and critic optimizers
# NOTE(review): 64000 images per batch is unusually large (each is resized to
# IMAGE_SIZE x IMAGE_SIZE); confirm this is intended and not a typo for 64.
BATCH_SIZE = 64000
IMAGE_SIZE = 64           # images are resized to this size before training
CHANNELS_IMG = 3          # RGB
NUM_CLASSES = 10          # number of class labels used for conditioning
GEN_EMBEDDING = 100       # size of the generator's label embedding
Z_DIM = 100               # number of noise channels fed to the generator
NUM_EPOCHS = 50
FEATURES_CRITIC = 16      # base channel width of the critic
FEATURES_GEN = 16         # base channel width of the generator
CRITIC_ITERATIONS = 5     # critic updates per generator update
LAMBDA_GP = 10            # weight of the gradient penalty in the critic loss

# Data pipeline: SVHN images resized to IMAGE_SIZE, converted to tensors and
# normalised per channel with mean 0.5 / std 0.5.
# NOTE: this rebinds the name `transforms` from the torchvision module to the
# composed pipeline; the module stays reachable through `transformations`.
transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
])

# Load (downloading if necessary) the three SVHN splits with the same
# preprocessing, then merge them into one larger training set.
train_dataset, test_dataset, extra_dataset = (
    datasets.SVHN(root="dataset_svhm/", split=split_name, transform=transforms, download=True)
    for split_name in ("train", "test", "extra")
)
dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset, extra_dataset])
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Sanity output: total number of images, then the tensor shape and the label
# of the first sample.
print(len(dataset))
print(dataset[0][0].shape)
print(dataset[0][1])
# Model

## generator
class Generator(nn.Module):
    def __init__(self, channels_noise, channels_img, features_g , num_classes, img_size, embed_size):
        super(Generator, self).__init__()
        self.img_size = img_size
        self.net = nn.Sequential(
            # Input: N x channels_noise x 1 x 1
            self.block_architecture_generator(channels_noise+embed_size, features_g * 16, 4, 1, 0),  # img: 4x4
            self.block_architecture_generator(features_g * 16, features_g * 8, 4, 2, 1),  # img: 8x8
            self.block_architecture_generator(features_g * 8, features_g * 4, 4, 2, 1),  # img: 16x16
            self.block_architecture_generator(features_g * 4, features_g * 2, 4, 2, 1),  # img: 32x32
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            ),
            # Output: N x channels_img x 64 x 64
            nn.Tanh(),
        )
        
        self.embed = nn.Embedding(num_classes, embed_size)

    def block_architecture_generator(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x, labels):
        embedding = self.embed(labels).unsqueeze(2).unsqueeze(3)
        x = torch.cat([x, embedding], 1)
        return self.net(x)
class Discriminator(nn.Module):
    """Label-conditioned critic.

    The class label is embedded into ``img_size * img_size`` values, reshaped
    into one extra image plane and stacked onto the input image channels. The
    resulting tensor is reduced by strided convolutions to one unbounded score
    per sample (there is no final activation).
    """

    def __init__(self, channels_img, features_d, num_classes, img_size):
        super().__init__()
        self.img_size = img_size

        # Channel widths after each of the four downsampling stages.
        widths = [features_d * factor for factor in (1, 2, 4, 8)]

        # First stage: +1 input channel for the label plane, no normalisation.
        layers = [
            nn.Conv2d(channels_img + 1, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Three normalised stages, each halving the spatial size.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(self.block_architecture_critic(c_in, c_out, 4, 2, 1))
        # For 64x64 inputs the feature map is 4x4 here; this conv maps it to 1x1.
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))
        self.disc = nn.Sequential(*layers)

        # Label embedding used for conditioning: one full image plane per class.
        self.embed = nn.Embedding(num_classes, img_size * img_size)

    def block_architecture_critic(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return one downsampling stage: bias-free Conv2d, affine InstanceNorm2d, LeakyReLU(0.2)."""
        downsample = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,
        )
        return nn.Sequential(
            downsample,
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x, labels):
        """Score images ``x`` given their integer ``labels``."""
        batch = labels.shape[0]
        label_plane = self.embed(labels).view(batch, 1, self.img_size, self.img_size)
        return self.disc(torch.cat([x, label_plane], dim=1))
--------------------------------
# model initialization

def initialize_weights(model):
    """Re-draw convolution and batch-norm weights from N(0, 0.02), in place.

    Affects the ``weight`` tensor of every ``Conv2d``, ``ConvTranspose2d`` and
    ``BatchNorm2d`` found in ``model.modules()``. Biases and every other
    module type are left as they are.
    """
    target_types = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    for module in model.modules():
        if not isinstance(module, target_types):
            continue
        nn.init.normal_(module.weight.data, mean=0.0, std=0.02)


# Build both networks on the target device, then apply the shared weight
# initialisation to each.
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN, NUM_CLASSES, IMAGE_SIZE, GEN_EMBEDDING).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC, NUM_CLASSES, IMAGE_SIZE).to(device)
initialize_weights(gen)
initialize_weights(critic)

# Optimizers: Adam with beta1 = 0.0 and beta2 = 0.9 for both networks.
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

# Fixed inputs for the per-epoch sample grid, so that images logged at
# different epochs come from the same noise/label pairs.
fixed_noise = torch.randn(100, Z_DIM, 1, 1).to(device)
# Labels are drawn from [0, NUM_CLASSES) rather than a hard-coded 10, so they
# stay valid indices for the label embeddings if NUM_CLASSES changes.
fixed_labels = torch.randint(0, NUM_CLASSES, (100,)).to(device)

# TensorBoard writers.
# NOTE(review): scalars go under "runs/" while the image grids go under
# "logs/"; kept as in the original -- confirm the split is intentional.
writer_loss = SummaryWriter(f"runs/{model_name}/loss")
writer_real = SummaryWriter(f"logs/{model_name}/real")
writer_fake = SummaryWriter(f"logs/{model_name}/fake")
#
--------------------------------
### initialize FID wrapper
# Single FID metric instance (ignite.metrics.gan.FID, default arguments). The
# training loop below feeds it one (real, fake) batch at a time via update(),
# reads compute(), and then calls reset().
fid_score = FID()
# Resize helper for FID: the Inception network expects 299x299 RGB inputs.
def interpolate(batch):
    """Return ``batch`` with every image resized to 299x299.

    Each image is converted to a PIL image, resized with bilinear resampling,
    converted back with ``ToTensor`` and the results are stacked along a new
    leading axis.

    NOTE(review): callers in this file pass images normalised with mean 0.5 /
    std 0.5 (and tanh outputs); confirm ToPILImage treats that value range as
    intended.
    """
    to_pil = transformations.ToPILImage()
    to_tensor = transformations.ToTensor()
    resized = [
        to_tensor(to_pil(image).resize((299, 299), Image.BILINEAR))
        for image in batch
    ]
    return torch.stack(resized)
#TEST FID 

# y_pred, y = torch.rand(100, 3, 64, 64), torch.rand(100, 3, 64, 64)
# y_pred = interpolate(y_pred)
# y = interpolate(y)
# # m = FID()
# fid_score.update((y_pred, y))

# print('ignite batch FID', fid_score.compute())  # 8.98434072559458e-05
# #reset the fid score
# fid_score.reset()
## start training

# Switch both networks to training mode before the optimisation loop starts.
gen.train(True)
critic.train(True)
#### gradient penalty function for WGAN-GP
def gradient_penalty(critic,labels, real, fake, device="cpu"):
    """Return the WGAN-GP gradient penalty for one batch.

    Real and fake images are mixed with one random coefficient per sample.
    The critic scores the mixtures, and the penalty is the batch mean of
    ``(||d score / d mixture||_2 - 1) ** 2``.

    Args:
        critic: callable taking ``(images, labels)`` and returning scores.
        labels: conditioning labels, passed through to ``critic`` unchanged.
        real: 4-D batch of real images ``(N, C, H, W)``.
        fake: batch of generated images with the same shape as ``real``.
        device: device the mixing coefficients are moved to; it must match
            the device of ``real`` and ``fake``.

    Returns:
        Scalar tensor. It is built with ``create_graph=True``, so it can be
        back-propagated as part of the critic loss.
    """
    # Unpacking (rather than indexing) keeps the requirement that `real` is 4-D.
    # A local name is used so the module-level BATCH_SIZE is not shadowed.
    batch_size, _channels, _height, _width = real.shape

    # One coefficient per sample; broadcasting spreads it over C, H and W.
    alpha = torch.rand((batch_size, 1, 1, 1)).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)

    # If neither input carries a graph (e.g. `fake` was detached), the mixture
    # does not require grad and autograd.grad below would raise. In that case
    # the mixture is a leaf tensor, so gradient tracking can be enabled on it.
    if not interpolated_images.requires_grad:
        interpolated_images.requires_grad_(True)

    # Critic scores of the mixtures.
    mixed_scores = critic(interpolated_images, labels)

    # Gradient of the scores with respect to the mixed images.
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    # Per-sample L2 norm over all non-batch dimensions.
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)
# Epoch counter: x-axis of every per-epoch TensorBoard entry.
step = 0
# Counter of logged batches: x-axis of the per-batch loss and FID curves.
# It lives outside both loops so that successive points get distinct steps.
batch_step = 0


def _mean_input_grad_norm(outputs, inputs, retain_graph=False):
    """Return the batch-mean L2 norm of d(sum(outputs)) / d(inputs) as a float."""
    grad = torch.autograd.grad(
        inputs=inputs,
        outputs=outputs,
        grad_outputs=torch.ones_like(outputs),
        retain_graph=retain_graph,
    )[0]
    return grad.reshape(grad.shape[0], -1).norm(2, dim=1).mean().item()


for epoch in range(NUM_EPOCHS):
    # Running sums of the per-batch losses, averaged at the end of the epoch.
    total_loss_gen = 0
    total_loss_critic = 0

    for batch_idx, (real, labels) in enumerate(loader):
        labels = labels.to(device)
        real = real.to(device)
        cur_batch_size = real.shape[0]

        # Train Critic: CRITIC_ITERATIONS updates per generator update.
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise, labels)
            critic_real = critic(real, labels).reshape(-1)
            critic_fake = critic(fake, labels).reshape(-1)
            # gradient_penalty's signature is (critic, labels, real, fake, device).
            gp = gradient_penalty(critic, labels, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )

            critic.zero_grad()
            # `fake` is not detached, so the generator graph must be retained
            # for the generator update that follows the last critic iteration.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator: maximise the critic score of the last fake batch.
        gen_fake = critic(fake, labels).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Accumulate plain Python floats (no autograd history is kept).
        total_loss_gen += loss_gen.item()
        total_loss_critic += loss_critic.item()

        # Every 10th batch: print the losses and log losses + FID.
        if batch_idx % 10 == 0 and batch_idx > 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                # BATCH LOSS
                writer_loss.add_scalar("Generator loss Batch", loss_gen.item(), global_step=batch_step)
                writer_loss.add_scalar("Critic loss Batch", loss_critic.item(), global_step=batch_step)

                # FID of this batch: update, read, then reset so that batches
                # are not mixed across readings.
                real_images_fid = interpolate(real)
                fake_images_fid = interpolate(fake.detach())
                fid_score.update((real_images_fid, fake_images_fid))
                computed_fid_score = fid_score.compute()
                print("FID score: ", computed_fid_score)
                writer_loss.add_scalar("FID Score WGAN", computed_fid_score, global_step=batch_step)
                fid_score.reset()

            batch_step += 1

    # ---- end of epoch: sample grids ----
    with torch.no_grad():
        # A separate name is used so the last training `fake` is not clobbered.
        fixed_fake = gen(fixed_noise, fixed_labels)
        # take out (up to) 100 examples
        img_grid_real = torchvision.utils.make_grid(real[:100], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fixed_fake[:100], normalize=True)

        writer_real.add_image("Real", img_grid_real, global_step=step)
        writer_fake.add_image("Fake", img_grid_fake, global_step=step)

    # ---- end of epoch: average losses ----
    avg_loss_gen = total_loss_gen / len(loader)
    avg_loss_critic = total_loss_critic / len(loader)
    writer_loss.add_scalar("Generator loss Epoch", avg_loss_gen, global_step=step)
    writer_loss.add_scalar("Critic loss Epoch", avg_loss_critic, global_step=step)

    # ---- end of epoch: input-gradient diagnostics ----
    # These must run with autograd enabled, on inputs that require grad, and
    # on a fresh forward pass: the graphs built during training refer to
    # critic weights that the optimizer has since updated in place.
    real_probe = real.detach().clone().requires_grad_(True)
    real_grad_norm = _mean_input_grad_norm(
        critic(real_probe, labels).reshape(-1), real_probe
    )
    writer_loss.add_scalar("Gradient norm Critic Real", real_grad_norm, global_step=step)

    noise_probe = noise.detach().clone().requires_grad_(True)
    fake_probe = gen(noise_probe, labels)
    fake_probe_scores = critic(fake_probe, labels).reshape(-1)
    # Gradient of the critic score w.r.t. the generated images ...
    fake_grad_norm = _mean_input_grad_norm(
        fake_probe_scores, fake_probe, retain_graph=True
    )
    writer_loss.add_scalar("Gradient norm Critic Fake", fake_grad_norm, global_step=step)
    # ... and w.r.t. the generator's input noise (through generator and critic).
    noise_grad_norm = _mean_input_grad_norm(fake_probe_scores, noise_probe)
    writer_loss.add_scalar("Gradient norm Generator", noise_grad_norm, global_step=step)

    # Gradient penalty and mean critic scores of the last critic iteration,
    # logged to analyse vanishing gradients.
    writer_loss.add_scalar("GP", gp.item(), global_step=step)
    writer_loss.add_scalar("Critic Real", critic_real.mean().item(), global_step=step)
    writer_loss.add_scalar("Critic Fake", critic_fake.mean().item(), global_step=step)

    # Norm of the accumulated weight gradients of the critic's first two
    # convolution layers.
    # NOTE: at this point .grad holds the last critic backward pass plus the
    # contribution of the generator backward pass that followed it.
    for name, param in critic.named_parameters():
        if name == "disc.0.weight":
            writer_loss.add_scalar("Critic Gradient w.r.t 1st layer WGAN", param.grad.norm(), global_step=step)
        elif name == "disc.2.0.weight":
            writer_loss.add_scalar("Critic Gradient w.r.t 2nd layer WGAN", param.grad.norm(), global_step=step)

    step += 1

    # Save a checkpoint after every epoch, creating the directory (and its
    # parent) on first use.
    checkpoint_dir = "trained_models/" + model_name
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(gen.state_dict(), checkpoint_dir + "/gen.pth")
    torch.save(critic.state_dict(), checkpoint_dir + "/critic.pth")


# Training finished: close the TensorBoard writers.
writer_real.close()
writer_fake.close()
writer_loss.close()
# NOTE(review): an offline FID evaluation with
# https://github.com/mseitzer/pytorch-fid was announced here, but no code for
# it is present.

#print generator model architecture
print(gen)
