In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

from fjd.fjd_metric import FJDMetric
from fjd.embeddings import OneHotEmbedding, InceptionEmbedding

# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# CIFAR-10 label names; GANWrapper indexes this tuple with the integer class id.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [2]:
class SuspiciouslyGoodGAN(torch.nn.Module):
    """Fake "generator" that replays real CIFAR-10 test images.

    The latent and conditioning inputs are ignored: every call simply returns
    the next batch of the (unshuffled) CIFAR-10 test set, scaled to [-1, 1].
    Presumably intended as a best-case baseline for the metric.
    """

    def __init__(self):
        super(SuspiciouslyGoodGAN, self).__init__()

        # ToTensor yields [0, 1]; Normalize(0.5, 0.5) then maps that to [-1, 1].
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(32),
            transforms.Normalize(mean=[0.5],
                                 std=[0.5])])

        test_set = CIFAR10(root='./data',
                           train=False,
                           download=True,
                           transform=transform)

        test_loader = DataLoader(test_set,
                                 batch_size=128,
                                 shuffle=False,
                                 drop_last=True)

        # Keep the loader so the iterator can be rebuilt once it is exhausted.
        self.test_loader = test_loader
        self.data_iter = iter(test_loader)

    def forward(self, z, y):
        """Return the next batch of test images.

        Args:
            z: latent input; ignored.
            y: conditioning input; ignored.

        Returns:
            The image tensor of the next batch, moved to the GPU when one is
            available and left on the CPU otherwise.
        """
        # Normally a GAN would actually do something with z and y, but for this fake GAN we ignore them
        try:
            samples, _ = next(self.data_iter)
        except StopIteration:
            # Reset dataloader if it runs out of samples
            self.data_iter = iter(self.test_loader)
            samples, _ = next(self.data_iter)

        # An unconditional `.cuda()` raises on CPU-only machines; choose the
        # target the same way the notebook-level `device` is chosen.
        target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return samples.to(target)
    
class Generator(nn.Module):
    """Unconditional DCGAN-style decoder producing single-channel images.

    A ``(bs, latent_dim, 1, 1)`` input is expanded to 4x4 by the first
    transposed convolution and doubled by each of the four that follow, so the
    output is ``(bs, 1, 64, 64)`` with values in [-1, 1] (tanh).

    Args:
        latent_dim: number of input channels of the latent code. When omitted,
            the notebook-level ``z_dim`` is used, which keeps the original
            zero-argument ``Generator()`` call working.
    """

    def __init__(self, latent_dim=None):
        super(Generator, self).__init__()
        if latent_dim is None:
            # `z_dim` is defined in a later cell; looking it up only when no
            # explicit size is given removes the hard dependency on cell order.
            latent_dim = z_dim
        ## Decoding:
        self.deconv1 = nn.ConvTranspose2d(latent_dim, 1024, 4, 1, 0, bias = False)  # 1x1 -> 4x4
        self.deconv1_bn = nn.BatchNorm2d(1024)
        self.deconv2 = nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias = False)  # 4x4 -> 8x8
        self.deconv2_bn = nn.BatchNorm2d(512)
        self.deconv3 = nn.ConvTranspose2d(512, 256, 4, 2, 1, bias = False)  # 8x8 -> 16x16
        self.deconv3_bn = nn.BatchNorm2d(256)
        self.deconv4 = nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False)  # 16x16 -> 32x32
        self.deconv4_bn = nn.BatchNorm2d(128)
        self.deconv5 = nn.ConvTranspose2d(128, 1, 4, 2, 1)  # 32x32 -> 64x64, one output channel

    def weight_init(self):
        """Apply ``normal_init`` to every direct sub-module.

        NOTE(review): ``normal_init`` is not defined in the visible notebook;
        it must exist in the namespace before this method is called.
        """
        for m in self._modules:
            normal_init(self._modules[m])

    def forward(self, z):
        """Decode latent codes ``z`` of shape ``(bs, latent_dim, 1, 1)`` into images."""
        # NOTE(review): unlike the later stages, the first stage has no ReLU
        # after its batch norm. Left untouched so that the function computed
        # by any existing trained weights does not change.
        x = self.deconv1_bn(self.deconv1(z))
        x = F.relu(self.deconv2_bn(self.deconv2(x)))
        x = F.relu(self.deconv3_bn(self.deconv3(x)))
        x = F.relu(self.deconv4_bn(self.deconv4(x)))
        x = torch.tanh(self.deconv5(x))
        return x

