# DeepART Drafting

## Load Modules

In [11]:
import tensorflow as tf
import tensorflow_datasets as tfds
print(tf.__version__)

2.6.0


## Load the data

In [12]:
# Fetch the MNIST train and test splits in supervised (image, label) form,
# together with the dataset metadata object that later cells query for the
# number of training examples.
(ds_train, ds_test), ds_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    with_info=True,
    as_supervised=True,
    shuffle_files=True,
)

def normalize_img(image, label):
    """Cast `image` to `float32` and divide it by 255; `label` is returned as-is.

    Returns:
        A `(scaled_image, label)` tuple, the shape `Dataset.map` expects for
        supervised pairs.
    """
    as_float = tf.cast(image, tf.float32)
    return as_float / 255., label

# Training pipeline: normalise, cache the normalised examples, reshuffle with
# a buffer covering the whole training split, then batch and prefetch.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)

# Evaluation pipeline: no shuffling, and caching happens after batching.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

## Train the model

In [13]:
# Baseline classifier: flatten -> 128-unit ReLU layer -> 10 raw class scores.
# The last layer has no activation, so the loss is told to expect logits.
model = tf.keras.Sequential(layers=[
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(units=128, activation='relu'),
    tf.keras.layers.Dense(units=10),
])

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Kept as the last expression so the cell still displays the History object.
model.fit(x=ds_train, validation_data=ds_test, epochs=6)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


<keras.callbacks.History at 0x27ba8357670>

## EWC + ART training

In [25]:
dim = 28  # side length of the square input image

# "bu" network: (dim, dim) image -> 128 ReLU units -> 10 raw scores.
model_bu = tf.keras.Sequential(layers=[
    tf.keras.layers.Flatten(input_shape=(dim, dim)),
    tf.keras.layers.Dense(units=128, activation='relu'),
    tf.keras.layers.Dense(units=10),
])
model_bu.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# "td" network: 10-vector -> 128 ReLU units -> dim*dim raw outputs.
# NOTE(review): this network emits dim*dim values yet is compiled with the
# same sparse categorical loss/accuracy as model_bu -- confirm that this is
# the intended objective for this path before training it.
model_td = tf.keras.Sequential(layers=[
    tf.keras.layers.Flatten(input_shape=(10,)),
    tf.keras.layers.Dense(units=128, activation='relu'),
    tf.keras.layers.Dense(units=dim * dim),
])
model_td.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In [21]:
# def compute_precision_matrices(model, task_set, num_batches=1, batch_size=32):
#   task_set = task_set.repeat()
#   precision_matrices = {n: tf.zeros_like(p.value()) for n, p in enumerate(model.trainable_variables)}

#   for i, (imgs, labels) in enumerate(task_set.take(num_batches)):
#     # We need gradients of model params
#     with tf.GradientTape() as tape:
#       # Get model predictions for each image
#       preds = model(imgs)
#       # Get the log likelihoods of the predictions
#       ll = tf.nn.log_softmax(preds)
#     # Attach gradients of ll to ll_grads
#     ll_grads = tape.gradient(ll, model.trainable_variables)
#     # Compute F_i as mean of gradients squared
#     for i, g in enumerate(ll_grads):
#       precision_matrices[i] += tf.math.reduce_mean(g ** 2, axis=0) / num_batches

#   return precision_matrices

# def compute_elastic_penalty(F, theta, theta_A, alpha=25):
#   penalty = 0
#   for i, theta_i in enumerate(theta):
#     _penalty = tf.math.reduce_sum(F[i] * (theta_i - theta_A[i]) ** 2)
#     penalty += _penalty
#   return 0.5*alpha*penalty

# def ewc_loss(labels, preds, model, F, theta_A):
#   loss_b = model.loss(labels, preds)
#   penalty = compute_elastic_penalty(F, model.trainable_variables, theta_A)
#   return loss_b + penalty

def l2_penalty(theta, theta_A):
    """Half the summed squared difference between two parameter sets.

    Args:
        theta: Iterable of current parameter tensors.
        theta_A: Reference parameters, looked up by the position of the
            matching entry in `theta` (so a list or an int-keyed dict works).

    Returns:
        0.5 * sum_i reduce_sum((theta[i] - theta_A[i]) ** 2).
    """
    squared_gaps = (
        tf.math.reduce_sum(tf.math.square(current - theta_A[position]))
        for position, current in enumerate(theta)
    )
    return 0.5 * sum(squared_gaps)

