# Meta Continual Learning

In [None]:
import torch
import torchvision
from torch import nn
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Normalize
import numpy as np

## Data Preprocessing

In [None]:
# A function to split the FashionMNIST dataset into tasks
def create_fashion_mnist_tasks(root_dir, num_tasks=5, transform=None, train=False):
    """Split FashionMNIST into class-incremental tasks.

    The label space is cut into ``num_tasks`` contiguous groups of
    ``len(classes) // num_tasks`` classes each; task ``i`` contains every
    sample whose label falls in group ``i``. When the number of classes is
    not divisible by ``num_tasks`` the trailing classes are not assigned to
    any task.

    Args:
        root_dir: Directory the dataset is read from (and downloaded to).
        num_tasks: Number of tasks to create.
        transform: Optional transform forwarded to the dataset.
        train: Load the training split when True. Defaults to False (the
            test split), which is what this function always loaded before
            the parameter existed.

    Returns:
        A list of ``num_tasks`` ``Subset`` objects over one shared dataset.

    Raises:
        ValueError: If ``num_tasks`` is smaller than 1 or larger than the
            number of classes.
    """
    full_dataset = FashionMNIST(root=root_dir, train=train, download=True, transform=transform)

    num_classes = len(full_dataset.classes)
    if not 1 <= num_tasks <= num_classes:
        raise ValueError(
            f"num_tasks must be between 1 and {num_classes}, got {num_tasks}"
        )
    classes_per_task = num_classes // num_tasks

    # Work on the stored labels directly. Enumerating the dataset instead
    # would decode and transform every image once per task just to read
    # its label.
    targets = torch.as_tensor(full_dataset.targets)

    tasks = []
    for task_idx in range(num_tasks):
        # Half-open label range [class_start, class_end) owned by this task.
        class_start = task_idx * classes_per_task
        class_end = class_start + classes_per_task

        in_task = (targets >= class_start) & (targets < class_end)
        task_indices = torch.nonzero(in_task).flatten().tolist()

        tasks.append(Subset(full_dataset, task_indices))

    return tasks

# Preprocessing shared by every split: ToTensor yields float images in
# [0, 1]; Normalize with mean 0.5 / std 0.5 rescales them to [-1, 1].
# NOTE(review): the notebook used `transform` without ever defining it;
# this pipeline is an assumption based on the imports — confirm.
transform = transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))])

# Create tasks
root_dir = './data'
# NOTE(review): `tasks` is consumed by the training cell but was never
# created anywhere in the notebook. Building it from the training split is
# the presumed intent (the evaluation tasks use the test split) — confirm.
tasks = create_fashion_mnist_tasks(root_dir, num_tasks=5, transform=transform, train=True)
tasks_evaluation = create_fashion_mnist_tasks(root_dir, num_tasks=5, transform=transform)

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using {device} device.")

Using cpu device.


## NN Architecture

In [None]:
class SimpleCNN(nn.Module):
    """Two-convolution, two-linear classifier producing 10 class scores.

    Each 3x3 convolution (padding 1) is followed by 2x2 max pooling and a
    ReLU. ``fc1`` expects ``64 * 7 * 7`` flattened features, which is what
    two halvings of a 28x28 single-channel input produce.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 32 -> 64 channels.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # Classifier head: flattened features -> 128 hidden units -> 10 scores.
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, 10)."""
        pool = nn.functional.max_pool2d
        relu = nn.functional.relu

        # Both convolution stages share the same conv -> pool -> ReLU shape.
        for conv in (self.conv1, self.conv2):
            x = relu(pool(conv(x), 2))

        flat = x.view(x.shape[0], -1)
        hidden = relu(self.fc1(flat))
        return self.fc2(hidden)

## Training

