# Semi-Supervised Deep Learning project
This notebook contains the Python code for the final project by the group "Uffe & Axel" in the 02456 Deep Learning course autumn 2020.

## Core

**IMPORTS**

In [None]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import random
import numpy as np
import torch
from torch import Tensor, nn
from torchvision.datasets import MNIST
from mpl_toolkits.axes_grid1 import ImageGrid
from torchvision.transforms import ToTensor
from torch.distributions import Distribution
from torch.distributions import Bernoulli
from typing import Dict, Any
from torchvision import transforms
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler
from collections import defaultdict
from IPython.display import Image, display, clear_output
import statistics
import math
import os.path
import socket
from torch.distributions import Categorical
# Static random seeds. Every RNG the notebook draws from is seeded:
# Python's `random` picks the labelled/unlabelled/validation subsets (label_indices,
# loader_setup), torch drives sampler shuffling, Bernoulli binarization and network
# initialisation, and NumPy is seeded for any numpy-based sampling.
SEED = 89
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

**LAMBDAS**

In [None]:
def pil2tensor(image):
    """Convert a PIL image to a tensor and drop its singleton axes.

    The DataLoader cannot collate PIL image objects, so each sample is converted to a
    tensor first. For a single grey-scale MNIST image ToTensor() yields (1, H, W);
    squeeze() removes the channel axis, leaving (H, W). Batching happens later, in the
    DataLoader.
    """
    return ToTensor()(image).squeeze()


def binarize(image):
    """Stochastically binarize an image: each pixel becomes 1 with probability equal to its intensity."""
    return torch.bernoulli(image)

**DATASETS**

In [None]:
# Grey-scale MNIST as (28, 28) float tensors; train split first, then test split.
mnist_train_data, mnist_test_data = (
    MNIST("./temp/", transform=pil2tensor, download=True, train=is_train)
    for is_train in (True, False)
)

# Stochastically binarized MNIST: the Bernoulli draw in `binarize` is part of the
# transform, so it is applied each time a sample is fetched.
binarized_mnist_train_data, binarized_mnist_test_data = (
    MNIST("./temp/",
          download=True,
          train=is_train,
          transform=transforms.Compose([pil2tensor, binarize]))
    for is_train in (True, False)
)

**DATALOADERS**

In [None]:
def label_indices(dataset, total_labels, num_classes=10):
    """Return a random, class-balanced list of indices into ``dataset``.

    For each class 0 .. num_classes-1, ``int(total_labels / num_classes)`` indices are
    drawn without replacement (with Python's ``random.sample``) from the samples carrying
    that label. The result is grouped by class, in increasing class order.

    Args:
        dataset: object with a ``targets`` tensor of integer class labels
            (e.g. a torchvision MNIST dataset).
        total_labels: total number of indices wanted; any remainder after dividing by
            ``num_classes`` is dropped.
        num_classes: number of classes, labelled 0 .. num_classes-1 (default 10).

    Raises:
        ValueError: from ``random.sample`` if a class has fewer samples than requested.
    """
    targets = dataset.targets.numpy()  # convert once, not once per class
    per_class = int(total_labels / num_classes)
    idx_list = []
    for target in range(num_classes):
        class_idx = np.flatnonzero(targets == target)
        idx_list += random.sample(list(class_idx), k=per_class)
    return idx_list
# Sanity check: with one label per class, the returned indices are ordered by digit 0..9.
for expected_digit, sample_idx in enumerate(label_indices(binarized_mnist_train_data, 10)):
    actual_digit = binarized_mnist_train_data[sample_idx][1]
    assert expected_digit == actual_digit

In [None]:
def loader_setup(labelled_size = 10000,
                 unlabelled_size = None,    # None: 50k - lbl_size
                 validation_size = 10000,
                 test_size = None,          # None: use all 10k
                 batch_size = 64,
                 test_batch_size = 64):
    '''Set up all data loaders, providing labelled, unlabelled, validation & test samples.

    The loaders are published as module-level globals:
    ``binarized_mnist_train_loader_labelled``, ``binarized_mnist_train_loader_unlabelled``
    (None when ``unlabelled_size == 0``), ``binarized_mnist_train_loader_validation`` and
    ``binarized_mnist_test_loader``.

    The labelled, unlabelled and validation subsets are disjoint index sets drawn from the
    binarized training data. The labelled subset is class-balanced (see ``label_indices``).

    Args:
        labelled_size: number of labelled training samples.
        unlabelled_size: number of unlabelled samples; None means ``50000 - labelled_size``.
        validation_size: number of validation samples taken from the remaining training data.
        test_size: use the first ``test_size`` test images; None means all of them.
        batch_size: batch size of the three training-data loaders.
        test_batch_size: batch size of the test loader (default 64).
    '''
    global binarized_mnist_train_loader_labelled
    global binarized_mnist_train_loader_unlabelled
    global binarized_mnist_train_loader_validation
    global binarized_mnist_test_loader
    if unlabelled_size is None:
        unlabelled_size = 50000 - labelled_size
    indices_train = np.arange(len(binarized_mnist_train_data)) # 60000

    # Draw the three subsets so that none of them overlap.
    labelled_idx = label_indices(binarized_mnist_train_data, labelled_size)
    unlabelled_idx = random.sample(list(np.setdiff1d(indices_train, labelled_idx)),
                                   k=unlabelled_size)
    validation_idx = random.sample(list(np.setdiff1d(indices_train,
                                                     np.concatenate((labelled_idx, unlabelled_idx)))),
                                   k=validation_size)

    # Last, generate the dataloaders.
    #
    # There is no need to pass 'shuffle=True' to DataLoader: SubsetRandomSampler draws a
    # new permutation of its indices every time the loader is iterated (every epoch).
    binarized_mnist_train_loader_labelled = DataLoader(binarized_mnist_train_data,
                                                       batch_size = batch_size,
                                                       sampler = SubsetRandomSampler(labelled_idx))
    if unlabelled_size == 0:
        # The M2 trainer skips unlabelled training if train_loader is 'None'
        binarized_mnist_train_loader_unlabelled = None
    else:
        binarized_mnist_train_loader_unlabelled = DataLoader(binarized_mnist_train_data,
                                                             batch_size = batch_size,
                                                             sampler = SubsetRandomSampler(unlabelled_idx))
    binarized_mnist_train_loader_validation = DataLoader(binarized_mnist_train_data,
                                                         batch_size = batch_size,
                                                         sampler = SubsetRandomSampler(validation_idx))
    if test_size is None:
        test_size = len(binarized_mnist_test_data)  # 10000
    indices_test = np.arange(test_size)
    binarized_mnist_test_loader = DataLoader(binarized_mnist_test_data,
                                             batch_size = test_batch_size,
                                             sampler = SubsetRandomSampler(indices_test))

In [None]:
# Tiny labelled set (one image per digit) in batches of 5, for the visual check below.
loader_setup(batch_size=5, labelled_size=10)

In [None]:
# QUESTION: Can we rely on DataLoader to visit every contained sample over the course
# of one epoch when using SubsetRandomSampler?

# ANSWER: yes. Here we check that restarting the data loader gives the same pictures in a
# new order, and that all pictures in the set are visited. Run this cell a couple of
# times to verify.
for images, labels in binarized_mnist_train_loader_labelled:
    # One column per image in the batch, so a short final batch cannot index past its end.
    fig, axs = plt.subplots(1, len(images), figsize=(6, 3), squeeze=False)
    for ax, image, label in zip(axs.flat, images, labels):
        ax.imshow(image, cmap='gray')
        ax.set_title("%s" % (label.item()))
        ax.axis('off')
    plt.tight_layout()

In [None]:
# Full split: 10k labelled, 40k unlabelled and 10k validation images (all 60k training images).
loader_setup(validation_size=10000,
             unlabelled_size=40000,
             labelled_size=10000)

In [None]:
# Non-binarized dataloader: every training image (60000), in random order, 64 per batch.
indices_train = [*range(len(mnist_train_data))]
mnist_train_loader = DataLoader(
    mnist_train_data,
    batch_size=64,
    sampler=SubsetRandomSampler(indices_train),
)

**REPARAMETERIZED DIAGONAL GAUSSIAN**

In [None]:
# Implement reparameterized diagonal gaussian

class ReparameterizedDiagonalGaussian(Distribution):
    """Diagonal Gaussian N(mu, diag(sigma^2)) with reparameterized sampling.

    ``rsample`` returns z = mu + sigma * eps with eps ~ N(0, I), so gradients flow back
    into ``mu`` and ``log_sigma``. ``sample`` draws the same way without tracking gradients.

    Args:
        mu: mean tensor.
        log_sigma: log standard deviation, same shape as ``mu``.
    """
    # Tells torch.distributions that this distribution implements rsample().
    has_rsample = True

    def __init__(self, mu: Tensor, log_sigma: Tensor):
        assert mu.shape == log_sigma.shape, f"Tensors `mu` : {mu.shape} and ` log_sigma` : {log_sigma.shape} must be of the same shape"
        # Initialise the Distribution base class so batch_shape/event_shape are defined.
        # Validation is off because this class declares no arg_constraints.
        super().__init__(batch_shape=mu.shape, validate_args=False)
        self.mu = mu
        self.sigma = log_sigma.exp()

    def sample_epsilon(self) -> Tensor:
        """Standard-normal noise with the shape, dtype and device of ``mu``."""
        return torch.empty_like(self.mu).normal_()

    def sample(self) -> Tensor:
        """Draw z without building an autograd graph."""
        with torch.no_grad():
            return self.rsample()

    def rsample(self) -> Tensor:
        """Draw z = mu + sigma * eps (differentiable w.r.t. mu and sigma)."""
        return self.mu + self.sigma*self.sample_epsilon()

    def log_prob(self, z: Tensor) -> Tensor:
        """Element-wise log-density of ``z`` (not summed over dimensions)."""
        return torch.distributions.Normal(loc=self.mu, scale=self.sigma).log_prob(z)

**NETWORK CONSTRUCTOR HELPER**

In [None]:
def FF_NetworkConstructor(layers: [],
                          pre_batchnorm: bool,
                          hidden_batchnorm: bool,
                          hidden_activation,
                          dropout_prob: float,
                          final_activation) -> nn.Sequential:
    """Assemble a feed-forward ``nn.Sequential`` from a list of layer widths.

    ``layers`` is [in, h1, ..., hk, out]. Layout of the result:
      * optional BatchNorm1d over the ``in`` features (``pre_batchnorm``);
      * for every step except the last: Linear -> hidden_activation ->
        [BatchNorm1d if ``hidden_batchnorm``] -> Dropout(dropout_prob);
      * final Linear to ``out``, followed by ``final_activation`` unless it is None.
    The same ``hidden_activation`` module instance is reused in every hidden step.
    """
    def hidden_step(n_in, n_out):
        # Modules for one hidden transition, created in Linear -> BatchNorm -> Dropout order.
        step = [nn.Linear(in_features = n_in, out_features = n_out), hidden_activation]
        if hidden_batchnorm:
            step.append(nn.BatchNorm1d(num_features = n_out))
        step.append(nn.Dropout(p=dropout_prob))
        return step

    modules = [nn.BatchNorm1d(num_features = layers[0])] if pre_batchnorm else []
    for n_in, n_out in zip(layers[:-2], layers[1:-1]):
        modules.extend(hidden_step(n_in, n_out))

    modules.append(nn.Linear(in_features = layers[-2], out_features = layers[-1]))
    if final_activation is not None:
        modules.append(final_activation)
    return nn.Sequential(*modules)

**DATA ANALYSIS SUPPORT**

