## Overview

This notebook implements the replay + EWC (Elastic Weight Consolidation) approach for training on coarse-grained and fine-grained datasets in a continual-learning classification setting.

In [None]:
# seeds for reproducibility
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import torch
import torch.backends.cudnn as cudnn
import random
import numpy as np

# A single seed value drives every random number generator used below.
seed = 88
for seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_rng(seed)

# Request deterministic kernels and switch off cuDNN benchmarking.
cudnn.benchmark = False
cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

# Dedicated generator, seeded with the same value, for calls that accept an
# explicit `generator=` argument.
g = torch.Generator().manual_seed(seed)

def seed_worker(worker_id):
    """Re-seed NumPy and Python's `random` from the current torch seed.

    Written to be used as a `worker_init_fn`; `worker_id` is accepted for
    that signature but is not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)  # fold into 32 bits
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

## Data Loading and Transforms

In [None]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import ConcatDataset

# remove copyright banner
class RemoveCopyrightBanner(object):
    """Transform that crops a fixed-height strip off the bottom of an image.

    Args:
        banner_height: Number of pixel rows removed from the bottom edge.
            Defaults to 20, the value that used to be hard-coded.

    Raises:
        ValueError: If ``banner_height`` is negative.
    """

    def __init__(self, banner_height=20):
        if banner_height < 0:
            raise ValueError(f"banner_height must be non-negative, got {banner_height}")
        self.banner_height = banner_height

    def __call__(self, img):
        """Return ``img`` without its bottom ``banner_height`` rows.

        ``img`` must expose ``size`` as ``(width, height)`` and a
        ``crop((left, upper, right, lower))`` method.
        """
        width, height = img.size
        # An image no taller than the banner would give an empty or invalid
        # crop box, so it is passed through unchanged.
        if height <= self.banner_height:
            return img
        return img.crop((0, 0, width, height - self.banner_height))

    def __repr__(self):
        return f"{self.__class__.__name__}(banner_height={self.banner_height})"

# Per-channel statistics handed to Normalize. These are the values
# conventionally used with torchvision's ImageNet-pretrained models (a
# pretrained resnet18 is loaded later in this notebook).
normalize_mean = [0.485, 0.456, 0.406]
normalize_std = [0.229, 0.224, 0.225]

# Preprocessing applied to every image, in order:
# banner crop -> 224x224 resize -> tensor conversion -> channel normalisation.
# NOTE(review): the banner crop is applied to whichever dataset is loaded
# below, including DTD; presumably it was written for the FGVC-Aircraft images
# referenced in the commented-out code. Confirm it is wanted for DTD.
preprocessing_steps = [
    RemoveCopyrightBanner(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=normalize_mean, std=normalize_std),
]
transform = transforms.Compose(preprocessing_steps)

# granularity = 'variant'

# # Create the FGVC Aircraft dataset instance
# train_dataset = FGVCAircraft(
#     root='./data',
#     split='trainval',              # Options: 'train', 'val', 'trainval', 'test'
#     annotation_level=granularity,    # Options: 'variant', 'family', 'manufacturer'
#     transform=transform,
#     download=True
# )

# val_dataset = FGVCAircraft(
#     root='./data',
#     split='val',
#     annotation_level='variant',
#     transform=transform,
#     download=True
# )

# test_dataset = FGVCAircraft(
#     root='./data',
#     split='test',
#     annotation_level=granularity,
#     transform=transform,
#     download=True
# )

data_root = './data'

# Build the three DTD splits with identical settings (train, then val, then test).
train_dataset, val_dataset, test_dataset = (
    datasets.DTD(root=data_root, split=split_name, download=True, transform=transform)
    for split_name in ('train', 'val', 'test')
)

# Pool the train and val splits, then re-divide the pool 80/20 using the
# seeded generator `g`. `train_dataset` and `val_dataset` are rebound to the
# resulting subsets.
trainval_dataset = ConcatDataset([train_dataset, val_dataset])
train_dataset, val_dataset = torch.utils.data.random_split(
    trainval_dataset,
    [0.8, 0.2],
    generator=g,
)

In [None]:
# function to show images
def show_images(train_dataset, num_images=5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Plot a random selection of images from a dataset in a single row.

    Args:
        train_dataset: Indexable dataset whose items are ``(image, label)``
            pairs, where ``image`` is a CxHxW tensor.
        num_images: Number of images to draw; capped at the dataset size.
        mean: Per-channel means used to undo normalisation before plotting.
            Defaults to the values given to ``transforms.Normalize`` in this
            notebook.
        std: Per-channel standard deviations, same convention as ``mean``.
            Pass ``None`` for ``mean`` or ``std`` to plot the tensor values
            unchanged.

    Returns:
        None. Nothing is drawn when there are no images to show.
    """
    num_images = min(num_images, len(train_dataset))
    if num_images <= 0:
        return

    # A random permutation yields distinct indices; only the first few are needed.
    chosen = torch.randperm(len(train_dataset))[:num_images].tolist()

    # squeeze=False keeps `axes` two-dimensional even when num_images == 1,
    # where plt.subplots would otherwise return a bare Axes object.
    fig, axes = plt.subplots(1, num_images, figsize=(15, 5), squeeze=False)
    for ax, sample_idx in zip(axes[0], chosen):
        image, label = train_dataset[sample_idx]
        if mean is not None and std is not None:
            # Undo Normalize so colours are displayed in the [0, 1] range
            # instead of being clipped by imshow.
            channel_std = torch.as_tensor(std, dtype=image.dtype, device=image.device).view(-1, 1, 1)
            channel_mean = torch.as_tensor(mean, dtype=image.dtype, device=image.device).view(-1, 1, 1)
            image = (image * channel_std + channel_mean).clamp(0, 1)
        image = image.permute(1, 2, 0)  # convert from CxHxW to HxWxC
        ax.imshow(image)
        ax.set_title(f'Label: {label}')
        ax.axis('off')
    plt.show()

