<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Continual_Learning_with_Elastic_Weight_Consolidation_(EWC).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define a simple model
class MyModel(nn.Module):
    """Two-layer MLP classifier: 784 inputs -> 128 hidden units (ReLU) -> 10 logits."""

    def __init__(self):
        super().__init__()
        in_features, hidden_units, num_classes = 784, 128, 10
        # Attribute names fc1/fc2 are kept: they define the state_dict keys.
        self.fc1 = nn.Linear(in_features, hidden_units)
        self.fc2 = nn.Linear(hidden_units, num_classes)

    def forward(self, x):
        """Map ``x`` (last dimension 784) to unnormalised scores (last dimension 10)."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden)

class EWC:
    """Elastic Weight Consolidation (EWC) regulariser for continual learning.

    On construction it estimates a diagonal Fisher information term for every
    trainable parameter of ``model`` from ``dataloader`` (the "old" task). It
    also snapshots the current parameter values as the anchor point.
    ``penalty`` then charges ``importance * sum(F * (theta - theta_old) ** 2)``
    for moving away from that anchor. ``update_model`` trains on new data with
    this penalty added to the cross-entropy loss.

    Args:
        model: The ``nn.Module`` to protect. It should already be trained on
            the old task, because the anchor is taken at construction time.
        dataloader: Iterable of ``(inputs, labels)`` batches from the old task,
            used only to estimate the Fisher diagonal. It needs no ``__len__``.
        importance: Scalar weight (lambda) applied to the penalty.
    """

    def __init__(self, model, dataloader, importance):
        self.model = model
        self.dataloader = dataloader
        self.importance = importance
        # Live references to the trainable parameters (not copies).
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._means = {}
        self._fisher = {}
        self._compute_fisher()

    def _compute_fisher(self):
        """Estimate the diagonal Fisher and snapshot the anchor parameters.

        The estimate is the squared gradient of the batch-mean cross-entropy
        loss, averaged over batches (an empirical-Fisher approximation).
        Afterwards the model's original train/eval mode is restored and the
        gradients produced during estimation are cleared.
        """
        criterion = nn.CrossEntropyLoss()
        was_training = self.model.training
        self.model.eval()
        fisher = {n: torch.zeros_like(p) for n, p in self.params.items()}
        num_batches = 0
        try:
            for inputs, labels in self.dataloader:
                self.model.zero_grad()
                loss = criterion(self.model(inputs), labels)
                loss.backward()
                for n, p in self.params.items():
                    # A parameter not reached by forward() has no gradient;
                    # its Fisher entry stays zero (no information about it).
                    if p.grad is not None:
                        fisher[n] += p.grad.detach().pow(2)
                num_batches += 1
        finally:
            # Don't leak estimation gradients or the eval() switch to callers.
            self.model.zero_grad()
            self.model.train(was_training)
        if num_batches:
            for n in fisher:
                fisher[n] /= num_batches
        self._fisher = fisher
        self._means = {n: p.detach().clone() for n, p in self.params.items()}

    def penalty(self, model):
        """Return ``importance * sum_n sum(F_n * (theta_n - theta_n_old) ** 2)``.

        Only parameters whose names were recorded at construction contribute.
        The result is differentiable w.r.t. ``model``'s parameters, so it can
        be added to a training loss.
        """
        loss = 0
        for n, p in model.named_parameters():
            fisher = self._fisher.get(n)
            if fisher is not None:
                loss = loss + (fisher * (p - self._means[n]).pow(2)).sum()
        return self.importance * loss

    def update_model(self, new_data, optimizer, *, batch_size=32, epochs=1):
        """Train ``self.model`` on ``new_data`` with the EWC penalty added.

        Args:
            new_data: Sequence of tensors accepted by ``TensorDataset(*new_data)``,
                unpacked per batch as ``(inputs, labels)``.
            optimizer: Optimizer over ``self.model``'s parameters.
            batch_size: Mini-batch size (default 32, the previous fixed value).
            epochs: Number of passes over ``new_data`` (default 1).

        Returns:
            Mean total (task + EWC) loss over all processed batches, or
            ``None`` if no batch was processed.
        """
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        dataloader = DataLoader(TensorDataset(*new_data), batch_size=batch_size, shuffle=True)
        total, steps = 0.0, 0
        for _ in range(epochs):
            for inputs, labels in dataloader:
                optimizer.zero_grad()
                task_loss = criterion(self.model(inputs), labels)
                ewc_loss = task_loss + self.penalty(self.model)
                ewc_loss.backward()
                optimizer.step()
                total += ewc_loss.item()
                steps += 1
        return total / steps if steps else None

# Demo: learn task A, consolidate it with EWC, then learn new data (task B)
# while the EWC penalty discourages drifting away from the task-A solution.
torch.manual_seed(0)  # reproducible synthetic data and initial weights

# Initialize the model, task-A dataloader, and optimizer
model = MyModel()
X_train = torch.randn(1000, 784)  # example task-A training data
y_train = torch.randint(0, 10, (1000,))  # example task-A training labels
dataloader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train on task A first. EWC records the parameter values (and Fisher) at the
# moment the EWC object is built, so building it on an untrained model would
# only anchor the random initialisation.
criterion = nn.CrossEntropyLoss()
num_epochs_task_a = 5
model.train()
for _ in range(num_epochs_task_a):
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

# Consolidate task A: estimate Fisher information and snapshot parameters.
ewc = EWC(model, dataloader, importance=1000)

# Example new data stream for continual learning
new_X1 = torch.randn(200, 784)  # example new data
new_y1 = torch.randint(0, 10, (200,))  # example new labels
data_stream = [(new_X1, new_y1)]

# Learn each new chunk with the EWC penalty added to the loss.
for new_data in data_stream:
    ewc.update_model(new_data, optimizer)