# How to Use FJD

This notebook provides a simple example of how to evaluate a conditional GAN using FJD.

In [1]:
%matplotlib inline
import numpy as np

import torch
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torch.nn as nn
import torch.nn.functional as F

from fjd.fjd_metric import FJDMetric
from fjd.embeddings import OneHotEmbedding, InceptionEmbedding

import os
os.getcwd()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [3]:
if False:
    AE = Autoencoder()
    print(AE)
    AE.load_state_dict(torch.load('models/caegan_mnist_G.pt'))
    with torch.no_grad():
            codes = one_hot_embedding(torch.tensor(list(range(9)), device = 'cpu')).view(9,c_dim,1,1).float()
            varis = torch.randn((9, v_dim,1,1), device = 'cpu') # walk from [0,...,0] to [1,...,1]
            #print(codes.shape, varis.shape)
            generated = .5*(AE.forward(varis, codes).cpu() + 1)
            generated = torch.squeeze(generated)
            #print(generated.shape)
            for i in range(9):
                plt.subplot(330 + 1 + i)
                # plot raw pixel data
                element = generated[i,:]
                plt.imshow(element, cmap = 'gray')
            plt.show()
        

In order to compute FJD we will need two data loaders: one to provide images and conditioning for the reference distribution, and a second one whose conditioning will be used to condition the GAN for creating the generated distribution. For this example we will use the CIFAR-10 dataset.

When loading in reference images, it is important to normalize them between [-1, 1].

In [4]:
def get_dataloaders():
    #transform = transforms.Compose(
    #    [transforms.ToTensor(),
    #     transforms.Normalize(mean=(0.5, 0.5, 0.5), 
    #                          std=(0.5, 0.5, 0.5))])
    
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean=[0.5], 
                              std=[0.5])])

    train_set = MNIST(root='./data',
                        train=True,
                        download=True,
                        transform=transform)

    test_set = MNIST(root='./data',
                       train=False,
                       download=True,
                       transform=transform)

    train_loader = DataLoader(train_set,
                              batch_size=128,
                              shuffle=True,
                              drop_last=False)

    test_loader = DataLoader(test_set,
                             batch_size=128,
                             shuffle=False,
                             drop_last=False)

    return train_loader, test_loader

In [5]:
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        ## Encoding: Unconditional samples
        self.conv1 = nn.Conv2d(1, 128, 4, 2, 1) # Input: (bs, 3, img_size, img_size)
        self.conv2 = nn.Conv2d(128, 256, 4, 2, 1, bias = False)
        self.conv2_bn = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(256, 512, 4, 2, 1, bias = False)
        self.conv3_bn = nn.BatchNorm2d(512)
        self.conv4 = nn.Conv2d(512, 1024, 4, 2, 1, bias = False)
        self.conv4_bn = nn.BatchNorm2d(1024)
        
        self.conv5v = nn.Conv2d(1024, v_dim, 4, 1, 0) # Output: (bs, c_dim, 1, 1)
        self.conv5c = nn.Conv2d(1024, c_dim, 4, 1, 0) # Output, same as above: but this one to condition-space
        
        ## Decoding:
        self.deconv1v = nn.ConvTranspose2d(v_dim, 1024, 4, 1, 0, bias = False) # Not sure how this looks
        self.deconv1c = nn.ConvTranspose2d(c_dim, 1024, 4, 1, 0, bias = False) # Input: (bs, cdim+v_dim, 1, 1)
        
        self.deconv1_bn = nn.BatchNorm2d(1024)
        self.deconv2 = nn.ConvTranspose2d(1024+1024, 512, 4, 2, 1, bias = False)
        self.deconv2_bn = nn.BatchNorm2d(512)
        self.deconv3 = nn.ConvTranspose2d(512, 256, 4, 2, 1, bias = False)
        self.deconv3_bn = nn.BatchNorm2d(256)
        self.deconv4 = nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False)
        self.deconv4_bn = nn.BatchNorm2d(128)
        self.deconv5 = nn.ConvTranspose2d(128, 1, 4, 2, 1)
    
    def weight_init(self):
        for m in self._modules:
            normal_init(self._modules[m])
            
    def encode(self, x):
        # Encode data x to 2 spaces: condition space and variance-space
        x = F.leaky_relu(self.conv1(x), 0.2)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
        
        v = torch.sigmoid(self.conv5v(x)) # Variance-space unif~[0,1]
        c = torch.sigmoid(self.conv5c(x)) # this is softmax for CLASSIFICATION. Shapes3d is not 1-classif..
        
        return v, c
      
    def forward(self, v, c):
        # This is actually conditional generation // decoding.
        # It's beneficial to call this forward, though, for FJD calculation
        v = self.deconv1_bn(self.deconv1v(v))
        c = self.deconv1_bn(self.deconv1c(c))
        x = torch.cat((v, c), dim = 1) #stack on channel dim, should be (bs, vdim+cdim, 1, 1). Not sure here
        x = F.relu(self.deconv2_bn(self.deconv2(x)))
        x = F.relu(self.deconv3_bn(self.deconv3(x)))
        x = F.relu(self.deconv4_bn(self.deconv4(x)))
        x = torch.tanh(self.deconv5(x))
        return x
    
    def pass_thru(self, x):
        v, c = self.encode(x)
        return self.forward(v, c)