def train_art(model_bu, model_td, train, test, epochs=6):
    """Scaffold for jointly training the `bu` and `td` networks.

    This is still a draft. It snapshots the weights both networks have on
    entry, creates running accuracy/loss metrics for each network, and resets
    those metrics at the start of every epoch. No batch is consumed and no
    weight is updated yet; the per-batch step only exists as the commented
    sketch inside the epoch loop.

    Args:
        model_bu: Model exposing `trainable_variables` (presumably the
            bottom-up network -- confirm naming with the author).
        model_td: Model exposing `trainable_variables` (presumably the
            top-down network -- confirm naming with the author).
        train: Training data. Not consumed yet.
        test: Evaluation data. Not consumed yet.
        epochs: Number of epochs to iterate over.

    Returns:
        None.
    """
    # We'll only compute Fisher once, you can do it whenever
    # F = compute_precision_matrices(model, task_A_set, num_batches=1000)

    # Entry-time weight snapshots keyed by position in `trainable_variables`.
    # Unused for now; they are the anchors the planned penalty term needs.
    theta_bu = {n: p.value() for n, p in enumerate(model_bu.trainable_variables)}
    theta_td = {n: p.value() for n, p in enumerate(model_td.trainable_variables)}

    # Both networks in this notebook end in a linear Dense layer and are
    # compiled with from_logits=True, so the running cross-entropy metrics
    # must also treat predictions as logits (the metric defaults to
    # from_logits=False, which would report wrong loss values).
    accuracy_bu = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
    loss_bu = tf.keras.metrics.SparseCategoricalCrossentropy('loss', from_logits=True)

    # NOTE(review): categorical metrics are reused for the td network, whose
    # output is not a 10-way score vector -- confirm the intended metrics.
    accuracy_td = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
    loss_td = tf.keras.metrics.SparseCategoricalCrossentropy('loss', from_logits=True)

    for epoch in range(epochs):
        # Start every epoch with fresh running statistics.
        for metric in (accuracy_bu, loss_bu, accuracy_td, loss_td):
            metric.reset_states()

        # TODO(draft): per-batch update, not implemented yet. Earlier sketch:
        # for batch, (imgs, labels) in enumerate(train):
        #     with tf.GradientTape() as tape:
        #         preds = model_bu(imgs)
        #         total_loss = (model_bu.loss(labels, preds)
        #                       + l2_penalty(model_bu.trainable_variables, theta_bu))
        #     grads = tape.gradient(total_loss, model_bu.trainable_variables)
        #     model_bu.optimizer.apply_gradients(zip(grads, model_bu.trainable_variables))
        #
        #     accuracy_bu.update_state(labels, preds)
        #     loss_bu.update_state(labels, preds)
        #     print("\rEpoch: {}, Batch: {}, Loss: {:.3f}, Accuracy: {:.3f}".format(
        #         epoch+1, batch+1, loss_bu.result().numpy(), accuracy_bu.result().numpy()),
        #         flush=True, end='')
        # print("")

# Single-epoch run of the training scaffold on the MNIST pipelines.
train_art(model_bu, model_td, train=ds_train, test=ds_test, epochs=1)


In [14]:

# def l2_penalty(theta, theta_A):
#     penalty = 0
#     for i, theta_i in enumerate(theta):
#         _penalty = tf.math.reduce_sum((theta_i - theta_A[i]) ** 2)
#         penalty += _penalty
#     return 0.5*penalty

# def train_art(model, train, test, epochs=6):
#     #     theta_A = {n: p.value() for n, p in enumerate(model.trainable_variables.copy())}
#     accuracy = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
#     loss = tf.keras.metrics.SparseCategoricalCrossentropy('loss')
#     for epoch in range(epochs):
#         accuracy.reset_states()
#         loss.reset_states()
#         for batch, (imgs, labels) in enumerate(task_B_train):
#         with tf.GradientTape() as tape:
#             preds = model(imgs)
#             total_loss = model.loss(labels, preds) + l2_penalty(model.trainable_variables, theta_A)
#         grads = tape.gradient(total_loss, model.trainable_variables)
#         model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
        
#         accuracy.update_state(labels, preds)
#         loss.update_state(labels, preds)
#         print("\rEpoch: {}, Batch: {}, Loss: {:.3f}, Accuracy: {:.3f}".format(
#             epoch+1, batch+1, loss.result().numpy(), accuracy.result().numpy()), flush=True, end=''
#             )
#         print("")
  

# def train_with_l2(model, task_A_train, task_B_train, task_A_test, task_B_test, epochs=6):
#     # First we're going to fit to task A and retain a copy of parameters trained on Task A
#     model.fit(task_A_train, epochs=epochs)
#     theta_A = {n: p.value() for n, p in enumerate(model.trainable_variables.copy())}

#     print("Task A accuracy after training on Task A: {}".format(evaluate(model, task_A_test)))

#     accuracy = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
#     loss = tf.keras.metrics.SparseCategoricalCrossentropy('loss')

#     for epoch in range(epochs):
#         accuracy.reset_states()
#         loss.reset_states()
#         for batch, (imgs, labels) in enumerate(task_B_train):
#         with tf.GradientTape() as tape:
#             preds = model(imgs)
#             total_loss = model.loss(labels, preds) + l2_penalty(model.trainable_variables, theta_A)
#         grads = tape.gradient(total_loss, model.trainable_variables)
#         model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
        
#         accuracy.update_state(labels, preds)
#         loss.update_state(labels, preds)
#         print("\rEpoch: {}, Batch: {}, Loss: {:.3f}, Accuracy: {:.3f}".format(
#             epoch+1, batch+1, loss.result().numpy(), accuracy.result().numpy()), flush=True, end=''
#             )
#         print("")
  
#     print("Task B accuracy after training trained model on Task B: {}".format(evaluate(model, task_B_test)))
#     print("Task A accuracy after training trained model on Task B: {}".format(evaluate(model, task_A_test)))