# IMPORTS

In [None]:
import os
import numpy as np
import tarfile
import time
import scipy.io
import math
import pickle

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
from torchvision.utils import save_image
from torch.distributions.normal import Normal

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import matplotlib.pyplot as plt
%matplotlib inline

# PROB FUNCTIONS

In [None]:
# Numerical constants used by the log-density helpers.
PI = torch.tensor(np.pi, dtype=torch.float64)  # 0-dim float64 tensor so torch.log can consume it
EPS = 1.e-5  # clamp margin keeping probabilities strictly inside (0, 1) before taking logs

def log_categorical(x, p, num_classes=256, reduction=None, dim=None):
    """Element-wise categorical log-likelihood terms.

    ``x`` is cast to integer class indices and one-hot encoded with
    ``num_classes`` classes; each one-hot entry is multiplied by ``log(p)``.
    ``p`` is first clamped to ``[EPS, 1 - EPS]`` so the log stays finite.
    Presumably ``p`` carries a trailing class axis of size ``num_classes``
    (TODO confirm against callers).

    reduction: 'avg' -> ``torch.mean`` over ``dim``; 'sum' -> ``torch.sum`` over
    ``dim``; any other value -> the unreduced tensor.
    """
    one_hot = F.one_hot(x.long(), num_classes=num_classes)
    safe_p = p.clamp(EPS, 1. - EPS)
    log_p = one_hot * safe_p.log()

    reducer = torch.mean if reduction == 'avg' else torch.sum if reduction == 'sum' else None
    if reducer is None:
        return log_p
    return reducer(log_p, dim)

def log_bernoulli(x, p, reduction=None, dim=None):
    """Element-wise Bernoulli log-likelihood ``x*log(p) + (1-x)*log(1-p)``.

    ``p`` is clamped to ``[EPS, 1 - EPS]`` so neither log term can be -inf.

    reduction: 'avg' -> ``torch.mean`` over ``dim``; 'sum' -> ``torch.sum`` over
    ``dim``; any other value -> the unreduced tensor.
    """
    clipped = p.clamp(EPS, 1. - EPS)
    hit_term = x * clipped.log()
    miss_term = (1. - x) * (1. - clipped).log()
    log_p = hit_term + miss_term

    reducer = torch.mean if reduction == 'avg' else torch.sum if reduction == 'sum' else None
    if reducer is None:
        return log_p
    return reducer(log_p, dim)

def log_normal_diag(x, mu, log_var, reduction=None, dim=None):
    """Element-wise log-density of a diagonal Gaussian N(mu, exp(log_var)) at ``x``.

    ``log_var`` is the log of the *variance* (not of the standard deviation).

    reduction: 'avg' -> ``torch.mean`` over ``dim``; 'sum' -> ``torch.sum`` over
    ``dim``; any other value -> the unreduced tensor.
    """
    half_log_two_pi = 0.5 * torch.log(2. * PI)
    log_p = -half_log_two_pi - 0.5 * log_var - 0.5 * torch.exp(-log_var) * (x - mu) ** 2.

    reducer = torch.mean if reduction == 'avg' else torch.sum if reduction == 'sum' else None
    if reducer is None:
        return log_p
    return reducer(log_p, dim)

def log_standard_normal(x, reduction=None, dim=None):
    """Element-wise log-density of the standard normal N(0, 1) at ``x``.

    reduction: 'avg' -> ``torch.mean`` over ``dim``; 'sum' -> ``torch.sum`` over
    ``dim``; any other value -> the unreduced tensor.
    """
    normaliser = 0.5 * torch.log(2. * PI)
    log_p = -normaliser - 0.5 * x ** 2.

    reducer = torch.mean if reduction == 'avg' else torch.sum if reduction == 'sum' else None
    if reducer is None:
        return log_p
    return reducer(log_p, dim)
    
