# Experiment with CNN via Various Non-VCL Approaches (e.g. MAP, LP, EWC, SI) on Split-CIFAR100 Task

The models are configured in almost the same way (in terms of widths and depths) as in the VCL experiments, namely CNN-4 and ResNet-4.

## Model Definition and Data Preparation

In [1]:
import os
import json
from datetime import datetime
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
import copy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Define non-vcl CNN & ResNet with task heads 
class Cifar10CNN(nn.Module):
    def __init__(self, in_channels, num_tasks=5, num_classes_per_task=2):
        super(Cifar10CNN, self).__init__()
        self.shared_conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        
        self.fc_input_dim = 256 * 4 * 4  
        
        self.task_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.fc_input_dim, 128),
                nn.ReLU(),
                nn.Linear(128, num_classes_per_task)
            ) for _ in range(num_tasks)
        ])
        
    def forward(self, x, task_idx):
        x = self.shared_conv_layers(x)
        x = x.view(-1, self.fc_input_dim)  # Flatten
        task_output = self.task_heads[task_idx](x)
        return F.log_softmax(task_output, dim=1)

class Cifar100CNN(nn.Module):
    def __init__(self, in_channels=3, num_tasks=5, num_classes_per_task=100):
        super(Cifar100CNN, self).__init__()
        self.shared_conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),  # Added layer
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.MaxPool2d(2, 2)
        )
        
        self.fc_input_dim = 512 * 2 * 2  
        
        self.task_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.fc_input_dim, 256),  # Increased capacity
                nn.ReLU(),
                nn.Linear(256, num_classes_per_task) 
            ) for _ in range(num_tasks)
        ])
        
    def forward(self, x, task_idx):
        x = self.shared_conv_layers(x)
        x = x.view(-1, self.fc_input_dim)  # Flatten
        task_output = self.task_heads[task_idx](x)
        return F.log_softmax(task_output, dim=1)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNetCIFAR100(nn.Module):
    def __init__(self, block=BasicBlock, num_blocks=[2, 2, 2, 2], num_tasks=5, num_classes_per_task=100, in_channels=3):
        super(ResNetCIFAR100, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.task_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(512 * block.expansion, 256),
                nn.ReLU(),
                nn.Linear(256, num_classes_per_task)
            ) for _ in range(num_tasks)
        ])

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, task_idx):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        task_output = self.task_heads[task_idx](out)
        return F.log_softmax(task_output, dim=1)
        
# model = ResNetCIFAR100(num_tasks=10, num_classes_per_task=10, in_channels=3)

In [3]:
# Get split CIFAR100 dataset
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR100
from torchvision.transforms import Compose, ToTensor, Normalize

