# Experiment with CNN via Various Non-VCL Approaches (e.g. MAP, LP, EWC, SI) on Split-CIFAR10 Task

The models are configured in almost the same way (in terms of widths and depths) as in the VCL experiments, namely CNN-4 and ResNet-4.

## Model Definition and Data Preparation

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import copy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Define non-vcl CNN & ResNet with task heads 
class CNN(nn.Module):
    def __init__(self, in_channels, num_tasks=5, num_classes_per_task=2):
        super(CNN, self).__init__()
        self.shared_conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        
        self.fc_input_dim = 256 * 4 * 4  
        
        self.task_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.fc_input_dim, 128),
                nn.ReLU(),
                nn.Linear(128, num_classes_per_task)
            ) for _ in range(num_tasks)
        ])
        
    def forward(self, x, task_idx):
        x = self.shared_conv_layers(x)
        x = x.view(-1, self.fc_input_dim)  # Flatten
        task_output = self.task_heads[task_idx](x)
        return F.log_softmax(task_output, dim=1)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNetCIFAR10(nn.Module):
    def __init__(self, block=BasicBlock, num_blocks=[2, 2], num_tasks=5, num_classes_per_task=100, in_channels=3):
        super(ResNetCIFAR10, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)

        self.task_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(8192 * block.expansion, num_classes_per_task)
            ) for _ in range(num_tasks)
        ])

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, task_idx):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.avg_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        task_output = self.task_heads[task_idx](out)
        return F.log_softmax(task_output, dim=1)

In [3]:
# Get split CIFAR10 dataset
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize

from util.transforms import Flatten, Scale

# Normalization for CIFAR10
normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# permutation used for each task (add flatten for mlp)
transform = Compose([
    ToTensor(),
    normalize
])
# transform = Compose([Flatten(), Scale()])

cifar_train = CIFAR10(root=f"/scratch-ssd/oatml/data", train=True, download=False, transform=transform)
cifar_test = CIFAR10(root=f"/scratch-ssd/oatml/data", train=False, download=False, transform=transform)

label_to_task_mapping = {
    0: 0, 1: 0,
    2: 1, 3: 1,
    4: 2, 5: 2,
    6: 3, 7: 3,
    8: 4, 9: 4,
}

if isinstance(cifar_train[0][1], int):
    train_task_ids = torch.Tensor([label_to_task_mapping[y] for _, y in cifar_train])
    test_task_ids = torch.Tensor([label_to_task_mapping[y] for _, y in cifar_test])
elif isinstance(cifar_train[0][1], torch.Tensor):
    train_task_ids = torch.Tensor([label_to_task_mapping[y.item()] for _, y in cifar_train])
    test_task_ids = torch.Tensor([label_to_task_mapping[y.item()] for _, y in cifar_test])

## Joint Accuracy (Upperbounds)

In [6]:
# Train on Split CIFAR10 with new model instantiated at each run
import os
import json
from datetime import datetime
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter

from util.operations import task_subset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
binarize_y = lambda y, task: (y == (2 * task + 1)).long()

def test_model(model, dataloader, device, task_idx):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data[0].to(device), binarize_y(data[1], task_idx).to(device)
            outputs = model(images, 0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy


def train_cifar_split(cls, log_name, cifar_train, cifar_test, train_task_ids, test_task_ids, device):
    """
    Trains separate models on a split CIFAR10 dataset (five binary tasks).

    Args:
    - cls: The model constructor to train with.
    - log_name: Name for the TensorBoard logs.
    - cifar_train: Training dataset.
    - cifar_test: Test dataset.
    - train_task_ids: Task IDs for the training data.
    - test_task_ids: Task IDs for the test data.
    - device: Torch device to use for training.
    """
    # Setup TensorBoard writer
    summary_logdir = os.path.join("logs", log_name, datetime.now().strftime('%b%d_%H-%M-%S'))
    summary_writer = SummaryWriter(summary_logdir)
    experiment_path = f"out/experiments/{log_name}"
    os.makedirs(experiment_path, exist_ok=True)  # Ensure output directory exists
    accuracies = {}

    num_tasks = 5
    for task_idx in range(num_tasks):
        print(f"Training on task {task_idx}")
        model = cls(in_channels=3, num_tasks=1, num_classes_per_task=2).to(device)
        
        task_dataset = task_subset(cifar_train, train_task_ids, task_idx)
        task_dataloader = DataLoader(task_dataset, batch_size=256, shuffle=True)

        test_dataset = task_subset(cifar_test, test_task_ids, task_idx)
        test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)

        optimizer = optim.Adam(model.parameters(), lr=0.001)

        model.train()
        for epoch in tqdm(range(100), desc=f"Epoch: "):
            for images, labels in task_dataloader:
                images, labels = images.to(device), binarize_y(labels, task_idx).to(device)
                optimizer.zero_grad()
                outputs = model(images, 0)
                loss = F.nll_loss(outputs, labels)
                loss.backward()
                optimizer.step()

            summary_writer.add_scalar(f'Task_{task_idx}/Train_Loss', loss.item(), epoch)

        task_accuracy = test_model(model, test_dataloader, device, task_idx)
        print(f"Test accuracy on task {task_idx}: {task_accuracy}%")
        summary_writer.add_scalar(f"Accuracy/task_{task_idx}_max", task_accuracy, global_step=task_idx)
        accuracies[f"TASK {task_idx}"] = task_accuracy

    # Save accuracies to file
    accuracies_file = os.path.join(experiment_path, "final_accuracies.json")
    with open(accuracies_file, 'w') as f:
        json.dump(accuracies, f)
    print(f"Accuracies saved to {accuracies_file}")

    summary_writer.close()

