## IMPORTS

In [1]:
import os
import numpy as np
import tarfile 
import time
import scipy.io
import math
import pickle

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
from torchvision.utils import save_image
from torch.distributions.normal import Normal

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import matplotlib.pyplot as plt
%matplotlib inline

## DISTRIBUTIONS

In [2]:
PI = torch.from_numpy(np.asarray(np.pi))
EPS = 1.e-5

def log_categorical(x, p, num_classes=256, reduction=None, dim=None):
    x_one_hot = F.one_hot(x.long(), num_classes=num_classes)

    log_p = x_one_hot * torch.log(torch.clamp(p, EPS, 1. - EPS))
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p

def log_bernoulli(x, p, reduction=None, dim=None):
    pp = torch.clamp(p, EPS, 1. - EPS)
    log_p = x * torch.log(pp) + (1. - x) * torch.log(1. - pp)
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p

def log_normal_diag(x, mu, log_var, reduction=None, dim=None):
    log_p = -0.5 * torch.log(2. * PI) - 0.5 * log_var - 0.5 * torch.exp(-log_var) * (x - mu)**2.
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p

def log_standard_normal(x, reduction=None, dim=None):
    log_p = -0.5 * torch.log(2. * PI) - 0.5 * x**2.
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p

## LOAD DATA

In [3]:
def load_data(dataset='MNIST', bs=64, n_instances=5000):
    if dataset == 'MNIST':
        transform = transforms.Compose([transforms.ToTensor()])

        trainset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
        trainset, valset = torch.utils.data.random_split(trainset, [n_instances, 60000 - n_instances])
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True)
        valloader =  torch.utils.data.DataLoader(valset, batch_size=bs, shuffle=True)
        
        c, h, w = 1, 28, 28
        
    else: #dataset == 'SVHN'
        mat = scipy.io.loadmat('./data/train_32x32.mat')
        trainset = torch.utils.data.TensorDataset((torch.Tensor(mat['X']) / 255).permute(3, 2, 0, 1),
                                                 torch.Tensor(mat['y']))
        trainset, valset = torch.utils.data.random_split(trainset, [n_instances, 73257 - n_instances])
        trainloader =  torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True)
        valloader =  torch.utils.data.DataLoader(valset, batch_size=bs, shuffle=True)
        
        c, h, w = 3, 32, 32
        
    return trainloader, valloader, (c, h, w)

# VAE

### ENCODER

In [26]:
class Encoder(nn.Module):
    def __init__(self, c, h, w, variant='Normal', partial=True):
        super(Encoder, self).__init__()
        self.c, self.h, self.w, self.D = c, h, w, c*h*w
        self.variant = variant
        self.partial = partial
        self.variance = torch.tensor(1e-10)
        
        if variant != 'RKS':
            self.enc1 = nn.Conv2d(self.c, 4*mult, kernel_size=3, stride=1, padding=1)
            self.enc2 = nn.Conv2d(4*mult, 8*mult, kernel_size=3, stride=1, padding=1)
            self.enc3 = nn.Conv2d(8*mult, 16*mult, kernel_size=3, stride=1, padding=1)
            self.enc4 = nn.Conv2d(16*mult, 32*mult, kernel_size=3, stride=1)

            if self.variant != 'Normal':
                self.enc1.weight.requires_grad = False
                self.enc1.bias.requires_grad = False
                self.enc2.weight.requires_grad = False
                self.enc2.bias.requires_grad = False
                self.enc3.weight.requires_grad = False
                self.enc3.bias.requires_grad = False
                if not self.partial:
                    self.enc4.weight.requires_grad = False
                    self.enc4.bias.requires_grad = False
        else:
            #self.enc1 = nn.Linear(32*32*3, 32*2*2)
            self.enc1 = nn.Conv2d(self.c, 32, kernel_size=16, stride=16)
            self.enc1.weight.requires_grad = False
            self.enc1.bias.requires_grad = False
        
        
    def forward(self, x):
        if self.variant != 'RKS':
            x = F.max_pool2d(F.relu(self.enc1(x)), 2)
            x = F.max_pool2d(F.relu(self.enc2(x)), 2)
            x = F.max_pool2d(F.relu(self.enc3(x)), 2)
            x = self.enc4(x)
            return torch.chunk(x, 2, dim=1)
        else:
            #x = x.view(-1, 32*32*3)
            x = self.enc1(x)
            #x = x.view(-1, 32, 2, 2)
            return torch.chunk(x, 2, dim=1)
    
    
    def reparameterize(self, mu, log_var):
        if (not self.partial) and (self.variant != 'Normal'):
            mu = torch.cat((mu, log_var), dim=1)
            std = torch.sqrt(self.variance)
            eps = torch.randn_like(mu)
            return mu + (eps*std)
        else:
            std = torch.sqrt(log_var.exp())
            eps = torch.randn_like(std)
            return mu + (eps*std)
    
    
    def log_prob(self, z, mu, log_var):    
        if (not self.partial) and (self.variant != 'Normal'):
            mu = torch.cat((mu, log_var), dim=1)
            z, mu = z.view(-1, 32*2*2), mu.view(-1, 32*2*2)
            return log_normal_diag(z, mu, (torch.ones(z.shape)*torch.log(self.variance)).to(device))
        else:
            z, mu, log_var = z.view(-1, 32*2*2), mu.view(-1, 32*2*2), log_var.view(-1, 32*2*2)
            return log_normal_diag(z, mu, log_var)