In [None]:
def confuse_matrix_update(predictions, labels, confusion_matrix):
    """Accumulate (label, prediction) pairs into ``confusion_matrix`` in place.

    Confusion-matrix rows are actual labels and columns are predictions. E.g. row 0
    tells how the classifier handled true '0's: (0, 0) counts correct predictions and
    (0, 1) counts how many times a '0' was classified as a '1'.

    Args:
        predictions: 1-D integer predictions, a torch tensor (any device) or array-like.
        labels: 1-D integer true labels, same kinds as ``predictions``.
        confusion_matrix: numpy array, updated in place.
    """
    def _as_index_array(values):
        # Move tensors to host memory once rather than indexing element by element
        # (which costs one device synchronisation per sample for CUDA tensors).
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        return np.asarray(values, dtype=np.int64).ravel()

    preds = _as_index_array(predictions)
    lbls = _as_index_array(labels)
    count = min(len(preds), len(lbls))  # like zip(): ignore any unmatched tail
    # np.add.at (unlike fancy-index +=) counts repeated (label, prediction) cells correctly.
    np.add.at(confusion_matrix, (lbls[:count], preds[:count]), 1)
def confuse_matrix_accuracy(cm):
    """Share of correct predictions (diagonal over total); 0.0 for an empty matrix."""
    total = cm.sum()
    if total == 0:
        return 0.0
    return cm.trace() / total
# Sanity check of the confusion-matrix helpers on a small hand-computed example:
# each update adds three correct predictions and one '4' misclassified as '6'.
_cm_check = np.zeros((10, 10))
_pred_check = torch.tensor([2, 9, 8, 6])
_lab_check = torch.tensor([2, 9, 8, 4])
confuse_matrix_update(_pred_check, _lab_check, _cm_check)
confuse_matrix_update(_pred_check, _lab_check, _cm_check)
assert _cm_check[4, 6] == 2 and np.trace(_cm_check) == 6 and _cm_check.sum() == 8
assert confuse_matrix_accuracy(_cm_check) == 0.75
del _cm_check, _pred_check, _lab_check

In [None]:
def print_confuse_matrix(VAE, test_loader = None):
    """Print the 10x10 confusion matrix and accuracy of ``VAE.classifier``.

    ``VAE`` is assumed to expose a ``classifier`` network and a ``device`` attribute
    (e.g. "cpu" or "cuda:0"). Images are flattened to 28*28 before classification.
    When ``VAE`` is an ``nn.Module``, it is evaluated in eval mode and its previous
    train/eval mode is restored afterwards.

    Args:
        VAE: model with ``classifier`` and ``device``.
        test_loader: loader of (images, labels); defaults to ``binarized_mnist_test_loader``.
    """
    if test_loader is None:
        test_loader = binarized_mnist_test_loader
    confuse_matrix = np.zeros((10, 10), dtype=int)

    # The classifier may contain dropout and batch-norm layers. In training mode they
    # would inject noise and overwrite the batch-norm running statistics during testing.
    was_training = VAE.training if isinstance(VAE, nn.Module) else None
    if was_training is not None:
        VAE.eval()
    try:
        with torch.no_grad():  # evaluation only: no autograd graph needed
            for images_test, labels_test in test_loader:
                images_test = images_test.to(VAE.device)
                classifications = VAE.classifier(images_test.view(-1, 28*28))
                preds = torch.argmax(classifications, 1).cpu()
                confuse_matrix_update(preds, labels_test, confuse_matrix)
    finally:
        if was_training is not None:
            VAE.train(was_training)

    print(confuse_matrix)
    print("Classifier accuracy: %.3f" % confuse_matrix_accuracy(confuse_matrix))

In [None]:
class Plotter:
    """Collect per-batch losses/predictions and plot per-epoch learning curves.

    Trainers call ``append_train`` / ``append_valid`` / ``append_test`` once per batch and
    ``plot`` once per epoch. ``plot`` folds the buffered batches of each split into one
    loss (mean of the batch losses) and one accuracy, appends them to the history arrays
    (``<split>_losses`` / ``<split>_accuracies``) and redraws the figure. A split with no
    buffered predictions (e.g. a VAE passing ``preds=None``) gets accuracy NaN, and a
    split with no buffered losses gets loss NaN.
    """

    _SPLITS = ('train', 'valid', 'test')

    def __init__(self):
        # Per-epoch history (numpy arrays after the first plot()).
        self.train_losses = []
        self.train_accuracies = []
        self.valid_losses = []
        self.valid_accuracies = []
        self.test_losses = []
        self.test_accuracies = []

        # Per-batch buffers for the current epoch: lists of 1-D numpy chunks, flushed by plot().
        self.train_loss_buffer = []
        self.train_labels_buffer = []
        self.train_preds_buffer = []
        self.valid_loss_buffer = []
        self.valid_labels_buffer = []
        self.valid_preds_buffer = []
        self.test_loss_buffer = []
        self.test_labels_buffer = []
        self.test_preds_buffer = []

    @staticmethod
    def _to_numpy(values) -> np.ndarray:
        """Flatten a tensor (on any device) or array-like into a 1-D numpy array."""
        if isinstance(values, Tensor):
            values = values.detach().cpu().numpy()
        return np.ravel(values)

    def _append(self, split: str, loss: Tensor, preds: Tensor, targets: Tensor) -> None:
        """Buffer one batch of results for ``split`` ('train', 'valid' or 'test')."""
        getattr(self, f'{split}_loss_buffer').append(self._to_numpy(loss))
        if preds is not None and targets is not None:
            getattr(self, f'{split}_labels_buffer').append(self._to_numpy(targets))
            getattr(self, f'{split}_preds_buffer').append(self._to_numpy(preds))

    def _flush(self, split: str) -> None:
        """Fold the buffered batches of ``split`` into its history and clear its buffers."""
        losses = getattr(self, f'{split}_loss_buffer')
        labels = getattr(self, f'{split}_labels_buffer')
        preds = getattr(self, f'{split}_preds_buffer')
        mean_loss = np.mean(np.concatenate(losses)) if losses else np.nan
        if labels:
            accuracy = accuracy_score(np.concatenate(labels), np.concatenate(preds))
        else:
            accuracy = np.nan  # no predictions were supplied for this split
        setattr(self, f'{split}_losses', np.append(getattr(self, f'{split}_losses'), mean_loss))
        setattr(self, f'{split}_accuracies', np.append(getattr(self, f'{split}_accuracies'), accuracy))
        setattr(self, f'{split}_loss_buffer', [])
        setattr(self, f'{split}_labels_buffer', [])
        setattr(self, f'{split}_preds_buffer', [])

    def append_train(self, loss: Tensor, preds: Tensor = None, targets: Tensor = None):
        """Buffer one training batch (loss, and optionally predictions with their targets)."""
        self._append('train', loss, preds, targets)

    def append_valid(self, loss: Tensor, preds: Tensor = None, targets: Tensor = None):
        """Buffer one validation batch (loss, and optionally predictions with their targets)."""
        self._append('valid', loss, preds, targets)

    def append_test(self, loss: Tensor, preds: Tensor = None, targets: Tensor = None):
        """Buffer one test batch (loss, and optionally predictions with their targets)."""
        self._append('test', loss, preds, targets)

    def plot(self):
        """Close the current epoch for every split, redraw the curves and print the latest values."""
        for split in self._SPLITS:
            self._flush(split)

        # Display
        clear_output(wait=True)
        fig, axs = plt.subplots(1, 2, figsize=(10, 5), squeeze=False)
        panels = ((axs[0, 0], 'Loss', 'losses'), (axs[0, 1], 'Accuracy', 'accuracies'))
        curves = (('train', 'Training'), ('valid', 'Validation'), ('test', 'Test'))
        for ax, title, metric in panels:
            ax.set_title(title)
            for split, label in curves:
                ax.plot(getattr(self, f'{split}_{metric}'), label = label)
            ax.legend()

        plt.tight_layout()
        # Render via a temporary PNG file, then display that image.
        tmp_img = "tmp_ae_out.png"
        plt.savefig(tmp_img)
        plt.close(fig)
        display(Image(filename=tmp_img))

        print("Training Loss: %.3f, Validation Loss: %.3f, Test Loss: %.3f" %
              (self.train_losses[-1], self.valid_losses[-1], self.test_losses[-1]))
        print("Training Accuracy: %.3f, Validation Accuracy: %.3f, Test Accuracy: %.3f" %
              (self.train_accuracies[-1], self.valid_accuracies[-1], self.test_accuracies[-1]))

## Vanilla FFNN classifier

**FFNN IMPLEMENTATION**

In [None]:
# Define a classification model

class SimpleClassifier(nn.Module):
    """Feed-forward classifier that returns unnormalised class logits.

    Inputs of shape (batch, ...) are flattened to (batch, prod(input_shape)) before the
    network. Hidden layers use ReLU, batch-norm and dropout (``dropout_prob``).
    """
    def __init__(self, input_shape: torch.Size, num_classes: int, hidden_shape: list, dropout_prob:float = 0.5) -> None:
        super(SimpleClassifier, self).__init__()

        # Core params
        self.input_shape = input_shape
        self.observation_features = np.prod(input_shape)
        self.epoch = 0

        # Model construction: [input features, *hidden_shape, num_classes]
        model_shape = [self.observation_features]
        model_shape.extend(hidden_shape)
        model_shape.append(num_classes)
        # No final activation: FFNN_Trainer uses nn.CrossEntropyLoss, which applies
        # log-softmax itself and expects raw logits. A sigmoid here would squash the logits
        # into (0, 1), flattening the softmax and weakening the gradients. Since sigmoid is
        # monotonic, argmax predictions are unaffected by its removal.
        self.model = FF_NetworkConstructor(layers = model_shape,
                                           pre_batchnorm = False,
                                           hidden_batchnorm = True,
                                           hidden_activation = nn.ReLU(),
                                           dropout_prob = dropout_prob,
                                           final_activation = None)

    def forward(self, x) -> Tensor:
        """Return class logits of shape (batch, num_classes)."""
        x = x.view(x.size(0), -1)
        x = self.model(x)
        return x

**FFNN TRAINER**

In [None]:
class FFNN_Trainer:
    """Train a classifier with Adam (lr 1e-3) and cross-entropy, plotting per-epoch curves.

    Args:
        network: classifier returning class logits; must have an integer ``epoch`` attribute.
        train_set, valid_set, test_set: iterables of (images, labels) batches.
    """
    def __init__(self,
                 network: 'SimpleClassifier',
                 train_set: DataLoader,
                 valid_set: DataLoader,
                 test_set: DataLoader):
        self.network = network
        self.train_set = train_set
        self.valid_set = valid_set
        self.test_set = test_set
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.network.parameters(), lr=0.001)
        self.plotter = Plotter()

    def train(self):
        """Run one training epoch over ``train_set`` and increment ``network.epoch``."""
        self.network.train()
        for images, labels in self.train_set:
            self.optimizer.zero_grad()
            classifications = self.network(images)
            loss = self.criterion(classifications, labels)
            loss.backward()
            self.optimizer.step()
        self.network.epoch += 1

    def _evaluate(self, loader, record):
        """Pass (loss, predicted class, labels) of every batch in ``loader`` to ``record``."""
        for images, labels in loader:
            classifications = self.network(images)
            loss = self.criterion(classifications, labels)
            record(loss, torch.argmax(classifications, 1), labels)

    def test(self):
        """Evaluate on the train/validation/test sets in eval mode and redraw the plots."""
        self.network.eval()
        # Evaluation needs no gradients; skipping graph construction saves memory and time.
        with torch.no_grad():
            self._evaluate(self.train_set, self.plotter.append_train)
            self._evaluate(self.valid_set, self.plotter.append_valid)
            self._evaluate(self.test_set, self.plotter.append_test)
        self.plotter.plot()

## Standard VAE

**VAE IMPLEMENTATION**

In [None]:
# Summarize values per sample
def reduce(x: Tensor) -> Tensor:
    """Sum every non-batch element of ``x``, giving one value per sample (dim 0)."""
    batch_size = x.size(0)
    per_sample = x.view(batch_size, -1)
    return per_sample.sum(dim=1)

