# IMPORTS

In [1]:
import os
import numpy as np
import tarfile
import time
import scipy.io
import math
import pickle

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
from torchvision.utils import save_image
from torch.distributions.normal import Normal

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

import matplotlib.pyplot as plt
%matplotlib inline

# Log-probability helper functions

In [2]:
# Constants shared by the log-density helpers below.
PI = torch.tensor(np.pi, dtype=torch.float64)  # 0-dim float64 tensor holding pi
EPS = 1e-5  # clamp margin keeping probabilities away from 0 and 1 before taking logs

def log_categorical(x, p, num_classes=256, reduction=None, dim=None):
    """Element-wise categorical log-likelihood with one-hot masking.

    ``x`` holds class indices (cast to long); ``p`` holds class probabilities
    whose trailing axis presumably has size ``num_classes`` (TODO confirm at
    call sites). ``p`` is clamped to [EPS, 1 - EPS] before the log, and only
    the entry of the observed class is non-zero in the result.

    reduction: 'avg' -> mean over ``dim``; 'sum' -> sum over ``dim``;
    anything else -> the unreduced tensor.
    """
    one_hot = F.one_hot(x.long(), num_classes=num_classes)
    log_p = one_hot * torch.log(torch.clamp(p, EPS, 1. - EPS))

    if reduction not in ('avg', 'sum'):
        return log_p
    reducer = torch.mean if reduction == 'avg' else torch.sum
    return reducer(log_p, dim)

def log_bernoulli(x, p, reduction=None, dim=None):
    """Element-wise Bernoulli log-likelihood x*log(p) + (1-x)*log(1-p).

    ``p`` is clamped to [EPS, 1 - EPS] so neither log term can be -inf.

    reduction: 'avg' -> mean over ``dim``; 'sum' -> sum over ``dim``;
    anything else -> the unreduced tensor.
    """
    p_safe = torch.clamp(p, EPS, 1. - EPS)
    log_p = x * torch.log(p_safe) + (1. - x) * torch.log(1. - p_safe)

    if reduction not in ('avg', 'sum'):
        return log_p
    reducer = torch.mean if reduction == 'avg' else torch.sum
    return reducer(log_p, dim)

def log_normal_diag(x, mu, log_var, reduction=None, dim=None):
    """Element-wise log density of a diagonal Gaussian N(x; mu, exp(log_var)).

    reduction: 'avg' -> mean over ``dim``; 'sum' -> sum over ``dim``;
    anything else -> the unreduced tensor.
    """
    log_norm_const = -0.5 * torch.log(2. * PI)
    log_p = log_norm_const - 0.5 * log_var - 0.5 * torch.exp(-log_var) * (x - mu)**2.

    if reduction not in ('avg', 'sum'):
        return log_p
    reducer = torch.mean if reduction == 'avg' else torch.sum
    return reducer(log_p, dim)

def log_standard_normal(x, reduction=None, dim=None):
    """Element-wise log density of the standard normal N(x; 0, 1).

    reduction: 'avg' -> mean over ``dim``; 'sum' -> sum over ``dim``;
    anything else -> the unreduced tensor.
    """
    log_norm_const = -0.5 * torch.log(2. * PI)
    log_p = log_norm_const - 0.5 * x**2.

    if reduction not in ('avg', 'sum'):
        return log_p
    reducer = torch.mean if reduction == 'avg' else torch.sum
    return reducer(log_p, dim)
    
def init_weights(m):
    """Xavier-uniform initialise a Linear layer's weight and set its bias to 0.01.

    Intended for ``model.apply(init_weights)``; modules that are not
    ``nn.Linear`` are left untouched.
    """
    if isinstance(m, nn.Linear):
        # The trailing-underscore form is the in-place API; the un-suffixed
        # ``xavier_uniform`` is a deprecated alias.
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:  # nn.Linear(..., bias=False) has no bias
            nn.init.constant_(m.bias, 0.01)

