In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tqdm.notebook as tqdm
import copy
import os

from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR


import sys
sys.path.append('/Users/Matt/projects/sgpvae/')
sys.path.append('/Users/Matt/projects/entrovae/')

import sgpvae
import entrovae

## Data preparation.

In [2]:
# ToTensor is the only transform applied (no normalisation).
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(
    '../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(
    '../data', train=False, transform=transform)

# Training batches are shuffled every epoch; the test order stays fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1000, shuffle=False)

## Classifier training.

In [3]:
def train_cls(model, train_loader, optimiser, epoch, log_interval=10):
    """Run one training epoch of a classifier.

    Args:
        model: classifier exposing ``nll(x, y)`` that returns ``(loss, output)``;
            ``loss`` is back-propagated directly.
        train_loader: iterable of ``(x, y)`` batches with a ``dataset`` attribute
            and a ``len`` (number of batches), e.g. a DataLoader.
        optimiser: optimiser over ``model``'s parameters.
        epoch: epoch number, used only in the log lines.
        log_interval: print a progress line every ``log_interval`` batches.
    """
    model.train()
    num_batches = len(train_loader)
    # Passing ``total`` lets tqdm draw a real progress bar instead of an
    # indeterminate one (enumerate() itself has no length).
    batch_iter = tqdm.tqdm(enumerate(train_loader), total=num_batches, desc='Batch')
    for batch_idx, (x, y) in batch_iter:
        optimiser.zero_grad()
        loss, _ = model.nll(x, y)
        loss.backward()
        optimiser.step()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(x), len(train_loader.dataset),
                100. * batch_idx / num_batches, loss.item()))
            
def test_cls(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            loss, output = model.nll(x, y)
            test_loss += loss.item()
            
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(y.view_as(pred)).sum().item()
            
    test_loss /= len(test_loader.dataset)
    
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [4]:
# Classifier optimised with Adadelta (lr=1.0); StepLR multiplies the learning
# rate by 0.7 each time cls_scheduler.step() is called.
cls = entrovae.classifiers.MNISTClassificationNet()
cls_optimiser = optim.Adadelta(params=cls.parameters(), lr=1.0)
cls_scheduler = StepLR(optimizer=cls_optimiser, step_size=1, gamma=0.7)

In [5]:
# Epochs 1-14 inclusive: train, evaluate on the test set, then decay the LR.
for epoch in range(1, 15):
    train_cls(cls, train_loader, cls_optimiser, epoch)
    test_cls(cls, test_loader)
    cls_scheduler.step()

HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…




KeyboardInterrupt: 

## Set up EntroVAE and VAE datasets.
* For ***EntroVAE***, each image is paired with the classifier's predictive entropy.
* For ***VAE***, the dataset is simply the images.

In [6]:
# Predictive entropy of the trained classifier for every training image, in
# dataset order (shuffle=False) so entropies[k] belongs to train_dataset[k].
# NOTE(review): cls(data) is assumed to return log-probabilities (its output is
# exponentiated to obtain probabilities) — TODO confirm against entrovae.classifiers.
cls.eval()  # switch off train-mode layers (e.g. dropout) before predicting

pred_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=False)
log_prob_batches = []
with torch.no_grad():
    for data, _ in pred_loader:
        log_prob_batches.append(cls(data))

log_probs = torch.cat(log_prob_batches)
probs = log_probs.exp()

# H[p] = -sum_k p_k log p_k, using the log-probabilities directly: a class whose
# probability underflows to 0 then contributes 0 instead of 0 * log(0) = nan.
entropies = -(probs * log_probs).sum(1)

