# Variational Continual Learning

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
# visualization tools
import matplotlib
import matplotlib.pyplot as plt
# datasets
import torchvision.datasets as datasets

We begin by defining `VCLModel`, a Bayesian multi-head neural network which updates its parameters using variational inference.

In [0]:
class VCL_NN(nn.Module):
  def __init__(self, input_size: int, out_size: int, layer_width: int, n_hidden_layers: int):
    super(VCL_NN, self).__init__()
    self.input_size = input_size
    self.out_size = out_size
    self.n_hidden_layers = n_hidden_layers
    self.layer_width = layer_width
    self.prior, self.posterior = None, None
    init_prior()
  
  def forward(self, x):
    (w_means, w_vars), (b_means, b_vars) = self.posterior
    w_epsilons = torch.randn_like(w_means)
    b_epsilons = torch.randn_like(b_means)

    sampled_weights = w_means + w_epsilons * torch.sqrt(w_vars)
    sampled_bias = b_means + b_epsilons * torch.sqrt(b_vars)

    for i, layer in enumerated(sampled_weights):
      x = F.relu(layer @ x + sampled_bias[i])
      
    return x

  def prediction(self):
    pass
  
  def calculate_KL_term(self):
    '''Calculates and returns KL(posterior, prior). Formula from L3 slide 14.'''
    # Concatenate w and b statistics into one tensor for ease of calculation
    ((prior_w_means, prior_w_vars), (prior_b_means, prior_b_vars)) = self.prior
    prior_means = torch.cat((prior_w_means, prior_b_means), axis=0)
    prior_vars = torch.cat((prior_w_vars, prior_b_vars), axis=0)
    
    ((post_w_means, post_w_vars), (post_b_means, post_b_vars)) = self.posterior
    post_means = torch.cat((post_w_means, post_b_means), axis=0)
    post_vars = torch.cat((post_w_vars, post_b_vars), axis=0)
    
    # Calculate KL for individual normal distributions over parameters
    KL_elementwise = 
      post_vars / prior_vars +
      torch.pow(prior_means - post_means, 2) / prior_vars
      - 1 + torch.log(prior_vars / post_vars)
    
    # Sum KL over all parameters 
    return 0.5 * KL_elementwise.sum()
    
  
  def loss(self):
    pass
  
  def init_prior(self):
    if self.prior == None:
      w_means = torch.zeros(self.n_hidden_layers + 1, self.layer_width, self.layer_width)
      w_vars = torch.ones(self.n_hidden_layers + 1, self.layer_width, self.layer_width)
      b_means = torch.zeros(self.n_hidden_layers + 1, self.layer_width, self.layer_width)
      b_vars = torch.ones(self.n_hidden_layers + 1, self.layer_width, self.layer_width)
      self.prior = ((w_means, w_vars), (b_means, b_vars))
    else:
      self.prior = self.posterior

# EXPERIMENTS
In this section we replicate the experiments from the original paper. The experiments are divided into discriminative and generative cases.

## Discriminative Experiments
### 1)  Permuted MNIST
In the permuted MNIST experiment, each of the *T* datasets is a copy of the MNIST dataset where the pixels in the image have undergone a random fixed permutation. The original paper compares the performance (average test set accuracy) of VCL to EWC, SI, and diagonal LP.

In [0]:
# download MNIST
mnist_train = datasets.MNIST(root='./data/', train=True, download=True)
mnist_test = datasets.MNIST(root='./data/', train=False, download=True)

mnist_in_dim = 28*28
mnist_out_dim = 10

print(mnist_train[0][1].value())

# transformation that is applied to each new task
perm = torch.randperm(mnist_in_dim)

# dims are the same as in original paper
model = VCL_NN(mnist_in_dim, mnist_out_dim, 100, 2)

# each task is a random permutation of MNIST
for task in range(10):
  mnist_train = None

# test
for task in range(10):
  pass
  

AttributeError: ignored

### 2) Split MNIST
In the split MNIST experiment, the MNIST dataset is divided into five binary classification tasks. The tasks are (0/1), (2/3), (4/5), (6/7), (8/9). Again the original paper compares the *average test set accuracy* of VCL against EWC, SI, and diagonal LP.

In [0]:
# split MNIST
mnist_train = datas

### 3) Split notMNIST
In this experiment, the notMNIST dataset is used, which consists of images of 10 different characters (from A to J). The dataset is then split into five binary classification tasks as it was in the second experiment. The comparison is again *average test set accuracy* against EWC, SI, and diagonal LP.

In [0]:
# split notMNIST


## Generative Experiments