In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
(mnist_train, mnist_test), ds_info = tfds.load('mnist', split=['train', 'test'], as_supervised=True, with_info=True)

[1mDownloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to C:\Users\Sasha\tensorflow_datasets\mnist\3.0.1...[0m


Dl Completed...: 0 url [00:00, ? url/s]
Dl Completed...:   0%|          | 0/1 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/2 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/3 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.15 url/s]
Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.15 url/s]
Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.15 url/s]
Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.15 url/s]
Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.15 url/s]
Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  5.15 url/s]
[A
Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  5.15 url/s]
Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  5.15 url/s]
Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  5.15 url/s]
Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  5.15 url/s]
Dl Completed...:  50%

[1mDataset mnist downloaded and prepared to C:\Users\Sasha\tensorflow_datasets\mnist\3.0.1. Subsequent calls will reuse this data.[0m




In [3]:
def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def transform_labels(image, label):
  return image, tf.math.floor(label / 2)

def prepare(ds, shuffle=True, batch_size=32, prefetch=True):
  ds = ds.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.map(transform_labels, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.shuffle(ds_info.splits['train'].num_examples) if shuffle else ds
  ds = ds.cache()
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.experimental.AUTOTUNE) if prefetch else ds
  return ds

def split_tasks(ds, predicate):
  return ds.filter(predicate), ds.filter(lambda img, label: not predicate(img, label))

multi_task_train, multi_task_test = prepare(mnist_train), prepare(mnist_test)
task_A_train, task_B_train = split_tasks(mnist_train, lambda img, label: label % 2 == 0)
task_A_train, task_B_train = prepare(task_A_train), prepare(task_B_train)
task_A_test, task_B_test = split_tasks(mnist_test, lambda img, label: label % 2 == 0)
task_A_test, task_B_test = prepare(task_A_test), prepare(task_B_test)

In [5]:
def evaluate(model, test_set):
  acc = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')
  for i, (imgs, labels) in enumerate(test_set):
    preds = model.predict_on_batch(imgs)
    acc.update_state(labels, preds)
  return acc.result().numpy()

In [6]:
multi_task_model = tf.keras.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(5)
])

multi_task_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics='accuracy')

multi_task_model.fit(multi_task_train, epochs=6)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


<keras.callbacks.History at 0x256c51aba90>

In [7]:
print("Task A accuracy after training on Multi-Task Problem: {}".format(evaluate(multi_task_model, task_A_test)))
print("Task B accuracy after training on Multi-Task Problem: {}".format(evaluate(multi_task_model, task_B_test)))

Task A accuracy after training on Multi-Task Problem: 0.9748274683952332
Task B accuracy after training on Multi-Task Problem: 0.9755616784095764


In [None]:
basic_cl_model = tf.keras.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(5)
])

basic_cl_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics='accuracy')

basic_cl_model.fit(task_A_train, epochs=6)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


<keras.callbacks.History at 0x25455f00d60>

In [9]:
print("Task A accuracy after training model on only Task A: {}".format(evaluate(basic_cl_model, task_A_test)))

Task A accuracy after training model on only Task A: 0.984774649143219


In [10]:
basic_cl_model.fit(task_B_train, epochs=6)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


<keras.callbacks.History at 0x253cd4f8fd0>

In [12]:
print("Task B accuracy after training trained model on Task B: {}".format(evaluate(basic_cl_model, task_B_test)))
print("Task A accuracy after training trained model on Task B: {}".format(evaluate(basic_cl_model, task_A_test)))

Task B accuracy after training trained model on Task B: 0.9848245978355408
Task A accuracy after training trained model on Task B: 0.22107186913490295


In [13]:
def l2_penalty(theta, theta_A):
  penalty = 0
  for i, theta_i in enumerate(theta):
    _penalty = tf.math.reduce_sum((theta_i - theta_A[i]) ** 2)
    penalty += _penalty
  return 0.5*penalty

def train_with_l2(model, task_A_train, task_B_train, task_A_test, task_B_test, epochs=6):
  # First we're going to fit to task A and retain a copy of parameters trained on Task A
  model.fit(task_A_train, epochs=epochs)
  theta_A = {n: p.value() for n, p in enumerate(model.trainable_variables.copy())}

  print("Task A accuracy after training on Task A: {}".format(evaluate(model, task_A_test)))

  accuracy = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
  loss = tf.keras.metrics.SparseCategoricalCrossentropy('loss')

  for epoch in range(epochs):
    accuracy.reset_states()
    loss.reset_states()
    for batch, (imgs, labels) in enumerate(task_B_train):
      with tf.GradientTape() as tape:
        preds = model(imgs)
        total_loss = model.loss(labels, preds) + l2_penalty(model.trainable_variables, theta_A)
      grads = tape.gradient(total_loss, model.trainable_variables)
      model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
      
      accuracy.update_state(labels, preds)
      loss.update_state(labels, preds)
      print("\rEpoch: {}, Batch: {}, Loss: {:.3f}, Accuracy: {:.3f}".format(
          epoch+1, batch+1, loss.result().numpy(), accuracy.result().numpy()), flush=True, end=''
         )
    print("")
  
  print("Task B accuracy after training trained model on Task B: {}".format(evaluate(model, task_B_test)))
  print("Task A accuracy after training trained model on Task B: {}".format(evaluate(model, task_A_test)))