In [None]:
# Implement VAE
class VariationalAutoEncoder(nn.Module):
    """Variational auto-encoder with a diagonal-Gaussian posterior and a Bernoulli likelihood.

    The encoder maps flattened inputs to (mu, log_sigma) of q(z|x); the decoder maps z to
    Bernoulli logits of p(x|z), reshaped to ``input_shape``. The prior p(z) comes from
    the zero-initialised ``prior_params`` buffer (mu = 0, log_sigma = 0), i.e. N(0, I).

    Args:
        input_shape: shape of one observation (flattened to prod(input_shape) features).
        latent_features: dimensionality of z.
        hidden_layers_encoder / hidden_layers_decoder: hidden layer widths.
    """
    def __init__(self,
                 input_shape:torch.Size,
                 latent_features:int,
                 hidden_layers_encoder: [],
                 hidden_layers_decoder: []) -> None:
        super(VariationalAutoEncoder, self).__init__()

        # Core parameters
        self.input_shape = input_shape
        self.latent_features = latent_features
        self.observation_features = np.prod(input_shape)
        self.register_buffer('prior_params', torch.zeros(torch.Size([1, 2*latent_features])))
        # Having epochs here makes it possible to continue training, with correct epochs counting
        self.epochs = 0

        # Encoder: x -> (mu, log_sigma), hence 2 * latent_features outputs
        encoder_shape = [self.observation_features]
        encoder_shape.extend(hidden_layers_encoder)
        encoder_shape.append(self.latent_features * 2)
        self.encoder = FF_NetworkConstructor(layers = encoder_shape,
                                             pre_batchnorm = False,
                                             hidden_batchnorm = True,
                                             hidden_activation = nn.ReLU(),
                                             dropout_prob = 0.5,
                                             final_activation = None)

        # Decoder: z -> Bernoulli logits over the observation features
        decoder_shape = [self.latent_features]
        decoder_shape.extend(hidden_layers_decoder)
        decoder_shape.append(self.observation_features)
        self.decoder = FF_NetworkConstructor(layers = decoder_shape,
                                             pre_batchnorm = False,
                                             hidden_batchnorm = True,
                                             hidden_activation = nn.ReLU(),
                                             dropout_prob = 0.5,
                                             final_activation = None)

        if torch.cuda.is_available():
            self.device = "cuda:0"
            self.encoder.cuda()
            self.decoder.cuda()
        else:
            self.device = "cpu"
            self.encoder.cpu()
            self.decoder.cpu()

    def encode(self, x: Tensor) -> Distribution:
        """Encode flattened input ``x`` into the posterior q(z|x)."""
        h_x = self.encoder(x)
        mu, log_sigma = h_x.chunk(2, dim=-1)
        result = ReparameterizedDiagonalGaussian(mu, log_sigma)
        result.mu = result.mu.to(self.device)
        result.sigma = result.sigma.to(self.device)
        return result

    def decode(self, z: Tensor) -> Distribution:
        """Decode latent samples ``z`` into the observation model p(x|z) (Bernoulli)."""
        px_logits = self.decoder(z)
        px_logits = px_logits.view(-1, *self.input_shape)
        return Bernoulli(logits=px_logits)

    def prior(self, batch_size: int = 1) -> Distribution:
        """Prior p(z), expanded to ``batch_size`` rows and placed on ``self.device``."""
        local_prior_params = self.prior_params.expand(batch_size, *self.prior_params.shape[-1:])
        mu, log_sigma = local_prior_params.chunk(2, dim=-1)
        result = ReparameterizedDiagonalGaussian(mu, log_sigma)
        result.mu = result.mu.to(self.device)
        result.sigma = result.sigma.to(self.device)
        return result

    def sample(self, distribution: 'ReparameterizedDiagonalGaussian') -> Tensor:
        """Reparameterized sample from ``distribution``."""
        return distribution.rsample()

    def elbo(self, prior: Distribution, posterior: Distribution, reconstruction: Distribution, x: Tensor, z: Tensor) -> Tensor:
        """Per-sample ELBO: log p(x|z) - (log q(z|x) - log p(z)), one tensor entry per sample.

        The KL term is the single-sample estimate at the given ``z``.
        """
        x = x.to(self.device)
        z = z.to(self.device)
        x = x.view(x.size(0), -1)
        log_px = reduce(reconstruction.log_prob(x))
        log_pz = reduce(prior.log_prob(z))
        log_qz = reduce(posterior.log_prob(z))
        kl = log_qz - log_pz
        return log_px - kl

    def forward(self, x: Tensor) -> Dict[str, Any]:
        """Encode, sample and decode ``x``; returns {'px', 'pz', 'qz', 'z'}."""
        # Figure out where the model currently lives and update 'self.device', so the model
        # can be moved between cpu and cuda after it has been initialised.
        if next(self.parameters()).is_cuda:
            self.device = "cuda:0"
        else:
            self.device = "cpu"
        x = x.to(self.device)
        x = x.view(x.size(0), -1) # flatten the input
        qz = self.encode(x) # posterior q(z|x)
        pz = self.prior(batch_size=x.size(0)) # prior p(z)
        z = qz.rsample() # reparameterized sample z ~ q(z|x)
        px = self.decode(z) # observation model p(x|z) = B(x | g(z))
        return {'px': px, 'pz': pz, 'qz': qz, 'z': z}

**VAE TRAINER**

In [None]:
class VAE_Trainer:
    """Train a VAE by minimising the negative ELBO (Adam, lr 1e-3) and plot per-epoch losses.

    Arguments left as None fall back to notebook-level globals: ``TestM2`` for the network
    and the binarized MNIST loaders for the data. The network must provide ``elbo(...)``,
    an integer ``epochs`` attribute and a forward() returning {'px', 'pz', 'qz', 'z'}.
    """
    def __init__(self,
                 network:'VariationalAutoEncoder' = None,
                 train_data:DataLoader            = None,
                 valid_data:DataLoader            = None,
                 test_data:DataLoader             = None):
        # NOTE(review): the network fallback is the global `TestM2`, which looks copied from
        # the M2 trainer; confirm this is intended for the plain VAE.
        self.model      = network    if network    is not None else TestM2
        self.train_data = train_data if train_data is not None else binarized_mnist_train_loader_labelled
        self.valid_data = valid_data if valid_data is not None else binarized_mnist_train_loader_validation
        self.test_data  = test_data  if test_data  is not None else binarized_mnist_test_loader
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        self.plotter = Plotter()

    def _negative_elbo(self, images) -> Tensor:
        """Mean negative ELBO of one batch (the training loss)."""
        outputs = self.model(images)
        return -self.model.elbo(outputs['pz'], outputs['qz'], outputs['px'], images, outputs['z']).mean()

    def train(self):
        """Run one training epoch over ``train_data`` and increment ``model.epochs``."""
        self.model.train()
        for images, _labels in self.train_data:
            self.optimizer.zero_grad()
            loss = self._negative_elbo(images)
            loss.backward()
            self.optimizer.step()
        self.model.epochs += 1

    def test(self):
        """Evaluate the loss on train/validation/test data in eval mode and redraw the plots."""
        self.model.eval()
        splits = ((self.train_data, self.plotter.append_train),
                  (self.valid_data, self.plotter.append_valid),
                  (self.test_data, self.plotter.append_test))
        # Evaluation needs no gradients; skipping graph construction saves memory and time.
        with torch.no_grad():
            for loader, record in splits:
                for images, _labels in loader:
                    record(self._negative_elbo(images), None, None)
        self.plotter.plot()

## M2 VAE

**IMPLEMENTATION**

