In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torchvision
from torchvision import transforms

In [2]:
# Define Class Incremental MNIST Dataset
class ClassIncrementalMNIST(Dataset):
    def __init__(self, root, train=True, transform=None, classes=None):
        self.mnist_dataset = torchvision.datasets.MNIST(root=root, train=train, transform=transforms.ToTensor(), download=True)
        self.transform = transform
        self.classes = classes
        self.train = train
        # Filter data to include only the specified classes
        self.data = []
        self.targets = []
        for image, label in self.mnist_dataset:
            if label in self.classes:
                self.data.append(image)
                self.targets.append(label)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, label = self.data[idx], self.targets[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Setup Class Incremental MNIST Tasks
num_tasks = 5
classes_per_task = 2

# Divide classes into tasks
class_splits = [list(range(i * classes_per_task, (i + 1) * classes_per_task)) for i in range(num_tasks)]

# Load datasets for each task
train_tasks = [ClassIncrementalMNIST(root="./data", train=True, classes=class_splits[i]) for i in range(num_tasks)]
test_tasks = [ClassIncrementalMNIST(root="./data", train=False, classes=class_splits[i]) for i in range(num_tasks)]

# Function to create DataLoaders
def get_task_data(task_idx, batch_size=64):
    train_loader = DataLoader(train_tasks[task_idx], batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_tasks[task_idx], batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [5]:
# Define Experience Replay Buffer
class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.data = []
        self.labels = []

    def add(self, images, labels):
        self.data.extend(images)
        self.labels.extend(labels)
        
        # Keep buffer size within capacity
        if len(self.data) > self.capacity:
            self.data = self.data[-self.capacity:]
            self.labels = self.labels[-self.capacity:]

    def sample(self, batch_size):
        indices = random.sample(range(len(self.data)), min(batch_size, len(self.data)))
        return torch.stack([self.data[i] for i in indices]), torch.tensor([self.labels[i] for i in indices], dtype=torch.long)

# Define SimpleNN model
class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Train function with experience replay
def train_with_experience_replay(model, task_idx, criterion, optimizer, buffer, epochs=5, batch_size=64):
    train_loader, _ = get_task_data(task_idx)
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct, total = 0, 0

        for inputs, labels in train_loader:
            # Sample from the buffer
            if len(buffer.data) > 0:
                replay_inputs, replay_labels = buffer.sample(batch_size // 2)
                inputs = torch.cat((inputs, replay_inputs))
                labels = torch.cat((labels, replay_labels))

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * correct / total
        print(f'Task {task_idx+1}, Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')

# Evaluate function for cumulative classes
def evaluate_cumulative_classes(model, num_tasks):
    combined_test_set = ConcatDataset(test_tasks[:num_tasks])
    test_loader = DataLoader(combined_test_set, batch_size=64, shuffle=False)
    
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    return accuracy

# Full continual learning setup with experience replay
def demonstrate_experience_replay():
    input_size = 28 * 28
    hidden_size = 256
    learning_rate = 0.01
    epochs_per_task = 5
    replay_capacity = 500  # Buffer capacity

    # Start with output size for all 10 MNIST digits
    output_size = 10

    model = SimpleNN(input_size, hidden_size, output_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    
    buffer = ReplayBuffer(replay_capacity)
    accuracies = []
    
    for task_idx in range(len(class_splits)):
        print(f"\n{'='*50}")
        print(f"Training on Task {task_idx+1}: Classes {class_splits[task_idx]}")
        print(f"{'='*50}")

        # Train with experience replay
        train_with_experience_replay(model, task_idx, criterion, optimizer, buffer, epochs_per_task)
        
        # Evaluate on cumulative tasks
        accuracy = evaluate_cumulative_classes(model, task_idx + 1)
        accuracies.append(accuracy)
        print(f"\nModel Accuracy after Task {task_idx + 1}: {accuracy:.2f}%")
        
        # Add current task data to buffer
        current_task_data = torch.stack([x for x, _ in train_tasks[task_idx]])
        current_task_labels = torch.tensor([y for _, y in train_tasks[task_idx]], dtype=torch.long)
        buffer.add(current_task_data, current_task_labels)

    return accuracies

In [6]:
# Run the demonstration
if __name__ == "__main__":
    experience_replay_accuracies = demonstrate_experience_replay()


Training on Task 1: Classes [0, 1]
Task 1, Epoch 1/5, Loss: 0.4762, Accuracy: 97.37%
Task 1, Epoch 2/5, Loss: 0.0380, Accuracy: 99.67%
Task 1, Epoch 3/5, Loss: 0.0221, Accuracy: 99.72%
Task 1, Epoch 4/5, Loss: 0.0165, Accuracy: 99.76%
Task 1, Epoch 5/5, Loss: 0.0136, Accuracy: 99.76%

Model Accuracy after Task 1: 99.91%

Training on Task 2: Classes [2, 3]
Task 2, Epoch 1/5, Loss: 0.7913, Accuracy: 82.96%
Task 2, Epoch 2/5, Loss: 0.2591, Accuracy: 94.27%
Task 2, Epoch 3/5, Loss: 0.1911, Accuracy: 95.20%
Task 2, Epoch 4/5, Loss: 0.1617, Accuracy: 95.75%
Task 2, Epoch 5/5, Loss: 0.1458, Accuracy: 95.96%

Model Accuracy after Task 2: 96.75%

Training on Task 3: Classes [4, 5]
Task 3, Epoch 1/5, Loss: 0.8979, Accuracy: 79.11%
Task 3, Epoch 2/5, Loss: 0.3081, Accuracy: 93.87%
Task 3, Epoch 3/5, Loss: 0.2132, Accuracy: 95.32%
Task 3, Epoch 4/5, Loss: 0.1787, Accuracy: 95.61%
Task 3, Epoch 5/5, Loss: 0.1514, Accuracy: 96.11%

Model Accuracy after Task 3: 67.82%

Training on Task 4: Classes [6