<a href="https://colab.research.google.com/github/VatsalRaina/variational_continual_learning/blob/master/VCL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [75]:
!pip -q install blitz-bayesian-pytorch
!pip -q install transformers

In [76]:
!wget "https://github.com/VatsalRaina/variational_continual_learning/raw/main/discriminative/permutedMNIST/mnist.pkl.gz"

--2021-03-11 09:50:49--  https://github.com/VatsalRaina/variational_continual_learning/raw/main/discriminative/permutedMNIST/mnist.pkl.gz
Resolving github.com (github.com)... 140.82.114.4
Connecting to github.com (github.com)|140.82.114.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/VatsalRaina/variational_continual_learning/main/discriminative/permutedMNIST/mnist.pkl.gz [following]
--2021-03-11 09:50:49--  https://raw.githubusercontent.com/VatsalRaina/variational_continual_learning/main/discriminative/permutedMNIST/mnist.pkl.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 16168813 (15M) [application/octet-stream]
Saving to: ‘mnist.pkl.gz.1’


2021-03-11 09:50:49 (37.7 MB/s) - ‘mnist.pkl.

In [127]:
import torch
import math
from torch.nn.parameter import Parameter
class VCL_layer(torch.nn.Module):
    def __init__(self, input_size: int, output_size: int, init_variance: float, previous_W_b:None):
        super().__init__()
        self.epsilon = 1e-8
        self.input_size = input_size
        self.output_size = output_size
        self.init_variance = init_variance

        if previous_W_b !=None:
            self.register_buffer('prior_W_mean', torch.from_numpy(previous_W_b[0])) # Reversed dimensions for torch.nn.functional.linear
            self.register_buffer('prior_b_mean', torch.from_numpy(previous_W_b[1]))
        else:
            self.register_buffer('prior_W_mean', torch.randn((output_size, input_size))/10) # Reversed dimensions for torch.nn.functional.linear
            self.register_buffer('prior_b_mean', torch.randn(output_size)/10)
        self.register_buffer('prior_W_logvar', torch.ones(self.output_size,self.input_size)*init_variance)
        self.register_buffer('prior_b_logvar', torch.ones(self.output_size)*init_variance)

        if previous_W_b !=None:
            self.posterior_W_mean = Parameter(torch.tensor(previous_W_b[0], requires_grad=True))
            self.posterior_b_mean = Parameter(torch.tensor(previous_W_b[1], requires_grad=True))
        else:
            self.posterior_W_mean = Parameter(torch.randn((output_size, input_size), requires_grad=True)/10)
            self.posterior_b_mean = Parameter(torch.randn((output_size), requires_grad=True)/10)
        self.posterior_W_logvar = Parameter(torch.nn.init.constant(torch.empty((self.output_size,self.input_size), requires_grad=True), math.log(self.init_variance)))
        self.posterior_b_logvar = Parameter(torch.nn.init.constant(torch.empty(self.output_size,requires_grad=True), math.log(self.init_variance)))
        self.register_parameter('posterior_W_mean', self.posterior_W_mean)
        self.register_parameter('posterior_b_mean', self.posterior_b_mean)
        self.register_parameter('posterior_W_logvar', self.posterior_W_logvar)
        self.register_parameter('posterior_b_logvar', self.posterior_b_logvar)

    def sample_parameters(self):
        epsilon_W = torch.randn_like(self.posterior_W_mean)
        epsilon_b = torch.randn_like(self.posterior_b_mean)
        W_sample = self.posterior_W_mean + epsilon_W * torch.exp(0.5 * self.posterior_W_logvar) # Element-wise multiplication of epsilon with variance
        b_sample = self.posterior_b_mean + epsilon_b * torch.exp(0.5 * self.posterior_b_logvar) 
        return W_sample, b_sample

    def forward(self, x):
        W, b = self.sample_parameters()
        return torch.nn.functional.linear(x, W, b) # No activation function here, will be managed in main model

    def kl_divergence(self):
        #TODO: redo the demonstration of this
        prior_means = torch.autograd.Variable(torch.cat(
            (torch.reshape(self.prior_W_mean, (-1,)),
             torch.reshape(self.prior_b_mean, (-1,)))),
            requires_grad=False
        )
        prior_logvars = torch.autograd.Variable(torch.cat(
            (torch.reshape(self.prior_W_logvar, (-1,)),
             torch.reshape(self.prior_b_logvar, (-1,)))),
            requires_grad=False
        )
        prior_vars = torch.exp(prior_logvars)

        posterior_means = torch.cat(
            (torch.reshape(self.posterior_W_mean, (-1,)),
             torch.reshape(self.posterior_b_mean, (-1,))),
        )
        posterior_logvars = torch.cat(
            (torch.reshape(self.posterior_W_logvar, (-1,)),
             torch.reshape(self.posterior_b_logvar, (-1,))),
        )
        posterior_vars = torch.exp(posterior_logvars)

        # compute kl divergence (this computation is valid for multivariate diagonal Gaussians)
        kl_elementwise = posterior_vars / (prior_vars + self.epsilon) + \
                         torch.pow(prior_means - posterior_means, 2) / (prior_vars + self.epsilon) - \
                         1 + prior_logvars - posterior_logvars
        return 0.5 * kl_elementwise.sum()
    
    def update_prior_posterior(self):
        """The previous posterior becomes the new prior"""
        self._buffers['prior_W_mean'].data.copy_(self.posterior_W_mean.data)
        self._buffers['prior_W_logvar'].data.copy_(self.posterior_W_logvar.data)
        self._buffers['prior_b_mean'].data.copy_(self.posterior_b_mean.data)
        self._buffers['prior_b_logvar'].data.copy_(self.posterior_b_logvar.data)

