In [None]:
!pip install loguru



In [None]:
import torch
import torch.nn as nn


import torch
import torch.nn as nn


class BaseModel(nn.Module):
    """Common base class for the networks in this notebook; adds an EWC helper."""

    def __init__(self):
        super(BaseModel, self).__init__()

    def compute_fisher_information(self, dataloader, device):
        """Compute the diagonal empirical Fisher Information for EWC.

        For every batch the cross-entropy gradient is computed from scratch and
        its element-wise square is accumulated. The sum is then averaged over
        the number of batches: ``F_i ~= mean_b (dL_b / dtheta_i)^2``.

        Args:
            dataloader: iterable of ``(data, target)`` batches; ``len()`` is
                not required.
            device: device each batch is moved to before the forward pass.

        Returns:
            dict mapping every name from ``named_parameters()`` to a tensor of
            that parameter's shape. Parameters that receive no gradient (e.g.
            frozen ones) get an all-zero entry, as do all parameters when the
            loader yields no batches.
        """
        criterion = nn.CrossEntropyLoss()
        fisher_info = {
            name: torch.zeros_like(param) for name, param in self.named_parameters()
        }

        # NOTE(review): in train mode any BatchNorm layers use batch statistics
        # and update their running stats during this pass; kept as-is, consider
        # eval() if that side effect is unwanted.
        self.train()
        num_batches = 0
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)

            # Fresh gradients per batch: squaring an accumulated *sum* of
            # gradients is not the Fisher diagonal.
            self.zero_grad()
            loss = criterion(self(data), target)
            loss.backward()

            for name, param in self.named_parameters():
                if param.grad is not None:  # frozen params have no gradient
                    fisher_info[name] += param.grad.detach() ** 2
            num_batches += 1

        # Do not leave Fisher gradients behind for a later optimizer.step().
        self.zero_grad()

        if num_batches > 0:
            fisher_info = {name: value / num_batches for name, value in fisher_info.items()}
        return fisher_info

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger


class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()

        # The first split convolution layer
        self.conv1_0 = nn.Conv2d(
            in_channels,
            out_channels // 2,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.conv1_1 = nn.Conv2d(
            in_channels,
            out_channels // 2,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        # The second split convolution layer
        self.conv2_0 = nn.Conv2d(
            out_channels,
            out_channels // 2,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.conv2_1 = nn.Conv2d(
            out_channels,
            out_channels // 2,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The identity shortcut connection
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        # First split convolution
        out1_0 = self.conv1_0(x)
        out1_1 = self.conv1_1(x)
        out1 = torch.cat((out1_0, out1_1), dim=1)  # Concatenate along channel dimension
        out1 = F.relu(self.bn1(out1))

        # Second split convolution
        out2_0 = self.conv2_0(out1)
        out2_1 = self.conv2_1(out1)
        out2 = torch.cat((out2_0, out2_1), dim=1)  # Concatenate along channel dimension
        out2 = self.bn2(out2)

        # Adding the shortcut (skip connection)
        out = out2 + self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet18(BaseModel):
    """ResNet-18-style classifier built from split-convolution ``BasicBlock``s.

    The stem is a split 7x7 convolution (2 x 32 = 64 channels) followed by BN,
    ReLU and max-pooling. Four stages of two ``BasicBlock``s follow, then
    global average pooling and a linear head.

    Args:
        speed: if True, ``layer3``/``layer4`` are omitted (set to None), so the
            network ends after ``layer2`` with 128 channels.
        num_classes: number of output logits.
        in_channels: number of input image channels (1 for MNIST, the default).
    """

    def __init__(self, speed=False, num_classes=10, in_channels=1):
        super(ResNet18, self).__init__()
        self.speed = speed
        if self.speed:
            logger.info("Speed ResNet18 version with 128 channels")

        # Initial split convolution layer (two halves of 32 channels each)
        self.conv1_0 = nn.Conv2d(in_channels, 32, kernel_size=7, stride=2, padding=3, bias=False)
        self.conv1_1 = nn.Conv2d(in_channels, 32, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Layer blocks; the speed variant stops after layer2
        self.layer1 = self._make_layer(64, 64, stride=1)
        self.layer2 = self._make_layer(64, 128, stride=2)
        self.layer3 = self._make_layer(128, 256, stride=2) if not self.speed else None
        self.layer4 = self._make_layer(256, 512, stride=2) if not self.speed else None

        # Head width matches the channel count of the last stage used
        self.fc = nn.Linear(128 if self.speed else 512, num_classes)

    def _make_layer(self, in_channels, out_channels, stride):
        """Two BasicBlocks; only the first changes stride/channel count."""
        layers = []
        layers.append(BasicBlock(in_channels, out_channels, stride))
        layers.append(BasicBlock(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Stem with split convolutions, concatenated on the channel axis
        out1_0 = self.conv1_0(x)
        out1_1 = self.conv1_1(x)
        x = torch.cat((out1_0, out1_1), dim=1)
        x = F.relu(self.bn1(x))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        if not self.speed:
            x = self.layer3(x)
            x = self.layer4(x)

        # Global average pooling -> (batch, channels)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)

        return self.fc(x)

    def _set_filters_requires_grad(self, suffix, requires_grad):
        """Set ``requires_grad`` on every parameter whose name ends with ``suffix``.

        With ``"_0.weight"`` / ``"_1.weight"`` this matches only the split conv
        halves (``conv1_*``, ``conv2_*``). BatchNorm, shortcut (``shortcut.0.weight``)
        and ``fc`` parameters are not matched.
        NOTE(review): with SGD momentum, a frozen weight whose ``.grad`` is
        zeros rather than None can still be moved by the momentum buffer;
        confirm ``zero_grad`` sets grads to None.
        """
        for name, param in self.named_parameters():
            if name.endswith(suffix):
                param.requires_grad = requires_grad

    def freeze_0_filters(self):
        """Stop gradient updates for all ``*_0`` conv halves."""
        self._set_filters_requires_grad("_0.weight", False)
        logger.info("Frozen _0 filters")

    def unfreeze_0_filters(self):
        """Re-enable gradient updates for all ``*_0`` conv halves."""
        self._set_filters_requires_grad("_0.weight", True)
        logger.info("Unfrozen _0 filters")

    def freeze_1_filters(self):
        """Stop gradient updates for all ``*_1`` conv halves."""
        self._set_filters_requires_grad("_1.weight", False)
        logger.info("Frozen _1 filters")

    def unfreeze_1_filters(self):
        """Re-enable gradient updates for all ``*_1`` conv halves."""
        self._set_filters_requires_grad("_1.weight", True)
        logger.info("Unfrozen _1 filters")


In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
import wandb



class BaseStrategy:
    """Plain fine-tuning train/evaluate loop used by the continual-learning strategies.

    Subclasses customise training by overriding ``compute_loss`` (which may
    return a tensor or a tuple whose first element is the loss) and
    ``on_task_complete``.
    """

    def __init__(self, model, optimizer, device, nr_epochs=2):
        self.model = model
        self.optimizer: torch.optim.Optimizer = optimizer
        self.device = device
        self.current_task = 0
        self.epoch = 0  # incremented per train_epoch call; reset by on_task_complete
        self.nr_epochs = nr_epochs  # stored for callers/subclasses; unused here

    @staticmethod
    def _count_correct(output, target):
        """Number of rows whose arg-max logit equals the target label."""
        pred = output.argmax(dim=1, keepdim=True)
        return pred.eq(target.view_as(pred)).sum().item()

    @staticmethod
    def _epoch_metrics(total_loss, num_batches, correct, total):
        """Mean per-batch loss and accuracy in percent; both NaN for an empty loader."""
        if num_batches == 0 or total == 0:
            return float("nan"), float("nan")
        return total_loss / num_batches, 100.0 * correct / total

    def train_epoch(self, dataloader, **kwargs):
        """Run one optimisation pass over ``dataloader``.

        Args:
            dataloader: iterable of ``(data, target)`` batches; ``len()`` is
                not required.
            **kwargs: accepted for caller/subclass compatibility; unused.

        Returns:
            ``(mean_batch_loss, accuracy_percent)``. Accuracy is measured on
            the logits computed before each optimizer step.
        """
        self.epoch += 1
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for data, target in dataloader:
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.compute_loss(output, target)
            if isinstance(loss, tuple):  # subclasses may return (loss, extras...)
                loss = loss[0]

            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            correct += self._count_correct(output, target)
            total += target.size(0)
            num_batches += 1

        return self._epoch_metrics(total_loss, num_batches, correct, total)

    def evaluate(self, dataloader):
        """Compute plain cross-entropy loss and accuracy without updating the model.

        Returns:
            ``(mean_batch_loss, accuracy_percent)``; both NaN if the loader
            yields no batches.
        """
        self.model.eval()
        criterion = nn.CrossEntropyLoss()
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                total_loss += criterion(output, target).item()
                correct += self._count_correct(output, target)
                total += target.size(0)
                num_batches += 1

        return self._epoch_metrics(total_loss, num_batches, correct, total)

    def compute_loss(self, output, target):
        """Training loss; plain cross-entropy here, overridden by regularising strategies."""
        return nn.CrossEntropyLoss()(output, target)

    def on_task_complete(self, dataloader):
        """Advance to the next task and reset the epoch counter."""
        self.current_task += 1
        self.epoch = 0


In [None]:
from loguru import logger
import torch
import torchvision
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import yaml


class PermutedMNIST(Dataset):
    """MNIST with a fixed, task-specific pixel permutation.

    Task 0 is unpermuted MNIST. Task ``t > 0`` shuffles the pixels of every
    image with a permutation seeded by ``t``. The same permutation is used for
    the train and test splits.

    Args:
        task_id: task index; also the permutation seed for ``task_id > 0``.
        train: load the MNIST training split if True, else the test split.
        transform: optional callable applied to each ``(1, H, W)`` float image
            in ``__getitem__``.
    """

    def __init__(self, task_id, train=True, transform=None):
        super(PermutedMNIST, self).__init__()
        self.task_id = task_id
        self.train = train
        self.transform = transform

        # Only the raw uint8 `.data` / `.targets` tensors are used below, so
        # the ToTensor transform handed to torchvision is never applied.
        mnist = torchvision.datasets.MNIST(
            root="./data", train=train, download=True, transform=transforms.ToTensor()
        )

        self.data = mnist.data
        self.targets = mnist.targets

        # None means "no permutation" (task 0 uses the original MNIST).
        self.permutation = None
        if task_id > 0:
            height, width = self.data.shape[-2:]
            self.permutation = self._make_permutation(task_id, height * width)
            self.data = self._apply_permutation(self.data, self.permutation)

    @staticmethod
    def _make_permutation(task_id, num_pixels=784):
        """Deterministic pixel permutation for ``task_id``.

        Uses a private RandomState, so the global NumPy RNG is left untouched.
        The result equals ``np.random.seed(task_id); np.random.permutation(num_pixels)``.
        """
        return np.random.RandomState(task_id).permutation(num_pixels)

    @staticmethod
    def _apply_permutation(images, permutation):
        """Reorder the pixels of every image in an ``(N, H, W)`` stack."""
        num_images, height, width = images.shape
        flat = images.reshape(num_images, height * width)
        return flat[:, torch.as_tensor(permutation)].reshape(num_images, height, width)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, target = self.data[idx], self.targets[idx]
        # uint8 [0, 255] -> float [0, 1]
        img = img.float() / 255.0

        img = img.unsqueeze(0)  # add channel dimension -> (1, H, W)
        if self.transform:
            img = self.transform(img)

        return img, target


class SequentialMNIST(Dataset):
    """MNIST restricted to a single digit class per task.

    Task ``t`` keeps only the images whose label equals ``t % 10``. Labels are
    left unchanged, so every sample in a task carries the same target.

    Args:
        task_id: task index; reduced modulo 10 to pick the digit.
        train: load the MNIST training split if True, else the test split.
        transform: optional callable applied to each ``(1, H, W)`` float image.
    """

    def __init__(self, task_id, train=True, transform=None):
        super().__init__()
        self.task_id = task_id % 10
        self.train = train
        self.transform = transform

        # Only the raw uint8 tensors are read below, so the ToTensor transform
        # given to torchvision is never applied.
        source = torchvision.datasets.MNIST(
            root="./data", train=train, download=True, transform=transforms.ToTensor()
        )

        # Keep the samples of this task's digit only.
        keep = source.targets == self.task_id
        self.data = source.data[keep]
        self.targets = source.targets[keep]

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        # uint8 (H, W) -> float (1, H, W) in [0, 1], then the optional transform
        image = (self.data[idx].float() / 255.0).unsqueeze(0)
        if self.transform:
            image = self.transform(image)
        return image, self.targets[idx]

def get_datasets(task_id, dataset_name="permuted_mnist"):
    """Build the train/test datasets for one continual-learning task.

    Args:
        task_id: task index (selects the permutation or the digit).
        dataset_name: ``"sequential_mnist"`` for one-digit-per-task MNIST; any
            other value uses permuted MNIST (the default, and the only option
            previously reachable).

    Returns:
        ``(train_dataset, test_dataset)``, both normalised with MNIST mean/std.
    """
    transform = transforms.Compose(
        [transforms.Normalize((0.1307,), (0.3081,))]  # MNIST normalization
    )

    if dataset_name == "sequential_mnist":
        train_dataset = SequentialMNIST(task_id=task_id, train=True, transform=transform)
        test_dataset = SequentialMNIST(task_id=task_id, train=False, transform=transform)
        logger.info(f"Using Sequential MNIST for task {task_id}")
    else:
        # Default to permuted MNIST
        train_dataset = PermutedMNIST(task_id=task_id, train=True, transform=transform)
        test_dataset = PermutedMNIST(task_id=task_id, train=False, transform=transform)
        logger.info(f"Using Permuted MNIST for task {task_id}")

    return train_dataset, test_dataset

def get_dataloaders(task_id, batch_size=256):
    """Return ``(train_loader, test_loader)`` for ``task_id``.

    The training loader reshuffles every epoch; the test loader keeps the
    dataset order.
    """
    train_set, test_set = get_datasets(task_id)
    loaders = (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(test_set, batch_size=batch_size, shuffle=False),
    )
    return loaders

In [None]:
import torch
import wandb
from tqdm.notebook import tqdm

# SETUP
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (weight init, DataLoader shuffling, randperm-based
# memory sampling) so Restart & Run All reproduces the results.
# NOTE(review): on CUDA, cuDNN kernels may still be non-deterministic; set
# torch.backends.cudnn.deterministic = True if bit-exact reruns are required.
SEED = 0
torch.manual_seed(SEED)

NR_TASKS = 6
BATCH_SIZE = 128
NR_EPOCHS = 3
SPEED = False  # True -> smaller ResNet18 without layer3/layer4


def OPTIMIZER(model):
    """Optimizer factory shared by every experiment.

    Alternative tried earlier: ``torch.optim.Adam(model.parameters(), lr=0.001)``.
    """
    return torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Subset


class RehearsalStrategy(BaseStrategy):
    """Experience replay with a fixed-capacity random memory of past samples.

    This class only maintains the memory (``update_memory``) and exposes it as
    a DataLoader (``get_memory_loader``). How replayed samples are mixed with
    new-task data is up to the caller.
    """

    def __init__(self, model, optimizer, device, memory_size=10_000, **kwargs):
        super().__init__(model, optimizer, device, **kwargs)
        self.memory_size = memory_size
        self.memory_data = None  # TensorDataset(data, targets) once populated
        self.current_memory_size = 0  # number of samples currently stored

    def update_memory(self, dataloader):
        """Merge a random subset of ``dataloader``'s samples into the memory.

        Up to ``memory_size`` samples are drawn uniformly from the new data and
        concatenated with the existing memory. The union is then subsampled
        uniformly back down to ``memory_size``.
        NOTE(review): this is not reservoir sampling. When both parts are full,
        about half of the kept samples come from the newest data, so older
        tasks are progressively diluted.

        Args:
            dataloader: iterable of ``(data, target)`` batches. It is iterated
                in full; the batches are assumed to be on the same device as
                the stored memory.
        """
        print("Updating memory with new samples")
        data_list, target_list = [], []
        for data, target in dataloader:
            data_list.append(data)
            target_list.append(target)

        if not data_list:
            # Nothing to add; torch.cat would raise on an empty list.
            return

        data = torch.cat(data_list)
        targets = torch.cat(target_list)

        # Up to `memory_size` random samples from the new data
        indices = torch.randperm(len(data))[: self.memory_size]
        new_data, new_targets = data[indices], targets[indices]

        if self.memory_data is not None:
            old_data, old_targets = self.memory_data.tensors
            new_data = torch.cat([old_data, new_data])
            new_targets = torch.cat([old_targets, new_targets])
            # Random subsample of old + new back down to capacity
            keep = torch.randperm(len(new_data))[: self.memory_size]
            new_data, new_targets = new_data[keep], new_targets[keep]

        self.memory_data = TensorDataset(new_data, new_targets)
        self.current_memory_size = len(self.memory_data)

    def get_memory_loader(self, batch_size):
        """Shuffled DataLoader over the memory, or None before the first update.

        A loader over an empty memory (``memory_size=0``) has length 0 and is
        therefore falsy.
        """
        if self.memory_data is not None:
            return DataLoader(self.memory_data, batch_size=batch_size, shuffle=True)
        return None


In [None]:
# Interleaved rehearsal: every epoch trains on the current task's data and the
# replay memory shuffled together into a single loader.
model = ResNet18(speed=SPEED).to(device)
optimizer = OPTIMIZER(model)
memory_size = 1000
strategy_interleaved = RehearsalStrategy(model, optimizer, device, memory_size)
rehearsal_interleaved_accs = []  # mean accuracy over ALL NR_TASKS tasks, one entry per task

# The test sets never change, so build their loaders once instead of reloading
# MNIST (train + test split) for every task on every evaluation pass.
eval_test_loaders = [
    get_dataloaders(eval_task_id, batch_size=BATCH_SIZE)[1] for eval_task_id in range(NR_TASKS)
]

for task_id in tqdm(range(NR_TASKS), desc="Tasks", unit="task"):
    print(f"\nTraining on task {task_id}")
    train_loader, test_loader = get_dataloaders(task_id, batch_size=BATCH_SIZE)
    epoch_progress_bar = tqdm(range(NR_EPOCHS), desc=f"Task {task_id} Epochs", unit="epoch")

    for epoch in epoch_progress_bar:
        # Mix the replay memory (if any) into the current task's training set.
        memory_loader = strategy_interleaved.get_memory_loader(batch_size=BATCH_SIZE)
        combined_loader = train_loader
        if memory_loader:  # None before the first update; empty loaders are falsy
            combined_loader = torch.utils.data.DataLoader(
                train_loader.dataset + memory_loader.dataset,  # ConcatDataset
                batch_size=BATCH_SIZE,
                shuffle=True,
            )

        train_loss, train_acc = strategy_interleaved.train_epoch(combined_loader)
        test_loss, test_acc = strategy_interleaved.evaluate(test_loader)
        # Single postfix update so test metrics do not overwrite train metrics.
        epoch_progress_bar.set_postfix(
            {"Train Loss": train_loss, "Train Acc": train_acc, "Test Loss": test_loss, "Test Acc": test_acc}
        )

        print(f"Strategy : Rehearsal | Task : {task_id} | Epoch : {epoch} | Train Loss : {train_loss} | Train Acc : {train_acc} | Test Loss : {test_loss} | Test Acc : {test_acc}")

    # Mean accuracy over ALL tasks, including ones not trained on yet
    # (unseen permutations score near chance and pull the average down).
    avg_acc = sum(strategy_interleaved.evaluate(loader)[1] for loader in eval_test_loaders) / NR_TASKS
    rehearsal_interleaved_accs.append(avg_acc)
    print(f"Average total accuracy after training on {task_id}: {avg_acc}")

    # Update memory with current task data
    strategy_interleaved.update_memory(train_loader)

Tasks:   0%|          | 0/6 [00:00<?, ?task/s]

[32m2024-12-21 09:31:29.075[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 0[0m



Training on task 0


Task 0 Epochs:   0%|          | 0/3 [00:00<?, ?epoch/s]

Strategy : Rehearsal | Task : 0 | Epoch : 0 | Train Loss : 0.11849934698692136 | Train Acc : 96.325 | Test Loss : 0.055186214354301835 | Test Acc : 98.35
Strategy : Rehearsal | Task : 0 | Epoch : 1 | Train Loss : 0.038990021607320284 | Train Acc : 98.83166666666666 | Test Loss : 0.03372054585753763 | Test Acc : 98.96


[32m2024-12-21 09:32:40.403[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 0[0m


Strategy : Rehearsal | Task : 0 | Epoch : 2 | Train Loss : 0.028390054698975913 | Train Acc : 99.15666666666667 | Test Loss : 0.031189459640504364 | Test Acc : 99.0


[32m2024-12-21 09:32:42.700[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 1[0m
[32m2024-12-21 09:32:44.925[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 2[0m
[32m2024-12-21 09:32:46.752[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 3[0m
[32m2024-12-21 09:32:48.631[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 4[0m
[32m2024-12-21 09:32:50.364[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 5[0m


Average total accuracy after training on 0: 25.88166666666667
Updating memory with new samples

Training on task 1


[32m2024-12-21 09:32:58.101[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 1[0m


Task 1 Epochs:   0%|          | 0/3 [00:00<?, ?epoch/s]

Strategy : Rehearsal | Task : 1 | Epoch : 0 | Train Loss : 0.2940416549502304 | Train Acc : 91.5672131147541 | Test Loss : 0.2046888233457185 | Test Acc : 93.34
Strategy : Rehearsal | Task : 1 | Epoch : 1 | Train Loss : 0.12161532237313578 | Train Acc : 96.24098360655738 | Test Loss : 0.14062785565310829 | Test Acc : 95.52


[32m2024-12-21 09:34:05.803[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 0[0m


Strategy : Rehearsal | Task : 1 | Epoch : 2 | Train Loss : 0.08419440261457327 | Train Acc : 97.32622950819672 | Test Loss : 0.19372096134326125 | Test Acc : 93.79


[32m2024-12-21 09:34:07.622[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 1[0m
[32m2024-12-21 09:34:10.207[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 2[0m
[32m2024-12-21 09:34:12.209[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 3[0m
[32m2024-12-21 09:34:14.085[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 4[0m
[32m2024-12-21 09:34:15.951[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 5[0m


Average total accuracy after training on 1: 38.27333333333334
Updating memory with new samples

Training on task 2


[32m2024-12-21 09:34:24.125[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 2[0m


Task 2 Epochs:   0%|          | 0/3 [00:00<?, ?epoch/s]

Strategy : Rehearsal | Task : 2 | Epoch : 0 | Train Loss : 0.26624497208955156 | Train Acc : 92.04426229508196 | Test Loss : 0.15499910570185962 | Test Acc : 94.88
Strategy : Rehearsal | Task : 2 | Epoch : 1 | Train Loss : 0.10798393278548678 | Train Acc : 96.55737704918033 | Test Loss : 0.1116158143540585 | Test Acc : 96.58


[32m2024-12-21 09:35:31.203[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 0[0m


Strategy : Rehearsal | Task : 2 | Epoch : 2 | Train Loss : 0.07558670444355446 | Train Acc : 97.6344262295082 | Test Loss : 0.12110553460219238 | Test Acc : 96.13


[32m2024-12-21 09:35:32.955[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 1[0m
[32m2024-12-21 09:35:35.129[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 2[0m
[32m2024-12-21 09:35:37.381[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 3[0m
[32m2024-12-21 09:35:39.286[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 4[0m
[32m2024-12-21 09:35:41.053[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 5[0m


Average total accuracy after training on 2: 50.47333333333333
Updating memory with new samples

Training on task 3


[32m2024-12-21 09:35:48.739[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 3[0m


Task 3 Epochs:   0%|          | 0/3 [00:00<?, ?epoch/s]

Strategy : Rehearsal | Task : 3 | Epoch : 0 | Train Loss : 0.25413939696531623 | Train Acc : 92.34754098360656 | Test Loss : 0.1276201983395068 | Test Acc : 95.94
Strategy : Rehearsal | Task : 3 | Epoch : 1 | Train Loss : 0.10097948289173704 | Train Acc : 96.76065573770492 | Test Loss : 0.10503485362794061 | Test Acc : 96.55


[32m2024-12-21 09:36:55.973[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 0[0m


Strategy : Rehearsal | Task : 3 | Epoch : 2 | Train Loss : 0.06689163722464936 | Train Acc : 97.84098360655737 | Test Loss : 0.0988243220556716 | Test Acc : 96.92


[32m2024-12-21 09:36:57.743[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 1[0m
[32m2024-12-21 09:36:59.645[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 2[0m
[32m2024-12-21 09:37:01.480[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 3[0m
[32m2024-12-21 09:37:03.859[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 4[0m
[32m2024-12-21 09:37:05.980[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 5[0m


Average total accuracy after training on 3: 61.88333333333333
Updating memory with new samples

Training on task 4


[32m2024-12-21 09:37:13.060[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 4[0m


Task 4 Epochs:   0%|          | 0/3 [00:00<?, ?epoch/s]

Strategy : Rehearsal | Task : 4 | Epoch : 0 | Train Loss : 0.25047870119350263 | Train Acc : 92.59508196721312 | Test Loss : 0.12435109010614644 | Test Acc : 96.15
Strategy : Rehearsal | Task : 4 | Epoch : 1 | Train Loss : 0.0973145296676104 | Train Acc : 96.90819672131147 | Test Loss : 0.10118462476228023 | Test Acc : 96.8


[32m2024-12-21 09:38:19.962[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 0[0m


Strategy : Rehearsal | Task : 4 | Epoch : 2 | Train Loss : 0.06519171892932814 | Train Acc : 97.90163934426229 | Test Loss : 0.10015462992986947 | Test Acc : 96.83


[32m2024-12-21 09:38:21.745[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 1[0m
[32m2024-12-21 09:38:23.576[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 2[0m
[32m2024-12-21 09:38:25.313[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 3[0m
[32m2024-12-21 09:38:27.102[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 4[0m
[32m2024-12-21 09:38:29.292[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 5[0m


Average total accuracy after training on 4: 69.43166666666667
Updating memory with new samples

Training on task 5


[32m2024-12-21 09:38:36.683[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 5[0m


Task 5 Epochs:   0%|          | 0/3 [00:00<?, ?epoch/s]

Strategy : Rehearsal | Task : 5 | Epoch : 0 | Train Loss : 0.24369042152753667 | Train Acc : 92.6344262295082 | Test Loss : 0.13260704555395472 | Test Acc : 95.69
Strategy : Rehearsal | Task : 5 | Epoch : 1 | Train Loss : 0.10058990408875344 | Train Acc : 96.76229508196721 | Test Loss : 0.12073434032853457 | Test Acc : 96.03


[32m2024-12-21 09:39:44.187[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 0[0m


Strategy : Rehearsal | Task : 5 | Epoch : 2 | Train Loss : 0.0657000040432411 | Train Acc : 97.8672131147541 | Test Loss : 0.10458613900007986 | Test Acc : 96.5


[32m2024-12-21 09:39:46.241[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 1[0m
[32m2024-12-21 09:39:48.106[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 2[0m
[32m2024-12-21 09:39:49.986[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 3[0m
[32m2024-12-21 09:39:51.794[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 4[0m
[32m2024-12-21 09:39:53.600[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 5[0m


Average total accuracy after training on 5: 72.77999999999999
Updating memory with new samples


In [None]:
############################## REHEARSAL ##############################

# Sequential rehearsal: each epoch first makes one pass over the replay memory
# and then one pass over the current task's data.
model = ResNet18(speed=SPEED).to(device)
optimizer = OPTIMIZER(model)
memory_size = 1000
strategy = RehearsalStrategy(model, optimizer, device, memory_size)
rehearsal_accs = []  # mean accuracy over ALL NR_TASKS tasks, one entry per task

# The test sets never change, so build their loaders once instead of reloading
# MNIST (train + test split) for every task on every evaluation pass.
eval_test_loaders = [
    get_dataloaders(eval_task_id, batch_size=BATCH_SIZE)[1] for eval_task_id in range(NR_TASKS)
]

for task_id in tqdm(range(NR_TASKS), desc="Tasks", unit="task"):
    print(f"\nTraining on task {task_id}")
    train_loader, test_loader = get_dataloaders(task_id, batch_size=BATCH_SIZE)
    epoch_progress_bar = tqdm(range(NR_EPOCHS), desc=f"Task {task_id} Epochs", unit="epoch")

    for epoch in epoch_progress_bar:
        # Replay pass (skipped until the memory holds data). The current task is
        # always trained last within an epoch. train_epoch runs twice per epoch
        # here, so strategy.epoch advances by 2.
        memory_loader = strategy.get_memory_loader(batch_size=BATCH_SIZE)
        if memory_loader:
            train_loss, train_acc = strategy.train_epoch(memory_loader)
            print(f"Strategy : Rehearsal (Replay) | Task : {task_id} | Epoch : {epoch} | Train Loss : {train_loss} | Train Acc : {train_acc}")

        train_loss, train_acc = strategy.train_epoch(train_loader)
        test_loss, test_acc = strategy.evaluate(test_loader)
        # Single postfix update so test metrics do not overwrite train metrics.
        epoch_progress_bar.set_postfix(
            {"Train Loss": train_loss, "Train Acc": train_acc, "Test Loss": test_loss, "Test Acc": test_acc}
        )

        print(f"Strategy : Rehearsal | Task : {task_id} | Epoch : {epoch} | Train Loss : {train_loss} | Train Acc : {train_acc} | Test Loss : {test_loss} | Test Acc : {test_acc}")

    # Mean accuracy over ALL tasks, including ones not trained on yet
    # (unseen permutations score near chance and pull the average down).
    avg_acc = sum(strategy.evaluate(loader)[1] for loader in eval_test_loaders) / NR_TASKS
    rehearsal_accs.append(avg_acc)
    print(f"Average total accuracy after training on {task_id}: {avg_acc}")

    # Update memory with current task data
    strategy.update_memory(train_loader)

Tasks:   0%|          | 0/6 [00:00<?, ?task/s]

[32m2024-12-21 09:40:01.142[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 0[0m



Training on task 0


Task 0 Epochs:   0%|          | 0/3 [00:00<?, ?epoch/s]

Strategy : Rehearsal | Task : 0 | Epoch : 0 | Train Loss : 0.119804160769727 | Train Acc : 96.355 | Test Loss : 0.05702507843652481 | Test Acc : 98.14
Strategy : Rehearsal | Task : 0 | Epoch : 1 | Train Loss : 0.042729227466564344 | Train Acc : 98.62 | Test Loss : 0.03882096827349219 | Test Acc : 98.68


[32m2024-12-21 09:41:11.707[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 0[0m


Strategy : Rehearsal | Task : 0 | Epoch : 2 | Train Loss : 0.028684640054537386 | Train Acc : 99.10833333333333 | Test Loss : 0.03248680669998918 | Test Acc : 98.85


[32m2024-12-21 09:41:13.502[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 1[0m
[32m2024-12-21 09:41:15.296[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 2[0m
[32m2024-12-21 09:41:17.038[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 3[0m
[32m2024-12-21 09:41:18.780[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 4[0m
[32m2024-12-21 09:41:20.532[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 5[0m


Average total accuracy after training on 0: 26.464999999999993
Updating memory with new samples

Training on task 1


[32m2024-12-21 09:41:28.233[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 1[0m


Task 1 Epochs:   0%|          | 0/3 [00:00<?, ?epoch/s]

Strategy : Rehearsal (Replay) | Task : 1 | Epoch : 0 | Train Loss : 0.035302827367559075 | Train Acc : 99.1
Strategy : Rehearsal | Task : 1 | Epoch : 0 | Train Loss : 0.2931406260140415 | Train Acc : 91.575 | Test Loss : 0.21667071936439863 | Test Acc : 92.88
Strategy : Rehearsal (Replay) | Task : 1 | Epoch : 1 | Train Loss : 0.6433165855705738 | Train Acc : 82.7
Strategy : Rehearsal | Task : 1 | Epoch : 1 | Train Loss : 0.1369765876158913 | Train Acc : 95.71333333333334 | Test Loss : 0.14843807354145036 | Test Acc : 95.27
Strategy : Rehearsal (Replay) | Task : 1 | Epoch : 2 | Train Loss : 0.12200849549844861 | Train Acc : 95.3


[32m2024-12-21 09:42:34.986[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 0[0m


Strategy : Rehearsal | Task : 1 | Epoch : 2 | Train Loss : 0.08260949739594577 | Train Acc : 97.335 | Test Loss : 0.13062839270598856 | Test Acc : 95.81


[32m2024-12-21 09:42:37.195[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 1[0m
[32m2024-12-21 09:42:38.969[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 2[0m
[32m2024-12-21 09:42:40.731[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 3[0m
[32m2024-12-21 09:42:42.578[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 4[0m
[32m2024-12-21 09:42:44.378[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 5[0m


Average total accuracy after training on 1: 36.58166666666666
Updating memory with new samples

Training on task 2


[32m2024-12-21 09:42:52.067[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 2[0m


Task 2 Epochs:   0%|          | 0/3 [00:00<?, ?epoch/s]

Strategy : Rehearsal (Replay) | Task : 2 | Epoch : 0 | Train Loss : 0.09424901893362403 | Train Acc : 96.9
Strategy : Rehearsal | Task : 2 | Epoch : 0 | Train Loss : 0.25666806147074395 | Train Acc : 92.42 | Test Loss : 0.13905348216737562 | Test Acc : 95.47
Strategy : Rehearsal (Replay) | Task : 2 | Epoch : 1 | Train Loss : 1.0197686441242695 | Train Acc : 69.6
Strategy : Rehearsal | Task : 2 | Epoch : 1 | Train Loss : 0.12746969443251457 | Train Acc : 95.94166666666666 | Test Loss : 0.11168256185100048 | Test Acc : 96.37
Strategy : Rehearsal (Replay) | Task : 2 | Epoch : 2 | Train Loss : 0.34582050517201424 | Train Acc : 88.1


[32m2024-12-21 09:43:57.904[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 0[0m


Strategy : Rehearsal | Task : 2 | Epoch : 2 | Train Loss : 0.07743737899037058 | Train Acc : 97.42666666666666 | Test Loss : 0.10440365986418144 | Test Acc : 96.69


[32m2024-12-21 09:43:59.981[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 1[0m
[32m2024-12-21 09:44:02.192[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 2[0m
[32m2024-12-21 09:44:03.936[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 3[0m
[32m2024-12-21 09:44:05.736[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 4[0m
[32m2024-12-21 09:44:07.520[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 5[0m


Average total accuracy after training on 2: 45.705000000000005
Updating memory with new samples

Training on task 3


[32m2024-12-21 09:44:15.087[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 3[0m


Task 3 Epochs:   0%|          | 0/3 [00:00<?, ?epoch/s]

Strategy : Rehearsal (Replay) | Task : 3 | Epoch : 0 | Train Loss : 0.2067897692322731 | Train Acc : 93.4
Strategy : Rehearsal | Task : 3 | Epoch : 0 | Train Loss : 0.23674081754423917 | Train Acc : 93.01333333333334 | Test Loss : 0.14073704510832888 | Test Acc : 95.4
Strategy : Rehearsal (Replay) | Task : 3 | Epoch : 1 | Train Loss : 0.8351709023118019 | Train Acc : 73.4
Strategy : Rehearsal | Task : 3 | Epoch : 1 | Train Loss : 0.1122301912773202 | Train Acc : 96.425 | Test Loss : 0.10757244215641595 | Test Acc : 96.58
Strategy : Rehearsal (Replay) | Task : 3 | Epoch : 2 | Train Loss : 0.291321599856019 | Train Acc : 89.6


[32m2024-12-21 09:45:21.489[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 0[0m


Strategy : Rehearsal | Task : 3 | Epoch : 2 | Train Loss : 0.07315888147411952 | Train Acc : 97.62166666666667 | Test Loss : 0.10451273761300629 | Test Acc : 96.71


[32m2024-12-21 09:45:23.276[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 1[0m
[32m2024-12-21 09:45:25.089[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 2[0m
[32m2024-12-21 09:45:27.575[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 3[0m
[32m2024-12-21 09:45:29.465[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 4[0m
[32m2024-12-21 09:45:31.208[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 5[0m


Average total accuracy after training on 3: 56.72333333333333
Updating memory with new samples

Training on task 4


[32m2024-12-21 09:45:38.541[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 4[0m


Task 4 Epochs:   0%|          | 0/3 [00:00<?, ?epoch/s]

Strategy : Rehearsal (Replay) | Task : 4 | Epoch : 0 | Train Loss : 0.1496552093885839 | Train Acc : 95.8
Strategy : Rehearsal | Task : 4 | Epoch : 0 | Train Loss : 0.23389319758981403 | Train Acc : 92.94 | Test Loss : 0.11555357783726311 | Test Acc : 96.16
Strategy : Rehearsal (Replay) | Task : 4 | Epoch : 1 | Train Loss : 0.9893712624907494 | Train Acc : 71.0
Strategy : Rehearsal | Task : 4 | Epoch : 1 | Train Loss : 0.10647865657263728 | Train Acc : 96.46333333333334 | Test Loss : 0.10036590901613117 | Test Acc : 96.66
Strategy : Rehearsal (Replay) | Task : 4 | Epoch : 2 | Train Loss : 0.2311564888805151 | Train Acc : 92.2


[32m2024-12-21 09:46:44.197[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 0[0m


Strategy : Rehearsal | Task : 4 | Epoch : 2 | Train Loss : 0.062471061982691034 | Train Acc : 97.97166666666666 | Test Loss : 0.10244664951239395 | Test Acc : 96.62


[32m2024-12-21 09:46:45.914[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 1[0m
[32m2024-12-21 09:46:47.750[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 2[0m
[32m2024-12-21 09:46:49.464[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 3[0m
[32m2024-12-21 09:46:51.889[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 4[0m
[32m2024-12-21 09:46:53.882[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 5[0m


Average total accuracy after training on 4: 60.22666666666667
Updating memory with new samples

Training on task 5


[32m2024-12-21 09:47:00.703[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 5[0m


Task 5 Epochs:   0%|          | 0/3 [00:00<?, ?epoch/s]

Strategy : Rehearsal (Replay) | Task : 5 | Epoch : 0 | Train Loss : 0.10279931081458926 | Train Acc : 96.5
Strategy : Rehearsal | Task : 5 | Epoch : 0 | Train Loss : 0.23468547356503605 | Train Acc : 92.97 | Test Loss : 0.12181167850636324 | Test Acc : 96.0
Strategy : Rehearsal (Replay) | Task : 5 | Epoch : 1 | Train Loss : 0.8545159287750721 | Train Acc : 73.4
Strategy : Rehearsal | Task : 5 | Epoch : 1 | Train Loss : 0.10078330851122261 | Train Acc : 96.835 | Test Loss : 0.11655685901523957 | Test Acc : 96.23
Strategy : Rehearsal (Replay) | Task : 5 | Epoch : 2 | Train Loss : 0.2780282683670521 | Train Acc : 90.8


[32m2024-12-21 09:48:06.947[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 0[0m


Strategy : Rehearsal | Task : 5 | Epoch : 2 | Train Loss : 0.06348933807131388 | Train Acc : 97.94833333333334 | Test Loss : 0.09168427292614162 | Test Acc : 97.06


[32m2024-12-21 09:48:08.667[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 1[0m
[32m2024-12-21 09:48:10.443[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 2[0m
[32m2024-12-21 09:48:12.198[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 3[0m
[32m2024-12-21 09:48:14.027[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 4[0m
[32m2024-12-21 09:48:16.301[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_datasets[0m:[36m95[0m - [1mUsing Permuted MNIST for task 5[0m


Average total accuracy after training on 5: 63.43833333333333
Updating memory with new samples


In [None]:

# LSTW run: the same rehearsal loop as the plain Rehearsal run, except that the
# current-task pass trains with `model.freeze_0_filters()` active.
# NOTE(review): presumably this freezes the `*_0` halves of the split convolutions
# (e.g. BasicBlock.conv1_0 / conv2_0) so the new task only updates the `*_1`
# halves — confirm against ResNet18.freeze_0_filters.
model = ResNet18(speed=SPEED).to(device)
optimizer = OPTIMIZER(model)
memory_size = 1000
strategy_lstw = RehearsalStrategy(model, optimizer, device, memory_size)

# One entry per trained task: test accuracy averaged over ALL NR_TASKS tasks
# (including tasks not trained yet), the same metric printed as
# "Average total accuracy" for the other strategies.
rehearsal_lstw_accs = []
for task_id in tqdm(range(NR_TASKS), desc="Tasks", unit="task"):
    print(f"\nTraining on task {task_id}")
    train_loader, test_loader = get_dataloaders(task_id, batch_size=BATCH_SIZE)

    epoch_progress_bar = tqdm(range(NR_EPOCHS), desc=f"Task {task_id} Epochs", unit="epoch")
    for epoch in epoch_progress_bar:
        # Replay stored samples from earlier tasks first; skipped when
        # get_memory_loader returns a falsy value (presumably: empty memory).
        memory_loader = strategy_lstw.get_memory_loader(batch_size=BATCH_SIZE)
        if memory_loader:
            train_loss, train_acc = strategy_lstw.train_epoch(memory_loader)
            print(f"Strategy : LSTW (Replay) | Task : {task_id} | Epoch : {epoch} | Train Loss : {train_loss} | Train Acc : {train_acc}")

        # Train one epoch on the current task with the *_0 filters frozen.
        # try/finally guarantees the model is not left frozen if training is interrupted.
        model.freeze_0_filters()
        try:
            train_loss, train_acc = strategy_lstw.train_epoch(train_loader)
        finally:
            model.unfreeze_0_filters()

        # Evaluate on the current task only.
        test_loss, test_acc = strategy_lstw.evaluate(test_loader)

        # Single set_postfix call: a second call replaces (not extends) the postfix.
        epoch_progress_bar.set_postfix({
            "Train Loss": train_loss,
            "Train Acc": train_acc,
            "Test Loss": test_loss,
            "Test Acc": test_acc,
        })
        print(f"Strategy : LSTW | Task : {task_id} | Epoch : {epoch} | Train Loss : {train_loss} | Train Acc : {train_acc} | Test Loss : {test_loss} | Test Acc : {test_acc}")

    # Evaluate on every task's test set (seen and not-yet-seen) and average.
    avg_acc = 0
    for eval_task_id in range(NR_TASKS):
        _, eval_loader = get_dataloaders(eval_task_id, batch_size=BATCH_SIZE)
        _, eval_acc = strategy_lstw.evaluate(eval_loader)
        avg_acc += eval_acc
    avg_acc /= NR_TASKS
    rehearsal_lstw_accs.append(avg_acc)
    print(f"Average total accuracy after training on {task_id}: {avg_acc}")

    # Store samples of the task just trained for replay in later tasks.
    strategy_lstw.update_memory(train_loader)

In [None]:
import matplotlib.pyplot as plt

# Plot, for each strategy, the average accuracy over all tasks recorded after
# each task finished training. Every curve uses its own length for the x-axis,
# so a strategy whose run is missing or incomplete does not crash the cell with
# an x/y length mismatch.
strategy_results = {
    'Rehearsal': rehearsal_accs,
    'LSTW': rehearsal_lstw_accs,
    'Rehearsal (Interleaved)': rehearsal_interleaved_accs,
}

fig, ax = plt.subplots(figsize=(10, 6))
for label, accs in strategy_results.items():
    # x = number of tasks trained so far (1-based), i.e. task id + 1 in the logs.
    ax.plot(range(1, len(accs) + 1), accs, marker='o', label=label)

ax.set_xlabel('Number of tasks trained')
ax.set_ylabel('Average accuracy over all tasks (%)')
ax.set_title('Comparison of Continual Learning Strategies')
ax.legend()
ax.grid(True)
plt.show()