In [7]:
import tensorflow as tf
import tensorflow_datasets as tfds
import datetime
import tqdm

%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [8]:
def get_cifar10(batch_size, shuffle_buffer=4096):
    """
    Load CIFAR-10 from TensorFlow Datasets and build train/validation pipelines.

    Each example is turned into an (image, label) pair where the image is cast
    to float32 and divided by 255, and the label is one-hot encoded over 10
    classes (float32). The CIFAR-10 'test' split is used as the validation set.

    Only the training split is shuffled: validation metrics are aggregated over
    the whole split, so shuffling it would add cost without changing results.

    Args:
        batch_size (int): number of examples per batch.
        shuffle_buffer (int): shuffle-buffer size for the training split
            (defaults to 4096, the value previously hard-coded).

    Returns:
        tuple: (train_ds, val_ds), batched and prefetched tf.data.Dataset objects.
    """
    train_ds, val_ds = tfds.load('cifar10', split=['train', 'test'], shuffle_files=True)

    def _preprocess(example):
        # One map instead of two chained ones; tf.one_hot already yields float32.
        image = tf.cast(example["image"], tf.float32) / 255.
        label = tf.one_hot(example["label"], 10, dtype=tf.float32)
        return image, label

    # Cache after the deterministic preprocessing but before shuffling, so each
    # epoch still sees a fresh order of the training examples.
    train_ds = (train_ds
                .map(_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
                .cache()
                .shuffle(shuffle_buffer)
                .batch(batch_size)
                .prefetch(tf.data.AUTOTUNE))

    val_ds = (val_ds
              .map(_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
              .cache()
              .batch(batch_size)
              .prefetch(tf.data.AUTOTUNE))

    return train_ds, val_ds

train_ds, val_ds = get_cifar10(128)

In [12]:
class BasicCNNBlock(tf.keras.layers.Layer):
    """
    A stack of `layers` Conv2D layers ('same' padding, ReLU), each with `depth` filters.

    Because padding is 'same' and no stride is set, the block keeps the spatial
    size of its input and changes only the channel count to `depth`.

    Args:
        depth (int): number of filters in every convolution.
        layers (int): how many convolutions to stack.
        kernel_size (int): convolution kernel size (default 3, as before).
        **kwargs: forwarded to tf.keras.layers.Layer (e.g. `name`, `dtype`).
    """
    def __init__(self, depth, layers, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        self.layers = [tf.keras.layers.Conv2D(filters=depth,
                                              kernel_size=kernel_size,
                                              padding='same',
                                              activation='relu')
                       for _ in range(layers)]

    @tf.function
    def call(self, x):
        """Apply the convolutions in order and return the final feature map."""
        for layer in self.layers:
            x = layer(x)
        return x

class BasicCNN(tf.keras.Model):
    """
    Small convolutional CIFAR-10 classifier driven by a hand-written training loop.

    Architecture: BasicCNNBlock(24 filters x 6 convs) -> 2x2 max-pool ->
    BasicCNNBlock(48 x 4) -> 2x2 max-pool -> BasicCNNBlock(96 x 2) ->
    global average pooling -> Dense(10, softmax).

    The model owns its optimizer (Adam), loss (categorical cross-entropy on
    one-hot targets) and metrics (mean loss, categorical accuracy), so it is
    used through `train_step` / `test_step` directly, without `compile()`.
    """

    def __init__(self):
        super().__init__()

        # Order matters: metrics[0] must be the loss tracker, because
        # _update_metrics feeds it the loss and the rest (target, output).
        self.metrics_list = [tf.keras.metrics.Mean(name="loss"),
                             tf.keras.metrics.CategoricalAccuracy(name="acc")]

        self.optimizer = tf.keras.optimizers.Adam()

        # Default arguments (from_logits=False): the network already ends in a
        # softmax, so the loss receives probabilities.
        self.loss_function = tf.keras.losses.CategoricalCrossentropy()

        self.layer_list = [BasicCNNBlock(24, 6),
                           tf.keras.layers.MaxPooling2D(pool_size=2, strides=2),
                           BasicCNNBlock(48, 4),
                           tf.keras.layers.MaxPooling2D(pool_size=2, strides=2),
                           BasicCNNBlock(96, 2),
                           tf.keras.layers.GlobalAveragePooling2D(),
                           tf.keras.layers.Dense(10, activation='softmax')]

    @tf.function
    def call(self, x):
        """Run `x` through every layer in `layer_list` and return class probabilities."""
        for item in self.layer_list:
            x = item(x)
        return x

    @property
    def metrics(self):
        """All metric objects tracked by the model; the loss tracker comes first."""
        return self.metrics_list

    def reset_metrics(self):
        """Clear the accumulated state of every metric (called between epochs/phases)."""
        for metric in self.metrics:
            # `reset_states` is a deprecated alias that newer Keras releases removed.
            # Prefer `reset_state`, falling back only for old TF versions.
            reset = getattr(metric, "reset_state", None) or metric.reset_states
            reset()

    def _update_metrics(self, target, output, loss):
        """Feed one batch into every metric and return {metric name: current value}."""
        self.metrics[0].update_state(loss)
        # Every metric after the loss tracker compares targets with predictions.
        for metric in self.metrics[1:]:
            metric.update_state(target, output)
        return {m.name: m.result() for m in self.metrics}

    @tf.function
    def train_step(self, data):
        """
        Perform one optimization step on a batch.

        Args:
            data: an (images, one_hot_targets) pair.

        Returns:
            dict mapping metric names to their running values.
        """
        img, target = data

        with tf.GradientTape() as tape:
            output = self(img, training=True)
            loss = self.loss_function(target, output)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return self._update_metrics(target, output, loss)

    @tf.function
    def test_step(self, data):
        """
        Evaluate one batch without updating weights.

        Args:
            data: an (images, one_hot_targets) pair.

        Returns:
            dict mapping metric names to their running values.
        """
        img, target = data

        output = self(img, training=False)
        loss = self.loss_function(target, output)

        return self._update_metrics(target, output, loss)

In [13]:
# TensorBoard log locations: one directory per configuration and launch time,
# split into a "train" and a "val" sub-directory.
config_name = "BasicCNNavg"
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

train_log_path, val_log_path = (
    f"logs/{config_name}/{current_time}/train",
    f"logs/{config_name}/{current_time}/val",
)

# A dedicated summary writer per split (training first, then validation).
train_summary_writer, val_summary_writer = (
    tf.summary.create_file_writer(log_dir)
    for log_dir in (train_log_path, val_log_path)
)

def _log_epoch_metrics(model, summary_writer, epoch, prefix=""):
    """Write each metric's epoch aggregate to TensorBoard once, then print it."""
    results = {metric.name: metric.result() for metric in model.metrics}
    with summary_writer.as_default():
        for name, value in results.items():
            tf.summary.scalar(name, value, step=epoch)
    # Make the epoch's scalars visible to TensorBoard without waiting for a flush.
    summary_writer.flush()
    print([f"{prefix}{name}: {value.numpy()}" for name, value in results.items()])


def training_loop(model, train_ds, val_ds, epochs, train_summary_writer, val_summary_writer):
    """
    Train `model` for `epochs` epochs, evaluating on `val_ds` after each one.

    Metrics accumulate over a whole phase (training or validation). At the end
    of each phase, the aggregated values are written ONCE to TensorBoard at
    step=epoch and printed. The metrics are then reset so the next phase starts
    fresh. Previously the scalars were written after every batch with the same
    step, which duplicated hundreds of points per epoch.

    Args:
        model: object exposing `train_step`, `test_step`, `metrics` and
            `reset_metrics` (e.g. BasicCNN).
        train_ds: iterable of (inputs, targets) training batches.
        val_ds: iterable of (inputs, targets) validation batches.
        epochs (int): number of passes over `train_ds`.
        train_summary_writer: tf.summary writer for the training metrics.
        val_summary_writer: tf.summary writer for the validation metrics.
    """
    for epoch in range(epochs):
        print(f"Epoch {epoch}:")

        # Training: one optimizer step per batch; metrics accumulate across the epoch.
        for data in tqdm.tqdm(train_ds, position=0, leave=True):
            model.train_step(data)
        # Read results from model.metrics, so an empty dataset cannot leave
        # a per-batch result variable unbound.
        _log_epoch_metrics(model, train_summary_writer, epoch)
        model.reset_metrics()

        # Validation: no weight updates, metrics accumulate across the split.
        for data in val_ds:
            model.test_step(data)
        _log_epoch_metrics(model, val_summary_writer, epoch, prefix="val_")
        model.reset_metrics()

        print("\n")

In [5]:
# Start (or reuse an already running) inline TensorBoard that reads every run under logs/.
%tensorboard --logdir logs/

Reusing TensorBoard on port 6006 (pid 6152), started 5 days, 20:55:31 ago. (Use '!kill 6152' to kill it.)

In [14]:
# Build a fresh model and train it for 20 epochs, logging to the writers created above.
model = BasicCNN()

training_loop(
    model=model,
    train_ds=train_ds,
    val_ds=val_ds,
    epochs=20,
    train_summary_writer=train_summary_writer,
    val_summary_writer=val_summary_writer,
)

Epoch 0:


100%|██████████| 391/391 [00:15<00:00, 24.45it/s]


['loss: 1.9912642240524292', 'acc: 0.24305999279022217']
['val_loss: 1.6718600988388062', 'val_acc: 0.37529999017715454']


Epoch 1:


100%|██████████| 391/391 [00:13<00:00, 28.93it/s]


['loss: 1.5835299491882324', 'acc: 0.4131399989128113']
['val_loss: 1.4795026779174805', 'val_acc: 0.4575999975204468']


Epoch 2:


100%|██████████| 391/391 [00:13<00:00, 29.33it/s]


['loss: 1.401829719543457', 'acc: 0.4859200119972229']
['val_loss: 1.2616636753082275', 'val_acc: 0.5394999980926514']


Epoch 3:


100%|██████████| 391/391 [00:13<00:00, 29.48it/s]


['loss: 1.2197741270065308', 'acc: 0.5588399767875671']
['val_loss: 1.1795605421066284', 'val_acc: 0.5799999833106995']


Epoch 4:


100%|██████████| 391/391 [00:13<00:00, 29.65it/s]


['loss: 1.0611144304275513', 'acc: 0.6220200061798096']
['val_loss: 1.004021406173706', 'val_acc: 0.6438999772071838']


Epoch 5:


100%|██████████| 391/391 [00:13<00:00, 29.75it/s]


['loss: 0.9419018626213074', 'acc: 0.6665599942207336']
['val_loss: 0.9486222863197327', 'val_acc: 0.6671000123023987']


Epoch 6:


100%|██████████| 391/391 [00:13<00:00, 29.64it/s]


['loss: 0.8550794124603271', 'acc: 0.6966599822044373']
['val_loss: 0.8902435302734375', 'val_acc: 0.6909999847412109']


Epoch 7:


100%|██████████| 391/391 [00:13<00:00, 29.57it/s]


['loss: 0.7704941630363464', 'acc: 0.7279999852180481']
['val_loss: 0.799218475818634', 'val_acc: 0.7242000102996826']


Epoch 8:


100%|██████████| 391/391 [00:13<00:00, 29.53it/s]


['loss: 0.7119725942611694', 'acc: 0.7500600218772888']
['val_loss: 0.8744446039199829', 'val_acc: 0.7057999968528748']


Epoch 9:


100%|██████████| 391/391 [00:13<00:00, 29.34it/s]


['loss: 0.6547601819038391', 'acc: 0.7714400291442871']
['val_loss: 0.7512117028236389', 'val_acc: 0.7419999837875366']


Epoch 10:


100%|██████████| 391/391 [00:13<00:00, 29.52it/s]


['loss: 0.6072371602058411', 'acc: 0.7864800095558167']
['val_loss: 0.7488740682601929', 'val_acc: 0.7462999820709229']


Epoch 11:


100%|██████████| 391/391 [00:13<00:00, 29.52it/s]


['loss: 0.5590533018112183', 'acc: 0.8045399785041809']
['val_loss: 0.7691168189048767', 'val_acc: 0.7422999739646912']


Epoch 12:


100%|██████████| 391/391 [00:13<00:00, 29.15it/s]


['loss: 0.520820140838623', 'acc: 0.8163999915122986']
['val_loss: 0.7388847470283508', 'val_acc: 0.7491000294685364']


Epoch 13:


100%|██████████| 391/391 [00:13<00:00, 28.68it/s]


['loss: 0.48020705580711365', 'acc: 0.8306199908256531']
['val_loss: 0.759713888168335', 'val_acc: 0.7578999996185303']


Epoch 14:


100%|██████████| 391/391 [00:14<00:00, 27.85it/s]


['loss: 0.44595029950141907', 'acc: 0.8433200120925903']
['val_loss: 0.7562196850776672', 'val_acc: 0.7599999904632568']


Epoch 15:


100%|██████████| 391/391 [00:15<00:00, 25.74it/s]


['loss: 0.4209563136100769', 'acc: 0.8502399921417236']
['val_loss: 0.7578818202018738', 'val_acc: 0.756600022315979']


Epoch 16:


100%|██████████| 391/391 [00:14<00:00, 27.42it/s]


['loss: 0.37382957339286804', 'acc: 0.867680013179779']
['val_loss: 0.7817038297653198', 'val_acc: 0.7659000158309937']


Epoch 17:


100%|██████████| 391/391 [00:14<00:00, 27.74it/s]


['loss: 0.35040482878685', 'acc: 0.8743399977684021']
['val_loss: 0.7576773166656494', 'val_acc: 0.7630000114440918']


Epoch 18:


100%|██████████| 391/391 [00:13<00:00, 28.33it/s]


['loss: 0.31902873516082764', 'acc: 0.8873199820518494']
['val_loss: 0.8372607827186584', 'val_acc: 0.7645999789237976']


Epoch 19:


100%|██████████| 391/391 [00:13<00:00, 28.40it/s]


['loss: 0.2906223237514496', 'acc: 0.8957200050354004']
['val_loss: 0.8093714118003845', 'val_acc: 0.7731999754905701']