Class to store VCL models

In [128]:
import torch
import torchvision.models as models

from blitz.modules import BayesianLinear
from blitz.utils import variational_estimator

class Vanilla_NN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, n_hidden_layers=2):

        super(Vanilla_NN, self).__init__()

        self.relu = torch.nn.ReLU()
        self.input_layer = torch.nn.Linear(input_dim, hidden_dim)
        self.hidden_layers = []
        self.n_hidden_layers = n_hidden_layers
        for i in range(n_hidden_layers):
            self.hidden_layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
        self.output_layer = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        h1 = self.relu(self.input_layer(x))
        for hidden_layer in self.hidden_layers:
            h2 = self.relu(hidden_layer(h1))
            h1=h2
        prediction_logits = self.output_layer(h2)
        
        return prediction_logits

    
    def get_parameters(self):
        """ Returns model weights and biases as a dictionnary: 
        dic['input'] = (W_mean_input, b_mean_input)
        """
        input_W_b = (self.input_layer.weight.detach().numpy(),self.input_layer.bias.detach().numpy())
        layers_W_b = [(self.hidden_layers[i].weight.detach().numpy(),self.hidden_layers[i].bias.detach().numpy()) for i in range(self.n_hidden_layers)]
        output_W_b = (self.output_layer.weight.detach().numpy(),self.output_layer.bias.detach().numpy())
        dic = {'input':input_W_b, 'layers':layers_W_b, 'output':output_W_b}
        return dic



# See explanation at:
# https://towardsdatascience.com/blitz-a-bayesian-neural-network-library-for-pytorch-82f9998916c7
# Original code at:
# https://github.com/piEsposito/blitz-bayesian-deep-learning 
@variational_estimator
class MFVI_NN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, n_heads, prev_weights, n_hidden_layers=2):

        super(MFVI_NN, self).__init__()

        self.relu = torch.nn.ReLU()
        self.input_layer = BayesianLinear(input_dim, hidden_dim)
        self.hidden_layers = []
        for i in range(n_hidden_layers):
            self.hidden_layers.append(BayesianLinear(hidden_dim, hidden_dim))
        self.heads = []
        for i in range(n_heads):
            self.heads.append(BayesianLinear(hidden_dim, output_dim))

        # Initialise using the Vanilla neural network weights when the model is first initialised
        # self.init_weights(prev_weights)

    def init_weights(self, prev_weights):
        """
        Initialise using Vanilla neural netwrok parameters for the means and a pre-decided variance
        """
        # Instead of initialising using Vanilla NN, this model can be trained with all the variances for the first task
        pass

    def forward(self, x, task):

        h1 = self.relu(self.input_layer(x))
        for hidden_layer in self.hidden_layers:
            h2 = self.relu(hidden_layer(h1))
            h1=h2
        prediction_logits = self.heads[task](h2)
        
        return prediction_logits




class VCL_discriminative(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, n_heads, prev_weights, n_hidden_layers=3):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.n_hidden_layers = n_hidden_layers
        self.n_heads = n_heads
        self.init_variance = 1e-3

        self.relu = torch.nn.ReLU()
        self.input_layer = VCL_layer(input_dim, hidden_dim, self.init_variance, previous_W_b=prev_weights['input'])
        self.hidden_layers = torch.nn.ModuleList([VCL_layer(hidden_dim, hidden_dim, self.init_variance,previous_W_b=prev_weights['layers'][i]) for i in range(n_hidden_layers)])

        self.heads = torch.nn.ModuleList([VCL_layer(hidden_dim, output_dim, self.init_variance, previous_W_b=prev_weights['output']) for _ in range(n_heads)])

        self.softmax = torch.nn.Softmax(dim=1)
        return

    def forward(self, x, task:int):
        x = self.relu(self.input_layer(x))
        for layer in self.hidden_layers:
            x = self.relu(layer(x))
        x = self.softmax(self.heads[task](x))
        return x

    def vcl_loss(self, x, y, task):
        return self.kl_divergence(task) - torch.nn.NLLLoss()(self(x, task), y)

    def kl_divergence(self, task:int):
        div = torch.zeros(1, requires_grad=False)
        for layer in self.hidden_layers:
            div = torch.add(div, layer.kl_divergence())
        div = torch.add(div, self.heads[task].kl_divergence())
        return div

    def update_prior_posterior(self, head:int):
        for layer in self.hidden_layers:
            layer.update_prior_posterior()
        self.heads[head].update_prior_posterior()
        return

    def prediction(self, x, head:int):
        return torch.argmax(self(x, task))


