This notebook trains the AcousticPhaseNet model.

In [None]:
import random
import csv
import os
import pickle

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

from sklearn.utils import shuffle

from src.utils.training.metrics import accuracy_for_segmenter, AUC_for_segmenter
from src.utils.training.data_loading import get_line_to_dataset_waveform
from src.utils.training.keras_models import AcousticPhaseNet

## Parameters

In [None]:
ROOT_DIR = "/path/to/the/dataset"  # directory expected to contain the "positives" and "negatives" directories and the dataset.csv file
SEED = 0  # seed for the random number generators (row and sample shuffling)
BATCH_SIZE = 32
EPOCHS = 50
CHECKPOINTS_DIR = "../../../../data/model_saves/AcousticPhaseNet"  # directory where the model will save its history and checkpoints

FOLDS = 5  # number of folds for the cross-validation
DURATION_S = 100  # duration of the files in s
POINTS_PER_S = 240  # number of points per second of signal in each file
# Number of points in each file (DURATION_S * POINTS_PER_S) rounded up to the next power of 2.
# `1 << n.bit_length()` equals 2**ceil(log2(n + 1)) but uses exact integer arithmetic.
SIZE = 1 << (DURATION_S * POINTS_PER_S).bit_length()
OBJECTIVE_CURVE_WIDTH = 10  # width of the objective curve, in s

data_loader = get_line_to_dataset_waveform(size=SIZE, duration_s=DURATION_S, objective_curve_width=OBJECTIVE_CURVE_WIDTH)
model = AcousticPhaseNet

## Load data

In [None]:
# Read the CSV listing the data files. Rows are filtered below on their second
# column, which is expected to be "positive" or "negative".
# newline="" is what the csv module expects for files passed to csv.reader.
with open(os.path.join(ROOT_DIR, "dataset.csv"), "r", newline="") as f:
    lines = list(csv.reader(f, delimiter=","))
# Shuffle the rows deterministically for the given SEED.
random.Random(SEED).shuffle(lines)
print(len(lines), "files found")

# Load the waveforms and targets of each class.
pos = [row for row in lines if row[1] == "positive"]
xpos, ypos = data_loader(pos)
neg = [row for row in lines if row[1] == "negative"]
xneg, yneg = data_loader(neg)
print(f"{len(xpos)} positive files found and {len(xneg)} negative files found")

# Build a balanced dataset with at most as many negatives as positives; the
# remaining negatives are kept aside in extra_x / extra_y.
xd = np.concatenate((xpos, xneg[:len(xpos)]))
extra_x = xneg[len(xpos):]
yd = np.concatenate((ypos, yneg[:len(ypos)]))
extra_y = yneg[len(ypos):]
# Seeded shuffle: the cross-validation folds are contiguous slices of xd, so the
# order must be identical across runs for folds that are skipped (already
# trained) and folds that are (re)trained to come from the same partition.
xd, yd = shuffle(xd, yd, random_state=SEED)

## Plot the data

In [None]:
# Grid of samples: each sample gets two vertically stacked subplots, the
# waveform on top and its target curve below.
# `rows` is deliberately not named `lines`, which holds the CSV rows loaded above.
cols = 8
rows = 4
batch_number = 1  # number of the batch we want to inspect

first = batch_number * BATCH_SIZE
last = min(first + cols * rows, len(xd))  # never index past the end of the dataset

