In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
from ewc import EWC
import generate_datasets as ds

  _torch_pytree._register_pytree_node(


In [2]:
# Helper function to evaluate the model
def evaluate_model(model, dataloader, device='cuda'):
    """Return the fraction of samples in ``dataloader`` that ``model`` classifies correctly.

    Args:
        model: Callable module; it is switched to eval mode and left in that
            mode when the function returns.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        device: Device both tensors of every batch are moved to before the
            forward pass.

    Returns:
        ``correct / total`` as a Python float, where the predicted class is the
        index of the largest value along dimension 1 of the model output.
    """
    model.eval()
    n_correct, n_seen = 0, 0

    # Gradients are not needed for scoring.
    with torch.no_grad():
        for batch_inputs, batch_targets in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            scores = model(batch_inputs)
            predicted_classes = scores.max(dim=1).indices

            n_seen += batch_targets.shape[0]
            n_correct += predicted_classes.eq(batch_targets).sum().item()

    return n_correct / n_seen


In [3]:
def train_model(model, train_dataloader, test_dataloaders, criterion, optimizer, ewc=None, lambda_ewc=0.0, epochs=20, device='cuda'):
    """Train ``model`` for ``epochs`` epochs and score it on every test loader after each one.

    Args:
        model: Module to optimise.
        train_dataloader: Iterable yielding ``(inputs, targets)`` batches.
        test_dataloaders: Sized sequence of evaluation loaders; column ``i`` of
            the result corresponds to ``test_dataloaders[i]``.
        criterion: Callable ``criterion(outputs, targets)`` returning the task loss.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        ewc: Optional regulariser. When truthy, ``ewc.compute_ewc_loss(model,
            lambda_ewc)`` is added to the task loss of every batch.
        lambda_ewc: Value forwarded to ``compute_ewc_loss``.
        epochs: Number of passes over ``train_dataloader``.
        device: Device every batch is moved to.

    Returns:
        ``numpy`` array of shape ``(epochs, len(test_dataloaders))`` holding the
        accuracy returned by ``evaluate_model`` per epoch and test loader.
    """
    accuracies = np.zeros((epochs, len(test_dataloaders)))
    for epoch in range(epochs):
        # The per-epoch evaluation below leaves the model in eval mode, so train
        # mode has to be re-entered here; calling it only once before the loop
        # would run every epoch after the first with dropout disabled.
        model.train()
        total_loss = 0
        num_batches = 0

        for inputs, targets in train_dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            task_loss = criterion(outputs, targets)

            # Add regularization loss if applicable
            ewc_loss = ewc.compute_ewc_loss(model, lambda_ewc) if ewc else 0.0
            loss = task_loss + ewc_loss

            loss.backward()
            optimizer.step()
            # Only the task loss is reported; the EWC penalty is excluded.
            total_loss += task_loss.item()
            num_batches += 1

        # Counting batches avoids requiring len() on the loader and avoids a
        # ZeroDivisionError when it yields nothing.
        mean_loss = total_loss / num_batches if num_batches else float("nan")

        # Evaluate the model after each epoch on each test set
        for i, test_dataloader in enumerate(test_dataloaders):
            accuracy = evaluate_model(model, test_dataloader, device)
            accuracies[epoch, i] = accuracy
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {mean_loss:.4f}, Accuracy on task {i}: {accuracy:.4f}")

    return accuracies


In [4]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# load_datasets() returns four values; only the first two (the train and test
# loaders for the permuted tasks) are kept, the remaining two are discarded.
(permuted_train_loaders, permuted_test_loaders, _, _) = ds.load_datasets()

In [None]:
from MNIST_functions import CustomNN, EarlyStopping, set_experiment_params

# Set experiment parameters
params = set_experiment_params('2A')
learning_rate = params['learning_rate']
dropout_input = params['dropout_input']
dropout_hidden = params['dropout_hidden']
early_stopping_enabled = params['early_stopping_enabled']
num_hidden_layers = params['num_hidden_layers']
width_hidden_layers = params['width_hidden_layers']
epochs = params['epochs']

# Value passed as lambda_ewc to every training stage.
LAMBDA_EWC = 1000
# Number of permuted tasks trained one after another (A, B, C).
NUM_TASKS = 3

# Initialize the model, criterion, optimizer, and early stopping
model_ewc = CustomNN(
    num_hidden_layers=num_hidden_layers,
    hidden_size=width_hidden_layers,
    dropout_input=dropout_input,
    dropout_hidden=dropout_hidden,
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model_ewc.parameters(), lr=learning_rate)
# NOTE(review): early_stopping is built here but never passed to train_model,
# so it currently has no effect on training.
early_stopping = EarlyStopping(patience=5) if early_stopping_enabled else None

ewc = EWC(model_ewc)

# Train the tasks in sequence. Before every task except the first, the Fisher
# information is computed on the task that was just finished and the EWC
# reference parameters are updated. After each epoch the model is scored on all
# tasks seen so far (test loaders 0..task_idx).
task_accuracies = []
for task_idx in range(NUM_TASKS):
    if task_idx > 0:
        ewc.compute_fisher(permuted_train_loaders[task_idx - 1])
        ewc.update_params()

    task_accuracies.append(
        train_model(
            model_ewc,
            permuted_train_loaders[task_idx],
            permuted_test_loaders[0:task_idx + 1],
            criterion,
            optimizer,
            ewc=ewc,
            lambda_ewc=LAMBDA_EWC,
            epochs=epochs,
            device=device,
        )
    )

# Keep the per-task names for any cell that reads them.
accuracies_a, accuracies_b, accuracies_c = task_accuracies


Epoch 1/20, Loss: 0.4210, Accuracy on task 0: 0.9431
Epoch 2/20, Loss: 0.1395, Accuracy on task 0: 0.9554
Epoch 3/20, Loss: 0.0932, Accuracy on task 0: 0.9584
Epoch 4/20, Loss: 0.0679, Accuracy on task 0: 0.9616
Epoch 5/20, Loss: 0.0508, Accuracy on task 0: 0.9617
Epoch 6/20, Loss: 0.0394, Accuracy on task 0: 0.9666
Epoch 7/20, Loss: 0.0306, Accuracy on task 0: 0.9673
Epoch 8/20, Loss: 0.0243, Accuracy on task 0: 0.9689
Epoch 9/20, Loss: 0.0193, Accuracy on task 0: 0.9683
Epoch 10/20, Loss: 0.0159, Accuracy on task 0: 0.9702
Epoch 11/20, Loss: 0.0130, Accuracy on task 0: 0.9697
Epoch 12/20, Loss: 0.0110, Accuracy on task 0: 0.9712
Epoch 13/20, Loss: 0.0092, Accuracy on task 0: 0.9719
Epoch 14/20, Loss: 0.0079, Accuracy on task 0: 0.9723
Epoch 15/20, Loss: 0.0069, Accuracy on task 0: 0.9719
Epoch 16/20, Loss: 0.0061, Accuracy on task 0: 0.9717
Epoch 17/20, Loss: 0.0054, Accuracy on task 0: 0.9738
Epoch 18/20, Loss: 0.0048, Accuracy on task 0: 0.9719
Epoch 19/20, Loss: 0.0044, Accuracy o

KeyboardInterrupt: 

In [None]:
# Score the final model on the test loader of each of the three tasks, in order A, B, C.
acc_a_ewc, acc_b_ewc, acc_c_ewc = (
    evaluate_model(model_ewc, permuted_test_loaders[task_idx], device=device)
    for task_idx in range(3)
)

print(f"EWC - Accuracy on Task A: {acc_a_ewc:.2f}, Task B: {acc_b_ewc:.2f}, Task C: {acc_c_ewc:.2f}")

EWC - Accuracy on Task A: 0.86, Task B: 0.95, Task C: 0.97