In [129]:
import gzip
import pickle 
from copy import deepcopy
from PIL import Image
import numpy as np

class PermutedMnistGenerator():
    def __init__(self, num_tasks=10):
        #Unzipping and reading Compressex MNIST DATA
        f = gzip.open('mnist.pkl.gz', 'rb')
        u = pickle._Unpickler( f )
        u.encoding = 'latin1'
        train_set, valid_set, test_set = u.load()
        f.close()

        self.X_train = np.vstack((train_set[0], valid_set[0]))
        self.Y_train = np.hstack((train_set[1], valid_set[1]))
        self.X_test = test_set[0]
        self.Y_test = test_set[1]

    def create_tasks(self, num_tasks=10):
        np.random.seed(0)

        X_train, Y_train, X_test, Y_test = [], [], [] ,[]

        for i in range(num_tasks):
            x_train, y_train, x_test, y_test = self.generate_new_task()
            X_train.append(x_train)
            Y_train.append(y_train)
            X_test.append(x_test)
            Y_test.append(y_test)

        return (X_train, Y_train, X_test, Y_test)

    def print_example(self, examples=[0]):
        for example in examples:
            array = self.X_train[example]
            array_2D = np.reshape(array, (28, 28))
            img = Image.fromarray(np.uint8(array_2D * 255) , 'L')
            img.show()

    def generate_new_task(self):
        perm_inds = list(range(self.X_train.shape[1]))
        np.random.shuffle(perm_inds)

        # Retrieve train data
        x_train = deepcopy(self.X_train)
        x_train = x_train[:,perm_inds]
        # y_train = np.eye(10)[self.Y_train]   #One hot encodes labels
        y_train = self.Y_train

        # Retrieve test data
        x_test = deepcopy(self.X_test)
        x_test = x_test[:,perm_inds]
        # y_test = np.eye(10)[self.Y_test]
        y_test = self.Y_test

        return x_train, y_train, x_test, y_test


In [130]:
import time
import numpy as np
import random
import datetime

def format_time(elapsed):
    elapsed_rounded = int(round((elapsed)))
    return str(datetime.timedelta(seconds=elapsed_rounded))

def get_default_device():
    if torch.cuda.is_available():
        print("Got CUDA!")
        return torch.device('cuda')
    else:
        return torch.device('cpu')
############################ Handle to use fron scrach model or not ######################
###########################################################################################
use_from_scratch_model = use_from_scratch_model
# Set the seed value all over the place to make this reproducible.
seed_val = seed
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

# Choose device
device = get_default_device()


In [131]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

data_processor = PermutedMnistGenerator(num_tasks)
X_train, Y_train, X_test, Y_test = data_processor.create_tasks(num_tasks)
x_train, y_train = torch.tensor(X_train[0]).to(device), torch.tensor(Y_train[0]).long().to(device)
train_data = TensorDataset(x_train, y_train)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

In [132]:
hidden_size=100
n_hidden=2
batch_size=256
no_epochs=100
num_tasks=5
coreset_size=0
adam_epsilon=1e-8
learning_rate=2e-3
use_from_scratch_model=True
seed = 2

In [137]:
from transformers import AdamW, get_linear_schedule_with_warmup

################## Train Vanilla_NN using data for first task ######################

vanilla_model = Vanilla_NN(input_dim=x_train.size()[1], hidden_dim=hidden_size, output_dim=10, n_hidden_layers=n_hidden).to(device)
mf_weights = vanilla_model.get_parameters()

optimizer = AdamW(vanilla_model.parameters(),
                lr = learning_rate,
                eps = adam_epsilon
                )
loss_values = []
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(5):
    print(f"\n ======== Epoch {epoch + 1} / {no_epochs} ========")
    print('Training...')
    t0 = time.time()
    total_loss = 0
    vanilla_model.train()
    vanilla_model.zero_grad()
    for step, batch in enumerate(train_dataloader):
        # Progress update every 40 batches.
        if step % 40 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))
        b_x = batch[0].to(device)
        b_y = batch[1].to(device)
        vanilla_model.zero_grad()
        logits = vanilla_model(b_x)
        loss = criterion(logits, b_y)
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        # Clip the norm of the gradients to 1.0.
        # This is to help prevent the "exploding gradients" problem.
        torch.nn.utils.clip_grad_norm_(vanilla_model.parameters(), 1.0)
        optimizer.step()
    # Calculate the average loss over the training data.
    avg_train_loss = total_loss / len(train_dataloader)

    # Store the loss value for plotting the learning curve.
    loss_values.append(avg_train_loss)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epoch took: {:}".format(format_time(time.time() - t0)))

