In [None]:
!pip install pytorch-msssim
!pip install faiss-cpu
!pip install wandb
!pip install sympy

In [None]:
from torchvision import datasets, transforms

from torchvision.utils import save_image
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from torchsummary import summary
import torch.optim as optim
import logging

import wandb
import random

import  torch, time, os, pickle
import os
import torch
import numpy as np

from os import listdir
 

import torch
import torch.nn.functional as F

from pytorch_msssim import ssim  # For SSIM metric
import numpy as np
import torch.nn as nn
import torch.optim as optim


import matplotlib.image as mpimg


In [None]:
# Pick the compute device: a CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def save_images(epoch,y_,model_type):
    """Sample one batch from the generator and save it as a PNG image grid.

    Relies on the notebook-level names ``G``, ``batch_size``, ``z_dim`` and
    ``class_num`` being defined before this function is called.

    Args:
        epoch: Value appended (through ``str``) to the output file name.
        y_: 1-D tensor of class indices; it is one-hot encoded and passed to
            the generator as the conditioning vector.
        model_type: String used as the file-name prefix.

    The grid is written to ``outputs/<model_type><epoch>.png`` (8 images per
    row, pixel values normalised by ``save_image``).
    """
    os.makedirs('outputs', exist_ok=True)

    # Uniform noise in [0, 1), one row per generated image.
    z_ = torch.rand((batch_size, z_dim))

    # One-hot encode the labels: row i gets a 1 in column y_[i].
    y_vec_ = torch.zeros((batch_size, class_num)) \
            .scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1)

    # Pure sampling: no gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        fake_images = G(z_, y_vec_).detach().cpu()

    save_image(fake_images, 'outputs/'+model_type+str(epoch)+'.png', nrow=8, normalize=True)

In [None]:
def generated_images(folder_dir="/content/outputs/"):
    """Display every image file found in ``folder_dir``, one figure per file.

    Args:
        folder_dir: Directory to scan. Defaults to ``/content/outputs/``, the
            location that used to be hard-coded in this function.

    Files are shown in sorted name order so the output is deterministic.
    Entries that are not regular files (for example sub-directories) are
    skipped instead of being handed to the image reader.
    """
    for image in sorted(os.listdir(folder_dir)):
        image_path = os.path.join(folder_dir, image)
        if not os.path.isfile(image_path):
            continue
        img = mpimg.imread(image_path)
        plt.imshow(img)
        plt.axis('off')
        plt.show()

In [None]:
class GeneratorCGAN(nn.Module):
    """Conditional GAN generator mapping (noise, label vector) to an image.

    The noise and label vectors are concatenated, passed through two fully
    connected layers, reshaped to a ``128 x s x s`` feature map with
    ``s = input_size // 4``, and upsampled by two transposed convolutions
    (kernel 4, stride 2, padding 1), each of which doubles the spatial size.
    The final ``Tanh`` keeps the output in ``[-1, 1]``.

    Args:
        nz: Length of the noise vector.
        nc: Number of output image channels.
        input_size: Nominal side length of the image; only
            ``input_size // 4`` is used to size the feature map.
        class_num: Length of the label vector.
    """

    def __init__(self, nz=100, nc=1, input_size=32, class_num=10):
        super().__init__()
        self.nz, self.nc = nz, nc
        self.input_size = input_size
        self.class_num = class_num

        side = self.input_size // 4

        # Fully connected stem: (nz + class_num) -> 1024 -> 128 * side * side.
        self.fc = nn.Sequential(
            nn.Linear(self.nz + self.class_num, 1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * side * side),
            nn.ReLU(),
        )

        # Upsampling head: side -> 2 * side -> 4 * side.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.nc, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

        # Re-initialise every Linear / ConvTranspose2d layer created above.
        self.apply(self._initialize_weights)

    def forward(self, input, label):
        """Generate images from a noise batch and a matching label batch.

        Both arguments are concatenated along dimension 1, so they must have
        the same batch size and together provide ``nz + class_num`` features.
        """
        side = self.input_size // 4
        joined = torch.cat((input, label), dim=1)
        hidden = self.fc(joined)
        feature_map = hidden.view(-1, 128, side, side)
        return self.deconv(feature_map)

    def _initialize_weights(self, m):
        """Set N(0, 0.02) weights and zero biases on Linear / ConvTranspose2d modules."""
        if not isinstance(m, (nn.ConvTranspose2d, nn.Linear)):
            return
        nn.init.normal_(m.weight, 0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


In [None]:
def data_set(batch_size):

    # Define transformations (convert to tensor + normalize)

    transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

    train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)

    dataloader = DataLoader(train_dataset,  batch_size)

    return dataloader


