In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense
import datetime
import pprint
import tqdm

# Load the TensorBoard notebook extension and start it on the `logs` directory,
# which is where the summary writers created later in this notebook write to.
%load_ext tensorboard
%tensorboard --logdir logs

# CIFAR-10 from TensorFlow Datasets: as_supervised=True yields (image, label)
# pairs; with_info=True additionally returns the dataset metadata as ds_info.
(train_ds, test_ds), ds_info = tfds.load('cifar10', split=['train', 'test'], as_supervised=True, with_info=True)

In [None]:
def prepare_cifar_data(cifar, batch_size=32, shuffle_buffer=1000, prefetch_buffer=20, num_classes=10):
  """Turn a CIFAR-10 dataset of (image, integer label) pairs into training batches.

  Pipeline: cast images to float32, rescale pixels with x / 128 - 1, one-hot
  encode the labels, cache, shuffle, batch and prefetch.

  Args:
    cifar: tf.data dataset yielding (image, label) pairs, e.g. from
      tfds.load(..., as_supervised=True).
    batch_size: number of examples per batch.
    shuffle_buffer: size of the shuffle buffer.
    prefetch_buffer: number of batches to prefetch.
    num_classes: depth of the one-hot label encoding.

  Returns:
    The transformed dataset.
  """
  def _preprocess(img, target):
    img = tf.cast(img, tf.float32)
    # Sloppy input normalization: 0 -> -1.0 and 255 -> ~0.992, i.e. roughly [-1, 1).
    img = img / 128. - 1.
    return img, tf.one_hot(target, depth=num_classes)

  cifar = cifar.map(_preprocess)
  # The preprocessing is deterministic, so cache it once. Shuffle only after
  # the cache so the example order is not frozen into it.
  cifar = cifar.cache()
  cifar = cifar.shuffle(shuffle_buffer)
  cifar = cifar.batch(batch_size)
  cifar = cifar.prefetch(prefetch_buffer)
  return cifar

# Run both splits through the same preprocessing pipeline.
train_ds, test_ds = (split.apply(prepare_cifar_data) for split in (train_ds, test_ds))


def try_model(model, ds, num_batches=5):
  """Smoke-test `model` by calling it on the first `num_batches` batches of `ds`.

  `ds` must yield (inputs, targets) pairs; the targets and the model outputs
  are discarded. For a Keras model, the first call also builds its weights.
  """
  for x, _ in ds.take(num_batches):
    model(x)

In [None]:
class CifarConv(tf.keras.Model):
    """Small convolutional classifier for CIFAR-10 with custom train/test steps.

    Layers: conv(24) -> conv(24) -> 2x2 max-pool -> conv(48) -> conv(48)
    -> global average pooling -> dense softmax over 10 classes. All convs
    use 3x3 kernels, 'same' padding and ReLU.

    The model owns its optimizer, loss and metrics, so `train_step` and
    `test_step` can be driven by a hand-written training loop. Targets are
    expected one-hot encoded (CategoricalCrossentropy on softmax outputs).
    """

    def __init__(self):
        super(CifarConv, self).__init__()

        self.convlayer1 = tf.keras.layers.Conv2D(filters=24, kernel_size=3, padding='same', activation='relu')
        self.convlayer2 = tf.keras.layers.Conv2D(filters=24, kernel_size=3, padding='same', activation='relu')
        self.pooling = tf.keras.layers.MaxPooling2D(pool_size=2, strides=2)

        self.convlayer3 = tf.keras.layers.Conv2D(filters=48, kernel_size=3, padding='same', activation='relu')
        self.convlayer4 = tf.keras.layers.Conv2D(filters=48, kernel_size=3, padding='same', activation='relu')
        self.global_pool = tf.keras.layers.GlobalAvgPool2D()

        self.out = tf.keras.layers.Dense(10, activation='softmax')

        self.optimizer = tf.keras.optimizers.Adam()
        self.loss_function = tf.keras.losses.CategoricalCrossentropy()
        # CategoricalAccuracy compares argmax(label) with argmax(prediction).
        # The plain `Accuracy` metric tests exact element-wise equality, which
        # softmax probabilities vs. one-hot labels essentially never satisfy,
        # so it reports ~0.0 regardless of how well the model does.
        self.metrics_list = [tf.keras.metrics.Mean(name="loss"),
                             tf.keras.metrics.CategoricalAccuracy(name="accuracy")]

    def call(self, x):
        """Forward pass: map a batch of images to class probabilities."""
        x = self.convlayer1(x)
        x = self.convlayer2(x)
        x = self.pooling(x)
        x = self.convlayer3(x)
        x = self.convlayer4(x)
        x = self.global_pool(x)
        return self.out(x)

    @property
    def metrics(self):
        """All metrics tracked by this model; the loss mean is always first."""
        return self.metrics_list

    def reset_metrics(self):
        """Reset the accumulated state of every tracked metric."""
        for metric in self.metrics:
            metric.reset_states()

    def _record_step_metrics(self, label, output, loss):
        """Update the loss mean and all label/prediction metrics; return their results."""
        self.metrics[0].update_state(loss)
        for metric in self.metrics[1:]:
            metric.update_state(label, output)
        return {m.name: m.result() for m in self.metrics}

    @tf.function
    def train_step(self, data):
        """Apply one optimizer step on an (images, one-hot labels) batch."""
        img, label = data

        with tf.GradientTape() as tape:
            output = self(img, training=True)
            loss = self.loss_function(label, output)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return self._record_step_metrics(label, output, loss)

    @tf.function
    def test_step(self, data):
        """Evaluate one (images, one-hot labels) batch without updating weights."""
        img, label = data

        output = self(img, training=False)
        loss = self.loss_function(label, output)

        return self._record_step_metrics(label, output, loss)


