# Load packages and libraries

In [1]:
# TF_CPP_MIN_LOG_LEVEL only has a reliable effect when it is set before
# TensorFlow is imported for the first time, so it is set here, ahead of the
# first `import tensorflow` of the notebook ('3' = errors only).
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Workaround for cuDNN initialisation failures that surface as
# "UnimplementedError: Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, ..."
# Let the GPU allocator grow on demand instead of reserving all memory up front.
import tensorflow as tf
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True # https://stackoverflow.com/a/61786189
# The session object itself is not used by later cells; it is created for the
# side effect of applying `config`.
sess = tf.compat.v1.Session(config=config)

In [2]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from matplotlib import pyplot as plt
import myloginpath
import numpy as np
import pandas as pd
import json
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix 
import tensorflow as tf
import time
import warnings

# Configure access to GPU

In [3]:
# Query the GPUs visible to the host runtime once and reuse the result.
physical_devices = tf.config.list_physical_devices('GPU')
print('Num GPUs Available:', len(physical_devices))

if physical_devices:
    # Restrict the runtime to the first GPU: TF only allocates memory and places
    # operations on visible physical devices.
    tf.config.set_visible_devices(physical_devices[0], 'GPU')

    try:
        for gpu in physical_devices:
            # With memory growth enabled, runtime initialization does not allocate
            # all memory on the device.
            # NOTE(review): TF documents that memory growth cannot be configured on a
            # device that already has virtual devices configured, so re-running this
            # cell in a live kernel is expected to end up in the except branch.
            tf.config.experimental.set_memory_growth(gpu, True)
            # Cap the memory TF may use on this device (memory_limit is in MB).
            tf.config.experimental.set_virtual_device_configuration(
                gpu,
                [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=20_000)],
            )
    except (ValueError, RuntimeError) as err:
        # ValueError: invalid device / conflicting configuration.
        # RuntimeError: the runtime has already been initialized.
        # Report the actual cause instead of silently swallowing every exception.
        print('Invalid device or cannot modify virtual devices once initialized.', err, flush=True)
else:
    # Avoid an IndexError on physical_devices[0] when the host has no GPU.
    print('No GPU detected: skipping GPU configuration.', flush=True)

Num GPUs Available: 2


In [4]:
# OneDeviceStrategy places any variables created in its scope on the specified device.
# Input distributed through this strategy will be prefetched to the specified device.
# Moreover, any functions called via strategy.run will also be placed on the specified device as well.
#
# In this notebook the strategy is used for `strategy.scope()` blocks and for
# `strategy.num_replicas_in_sync` when deriving the global batch size.

# TODO: document why a distribution strategy is needed for a single GPU.
strategy = tf.distribute.OneDeviceStrategy(device='/gpu:0')

# Prepare datasets

## Initialization

Parameters to connect to the database

In [5]:
# Read the connection settings of the 'client' login path from the MySQL login
# file, so that no credentials are hard-coded in the notebook.
mysqlSettings = myloginpath.parse('client', path='/tf/.mylogin.cnf')
# Select the schema to query; presumably it holds the measurement events -- TODO confirm.
mysqlSettings['database'] = 'sensor_data_schema'

Global parameters (Make sure to choose a unique name for each run!)

In [6]:
# Name of this run; it is embedded in the TensorBoard log directory and in the
# checkpoint directory, so reusing a name mixes the artefacts of two runs.
modelName = 'test_12Mparam'
# Root directory under which logs and checkpoints are written.
path = '/tf/home'
# Parent directory for TensorBoard logs (the run name is appended where the callback is created).
tensorboardLogFolder = f'{path}/logs'
# Per-run checkpoint directory.
checkpointFolder = f'{path}/checkpoints/{modelName}/'
# Confusion-matrix output directory (currently disabled).
#confMatFolder = f'{path}/confusion_matrix/{modelName}/'

