In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np

# Simulated dataset
class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Neural network model
class SimpleNN(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Compute Fisher Information Matrix
def compute_fisher_matrix(model, dataloader, criterion):
    model.eval()
    fisher_matrix = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
    
    for inputs, labels in dataloader:
        model.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher_matrix[n] += p.grad ** 2 / len(dataloader)
    
    return fisher_matrix

# EWC regularization term
def ewc_loss(model, fisher_matrix, old_params, lambda_ewc):
    reg_loss = 0.0
    for n, p in model.named_parameters():
        if p.requires_grad:
            reg_loss += torch.sum(fisher_matrix[n] * (p - old_params[n]) ** 2)
    return lambda_ewc * reg_loss

# Updated EWC regularization term
def ewc_loss(model, fisher_matrix, old_params, lambda_ewc):
    reg_loss = 0.0
    for n, p in model.named_parameters():
        if n in fisher_matrix and n in old_params and p.shape == old_params[n].shape:
            reg_loss += torch.sum(fisher_matrix[n] * (p - old_params[n]) ** 2)
    return lambda_ewc * reg_loss

# Train the model
def train(model, dataloader, criterion, optimizer, fisher_matrix=None, old_params=None, lambda_ewc=0.0):
    model.train()
    running_loss = 0.0

    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        if fisher_matrix and old_params:
            loss += ewc_loss(model, fisher_matrix, old_params, lambda_ewc)
        
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    return running_loss / len(dataloader)

# Simulate Task A and Task B datasets
np.random.seed(42)
task_a_data = np.random.rand(200, 4)  # 200 samples, 4 features
task_a_labels = np.random.randint(0, 2, 200)  # Binary classification

task_b_data = np.random.rand(200, 4)
task_b_labels = np.random.randint(0, 4, 200)  # Multi-class classification (4 classes)

task_a_loader = DataLoader(CustomDataset(task_a_data, task_a_labels), batch_size=32, shuffle=True)
task_b_loader = DataLoader(CustomDataset(task_b_data, task_b_labels), batch_size=32, shuffle=True)

# Initialize model, loss, and optimizer
model = SimpleNN(input_size=4, num_classes=2)  # Start with Task A's 2 classes
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train on Task A
print("Training Task A...")
for epoch in range(10):
    loss = train(model, task_a_loader, criterion, optimizer)
    print(f"Epoch {epoch + 1}, Loss: {loss:.4f}")

# Save Task A parameters and Fisher Information Matrix
old_params = {n: p.clone() for n, p in model.named_parameters() if p.requires_grad}
fisher_matrix = compute_fisher_matrix(model, task_a_loader, criterion)

# Modify model for Task B
# Modify model for Task B (4 classes)
model.fc2 = nn.Linear(64, 4)  # Update the output layer for Task B
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train on Task B with EWC
print("\nTraining Task B with EWC...")
lambda_ewc = 100.0  # Regularization strength
for epoch in range(10):
    loss = train(model, task_b_loader, criterion, optimizer, fisher_matrix, old_params, lambda_ewc)
    print(f"Epoch {epoch + 1}, Loss: {loss:.4f}")


Training Task A...
Epoch 1, Loss: 0.7002
Epoch 2, Loss: 0.6950
Epoch 3, Loss: 0.6888
Epoch 4, Loss: 0.6906
Epoch 5, Loss: 0.6946
Epoch 6, Loss: 0.6949
Epoch 7, Loss: 0.6949
Epoch 8, Loss: 0.6865
Epoch 9, Loss: 0.6893
Epoch 10, Loss: 0.6888

Training Task B with EWC...
Epoch 1, Loss: 1.3849
Epoch 2, Loss: 1.3806
Epoch 3, Loss: 1.3794
Epoch 4, Loss: 1.3656
Epoch 5, Loss: 1.3814
Epoch 6, Loss: 1.3700
Epoch 7, Loss: 1.3644
Epoch 8, Loss: 1.3652
Epoch 9, Loss: 1.3647
Epoch 10, Loss: 1.3608