In [14]:
train_cifar_split(
    CNN, 
    "joint_disc_conv_s_cifar10_upperbound", 
    cifar_train, 
    cifar_test, 
    train_task_ids, 
    test_task_ids, 
    device
)

Training on task 0


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 0: 96.15%
Training on task 1


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 1: 86.4%
Training on task 2


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 2: 91.8%
Training on task 3


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 3: 96.8%
Training on task 4


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 4: 96.0%
Accuracies saved to out/experiments/joint_disc_conv_s_cifar10_upperbound/final_accuracies.json


NameError: name 'writer' is not defined

In [18]:
train_cifar_split(
    ResNetCIFAR10, 
    "joint_disc_resnet_s_cifar10_upperbound", 
    cifar_train, 
    cifar_test, 
    train_task_ids, 
    test_task_ids, 
    device
)

Training on task 0


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 0: 97.95%
Training on task 1


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 1: 88.35%
Training on task 2


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 2: 93.55%
Training on task 3


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 3: 97.7%
Training on task 4


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 4: 91.75%
Accuracies saved to out/experiments/joint_disc_resnet_s_cifar10_upperbound/final_accuracies.json


NameError: name 'writer' is not defined

## MLE / MAP for Split CIFAR10 with CNN/ResNet

In [24]:
# Train on Split CIFAR10 without any CL strategies.
def test_model(model, dataloader, device, task_idx):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data[0].to(device), binarize_y(data[1], task_idx).to(device)
            outputs = model(images, task_idx)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy


def train_cifar_split(model, log_name, cifar_train, cifar_test, train_task_ids, test_task_ids, device):
    """
    Trains a given model on a split CIFAR10 dataset without continual learning strategies.

    Args:
    - model: The model to train.
    - log_name: Name for the TensorBoard logs.
    - cifar_train: Training dataset.
    - cifar_test: Test dataset.
    - train_task_ids: Task IDs for the training data.
    - test_task_ids: Task IDs for the test data.
    - device: Torch device to use for training.
    """
    # Setup TensorBoard writer
    summary_logdir = os.path.join("logs", log_name, datetime.now().strftime('%b%d_%H-%M-%S'))
    summary_writer = SummaryWriter(summary_logdir)
    os.makedirs("out/models/", exist_ok=True)  # Ensure output directory exists
    experiment_path = f"out/experiments/{log_name}"
    os.makedirs(experiment_path, exist_ok=True)  # Ensure output directory exists
    accuracies = {}

    num_tasks = 5
    for task_idx in range(num_tasks):
        print(f"Training on task {task_idx}")
        task_dataset = task_subset(cifar_train, train_task_ids, task_idx)
        task_dataloader = DataLoader(task_dataset, batch_size=256, shuffle=True)

        test_dataset = task_subset(cifar_test, test_task_ids, task_idx)
        test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)

        optimizer = optim.Adam(model.parameters(), lr=0.001)

        model.train()
        for epoch in tqdm(range(100), desc=f"Epoch: "):
            for images, labels in task_dataloader:
                images, labels = images.to(device), binarize_y(labels, task_idx).to(device)
                optimizer.zero_grad()
                outputs = model(images, task_idx)
                loss = F.nll_loss(outputs, labels)
                loss.backward()
                optimizer.step()

            summary_writer.add_scalar(f'Task_{task_idx}/Train_Loss', loss.item(), epoch)

        task_accuracy = test_model(model, test_dataloader, device, task_idx)
        print(f"Test accuracy on task {task_idx}: {task_accuracy}%")
        summary_writer.add_scalar(f"Accuracy/task_{task_idx}", task_accuracy, global_step=task_idx)
        accuracies[f"TASK {task_idx}"] = task_accuracy

        for previous_task_idx in range(task_idx + 1):
            test_dataset = task_subset(cifar_test, test_task_ids, previous_task_idx)
            test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)
            
            accuracy = test_model(model, test_dataloader, device, previous_task_idx)
            print(f"Test accuracy on previous task {previous_task_idx}: {accuracy}%")
            summary_writer.add_scalar(f"Cross_Task_Accuracy/task_{task_idx}_on_{previous_task_idx}", accuracy, global_step=task_idx)
            accuracies[f"TASK {previous_task_idx}"] = accuracy
            
    # Save model state
    model_save_path = os.path.join("out/models/", f"{log_name}_model_final.pth")
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    # Save accuracies to file
    accuracies_file = os.path.join(experiment_path, "final_accuracies.json")
    with open(accuracies_file, 'w') as f:
        json.dump(accuracies, f)
    print(f"Accuracies saved to {accuracies_file}")

    summary_writer.close()

In [25]:
model = ResNetCIFAR10(in_channels=3, num_tasks=5, num_classes_per_task=2).to(device)
train_cifar_split(
    model, 
    "mle_disc_resnet_s_cifar10", 
    cifar_train, 
    cifar_test, 
    train_task_ids, 
    test_task_ids, 
    device
)

