This notebook trains a ResNet-50 model: it first runs a k-fold cross-validation, then trains a final model on all the data.

In [None]:
import csv
import os
import pickle
import random

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

from utils.training.data_loading import get_load_spectro_for_class
from utils.training.keras_models import resnet

## Parameters

In [None]:
ROOT_DIR = "/path/to/the/dataset"  # path where we expect to find directories named "positives", "negatives" and the file "dataset.csv" (rows: "<file path>,<positive|negative>")
SEED = 0  # seed used to shuffle the file lists
BATCH_SIZE = 64
EPOCHS = 50
CHECKPOINTS_DIR = "../../../data/model_saves/ResNet-50"  # directory where the model will save its history and checkpoints

FOLDS = 5   # number of folds for the cross-validation

# Maps a file path to an (image, label) pair; size=224 and channels=3 must match
# the input shape given to `m.build` in the training cells.
data_loader = get_load_spectro_for_class(size=224, channels=3)
# Model factory: called once per training run, presumably returning a fresh,
# untrained Keras model each time -- TODO confirm in utils.training.keras_models.
model = resnet

## Load data

In [None]:
def _make_dataset(files, shuffle_each_epoch=False):
    """Build a batched, prefetched tf.data pipeline from a list of file paths.

    Each path is turned into a model input by `data_loader`, and the decoded
    examples are cached in memory after the first pass, so each file is loaded
    only once. Caching happens before batching and prefetching, which is the
    order recommended by the tf.data performance guide.

    When `shuffle_each_epoch` is True, the cached examples are reshuffled
    (deterministically from SEED) at every epoch. Without this, every epoch
    would see exactly the same batches in the same order.
    """
    dataset = tf.data.Dataset.from_tensor_slices(files).map(data_loader).cache()
    if shuffle_each_epoch:
        # max(..., 1): shuffle() rejects a zero buffer size.
        dataset = dataset.shuffle(max(len(files), 1), seed=SEED, reshuffle_each_iteration=True)
    return dataset.batch(batch_size=BATCH_SIZE).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)


# newline="" is what the csv module documentation requires for files passed to csv.reader.
with open(f"{ROOT_DIR}/dataset.csv", "r", newline="") as f:
    # Skip blank or malformed rows instead of failing on row[1].
    rows = [row for row in csv.reader(f, delimiter=",") if len(row) >= 2]

pos = [row[0] for row in rows if row[1] == "positive"]
neg = [row[0] for row in rows if row[1] == "negative"]
random.Random(SEED).shuffle(pos)
random.Random(SEED).shuffle(neg)
print(f"{len(pos)} positive files found and {len(neg)} negative files found")

# prepare lists for cross-validation
train_datasets, valid_datasets = [], []
for i in range(FOLDS):
    # Fold boundaries are computed from the number of positives and applied to
    # both classes. Each validation set therefore holds as many negatives as
    # positives (balanced), while training keeps every other negative
    # (unbalanced). Negatives past index len(pos) never go to validation. If
    # there are fewer negatives than positives, validation is unbalanced.
    start_valid_idx = int(len(pos) * i / FOLDS)
    end_valid_idx = int(len(pos) * (i + 1) / FOLDS)
    train_files = pos[:start_valid_idx] + pos[end_valid_idx:] + neg[:start_valid_idx] + neg[end_valid_idx:]  # unbalanced training set
    valid_files = pos[start_valid_idx:end_valid_idx] + neg[start_valid_idx:end_valid_idx]  # balanced validation set
    random.Random(SEED).shuffle(train_files)
    random.Random(SEED).shuffle(valid_files)
    train_datasets.append(_make_dataset(train_files, shuffle_each_epoch=True))
    valid_datasets.append(_make_dataset(valid_files))

# prepare a dataset with all data to train the model at the end
all_train_files = pos + neg
random.Random(SEED).shuffle(all_train_files)
all_train_dataset = _make_dataset(all_train_files, shuffle_each_epoch=True)

## Plot the data

In [None]:
cols = 8
lines = 4
batch_number = 1  # 0-based index of the first batch we want to inspect

to_show = cols * lines
plt.figure(figsize=(20, lines * 3))
shown = 0
# Skip the first `batch_number` batches, then fill the grid across as many
# batches as needed.
for images, y in valid_datasets[0].skip(batch_number):
    if shown >= to_show:
        break
    # len(images) rather than BATCH_SIZE: the last batch may be partial.
    for i in range(min(len(images), to_show - shown)):
        plt.subplot(lines, cols, 1 + shown)
        plt.xlabel("time (s)")
        plt.ylabel("frequency (Hz)")
        plt.imshow(images[i].numpy()[:, :, 0], cmap='inferno')  # first channel only
        label = "positive" if y[i] == 1 else "negative"
        plt.title(label)
        shown += 1
