<a href="https://colab.research.google.com/github/Stanfording/Trying-FID-LOSS/blob/main/FID_LOSS_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>



# Goal:   Testing if FID Loss works on training GAN



### Get the preprocessed celebHD data from google drive

In [43]:
# #Download the dataset
# !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1KqBRLsB0CJuQGycvaPINwaPgcGDUsAxN' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1KqBRLsB0CJuQGycvaPINwaPgcGDUsAxN" -O "proCeleba.zip" && rm -rf /tmp/cookies.txt

# #unzip the dataset
# !unzip "/content/proCeleba.zip"

# #remove unnecessary files
# !rm -rf /content/__MACOSX

### Import libraries

In [44]:
from tqdm import tqdm

from torch.autograd import Variable, grad
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
from PIL import Image
import os
import torch as t
import torch.nn as nn
from torchvision import datasets, transforms, utils
from torch.utils.data import Dataset, DataLoader
from skimage import io
import copy

### Set initial hyperparameters

In [45]:
batch_size = 64

resolution = 16

img_fold_dir_64_reso = f"/content/proCeleba/{resolution}"

iteration = 200

critic = 5          

eval_size = 25

laten_space = 100

device = t.device('cuda' if t.cuda.is_available() else 'cpu')

log_folder = "log"
!mkdir "log"
!mkdir "log/checkpoint"
!mkdir "log/sample"

mkdir: cannot create directory ‘log’: File exists
mkdir: cannot create directory ‘log/checkpoint’: File exists
mkdir: cannot create directory ‘log/sample’: File exists


### Show me which gpu I am using.

In [46]:
!nvidia-smi -L

GPU 0: Tesla P100-PCIE-16GB (UUID: GPU-07454f12-3b42-7271-d801-ce798afb5a0f)


### Pre loading the data

In [47]:
# Define a data class for load unclassfied data.
class Get_No_Classes_Img_Dataset(Dataset):
    
    def __init__(self, folder_dir, transform = None):
        self.folder_dir = os.path.join(folder_dir)
        self.transform = transform
        self.image_list = os.listdir(self.folder_dir)
        
    def __len__(self):
        return len(os.listdir(self.folder_dir))
    
    def __getitem__(self, index):
        
        image_name = self.image_list[index]
        
        image_dir = os.path.join(self.folder_dir, image_name)
        
        image = io.imread(image_dir)
        
        if (self.transform != None):
            image = self.transform(image)
        
        return image
    

transform = transforms.Compose([
    transforms.ToTensor()       #From Batch * Highth * Width * Channel to Batch * Channel * Highth * Width
                                  #Which is what pytorch CNN can work with.
]) 

dataset = Get_No_Classes_Img_Dataset(img_fold_dir_64_reso, transform = transform) 
                                                            # datasets[0].shape = (16,16,3)
                                                            # len(datasets) = 28000
total_data_len = len(dataset)
                                                                                            
#datasets_batched = DataLoader(datasets, batch_size = batch_size) #loader is renewed every epoch




```
loader = iter(loader)
print(next(loader).shape) 
```
will output


```
torch.Size([batch_size, 3, resolution, resolution])
```

So data loading is ready.

What's left is keep using 

```
next(loader)
```
to access each batch of data


### Visualize a picture

In [48]:
from IPython.display import Image, display




# Viewing one data sample function:
def showOneImge(img, i, shouldSave):
    
    img = img.squeeze()
    
    img = transforms.ToPILImage()(img)
    
    plt.figure(figsize = (10,10), dpi = 10)
    plt.axis('off')
    
    if shouldSave:
      saveDir = f'{log_folder}/sample/{str(i).zfill(6)}.png'
      plt.imshow(img)
      plt.savefig(saveDir, bbox_inches='tight', pad_inches = 0)
      img = Image(saveDir)
      display(img)
    else:
      deleteDir = f"{log_folder}/sample/Delete.png"
      plt.imshow(img)
      plt.savefig(deleteDir, bbox_inches='tight', pad_inches = 0)
      img = Image(deleteDir)
      display(img)
      !rm "/content/log/sample/Delete.png"
    return 

