<a href="https://colab.research.google.com/github/OneFineStarstuff/TheOneEverAfter/blob/main/Implementing_the_Elastic_Weight_Consolidation_(EWC)_loss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Placeholder for task-specific loss computation
def compute_task_loss(agent, task_data):
    """Return the cross-entropy loss of ``agent`` on one labelled batch.

    This is a placeholder objective; swap in the real task-specific loss.

    Args:
        agent: Callable model mapping ``inputs`` to unnormalised class scores.
        task_data: A 2-item ``(inputs, targets)`` pair, where ``targets``
            holds integer class indices accepted by ``nn.CrossEntropyLoss``.

    Returns:
        Scalar loss tensor (``nn.CrossEntropyLoss`` with its default
        mean reduction).
    """
    batch_inputs, batch_targets = task_data
    criterion = nn.CrossEntropyLoss()
    predictions = agent(batch_inputs)
    return criterion(predictions, batch_targets)

# EWC loss computation
def ewc_loss(agent, task_data, importance_weights, param_old, ewc_lambda=1.0):
    """Return the current-task loss plus the Elastic Weight Consolidation penalty.

    The penalty pulls each regularised parameter back towards its value
    from a previous task, weighted element-wise by its importance::

        loss = task_loss + ewc_lambda * sum_i F_i * (theta_i - theta_old_i) ** 2

    Args:
        agent: Module being trained; its ``named_parameters()`` are checked
            against ``importance_weights``.
        task_data: Batch forwarded unchanged to ``compute_task_loss``.
        importance_weights: Mapping from parameter name to a tensor that is
            multiplied element-wise with that parameter's squared drift
            (e.g. diagonal Fisher estimates). Parameters without an entry
            are not regularised.
        param_old: Mapping from parameter name to the anchored (previous
            task) values, e.g. as produced by ``save_old_params``.
        ewc_lambda: Overall penalty strength. The default ``1.0`` reproduces
            the unscaled penalty. The formulation in the EWC paper uses
            ``lambda / 2``, so pass ``ewc_lambda=lam / 2`` to match it.

    Returns:
        Scalar tensor: task loss plus the scaled penalty.

    Raises:
        KeyError: if a parameter has an importance weight but no entry in
            ``param_old``.
    """
    # Standard loss for the current task.
    loss = compute_task_loss(agent, task_data)

    # Quadratic penalty for drifting away from previous-task parameters.
    for name, param in agent.named_parameters():
        if name not in importance_weights:
            continue
        drift_sq = (param - param_old[name]).pow(2)
        penalty = (importance_weights[name] * drift_sq).sum()
        # Out-of-place add: an in-place `+=` on the task loss can break
        # backward if the user-supplied task loss is saved by autograd.
        loss = loss + ewc_lambda * penalty

    return loss

# Function to save old parameters
def save_old_params(agent):
    """Snapshot every named parameter of ``agent``.

    Returns:
        Dict mapping each parameter name (in ``named_parameters()`` order)
        to an independent copy of its current values. The copies are
        detached from the autograd graph, so later training steps do not
        change them and no gradient flows into them.
    """
    snapshot = {}
    for param_name, tensor in agent.named_parameters():
        snapshot[param_name] = tensor.detach().clone()
    return snapshot

# Example model definition
class SimpleModel(nn.Module):
    """Minimal model: a single fully connected layer producing raw scores."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The attribute name `fc` fixes the parameter names ("fc.weight",
        # "fc.bias") that key the EWC snapshot/importance dictionaries.
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the linear layer; no activation is applied to the output."""
        logits = self.fc(x)
        return logits

# --- Demo: build a toy model, a synthetic batch, and evaluate the EWC loss ---
input_dim, output_dim = 10, 2
model = SimpleModel(input_dim, output_dim)

# One random batch of 32 samples with integer labels in [0, output_dim).
# Drawn after the model is built, so the RNG consumption order is unchanged.
inputs = torch.randn(32, input_dim)
targets = torch.randint(0, output_dim, (32,))
task_data = inputs, targets

# Anchor: the current weights stand in for the "previous task" solution.
param_old = save_old_params(model)

# Placeholder importance: a uniform weight of 1 for every parameter entry.
# Replace with real estimates (e.g. diagonal Fisher information).
importance_weights = {
    pname: torch.ones_like(ptensor) for pname, ptensor in model.named_parameters()
}

# The model has not changed since the snapshot, so the penalty term is
# exactly zero here and the printed value equals the plain task loss.
loss = ewc_loss(model, task_data, importance_weights, param_old)

print(f'EWC Loss: {loss.item()}')