In [7]:
# One entry per source dataset: (class label, event count, dataset UUID).
# The trailing comment on each line gives the class number (order of first
# appearance of the label), the dataset name and its event count; the last
# dataset of a class also carries the class total.
datasetList = [
    #POLLEN DATASETS --------------------------------------------------------
    ('alnus', 3990, '11ea8493-7107-8db4-9bf7-ae7b87f820b4'),                 # 0alnus 'alnus_20200220_p5_1_benoit' 3990
    ('alnus', 4966, '11ea847a-f995-790c-830f-ae7b87f820b4'),                 # 0alnus 'alnus_20200218_p2_1_benoit' 4966
    ('alnus', 3474, '11ea8475-957e-347c-985a-ae7b87f820b4'),                 # 0alnus 'alnus_20200214_p4_1_benoit' 3474 TOTAL ALNUS=12'430
    ('betula', 5770, '11ea8897-f50e-66a2-9876-ae7b87f820b4'),                # 1betula 'betula_20200406_p2_1_benoit' 5770
    ('betula', 6533, '11ea8632-18ed-7210-985a-ae7b87f820b4'),                # 1betula 'betula_20200407_p4_2_benoit' 6533
    ('betula', 2173, '11ea8632-1eb2-2452-bc84-ae7b87f820b4'),                # 1betula 'betula_20200406_p4_1_benoit' 2173 TOTAL BETULA=14'476
    ('carpinus', 643, '11ea8f77-4ee3-aef4-b330-ae7b87f820b4'),               # 2carpinus 'carpinus_20200319_p5_2_fiona' 643
    ('carpinus', 664, '11ea8f6d-3e75-9fe6-b46e-ae7b87f820b4'),               # 2carpinus 'carpinus_20200319_p2_2_fiona' 664
    ('carpinus', 545, '11ea8f6d-1562-211a-8192-ae7b87f820b4'),               # 2carpinus 'carpinus_20200319_p2_3_fiona' 545
    ('carpinus', 395, '11ea8f6c-b78c-d076-a542-ae7b87f820b4'),               # 2carpinus 'carpinus_20200319_p4_2_fiona' 395 TOTAL CARPINUS=2'247
    ('corylus', 3736, '11ea8498-b729-d4e6-bc84-ae7b87f820b4'),               # 3corylus 'corylus_20200225_p2_2_benoit' 3736
    ('corylus', 500, '11ea8498-b083-cb92-a1a5-ae7b87f820b4'),                # 3corylus 'corylus_20200225_p2_1_benoit' 500
    ('corylus', 3578, '11ea8498-afa9-cec4-a877-ae7b87f820b4'),               # 3corylus 'corylus_20200225_p5_1_benoit' 3578 TOTAL CORYLUS=7'814
    ('cupressus', 421, '11ea8fa9-6c12-723a-b3dd-ae7b87f820b4'),              # 4cupressus 'cupressus_20200317_p5_1_fiona' 421
    ('cupressus', 2340, '11ea8fa8-fafa-aeb4-ac46-ae7b87f820b4'),             # 4cupressus 'cupressus_20200317_p2_1_fiona' 2340
    ('cupressus', 583, '11ea8fa8-d163-dce2-b1cb-ae7b87f820b4'),              # 4cupressus 'cupressus_20200317_p4_1_fiona' 583 TOTAL CUPRESSUS=3'344
    ('fagus', 2759, '11ea8636-313b-a6e4-a69e-ae7b87f820b4'),                 # 5fagus 'fagus_20200413_p4_1_benoit' 2759
    ('fagus', 3410, '11ea8635-ef91-6ab2-a877-ae7b87f820b4'),                 # 5fagus 'fagus_20200407_p2_1_benoit' 3410
    ('fagus', 4143, '11ea8635-eb18-6ee0-9876-ae7b87f820b4'),                 # 5fagus 'fagus_20200413_p5_1_benoit' 4143 TOTAL FAGUS=10'312
    ('fraxinus', 5703, '11ea857e-7bc5-60a0-842e-ae7b87f820b4'),              # 6fraxinus 'fraxinus_20200402_p5_2_benoit' 5703
    ('fraxinus', 2621, '11ea857b-3d52-9034-830f-ae7b87f820b4'),              # 6fraxinus 'fraxinus_20200330_p4_1_benoit' 2621
    ('fraxinus', 1712, '11ea857b-150e-c372-bc84-ae7b87f820b4'),              # 6fraxinus 'fraxinus_20200330_p2_1_benoit' 1712 TOTAL FRAXINUS=10'036
    ('pinaceae', 1826, '11ea8af3-c533-f39e-8b25-ae7b87f820b4'),              # 7pinaceae 'picea_20200423_p2_1_fiona' 1826
    ('pinaceae', 2375, '11ea8af1-91fc-9a46-8b25-ae7b87f820b4'),              # 7pinaceae 'picea_20200423_p4_1_fiona' 2375
    ('pinaceae', 1969, '11ea8af0-83dc-6d66-b06c-ae7b87f820b4'),              # 7pinaceae 'picea_20200423_p5_1_fiona' 1969
    ('pinaceae', 3403, '11ea863d-acf6-0ade-985a-ae7b87f820b4'),              # 7pinaceae 'pinus_20200421_p5_1_benoit' 3403
    ('pinaceae', 8582, '11ea863c-2449-be52-8814-ae7b87f820b4'),              # 7pinaceae 'pinus_20200421_p2_1_benoit' 8582 TOTAL PINACEAE=18'155
    ('platanus', 5603, '11ea8b83-25c9-8194-90d1-ae7b87f820b4'),              # 8platanus 'platanus_20200417_p4_1_benoit' 5603
    ('platanus', 5544, '11ea8881-3721-9aa8-a907-ae7b87f820b4'),              # 8platanus 'platanus_20200417_p2_1_benoit' 5544 TOTAL PLATANUS=11'147
    ('poaceae', 1229, '11ea990f-ee01-8334-b3dd-ae7b87f820b4'),               # 9poaceae 'gram_20200518_p2_1_benoit' 1229
    ('poaceae', 1508, '11ea990c-b2bc-fe96-b46e-ae7b87f820b4'),               # 9poaceae 'gram_20200518_p5_1_benoit' 1508  TOTAL POACEAE initial=4'909
    ('poaceae', 5895, '11eb5fd9-961a-313e-ac56-ae7b87f820b4'),               # 9poaceae 'POCclean_cynosurus_20200520_p4_1_fiona' 5895
    ('poaceae', 6248, '11eb5fd9-dd36-0a20-88f3-ae7b87f820b4'),               # 9poaceae 'POCclean_cynosurus_20200520_p2_1_' 6248
    ('poaceae', 3110, '11ebe542-660e-0206-80be-ae7b87f820b4'),               # 9poaceae 'poaceae_dactylis_fresh_p19_2021_tri_Nina' 3110
    ('poaceae', 1127, '11eb5fc3-03fa-6da2-8b42-ae7b87f820b4'),               # 9poaceae 'POCclean_dactylis_20200518_p4_1' 1127
    ('poaceae', 1377, '11ebe540-187e-9a0c-b0e2-ae7b87f820b4'),               # 9poaceae 'poaceae_trisetum_fresh_p19_2021_tri_Nina' 1377
    ('populus', 657, '11ea8893-edfb-ca84-a877-ae7b87f820b4'),                # 10populus 'populus_20200327_p5_1_benoit' 657
    ('populus', 508, '11ea84a0-e89b-43b8-a69e-ae7b87f820b4'),                # 10populus 'populus_20200327_p2_benoit' 508
    ('populus', 2913, '11ea84a0-a2f0-ab8c-a877-ae7b87f820b4'),               # 10populus 'populus_20200327_p4_benoit' 2913 TOTAL POPULUS=4'078
    ('quercus', 3824, '11ea863e-1fea-0f7c-a1a5-ae7b87f820b4'),               # 11quercus 'quercus_20200421_p4_1_benoit' 3824
    ('quercus', 4768, '11ea863e-1b86-8226-a1a5-ae7b87f820b4'),               # 11quercus 'quercus_20200421_p2_1_benoit' 4768
    ('quercus', 2519, '11ea863d-f388-a038-a1a5-ae7b87f820b4'),               # 11quercus 'quercus_20200421_p5_1_benoit' 2519 TOTAL QUERCUS=11'111
    ('taxus', 4872, '11ea8477-cede-e7dc-897d-ae7b87f820b4'),                 # 12taxus 'taxus_20200218_p4_1_benoit' 4872
    ('taxus', 5593, '11ea8477-b584-b690-830f-ae7b87f820b4'),                 # 12taxus 'taxus_20200218_p2_1_benoit' 5593
    ('taxus', 3411, '11ea8494-33a5-2e4e-bc84-ae7b87f820b4'),                 # 12taxus 'taxus_20200220_p5_1_benoit' 3411 TOTAL TAXUS=13'876
    ('ulmus', 3289, '11ea849c-df8f-d95e-897d-ae7b87f820b4'),                 # 13ulmus 'ulmus_20200311_p4_2_benoit' 3289
    ('ulmus', 2392, '11ea849c-db7b-2170-8b0f-ae7b87f820b4'),                 # 13ulmus 'ulmus_20200311_p2_2_benoit' 2392
    ('ulmus', 4844, '11ea849a-0e25-4018-8814-ae7b87f820b4'),                 # 13ulmus 'ulmus_20200304_p5_1_benoit' 4844 TOTAL ULMUS=10'525
    ## SPORES DATASETS ------------------------------------------------------
    ('alternaria solani', 2767, '11ebf9db-f2e9-98cc-bc67-ae7b87f820b4'),     # 14Alternaria solani 'alternaria_solani_sophie_clean' event counts 2767
    ('fusarium graminearum', 25054, '11ec01b9-d571-ea8e-b7e1-ae7b87f820b4'), # 15Fusarium graminearum 'fusarium_graminearum_p1' event count 25054
    ## WATER DROPLETS (WD) DATASETS -----------------------------------------
    ('rain', 389, '11ebe542-f782-c172-bf10-ae7b87f820b4'),                   # 16Rain 'P5_Payerne_Rain_28_04' event counts 389
    ('rain', 2786, '11ebeabd-e224-d5c4-8b63-ae7b87f820b4'),                  # 16Rain 'P5_Payerne_Rain_30_04_AM' event counts 2786
    ('rain', 3179, '11ebedec-0da5-47ac-8066-ae7b87f820b4'),                  # 16Rain 'P5_Payerne_Rain_30_04_PM' event counts 3179
    ('rain', 7691, '11ebee15-1fea-4c68-9cd6-ae7b87f820b4'),                  # 16Rain 'P16_Locarno_Rain_29_04' event counts 7691 TOTAL RAIN = 14045
    ## IBERULITES DATASETS --------------------------------------------------
    ('iberulite', 190, '11ec6179-fde0-042c-adac-ae7b87f820b4'),              # 17Iberulite 'P4_iberulites_06022021_clean' event counts 190
    ('iberulite', 183, '11ec5821-b371-c3dc-8359-ae7b87f820b4'),              # 17Iberulite 'P5_iberulites_06022021_clean' event counts 183
    ('iberulite', 66, '11ec617a-fb8b-e3f2-80fb-ae7b87f820b4'),               # 17Iberulite 'P4_saharan_dust_april2021_clean' event counts 66
    # The count below used to be 556, which is the class total (190+183+66+117), not this dataset's event count.
    ('iberulite', 117, '11ec5832-4fee-03b4-8561-ae7b87f820b4'),              # 17Iberulite 'P5_saharan_dust_april2021_clean' event counts 117 TOTAL IBERULITE = 556
]