data_loader = data_set(batch_size)

In [None]:
# Define hyperparameters

batch_size=64

input_size = 28

z_dim = 62

class_num = 10

sample_num = class_num ** 2

In [None]:
def polynomial_mmd(x, y, degree=3, gamma=None, coef0=1):
    """Polynomial-kernel MMD estimate between two batches of samples.

    Each input is flattened to ``(batch, features)`` and compared with the
    kernel ``k(a, b) = (gamma * <a, b> + coef0) ** degree``. The result is
    ``mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y))``, where the means run
    over the full kernel matrices (diagonal terms included).

    Args:
        x: Tensor whose first dimension indexes samples.
        y: Tensor whose first dimension indexes samples; after flattening it
            must have as many features as ``x``.
        degree: Exponent of the polynomial kernel.
        gamma: Scale applied to the dot product; defaults to ``1 / features``.
        coef0: Constant added before exponentiation.

    Returns:
        A 0-dimensional tensor holding the MMD estimate.

    Raises:
        ValueError: If the flattened feature dimensions differ.
    """
    # reshape (rather than view) also handles non-contiguous inputs such as
    # permuted feature maps, for which view raises a RuntimeError.
    x = x.reshape(x.size(0), -1)
    y = y.reshape(y.size(0), -1)

    # Check if feature dimensions match
    if x.shape[1] != y.shape[1]:
        raise ValueError(f"Feature dimensions do not match: x has {x.shape[1]} features, y has {y.shape[1]} features. Ensure real and generated images have the same number of channels.")

    if gamma is None:
        gamma = 1.0 / x.shape[1]

    kernel_xx = (gamma * x.mm(x.t()) + coef0) ** degree
    kernel_yy = (gamma * y.mm(y.t()) + coef0) ** degree
    kernel_xy = (gamma * x.mm(y.t()) + coef0) ** degree
    return kernel_xx.mean() + kernel_yy.mean() - 2 * kernel_xy.mean()



def kernel_inception_distance(real_features, generated_features):
    """Kernel Inception Distance between real and generated features.

    When the two inputs disagree on the size of dimension 1, the real
    features are cut down to the first ``generated_features.shape[1]``
    entries along that dimension (this slicing indexes four dimensions).
    Both inputs are then copied into detached tensors and compared with
    ``polynomial_mmd`` using its default kernel parameters.

    Args:
        real_features: Features of real samples (tensor or array exposing
            ``.shape``).
        generated_features: Features of generated samples.

    Returns:
        The value returned by ``polynomial_mmd`` for the two batches.
    """
    # Ensure real and generated features have the same number of channels
    if real_features.shape[1] != generated_features.shape[1]:
        real_features = real_features[:, :generated_features.shape[1], :, :]

    # detach().clone() is the documented equivalent of torch.tensor(tensor)
    # without the copy-construct UserWarning; as_tensor first so that array
    # inputs are accepted as well.
    real_features = torch.as_tensor(real_features).detach().clone()
    generated_features = torch.as_tensor(generated_features).detach().clone()

    # Calculate KID using polynomial MMD
    return polynomial_mmd(real_features, generated_features)

In [None]:
def inicialization(sample_num,z_dim,class_num):
    """Create the fixed noise and one-hot labels used for sample grids.

    Rows are organised in ``class_num`` consecutive groups of ``class_num``
    rows. Every row of a group shares one uniformly drawn noise vector, and
    within a group the labels run through ``0 .. class_num - 1``.

    Args:
        sample_num: Number of rows in both outputs (the layout above fills
            ``class_num ** 2`` of them).
        z_dim: Length of each noise vector.
        class_num: Number of classes.

    Returns:
        ``(sample_z_, sample_y_)`` with shapes ``(sample_num, z_dim)`` and
        ``(sample_num, class_num)``.
    """
    noise = torch.zeros((sample_num, z_dim))

    # One fresh noise vector per group, copied to the remaining rows of that group.
    for group in range(class_num):
        first_row = group * class_num
        noise[first_row] = torch.rand(1, z_dim)
        for offset in range(1, class_num):
            noise[first_row + offset] = noise[first_row]

    # Column vector holding the class indices 0 .. class_num - 1.
    class_column = torch.arange(class_num).unsqueeze(1)

    # Tile that column once per group to obtain the label of every row.
    row_labels = torch.zeros((sample_num, 1))
    for group in range(class_num):
        row_labels[group * class_num:(group + 1) * class_num] = class_column

    # One-hot encode: row r gets a 1 in column row_labels[r].
    one_hot = torch.zeros((sample_num, class_num))
    one_hot.scatter_(1, row_labels.type(torch.LongTensor), 1)

    return noise, one_hot


