In [5]:
import tensorflow as tf
import tensorflow_datasets as tfds
import copy

In [6]:
(mnist_train, mnist_test), ds_info = tfds.load('mnist', split=['train', 'test'], as_supervised=True, with_info=True)

def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def transform_labels(image, label):
  return image, tf.math.floor(label / 2)

def prepare(ds, shuffle=True, batch_size=32, prefetch=True):
  ds = ds.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.map(transform_labels, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.shuffle(ds_info.splits['train'].num_examples) if shuffle else ds
  ds = ds.cache()
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.experimental.AUTOTUNE) if prefetch else ds
  return ds

def split_tasks(ds, predicate):
  return ds.filter(predicate), ds.filter(lambda img, label: not predicate(img, label))

multi_task_train, multi_task_test = prepare(mnist_train), prepare(mnist_test)
task_A_train, task_B_train = split_tasks(mnist_train, lambda img, label: label % 2 == 0)
task_A_train, task_B_train = prepare(task_A_train), prepare(task_B_train)
task_A_test, task_B_test = split_tasks(mnist_test, lambda img, label: label % 2 == 0)
task_A_test, task_B_test = prepare(task_A_test), prepare(task_B_test)

2022-03-29 11:50:48.927309: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-03-29 11:50:48.927802: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Metal device set to: Apple M1

systemMemory: 8.00 GB
maxCacheSize: 2.67 GB



In [7]:
def evaluate(model, test_set):
  acc = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')
  for i, (imgs, labels) in enumerate(test_set):
    preds = model.predict_on_batch(imgs)
    acc.update_state(labels, preds)
  return acc.result().numpy()

In [64]:
class Model():

    def __init__(self):

        input= tf.keras.Input(shape=(28, 28, 1))
        l1=tf.keras.layers.Dense(128, activation='relu')(input)
        self.feature_extractor=tf.keras.Model(input,l1)

        self.fc=tf.keras.layers.Dense(1)

        self.model=tf.keras.Model(input,self.fc(self.feature_extractor(input)))

        self.model.build(input_shape=(None, 28, 28, 1))

        # self.model= tf.keras.Model(input,self.fc)


        self.n_classes=1
        self.classes=[]

    def forward(self,x):
        return self.model.predict(x)

    def fit_existing_model(self,dataset, epochs=10):
        self.model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])

        self.model.fit(dataset, epochs=epochs)

    def update_model(self,dataset):
        new_classes=[]
        for i, (imgs, labels) in enumerate(dataset):
            for label in labels:
                if label not in new_classes:
                    new_classes.append(label.numpy())
        new_n=len(new_classes)
        new_size=self.n_classes+new_n


        prev_model=copy.deepcopy(self)
        weights_old=self.fc.get_weights()[0]

        new_layer=tf.keras.layers.Dense(new_size,kernel_initializer='random_normal',bias_initializer='zeros')
        new_layer.build(input_shape=self.feature_extractor.output_shape[1:])
        # print(new_layer.get_weights())
        weights_new=new_layer.get_weights()[0]
        print(weights_new.shape)
        print(weights_old.shape)
        weights_new[:,:self.n_classes]=weights_old
        weights_new_final=[weights_new,new_layer.get_weights()[1]]
        new_layer.set_weights(weights_new_final)

        accuracy = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
        loss_cls = tf.keras.metrics.SparseCategoricalCrossentropy('loss')
        loss_dist = tf.keras.metrics.SparseCategoricalCrossentropy('loss')

        self.model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])

        for epoch in range(15):

            accuracy.reset_states()
            loss_cls.reset_states()
            loss_dist.reset_states()
            for i, (imgs, labels) in enumerate(dataset):

                with tf.GradientTape() as tape:
                    logits = self.model.predict(imgs)
                    cls_loss = self.model.loss(labels, logits)
                          
                    # dist_target= prev_model.forward(imgs)
                    # logits_dist=logits[:,:-len(new_classes)]
                    # dist_loss=loss_dist(dist_target,logits_dist)

                    loss_value = cls_loss #+ dist_loss
                grads=tape.gradient(loss_value, self.model.trainable_variables)
                self.model.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

        self.n_classes=new_size


In [65]:
model=Model()

model.update_model(task_A_train)

(128, 6)
(128, 1)


LookupError: No gradient defined for operation 'IteratorGetNext' (op type: IteratorGetNext)

In [23]:
for i, (imgs, labels) in enumerate(task_A_test):
    print((imgs.numpy()).shape)

(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28, 28, 1)
(32, 28,