In [8]:
# How many events should be used per dataset. TF will train on them for x epochs
# before going to the next chunk of data. Choose the size according to your
# hardware (RAM, GPU, GPU memory).
chunksize = 256
# How many batches should be cached in the background.
n_prefetch = 2
# Number of samples each replica processes per step.
batchsize_per_replica = 64 
# Global batch size: number of samples processed before the model is updated.
batchsize = batchsize_per_replica * strategy.num_replicas_in_sync
# Number of complete passes through the training dataset.
epochs = 2

In [9]:
# Unique class labels in order of first appearance in datasetList.
# A set was used here before; its iteration order for strings changes between
# interpreter runs (hash randomisation), so the class indices were not
# reproducible across kernel restarts. An ordered list makes index i always
# denote the same class (0 = alnus, 1 = betula, ...).
target_names = list(dict.fromkeys(label for label, _, _ in datasetList))

# class label -> integer class index
_class_index_by_label = {label: class_idx for class_idx, label in enumerate(target_names)}

# position of a dataset in datasetList -> integer class index of its label
target_ids_mapping = {
    dataset_pos: _class_index_by_label[label]
    for dataset_pos, (label, _, _) in enumerate(datasetList)
}

# number of distinct classes
num_classes = len(target_names)

## Loading

In [10]:
from preprocessingFunc import load_blob, process_waves, filter_blur, filter_crop
%load_ext autoreload
%autoreload 2
def img_preprocessing(img0, img1):
    """Run the image preprocessing chain on one pair of raw images.

    The pair is passed through ``load_blob`` -> ``filter_blur`` ->
    ``filter_crop`` -> ``process_waves`` (all from ``preprocessingFunc``); the
    result of each stage is unpacked as the positional arguments of the next.

    Args:
        img0: first raw image of an event, as accepted by ``load_blob``.
        img1: second raw image of the same event.

    Returns:
        Whatever ``process_waves`` returns for the preprocessed pair.
    """
    loaded = load_blob(img0, img1)
    blurred = filter_blur(*loaded)
    cropped = filter_crop(*blurred)
    return process_waves(*cropped)