# LOAD DATA

In [3]:
def load_data(dataset='MNIST', bs=64, n_instances=5000):
    """Build train/validation DataLoaders from a dataset's training split.

    The training split is randomly divided into ``n_instances`` training
    examples and the remaining examples, which form the validation set.

    Args:
        dataset: 'MNIST' (downloaded by torchvision into ./data); any other
            value loads SVHN from ./data/train_32x32.mat.
        bs: batch size for both loaders.
        n_instances: number of training examples to keep.

    Returns:
        (trainloader, valloader, (c, h, w)).

    Raises:
        ValueError: if ``n_instances`` is negative or exceeds the split size.
    """
    if dataset == 'MNIST':
        transform = transforms.Compose([transforms.ToTensor()])
        full_set = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
        shape = (1, 28, 28)
    else:  # dataset == 'SVHN'
        mat = scipy.io.loadmat('./data/train_32x32.mat')
        # permute(3, 2, 0, 1) moves the sample axis first, giving (N, 3, 32, 32);
        # presumably the .mat stores X as (H, W, C, N) — verify if the source changes.
        images = (torch.Tensor(mat['X']) / 255).permute(3, 2, 0, 1)
        full_set = torch.utils.data.TensorDataset(images, torch.Tensor(mat['y']))
        shape = (3, 32, 32)

    # Derive the split from the actual dataset size instead of hard-coding it.
    n_total = len(full_set)
    if not 0 <= n_instances <= n_total:
        raise ValueError(f"n_instances must be in [0, {n_total}], got {n_instances}")
    trainset, valset = torch.utils.data.random_split(full_set, [n_instances, n_total - n_instances])

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True)
    valloader = torch.utils.data.DataLoader(valset, batch_size=bs, shuffle=True)
    return trainloader, valloader, shape

# Define the VAE (encoder, decoder, priors)