In [None]:
class M2_VAE(nn.Module):
    """Semi-supervised "M2" VAE: a classifier q(y|x) combined with a class-conditional VAE.

    Sub-networks (all built with FF_NetworkConstructor):
      * ``classifier``        x -> class logits, defining q(y|x)
      * ``preclass_encoder``  x -> hidden features
      * ``postclass_encoder`` [features, onehot(y)] -> (mu, log_sigma) of q(z|x, y)
      * ``decoder``           [z, onehot(y)] -> Bernoulli logits of p(x|z, y)
    Priors: p(z) from the zero-initialised ``prior_params`` buffer (mu = 0, log_sigma = 0)
    and a uniform categorical p(y) over ``classes``.
    """
    def __init__(self,
                 input_shape: torch.Size,
                 latent_features: int,
                 classes: int,
                 hidden_layers_preclass: [],
                 hidden_layers_postclass: [],
                 hidden_layers_classification: [],
                 hidden_layers_decoder: []):
        super(M2_VAE, self).__init__()

        # Core params
        self.input_shape = input_shape
        self.observation_features = np.prod(input_shape)
        self.register_buffer('prior_params', torch.zeros(torch.Size([1, 2*latent_features])))
        self.register_buffer('classification_prior_params', torch.full(torch.Size([1, classes]), 1 / classes))
        self.classes = classes
        # Having epochs here makes it possible to continue training, with correct epochs counting
        self.epochs = 0

        # Cuda enabling
        if torch.cuda.is_available():
            self.use_cuda = True
            self.device = "cuda:0"
        else:
            self.use_cuda = False
            self.device = "cpu"

        # Classifier construction
        classifier_shape = [self.observation_features]
        classifier_shape.extend(hidden_layers_classification)
        classifier_shape.append(classes)
        self.classifier = FF_NetworkConstructor(layers = classifier_shape,
                                                pre_batchnorm = False,
                                                hidden_batchnorm = True,
                                                hidden_activation = nn.ReLU(),
                                                dropout_prob = 0.5,
                                                final_activation = None)

        # Pre-classification network construction
        preclass_shape = [self.observation_features]
        preclass_shape.extend(hidden_layers_preclass)
        self.preclass_encoder = FF_NetworkConstructor(layers = preclass_shape,
                                                      pre_batchnorm = False,
                                                      hidden_batchnorm = True,
                                                      hidden_activation = nn.ReLU(),
                                                      dropout_prob = 0.5,
                                                      final_activation = nn.ReLU())

        # Post-classification network construction: [features, onehot(y)] -> (mu, log_sigma)
        postclass_shape = [hidden_layers_preclass[-1] + classes]
        postclass_shape.extend(hidden_layers_postclass)
        postclass_shape.append(latent_features * 2)
        self.postclass_encoder = FF_NetworkConstructor(layers = postclass_shape,
                                                       pre_batchnorm = False,
                                                       hidden_batchnorm = True,
                                                       hidden_activation = nn.ReLU(),
                                                       dropout_prob = 0.5,
                                                       final_activation = None)

        # Decoder construction: [z, onehot(y)] -> Bernoulli logits
        decoder_shape = [latent_features + classes]
        decoder_shape.extend(hidden_layers_decoder)
        decoder_shape.append(self.observation_features)
        self.decoder = FF_NetworkConstructor(layers = decoder_shape,
                                             pre_batchnorm = False,
                                             hidden_batchnorm = True,
                                             hidden_activation = nn.ReLU(),
                                             dropout_prob = 0.5,
                                             final_activation = None)

        # Move networks to cuda if available
        if self.use_cuda:
            self.classifier.cuda()
            self.preclass_encoder.cuda()
            self.postclass_encoder.cuda()
            self.decoder.cuda()

    def reset(self) -> None:
        """Re-initialise every layer of the four sub-networks and reset the epoch counter."""
        for layers in self.children():
            for layer in layers:
                if hasattr(layer, 'reset_parameters'):
                    layer.reset_parameters()
        self.epochs = 0

    def prior(self, batch_size: int = 1) -> Distribution:
        """Latent prior p(z), expanded to ``batch_size`` rows and placed on ``self.device``."""
        local_prior_params = self.prior_params.expand(batch_size, *self.prior_params.shape[-1:])
        mu, log_sigma = local_prior_params.chunk(2, dim=-1)
        result = ReparameterizedDiagonalGaussian(mu, log_sigma)
        result.mu = result.mu.to(self.device)
        result.sigma = result.sigma.to(self.device)
        return result

    def classification_prior(self, batch_size: int = 1) -> Distribution:
        """Uniform class prior p(y), expanded to ``batch_size`` rows."""
        local_classification_prior_params = self.classification_prior_params.expand(batch_size,
                                                                                    *self.classification_prior_params.shape[-1:])
        result = Categorical(probs = local_classification_prior_params)
        return result

    def classification_posterior(self, x: Tensor) -> Distribution:
        """Classifier distribution q(y|x) for flattened inputs ``x``."""
        result = self.classifier(x)
        result = result.view(-1, self.classes)
        result = Categorical(logits = result)
        return result

    def classification_entropy(self, qy: Tensor) -> Tensor:
        """Row-wise entropy -sum_k q_k log q_k of the class probabilities ``qy``.

        Probabilities are clamped inside the log so that q_k == 0 contributes 0; the
        plain 0 * log(0) would evaluate to NaN and poison the loss.
        """
        tiny = torch.finfo(qy.dtype).tiny
        return -(qy * torch.log(qy.clamp(min=tiny))).sum(1)

    def encode(self, x: Tensor, y: Tensor = None) -> Distribution:
        """Posterior q(z|x, y). If ``y`` is None, the classifier output for ``x`` is used instead."""
        if y is None:
            y = self.classifier(x)
        result = self.preclass_encoder(x)
        result = torch.cat((result, y), 1)
        result = self.postclass_encoder(result)
        mu, log_sigma = result.chunk(2, dim=-1)
        return ReparameterizedDiagonalGaussian(mu, log_sigma)

    def decode(self, z: Tensor, y: Tensor) -> Distribution:
        """Observation model p(x|z, y) = Bernoulli(logits = decoder([z, y]))."""
        px_logits = self.decoder(torch.cat((z, y), 1))
        px_logits = px_logits.view(-1, *self.input_shape)
        return Bernoulli(logits = px_logits)

    def onehot(self, y: Tensor) -> Tensor:
        """Float one-hot encoding, on ``self.device``, of the 1-D integer class labels ``y``."""
        index = y.to(self.device).long().view(-1, 1)
        return torch.zeros(index.shape[0], self.classes, device=self.device).scatter_(1, index, 1.0)

    def loss(self,
             px: Distribution,
             py: Distribution,
             pz: Distribution,
             qy: Distribution,
             qz: Distribution,
             x: Tensor,
             y: Tensor,
             z: Tensor,
             alpha: float,
             debug: bool = False) -> Tensor:
        """Per-sample loss (negative M2 objective).

        With ELBO(x, y) = log p(x|y,z) + log p(y) + log p(z) - log q(z|x,y):
          * labelled (``y`` given): -ELBO(x, y) - alpha * log q(y|x), shape (batch, 1);
          * unlabelled (``y`` None): -(sum_y q(y|x) ELBO(x, y) + H(q(y|x))), shape (batch,),
            where every class is evaluated with a fresh forward pass.
        """
        x = x.to(self.device)
        z = z.to(self.device)
        x = x.view(x.size(0), -1)

        py_logprob = py.logits.to(self.device)  # (batch, classes)

        if y is None:
            batch = x.shape[0]
            # Evaluate all classes in one forward pass. Rows are laid out class-major:
            # row k * batch + b pairs sample b with class k.
            x_all = x.repeat(self.classes, 1)
            y_all = torch.arange(self.classes, device=self.device).repeat_interleave(batch)
            outputs = self.forward(x_all, y_all)

            def per_class(logprob: Tensor) -> Tensor:
                # Sum over features, then undo the class-major layout -> (batch, classes).
                return logprob.sum(dim=1).view(self.classes, batch).t()

            px_logprob = per_class(outputs['px'].log_prob(x_all))
            qz_logprob = per_class(outputs['qz'].log_prob(outputs['z']))
            pz_logprob = per_class(outputs['pz'].log_prob(outputs['z']))

            L = -(px_logprob + py_logprob + pz_logprob - qz_logprob)
            U = -(torch.mul(qy.probs, -L).sum(1) + self.classification_entropy(qy.probs))
            return U

        # gather() needs int64 indices of shape (batch, 1).
        y = y.to(self.device).view(-1, 1).long()
        # Per-sample terms of shape (batch, 1) broadcast against the (batch, classes) prior.
        px_logprob = px.log_prob(x).sum(dim = 1).view(-1, 1)
        qz_logprob = qz.log_prob(z).sum(dim = 1).view(-1, 1)
        pz_logprob = pz.log_prob(z).sum(dim = 1).view(-1, 1)
        L = -(px_logprob + py_logprob + pz_logprob - qz_logprob)
        J = L.gather(1, y)
        # Extra classification term weighted by alpha (qy.logits are log q(y|x)).
        J_alpha = J - (alpha * qy.logits.gather(1, y))
        return J_alpha

    def forward(self, x: Tensor, y: Tensor = None, debug: bool = False) -> Dict[str, Any]:
        """Run the M2 model; if ``y`` is None, a label is sampled from q(y|x).

        Returns {'px', 'py', 'pz', 'qy', 'qz', 'z'}.
        """
        x = x.to(self.device) # Move to cuda if applicable
        x = x.view(x.size(0), -1) # Flatten image input
        qy = self.classification_posterior(x) # Classification posterior q(y|x)
        py = self.classification_prior(batch_size = x.size(0)) # Classification prior p(y)
        pz = self.prior(batch_size=x.size(0)) # Prior p(z)
        if y is None: # If labels are not provided, sample from classification posterior
            try:
                y = qy.sample()
            except Exception as err:
                # Show the offending classifier output, then surface the original cause.
                print(self.classifier(x))
                raise RuntimeError("QY sample error") from err

        y = y.to(self.device)
        y = y.int()
        y = self.onehot(y)
        qz = self.encode(x, y) # Approximate posterior q(z|x, y)
        z = qz.rsample() # Sample the posterior
        px = self.decode(z, y) # Reconstruction p(x|z, y) = B(x | g(z, y))

        return {'px': px, 'py': py, 'pz': pz, 'qy': qy, 'qz': qz, 'z': z}

**TRAINER**

In [None]:
class M2_VAE_Trainer:
    """Train and evaluate an M2 semi-supervised VAE.

    With an unlabelled loader, every unlabelled batch is combined with
    ``labelled_ratio`` labelled batches into a single optimizer step; the
    labelled loader is restarted whenever it runs out. Without one, the model
    is trained purely supervised on the labelled loader.
    """

    def __init__(self,
                 network:M2_VAE              = None,
                 train_labelled:DataLoader   = None,
                 train_unlabelled:DataLoader = None,
                 valid:DataLoader            = None,
                 test:DataLoader             = None,
                 labelled_ratio:int          = 1,
                 alpha:float                 = 0.1,
                 verbose:int                 = 0):
        """
        Args:
            network: model to train (default: the notebook-level ``TestM2``).
            train_labelled / train_unlabelled / valid / test: data loaders;
                ``None`` falls back to the notebook-level binarized MNIST loaders.
            labelled_ratio: labelled batches per unlabelled batch.
            alpha: weight passed through to ``network.loss``.
            verbose: verbosity level (> 0 enables extra output).
        """
        self.model            = network          if network          is not None else TestM2
        self.train_labelled   = train_labelled   if train_labelled   is not None else binarized_mnist_train_loader_labelled
        self.train_unlabelled = train_unlabelled if train_unlabelled is not None else binarized_mnist_train_loader_unlabelled
        self.valid_data       = valid            if valid            is not None else binarized_mnist_train_loader_validation
        self.test_data        = test             if test             is not None else binarized_mnist_test_loader
        self.alpha = alpha
        self.training_data = defaultdict(list)
        self.validation_data = defaultdict(list)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        self.labelled_ratio = labelled_ratio
        # Batch counters (unlabelled, labelled, evaluation) for diagnostics.
        self.unlbl_bcnt = 0
        self.lbl_bcnt = 0
        self.tst_bcnt = 0
        self.verbose = verbose
        self.plotter = Plotter()

    def _batch_loss(self, images, labels):
        """Mean M2 loss of one batch; ``labels=None`` selects the unlabelled objective."""
        outputs = self.model(images, labels)
        return self.model.loss(outputs['px'],
                               outputs['py'],
                               outputs['pz'],
                               outputs['qy'],
                               outputs['qz'],
                               images,
                               labels,
                               outputs['z'],
                               self.alpha).mean()

    def train(self):
        """Run one training epoch and increment ``self.model.epochs``."""
        self.model.train()
        if self.train_unlabelled is not None:
            # The supervised term must draw from the *labelled* loader; cycling
            # the unlabelled loader here would leak its true labels into training.
            labelled_iter = iter(self.train_labelled)
            for images, _ in self.train_unlabelled:
                self.unlbl_bcnt += 1
                self.optimizer.zero_grad()
                loss = self._batch_loss(images, None)
                for _ in range(self.labelled_ratio):
                    try:
                        img_lbl, lbl_lbl = next(labelled_iter)
                    except StopIteration:
                        # The labelled set is smaller: restart it as often as needed.
                        labelled_iter = iter(self.train_labelled)
                        img_lbl, lbl_lbl = next(labelled_iter)
                    self.lbl_bcnt += 1
                    loss = loss + self._batch_loss(img_lbl, lbl_lbl)
                loss.backward()
                self.optimizer.step()
        else:
            for images, labels in self.train_labelled:
                self.lbl_bcnt += 1
                self.optimizer.zero_grad()
                self._batch_loss(images, labels).backward()
                self.optimizer.step()
        self.model.epochs += 1

    def _evaluate(self, loader, record):
        """Pass each batch's loss, argmax class predictions and labels to ``record``."""
        for images, labels in loader:
            self.tst_bcnt += 1
            loss = self._batch_loss(images, labels)
            flat = images.view(images.size(0), -1).to(self.model.device)
            preds = torch.argmax(self.model.classifier(flat), 1)
            record(loss, preds, labels)

    def test(self):
        """Evaluate on the labelled-train, validation and test loaders, then plot."""
        self.model.eval()
        # Evaluation only: skip the autograd graph. This saves memory, and the
        # losses handed to the plotter do not keep graphs alive.
        with torch.no_grad():
            self._evaluate(self.train_labelled, self.plotter.append_train)
            self._evaluate(self.valid_data, self.plotter.append_valid)
            self._evaluate(self.test_data, self.plotter.append_test)
        self.plotter.plot()

# Execution

## Simple FF classifier

**100 OBSERVATIONS**

In [None]:
# Only the sample's shape is used, to size the classifier's input layer.
sample = binarized_mnist_train_data[22][0]
ffnn = SimpleClassifier(sample.flatten().shape, 10, [500, 250], dropout_prob=0.0)
loader_setup(labelled_size=100)
trainer = FFNN_Trainer(ffnn,
                       binarized_mnist_train_loader_labelled,
                       binarized_mnist_train_loader_validation,
                       binarized_mnist_test_loader)

In [None]:
%%time
epochs = 51
fn = "ff_100_0p73.pt"
if os.path.isfile(fn):
    ffnn.load_state_dict(torch.load(fn))
    trainer.test()
    print("Model loaded from ''%s'" % fn)
else:
    trainer.test()
    for i in range(epochs):
        trainer.train()
        trainer.test()

In [None]:
# torch.save(ffnn.state_dict(), fn)

dropout_prob = 0.0
![image.png](attachment:image.png)

dropout_prob = 0.95
![image.png](attachment:image.png)

**1000 OBSERVATIONS**

