<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Online_Learning_with_Elastic_Weight_Consolidation_(EWC).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

# Minimal demonstration network: a single linear layer.
class SimpleModel(nn.Module):
    """One-layer linear model mapping 10 input features to 2 output logits.

    The layer is stored as ``self.fc``, so its parameters appear in
    ``named_parameters()`` as ``fc.weight`` (shape ``(2, 10)``) and
    ``fc.bias`` (shape ``(2,)``).
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(in_features=10, out_features=2)

    def forward(self, x):
        """Apply the linear layer to ``x`` (last dimension must be 10)."""
        logits = self.fc(x)
        return logits

# Function to calculate Fisher Information
def calculate_fisher_information(model, data, loss_fn=None):
    """Estimate the diagonal (empirical) Fisher information of ``model``'s parameters.

    For every ``(inputs, labels)`` batch in ``data`` the loss gradient is
    computed and its element-wise square is accumulated per parameter; the sums
    are then averaged over the number of batches seen. Note that this squares
    the gradient of each *batch*, so with batch sizes > 1 it is a coarser
    estimate than squaring per-sample gradients.

    Args:
        model: ``nn.Module`` whose ``named_parameters()`` are scored.
        data: iterable of ``(inputs, labels)`` pairs. Any iterable is accepted
            (lists, generators, data loaders); batches are counted while
            iterating, so ``len(data)`` is not required.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar
            tensor. Defaults to ``nn.CrossEntropyLoss()``.

    Returns:
        dict mapping each parameter name to a tensor of the parameter's shape
        holding the averaged squared gradients. Parameters that never receive
        a gradient (e.g. ``requires_grad=False``) keep an all-zero entry.

    Raises:
        ValueError: if ``data`` yields no batches (averaging would be 0/0).

    Side effects:
        ``model`` is switched to eval mode for the computation and its previous
        train/eval mode is restored afterwards. The last batch's gradients are
        left in ``param.grad``.
    """
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()

    was_training = model.training
    model.eval()
    # zeros_like does not inherit requires_grad, so these accumulators are plain tensors.
    fisher_information = {name: torch.zeros_like(param) for name, param in model.named_parameters()}

    num_batches = 0
    try:
        for inputs, labels in data:
            model.zero_grad()
            loss = loss_fn(model(inputs), labels)
            loss.backward()

            for name, param in model.named_parameters():
                # Frozen or unused parameters have no gradient; they contribute zero.
                if param.grad is not None:
                    fisher_information[name] += param.grad.detach() ** 2
            num_batches += 1
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    if num_batches == 0:
        raise ValueError("data must contain at least one (inputs, labels) batch")

    for tensor in fisher_information.values():
        tensor /= num_batches

    return fisher_information

# Function to calculate EWC loss
def ewc_loss(model, model_previous, fisher_information, importance_factor):
    """Return the Elastic Weight Consolidation (EWC) quadratic penalty.

    Computes ``importance_factor * sum(F * (theta - theta_prev) ** 2)`` summed
    over every named parameter of ``model``, where ``theta_prev`` comes from
    ``model_previous`` and ``F`` from ``fisher_information``.

    Args:
        model: current ``nn.Module``; gradients of the result flow into its
            parameters.
        model_previous: dict mapping parameter name to the anchor tensor
            (``theta_prev``) for that parameter.
        fisher_information: dict mapping parameter name to a per-element
            importance weight tensor (``F``).
        importance_factor: scalar multiplier for the penalty.

    Returns:
        The scaled penalty (a scalar tensor when ``model`` has parameters).

    Raises:
        KeyError: if a parameter name of ``model`` is missing from
            ``model_previous`` or ``fisher_information``.
    """
    penalty = 0
    for name, param in model.named_parameters():
        # Anchors and Fisher weights are constants of the penalty. Detaching is
        # essential: if model_previous was built with a plain `param.clone()`,
        # the anchor stays linked to `param` in the autograd graph. The gradient
        # of (param - anchor) ** 2 w.r.t. param would then cancel to zero.
        fisher = fisher_information[name].detach()
        prev_param = model_previous[name].detach()
        penalty = penalty + (fisher * (param - prev_param) ** 2).sum()
    return importance_factor * penalty

# Example data for Fisher Information calculation
data = [(torch.randn(1, 10), torch.tensor([0])) for _ in range(10)]
model = SimpleModel()
# Snapshot the current parameters as the EWC anchor. `detach()` is essential:
# a plain `param.clone()` stays attached to the autograd graph. Gradients of the
# penalty would then flow back through the anchor and cancel, making EWC a no-op.
model_previous = {name: param.detach().clone() for name, param in model.named_parameters()}

# Calculate Fisher Information
fisher_information = calculate_fisher_information(model, data)

# Calculate EWC loss. The anchor was just copied from `model` itself, so the
# penalty is 0 here. It only becomes positive once `model` is trained away from it.
importance_factor = 1000
loss = ewc_loss(model, model_previous, fisher_information, importance_factor)

print("EWC Loss:", loss.item())