This notebook trains the TiSSNet model.

In [1]:
import random
import csv

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

from utils.training.metrics import accuracy_for_segmenter, AUC_for_segmenter
from utils.training.data_loading import lines_to_line_generator, get_line_to_spectro_seg
from utils.training.keras_models import TiSSNet

In [None]:
# List the physical devices TensorFlow can see (e.g. to check whether a GPU is available before training).
available_devices = tf.config.list_physical_devices()
print(available_devices)

## Parameters

In [3]:
# Root of the public dataset on the author's machine -- adapt this to your own setup.
DATA_ROOT = "/media/plerolland/LaBoite/PublicData"
# Paths where we expect to find directories named "positives", "negatives" and a "dataset.csv" file.
ROOT_DIRS = [f"{DATA_ROOT}/training/spectrograms",
             f"{DATA_ROOT}/test/OHASISBIO_2020/STFT",
             f"{DATA_ROOT}/test/HYDROMOMAR_2013/STFT"]
SEED = 0  # seed used for the RNGs (train/validation shuffling and TensorFlow randomness)
BATCH_SIZE = 64
EPOCHS = 50
CHECKPOINTS_DIR = "../../../../data/model_saves/TiSSNet_final"  # directory where the model will save its history and checkpoints

SIZE = (128, 186)  # number of pixels in the spectrograms, as (frequency, time)
DURATION_S = 100  # duration of the spectrograms in s
OBJECTIVE_CURVE_WIDTH = 10  # defines width of objective function in s

# The shuffle uses its own random.Random(SEED) instance, which does not seed TensorFlow.
# Seed TensorFlow too, so weight initialization and other TF-level randomness are reproducible.
# GPU kernels may still introduce some nondeterminism.
tf.random.set_seed(SEED)

data_loader = get_line_to_spectro_seg(size=SIZE, duration_s=DURATION_S, channels=1, objective_curve_width=OBJECTIVE_CURVE_WIDTH)
model = TiSSNet

## Load data

In [4]:
def _make_batched_dataset(files):
    """Build the batched tf.data pipeline over the given CSV rows.

    The generator is created *inside* the callable given to ``from_generator``. As a result,
    every iteration of the dataset (each validation pass, the preview plot, ...) starts from
    a fresh generator instead of sharing, and consuming, one generator object.

    Args:
        files: list of CSV rows (lists of strings) to feed to ``lines_to_line_generator``.

    Returns:
        A dataset mapped through ``data_loader`` and batched with ``BATCH_SIZE``.
    """
    dataset = tf.data.Dataset.from_generator(
        lambda: map(tuple, lines_to_line_generator(files)),
        output_signature=tf.TensorSpec(shape=[None], dtype=tf.string))
    return dataset.map(data_loader).batch(batch_size=BATCH_SIZE)


# List the data samples: every row of each root's "dataset.csv" is treated as one sample.
lines = []
for root in ROOT_DIRS:
    with open(root + "/dataset.csv", "r", newline="") as f:
        lines.extend(csv.reader(f, delimiter=","))
print(len(lines), "files found")
assert lines, "no samples found in any dataset.csv of ROOT_DIRS"

random.Random(SEED).shuffle(lines)
n_valid = int(0.2 * len(lines))  # the first 20 % of the shuffled rows are used for validation
train_files = lines[n_valid:]
valid_files = lines[:n_valid]
train_dataset = _make_batched_dataset(train_files)
valid_dataset = _make_batched_dataset(valid_files)

## Plot of the data

In [5]:
# Preview some training samples: each spectrogram is drawn above its ground-truth curve.
cols = 8
rows = 2  # named `rows` (not `lines`) so the list of dataset rows is not shadowed
batch_number = 1  # index of the first batch to inspect (earlier batches are skipped)