In [5]:
class EntroVAEDataset(torch.utils.data.Dataset):
    """Dataset pairing each data item with its entropy value.

    Indexing returns ``(data[idx, ...], entropy[idx])``.

    Args:
        data: items indexable along their first axis (e.g. an image tensor).
        entropy: one value per item; must have the same length as ``data``.

    Raises:
        ValueError: if ``data`` and ``entropy`` differ in length.
    """

    def __init__(self, data, entropy):
        # Catch misaligned inputs up front rather than as an IndexError (or a
        # silently truncated dataset) somewhere in the middle of training.
        if len(data) != len(entropy):
            raise ValueError(
                'data and entropy must have the same length, got {} and {}'.format(
                    len(data), len(entropy)))
        self.data = data
        self.entropy = entropy

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Tensor indices are converted to plain Python ints/lists before indexing.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return self.data[idx, ...], self.entropy[idx]
    
# EntroVAE data: each raw training image paired with the classifier's predictive
# entropy (`entropies`, computed above). The EntroVAE training cell iterates
# `entrovae_loader`, so it must be defined here for Restart & Run All to work.
entrovae_dataset = EntroVAEDataset(train_dataset.data, entropies)
entrovae_loader = torch.utils.data.DataLoader(entrovae_dataset, batch_size=64, shuffle=True)

# VAE data: the raw training images only (train_vae rescales them to [0, 1]).
# Shuffle so each epoch sees differently composed mini-batches.
vae_loader = torch.utils.data.DataLoader(train_dataset.data, batch_size=64, shuffle=True)

In [6]:
def _flatten_pixels(x):
    """Flatten an image batch to a (-1, 784) float32 tensor scaled to [0, 1].

    ``uint8`` input (e.g. raw ``MNIST.data``) is divided by 255. Floating-point
    input is assumed to be scaled already (as ``transforms.ToTensor`` output is)
    and is only cast.
    """
    flat = x.reshape(-1, 784).float()
    return flat / 255 if x.dtype == torch.uint8 else flat


def train_vae(model, train_loader, optimiser, epoch):
    """Run one training epoch of a VAE or EntroVAE by maximising ``model.elbo``.

    Each batch is either an image tensor (plain VAE, ``model.elbo(x)``) or an
    ``(x, h)`` pair whose ``h`` is passed as the second argument of
    ``model.elbo`` (EntroVAE).

    Args:
        model: module exposing ``elbo(x)`` / ``elbo(x, h)``.
        train_loader: iterable of batches with ``len`` and a ``dataset`` attribute.
        optimiser: optimiser over ``model``'s parameters.
        epoch: epoch number, used only in the printed summary.

    Returns:
        The summed batch losses divided by ``len(train_loader.dataset)``; this is
        also the value printed as the epoch's average loss.
    """
    model.train()
    train_loss = 0
    batch_iter = tqdm.tqdm(enumerate(train_loader), total=len(train_loader), desc='Batch')

    for batch_idx, batch in batch_iter:
        optimiser.zero_grad()

        if isinstance(batch, (tuple, list)):
            x, h = batch
            loss = -model.elbo(_flatten_pixels(x), h)
        else:
            loss = -model.elbo(_flatten_pixels(batch))

        loss.backward()
        train_loss += loss.item()
        optimiser.step()

        if batch_idx % 10 == 0:
            batch_iter.set_postfix(loss=loss.item())

    # NOTE(review): if model.elbo returns a per-batch *mean*, this "average" is
    # smaller than the per-example loss by roughly the batch size. TODO confirm
    # elbo's reduction and weight each batch by its size if needed.
    avg_loss = train_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))
    return avg_loss

In [14]:
# Latent size and flattened image size (28 * 28 pixels).
z_dim = 2
x_dim = 784

# Shared encoder / likelihood; the lists are hidden-layer sizes.
encoder = sgpvae.networks.LinearGaussian(x_dim, z_dim, [512, 256], min_sigma=1e-3)
loglikelihood = entrovae.loglikelihoods.NNBernoulli(z_dim, x_dim, [256, 512])

# Each model receives its own deep copies, so both start from the same weights
# but are trained independently.
vae_model, entrovae_model = (
    model_cls(copy.deepcopy(encoder), copy.deepcopy(loglikelihood), z_dim)
    for model_cls in (entrovae.models.VAE, entrovae.models.EntroVAE)
)

