<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStardust/blob/main/Continual_Learning_for_Long_Term_Adaptability.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

class EWC:
    """EWC-style quadratic penalty anchoring a model to a parameter snapshot.

    At construction time the current parameter values are stored, together
    with the gradient of the model's scalar output on ``old_data`` with
    respect to every parameter.  ``penalty()`` then returns::

        sum_i importance * (p_i - old_p_i) ** 2 * old_grad_i ** 2

    i.e. parameters whose gradient on the old data was large are penalised
    more strongly for drifting away from their stored values.

    The wrapper does not change the model's train/eval mode and does not
    touch the ``.grad`` attributes of its parameters.

    Attributes:
        model: The wrapped ``nn.Module``.
        importance: Multiplicative weight applied to the penalty.
        old_params: ``{name: tensor}`` detached copies of the parameters
            taken when the object was created.
        old_gradients: ``{name: tensor}`` gradients returned by
            ``compute_gradients(old_data)``.
    """

    def __init__(self, model, old_data, importance=1e4):
        """Snapshot ``model`` and record its gradients on ``old_data``.

        Args:
            model: Module whose ``forward`` returns a scalar tensor when
                called on ``old_data``.
            old_data: Tensor fed to ``model`` to obtain the gradients.
            importance: Weight of the penalty term.
        """
        self.model = model
        self.importance = importance
        # Detach before cloning: a clone that stays in the autograd graph
        # routes an equal and opposite gradient back to the live parameter,
        # which would cancel the penalty's gradient exactly.
        self.old_params = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
        }
        self.old_gradients = self.compute_gradients(old_data)

    def compute_gradients(self, data):
        """Return the gradient of ``model(data)`` for every named parameter.

        The forward pass runs in eval mode; each module's previous
        train/eval flag is restored afterwards.  Gradients are obtained
        with ``torch.autograd.grad`` so parameter ``.grad`` attributes are
        neither read nor modified.

        Args:
            data: Input tensor; it is moved to the device of the model's
                first parameter before the forward pass.

        Returns:
            ``{name: tensor}`` with one detached gradient per named
            parameter.  Parameters that are frozen or do not contribute to
            the output get a zero tensor of matching shape.

        Raises:
            RuntimeError: Propagated from autograd, e.g. when the model
                output is not a scalar.
        """
        named_params = list(self.model.named_parameters())
        # Start from zeros so frozen/unused parameters still get an entry.
        gradients = {
            name: torch.zeros_like(param) for name, param in named_params
        }
        trainable = [
            (name, param) for name, param in named_params
            if param.requires_grad
        ]
        if not trainable:
            return gradients

        # Remember per-module flags so mixed train/eval setups survive.
        previous_modes = [
            (module, module.training) for module in self.model.modules()
        ]
        self.model.eval()
        try:
            # Ensure data is on the same device as the model.
            data = data.to(named_params[0][1].device)
            output = self.model(data)
            computed = torch.autograd.grad(
                output,
                [param for _, param in trainable],
                allow_unused=True,
            )
        finally:
            for module, was_training in previous_modes:
                module.training = was_training

        for (name, _), grad in zip(trainable, computed):
            if grad is not None:
                gradients[name] = grad.detach().clone()
        return gradients

    def penalty(self):
        """Return the penalty for the model's current parameters.

        Returns:
            A scalar tensor that is differentiable with respect to the
            model parameters.  For a model without parameters a zero
            tensor is returned.
        """
        total = None
        for name, param in self.model.named_parameters():
            drift = param - self.old_params[name]
            term = (
                self.importance * drift ** 2 * self.old_gradients[name] ** 2
            ).sum()
            total = term if total is None else total + term
        if total is None:
            return torch.zeros(())
        return total

# Example usage
class DummyModel(nn.Module):
    """Minimal one-layer network whose forward pass yields a scalar.

    The single linear layer maps 10 input features to 1 output feature;
    ``forward`` sums every element of that layer's output.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(in_features=10, out_features=1)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the sum of its output."""
        projected = self.fc(x)
        return projected.sum()
# Build the toy model and a random (10, 10) batch used as the "old" data.
model = DummyModel()
old_data = torch.randn((10, 10))

# Snapshot the model, then evaluate the penalty straight away.  The
# parameters have not moved since the snapshot, so the printed value is
# expected to be zero.
ewc = EWC(model=model, old_data=old_data)
penalty = ewc.penalty()
print(f"EWC Penalty: {penalty.item()}")