<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Continual_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class SimpleNet(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: number of input features (last dimension of ``x``).
        hidden_size: width of the hidden layer.
        output_size: number of output features.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # fc1 is created before fc2 so parameter registration order (and hence
        # named_parameters()/state_dict keys and RNG draws at init) is stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Apply fc1, a ReLU, then fc2 to ``x`` and return the result."""
        return self.fc2(self.fc1(x).relu())

def compute_fisher(model, data_loader, optimizer):
    """Estimate the diagonal (empirical) Fisher information of ``model``'s parameters.

    For every ``(data, target)`` batch in ``data_loader`` the mean-squared-error
    loss is back-propagated, and the element-wise square of each parameter's
    gradient is accumulated, averaged over the number of batches.

    Args:
        model: network whose parameters are scored.
        data_loader: iterable of ``(data, target)`` batches; ``len()`` must
            return the number of batches.
        optimizer: retained for backward compatibility. Gradients are cleared
            with ``model.zero_grad()`` so that *every* model parameter is reset
            between batches, even ones the optimizer does not track.

    Returns:
        dict mapping each name from ``model.named_parameters()`` to a tensor of
        the same shape holding its Fisher estimate. Parameters that receive no
        gradient (e.g. frozen ones) get all zeros.
    """
    fisher = {name: torch.zeros_like(param) for name, param in model.named_parameters()}
    num_batches = len(data_loader)
    if num_batches == 0:
        return fisher

    # Score the model as it is: eval mode keeps dropout noise out of the
    # estimate and stops BatchNorm from updating its running statistics.
    # The caller's train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        for data, target in data_loader:
            model.zero_grad()
            output = model(data)
            loss = ((output - target) ** 2).mean()
            loss.backward()
            for name, param in model.named_parameters():
                # Frozen / unused parameters have no gradient; they keep zero importance.
                if param.grad is not None:
                    fisher[name] += param.grad.detach() ** 2 / num_batches
    finally:
        # Don't leak the scoring gradients into the caller's next optimizer step.
        model.zero_grad()
        model.train(was_training)

    return fisher

def ewc_loss(model, fisher, old_params, lambda_ewc=0.4):
    """Elastic Weight Consolidation penalty anchoring ``model`` to ``old_params``.

    Returns ``lambda_ewc * sum_i (fisher_i * (param_i - old_i) ** 2).sum()``
    over every parameter in ``model.named_parameters()``. Unlike the
    lambda/2 form in Kirkpatrick et al. (2017), no 1/2 factor is applied; it is
    absorbed into ``lambda_ewc``.

    Args:
        model: network being trained on the new task.
        fisher: name -> per-element importance tensor (same shape as the param).
        old_params: name -> parameter values from the previous task.
        lambda_ewc: penalty strength.

    Returns:
        Scalar tensor (or ``0.0`` if the model has no parameters).

    Raises:
        KeyError: if a model parameter name is missing from ``fisher`` or
            ``old_params``.
    """
    penalty = 0
    for name, param in model.named_parameters():
        # The anchors must be constants. If old_params were produced by a plain
        # clone(), they are still on param's autograd graph and the gradient
        # flowing back through them exactly cancels the direct gradient,
        # making the penalty inert during training.
        anchor = old_params[name].detach()
        importance = fisher[name].detach()
        penalty = penalty + (importance * (param - anchor) ** 2).sum()
    return penalty * lambda_ewc

# Continual-learning demo: train on task 1, then on task 2 with an EWC penalty
# that discourages moving parameters that were important for task 1.
model = SimpleNet(input_size=2, hidden_size=10, output_size=1)
# Plain SGD keeps no per-parameter state, so reusing it for task 2 carries
# nothing over from task 1.
optimizer = optim.SGD(model.parameters(), lr=0.01)

# First task training
data = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
target = torch.tensor([[1], [0]], dtype=torch.float32)
dataset = TensorDataset(data, target)
loader = DataLoader(dataset, batch_size=2)

model.train()
for epoch in range(10):
    for data, target in loader:
        optimizer.zero_grad()
        output = model(data)
        loss = ((output - target) ** 2).mean()
        loss.backward()
        optimizer.step()
    # Reports the loss of the last batch in the epoch.
    print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

# Snapshot the task-1 parameters and compute their Fisher importance.
# detach() is essential: a bare clone() stays on the autograd graph, so the EWC
# gradient flowing back through old_params would cancel the direct gradient and
# the penalty would have no effect on training.
old_params = {name: param.detach().clone() for name, param in model.named_parameters()}
fisher = compute_fisher(model, loader, optimizer)

# Second task training with EWC
data = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32)
target = torch.tensor([[0], [1]], dtype=torch.float32)
dataset = TensorDataset(data, target)
loader = DataLoader(dataset, batch_size=2)

model.train()
for epoch in range(10):
    for data, target in loader:
        optimizer.zero_grad()
        output = model(data)
        loss = ((output - target) ** 2).mean()
        # Penalise drift away from old_params, weighted by the Fisher estimate.
        loss_ewc = ewc_loss(model, fisher, old_params)
        total_loss = loss + loss_ewc
        total_loss.backward()
        optimizer.step()
    # Reports task loss + EWC penalty of the last batch in the epoch.
    print(f'Epoch {epoch + 1}, Loss: {total_loss.item()}')