## VAE training.

In [15]:
# Plain VAE: Adam with its default learning rate, epochs 1-14 inclusive.
vae_optimiser = optim.Adam(params=vae_model.parameters())
for epoch in range(1, 15):
    train_vae(vae_model, vae_loader, vae_optimiser, epoch)

HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 1 Average loss: 2.8210


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 2 Average loss: 2.4883


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 3 Average loss: 2.3919


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 4 Average loss: 2.3362


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 5 Average loss: 2.2910


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 6 Average loss: 2.2609


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 7 Average loss: 2.2357


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 8 Average loss: 2.2194


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…




KeyboardInterrupt: 

In [17]:
# Decode a 10x10 grid of latent points covering [-2, 2]^2 with the trained VAE.
# torch.cartesian_prod orders the rows so that image grid_size*i + j is decoded
# from z = (grid[i], grid[j]).
grid_size = 10
grid = torch.linspace(-2, 2, grid_size, dtype=torch.float32)
with torch.no_grad():
    z = torch.cartesian_prod(grid, grid)
    sample = vae_model.loglikelihood.predict(z)

# Use the first unused name (vae_sample.png, vae_sample_1.png, ...) so earlier
# samples are never overwritten; create the output directory if it is missing.
os.makedirs('./samples', exist_ok=True)
base = './samples/vae_sample'
filename = base + '.png'
suffix = 1
while os.path.exists(filename):
    filename = '{}_{}.png'.format(base, suffix)
    suffix += 1

# nrow=grid_size keeps the latent grid square: each image row holds grid[i]
# fixed (save_image would otherwise lay the images out 8 per row).
save_image(sample.view(grid_size ** 2, 1, 28, 28), filename, nrow=grid_size)

## EntroVAE training.

In [288]:
# EntroVAE: same optimiser and 14-epoch budget as the plain VAE. Note that
# `entrovae` is the imported package; the model instance is `entrovae_model`.
entrovae_optimiser = optim.Adam(entrovae_model.parameters())
for epoch in range(1, 14 + 1):
    train_vae(entrovae_model, entrovae_loader, entrovae_optimiser, epoch)

HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 1 Average loss: 2.8167


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 2 Average loss: 2.5384


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 3 Average loss: 2.4656


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 4 Average loss: 2.4169


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 5 Average loss: 2.3847


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 6 Average loss: 2.3670


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 7 Average loss: 2.3474


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 8 Average loss: 2.3300


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 9 Average loss: 2.3171


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 10 Average loss: 2.3078


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 11 Average loss: 2.2985


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 12 Average loss: 2.2943


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 13 Average loss: 2.2823


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 14 Average loss: 2.2796


In [289]:
# Decode a 10x10 grid of latent points covering [-2, 2]^2 with the trained
# EntroVAE, exactly as for the plain VAE. `entrovae` is the imported package;
# the trained model is `entrovae_model`.
# torch.cartesian_prod orders the rows so that image grid_size*i + j is decoded
# from z = (grid[i], grid[j]).
grid_size = 10
grid = torch.linspace(-2, 2, grid_size, dtype=torch.float32)
with torch.no_grad():
    z = torch.cartesian_prod(grid, grid)
    sample = entrovae_model.loglikelihood.predict(z)

# Use the first unused name (entrovae_sample.png, entrovae_sample_1.png, ...) so
# earlier samples are never overwritten; create the output directory if missing.
os.makedirs('./samples', exist_ok=True)
base = './samples/entrovae_sample'
filename = base + '.png'
suffix = 1
while os.path.exists(filename):
    filename = '{}_{}.png'.format(base, suffix)
    suffix += 1

# nrow=grid_size keeps the latent grid square: each image row holds grid[i]
# fixed (save_image would otherwise lay the images out 8 per row).
save_image(sample.view(grid_size ** 2, 1, 28, 28), filename, nrow=grid_size)

## Compare disentanglement of the latent spaces?
* **TODO:** it is not yet clear which metric to use here.

