<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Elastic_Weight_Consolidation_(EWC)_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class EWC:
    """Elastic Weight Consolidation (EWC) regulariser.

    Keeps a snapshot ("anchor") of the model's parameters and a per-parameter
    importance estimate (a diagonal Fisher-information approximation), and adds
    ``importance * sum(F * (param - anchor) ** 2)`` to a new task's loss so that
    parameters that mattered for the old task are discouraged from moving.

    Args:
        model: The ``nn.Module`` being regularised. Its parameters at
            construction time become the anchor, so create this object after
            training on the old task.
        importance: Scalar weight applied to the penalty (the EWC "lambda").
    """

    def __init__(self, model, importance=1000):
        self.model = model
        self.importance = importance
        # ``detach()`` makes the anchor a constant. A bare ``clone()`` stays in
        # the autograd graph, and the gradient flowing back through it into the
        # same parameter cancels the penalty's gradient exactly.
        self.old_params = {
            name: param.detach().clone() for name, param in model.named_parameters()
        }
        # Filled in by ``calculate_importance``; ``None`` until then.
        self.importance_matrix = None

    def calculate_importance(self, dataset, criterion=None):
        """Estimate the diagonal Fisher information of the parameters on *dataset*.

        Samples are fed through the model one at a time. The squared gradient of
        each sample's negative log-likelihood is accumulated, then averaged over
        the number of samples. The model's per-module train/eval modes are
        restored and its gradients are cleared afterwards.

        Args:
            dataset: A dataset accepted by ``torch.utils.data.DataLoader``. Items
                are either an input tensor or a sequence whose first element is
                the input (and, when ``criterion`` is given, whose second
                element is the target).
            criterion: Optional loss called as ``criterion(outputs, targets)``.
                When omitted, the outputs are assumed to be class logits with the
                class dimension at index 1, and the log-likelihood of the model's
                own most likely class is used, so no labels are needed. Pass a
                criterion for models whose outputs are not class logits.

        Raises:
            ValueError: if ``criterion`` is given but dataset items carry no
                targets.
        """
        # Remember every submodule's mode so a mixed train/eval state is restored.
        saved_modes = [(module, module.training) for module in self.model.modules()]
        self.model.eval()
        importance_matrix = {
            name: torch.zeros_like(param) for name, param in self.model.named_parameters()
        }
        # batch_size=1 is deliberate: the Fisher diagonal is the mean of squared
        # *per-sample* gradients, which a mini-batch gradient cannot provide.
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
        num_samples = 0
        try:
            for batch in dataloader:
                has_fields = isinstance(batch, (list, tuple))
                # Keep the batch dimension so the model always sees batched input.
                inputs = batch[0] if has_fields else batch
                self.model.zero_grad()
                outputs = self.model(inputs)
                if criterion is not None:
                    if not has_fields or len(batch) < 2:
                        raise ValueError(
                            "criterion requires dataset items of the form (inputs, targets, ...)"
                        )
                    loss = criterion(outputs, batch[1])
                else:
                    # Negative log-likelihood of the model's own most likely class.
                    loss = nn.functional.cross_entropy(outputs, outputs.argmax(dim=1))
                loss.backward()
                for name, param in self.model.named_parameters():
                    # Parameters unused in the forward pass (or frozen) have no grad.
                    if param.grad is not None:
                        importance_matrix[name] += param.grad.detach() ** 2
                num_samples += 1
        finally:
            # Don't leak the estimation gradients into the next optimizer step.
            self.model.zero_grad()
            for module, mode in saved_modes:
                module.training = mode

        # Average over samples; an empty dataset leaves every importance at zero
        # instead of dividing by zero.
        if num_samples:
            importance_matrix = {
                name: imp / num_samples for name, imp in importance_matrix.items()
            }
        self.importance_matrix = importance_matrix

    def ewc_loss(self, new_task_loss):
        """Return ``new_task_loss`` plus the weighted EWC penalty.

        The penalty is ``sum(F * (param - anchor) ** 2)`` over all parameters,
        scaled by ``self.importance``.

        Raises:
            RuntimeError: if ``calculate_importance`` has not been called yet.
        """
        if self.importance_matrix is None:
            raise RuntimeError("calculate_importance() must be called before ewc_loss()")
        penalty = 0
        for name, param in self.model.named_parameters():
            old_param = self.old_params[name]
            importance = self.importance_matrix[name]
            penalty = penalty + (importance * (param - old_param) ** 2).sum()
        return new_task_loss + self.importance * penalty

# Example usage
torch.manual_seed(0)  # make the demo reproducible

model = nn.Linear(10, 2)  # Example model (stands in for a network trained on task A)
ewc = EWC(model, importance=1000)

# Simulate the old task's dataset
dataset = torch.utils.data.TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)))

# Estimate how important each parameter was for the old task
ewc.calculate_importance(dataset)

# Simulate a new task
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
new_task_inputs = torch.randn(10, 10)
new_task_labels = torch.randint(0, 2, (10,))

# Train briefly on the new task with the EWC-augmented loss. On the first step
# the parameters still equal the anchor, so the penalty is zero; it grows (and
# pulls back) as the parameters drift away from the old task's solution.
for step in range(5):
    optimizer.zero_grad()
    new_task_outputs = model(new_task_inputs)
    new_task_loss = criterion(new_task_outputs, new_task_labels)
    total_loss = ewc.ewc_loss(new_task_loss)
    total_loss.backward()
    optimizer.step()
    print(f"step {step}: task loss {new_task_loss.item():.4f}, total loss {total_loss.item():.4f}")