plt.show()

## Cross-validation training

In [None]:
for i in range(FOLDS):
    path_prefix = f'{CHECKPOINTS_DIR}/FOLD-{i}'
    history_file = f'{path_prefix}/history.pkl'

    # The history file is written only after a fold finishes training, so its
    # presence marks the fold as done and lets this cell resume after an interruption.
    if os.path.isfile(history_file):
        print(f"fold {i} already has an history file, skipping it")
        continue

    print(f"starting training of fold {i}")
    os.makedirs(path_prefix, exist_ok=True)
    m = model()
    m.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
        loss=tf.losses.binary_crossentropy,
        # Explicit BinaryAccuracy (threshold 0.5). Depending on the Keras version,
        # the string 'Accuracy' may resolve to tf.keras.metrics.Accuracy, which
        # tests exact equality between the labels and the continuous outputs.
        # Named 'Accuracy' to keep that history key.
        metrics=[tf.keras.metrics.BinaryAccuracy(name='Accuracy'), 'AUC'])

    m.build((BATCH_SIZE, 224, 224, 3))

    if i == 0:
        m.summary()

    # Doubled braces leave a literal "{epoch:04d}" template, which
    # ModelCheckpoint fills with the epoch number at each save.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=f'{path_prefix}/cp-{{epoch:04d}}.ckpt', save_weights_only=True, verbose=1)

    # No batch_size: the datasets are already batched, and Keras documents that
    # batch_size must not be specified for tf.data inputs.
    history = m.fit(
        train_datasets[i],
        validation_data=valid_datasets[i],
        epochs=EPOCHS,
        callbacks=[cp_callback],
    )

    with open(history_file, 'wb') as f:
        pickle.dump(history.history, f)

## Training on all data

In [None]:
path_prefix = f'{CHECKPOINTS_DIR}/all'
history_file = f'{path_prefix}/history.pkl'

print("starting training with all data")
os.makedirs(path_prefix, exist_ok=True)
m = model()
m.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=tf.losses.binary_crossentropy,
    # Explicit BinaryAccuracy (threshold 0.5), as in the cross-validation cell:
    # the string 'Accuracy' may resolve to an exact-equality metric.
    metrics=[tf.keras.metrics.BinaryAccuracy(name='Accuracy')])

m.build((BATCH_SIZE, 224, 224, 3))

# Doubled braces leave a literal "{epoch:04d}" template for ModelCheckpoint.
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=f'{path_prefix}/cp-{{epoch:04d}}.ckpt', save_weights_only=True, verbose=1)

# No batch_size: the dataset is already batched.
history = m.fit(
    all_train_dataset,
    epochs=EPOCHS,
    callbacks=[cp_callback],
)

with open(history_file, 'wb') as f:
    pickle.dump(history.history, f)

## Plot some example outputs of the network

In [None]:
m = model()
epoch = 31  # epoch checkpoint that we want to load
# Sub-directory written by the training cells: "FOLD-<i>" or "all". FOLD-0 is
# the default because the next cell displays valid_datasets[0], which the
# fold-0 model never trained on (the "all" model did).
run_name = "FOLD-0"
# Single braces: the value of `epoch` must be interpolated here. The training
# cells double them only to hand a literal template to ModelCheckpoint.
m.load_weights(f"{CHECKPOINTS_DIR}/{run_name}/cp-{epoch:04d}.ckpt")

In [None]:
cols = 8
lines = 4
batch_number = 1  # 0-based index of the first validation batch to display

to_show = cols * lines
plt.figure(figsize=(20, lines * 3))
shown = 0
for images, y in valid_datasets[0].skip(batch_number):
    if shown >= to_show:
        break
    # `m` is the model instance loaded in the previous cell; `model` is only the
    # factory function. Predict the whole batch in one call, not once per image.
    predictions = m.predict(images, verbose=0)
    # len(images) rather than BATCH_SIZE: the last batch may be partial.
    for i in range(min(len(images), to_show - shown)):
        plt.subplot(lines, cols, 1 + shown)
        plt.xlabel("time (s)")
        plt.ylabel("frequency (Hz)")
        plt.imshow(images[i].numpy()[:, :, 0], cmap='inferno')  # first channel only
        label = "positive" if y[i] == 1 else "negative"
        # Outputs are assumed to be probabilities (training used
        # binary_crossentropy without from_logits), so 0.5 is the threshold.
        predicted_label = "positive" if predictions[i] >= 0.5 else "negative"
        plt.title(f"{predicted_label}/{label}")
        shown += 1
plt.show()