In [11]:
# swisens way
from swisensDataFunc import init_sets as swisens_init_sets, get_train as swisens_get_train
# Initialise the data sources from the dataset list, batching parameters,
# database settings, class mapping and preprocessing function defined above.
# Returns `itList`, which is later handed to swisens_get_train to obtain the
# training batches, and `test_set`.
# NOTE(review): the exact types of both return values are defined in
# swisensDataFunc and are not visible in this notebook.
itList, test_set = swisens_init_sets(
    datasetList, 
    batchsize, 
    chunksize, 
    n_prefetch, 
    mysqlSettings, 
    target_ids_mapping, 
    num_classes, 
    img_preprocessing
)

Building testset


# Model initialization

In [12]:
def _image_tower(image_input):
    """Apply the convolutional feature extractor to one image input.

    The tower consists of five stages; each stage is a stack of 'same'-padded
    ReLU convolutions followed by a 2x2 max-pooling (stride 2) and a dropout
    of 0.1. Layers are created in exactly the order listed in ``stages``.

    Args:
        image_input: Keras tensor of one image input.

    Returns:
        The (unflattened) feature map produced by the last stage.
    """
    # (filters, kernel size, number of convolutions, padding of the pooling layer)
    stages = (
        (64, (5, 5), 2, 'same'),
        (64, (3, 3), 2, 'same'),
        (128, (3, 3), 3, 'same'),
        (256, (3, 3), 3, 'valid'),
        (256, (3, 3), 3, 'same'),
    )
    features = image_input
    for filters, kernel_size, n_convs, pool_padding in stages:
        for _ in range(n_convs):
            features = tf.keras.layers.Conv2D(
                filters, kernel_size, padding='same', activation='relu'
            )(features)
        features = tf.keras.layers.MaxPool2D(
            (2, 2), strides=(2, 2), padding=pool_padding
        )(features)
        features = tf.keras.layers.Dropout(0.1)(features)
    return features