Training on task 0


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 0: 97.4%
Test accuracy on previous task 0: 97.4%
Training on task 1


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 1: 88.5%
Test accuracy on previous task 0: 84.9%
Test accuracy on previous task 1: 88.5%
Training on task 2


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 2: 93.8%
Test accuracy on previous task 0: 77.7%
Test accuracy on previous task 1: 80.8%
Test accuracy on previous task 2: 93.8%
Training on task 3


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 3: 98.2%
Test accuracy on previous task 0: 58.5%
Test accuracy on previous task 1: 71.9%
Test accuracy on previous task 2: 86.95%
Test accuracy on previous task 3: 98.2%
Training on task 4


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 4: 97.55%
Test accuracy on previous task 0: 64.85%
Test accuracy on previous task 1: 65.55%
Test accuracy on previous task 2: 72.2%
Test accuracy on previous task 3: 83.95%
Test accuracy on previous task 4: 97.55%
Model saved to out/models/mle_disc_resnet_s_cifar10_model_final.pth
Accuracies saved to out/experiments/mle_disc_resnet_s_cifar10/final_accuracies.json


In [48]:
model = CNN(in_channels=3, num_tasks=5, num_classes_per_task=2).to(device)
train_cifar_split(
    model, 
    "mle_disc_conv_s_cifar10", 
    cifar_train, 
    cifar_test, 
    train_task_ids, 
    test_task_ids, 
    device
)

Training on task 0


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Testing Task 0:   0%|          | 0/8 [00:00<?, ?it/s]

Test accuracy on task 0: 96.2%
Training on task 1


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Testing Task 1:   0%|          | 0/8 [00:00<?, ?it/s]

Test accuracy on task 1: 86.6%


Testing Task 0:   0%|          | 0/8 [00:00<?, ?it/s]

Test accuracy on previous task 0: 89.75%
Training on task 2


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Testing Task 2:   0%|          | 0/8 [00:00<?, ?it/s]

Test accuracy on task 2: 91.25%


Testing Task 0:   0%|          | 0/8 [00:00<?, ?it/s]

Test accuracy on previous task 0: 67.2%


Testing Task 1:   0%|          | 0/8 [00:00<?, ?it/s]

Test accuracy on previous task 1: 52.0%
Training on task 3


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Testing Task 3:   0%|          | 0/8 [00:00<?, ?it/s]

Test accuracy on task 3: 97.45%


Testing Task 0:   0%|          | 0/8 [00:00<?, ?it/s]

Test accuracy on previous task 0: 65.85%


Testing Task 1:   0%|          | 0/8 [00:00<?, ?it/s]

Test accuracy on previous task 1: 71.55%


Testing Task 2:   0%|          | 0/8 [00:00<?, ?it/s]

Test accuracy on previous task 2: 81.7%
Training on task 4


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Testing Task 4:   0%|          | 0/8 [00:00<?, ?it/s]

Test accuracy on task 4: 95.65%


Testing Task 0:   0%|          | 0/8 [00:00<?, ?it/s]

Test accuracy on previous task 0: 83.55%


Testing Task 1:   0%|          | 0/8 [00:00<?, ?it/s]

Test accuracy on previous task 1: 78.55%


Testing Task 2:   0%|          | 0/8 [00:00<?, ?it/s]

Test accuracy on previous task 2: 75.45%


Testing Task 3:   0%|          | 0/8 [00:00<?, ?it/s]

Test accuracy on previous task 3: 86.9%
Model saved to out/models/comp_disc_conv_s_cifar10_model_final.pth


## EWC Method for Split CIFAR10 with CNN/ResNet

We may see significant improvements across tasks from above.

In [27]:
# Test models with EWC method on split cifar10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def compute_fisher_information(model, dataloader, task_idx, device):
    model.eval()
    fisher_information = {}
    with torch.no_grad():
        for name, param in model.named_parameters():
            fisher_information[name] = torch.zeros_like(param.data)

        for data, target in dataloader:
            data, target = data.to(device), binarize_y(target, task_idx).to(device)
            with torch.enable_grad():  # Enables gradient calculation in this block
                model.zero_grad()
                output = model(data, task_idx)
                loss = F.nll_loss(output, target)
                loss.backward()
                for name, param in model.named_parameters():
                    if param.grad is not None:
                        fisher_information[name] += param.grad.data ** 2 / len(dataloader.dataset)
    return fisher_information

def modify_loss_function(original_loss, model, lambda_ewc, fisher_matrices, optimal_params):
    ewc_loss = 0
    for name, param in model.named_parameters():
        if name in fisher_matrices:
            fisher_matrix = fisher_matrices[name]
            optimal_param = optimal_params[name].to(device)
            ewc_loss += (fisher_matrix * (param - optimal_param) ** 2).sum()
    return original_loss + lambda_ewc / 2 * ewc_loss