In [None]:
# Only the sample's shape is used, to size the classifier's input layer.
sample = binarized_mnist_train_data[22][0]
ffnn = SimpleClassifier(sample.flatten().shape, 10, [500, 250], dropout_prob=0.0)
loader_setup(labelled_size=1000)
trainer = FFNN_Trainer(ffnn,
                       binarized_mnist_train_loader_labelled,
                       binarized_mnist_train_loader_validation,
                       binarized_mnist_test_loader)

In [None]:
%%time
epochs = 65
fn = "ff_1000_0p92.pt"
if os.path.isfile(fn):
    ffnn.load_state_dict(torch.load(fn))
    trainer.test()
    print("Loaded model from %s" % fn)
else:
    trainer.test()
    for i in range(epochs):
        trainer.train()
        trainer.test()

dropout_prob = 0.0
![image.png](attachment:image.png)

**10000 OBSERVATIONS**

In [None]:
# Only the sample's shape is used, to size the classifier's input layer.
sample = binarized_mnist_train_data[22][0]
ffnn = SimpleClassifier(sample.flatten().shape, 10, [500, 250], dropout_prob=0.0)
loader_setup(labelled_size=10000)
trainer = FFNN_Trainer(ffnn,
                       binarized_mnist_train_loader_labelled,
                       binarized_mnist_train_loader_validation,
                       binarized_mnist_test_loader)

In [None]:
%%time
epochs = 16
fn = "ff_10000_0p96.pt"
if os.path.isfile(fn):
    ffnn.load_state_dict(torch.load(fn))
    trainer.test()
    print("Loaded model from %s" % fn)
else:
    trainer.test()
    for i in range(epochs):
        trainer.train()
        trainer.test()

Test accuracy 0.960 with dropout_prob=0.00; dropout_prob 0.30 and 0.60 gave very similar results.
![image-2.png](attachment:image-2.png)

## Standard VAE (M1)

In [None]:
# Instantiate a VAE (5 latent features); the sample is only used for its shape.
sample = binarized_mnist_train_data[0][0]
testVAE = VariationalAutoEncoder(sample.flatten().shape, 5, [512, 256, 128], [128, 256, 512])

In [None]:
%%time
# Train the VAE
loader_setup(labelled_size=50000)
trainer = VAE_Trainer(testVAE,
                      binarized_mnist_train_loader_labelled,
                      binarized_mnist_train_loader_validation,
                      binarized_mnist_test_loader)

testVAE_fn = "./M1_119ls_base_20ep.pt"
if False and os.path.isfile(testVAE_fn):
    testVAE.load_state_dict(torch.load(testVAE_fn))
    print("Loaded testVAE model from %s" % testVAE_fn)
    trainer.test()
else:
    epochs = 20
    for i in range(epochs):
        trainer.train()
        trainer.test()
    torch.save(testVAE.state_dict(), testVAE_fn)

## M2 - Classification

**100/40K LABELLED/UNLABELLED OBSERVATIONS**

In [None]:
# Only the sample's shape is used, to size the network's input layer.
sample = binarized_mnist_train_data[22][0]
TestM2 = M2_VAE(sample.flatten().shape, # Input shape
                5, # Latent features
                10, # Classes
                [200], # Network dimensions for encoder before adding classifications
                [200], # Network dimensions for encoder after adding classification
                [400, 200], # Network dimensions for classification network
                [200, 400, 600]) # Network dimensions for decoder

In [None]:
%%time
loader_setup(labelled_size=100, unlabelled_size=40000)
trainer = M2_VAE_Trainer(TestM2,
                         binarized_mnist_train_loader_labelled,
                         binarized_mnist_train_loader_unlabelled,
                         binarized_mnist_train_loader_validation,
                         binarized_mnist_test_loader,
                         10,
                         0.1)
epochs = 20
fn = "m2_class_100_labelled_%dep.pt" % epochs
if os.path.isfile(fn):
    TestM2.load_state_dict(torch.load(fn))
    print("Loaded model from %s" % fn)
    trainer.test()
else:
    trainer.test()
    for i in range(epochs):
        trainer.train()
        trainer.test()
    torch.save(TestM2.state_dict(), fn)

Training Loss: 122.325, Validation Loss: 120.551, Test Loss: 119.317

Training Accuracy: 0.780, Validation Accuracy: 0.792, Test Accuracy: 0.791

![image.png](attachment:image.png)

**1K/40K LABELLED/UNLABELLED OBSERVATIONS**

In [None]:
# Only the sample's shape is used, to size the network's input layer.
sample = binarized_mnist_train_data[22][0]
TestM2 = M2_VAE(sample.flatten().shape, # Input shape
                5, # Latent features
                10, # Classes
                [200], # Network dimensions for encoder before adding classifications
                [200], # Network dimensions for encoder after adding classification
                [400, 200], # Network dimensions for classification network
                [200, 400, 600]) # Network dimensions for decoder

In [None]:
%%time
loader_setup(labelled_size=1000, unlabelled_size=40000)
trainer = M2_VAE_Trainer(TestM2,
                         binarized_mnist_train_loader_labelled,
                         binarized_mnist_train_loader_unlabelled,
                         binarized_mnist_train_loader_validation,
                         binarized_mnist_test_loader,
                         10,
                         0.1)
epochs = 20
fn = "m2_class_1k_labelled_%dep.pt" % epochs
if os.path.isfile(fn):
    TestM2.load_state_dict(torch.load(fn))
    print("Loaded model from %s" % fn)
    trainer.test()
else:
    trainer.test()
    for i in range(epochs):
        trainer.train()
        trainer.test()
    torch.save(TestM2.state_dict(), fn)

Training Loss: 121.014, Validation Loss: 119.316, Test Loss: 118.913

Training Accuracy: 0.783, Validation Accuracy: 0.779, Test Accuracy: 0.791

![image.png](attachment:image.png)

**10K/40K LABELLED/UNLABELLED OBSERVATIONS**

In [None]:
# Only the sample's shape is used, to size the network's input layer.
sample = binarized_mnist_train_data[22][0]
TestM2 = M2_VAE(sample.flatten().shape, # Input shape
                5, # Latent features
                10, # Classes
                [200], # Network dimensions for encoder before adding classifications
                [200], # Network dimensions for encoder after adding classification
                [400, 200], # Network dimensions for classification network
                [200, 400, 600]) # Network dimensions for decoder

In [None]:
%%time
loader_setup(labelled_size=10000, unlabelled_size=40000)
trainer = M2_VAE_Trainer(TestM2,
                         binarized_mnist_train_loader_labelled,
                         binarized_mnist_train_loader_unlabelled,
                         binarized_mnist_train_loader_validation,
                         binarized_mnist_test_loader,
                         10,
                         0.1)
epochs = 20
fn = "m2_class_10k_labelled_%dep.pt" % epochs
if os.path.isfile(fn):
    TestM2.load_state_dict(torch.load(fn))
    print("Loaded model from %s" % fn)
    trainer.test()
else:
    trainer.test()
    for i in range(epochs):
        trainer.train()
        trainer.test()
    torch.save(TestM2.state_dict(), fn)

Training Loss: 121.500, Validation Loss: 120.390, Test Loss: 119.620

Training Accuracy: 0.712, Validation Accuracy: 0.709, Test Accuracy: 0.709

![image.png](attachment:image.png)

**50K LABELLED OBSERVATIONS**

In [None]:
# Only the sample's shape is used, to size the network's input layer.
sample = binarized_mnist_train_data[22][0]
TestM2 = M2_VAE(sample.flatten().shape, # Input shape
                5, # Latent features
                10, # Classes
                [200], # Network dimensions for encoder before adding classifications
                [200], # Network dimensions for encoder after adding classification
                [400, 200], # Network dimensions for classification network
                [200, 400, 600]) # Network dimensions for decoder

In [None]:
%%time
loader_setup(labelled_size=50000)
trainer = M2_VAE_Trainer(TestM2,
                         binarized_mnist_train_loader_labelled,
                         binarized_mnist_train_loader_unlabelled,
                         binarized_mnist_train_loader_validation,
                         binarized_mnist_test_loader,
                         10,
                         0.1)
epochs = 10
fn = "m2_class_50k_labelled_%dep.pt" % epochs
if os.path.isfile(fn):
    TestM2.load_state_dict(torch.load(fn))
    print("Loaded model from %s" % fn)
    trainer.test()
else:
    trainer.test()
    for i in range(epochs):
        trainer.train()
        trainer.test()
    torch.save(TestM2.state_dict(), fn)

Training Loss: 125.851, Validation Loss: 119.831, Test Loss: 124.206

Training Accuracy: 0.983, Validation Accuracy: 0.971, Test Accuracy: 0.976

![image.png](attachment:image.png)

## M2 - Style Transfer

**5-DIMENSIONAL LATENT SPACE**

In [None]:
# Only the sample's shape is used, to size the network's input layer.
sample = binarized_mnist_train_data[22][0]
TestM2 = M2_VAE(sample.flatten().shape, # Input shape
                5, # Latent features
                10, # Classes
                [200], # Network dimensions for encoder before adding classifications
                [200], # Network dimensions for encoder after adding classification
                [400, 200], # Network dimensions for classification network
                [200, 400, 600]) # Network dimensions for decoder

In [None]:
%%time
loader_setup(labelled_size=10000, unlabelled_size=40000)
trainer = M2_VAE_Trainer(TestM2,
                         binarized_mnist_train_loader_labelled,
                         binarized_mnist_train_loader_unlabelled,
                         binarized_mnist_train_loader_validation,
                         binarized_mnist_test_loader,
                         10,
                         0.1)
epochs = 10
fn = "./style_transfer_5D_%dep.pt" % epochs
if os.path.isfile(fn):
    TestM2.load_state_dict(torch.load(fn))
    print("Loaded model from %s" % fn)
    trainer.test()
else:
    trainer.test()
    for i in range(epochs):
        trainer.train()
        trainer.test()
    torch.save(TestM2.state_dict(), fn)

In [None]:
TestM2.eval()
# Pick 10 random training images (with replacement). Column 0 shows the input;
# columns 1-10 decode the image's latent code z conditioned on classes 0-9.
picked = random.choices(binarized_mnist_train_data, k=10)
img = torch.stack([image for image, _ in picked])
lbl = torch.Tensor([label for _, label in picked])
fig, axs = plt.subplots(10, 11, figsize=(10, 10), squeeze=False)

with torch.no_grad():  # inference only
    z = TestM2(img, lbl)['z']
    for i in range(len(img)):
        axs[i, 0].imshow(img[i], cmap='gray')
        axs[i, 0].axis('off')
        for j in range(10):
            # Build the class label on the model's device, as forward() does
            # before calling onehot().
            y_j = TestM2.onehot(torch.tensor([j], device=TestM2.device).int())
            image = TestM2.decode(z[i].reshape(1, -1), y_j).sample().view(28, 28).cpu()
            axs[i, j + 1].imshow(image, cmap='gray')
            axs[i, j + 1].axis('off')
plt.tight_layout()

![image.png](attachment:image.png)

**50-DIMENSIONAL LATENT SPACE**

In [None]:
# Only the sample's shape is used, to size the network's input layer.
sample = binarized_mnist_train_data[22][0]
TestM2 = M2_VAE(sample.flatten().shape, # Input shape
                50, # Latent features
                10, # Classes
                [200], # Network dimensions for encoder before adding classifications
                [200], # Network dimensions for encoder after adding classification
                [400, 200], # Network dimensions for classification network
                [200, 400, 600]) # Network dimensions for decoder

In [None]:
%%time
loader_setup(labelled_size=10000, unlabelled_size=40000)
trainer = M2_VAE_Trainer(TestM2,
                         binarized_mnist_train_loader_labelled,
                         binarized_mnist_train_loader_unlabelled,
                         binarized_mnist_train_loader_validation,
                         binarized_mnist_test_loader,
                         10,
                         0.1)
epochs = 10
fn = "./style_transfer_50D_%dep.pt" % epochs
if os.path.isfile(fn):
    TestM2.load_state_dict(torch.load(fn))
    print("Loaded model from %s" % fn)
    trainer.test()