cifar_model = CifarConv()
# Call the model on a few batches; this builds its layers' weights.
try_model(cifar_model, train_ds)

In [None]:

# Every run gets its own timestamped directory under logs/<config_name>/, with
# separate train/ and test/ subdirectories so TensorBoard can display them apart.
config_name= "config_name"
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
run_log_dir = f"logs/{config_name}/{current_time}"

train_log_path = f"{run_log_dir}/train"
test_log_path = f"{run_log_dir}/test"

# One summary writer per phase: training metrics and validation metrics.
train_summary_writer, test_summary_writer = (
    tf.summary.create_file_writer(path) for path in (train_log_path, test_log_path)
)

In [None]:
def training_loop(model, train_ds, test_ds, epochs, train_summary_writer, test_summary_writer):
    for epoch in range(epochs):
        print(f"Epoch {epoch}:")

        # Training: 
        for data in tqdm.tqdm(train_ds, position=0, leave=True):
            #print(data)
            metrics = model.train_step(data)

            with train_summary_writer.as_default():
                for metric in model.metrics:
                    tf.summary.scalar(f"{metric.name}", metric.result(), step=epoch)
        
        # print the metrics
        print([f"train_{key}: {value.numpy()}" for (key, value) in metrics.items()])

        # reset all metrics
        model.reset_metrics()

        # Validation:
        for data in test_ds:
            metrics = model.test_step(data)

            # logging the validation metrics to the log file which is used by tensorboard
            with test_summary_writer.as_default():
              for metric in model.metrics:
                    tf.summary.scalar(f"{metric.name}", metric.result(), step=epoch)
        
        print([f"test_{key}: {value.numpy()}" for (key, value) in metrics.items()])

        # reset all metrics
        model.reset_metrics()
        print("\n")

In [None]:
# Checkpoint the current weights under a name derived from the run config.
weights_path = f"saved_model_{config_name}"
cifar_model.save_weights(weights_path, save_format="tf")

# To restore into a fresh model instead, instantiate a new CifarConv:
# loaded_model = CifarConv()
# loaded_model.load_weights(weights_path)

# Reload the checkpoint and continue training.
# NOTE(review): the TF checkpoint format also saves objects assigned as
# attributes (e.g. self.optimizer). Confirm whether the optimizer state is
# really lost before relying on that.
cifar_model.load_weights(weights_path)

training_loop(model=cifar_model,
              train_ds=train_ds,
              test_ds=test_ds,
              epochs=10,
              train_summary_writer=train_summary_writer,
              test_summary_writer=test_summary_writer)

# Save again so the checkpoint on disk holds the trained weights.
cifar_model.save_weights(weights_path, save_format="tf")

Epoch 0:


100%|██████████| 1563/1563 [00:37<00:00, 41.29it/s]


['train_loss: 1.6723462343215942', 'train_accuracy: 0.0']
['test_loss: 1.4104536771774292', 'test_accuracy: 0.0']


Epoch 1:


100%|██████████| 1563/1563 [00:30<00:00, 51.16it/s]


['train_loss: 1.3179869651794434', 'train_accuracy: 0.0']
['test_loss: 1.2165377140045166', 'test_accuracy: 0.0']


Epoch 2:


100%|██████████| 1563/1563 [00:40<00:00, 38.18it/s]


['train_loss: 1.1806749105453491', 'train_accuracy: 0.0']
['test_loss: 1.160409688949585', 'test_accuracy: 0.0']


Epoch 3:


100%|██████████| 1563/1563 [00:29<00:00, 52.56it/s]


['train_loss: 1.1001814603805542', 'train_accuracy: 0.0']
['test_loss: 1.0493645668029785', 'test_accuracy: 0.0']


Epoch 4:


100%|██████████| 1563/1563 [00:31<00:00, 49.62it/s]


['train_loss: 1.0393104553222656', 'train_accuracy: 0.0']
['test_loss: 1.018194317817688', 'test_accuracy: 0.0']


Epoch 5:


100%|██████████| 1563/1563 [00:30<00:00, 51.72it/s]


['train_loss: 0.9910697937011719', 'train_accuracy: 0.0']
['test_loss: 0.9671459197998047', 'test_accuracy: 0.0']


Epoch 6:


100%|██████████| 1563/1563 [00:40<00:00, 38.17it/s]


['train_loss: 0.9524663090705872', 'train_accuracy: 0.0']
['test_loss: 0.9557242393493652', 'test_accuracy: 0.0']


Epoch 7:


100%|██████████| 1563/1563 [00:29<00:00, 52.19it/s]


['train_loss: 0.9179278016090393', 'train_accuracy: 0.0']
['test_loss: 1.0109692811965942', 'test_accuracy: 0.0']


Epoch 8:


100%|██████████| 1563/1563 [00:31<00:00, 50.22it/s]


['train_loss: 0.8819910287857056', 'train_accuracy: 0.0']
['test_loss: 0.8788188695907593', 'test_accuracy: 0.0']


Epoch 9:


100%|██████████| 1563/1563 [00:29<00:00, 52.84it/s]


['train_loss: 0.8545975685119629', 'train_accuracy: 0.0']
['test_loss: 0.8962308764457703', 'test_accuracy: 0.0']