def run_task_ewc(model, log_name, cifar_train, cifar_test, train_task_ids, test_task_ids, device, epochs, batch_size, lambda_ewc):
    """
    Trains a given model on CIFAR10 split tasks using the Elastic Weight Consolidation (EWC) method.

    Args:
    - model: The model to train.
    - log_name: Name for the TensorBoard logs.
    - cifar_train: Training dataset.
    - cifar_test: Test dataset.
    - train_task_ids: Task IDs for the training data.
    - test_task_ids: Task IDs for the test data.
    - device: Torch device to use for training.
    - epochs: Number of epochs to train each task.
    - batch_size: Batch size for training and testing.
    - lambda_ewc: The EWC penalty term lambda.
    """    
    summary_writer = SummaryWriter(log_dir=os.path.join("logs", log_name, datetime.now().strftime('%Y%m%d_%H%M%S')))
    experiment_path = f"out/experiments/{log_name}"
    os.makedirs(experiment_path, exist_ok=True)  # Ensure output directory exists
    accuracies = {}
    
    previous_fisher_matrices = {}
    previous_optimal_params = {}

    for task_idx in range(5):
        print(f"Training on task {task_idx}")
        task_dataset = task_subset(cifar_train, train_task_ids, task_idx)
        train_loader = DataLoader(task_dataset, batch_size=batch_size, shuffle=True)
    
        test_dataset = task_subset(cifar_test, test_task_ids, task_idx)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        model.train()
        for epoch in tqdm(range(epochs), desc=f"Training Task {task_idx}"):
            for data, target in train_loader:
                data, target = data.to(device), binarize_y(target, task_idx).to(device)
                optimizer.zero_grad()
                output = model(data, task_idx)
                loss = F.nll_loss(output, target)
                if task_idx > 0:
                    ewc_loss = modify_loss_function(loss, model, lambda_ewc, previous_fisher_matrices, previous_optimal_params)
                    # print(f"Task {task_idx}'s ewc_loss: ", ewc_loss)
                    loss = ewc_loss
                loss.backward()
                optimizer.step()
        
            summary_writer.add_scalar(f'Task_{task_idx}/Train_Loss', loss.item(), epoch)

        test_accuracy = test_model(model, test_loader, device, task_idx)
        print(f"Test accuracy on task {task_idx}: {test_accuracy}%")
        summary_writer.add_scalar(f'Task_{task_idx}/Test_Accuracy', test_accuracy, epoch)
        accuracies[f"TASK {task_idx}"] = test_accuracy

        for previous_task_idx in range(task_idx + 1):
            test_dataset = task_subset(cifar_test, test_task_ids, previous_task_idx)
            test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)
            
            accuracy = test_model(model, test_dataloader, device, previous_task_idx)
            print(f"Test accuracy on previous task {previous_task_idx}: {accuracy}%")
            summary_writer.add_scalar(f"Cross_Task_Accuracy/task_{task_idx}_on_{previous_task_idx}", accuracy, global_step=task_idx)
            accuracies[f"TASK {previous_task_idx}"] = accuracy

        # Update for EWC (after model.eval())
        model.eval()
        fisher_information = compute_fisher_information(model, train_loader, task_idx, device)
        optimal_params = copy.deepcopy(model.state_dict())

        if task_idx == 0:
            previous_fisher_matrices = fisher_information
            previous_optimal_params = optimal_params
        else:
            for name in fisher_information:
                previous_fisher_matrices[name] += fisher_information[name]

    # Save accuracies to file
    import json
    accuracies_file = os.path.join(experiment_path, "final_accuracies.json")
    with open(accuracies_file, 'w') as f:
        json.dump(accuracies, f)
    print(f"Accuracies saved to {accuracies_file}")

    summary_writer.close()

In [84]:
lambda_ewc = 5000

model = CNN(in_channels=3, num_tasks=5, num_classes_per_task=2).to(device)
run_task_ewc(model, f"ewc_conv_s_cifar10_lambda_ewc_{lambda_ewc}", cifar_train, cifar_test, train_task_ids, test_task_ids, device, epochs=100, batch_size=256, lambda_ewc=lambda_ewc)

Training on task 0


Training Task 0:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 0: 96.25%
Training on task 1


Training Task 1:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 1: 87.3%
Test accuracy on previous task 0: 90.2%
Training on task 2


Training Task 2:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 2: 91.6%
Test accuracy on previous task 0: 88.65%
Test accuracy on previous task 1: 70.05%
Training on task 3


Training Task 3:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 3: 97.25%
Test accuracy on previous task 0: 89.3%
Test accuracy on previous task 1: 70.8%
Test accuracy on previous task 2: 86.1%
Training on task 4


Training Task 4:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 4: 95.85%
Test accuracy on previous task 0: 90.0%
Test accuracy on previous task 1: 73.9%
Test accuracy on previous task 2: 80.5%
Test accuracy on previous task 3: 91.45%


In [30]:
lambda_ewc = 1
model = ResNetCIFAR10(in_channels=3, num_tasks=5, num_classes_per_task=2).to(device)

run_task_ewc(
    model, 
    f"ewc_resnet_s_cifar10_lambda_ewc_{lambda_ewc}", 
    cifar_train, 
    cifar_test, 
    train_task_ids, 
    test_task_ids, 
    device,
    epochs=100, 
    batch_size=256, 
    lambda_ewc=lambda_ewc
)

Training on task 0


Training Task 0:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 0: 98.1%
Test accuracy on previous task 0: 98.1%
Training on task 1


Training Task 1:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 1: 89.65%
Test accuracy on previous task 0: 82.35%
Test accuracy on previous task 1: 89.65%
Training on task 2


Training Task 2:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 2: 94.25%
Test accuracy on previous task 0: 75.5%
Test accuracy on previous task 1: 80.85%
Test accuracy on previous task 2: 94.25%
Training on task 3


Training Task 3:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 3: 98.25%
Test accuracy on previous task 0: 66.3%
Test accuracy on previous task 1: 74.25%
Test accuracy on previous task 2: 83.9%
Test accuracy on previous task 3: 98.25%
Training on task 4