def get_model(nClasses, with_fluorescence=False, n_fl_configs=1,strategy=strategy):
    """Build the two-image classification model.

    Each of the two 200x200x1 image inputs goes through its own convolutional
    tower (same architecture, separate weights). The flattened tower outputs
    are concatenated (together with the processed fluorescence inputs when
    ``with_fluorescence`` is set) and fed to a dense head ending in a softmax.

    Args:
        nClasses: number of output classes (size of the softmax layer).
        with_fluorescence: if True, add three fluorescence inputs of shape
            ``(n_fl_configs*6, 1)`` and merge them into the dense head.
        n_fl_configs: number of fluorescence configurations; only used when
            ``with_fluorescence`` is True.
        strategy: accepted for backward compatibility; not used inside this
            function (callers are expected to invoke it within
            ``strategy.scope()``).

    Returns:
        An uncompiled ``tf.keras.Model`` whose output is a probability
        distribution over ``nClasses`` classes.
    """
    in_img0 = tf.keras.layers.Input((200,200,1))
    in_img1 = tf.keras.layers.Input((200,200,1))

    # If you want to train a model including fluorescence, you need to include these inputs in your model
    if with_fluorescence:
        in_fl_avg = tf.keras.layers.Input((n_fl_configs*6, 1))
        in_fl_pha = tf.keras.layers.Input((n_fl_configs*6, 1))
        in_fl_corrMag = tf.keras.layers.Input((n_fl_configs*6, 1))

    # Image processing: all layers of the first tower are created before those
    # of the second one, and both Flatten layers afterwards.
    path1 = _image_tower(in_img0)
    path2 = _image_tower(in_img1)

    path1Flat = tf.keras.layers.Flatten()(path1)
    path2Flat = tf.keras.layers.Flatten()(path2)

    # FL Processing
    if with_fluorescence:
        # NOTE(review): processFlInput is not defined in the visible part of this
        # notebook; it must be provided before calling with_fluorescence=True.
        fl_avg_path = processFlInput(in_fl_avg)
        fl_pha_path = processFlInput(in_fl_pha)
        fl_corrMag_path = processFlInput(in_fl_corrMag)
        path = tf.keras.layers.Concatenate()(
            [path1Flat, path2Flat, fl_avg_path, fl_pha_path, fl_corrMag_path]
        )
    else:
        path = tf.keras.layers.Concatenate()([path1Flat, path2Flat])

    # Densely (fully) connected head.
    # NOTE(review): these Dense layers have no activation, i.e. they are linear;
    # confirm that this is intended.
    path = tf.keras.layers.Dense(256)(path)
    path = tf.keras.layers.Dropout(0.2)(path)
    path = tf.keras.layers.Dense(128)(path)
    path = tf.keras.layers.Dropout(0.2)(path)
    # One unit per class, turned into probabilities by the softmax.
    path = tf.keras.layers.Dense(nClasses)(path)
    output = tf.keras.layers.Softmax()(path)

    # If we work with fluorescence, we need to add all the inputs to the final model
    if with_fluorescence:
        model_inputs = [in_img0, in_img1, in_fl_avg, in_fl_pha, in_fl_corrMag]
    else:
        model_inputs = [in_img0, in_img1]

    return tf.keras.Model(inputs=model_inputs, outputs=output)

