In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing

# Compact numpy printing: 3 decimals, fixed-point notation (no scientific format).
np.set_printoptions(suppress=True, precision=3)

In [None]:
# Load the raw measurements (';'-separated, ',' as decimal mark) and discard the
# first three columns, which the models below never use.
# Columns are dropped *before* dropna() so that a row is not thrown away merely
# because of a missing value in a column that is discarded anyway.
dataset_raw = pd.read_csv('measurements.csv', sep=';', decimal=',')
dataset = dataset_raw.drop(columns=dataset_raw.columns[:3]).dropna()
dataset.tail()

In [None]:
# 80/20 train/test split with a fixed seed; the test rows are every row whose
# index label was not sampled into the training set.
train_dataset = dataset.sample(frac=0.8, random_state=42)
test_dataset = dataset.loc[~dataset.index.isin(train_dataset.index)]
train_dataset.describe().T

In [None]:
def _plot_train_val(history, key, train_label, val_label, ylabel, ax=None):
    """Plot one training metric and its ``val_`` counterpart against epochs.

    Args:
        history: object whose ``history`` attribute maps metric names to
            per-epoch value lists (e.g. the ``History`` returned by ``fit``).
        key: training metric name; the validation series is ``'val_' + key``.
        train_label: legend label for the training curve.
        val_label: legend label for the validation curve.
        ylabel: y-axis label.
        ax: matplotlib Axes to draw on; defaults to the current Axes.
    """
    if ax is None:
        ax = plt.gca()
    ax.plot(history.history[key], label=train_label)
    ax.plot(history.history['val_' + key], label=val_label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)


def plot_loss(history, name, ax=None):
    """Plot training and validation loss per epoch.

    ``history.history['loss']`` holds whatever loss the model was compiled with
    (``'mse'`` in this notebook), so the curves are labelled 'loss'/'val_loss'
    rather than with a specific metric name.

    Args:
        history: Keras ``History`` (or any object with a ``history`` dict).
        name: model name shown in the y-axis label.
        ax: optional matplotlib Axes; defaults to the current Axes.
    """
    _plot_train_val(history, 'loss', 'loss', 'val_loss',
                    'Loss [' + name + ']', ax=ax)


def plot_rmse(history, name, ax=None):
    """Plot training and validation RMSE per epoch.

    Args:
        history: Keras ``History`` whose dict contains
            ``'root_mean_squared_error'`` and ``'val_root_mean_squared_error'``.
        name: model name shown in the y-axis label.
        ax: optional matplotlib Axes; defaults to the current Axes.
    """
    _plot_train_val(history, 'root_mean_squared_error', 'rmse', 'val_rmse',
                    'RMSE [' + name + ']', ax=ax)

In [None]:
# Glottic task: regress 'antéro-post CV' from the remaining columns, with the
# cricoid target ('AP cricoïde') removed from the inputs.
train_features = train_dataset.copy()
test_features = test_dataset.copy()

train_ds_glottic = train_features.drop(columns=['AP cricoïde'])
test_ds_glottic = test_features.drop(columns=['AP cricoïde'])

# pop() removes the target column from the feature frame and returns it.
train_target_glottic = train_ds_glottic.pop('antéro-post CV')
test_target_glottic = test_ds_glottic.pop('antéro-post CV')

In [None]:
# Glottic model: feature normalization -> one 64-unit ReLU layer -> linear output,
# trained with MSE loss to regress 'antéro-post CV'.
tf.random.set_seed(42)  # reproducible weight init / shuffling on Restart & Run All

normalizer_glottic = preprocessing.Normalization()
# Fit the normalization statistics on the training features only.
normalizer_glottic.adapt(np.array(train_ds_glottic))

glottic_model = tf.keras.Sequential([
    normalizer_glottic,
    layers.Dense(units=64, activation='relu'),
    layers.Dense(units=1)
])

glottic_model.compile(
    optimizer=tf.optimizers.Adam(learning_rate=0.001),
    loss='mse',
    metrics=[tf.keras.metrics.RootMeanSquaredError(),
             tf.keras.metrics.MeanAbsoluteError()])

