In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torchvision
from torchvision import transforms

In [31]:
# Define Class Incremental MNIST Dataset
class ClassIncrementalMNIST(Dataset):
    """In-memory MNIST subset that contains only the requested digit classes.

    Each kept image is converted with ``transforms.ToTensor()`` once, at
    construction time, and cached in ``self.data``. The optional ``transform``
    is applied on top of that cached tensor every time an item is accessed.

    Args:
        root: Directory handed to ``torchvision.datasets.MNIST`` (downloaded if missing).
        train: Use the training split if True, otherwise the test split.
        transform: Optional callable applied to each cached image tensor in ``__getitem__``.
        classes: Iterable of digit labels to keep; ``None`` keeps every class.
    """

    def __init__(self, root, train=True, transform=None, classes=None):
        self.mnist_dataset = torchvision.datasets.MNIST(root=root, train=train, transform=transforms.ToTensor(), download=True)
        self.transform = transform
        self.classes = classes
        self.train = train
        # Filter on the raw label tensor first so only the kept samples pay for
        # the PIL -> tensor conversion (previously every image was converted,
        # then most were discarded). classes=None used to raise TypeError.
        keep_idx = self._select_indices(self.mnist_dataset.targets.tolist(), classes)
        self.data = []
        self.targets = []
        for idx in keep_idx:
            image, label = self.mnist_dataset[idx]
            self.data.append(image)
            self.targets.append(label)

    @staticmethod
    def _select_indices(labels, classes):
        """Return the positions in ``labels`` whose value is in ``classes`` (all positions if ``classes`` is None)."""
        if classes is None:
            return list(range(len(labels)))
        wanted = set(classes)
        return [i for i, label in enumerate(labels) if label in wanted]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``image`` is passed through ``self.transform`` when one is set."""
        image, label = self.data[idx], self.targets[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Setup Class Incremental MNIST Tasks
# Seed every RNG the notebook touches (weight init, DataLoader shuffling,
# replay sampling) so Restart & Run All reproduces the same numbers.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DATA_ROOT = "./data"  # where torchvision stores / downloads MNIST
num_tasks = 5
classes_per_task = 2
# MNIST only has digits 0-9; any class beyond 9 would silently give an empty task.
assert num_tasks * classes_per_task <= 10, "class splits exceed the 10 MNIST digits"

# Divide classes into tasks: task i gets digits [i*k, (i+1)*k) for k = classes_per_task
class_splits = [list(range(i * classes_per_task, (i + 1) * classes_per_task)) for i in range(num_tasks)]

# Load train/test datasets for each task (one filtered MNIST copy per split)
train_tasks = [ClassIncrementalMNIST(root=DATA_ROOT, train=True, classes=split) for split in class_splits]
test_tasks = [ClassIncrementalMNIST(root=DATA_ROOT, train=False, classes=split) for split in class_splits]

# Function to create DataLoaders
def get_task_data(task_idx, batch_size=64):
    """Build the ``(train_loader, test_loader)`` pair for one task.

    The training loader reshuffles on every pass; the test loader keeps the
    dataset order. Both read from the module-level ``train_tasks`` / ``test_tasks``.
    """
    return tuple(
        DataLoader(split[task_idx], batch_size=batch_size, shuffle=is_train)
        for split, is_train in ((train_tasks, True), (test_tasks, False))
    )

In [34]:
# Define SimpleNN model
class SimpleNN(nn.Module):
    """One-hidden-layer MLP: flatten -> Linear -> ReLU -> Linear, returning raw logits.

    Args:
        input_size: Number of features per sample after flattening.
        hidden_size: Width of the hidden layer.
        output_size: Number of output logits.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute names and creation order are deliberate: they fix the
        # state_dict keys and the order in which random init draws happen.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden = self.relu(self.fc1(self.flatten(x)))
        return self.fc2(hidden)