# Visual sanity check: plot five randomly selected training images with their labels.
show_images(train_dataset, num_images=5)

## Create the Dataset
In a continual learning setting, each task contains a new set of classes to train the model on. The validation and test datasets should be cumulative (to evaluate the performance of the model on all classes seen by the model so far).

In [None]:
from collections import defaultdict
import torch
from tqdm import tqdm

def group_task_indices(dataset, cumulative=False, max_per_class=1000, classes_per_task=10, num_tasks=None):
    """Group the sample indices of ``dataset`` into continual-learning tasks.

    Labels are assigned to tasks in contiguous blocks of ``classes_per_task``
    classes: with the default of 10, task 0 holds labels 0-9, task 1 holds
    labels 10-19, and so on.

    Args:
        dataset: Iterable of ``(sample, label)`` pairs supporting ``len()``.
        cumulative: If False, each task lists only the samples of its own
            classes (training). If True, each task lists the samples of its
            own classes and of all earlier tasks (validation / test).
        max_per_class: Maximum number of samples kept per class; further
            samples of that class are skipped.
        classes_per_task: Number of consecutive class labels per task.
        num_tasks: Number of tasks to build in cumulative mode. When None it
            is inferred from the largest label seen. Samples whose own task
            index is >= ``num_tasks`` are left out of the cumulative lists.

    Returns:
        ``defaultdict(list)`` mapping task index to a list of sample indices,
        in dataset order.
    """
    per_class_counts = defaultdict(int)
    kept = []  # (sample index, task index) for every sample that passes the cap
    for idx, (_, label) in tqdm(enumerate(dataset), total=len(dataset)):
        label = int(label)  # also accepts 0-dim tensors / numpy integers
        if per_class_counts[label] >= max_per_class:
            continue
        per_class_counts[label] += 1
        kept.append((idx, label // classes_per_task))

    task_dict = defaultdict(list)
    if not cumulative:
        for idx, task_id in kept:
            task_dict[task_id].append(idx)
        return task_dict

    # The number of tasks is independent of classes_per_task, so derive it
    # from the data instead of reusing classes_per_task as the upper bound.
    if num_tasks is None:
        num_tasks = max((task_id for _, task_id in kept), default=-1) + 1

    # A sample belongs to its own task and to every later one.
    for idx, task_id in kept:
        for later_task in range(task_id, num_tasks):
            task_dict[later_task].append(idx)
    return task_dict

# Training indices are grouped per task (non-cumulative), keeping at most 60 samples per class.
train_task_idxs = group_task_indices(train_dataset, cumulative=False, max_per_class=60)
# Validation and test indices are cumulative and use the default per-class cap.
val_task_idxs = group_task_indices(val_dataset, cumulative=True)
test_task_idxs = group_task_indices(test_dataset, cumulative=True)

## Training Functions

In [None]:
def val_net(net_to_val, val_loader):
    """Return the mean cross-entropy loss of ``net_to_val`` over ``val_loader``.

    The model is put in eval mode and evaluated without gradients. Batches are
    moved to the device that holds the model's parameters.

    Args:
        net_to_val: Model to evaluate.
        val_loader: Sized iterable of ``(images, labels)`` batches.

    Returns:
        float: Average of the per-batch mean losses. A plain Python float is
        returned so the value can be logged and compared directly.

    Raises:
        ValueError: If ``val_loader`` contains no batches.
    """
    num_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError("val_loader is empty; cannot compute a validation loss")

    net_to_val.eval()
    criterion = torch.nn.CrossEntropyLoss()

    # Follow the model's device instead of assuming CUDA.
    first_param = next(net_to_val.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    total_loss = 0.0
    with torch.no_grad():
        for img, label in tqdm(val_loader, desc="Validating"):
            img, label = img.to(device), label.to(device)
            outputs = net_to_val(img)
            total_loss += criterion(outputs, label).item()

    return total_loss / num_batches

def train_net_ewc_replay(max_epochs, freeze_epochs, patience, ewc, net_to_train, opt, train_loader, val_loader, task, save_file=None, save_path=None):
    """Train a model on one task with an EWC penalty added to the loss.

    Args:
        max_epochs: Maximum number of epochs to run.
        freeze_epochs: Epoch index at which, for ``task > 0``, still-frozen
            non-'fc' parameters are unfrozen and the learning rate is set to
            1e-4.
        patience: Number of consecutive epochs without validation improvement
            after which training stops early.
        ewc: Object exposing ``ewc_loss(task)``. An optional
            ``important_params`` mapping names parameters that must stay
            frozen.
        net_to_train: Model to train. Parameters whose name contains 'fc' are
            treated as the classifier head.
        opt: Optimizer already built over the model's parameters.
        train_loader: Sized iterable of ``(images, labels)`` training batches.
        val_loader: Validation batches, passed to ``val_net``.
        task: Index of the current task. For task 0 only the head is trained.
        save_file: Optional path. One ``task,epoch,train_loss,val_loss`` row
            is appended per epoch.
        save_path: Optional directory for the best checkpoint. It is created
            if missing. Best weights are reloaded only when early stopping
            triggers and this directory is set.

    Returns:
        tuple: ``(train_losses, val_losses)``, two lists of floats with one
        entry per completed epoch.
    """
    criterion = torch.nn.CrossEntropyLoss()

    # Use CUDA when present; fall back to CPU so the function also runs without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net_to_train.to(device)

    best_checkpoint = None
    if save_path:
        os.makedirs(save_path, exist_ok=True)  # torch.save does not create directories
        best_checkpoint = os.path.join(save_path, f"model_task{task}_best.pth")

    # Task 0 trains only the classifier head; later tasks start with every
    # parameter trainable.
    initial_freeze = (task == 0)
    for name, param in net_to_train.named_parameters():
        param.requires_grad = not (initial_freeze and 'fc' not in name)

    optimizer = opt

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    epochs_no_improve = 0

    print(f"Starting training for Task {task}. Trainable parameters:")
    for name, param in net_to_train.named_parameters():
        if param.requires_grad:
            print(f"  - {name}")

    for epoch in range(max_epochs):
        net_to_train.train()
        running_loss = 0.0
        running_ewc = 0.0

        if epoch == freeze_epochs and task > 0:
            # NOTE(review): for task > 0 every parameter was made trainable
            # above, so this branch finds nothing to unfreeze unless
            # requires_grad is changed elsewhere. Confirm whether a frozen
            # warm-up phase was intended for later tasks.
            print(f"Unfreezing backbone at epoch {epoch} for task {task}")
            protected = getattr(ewc, 'important_params', {})
            for name, param in net_to_train.named_parameters():
                # Parameters listed in important_params stay frozen.
                if not param.requires_grad and 'fc' not in name and name not in protected:
                    param.requires_grad = True

            # adjust LR for the existing optimizer
            new_lr = 1e-4
            if optimizer.param_groups[0]['lr'] != new_lr:
                print(f"Setting LR to {new_lr}")
                for group in optimizer.param_groups:
                    group['lr'] = new_lr

        for imgs, labels in tqdm(train_loader, unit='batch', desc=f"Task {task} Epoch {epoch+1}"):
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net_to_train(imgs)
            ewc_loss = ewc.ewc_loss(task)
            running_ewc += ewc_loss.item()
            loss = criterion(outputs, labels) + ewc_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net_to_train.parameters(), max_norm=1.0)  # grad clipping
            optimizer.step()
            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)
        avg_ewc_loss = running_ewc / len(train_loader)
        train_losses.append(avg_loss)

        # validation; float() keeps tensor reprs out of the history and the log file
        current_val_loss = float(val_net(net_to_train, val_loader))
        val_losses.append(current_val_loss)
        print(f"Task {task}, Epoch {epoch + 1}, EWC Loss: {avg_ewc_loss:.4f}, Total Loss: {avg_loss:.4f}, Val Loss (Cumulative): {current_val_loss:.4f}")

        # logging
        if save_file:
            with open(save_file, 'a') as f:
                f.write(f"{task},{epoch + 1},{avg_loss},{current_val_loss}\n")

        # early stopping
        if current_val_loss < best_val_loss:
            best_val_loss = current_val_loss
            epochs_no_improve = 0
            if best_checkpoint:
                torch.save(net_to_train.state_dict(), best_checkpoint)
                print(f"  New best validation loss: {best_val_loss:.4f}. Saved best model.")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered at epoch {epoch + 1} for task {task}. Best Val Loss: {best_val_loss:.4f}")
            if best_checkpoint and os.path.exists(best_checkpoint):
                print("Loading best model weights before exiting.")
                net_to_train.load_state_dict(torch.load(best_checkpoint, map_location=device))
            break

    print(f"Finished training task {task}")
    return train_losses, val_losses

In [None]:
import torch.nn as nn
def modify_resnet_head(model, num_classes):
    """Replace ``model.fc`` with a new linear head of ``num_classes`` outputs.

    When the head grows, the rows of the old head are copied into the first
    rows of the new one, so existing class logits are preserved. When the new
    head is not larger than the old one, it keeps its fresh initialisation.

    Args:
        model: Module with an ``nn.Linear`` attribute named ``fc``.
        num_classes: Output size of the new head.

    Returns:
        The same ``model`` object, with ``fc`` replaced.
    """
    old_fc = model.fc
    old_num_classes = old_fc.out_features

    # Create the new head on the old head's device and dtype instead of
    # forcing CUDA, so the model never ends up split across devices.
    new_fc = nn.Linear(old_fc.in_features, num_classes).to(
        device=old_fc.weight.device, dtype=old_fc.weight.dtype
    )

    # Copy weights and biases from the old head
    if old_num_classes < num_classes:
        with torch.no_grad():
            new_fc.weight[:old_num_classes].copy_(old_fc.weight)
            if old_fc.bias is not None:
                new_fc.bias[:old_num_classes].copy_(old_fc.bias)

    model.fc = new_fc
    return model

In [None]:
import torch

def get_test_accuracy(model, test_loader, num_classes):
    """Compute overall and per-class accuracy of ``model`` on ``test_loader``.

    Args:
        model: Classifier. It is put in eval mode and run without gradients.
        test_loader: Sized iterable of ``(images, labels)`` batches, with
            integer class labels.
        num_classes: Number of classes to report per-class accuracy for.
            Labels outside ``[0, num_classes)`` count towards the overall
            accuracy only.

    Returns:
        tuple: ``(overall_acc, per_class_acc)``. ``overall_acc`` is a float
        and ``per_class_acc`` a list of ``num_classes`` floats. Classes with
        no samples, and an empty loader, yield 0.0.
    """
    model.eval()

    # Follow the model's device instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    correct_preds = 0
    total = 0
    correct_per_class = torch.zeros(num_classes, dtype=torch.long)
    total_per_class = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        for imgs, labels in tqdm(test_loader, desc="Testing", total=len(test_loader)):
            preds = model(imgs.to(device)).argmax(dim=1).cpu()
            labels = labels.cpu()
            hits = preds == labels

            correct_preds += int(hits.sum())
            total += labels.size(0)

            # Per-class stats: one bincount per batch, done on CPU, replaces
            # 2 * num_classes device synchronisations.
            in_range = (labels >= 0) & (labels < num_classes)
            total_per_class += torch.bincount(labels[in_range], minlength=num_classes)
            correct_per_class += torch.bincount(labels[in_range & hits], minlength=num_classes)

    overall_acc = correct_preds / total if total > 0 else 0.0
    per_class_acc = [
        n_correct / n_seen if n_seen > 0 else 0.0
        for n_correct, n_seen in zip(correct_per_class.tolist(), total_per_class.tolist())
    ]
    return overall_acc, per_class_acc


## EWC+Replay Classes and Helpers

In [None]:
def update_memory_buffer(buffer, max_size, new_samples):
    """Append ``new_samples`` to ``buffer`` in place, then shrink it to ``max_size``.

    When the buffer is over capacity, the surplus entries are chosen uniformly
    at random (via the ``random`` module) and dropped; the survivors keep
    their relative order.
    """
    buffer.extend(new_samples)
    excess = len(buffer) - max_size
    if excess > 0:
        evicted = set(random.sample(range(len(buffer)), excess))
        # Slice assignment keeps the caller's list object and mutates it in place.
        buffer[:] = [entry for position, entry in enumerate(buffer) if position not in evicted]
    print(f"Memory buffer size: {len(buffer)} / {max_size}")

class MemoryDataset(torch.utils.data.Dataset):
    """Dataset view over the replay memory buffer.

    The base class is referenced through ``torch.utils.data`` because the bare
    name ``Dataset`` is only imported in a later cell; using it here would
    raise ``NameError`` on a fresh top-to-bottom run.

    Args:
        buffer_list: List of ``(image_tensor, label)`` pairs. The list is
            stored by reference, not copied.
    """

    def __init__(self, buffer_list):
        self.buffer = buffer_list

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.buffer)

    def __getitem__(self, idx):
        """Return the ``(image_tensor, label)`` pair stored at ``idx``."""
        return self.buffer[idx]