To simulate a GAN we will use samples from the test set. Note that the labels from the test set are shuffled, so although the image distribution should match the reference distribution well, the conditional consistency will be very bad since most of the "generated" images will not match the requested conditions.

In [6]:
## My GANWrapper
class GANWrapper:
    def __init__(self, model, model_checkpoint=None):
        self.model = model
        
        if model_checkpoint is not None:
            self.model_checkpoint = model_checkpoint
            self.load_model()
         
        # model weights are loaded correctly...
        #print(self.model.deconv1v.weight)

    def load_model(self):
        self.model.eval()  # uncomment to put in eval mode if desired
        self.model = self.model.cuda()
        
        #state_dict = torch.load(self.model_checkpoint)
        self.model.load_state_dict(torch.load('models/caegan_mnist_G.pt'))
        
        print('State dict freshly loaded. Now print:')
        with torch.no_grad():
            codes = one_hot_embedding(torch.tensor(list(range(9)), device = device)).view(9,c_dim,1,1).float()
            varis = torch.randn((9, v_dim,1,1), device = device) # walk from [0,...,0] to [1,...,1]
            print(codes.shape, varis.shape)
            generated = .5*(self.model.forward(varis, codes).cpu() + 1)
            generated = torch.squeeze(generated)
            print(generated.shape)
            for i in range(9):
                plt.subplot(330 + 1 + i)
                # plot raw pixel data
                element = generated[i,:]
                plt.imshow(element, cmap = 'gray')
            plt.show()

    def get_noise(self, batch_size):
        # change the noise dimension as required
        z = torch.randn((batch_size, v_dim, 1, 1), device = device)
        return z

    def __call__(self, y):
        #print(y.shape) # Might need to use my one-hot, etc on this
        #print(y[0])
        batch_size = y.size(0)
        #print(y)
        #print(y.shape)
        
        

        y = one_hot_embedding(y).view(batch_size, c_dim, 1, 1).to(device).float()
        z = self.get_noise(batch_size)
        #print(y.shape)
        #print(y[0])
        samples = self.model(z, y)
        #print(samples.shape)
        #print(torch.min(samples), torch.max(samples))
        #plt.imshow(samples[0].cpu().squeeze(), cmap = 'gray')
        #plt.show()
        #samples = self.model(z)
        return samples

In order to be able to accomodate a wide variety of model configurations, we use a GAN wrapper to standardize model inputs and outputs. Each model is expected to take as input a set of conditions _y_, and return a corresponding set of generated samples.

In [7]:
def one_hot_embedding(labels):
    #y = torch.eye(num_classes)
    #return y[labels]
    #return torch.nn.functional.one_hot(labels)[:,1:]
    
    labels = torch.nn.functional.one_hot(torch.tensor(labels).to(torch.int64), num_classes = c_dim)
    return torch.squeeze(labels)