else:
    trainer.test()
    for i in range(epochs):
        trainer.train()
        trainer.test()
    torch.save(TestM2.state_dict(), fn)

In [None]:
TestM2.eval()
# Pick 10 random training images (with replacement). Column 0 shows the input;
# columns 1-10 decode the image's latent code z conditioned on classes 0-9.
picked = random.choices(binarized_mnist_train_data, k=10)
img = torch.stack([image for image, _ in picked])
lbl = torch.Tensor([label for _, label in picked])
fig, axs = plt.subplots(10, 11, figsize=(10, 10), squeeze=False)

with torch.no_grad():  # inference only
    z = TestM2(img, lbl)['z']
    for i in range(len(img)):
        axs[i, 0].imshow(img[i], cmap='gray')
        axs[i, 0].axis('off')
        for j in range(10):
            # Build the class label on the model's device, as forward() does
            # before calling onehot().
            y_j = TestM2.onehot(torch.tensor([j], device=TestM2.device).int())
            image = TestM2.decode(z[i].reshape(1, -1), y_j).sample().view(28, 28).cpu()
            axs[i, j + 1].imshow(image, cmap='gray')
            axs[i, j + 1].axis('off')
plt.tight_layout()

![image.png](attachment:image.png)

# Subtasks

**SUBTASK 2.1.2: Plot 8x8 random samples from MNIST data**

In [None]:
images, labels = next(iter(mnist_train_loader))

# 8x8 grid of the first batch; zip stops early instead of raising IndexError
# if the batch holds fewer than 64 images.
fig, axs = plt.subplots(8, 8, figsize=(10, 10), squeeze=False)
for ax, image, label in zip(axs.flat, images, labels):
    ax.imshow(image, cmap='gray')
    ax.set_title("%s" % label.item())
    ax.axis('off')
plt.tight_layout()

**SUBTASK 2.1.4: Plot binarized MNIST samples**

In [None]:
# QUESTION: are the grey tones binarized by random sampling, so that a picture
#           looks slightly different every time the DataLoader delivers it?
#
# ANSWER: yes; the plot below shows the same image (index 22) four times.
fig, axs = plt.subplots(1, 4, figsize=(12, 3), squeeze=False)
for ax in axs.flat:
    sample = binarized_mnist_train_data[22][0]
    # Every pixel must be exactly 0 or 1; min/max alone would not prove that.
    assert bool(((sample == 0) | (sample == 1)).all())
    assert torch.max(sample) == 1.0
    assert torch.min(sample) == 0.0
    ax.imshow(sample, cmap='gray')
    ax.axis('off')
plt.tight_layout()

In [None]:
images, labels = next(iter(binarized_mnist_train_loader_labelled))
# 8x8 grid of the first batch; zip stops early instead of raising IndexError
# if the batch holds fewer than 64 images.
fig, axs = plt.subplots(8, 8, figsize=(10, 10), squeeze=False)
for ax, image, label in zip(axs.flat, images, labels):
    ax.imshow(image, cmap='gray')
    ax.set_title("%s" % label.item())
    ax.axis('off')
plt.tight_layout()

**SUBTASK 2.2.1.2: Print samples from untrained VAE**

In [None]:
# Instantiate a VAE with 5 latent features; the sample is only used for its shape.
sample = binarized_mnist_train_data[0][0]
testVAE = VariationalAutoEncoder(sample.flatten().shape, 5)

In [None]:
# Method 1: Decoding sample from prior
testVAE.eval()

with torch.no_grad():  # inference only
    prior = testVAE.prior(64)
    prior_sample = testVAE.sample(prior)
    decoded_prior_sample = testVAE.decode(prior_sample)
    # note that .cpu() works differently for tensor & model
    sampled_decode_content = decoded_prior_sample.sample().view(64, 28, 28).cpu()

fig, axs = plt.subplots(8, 8, figsize=(10, 10), squeeze=False)
for ax, image in zip(axs.flat, sampled_decode_content):
    ax.imshow(image, cmap='gray')
    ax.axis('off')
plt.tight_layout()

In [None]:
# Method 2: Reconstruction of 64 random (with replacement) binarized MNIST inputs
picked = random.choices(binarized_mnist_train_data, k=64)
img = torch.stack([image for image, _ in picked]).to(testVAE.device)
with torch.no_grad():  # inference only
    sampled_decode_content = testVAE(img)['px'].sample().view(-1, 28, 28).cpu()

fig, axs = plt.subplots(8, 8, figsize=(10, 10), squeeze=False)
for ax, image in zip(axs.flat, sampled_decode_content):
    ax.imshow(image, cmap='gray')
    ax.axis('off')
plt.tight_layout()

**2.2.1.3: Compute ELBO of 64 samples**

In [None]:
# ELBO of 64 random (with replacement) training images under the current VAE.
sample_cnt = 64
picked = random.choices(binarized_mnist_train_data, k=sample_cnt)
samples_tensor = torch.stack([image.view(-1) for image, _ in picked]).to(testVAE.device)

with torch.no_grad():  # evaluation only
    prior = testVAE.prior(sample_cnt)
    posterior = testVAE.encode(samples_tensor)
    z = testVAE.sample(posterior)  # Random sampling
    reconstruction = testVAE.decode(z)
    elbo = testVAE.elbo(prior, posterior, reconstruction, samples_tensor, z)

# 'float64' required because 'stdev' chokes on 'float32' which is the default type when detaching from GPU
elbo_ary = elbo.cpu().numpy().astype('float64')
elbo_stddev = statistics.stdev(elbo_ary)
elbo_mean = statistics.mean(elbo_ary)
print("ELBO on %d train data: %.1f +/-%.1f" % (sample_cnt, elbo_mean, elbo_stddev))

**2.2.2.3: Training the network**


In [None]:
%%time
#binarized_mnist_test_loader = DataLoader(binarized_mnist_test_data, batch_size = 64)
trainer = VAE_Trainer(testVAE, binarized_mnist_train_loader_unlabelled, binarized_mnist_test_loader)

testVAE_fn = "./M1_119ls_base_200ep.pt"
if os.path.isfile(testVAE_fn):
    testVAE.load_state_dict(torch.load(testVAE_fn))
    print("Loaded testVAE model from %s" % testVAE_fn)
    trainer.test()
else:
    epochs = 200
    for i in range(epochs):
        print("Training epoch ", i)
        trainer.train()
        print("Testing epoch ", i)
        trainer.test()
    torch.save(testVAE.state_dict(), testVAE_fn)

**2.2.2.4: Generating samples from trained model**

In [None]:
# Decode 64 samples drawn from the prior with the trained VAE.
testVAE.eval()

with torch.no_grad():  # inference only
    prior = testVAE.prior(64)
    prior_sample = testVAE.sample(prior)
    decoded_prior_sample = testVAE.decode(prior_sample)
    sampled_decode_content = decoded_prior_sample.sample().view(64, 28, 28).cpu()

fig, axs = plt.subplots(8, 8, figsize=(10, 10), squeeze=False)
for ax, image in zip(axs.flat, sampled_decode_content):
    ax.imshow(image, cmap='gray')
    ax.axis('off')
plt.tight_layout()

In [None]:
# ELBO over the whole binarized MNIST test set. Binarization is re-sampled on
# every access, so the numbers vary slightly between runs.
sample_cnt = len(binarized_mnist_test_data)  # lower this (e.g. 100) for a quick check
samples_tensor = torch.stack([binarized_mnist_test_data[idx][0] for idx in range(sample_cnt)])
samples_tensor = samples_tensor.view(sample_cnt, -1).to(testVAE.device)

with torch.no_grad():  # single large forward pass: no autograd graph needed
    prior = testVAE.prior(sample_cnt)
    posterior = testVAE.encode(samples_tensor)
    z = testVAE.sample(posterior)
    reconstruction = testVAE.decode(z)
    elbo = testVAE.elbo(prior, posterior, reconstruction, samples_tensor, z)

# 'float64' required because 'stdev' chokes on 'float32' which is the default type when detaching from GPU
elbo_ary = elbo.cpu().numpy().astype('float64')
elbo_stddev = statistics.stdev(elbo_ary)
elbo_mean = statistics.mean(elbo_ary)
print("ELBO on %d test data: %.1f +/-%.1f" % (sample_cnt, elbo_mean, elbo_stddev))

**2.3.1: Extracting 10 samples per class for classification training**

**ToDo:** Consider a more elegant solution for `classification_sampler`.

In [None]:
from torch.utils.data.sampler import SubsetRandomSampler
import functools

def classification_sampler(labels, classes: int = 10, per_class: int = 10):
    """Build a sampler over a small, class-balanced labelled subset.

    Draws ``per_class`` *distinct* indices for each class ``0..classes-1``.
    The returned ``SubsetRandomSampler`` yields exactly those indices, in a
    new random order on every pass.

    Args:
        labels: 1-D tensor or array-like of integer class labels.
        classes: number of classes; labels are assumed to be ``0..classes-1``.
        per_class: number of examples drawn per class.

    Raises:
        ValueError: if some class has fewer than ``per_class`` examples.
    """
    label_array = np.asarray(labels)
    indices = []
    for cls in range(classes):
        class_indices = np.flatnonzero(label_array == cls).tolist()
        # Sample without replacement: random.choices could pick the same image
        # twice and silently shrink the labelled set.
        indices.extend(random.sample(class_indices, k=per_class))
    return SubsetRandomSampler(indices)
    
# Class-balanced labelled subset (see classification_sampler), in batches of 25.
classification_loader = DataLoader(binarized_mnist_train_data, batch_size=25,
                                   sampler=classification_sampler(binarized_mnist_train_data.targets))
# Precision of the test-accuracy estimate vs. evaluation batch size
# (deviation from the full 10k-image baseline):
#  10k: baseline
# 2000: 0.5% 0.9% 1.6%
# 5000: 0.2% 0.9% 1.2%
# Using all test data instead of just 1000 images adds ~2.5 s. Worth the price.
classification_loader_test = DataLoader(binarized_mnist_test_data, shuffle=True, batch_size=10000)

In [None]:
# Visual check that the same 100 pictures are printed, only permuted.
# Evaluate this cell a couple of times.
for images, labels in classification_loader:
    # One column per image, so a final, smaller batch does not index past the end.
    fig, axs = plt.subplots(1, len(images), figsize=(20, 25), squeeze=False)
    for ax, image in zip(axs.flat, images):
        ax.imshow(image, cmap='gray')
        ax.axis('off')
    plt.tight_layout()

**2.3.2: Training classifier on latent representation**

In [None]:
# Define a classification model

class LatentClassifier(nn.Module):
    """Single-layer classifier on a VAE latent code.

    Input: the concatenation of the posterior mean and standard deviation,
    i.e. ``2 * latent_features`` values per row. Output: 10 class scores.
    """

    def __init__(self, latent_features:int) -> None:
        super().__init__()
        width = 2 * latent_features
        layers = [
            # Normalising the inputs raised accuracy from ~50% to ~75%.
            nn.BatchNorm1d(width),
            nn.Linear(in_features=width, out_features=10),
            # Sigmoid rather than ReLU on the output: ReLU was prone to dead
            # outputs, so some classes were never predicted.
            nn.Sigmoid(),
            # Dropout on the output raised accuracy from ~75% to ~83%.
            nn.Dropout(p=0.5),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x) -> Tensor:
        return self.model(x)

In [None]:
# Planned M1 classifier: a VAE paired with a latent-space classifier, meant to
# output logits. UNDER CONSTRUCTION: there is no forward() yet.
class M1Classifier(nn.Module):
    """Holds a VAE and a latent classifier as registered submodules (WIP)."""

    def __init__(self, VAE: nn.Module, LatentClassifier: nn.Module) -> None:
        super().__init__()
        self.VAE, self.LatentClassifier = VAE, LatentClassifier
        