In [None]:
def maml_update(model, optimizer, loss_fn, data_loader, steps=3, alpha=0.0001):
    """Perform one first-order MAML meta-update of ``model`` on a task.

    The procedure is:

    1. Snapshot the current (meta) parameters.
    2. Adapt the model to the task with plain gradient descent of step size
       ``alpha``, making ``steps`` passes over ``data_loader``.
    3. Evaluate the loss at the adapted parameters and accumulate its
       gradient, averaged over batches (first-order approximation: the
       gradient is not propagated back through the adaptation steps).
    4. Restore the snapshot and let ``optimizer`` apply that gradient to the
       meta parameters.

    Earlier, the meta step was applied to the adapted weights and then
    overwritten by the restore, so the model never changed. The inner loop
    also stepped twice per batch (``optimizer.step()`` plus the manual
    ``alpha`` update); only the ``alpha`` update is kept.

    Args:
        model: Module to meta-train; updated in place.
        optimizer: Optimizer over ``model``'s parameters. It performs the
            outer (meta) step, so its state persists across calls.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar.
        data_loader: Re-iterable of ``(inputs, labels)`` batches. The same
            batches are used for adaptation and for the meta-loss.
        steps: Number of adaptation passes over ``data_loader``.
        alpha: Step size of the inner gradient-descent updates.

    Returns:
        The mean meta-loss over the batches as a float, or ``None`` when
        ``data_loader`` yields no batches (the model is left untouched).
    """
    # Batches are moved to wherever the model lives so the function also
    # works for models placed on an accelerator.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()

    # Detached snapshot of the meta parameters.
    initial_state = {
        name: param.detach().clone() for name, param in model.named_parameters()
    }

    # Inner loop: task-specific adaptation with plain SGD.
    for _ in range(steps):
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(model_device), labels.to(model_device)
            model.zero_grad()
            loss = loss_fn(model(inputs), labels)
            loss.backward()
            with torch.no_grad():
                for param in model.parameters():
                    if param.grad is not None:
                        param -= alpha * param.grad

    # Meta-gradient: gradient of the task loss at the adapted parameters.
    # Calling backward() per batch accumulates into .grad without keeping
    # every batch's graph alive at once.
    model.zero_grad()
    meta_loss_total = 0.0
    num_batches = 0
    for inputs, labels in data_loader:
        inputs, labels = inputs.to(model_device), labels.to(model_device)
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        meta_loss_total += loss.item()
        num_batches += 1

    # Put the meta parameters back; the accumulated .grad is kept so the
    # optimizer applies it to the restored weights, not the adapted ones.
    with torch.no_grad():
        for name, param in model.named_parameters():
            param.copy_(initial_state[name])
            if num_batches and param.grad is not None:
                param.grad /= num_batches

    if num_batches == 0:
        return None

    optimizer.step()
    optimizer.zero_grad()
    return meta_loss_total / num_batches

In [None]:
# One shuffled mini-batch loader per training task, in task order.
task_dataloaders = list(
    DataLoader(task_dataset, batch_size=64, shuffle=True)
    for task_dataset in tasks
)

# Fresh network placed on the selected device, plus its Adam optimizer.
model = SimpleCNN()
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Visit the tasks sequentially, running one meta-update on each.
for task_id in range(len(task_dataloaders)):
    task_loader = task_dataloaders[task_id]
    print(f"Training on task {task_id}")
    maml_update(model, optimizer, nn.CrossEntropyLoss(), task_loader)

Training on task 0
Training on task 1
Training on task 2
Training on task 3
Training on task 4


In [None]:
def evaluate_model(model, task_loader):
    """Measure and print the classification accuracy of ``model`` on a task.

    The predicted class of each sample is the argmax of the model output
    along dimension 1. Evaluation runs in eval mode with gradient tracking
    disabled; the model's previous train/eval mode is restored afterwards.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        task_loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        The fraction of correctly classified samples as a float, or ``0.0``
        when ``task_loader`` yields no samples.
    """
    # Use the model's own device instead of a notebook-level global so the
    # inputs always end up where the weights are.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        # No autograd graph is needed for scoring.
        with torch.no_grad():
            for inputs, labels in task_loader:
                if target_device is not None:
                    inputs = inputs.to(target_device)
                    labels = labels.to(target_device)
                predicted = model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Do not leave a model that was being trained stuck in eval mode.
        model.train(was_training)

    # An empty loader would otherwise divide by zero.
    accuracy = correct / total if total else 0.0
    print(f"Task Accuracy: {accuracy}")
    return accuracy


In [None]:
# One mini-batch loader per evaluation task, in task order.
eval_task_dataloaders = list(
    DataLoader(eval_task, batch_size=64, shuffle=True)
    for eval_task in tasks_evaluation
)

In [None]:
# Score the model on every evaluation task in turn, then report the mean
# of the per-task accuracies.
total_accuracy = 0
for task_id in range(len(eval_task_dataloaders)):
    task_loader = eval_task_dataloaders[task_id]
    print(f"Evaluation on task {task_id}")
    accuracy = evaluate_model(model, task_loader)
    total_accuracy = total_accuracy + accuracy

num_tasks = len(eval_task_dataloaders)
print(f"Average Accuracy: {total_accuracy / num_tasks}")

Evaluation on task 0
Task Accuracy: 0.303
Evaluation on task 1
Task Accuracy: 0.315
Evaluation on task 2
Task Accuracy: 0.0
Evaluation on task 3
Task Accuracy: 0.0
Evaluation on task 4
Task Accuracy: 0.0
Average Accuracy: 0.1236