#class GANWrapper:
#    def __init__(self, model, model_checkpoint=None):
#        self.model = model
#        
#        if model_checkpoint is not None:
#            self.model_checkpoint = model_checkpoint
#            self.load_model()#
#
#    def load_model(self):
#        # self.model.eval()  # uncomment to put in eval mode if desired
#        self.model = self.model.cuda()#
#
#        state_dict = torch.load(self.model_checkpoint)
#        self.model.load_state_dict(state_dict)##

#    def get_noise(self, batch_size):
#        # change the noise dimension as required
#        z = torch.cuda.FloatTensor(batch_size, 128).normal_()
#        return z

#    def __call__(self, y):
#        #y = one_hot_embedding(y)
#        batch_size = y.size(0)
#        z = self.get_noise(batch_size)
#        samples = self.model(z, y)
#        return samples

The FJDMetric object handles embedding the images and conditioning, the computation of the reference distribution and generated distribution statistics, the scaling of the conditioning component with alpha, and the calculation of FJD. It requires several inputs:

1. **gan** - A GAN model which takes as input conditioning and yields image samples as output.  
2. **reference_loader** - A data loader for the reference distribution, which yields image-condition pairs.  
3. **condition_loader** - A data loader for the generated distribution, which yields image-condition pairs. Images are ignored, and the conditioning is used as input to the GAN.  
4. **image_embedding** - An image embedding function. This will almost always be the InceptionEmbedding.  
5. **condition_embedding** - A conditioning embedding function. As we are dealing with class conditioning in this example, we will use one-hot encoding.

Other options:
* **save_reference_stats** - Indicates whether the statistics of the reference distribution should be saved to the path provided in **reference_stats_path**. This can speed up computation of FJD if the same reference set is used for multiple evaluations.
* **samples_per_condition** - Indicates the number of images that will be generated for each condition drawn from the condition loader. This may be useful if there are very few samples in the conditioning dataset, or to emphasize intra-conditioning diversity when calculating FJD.
* **cuda** - If True, indicates that the GPU accelerated version of FJD should be used. This version should be considerably faster than the CPU version, but may be slightly more unstable.

In [8]:
z_dim = 100
v_dim = 100
c_dim = 10

train_loader, test_loader = get_dataloaders()
inception_embedding = InceptionEmbedding(parallel=False)
onehot_embedding = OneHotEmbedding(num_classes=10)
gan = SuspiciouslyGoodGAN()
#gan = Autoencoder()
#params = 'models/cgan_mnist_G.pt'
params = 'models/caegan_mnist_G.pt'
gan = GANWrapper(gan, params)

fjd_metric = FJDMetric(gan=gan,
                       reference_loader=train_loader, #mnist train
                       condition_loader=test_loader, #mnist test
                       image_embedding=inception_embedding, #dont change
                       condition_embedding=onehot_embedding,
                       reference_stats_path='datasets/cifar_train_stats.npz',
                       save_reference_stats=True,
                       samples_per_condition=6,
                       cuda=True)

NameError: name 'SuspiciouslyGoodGAN' is not defined

Once the FJD object is initialized, FID and FJD can be calculated by calling **get_fid** or **get_fjd**.  By default, the alpha value used to weight the conditional component of FJD is selected to be the ratio between the average L2 norm of the image embedding and conditioning embedding.  

We see in this example that even though our "GAN" gets a very good FID score due to the generated image distribution being very close to the reference image distribution, its FJD score is very bad, as the model lacks any conditional consistency.

In [None]:
fid = fjd_metric.get_fid()
fjd = fjd_metric.get_fjd()
print('FID: ', fid)
print('FJD: ', fjd)

To visualize how FJD changes as we increase the weighting on the conditional component, we can evaluate it at a range of alpha values using the **sweep_alpha** function.

In [None]:
alpha = fjd_metric.alpha
alphas = [0, 1, 2, 4, 8, 16, 32]
fjds = fjd_metric.sweep_alpha(alphas)

plt.plot(alphas, fjds, label='FJD', linewidth=3)
plt.plot(alphas, [fid]*len(alphas), label='FID', linewidth=3)
plt.axvline(x=alpha, c='black', label=r'Suggested $\alpha$', linewidth=2)
plt.xlabel(r'$\alpha$')
plt.ylabel('Distance')
plt.legend()