In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
(mnist_train, mnist_test), ds_info = tfds.load('mnist', split=['train', 'test'], as_supervised=True, with_info=True)

In [3]:
def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def transform_labels(image, label):
  return image, tf.math.floor(label / 2)

def prepare(ds, shuffle=True, batch_size=32, prefetch=True):
  ds = ds.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.map(transform_labels, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.shuffle(ds_info.splits['train'].num_examples) if shuffle else ds
  ds = ds.cache()
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.experimental.AUTOTUNE) if prefetch else ds
  return ds

def split_tasks(ds, predicate):
  return ds.filter(predicate), ds.filter(lambda img, label: not predicate(img, label))

multi_task_train, multi_task_test = prepare(mnist_train), prepare(mnist_test)
task_A_train, task_B_train = split_tasks(mnist_train, lambda img, label: label % 2 == 0)
task_A_train, task_B_train = prepare(task_A_train), prepare(task_B_train)
task_A_test, task_B_test = split_tasks(mnist_test, lambda img, label: label % 2 == 0)
task_A_test, task_B_test = prepare(task_A_test), prepare(task_B_test)

In [4]:
def evaluate(model, test_set):
  acc = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')
  for i, (imgs, labels) in enumerate(test_set):
    preds = model.predict_on_batch(imgs)
    acc.update_state(labels, preds)
  return acc.result().numpy()

In [5]:
multi_task_model = tf.keras.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(5)
])

multi_task_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics='accuracy')

multi_task_model.fit(multi_task_train, epochs=6)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


<keras.callbacks.History at 0x2a721960be0>

In [6]:
print("Task A accuracy after training on Multi-Task Problem: {}".format(evaluate(multi_task_model, task_A_test)))
print("Task B accuracy after training on Multi-Task Problem: {}".format(evaluate(multi_task_model, task_B_test)))

Task A accuracy after training on Multi-Task Problem: 0.9758424758911133
Task B accuracy after training on Multi-Task Problem: 0.9737879633903503


In [7]:
basic_cl_model = tf.keras.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(5)
])

basic_cl_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics='accuracy')

basic_cl_model.fit(task_A_train, epochs=6)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


<keras.callbacks.History at 0x2a7267ce820>

In [8]:
print("Task A accuracy after training model on only Task A: {}".format(evaluate(basic_cl_model, task_A_test)))

Task A accuracy after training model on only Task A: 0.9861956834793091


In [9]:
basic_cl_model.fit(task_B_train, epochs=6)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


<keras.callbacks.History at 0x2a61389d880>

In [10]:
print("Task B accuracy after training trained model on Task B: {}".format(evaluate(basic_cl_model, task_B_test)))
print("Task A accuracy after training trained model on Task B: {}".format(evaluate(basic_cl_model, task_A_test)))

Task B accuracy after training trained model on Task B: 0.9838391542434692
Task A accuracy after training trained model on Task B: 0.2771010994911194


In [11]:
def l2_penalty(theta, theta_A):
  penalty = 0
  for i, theta_i in enumerate(theta):
    _penalty = tf.math.reduce_sum((theta_i - theta_A[i]) ** 2)
    penalty += _penalty
  return 0.5*penalty

def train_with_l2(model, task_A_train, task_B_train, task_A_test, task_B_test, epochs=6):
  # First we're going to fit to task A and retain a copy of parameters trained on Task A
  model.fit(task_A_train, epochs=epochs)
  theta_A = {n: p.value() for n, p in enumerate(model.trainable_variables.copy())}

  print("Task A accuracy after training on Task A: {}".format(evaluate(model, task_A_test)))

  accuracy = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
  loss = tf.keras.metrics.SparseCategoricalCrossentropy('loss')

  for epoch in range(epochs):
    accuracy.reset_states()
    loss.reset_states()
    for batch, (imgs, labels) in enumerate(task_B_train):
      with tf.GradientTape() as tape:
        preds = model(imgs)
        total_loss = model.loss(labels, preds) + l2_penalty(model.trainable_variables, theta_A)
      grads = tape.gradient(total_loss, model.trainable_variables)
      model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
      
      accuracy.update_state(labels, preds)
      loss.update_state(labels, preds)
      print("\rEpoch: {}, Batch: {}, Loss: {:.3f}, Accuracy: {:.3f}".format(
          epoch+1, batch+1, loss.result().numpy(), accuracy.result().numpy()), flush=True, end=''
         )
    print("")
  
  print("Task B accuracy after training trained model on Task B: {}".format(evaluate(model, task_B_test)))
  print("Task A accuracy after training trained model on Task B: {}".format(evaluate(model, task_A_test)))

In [12]:
l2_model = tf.keras.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(5)
])

l2_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics='accuracy')

