In [None]:
import csv
import os
import random

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

from utils.training import data_loading
from utils.training.data_loading import list_to_generator_seg, load_spectro_for_seg
from utils.training.models import time_segmenter_model

In [None]:
# Root of the spectrogram dataset; it must contain `dataset.csv`.
# Set the SPECTRO_ROOT_DIR environment variable to use another location without editing this cell.
ROOT_DIR = os.environ.get(
    "SPECTRO_ROOT_DIR",
    "/media/plerolland/LaBoite/PublicData/training/classification/spectrograms",
)
SEED = 0  # drives the train/validation split and the global RNG seeds below
BATCH_SIZE = 64
EPOCHS = 50
CHECKPOINTS_DIR = "checkpoints_seg"  # directory for the per-epoch weight checkpoints

# Seed every RNG the notebook may touch (Python, NumPy, TensorFlow ops / weight initializers)
# so that Restart & Run All reproduces the run as closely as TensorFlow allows.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Read the dataset index. The second column of each row is its label ("positive" / "negative");
# rows with any other value there (e.g. a header row) are ignored by the filters below.
with open(f"{ROOT_DIR}/dataset.csv", "r", newline="") as f:
    csv_rows = list(csv.reader(f, delimiter=","))
pos = [row for row in csv_rows if row[1] == "positive"]
neg = [row for row in csv_rows if row[1] == "negative"]
# Balance the classes: keep as many negatives as positives (the first ones in file order).
files = pos + neg[:len(pos)]
print(len(files), "files found")

# Deterministic shuffle, then hold out the first 20% as the validation split.
random.Random(SEED).shuffle(files)
n_valid = int(0.2 * len(files))
train_files = files[n_valid:]
valid_files = files[:n_valid]
assert train_files and valid_files, "not enough files for a train/validation split"


def make_generator_fn(file_list):
    """Return the zero-argument callable that `tf.data.Dataset.from_generator` expects.

    Every call builds a new `list_to_generator_seg(file_list)` iterator, so each new pass over
    the dataset (the preview plots, `model.fit`, the prediction plots) gets its own generator
    instead of resuming one shared generator object that earlier passes already consumed.
    NOTE(review): assumes `list_to_generator_seg` returns a one-shot iterator, as its name suggests.
    """
    return lambda: map(tuple, list_to_generator_seg(file_list))