Training Task 4:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 4: 97.0%
Test accuracy on previous task 0: 75.9%
Test accuracy on previous task 1: 68.05%
Test accuracy on previous task 2: 73.65%
Test accuracy on previous task 3: 88.4%
Test accuracy on previous task 4: 97.0%
Accuracies saved to out/experiments/ewc_resnet_s_cifar10_lambda_ewc_1/final_accuracies.json


In [31]:
lambda_ewc = 1

model = CNN(in_channels=3, num_tasks=5, num_classes_per_task=2).to(device)
run_task_ewc(model, f"ewc_conv_CIFAR10_lambda_ewc_{lambda_ewc}", cifar_train, cifar_test, train_task_ids, test_task_ids, device, epochs=100, batch_size=256, lambda_ewc=lambda_ewc)

Training on task 0


Training Task 0:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 0: 96.2%
Test accuracy on previous task 0: 96.2%
Training on task 1


Training Task 1:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 1: 87.3%
Test accuracy on previous task 0: 92.6%
Test accuracy on previous task 1: 87.3%
Training on task 2


Training Task 2:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 2: 92.25%
Test accuracy on previous task 0: 83.35%
Test accuracy on previous task 1: 69.25%
Test accuracy on previous task 2: 92.25%
Training on task 3


Training Task 3:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 3: 97.5%
Test accuracy on previous task 0: 84.25%
Test accuracy on previous task 1: 71.15%
Test accuracy on previous task 2: 88.15%
Test accuracy on previous task 3: 97.5%
Training on task 4


Training Task 4:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 4: 96.2%
Test accuracy on previous task 0: 87.05%
Test accuracy on previous task 1: 70.8%
Test accuracy on previous task 2: 77.7%
Test accuracy on previous task 3: 95.05%
Test accuracy on previous task 4: 96.2%
Accuracies saved to out/experiments/ewc_conv_CIFAR10_lambda_ewc_1/final_accuracies.json


## Synaptic Intelligence (SI) for Split CIFAR10 with CNN

In [41]:
def run_task_si(model, log_name, cifar_train, cifar_test, train_task_ids, test_task_ids, device, epochs, batch_size, c_si):
    """
    Trains a given model on CIFAR10 split tasks using the Synaptic Intelligence (SI) method.
    
    Args:
    - model: The model to train.
    - log_name: Name for the TensorBoard logs.
    - cifar_train: Training dataset.
    - cifar_test: Test dataset.
    - train_task_ids: Task IDs for the training data.
    - test_task_ids: Task IDs for the test data.
    - device: Torch device to use for training.
    - epochs: Number of epochs to train each task.
    - batch_size: Batch size for training and testing.
    - c_si: The SI regularization term coefficient.
    """
    summary_writer = SummaryWriter(log_dir=os.path.join("logs", log_name, datetime.now().strftime('%Y%m%d_%H%M%S')))
    experiment_path = f"out/experiments/{log_name}"
    os.makedirs(experiment_path, exist_ok=True)  # Ensure output directory exists
    accuracies = {}
    
    # Initialize importance and previous parameters dictionaries
    importance = {}
    prev_params = {}
    
    for name, param in model.named_parameters():
        if param.requires_grad:
            importance[name] = torch.zeros_like(param, device=device)
            prev_params[name] = param.clone().detach()

    for task_idx in range(5):
        print(f"Training on task {task_idx}")
        # Assume task_subset and binarize_y are defined elsewhere to handle task-specific data preparation
        task_dataset = task_subset(cifar_train, train_task_ids, task_idx)
        train_loader = DataLoader(task_dataset, batch_size=batch_size, shuffle=True)

        test_dataset = task_subset(cifar_test, test_task_ids, task_idx)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        for epoch in tqdm(range(epochs), desc=f"Training Task {task_idx}"):
            model.train()
            for data, target in train_loader:
                data, target = data.to(device), binarize_y(target, task_idx).to(device)
                optimizer.zero_grad()
                output = model(data, task_idx)
                loss = F.nll_loss(output, target)
                
                # Backpropagate to compute gradients
                loss.backward()

                # Calculate SI regularization term and add it to the loss before optimizer step
                si_reg_loss = sum((importance[name] * (param - prev_params[name]) ** 2).sum() for name, param in model.named_parameters() if param.requires_grad)
                (c_si * si_reg_loss).backward() 
                
                optimizer.step()

                # Update the importance weights after optimizer step
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        delta_param = param.detach() - prev_params[name]
                        # Assuming a zero-initialized importance, this needs accumulation of gradient information over tasks
                        importance[name] += (param.grad.detach() ** 2) * delta_param.abs()
            summary_writer.add_scalar(f'Task_{task_idx}/Train_Loss', loss.item(), epoch)
        
        # Evaluate the model on the current task's test set
        test_accuracy = test_model(model, test_loader, device, task_idx)
        print(f"Test accuracy on task {task_idx}: {test_accuracy}%")
        summary_writer.add_scalar(f'Task_{task_idx}/Test_Accuracy', test_accuracy, epoch)
        accuracies[f"TASK {task_idx}"] = test_accuracy
        
        # Evaluate the model on all previous tasks' test sets to measure forgetting
        for previous_task_idx in range(task_idx + 1):
            test_dataset = task_subset(cifar_test, test_task_ids, previous_task_idx)
            test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)
            accuracy = test_model(model, test_dataloader, device, previous_task_idx)
            print(f"Test accuracy on previous task {previous_task_idx}: {accuracy}%")
            summary_writer.add_scalar(f"Cross_Task_Accuracy/task_{task_idx}_on_{previous_task_idx}", accuracy, global_step=task_idx)
            accuracies[f"TASK {previous_task_idx}"] = accuracy
        
        # After task training, update previous parameters for the next task
        prev_params = {name: param.clone().detach() for name, param in model.named_parameters() if param.requires_grad}


    # Save accuracies to file
    accuracies_file = os.path.join(experiment_path, "final_accuracies.json")
    with open(accuracies_file, 'w') as f:
        json.dump(accuracies, f)
    print(f"Accuracies saved to {accuracies_file}")

    summary_writer.close()

