In [2]:
import torch
import torch.nn as nn
import torchvision.utils as vutils
from torchvision.utils import save_image
from scipy import linalg
import torchvision.models as models
import torch.nn.functional as F
import json
import os
import numpy as np
import faiss

from ipynb.fs.full.definitions import SEED, MODELS_PATH, IMAGE_SIZE

In [None]:
def get_features(images, batch_size=32, device=None):
    """Extract pooled Inception-v3 features for a stack of images.

    The pretrained torchvision Inception-v3 has its final `fc` layer replaced by
    `nn.Identity()`, so the returned vectors are the activations that would have
    been fed to the classifier.

    The network is built once per device and cached on this function. The
    training loop calls this several times per iteration, and reloading the
    pretrained weights each time dominated the cost.

    Args:
        images: tensor of images, indexed along dim 0 and sliced into batches.
            NOTE(review): assumed (N, 3, H, W) float; with pretrained weights
            torchvision's Inception expects ImageNet-normalised inputs, while GAN
            outputs are often in [-1, 1] — confirm the intended preprocessing.
        batch_size: number of images pushed through the network at once.
        device: device to run on; defaults to CUDA when available, else CPU.

    Returns:
        numpy array of shape (N, D): one feature row per input image.

    Raises:
        ValueError: if `images` is empty.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    # Lazily build and cache one feature extractor per device.
    cache = getattr(get_features, '_model_cache', None)
    if cache is None:
        cache = {}
        get_features._model_cache = cache
    model = cache.get(device)
    if model is None:
        model = models.inception_v3(pretrained=True).eval()
        model.fc = nn.Identity()
        model = model.to(device)
        cache[device] = model

    features = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size].to(device)
            # Inception-v3 is designed for 299x299 inputs, and small images (e.g. 64x64)
            # collapse to nothing in its stride-2 stages. Resizing a 299x299 input to
            # 299x299 is an identity, so already-sized inputs are unaffected.
            if tuple(batch.shape[-2:]) != (299, 299):
                batch = F.interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)
            features.append(model(batch).cpu().numpy())

    if not features:
        raise ValueError("get_features received no images")
    return np.concatenate(features)

In [10]:
# FID
def calculate_fretchet(images_real, images_fake, eps=1e-6):
    """Frechet Inception Distance (FID) between two image sets.

    Both sets are embedded with `get_features`. Each feature cloud is modelled as
    a Gaussian (mean, covariance), and the distance is computed as

        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))

    Args:
        images_real: real images, in whatever form `get_features` accepts.
        images_fake: generated images, same form.
        eps: diagonal offset added to both covariances when the matrix square
            root is not finite. With fewer samples than feature dimensions the
            covariances are singular and `sqrtm` can return inf/NaN.

    Returns:
        The FID as a numpy scalar (lower is better).
    """
    real_features = get_features(images_real)
    generated_features = get_features(images_fake)

    # atleast_2d keeps the code valid when np.cov collapses to a 0-d value (1 feature).
    mu1, sigma1 = real_features.mean(axis=0), np.atleast_2d(np.cov(real_features, rowvar=False))
    mu2, sigma2 = generated_features.mean(axis=0), np.atleast_2d(np.cov(generated_features, rowvar=False))
    diff = mu1 - mu2

    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Singular product: regularise both covariances slightly and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    # Numerical error can leave a tiny imaginary component; only the real part is meaningful.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)
    return fid

In [None]:
def polynomial_mmd(x, y, degree=3, gamma=None, coef0=1, unbiased=False):
    """Squared MMD between two sample sets under a polynomial kernel.

    The kernel is k(a, b) = (gamma * <a, b> + coef0) ** degree.

    Args:
        x: 2-D torch tensor, one sample per row.
        y: 2-D torch tensor with the same number of columns as `x`.
        degree: polynomial degree of the kernel.
        gamma: scale of the inner product; defaults to 1 / number of columns of `x`.
        coef0: additive constant inside the kernel.
        unbiased: if False (default), return the biased estimate (means over the
            full kernel matrices, diagonals included). If True, exclude the
            diagonal self-similarity terms of k(x, x) and k(y, y). That is the
            unbiased MMD^2 estimator used by the Kernel Inception Distance.

    Returns:
        0-d torch tensor holding the MMD^2 estimate.

    Raises:
        ValueError: if `unbiased` is True and either set has fewer than 2 rows.
    """
    if gamma is None:
        gamma = 1.0 / x.shape[1]
    kernel_xx = (gamma * x.mm(x.t()) + coef0) ** degree
    kernel_yy = (gamma * y.mm(y.t()) + coef0) ** degree
    kernel_xy = (gamma * x.mm(y.t()) + coef0) ** degree

    if not unbiased:
        return kernel_xx.mean() + kernel_yy.mean() - 2 * kernel_xy.mean()

    m, n = x.shape[0], y.shape[0]
    if m < 2 or n < 2:
        raise ValueError("unbiased MMD needs at least two samples in each set")
    # Drop k(x_i, x_i) / k(y_j, y_j): the unbiased estimator averages only over i != j.
    off_diag_xx = kernel_xx.sum() - kernel_xx.diagonal().sum()
    off_diag_yy = kernel_yy.sum() - kernel_yy.diagonal().sum()
    return off_diag_xx / (m * (m - 1)) + off_diag_yy / (n * (n - 1)) - 2 * kernel_xy.mean()

In [None]:
# KID
def kernel_inception_distance(images_real, images_fake):
    """Kernel Inception Distance: polynomial-kernel MMD between Inception features.

    Real images are embedded first, then fake ones. Each feature matrix is
    wrapped in a torch tensor, and the pair is handed to `polynomial_mmd` with
    its default kernel settings.

    Returns:
        0-d torch tensor with the MMD estimate.
    """
    feature_tensors = [torch.tensor(get_features(image_set)) for image_set in (images_real, images_fake)]
    real_tensor, fake_tensor = feature_tensors
    return polynomial_mmd(real_tensor, fake_tensor)

In [None]:
# PRECISION AND RECALL
def _knn_radii(features, k):
    """Squared L2 distance from each row of `features` to its k-th nearest other row."""
    n = features.shape[0]
    # Cannot ask for more neighbours than there are *other* points in the set.
    k = min(k, n - 1)
    index = faiss.IndexFlatL2(features.shape[1])
    index.add(features)
    # Search k + 1 neighbours because each point's nearest neighbour in its own set is itself.
    distances, _ = index.search(features, k + 1)
    return distances[:, k]


def _manifold_fraction(reference, radii, queries):
    """Fraction of `queries` lying inside at least one k-NN ball centred on a `reference` row."""
    index = faiss.IndexFlatL2(reference.shape[1])
    index.add(reference)
    # Retrieve every reference point for each query (IndexFlatL2 returns squared L2,
    # the same units as `radii`). Memory is len(queries) * len(reference).
    distances, neighbours = index.search(queries, reference.shape[0])
    inside = distances <= radii[neighbours]
    return np.mean(inside.any(axis=1))


def compute_precision_recall(images_real, images_fake, k=5):
    """k-NN manifold precision and recall (Kynkaenniemi et al., 2019).

    Each set's "manifold" is the union of balls around its points. The radius of
    each ball is the distance to that point's k-th nearest neighbour within its
    own set.

    - precision: fraction of generated samples inside the real manifold.
    - recall: fraction of real samples inside the generated manifold.

    Args:
        images_real: real images, in whatever form `get_features` accepts.
        images_fake: generated images, same form.
        k: neighbourhood size used for the radii (clamped to set size - 1).

    Returns:
        (precision, recall) as numpy floats in [0, 1].

    Raises:
        ValueError: if either image set produces no features.
    """
    # faiss requires C-contiguous float32 input.
    real_features = np.ascontiguousarray(get_features(images_real), dtype=np.float32)
    generated_features = np.ascontiguousarray(get_features(images_fake), dtype=np.float32)
    if len(real_features) == 0 or len(generated_features) == 0:
        raise ValueError("precision/recall needs at least one real and one generated sample")

    real_radii = _knn_radii(real_features, k)
    generated_radii = _knn_radii(generated_features, k)

    precision = _manifold_fraction(real_features, real_radii, generated_features)
    recall = _manifold_fraction(generated_features, generated_radii, real_features)
    return precision, recall

In [12]:
def seed_everything(seed=None):
    """Seed every RNG used by the training code and make cuDNN deterministic.

    Args:
        seed: integer seed; defaults to the notebook-wide `SEED` constant.
    """
    import random

    if seed is None:
        seed = SEED
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning picks kernels non-deterministically; disable it for reproducibility.
    torch.backends.cudnn.benchmark = False

In [14]:
# Seed the random number generators (via seed_everything, defined above) once here,
# so that everything created afterwards is reproducible across runs.
seed_everything()

In [11]:
def weights_init(m):
    """Re-initialise one module in place; intended for `model.apply(weights_init)`.

    - Any module whose class name contains 'Conv': weight ~ N(0, 0.02).
    - Otherwise, any whose class name contains 'BatchNorm': weight ~ N(1, 0.02), bias = 0.
    - All other modules are left untouched.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [13]:
class LambdaLR:
    """Learning-rate multiplier schedule.

    The factor is 1.0 until epoch + offset reaches `decay_start_epoch`. It then
    falls linearly, reaching 0.0 at `n_epochs`, and would keep going negative
    beyond that.

    NOTE(review): presumably passed as `lr_lambda=LambdaLR(...).step` to a torch
    LR scheduler; confirm at the call site.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        decay_span = n_epochs - decay_start_epoch
        assert decay_span > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the LR multiplier for `epoch` (shifted by `offset`)."""
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        decay_span = self.n_epochs - self.decay_start_epoch
        return 1.0 - epochs_into_decay / decay_span

In [23]:
def _to_numpy_if_tensor(value):
    """Convert a torch tensor to a host numpy array; pass any other value through."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return value


def _compute_gan_metrics(real_images, fake_images):
    """Return (FID, KID, precision, recall) between two image batches, tensors converted to numpy."""
    fretchet_dist = _to_numpy_if_tensor(calculate_fretchet(real_images, fake_images))
    kernel_dist = _to_numpy_if_tensor(kernel_inception_distance(real_images, fake_images))
    precision, recall = compute_precision_recall(real_images, fake_images)
    return fretchet_dist, kernel_dist, _to_numpy_if_tensor(precision), _to_numpy_if_tensor(recall)


def training_loop(num_epochs, dataloader, netG, netD, device, criterion, nz, optimizerG, optimizerD, schedulerG, schedulerD, fixed_noise, out):
    """Alternating discriminator/generator training loop with periodic logging.

    Per batch: one D step (real batch labelled `real_label`, a fresh fake batch
    labelled `fake_label`), then one G step (fake batch scored against `real_label`).

    Side effects:
      - every 500 iterations (and on the very last one): a `fixed_noise` sample is
        appended to the returned image lists;
      - every 20 iterations after epoch 0: FID/KID/precision/recall are computed
        on the current batch and appended to `out/metrics.json` (one indented
        JSON object per record, newline-separated), and both networks'
        state_dicts are saved under MODELS_PATH;
      - at the end of every 50th epoch (except 0): the loss histories are saved
        to `out/` and the last generated batch is written to `out/synthetics/`.

    Args:
        num_epochs: number of passes over `dataloader`.
        dataloader: iterable of batches; `data[0]` is the image tensor.
        netG, netD: generator and discriminator modules.
        device: device the real batches and noise are moved to.
        criterion: loss applied to flattened D outputs and label tensors.
        nz: latent size; noise is drawn as (batch, nz, 1, 1).
        optimizerG, optimizerD: optimisers for G and D.
        schedulerG, schedulerD: LR schedulers stepped once per epoch.
        fixed_noise: latent batch used for the preview images.
        out: output directory for metrics, losses and synthetic images.

    Returns:
        (G_losses, D_losses, img_list, img_list_only).
    """
    img_list = []
    img_list_only = []
    G_losses = []
    D_losses = []
    iters = 0

    real_label = 0.9  # label smoothing
    fake_label = 0.

    print("Starting Training Loop...")
    for epoch in range(num_epochs):
        fake = None  # last generated training batch of this epoch
        for i, data in enumerate(dataloader, 0):
            # --- Discriminator step ---
            netD.zero_grad()
            real_images = data[0].to(device)
            b_size = real_images.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

            output = netD(real_images).view(-1)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)

            # detach: the D step must not push gradients into G.
            output = netD(fake.detach()).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()

            errD = errD_real + errD_fake
            optimizerD.step()

            # --- Generator step ---
            netG.zero_grad()
            label.fill_(real_label)
            output = netD(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, num_epochs, i, len(dataloader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            G_losses.append(errG.item())
            D_losses.append(errD.item())

            if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
                # Separate name so the training batch `fake` is not shadowed.
                with torch.no_grad():
                    fixed_fake = netG(fixed_noise).detach().cpu()
                img_list.append(vutils.make_grid(fixed_fake, padding=2, normalize=True))
                img_list_only.append(fixed_fake)

            if (iters % 20 == 0) and (epoch != 0):
                # The metrics run several Inception passes, so compute them only when they are logged.
                fretchet_dist, kernel_dist, precision, recall = _compute_gan_metrics(real_images, fake.detach())
                js = {"Iteration" : iters, "FID": str(fretchet_dist), "KID": str(kernel_dist), "PR": str(precision), "REC": str(recall)}

                with open(os.path.join(out, 'metrics.json'), 'a') as f:
                    f.write(json.dumps(js, indent=4))
                    f.write('\n')  # keep appended records separable
                # save the trained generator so far...
                torch.save(netG.state_dict(), os.path.join(MODELS_PATH, "generator_" + str(IMAGE_SIZE) + 'x' + str(IMAGE_SIZE) + "_" + str(epoch) + ".pth"))
                # ...as well as the trained discriminator
                torch.save(netD.state_dict(), os.path.join(MODELS_PATH, "discriminator_" + str(IMAGE_SIZE) + 'x' + str(IMAGE_SIZE) + "_" + str(epoch) + ".pth"))

            iters += 1

        # Once per checkpoint epoch (previously this ran on every batch of the epoch).
        if fake is not None and (epoch % 50 == 0) and (epoch != 0):
            np.save(os.path.join(out, "G_loss_" + str(epoch)), G_losses)
            np.save(os.path.join(out, "D_loss_" + str(epoch)), D_losses)

            synthetic_dir = os.path.join(out, 'synthetics')
            os.makedirs(synthetic_dir, exist_ok=True)
            last_batch = fake.detach().cpu()
            for j in range(len(last_batch)):  # every image, including the last one
                save_image(last_batch[j], os.path.join(synthetic_dir, 'e' + str(epoch) + '_' + str(j) + '_synthetic.png'))

        # Update learning rates
        schedulerG.step()
        schedulerD.step()

    return G_losses, D_losses, img_list, img_list_only