In [None]:
def label_processing(batch_size,z_dim,class_num,input_size,y_):
    """Draw a noise batch and build the label encodings for one training step.

    Args:
        batch_size: Number of rows in every returned tensor.
        z_dim: Length of each noise vector.
        class_num: Number of classes (width of the one-hot encoding).
        input_size: Spatial size of the label maps.
        y_: 1-D tensor of class indices.

    Returns:
        ``(z_, y_vec_, y_fill_)``: uniform noise of shape
        ``(batch_size, z_dim)``; one-hot labels of shape
        ``(batch_size, class_num)``; and the same one-hot labels expanded to
        ``(batch_size, class_num, input_size, input_size)``.
    """
    noise = torch.rand((batch_size, z_dim))

    # One-hot encode: row i gets a 1 in column y_[i].
    label_index = y_.type(torch.LongTensor).unsqueeze(1)
    one_hot = torch.zeros((batch_size, class_num))
    one_hot.scatter_(1, label_index, 1)

    # Repeat each one-hot entry over the spatial grid (expand makes no copy).
    label_maps = one_hot.unsqueeze(2).unsqueeze(3)
    label_maps = label_maps.expand(batch_size, class_num, input_size, input_size)

    return noise, one_hot, label_maps


In [None]:
def plot_losses(d_losses, g_losses,gan_type):
    """Plot discriminator and generator losses, save the figure and show it.

    Args:
        d_losses: Sequence of discriminator loss values.
        g_losses: Sequence of generator loss values.
        gan_type: String appended to the title and used as file-name prefix.

    The figure is saved as
    ``./Metrics/<gan_type>_generator_discriminator_losses.png``; the
    ``Metrics`` directory is created when missing.
    """
    plt.figure(figsize=(10, 5))

    # Draw both curves on the same axes.
    curves = ((d_losses, 'Discriminator Loss'), (g_losses, 'Generator Loss'))
    for values, legend_name in curves:
        plt.plot(values, label=legend_name)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Losses '+gan_type)
    plt.legend()
    plt.grid()

    # Make sure the target directory exists before writing the file.
    os.makedirs('Metrics', exist_ok=True)
    target_file = './Metrics/'+gan_type+'_generator_discriminator_losses.png'
    plt.savefig(target_file)
    plt.show()

In [None]:
def ssim_plot(SSIM_SCORES_EPOCHS,gan_type):
    """Plot the per-epoch SSIM scores in green, save the figure and show it.

    Args:
        SSIM_SCORES_EPOCHS: Sequence of SSIM values.
        gan_type: String used as file-name prefix.

    The figure is saved as ``./Metrics/<gan_type>SSIM.png``; the ``Metrics``
    directory is created when missing.
    """
    plt.title("SSIM for every Epoch")
    plt.plot(SSIM_SCORES_EPOCHS, color="green")
    plt.grid()

    # Make sure the target directory exists before writing the file.
    os.makedirs('Metrics', exist_ok=True)
    target_file = './Metrics/'+gan_type+'SSIM.png'
    plt.savefig(target_file)
    plt.show()

In [None]:
def kid_plot(KID_SCORES_EPOCHS,gan_type):
    """Plot the per-epoch KID scores in yellow, save the figure and show it.

    Args:
        KID_SCORES_EPOCHS: Sequence of KID values.
        gan_type: String used as file-name prefix.

    The figure is saved as ``./Metrics/<gan_type>_KID_losses.png``. The
    ``Metrics`` directory is created when missing, so this function works
    even when it is the first plotting helper to be called.
    """
    plt.title("KID in every Epoch")
    plt.plot(KID_SCORES_EPOCHS,color="yellow")
    plt.grid()
    # savefig does not create missing directories, so create it here.
    os.makedirs('Metrics', exist_ok=True)
    plt.savefig('./Metrics/'+gan_type+'_KID_losses.png')
    plt.show()