### DECODER

In [27]:
class Decoder(nn.Module):
    def __init__(self, c, h, w):
        super(Decoder, self).__init__()
        self.c, self.h, self.w, self.D = c, h, w, c*h*w
        
        self.dec1 = nn.Conv2d(32, 24, 3, padding=1)
        self.dec2 = nn.Conv2d(24, 16, 3, padding=1)
        self.dec3 = nn.Conv2d(16, 12, 3, padding=1)
        self.dec4 = nn.Conv2d(12, 8, 3, padding=1)
        self.dec5 = nn.Conv2d(8, 3, 3, padding=1)

    def forward(self, z):
        x = F.upsample(F.leaky_relu(self.dec1(z), 0.2), scale_factor=2)
        x = F.upsample(F.leaky_relu(self.dec2(x), 0.2), scale_factor=2)
        x = F.upsample(F.leaky_relu(self.dec3(x), 0.2), scale_factor=2)
        x = F.upsample(F.leaky_relu(self.dec4(x), 0.2), scale_factor=2)
        x = self.dec5(x)
        return F.interpolate(x, size=(32, 32), mode='nearest')
    
    
    def log_prob(self, x_hat, x):
        RE = F.mse_loss(x_hat, x, reduction='none').sum(dim=(1,2,3))
        return RE

### PRIORS

In [28]:
class Prior(nn.Module):
    def __init__(self):
        super(Prior, self).__init__()

    def sample(self, batch_size):
        z = torch.randn((batch_size, 32, 2, 2))
        return z

    def log_prob(self, z):
        return log_standard_normal(z.view(-1, 32*2*2))