# NOTE(review): the held-out test split doubles as validation data, so any epoch
# selection based on val_* metrics is optimistic. Consider validation_split on the
# training data plus a final evaluate() on the test split.
history = glottic_model.fit(x=train_ds_glottic, y=train_target_glottic,
                            validation_data=(test_ds_glottic, test_target_glottic),
                            epochs=300,
                            verbose=0)  # avoid 300 epochs of progress output in the cell

In [None]:
# Report validation metrics from ONE epoch (the one with the lowest val_loss) so
# MAE and RMSE describe the same model state; taking each metric's own minimum
# can mix numbers from different epochs. nanargmin skips diverged (NaN) epochs.
best_epoch = int(np.nanargmin(history.history['val_loss']))
print('Glottic best epoch (lowest val_loss): {}'.format(best_epoch + 1))
print('Glottic MAE_val: {}'.format(history.history['val_mean_absolute_error'][best_epoch]))
print('Glottic RMSE_val: {}'.format(history.history['val_root_mean_squared_error'][best_epoch]))

plot_loss(history, 'Glottic')
plt.show()
plot_rmse(history, 'Glottic')
plt.show()

In [None]:
# Cricoid task: regress 'AP cricoïde' from the remaining columns, with the
# glottic target ('antéro-post CV') removed from the inputs.
train_features = train_dataset.copy()
test_features = test_dataset.copy()

train_ds_cricoid = train_features.drop(columns=['antéro-post CV'])
test_ds_cricoid = test_features.drop(columns=['antéro-post CV'])

# pop() strips the target column out of the feature frame and returns it.
train_target_cricoid = train_ds_cricoid.pop('AP cricoïde')
test_target_cricoid = test_ds_cricoid.pop('AP cricoïde')

In [None]:
# Cricoid model: feature normalization -> one 64-unit ReLU layer -> linear output,
# trained with MSE loss to regress 'AP cricoïde'.
tf.random.set_seed(42)  # reproducible weight init / shuffling on Restart & Run All

normalizer_cricoid = preprocessing.Normalization()
# Fit the normalization statistics on the training features only.
normalizer_cricoid.adapt(np.array(train_ds_cricoid))

cricoid_model = tf.keras.Sequential([
    normalizer_cricoid,
    layers.Dense(units=64, activation='relu'),
    layers.Dense(units=1)
])

cricoid_model.compile(
    # NOTE(review): 100x the glottic model's learning rate (0.001); 0.1 is unusually
    # high for Adam and may make training unstable — confirm this is intended.
    optimizer=tf.optimizers.Adam(learning_rate=0.1),
    loss='mse',
    metrics=[tf.keras.metrics.RootMeanSquaredError(),
             tf.keras.metrics.MeanAbsoluteError()])

# Save the weights of the epoch with the lowest validation loss. This regression
# model compiles no accuracy metric, so 'val_accuracy' is never logged and cannot
# be monitored; 'val_loss' (MSE) is, and lower is better.
# Restore later with cricoid_model.load_weights(checkpoint_filepath).
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_best_only=True)

# NOTE(review): the held-out test split doubles as validation data, so checkpoint
# selection on val_loss is optimistic with respect to test performance.
history = cricoid_model.fit(x=train_ds_cricoid, y=train_target_cricoid,
                            validation_data=(test_ds_cricoid, test_target_cricoid),
                            epochs=300,
                            callbacks=[model_checkpoint_callback],
                            verbose=0)  # avoid 300 epochs of progress output in the cell

In [None]:
# Report validation metrics from ONE epoch (the one with the lowest val_loss, which
# is also the epoch the checkpoint monitors) so MAE and RMSE describe the same model
# state. nanargmin skips epochs where training diverged to NaN.
best_epoch = int(np.nanargmin(history.history['val_loss']))
print('Cricoid best epoch (lowest val_loss): {}'.format(best_epoch + 1))
print('Cricoid MAE_val: {}'.format(history.history['val_mean_absolute_error'][best_epoch]))
print('Cricoid RMSE_val: {}'.format(history.history['val_root_mean_squared_error'][best_epoch]))

plot_loss(history, 'Cricoid')
plt.show()
plot_rmse(history, 'Cricoid')
plt.show()