In [None]:
def do_test_eval(cnt, epochs, num_of_evals):
    """Decide whether to run the occasional test evaluation in this epoch.

    Inside a ``for cnt in range(epochs)`` loop this returns True about
    ``num_of_evals`` times: always at the first and last epoch, and at evenly
    spaced epochs in between.

    Args:
        cnt: current epoch index (0-based).
        epochs: total number of epochs.
        num_of_evals: desired number of evaluations. Values below 2 behave
            like 2 (first and last epoch only) instead of dividing by zero.
    """
    if cnt == 0 or cnt == epochs - 1:
        return True
    if num_of_evals <= 2:
        return False
    interval = max(1, math.ceil(epochs / (num_of_evals - 1)))
    return cnt % interval == interval - 1

In [None]:
%%time
testVAE.eval()

testLC = LatentClassifier(5)
testLC.to(testVAE.device)

criterion = nn.CrossEntropyLoss()
#optimizer = optim.SGD(testLC.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(testLC.parameters(), lr=0.005)

epochs = 200
images_test, labels_test = next(iter(classification_loader_test))

for epoch in range(epochs):
    # Training
    testLC.train()
    for images, labels in classification_loader:
        images = images.to(testVAE.device)
        labels = labels.to(testVAE.device)
        outputs = testVAE(images)
        classifier_input = torch.cat((outputs['qz'].mu, outputs['qz'].sigma), 1).to(testVAE.device)
        optimizer.zero_grad()
        classifications = testLC(classifier_input).to(testVAE.device)
        loss = criterion(classifications, labels)
        loss.backward()
        optimizer.step()
    
    # Evaluating
    if do_test_eval(epoch,epochs,10):
        testLC.eval()
        confuse_matrix = np.zeros((10,10)).astype('int')
        outputs = testVAE(images_test)
        classifier_input = torch.cat((outputs['qz'].mu, outputs['qz'].sigma), 1)
        classifications = testLC(classifier_input)
        preds = torch.max(classifications, 1)[1]
        confuse_matrix_update(preds, labels_test, confuse_matrix)
        accuracy = confuse_matrix_accuracy(confuse_matrix)
        print("Epoch %3d, train loss %.3f, test accuracy: %.3f" % (epoch, loss.item(), accuracy))
print()
print(confuse_matrix)

In [None]:
# Continue training with x10 reduced 'lr' (learning rate) and batches of 100.
# Reuse the existing sampler so training continues on the *same* labelled
# images; calling classification_sampler() again would draw a fresh subset
# and quietly enlarge the labelled set.
classification_loader = DataLoader(binarized_mnist_train_data, batch_size=100,
                                   sampler=classification_loader.sampler)
optimizer = optim.Adam(testLC.parameters(), lr=0.0005)
epochs = 200
for epoch in range(epochs):
    # Training (frozen VAE: no autograd graph for its forward pass)
    testLC.train()
    for images, labels in classification_loader:
        images = images.to(testVAE.device)
        labels = labels.to(testVAE.device)
        with torch.no_grad():
            outputs = testVAE(images)
        classifier_input = torch.cat((outputs['qz'].mu, outputs['qz'].sigma), 1)
        optimizer.zero_grad()
        classifications = testLC(classifier_input)
        loss = criterion(classifications, labels)
        loss.backward()
        optimizer.step()

    # Evaluating
    if do_test_eval(epoch, epochs, 10):
        testLC.eval()
        with torch.no_grad():
            outputs = testVAE(images_test)
            classifier_input = torch.cat((outputs['qz'].mu, outputs['qz'].sigma), 1)
            classifications = testLC(classifier_input)
        preds = torch.max(classifications, 1)[1]
        confuse_matrix = np.zeros((10, 10)).astype('int')
        confuse_matrix_update(preds, labels_test, confuse_matrix)
        accuracy = confuse_matrix_accuracy(confuse_matrix)
        print("Epoch %3d, train loss %.3f, test accuracy: %.3f" % (epoch, loss.item(), accuracy))
print()
print(confuse_matrix)

**2.3.3: Classifying MNIST using simple FFNN**

Notes: training results are noisy; rerunning the training can give anywhere from 64% to 69%. At least 3 restarts are therefore needed to determine whether a change was beneficial. With only a single-layer FFNN I have difficulty getting above 69%. The dropout layer on the output improves results from 65% to 69%. This surprises me, because every node in the output layer is necessary: each one is needed to identify one digit.

**2.4: M2 implementation**

In [None]:
class M2_VAE(nn.Module):
    """M2 semi-supervised VAE operating on flattened observations.

    Sub-networks (built with FF_NetworkConstructor, defined earlier in the notebook):
      * classifier         x -> logits of q(y|x)
      * preclass_encoder   x -> hidden features h
      * postclass_encoder  [h, onehot(y)] -> (mu, log_sigma) of q(z|x,y)
      * decoder            [z, onehot(y)] -> Bernoulli logits of p(x|z,y)
    Priors: p(z) is built from the all-zero `prior_params` buffer (mu = 0, log_sigma = 0);
    p(y) is uniform over `classes`.
    """
    def __init__(self, 
                 input_shape: torch.Size,
                 latent_features: int,
                 classes: int, 
                 hidden_layers_preclass: list, 
                 hidden_layers_postclass: list, 
                 hidden_layers_classification: list,
                 hidden_layers_decoder: list):
        """
        Args:
            input_shape: shape of one observation; np.prod(input_shape) is the network input width.
            latent_features: dimensionality of z.
            classes: number of label classes.
            hidden_layers_preclass: hidden widths of the encoder part that only sees x.
                Must be non-empty: its last width is concatenated with onehot(y).
            hidden_layers_postclass: hidden widths of the encoder part after y is appended.
            hidden_layers_classification: hidden widths of the classifier q(y|x).
            hidden_layers_decoder: hidden widths of the decoder p(x|z,y).
        """
        super(M2_VAE, self).__init__()
        
        # Core params
        self.input_shape = input_shape
        self.observation_features = np.prod(input_shape)
        self.register_buffer('prior_params', torch.zeros(torch.Size([1, 2*latent_features])))
        self.register_buffer('classification_prior_params', torch.full(torch.Size([1, classes]), 1 / classes))
        self.classes = classes
        # Having epochs here makes it possible to continue training, with correct epochs counting
        self.epochs = 0
        
        # Cuda enabling
        if torch.cuda.is_available():
            self.use_cuda = True
            self.device = "cuda:0"
        else:
            self.use_cuda = False
            self.device = "cpu"
        
        # Classifier construction: x -> class logits
        classifier_shape = [self.observation_features]
        classifier_shape.extend(hidden_layers_classification)
        classifier_shape.append(classes)
        self.classifier = FF_NetworkConstructor(layers = classifier_shape,
                                                pre_batchnorm = False,
                                                hidden_batchnorm = True,
                                                hidden_activation = nn.ReLU(),
                                                final_activation = None)
        
        # Pre-classification network construction: x -> hidden features
        preclass_shape = [self.observation_features]
        preclass_shape.extend(hidden_layers_preclass)
        self.preclass_encoder = FF_NetworkConstructor(layers = preclass_shape,
                                                      pre_batchnorm = False,
                                                      hidden_batchnorm = True,
                                                      hidden_activation = nn.ReLU(),
                                                      final_activation = nn.ReLU())
        
        # Post-classification network construction: [features, onehot(y)] -> (mu, log_sigma)
        postclass_shape = [hidden_layers_preclass[-1] + classes]
        postclass_shape.extend(hidden_layers_postclass)
        postclass_shape.append(latent_features * 2)
        self.postclass_encoder = FF_NetworkConstructor(layers = postclass_shape,
                                                       pre_batchnorm = False,
                                                       hidden_batchnorm = True,
                                                       hidden_activation = nn.ReLU(),
                                                       final_activation = None)
        
        # Decoder construction: [z, onehot(y)] -> Bernoulli logits
        decoder_shape = [latent_features + classes]
        decoder_shape.extend(hidden_layers_decoder)
        decoder_shape.append(self.observation_features)
        self.decoder = FF_NetworkConstructor(layers = decoder_shape,
                                             pre_batchnorm = False,
                                             hidden_batchnorm = True,
                                             hidden_activation = nn.ReLU(),
                                             final_activation = None)
        
        # Move networks to cuda if available. Only the sub-networks are moved, so the
        # prior buffers stay on the CPU; prior() and loss() move them explicitly.
        if self.use_cuda:
            self.classifier.cuda()
            self.preclass_encoder.cuda()
            self.postclass_encoder.cuda()
            self.decoder.cuda()

    def reset(self) -> None:
        """Re-initialise every layer of each sub-network that has reset_parameters(); zero the epoch count."""
        for layers in self.children():
            for layer in layers:
                if hasattr(layer, 'reset_parameters'):
                    layer.reset_parameters()
        self.epochs = 0

    def prior(self, batch_size: int = 1) -> Distribution:
        """Latent prior p(z), broadcast to `batch_size` rows and moved to self.device."""
        local_prior_params = self.prior_params.expand(batch_size, *self.prior_params.shape[-1:])
        mu, log_sigma = local_prior_params.chunk(2, dim=-1)
        result = ReparameterizedDiagonalGaussian(mu, log_sigma)
        result.mu = result.mu.to(self.device)
        result.sigma = result.sigma.to(self.device)
        return result
    
    def classification_prior(self, batch_size: int = 1) -> Distribution:
        """Uniform class prior p(y) as a Categorical with `batch_size` rows."""
        local_classification_prior_params = self.classification_prior_params.expand(batch_size, 
                                                                                    *self.classification_prior_params.shape[-1:])
        result = Categorical(probs = local_classification_prior_params)
        return result
    
    def classification_posterior(self, x: Tensor) -> Distribution:
        """Class posterior q(y|x) as a Categorical over the classifier logits."""
        result = self.classifier(x)
        result = result.view(-1, self.classes)
        result = Categorical(logits = result)
        return result
    
    def classification_entropy(self, qy: Tensor) -> Tensor:
        """Row-wise entropy H(q) = -sum_k q_k log q_k of a (batch, classes) probability tensor.

        Zero probabilities contribute 0 (the limit of q log q). Clamping inside the log keeps
        value and gradient finite instead of producing 0 * -inf = NaN when a softmax
        probability underflows to exactly 0.
        """
        tiny = torch.finfo(qy.dtype).tiny
        return -(qy * torch.log(qy.clamp(min=tiny))).sum(1)
        
    def encode(self, x: Tensor, y: Tensor = None) -> Tensor:
        """Approximate posterior q(z|x, y) given flattened x and one-hot y."""
        # NOTE(review): when y is None the raw classifier logits (not a one-hot/probability
        # vector) are concatenated; forward() always passes a one-hot y.
        if y is None:
            y = self.classifier(x)
        # Encode input
        result = self.preclass_encoder(x)
        result = torch.cat((result, y), 1)
        result = self.postclass_encoder(result)
        mu, log_sigma =  result.chunk(2, dim=-1)
        return ReparameterizedDiagonalGaussian(mu, log_sigma)
    
    def decode(self, z: Tensor, y: Tensor) -> Distribution:
        """Observation model p(x|z, y) = Bernoulli(logits = decoder([z, y]))."""
        px_logits = self.decoder(torch.cat((z, y), 1))
        px_logits = px_logits.view(-1, *self.input_shape)
        return Bernoulli(logits = px_logits)
    
    def onehot(self, y: Tensor) -> Tensor:
        """One-hot encode integer labels `y` into a float (len(y), classes) tensor on self.device."""
        encoded = nn.functional.one_hot(y.long().reshape(-1), self.classes)
        return encoded.to(device=self.device, dtype=torch.get_default_dtype())
    
    def loss(self,
             px: Distribution, 
             py: Distribution, 
             pz: Distribution, 
             qy: Distribution, 
             qz: Distribution, 
             x: Tensor,
             y: Tensor, 
             z: Tensor,
             alpha: float,
             debug: bool = False) -> Tensor:
        """Per-sample negative M2 objective.

        Labelled (y given): returns (batch, 1) with
            -(log p(x|y,z) + log p(y) + log p(z) - log q(z|x,y)) - alpha * log q(y|x).
        Unlabelled (y is None): px/pz/qz/z are ignored; the model is re-run for every class
        and the (batch,) result is
            U(x) = sum_y q(y|x) L(x, y) - H(q(y|x)).
        """
        x = x.to(self.device)
        z = z.to(self.device)
        x = x.view(x.size(0), -1)
        classes = self.classes
        
        py_logprob = py.logits.to(self.device)  # (batch, classes) log p(y)
        
        if y is None:
            batch_size = x.shape[0]
            # Evaluate every (x_b, y=k) pair in a single forward pass. Rows are class-major:
            # row k * batch_size + b holds sample x_b paired with label k.
            x_all = x.repeat(classes, 1)
            y_all = torch.arange(classes, device=self.device).repeat_interleave(batch_size)
            outputs = self.forward(x_all, y_all)

            def per_pair(logprob: Tensor) -> Tensor:
                # (classes*batch, features) -> (batch, classes) with [b, k] = value for (x_b, y=k).
                # Must be view(classes, batch).t(); view(batch, classes) would mix samples.
                return logprob.sum(dim=1).view(classes, batch_size).t()

            px_logprob = per_pair(outputs['px'].log_prob(x_all))
            qz_logprob = per_pair(outputs['qz'].log_prob(outputs['z']))
            pz_logprob = per_pair(outputs['pz'].log_prob(outputs['z']))
            L = -(px_logprob + py_logprob + pz_logprob - qz_logprob)
            U = -(torch.mul(qy.probs, -L).sum(1) + self.classification_entropy(qy.probs))
            return U

        y = y.to(self.device).view(-1, 1).long()
        px_logprob = px.log_prob(x).sum(dim = 1).view(-1, 1)
        qz_logprob = qz.log_prob(z).sum(dim = 1).view(-1, 1)
        pz_logprob = pz.log_prob(z).sum(dim = 1).view(-1, 1)
        L = -(px_logprob + py_logprob.gather(1, y) + pz_logprob - qz_logprob)
        # Extra classification term: alpha-weighted negative log q(y|x) of the true label
        return L - alpha * qy.logits.gather(1, y)
    
    def forward(self, x: Tensor, y: Tensor = None, debug: bool = False) -> Dict[str, Any]:
        """Run the model; if `y` is None a label is sampled from q(y|x).

        Returns a dict with distributions 'px', 'py', 'pz', 'qy', 'qz' and the latent sample 'z'.
        """
        x = x.to(self.device) # Move to cuda if applicable
        x = x.view(x.size(0), -1) # Flatten image input
        qy = self.classification_posterior(x) # Classification posterior q(y|x)
        py = self.classification_prior(batch_size = x.size(0)) # Classification prior p(y)
        pz = self.prior(batch_size=x.size(0)) # Prior p(z)
        if y is None: # If labels are not provided, sample from classification posterior
            try:
                y = qy.sample()
            except Exception as err:
                print(self.classifier(x))
                # Chain the original error so its cause (e.g. invalid probabilities) is kept
                raise RuntimeError("QY sample error") from err
        
        y = self.onehot(y.to(self.device).int())
        qz = self.encode(x, y) # Approximate posterior q(z|x, y)
        z = qz.rsample() # Sample the posterior
        px = self.decode(z, y) # Reconstruction p(x|z, y) = B(x | g(z, y))
        
        return {'px': px, 'py': py, 'pz': pz, 'qy': qy, 'qz': qz, 'z': z}

