<a href="https://colab.research.google.com/github/dchu1/AI_P2_cl/blob/master/SynapticIntelligence.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import Library

In [0]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torch.utils.data as data_utils
import numpy as np
import subprocess
import os
import random
import matplotlib.pyplot as plt

from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn import Module
from torch.nn import init
from torchvision import datasets, transforms
from PIL import Image
from IPython.core.debugger import set_trace

n_tasks = 20
n_epochs = 3
print_messages = False

# Download package

In [0]:
pip install ax-platform

Collecting ax-platform
[?25l  Downloading https://files.pythonhosted.org/packages/c3/e5/defa97540bf23447f15d142a644eed9a9d9fd1925cf1e3c4f47a49282ec0/ax_platform-0.1.9-py3-none-any.whl (499kB)
[K     |▋                               | 10kB 24.9MB/s eta 0:00:01[K     |█▎                              | 20kB 3.0MB/s eta 0:00:01[K     |██                              | 30kB 3.7MB/s eta 0:00:01[K     |██▋                             | 40kB 2.8MB/s eta 0:00:01[K     |███▎                            | 51kB 3.2MB/s eta 0:00:01[K     |████                            | 61kB 3.8MB/s eta 0:00:01[K     |████▋                           | 71kB 4.1MB/s eta 0:00:01[K     |█████▎                          | 81kB 3.8MB/s eta 0:00:01[K     |██████                          | 92kB 4.3MB/s eta 0:00:01[K     |██████▋                         | 102kB 4.4MB/s eta 0:00:01[K     |███████▏                        | 112kB 4.4MB/s eta 0:00:01[K     |███████▉                        | 122kB 4.4MB/

# Constructing Data Set

In [0]:
def rotate_dataset(d, rotation):
  result = torch.FloatTensor(d.size(0), 784)
  tensor = transforms.ToTensor()

  for i in range(d.size(0)):
    img = Image.fromarray(d[i].numpy(), mode="L")
    result[i] = tensor(img.rotate(rotation)).view(784)
  return result

mnist_path = "mnist.npz"
if not os.path.exists(os.path.join("/content", mnist_path)):
  subprocess.call("wget https://s3.amazonaws.com/img-datasets/mnist.npz", shell=True)

f = np.load(mnist_path)
x_tr = torch.from_numpy(f["x_train"])
y_tr = torch.from_numpy(f["y_train"]).long()
x_te = torch.from_numpy(f["x_test"])
y_te = torch.from_numpy(f["y_test"]).long()
f.close()

# Rotate Dataset
tasks_tr = []
tasks_te = []
mnist_rot_path = "mnist_rotations.pt"
if not os.path.exists(os.path.join("/content", mnist_rot_path)):
    torch.manual_seed(0)

    for t in range(n_tasks):
      min_rot = 1.0 * t / n_tasks * (180.0 - 0.0) + 0.0
      max_rot = 1.0 * (t + 1) / n_tasks * (180.0 - 0.0) + 0.0
      rot = random.random() * (max_rot - min_rot) + min_rot

      tasks_tr.append([rot, rotate_dataset(x_tr, rot), y_tr])
      tasks_te.append([rot, rotate_dataset(x_te, rot), y_te])

    torch.save([tasks_tr, tasks_te], 'mnist_rotations.pt')
else:
    tasks_tr, tasks_te = torch.load('/content/mnist_rotations.pt')

# Defining Synaptic Intelligence Model (Simple)

In [0]:
class SIModel(nn.Module):
    def __init__(self):
        super(SIModel, self).__init__()

        # SI Hyperparameters
        self.si_c = 0.          #-> hyperparam: how strong to weigh SI-loss ("regularisation strength")
        self.epsilon = 0.1      #-> dampening parameter: bounds 'omega' when squared parameter-change goes to 0
    
    def init_weights(self, m):
        if type(m) == nn.Linear:
            torch.nn.init.kaiming_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def init(self, n_neurons):
        # Our Network
        self.net = nn.Sequential(
            nn.Linear(28*28, n_neurons),
            nn.ReLU(),
            nn.Linear(n_neurons, n_neurons),
            nn.ReLU(),
            nn.Linear(n_neurons, 10)
        )
        self.net.apply(self.init_weights)

    def forward(self, x):
        return self.net(x)

    def update_omega(self, W):
        '''
        After completing training on a task, update the per-parameter regularization strength.
        [W] <dict> estimated parameter-specific contribution to changes in total loss of completed task
        '''

        # Loop over all parameters
        for n, p in self.named_parameters():
            if p.requires_grad:
                n = n.replace('.', '__')

                # Find/calculate new values for quadratic penalty on parameters
                p_prev = getattr(self, '{}_SI_prev_task'.format(n))
                p_current = p.detach().clone()
                p_change = p_current - p_prev
                
                omega_add = W[n]/(p_change**2 + self.epsilon)
                try:
                    omega = getattr(self, '{}_SI_omega'.format(n))
                except AttributeError:
                    omega = p.detach().clone().zero_()
                omega_new = omega + omega_add

                # Store these new values in the model
                self.register_buffer('{}_SI_prev_task'.format(n), p_current)
                self.register_buffer('{}_SI_omega'.format(n), omega_new)

    def surrogate_loss(self):
        '''
        Calculate SI's surrogate loss.
        '''
        try:
            losses = []
            for n, p in self.named_parameters():
                if p.requires_grad:
                    # Retrieve previous parameter values and their normalized path integral (i.e., omega)
                    n = n.replace('.', '__')
                    prev_values = getattr(self, '{}_SI_prev_task'.format(n))
                    omega = getattr(self, '{}_SI_omega'.format(n))
                    # Calculate SI's surrogate loss, sum over all parameters
                    losses.append((omega * (p-prev_values)**2).sum())
            return sum(losses)
        except AttributeError:
            # SI-loss is 0 if there is no stored omega yet
            return torch.tensor(0., device=self._device())

    def _device(self):
        return next(self.parameters()).device


# Running our experiment

In [0]:
def train_task(model, device, train_loader, optimizer, batch_log = 0):
    model.train()
    # Prepare <dicts> to store running importance estimates and param-values before update
    W = {}
    param_old = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            name = name.replace('.', '__')
            W[name] = param.data.clone().zero_()
            param_old[name] = param.data.clone()

    losses = []
    total_losses = []
    for k in range(n_epochs):
        if print_messages:
            print("----> Epoch {}:".format(k))
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            # Get the prediction
            y_hat = model(x)

            # Calculate training-precision
            precision = (y == y_hat.max(1)[1]).sum().item() / x.size(0)

            # Calculate the loss using cross entropy
            # and the surrogate loss
            loss = F.cross_entropy(input=y_hat, target=y, reduction='mean')
            surrogate_loss = model.surrogate_loss()
            total_loss = loss + model.si_c * surrogate_loss

            # Backpropagate errors
            total_loss.backward()

            # Take optimization-step
            optimizer.step()

            # Update running parameter importance estimates in W
            # "In practice, we can approximate w as the running sum of the 
            # product of the gradient g(w) and the parameter update" 
            for name, param in model.named_parameters():
                if param.requires_grad:
                    name = name.replace('.', '__')
                    if param.grad is not None:
                        W[name].add_(-param.grad*(param.detach()-param_old[name]))
                    param_old[name] = param.detach().clone()

            # Print out a log
            if batch_idx % batch_log == 0:
                losses.append(loss.item())
                total_losses.append(total_loss.item())
                if print_messages:
                    print('---->[{}/{} ({:.0f}%)]\tPrecision: {:.6f}\tLoss: {:.6f}\tSurrogate Loss: {:.6f}\tTotal Loss: {:.6f}'.format(
                        batch_idx * len(x), len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), 
                        precision, loss.item(), surrogate_loss.item(), total_loss.item()))
            
    # After finishing training on a task, update the omega value in the model
    model.update_omega(W)

    return losses, total_losses
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            test_loss += F.cross_entropy(input=y_hat, target=y, reduction='mean')
            pred = y_hat.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(y.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader.dataset)
    return correct, test_loss

def eval_on_tasks(model, device, test_loaders):
    acc = []
    test_losses = []
    for j in range(n_tasks):
        correct, test_loss = test(model, device, test_loaders[j])
        acc.append(correct / len(test_loaders[j].dataset))
        test_losses.append(test_loss)
        if print_messages:
            print('---->Test set {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
                j, test_loss, correct, len(test_loaders[j].dataset),
                100. * correct / len(test_loaders[j].dataset)))
    return acc, test_losses

In [0]:
config={
    "lr": 0.003, 
    "si_c": 0.152, 
    "si_epsilon": 0.01,
    "optimizer": "adam",
    "batch_size": 64,
    "n_neurons": 100,
    "sample_size": 60000
    }

In [0]:
def main(config): 
    # Use cuda?
    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    # Create Model
    model = SIModel()
    model.init(config["n_neurons"])
    model.to(device)
    optim_list = [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'lr': config['lr']}]
    if config['optimizer'] == "adam":
        optimizer = optim.Adam(optim_list, betas=(0.9, 0.999))
    else:
        optimizer = optim.SGD(optim_list)

    # SI Parameters
    model.si_c = config["si_c"]
    model.epsilon = config["si_epsilon"]

    for name, param in model.named_parameters():
        if param.requires_grad:
            name = name.replace('.', '__')
            model.register_buffer('{}_SI_prev_task'.format(name), param.data.clone())

    # Load our test data
    test_loaders = []
    for i in range(n_tasks):
        test_loaders.append(data_utils.DataLoader(data_utils.TensorDataset(tasks_te[i][1], tasks_te[i][2]), batch_size=1000, shuffle = False))

    # Training
    if print_messages:
        print("--> Training:")

    total_acc = []
    total_test_losses = []
    
    # Before we start training we will get a baseline by evaluating our tasks
    acc, test_losses = eval_on_tasks(model, device, test_loaders)
    total_acc.append(acc)
    total_test_losses.append(test_losses)

    for i in range(n_tasks):
        if print_messages:
            print("--> Training Task {}:".format(i))

        perm = np.random.permutation(tasks_tr[i][1].size(0))
        perm = perm[:config['sample_size']]
        train_data = data_utils.TensorDataset(tasks_tr[i][1], tasks_tr[i][2])
        train_loader = data_utils.DataLoader(train_data, batch_size=config["batch_size"], 
                                      sampler = data_utils.SubsetRandomSampler(perm), drop_last = True)
        
        train_losses, total_train_losses = train_task(model, device, train_loader, optimizer, 100)
        
        # Reset the optimizer (if using adam)
        if config['optimizer'] == "adam":
            model.optimizer = optim.Adam(optim_list, betas=(0.9, 0.999))

        if print_messages:
            print(train_losses)
            print(total_train_losses)
            print("--> Finished Training Task {}. Starting Test phase:".format(i))

        acc, test_losses = eval_on_tasks(model, device, test_loaders)
        total_acc.append(acc)
        total_test_losses.append(test_losses)
    
    return total_acc, total_test_losses

In [0]:
total_acc, total_test_losses = main(config)

# Get the accuracy metric as defined by Facebook paper: sum(R_Ti) 
# where T is the test set of the last Task and i is the current trained task
average_acc = np.mean(total_acc[n_tasks-1])
print("Accuracy:", average_acc)
print("Confusion matrix:")
print('\n'.join([','.join([str(item) for item in row]) for row in total_acc]))

Accuracy: 0.45984500000000006
Confusion matrix:
0.114,0.0904,0.0654,0.0872,0.101,0.1139,0.1074,0.1052,0.1008,0.0843,0.0725,0.0742,0.0696,0.0899,0.0954,0.1263,0.1184,0.1156,0.1087,0.0915
0.9702,0.9319,0.8998,0.7832,0.5837,0.4048,0.3202,0.1784,0.1414,0.1235,0.1045,0.1109,0.1192,0.1766,0.1999,0.2316,0.2382,0.2573,0.2818,0.3039
0.9184,0.9704,0.9641,0.9331,0.8081,0.6314,0.5259,0.2932,0.2136,0.1571,0.1234,0.1099,0.1001,0.1178,0.1322,0.1752,0.178,0.213,0.245,0.2782
0.8964,0.9735,0.9763,0.963,0.8918,0.747,0.6472,0.3599,0.2715,0.2095,0.1554,0.1361,0.1227,0.125,0.1364,0.164,0.1693,0.1966,0.2277,0.272
0.8207,0.9587,0.9708,0.9741,0.9529,0.874,0.7965,0.5234,0.3967,0.2768,0.1662,0.1392,0.1191,0.0956,0.1007,0.135,0.1435,0.1796,0.215,0.2823
0.6099,0.8901,0.9326,0.9693,0.976,0.9531,0.9255,0.6926,0.5328,0.3718,0.2206,0.1744,0.1388,0.1004,0.0969,0.1131,0.1164,0.1418,0.182,0.2375
0.4012,0.675,0.7693,0.8933,0.9637,0.9726,0.9682,0.8499,0.7326,0.5608,0.3774,0.3032,0.2562,0.1593,0.1401,0.125,0.1264,0.1225,0.1

# Tune Hyperparamters using Ax

In [0]:
def tune(config, objective):
    total_acc, total_loss = main(config)
    if objective == "accuracy":
        return np.mean(total_acc[n_tasks-1])
    elif objective == "loss":
        return np.mean(total_loss[n_tasks-1])
    else:
        return

from ax import optimize
best_parameters, values, experiment, model = optimize(
    parameters=[
        {
            "name": "lr",
            "type": "range",
            "bounds": [1e-4, 0.4], 
            "log_scale": True,
            "value_type": "float",
        },
        {  
            "name": "si_c",
            "type": "range",
            "bounds": [0.01, 0.5],
            "value_type": "float",
        },
        {  
            "name": "si_epsilon",
            "type": "fixed",
            "value": 0.01,
            "value_type": "float",
        },
        {  
            "name": "batch_size",
            "type": "fixed",
            "value": 64,
            "value_type": "int",
        },
        {  
            "name": "sample_size",
            "type": "fixed",
            "value": 10000,
            "value_type": "int",
        },
        {  
            "name": "n_neurons",
            "type": "fixed",
            "value": 100,
            "value_type": "int",
        },
        {  
            "name": "optimizer",
            "type": "fixed",
            "value": "adam",
            "value_type": "str",
        },
    ],
    evaluation_function=lambda p: tune(p, "accuracy"),
    objective_name='accuracy',
)
print(best_parameters)
print(values)
    #evaluation_function=lambda p: np.mean(main(p)[n_tasks-1]),
    #minimize=True,)

[INFO 04-05 20:41:10] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 7 arms, GPEI for subsequent arms], generated 0 arm(s) so far). Iterations after 7 will take longer to generate due to model-fitting.
[INFO 04-05 20:41:10] ax.service.managed_loop: Started full optimization with 20 steps.
[INFO 04-05 20:41:10] ax.service.managed_loop: Running optimization trial 1...
[INFO 04-05 20:42:21] ax.service.managed_loop: Running optimization trial 2...
[INFO 04-05 20:43:32] ax.service.managed_loop: Running optimization trial 3...
[INFO 04-05 20:44:44] ax.service.managed_loop: Running optimization trial 4...
[INFO 04-05 20:45:55] ax.service.managed_loop: Running optimization trial 5...
[INFO 04-05 20:47:06] ax.service.managed_loop: Running optimization trial 6...
[INFO 04-05 20:48:17] ax.service.managed_loop: Running optimization trial 7...
[INFO 04-05 20:49:28] ax.service.managed_loop: Running optimization t

{'lr': 0.002969074196382516, 'si_c': 0.441531161522652, 'si_epsilon': 0.01, 'batch_size': 64, 'sample_size': 10000, 'n_neurons': 100, 'optimizer': 'adam'}
({'accuracy': 0.6400499142783531}, {'accuracy': {'accuracy': 1.1108064544778255e-09}})


In [0]:
best_parameters, values, experiment, model = optimize(
    parameters=[
        {
            "name": "lr",
            "type": "range",
            "bounds": [1e-4, 0.4], 
            "log_scale": True,
            "value_type": "float",
        },
        {  
            "name": "si_c",
            "type": "fixed",
            "value": 0.152,
            "value_type": "float",
        },
        {  
            "name": "si_epsilon",
            "type": "fixed",
            "value": 0.01,
            "value_type": "float",
        },
        {  
            "name": "batch_size",
            "type": "choice",
            "values": [64, 128, 256],
            "value_type": "int",
        },
        {  
            "name": "sample_size",
            "type": "choice",
            "values": [1000, 5000, 10000, 20000, 40000, 60000],
            "value_type": "int",
        },
        {  
            "name": "n_neurons",
            "type": "fixed",
            "value": 100,
            "value_type": "int",
        },
        {  
            "name": "optimizer",
            "type": "fixed",
            "value": "adam",
            "value_type": "str",
        },
    ],
    evaluation_function=lambda p: tune(p, "accuracy"),
    objective_name='accuracy',
)
print(best_parameters)
print(values)

[INFO 04-05 21:05:27] ax.modelbridge.dispatch_utils: Using Sobol generation strategy.
[INFO 04-05 21:05:27] ax.service.managed_loop: Started full optimization with 20 steps.
[INFO 04-05 21:05:27] ax.service.managed_loop: Running optimization trial 1...
[INFO 04-05 21:06:12] ax.service.managed_loop: Running optimization trial 2...
[INFO 04-05 21:06:51] ax.service.managed_loop: Running optimization trial 3...
[INFO 04-05 21:08:05] ax.service.managed_loop: Running optimization trial 4...
[INFO 04-05 21:08:57] ax.service.managed_loop: Running optimization trial 5...
[INFO 04-05 21:11:03] ax.service.managed_loop: Running optimization trial 6...
[INFO 04-05 21:13:53] ax.service.managed_loop: Running optimization trial 7...
[INFO 04-05 21:14:45] ax.service.managed_loop: Running optimization trial 8...
[INFO 04-05 21:15:29] ax.service.managed_loop: Running optimization trial 9...
[INFO 04-05 21:17:34] ax.service.managed_loop: Running optimization trial 10...
[INFO 04-05 21:20:38] ax.service.ma

{'lr': 0.004312738958492966, 'batch_size': 256, 'sample_size': 60000, 'si_c': 0.152, 'si_epsilon': 0.01, 'n_neurons': 100, 'optimizer': 'adam'}
({'accuracy': 0.551365}, {'accuracy': {'accuracy': 0.0}})