In [3]:
# Code sizes shared by the models and helpers in this notebook.
c_dim = 10            # width of the one-hot condition code
v_dim = z_dim = 100   # width of the "variance" code (v_dim) and of the plain latent code (z_dim)
class InverseAutoencoder(nn.Module):
    """Conditional decoder paired with an encoder back into code space.

    ``forward`` decodes a (variance code, condition code) pair into a
    single-channel image; ``encode`` maps a single-channel image back to a
    (variance code, condition code) pair; ``pass_thru`` chains the two.
    """

    def __init__(self):
        super().__init__()

        # Shared layer geometries: `halve`/`double` change the spatial size by
        # a factor of two, `edge` converts between 4x4 maps and 1x1 codes.
        halve = dict(kernel_size=4, stride=2, padding=1)
        double = dict(kernel_size=4, stride=2, padding=1)
        edge = dict(kernel_size=4, stride=1, padding=0)

        # --- Encoder: single-channel image -> 1024 feature maps -------------
        self.conv1 = nn.Conv2d(1, 128, **halve)
        self.conv2 = nn.Conv2d(128, 256, bias=False, **halve)
        self.conv2_bn = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(256, 512, bias=False, **halve)
        self.conv3_bn = nn.BatchNorm2d(512)
        self.conv4 = nn.Conv2d(512, 1024, bias=False, **halve)
        self.conv4_bn = nn.BatchNorm2d(1024)

        # Two heads on the shared features: v_dim channels for the variance
        # code and c_dim channels for the condition code.
        self.conv5v = nn.Conv2d(1024, v_dim, **edge)
        self.conv5c = nn.Conv2d(1024, c_dim, **edge)

        # --- Decoder: each code is expanded separately, then concatenated ---
        self.deconv1v = nn.ConvTranspose2d(v_dim, 1024, bias=False, **edge)
        self.deconv1c = nn.ConvTranspose2d(c_dim, 1024, bias=False, **edge)

        self.deconv1_bn = nn.BatchNorm2d(1024)
        self.deconv2 = nn.ConvTranspose2d(1024 + 1024, 512, bias=False, **double)
        self.deconv2_bn = nn.BatchNorm2d(512)
        self.deconv3 = nn.ConvTranspose2d(512, 256, bias=False, **double)
        self.deconv3_bn = nn.BatchNorm2d(256)
        self.deconv4 = nn.ConvTranspose2d(256, 128, bias=False, **double)
        self.deconv4_bn = nn.BatchNorm2d(128)
        self.deconv5 = nn.ConvTranspose2d(128, 1, **double)

    def weight_init(self):
        """Run ``normal_init`` over every direct sub-module.

        NOTE(review): ``normal_init`` is not defined in the visible notebook.
        """
        for layer in self._modules.values():
            normal_init(layer)

    def encode(self, x):
        """Encode images ``x`` into ``(v, c)``: variance code and condition code.

        Both codes go through a sigmoid, so every entry lies in [0, 1].
        """
        slope = 0.2
        h = F.leaky_relu(self.conv1(x), slope)
        stages = (
            (self.conv2, self.conv2_bn),
            (self.conv3, self.conv3_bn),
            (self.conv4, self.conv4_bn),
        )
        for conv, bn in stages:
            h = F.leaky_relu(bn(conv(h)), slope)

        v = torch.sigmoid(self.conv5v(h))
        # Sigmoid, not softmax: each condition entry is squashed independently.
        c = torch.sigmoid(self.conv5c(h))
        return v, c

    def forward(self, v, c):
        """Decode variance code ``v`` and condition code ``c`` into an image.

        This is conditional generation; it is exposed as ``forward`` so the
        module can be called like a generator (e.g. for the FJD calculation).
        """
        # The same BatchNorm layer is applied to both branches, v first.
        v_feat = self.deconv1_bn(self.deconv1v(v))
        c_feat = self.deconv1_bn(self.deconv1c(c))
        # Channel-wise concatenation gives 1024 + 1024 feature maps.
        h = torch.cat((v_feat, c_feat), dim=1)
        stages = (
            (self.deconv2, self.deconv2_bn),
            (self.deconv3, self.deconv3_bn),
            (self.deconv4, self.deconv4_bn),
        )
        for deconv, bn in stages:
            h = F.relu(bn(deconv(h)))
        return torch.tanh(self.deconv5(h))

    def pass_thru(self, v, c):
        """Decode ``(v, c)`` to an image and immediately encode it again."""
        return self.encode(self.forward(v, c))
  