In [90]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 100
batch_size = 256
c_si = 1
log_name = f"si_conv_s_CIFAR10_c_si_{c_si}" 
model = CNN(in_channels=3, num_tasks=5, num_classes_per_task=2).to(device)

run_task_si(
    model=model,
    log_name=log_name,
    cifar_train=cifar_train,
    cifar_test=cifar_test,
    train_task_ids=train_task_ids,
    test_task_ids=test_task_ids,
    device=device,
    epochs=epochs,
    batch_size=batch_size,
    c_si=c_si
)

Training on task 0


Training Task 0:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 0: 87.05%
Test accuracy on task 0: 92.8%
Test accuracy on task 0: 94.3%
Test accuracy on task 0: 94.0%
Test accuracy on task 0: 93.9%
Test accuracy on task 0: 95.55%
Test accuracy on task 0: 95.15%
Test accuracy on task 0: 93.5%
Test accuracy on task 0: 95.75%
Test accuracy on task 0: 96.0%
Test accuracy on task 0: 95.85%
Test accuracy on task 0: 95.7%
Test accuracy on task 0: 95.7%
Test accuracy on task 0: 95.6%
Test accuracy on task 0: 96.3%
Test accuracy on task 0: 95.45%
Test accuracy on task 0: 96.3%
Test accuracy on task 0: 96.25%
Test accuracy on task 0: 96.15%
Test accuracy on task 0: 96.1%
Test accuracy on task 0: 96.1%
Test accuracy on task 0: 96.15%
Test accuracy on task 0: 96.05%
Test accuracy on task 0: 96.1%
Test accuracy on task 0: 96.15%
Test accuracy on task 0: 96.2%
Test accuracy on task 0: 96.15%
Test accuracy on task 0: 96.15%
Test accuracy on task 0: 96.2%
Test accuracy on task 0: 96.15%
Test accuracy on task 0: 96.15%
Test accuracy on task 0:

Training Task 1:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 1: 77.55%
Test accuracy on task 1: 81.6%
Test accuracy on task 1: 80.85%
Test accuracy on task 1: 83.65%
Test accuracy on task 1: 85.1%
Test accuracy on task 1: 84.6%
Test accuracy on task 1: 84.45%
Test accuracy on task 1: 78.6%
Test accuracy on task 1: 84.8%
Test accuracy on task 1: 85.75%
Test accuracy on task 1: 85.2%
Test accuracy on task 1: 85.9%
Test accuracy on task 1: 85.3%
Test accuracy on task 1: 86.2%
Test accuracy on task 1: 85.85%
Test accuracy on task 1: 86.35%
Test accuracy on task 1: 86.3%
Test accuracy on task 1: 86.35%
Test accuracy on task 1: 86.35%
Test accuracy on task 1: 86.35%
Test accuracy on task 1: 86.3%
Test accuracy on task 1: 86.25%
Test accuracy on task 1: 86.1%
Test accuracy on task 1: 86.15%
Test accuracy on task 1: 86.3%
Test accuracy on task 1: 86.15%
Test accuracy on task 1: 86.35%
Test accuracy on task 1: 86.25%
Test accuracy on task 1: 86.35%
Test accuracy on task 1: 86.25%
Test accuracy on task 1: 86.25%
Test accuracy on task

Training Task 2:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 2: 85.65%
Test accuracy on task 2: 88.8%
Test accuracy on task 2: 89.45%
Test accuracy on task 2: 84.25%
Test accuracy on task 2: 90.3%
Test accuracy on task 2: 90.05%
Test accuracy on task 2: 89.75%
Test accuracy on task 2: 89.65%
Test accuracy on task 2: 90.65%
Test accuracy on task 2: 89.75%
Test accuracy on task 2: 90.8%
Test accuracy on task 2: 91.1%
Test accuracy on task 2: 91.25%
Test accuracy on task 2: 90.85%
Test accuracy on task 2: 91.0%
Test accuracy on task 2: 91.15%
Test accuracy on task 2: 91.05%
Test accuracy on task 2: 91.0%
Test accuracy on task 2: 91.05%
Test accuracy on task 2: 91.05%
Test accuracy on task 2: 91.0%
Test accuracy on task 2: 91.05%
Test accuracy on task 2: 91.05%
Test accuracy on task 2: 91.05%
Test accuracy on task 2: 91.15%
Test accuracy on task 2: 91.0%
Test accuracy on task 2: 91.15%
Test accuracy on task 2: 91.0%
Test accuracy on task 2: 91.0%
Test accuracy on task 2: 91.0%
Test accuracy on task 2: 91.0%
Test accuracy on tas