In [None]:
class M2_VAE_Trainer:
    """Training/evaluation driver for an M2_VAE.

    Arguments left as None fall back to the notebook globals TestM2,
    binarized_mnist_train_loader_labelled, binarized_mnist_train_loader_unlabelled and
    binarized_mnist_test_loader, looked up when the trainer is constructed.
    train() skips the unlabelled pass when the resolved unlabelled loader is None.
    """
    def __init__(self,
                 network:"M2_VAE"            = None,
                 train_labelled:DataLoader   = None,
                 train_unlabelled:DataLoader = None,
                 test:DataLoader             = None,
                 labelled_passes:int         = 1,
                 alpha:float                 = 0.1,
                 verbose:int                 = 0):
        """
        Args:
            network: model to train (default: global TestM2).
            train_labelled / train_unlabelled / test: iterables of (images, labels) batches.
            labelled_passes: passes over the labelled data per call to train().
            alpha: weight of the classification term in the labelled loss.
            verbose: > 0 additionally prints batch counters in test().
        """
        self.model            = network          if network          is not None else TestM2
        self.train_labelled   = train_labelled   if train_labelled   is not None else binarized_mnist_train_loader_labelled
        self.train_unlabelled = train_unlabelled if train_unlabelled is not None else binarized_mnist_train_loader_unlabelled
        self.test_data        = test             if test             is not None else binarized_mnist_test_loader
        self.alpha = alpha
        self.training_data = defaultdict(list)
        self.validation_data = defaultdict(list)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        self.labelled_passes = labelled_passes
        # Running batch counters (unlabelled / labelled / test), reported when verbose > 0
        self.unlbl_bcnt = 0
        self.lbl_bcnt = 0
        self.tst_bcnt = 0
        self.verbose = verbose

    def _batch_loss(self, images, labels):
        """Mean model loss over one batch; labels=None selects the unlabelled objective."""
        outputs = self.model(images, labels)
        return self.model.loss(outputs['px'], 
                               outputs['py'], 
                               outputs['pz'], 
                               outputs['qy'], 
                               outputs['qz'],
                               images, 
                               labels,
                               outputs['z'],
                               self.alpha).mean()

    def _optimisation_step(self, images, labels):
        """One gradient step of the optimizer on a single batch."""
        self.optimizer.zero_grad()
        loss = self._batch_loss(images, labels)
        loss.backward()
        self.optimizer.step()
    
    def train(self):
        """One epoch: an unlabelled pass (if available), then `labelled_passes` labelled passes."""
        self.model.train()
        if self.train_unlabelled is not None:
            for images, _ in self.train_unlabelled:
                self.unlbl_bcnt += 1
                self._optimisation_step(images, None)  # labels are deliberately withheld
        for _ in range(self.labelled_passes):
            for images, labels in self.train_labelled:
                self.lbl_bcnt += 1
                self._optimisation_step(images, labels)
        self.model.epochs += 1
    
    def test(self):
        """Evaluate mean loss and classifier accuracy on the test data and print them."""
        self.model.eval()
        epoch_losses = []
        confuse_matrix = np.zeros((10,10)).astype('int')
        # Evaluation only: no autograd graph is needed
        with torch.no_grad():
            for images, labels in self.test_data:
                self.tst_bcnt += 1
                epoch_losses.append(self._batch_loss(images, labels).item())

                flat_images = images.view(images.size(0), -1).to(self.model.device)
                preds = torch.argmax(self.model.classifier(flat_images), 1)
                confuse_matrix_update(preds, labels, confuse_matrix)

        if epoch_losses:
            self.validation_data['loss'] += [np.mean(epoch_losses)]
        self.validation_data['class_accu'] += [confuse_matrix_accuracy(confuse_matrix)]

        print("epoch=%d  test_loss=%.2f  class_accu=%.3f" % (self.model.epochs, 
                                                             self.validation_data['loss'][-1], 
                                                             self.validation_data['class_accu'][-1]))
        if self.verbose > 0:
            print("         unlbl_bcnt=%d, lbl_bcnt=%d, tst_bcnt=%d" % (self.unlbl_bcnt,
                                                                        self.lbl_bcnt,
                                                                        self.tst_bcnt))

In [None]:
# Build the M2 model; the input shape is taken from one flattened training image.
sample = binarized_mnist_train_data[22][0]
TestM2 = M2_VAE(input_shape=sample.flatten().shape,
                latent_features=5,
                classes=10,
                hidden_layers_preclass=[200],             # encoder layers before the class is appended
                hidden_layers_postclass=[200],            # encoder layers after the class is appended
                hidden_layers_classification=[200, 100],  # classifier q(y|x)
                hidden_layers_decoder=[200, 400, 600])    # decoder p(x|z,y)

In [None]:
# Saved models: to load a checkpoint instead of training, set fn to one of these files.
#   "M2_axbr1_0p76_10ep.pt"   2020-12-04 11:00, 100 labelled, 0 unlabelled, 10epochs*10passes, 56 sec training
# An empty fn means: train from scratch.
fn = ""

In [None]:
# Configure the data loaders with loader_setup (defined earlier in the notebook).
# NOTE(review): both calls below execute; presumably the second one overrides the loaders
# created by the first, so "100 labelled, no unlabelled" is what is in effect. Confirm this,
# and comment out whichever setup is not wanted.
loader_setup(labelled_size=100)                     # mixed, unlabelled will be 50k - 100
loader_setup(labelled_size=100, unlabelled_size=0)  # skip unlabelled training
#loader_setup(labelled_size=0, unlabelled_size=1000, test_size=10)  # debugging

In [None]:
# Re-initialise every TestM2 layer that has reset_parameters() and set its epoch counter
# back to 0, so the training cell below starts from scratch.
TestM2.reset()

In [None]:
%%time
trainer = M2_VAE_Trainer(labelled_passes = 10)

epochs = 10
if fn != "" and os.path.isfile(fn):
    TestM2.load_state_dict(torch.load(fn))
    print("Loaded model from %s" % fn)
    trainer.test()
else:
    trainer.test()
    for i in range(epochs):
        trainer.train()
        trainer.test()

In [None]:
# Execution speed measurements
# measure    calc.   
# sec        sec   
#  8.4       4.2    10k test
# 44         36     10k unlbl train
# 80,76,75    6.8   10k lbl train

In [None]:
# To save the model trained above, fill in fn with a file name, and document that
# file name in the "Saved models" cell higher up.
fn = ""
if len(fn) > 0:
    torch.save(TestM2.state_dict(), fn)

In [None]:
# Show TestM2's confusion matrix via the print_confuse_matrix helper defined earlier in the
# notebook (presumably evaluated on the test set — confirm against the helper).
print_confuse_matrix(TestM2)

**STYLE TRANSFER**

In [None]:
# Style transfer: encode 10 random training digits using their true labels, then decode
# each digit's latent z with every class label. Column 0 shows the input digit; column
# j+1 shows the decoding of that digit's z with label j.
TestM2.eval()
n_rows = 10
# random.sample draws without replacement, so no training digit appears twice in the grid
picked = [binarized_mnist_train_data[idx]
          for idx in random.sample(range(len(binarized_mnist_train_data)), n_rows)]
img = torch.stack([image for image, _ in picked])
lbl = torch.Tensor([label for _, label in picked])
fig, axs = plt.subplots(n_rows, TestM2.classes + 1, figsize=(10, 10), squeeze=False)

# Inference only: no autograd graph is needed for plotting
with torch.no_grad():
    outputs = TestM2(img, lbl)
    z = outputs['z']
    for i in range(len(img)):
        axs[i, 0].imshow(img[i], cmap='gray')
        axs[i, 0].axis('off')
        for j in range(TestM2.classes):
            px = TestM2.decode(z[i].reshape(1, -1), TestM2.onehot(torch.Tensor([j]).int()))
            axs[i, j + 1].imshow(px.sample().view(*img.shape[1:]).cpu(), cmap='gray')
            axs[i, j + 1].axis('off')

# Column headers: the input, then the class label used for decoding
axs[0, 0].set_title('input')
for j in range(TestM2.classes):
    axs[0, j + 1].set_title(str(j))
plt.tight_layout()