plt.figure(figsize=(cols * 2.5, rows * 5))
for shown, i in enumerate(range(first, last)):
    x, y = xd[i], yd[i]
    # Subplot index of the waveform; the target goes one grid row (cols) below.
    top_index = 1 + shown % cols + cols * 2 * (shown // cols)

    ax1 = plt.subplot(rows * 2, cols, top_index)
    ax1.plot(x)
    ax1.set_xlabel("time")
    ax1.set_ylabel("pressure")
    ax1.set_xlim([0, SIZE])

    ax2 = plt.subplot(rows * 2, cols, top_index + cols)
    ax2.plot(y, label='ground truth')
    ax2.legend(loc="upper left")
    ax2.set_xlim([0, SIZE])
    ax2.set_ylim([0, 1])
    ax2.set_xlabel("time")
    ax2.set_ylabel("probability")
plt.show()

## Cross-validation training

In [None]:
# Train one model per cross-validation fold. Fold i writes its checkpoints and,
# once training is over, its history to CHECKPOINTS_DIR/FOLD-<i>.
for i in range(FOLDS):
    path_prefix = f'{CHECKPOINTS_DIR}/FOLD-{i}'
    history_file = f'{path_prefix}/history.pkl'

    # The history file is only written after fit() returns, so its presence
    # means this fold already finished; this lets an interrupted run resume.
    if os.path.isfile(history_file):
        print(f"fold {i} already has an history file, skipping it")
        continue

    print(f"starting training of fold {i}")
    m = model()
    m.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
        loss=tf.losses.binary_crossentropy,
        metrics=[accuracy_for_segmenter, AUC_for_segmenter()])
    m.build((BATCH_SIZE, SIZE))

    if i == 0:
        m.summary()

    # `{{epoch:04d}}` stays a literal `{epoch:04d}` after f-string formatting;
    # ModelCheckpoint substitutes the epoch number when it saves.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=f'{path_prefix}/cp-{{epoch:04d}}.ckpt', save_weights_only=True, verbose=1)

    # We make the folds right before using them to save memory: fold i
    # validates on the i-th contiguous slice of xd and trains on the rest.
    start_valid_idx = int(len(xd) * i / FOLDS)
    end_valid_idx = int(len(xd) * (i + 1) / FOLDS)

    x_train = np.concatenate((xd[:start_valid_idx], xd[end_valid_idx:]))
    y_train = np.concatenate((yd[:start_valid_idx], yd[end_valid_idx:]))

    x_valid = xd[start_valid_idx:end_valid_idx]
    y_valid = yd[start_valid_idx:end_valid_idx]

    # Seeded shuffles make the sample order reproducible for a given SEED.
    x_train, y_train = shuffle(x_train, y_train, random_state=SEED)
    x_valid, y_valid = shuffle(x_valid, y_valid, random_state=SEED)
    # Flatten the targets to (n_samples, SIZE).
    y_train = np.reshape(y_train, (-1, SIZE))
    y_valid = np.reshape(y_valid, (-1, SIZE))

    history = m.fit(x_train, y_train,
                    batch_size=BATCH_SIZE,
                    validation_data=(x_valid, y_valid),
                    epochs=EPOCHS,
                    callbacks=[cp_callback])

    # Ensure the fold directory exists before writing the history into it.
    os.makedirs(path_prefix, exist_ok=True)
    with open(history_file, 'wb') as f:
        pickle.dump(history.history, f)

## Training on all the data

In [None]:
# Train a final model on the balanced dataset plus the negatives left out of it.
all_train_x = np.concatenate((xd, extra_x))
all_train_y = np.concatenate((yd, extra_y))
all_train_x, all_train_y = shuffle(all_train_x, all_train_y, random_state=SEED)
# Flatten the targets to (n_samples, SIZE), exactly as done for the cross-validation folds.
all_train_y = np.reshape(all_train_y, (-1, SIZE))

m = model()
m.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
    loss=tf.losses.binary_crossentropy,
    metrics=[accuracy_for_segmenter, AUC_for_segmenter()])
m.build((BATCH_SIZE, SIZE))
m.summary()

# `{{epoch:04d}}` stays a literal `{epoch:04d}` after f-string formatting;
# ModelCheckpoint substitutes the epoch number when it saves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=f"{CHECKPOINTS_DIR}/cp-{{epoch:04d}}.ckpt",
    save_weights_only=True,
    verbose=1)

# Keep the training history available in the notebook.
history = m.fit(
    all_train_x, all_train_y,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[cp_callback])

## Plot some example outputs of the network

In [None]:
m = model()
epoch = 22  # epoch checkpoint that we want to load
# The all-data training saves its checkpoints as CHECKPOINTS_DIR/cp-<epoch>.ckpt,
# with the epoch zero-padded to 4 digits. Single braces here so the f-string
# actually substitutes the epoch number.
m.load_weights(f"{CHECKPOINTS_DIR}/cp-{epoch:04d}.ckpt")

In [None]:
# Plot network predictions against the ground truth for a grid of samples.
# NOTE(review): samples come from xd, which the all-data model was trained on,
# so these are not held-out examples.
# `rows` is deliberately not named `lines`, which holds the CSV rows loaded above.
cols = 4
rows = 2
to_skip = 2  # number of full grids of samples to skip before the ones shown

to_show = cols * rows
first = to_skip * to_show
last = min(first + to_show, len(xd))  # never index past the end of the dataset

samples_x = xd[first:last]
samples_y = yd[first:last]
# One batched prediction instead of one predict() call per sample.
predictions = m.predict(np.reshape(samples_x, (-1, SIZE)), verbose=False) if len(samples_x) else []

plt.figure(figsize=(cols * 5, rows * 10))
for shown, (x, y, predicted) in enumerate(zip(samples_x, samples_y, predictions)):
    # Subplot index of the waveform; the curves go one grid row (cols) below.
    top_index = 1 + shown % cols + cols * 2 * (shown // cols)

    ax1 = plt.subplot(rows * 2, cols, top_index)
    ax1.plot(x)
    ax1.set_xlabel("time")
    ax1.set_ylabel("normalized pressure")
    ax1.set_xlim([0, SIZE])

    ax2 = plt.subplot(rows * 2, cols, top_index + cols)
    ax2.plot(predicted, label='predicted')
    ax2.plot(y, label='ground truth')
    ax2.legend(loc="upper right")
    ax2.set_xlim([0, SIZE])
    ax2.set_ylim([0, 1])
    ax2.set_xlabel("time")
    ax2.set_ylabel("probability")
plt.show()