Training Task 3:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 3: 95.4%
Test accuracy on task 3: 95.2%
Test accuracy on task 3: 96.6%
Test accuracy on task 3: 95.55%
Test accuracy on task 3: 96.4%
Test accuracy on task 3: 96.85%
Test accuracy on task 3: 96.8%
Test accuracy on task 3: 96.9%
Test accuracy on task 3: 95.4%
Test accuracy on task 3: 96.95%
Test accuracy on task 3: 96.4%
Test accuracy on task 3: 95.9%
Test accuracy on task 3: 96.85%
Test accuracy on task 3: 97.05%
Test accuracy on task 3: 96.95%
Test accuracy on task 3: 97.25%
Test accuracy on task 3: 97.2%
Test accuracy on task 3: 97.3%
Test accuracy on task 3: 97.15%
Test accuracy on task 3: 97.15%
Test accuracy on task 3: 97.15%
Test accuracy on task 3: 97.25%
Test accuracy on task 3: 97.25%
Test accuracy on task 3: 97.15%
Test accuracy on task 3: 97.25%
Test accuracy on task 3: 97.15%
Test accuracy on task 3: 97.15%
Test accuracy on task 3: 97.25%
Test accuracy on task 3: 97.25%
Test accuracy on task 3: 97.3%
Test accuracy on task 3: 97.2%
Test accuracy on task

Training Task 4:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 4: 90.65%
Test accuracy on task 4: 93.35%
Test accuracy on task 4: 94.8%
Test accuracy on task 4: 95.05%
Test accuracy on task 4: 93.5%
Test accuracy on task 4: 95.3%
Test accuracy on task 4: 95.35%
Test accuracy on task 4: 95.4%
Test accuracy on task 4: 95.7%
Test accuracy on task 4: 95.75%
Test accuracy on task 4: 95.95%
Test accuracy on task 4: 96.0%
Test accuracy on task 4: 95.95%
Test accuracy on task 4: 95.9%
Test accuracy on task 4: 95.9%
Test accuracy on task 4: 96.05%
Test accuracy on task 4: 95.95%
Test accuracy on task 4: 96.05%
Test accuracy on task 4: 96.0%
Test accuracy on task 4: 96.0%
Test accuracy on task 4: 96.1%
Test accuracy on task 4: 96.1%
Test accuracy on task 4: 95.9%
Test accuracy on task 4: 96.1%
Test accuracy on task 4: 95.9%
Test accuracy on task 4: 96.0%
Test accuracy on task 4: 95.95%
Test accuracy on task 4: 96.0%
Test accuracy on task 4: 96.0%
Test accuracy on task 4: 95.95%
Test accuracy on task 4: 96.0%
Test accuracy on task 4: 96

In [42]:
epochs = 100
batch_size = 256
c_si = 1
log_name = f"si_resnet_s_CIFAR10_c_si_{c_si}" 
model = ResNetCIFAR10(in_channels=3, num_tasks=5, num_classes_per_task=2).to(device)

run_task_si(
    model=model,
    log_name=log_name,
    cifar_train=cifar_train,
    cifar_test=cifar_test,
    train_task_ids=train_task_ids,
    test_task_ids=test_task_ids,
    device=device,
    epochs=epochs,
    batch_size=batch_size,
    c_si=c_si
)

Training on task 0


Training Task 0:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 0: 97.45%
Test accuracy on previous task 0: 97.45%
Training on task 1


Training Task 1:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 1: 90.15%
Test accuracy on previous task 0: 85.0%
Test accuracy on previous task 1: 90.15%
Training on task 2


Training Task 2:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 2: 94.15%
Test accuracy on previous task 0: 70.9%
Test accuracy on previous task 1: 77.5%
Test accuracy on previous task 2: 94.15%
Training on task 3


Training Task 3:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 3: 97.7%
Test accuracy on previous task 0: 59.85%
Test accuracy on previous task 1: 70.4%
Test accuracy on previous task 2: 86.85%
Test accuracy on previous task 3: 97.7%
Training on task 4


Training Task 4:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 4: 97.15%
Test accuracy on previous task 0: 68.75%
Test accuracy on previous task 1: 69.7%
Test accuracy on previous task 2: 83.35%
Test accuracy on previous task 3: 92.65%
Test accuracy on previous task 4: 97.15%
Accuracies saved to out/experiments/si_resnet_s_CIFAR10_c_si_1/final_accuracies.json


## Laplace Propagation Method for Split CIFAR10 Task with CNN

In [7]:
def compute_hessian_diag(model, dataloader, device, task_idx):
    model.eval()
    hessian_diag = {}
    for name, param in model.named_parameters():
        hessian_diag[name] = torch.zeros_like(param)

    for data, target in dataloader:
        data, target = data.to(device), binarize_y(target, task_idx).to(device)
        model.zero_grad()
        output = model(data, task_idx)
        loss = F.nll_loss(output, target)
        # Set allow_unused=True to handle parameters not used in the graph
        grad_params = torch.autograd.grad(loss, model.parameters(), create_graph=True, allow_unused=True)

        for grad, (name, param) in zip(grad_params, model.named_parameters()):
            if grad is not None:  # Only proceed if the gradient is not None
                grad2 = torch.autograd.grad(grad.sum(), param, retain_graph=True, allow_unused=True)[0]
                if grad2 is not None:  # Check if the second derivative is not None
                    hessian_diag[name] += grad2.data / len(dataloader.dataset)

    return hessian_diag