# Train function with naive rehearsal
def train_task_with_replay(model, task_idx, criterion, optimizer, replay_data, replay_labels, epochs=5, batch_size=64):
    """Train ``model`` on task ``task_idx`` mixed with a replay buffer of earlier samples.

    The current task's samples and the stored replay samples are combined into
    one in-memory ``TensorDataset`` and shuffled together, so every epoch sees
    both the new classes and the replayed old ones.

    Args:
        model: Network to train (updated in place).
        task_idx: Task whose training data is fetched through ``get_task_data``.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        replay_data: Stored samples from earlier tasks; may be empty (``numel() == 0``).
        replay_labels: ``torch.long`` labels aligned with ``replay_data``.
        epochs: Number of passes over the combined data.
        batch_size: Mini-batch size of the combined loader (default 64).

    Returns:
        ``(task_train_loss, task_train_acc)``: per-epoch mean batch loss and
        training accuracy in percent.
    """
    train_loader, _ = get_task_data(task_idx)

    # Materialise the current task in a single pass over the dataset (the
    # previous version iterated it twice: once for images, once for labels).
    images, labels = [], []
    for x, y in train_loader.dataset:
        images.append(x)
        labels.append(y)
    all_data = torch.stack(images)
    all_labels = torch.tensor(labels, dtype=torch.long)

    # Only append a non-empty buffer: an initial ``torch.tensor([])`` has shape
    # (0,), whose rank does not match the images, and concatenating it relies on
    # torch.cat's legacy special case for empty 1-D tensors.
    if replay_data.numel() > 0:
        all_data = torch.cat([all_data, replay_data])
        all_labels = torch.cat([all_labels, replay_labels])

    combined_data = torch.utils.data.TensorDataset(all_data, all_labels)
    combined_loader = DataLoader(combined_data, batch_size=batch_size, shuffle=True)

    # Per-epoch metrics returned to the caller
    task_train_loss = []
    task_train_acc = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, targets in combined_loader:
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Training accuracy on this batch (predictions from the pre-step forward pass)
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        epoch_loss = running_loss / len(combined_loader)
        epoch_acc = 100 * correct / total

        task_train_loss.append(epoch_loss)
        task_train_acc.append(epoch_acc)

        print(f'Task {task_idx+1}, Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')

    return task_train_loss, task_train_acc

# Evaluate function for cumulative classes
def evaluate_cumulative_classes(model, num_tasks, batch_size=64):
    """Return ``model``'s accuracy (percent) on the test sets of the first ``num_tasks`` tasks.

    The test splits of tasks ``0 .. num_tasks - 1`` (from the module-level
    ``test_tasks``) are concatenated, so the score covers every class seen so
    far. Leaves ``model`` in eval mode.

    Args:
        model: Classifier producing one logit per class.
        num_tasks: Number of leading tasks to include. This parameter shadows
            the module-level ``num_tasks`` inside the function.
        batch_size: Evaluation batch size (default 64).

    Raises:
        ValueError: If ``num_tasks < 1`` or the selected test sets hold no samples
            (previously an opaque ConcatDataset assertion / ZeroDivisionError).
    """
    if num_tasks < 1:
        raise ValueError(f"num_tasks must be >= 1, got {num_tasks}")
    combined_test_set = ConcatDataset(test_tasks[:num_tasks])
    if len(combined_test_set) == 0:
        raise ValueError("the selected test sets contain no samples")
    test_loader = DataLoader(combined_test_set, batch_size=batch_size, shuffle=False)

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100 * correct / total