''' Testing showOneImage'''
# loader = DataLoader(datasets, batch_size = batch_size)

# data = iter(loader)

# oneSample = next(data)[0]

# showOneImge(oneSample, 9999, True)


def showMoreImages(img, num):

  subplot_x = int(num ** (1/2))
  subplot_y = num // subplot_x
  plt.figure(figsize = (2,2))
  for i in range(len(img)):
      aimg = transforms.ToPILImage()(img[i])
      plt.subplot(subplot_x, subplot_y, i+1)
      plt.imshow(aimg)
      plt.axis('off')


-------------------------------------------------

### Now Designing the simple GAN network

In [49]:
"""
Define the generator
"""

class G(nn.Module):
    
    def __init__(self):
        super().__init__()

        self.laten = nn.Sequential(
            nn.Linear(laten_space, 100),
            nn.Linear(100, 500),
            nn.Linear(500, 128 * resolution * resolution))

        self.model = nn.Sequential(
            nn.Conv2d(128, 64, (3, 3), padding = "same"),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 32, (3, 3), padding = "same"),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 3, (3, 3), padding = "same"),
            nn.Sigmoid()
        )
        
        
        
    def forward(self, theInput, batch_size):
        
        x = self.laten(theInput)
        
        x = t.reshape(x, (batch_size, 128, resolution, resolution))
         
        x = self.model(x)
            
            
        return x
    
class D(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.netWork = nn.Sequential(
            nn.Conv2d(3, 64, (3, 3), padding = "same"),
            nn.LeakyReLU(),
            nn.Dropout2d(0.3),
            nn.Conv2d(64, 128, (3, 3), stride = (3, 3)),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128, 1, (3, 3), stride = (3, 3)),
            nn.Flatten())
        
    def forward(self, theInput):
        
        return self.netWork(theInput)

### Testing the network

In [50]:
"""

Testing the net work:

"""

# oneImg = next(iter(datasets))[0]

# showOneImge(oneImg, 0)

# oneImg = oneImg.expand(1,3,64,64)

# print(oneImg.shape)

# print(oneImg)

# #img into G to test shape

# input_noise_example = t.randn((batch_size, 1, 1, 5))

# a = G()(input_noise_example, batch_size)
# print(a[0])
# showOneImge(a[0], 0)

# b = D()(next(loader))

# print(b.shape)

'\n\nTesting the net work:\n\n'

### Gradient Penalty from wGAN.

In [51]:
def compute_gradient_penalty(D, real_samples, fake_samples, current_batch_size):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = t.randn((current_batch_size, 1, 1, 1)).to(device)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    
    
    d_interpolates = D(interpolates)
    
    
    grad_x_hat = grad(
            outputs=d_interpolates.sum(), inputs=interpolates, create_graph=True)[0]
    grad_penalty = ((grad_x_hat.view(grad_x_hat.size(0), -1)
                      .norm(2, dim=1) - 1)**2).mean()
    grad_penalty = 10 * grad_penalty
    
    return grad_penalty

### Try FID LOSS

In [52]:
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names
from torch.autograd import Function

import numpy as np

import scipy.linalg as linalg

"""MatrixSquareRoot is from https://github.com/steveli/pytorch-sqrtm/blob/master/sqrtm.py"""

class MatrixSquareRoot(Function):
    """Square root of a positive definite matrix.
    NOTE: matrix square root is not differentiable for matrices with
          zero eigenvalues.
    """
    @staticmethod
    def forward(ctx, input):
        m = input.detach().cpu().numpy().astype(np.float_)
        sqrtm = t.from_numpy(linalg.sqrtm(m).real).to(input)
        ctx.save_for_backward(sqrtm)
        return sqrtm

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        if ctx.needs_input_grad[0]:
            sqrtm, = ctx.saved_tensors
            sqrtm = sqrtm.data.cpu().numpy().astype(np.float_)
            gm = grad_output.data.cpu().numpy().astype(np.float_)

            # Given a positive semi-definite matrix X,
            # since X = X^{1/2}X^{1/2}, we can compute the gradient of the
            # matrix square root dX^{1/2} by solving the Sylvester equation:
            # dX = (d(X^{1/2})X^{1/2} + X^{1/2}(dX^{1/2}).
            grad_sqrtm = linalg.solve_sylvester(sqrtm, sqrtm, gm)

            grad_input = t.from_numpy(grad_sqrtm).to(grad_output)
        return grad_input