In [None]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import random

class EWC:
    """Elastic Weight Consolidation bookkeeping for a growing classifier.

    After each task, a snapshot of the trainable parameters and a diagonal
    Fisher estimate are stored. ``ewc_loss`` then penalises movement away from
    those snapshots, weighted by the Fisher values.
    """

    def __init__(self, model, device='cuda', lambda_ewc=5000):
        """
        Create EWC object to manage EWC regularization.
        lambda_ewc is the regularization strength for EWC penalty.
        """
        self.model = model
        self.device = device
        self.lambda_ewc = lambda_ewc

        # fisher information for each task
        self.fisher_dict = {}
        # optimal parameters for each task
        self.optpar_dict = {}
        # output layer sizes for each task
        self.output_sizes = {}
        # track important parameters (never filled in by this class)
        self.important_params = {}

    def compute_fisher(self, data_loader, samples=500):
        """
        Compute the diagonal Fisher matrix of the trainable parameters.

        Up to ``samples`` items are drawn at random from ``data_loader.dataset``.
        For each item the squared gradient of every class log-probability is
        accumulated, weighted by that class's predicted probability. The
        result is averaged over the number of items actually drawn.

        Returns a dict mapping parameter name to a tensor shaped like the
        parameter. The model is left in eval mode.
        """
        fisher = {n: torch.zeros_like(p) for n, p in self.model.named_parameters() if p.requires_grad}
        self.model.eval()

        dataset = data_loader.dataset
        chosen = torch.randperm(len(dataset))[:samples].tolist()
        # Average over the samples actually used; the dataset may hold fewer
        # than `samples` items.
        num_used = len(chosen)

        sample_loader = torch.utils.data.DataLoader(
            torch.utils.data.Subset(dataset, chosen),
            batch_size=1, shuffle=True
        )

        for input_data, _ in sample_loader:
            input_data = input_data.to(self.device)
            # log_softmax stays finite where log(softmax(x)) would hit log(0).
            log_probs = F.log_softmax(self.model(input_data), dim=1)
            probs = log_probs.detach().exp()

            num_classes = log_probs.size(1)
            for c in range(num_classes):
                self.model.zero_grad()
                # The graph is reused for every class, so keep it until the last one.
                log_probs[0, c].backward(retain_graph=(c < num_classes - 1))

                prob_value = probs[0, c].item()
                for n, p in self.model.named_parameters():
                    if p.grad is not None and p.requires_grad:
                        fisher[n] += prob_value * p.grad.data.pow(2) / num_used

        return fisher

    def store_task_parameters(self, task_id, data_loader):
        """
        Store the optimal parameters and compute fisher after training on a task.

        Only parameters with ``requires_grad=True`` at call time are stored.
        """
        print(f"Storing parameters for task {task_id}...")

        # Store current parameter values
        self.optpar_dict[task_id] = {}
        for n, p in self.model.named_parameters():
            if p.requires_grad:
                self.optpar_dict[task_id][n] = p.data.clone()
                if 'fc' in n or 'layer4' in n:
                    print(f"Stored parameter {n}: min={p.min().item():.6f}, max={p.max().item():.6f}, mean={p.mean().item():.6f}")

        # compute fisher matrix
        self.fisher_dict[task_id] = self.compute_fisher(data_loader)

        # store output layer size so later, larger heads can be sliced back to it
        if hasattr(self.model, 'fc'):
            self.output_sizes[task_id] = self.model.fc.weight.size(0)
            print(f"Stored output size for task {task_id}: {self.output_sizes[task_id]}")
        elif hasattr(self.model, 'classifier'):
            self.output_sizes[task_id] = self.model.classifier.weight.size(0)
            print(f"Stored output size for task {task_id}: {self.output_sizes[task_id]}")

    def ewc_loss(self, current_task_id):
        """
        Calculate the EWC penalty for ``current_task_id``.

        Sums ``fisher * (param - stored_param)^2`` over every earlier task and
        every trainable parameter stored for it, then scales by
        ``lambda_ewc / 2``. Head parameters ('fc.weight' / 'fc.bias') are
        compared on the rows that existed when the task was stored.

        Always returns a 0-dim tensor, including when nothing matches.
        """
        total_loss = torch.tensor(0.0, device=self.device)
        if current_task_id == 0:
            return total_loss

        # calculate EWC loss for all previous tasks
        for task_id in range(current_task_id):
            fisher = self.fisher_dict[task_id]
            optpar = self.optpar_dict[task_id]
            for n, p in self.model.named_parameters():
                if not p.requires_grad or n not in fisher or n not in optpar:
                    continue
                if "fc.weight" in n or "fc.bias" in n:
                    # The head has grown since this task; slicing the leading
                    # dimension works for both the weight and the bias.
                    prev_size = self.output_sizes[task_id]
                    param_diff = (p[:prev_size] - optpar[n][:prev_size]).pow(2)
                    total_loss = total_loss + (fisher[n][:prev_size] * param_diff).sum()
                else:
                    param_diff = (p - optpar[n]).pow(2)
                    total_loss = total_loss + (fisher[n] * param_diff).sum()

        # lambda scaling
        return self.lambda_ewc * total_loss / 2

