# Train CNN on ABG data without prior XC training
- You will need to provide training and validation data to run this script. 
- Uses env ABG-cnn_tf230

In [None]:
import tensorflow as tf

%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=1

import IPython.display as display
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import pathlib

# option to not use GPUs
# NOTE: this assignment replaces the `%env CUDA_VISIBLE_DEVICES=1` value set
# earlier in this cell; comment it out to make GPU 1 visible again.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Shorthand used for `num_parallel_calls` and `prefetch` in the tf.data pipeline below.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Show the TensorFlow version in the cell output.
tf.__version__

In [None]:
# Report how many GPUs TensorFlow can see (expected to be 0 while the first
# cell sets CUDA_VISIBLE_DEVICES to "-1").
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

In [None]:
# MODEL OUTPUT PATH
# Prefix under which the fine-tuning stage checkpoints the model weights.
checkpoint_path = '../results/ABGQI-CNN/cp.ckpt'
print("does this checkpoint exist?")
print(checkpoint_path)
# A weights-only TensorFlow-format checkpoint is stored as `<prefix>.index`
# plus `<prefix>.data-*` shard files, so a file named exactly `cp.ckpt` is
# normally never created. Accept either the bare path or the index file.
os.path.isfile(checkpoint_path) or os.path.isfile(checkpoint_path + '.index')

In [None]:
# INPUT DATA
# Root folders of the training and validation splits. The globs further
# down (`*/*`) expect one sub-directory per class inside each root.
tr_pth, val_pth = '../data/splits/training', '../data/splits/validation'

# pathlib handles used for counting and listing the image files
data_dir, val_data_dir = pathlib.Path(tr_pth), pathlib.Path(val_pth)

In [None]:
# PRETRAINED MODEL INPUT
# MobNet pretrained on imagenet
# NOTE(review): the directory name contains "S2L_finetune" and later cells
# talk about S2L-MobileNet weights; confirm which pretraining this saved
# model actually carries.
model_path = '../data/IMGNET_mobileNet_S2L_finetune/my_model/'
new_model = tf.keras.models.load_model(model_path)
# Print the layer-by-layer summary of the loaded model.
new_model.summary()

### Functions

In [None]:
def get_label(file_path):
  """Build the label for an image from the name of its parent directory.

  The path is split on the OS path separator and the second-to-last
  component (the class folder) is compared element-wise against the
  module-level CLASS_NAMES array, giving a boolean vector with one entry
  per class name.
  """
  path_components = tf.strings.split(file_path, os.path.sep)
  class_dir = path_components[-2]
  return class_dir == CLASS_NAMES

def decode_img(img):
  """Decode an encoded image string into a resized float32 RGB tensor.

  Steps: decode to a 3-channel uint8 tensor, convert to float32 (values in
  the [0, 1] range), then resize to [IMG_HEIGHT, IMG_WIDTH].

  NOTE(review): the notebook globs `*.png` files while this calls
  `decode_jpeg`; TensorFlow documents that op as also accepting PNG input,
  but confirm this for the installed version.
  """
  decoded = tf.image.decode_jpeg(img, channels=3)
  as_float = tf.image.convert_image_dtype(decoded, tf.float32)
  return tf.image.resize(as_float, [IMG_HEIGHT, IMG_WIDTH])

def process_path(file_path):
    """Map an image file path to an `(image, label)` pair.

    The file contents are read as a raw string and decoded with
    `decode_img`; the label comes from `get_label` on the same path.
    """
    raw_bytes = tf.io.read_file(file_path)
    return decode_img(raw_bytes), get_label(file_path)

In [None]:
def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000, repeat=1):
  """Turn a dataset of (image, label) pairs into shuffled, batched input.

  Args:
    ds: dataset to prepare.
    cache: falsy to skip caching; a string to cache to that file (for data
      that does not fit in memory); any other truthy value to cache in memory.
    shuffle_buffer_size: buffer size handed to `shuffle`.
    repeat: `count` handed to `repeat`. The default of 1 yields a single
      pass; None or -1 repeats indefinitely.

  Returns:
    The dataset after cache -> shuffle -> repeat -> batch(BATCH_SIZE) ->
    prefetch(AUTOTUNE).
  """
  if cache:
    # a string selects an on-disk cache file, anything else an in-memory cache
    ds = ds.cache(cache) if isinstance(cache, str) else ds.cache()

  # `prefetch` lets batches be fetched in the background while the model trains
  return (
      ds.shuffle(buffer_size=shuffle_buffer_size)
        .repeat(repeat)
        .batch(BATCH_SIZE)
        .prefetch(buffer_size=AUTOTUNE)
  )

