In [6]:
import torch
import torch.optim as optim
import tqdm.notebook as tqdm

from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR


import os
import sys
sys.path.append('/Users/Matt/projects/gmmvae/')
sys.path.append('/Users/Matt/projects/sgpvae/')

import gmmvae
import entrovae





## Data preparation.

In [7]:
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST('../data', train=True, download=True,
                               transform=transform)
test_dataset = datasets.MNIST('../data', train=False,
                              transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

## Classifier training.

In [8]:
def train_cls(model, train_loader, optimiser, epoch):
    model.train()
    batch_iter = tqdm.tqdm(enumerate(train_loader), desc='Batch')
    for batch_idx, (x, y) in batch_iter:
        optimiser.zero_grad()
        loss, _ = model.nll(x, y)
        loss.backward()
        optimiser.step()
        
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(x), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            
def test_cls(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            loss, output = model.nll(x, y)
            test_loss += loss.item()
            
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(y.view_as(pred)).sum().item()
            
    test_loss /= len(test_loader.dataset)
    
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [9]:
cls = gmmvae.classifiers.MNISTClassificationNet()
cls_optimiser = optim.Adadelta(cls.parameters(), lr=1.0)
cls_scheduler = StepLR(cls_optimiser, step_size=1, gamma=0.7)

In [10]:
for epoch in range(1, 14+1):
    train_cls(cls, train_loader, cls_optimiser, epoch)
    test_cls(cls, test_loader)
    cls_scheduler.step()

HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…



Test set: Average loss: 0.0000, Accuracy: 9845/10000 (98%)



HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…




KeyboardInterrupt: 

## Set GMMVAE datasets. 

In [11]:
pred_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=False)
cls_output = torch.zeros(len(train_dataset), 10)
with torch.no_grad():
    for batch_idx, (x, y) in enumerate(pred_loader):
        cls_output[batch_idx*64:(batch_idx+1)*64] = cls(x).detach().exp()

In [12]:
class GMMVAEDataset(torch.utils.data.Dataset):
    def __init__(self, x, cls_output):
        self.x = x
        self.cls_output = cls_output
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        return self.x[idx, ...], self.cls_output[idx]
    
gmmvae_dataset = GMMVAEDataset(train_dataset.data, cls_output)
gmmvae_loader = torch.utils.data.DataLoader(gmmvae_dataset, batch_size=64)

In [13]:
from torch import autograd

def train_gmmvae(model, loader, optimiser, epoch):
    model.train()
    train_loss = 0
    batch_iter = tqdm.tqdm(enumerate(loader), desc='Batch')
    
    for batch_idx, (x, pi) in batch_iter:
        optimiser.zero_grad()
        
        x = x.view(-1, 784).float() / 255
        
        loss = -model.elbo(x, pi)
        loss.backward()
        
        for name, param in model.named_parameters():
            if (param.grad != param.grad).any():
                pdb.set_trace()
                print('wtf')
        
        optimiser.step()
        
        train_loss += loss.item()        
        batch_iter.set_postfix(loss=loss.item())
            
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(loader.dataset)))

In [14]:
z_dim = 2
x_dim = 784

encoder = sgpvae.networks.LinearGaussian(x_dim, z_dim, [512, 256], min_sigma=1e-3)
loglikelihood = gmmvae.loglikelihoods.NNBernoulli(z_dim, x_dim, [256, 512])

gmmvae_model = gmmvae.models.GMMVAE(encoder, loglikelihood, z_dim, 10)

## GMMVAE training.

In [15]:
gmmvae_optimiser = optim.Adam(gmmvae_model.parameters())
for epoch in range(1, 14+1):
    train_gmmvae(gmmvae_model, gmmvae_loader, gmmvae_optimiser, epoch)

HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 1 Average loss: 2.8252


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 2 Average loss: 2.4922


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 3 Average loss: 2.3915


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 4 Average loss: 2.3333


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 5 Average loss: 2.2953


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 6 Average loss: 2.2684


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 7 Average loss: 2.2488


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…




KeyboardInterrupt: 

In [19]:
if not os.path.exists('./samples'):
    os.makedirs('./samples')

with torch.no_grad():            
    sample = gmmvae_model.sample(
        pi=torch.tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        num_samples=100)
    
    filename = './samples/gmmvae_sample'
    
    if os.path.exists(filename + '.png'):
        i = 1
        while os.path.exists(filename + '_' + str(i) + '.png'):
            i += 1
            
        filename = filename + '_' + str(i) + '.png'
        
    else:
        filename = filename + '.png'
    
    save_image(sample.view(100, 1, 28, 28), filename)