## Training Loop

In [None]:
import random
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from torchvision import models
import torch.optim as optim

LAMBDA_EWC = 10e5  # i.e. 1e6; matches the "1M" suffix of save_dir
save_dir = 'replay-ewc-coarse-grained-1M'

# init model and ewc object
# `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 is the
# checkpoint that `pretrained=True` selected.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
ewc = EWC(model, device='cuda', lambda_ewc=LAMBDA_EWC)

# freeze all parameters except the last fc layer
for name, param in model.named_parameters():
    if name != 'fc.weight' and name != 'fc.bias':
        param.requires_grad = False

# replay init
memory_buffer = []
memory_size = 1000  # maximum number of samples kept in the replay buffer
samples_per_task_in_memory = 20  # samples added to the buffer after each task


# init log files: both are truncated at the start of every run
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, 'train_val_losses.txt'), 'w'):
    pass
# `with` closes the handle, so the header is flushed before the file is reopened for appending
with open(os.path.join(save_dir, 'accuracies.txt'), 'w') as accuracy_log:
    accuracy_log.write("Task,Overall Accuracy,Per-Class Accuracy\n")


# Sequentially train on five tasks of ten classes each. After every task the
# EWC statistics and the replay buffer are updated and the model is evaluated
# on the cumulative test set.
for task in range(5):
    print(f"Training on task {task}...")

    # grow the classifier head to cover every class seen so far
    model = modify_resnet_head(model, (task+1) * 10)
    model = model.cuda()

    # get current task data
    current_task_train_subset = Subset(train_dataset, train_task_idxs[task])

    # combine with memory buffer if task > 0
    if task > 0 and len(memory_buffer) > 0:
        replay_dataset = MemoryDataset(memory_buffer)
        combined_train_dataset = ConcatDataset([current_task_train_subset, replay_dataset])
        print(f"Task {task}: Training with {len(current_task_train_subset)} current samples and {len(replay_dataset)} replay samples.")
    else:
        # task 0 or empty buffer: train only on current task data
        combined_train_dataset = current_task_train_subset
        print(f"Task {task}: Training only with {len(current_task_train_subset)} current samples.")

    # dataloader for current task (plus replay samples when available)
    train_loader_combined = DataLoader(
        combined_train_dataset,
        batch_size=128,
        shuffle=True,
        num_workers=4,
        worker_init_fn=seed_worker,
        generator=g
    )
    # current-task data only; passed to EWC for the Fisher estimate
    train_loader_not_combined = DataLoader(
        current_task_train_subset,
        batch_size=128,
        shuffle=True,
        num_workers=4,
        worker_init_fn=seed_worker,
        generator=g
    )

    # validation and test loaders for the current task (cumulative index lists)
    val_loader = torch.utils.data.DataLoader(
        Subset(val_dataset, val_task_idxs[task]),
        batch_size=256,
        shuffle=False,
        num_workers=4,
        worker_init_fn=seed_worker,
        generator=g
    )
    test_loader = torch.utils.data.DataLoader(
        Subset(test_dataset, test_task_idxs[task]),
        batch_size=256,
        shuffle=False,
        num_workers=4,
        worker_init_fn=seed_worker,
        generator=g
    )

    # optimizer initialization
    if task == 0:
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0015)
    else:
        optimizer = optim.Adam(model.parameters(), lr=1e-4) # Lower LR for full network


    # train the model
    # `ewc` is a required parameter of train_net_ewc_replay; leaving it out
    # raises TypeError before any training happens.
    train_losses, val_losses = train_net_ewc_replay(
        max_epochs=15,
        freeze_epochs=5,
        patience=5,
        ewc=ewc,
        net_to_train=model,
        opt=optimizer,
        train_loader=train_loader_combined,
        val_loader=val_loader,
        task=task,
        save_file=os.path.join(save_dir, 'train_val_losses.txt')
    )

    # update ewc after training
    ewc.store_task_parameters(task, train_loader_not_combined)

    # update memory buffer with samples from current task
    num_to_sample = min(samples_per_task_in_memory, len(current_task_train_subset))
    if num_to_sample > 0:
        indices_to_sample = random.sample(range(len(current_task_train_subset)), num_to_sample)
        new_memory_samples = []
        print(f"Sampling {num_to_sample} examples from task {task} for memory buffer...")
        for idx in indices_to_sample:
            img_tensor, label = current_task_train_subset[idx]
            new_memory_samples.append((img_tensor, label)) # Append as tuple
        update_memory_buffer(memory_buffer, memory_size, new_memory_samples)
    else:
        print(f"Not enough samples in task {task} subset to add to memory.")


    # evaluate and save the model
    overall_acc, per_class_acc = get_test_accuracy(model, test_loader, (task+1) * 10)
    print(f"Overall accuracy for task {task} (on classes 0-{(task+1)*10 - 1}): {overall_acc:.4f}")

    with open(os.path.join(save_dir, 'accuracies.txt'), 'a') as f:
        f.write(f"{task},{overall_acc:.4f},{per_class_acc}\n")

    torch.save(model.state_dict(), os.path.join(save_dir, f"model_task_{task}.pth"))
    print(f"Model for task {task} saved as model_task_{task}.pth")