def run_task_lp(model, log_name, cifar_train, cifar_test, train_task_ids, test_task_ids, device, epochs, batch_size, gamma_lp):
    """
    Trains a given model on CIFAR10 split tasks using a simplified Laplace Propagation method (approximated with second-order Taylor Expansion).
    
    Args:
    - model: The model to train.
    - log_name: Name for the TensorBoard logs.
    - cifar_train: Training dataset.
    - cifar_test: Test dataset.
    - train_task_ids: Task IDs for the training data.
    - test_task_ids: Task IDs for the test data.
    - device: Torch device to use for training.
    - epochs: Number of epochs to train each task.
    - batch_size: Batch size for training and testing.
    - gamma_lp: Coefficient for the LP regularization term.
    """
    summary_writer = SummaryWriter(log_dir=os.path.join("logs", log_name, datetime.now().strftime('%Y%m%d_%H%M%S')))
    experiment_path = f"out/experiments/{log_name}"
    os.makedirs(experiment_path, exist_ok=True)  # Ensure output directory exists
    accuracies = {}
    
    # Initialize Hessian approximation (diagonal) and previous parameters
    hessian_diag = {}
    prev_params = {}

    for task_idx in range(5):
        print(f"Training on task {task_idx}")
        task_dataset = task_subset(cifar_train, train_task_ids, task_idx)
        train_loader = DataLoader(task_dataset, batch_size=batch_size, shuffle=True)

        test_dataset = task_subset(cifar_test, test_task_ids, task_idx)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        for epoch in tqdm(range(epochs), desc=f"Training Task {task_idx}"):
            model.train()
            for data, target in train_loader:
                data, target = data.to(device), binarize_y(target, task_idx).to(device)
                optimizer.zero_grad()
                output = model(data, task_idx)
                loss = F.nll_loss(output, target)
                
                if task_idx > 0:
                    # Calculate LP regularization term
                    lp_loss = 0
                    for name, param in model.named_parameters():
                        if name in hessian_diag:
                            lp_loss += (hessian_diag[name] * (param - prev_params[name]) ** 2).sum()
                    loss += gamma_lp * lp_loss
                
                loss.backward()
                optimizer.step()
            
            # Update Hessian approximation and previous parameters after each epoch
            if task_idx > 0:
                hessian_diag = compute_hessian_diag(model, train_loader, device, task_idx)
            prev_params = {name: param.clone().detach() for name, param in model.named_parameters()}

            summary_writer.add_scalar(f'Task_{task_idx}/Train_Loss', loss.item(), epoch)
            
        # Evaluate the model on the current task's test set
        test_accuracy = test_model(model, test_loader, device, task_idx)
        print(f"Test accuracy on task {task_idx}: {test_accuracy}%")
        summary_writer.add_scalar(f'Task_{task_idx}/Test_Accuracy', test_accuracy, epoch)
        accuracies[f"TASK {task_idx}"] = test_accuracy

        # Evaluate the model on all previous tasks' test sets to measure forgetting
        for previous_task_idx in range(task_idx + 1):
            test_dataset = task_subset(cifar_test, test_task_ids, previous_task_idx)
            test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)
            accuracy = test_model(model, test_dataloader, device, previous_task_idx)
            print(f"Test accuracy on previous task {previous_task_idx}: {accuracy}%")
            summary_writer.add_scalar(f"Cross_Task_Accuracy/task_{task_idx}_on_{previous_task_idx}", accuracy, global_step=task_idx)
            accuracies[f"TASK {previous_task_idx}"] = accuracy

    # Save accuracies to file
    accuracies_file = os.path.join(experiment_path, "final_accuracies.json")
    with open(accuracies_file, 'w') as f:
        json.dump(accuracies, f)
    print(f"Accuracies saved to {accuracies_file}")
    
    summary_writer.close()

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 100 
batch_size = 256 
gamma_lp = 0.05
log_name = f"lp_conv_CIFAR10_gamma_lp_{gamma_lp}" 

model = CNN(in_channels=3, num_tasks=5, num_classes_per_task=2).to(device)

run_task_lp(
    model=model,
    log_name=log_name,
    cifar_train=cifar_train,
    cifar_test=cifar_test,
    train_task_ids=train_task_ids,
    test_task_ids=test_task_ids,
    device=device,
    epochs=epochs,
    batch_size=batch_size,
    gamma_lp=gamma_lp
)

Training on task 0


Training Task 0:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 0: 95.9%
Test accuracy on previous task 0: 95.9%
Training on task 1


Training Task 1:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 1: 64.4%
Test accuracy on previous task 0: 86.0%
Test accuracy on previous task 1: 64.4%
Training on task 2


Training Task 2:   0%|          | 0/100 [00:00<?, ?it/s]

Test accuracy on task 2: 59.2%
Test accuracy on previous task 0: 80.0%
Test accuracy on previous task 1: 56.0%
Test accuracy on previous task 2: 59.2%
Training on task 3


Training Task 3:   0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 100 
batch_size = 256 
gamma_lp = 0.05
log_name = f"lp_resnet_CIFAR10_gamma_lp_{gamma_lp}" 

model = ResNetCIFAR10(in_channels=3, num_tasks=5, num_classes_per_task=2).to(device)

run_task_lp(
    model=model,
    log_name=log_name,
    cifar_train=cifar_train,
    cifar_test=cifar_test,
    train_task_ids=train_task_ids,
    test_task_ids=test_task_ids,
    device=device,
    epochs=epochs,
    batch_size=batch_size,
    gamma_lp=gamma_lp
)