In [14]:
l2_model = tf.keras.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(5)
])

l2_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics='accuracy')

train_with_l2(l2_model, task_A_train, task_B_train, task_A_test, task_B_test)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6
Task A accuracy after training on Task A: 0.9872106909751892
Epoch: 1, Batch: 954, Loss: 4.584, Accuracy: 0.530
Epoch: 2, Batch: 954, Loss: 4.460, Accuracy: 0.537
Epoch: 3, Batch: 954, Loss: 4.453, Accuracy: 0.537
Epoch: 4, Batch: 954, Loss: 4.453, Accuracy: 0.537
Epoch: 5, Batch: 954, Loss: 4.453, Accuracy: 0.537
Epoch: 6, Batch: 954, Loss: 4.450, Accuracy: 0.537
Task B accuracy after training trained model on Task B: 0.536657452583313
Task A accuracy after training trained model on Task B: 0.9332115054130554


In [16]:
def compute_precision_matrices(model, task_set, num_batches=1, batch_size=32):
  task_set = task_set.repeat()
  precision_matrices = {n: tf.zeros_like(p.value()) for n, p in enumerate(model.trainable_variables)}

  for i, (imgs, labels) in enumerate(task_set.take(num_batches)):
    # We need gradients of model params
    with tf.GradientTape() as tape:
      # Get model predictions for each image
      preds = model(imgs)
      # Get the log likelihoods of the predictions
      ll = tf.nn.log_softmax(preds)
    # Attach gradients of ll to ll_grads
    ll_grads = tape.gradient(ll, model.trainable_variables)
    # Compute F_i as mean of gradients squared
    for i, g in enumerate(ll_grads):
      precision_matrices[i] += tf.math.reduce_mean(g ** 2, axis=0) / num_batches

  return precision_matrices

def compute_elastic_penalty(F, theta, theta_A, alpha=25):
  penalty = 0
  for i, theta_i in enumerate(theta):
    _penalty = tf.math.reduce_sum(F[i] * (theta_i - theta_A[i]) ** 2)
    penalty += _penalty
  return 0.5*alpha*penalty

def ewc_loss(labels, preds, model, F, theta_A):
  loss_b = model.loss(labels, preds)
  penalty = compute_elastic_penalty(F, model.trainable_variables, theta_A)
  return loss_b + penalty

def train_with_ewc(model, task_A_set, task_B_set, task_A_test, task_B_test, epochs=3):
  # First we're going to fit to task A and retain a copy of parameters trained on Task A
  model.fit(task_A_set, epochs=epochs)
  theta_A = {n: p.value() for n, p in enumerate(model.trainable_variables.copy())}
  # We'll only compute Fisher once, you can do it whenever
  F = compute_precision_matrices(model, task_A_set, num_batches=1000)

  print("Task A accuracy after training on Task A: {}".format(evaluate(model, task_A_test)))

  # Now we set up the training loop for task B with EWC
  accuracy = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
  loss = tf.keras.metrics.SparseCategoricalCrossentropy('loss')

  for epoch in range(epochs*3):
    accuracy.reset_states()
    loss.reset_states()

    for batch, (imgs, labels) in enumerate(task_B_set):
      with tf.GradientTape() as tape:
        # Make the predictions
        preds = model(imgs)
        # Compute EWC loss
        total_loss = ewc_loss(labels, preds, model, F, theta_A)
      # Compute the gradients of model's trainable parameters wrt total loss
      grads = tape.gradient(total_loss, model.trainable_variables)
      # Update the model with gradients
      model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
      # Report updated loss and accuracy
      accuracy.update_state(labels, preds)
      loss.update_state(labels, preds)
      print("\rEpoch: {}, Batch: {}, Loss: {:.3f}, Accuracy: {:.3f}".format(
          epoch+1, batch+1, loss.result().numpy(), accuracy.result().numpy()), flush=True, end=''
         )
    print("")

  print("Task B accuracy after training trained model on Task B: {}".format(evaluate(model, task_B_test)))
  print("Task A accuracy after training trained model on Task B: {}".format(evaluate(model, task_A_test)))

In [17]:
ewc_model = tf.keras.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.5),
  tf.keras.layers.Dense(5)
])

ewc_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics='accuracy')

train_with_ewc(ewc_model, task_A_train, task_B_train, task_A_test, task_B_test)

Epoch 1/3
Epoch 2/3
Epoch 3/3
Task A accuracy after training on Task A: 0.9825416207313538
Epoch: 1, Batch: 954, Loss: 9.363, Accuracy: 0.249
Epoch: 2, Batch: 954, Loss: 8.059, Accuracy: 0.391
Epoch: 3, Batch: 954, Loss: 7.881, Accuracy: 0.427
Epoch: 4, Batch: 954, Loss: 7.817, Accuracy: 0.443
Epoch: 5, Batch: 954, Loss: 7.780, Accuracy: 0.453
Epoch: 6, Batch: 954, Loss: 7.749, Accuracy: 0.460
Epoch: 7, Batch: 954, Loss: 7.733, Accuracy: 0.465
Epoch: 8, Batch: 954, Loss: 7.723, Accuracy: 0.469
Epoch: 9, Batch: 954, Loss: 7.712, Accuracy: 0.472
Task B accuracy after training trained model on Task B: 0.47516751289367676
Task A accuracy after training trained model on Task B: 0.5783597230911255