# Each generated element becomes a 1-D string tensor that is then mapped through load_spectro_for_seg.
element_spec = tf.TensorSpec(shape=[None], dtype=tf.string)
train_dataset = tf.data.Dataset.from_generator(make_generator_fn(train_files), output_signature=element_spec)
valid_dataset = tf.data.Dataset.from_generator(make_generator_fn(valid_files), output_signature=element_spec)
# Load the spectrograms, batch them, and prefetch so input loading overlaps with training
# (no caching is applied).
AUTOTUNE = tf.data.experimental.AUTOTUNE
train_dataset = train_dataset.map(load_spectro_for_seg).batch(batch_size=BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
valid_dataset = valid_dataset.map(load_spectro_for_seg).batch(batch_size=BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)

In [None]:
# Preview validation examples: each spectrogram is drawn with its ground-truth curve
# in the subplot directly below it.
n_cols = 8
n_rows = 4        # rows of (spectrogram, curve) pairs
batch_number = 2  # number of leading validation batches to skip

to_show = n_cols * n_rows
n_batches = -(-to_show // BATCH_SIZE)  # ceil division: batches needed to fill the grid
plt.figure(figsize=(20, n_rows * 5))
shown = 0
for images, y in valid_dataset.skip(batch_number).take(n_batches):
    # len(images), not BATCH_SIZE: the last batch of a dataset may be smaller.
    for i in range(min(len(images), to_show - shown)):
        # The grid has 2 * n_rows rows: spectrogram on one row, its curve on the row below.
        position = 1 + shown % n_cols + n_cols * 2 * (shown // n_cols)

        plt.subplot(n_rows * 2, n_cols, position)
        plt.xlabel("time (s)")
        plt.ylabel("frequency (Hz)")
        plt.imshow(images[i].numpy(), cmap='inferno')

        ax = plt.subplot(n_rows * 2, n_cols, position + n_cols)
        ax.plot(y[i], label='ground truth')
        ax.legend(loc="upper left")
        ax.set_xlim([0, 128])
        ax.set_ylim([0, 1])
        ax.set_xlabel("time (s)")
        ax.set_ylabel("probability")

        shown += 1
plt.show()

In [None]:
# Create and compile the time-segmentation model (binary cross-entropy loss, MAE reported as a metric).
model = time_segmenter_model()

model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
        loss=tf.losses.binary_crossentropy,
        # Keras documents `metrics` as a list; newer Keras releases reject a bare string.
        metrics=['MeanAbsoluteError'])

# Leave the batch dimension unspecified (None) so the model accepts any batch size,
# e.g. a smaller final batch or a single example, not only exactly BATCH_SIZE.
model.build((None, data_loading.SPECTRO_SIZE[0], data_loading.SPECTRO_SIZE[1], 1))

model.summary()

In [None]:
# Save the model weights after every epoch to CHECKPOINTS_DIR/cp-<epoch>.ckpt.
# The f-string substitutes CHECKPOINTS_DIR now; the doubled braces leave a literal
# `{epoch:04d}` placeholder that Keras fills with the epoch number when it saves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=f"{CHECKPOINTS_DIR}/cp-{{epoch:04d}}.ckpt",
                                                 save_weights_only=True,
                                                 verbose=1)

In [None]:
# Train the model, evaluating on the validation split after each epoch.
# Step counts are given explicitly because generator-backed datasets have unknown length.
# NOTE(review): floor division drops a final partial batch; confirm list_to_generator_seg
# yields (or repeats) enough examples for these step counts.
steps_per_epoch = len(train_files) // BATCH_SIZE
validation_steps = len(valid_files) // BATCH_SIZE
# No `batch_size` argument: the datasets are already batched, and Keras documents that it
# must not be specified for dataset inputs.
history = model.fit(
        train_dataset,
        validation_data=valid_dataset,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        epochs=EPOCHS,
        callbacks=[cp_callback]
    )

In [None]:
# Compare model predictions with the ground truth on validation examples: each spectrogram
# is drawn with its ground-truth and predicted curves in the subplot directly below it.
n_cols = 8
n_rows = 8          # rows of (spectrogram, curves) pairs
batch_number = 15   # number of leading validation batches to skip

to_show = n_cols * n_rows
n_batches = -(-to_show // BATCH_SIZE)  # ceil division: batches needed to fill the grid
plt.figure(figsize=(20, n_rows * 5))
shown = 0
for images, y in valid_dataset.skip(batch_number).take(n_batches):
    # len(images), not BATCH_SIZE: the last batch of a dataset may be smaller.
    n_here = min(len(images), to_show - shown)
    # One predict call for the whole slice instead of one per image; the reshape adds the
    # trailing channel axis, matching the input shape passed to model.build.
    batch_input = np.reshape(images[:n_here].numpy(),
                             (-1, data_loading.SPECTRO_SIZE[0], data_loading.SPECTRO_SIZE[1], 1))
    predictions = model.predict(batch_input, verbose=False)
    for i in range(n_here):
        # Position from the running `shown` count (not `i`), so panels from a later batch
        # continue the grid instead of overwriting the first ones.
        position = 1 + shown % n_cols + n_cols * 2 * (shown // n_cols)

        plt.subplot(n_rows * 2, n_cols, position)
        plt.xlabel("time (s)")
        plt.ylabel("frequency (Hz)")
        plt.imshow(images[i].numpy(), cmap='inferno')

        ax = plt.subplot(n_rows * 2, n_cols, position + n_cols)
        ax.plot(y[i], label='ground truth')
        ax.plot(predictions[i], label='predicted')
        ax.legend(loc="upper left")
        ax.set_xlim([0, 128])
        ax.set_ylim([0, 1])
        ax.set_xlabel("time (s)")
        ax.set_ylabel("probability")

        shown += 1
plt.show()