<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Continual_Learning_with_Elastic_Weight_Consolidation_(EWC).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

class EWC:
    """Elastic Weight Consolidation (EWC) regularizer.

    On construction, estimates a diagonal Fisher information for each of
    ``model``'s parameters from ``data``, then snapshots the parameters as
    fixed anchors. ``penalty`` returns a quadratic cost for moving each
    parameter away from its anchor, weighted by its Fisher value and scaled by
    ``importance``; add it to the loss while training on a later task.
    """

    def __init__(self, model, data, importance=1000, criterion=None):
        """Build the regularizer.

        Args:
            model: the ``nn.Module`` trained on the previous task.
            data: iterable of ``(inputs, targets)`` batches from that task.
            importance: scalar multiplier applied to the penalty.
            criterion: loss used to estimate the Fisher information;
                ``nn.CrossEntropyLoss()`` when ``None``.
        """
        self.model = model
        self.importance = importance
        self.fisher_information = self.compute_fisher(data, criterion)
        # detach(): anchors must be constants. A plain clone() stays in the
        # autograd graph, so d/dθ (θ - clone(θ))² cancels to zero and the
        # penalty would never push parameters back toward their anchors.
        self.params = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
        }

    def compute_fisher(self, data, criterion=None):
        """Estimate a diagonal (empirical) Fisher information per parameter.

        For each batch the loss gradient is squared element-wise and
        accumulated; the total is divided by the number of batches. With
        batches of more than one sample this squares the batch-mean gradient,
        a coarser approximation than averaging per-sample squared gradients.

        The model is put in eval mode for the pass, and its previous
        train/eval mode is restored afterwards. Gradients produced by the pass
        are cleared before returning.

        Args:
            data: iterable of ``(inputs, targets)`` batches. It does not need
                to support ``len()``, so generators work.
            criterion: loss function; ``nn.CrossEntropyLoss()`` if ``None``.

        Returns:
            dict mapping parameter name to a detached tensor of the same shape.
            Parameters that never received a gradient are absent. The dict is
            empty when ``data`` yields no batches.
        """
        if criterion is None:
            criterion = nn.CrossEntropyLoss()
        fisher_information = {}
        was_training = self.model.training
        self.model.eval()
        num_batches = 0
        try:
            for x, y in data:
                num_batches += 1
                self.model.zero_grad()
                loss = criterion(self.model(x), y)
                loss.backward()

                for name, param in self.model.named_parameters():
                    if param.grad is None:
                        continue
                    sq_grad = param.grad.detach() ** 2
                    if name in fisher_information:
                        fisher_information[name] += sq_grad
                    else:
                        fisher_information[name] = sq_grad
        finally:
            # Don't leak Fisher-pass gradients into the caller's next step,
            # and hand the model back in the mode it arrived in.
            self.model.zero_grad()
            self.model.train(was_training)

        # Average over batches. With no batches the dict is empty, so there
        # is nothing to divide.
        for name in fisher_information:
            fisher_information[name] /= num_batches

        return fisher_information

    def penalty(self, model):
        """Return the EWC penalty for ``model``'s current parameters.

        Computes ``importance * sum(F * (θ - θ_anchor) ** 2)`` over every
        parameter that has a Fisher estimate. Parameters without one (no
        gradient during the Fisher pass) are skipped. When nothing
        contributes, a zero scalar tensor is returned, so ``.item()`` and
        adding the result to a loss still work.
        """
        terms = [
            (self.fisher_information[name] * (param - self.params[name]) ** 2).sum()
            for name, param in model.named_parameters()
            if name in self.fisher_information
        ]
        if not terms:
            return torch.zeros(())
        return self.importance * sum(terms)

# Example usage with a dummy model and data
class DummyModel(nn.Module):
    """Tiny demo network: a single linear layer mapping 10 features to 2 outputs."""

    def __init__(self):
        super().__init__()
        # The only trainable layer; its weight and bias are what EWC anchors.
        self.fc = nn.Linear(in_features=10, out_features=2)

    def forward(self, x):
        """Apply the linear layer; ``x``'s last dimension must be 10."""
        logits = self.fc(x)
        return logits

# Create dummy data: 100 single-sample batches of 10 random features, all labelled class 0.
dummy_data = [(torch.randn(1, 10), torch.tensor([0])) for _ in range(100)]

# Initialize model and EWC; the model's current weights become the anchors.
model = DummyModel()
ewc = EWC(model, dummy_data)

# Compute penalty. Right after construction every parameter equals its anchor,
# so this value is zero.
penalty_loss = ewc.penalty(model)
print("EWC Penalty Loss:", penalty_loss.item())

# Nudge the weights, as training on a new task would, so the penalty becomes
# visible.
with torch.no_grad():
    for param in model.parameters():
        param.add_(0.1 * torch.randn_like(param))
penalty_loss = ewc.penalty(model)
print("EWC Penalty Loss after perturbation:", penalty_loss.item())