In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tqdm.notebook as tqdm
import copy
import os

from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR


import sys
sys.path.append('/Users/Matt/projects/entrovae/')
sys.path.append('/Users/Matt/projects/sgpvae/')

import sgpvae
import entrovae

## Data preparation.

In [2]:
transform=transforms.Compose([
        transforms.ToTensor(),
        ])

train_dataset = datasets.MNIST('../data', train=True, download=True,
                               transform=transform)
test_dataset = datasets.MNIST('../data', train=False,
                              transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

## Classifier training.

In [3]:
def train_cls(model, train_loader, optimiser, epoch):
    model.train()
    batch_iter = tqdm.tqdm(enumerate(train_loader), desc='Batch')
    for batch_idx, (x, y) in batch_iter:
        optimiser.zero_grad()
        loss, _ = model.nll(x, y)
        loss.backward()
        optimiser.step()
        
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(x), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            
def test_cls(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            loss, output = model.nll(x, y)
            test_loss += loss.item()
            
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(y.view_as(pred)).sum().item()
            
    test_loss /= len(test_loader.dataset)
    
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [4]:
cls = entrovae.classifiers.MNISTClassificationNet()
cls_optimiser = optim.Adadelta(cls.parameters(), lr=1.0)
cls_scheduler = StepLR(cls_optimiser, step_size=1, gamma=0.7)

In [5]:
for epoch in range(1, 14+1):
    train_cls(cls, train_loader, cls_optimiser, epoch)
    test_cls(cls, test_loader)
    cls_scheduler.step()

HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…



Test set: Average loss: 0.0001, Accuracy: 9836/10000 (98%)



HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…



Test set: Average loss: 0.0000, Accuracy: 9871/10000 (99%)



HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…




KeyboardInterrupt: 

## Set up EntroVAE and VAE datasets. 
* For ***EntroVAE*** use entropies as well as the data.
* For ***VAE***, dataset is simply the images.

In [6]:
pred_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=False)
probs = torch.zeros(len(train_dataset), 10)
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(pred_loader):
        probs[batch_idx*64:(batch_idx+1)*64] = cls(data).detach().exp()
        
entropies = - (probs * probs.log()).sum(1)

In [7]:
class EntroVAEDataset(torch.utils.data.Dataset):
    def __init__(self, data, entropy):
        self.data = data
        self.entropy = entropy
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        return self.data[idx, ...], self.entropy[idx]
    
entrovae_dataset = EntroVAEDataset(train_dataset.data, entropies)
entrovae_loader = torch.utils.data.DataLoader(entrovae_dataset, batch_size=64)
vae_loader = torch.utils.data.DataLoader(train_dataset.data, batch_size=64)

In [8]:
def train_vae(model, train_loader, optimiser, epoch):
    model.train()
    train_loss = 0
    batch_iter = tqdm.tqdm(enumerate(train_loader), desc='Batch')
    
    for batch_idx, x in batch_iter:
        optimiser.zero_grad()
        
        if isinstance(x, tuple) or isinstance(x, list):
            x, h = x
            x = x.view(-1, 784).float() / 255
            loss = -model.elbo(x, h)
        else:
            x = x.view(-1, 784).float() / 255
            loss = -model.elbo(x)
            
        loss.backward()
        train_loss += loss.item()
        optimiser.step()
        
        if batch_idx % 10 == 0:
            batch_iter.set_postfix(loss=loss.item())
            
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))

In [9]:
z_dim = 2
x_dim = 784

encoder = sgpvae.networks.LinearGaussian(x_dim, z_dim, [512, 256], min_sigma=1e-3)
loglikelihood = entrovae.loglikelihoods.NNBernoulli(z_dim, x_dim, [256, 512])

vae_model = entrovae.models.VAE(copy.deepcopy(encoder), copy.deepcopy(loglikelihood), z_dim)
entrovae_model = entrovae.models.EntroVAE(copy.deepcopy(encoder), copy.deepcopy(loglikelihood), z_dim)

## VAE training.

In [10]:
vae_optimiser = optim.Adam(vae_model.parameters())
for epoch in range(1, 14+1):
    train_vae(vae_model, vae_loader, vae_optimiser, epoch)

HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 1 Average loss: 2.8252


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 2 Average loss: 2.4959


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 3 Average loss: 2.3994


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 4 Average loss: 2.3392


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 5 Average loss: 2.2988


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 6 Average loss: 2.2699


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 7 Average loss: 2.2503


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 8 Average loss: 2.2311


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 9 Average loss: 2.2147


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…




KeyboardInterrupt: 

In [11]:
with torch.no_grad():
    z = torch.zeros(100, 2)
    for i, x in enumerate(torch.linspace(-2, 2, 10, dtype=torch.float32)):
        for j, y in enumerate(torch.linspace(-2, 2, 10, dtype=torch.float32)):
            idx = 10 * i + j
            z[idx, :] = torch.tensor([x, y])
            
    sample = vae_model.loglikelihood.predict(z)
    
    filename = './samples/vae_sample'
    
    if os.path.exists(filename + '.png'):
        i = 1
        while os.path.exists(filename + '_' + str(i) + '.png'):
            i += 1
            
        filename = filename + '_' + str(i) + '.png'
        
    else:
        filename = filename + '.png'
    
    save_image(sample.view(100, 1, 28, 28), filename)

## EntroVAE training.

In [12]:
entrovae_optimiser = optim.Adam(entrovae_model.parameters())
for epoch in range(1, 14+1):
    train_vae(entrovae_model, entrovae_loader, entrovae_optimiser, epoch)

HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 1 Average loss: 2.8848


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 2 Average loss: 2.5583


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 3 Average loss: 2.4653


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 4 Average loss: 2.4169


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 5 Average loss: 2.3779


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 6 Average loss: 2.3509


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 7 Average loss: 2.3289


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 8 Average loss: 2.3121


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 9 Average loss: 2.2967


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 10 Average loss: 2.2898


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 11 Average loss: 2.2825


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…




KeyboardInterrupt: 

In [14]:
with torch.no_grad():
    z = torch.zeros(100, 2)
    for i, x in enumerate(torch.linspace(-2, 2, 10, dtype=torch.float32)):
        for j, y in enumerate(torch.linspace(-2, 2, 10, dtype=torch.float32)):
            idx = 10 * i + j
            z[idx, :] = torch.tensor([x, y])
            
    sample = entrovae_model.loglikelihood.predict(z)
    
    filename = './samples/entrovae_sample'
    
    if os.path.exists(filename + '.png'):
        i = 1
        while os.path.exists(filename + '_' + str(i) + '.png'):
            i += 1
            
        filename = filename + '_' + str(i) + '.png'
        
    else:
        filename = filename + '.png'
    
    save_image(sample.view(100, 1, 28, 28), filename)

## Compare disentalglement of latent space?
* Not sure what to use here tbh.