class C_Generator(nn.Module):
    """Conditional DCGAN-style decoder producing three-channel images.

    The variance code and the condition code are expanded separately to
    1024 x 4 x 4 feature maps, concatenated along the channel axis and
    upsampled three times; a final 1x1 layer maps to 3 channels, so a pair of
    1x1 codes yields a ``(bs, 3, 32, 32)`` output in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        expand = dict(kernel_size=4, stride=1, padding=0, bias=False)   # 1x1 -> 4x4
        double = dict(kernel_size=4, stride=2, padding=1, bias=False)   # doubles H and W

        ## Decoding:
        self.deconv1v = nn.ConvTranspose2d(v_dim, 1024, **expand)
        self.deconv1c = nn.ConvTranspose2d(c_dim, 1024, **expand)

        self.deconv1_bn = nn.BatchNorm2d(1024)
        self.deconv2 = nn.ConvTranspose2d(1024 + 1024, 512, **double)
        self.deconv2_bn = nn.BatchNorm2d(512)
        self.deconv3 = nn.ConvTranspose2d(512, 256, **double)
        self.deconv3_bn = nn.BatchNorm2d(256)
        self.deconv4 = nn.ConvTranspose2d(256, 128, **double)
        self.deconv4_bn = nn.BatchNorm2d(128)
        # 1x1 projection to RGB that keeps the spatial size; the alternative
        # noted by the author, ConvTranspose2d(128, 3, 4, 2, 1), would double it.
        self.deconv5 = nn.ConvTranspose2d(128, 3, kernel_size=1, stride=1, padding=0)

    def weight_init(self):
        """Run ``normal_init`` over every direct sub-module.

        NOTE(review): ``normal_init`` is not defined in the visible notebook.
        """
        for layer in self._modules.values():
            normal_init(layer)

    def forward(self, v, c):
        """Decode variance code ``v`` and condition code ``c`` into an image."""
        # One shared BatchNorm layer normalises both branches, v first.
        v_feat = self.deconv1_bn(self.deconv1v(v))
        c_feat = self.deconv1_bn(self.deconv1c(c))
        # Channel-wise concatenation gives 1024 + 1024 feature maps.
        h = torch.cat((v_feat, c_feat), dim=1)
        stages = (
            (self.deconv2, self.deconv2_bn),
            (self.deconv3, self.deconv3_bn),
            (self.deconv4, self.deconv4_bn),
        )
        for deconv, bn in stages:
            h = F.relu(bn(deconv(h)))
        return torch.tanh(self.deconv5(h))
    
def one_hot_embedding(labels, num_classes=None):
    """Convert integer class labels to one-hot vectors.

    Args:
        labels: integer class ids, as a tensor or anything ``torch.as_tensor``
            accepts (list, NumPy array, scalar).
        num_classes: width of the one-hot vectors. Defaults to the
            notebook-level ``c_dim``.

    Returns:
        An int64 tensor of one-hot rows with all size-1 dimensions squeezed
        out, so a single label yields a 1-D vector of length ``num_classes``.
    """
    if num_classes is None:
        num_classes = c_dim

    # `torch.tensor(existing_tensor)` makes a needless copy and emits a
    # copy-construct UserWarning; `as_tensor` reuses the tensor instead.
    indices = torch.as_tensor(labels).to(torch.int64)
    encoded = torch.nn.functional.one_hot(indices, num_classes=num_classes)
    return torch.squeeze(encoded)

def print_g_sample(model=None):
    """Plot a 3x3 grid of grayscale samples, one for each of classes 0-8.

    Args:
        model: callable taking ``(variance_codes, condition_codes)`` and
            returning images in [-1, 1]. Defaults to the notebook-level
            ``IAE``, so the original zero-argument call keeps working
            whenever ``IAE`` is defined.
    """
    if model is None:
        # Fall back to the global only on demand, so the helper can also be
        # used with any other conditional generator passed in explicitly.
        model = IAE

    with torch.no_grad():
        # One condition code per class id 0..8, shaped (9, c_dim, 1, 1).
        codes = one_hot_embedding(torch.tensor(list(range(9)), device = device)).view(9, c_dim, 1, 1).float()
        # Variance codes are drawn from a standard normal distribution.
        varis = torch.randn((9, v_dim, 1, 1), device = device)
        # Map the generator output from [-1, 1] to [0, 1] for display.
        generated = .5 * (model(varis, codes).cpu() + 1)
        generated = torch.squeeze(generated)
        for i in range(9):
            plt.subplot(3, 3, i + 1)
            # plot raw pixel data
            plt.imshow(generated[i, :], cmap = 'gray')
        plt.show()
        
#print_g_sample()

In [4]:
def get_dataloaders():
    """Build shuffled CIFAR-10 train and test loaders and print a sanity check.

    Images are converted to tensors, resized to 32 and normalised with
    mean 0.5 / std 0.5. Both loaders use batches of 128 and drop the last
    incomplete batch. One training batch is fetched so that its image shape,
    label shape and value range can be printed.

    Returns:
        ``(train_loader, test_loader)``.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(32),  # the original note suggests 64 as an alternative
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

    # Identical settings for both splits; the train split is built first.
    train_set, test_set = [
        CIFAR10(root='./data', train=is_train, download=True, transform=preprocess)
        for is_train in (True, False)
    ]
    train_loader, test_loader = [
        DataLoader(split, batch_size=128, shuffle=True, drop_last=True)
        for split in (train_set, test_set)
    ]

    # Quick look at one batch: image shape, label shape, value range.
    images, labels = next(iter(train_loader))
    print(images.shape)
    print(labels.shape)
    print(torch.min(images), torch.max(images))

    return train_loader, test_loader