def show_batch(image_batch, label_batch):
  """Plot up to 25 images of a batch in a 5x5 grid, titled with their class.

  Args:
    image_batch: indexable batch of images accepted by `plt.imshow`.
    label_batch: per-image label vectors aligned with CLASS_NAMES; the
      entry equal to 1 (True) selects the class name used as the title.
  """
  plt.figure(figsize=(10,10))
  # Never index past the end of a batch smaller than the 5x5 grid.
  n_shown = min(25, len(image_batch))
  for n in range(n_shown):
      plt.subplot(5,5,n+1)
      plt.imshow(image_batch[n])
      plt.title(CLASS_NAMES[label_batch[n]==1][0].title())
      plt.axis('off')

### Image analysis

In [None]:
# Number of PNG files found one class folder below the training root
image_count = sum(1 for _ in data_dir.glob('*/*.png'))
image_count

In [None]:
# Class names are the sub-directory names of the training root.
# Sorted so the label column order is deterministic (glob order depends on
# the filesystem), and restricted to directories so stray files such as a
# LICENSE.txt are not turned into classes.
CLASS_NAMES = np.array(sorted(item.name for item in data_dir.glob('*') if item.is_dir()))
CLASS_NAMES

In [None]:
# The 1./255 is to convert from uint8 to float32 in range [0,1].
# NOTE: no other cell shown in this notebook references `image_generator`;
# images are loaded through the tf.data pipeline (list_ds -> process_path).
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

In [None]:
# training parameters
BATCH_SIZE = 64
# height and width every image is resized to in decode_img
IMG_HEIGHT = IMG_WIDTH = 224
# batches needed to cover the training images once (np.ceil returns a float)
STEPS_PER_EPOCH = np.ceil(image_count / BATCH_SIZE)

In [None]:
# example of 5 pngs
# List only PNG files so the dataset holds exactly the files counted by
# `image_count` (which globs `*/*.png`) and no stray non-image file reaches
# the decoder.
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*.png'))
for f in list_ds.take(5):
    print(f.numpy())

In [None]:
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
# Each file path becomes an (image, label) pair via process_path.
labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)

In [None]:
# what are the dimensions of a png and what do the labels look like
# (takes a single element from the training dataset and prints it)
for image, label in labeled_ds.take(1):
    print("Image shape: ", image.numpy().shape)
    print("Label: ", label.numpy())

In [None]:
# Prep dataset iterations
# repeat=None makes the training dataset repeat indefinitely, which is why
# the later model.fit calls pass steps_per_epoch.
train_ds = prepare_for_training(labeled_ds, repeat = None)
# Pull one batch for the preview plot in the next cell.
image_batch, label_batch = next(iter(train_ds))

In [None]:
# display some pngs
# (the tensors are converted to numpy arrays before plotting)
show_batch(image_batch.numpy(), label_batch.numpy())

#### Validation data 
- follows the same preparation as training data

In [None]:
# Number of PNG files found one class folder below the validation root
val_image_count = sum(1 for _ in val_data_dir.glob('*/*.png'))
val_image_count

In [None]:
# Show the validation root path in the cell output.
val_data_dir

In [None]:
# Class folders present in the validation split (stray files such as
# LICENSE.txt are ignored).
val_class_names = sorted(item.name for item in val_data_dir.glob('*')
                         if item.is_dir() and item.name != "LICENSE.txt")

# Keep the CLASS_NAMES array built from the training split instead of
# rebuilding it here: glob order depends on the filesystem, so a second,
# independently ordered array could permute the label columns between the
# training and validation datasets. Only check that the validation split
# introduces no class the training split lacks.
unknown_classes = set(val_class_names) - set(CLASS_NAMES.tolist())
if unknown_classes:
    raise ValueError(
        "validation split contains classes missing from the training split: "
        + ", ".join(sorted(unknown_classes)))
CLASS_NAMES

In [None]:
# The 1./255 is to convert from uint8 to float32 in range [0,1].
# NOTE: repeats the identical assignment made in the training section;
# no other cell shown in this notebook references `image_generator`.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

In [None]:
# NOTE: this overwrites the training value of STEPS_PER_EPOCH with the number
# of validation batches. The model.fit calls below do not read this variable;
# they recompute np.ceil(image_count/BATCH_SIZE) inline.
STEPS_PER_EPOCH = np.ceil(val_image_count/BATCH_SIZE)

In [None]:
# Dataset of validation image paths. List only PNG files so the dataset
# holds exactly the files counted by `val_image_count` (which globs
# `*/*.png`) and no stray non-image file reaches the decoder.
list_ds = tf.data.Dataset.list_files(str(val_data_dir/'*/*.png'))

# print five of the listed paths
for f in list_ds.take(5):
    print(f.numpy())