In [29]:
class FlowPrior(nn.Module):
    def __init__(self, num_flows, D=50, M=256):
        super(FlowPrior, self).__init__()
        
        nets = lambda: nn.Sequential(nn.Linear(D // 2, M), nn.LeakyReLU(),
                             nn.Linear(M, M), nn.LeakyReLU(),
                             nn.Linear(M, D // 2), nn.Tanh())

        nett = lambda: nn.Sequential(nn.Linear(D // 2, M), nn.LeakyReLU(),
                             nn.Linear(M, M), nn.LeakyReLU(),
                             nn.Linear(M, D // 2))

        self.D = D
        self.t = torch.nn.ModuleList([nett() for _ in range(num_flows)])
        self.s = torch.nn.ModuleList([nets() for _ in range(num_flows)])
        self.num_flows = num_flows

        
    def coupling(self, x, index, forward=True):
        (xa, xb) = torch.chunk(x, 2, 1)
        xa, xb = xa.to(device), xb.to(device)
        s = self.s[index](xa)
        t = self.t[index](xa)

        if forward:
            #yb = f^{-1}(x)
            yb = (xb - t) * torch.exp(-s)
        else:
            #xb = f(y)
            yb = torch.exp(s) * xb + t

        return torch.cat((xa, yb), 1), s

    
    def permute(self, x):
        return x.flip(1)

    
    def f(self, x):
        log_det_J, z = x.new_zeros(x.shape[0]), x
        for i in range(self.num_flows):
            z, s = self.coupling(z, i, forward=True)
            z = self.permute(z)
            log_det_J = log_det_J - s.sum(1)

        return z, log_det_J

    
    def f_inv(self, z):
        x = z
        for i in reversed(range(self.num_flows)):
            x = self.permute(x)
            x, _ = self.coupling(x, i, forward=False)

        return x

    
    def sample(self, batch_size):
        z = torch.randn(batch_size, self.D)
        x = self.f_inv(z)
        return x.view(-1, 32, 2, 2)

    
    def log_prob(self, x):
        z, log_det_J = self.f(x.view(-1, 32*2*2))
        return log_standard_normal(z), log_det_J.unsqueeze(1)

### VAE

In [30]:
class VAE(nn.Module):
    def __init__(self, c, h, w, variant='Normal', partial=True):
        super(VAE, self).__init__()
        self.variant = variant
        self.partial = partial
            
        self.encoder = Encoder(c, h, w, variant=self.variant, partial=self.partial)
        self.decoder = Decoder(c, h, w)
        if prior == 'Flow':
            self.prior = FlowPrior(flows, D = 32*2*2)
        else:
            self.prior = Prior()
        
            
    def forward(self, x):
        mu, log_var = self.encoder.forward(x)
        z = self.encoder.reparameterize(mu, log_var)
        x_hat = self.decoder.forward(z)

        RE = self.decoder.log_prob(x_hat, x)
        ENC = self.encoder.log_prob(z, mu, log_var)
        if prior == 'Flow':
            stdn, log_det_J = self.prior.log_prob(z)
            return (RE, ENC, stdn, log_det_J)
        else:
            stdn = self.prior.log_prob(z)
            return (RE, ENC, stdn)
        
        
    def reconstruct(self, n=8):
        for batch_idx, (x, _) in enumerate(valloader):
            mu, log_var = self.encoder.forward(x[:n].to(device))
            z = self.encoder.reparameterize(mu, log_var)
            x_hat = self.decoder.forward(z)

            save_image(torch.cat((x[:n].view(n, c, h, w), 
                                   x_hat.view(n, c, h, w).cpu())), 
                                 f"{dataset}/conv/{variant+str(partial)}/{prior+str(flows)}/{subs}_reconstructed_{epoch}.png", nrow=n)
            break
    
    
    def generate(self, n=8):
        z = self.prior.sample(n**2).to(device)
        x_hat = model.decoder(z)
        save_image(x_hat.view(n**2, c, h, w).cpu(), 
                   f"{dataset}/conv/{variant+str(partial)}/{prior+str(flows)}/{subs}_generated_{epoch}.png", nrow=n)  

# Initialize + Hyperparameters

In [31]:
dataset = 'SVHN'
variant = 'RKS' #'RandomNet'
partial = False
mult = 2 if partial else 1
prior = 'Flow' #'Flow'
n_instances = 25000
bs = 256
flows = 10
n_epochs = 500
lr = 0.001


subs = 'C'
if variant == 'RandomNet' and partial == True:
    subs += 'P'
elif variant == 'RandomNet' and partial == False:
    subs += 'N'
elif variant == 'RKS' and partial == False:
    subs += 'R'
else:
    subs += 'L'
    
if prior == 'Flow':
    subs += 'F'+str(flows)
else:
    subs += 'N'

In [32]:
trainloader, valloader, (c, h, w) = load_data(dataset=dataset, n_instances=n_instances, bs=bs)

model = VAE(c, h, w, variant=variant, partial=partial)
model.to(device)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = lr)

n_weights = 0
for parameter in model.parameters():
    if parameter.requires_grad:
        n_weights += torch.tensor(parameter.shape).sum()

print(f'There are {n_weights} weights in the VAE.')

There are 34808 weights in the VAE.


In [33]:
def init_weights(m, distr = 'Xavier', uniform=True):
    if type(m) == nn.Conv2d:
        if distr == 'Kaiming' and uniform:
            torch.nn.init.kaiming_uniform(m.weight)
        elif distr == 'Kaiming' and not uniform:
            torch.nn.init.kaiming_normal(m.weight)
        elif distr == 'Xavier' and uniform:
            torch.nn.init.xavier_uniform(m.weight)
        else:  # distr == 'Xavier' and not uniform:
            torch.nn.init.xavier_normal(m.weight)
        
        m.bias.data.fill_(0.01)

model.apply(init_weights)

  


VAE(
  (encoder): Encoder(
    (enc1): Conv2d(3, 32, kernel_size=(16, 16), stride=(16, 16))
  )
  (decoder): Decoder(
    (dec1): Conv2d(32, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dec2): Conv2d(24, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dec3): Conv2d(16, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dec4): Conv2d(12, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dec5): Conv2d(8, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (prior): FlowPrior(
    (t): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
        (2): Linear(in_features=256, out_features=256, bias=True)
        (3): LeakyReLU(negative_slope=0.01)
        (4): Linear(in_features=256, out_features=64, bias=True)
      )
      (1): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
 

In [34]:
losses, RElosses, stdnlosses, KLlosses = [], [], [], []
timer = []
for epoch in range(1, n_epochs+1):
    t = time.time()
    epochloss, REloss, stdnloss, KLloss = 0, 0, 0, 0
    for batch_idx, (x, _) in enumerate(trainloader):
        if prior == 'Flow':
            RE, ENC, stdn, log_det_J = model.forward(x.to(device))
            KL = (stdn + log_det_J - ENC).mean(-1)
            loss = -(-RE + KL).mean()
        else:
            RE, ENC, stdn = model.forward(x.to(device))
            KL =  (stdn - ENC).sum(-1)
            loss = -(-RE + KL).mean()
            
        
        model.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
        
        epochloss += loss.item()
        REloss += RE.mean().item()
        stdnloss += stdn.mean().item()
        KLloss += KL.mean().item()
        
    losses.append(epochloss/len(trainloader))
    RElosses.append(REloss/len(trainloader))
    stdnlosses.append(stdnloss/len(trainloader))
    KLlosses.append(KLloss/len(trainloader))
    
    print('[%d/%d]: loss: %.3f || RE: %.3f || KL: %.3f || stdn: %.3f' % ((epoch), n_epochs, epochloss/len(trainloader), 
                                                                         REloss/len(trainloader), KLloss/len(trainloader),
                                                                         stdnloss/len(trainloader)))
    timer.append(round(time.time() - t, 5))
    
    if epoch%10 == 0:
        model.reconstruct()
        model.generate()

[1/500]: loss: -384.289 || RE: 141.411 || KL: 525.700 || stdn: -25.556
[2/500]: loss: -569.895 || RE: 42.597 || KL: 612.492 || stdn: -14.821
[3/500]: loss: -585.089 || RE: 31.204 || KL: 616.293 || stdn: -11.544
[4/500]: loss: -590.544 || RE: 26.925 || KL: 617.468 || stdn: -10.409
[5/500]: loss: -593.730 || RE: 24.392 || KL: 618.122 || stdn: -9.772
[6/500]: loss: -595.942 || RE: 22.603 || KL: 618.545 || stdn: -9.356
[7/500]: loss: -597.382 || RE: 21.372 || KL: 618.754 || stdn: -9.614
[8/500]: loss: -600.210 || RE: 20.400 || KL: 620.610 || stdn: -9.292
[9/500]: loss: -601.371 || RE: 19.559 || KL: 620.930 || stdn: -8.974
[10/500]: loss: -602.127 || RE: 18.974 || KL: 621.100 || stdn: -8.804
[11/500]: loss: -602.976 || RE: 18.244 || KL: 621.220 || stdn: -8.684
[12/500]: loss: -603.466 || RE: 17.881 || KL: 621.347 || stdn: -8.557
[13/500]: loss: -604.190 || RE: 17.219 || KL: 621.409 || stdn: -8.496
[14/500]: loss: -604.776 || RE: 16.767 || KL: 621.544 || stdn: -8.362
[15/500]: loss: -605.190

KeyboardInterrupt: 

In [None]:
info={'Dataset': dataset,
     'Variant': variant,
     'Partial': partial,
     'Prior': prior,
     'n_instances': n_instances,
     'batch_size': bs,
     'flows': flows,
     'n_epochs': n_epochs,
     'learning_rate': lr,
     'n_weights': n_weights,
     'losses': list(np.asarray(losses).round(2)),
     'RE': list(np.asarray(RElosses).round(2)),
     'stdn': list(np.asarray(stdnlosses).round(2)),
     'KL': list(np.asarray(KLlosses).round(2)),
     'Times': list(np.asarray(timer).round(4))}

pickle.dump(info, open(f'{dataset}/conv/{variant+str(partial)}/{prior+str(flows)}/info.p', 'wb'))
torch.save(model, f'{dataset}/conv/{variant+str(partial)}/{prior+str(flows)}/model.pt')

In [None]:
dataset = 'SVHN'
n_instances = 25000
bs = 256
n_epochs = 500
lr = 0.001

In [None]:
def get_subs():
    subs = 'C'
    if variant == 'RandomNet' and partial == True:
        subs += 'P'
    elif variant == 'RandomNet' and partial == False:
        subs += 'N'
    elif variant == 'RKS' and partial == False:
        subs += 'R'
    else:
        subs += 'L'

    if prior == 'Flow':
        subs += 'F'+str(flows)
    else:
        subs += 'N'
    return subs

In [None]:
for variant, partial in [('Normal', True), ('RandomNet', True), ('RandomNet', False), ('RKS', False)]:
    for prior, flows in [('Flow', 10)]:
        model = torch.load(f'{dataset}/conv/{variant+str(partial)}/{prior+str(flows)}/model.pt')
        subs = get_subs()
        
        m = torch.zeros((8, 8, 32, 2, 2))
        m[0] = model.prior.sample(8)
        m[7] = model.prior.sample(8)
              
        for i in range(1, 7):
            m[i] = m[0] - (((m[0] - m[7]) / 7)*i)
            
        for row in m.permute(1,0,2,3,4):
            x_hat = model.decoder.forward(torch.Tensor(row).to(device))
            
            if "result_" in dir():
                result_ = torch.cat((result_, x_hat))
            else:
                result_ = x_hat 
                
        save_image(result_.detach().cpu(), 
                   f"SVHN_{subs}_interpolation.png", nrow=8)
        del result_