In [4]:
class Encoder(nn.Module):
    """Fully connected encoder q(z|x) with LeakyReLU(0.2) hidden layers 256 -> 128 -> 64.

    Args:
        D: flattened input size.
        M: latent dimensionality.
        variant: 'Normal' trains every layer. Any other value (the notebook
            uses 'RandomNet') freezes the hidden layers at their random init.
        partial: True -> the output layer produces 2*M values (mean and log
            standard deviation) and stays trainable. False -> it produces M
            values that together form the mean of a posterior with fixed std
            ``FIXED_STD``; with a non-'Normal' variant this layer is frozen too.

    NOTE(review): partial=False with variant='Normal' makes ``reparameterize``
    return an M//2-dimensional z, which does not match the decoder input —
    presumably unsupported; confirm.
    """

    # Posterior std used in the fixed-variance mode (partial=False, variant != 'Normal').
    FIXED_STD = 0.001

    def __init__(self, D, M, variant='Normal', partial=True):
        super(Encoder, self).__init__()
        self.D = D
        self.M = M
        self.variant = variant
        self.partial = partial

        # Output size: mean + log-std when partial, otherwise the mean only.
        self.M2 = self.M*2 if self.partial else self.M

        self.enc1 = nn.Linear(self.D, 256)
        self.enc2 = nn.Linear(self.enc1.out_features, self.enc1.out_features//2)
        self.enc3 = nn.Linear(self.enc2.out_features, self.enc2.out_features//2)
        self.enc4 = nn.Linear(self.enc3.out_features, self.M2)

        if self.variant != 'Normal':
            # Random-network variant: hidden layers keep their initial weights;
            # the output layer is frozen as well unless partial=True.
            frozen = [self.enc1, self.enc2, self.enc3]
            if not self.partial:
                frozen.append(self.enc4)
            for layer in frozen:
                for param in layer.parameters():
                    param.requires_grad = False

    def _fixed_variance(self):
        """True when the posterior uses FIXED_STD instead of a learned std."""
        return (not self.partial) and (self.variant != 'Normal')

    def forward(self, x):
        """Encode ``x`` (flattened per example) and split the output in two along dim 1.

        Returns (mu, log_std) when the std is learned. In the fixed-variance
        mode both halves are parts of the mean and are re-joined by
        ``reparameterize`` / ``log_prob``.
        """
        x = x.view(x.shape[0], -1)
        x = F.leaky_relu(self.enc1(x), 0.2)
        x = F.leaky_relu(self.enc2(x), 0.2)
        x = F.leaky_relu(self.enc3(x), 0.2)
        x = self.enc4(x)
        return torch.chunk(x, 2, dim=1)

    def reparameterize(self, mu, log_std):
        """Draw z = mu + std * eps with eps ~ N(0, I)."""
        if self._fixed_variance():
            mu = torch.cat((mu, log_std), dim=1)
            std = self.FIXED_STD
            eps = torch.randn_like(mu)
        else:
            std = log_std.exp()
            eps = torch.randn_like(std)
        return mu + (eps*std)

    def log_prob(self, z, mu, log_std):
        """Element-wise log q(z | x) under the same Gaussian ``reparameterize`` samples from."""
        if self._fixed_variance():
            mu = torch.cat((mu, log_std), dim=1)
            # log(std**2) on z's own device/dtype; no reliance on a global device.
            log_var = torch.full_like(z, 2. * math.log(self.FIXED_STD))
        else:
            # `log_std` is a log standard deviation (reparameterize uses
            # exp(log_std) as std), so the log variance is 2 * log_std.
            log_var = 2. * log_std
        return log_normal_diag(z, mu, log_var)

In [5]:
class Decoder(nn.Module):
    """Fully connected decoder mapping an M-dimensional latent to D outputs.

    Hidden widths double from 128 up to 1024 with LeakyReLU(0.2); the output
    layer is linear, so decoded values are not squashed into any range.
    """

    def __init__(self, D, M):
        super(Decoder, self).__init__()
        self.D, self.M = D, M

        self.dec1 = nn.Linear(M, 128)
        self.dec2 = nn.Linear(128, 256)
        self.dec3 = nn.Linear(256, 512)
        self.dec4 = nn.Linear(512, 1024)
        self.dec5 = nn.Linear(1024, D)

    def forward(self, z):
        """Decode latents of shape (B, M) into outputs of shape (B, D)."""
        hidden = z
        for layer in (self.dec1, self.dec2, self.dec3, self.dec4):
            hidden = F.leaky_relu(layer(hidden), 0.2)
        return self.dec5(hidden)

    def log_prob(self, x_hat, x):
        """Per-example summed squared error between ``x_hat`` and flattened ``x``.

        Despite the name this is a non-negative reconstruction loss, not a
        log-probability; the training loop adds it to the loss as-is.
        """
        target = x.view(x.shape[0], -1)
        return F.mse_loss(target, x_hat, reduction='none').sum(1)

In [6]:
class Prior(nn.Module):
    """Standard-normal prior p(z) = N(0, I) over an L-dimensional latent."""

    def __init__(self, L):
        super(Prior, self).__init__()
        self.L = L

    def sample(self, batch_size):
        """Return ``batch_size`` draws of shape (batch_size, L), created without an explicit device."""
        return torch.randn(batch_size, self.L)

    def log_prob(self, z):
        """Element-wise log N(z; 0, 1), not summed over dimensions."""
        return log_standard_normal(z)

In [7]:
class FlowPrior(nn.Module):
    """Flow prior built from RealNVP-style affine coupling layers.

    Each coupling splits its input into halves (xa, xb), keeps xa and applies
    an affine map to xb with a tanh-bounded log-scale ``s(xa)`` and a
    translation ``t(xa)``. Between couplings the feature order is reversed.

    Args:
        num_flows: number of coupling layers.
        D: dimensionality of the modelled variable; must be even.
        M: hidden width of the scale and translation networks.

    Raises:
        ValueError: if ``D`` is odd.
    """

    def __init__(self, num_flows, D=50, M=256):
        super(FlowPrior, self).__init__()
        if D % 2 != 0:
            raise ValueError(f"FlowPrior needs an even D to split into halves, got D={D}")

        def nets():
            # Log-scale network; tanh keeps s in (-1, 1).
            return nn.Sequential(nn.Linear(D // 2, M), nn.LeakyReLU(),
                                 nn.Linear(M, M), nn.LeakyReLU(),
                                 nn.Linear(M, D // 2), nn.Tanh())

        def nett():
            # Translation network (unbounded output).
            return nn.Sequential(nn.Linear(D // 2, M), nn.LeakyReLU(),
                                 nn.Linear(M, M), nn.LeakyReLU(),
                                 nn.Linear(M, D // 2))

        self.D = D
        # t-nets are created before s-nets; keep this order so a given seed
        # yields the same initial weights.
        self.t = torch.nn.ModuleList([nett() for _ in range(num_flows)])
        self.s = torch.nn.ModuleList([nets() for _ in range(num_flows)])
        self.num_flows = num_flows

    def coupling(self, x, index, forward=True):
        """Apply coupling layer ``index``; returns (output, s).

        forward=True computes yb = (xb - t) * exp(-s) (data -> base direction);
        forward=False computes yb = exp(s) * xb + t (base -> data direction).
        Inputs stay on whatever device they are on, which must match the
        module's parameters.
        """
        xa, xb = torch.chunk(x, 2, 1)
        s = self.s[index](xa)
        t = self.t[index](xa)

        if forward:
            yb = (xb - t) * torch.exp(-s)
        else:
            yb = torch.exp(s) * xb + t

        return torch.cat((xa, yb), 1), s

    def permute(self, x):
        """Reverse the feature order (its own inverse)."""
        return x.flip(1)

    def f(self, x):
        """Map x to base space; returns (z, log|det dz/dx|) with the log-det of shape (B,)."""
        log_det_J, z = x.new_zeros(x.shape[0]), x
        for i in range(self.num_flows):
            z, s = self.coupling(z, i, forward=True)
            z = self.permute(z)
            # The forward coupling scales xb by exp(-s).
            log_det_J = log_det_J - s.sum(dim=1)

        return z, log_det_J

    def f_inv(self, z):
        """Invert ``f``: undo the permutation and coupling of each layer in reverse order."""
        x = z
        for i in reversed(range(self.num_flows)):
            x = self.permute(x)
            x, _ = self.coupling(x, i, forward=False)

        return x

    def sample(self, batch_size):
        """Draw ``batch_size`` samples of shape (batch_size, D) on the parameters' device."""
        param = next(self.parameters(), None)
        z = torch.randn(batch_size, self.D, device=param.device if param is not None else None)
        x = self.f_inv(z)
        return x.view(-1, self.D)

    def log_prob(self, x):
        """Return (element-wise log N(f(x)), log_det_J of shape (B, 1))."""
        z, log_det_J = self.f(x)
        return log_standard_normal(z), log_det_J.unsqueeze(1)


In [8]:
class VAE(nn.Module):
    """Fully connected VAE: an Encoder q(z|x), a Decoder and a prior over z.

    NOTE(review): this class reads notebook globals. ``prior`` and ``flows``
    pick the prior at construction ('Flow' -> FlowPrior with ``flows``
    couplings, otherwise Prior). ``reconstruct``/``generate`` read
    ``valloader``, ``c``, ``h``, ``w``, ``dataset``, ``variant``,
    ``partial``, ``subs``, ``epoch`` and ``device``.

    Args:
        D: flattened input size (c*h*w).
        M: latent dimensionality.
        variant, partial: forwarded to ``Encoder``.
    """

    def __init__(self, D, M, variant='Normal', partial=True):
        super(VAE, self).__init__()
        self.D = D
        self.M = M
        self.variant = variant
        self.partial = partial

        self.encoder = Encoder(self.D, self.M, variant=self.variant, partial=self.partial)
        self.decoder = Decoder(self.D, self.M)
        if prior == 'Flow':
            self.prior = FlowPrior(flows, D=M)
        else:
            self.prior = Prior(M)

    def forward(self, x, reduction='avg'):
        """Encode, sample z, decode, and return the terms of the ELBO.

        Returns (RE, ENC, stdn) for a standard-normal prior, or
        (RE, ENC, stdn, log_det_J) for a FlowPrior. RE is the per-example
        squared reconstruction error and ENC the element-wise log q(z|x).
        stdn is the element-wise standard-normal log density of z (after the
        flow for FlowPrior). ``reduction`` is accepted for backward
        compatibility and is unused.
        """
        mu, log_std = self.encoder(x)
        z = self.encoder.reparameterize(mu, log_std)
        x_hat = self.decoder(z)

        RE = self.decoder.log_prob(x_hat, x)
        ENC = self.encoder.log_prob(z, mu, log_std)
        # Dispatch on the prior actually built for this model, not on the
        # global `prior`, which later cells reassign.
        if isinstance(self.prior, FlowPrior):
            stdn, log_det_J = self.prior.log_prob(z)
            return (RE, ENC, stdn, log_det_J)
        stdn = self.prior.log_prob(z)
        return (RE, ENC, stdn)

    @staticmethod
    def _image_path(kind):
        """Return this run's image path for ``kind`` at the current epoch, creating its directory."""
        out_dir = f"{dataset}/fc/{variant+str(partial)}/{prior+str(flows)}"
        os.makedirs(out_dir, exist_ok=True)
        return f"{out_dir}/{subs}_{kind}_{epoch}.png"

    @torch.no_grad()
    def reconstruct(self, n=8):
        """Save up to ``n`` validation images in one row with their reconstructions below.

        Does nothing if the validation loader is empty; uses fewer images if
        the first batch is smaller than ``n``.
        """
        batch = next(iter(valloader), None)
        if batch is None:
            return
        x = batch[0][:n]
        n = x.shape[0]
        mu, log_std = self.encoder(x.to(device))
        z = self.encoder.reparameterize(mu, log_std)
        x_hat = self.decoder(z)

        save_image(torch.cat((x.view(n, c, h, w), x_hat.view(n, c, h, w).cpu())),
                   self._image_path('reconstructed'), nrow=n)

    @torch.no_grad()
    def generate(self, n=8):
        """Decode ``n**2`` prior samples and save them as an n-by-n grid."""
        z = self.prior.sample(n**2).to(device)
        x_hat = self.decoder(z)  # this model's decoder, not the global `model`
        save_image(x_hat.view(n**2, c, h, w).cpu(), self._image_path('generated'), nrow=n)
    

# Hyperparameters and model initialization

In [None]:
def get_subs(variant=None, partial=None, prior=None, flows=None):
    """Return the short run tag used as a prefix for saved image names.

    The tag is 'F', then an encoder code ('P' for RandomNet with
    partial=True, 'N' for RandomNet with partial=False, 'L' otherwise), then
    'F<flows>' for a flow prior or 'N' for any other prior.

    Each argument left as None falls back to the module-level setting of the
    same name, so ``get_subs()`` keeps working. A setting is only looked up
    when it affects the tag.
    """
    def setting(name, value):
        # Use the explicit argument when given, else the notebook global.
        if value is not None:
            return value
        try:
            return globals()[name]
        except KeyError:
            raise NameError(f"name {name!r} is not defined") from None

    subs = 'F'
    variant = setting('variant', variant)
    # `== True/False` (not truthiness) mirrors the original tag rules exactly.
    if variant == 'RandomNet' and setting('partial', partial) == True:
        subs += 'P'
    elif variant == 'RandomNet' and setting('partial', partial) == False:
        subs += 'N'
    else:
        subs += 'L'

    if setting('prior', prior) == 'Flow':
        subs += 'F' + str(setting('flows', flows))
    else:
        subs += 'N'
    return subs

In [None]:
# Experiment configuration; VAE, get_subs and the save paths read these as globals.
dataset = 'MNIST'        # 'MNIST'; any other value makes load_data use SVHN
variant = 'RandomNet'    # 'Normal' trains the whole encoder; other values freeze its hidden layers
partial = False          # with RandomNet: True keeps the encoder output layer trainable
mult = 2 if partial else 1  # NOTE(review): not used by any visible cell
prior = 'Flow'           # 'Flow' -> FlowPrior, otherwise a standard-normal Prior
n_instances = 25000      # training examples; the rest of the training split is validation
bs = 256
flows = 10               # coupling layers in the flow prior
n_epochs = 500
lr = 0.001
m = 32                   # latent dimensionality
seed = 0

# Seed the RNGs so the train/val split, weight init and sampling are reproducible.
torch.manual_seed(seed)
np.random.seed(seed)

subs = get_subs()

In [None]:
trainloader, valloader, (c, h, w) = load_data(dataset=dataset, n_instances=n_instances, bs=bs)

model = VAE(c*h*w, m, variant=variant, partial=partial)
model.to(device)
# model.apply(init_weights)  # optional Xavier init; left disabled, so PyTorch's default init is used
# Optimise only parameters that are not frozen (RandomNet freezes encoder layers).
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

# Number of trainable scalars: numel() multiplies a parameter's dimensions
# (summing them, as before, reported e.g. 256 + 784 instead of 256 * 784).
n_weights = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'There are {n_weights} weights in the VAE.')

# TRAIN MODEL

In [None]:
# Minimise the negative ELBO; every 10 epochs save reconstructions and samples.
losses, RElosses, stdnlosses, KLlosses = [], [], [], []
timer = []
n_batches = len(trainloader)
for epoch in range(1, n_epochs+1):
    t = time.time()
    epochloss, REloss, stdnloss, KLloss = 0., 0., 0., 0.
    for x, _ in trainloader:
        if prior == 'Flow':
            RE, ENC, stdn, log_det_J = model(x.to(device))
            # log p(z) = sum_d log N(f(z)_d) + log|det J|. log_det_J holds one
            # value per sample (shape (B, 1)), so add it once after summing the
            # element-wise terms instead of broadcasting and averaging over dims.
            KL = (stdn - ENC).sum(-1) + log_det_J.squeeze(-1)
        else:
            RE, ENC, stdn = model(x.to(device))
            KL = (stdn - ENC).sum(-1)
        # KL holds log p(z) - log q(z|x) per sample (the negative of a
        # single-sample KL estimate), so the negative ELBO is RE - KL.
        loss = (RE - KL).mean()

        optimizer.zero_grad()
        loss.backward()  # the graph is not reused, so retain_graph is unnecessary
        optimizer.step()

        epochloss += loss.item()
        REloss += RE.mean().item()
        stdnloss += stdn.mean().item()
        KLloss += KL.mean().item()

    losses.append(epochloss/n_batches)
    RElosses.append(REloss/n_batches)
    stdnlosses.append(stdnloss/n_batches)
    KLlosses.append(KLloss/n_batches)

    print('[%d/%d]: loss: %.3f || RE: %.3f || KL: %.3f || stdn: %.3f' % ((epoch), n_epochs, epochloss/n_batches,
                                                                         REloss/n_batches, KLloss/n_batches,
                                                                         stdnloss/n_batches))
    timer.append(round(time.time() - t, 5))

    if epoch%10 == 0:
        # Image dumps need no gradients.
        with torch.no_grad():
            model.reconstruct()
            model.generate()

In [None]:
# Persist training curves, settings and the full model in this run's output directory.
out_dir = f'{dataset}/fc/{variant+str(partial)}/{prior+str(flows)}'
os.makedirs(out_dir, exist_ok=True)

info = {'Dataset': dataset,
        'Variant': variant,
        'Partial': partial,
        'Prior': prior,
        'n_instances': n_instances,
        'batch_size': bs,
        'flows': flows,
        'n_epochs': n_epochs,
        'learning_rate': lr,
        'n_weights': int(n_weights),  # plain int, so the pickle does not embed a tensor
        'losses': list(np.asarray(losses).round(2)),
        'RE': list(np.asarray(RElosses).round(2)),
        'stdn': list(np.asarray(stdnlosses).round(2)),
        'KL': list(np.asarray(KLlosses).round(2)),
        'Times': list(np.asarray(timer).round(4))}

# The context manager guarantees the pickle file is flushed and closed.
with open(f'{out_dir}/info.p', 'wb') as info_file:
    pickle.dump(info, info_file)
torch.save(model, f'{out_dir}/model.pt')

In [21]:
# Settings re-declared for the interpolation cells below.
dataset = 'MNIST'
n_instances, bs = 25000, 256
n_epochs, lr = 500, 0.001
m = 32  # latent dimensionality

In [22]:
def get_subs(variant=None, partial=None, prior=None, flows=None):
    """Return the short run tag used as a prefix for saved image names.

    The tag is 'F', then an encoder code ('P' for RandomNet with
    partial=True, 'N' for RandomNet with partial=False, 'L' otherwise), then
    'F<flows>' for a flow prior or 'N' for any other prior.

    Each argument left as None falls back to the module-level setting of the
    same name, so ``get_subs()`` keeps working. A setting is only looked up
    when it affects the tag.
    """
    def setting(name, value):
        # Use the explicit argument when given, else the notebook global.
        if value is not None:
            return value
        try:
            return globals()[name]
        except KeyError:
            raise NameError(f"name {name!r} is not defined") from None

    subs = 'F'
    variant = setting('variant', variant)
    # `== True/False` (not truthiness) mirrors the original tag rules exactly.
    if variant == 'RandomNet' and setting('partial', partial) == True:
        subs += 'P'
    elif variant == 'RandomNet' and setting('partial', partial) == False:
        subs += 'N'
    else:
        subs += 'L'

    if setting('prior', prior) == 'Flow':
        subs += 'F' + str(setting('flows', flows))
    else:
        subs += 'N'
    return subs

In [23]:
# Linear latent-space interpolation between pairs of prior samples, for every
# trained (variant, prior) configuration.
n_steps = 8          # interpolants per row, including both endpoints
n_rows = 8           # independent interpolations per image
img_hw = (28, 28)    # MNIST image size
for variant, partial in [('Normal', True), ('RandomNet', True), ('RandomNet', False)]:
    for prior, flows in [('Normal', 5), ('Flow', 5), ('Flow', 10)]:
        # map_location lets CUDA-trained checkpoints load on a CPU-only machine.
        # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
        # which rejects whole pickled modules; pass weights_only=False there.
        model = torch.load(f'{dataset}/fc/{variant+str(partial)}/{prior+str(flows)}/model.pt',
                           map_location=device)
        model.eval()
        subs = get_subs()

        with torch.no_grad():
            z_start = model.prior.sample(n_rows).to(device)
            z_end = model.prior.sample(n_rows).to(device)
            # Step k gets weight k / (n_steps - 1): 0 at z_start, 1 at z_end.
            weights = torch.linspace(0, 1, n_steps, device=device).view(-1, 1, 1)
            z_grid = z_start + weights * (z_end - z_start)   # (n_steps, n_rows, latent)
            # One image row per sample: decode its n_steps interpolants together.
            rows = [model.decoder(z_row).reshape(n_steps, *img_hw)
                    for z_row in z_grid.permute(1, 0, 2)]
        images = torch.cat(rows)

        save_image(images.unsqueeze(1).cpu(), f"{subs}_interpolation.png", nrow=n_steps)