# Implementation: Elastic Weight Consolidation (Conceptual)

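**Goal**: Add a quadratic penalty to the loss that discourages changing weights that were important for a previous task (Task A) while training on a new one (Task B).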

In [None]:
import torch
import torch.nn as nn

# 1. Mock weights after Task A
old_weights = torch.tensor([1.0, 2.0, 3.0], requires_grad=False)

# 2. Importance of each weight (in EWC, an estimate of the diagonal of the Fisher information matrix)
# Suppose weight 2 was very important for Task A, so it gets a high importance value
importance = torch.tensor([0.1, 10.0, 0.1], requires_grad=False)

# 3. Current training for Task B
new_weights = torch.tensor([1.2, 2.5, 5.0], requires_grad=True)

def ewc_loss(new, old, importance, lambda_reg=100):
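    # Quadratic EWC penalty: lambda_reg * sum( importance * (new - old)^2 )
    # (the original formulation uses lambda / 2; that factor is absorbed into lambda_reg here)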
    penalty = torch.sum(importance * (new - old)**2)
    return lambda_reg * penalty

loss = ewc_loss(new_weights, old_weights, importance)
print(f"EWC Penalty: {loss.item()}")

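# Logic (per-weight cost before multiplying by lambda_reg):
# Changing weight 1 from 1.0 -> 1.2 (diff 0.2) costs 0.1 * 0.04 = 0.004
# Changing weight 2 from 2.0 -> 2.5 (diff 0.5) costs 10.0 * 0.25 = 2.5 (by far the largest term)
# Changing weight 3 from 3.0 -> 5.0 (diff 2.0) costs 0.1 * 4.0 = 0.4
# Sum = 2.904, so with lambda_reg=100 the printed penalty is about 290.4.
# The penalty pulls weight 2 back toward 2.0 much more strongly than it pulls the other weights.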
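
The penalty affects training through its gradient, so it can help to look at that directly. The following is a minimal sketch, reusing the mock values from the cell above, that compares the autograd gradient of the penalty with the analytic form `2 * lambda_reg * importance * (new - old)`. The names `per_weight` and `expected_grad` are introduced here for illustration only.

```python
import torch

# Same mock values as in the cell above.
old_weights = torch.tensor([1.0, 2.0, 3.0])
importance = torch.tensor([0.1, 10.0, 0.1])
new_weights = torch.tensor([1.2, 2.5, 5.0], requires_grad=True)
lambda_reg = 100.0

# Per-weight contributions before scaling by lambda_reg.
per_weight = importance * (new_weights - old_weights) ** 2
penalty = lambda_reg * per_weight.sum()

# Autograd gradient of the penalty with respect to the current weights.
penalty.backward()

# Analytic gradient of the quadratic penalty, for comparison.
expected_grad = 2 * lambda_reg * importance * (new_weights.detach() - old_weights)

print("per-weight penalty:", per_weight.detach().tolist())
print("total penalty:", penalty.item())
print("autograd gradient:", new_weights.grad.tolist())
print("analytic gradient:", expected_grad.tolist())
```

With these numbers the gradient on weight 2 is roughly 1000, against roughly 4 and 40 for weights 1 and 3, which is why a gradient step moves weight 2 back toward its Task A value far more aggressively than the others.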

## Conclusion
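EWC lets the model shift "unimportant" weights (like w3) to learn Task B while penalizing changes to "important" weights (like w2) to preserve Task A. In the example, w3 moves four times as far as w2 yet incurs a much smaller penalty (0.4 versus 2.5 before scaling by `lambda_reg`).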
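
To see this trade-off play out during optimization, here is a small sketch under strong simplifying assumptions: Task B is represented by a hypothetical quadratic loss whose optimum is `task_b_target`, and training is plain gradient descent. The function `train`, the target values, the learning rate, and the choice `lambda_reg=1.0` are all illustrative and not part of the example above.

```python
import torch

old_weights = torch.tensor([1.0, 2.0, 3.0])    # weights after Task A
importance = torch.tensor([0.1, 10.0, 0.1])    # mock importance values

# Hypothetical Task B optimum: the weights Task B alone would prefer.
task_b_target = torch.tensor([1.2, 2.5, 5.0])

def train(lambda_reg, steps=500, lr=0.02):
    """Gradient descent on (toy Task B loss + EWC penalty), starting from the Task A weights."""
    w = old_weights.clone().requires_grad_(True)
    for _ in range(steps):
        task_b_loss = torch.sum((w - task_b_target) ** 2)
        penalty = lambda_reg * torch.sum(importance * (w - old_weights) ** 2)
        (grad,) = torch.autograd.grad(task_b_loss + penalty, w)
        with torch.no_grad():
            w -= lr * grad
    return w.detach()

without_ewc = train(lambda_reg=0.0)
with_ewc = train(lambda_reg=1.0)

print("without EWC:", without_ewc.tolist())
print("with EWC:   ", with_ewc.tolist())
print("shift from Task A weights (with EWC):", (with_ewc - old_weights).tolist())
```

In this toy setting, training without the penalty simply reaches the Task B target. With the penalty, w2 ends up near 2.05 (barely moved from 2.0), while w3 still travels most of the way to 5.0, ending near 4.82.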