train_with_l2(l2_model, task_A_train, task_B_train, task_A_test, task_B_test)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6
Task A accuracy after training on Task A: 0.986601710319519
Epoch: 1, Batch: 954, Loss: 4.513, Accuracy: 0.538
Epoch: 2, Batch: 954, Loss: 4.404, Accuracy: 0.546
Epoch: 3, Batch: 954, Loss: 4.398, Accuracy: 0.546
Epoch: 4, Batch: 954, Loss: 4.396, Accuracy: 0.545
Epoch: 5, Batch: 954, Loss: 4.396, Accuracy: 0.546
Epoch: 6, Batch: 954, Loss: 4.395, Accuracy: 0.546
Task B accuracy after training trained model on Task B: 0.5650374293327332
Task A accuracy after training trained model on Task B: 0.8828664422035217


In [7]:
def compute_precision_matrices(model, task_set, num_batches=1, batch_size=32):
  task_set = task_set.repeat()
  precision_matrices = {n: tf.zeros_like(p.value()) for n, p in enumerate(model.trainable_variables)}

  for i, (imgs, labels) in enumerate(task_set.take(num_batches)):
    # We need gradients of model params
    with tf.GradientTape() as tape:
      # Get model predictions for each image
      preds = model(imgs)
      # Get the log likelihoods of the predictions
      ll = tf.nn.log_softmax(preds)
    # Attach gradients of ll to ll_grads
    ll_grads = tape.gradient(ll, model.trainable_variables)
    # Compute F_i as mean of gradients squared
    for i, g in enumerate(ll_grads):
      precision_matrices[i] += tf.math.reduce_mean(g ** 2, axis=0) / num_batches

  return precision_matrices

def compute_elastic_penalty(F, theta, theta_A, alpha=25):
  penalty = 0
  for i, theta_i in enumerate(theta):
    _penalty = tf.math.reduce_sum(F[i] * (theta_i - theta_A[i]) ** 2)
    penalty += _penalty
  return 0.5*alpha*penalty

def ewc_loss(labels, preds, model, F, theta_A):
  loss_b = model.loss(labels, preds)
  penalty = compute_elastic_penalty(F, model.trainable_variables, theta_A)
  return loss_b + penalty

def train_with_ewc(model, task_A_set, task_B_set, task_A_test, task_B_test, epochs=3):
  # First we're going to fit to task A and retain a copy of parameters trained on Task A
  model.fit(task_A_set, epochs=epochs)
  theta_A = {n: p.value() for n, p in enumerate(model.trainable_variables.copy())}
  # We'll only compute Fisher once, you can do it whenever
  F = compute_precision_matrices(model, task_A_set, num_batches=1000)

  print("Task A accuracy after training on Task A: {}".format(evaluate(model, task_A_test)))

  # Now we set up the training loop for task B with EWC
  accuracy = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
  loss = tf.keras.metrics.SparseCategoricalCrossentropy('loss')

  for epoch in range(epochs*3):
    accuracy.reset_states()
    loss.reset_states()

    for batch, (imgs, labels) in enumerate(task_B_set):
      with tf.GradientTape() as tape:
        # Make the predictions
        preds = model(imgs)
        # Compute EWC loss
        total_loss = ewc_loss(labels, preds, model, F, theta_A)
      # Compute the gradients of model's trainable parameters wrt total loss
      grads = tape.gradient(total_loss, model.trainable_variables)
      # Update the model with gradients
      model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
      # Report updated loss and accuracy
      accuracy.update_state(labels, preds)
      loss.update_state(labels, preds)
      print("\rEpoch: {}, Batch: {}, Loss: {:.3f}, Accuracy: {:.3f}".format(
          epoch+1, batch+1, loss.result().numpy(), accuracy.result().numpy()), flush=True, end=''
         )
    print("")

  print("Task B accuracy after training trained model on Task B: {}".format(evaluate(model, task_B_test)))
  print("Task A accuracy after training trained model on Task B: {}".format(evaluate(model, task_A_test)))

In [8]:
ewc_model = tf.keras.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.5),
  tf.keras.layers.Dense(5)
])

ewc_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics='accuracy')

train_with_ewc(ewc_model, task_A_train, task_B_train, task_A_test, task_B_test)

Epoch 1/3
Epoch 2/3
Epoch 3/3
Task A accuracy after training on Task A: 0.9815266132354736
Epoch: 1, Batch: 954, Loss: 10.589, Accuracy: 0.062
Epoch: 2, Batch: 954, Loss: 10.059, Accuracy: 0.131
Epoch: 3, Batch: 954, Loss: 10.034, Accuracy: 0.165
Epoch: 4, Batch: 954, Loss: 10.029, Accuracy: 0.177
Epoch: 5, Batch: 954, Loss: 10.022, Accuracy: 0.185
Epoch: 6, Batch: 954, Loss: 10.019, Accuracy: 0.189
Epoch: 7, Batch: 954, Loss: 10.016, Accuracy: 0.191
Epoch: 8, Batch: 954, Loss: 10.016, Accuracy: 0.193
Epoch: 9, Batch: 954, Loss: 10.012, Accuracy: 0.194
Task B accuracy after training trained model on Task B: 0.20555774867534637
Task A accuracy after training trained model on Task B: 0.9593991041183472
