<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Elastic_Weight_Consolidation_(EWC).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

class EWC:
    """Elastic Weight Consolidation (EWC) regulariser for a PyTorch model.

    The object snapshots the model's parameters and a per-parameter
    importance estimate (squared gradients of the batch loss, averaged over
    the batches of a data loader).  ``penalty()`` then returns a quadratic
    term that grows as parameters drift away from the snapshot, weighted by
    that importance estimate.

    The wrapped model is expected to return an object exposing a ``.logits``
    attribute when called on a batch of inputs.
    """

    def __init__(self, model, importance):
        """Store the model and the scalar weight applied to the penalty.

        Args:
            model: Module whose parameters are regularised.
            importance: Multiplier applied to every term of the penalty.
        """
        self.model = model
        self.importance = importance
        # Both dicts are keyed by parameter name (as yielded by
        # ``model.named_parameters()``) and filled by
        # ``compute_fisher_information``.
        self.fisher_information = {}
        self.params_old = {}

    def compute_fisher_information(self, data_loader, criterion):
        """Estimate per-parameter importance and snapshot current parameters.

        For every ``(inputs, labels)`` batch the loss
        ``criterion(model(inputs).logits, labels)`` is back-propagated and
        the squared gradients are accumulated, divided by the number of
        batches.  Previous estimates for the same parameter names are
        overwritten.

        Parameters that receive no gradient (frozen, or not reached by the
        forward pass) keep a Fisher value of zero instead of raising.

        Side effects: the model is switched to eval mode while gradients are
        collected, its gradients are cleared afterwards, and it is left in
        train mode on return.

        Args:
            data_loader: Sized iterable yielding ``(inputs, labels)`` pairs.
            criterion: Callable ``(logits, labels) -> scalar loss``.
        """
        self.model.eval()

        # Initialize Fisher information and remember the reference weights.
        for name, param in self.model.named_parameters():
            self.fisher_information[name] = torch.zeros_like(param)
            self.params_old[name] = param.detach().clone()

        num_batches = len(data_loader)

        # Accumulate the squared gradients batch by batch.
        for inputs, labels in data_loader:
            self.model.zero_grad()
            outputs = self.model(inputs)
            loss = criterion(outputs.logits, labels)
            loss.backward()

            for name, param in self.model.named_parameters():
                # ``grad`` is None for parameters that did not take part in
                # this backward pass; they contribute nothing to the estimate.
                if param.grad is None:
                    continue
                self.fisher_information[name] += param.grad.detach().pow(2) / num_batches

        # Do not leak the last batch's gradients into a later optimizer step.
        self.model.zero_grad()
        self.model.train()

    def penalty(self):
        """Return the EWC penalty for the model's current parameters.

        The result is the sum over parameters of
        ``importance * fisher * (param - param_old) ** 2``.  Parameters
        without a stored Fisher estimate (for example when
        ``compute_fisher_information`` has not been called yet) are skipped,
        so the penalty is ``0`` in that case.
        """
        penalty = 0
        for name, param in self.model.named_parameters():
            fisher_val = self.fisher_information.get(name)
            if fisher_val is None:
                continue
            drift = param - self.params_old[name]
            penalty += (self.importance * fisher_val * drift.pow(2)).sum()
        return penalty

# Example usage
# Create a dummy dataset: random integers stand in for token ids.
input_data = torch.randint(0, 1000, (100, 10))  # 100 samples, sequence length 10
labels = torch.randint(0, 2, (100,))  # 100 binary labels

dataset = TensorDataset(input_data, labels)
data_loader = DataLoader(dataset, batch_size=10)

# Initialize model and tokenizer
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
# NOTE(review): the tokenizer is loaded but not used by the code visible here.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

ewc = EWC(model, importance=1.0)

# Compute Fisher Information (also snapshots the current parameters)
ewc.compute_fisher_information(data_loader, criterion)

# Use the penalty in the training loop.
# Put the model in train mode explicitly rather than relying on a side effect
# of the Fisher computation.
model.train()
# Distinct loop names keep the full `labels` tensor above from being
# overwritten by the last mini-batch.
for batch_inputs, batch_labels in data_loader:
    optimizer.zero_grad()
    logits = model(batch_inputs).logits
    loss = criterion(logits, batch_labels) + ewc.penalty()  # Add EWC penalty to the loss
    loss.backward()
    optimizer.step()
    print(f"Loss: {loss.item()}")