<a href="https://colab.research.google.com/github/dchu1/AI_P2_cl/blob/master/SynapticIntelligence.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
n_tasks = 20
n_epochs = 3
print_messages = False

# Imports

In [3]:
pip install ax-platform

Collecting ax-platform
[?25l  Downloading https://files.pythonhosted.org/packages/c3/e5/defa97540bf23447f15d142a644eed9a9d9fd1925cf1e3c4f47a49282ec0/ax_platform-0.1.9-py3-none-any.whl (499kB)
[K     |████████████████████████████████| 501kB 2.8MB/s 
Collecting botorch==0.2.1
[?25l  Downloading https://files.pythonhosted.org/packages/64/e4/d696b12a84d505e9592fb6f8458a968b19efc22e30cc517dd2d2817e27e4/botorch-0.2.1-py3-none-any.whl (221kB)
[K     |████████████████████████████████| 225kB 32.8MB/s 
Collecting gpytorch>=1.0.0
[?25l  Downloading https://files.pythonhosted.org/packages/9c/5f/ce79e35c1a36deb25a0eac0f67bfe85fb8350eb8e19223950c3d615e5e9a/gpytorch-1.0.1.tar.gz (229kB)
[K     |████████████████████████████████| 235kB 44.0MB/s 
Building wheels for collected packages: gpytorch
  Building wheel for gpytorch (setup.py) ... [?25l[?25hdone
  Created wheel for gpytorch: filename=gpytorch-1.0.1-py2.py3-none-any.whl size=390441 sha256=0995ea6764052cca65a0d4ccb506f7a25d1ae71fd843a3803b

In [0]:
import math

import torch
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn import Module

import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import torch.optim as optim

import torchvision
from torchvision import datasets, transforms

import torch.utils.data as data_utils

import numpy as np
import subprocess
import os
import random
from PIL import Image

import matplotlib.pyplot as plt
from IPython.core.debugger import set_trace

# Constructing Data Set

In [0]:
def rotate_dataset(d, rotation):
  result = torch.FloatTensor(d.size(0), 784)
  tensor = transforms.ToTensor()

  for i in range(d.size(0)):
    img = Image.fromarray(d[i].numpy(), mode="L")
    result[i] = tensor(img.rotate(rotation)).view(784)
  return result

mnist_path = "mnist.npz"
if not os.path.exists(os.path.join("/content", mnist_path)):
  subprocess.call("wget https://s3.amazonaws.com/img-datasets/mnist.npz", shell=True)

f = np.load(mnist_path)
x_tr = torch.from_numpy(f["x_train"])
y_tr = torch.from_numpy(f["y_train"]).long()
x_te = torch.from_numpy(f["x_test"])
y_te = torch.from_numpy(f["y_test"]).long()
f.close()

# Rotate Dataset
tasks_tr = []
tasks_te = []
mnist_rot_path = "mnist_rotations.pt"
if not os.path.exists(os.path.join("/content", mnist_rot_path)):
    torch.manual_seed(0)

    for t in range(n_tasks):
      min_rot = 1.0 * t / n_tasks * (180.0 - 0.0) + 0.0
      max_rot = 1.0 * (t + 1) / n_tasks * (180.0 - 0.0) + 0.0
      rot = random.random() * (max_rot - min_rot) + min_rot

      tasks_tr.append([rot, rotate_dataset(x_tr, rot), y_tr])
      tasks_te.append([rot, rotate_dataset(x_te, rot), y_te])

    torch.save([tasks_tr, tasks_te], 'mnist_rotations.pt')
else:
    tasks_tr, tasks_te = torch.load('/content/mnist_rotations.pt')

# Defining Synaptic Intelligence Model (Simple)

In [0]:
class SIModel(nn.Module):
    def __init__(self):
        super(SIModel, self).__init__()

        # SI Hyperparameters
        self.si_c = 0.           #-> hyperparam: how strong to weigh SI-loss ("regularisation strength")
        self.epsilon = 0.1      #-> dampening parameter: bounds 'omega' when squared parameter-change goes to 0
    
    def init_weights(self, m):
        if type(m) == nn.Linear:
            torch.nn.init.kaiming_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def init(self, n_neurons):
        # Our Network
        self.net = nn.Sequential(
            nn.Linear(28*28, n_neurons),
            nn.ReLU(),
            nn.Linear(n_neurons, n_neurons),
            nn.ReLU(),
            nn.Linear(n_neurons, 10)
        )
        self.net.apply(self.init_weights)
        # self.fc1 = nn.Linear(28*28, n_neurons)
        # self.fc2 = nn.Linear(n_neurons, n_neurons)
        # self.fc3 = nn.Linear(n_neurons, 10)


    def forward(self, x):
        return self.net(x)

        # x = self.fc1(x)
        # x = F.relu(x)
        # x = self.fc2(x)
        # x = F.relu(x)
        # y = self.fc3(x)
        # return y

    def update_omega(self, W):
        '''After completing training on a task, update the per-parameter regularization strength.

        [W]         <dict> estimated parameter-specific contribution to changes in total loss of completed task
        '''

        # Loop over all parameters
        for n, p in self.named_parameters():
            if p.requires_grad:
                n = n.replace('.', '__')

                # Find/calculate new values for quadratic penalty on parameters
                p_prev = getattr(self, '{}_SI_prev_task'.format(n))
                p_current = p.detach().clone()
                p_change = p_current - p_prev
                
                omega_add = W[n]/(p_change**2 + self.epsilon)
                try:
                    omega = getattr(self, '{}_SI_omega'.format(n))
                except AttributeError:
                    omega = p.detach().clone().zero_()
                omega_new = omega + omega_add

                # Store these new values in the model
                self.register_buffer('{}_SI_prev_task'.format(n), p_current)
                self.register_buffer('{}_SI_omega'.format(n), omega_new)

    def surrogate_loss(self):
        '''Calculate SI's surrogate loss.'''
        try:
            losses = []
            for n, p in self.named_parameters():
                if p.requires_grad:
                    # Retrieve previous parameter values and their normalized path integral (i.e., omega)
                    n = n.replace('.', '__')
                    prev_values = getattr(self, '{}_SI_prev_task'.format(n))
                    omega = getattr(self, '{}_SI_omega'.format(n))
                    # Calculate SI's surrogate loss, sum over all parameters
                    losses.append((omega * (p-prev_values)**2).sum())
            return sum(losses)
        except AttributeError:
            # SI-loss is 0 if there is no stored omega yet
            return torch.tensor(0., device=self._device())

    def _device(self):
        return next(self.parameters()).device


# Running our experiment

In [0]:
def train_task(model, device, train_loader, optimizer, batch_log = 0):
    model.train()
    # Prepare <dicts> to store running importance estimates and param-values before update
    W = {}
    param_old = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            name = name.replace('.', '__')
            W[name] = param.data.clone().zero_()
            param_old[name] = param.data.clone()

    losses = []
    total_losses = []
    for k in range(n_epochs):
        if print_messages:
            print("----> Epoch {}:".format(k))
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            # Get the prediction
            y_hat = model(x)

            # Calculate training-precision
            precision = (y == y_hat.max(1)[1]).sum().item() / x.size(0)

            # Calculate the loss using cross entropy
            # and the surrogate loss
            loss = F.cross_entropy(input=y_hat, target=y, reduction='mean')
            surrogate_loss = model.surrogate_loss()
            total_loss = loss + model.si_c * surrogate_loss

            # Backpropagate errors
            total_loss.backward()

            # Take optimization-step
            optimizer.step()

            # Update running parameter importance estimates in W
            for name, param in model.named_parameters():
                if param.requires_grad:
                    name = name.replace('.', '__')
                    if param.grad is not None:
                        W[name].add_(-param.grad*(param.detach()-param_old[name]))
                    param_old[name] = param.detach().clone()

            # Print out a log
            if batch_idx % batch_log == 0:
                losses.append(loss.item())
                total_losses.append(total_loss.item())
                if print_messages:
                    print('---->[{}/{} ({:.0f}%)]\tPrecision: {:.6f}\tLoss: {:.6f}\tSurrogate Loss: {:.6f}\tTotal Loss: {:.6f}'.format(
                        batch_idx * len(x), len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), 
                        precision, loss.item(), surrogate_loss.item(), total_loss.item()))
            
    # After finishing training on a task, update the omega value in the model
    model.update_omega(W)

    return losses, total_losses
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            test_loss += F.cross_entropy(input=y_hat, target=y, reduction='mean')
            pred = y_hat.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(y.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader.dataset)
    return correct, test_loss

In [0]:
config={
    "lr": 0.02, 
    "si_c": 0.9, 
    "si_epsilon": 0.01,
    "optimizer": "adam",
    "batch_size": 64,
    "n_neurons": 100,
    "momentum": 0.4,
    "sample_size": 20000
    }

In [0]:
def main(config): 
    # Use cuda?
    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    # Create Model
    model = SIModel()
    model.init(config["n_neurons"])
    model.to(device)
    optim_list = [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'lr': config['lr'], 'momentum': config['momentum']}]
    if config['optimizer'] == "adam":
        optimizer = optim.Adam(optim_list, betas=(0.9, 0.999))
    else:
        optimizer = optim.SGD(optim_list)

    # SI Parameters
    model.si_c = config["si_c"]
    model.epsilon = config["si_epsilon"]

    for name, param in model.named_parameters():
        if param.requires_grad:
            name = name.replace('.', '__')
            model.register_buffer('{}_SI_prev_task'.format(name), param.data.clone())

    # Load our test data
    test_loaders = []
    for i in range(n_tasks):
        test_loaders.append(data_utils.DataLoader(data_utils.TensorDataset(tasks_te[i][1], tasks_te[i][2]), batch_size=1000, shuffle = False))

    # Training
    if print_messages:
        print("--> Training:")

    total_acc = []
    total_test_losses = []

    for i in range(n_tasks):
        if print_messages:
            print("--> Training Task {}:".format(i))

        perm = np.random.permutation(tasks_tr[i][1].size(0))
        perm = perm[:config['sample_size']]
        train_data = data_utils.TensorDataset(tasks_tr[i][1], tasks_tr[i][2])
        train_loader = data_utils.DataLoader(train_data, batch_size=config["batch_size"], 
                                      sampler = data_utils.SubsetRandomSampler(perm), drop_last = True)
        # train_loader = data_utils.DataLoader(train_data, batch_size=config["batch_size"], 
        #                                      shuffle = True, drop_last = True)
        
        train_losses, total_train_losses = train_task(model, device, train_loader, optimizer, 1000)
        
        # Reset the optimizer (if using adam)
        if config['optimizer'] == "adam":
            model.optimizer = optim.Adam(optim_list, betas=(0.9, 0.999))

        if print_messages:
            print(train_losses)
            print(total_train_losses)
            print("--> Finished Training Task {}. Starting Test phase:".format(i))

        # Get our accuracy metrics on all test sets
        acc = []
        test_losses = []
        for j in range(n_tasks):
            correct, test_loss = test(model, device, test_loaders[j])
            acc.append(100. * correct / len(test_loaders[j].dataset))
            test_losses.append(test_loss)
            if print_messages:
                print('---->Test set {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
                    j, test_loss, correct, len(test_loaders[j].dataset),
                    100. * correct / len(test_loaders[j].dataset)))
        total_acc.append(acc)
        total_test_losses.append(test_losses)
    
    # Get the accuracy metric as defined by Facebook paper: sum(R_Ti) 
    # where T is the test set of the last Task and i is the current trained task
    return total_acc, total_test_losses

In [73]:
total_acc, total_test_losses = main(config)
average_acc = np.mean(total_acc[n_tasks-1])
print("Accuracy:", average_acc)
print("Confusion matrix:")
print('\n'.join([','.join([str(item) for item in row]) for row in total_acc]))

Accuracy: 37.027
Confusion matrix:
93.5,89.8,72.96,64.45,51.03,28.24,24.0,13.58,11.06,9.54,9.65,10.16,11.21,14.64,16.55,19.04,19.43,19.46,21.14,20.1
94.21,92.2,78.85,70.62,56.53,32.0,27.1,14.53,11.55,9.47,9.54,9.64,10.6,14.14,16.79,21.08,22.23,22.57,24.28,23.08
94.12,92.05,79.3,71.56,57.71,33.36,28.03,14.98,11.77,9.64,9.96,9.87,10.78,14.45,16.96,21.05,22.53,22.72,24.82,23.37
94.21,92.18,79.52,71.57,58.07,33.57,28.16,15.02,11.95,9.88,9.95,9.9,10.87,14.75,17.11,21.05,22.5,22.83,24.66,23.18
94.19,92.08,79.55,71.72,58.31,33.81,28.47,14.99,11.86,9.8,9.91,10.0,10.93,14.76,17.08,20.97,22.33,22.66,24.52,23.06
87.89,85.22,75.41,72.73,66.02,56.09,52.0,36.03,29.57,22.64,20.56,17.03,15.77,16.27,18.37,20.82,20.92,21.25,22.57,21.15
86.54,83.93,74.86,72.41,66.48,57.89,54.43,38.7,32.56,24.8,22.54,17.9,16.37,16.26,18.23,20.38,20.26,20.97,22.42,20.92
78.96,73.26,61.48,59.14,56.88,60.6,61.8,59.43,53.63,41.36,38.64,26.82,22.9,18.74,18.75,20.26,18.98,20.3,22.28,20.46
78.36,72.74,60.57,58.4,56.4,60.14,61.55

# Tune Hyperparamters using Ax

In [69]:
def tune(config, objective):
    total_acc, total_loss = main(config)
    if objective == "accuracy":
        return np.mean(total_acc[n_tasks-1])
    elif objective == "loss":
        return np.mean(total_loss[n_tasks-1])
    else:
        return

from ax import optimize
best_parameters, values, experiment, model = optimize(
    parameters=[
        {
            "name": "lr",
            "type": "range",
            "bounds": [1e-6, 0.4], 
            "log_scale": True,
            "value_type": "float",
        },
        {  
            "name": "si_c",
            "type": "range",
            "bounds": [0.01, 1.0],
            "value_type": "float",
        },
        {  
            "name": "si_epsilon",
            "type": "fixed",
            "value": 0.01,
            "value_type": "float",
        },
        {  
            "name": "batch_size",
            "type": "choice",
            "values": [10, 64],
            "value_type": "int",
        },
        {  
            "name": "sample_size",
            "type": "fixed",
            "value": 10000,
            "value_type": "int",
        },
        {  
            "name": "n_neurons",
            "type": "fixed",
            "value": 100,
            "value_type": "int",
        },
        {  
            "name": "momentum",
            "type": "range",
            "bounds": [0., 1.],
            "value_type": "float",
        },
        {  
            "name": "optimizer",
            "type": "fixed",
            "value": "adam",
            "value_type": "str",
        },
    ],
    evaluation_function=lambda p: tune(p, "accuracy"),
    objective_name='accuracy',
)
print(best_parameters)
print(values)
    #evaluation_function=lambda p: np.mean(main(p)[n_tasks-1]),
    #minimize=True,)

[INFO 04-04 22:38:33] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 8 arms, GPEI for subsequent arms], generated 0 arm(s) so far). Iterations after 8 will take longer to generate due to model-fitting.
[INFO 04-04 22:38:33] ax.service.managed_loop: Started full optimization with 20 steps.
[INFO 04-04 22:38:33] ax.service.managed_loop: Running optimization trial 1...


ERROR! Session/line number was not unique in database. History logging moved to new session 60


[INFO 04-04 22:40:09] ax.service.managed_loop: Running optimization trial 2...
[INFO 04-04 22:45:52] ax.service.managed_loop: Running optimization trial 3...
[INFO 04-04 22:47:27] ax.service.managed_loop: Running optimization trial 4...
[INFO 04-04 22:49:02] ax.service.managed_loop: Running optimization trial 5...
[INFO 04-04 22:54:45] ax.service.managed_loop: Running optimization trial 6...
[INFO 04-04 22:56:21] ax.service.managed_loop: Running optimization trial 7...
[INFO 04-04 23:02:05] ax.service.managed_loop: Running optimization trial 8...
[INFO 04-04 23:07:45] ax.service.managed_loop: Running optimization trial 9...
[INFO 04-04 23:09:20] ax.service.managed_loop: Running optimization trial 10...
[INFO 04-04 23:10:55] ax.service.managed_loop: Running optimization trial 11...
[INFO 04-04 23:12:31] ax.service.managed_loop: Running optimization trial 12...
[INFO 04-04 23:14:06] ax.service.managed_loop: Running optimization trial 13...
[INFO 04-04 23:15:42] ax.service.managed_loop: R

{'lr': 0.002161103326179421, 'si_c': 0.9999999999999994, 'momentum': 0.3974544156588803, 'batch_size': 64, 'si_epsilon': 0.01, 'sample_size': 10000, 'n_neurons': 100, 'optimizer': 'adam'}
({'accuracy': 67.21499463717565}, {'accuracy': {'accuracy': 2.4184398641919375e-05}})


In [62]:
values

({'accuracy': 54.37497428669556},
 {'accuracy': {'accuracy': 2.8317326613905426e-05}})