In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms
import time
import numpy as np

In [2]:
# Define Permuted MNIST Dataset
class PermutedMNIST(Dataset):
    def __init__(self, root, train=True, transform=None, permutations=None):
        self.mnist_dataset = torchvision.datasets.MNIST(root=root, train=train, transform=transforms.ToTensor(), download=True)
        self.transform = transform
        self.permutations = permutations
        self.train = train

    def __len__(self):
        return len(self.mnist_dataset)

    def __getitem__(self, idx):
        image, label = self.mnist_dataset[idx]
        if self.permutations is not None:
            image = image.view(-1)[self.permutations].view(image.shape)
        if self.transform:
            image = self.transform(image)
        return image, label

# Setup Permuted MNIST Tasks
num_tasks = 5
input_size = 28 * 28  # Flattened MNIST image
permutations = [torch.randperm(input_size) for _ in range(num_tasks)]

# Load Permuted MNIST Datasets for each task
train_tasks = [PermutedMNIST(root="./data", train=True, permutations=permutations[i]) for i in range(num_tasks)]
test_tasks = [PermutedMNIST(root="./data", train=False, permutations=permutations[i]) for i in range(num_tasks)]

# Function to create DataLoaders for each task
def get_task_data(task_idx, batch_size=64):
    train_loader = DataLoader(train_tasks[task_idx], batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_tasks[task_idx], batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [4]:
# Define a simple neural network
class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

class PackNet:
    def __init__(self, model, prune_percent=0.1):
        self.model = model
        self.prune_percent = prune_percent
        self.masks = {}  # To store task-specific masks
        self.task_masks = {}  # To store masks for each task
        self.original_weights = {}  # To store original weights for resetting

    def prune_weights(self, task_idx):
        """Prune a percentage of weights to free up capacity for the next task."""
        for name, param in self.model.named_parameters():
            if 'weight' in name:  # Only prune weights, not biases
                # Store original weights for resetting later
                self.original_weights[name] = param.data.clone()
                
                # Prune based on weight magnitude
                weights = param.data.cpu().numpy()
                flat_weights = np.abs(weights.flatten())
                threshold = np.percentile(flat_weights, self.prune_percent * 100)
                mask = (np.abs(weights) > threshold).astype(float)
                self.masks[name] = torch.tensor(mask, dtype=torch.float32).to(param.device)
                self.task_masks[f"task_{task_idx}_{name}"] = self.masks[name].clone()

    def apply_masks(self, task_idx):
        """Apply masks to freeze pruned weights for the current task."""
        for name, param in self.model.named_parameters():
            if name in self.masks:
                param.data *= self.masks[name]

    def freeze_weights(self):
        """Freeze the remaining weights for the current task."""
        for name, param in self.model.named_parameters():
            if name in self.masks:
                param.requires_grad = False

    def load_task_masks(self, task_idx):
        """Load task-specific masks for inference."""
        for name, param in self.model.named_parameters():
            if f"task_{task_idx}_{name}" in self.task_masks:
                # Reset weights to original state before applying the mask
                param.data = self.original_weights[name].clone()
                # Apply task-specific mask
                param.data *= self.task_masks[f"task_{task_idx}_{name}"]

# Function to train the model on a specific task with PackNet
def train_task_packnet(model, packnet, task_idx, criterion, optimizer, epochs=5):
    train_loader, _ = get_task_data(task_idx)
    
    # For collecting metrics
    task_train_loss = []
    task_train_acc = []
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for inputs, labels in train_loader:
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
            # Calculate accuracy
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * correct / total
        
        task_train_loss.append(epoch_loss)
        task_train_acc.append(epoch_acc)
        
        print(f'Task {task_idx+1}, Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')
    
    # Prune and freeze weights after training the task
    packnet.prune_weights(task_idx)
    packnet.apply_masks(task_idx)
    packnet.freeze_weights()
    
    return task_train_loss, task_train_acc

# Function to evaluate the model on all seen tasks with PackNet
def evaluate_all_tasks_packnet(model, packnet, num_tasks):
    accuracies = []
    
    for i in range(num_tasks):
        _, test_loader = get_task_data(i)
        
        # Load task-specific masks for inference
        packnet.load_task_masks(i)
        
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for inputs, labels in test_loader:
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        accuracy = 100 * correct / total
        accuracies.append(accuracy)
        print(f'Task {i+1} Accuracy: {accuracy:.2f}%')
    
    return accuracies

# Main function to demonstrate PackNet
def demonstrate_packnet():
    # Hyperparameters
    input_size = 28 * 28  # Flattened MNIST image
    hidden_size = 256
    output_size = 10  # 10 classes for Permuted MNIST
    learning_rate = 0.01
    epochs_per_task = 5
    prune_percent = 0.05  # Reduced pruning percentage
    
    # Initialize model and PackNet
    model = SimpleNN(input_size, hidden_size, output_size)
    packnet = PackNet(model, prune_percent=prune_percent)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    
    # To store metrics
    training_history = {
        "task_accuracies": [],  # Performance on each task after sequential training
        "training_time": [],    # Time taken to train each task
        "learning_curves": {    # Loss and accuracy during training
            "loss": [],
            "accuracy": []
        }
    }
    
    # To compute forgetting metrics
    initial_accuracies = []  # Accuracy on each task right after learning it
    
    # Train on each task sequentially
    for task_idx in range(len(train_tasks)):
        print(f"\n{'='*50}")
        print(f"Training on Task {task_idx+1}")
        print(f"{'='*50}")
        
        # Measure training time
        start_time = time.time()
        
        # Train on current task with PackNet
        task_loss, task_acc = train_task_packnet(model, packnet, task_idx, criterion, optimizer, epochs=epochs_per_task)
        
        # Record training time
        end_time = time.time()
        training_time = end_time - start_time
        training_history["training_time"].append(training_time)
        
        # Save learning curves
        training_history["learning_curves"]["loss"].extend(task_loss)
        training_history["learning_curves"]["accuracy"].extend(task_acc)
        
        # Evaluate on all tasks seen so far
        print("\nEvaluating on all tasks seen so far:")
        task_accuracies = evaluate_all_tasks_packnet(model, packnet, task_idx + 1)
        
        # Store the accuracy on the current task after learning it
        if task_idx == 0:
            initial_accuracies.append(task_accuracies[0])
        else:
            training_history["task_accuracies"].append(task_accuracies.copy())
            initial_accuracies.append(task_accuracies[task_idx])
    
    # Calculate forgetting metrics
    forgetting_rate = calculate_forgetting_metrics(training_history, initial_accuracies)
    
    return training_history, forgetting_rate, initial_accuracies

# Function to calculate forgetting metrics
def calculate_forgetting_metrics(training_history, initial_accuracies):
    forgetting_rate = {}
    
    # For each task (except the last one since we don't have measurements after it)
    for task_idx in range(len(initial_accuracies) - 1):
        forgetting = []
        
        # Calculate forgetting for the task at each subsequent evaluation point
        for eval_idx, accuracies in enumerate(training_history["task_accuracies"]):
            if task_idx <= eval_idx:  # We only have measurements for tasks we've seen
                forgetting.append(initial_accuracies[task_idx] - accuracies[task_idx])
        
        forgetting_rate[f"Task {task_idx+1}"] = forgetting
    
    return forgetting_rate

In [5]:
# Run the PackNet demonstration
if __name__ == "__main__":
    training_history, forgetting_rate, initial_accuracies = demonstrate_packnet()
    print("Training History:", training_history)
    print("Forgetting Rate:", forgetting_rate)
    print("Initial Accuracies:", initial_accuracies)


Training on Task 1
Task 1, Epoch 1/5, Loss: 1.2487, Accuracy: 73.45%
Task 1, Epoch 2/5, Loss: 0.4920, Accuracy: 87.64%
Task 1, Epoch 3/5, Loss: 0.3910, Accuracy: 89.40%
Task 1, Epoch 4/5, Loss: 0.3506, Accuracy: 90.19%
Task 1, Epoch 5/5, Loss: 0.3255, Accuracy: 90.84%

Evaluating on all tasks seen so far:
Task 1 Accuracy: 91.59%

Training on Task 2
Task 2, Epoch 1/5, Loss: 2.2612, Accuracy: 16.34%
Task 2, Epoch 2/5, Loss: 2.1643, Accuracy: 19.66%
Task 2, Epoch 3/5, Loss: 2.1377, Accuracy: 21.68%
Task 2, Epoch 4/5, Loss: 2.1136, Accuracy: 23.09%
Task 2, Epoch 5/5, Loss: 2.0912, Accuracy: 24.48%

Evaluating on all tasks seen so far:
Task 1 Accuracy: 88.74%
Task 2 Accuracy: 24.17%

Training on Task 3
Task 3, Epoch 1/5, Loss: 2.3802, Accuracy: 11.74%
Task 3, Epoch 2/5, Loss: 2.2895, Accuracy: 15.23%
Task 3, Epoch 3/5, Loss: 2.2424, Accuracy: 17.21%
Task 3, Epoch 4/5, Loss: 2.2109, Accuracy: 18.70%
Task 3, Epoch 5/5, Loss: 2.1875, Accuracy: 19.89%

Evaluating on all tasks seen so far:
Task