In [124]:
class ClassificationNet(nn.Module):
    """Small CNN over single-channel images that returns log-probabilities.

    Two 3x3 convolutions with ReLU, 2x2 max-pooling, then two fully connected
    layers, with dropout after pooling and after the first linear layer. The
    output is ``log_softmax`` over 10 classes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12, i.e. a 28x28 input after two unpadded
        # 3x3 convolutions (28 -> 26 -> 24) and 2x2 pooling (24 -> 12).
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        features = self.dropout1(F.max_pool2d(features, 2))
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

In [286]:
class VAE(nn.Module):
    """Base VAE with a diagonal-Gaussian latent posterior.

    Args:
        encoder: callable mapping ``x`` to ``(qz_mu, qz_sigma)``; presumably each
            of shape (batch, latent_dim) — TODO confirm.
        decoder: callable applied to latent samples of shape
            (num_samples, latent_dim).
        latent_dim: dimensionality of the latent space.
    """

    def __init__(self, encoder, decoder, latent_dim):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.latent_dim = latent_dim

    def latent_posterior(self, x):
        """Return ``(qz_mu, qz_sigma)`` with their first two axes swapped.

        For a (batch, latent_dim) encoder output the result is (latent_dim, batch).
        """
        qz_mu, qz_sigma = self.encoder(x)
        return qz_mu.transpose(0, 1), qz_sigma.transpose(0, 1)

    def sample_prior(self, num_samples=1):
        """Draw ``num_samples`` latents from N(0, I) and return their decoding."""
        z = torch.randn(num_samples, self.latent_dim)

        return self.decoder(z)

    def sample_posterior(self, x, num_samples=1, h=None):
        """Draw reparameterised samples from the latent posterior.

        Args:
            x: input batch passed to ``latent_posterior``.
            num_samples: number of independent samples.
            h: optional extra conditioning, forwarded as the second argument of
                ``latent_posterior`` (needed by subclasses such as EntroVAE whose
                ``latent_posterior`` takes ``(x, h)``); ignored when ``None``.

        Returns:
            Tensor of shape ``(num_samples, *qz_mu.shape)``.
        """
        if h is None:
            qz_mu, qz_sigma = self.latent_posterior(x)
        else:
            qz_mu, qz_sigma = self.latent_posterior(x, h)

        # Draw the noise on the posterior's device/dtype so this also works when
        # the model and data live on a GPU.
        eps = torch.randn(num_samples, *qz_mu.shape,
                          device=qz_mu.device, dtype=qz_mu.dtype)

        return qz_mu + qz_sigma * eps


class EntroVAE(VAE):
    """VAE whose posterior standard deviation is driven by a per-example ``h``.

    The encoder supplies only the posterior mean. The standard deviation for
    example ``b`` and latent dimension ``d`` is
    ``h[b] * exp(log_scale[d]) + min_sigma``.

    Args:
        encoder, decoder, latent_dim: as for ``VAE``.
        init_scale: initial value of ``exp(log_scale)``; must be positive.
        min_sigma: constant added to every standard deviation to keep it away
            from zero (default 1e-3).

    Raises:
        ValueError: if ``init_scale`` is not positive.
    """

    def __init__(self, encoder, decoder, latent_dim, init_scale=1., min_sigma=1e-3):
        super().__init__(encoder, decoder, latent_dim)

        if init_scale <= 0:
            raise ValueError('init_scale must be positive, got {}'.format(init_scale))

        # Learned per-dimension log scale, initialised to log(init_scale). The log
        # is taken with torch, so no numpy import is required.
        self.log_scale = nn.Parameter(torch.full((latent_dim,), float(init_scale)).log())
        self.min_sigma = min_sigma

    def latent_posterior(self, x, h):
        """Return ``(qz_mu, qz_sigma)`` with their first two axes swapped.

        Args:
            x: input batch; only the first output of ``self.encoder(x)`` (the
                mean) is used.
            h: per-example values scaling the standard deviation; presumably of
                shape (batch,) — TODO confirm.
        """
        qz_mu = self.encoder(x)[0]

        # Outer product via broadcasting: (batch, 1) * (latent_dim,) gives
        # (batch, latent_dim) with entry [b, d] = h[b] * exp(log_scale[d]).
        # This is the same as the matrix product (batch, 1) @ (1, latent_dim).
        qz_sigma = h.unsqueeze(1) * self.log_scale.exp() + self.min_sigma

        return qz_mu.transpose(0, 1), qz_sigma.transpose(0, 1)

In [157]:
def loss_function(recon_x, x, mu, log_var):
    """Negative ELBO for a Bernoulli-output VAE with a standard-normal prior.

    Args:
        recon_x: reconstruction probabilities with shape (batch, 784).
        x: targets, reshaped to (batch, 784) before comparison.
        mu, log_var: mean and log-variance of the diagonal Gaussian posterior.

    Returns:
        Summed binary cross-entropy plus the analytic KL(N(mu, exp(log_var)) || N(0, I)).
    """
    reconstruction = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # Closed-form Gaussian KL against N(0, 1), summed over every entry.
    kl_divergence = 0.5 * torch.sum(mu.pow(2) + log_var.exp() - log_var - 1)
    return reconstruction + kl_divergence

In [265]:
def vae_elbo(model, x, h=None, num_samples=1):
    """Monte-Carlo estimate of the ELBO for a Bernoulli-output VAE / EntroVAE.

    Args:
        model: exposes ``latent_posterior(x)`` (or ``latent_posterior(x, h)``)
            returning ``(qz_mu, qz_sigma)`` laid out (latent_dim, batch), and a
            ``decoder`` whose output on ``z.T`` is treated as Bernoulli logits.
        x: targets in [0, 1] with the same shape as the decoder output.
        h: optional second argument for ``latent_posterior``; when ``None`` it
            is called with ``x`` only.
        num_samples: number of reparameterised samples for the reconstruction
            term; must be at least 1.

    Returns:
        Scalar tensor: the sample-averaged log p(x|z) minus KL(q(z) || N(0, I)),
        both summed over the batch.

    Raises:
        ValueError: if ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError('num_samples must be at least 1, got {}'.format(num_samples))

    if h is None:
        qz_mu, qz_sigma = model.latent_posterior(x)
    else:
        qz_mu, qz_sigma = model.latent_posterior(x, h)

    # log p(x|z) term, averaged over samples z ~ q(z).
    px_z_term = 0
    for _ in range(num_samples):
        z = qz_mu + qz_sigma * torch.randn_like(qz_mu)
        # The decoder output is used as logits. The fused BCE-with-logits is
        # numerically stable, unlike sigmoid followed by binary_cross_entropy,
        # whose log terms are clamped at -100 once the sigmoid saturates.
        logits = model.decoder(z.T)
        px_z_term = px_z_term - F.binary_cross_entropy_with_logits(logits, x, reduction='sum')
    px_z_term = px_z_term / num_samples

    # KL(q(z) || p(z)) against a standard-normal prior.
    pz_mu = torch.zeros_like(qz_mu)
    pz_var = torch.ones_like(qz_sigma)
    kl_term = gaussian_diagonal_kl(qz_mu, qz_sigma.pow(2), pz_mu, pz_var).sum()

    return px_z_term - kl_term

In [209]:
def gaussian_diagonal_kl(m1, v1, m2, v2):
    """Elementwise KL(N(m1, v1) || N(m2, v2)) between Gaussians, summed along axis 1.

    Args:
        m1, v1: mean and variance of the first Gaussian.
        m2, v2: mean and variance of the second Gaussian (broadcastable).

    Returns:
        Tensor of KL values with axis 1 reduced by summation.
    """
    squared_gap = (m1 - m2) ** 2
    log_ratio = torch.log(v2 / v1)
    per_element = 0.5 * (log_ratio + (v1 + squared_gap) / v2 - 1)
    return per_element.sum(1)