In [13]:
# Keras callbacks for TensorBoard logging and model checkpointing.
# NOTE(review): neither callback is referenced by the training loop in the
# visible cells of this notebook; Keras callbacks only take effect when they
# are passed to model.fit() or invoked explicitly.
logger = tf.keras.callbacks.TensorBoard(log_dir=f"{tensorboardLogFolder}/{modelName}")
saver = tf.keras.callbacks.ModelCheckpoint(filepath=checkpointFolder)

In [14]:
with strategy.scope():
    # Instantiate an optimizer to train the model.
    optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
    # Instantiate a loss function.
    # from_logits=False because the model ends with a Softmax layer and therefore
    # outputs probabilities, not logits.
    # Assumes the labels are one-hot encoded -- TODO confirm against swisensDataFunc.
    loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)

In [15]:
# Create and compile the model inside the strategy scope so that its variables
# are placed on the strategy's device.
with strategy.scope():
    print('Building model...', flush=True)
    model = get_model(num_classes, strategy=strategy)
    # compile() attaches optimizer, loss and the accuracy metric to the model.
    model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])
    print('Model is built:', flush=True)
    model.summary()

Building model...
Model is built:
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 200, 200, 1  0           []                               
                                )]                                                                
                                                                                                  
 input_2 (InputLayer)           [(None, 200, 200, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 200, 200, 64  1664        ['input_1[0][0]']                
                                )                           

# Model training

In [16]:
# Accuracy accumulators for the custom training loop: one for the training
# batches, one for the evaluation pass at the end of each epoch.
with strategy.scope():
    train_acc_metric = tf.keras.metrics.CategoricalAccuracy() # SparseCategoricalAccuracy if using old code
    val_acc_metric = tf.keras.metrics.CategoricalAccuracy()

In [None]:
# Custom training loop. `epochs` comes from the hyper-parameter cell above.
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,), flush=True)
    # Reference point for the per-epoch "Time taken" report below
    # (start_time was previously never defined, which raised a NameError).
    start_time = time.time()

    if epoch > 0:
        print('Shuffle the dataset and re-create batches for training...', flush=True)
        # TODO: no explicit shuffling is implemented here yet; the training set is
        # only re-created through swisens_get_train below.

    # Iterate over the batches of the dataset.
    train_set = swisens_get_train(itList, batchsize, num_classes) # TODO
    for step, (x_batch_train, y_batch_train) in enumerate(train_set):

        # Open a GradientTape to record the operations run during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:

            # Run the forward pass. The model ends with a softmax, so these are
            # class probabilities (matching from_logits=False in the loss).
            probs = model(x_batch_train, training=True)

            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, probs)

        # Use the gradient tape to automatically retrieve the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, probs)

        # Log every 200 batches.
        if step % 200 == 0:
            print("Training loss (for one batch) at step %d: %.4f" % (step, float(loss_value)))
            print("Seen so far: %d samples" % ((step + 1) * batchsize))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # Run an evaluation loop at the end of each epoch.
    # The notebook never creates a `val_set`; the only held-out data it builds is
    # `test_set` (returned by swisens_init_sets), so that is what is evaluated here.
    # NOTE(review): assumes test_set yields (inputs, labels) batches like train_set;
    # use a dedicated validation split if one becomes available.
    for x_batch_val, y_batch_val in test_set:
        val_probs = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_probs)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))