In [None]:
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
# NOTE: `labeled_ds` is rebound here to the validation pairs; `train_ds` was
# already built from the training version above.
labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)

In [None]:
# Print the image shape and label of a single validation element.
for image, label in labeled_ds.take(1):
    print("Image shape: ", image.numpy().shape)
    print("Label: ", label.numpy())

In [None]:
# Default repeat=1 gives a finite, single-pass validation dataset.
validation_ds = prepare_for_training(labeled_ds)
# NOTE: image_batch / label_batch now hold a validation batch; image_batch is
# reused below to push a sample through the base model.
image_batch, label_batch = next(iter(validation_ds))

In [None]:
# Preview images from the validation batch fetched above.
show_batch(image_batch.numpy(), label_batch.numpy())

In [None]:
# NOW WE HAVE:
# the two prepared datasets (printing shows their dataset representation)
print(validation_ds)
print(train_ds)

## MODEL TRAINING

In [None]:
# Image size, here 224 is default MobileNet x, y with 3 bands (RGB)
# NOTE: IMG_SIZE / IMG_SHAPE are not referenced by any later cell shown here;
# the resize itself uses IMG_HEIGHT / IMG_WIDTH.
IMG_SIZE = 224
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)

In [None]:
# number of target classes: ABGQI (one output unit per entry in CLASS_NAMES)
n_classes = len(CLASS_NAMES)
print(n_classes)
print(CLASS_NAMES)

In [None]:
# Remove FC and Global pooling layers to allow for ABGQI fine tuning
# Takes the third-from-last layer of the loaded model as the feature extractor.
# NOTE(review): presumes the last two layers are the pooling + dense head and
# that layers[-3] is itself a nested model (its `.layers` attribute is used
# later) -- confirm against new_model.summary().
base_model_output = new_model.layers[-3]#.output
print(base_model_output)
# Run the most recently fetched image_batch (the validation preview batch)
# through the extractor to obtain a sample feature tensor.
feature_batch = base_model_output(image_batch)

# Freeze the extractor for the first training stage.
base_model_output.trainable = False

In [None]:
# Add pooling layer
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
# Apply it to the sample features and print the resulting shape as a check.
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)

In [None]:
# Add FC/ Dense layer
# activation=None means the layer outputs raw logits, matching the
# from_logits=True loss used when the model is compiled.
prediction_layer = tf.keras.layers.Dense(n_classes, activation = None)

In [None]:
# Assemble the new model: the S2L-mobilenet base followed by the new
# pooling + FC layers (the actual `compile` call happens in a later cell).
model = tf.keras.Sequential([
  base_model_output,
  global_average_layer,  
  prediction_layer
])

In [None]:
# Let's take a look to see how many layers are in the base model (i.e. S2L pre-trained mobileNet)
# (relies on base_model_output exposing a `.layers` list)
print("Number of layers in the base model: ", len(base_model_output.layers))

In [None]:
# Fine tune FC layers
# Initial learning rate; the fine-tuning stage later divides it by 10.
base_learning_rate = 0.0001

# Specify the optimizer, accuracy metric and loss.
# - `learning_rate` is the supported argument name (`lr` is a deprecated alias).
# - metrics are passed as a list, the documented form for `compile`.
# - from_logits=True: the Dense head has no activation, so predictions are
#   raw logits rather than probabilities in [0, 1].
# NOTE(review): BinaryCrossentropy scores every class logit independently; if
# the classes are meant to be mutually exclusive, CategoricalCrossentropy
# would be the usual choice -- confirm this is intended.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              metrics=[tf.keras.metrics.CategoricalAccuracy()],
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))

In [None]:
# Summary after freezing the base: only the new head should be trainable.
model.summary() # trainable params = 8,965 here

In [None]:
# Number of trainable variable tensors while the base is frozen. Presumably
# just the Dense kernel and bias, since global average pooling has no weights.
len(model.trainable_variables) # pooling and dense layers

In [None]:
# NOW USE THE validation_ds and train_ds THAT WE BUILT BEFORE
# validation_ds is finite (built with the default repeat=1), so it is
# evaluated in full. A `steps=val_image_count // BATCH_SIZE` limit would skip
# the final partial batch and would be 0 when there are fewer validation
# images than BATCH_SIZE.
loss0, accuracy0 = model.evaluate(validation_ds)

In [None]:
# Baseline loss / accuracy of the untrained head, rounded to two decimals
print("initial loss: " + format(loss0, ".2f"))
print("initial accuracy: " + format(accuracy0, ".2f"))