def init_weights(m):
    """Initialise a module in place; intended for use with ``nn.Module.apply``.

    ``nn.Linear`` layers (including subclasses) get Xavier-uniform weights and
    a constant bias of 0.05. Every other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        # xavier_uniform_ is the in-place, non-deprecated form of xavier_uniform.
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:  # Linear(bias=False) has no bias to fill
            with torch.no_grad():
                m.bias.fill_(0.05)

# LOAD DATA

In [None]:
def load_data(dataset='MNIST', bs=64, n_instances=5000):
    """Load a 32x32 image dataset and split it into train / validation loaders.

    Args:
        dataset: 'MNIST' (downloaded to ./data; 28x28 images zero-padded to 32x32)
            or any other value, which is treated as SVHN and read from
            ./data/train_32x32.mat.
        bs: batch size of both DataLoaders.
        n_instances: number of images in the training split; every remaining
            image goes to the validation split.

    Returns:
        (trainloader, valloader, (c, h, w)).

    Note: the two datasets yield different batch types. MNIST loaders yield bare
    image tensors of shape (B, 1, 32, 32) scaled to [0, 1]. SVHN loaders yield
    (images, labels) pairs with images of shape (B, 3, 32, 32).
    """
    if dataset == 'MNIST':
        mnist = torchvision.datasets.MNIST(root='./data', train=True, download=True)
        # Work on the raw uint8 tensor directly:
        # pad 28x28 -> 32x32, add a channel axis, scale to [0, 1].
        full = transforms.Pad(2)(mnist.data).unsqueeze(1) / 255
        c, h, w = 1, 32, 32
    else:  # dataset == 'SVHN'
        mat = scipy.io.loadmat('./data/train_32x32.mat')
        # Assumes the standard SVHN .mat layout (H, W, C, N); permute to (N, C, H, W).
        images = (torch.Tensor(mat['X']) / 255).permute(3, 2, 0, 1)
        full = torch.utils.data.TensorDataset(images, torch.Tensor(mat['y']))
        c, h, w = 3, 32, 32

    # Derive the validation size from the data instead of hard-coding 60000 / 73257.
    trainset, valset = torch.utils.data.random_split(full, [n_instances, len(full) - n_instances])
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True)
    valloader = torch.utils.data.DataLoader(valset, batch_size=bs, shuffle=True)

    return trainloader, valloader, (c, h, w)

# DEF VAE

In [None]:
class Encoder(nn.Module):
    """Fully-connected encoder from flattened inputs to latent codes.

    Args:
        D: flattened input size.
        M: latent size.
        variant: one of
            - 'Normal': every layer is trainable.
            - 'RKS': a single frozen linear projection D -> M * mult.
            - anything else, e.g. 'RandomNet': the three hidden layers are frozen,
              and the output layer is frozen too unless ``partial`` is True.
        partial: if True, the output layer emits 2*M values that are split into
            (mu, log_std). Otherwise it emits M values and the fixed standard
            deviation ``self.std`` (1e-3) is used.

    ``forward`` returns ``(mu, log_std)``; ``log_std`` is None when the std is fixed.
    """

    def __init__(self, D, M, variant='Normal', partial=True):
        super(Encoder, self).__init__()
        self.D = D
        self.M = M
        self.variant = variant
        self.partial = partial
        self.std = torch.tensor(1e-3, requires_grad=False)
        # Output-width multiplier: with partial=True the last layer also predicts log_std.
        # Computed locally rather than read from a notebook-level global.
        mult = 2 if partial else 1

        if self.variant != 'RKS':
            self.enc1 = nn.Linear(self.D, 256)
            self.enc2 = nn.Linear(self.enc1.out_features, self.enc1.out_features // 2)
            self.enc3 = nn.Linear(self.enc2.out_features, self.enc2.out_features // 2)
            self.enc4 = nn.Linear(self.enc3.out_features, self.M * mult)

            if self.variant != 'Normal':
                frozen = [self.enc1, self.enc2, self.enc3]
                if not self.partial:
                    frozen.append(self.enc4)
                for layer in frozen:
                    layer.requires_grad_(False)  # freezes weight and bias
        else:
            self.enc1 = nn.Linear(self.D, self.M * mult)
            self.enc1.requires_grad_(False)

    def forward(self, x):
        """Flatten ``x`` per sample and return ``(mu, log_std)`` (``log_std`` may be None)."""
        x = x.view(x.shape[0], -1)
        if self.variant == 'RKS':
            return self.enc1(x), None
        x = F.leaky_relu(self.enc1(x), 0.2)
        x = F.leaky_relu(self.enc2(x), 0.2)
        x = F.leaky_relu(self.enc3(x), 0.2)
        x = self.enc4(x)
        if self.partial:
            return torch.chunk(x, 2, dim=1)
        return x, None

    def reparameterize(self, mu, log_std):
        """Draw ``z = mu + std * eps`` with ``eps ~ N(0, I)``.

        ``log_std`` is the log *standard deviation* returned by ``forward``. When it
        is None (no predicted scale), the fixed ``self.std`` is used instead.
        """
        if log_std is None:
            eps = torch.randn_like(mu)
            return mu + (eps * self.std)
        std = log_std.exp()
        eps = torch.randn_like(std)
        return mu + (eps * std)

    def log_prob(self, z, mu, log_std):
        """Element-wise log q(z | x) under N(mu, std^2), matching ``reparameterize``.

        Uses the same std that ``reparameterize`` samples with: ``exp(log_std)``, or
        the fixed ``self.std`` when ``log_std`` is None. The result has the broadcast
        shape of ``z`` and ``mu`` (unreduced).
        """
        if log_std is None:
            scale = torch.full_like(mu, float(self.std))
        else:
            # log_std is a log standard deviation (reparameterize uses exp(log_std)
            # as the std), so it must not be treated as a log-variance here.
            scale = log_std.exp()
        return Normal(mu, scale, validate_args=False).log_prob(z)

In [None]:
class Decoder(nn.Module):
    """Fully-connected decoder from M-dim latents to flattened D-dim outputs.

    Hidden widths are 128 -> 256 -> 512 -> 1024, each followed by LeakyReLU(0.2).
    The output layer is linear, with no squashing non-linearity.
    """

    def __init__(self, D, M):
        super().__init__()
        self.D = D
        self.M = M

        hidden = (128, 256, 512, 1024)
        self.dec1 = nn.Linear(M, hidden[0])
        self.dec2 = nn.Linear(hidden[0], hidden[1])
        self.dec3 = nn.Linear(hidden[1], hidden[2])
        self.dec4 = nn.Linear(hidden[2], hidden[3])
        self.dec5 = nn.Linear(hidden[3], D)

    def forward(self, z):
        """Map latents ``z`` of shape (B, M) to outputs of shape (B, D)."""
        h = z
        for layer in (self.dec1, self.dec2, self.dec3, self.dec4):
            h = F.leaky_relu(layer(h), 0.2)
        return self.dec5(h)

    def log_prob(self, x_hat, x):
        """Return the per-sample sum of squared errors between ``x`` and ``x_hat``.

        Despite its name this is not a log-probability. It is a non-negative
        reconstruction cost (larger means worse). ``x`` is flattened per sample
        before the comparison.
        """
        flat = x.view(x.shape[0], -1)
        squared = F.mse_loss(flat, x_hat, reduction='none')
        return squared.sum(1)

In [None]:
class Prior(nn.Module):
    """Fixed standard-normal prior N(0, I) over an L-dimensional latent space."""

    def __init__(self, L):
        super().__init__()
        self.L = L

    def sample(self, batch_size):
        """Return a (batch_size, L) tensor of N(0, 1) draws on torch's default device."""
        return torch.randn(batch_size, self.L)

    def log_prob(self, z):
        """Return the element-wise (unreduced) standard-normal log-density of ``z``."""
        return log_standard_normal(z)