#print_g_sample()

In [5]:
## My GANWrapper
class GANWrapper:
    """Adapter that turns a generator into a ``wrapper(y) -> images`` callable.

    The wrapper draws the latent noise itself, so the caller only supplies a
    batch of integer class labels ``y`` (presumably what ``FJDMetric`` passes
    to its ``gan`` argument; confirm against the fjd package).

    Args:
        model: the generator. Called as ``model(z)``, or as ``model(z, y_code)``
            when ``conditional`` is true.
        conditional: when true, the labels are one-hot encoded with the
            notebook-level ``onehot_embedding`` and passed to the model as a
            ``(batch, c_dim, 1, 1)`` tensor. The default (False) matches the
            original behaviour of sampling unconditionally.
    """

    def __init__(self, model, conditional=False):
        self.model = model
        self.conditional = conditional

    def __call__(self, y):
        """Generate one image per label in ``y`` (a 1-D tensor of class ids)."""
        if audit:
            print(y[0:9])

        batch_size = y.size(0)
        z = torch.randn((batch_size, v_dim, 1, 1), device = device)

        # Sampling is inference only: skip building an autograd graph.
        with torch.no_grad():
            if self.conditional:
                # The one-hot code is only built when the model consumes it.
                y_code = onehot_embedding(y).to(device).view(batch_size, c_dim, 1, 1)
                samples = self.model(z, y_code)
            else:
                samples = self.model(z)

        if audit:
            self._show_preview(samples, y)

        return samples

    @staticmethod
    def _show_preview(samples, labels):
        """Plot up to nine samples in a 3x3 grid, titled with their class names."""
        # Never index past the end of a batch that has fewer than nine items.
        for i in range(min(9, samples.size(0))):
            ind = int(labels[i])
            plt.subplot(3, 3, i + 1).set_title(str(classes[ind]))
            # Channels-last layout for imshow, values mapped from [-1, 1] to [0, 1].
            element = 0.5 * (samples[i].permute(1, 2, 0).cpu() + 1)
            plt.imshow(element, cmap = 'gray')
        plt.show()
    
#print_g_sample()

In [6]:
# When True, GANWrapper prints the first labels of every batch and plots a
# 3x3 preview of the generated images.
audit = False

# Alternative models (disabled):
#IAE = InverseAutoencoder().to(device)
#IAE.load_state_dict(torch.load('models/icaegan_mnist_G.pt'))