In [None]:
# train with our prepared data
initial_epochs = 10 # short training period
# train_ds repeats indefinitely, so the length of an epoch must be given
# explicitly. `steps_per_epoch` expects an integer, while np.ceil returns a
# numpy float, hence the int() cast.
history = model.fit(train_ds,
                    epochs=initial_epochs,
                    validation_data=validation_ds,
                    steps_per_epoch=int(np.ceil(image_count/BATCH_SIZE)))

In [None]:
# visualize accuracy and loss
# Per-epoch curves recorded by the first (frozen-base) training run
acc, val_acc = (history.history[key]
                for key in ('categorical_accuracy', 'val_categorical_accuracy'))
loss, val_loss = (history.history[key] for key in ('loss', 'val_loss'))

plt.figure(figsize=(8, 8))

# upper panel: accuracy, y-axis capped at 1
plt.subplot(2, 1, 1)
plt.title('Training and Validation Accuracy')
plt.ylabel('Accuracy')
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([min(plt.ylim()), 1])
plt.legend(loc='lower right')

# lower panel: loss, y-axis fixed to [0, 1]
plt.subplot(2, 1, 2)
plt.title('Training and Validation Loss')
plt.ylabel('Cross Entropy')
plt.xlabel('epoch')
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.legend(loc='upper right')
plt.show()

### MODEL TRAINING: fine tuning the base model

In [None]:
# Make the MobileNet base trainable again for the fine-tuning stage
# (its earliest layers are re-frozen in a later cell).
base_model_output.trainable = True

In [None]:
# Let's take a look to see how many layers are in the base model
# (useful for choosing the `fine_tune_at` index in the next cell)
print("Number of layers in the base model: ", len(base_model_output.layers))

In [None]:
# Train CNN features here
# Fine-tune from this layer onwards
# (index into base_model_output.layers; layers at or after it stay trainable)
fine_tune_at = 50

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model_output.layers[:fine_tune_at]:
    layer.trainable = False

In [None]:
# reduce learning rate by factor of ten
# (base_learning_rate is 0.0001, so this is 0.00001)
second_tr_lr = base_learning_rate/10

In [None]:
# Re-compile so the changed `trainable` flags take effect, now with the
# second (10x smaller) learning rate: 0.00001.
# - `learning_rate` is the supported argument name (`lr` is a deprecated alias).
# - metrics are passed as a list, the documented form for `compile`.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=second_tr_lr),
              metrics=[tf.keras.metrics.CategoricalAccuracy()],
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))

In [None]:
# Summary after unfreezing: the trainable parameter count should now include
# the unfrozen part of the base model.
model.summary()

In [None]:
# Number of trainable variable tensors for the fine-tuning stage.
len(model.trainable_variables) # more trainable parameters because we are tuning the base mobilenet now

In [None]:
fine_tune_epochs = 10 # short training period
# `epochs` in model.fit is the index of the final epoch, so the fine-tuning
# run is given the combined count of both stages.
total_epochs =  initial_epochs + fine_tune_epochs # total training 

In [None]:
# Checkpoint callback for the fine-tuning run: writes the weights only (not
# the full model) to checkpoint_path, and only for epochs where the monitored
# val_categorical_accuracy is the best seen so far. verbose=1 logs each save.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_categorical_accuracy',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

In [None]:
# second full fine-tune learning
# Resume the epoch counter right after the first stage. `history.epoch` holds
# the zero-based indices of the epochs already run, so its length is the next
# epoch index. Using `history.epoch[-1]` would restart one epoch early and run
# fine_tune_epochs + 1 epochs.
# steps_per_epoch is cast to int because np.ceil returns a numpy float.
history_fine = model.fit(train_ds,
                         epochs=total_epochs,
                         initial_epoch=len(history.epoch),
                         validation_data=validation_ds,
                         steps_per_epoch=int(np.ceil(image_count/BATCH_SIZE)),
                         callbacks=[cp_callback]) # checkpoint the best weights

In [None]:
# Per-epoch curves recorded by the fine-tuning run only
acc, val_acc = (history_fine.history[key]
                for key in ('categorical_accuracy', 'val_categorical_accuracy'))
loss, val_loss = (history_fine.history[key] for key in ('loss', 'val_loss'))

plt.figure(figsize=(8, 8))

# upper panel: accuracy, y-axis capped at 1
plt.subplot(2, 1, 1)
plt.title('Training and Validation Accuracy')
plt.ylabel('Accuracy')
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([min(plt.ylim()), 1])
plt.legend(loc='lower right')

# lower panel: loss, y-axis fixed to [0, 1]
plt.subplot(2, 1, 2)
plt.title('Training and Validation Loss')
plt.ylabel('Cross Entropy')
plt.xlabel('epoch')
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.legend(loc='upper right')
plt.show()