In [None]:
class FlowPrior(nn.Module):
    """RealNVP-style flow prior built from ``num_flows`` affine coupling layers.

    ``f`` maps latent vectors x to base-space z, where z is scored under N(0, I).
    ``f_inv`` maps z back to x. Each coupling layer transforms the second half of
    the vector conditioned on the first half. The vector is then flipped so the
    other half is transformed by the next layer.

    Args:
        num_flows: number of coupling layers.
        D: latent dimensionality; must be even so both halves have D // 2 entries.
        M: hidden width of the scale (s) and translation (t) networks.
    """

    def __init__(self, num_flows, D=50, M=64):
        super(FlowPrior, self).__init__()

        # Scale network; the final Tanh bounds s to (-1, 1) for numerical stability.
        nets = lambda: nn.Sequential(nn.Linear(D // 2, M), nn.LeakyReLU(),
                                     nn.Linear(M, M), nn.LeakyReLU(),
                                     nn.Linear(M, D // 2), nn.Tanh())

        # Translation network.
        nett = lambda: nn.Sequential(nn.Linear(D // 2, M), nn.LeakyReLU(),
                                     nn.Linear(M, M), nn.LeakyReLU(),
                                     nn.Linear(M, D // 2))

        self.D = D
        self.t = torch.nn.ModuleList([nett() for _ in range(num_flows)])
        self.s = torch.nn.ModuleList([nets() for _ in range(num_flows)])
        self.num_flows = num_flows

    def _device(self):
        """Device of this module's parameters (CPU when there are none)."""
        param = next(self.parameters(), None)
        return param.device if param is not None else torch.device('cpu')

    def coupling(self, x, index, forward=True):
        """Apply coupling layer ``index``. Return (output, s), where s gives log|det|."""
        (xa, xb) = torch.chunk(x, 2, 1)
        s = self.s[index](xa)
        t = self.t[index](xa)

        if forward:
            # yb = f^{-1}(x): data -> base direction
            yb = (xb - t) * torch.exp(-s)
        else:
            # xb = f(y): base -> data direction
            yb = torch.exp(s) * xb + t

        return torch.cat((xa, yb), 1), s

    def permute(self, x):
        """Reverse the feature order (self-inverse)."""
        return x.flip(1)

    def f(self, x):
        """Map x to base space. Return (z, log|det df/dx|) with log-det of shape (B,)."""
        z = x.to(self._device())
        log_det_J = z.new_zeros(z.shape[0])
        for i in range(self.num_flows):
            z, s = self.coupling(z, i, forward=True)
            z = self.permute(z)
            # d yb / d xb = exp(-s), and the flip has |det| = 1.
            log_det_J = log_det_J - s.sum(dim=1)

        return z, log_det_J

    def f_inv(self, z):
        """Inverse of ``f``: map base-space z back to latent space."""
        x = z.to(self._device())
        for i in reversed(range(self.num_flows)):
            x = self.permute(x)
            x, _ = self.coupling(x, i, forward=False)

        return x

    def sample(self, batch_size):
        """Draw ``batch_size`` latent samples on the module's device."""
        # Base noise is drawn from the default (CPU) generator; f_inv moves it to
        # the module's device.
        z = torch.randn(batch_size, self.D)
        x = self.f_inv(z)
        return x.view(-1, self.D)

    def log_prob(self, x):
        """Per-dimension log-density contributions of shape (B, D).

        Summing over dim=1 gives exactly log p(x) = sum_i log N(z_i) + log|det J|.
        The Jacobian term is spread evenly over the D dimensions instead of being
        added in full to every dimension, which would over-count it D times under
        a sum and distort its weight under a mean.
        """
        z, log_det_J = self.f(x)
        return log_standard_normal(z) + (log_det_J / self.D).unsqueeze(1)


In [None]:
class VAE(nn.Module):
    """Variational auto-encoder assembled from Encoder, Decoder and a latent prior.

    Args:
        c: number of image channels, used when reshaping images for saving.
        D: flattened input size (c * h * w).
        M: latent size.
        variant, partial: forwarded to ``Encoder``.

    NOTE: the prior type comes from notebook-level globals. It is a FlowPrior with
    ``flows`` coupling layers when ``prior == 'Flow'``, otherwise a standard normal.
    The image-saving helpers also read the globals ``valloader``, ``h``, ``w``,
    ``dataset``, ``REPEAT``, ``variant``, ``partial``, ``prior``, ``flows``,
    ``subs``, ``epoch`` and ``device``.
    """

    def __init__(self, c, D, M, variant='Normal', partial=True):
        super(VAE, self).__init__()
        self.D = D
        self.M = M
        self.c = c
        self.variant = variant
        self.partial = partial

        self.encoder = Encoder(self.D, self.M, variant=self.variant, partial=self.partial)
        self.decoder = Decoder(self.D, self.M)
        if prior == 'Flow':
            self.prior = FlowPrior(flows, D=M)
        else:
            self.prior = Prior(M)

    def forward(self, x, reduction='avg'):
        """Return ``(RE, KL)`` for a batch ``x``.

        RE: per-sample summed squared reconstruction error (``Decoder.log_prob``).
        KL: per-sample mean over latent dimensions of ``log p(z) - log q(z|x)`` for
            one reparameterised sample z. Note the sign: this is an estimate of the
            *negative* KL, averaged (not summed) over latent dims. The training code
            minimises ``RE - KL``.
        ``reduction`` is accepted for API compatibility but is unused.
        """
        # The encoder's second output is a log standard deviation (or None).
        mu, log_std = self.encoder.forward(x)
        z = self.encoder.reparameterize(mu, log_std)
        x_hat = self.decoder.forward(z)

        RE = self.decoder.log_prob(x_hat, x)
        KL = (self.prior.log_prob(z) - self.encoder.log_prob(z, mu, log_std)).mean(-1)
        return RE, KL

    def _image_path(self, tag):
        """Build the run's image path for ``tag`` and make sure its directory exists."""
        path = f"{dataset}-{str(REPEAT)}/fc/{variant+str(partial)}/{prior+str(flows)}/{subs}_{tag}_{epoch}.png"
        os.makedirs(os.path.dirname(path), exist_ok=True)
        return path

    def reconstruct(self, n=8):
        """Save the first ``n`` images of one validation batch above their reconstructions."""
        for batch in valloader:
            # MNIST loaders yield bare tensors; SVHN loaders yield (images, labels).
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            x = x[:n]
            k = x.shape[0]  # may be smaller than n for a short batch
            with torch.no_grad():
                mu, log_std = self.encoder.forward(x.to(device))
                z = self.encoder.reparameterize(mu, log_std)
                x_hat = self.decoder.forward(z)

            save_image(torch.cat((x.reshape(k, self.c, h, w),
                                  x_hat.reshape(k, self.c, h, w).cpu())),
                       self._image_path('reconstructed'), nrow=k)
            break

    def generate(self, n=8):
        """Decode ``n**2`` prior samples and save them as an n x n grid."""
        with torch.no_grad():
            z = self.prior.sample(n**2).to(device)
            x_hat = self.decoder(z)
        save_image(x_hat.reshape(n**2, self.c, h, w).cpu(),
                   self._image_path('generated'), nrow=n)

    def interpolate(self):
        """Save an 8x8 grid of linear interpolations between pairs of prior samples.

        Each grid row shows 8 evenly spaced points on the line from one prior
        sample to another.
        """
        n = 8
        # m is laid out as (step, pair, latent).
        m = torch.zeros((n, n, self.M))
        with torch.no_grad():
            m[0] = self.prior.sample(n)
            m[n - 1] = self.prior.sample(n)
            for i in range(1, n - 1):
                m[i] = m[0] - (((m[0] - m[n - 1]) / (n - 1)) * i)

            # Iterate over pairs so that each grid row is one interpolation path.
            rows = [self.decoder.forward(path.to(device)).reshape(n, self.c, h, w)
                    for path in m.permute(1, 0, 2)]

        save_image(torch.cat(rows).cpu(), self._image_path('interpolation'), nrow=n)
    

# Initialize + Hyperparameters

In [None]:
def get_subs():
    """Return the short run code used as a filename prefix.

    Built from the notebook-level globals ``variant``, ``partial``, ``prior`` and
    ``flows`` as 'F' + encoder code + prior code, where:
      * the encoder code is 'R' for RKS, 'P' for RandomNet with partial,
        'N' for RandomNet without partial, and 'L' otherwise;
      * the prior code is 'F<flows>' for a flow prior, and 'N' otherwise.
    """
    # Explicit == comparisons on `partial` are intentional: a non-bool value
    # matches neither RandomNet case and falls through to 'L'.
    if variant == 'RKS':
        encoder_code = 'R'
    elif variant == 'RandomNet' and partial == True:  # noqa: E712
        encoder_code = 'P'
    elif variant == 'RandomNet' and partial == False:  # noqa: E712
        encoder_code = 'N'
    else:
        encoder_code = 'L'

    prior_code = 'F' + str(flows) if prior == 'Flow' else 'N'
    return 'F' + encoder_code + prior_code

In [None]:
def get_times():
    """Train a fresh VAE for ``n_epochs`` and return the wall-clock seconds per epoch.

    Reads notebook-level globals: c, D, M, variant, partial, lr, n_epochs,
    trainloader and device (plus prior / flows through ``VAE``). Per-epoch losses
    are printed, not returned.
    """
    model = VAE(c, D, M, variant=variant, partial=partial)
    model.to(device)
    model.apply(init_weights)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

    # Count trainable scalars with numel(); summing each tensor's dimensions would under-count.
    n_weights = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'There are {n_weights} weights in the VAE.')

    n_batches = len(trainloader)
    timer = []
    for epoch in range(1, n_epochs + 1):
        t = time.time()
        epochloss, REloss, KLloss = 0, 0, 0
        for batch in trainloader:
            # MNIST loaders yield bare tensors; SVHN loaders yield (images, labels).
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            RE, KL = model.forward(x.to(device))
            loss = -(-RE + KL).mean()

            model.zero_grad()
            # Each step builds a fresh graph, so retain_graph is not needed.
            loss.backward()
            optimizer.step()

            epochloss += loss.item()
            REloss += RE.mean().item()
            KLloss += KL.mean().item()

        print('[%d/%d]: loss: %.3f || RE: %.3f || KL: %.3f' % ((epoch), n_epochs, epochloss/n_batches,
                                                                REloss/n_batches, KLloss/n_batches))
        timer.append(round(time.time() - t, 5))

    return timer

In [None]:
def experiment():
    """Train, evaluate and save one VAE configuration.

    Reads notebook-level globals: c, D, M, variant, partial, prior, flows, lr,
    n_epochs, n_instances, bs, dataset, REPEAT, trainloader, valloader and device.
    Every 10 epochs it saves reconstruction, generation and interpolation images
    through the VAE helpers. At the end it evaluates on ``valloader`` and writes
    two files into the run directory:
      * ``info.p``: hyperparameters, per-epoch losses and timings, validation losses;
      * ``model.pt``: the trained model.
    """
    global epoch  # VAE.reconstruct/generate/interpolate read the current epoch as a global
    out_dir = f'{dataset}-{str(REPEAT)}/fc/{variant+str(partial)}/{prior+str(flows)}'
    os.makedirs(out_dir, exist_ok=True)

    model = VAE(c, D, M, variant=variant, partial=partial)
    model.to(device)
    model.apply(init_weights)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

    # Count trainable scalars with numel(); summing each tensor's dimensions would under-count.
    n_weights = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'There are {n_weights} weights in the VAE.')

    n_batches = len(trainloader)
    losses, RElosses, KLlosses = [], [], []
    timer = []
    for epoch in range(1, n_epochs + 1):
        t = time.time()
        epochloss, REloss, KLloss = 0, 0, 0
        for batch in trainloader:
            # MNIST loaders yield bare tensors; SVHN loaders yield (images, labels).
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            RE, KL = model.forward(x.to(device))
            loss = -(-RE + KL).mean()

            model.zero_grad()
            # Each step builds a fresh graph, so retain_graph is not needed.
            loss.backward()
            optimizer.step()

            epochloss += loss.item()
            REloss += RE.mean().item()
            KLloss += KL.mean().item()

        losses.append(epochloss / n_batches)
        RElosses.append(REloss / n_batches)
        KLlosses.append(KLloss / n_batches)

        print('[%d/%d]: loss: %.3f || RE: %.3f || KL: %.3f ' % ((epoch), n_epochs, epochloss/n_batches,
                                                                 REloss/n_batches, KLloss/n_batches))
        timer.append(round(time.time() - t, 5))

        if epoch % 10 == 0:
            model.reconstruct()
            model.generate()
            model.interpolate()

    with torch.no_grad():
        test_loss, test_re, test_kl = 0, 0, 0
        n_val = len(valloader)
        for batch in valloader:
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            RE, KL = model.forward(x.to(device))

            loss = -(-RE + KL).mean()
            test_loss += (loss.item() / n_val)
            test_re += (RE.mean().item() / n_val)
            test_kl += (KL.mean().item() / n_val)

    print('TEST: loss: %.3f || RE: %.3f || KL: %.3f ' % (test_loss, test_re, test_kl))

    info = {'Dataset': dataset,
            'Variant': variant,
            'Partial': partial,
            'Prior': prior,
            'n_instances': n_instances,
            'batch_size': bs,
            'flows': flows,
            'n_epochs': n_epochs,
            'learning_rate': lr,
            'n_weights': n_weights,
            'losses': list(np.asarray(losses).round(4)),
            'RE': list(np.asarray(RElosses).round(4)),
            'KL': list(np.asarray(KLlosses).round(4)),
            'Times': list(np.asarray(timer).round(4)),
            'Test_loss': round(test_loss, 4),
            'Test_RE': round(test_re, 4),
            'Test_KL': round(test_kl, 4)}

    with open(f'{out_dir}/info.p', 'wb') as fh:
        pickle.dump(info, fh)
    torch.save(model, f'{out_dir}/model.pt')

# Train Models

In [None]:
# Notebook-level hyperparameters. experiment(), get_subs() and the VAE methods read
# these (and the loop variables below) as globals, so the names must not change.
dataset = 'MNIST'
n_instances = 25000  # size of the training split; the rest of the data is validation
bs = 256
n_epochs = 250
lr = 0.001
M = 32 if dataset == 'SVHN' else 16  # latent dimensionality

# NOTE(review): repeats start at 2 here while the FID section iterates range(1, 6);
# presumably repeat 1 was trained in an earlier run -- confirm.
for REPEAT in range(2, 6):
    print('ITERATION', REPEAT)
    # A fresh random train/validation split is drawn for every repeat.
    trainloader, valloader, (c, h, w) = load_data(dataset=dataset, n_instances=n_instances, bs=bs)
    D=c*h*w
   
    for (variant, partial) in [('RandomNet', False), ('RandomNet', True), ('RKS', False), ('Normal', True)]:
        for (prior, flows) in [('Flow', 6), ('Flow', 2), ('Normal', 2)]:

            # Encoder output-width multiplier (2 when it also predicts a log-std);
            # kept as a module-level global for code that reads `mult`.
            mult = 2 if partial else 1 
            subs = get_subs()
            print(subs)

            experiment()
            print()
            print()

## Compute FID-Scores

In [None]:
from pytorch_fid import fid_score
from pytorch_fid.inception import InceptionV3

In [None]:
# Feature dimensionality used for the FID statistics; mapped to an InceptionV3 block index below.
dims = 2048

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
# Feature extractor passed as `model=` to fid_score.compute_statistics_of_path in later cells.
inception = InceptionV3([block_idx]).to(device)

In [None]:
# Settings for the FID evaluation. These overwrite the training-cell globals of the same names.
dataset = 'MNIST'
n_instances = 25000
bs = 256
n_epochs = 250
lr = 0.001
# NOTE(review): these latent sizes (8 / 4) differ from the training cell (32 / 16). M is
# not used below (loaded models carry their own sizes), but confirm before relying on it.
M = 8 if dataset == 'SVHN' else 4
n = 10000  # number of real and of generated images used for FID

trainloader, valloader, (c, h, w) = load_data(dataset=dataset, n_instances=n_instances, bs=n)

# One batch of n real training images is the FID reference set.
batch = next(iter(trainloader))
# MNIST loaders yield bare image tensors; SVHN loaders yield (images, labels) pairs.
images = batch[0] if isinstance(batch, (list, tuple)) else batch

os.makedirs("FID/Original", exist_ok=True)
for i, x in enumerate(images):
    save_image(x, f"FID/Original/sample{i}.png")

og_m, og_s = fid_score.compute_statistics_of_path("FID/Original", model=inception, batch_size=100,
                                                  dims=dims, device=device)

In [None]:
# Compute the FID of every trained configuration against the reference statistics
# and store it in that run's info.p.
for REPEAT in range(1, 6):
    print('ITERATION', REPEAT)

    for (variant, partial) in [('RandomNet', False), ('RandomNet', True), ('RKS', False), ('Normal', True)]:
        for (prior, flows) in [('Flow', 6), ('Flow', 2), ('Normal', 2)]:
            run_dir = f'{dataset}-{str(REPEAT)}/fc/{variant+str(partial)}/{prior+str(flows)}'
            # SECURITY: torch.load / pickle.load unpickle arbitrary objects -- only load trusted files.
            # map_location lets CUDA-saved models load on a CPU-only machine.
            # NOTE(review): newer PyTorch versions default torch.load to weights_only=True, which
            # cannot restore a pickled nn.Module; pass weights_only=False there if needed.
            model = torch.load(f'{run_dir}/model.pt', map_location=device)
            with open(f'{run_dir}/info.p', 'rb') as fh:
                info = pickle.load(fh)
            subs = get_subs()

            with torch.no_grad():
                z = model.prior.sample(n).to(device)
                x_hat = model.decoder(z)

            os.makedirs("FID/Generated", exist_ok=True)
            for i, x in enumerate(x_hat):
                # Reshape to the dataset's (c, h, w) rather than a hard-coded (32, 32),
                # so 3-channel SVHN samples are saved correctly too.
                save_image(x.reshape(c, h, w), f"FID/Generated/sample{i}.png")

            gen_m, gen_s = fid_score.compute_statistics_of_path("FID/Generated", model=inception, batch_size=100,
                                                                dims=dims, device=device)
            fid = fid_score.calculate_frechet_distance(og_m, og_s, gen_m, gen_s)
            print(f'{subs}: {round(fid, 4)}')

            info['FID_score'] = fid
            with open(f'{run_dir}/info.p', 'wb') as fh:
                pickle.dump(info, fh)