#CG = C_Generator().to(device)
#CG.load_state_dict(torch.load('models/cgan_cifar_G.pt'))

from supervised.g_arches import rgb_32_G

G = rgb_32_G(z_dim).to(device)
# map_location makes a checkpoint that was saved on a GPU loadable on a
# CPU-only machine, matching the fallback used when `device` is chosen.
G.load_state_dict(torch.load('models/gan_cifar_60e_G.pt', map_location=device))
# NOTE(review): G stays in training mode (no G.eval()), so any batch-norm
# layers it contains would use batch statistics while sampling. Confirm this
# is intended before comparing scores across runs.

train_loader, test_loader = get_dataloaders()
inception_embedding = InceptionEmbedding(parallel=False)
onehot_embedding = OneHotEmbedding(num_classes=10)
#gan = SuspiciouslyGoodGAN()
#gan = InverseAutoencoder()
#gan = CG
#params = 'models/gan_mnist_G.pt'
gan = GANWrapper(G)
#gan = GANWrapper(model = gan, model_checkpoint = params)
#gan.print_sample()


Files already downloaded and verified
Files already downloaded and verified
torch.Size([128, 3, 32, 32])
torch.Size([128])
tensor(-1.) tensor(1.)


Downloading: "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth" to C:\Users\Alex/.cache\torch\hub\checkpoints\inception_v3_google-1a9a5a14.pth


HBox(children=(FloatProgress(value=0.0, max=108857766.0), HTML(value='')))




In [7]:
# Reference statistics come from the CIFAR-10 train split and are cached in
# the .npz file; conditions (labels) come from the test split.
fjd_metric = FJDMetric(gan=gan,
                       reference_loader=train_loader, #cifar10 train
                       condition_loader=test_loader, #cifar10 test
                       image_embedding=inception_embedding, #dont change
                       condition_embedding=onehot_embedding,
                       reference_stats_path='datasets/cifar_train_stats.npz',
                       save_reference_stats=True,
                       samples_per_condition=10,
                       # Follow the notebook-level device choice instead of
                       # unconditionally requesting CUDA.
                       cuda=(device.type == 'cuda'))

In [8]:
# Compute both scores (FID first, then FJD) before reporting either of them.
fid, fjd = fjd_metric.get_fid(), fjd_metric.get_fjd()
print('FID: ', fid)
print('FJD: ', fjd)

Computing generated distribution:   0%|                                                         | 0/78 [00:00<?, ?it/s]

Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|████████████████████████████████████████████████| 78/78 [15:15<00:00, 11.74s/it]
Computing generated distribution:  41%|███████████████████▋                            | 32/78 [17:58<25:50, 33.70s/it]


KeyboardInterrupt: 

In [None]:
# Binary cross-entropy criterion; check_accuracy below reads it as a global.
bce_loss = torch.nn.BCELoss()
#IAE.to(device)

In [None]:
def check_accuracy(test_loader: DataLoader, model: nn.Module, device):
    """Print and return the mean per-batch BCE of the model's condition head.

    Despite the name, this reports a loss rather than an accuracy: for every
    batch the images are encoded with ``model.encode`` and the predicted
    condition code is compared to the embedded labels with the notebook-level
    ``bce_loss``.

    Args:
        test_loader: iterable of ``(images, labels)`` batches.
        model: module exposing ``encode(images) -> (v, c)``; it is switched to
            evaluation mode.
        device: device the batches are moved to before encoding.

    Returns:
        The batch losses averaged over the number of batches (also printed).

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    total = 0.0
    num_batches = 0

    with torch.no_grad():
        for data, labels in test_loader:
            data = data.to(device=device)
            labels = onehot_embedding(labels.to(device=device))

            # Use the model that was passed in; the original ignored the
            # argument and always evaluated the global `IAE`.
            _, predictions = model.encode(data)
            loss = bce_loss(predictions.view(-1, c_dim), labels)
            total += loss.item()
            num_batches += 1

    if num_batches == 0:
        # Avoid an opaque ZeroDivisionError when there is nothing to average.
        raise ValueError('test_loader yielded no batches')

    mean_loss = total / num_batches
    print(mean_loss)
    return mean_loss
        
# Requires `IAE` to exist; its construction is commented out in an earlier cell.
check_accuracy(test_loader=test_loader, model=IAE, device=device)