sqrtm = MatrixSquareRoot.apply

""" FID is modified from official pytorch implementation """

class FID_Loss(nn.Module):
    def __init__(self):
      super().__init__()
      self.model = t.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True)
    
      self.model.eval()
    
      self.model = create_feature_extractor(self.model, {'avgpool': 'feat'})


    def cov(self, tensor, rowvar=True, bias=False):
      """Estimate a covariance matrix (np.cov)
      https://gist.github.com/ModarTensai/5ab449acba9df1a26c12060240773110
      """
      tensor = tensor if rowvar else tensor.transpose(-1, -2)
      tensor = tensor - tensor.mean(dim=-1, keepdim=True)
      factor = 1 / (tensor.shape[-1] - int(not bool(bias)))
      return factor * tensor @ tensor.transpose(-1, -2).conj()


    def sqrtm(self, a):
      # Computing diagonalization
          evalues, evectors = t.linalg.eigh(a)
          # Ensuring square root matrix exists
          #assert (evalues >= 0).all()
          covmean = evectors @ t.diag(t.sqrt(evalues)) @ t.linalg.inv(evectors)
          return covmean

  # calculate frechet inception distance
  # A faster FID calculation modified from https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
    def calculate_frechet_distance(self, ac1, ac2, eps=1e-6):
        """Numpy implementation of the Frechet Distance.
        The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
        and X_2 ~ N(mu_2, C_2) is
                d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
        Stable version by Dougal J. Sutherland.
        Params:
        -- mu1   : Numpy array containing the activations of a layer of the
                  inception net (like returned by the function 'get_predictions')
                  for generated samples.
        -- mu2   : The sample mean over activations, precalculated on an
                  representative data set.
        -- sigma1: The covariance matrix over activations for generated samples.
        -- sigma2: The covariance matrix over activations, precalculated on an
                  representative data set.
        Returns:
        --   : The Frechet Distance.
        """
        
        
        mu1 = t.mean(ac1)
        sigma1 = self.cov(ac1)
        
        mu2 = t.mean(ac2)
        sigma2 = self.cov(ac2)
        
        mu1 = t.atleast_1d(mu1)
        mu2 = t.atleast_1d(mu2)
        
        sigma1 = t.atleast_2d(sigma1)
        sigma2 = t.atleast_2d(sigma2)
        
        assert mu1.shape == mu2.shape, \
            'Training and test mean vectors have different lengths'
        assert sigma1.shape == sigma2.shape, \
            'Training and test covariances have different dimensions'
        
        diff = mu1 - mu2
        
        # Product might be almost singular
        #print("sigma1 shape: ", sigma1.shape)
        #print("sigma2 shape: ", sigma2.shape)
        
        
        # Computing diagonalization
        #covmean = self.sqrtm(sigma1.T@sigma2)
        covmean = sqrtm(sigma1@sigma2.T)
        #print(covmean.shape)
        if not t.isfinite(covmean).all():
            msg = ('fid calculation produces singular product; '
                  'adding %s to diagonal of cov estimates') % eps
            #print(msg)
            offset = t.eye(sigma1.shape[0]) * eps
            covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
        
        #Numerical error might give slight imaginary component
        if t.is_complex(covmean):
            if not t.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = t.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real
        
        tr_covmean = t.trace(covmean)
        
        return  (diff.dot(diff) + t.trace(sigma1) + t.trace(sigma2) - 2 * tr_covmean)
      



    def forward(self, fake, real):
      
        upsamle_layer = t.nn.UpsamplingBilinear2d(size=299)
    
        transform = transforms.Compose([
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
    
        fake = upsamle_layer(fake)
        fake = transform(fake)
    
        fake_feature = self.model(fake)['feat']
        fake_feature = fake_feature.reshape((fake_feature.shape[0], 2048))
    
        real = upsamle_layer(real)
        real = transform(real)
    
        real_feature = self.model(real)['feat']
        real_feature = real_feature.reshape((real_feature.shape[0], 2048))
        #print("real_feature: ", real_feature.shape)
        FID_score = self.calculate_frechet_distance(real_feature, fake_feature)
    
        return FID_score






In [53]:
# "Testing FID Loss"
# fake = t.randn((64, 3, 16, 16)).to(device)
# real = t.randn((64, 3, 16, 16)).to(device)

# #same = FID_loss(fake, fake)

# dif = FID_loss(fake, real)

# dif = dif.reshape((1,1))

# dif = t.asarray(dif)

# a = nn.Sequential(
#     nn.Linear(1, 10),
#     nn.Linear(10, 1)
# )

# a_opt = t.optim.Adam(a.parameters(), lr = 0.001)

# loss = a(dif.float())

# loss.backward()

# print(a.state_dict())


# #print("same: ", same)
# print("dif", dif)


### Initializing the generator, discriminator, optimizer, labels, and loss.

In [54]:
generator = G().to(device)
discriminator = D().to(device)
FID_loss = FID_Loss().to(device)

G_optimizer = t.optim.Adam(generator.parameters(), lr = 0.00001)
D_optimizer = t.optim.Adam(discriminator.parameters(), lr = 0.00001)

label_real = t.ones((batch_size, 1)).to(device)
label_fake = t.zeros((batch_size, 1)).to(device)

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


### Train decider

In [None]:
for i in range(iteration):

    p = tqdm(range(total_data_len // batch_size + 1)) # This is a progress bar run on each epoch
    
    datasets_batched = DataLoader(dataset, batch_size = batch_size)
    
    loader = iter(datasets_batched)

    for j in p:
        

        batchNum = str(i+1)
        
        #Training the D
        #real data
        real = next(loader).to(device)

        #current_batch size (the last batch is different than others)
        current_batch_size, c, h, w = real.shape
        #labels
        label_real = 0.1 * t.randint(7,10,(current_batch_size,1)).type(t.half)
        label_fake = 0.1 * t.randint(0,3,(current_batch_size,1)).type(t.half)
        
        input_noise = t.normal(0, 1, size = (current_batch_size, 1, 1, laten_space)).to(device)
        fake = generator(input_noise, current_batch_size)
        
        total_loss = FID_loss(real, fake)

        generator.zero_grad()
        FID_loss.zero_grad()
        total_loss.backward()
        G_optimizer.step()


          

        
        mse = "Epoch: " + batchNum
        
        p.set_description(mse)
            
        p.set_postfix(FID_loss = total_loss)
        
    if i == 0:
        showOneImge(real[0], 99999, True)   
    
    if i % 2 == 0:
      
      showOneImge(real[0], 99999, False)
      showOneImge(fake[0], i, False)
      print("epoch = ", i + 1)  
      #print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'.format(i, iteration, j+1, total_data_len // batch_size + 1, total_loss.item(), g_fake_loss.item(), real_score.mean().item(), fake_score.mean().item()))      
      

    if i % 50 == 0:
      t.save(generator.state_dict(), f'{log_folder}/checkpoint/{str(i + 1).zfill(6)}_g.model')
      t.save(discriminator.state_dict(), f'{log_folder}/checkpoint/{str(i + 1).zfill(6)}_d.model')


### Show Result

In [None]:

#Generated
generator.eval()

with t.no_grad():
  input_noise = t.normal(0, 1, size = (eval_size, 1, 1, laten_space)).to(device)
  generated = generator(input_noise, eval_size)
  showMoreImages(generated, eval_size)

#Real
datasets_batched = DataLoader(dataset, batch_size = eval_size)
loader = iter(datasets_batched)
real = next(loader)
showMoreImages(real, eval_size)

### remove log file when necessarry

In [None]:
# !rm -rf /content/log/checkpoint

# !mkdir /content/log/checkpoint

# !rm -rf /content/log/sample
# !mkdir /content/log/sample

# from google.colab import files
# files.download('/content/log') 