# In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist

# 1. Load MNIST and carve it into two 5-class tasks (digits 0–4 and digits 5–9)
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Convert the images to float32, divide by 255, and append a trailing channel
# axis so each sample matches the (28, 28, 1) input declared by the model.
x_train = (x_train.astype("float32") / 255.)[..., np.newaxis]
x_test = (x_test.astype("float32") / 255.)[..., np.newaxis]

# Task A: digits 0–4 (labels are used unchanged)
task_a_idx = y_train < 5
X_a = x_train[task_a_idx]
y_a = y_train[task_a_idx]

# Task B: digits 5–9, shifted down by 5 so the labels also run 0–4
task_b_idx = y_train >= 5
X_b = x_train[task_b_idx]
y_b = y_train[task_b_idx] - 5

# 2. Define a small CNN model
def build_model(num_classes, input_shape=(28, 28, 1)):
    """Build an uncompiled small CNN classifier.

    Architecture: Conv2D(16, 3x3, ReLU) -> MaxPooling2D -> Flatten ->
    Dense(64, ReLU) -> Dense(num_classes, softmax).

    Args:
        num_classes: Number of units in the softmax output layer.
        input_shape: Shape of a single input sample, without the batch axis.
            Defaults to ``(28, 28, 1)``.

    Returns:
        A ``Sequential`` model. It is not compiled; the caller chooses the
        optimizer, loss and metrics.
    """
    return models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(16, (3, 3), activation='relu'),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ])

# One 5-way classifier is shared by both tasks: it is fitted on Task A here
# and the same instance is fine-tuned on Task B further down.
model = build_model(num_classes=5)

# 3. Train on Task A
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
model.fit(X_a, y_a, batch_size=64, epochs=5, verbose=0)


# === Prepare test splits ===
# Same digit split as the training data: 0–4 for Task A, 5–9 for Task B.
task_a_test_idx = y_test < 5
X_a_test = x_test[task_a_test_idx]
y_a_test = y_test[task_a_test_idx]

task_b_test_idx = y_test >= 5
X_b_test = x_test[task_b_test_idx]
y_b_test = y_test[task_b_test_idx] - 5  # shift 5–9 down to 0–4

# === Evaluate after Task A training ===
# The model was compiled with a single 'accuracy' metric, so evaluate()
# returns [loss, accuracy]; index 1 is the accuracy.
print("\n✅ Evaluation after Task A training:")
task_a_metrics = model.evaluate(X_a_test, y_a_test, verbose=0)
acc_a = task_a_metrics[1]
task_b_metrics = model.evaluate(X_b_test, y_b_test, verbose=0)
acc_b = task_b_metrics[1]
print(f"Accuracy on Task A (0–4): {acc_a:.4f}")
print(f"Accuracy on Task B (5–9): {acc_b:.4f}  ← Expected to be low (not trained yet)")


# 4. Save weights and compute importance (approximated via gradients)
# Snapshot the trainable variables only, in the order of
# model.trainable_variables. The snapshot and the importance list are later
# zipped against model.trainable_variables, so all three lists must line up;
# model.get_weights() would also include any non-trainable weights and could
# shift that alignment.
weights_task_a = [var.numpy() for var in model.trainable_variables]

# Compute importance (Fisher-like approximation using gradients on task A data).
# One zero-initialised accumulator per trainable variable.
importance = [tf.zeros_like(var) for var in model.trainable_variables]

# Dataset of Task A mini-batches (64 examples each; the last may be smaller).
batch = tf.data.Dataset.from_tensor_slices((X_a, y_a)).batch(64)

num_batches = 0
for x_batch, y_batch in batch:
    with tf.GradientTape() as tape:
        # Inference-mode forward pass: importance is measured at the Task A solution.
        preds = model(x_batch, training=False)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_batch, preds)
    # `loss` holds one value per example. For a non-scalar target,
    # tape.gradient differentiates the sum of its elements, so each `grad`
    # is the gradient of the summed batch loss.
    grads = tape.gradient(loss, model.trainable_variables)
    for i, grad in enumerate(grads):
        if grad is not None:
            importance[i] += tf.square(grad)
    num_batches += 1

# Normalize importance: average the squared gradients over the batches seen.
# max(..., 1) avoids a division by zero when the dataset is empty.
importance = [imp / max(num_batches, 1) for imp in importance]

# 5. Fine-tune on Task B with EWC regularization
lambda_ewc = 1000.0  # importance regularization strength

optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

# Build the Task B input pipeline once; the same dataset object can be
# iterated again on every epoch, so there is no need to rebuild it per epoch.
task_b_batches = tf.data.Dataset.from_tensor_slices((X_b, y_b)).batch(64)

# Custom training loop with EWC
for epoch in range(5):
    print(f"Epoch {epoch + 1}")
    for x_batch, y_batch in task_b_batches:
        with tf.GradientTape() as tape:
            preds = model(x_batch, training=True)
            loss = loss_fn(y_batch, preds)
            # Add EWC regularization:
            #   (lambda / 2) * sum(importance * (weight - task_A_weight)^2)
            # accumulated per variable. The zip relies on weights_task_a and
            # importance being ordered like model.trainable_variables.
            for var, old_w, imp in zip(model.trainable_variables, weights_task_a, importance):
                loss += (lambda_ewc / 2) * tf.reduce_sum(imp * tf.square(var - old_w))
        # Differentiate the combined loss (Task B loss + EWC penalty).
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

# === Evaluate after Task B training with EWC ===
# Re-measure both tasks on their test splits; evaluate() returns
# [loss, accuracy], and index 1 is the accuracy.
print("\n✅ Evaluation after Task B training (EWC fine-tuning):")
task_a_metrics = model.evaluate(X_a_test, y_a_test, verbose=0)
acc_a = task_a_metrics[1]
task_b_metrics = model.evaluate(X_b_test, y_b_test, verbose=0)
acc_b = task_b_metrics[1]
print(f"Accuracy on Task A (0–4): {acc_a:.4f}  ← Should remain high if EWC works")
print(f"Accuracy on Task B (5–9): {acc_b:.4f}")