# permutation used for each task (add flatten for mlp)
transform = Compose([
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

cifar_train = CIFAR100(root=f"/scratch-ssd/oatml/data", train=True, download=False, transform=transform)
cifar_test = CIFAR100(root=f"/scratch-ssd/oatml/data", train=False, download=False, transform=transform)

def split_cifar100_into_tasks(dataset, num_tasks=10):
    class_indices = [[] for _ in range(num_tasks)]
    for idx, (_, label) in enumerate(dataset):
        task = label // (100 // num_tasks)  # Determine the task based on the label
        class_indices[task].append(idx)
    return [Subset(dataset, indices) for indices in class_indices]

# Split training and testing datasets into 10 tasks (10-way classification each)
train_tasks = split_cifar100_into_tasks(cifar_train)
test_tasks = split_cifar100_into_tasks(cifar_test)

## Joint Accuracy

The baseline (upper-bound of performances) should be to create a separate model (one head each) and train on each task to the furthest stretch.

In [7]:
# Train on Split CIFAR10 without any CL strategies.
def remap_labels(labels, task_idx, num_classes_per_task=10):
    """
    Remaps labels to ensure they start from 0 for each task.
    
    Args:
    - labels: The original labels.
    - task_idx: The current task index.
    - num_classes_per_task: The number of classes per task.
    
    Returns:
    - Remapped labels starting from 0 for the current task.
    """
    return labels % num_classes_per_task

def test_model(model, dataloader, device, task_idx):  # it gets redefined in later experiments 
    model.eval()
    model = model.to(device)
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), remap_labels(labels, task_idx).to(device)
            outputs = model(images, 0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

def train_cifar_split(cls, log_name, train_tasks, test_tasks, device, epochs=100, batch_size=256):
    summary_logdir = os.path.join("logs", log_name, datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)
    # os.makedirs("out/models/", exist_ok=True)
    experiment_path = f"out/experiments/{log_name}_{datetime.now().strftime('%b%d_%H-%M-%S')}"
    os.makedirs(experiment_path, exist_ok=True)
    accuracies = {}

    # create a new model for each task and train from scratch
    for current_task in range(len(train_tasks)):
        print(f"Training on task {current_task}")
        model = cls(in_channels=3, num_tasks=1, num_classes_per_task=10).to(device)
        print("New model created")
        model.train()
        train_loader = DataLoader(train_tasks[current_task], batch_size=batch_size, shuffle=True)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        for epoch in tqdm(range(epochs), desc=f"Task {current_task} Epoch"):
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                labels = remap_labels(labels, current_task)
                optimizer.zero_grad()
                outputs = model(images, 0)
                loss = F.nll_loss(outputs, labels)
                loss.backward()
                optimizer.step()

            writer.add_scalar(f'Task_{current_task}/Train_Loss', loss.item(), epoch)

        # Test on the current task
        current_task_accuracy = test_model(model, DataLoader(test_tasks[current_task], batch_size=batch_size), device, current_task)
        print(f"Test accuracy on current task {current_task}: {current_task_accuracy}%")
        writer.add_scalar(f"Accuracy/Current_Task_{current_task}", current_task_accuracy, epoch)

        accuracies[f"Task_{current_task}_current_max"] = current_task_accuracy

        # Save final model
        # model_save_path = os.path.join("out/models/", f"{log_name}_task{current_task}_model_final.pth")
        # torch.save(model.state_dict(), model_save_path)
        # print(f"Model saved to {model_save_path}")

    # Save accuracies to file
    accuracies_file = os.path.join(experiment_path, "final_accuracies.json")
    with open(accuracies_file, 'w') as f:
        json.dump(accuracies, f)
    print(f"Accuracies saved to {accuracies_file}")

    writer.close()

In [None]:
train_cifar_split(
    Cifar100CNN, 
    "joint_disc_conv_s_cifar100_upperbound", 
    train_tasks,
    test_tasks, 
    device,
    epochs=100
)

Training on task 0
New model created


Task 0 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 0: 78.9%
Training on task 1
New model created


Task 1 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 1: 76.7%
Training on task 2
New model created


Task 2 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 2: 80.1%
Training on task 3
New model created


Task 3 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 3: 76.9%
Training on task 4
New model created


Task 4 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 4: 78.6%
Training on task 5
New model created


Task 5 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 5: 80.5%
Training on task 6
New model created


Task 6 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 6: 78.9%
Training on task 7
New model created


Task 7 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

In [9]:
train_cifar_split(
    ResNetCIFAR100, 
    "joint_disc_resnet_s_cifar100_upperbound", 
    train_tasks,
    test_tasks, 
    device,
    epochs=200
)

Training on task 0
New model created


Task 0 Epoch:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on current task 0: 76.1%
Model saved to out/models/joint_disc_resnet_s_cifar100_upperbound_task0_model_final.pth
Training on task 1
New model created


Task 1 Epoch:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on current task 1: 74.1%
Model saved to out/models/joint_disc_resnet_s_cifar100_upperbound_task1_model_final.pth
Training on task 2
New model created


Task 2 Epoch:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on current task 2: 79.4%
Model saved to out/models/joint_disc_resnet_s_cifar100_upperbound_task2_model_final.pth
Training on task 3
New model created


Task 3 Epoch:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on current task 3: 74.3%
Model saved to out/models/joint_disc_resnet_s_cifar100_upperbound_task3_model_final.pth
Training on task 4
New model created


Task 4 Epoch:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on current task 4: 78.3%
Model saved to out/models/joint_disc_resnet_s_cifar100_upperbound_task4_model_final.pth
Training on task 5
New model created


Task 5 Epoch:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on current task 5: 76.7%
Model saved to out/models/joint_disc_resnet_s_cifar100_upperbound_task5_model_final.pth
Training on task 6
New model created


Task 6 Epoch:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on current task 6: 77.5%
Model saved to out/models/joint_disc_resnet_s_cifar100_upperbound_task6_model_final.pth
Training on task 7
New model created


Task 7 Epoch:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on current task 7: 72.2%
Model saved to out/models/joint_disc_resnet_s_cifar100_upperbound_task7_model_final.pth
Training on task 8
New model created


Task 8 Epoch:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on current task 8: 81.8%
Model saved to out/models/joint_disc_resnet_s_cifar100_upperbound_task8_model_final.pth
Training on task 9
New model created


Task 9 Epoch:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on current task 9: 80.8%
Model saved to out/models/joint_disc_resnet_s_cifar100_upperbound_task9_model_final.pth
Accuracies saved to out/experiments/joint_disc_resnet_s_cifar100_upperbound_Mar26_17-53-46/final_accuracies.json


## MLE / MAP for Split CIFAR100 with CNN/ResNet

In [9]:
# Train on Split CIFAR10 without any CL strategies.
def test_model(model, dataloader, device, task_idx):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            # Class-incremental Learning, as defined in https://github.com/GMvandeVen/continual-learning/
            images, labels = images.to(device), labels.to(device)
            labels = remap_labels(labels, task_idx)
            outputs = model(images, task_idx)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

def train_cifar_split(model, log_name, train_tasks, test_tasks, device, epochs=100, batch_size=256):
    summary_logdir = os.path.join("logs", log_name, datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)
    os.makedirs("out/models/", exist_ok=True)
    experiment_path = f"out/experiments/{log_name}_{datetime.now().strftime('%b%d_%H-%M-%S')}"
    os.makedirs(experiment_path, exist_ok=True)
    accuracies = {}

    for current_task in range(len(train_tasks)):
        print(f"Training on task {current_task}")
        model.train()
        train_loader = DataLoader(train_tasks[current_task], batch_size=batch_size, shuffle=True)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        for epoch in tqdm(range(epochs), desc=f"Task {current_task} Epoch"):
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                labels = remap_labels(labels, current_task)
                optimizer.zero_grad()
                outputs = model(images, current_task)
                loss = F.nll_loss(outputs, labels)
                loss.backward()
                optimizer.step()

            writer.add_scalar(f'Task_{current_task}/Train_Loss', loss.item(), epoch)

        # Test on the current task
        current_task_accuracy = test_model(model, DataLoader(test_tasks[current_task], batch_size=batch_size), device, current_task)
        print(f"Test accuracy on current task {current_task}: {current_task_accuracy}%")
        writer.add_scalar(f"Accuracy/Current_Task_{current_task}", current_task_accuracy, epoch)

        accuracies[f"Task_{current_task}_current"] = current_task_accuracy

        # Test on all previous tasks
        for previous_task in range(current_task + 1):
            prev_task_accuracy = test_model(model, DataLoader(test_tasks[previous_task], batch_size=batch_size), device, current_task)
            print(f"Test accuracy on previous task {previous_task}: {prev_task_accuracy}%")
            writer.add_scalar(f"Accuracy/Previous_Task_{previous_task}_after_learning_{current_task}", prev_task_accuracy, epoch)
            accuracies[f"Task_{previous_task}_after_{current_task}"] = prev_task_accuracy

    # Save final model
    model_save_path = os.path.join("out/models/", f"{log_name}_model_final.pth")
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    # Save accuracies to file
    accuracies_file = os.path.join(experiment_path, "final_accuracies.json")
    with open(accuracies_file, 'w') as f:
        json.dump(accuracies, f)
    print(f"Accuracies saved to {accuracies_file}")

    writer.close()

In [12]:
model = ResNetCIFAR100(in_channels=3, num_tasks=10, num_classes_per_task=10).to(device)
train_cifar_split(model, "mle_disc_resnet_s_cifar100", train_tasks, test_tasks, device, epochs=100)

Training on task 0


Task 0 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 0: 77.2%
Test accuracy on previous task 0: 77.2%
Training on task 1


Task 1 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 1: 74.8%
Test accuracy on previous task 0: 10.7%
Test accuracy on previous task 1: 74.8%
Training on task 2


Task 2 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 2: 80.7%
Test accuracy on previous task 0: 15.4%
Test accuracy on previous task 1: 9.8%
Test accuracy on previous task 2: 80.7%
Training on task 3


Task 3 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 3: 76.9%
Test accuracy on previous task 0: 6.6%
Test accuracy on previous task 1: 14.6%
Test accuracy on previous task 2: 11.3%
Test accuracy on previous task 3: 76.9%
Training on task 4


Task 4 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 4: 81.9%
Test accuracy on previous task 0: 16.7%
Test accuracy on previous task 1: 12.6%
Test accuracy on previous task 2: 11.3%
Test accuracy on previous task 3: 8.6%
Test accuracy on previous task 4: 81.9%
Training on task 5


Task 5 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 5: 79.0%
Test accuracy on previous task 0: 9.5%
Test accuracy on previous task 1: 9.8%
Test accuracy on previous task 2: 8.3%
Test accuracy on previous task 3: 4.2%
Test accuracy on previous task 4: 15.4%
Test accuracy on previous task 5: 79.0%
Training on task 6


Task 6 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 6: 79.2%
Test accuracy on previous task 0: 18.3%
Test accuracy on previous task 1: 6.5%
Test accuracy on previous task 2: 9.9%
Test accuracy on previous task 3: 8.3%
Test accuracy on previous task 4: 9.2%
Test accuracy on previous task 5: 8.7%
Test accuracy on previous task 6: 79.2%
Training on task 7


Task 7 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 7: 77.5%
Test accuracy on previous task 0: 20.6%
Test accuracy on previous task 1: 10.6%
Test accuracy on previous task 2: 10.2%
Test accuracy on previous task 3: 7.4%
Test accuracy on previous task 4: 11.1%
Test accuracy on previous task 5: 10.0%
Test accuracy on previous task 6: 5.1%
Test accuracy on previous task 7: 77.5%
Training on task 8


Task 8 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 8: 81.5%
Test accuracy on previous task 0: 4.0%
Test accuracy on previous task 1: 11.4%
Test accuracy on previous task 2: 5.6%
Test accuracy on previous task 3: 10.4%
Test accuracy on previous task 4: 6.6%
Test accuracy on previous task 5: 19.7%
Test accuracy on previous task 6: 3.8%
Test accuracy on previous task 7: 15.2%
Test accuracy on previous task 8: 81.5%
Training on task 9


Task 9 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on current task 9: 86.2%
Test accuracy on previous task 0: 9.1%
Test accuracy on previous task 1: 2.4%
Test accuracy on previous task 2: 9.3%
Test accuracy on previous task 3: 5.7%
Test accuracy on previous task 4: 4.5%
Test accuracy on previous task 5: 8.3%
Test accuracy on previous task 6: 19.1%
Test accuracy on previous task 7: 7.9%
Test accuracy on previous task 8: 10.9%
Test accuracy on previous task 9: 86.2%
Model saved to out/models/mle_disc_resnet_s_cifar100_model_final.pth
Accuracies saved to out/experiments/mle_disc_resnet_s_cifar100_Mar26_18-50-12/final_accuracies.json


## EWC Method for Split CIFAR100 with CNN

In [13]:
# Test models with EWC method on split cifar100
def remap_labels(labels, task_idx, num_classes_per_task=10):
    return labels % num_classes_per_task

def compute_fisher_information(model, dataloader, task_idx, device):
    model.eval()
    fisher_information = {}
    with torch.no_grad():
        for name, param in model.named_parameters():
            fisher_information[name] = torch.zeros_like(param.data)

        for data, target in dataloader:
            data, target = data.to(device), remap_labels(target, task_idx).to(device)
            with torch.enable_grad():
                model.zero_grad()
                output = model(data, task_idx)
                loss = F.nll_loss(output, target)
                loss.backward()
                for name, param in model.named_parameters():
                    if param.grad is not None:
                        fisher_information[name] += param.grad.data ** 2 / len(dataloader.dataset)
    return fisher_information

def ewc_loss_function(model, lambda_ewc, fisher_matrices, optimal_params):
    ewc_loss = 0
    for name, param in model.named_parameters():
        if name in fisher_matrices:
            fisher_matrix = fisher_matrices[name]
            optimal_param = optimal_params[name].to(device)
            ewc_loss += (fisher_matrix * (param - optimal_param) ** 2).sum()
    return lambda_ewc / 2 * ewc_loss

def run_task_ewc(model, log_name, train_tasks, test_tasks, device, epochs, batch_size, lambda_ewc):
    summary_writer = SummaryWriter(log_dir=os.path.join("logs", log_name, datetime.now().strftime('%Y%m%d_%H%M%S')))
    experiment_path = f"out/experiments/{log_name}"
    os.makedirs(experiment_path, exist_ok=True)  # Ensure output directory exists
    accuracies = {}
    
    previous_fisher_matrices = {}
    previous_optimal_params = {}

    # Assume there are as many tasks as there are train_tasks/test_tasks
    for task_idx in range(len(train_tasks)):
        print(f"Training on task {task_idx}")
        train_loader = DataLoader(train_tasks[task_idx], batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_tasks[task_idx], batch_size=batch_size, shuffle=False)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        model.train()
        for epoch in tqdm(range(epochs), desc=f"Training Task {task_idx}"):
            for data, target in train_loader:
                data, target = data.to(device), remap_labels(target, task_idx).to(device)
                optimizer.zero_grad()
                output = model(data, task_idx)
                loss = F.nll_loss(output, target)
                if task_idx > 0:
                    ewc_loss = ewc_loss_function(model, lambda_ewc, previous_fisher_matrices, previous_optimal_params)
                    loss += ewc_loss
                loss.backward()
                optimizer.step()
        
            summary_writer.add_scalar(f'Task_{task_idx}/Train_Loss', loss.item(), epoch)

        test_accuracy = test_model(model, test_loader, device, task_idx)
        print(f"Test accuracy on task {task_idx}: {test_accuracy}%")
        summary_writer.add_scalar(f'Task_{task_idx}/Test_Accuracy', test_accuracy, epoch)
        accuracies[f"Task_{task_idx}"] = test_accuracy

        # Test on all previous tasks
        for previous_task in range(task_idx + 1):
            prev_task_accuracy = test_model(model, DataLoader(test_tasks[previous_task], batch_size=batch_size), device, task_idx)
            print(f"Test accuracy on previous task {previous_task}: {prev_task_accuracy}%")
            summary_writer.add_scalar(f"Accuracy/Previous_Task_{previous_task}_after_learning_{task_idx}", prev_task_accuracy, epoch)
            accuracies[f"Task_{previous_task}_after_{task_idx}"] = prev_task_accuracy

        # Update Fisher Information and Optimal Parameters
        model.eval()
        fisher_information = compute_fisher_information(model, train_loader, task_idx, device)
        optimal_params = copy.deepcopy(model.state_dict())

    # Save accuracies and close the summary writer
    import json
    accuracies_file = os.path.join(experiment_path, "final_accuracies.json")
    with open(accuracies_file, 'w') as f:
        json.dump(accuracies, f)
    print(f"Accuracies saved to {accuracies_file}")

    summary_writer.close()

In [None]:
lambda_ewc = 1

model = Cifar100CNN(in_channels=3, num_tasks=10, num_classes_per_task=10).to(device)
run_task_ewc(model, f"ewc_conv_CIFAR100_lambda_ewc_{lambda_ewc}", train_tasks, test_tasks, device, epochs=200, batch_size=256, lambda_ewc=lambda_ewc)

Training on task 0


Training Task 0:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on task 0: 78.7%
Test accuracy on previous task 0: 78.7%
Training on task 1


Training Task 1:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on task 1: 76.5%
Test accuracy on previous task 0: 11.1%
Test accuracy on previous task 1: 76.5%
Training on task 2


Training Task 2:   0%|          | 0/200 [00:00<?, ?it/s]

In [14]:
lambda_ewc = 1

model = ResNetCIFAR100(in_channels=3, num_tasks=10, num_classes_per_task=10).to(device)
run_task_ewc(
    model, 
    f"ewc_resnet_CIFAR100_lambda_ewc_{lambda_ewc}", 
    train_tasks, 
    test_tasks, 
    device, 
    epochs=100, 
    batch_size=256, 
    lambda_ewc=lambda_ewc
)

Training on task 0


Training Task 0:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 0: 78.3%
Test accuracy on previous task 0: 78.3%
Training on task 1


Training Task 1:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 1: 76.5%
Test accuracy on previous task 0: 11.4%
Test accuracy on previous task 1: 76.5%
Training on task 2


Training Task 2:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 2: 83.1%
Test accuracy on previous task 0: 16.5%
Test accuracy on previous task 1: 10.6%
Test accuracy on previous task 2: 83.1%
Training on task 3


Training Task 3:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 3: 78.6%
Test accuracy on previous task 0: 7.1%
Test accuracy on previous task 1: 16.7%
Test accuracy on previous task 2: 14.8%
Test accuracy on previous task 3: 78.6%
Training on task 4


Training Task 4:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 4: 81.2%
Test accuracy on previous task 0: 18.2%
Test accuracy on previous task 1: 10.8%
Test accuracy on previous task 2: 12.0%
Test accuracy on previous task 3: 11.1%
Test accuracy on previous task 4: 81.2%
Training on task 5


Training Task 5:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 5: 81.1%
Test accuracy on previous task 0: 6.1%
Test accuracy on previous task 1: 9.3%
Test accuracy on previous task 2: 8.3%
Test accuracy on previous task 3: 6.4%
Test accuracy on previous task 4: 16.5%
Test accuracy on previous task 5: 81.1%
Training on task 6


Training Task 6:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 6: 80.8%
Test accuracy on previous task 0: 19.1%
Test accuracy on previous task 1: 9.0%
Test accuracy on previous task 2: 10.6%
Test accuracy on previous task 3: 8.2%
Test accuracy on previous task 4: 8.9%
Test accuracy on previous task 5: 10.4%
Test accuracy on previous task 6: 80.8%
Training on task 7


Training Task 7:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 7: 78.6%
Test accuracy on previous task 0: 19.5%
Test accuracy on previous task 1: 7.7%
Test accuracy on previous task 2: 9.3%
Test accuracy on previous task 3: 8.2%
Test accuracy on previous task 4: 10.1%
Test accuracy on previous task 5: 14.7%
Test accuracy on previous task 6: 6.1%
Test accuracy on previous task 7: 78.6%
Training on task 8


Training Task 8:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 8: 82.4%
Test accuracy on previous task 0: 2.9%
Test accuracy on previous task 1: 11.8%
Test accuracy on previous task 2: 7.0%
Test accuracy on previous task 3: 8.5%
Test accuracy on previous task 4: 7.0%
Test accuracy on previous task 5: 19.0%
Test accuracy on previous task 6: 5.6%
Test accuracy on previous task 7: 10.0%
Test accuracy on previous task 8: 82.4%
Training on task 9


Training Task 9:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 9: 86.2%
Test accuracy on previous task 0: 10.2%
Test accuracy on previous task 1: 2.1%
Test accuracy on previous task 2: 8.1%
Test accuracy on previous task 3: 5.4%
Test accuracy on previous task 4: 5.4%
Test accuracy on previous task 5: 7.8%
Test accuracy on previous task 6: 20.3%
Test accuracy on previous task 7: 7.4%
Test accuracy on previous task 8: 11.0%
Test accuracy on previous task 9: 86.2%
Accuracies saved to out/experiments/ewc_resnet_CIFAR100_lambda_ewc_1/final_accuracies.json


## Synaptic Intelligence (SI) for Split CIFAR100 with CNN

In [11]:
def run_task_si(model, log_name, train_tasks, test_tasks, device, epochs, batch_size, c_si):
    """
    Trains a given model on CIFAR10 split tasks using the Synaptic Intelligence (SI) method.
    
    Args:
    - model: The model to train.
    - log_name: Name for the TensorBoard logs.
    - cifar_train: Training dataset.
    - cifar_test: Test dataset.
    - train_task_ids: Task IDs for the training data.
    - test_task_ids: Task IDs for the test data.
    - device: Torch device to use for training.
    - epochs: Number of epochs to train each task.
    - batch_size: Batch size for training and testing.
    - c_si: The SI regularization term coefficient.
    """
    summary_writer = SummaryWriter(log_dir=os.path.join("logs", log_name, datetime.now().strftime('%Y%m%d_%H%M%S')))
    experiment_path = f"out/experiments/{log_name}"
    os.makedirs(experiment_path, exist_ok=True)  # Ensure output directory exists
    accuracies = {}
    
    # Initialize importance and previous parameters dictionaries
    importance = {}
    prev_params = {}
    
    for name, param in model.named_parameters():
        if param.requires_grad:
            importance[name] = torch.zeros_like(param, device=device)
            prev_params[name] = param.clone().detach()

    for task_idx in range(5):
        print(f"Training on task {task_idx}")
        train_loader = DataLoader(train_tasks[task_idx], batch_size=batch_size, shuffle=True)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        for epoch in tqdm(range(epochs), desc=f"Training Task {task_idx}"):
            model.train()
            for data, target in train_loader:
                data, target = data.to(device), target.to(device)
                target = remap_labels(target, task_idx)
                optimizer.zero_grad()
                output = model(data, task_idx)
                loss = F.nll_loss(output, target)
                
                # Backpropagate to compute gradients
                loss.backward()

                # Calculate SI regularization term and add it to the loss before optimizer step
                si_reg_loss = sum((importance[name] * (param - prev_params[name]) ** 2).sum() for name, param in model.named_parameters() if param.requires_grad)
                (c_si * si_reg_loss).backward() 
                
                optimizer.step()

                # Update the importance weights after optimizer step
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        delta_param = param.detach() - prev_params[name]
                        # Assuming a zero-initialized importance, this needs accumulation of gradient information over tasks
                        importance[name] += (param.grad.detach() ** 2) * delta_param.abs()
            summary_writer.add_scalar(f'Task_{task_idx}/Train_Loss', loss.item(), epoch)
        
        # Evaluate the model on the current task's test set
        test_accuracy = test_model(model, DataLoader(test_tasks[task_idx], batch_size=batch_size), device, task_idx)
        print(f"Test accuracy on task {task_idx}: {test_accuracy}%")
        summary_writer.add_scalar(f'Task_{task_idx}/Test_Accuracy', test_accuracy, epoch)
        accuracies[f"TASK {task_idx}"] = test_accuracy
        
        # Evaluate the model on all previous tasks' test sets to measure forgetting
        for previous_task_idx in range(task_idx + 1):
            accuracy = test_model(model, DataLoader(test_tasks[task_idx], batch_size=batch_size), device, previous_task_idx)
            print(f"Test accuracy on previous task {previous_task_idx}: {accuracy}%")
            summary_writer.add_scalar(f"Cross_Task_Accuracy/task_{task_idx}_on_{previous_task_idx}", accuracy, global_step=task_idx)
            accuracies[f"TASK {previous_task_idx}"] = accuracy
        
        # After task training, update previous parameters for the next task
        prev_params = {name: param.clone().detach() for name, param in model.named_parameters() if param.requires_grad}

    # Save accuracies to file
    accuracies_file = os.path.join(experiment_path, "final_accuracies.json")
    with open(accuracies_file, 'w') as f:
        json.dump(accuracies, f)
    print(f"Accuracies saved to {accuracies_file}")

    summary_writer.close()

In [12]:
epochs = 100
batch_size = 256
c_si = 1
log_name = f"si_resnet_c_cifar100_c_si_{c_si}" 
model = ResNetCIFAR100(in_channels=3, num_tasks=10, num_classes_per_task=10).to(device)

run_task_si(
    model=model,
    log_name=log_name,
    train_tasks=train_tasks,
    test_tasks=test_tasks,
    device=device,
    epochs=epochs,
    batch_size=batch_size,
    c_si=c_si
)

Training on task 0


Training Task 0:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 0: 78.4%
Test accuracy on previous task 0: 78.4%
Training on task 1


Training Task 1:   0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [22]:
epochs = 100
batch_size = 256
c_si = 1
log_name = f"si_conv_c_CIFAR100_c_si_{c_si}" 
model = CNN(in_channels=3, num_tasks=10, num_classes_per_task=10).to(device)

run_task_si(
    model=model,
    log_name=log_name,
    train_tasks=train_tasks,
    test_tasks=test_tasks,
    device=device,
    epochs=epochs,
    batch_size=batch_size,
    c_si=c_si
)

Training on task 0


Training Task 0:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on task 0: 39.5%
Test accuracy on task 0: 49.0%
Test accuracy on task 0: 55.2%
Test accuracy on task 0: 54.9%
Test accuracy on task 0: 61.2%
Test accuracy on task 0: 64.3%
Test accuracy on task 0: 64.0%
Test accuracy on task 0: 66.8%
Test accuracy on task 0: 68.8%
Test accuracy on task 0: 65.2%
Test accuracy on task 0: 69.5%
Test accuracy on task 0: 68.9%
Test accuracy on task 0: 69.8%
Test accuracy on task 0: 69.7%
Test accuracy on task 0: 70.9%
Test accuracy on task 0: 69.8%
Test accuracy on task 0: 69.6%
Test accuracy on task 0: 70.1%
Test accuracy on task 0: 71.5%
Test accuracy on task 0: 70.8%
Test accuracy on task 0: 70.6%
Test accuracy on task 0: 69.1%
Test accuracy on task 0: 69.7%
Test accuracy on task 0: 69.7%
Test accuracy on task 0: 70.9%
Test accuracy on task 0: 69.8%
Test accuracy on task 0: 71.9%
Test accuracy on task 0: 71.7%
Test accuracy on task 0: 71.9%
Test accuracy on task 0: 72.0%
Test accuracy on task 0: 72.0%
Test accuracy on task 0: 72.2%
Test acc

Training Task 1:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on task 1: 33.2%
Test accuracy on previous task 0: 13.8%
Test accuracy on task 1: 49.8%
Test accuracy on previous task 0: 13.2%
Test accuracy on task 1: 56.0%
Test accuracy on previous task 0: 12.9%
Test accuracy on task 1: 60.9%
Test accuracy on previous task 0: 12.5%
Test accuracy on task 1: 59.6%
Test accuracy on previous task 0: 11.2%
Test accuracy on task 1: 61.6%
Test accuracy on previous task 0: 12.6%
Test accuracy on task 1: 65.3%
Test accuracy on previous task 0: 13.5%
Test accuracy on task 1: 66.5%
Test accuracy on previous task 0: 11.5%
Test accuracy on task 1: 64.4%
Test accuracy on previous task 0: 11.3%
Test accuracy on task 1: 64.0%
Test accuracy on previous task 0: 12.6%
Test accuracy on task 1: 65.2%
Test accuracy on previous task 0: 11.2%
Test accuracy on task 1: 64.1%
Test accuracy on previous task 0: 11.0%
Test accuracy on task 1: 62.9%
Test accuracy on previous task 0: 12.3%
Test accuracy on task 1: 64.6%
Test accuracy on previous task 0: 12.5%
Test a

Training Task 2:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on task 2: 47.0%
Test accuracy on previous task 0: 12.0%
Test accuracy on previous task 1: 8.5%
Test accuracy on task 2: 62.6%
Test accuracy on previous task 0: 13.9%
Test accuracy on previous task 1: 6.7%
Test accuracy on task 2: 66.3%
Test accuracy on previous task 0: 14.2%
Test accuracy on previous task 1: 7.6%
Test accuracy on task 2: 69.5%
Test accuracy on previous task 0: 12.9%
Test accuracy on previous task 1: 8.0%
Test accuracy on task 2: 69.3%
Test accuracy on previous task 0: 12.4%
Test accuracy on previous task 1: 9.9%
Test accuracy on task 2: 68.6%
Test accuracy on previous task 0: 11.5%
Test accuracy on previous task 1: 11.1%
Test accuracy on task 2: 72.0%
Test accuracy on previous task 0: 11.9%
Test accuracy on previous task 1: 9.4%
Test accuracy on task 2: 72.4%
Test accuracy on previous task 0: 10.8%
Test accuracy on previous task 1: 11.1%
Test accuracy on task 2: 71.2%
Test accuracy on previous task 0: 11.8%
Test accuracy on previous task 1: 10.2%
Test ac

Training Task 3:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on task 3: 37.6%
Test accuracy on previous task 0: 8.3%
Test accuracy on previous task 1: 13.7%
Test accuracy on previous task 2: 9.4%
Test accuracy on task 3: 55.1%
Test accuracy on previous task 0: 5.5%
Test accuracy on previous task 1: 14.2%
Test accuracy on previous task 2: 14.0%
Test accuracy on task 3: 61.3%
Test accuracy on previous task 0: 5.5%
Test accuracy on previous task 1: 15.1%
Test accuracy on previous task 2: 13.8%
Test accuracy on task 3: 62.8%
Test accuracy on previous task 0: 5.7%
Test accuracy on previous task 1: 16.7%
Test accuracy on previous task 2: 10.8%
Test accuracy on task 3: 63.5%
Test accuracy on previous task 0: 5.5%
Test accuracy on previous task 1: 16.5%
Test accuracy on previous task 2: 13.0%
Test accuracy on task 3: 65.0%
Test accuracy on previous task 0: 5.1%
Test accuracy on previous task 1: 15.3%
Test accuracy on previous task 2: 13.0%
Test accuracy on task 3: 64.6%
Test accuracy on previous task 0: 5.1%
Test accuracy on previous task 

Training Task 4:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on task 4: 43.1%
Test accuracy on previous task 0: 9.2%
Test accuracy on previous task 1: 6.9%
Test accuracy on previous task 2: 8.1%
Test accuracy on previous task 3: 13.2%
Test accuracy on task 4: 58.5%
Test accuracy on previous task 0: 8.9%
Test accuracy on previous task 1: 10.7%
Test accuracy on previous task 2: 9.7%
Test accuracy on previous task 3: 13.4%
Test accuracy on task 4: 62.8%
Test accuracy on previous task 0: 10.7%
Test accuracy on previous task 1: 10.6%
Test accuracy on previous task 2: 10.6%
Test accuracy on previous task 3: 11.6%
Test accuracy on task 4: 66.3%
Test accuracy on previous task 0: 10.9%
Test accuracy on previous task 1: 10.3%
Test accuracy on previous task 2: 9.9%
Test accuracy on previous task 3: 11.6%
Test accuracy on task 4: 66.2%
Test accuracy on previous task 0: 11.3%
Test accuracy on previous task 1: 8.7%
Test accuracy on previous task 2: 11.1%
Test accuracy on previous task 3: 11.5%
Test accuracy on task 4: 65.6%
Test accuracy on prev

Training Task 5:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on task 5: 50.0%
Test accuracy on previous task 0: 8.1%
Test accuracy on previous task 1: 8.5%
Test accuracy on previous task 2: 6.0%
Test accuracy on previous task 3: 5.5%
Test accuracy on previous task 4: 13.4%
Test accuracy on task 5: 58.5%
Test accuracy on previous task 0: 6.9%
Test accuracy on previous task 1: 6.6%
Test accuracy on previous task 2: 7.2%
Test accuracy on previous task 3: 5.4%
Test accuracy on previous task 4: 14.1%
Test accuracy on task 5: 64.9%
Test accuracy on previous task 0: 6.9%
Test accuracy on previous task 1: 7.8%
Test accuracy on previous task 2: 7.5%
Test accuracy on previous task 3: 4.9%
Test accuracy on previous task 4: 14.6%
Test accuracy on task 5: 66.9%
Test accuracy on previous task 0: 5.9%
Test accuracy on previous task 1: 6.6%
Test accuracy on previous task 2: 8.8%
Test accuracy on previous task 3: 5.7%
Test accuracy on previous task 4: 14.1%
Test accuracy on task 5: 65.8%
Test accuracy on previous task 0: 8.0%
Test accuracy on previ

Training Task 6:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on task 6: 57.4%
Test accuracy on previous task 0: 13.8%
Test accuracy on previous task 1: 8.8%
Test accuracy on previous task 2: 7.5%
Test accuracy on previous task 3: 10.8%
Test accuracy on previous task 4: 8.9%
Test accuracy on previous task 5: 8.3%
Test accuracy on task 6: 65.8%
Test accuracy on previous task 0: 14.4%
Test accuracy on previous task 1: 8.7%
Test accuracy on previous task 2: 7.3%
Test accuracy on previous task 3: 9.8%
Test accuracy on previous task 4: 11.2%
Test accuracy on previous task 5: 8.9%
Test accuracy on task 6: 68.9%
Test accuracy on previous task 0: 17.9%
Test accuracy on previous task 1: 10.6%
Test accuracy on previous task 2: 6.4%
Test accuracy on previous task 3: 11.5%
Test accuracy on previous task 4: 11.2%
Test accuracy on previous task 5: 9.5%
Test accuracy on task 6: 68.1%
Test accuracy on previous task 0: 16.3%
Test accuracy on previous task 1: 10.0%
Test accuracy on previous task 2: 7.8%
Test accuracy on previous task 3: 10.0%
Test ac

Training Task 7:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on task 7: 40.5%
Test accuracy on previous task 0: 14.9%
Test accuracy on previous task 1: 11.3%
Test accuracy on previous task 2: 14.0%
Test accuracy on previous task 3: 8.4%
Test accuracy on previous task 4: 6.9%
Test accuracy on previous task 5: 12.9%
Test accuracy on previous task 6: 8.9%
Test accuracy on task 7: 56.1%
Test accuracy on previous task 0: 17.8%
Test accuracy on previous task 1: 9.3%
Test accuracy on previous task 2: 15.9%
Test accuracy on previous task 3: 6.3%
Test accuracy on previous task 4: 10.1%
Test accuracy on previous task 5: 14.2%
Test accuracy on previous task 6: 5.7%
Test accuracy on task 7: 61.4%
Test accuracy on previous task 0: 19.7%
Test accuracy on previous task 1: 11.9%
Test accuracy on previous task 2: 12.9%
Test accuracy on previous task 3: 6.9%
Test accuracy on previous task 4: 9.2%
Test accuracy on previous task 5: 13.8%
Test accuracy on previous task 6: 5.0%
Test accuracy on task 7: 64.2%
Test accuracy on previous task 0: 20.3%
Test 

Training Task 8:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on task 8: 44.3%
Test accuracy on previous task 0: 6.3%
Test accuracy on previous task 1: 9.2%
Test accuracy on previous task 2: 8.1%
Test accuracy on previous task 3: 8.3%
Test accuracy on previous task 4: 8.6%
Test accuracy on previous task 5: 16.6%
Test accuracy on previous task 6: 11.0%
Test accuracy on previous task 7: 11.6%
Test accuracy on task 8: 56.0%
Test accuracy on previous task 0: 5.2%
Test accuracy on previous task 1: 8.6%
Test accuracy on previous task 2: 6.9%
Test accuracy on previous task 3: 8.1%
Test accuracy on previous task 4: 9.8%
Test accuracy on previous task 5: 14.9%
Test accuracy on previous task 6: 9.8%
Test accuracy on previous task 7: 13.2%
Test accuracy on task 8: 61.5%
Test accuracy on previous task 0: 6.8%
Test accuracy on previous task 1: 8.0%
Test accuracy on previous task 2: 6.6%
Test accuracy on previous task 3: 7.2%
Test accuracy on previous task 4: 9.2%
Test accuracy on previous task 5: 14.3%
Test accuracy on previous task 6: 8.8%
Test

Training Task 9:   0%|          | 0/200 [00:00<?, ?it/s]

Test accuracy on task 9: 54.6%
Test accuracy on previous task 0: 10.5%
Test accuracy on previous task 1: 5.5%
Test accuracy on previous task 2: 6.1%
Test accuracy on previous task 3: 5.2%
Test accuracy on previous task 4: 4.4%
Test accuracy on previous task 5: 8.9%
Test accuracy on previous task 6: 19.3%
Test accuracy on previous task 7: 9.3%
Test accuracy on previous task 8: 13.1%
Test accuracy on task 9: 65.8%
Test accuracy on previous task 0: 9.7%
Test accuracy on previous task 1: 5.5%
Test accuracy on previous task 2: 7.4%
Test accuracy on previous task 3: 7.7%
Test accuracy on previous task 4: 4.1%
Test accuracy on previous task 5: 10.3%
Test accuracy on previous task 6: 18.3%
Test accuracy on previous task 7: 12.1%
Test accuracy on previous task 8: 11.4%
Test accuracy on task 9: 66.7%
Test accuracy on previous task 0: 9.2%
Test accuracy on previous task 1: 5.4%
Test accuracy on previous task 2: 7.8%
Test accuracy on previous task 3: 7.3%
Test accuracy on previous task 4: 3.8%
Tes

## Laplace Propagation Method for Split CIFAR100 Task with CNN

In [26]:
def compute_hessian_diag(model, dataloader, device, task_idx):
    model.eval()
    hessian_diag = {}
    for name, param in model.named_parameters():
        hessian_diag[name] = torch.zeros_like(param)

    for data, target in dataloader:
        data, target = data.to(device), remap_labels(target, task_idx).to(device)
        model.zero_grad()
        output = model(data, task_idx) 
        loss = F.nll_loss(output, target)
        grad_params = torch.autograd.grad(loss, model.parameters(), create_graph=True, allow_unused=True)

        for grad, (name, param) in zip(grad_params, model.named_parameters()):
            if grad is not None:
                grad2 = torch.autograd.grad(grad.sum(), param, retain_graph=True, allow_unused=True)[0]
                if grad2 is not None:
                    hessian_diag[name] += grad2.data / len(dataloader.dataset)

    return hessian_diag

def run_task_lp(model, log_name, train_tasks, test_tasks, device, epochs, batch_size, gamma_lp):
    summary_writer = SummaryWriter(log_dir=os.path.join("logs", log_name, datetime.now().strftime('%Y%m%d_%H%M%S')))
    experiment_path = f"out/experiments/{log_name}"
    os.makedirs(experiment_path, exist_ok=True)
    accuracies = {}
    
    hessian_diag = {}
    prev_params = {}

    for task_idx, (train_task, test_task) in enumerate(zip(train_tasks, test_tasks)):
        print(f"Training on task {task_idx}")
        model.train()
        train_loader = DataLoader(train_tasks[task_idx], batch_size=batch_size, shuffle=True)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        for epoch in tqdm(range(epochs), desc=f"Task {task_idx} Epoch"):
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                labels = remap_labels(labels, task_idx)
                optimizer.zero_grad()
                outputs = model(images, task_idx)
                loss = F.nll_loss(outputs, labels)
                
                if task_idx > 0:
                    lp_loss = 0
                    for name, param in model.named_parameters():
                        if name in hessian_diag:
                            lp_loss += (hessian_diag[name] * (param - prev_params[name]) ** 2).sum()
                    loss += gamma_lp * lp_loss
                
                loss.backward()
                optimizer.step()

            if task_idx > 0:
                hessian_diag = compute_hessian_diag(model, train_loader, device, task_idx)
            prev_params = {name: param.clone().detach() for name, param in model.named_parameters()}

            summary_writer.add_scalar(f'Task_{task_idx}/Train_Loss', loss.item(), epoch)
        
        test_accuracy = test_model(model, DataLoader(test_tasks[task_idx], batch_size=batch_size), device, task_idx)  # Test model utility adapted for CIFAR100
        print(f"Test accuracy on task {task_idx}: {test_accuracy}%")
        summary_writer.add_scalar(f'Task_{task_idx}/Test_Accuracy', test_accuracy, epoch)
        accuracies[f"TASK {task_idx}"] = test_accuracy

        # Evaluate the model on all previous tasks' test sets to measure forgetting
        for previous_task_idx in range(task_idx + 1):
            prev_task_accuracy = test_model(model, DataLoader(test_tasks[previous_task_idx], batch_size=batch_size), device, task_idx)
            print(f"Test accuracy on previous task {previous_task_idx}: {prev_task_accuracy}%")
            summary_writer.add_scalar(f"Accuracy/Previous_Task_{previous_task_idx}_after_learning_{task_idx}", prev_task_accuracy, epoch)
            accuracies[f"Task_{previous_task_idx}_after_{task_idx}"] = prev_task_accuracy

    # Save accuracies to file
    accuracies_file = os.path.join(experiment_path, "final_accuracies.json")
    with open(accuracies_file, 'w') as f:
        json.dump(accuracies, f)
    print(f"Accuracies saved to {accuracies_file}")
    
    summary_writer.close()

In [27]:
epochs = 100
batch_size = 256 
gamma_lp = 0.1  # Coefficient for the LP regularization term
log_name = f"lp_conv_CIFAR100_gamma_lp_{gamma_lp}" 

model = Cifar10CNN(in_channels=3, num_tasks=10, num_classes_per_task=10).to(device)

run_task_lp(
    model=model,
    log_name=log_name,
    train_tasks=train_tasks,
    test_tasks=test_tasks,
    device=device,
    epochs=epochs,
    batch_size=batch_size,
    gamma_lp=gamma_lp
)

Training on task 0


Task 0 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 0: 69.8%
Test accuracy on previous task 0: 69.8%
Training on task 1


Task 1 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 1: 65.7%
Test accuracy on previous task 0: 11.6%
Test accuracy on previous task 1: 65.7%
Training on task 2


Task 2 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 2: 72.8%
Test accuracy on previous task 0: 12.0%
Test accuracy on previous task 1: 9.0%
Test accuracy on previous task 2: 72.8%
Training on task 3


Task 3 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 3: 68.6%
Test accuracy on previous task 0: 5.8%
Test accuracy on previous task 1: 15.9%
Test accuracy on previous task 2: 12.9%
Test accuracy on previous task 3: 68.6%
Training on task 4


Task 4 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 4: 68.5%
Test accuracy on previous task 0: 14.2%
Test accuracy on previous task 1: 8.1%
Test accuracy on previous task 2: 10.6%
Test accuracy on previous task 3: 11.8%
Test accuracy on previous task 4: 68.5%
Training on task 5


Task 5 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 5: 69.9%
Test accuracy on previous task 0: 7.9%
Test accuracy on previous task 1: 8.6%
Test accuracy on previous task 2: 8.1%
Test accuracy on previous task 3: 5.1%
Test accuracy on previous task 4: 13.7%
Test accuracy on previous task 5: 69.9%
Training on task 6


Task 6 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 6: 71.8%
Test accuracy on previous task 0: 15.9%
Test accuracy on previous task 1: 8.5%
Test accuracy on previous task 2: 6.6%
Test accuracy on previous task 3: 8.1%
Test accuracy on previous task 4: 9.8%
Test accuracy on previous task 5: 8.9%
Test accuracy on previous task 6: 71.8%
Training on task 7


Task 7 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 7: 68.4%
Test accuracy on previous task 0: 20.4%
Test accuracy on previous task 1: 11.0%
Test accuracy on previous task 2: 11.1%
Test accuracy on previous task 3: 8.3%
Test accuracy on previous task 4: 11.5%
Test accuracy on previous task 5: 11.1%
Test accuracy on previous task 6: 4.8%
Test accuracy on previous task 7: 68.4%
Training on task 8


Task 8 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 8: 64.6%
Test accuracy on previous task 0: 5.3%
Test accuracy on previous task 1: 11.0%
Test accuracy on previous task 2: 5.7%
Test accuracy on previous task 3: 9.5%
Test accuracy on previous task 4: 8.8%
Test accuracy on previous task 5: 15.4%
Test accuracy on previous task 6: 8.4%
Test accuracy on previous task 7: 10.7%
Test accuracy on previous task 8: 64.6%
Training on task 9


Task 9 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 9: 72.9%
Test accuracy on previous task 0: 9.1%
Test accuracy on previous task 1: 4.5%
Test accuracy on previous task 2: 6.7%
Test accuracy on previous task 3: 8.9%
Test accuracy on previous task 4: 3.9%
Test accuracy on previous task 5: 8.3%
Test accuracy on previous task 6: 19.8%
Test accuracy on previous task 7: 13.1%
Test accuracy on previous task 8: 12.2%
Test accuracy on previous task 9: 72.9%
Accuracies saved to out/experiments/lp_conv_CIFAR100_gamma_lp_0.1/final_accuracies.json


In [38]:
epochs = 100
batch_size = 256 
gamma_lp = 0.1  # Coefficient for the LP regularization term
log_name = f"lp_resnet_CIFAR100_gamma_lp_{gamma_lp}" 

model = ResNetCIFAR100(in_channels=3, num_tasks=10, num_classes_per_task=10).to(device)

run_task_lp(
    model=model,
    log_name=log_name,
    train_tasks=train_tasks,
    test_tasks=test_tasks,
    device=device,
    epochs=epochs,
    batch_size=batch_size,
    gamma_lp=gamma_lp
)

Training on task 0


Task 0 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 0: 77.1%
Test accuracy on previous task 0: 77.1%
Training on task 1


Task 1 Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 