In [1]:
import os

import _init_paths
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np

from utils.cnn_2d_utils import get_2d_cnn_model
from utils.base_utils.data_prep_utils import get_cifar_10_data
from configs.exp_configs import tf_exp_cfg as exp_cfg

# Load the data, build the model, and train it

In [2]:
train_x, train_y, test_x, test_y = get_cifar_10_data()
# Spatial size is read from the first training sample; the channel count is
# fixed at 3. NOTE(review): assumes samples are (rows, cols, channels) RGB
# images — confirm against get_cifar_10_data.
img_rows, img_cols = train_x[0].shape[0], train_x[0].shape[1]
model, _ = get_2d_cnn_model((img_rows, img_cols, 3), exp_cfg)
model.compile(
    # NOTE(review): sigmoid focal crossentropy is an element-wise loss with
    # from_logits=False by default, so it presumably needs one-hot targets of
    # the same shape as the predictions and an activated (not logit) output
    # layer — confirm both for get_cifar_10_data / get_2d_cnn_model.
    loss=tfa.losses.sigmoid_focal_crossentropy,
    # `learning_rate` is the supported keyword; `lr` is deprecated and is
    # rejected by newer Keras optimizer implementations.
    optimizer=tf.keras.optimizers.Adam(learning_rate=exp_cfg["lr"]),
    metrics=["accuracy"])
print("*"*90)
# summary() prints the table itself and returns None; wrapping it in print()
# would emit a stray "None" line.
model.summary()
print("*"*90)
# Keep the History so per-epoch loss/accuracy can be inspected or plotted later.
history = model.fit(train_x, train_y, batch_size=exp_cfg["batch_size"],
                    epochs=exp_cfg["epochs"])

******************************************************************************************
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 32, 32, 32)        896       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 32, 32, 64)        18496     
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 16, 16, 64)        36928     
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 8, 8, 96)          55392     
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 4, 4, 128)         110720    
____________________________________

<tensorflow.python.keras.callbacks.History at 0x2ba467ed6350>

# Evaluate the model and save its weights

In [3]:
# Evaluate on the test split using the same batch size configured for training.
eval_results = model.evaluate(test_x, test_y, batch_size=exp_cfg["batch_size"])

# Build the weights path portably and make sure its directory exists before
# saving (tf.io.gfile also handles non-local filesystems).
wts_dir = exp_cfg["tf_wts_otpt_dir"]
tf.io.gfile.makedirs(wts_dir)
model.save_weights(os.path.join(wts_dir, "weights"))

# Display each evaluate() value next to its metric name.
dict(zip(model.metrics_names, eval_results))