vanilla_weights = vanilla_model.get_parameters()


Training...
  Batch    40  of    235.    Elapsed: 0:00:00.
  Batch    80  of    235.    Elapsed: 0:00:01.
  Batch   120  of    235.    Elapsed: 0:00:01.
  Batch   160  of    235.    Elapsed: 0:00:02.
  Batch   200  of    235.    Elapsed: 0:00:02.

  Average training loss: 0.64
  Training epoch took: 0:00:02

Training...
  Batch    40  of    235.    Elapsed: 0:00:00.
  Batch    80  of    235.    Elapsed: 0:00:01.
  Batch   120  of    235.    Elapsed: 0:00:01.
  Batch   160  of    235.    Elapsed: 0:00:02.
  Batch   200  of    235.    Elapsed: 0:00:02.

  Average training loss: 0.23
  Training epoch took: 0:00:02

Training...
  Batch    40  of    235.    Elapsed: 0:00:00.
  Batch    80  of    235.    Elapsed: 0:00:01.
  Batch   120  of    235.    Elapsed: 0:00:01.
  Batch   160  of    235.    Elapsed: 0:00:02.
  Batch   200  of    235.    Elapsed: 0:00:02.

  Average training loss: 0.18
  Training epoch took: 0:00:02

Training...
  Batch    40  of    235.    Elapsed: 0:00:00.
  Batch   

In [139]:
# Now we are at a stage where we can extract the weights from the above trained model and call them the W_means and b_means
init_weights = vanilla_weights
init_weights = {'input':None, 'layers':[None]*n_hidden, 'output':None}

################## Train MFVI NN #######################

if use_from_scratch_model:
    model = VCL_discriminative(input_dim = x_train.size()[1], hidden_dim=hidden_size, output_dim=10, \
                               n_heads=num_tasks, prev_weights=init_weights, n_hidden_layers=n_hidden).to(device)
    optimizer = AdamW(model.parameters(), lr = learning_rate, eps = adam_epsilon)

else:
    model = MFVI_NN(in_dim=x_train.size()[1], hidden_dim=hidden_size, out_dim=10, num_tasks=num_tasks, \
                    prev_weights=vanilla_weights, n_hidden_layers=n_hidden).to(device)
    criterion = torch.nn.CrossEntropyLoss()


for task_id in range(num_tasks):
    # Extract task specific data
    x_train, y_train = torch.tensor(X_train[task_id]).to(device), torch.tensor(Y_train[task_id]).long().to(device)
    train_data = TensorDataset(x_train, y_train)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
    # Set optimizer to be equal to all the shared parameters and the task specific head's parameters
    if not use_from_scratch_model:
        parameters = []
        parameters.extend(model.inputLayer.parameters())
        for hiddenlayer in model.hiddenLayers:
            parameters.extend(hiddenlayer.parameters())

        parameters.extend(model.outputHeads[task_id].parameters())
        optimizer = AdamW(parameters, lr = learning_rate, eps = adam_epsilon)
    loss_values = []

    model.train()
    for epoch in range(no_epochs):
        total_loss = 0
        for step, batch in enumerate(train_dataloader):
            b_x = batch[0].to(device)
            b_y = batch[1].to(device)
            model.zero_grad()
            optimizer.zero_grad()

            if not use_from_scratch_model:
                prediction_logits = model(b_x, task_id)
                fit_loss = criterion(prediction_logits, b_y)
                # This is an inbuilt function for the imported BNN
                # However, this KL term is finding the KL divergence between the setting of parameters in the current and previous mini-batch
                # We are actually interested in finding the KL divergence between the setting of the parameters in the current mini-batch 
                # and the the final setting of the parameters from the previous TASK
                # So we will need to write our own KL divergence function which finds KL only for the shared parameters
                complexity_loss = model.nn_kl_divergence()  
                loss = fit_loss + complexity_loss
            
            if use_from_scratch_model:
                loss = model.vcl_loss(b_x, b_y, task_id)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
        avg_train_loss = total_loss / len(train_dataloader)
        loss_values.append(avg_train_loss)
        print(f"\n  Average training loss: {avg_train_loss:.2f}")

    # Now perform evaluation on the test data
    x_test, y_test = X_test[task_id], Y_test[task_id]
    #TODO 




  Average training loss: 67539.39


KeyboardInterrupt: ignored