In [None]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from tqdm import tqdm

import load_split_mnist


In [None]:
# --- hyperparameters ---------------------------------------------------------
BATCHSIZE = 256            # batch size of the train / test loaders
LR = 0.01                  # SGD learning rate
MOMENTUM = 0.9             # SGD momentum
EPOCH = 10                 # number of epochs per task
GRAD_EST_BATCHSIZE = 32    # batch size used when collecting gradients for the subspace estimate
GRAD_EST_EPOCHS = 1        # passes over the task-1 data when collecting gradients
EWC_LAMBDA = 1000000.      # weight of the quadratic EWC penalty
ROTATED = True             # True: penalise in the PCA-rotated basis, False: in the raw weight basis

# Prefer Apple-silicon MPS (the original hard-coded choice) but fall back to
# CUDA / CPU so the notebook also runs on machines without MPS.
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

SEED = 42

# seed all the things
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
random.seed(SEED)

In [None]:
# load datasets
# The result of load_split_mnist.load() is unpacked as (train, test); each of
# those is unpacked again into one dataset per task (t1, t2).
train, test = load_split_mnist.load()

train_t1, train_t2 = train
test_t1, test_t2 = test

# dataloaders
# shuffled loader used for the actual task-1 training
train_loader_t1 = torch.utils.data.DataLoader(
    train_t1, batch_size=BATCHSIZE, shuffle=True
)
# unshuffled task-1 loader with the smaller GRAD_EST_BATCHSIZE; used later to
# collect one gradient per mini-batch for the subspace estimate
train_loader_subspace_est_t1 = torch.utils.data.DataLoader(
    train_t1, batch_size=GRAD_EST_BATCHSIZE, shuffle=False
)

# shuffled loader used for task-2 training
train_loader_t2 = torch.utils.data.DataLoader(
    train_t2, batch_size=BATCHSIZE, shuffle=True
)
test_loader_t1 = torch.utils.data.DataLoader(
    test_t1, batch_size=BATCHSIZE, shuffle=False
)  # test data must not be shuffled!!! Otherwise embeddings are not paired
test_loader_t2 = torch.utils.data.DataLoader(
    test_t2, batch_size=BATCHSIZE, shuffle=False
)

# %%

# get dataset shapes for train sets
n_train_t1 = len(train_t1)
n_train_t2 = len(train_t2)

print(f"n_train_t1: {n_train_t1}")
print(f"n_train_t2: {n_train_t2}")



In [None]:

def label_smoothing(y, alpha, num_classes=5):
    """Turn integer class labels into label-smoothed target distributions.

    Args:
        y: integer tensor of class indices (any shape, including 0-dim).
        alpha: smoothing strength; 0 gives plain one-hot targets.
        num_classes: number of classes. Defaults to 5, the value that was
            previously hard-coded.

    Returns:
        Float tensor of shape ``y.shape + (num_classes,)`` on ``y``'s device.
        The true class gets ``1 - alpha + alpha / num_classes`` and every
        other class gets ``alpha / num_classes``.
    """
    # convert y to one hot by indexing rows of an identity matrix
    one_hot = torch.eye(num_classes, device=y.device)[y]
    # divide by num_classes rather than one_hot.size(1) so that scalar labels work too
    return one_hot * (1 - alpha) + alpha / num_classes


In [None]:
# define a small CNN (two conv + max-pool stages, one hidden linear layer) with one linear readout head per task