to_show = cols * rows
n_batches = -(-to_show // BATCH_SIZE)  # ceil(to_show / BATCH_SIZE)
plt.figure(figsize=(cols*2.5, rows*5))
shown = 0
for images, y in train_dataset.skip(batch_number).take(n_batches):
    # len(images) rather than BATCH_SIZE: the last batch of a finite dataset may be smaller.
    for i in range(min(len(images), to_show - shown)):
        # Each sample uses two grid rows: the spectrogram on top and its curve just below.
        position = 1 + shown % cols + cols*2*(shown // cols)

        ax1 = plt.subplot(rows*2, cols, position)
        ax1.imshow(images[i].numpy(), cmap='inferno')
        ax1.set_xlabel("time")
        ax1.set_ylabel("frequency")

        ax2 = plt.subplot(rows*2, cols, position + cols)
        ax2.plot(y[i], label='ground truth')
        ax2.legend(loc="upper left")
        ax2.set_xlim([0, SIZE[1]])
        ax2.set_ylim([0, 1])
        ax2.set_xlabel("time")
        ax2.set_ylabel("probability")

        shown += 1
plt.show()

## Training on all the data (80 % training / 20 % validation split)

In [6]:
# Build, compile and train the model. ModelCheckpoint writes the weights to a numbered file after each epoch.
m = model()

m.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
        loss=tf.losses.binary_crossentropy,
        metrics=[accuracy_for_segmenter, AUC_for_segmenter()])

# Leave the batch dimension unspecified (None) so the model is not tied to BATCH_SIZE.
# Predictions are later made on single samples or partial batches.
m.build((None, SIZE[0], SIZE[1], 1))

m.summary()

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=f"{CHECKPOINTS_DIR}/checkpoints/cp-{{epoch:04d}}.ckpt",
                                                 save_weights_only=True,
                                                 verbose=1)

m.fit(
        train_dataset,
        # No batch_size here: the datasets are already batched, and Keras documents that
        # batch_size must not be specified for tf.data inputs.
        validation_data=valid_dataset,
        # At least one step each: a split with fewer than BATCH_SIZE rows would otherwise get 0 steps.
        steps_per_epoch=max(1, len(train_files) // BATCH_SIZE),
        validation_steps=max(1, len(valid_files) // BATCH_SIZE),
        epochs=EPOCHS,
        callbacks=[cp_callback]
    )

## Plot some examples of network outputs (validation set)

In [7]:
# Reload the weights saved after a given epoch, then compile so the model can be used for prediction.
# NOTE(review): epoch 22 is presumably the best validation epoch -- TODO confirm.
CHECKPOINT_EPOCH = 22
checkpoint_path = f"{CHECKPOINTS_DIR}/checkpoints/cp-{CHECKPOINT_EPOCH:04d}.ckpt"
m.load_weights(checkpoint_path)
m.compile()

In [9]:
# Compare ground truth and network output on some validation samples. Each spectrogram is
# drawn above a plot of its ground-truth curve and the predicted curve.
cols = 5
rows = 10  # named `rows` (not `lines`) so the list of dataset rows is not shadowed
batch_number = 0  # index of the first batch to inspect (earlier batches are skipped)
freq_max_hz = 120  # upper frequency bound of the spectrogram extent (value kept from the original plot)

to_show = cols * rows
n_batches = -(-to_show // BATCH_SIZE)  # ceil(to_show / BATCH_SIZE)
# Curves are indexed by time pixel: SIZE[1] pixels span DURATION_S seconds.
seconds_per_pixel = DURATION_S / SIZE[1]
plt.figure(figsize=(cols*5, rows*10))
shown = 0
for images, y in valid_dataset.skip(batch_number).take(n_batches):
    # len(images) rather than BATCH_SIZE: the last batch of a finite dataset may be smaller.
    n_needed = min(len(images), to_show - shown)
    if n_needed <= 0:
        break
    # One predict call per batch instead of one call per sample.
    predictions = m.predict(np.reshape(images[:n_needed], (-1, SIZE[0], SIZE[1], 1)), verbose=False)
    for i in range(n_needed):
        # Position comes from `shown` (not `i`) so samples from a later batch do not overwrite earlier ones.
        position = 1 + shown % cols + cols*2*(shown // cols)

        ax1 = plt.subplot(rows*2, cols, position)
        ax1.imshow(images[i].numpy(), cmap='inferno', extent=(0, DURATION_S, 0, freq_max_hz))
        ax1.set_xlabel("time (s)")
        ax1.set_ylabel("frequency (Hz)")

        ax2 = plt.subplot(rows*2, cols, position + cols)
        # Convert pixel indices to seconds so the x-axis matches its "time (s)" label and the spectrogram above.
        truth = y[i]
        predicted = predictions[i]
        ax2.plot(np.arange(len(truth)) * seconds_per_pixel, truth, label='ground truth')
        ax2.plot(np.arange(len(predicted)) * seconds_per_pixel, predicted, label='predicted')
        ax2.legend(loc="upper left")
        ax2.set_xlim([0, DURATION_S])
        ax2.set_ylim([0, 1])
        ax2.set_xlabel("time (s)")
        ax2.set_ylabel("probability")

        shown += 1
plt.show()