# Naive Rehearsal continual learning demonstration
def demonstrate_naive_rehearsal(replay_size=100, epochs_per_task=5, learning_rate=0.01, hidden_size=256):
    """Train one model sequentially on every task in ``class_splits`` using naive rehearsal.

    After each task the model is evaluated on the test data of all tasks seen
    so far. Then ``replay_size`` randomly chosen training samples of that task
    are appended to the replay buffer, which is mixed into the training data of
    every later task.

    Args:
        replay_size: Samples stored per task (default 100).
        epochs_per_task: Training epochs per task (default 5).
        learning_rate: SGD learning rate (default 0.01).
        hidden_size: Hidden width of ``SimpleNN`` (default 256).

    Returns:
        List of cumulative test accuracies (percent), one entry per task.
    """
    input_size = 28 * 28

    # A single 10-way head covers all MNIST digits (0-9) from the first task on;
    # it is not grown as new tasks arrive.
    output_size = 10

    model = SimpleNN(input_size, hidden_size, output_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    accuracies = []
    # Empty buffer shaped (0, *image_shape) so every torch.cat sees matching
    # ranks, instead of relying on torch.cat's legacy handling of shape-(0,) tensors.
    image_shape = train_tasks[0][0][0].shape
    replay_data = torch.empty((0, *image_shape))
    replay_labels = torch.empty(0, dtype=torch.long)

    for task_idx in range(len(class_splits)):
        print(f"\n{'='*50}")
        print(f"Training on Task {task_idx+1}: Classes {class_splits[task_idx]}")
        print(f"{'='*50}")

        # Train on current task with replay (per-epoch metrics are printed by the callee)
        train_task_with_replay(model, task_idx, criterion, optimizer, replay_data, replay_labels, epochs_per_task)

        # Evaluate on cumulative classes
        accuracy = evaluate_cumulative_classes(model, task_idx + 1)
        accuracies.append(accuracy)

        print(f"\nModel Accuracy after Task {task_idx + 1}: {accuracy:.2f}%")

        # Draw replay indices first and materialise only those samples; the
        # previous version stacked the entire task just to keep a few of them.
        task_dataset = train_tasks[task_idx]
        indices = torch.randperm(len(task_dataset))[:replay_size]
        picked = [task_dataset[int(i)] for i in indices]
        if picked:
            replay_data = torch.cat([replay_data, torch.stack([x for x, _ in picked])])
            replay_labels = torch.cat([replay_labels, torch.tensor([y for _, y in picked], dtype=torch.long)])

    return accuracies

In [35]:
# Run the naive-rehearsal demonstration and keep the list of cumulative test
# accuracies it returns (one entry per task) for later inspection.
# The guard passes when this cell is executed in the notebook, as the training
# log printed below shows.
if __name__ == "__main__":
    naive_rehearsal_accuracies = demonstrate_naive_rehearsal()


Training on Task 1: Classes [0, 1]
Task 1, Epoch 1/5, Loss: 0.4516, Accuracy: 96.98%
Task 1, Epoch 2/5, Loss: 0.0373, Accuracy: 99.67%
Task 1, Epoch 3/5, Loss: 0.0217, Accuracy: 99.75%
Task 1, Epoch 4/5, Loss: 0.0162, Accuracy: 99.76%
Task 1, Epoch 5/5, Loss: 0.0134, Accuracy: 99.77%

Model Accuracy after Task 1: 99.91%

Training on Task 2: Classes [2, 3]
Task 2, Epoch 1/5, Loss: 0.7270, Accuracy: 83.19%
Task 2, Epoch 2/5, Loss: 0.1863, Accuracy: 94.59%
Task 2, Epoch 3/5, Loss: 0.1485, Accuracy: 95.36%
Task 2, Epoch 4/5, Loss: 0.1327, Accuracy: 95.68%
Task 2, Epoch 5/5, Loss: 0.1231, Accuracy: 95.93%

Model Accuracy after Task 2: 52.49%

Training on Task 3: Classes [4, 5]
Task 3, Epoch 1/5, Loss: 0.7729, Accuracy: 83.04%
Task 3, Epoch 2/5, Loss: 0.2067, Accuracy: 95.73%
Task 3, Epoch 3/5, Loss: 0.1639, Accuracy: 96.52%
Task 3, Epoch 4/5, Loss: 0.1411, Accuracy: 96.91%
Task 3, Epoch 5/5, Loss: 0.1283, Accuracy: 97.16%

Model Accuracy after Task 3: 43.72%

Training on Task 4: Classes [6