class CNN(nn.Module):
    """Bias-free convolutional network with a shared trunk and two readout heads.

    Trunk: conv(1->16, 3x3) -> ReLU -> max-pool(2) -> conv(16->16, 3x3) -> ReLU
    -> max-pool(2) -> flatten to 400 features -> linear(400->hidden_dim) -> ReLU.
    The flatten size of 400 matches 16 channels of 5x5, which is what a
    1x28x28 input produces with these layers.

    Args:
        input_dim: accepted for signature compatibility; not used by the layers.
        hidden_dim: width of the hidden fully connected layer.
        output_dim: number of outputs of each readout head.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # shared feature extractor
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, bias=False)
        self.mpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, bias=False)
        self.mpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=400, out_features=hidden_dim, bias=False)

        # one linear head per task; the ModuleList lets forward() pick a head by index
        self.readout1 = nn.Linear(in_features=hidden_dim, out_features=output_dim, bias=False)
        self.readout2 = nn.Linear(in_features=hidden_dim, out_features=output_dim, bias=False)
        self.readouts = nn.ModuleList([self.readout1, self.readout2])

    def forward(self, x, t):
        """Return the logits of readout head ``t`` for the image batch ``x``."""
        features = self.mpool1(F.relu(self.conv1(x)))
        features = self.mpool2(F.relu(self.conv2(features)))
        hidden = F.relu(self.fc1(features.view(-1, 400)))
        head = self.readouts[t]
        return head(hidden)
    
# instantiate the network: hidden width 128, 5 outputs per readout head
cnn = CNN(28*28, 128, 5)
# NOTE(review): some torchsummary releases require an input size in addition to
# the model — confirm the installed version accepts a bare model.
summary(cnn)
cnn.to(DEVICE)

# optimizer and loss function
optimizer = torch.optim.SGD(cnn.parameters(), lr=LR, momentum=MOMENTUM)
loss_func = nn.CrossEntropyLoss()

In [None]:
# train step
def train_step(model, optimizer, loss_func, x, y, t):
    """Run one optimisation step on a single batch and return its loss.

    Args:
        model: module called as ``model(x, t)``.
        optimizer: optimizer over ``model``'s parameters.
        loss_func: callable ``loss_func(output, y)`` returning a scalar tensor.
        x, y: input and target batch (moved to the model's device here).
        t: task index, as an int or a 0-dim tensor.

    Returns:
        The batch loss as a Python float.
    """
    # x,y,t to model device (taken from the model itself rather than a global,
    # so the function works wherever the model lives)
    device = next(model.parameters()).device
    x, y = x.to(device), y.to(device)
    t = torch.as_tensor(t, device=device)
    model.train()
    optimizer.zero_grad()
    output = model(x, t)
    loss = loss_func(output, y)
    loss.backward()
    optimizer.step()
    return loss.item()

# test step
@torch.no_grad()
def test_step(model, loss_func, x, y, t):
    # x,y,t to model device
    x,y,t = x.to(DEVICE), y.to(DEVICE), t.to(DEVICE)
    model.eval()
    output = model(x, t)
    loss = loss_func(output, y)
    return loss.item()

def train_epoch(model, optimizer, loss_func, train_loader, t):
    """Train for one pass over ``train_loader`` and return the mean batch loss.

    Each batch is handed to ``train_step``; the returned value is the sum of
    the per-batch losses divided by the number of batches.
    """
    batch_losses = [
        train_step(model, optimizer, loss_func, inputs, targets, t)
        for inputs, targets in tqdm(train_loader)
    ]
    return sum(batch_losses, 0.0) / len(train_loader)

@torch.no_grad()
def test_epoch(model, loss_func, test_loader, t):
    """Return the mean per-batch loss of ``model`` over ``test_loader``.

    Each batch is evaluated with ``test_step``; the sum of the batch losses is
    divided by the number of batches.
    """
    batch_losses = [
        test_step(model, loss_func, inputs, targets, t)
        for inputs, targets in test_loader
    ]
    return sum(batch_losses, 0.0) / len(test_loader)

@torch.no_grad()
def compute_accuracy_for_dataset(model, test_loader, t):
    """Return the fraction of correctly classified samples in ``test_loader``.

    Args:
        model: module called as ``model(x, t)``; switched to eval mode, as
            ``test_step`` does.
        test_loader: iterable of ``(x, y)`` batches.
        t: task index, as an int or a 0-dim tensor.

    Returns:
        ``correct / total`` as a Python float.
    """
    device = next(model.parameters()).device
    # as_tensor accepts ints and tensors alike; torch.tensor(t) on an existing
    # tensor (which is what every caller passes) triggers a copy-construct warning
    t = torch.as_tensor(t, device=device)
    model.eval()
    correct = 0
    total = 0
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        output = model(x, t)
        # predicted class = index of the largest logit
        _, predicted = torch.max(output, 1)
        total += y.size(0)
        correct += (predicted == y).sum().item()
    return correct / total

In [None]:
# train for EPOCH epochs on first task (task index 0 selects the first readout head)
# accuracy histories; one entry is appended per epoch
accs_t1 = []
accs_t2 = []

for epoch in (range(EPOCH)):
    train_loss = train_epoch(cnn, optimizer, loss_func, train_loader_t1, torch.tensor(0))
    test_loss = test_epoch(cnn, loss_func, test_loader_t1, torch.tensor(0))
    test_acc = compute_accuracy_for_dataset(cnn, test_loader_t1, torch.tensor(0))
    accs_t1.append(test_acc)
    # task 2 is not evaluated during this phase; NaN keeps both lists the same length
    accs_t2.append(np.nan)
    print(f"Epoch: {epoch}, train_loss: {train_loss}, test_loss: {test_loss}, test_acc: {test_acc}")

In [None]:
'''
estimate the subspace spanned by gradients for the first task at each layer
'''

# one flattened weight-gradient vector is appended per mini-batch and per layer
layer_1_grads = []
layer_2_grads = []
layer_3_grads = []

loss_func_grad_est = nn.CrossEntropyLoss()

t = torch.tensor(0).to(DEVICE)

for _ in range(GRAD_EST_EPOCHS):
    for x, y in tqdm(train_loader_subspace_est_t1):
        x, y = x.to(DEVICE), y.to(DEVICE)
        # the optimizer is only used to clear old gradients; optimizer.step() is
        # never called in this loop, so the weights are not updated here
        optimizer.zero_grad()
        model_output = cnn(x, t)
        loss = loss_func_grad_est(model_output, y)
        loss.backward()
        # gradients of the three trunk layers (conv1, conv2, fc1), flattened to 1-D
        layer_1_grads.append(cnn.conv1.weight.grad.detach().cpu().numpy().reshape(-1))
        layer_2_grads.append(cnn.conv2.weight.grad.detach().cpu().numpy().reshape(-1))
        layer_3_grads.append(cnn.fc1.weight.grad.detach().cpu().numpy().reshape(-1))

# numpy array
# stacked to 2-D: one row per collected mini-batch gradient, one column per weight
layer_1_grads = np.array(layer_1_grads)
layer_2_grads = np.array(layer_2_grads)
layer_3_grads = np.array(layer_3_grads)


In [None]:
# shapes of the stacked gradient matrices, one line per layer
print(
    *(grad_matrix.shape for grad_matrix in (layer_1_grads, layer_2_grads, layer_3_grads)),
    sep="\n",
)

In [None]:
from copy import deepcopy
from torch.linalg import svd, svdvals
from scipy.linalg import svd as scipy_svd
from sklearn.utils.extmath import svd_flip
# import utilities for timing
import time

# pca as done by sklearn
# (timed on the layer-2 gradient matrix as a reference for the manual version below)
from sklearn.decomposition import PCA
A = layer_2_grads
start = time.time()
pca = PCA()
pca.fit(A)
end = time.time()

print(f"sklearn PCA took {end-start} seconds")


# manual pca
# work on a copy because the centring below is done in place
M = deepcopy(A)

tick = time.time()
mean_ = np.mean(M, axis=0)
M -= mean_

# thin SVD of the centred data; the rows of Vt are compared against pca.components_ further down
U, S, Vt = scipy_svd(M, full_matrices=False)
# flip eigenvectors' sign to enforce deterministic output
U, Vt = svd_flip(U, Vt)
# squared singular values divided by (n_samples - 1)
explained_variance_ = (S**2) / (len(M) - 1)
tock = time.time()
print(f"manual PCA took {tock-tick} seconds")


def svd_trick(M):
    """Principal directions of ``M`` via the small (samples x samples) Gram matrix.

    Useful when there are far fewer samples than features: the decomposition
    is done on ``M @ M.T / n_samples`` instead of on the feature covariance.

    Args:
        M: array of shape (n_samples, n_features); centre it beforehand if
            PCA directions are wanted.

    Returns:
        (u, s): ``u`` has shape (n_features, k) with unit-norm columns (the
        right singular vectors of ``M``, up to sign) and ``s`` holds the k
        matching eigenvalues of ``M @ M.T / n_samples`` in descending order.
        Directions whose eigenvalue is numerically zero are dropped, because
        the rescaling below would divide by ~0 for them.
    """
    n_samples = M.shape[0]
    sig = np.dot(M, M.T) / n_samples
    u2, s, _ = scipy_svd(sig)
    # keep only directions that carry numerically non-zero variance
    tol = np.finfo(s.dtype).eps * max(M.shape) * s.max(initial=0.0)
    keep = s > tol
    u2, s = u2[:, keep], s[keep]
    # map the Gram-matrix eigenvectors back to feature space and normalise
    u = np.dot(M.T, u2) / np.sqrt(s * n_samples)
    return u, s

# compute again with svd trick because number of samples is much smaller than number of features
# NOTE(review): M2 is centred and `tick` is reset here, but neither is used
# afterwards in this cell and svd_trick is never called — this comparison looks unfinished.
M2 = deepcopy(A)
tick = time.time()
mean_ = np.mean(M2, axis=0)
M2 -= mean_



# plot eigenvalues for both methods
plt.figure()
plt.plot(pca.explained_variance_, label="sklearn")
plt.plot(explained_variance_, label="manual")
plt.legend()
plt.show()

# are the eigenvectors the same?
# (last expression of the cell, so the boolean is shown as the cell output)
np.allclose(pca.components_, Vt)



In [None]:




# time torch's SVD on the layer-1 gradient matrix
# (`svd` and `time` are imported in the PCA comparison cell; A, U and S rebind
# names that cell also uses)
A = torch.tensor(layer_1_grads)
start = time.time()
# NOTE(review): torch.linalg.svd documents its third return value as Vh (the
# conjugate transpose of V) — confirm before treating `V` as V.
U, S, V = svd(A)
end = time.time()
print(f"torch.linalg.svd took {end-start} seconds")
print(U.shape)
print(V.shape)


In [None]:
# average the collected task-1 gradients over the sample axis, layer by layer
layer_1_grads_mean, layer_2_grads_mean, layer_3_grads_mean = (
    grad_matrix.mean(axis=0)
    for grad_matrix in (layer_1_grads, layer_2_grads, layer_3_grads)
)

In [None]:
from sklearn.decomposition import PCA
# fit one PCA per layer on that layer's matrix of collected gradients
# (rows = mini-batch gradients, columns = weights)
pca_layer1 = PCA()
pca_layer1.fit(layer_1_grads)
pca_layer2 = PCA()
pca_layer2.fit(layer_2_grads)
pca_layer3 = PCA()
pca_layer3.fit(layer_3_grads)


In [None]:

# plot cdf of explained variance
# one curve per layer: running sum of the explained-variance ratios over the components
plt.figure()
plt.plot(np.cumsum(pca_layer1.explained_variance_ratio_), label='layer 1')
plt.plot(np.cumsum(pca_layer2.explained_variance_ratio_), label='layer 2')
plt.plot(np.cumsum(pca_layer3.explained_variance_ratio_), label='layer 3')
plt.xlabel('number of components')
plt.ylabel('cumulative explained variance')
plt.legend()
plt.show()


In [None]:
def project_gradient(subspace, grad):
    '''
    remove from grad its component inside the given subspace, i.e. return the
    part of grad that is orthogonal to the subspace

    subspace        - torch tensor of shape (subspace_dim, layer_dim); its rows
                      are assumed to be orthonormal (otherwise the subtracted
                      term is not an orthogonal projection)
    grad            - torch tensor of shape (layer_dim,)

    returns         - torch tensor of shape (layer_dim,): grad - S^T (S grad)
    '''
    subspace = subspace.to(grad.device)
    # coordinates of grad in the subspace basis
    coefficients = subspace @ grad
    # back to layer space; an explicit transpose of the 2-D basis avoids calling
    # `.T` on a 1-D tensor, which PyTorch has deprecated
    projection = subspace.transpose(0, 1) @ coefficients
    return grad - projection

# for each layer, get the components
# (principal directions of the collected task-1 gradients, one row per component)
layer_1_components = pca_layer1.components_
layer_2_components = pca_layer2.components_
layer_3_components = pca_layer3.components_

# normalize
# each row is rescaled to unit L2 norm
# NOTE(review): sklearn's components_ are presumably unit-norm already, which would make this a no-op — confirm
layer_1_components = layer_1_components / np.linalg.norm(layer_1_components, axis=1, keepdims=True)
layer_2_components = layer_2_components / np.linalg.norm(layer_2_components, axis=1, keepdims=True)
layer_3_components = layer_3_components / np.linalg.norm(layer_3_components, axis=1, keepdims=True)
# transpose so that each COLUMN is one direction; a flattened weight vector w
# can then be rotated with `w @ components`
layer_1_components = layer_1_components.T
layer_2_components = layer_2_components.T
layer_3_components = layer_3_components.T

print(layer_1_components.shape)
print(layer_2_components.shape)
print(layer_3_components.shape)


# torch copies of the rotation matrices on the compute device
proj_l1 = torch.tensor(layer_1_components).to(DEVICE)
proj_l2 = torch.tensor(layer_2_components).to(DEVICE)
proj_l3 = torch.tensor(layer_3_components).to(DEVICE)

In [None]:
# print projection shapes, one line per layer
print(*(rotation.shape for rotation in (proj_l1, proj_l2, proj_l3)), sep="\n")

In [None]:

def compute_rotated_diagonal_fisher(model, layers, data, criterion, task, projections):
    """Estimate a diagonal Fisher information per layer in a rotated basis.

    For every layer, each batch's weight gradient is flattened, rotated with
    ``grad @ projection``, squared and accumulated; the sum is finally divided
    by the number of samples seen.

    Args:
        model: module called as ``model(x, task)``.
        layers: modules whose ``.weight`` gradients are used.
        data: iterable of ``(x, y)`` batches; it is iterated once per layer.
        criterion: callable ``criterion(yhat, y)`` returning a scalar loss.
        task: task index, as an int or a 0-dim tensor.
        projections: one entry per layer, either a tensor of shape
            (layer_dim, n_directions) or ``None`` to use the raw weight basis
            without building an identity matrix.

    Returns:
        List with one 1-D tensor per layer (length n_directions, or layer_dim
        when the projection is ``None``).

    NOTE(review): gradients are taken per batch (of whatever reduction
    ``criterion`` applies) but the sum is divided by the number of samples,
    so the scale depends on the batch size — confirm this is intended.
    """
    infos = []
    for layer, projection in zip(layers, projections):
        device = layer.weight.device
        # convert the task index once; torch.tensor() on an already-converted
        # tensor inside the batch loop raised a copy-construct warning
        task_index = torch.as_tensor(task, device=device)
        n_directions = layer.weight.numel() if projection is None else projection.shape[1]
        fish_info = torch.zeros(n_directions, device=device)
        N = 0
        for x, y in data:
            # make sure x and y live on the same device as the layer parameters
            x = x.to(device)
            y = y.to(device)
            model.zero_grad()
            yhat = model(x, task_index)
            loss = criterion(yhat, y)
            loss.backward()
            grads = layer.weight.grad.detach().reshape(-1)
            if projection is not None:
                # express the gradient in the rotated basis
                grads = grads @ projection
            fish_info += grads**2
            N += len(x)
        fish_info /= N
        infos.append(fish_info)
    return infos

# rotated diagonal Fisher of the three trunk layers for task 0, computed on the
# task-1 training loader in the PCA bases proj_l1..proj_l3
fisher_infos = compute_rotated_diagonal_fisher(cnn, [cnn.conv1, cnn.conv2, cnn.fc1], train_loader_t1, loss_func, 0, [proj_l1, proj_l2, proj_l3])



In [None]:
# plot distribution of fisher information
# (rotated basis; one histogram per layer, the later ones semi-transparent so they can overlap)
plt.figure()
plt.hist(fisher_infos[0].cpu().numpy(), label='layer 1')
plt.hist(fisher_infos[1].cpu().numpy(), label='layer 2', alpha=0.5)
plt.hist(fisher_infos[2].cpu().numpy(), label='layer 3', alpha=.5)
plt.xlabel('fisher information')
plt.ylabel('count')
plt.legend()
plt.show()



In [None]:
# recompute fisher info without projection
# number of weights in each trunk layer
dim_l1 = np.prod(cnn.conv1.weight.shape)
dim_l2 = np.prod(cnn.conv2.weight.shape)
dim_l3 = np.prod(cnn.fc1.weight.shape)

# identity matrices stand in for the rotation, so the Fisher is taken in the raw weight basis
# NOTE(review): each torch.eye(dim) is a dense dim x dim matrix. With fc1 mapping
# 400 -> 128 features, dim_l3 is 51200, i.e. roughly 10 GB at the default float32 dtype.
fisher_infos_unrotated = compute_rotated_diagonal_fisher(cnn, [cnn.conv1, cnn.conv2, cnn.fc1], train_loader_t1, loss_func, 0, [torch.eye(dim_l1).to(DEVICE), torch.eye(dim_l2).to(DEVICE), torch.eye(dim_l3).to(DEVICE)])

In [None]:
# distribution of the unrotated Fisher information, 100 bins per layer
plt.figure()
plt.hist(fisher_infos_unrotated[0].cpu().numpy(), bins=100, label='layer 1')
plt.hist(fisher_infos_unrotated[1].cpu().numpy(), bins=100, label='layer 2', alpha=0.5)
plt.hist(fisher_infos_unrotated[2].cpu().numpy(), bins=100, label='layer 3', alpha=.5)
# NOTE(review): labels are set on the histograms but no legend / ylabel is drawn here
plt.xlabel('fisher information')

In [None]:
# compare max fisher info for each layer between methods
# "aligned" = rotated (PCA) basis, "unaligned" = raw weight basis
print(f"layer 1: aligned: {torch.max(fisher_infos[0])}, unaligned: {torch.max(fisher_infos_unrotated[0])}")
print(f"layer 2: aligned: {torch.max(fisher_infos[1])}, unaligned: {torch.max(fisher_infos_unrotated[1])}")
print(f"layer 3: aligned: {torch.max(fisher_infos[2])}, unaligned: {torch.max(fisher_infos_unrotated[2])}")


In [None]:
def compute_regularisation_term(w_new, w_old, fisher_info, lambda_ewc, projection):
    """Quadratic EWC penalty between current and reference weights of one layer.

    Computes ``lambda_ewc / 2 * sum(fisher_info * ((w_old - w_new) @ projection) ** 2)``.

    Args:
        w_new: current weights (any shape); gradients flow through this tensor.
        w_old: reference weights of the same shape; detached, so it receives
            no gradient.
        fisher_info: 1-D tensor of per-direction importance weights.
        lambda_ewc: strength of the penalty.
        projection: tensor of shape (layer_dim, n_directions) defining the
            basis in which the penalty is diagonal, or ``None`` to penalise
            directly in the raw weight basis.

    Returns:
        Scalar tensor holding the penalty.
    """
    # Rotate the weight difference once instead of rotating both weight vectors
    # and subtracting: same result by linearity, half the matmuls, and no
    # cancellation between two nearly equal projected vectors.
    diff = w_old.detach().reshape(-1) - w_new.reshape(-1)
    if projection is not None:
        diff = diff @ projection
    weighted_square_diff = fisher_info * diff**2
    return (lambda_ewc / 2) * weighted_square_diff.sum()


def train_step_ewc(model, old_model, optimizer, loss_func, x, y, t, fisher_infos, lambda_ewc, projections):
    '''
    train the network on a batch of data with an EWC penalty added to the task loss

    The penalty is the sum of compute_regularisation_term over the conv1, conv2
    and fc1 weights, measured against the same layers of old_model.

    model           - network being trained; called as model(x, t)
    old_model       - reference network whose conv1 / conv2 / fc1 weights anchor the penalty
    optimizer       - optimizer over model's parameters
    loss_func       - callable loss_func(output, y) returning a scalar tensor
    x, y            - input and target batch (moved to the model's device here)
    t               - task index, as an int or a 0-dim tensor
    fisher_infos    - per-layer importance vectors, ordered (conv1, conv2, fc1)
    lambda_ewc      - strength of the penalty
    projections     - per-layer rotation matrices, in the same order

    returns         - task loss plus penalty for this batch, as a Python float
    '''
    # move the batch to wherever the model lives instead of relying on a global
    device = next(model.parameters()).device
    x, y = x.to(device), y.to(device)
    t = torch.as_tensor(t, device=device)

    model.train()
    optimizer.zero_grad()
    output = model(x, t)
    loss = loss_func(output, y)

    # ewc loss: accumulate the per-layer penalties, then add them to the task loss
    ewc_loss = 0.0
    for idx, layer_name in enumerate(('conv1', 'conv2', 'fc1')):
        ewc_loss += compute_regularisation_term(
            getattr(model, layer_name).weight,
            getattr(old_model, layer_name).weight,
            fisher_infos[idx],
            lambda_ewc,
            projections[idx],
        )

    loss = loss + ewc_loss
    loss.backward()
    optimizer.step()
    return loss.item()





In [None]:

def train_epoch_ewc(model, old_model, optimizer, loss_func, train_loader, t, fisher_infos, lambda_ewc, projections):
    """Train for one pass over ``train_loader`` with the EWC penalty.

    Every batch is handed to ``train_step_ewc``; the return value is the sum
    of the per-batch losses divided by the number of batches.
    """
    step_losses = [
        train_step_ewc(model, old_model, optimizer, loss_func, inputs, targets, t, fisher_infos, lambda_ewc, projections)
        for inputs, targets in tqdm(train_loader)
    ]
    return sum(step_losses, 0.0) / len(train_loader)



In [None]:
# evaluate task 1
# (task index 0 selects the first readout head)
t0 = torch.tensor(0).to(DEVICE)
test_acc_1 = compute_accuracy_for_dataset(cnn, test_loader_t1, t0)

# evaluate task 2
# (task index 1 selects the second readout head; when the cells are run in
# order, no task-2 training has happened yet at this point)
t1 = torch.tensor(1).to(DEVICE)
test_acc_2 = compute_accuracy_for_dataset(cnn, test_loader_t2, t1)

print(f"Task 1 accuracy: {test_acc_1}, Task 2 accuracy: {test_acc_2}")

In [None]:
from copy import deepcopy

# reset optimizer so that we do not carry over momentum from previous training
optimizer = torch.optim.SGD(cnn.parameters(), lr=LR, momentum=MOMENTUM)
# snapshot of the weights after task 1; the EWC penalty anchors the trunk to these
old_model = deepcopy(cnn)

# Pick the Fisher estimates and the matching basis ONCE. Previously the
# unrotated branch rebuilt three dense identity matrices on every epoch.
if ROTATED:
    ewc_fisher_infos = fisher_infos
    ewc_projections = [proj_l1, proj_l2, proj_l3]
else:
    ewc_fisher_infos = fisher_infos_unrotated
    ewc_projections = [torch.eye(dim).to(DEVICE) for dim in (dim_l1, dim_l2, dim_l3)]

# train on task 2
for i in range(EPOCH):
    train_loss = train_epoch_ewc(cnn, old_model, optimizer, loss_func, train_loader_t2, torch.tensor(1), ewc_fisher_infos, EWC_LAMBDA, ewc_projections)
    # track both tasks after every epoch to see how much task 1 is forgotten
    test_acc_t1 = compute_accuracy_for_dataset(cnn, test_loader_t1, torch.tensor(0))
    test_acc_t2 = compute_accuracy_for_dataset(cnn, test_loader_t2, torch.tensor(1))
    accs_t1.append(test_acc_t1)
    accs_t2.append(test_acc_t2)
    print(f"Epoch: {i}, train_loss: {train_loss}, test_acc_t1: {test_acc_t1}, test_acc_t2: {test_acc_t2}")


In [None]:
# plot accuracies

# ticks style, talk context
sns.set_context('talk')
sns.set_style('ticks')

plt.figure()
plt.plot(accs_t1, label='task 1', marker='o')
plt.plot(accs_t2, label='task 2', marker='o')
# vertical dashed line at EPOCH - 0.5, i.e. between the last task-1 epoch and
# the first task-2 epoch
plt.axvline(EPOCH - 0.5, linestyle='--', color='k')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
# plt.ylim(0.5, 1)
sns.despine(trim=False)
plt.tight_layout()

# save plot to the current user's desktop (the path used to be hard-coded to one
# specific home directory, so saving failed on every other machine)
figure_dir = Path.home() / "Desktop"
figure_dir.mkdir(parents=True, exist_ok=True)
figure_name = f"ewc_rotated_{ROTATED}_seed_{SEED}_lambda_{EWC_LAMBDA}_grad_est_epochs_{GRAD_EST_EPOCHS}_grad_est_bs_{GRAD_EST_BATCHSIZE}.pdf"
plt.savefig(figure_dir / figure_name, dpi=300, bbox_inches='tight')


## NOTE

It seems that we are making errors in the nullspace projection. As we approach the optimum, these errors become large relative to the gradient, and Adam then presumably amplifies them (unconfirmed).

Alternative explanation: Adam uses different momentum terms PER PARAMETER. This means the actual parameter updates can fall in a subspace that is forbidden by the projection, because the momentum is applied after the projection (TODO: check this). If we want to use Adam, we would have to insert the projection step into the momentum computation done by Adam itself.