In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
from moja_ewc import EWC
import generate_datasets as ds

In [18]:
# Helper function to evaluate the model
def evaluate_model(model, dataloader, device='cpu'):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    return correct / total


In [19]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load permuted datasets
permuted_train_loaders, permuted_test_loaders, rotated_train_loaders, rotated_test_loaders = ds.load_datasets()

In [20]:
def train_model(model, train_dataloader, test_dataloader, criterion, optimizer, ewc=None, lambda_ewc=0.0, epochs=5, device='cuda'):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        
        for inputs, targets in train_dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            task_loss = criterion(outputs, targets)

            # Add regularization loss if applicable
            ewc_loss = ewc.compute_ewc_loss(model, lambda_ewc) if ewc else 0.0
            loss = task_loss + ewc_loss

            loss.backward()
            optimizer.step()
            total_loss += task_loss.item()

        # Evaluate the model after each epoch
        accuracy = evaluate_model(model, test_dataloader, device=device)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(train_dataloader):.4f}, Accuracy: {accuracy:.4f}")

In [21]:
from MNIST_functions import CustomNN, EarlyStopping, set_experiment_params

# Set experiment parameters
learning_rate, dropout_input, dropout_hidden, early_stopping_enabled, num_hidden_layers, width_hidden_layers, epochs = set_experiment_params('2A')

# Initialize the model, criterion, optimizer, and early stopping
model_ewc = CustomNN(num_hidden_layers=num_hidden_layers, hidden_size=width_hidden_layers, dropout_input=dropout_input, dropout_hidden=dropout_hidden).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model_ewc.parameters(), lr=learning_rate)
early_stopping = EarlyStopping(patience=5) if early_stopping_enabled else None

# Train on first task with EWC
ewc = EWC(model_ewc)
train_model(model_ewc, permuted_train_loaders[0], permuted_test_loaders[0], criterion, optimizer, ewc=ewc, lambda_ewc=500, epochs=epochs, device=device)

ewc.compute_fisher(permuted_train_loaders[0])
ewc.update_params()

# Train on second task with EWC
train_model(model_ewc, permuted_train_loaders[1], permuted_test_loaders[1], criterion, optimizer, ewc=ewc, lambda_ewc=500, epochs=epochs, device=device)

ewc.compute_fisher(permuted_train_loaders[1])
ewc.update_params()

# Train on third task with EWC
train_model(model_ewc, permuted_train_loaders[2], permuted_test_loaders[2], criterion, optimizer, ewc=ewc, lambda_ewc=500, epochs=epochs, device=device)


Epoch 1/3, Loss: 0.4264, Accuracy: 0.9390
Epoch 2/3, Loss: 0.1441, Accuracy: 0.9551
Epoch 3/3, Loss: 0.0959, Accuracy: 0.9613
Epoch 1/3, Loss: 0.3860, Accuracy: 0.9371
Epoch 2/3, Loss: 0.1663, Accuracy: 0.9515
Epoch 3/3, Loss: 0.1221, Accuracy: 0.9569
Epoch 1/3, Loss: 0.3731, Accuracy: 0.9407
Epoch 2/3, Loss: 0.1676, Accuracy: 0.9456
Epoch 3/3, Loss: 0.1258, Accuracy: 0.9535


In [22]:
# evaluate on all tasks
acc_a_ewc = evaluate_model(model_ewc, permuted_test_loaders[0], device=device)
acc_b_ewc = evaluate_model(model_ewc, permuted_test_loaders[1], device=device)
acc_c_ewc = evaluate_model(model_ewc, permuted_test_loaders[2], device=device)

print(f"EWC - Accuracy on Task A: {acc_a_ewc:.2f}, Task B: {acc_b_ewc:.2f}, Task C: {acc_c_ewc:.2f}")

EWC - Accuracy on Task A: 0.90, Task B: 0.95, Task C: 0.95
