# Postoperative movement classification of mice

In [None]:
import tensorflow as tf
import keras
import numpy as np
import importlib
import keras
from keras import layers
from keras import backend as K
import pickle as pkl
import matplotlib.pyplot as plt

from mice_wellbeing import data_handler
from mice_wellbeing import model_builder
from mice_wellbeing import helper_func

In [None]:
# Re-import the project modules so that edits to their source files take
# effect without restarting the notebook kernel.
for _module in (data_handler, model_builder, helper_func):
    importlib.reload(_module)

## Load and prepare data

In [None]:
# Build the lookup that holds the label of each video.
label_lookup = data_handler.create_label_lookup(
    path="./video/Homecage_Observation",
)

In [None]:
# Read the .csv files with joint coordinates together with their labels.
X_pos, X_rel, y = data_handler.load_data_and_labels(
    datapath="/path/to/DLC_project_folder/videos",
)

In [None]:
# sanity check: display the shape of the first entry of X_pos
X_pos[0].shape

In [None]:
# Cut the recordings into a larger number of shorter windows.
X_pos_split, X_rel_split, y_split = data_handler.window_data(
    X_pos,
    X_rel,
    y,
    num_frames=250,
    overlap=150,
    start_buffer=0,
    end_buffer=0,
)

In [None]:
# look at index 0 of the displayed shape for the total sample count after slicing
X_pos_split.shape

In [None]:
# Shuffle the windows. One shared permutation is applied to all three arrays so
# that features and labels stay aligned; each array is reordered in place
# through the `out=` argument.
# NOTE(review): the windows overlap (see `overlap` in the slicing cell) and are
# shuffled before the train/val/test split, so overlapping windows of one video
# can end up in different splits — confirm this is intended.
permutation = np.random.permutation(X_pos_split.shape[0])
for _arr in (X_pos_split, X_rel_split, y_split):
    np.take(_arr, permutation, axis=0, out=_arr)

In [None]:
# Fractions for the split; whatever remains after train and val becomes test.
train_pct = 0.7
val_pct = 0.15

# The same settings are used for all three arrays.
_split_args = dict(
    train_pct=train_pct,
    val_pct=val_pct,
    ds_length=X_pos_split.shape[0],
)

X_pos_train, X_pos_val, X_pos_test = data_handler.np_train_val_test_split(X_pos_split, **_split_args)
X_rel_train, X_rel_val, X_rel_test = data_handler.np_train_val_test_split(X_rel_split, **_split_args)
y_train, y_val, y_test = data_handler.np_train_val_test_split(y_split, **_split_args)

In [None]:
# check for approx. equal distribution of classes in train:
# displays the distinct labels in y_train and how often each one occurs
np.unique(y_train, return_counts=True)

In [None]:
# tf.data datasets for training, validation and testing
def _slices(features, labels, name):
    """Return a named dataset of per-sample (features, labels) slices.

    `features` is expected to be a fresh copy already; the labels are copied here.
    """
    return tf.data.Dataset.from_tensor_slices((features, labels.copy()), name=name)


# single-stream datasets built from the X_pos arrays
pos_ds_train = _slices(X_pos_train.copy(), y_train, "pos_train")
pos_ds_val = _slices(X_pos_val.copy(), y_val, "pos_val")
pos_ds_test = _slices(X_pos_test.copy(), y_test, "pos_test")

# single-stream datasets built from the X_rel arrays
rel_ds_train = _slices(X_rel_train.copy(), y_train, "rel_train")
rel_ds_val = _slices(X_rel_val.copy(), y_val, "rel_val")
rel_ds_test = _slices(X_rel_test.copy(), y_test, "rel_test")

# two-stream datasets: every sample carries an (X_pos, X_rel) pair as features
comb_ds_train = _slices((X_pos_train.copy(), X_rel_train.copy()), y_train, "comb_train")
comb_ds_val = _slices((X_pos_val.copy(), X_rel_val.copy()), y_val, "comb_val")
comb_ds_test = _slices((X_pos_test.copy(), X_rel_test.copy()), y_test, "comb_test")

In [None]:
# Shuffle and batch the datasets.
batch_size = 128
_shuffle_buffer = 50000


def _shuffle_batch_shuffle(ds):
    """Shuffle the samples, group them into batches, then shuffle the batches.

    Both shuffles are redone on every iteration over the dataset.
    """
    shuffled = ds.shuffle(_shuffle_buffer, reshuffle_each_iteration=True)
    batched = shuffled.batch(batch_size)
    return batched.shuffle(_shuffle_buffer, reshuffle_each_iteration=True)


# training data: shuffled
pos_ds_train = _shuffle_batch_shuffle(pos_ds_train)
rel_ds_train = _shuffle_batch_shuffle(rel_ds_train)
comb_ds_train = _shuffle_batch_shuffle(comb_ds_train)

# validation and test data: only batched, order is kept
pos_ds_val, pos_ds_test = pos_ds_val.batch(batch_size), pos_ds_test.batch(batch_size)
rel_ds_val, rel_ds_test = rel_ds_val.batch(batch_size), rel_ds_test.batch(batch_size)
comb_ds_val, comb_ds_test = comb_ds_val.batch(batch_size), comb_ds_test.batch(batch_size)

## Create and train model

In [None]:
# create a 2-stream model to use with comb_ds
# !overwrites previous models with same var name!
# 1-stream model definition below
#
# Each stream (pos / rel) runs through a skeleton transformer, a convolution and
# a pooling layer; the two streams are then concatenated along axis 1, convolved
# once more, flattened and classified by a 4-unit softmax layer.
pos_input = keras.Input(X_pos_split.shape[1:], name="pos_input")
rel_input = keras.Input(X_rel_split.shape[1:], name="rel_input")

# for shared skeleton transformer
#-----
# one layer instance is applied to both inputs, so both streams use the same layer
skeleton_transformer = model_builder.SkeletonTransformerLayerV2((X_pos_split.shape[1:]), name="comb_transformer")

pos_x = skeleton_transformer(pos_input)
rel_x = skeleton_transformer(rel_input)
#-----

# for separate skeleton transformers
#-----
# pos_x = model_builder.SkeletonTransformerLayerV2((X_pos_split.shape[1:]), name="pos_transformer")(pos_input)
# rel_x = model_builder.SkeletonTransformerLayerV2((X_pos_split.shape[1:]), name="rel_transformer")(rel_input)
#-----

pos_x = layers.Conv2D(8, (10, 100), data_format="channels_first", padding="same", name="pos_conv")(pos_x)
rel_x = layers.Conv2D(8, (10, 100), data_format="channels_first", padding="same", name="rel_conv")(rel_x)

# pooling can be removed here
#-----
pos_x = layers.MaxPool2D((1, 2), data_format="channels_first", name="pos_pooling")(pos_x)
rel_x = layers.MaxPool2D((1, 2), data_format="channels_first", name="rel_pooling")(rel_x)
#-----

# wrap each stream in its own model; only their outputs are used below
pos_x = keras.Model(inputs=pos_input, outputs=pos_x, name="pos_model")
rel_x = keras.Model(inputs=rel_input, outputs=rel_x, name="rel_model")

# axis=1 is the channel axis of the channels_first tensors
x = layers.concatenate([pos_x.output, rel_x.output], name="concat_layer", axis=1)

x = layers.Conv2D(4, (5, 25), data_format="channels_first", padding="same", name="comb_conv")(x)

x = layers.Flatten(name="flatten_layer")(x)
# use the public activation identifier instead of the non-public
# `layers.activation` module path, which is not part of the keras.layers API
x = layers.Dense(4, activation="softmax", name="dense_output")(x)

model = keras.Model(inputs=[pos_input, rel_input], outputs=x, name="comb_model")

In [None]:
# create a 1-stream model to use with rel_ds or pos_ds
# !overwrites previous models with same var name!
#
# skeleton transformer -> convolution -> pooling -> flatten -> 4-unit softmax
# NOTE(review): no explicit keras.Input is added here; whether the model is built
# before the first call depends on SkeletonTransformerLayerV2 — confirm that
# model.summary() works right after this cell.
model = keras.Sequential()

model.add(model_builder.SkeletonTransformerLayerV2((X_pos_split.shape[1:])))

model.add(layers.Conv2D(2, (5, 50), data_format="channels_first", padding="same"))

# the convolution above is channels_first, so the pooling must be told the same
# layout; with the default layout the pooling window is applied to other axes
# than in the 2-stream model
model.add(layers.MaxPool2D((1, 2), data_format="channels_first"))

model.add(layers.Flatten())

# use the public activation identifier instead of the non-public
# `layers.activation` module path, which is not part of the keras.layers API
model.add(layers.Dense(4, activation="softmax"))

In [None]:
# print an overview of the model that was defined last
model.summary()

In [None]:
# draw the model graph; rankdir="LR" lays it out from left to right
keras.utils.plot_model(model, rankdir="LR")

In [None]:
# Compile the model: Adam optimizer, sparse categorical cross-entropy as the
# loss and accuracy as an additional metric.
model.compile(
    optimizer="adam",
    loss=[keras.losses.sparse_categorical_crossentropy],
    metrics=["accuracy"],
)

In [None]:
# define early stopping: stop once val_loss has not improved for 100 epochs
early_stopping_cb = keras.callbacks.EarlyStopping(monitor='val_loss', patience=100)

# define checkpoints
# The path is relative, like every other model path in this notebook, instead of
# pointing into one user's home directory. "<model_name>" is presumably a
# placeholder to be replaced; {epoch} and {val_loss} are filled in by Keras.
checkpoint_filepath = './models/ckpt/<model_name>_{epoch:02d}-{val_loss:.2f}.keras'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_loss',
    mode='auto',
    # only keep models that improve on the best val_loss seen so far ...
    save_best_only=True,
    # ... where the "best so far" starts at 0.8
    initial_value_threshold=0.8,
)

In [None]:
# Overwrite the learning rate of the compiled model's optimizer.
K.set_value(model.optimizer.learning_rate, 1e-6)  # 0.000001

In [None]:
# Train the model. Pick the train/val datasets that match the model:
#   2-stream: comb_ds_train and comb_ds_val
#   1-stream: pos_ds_train and pos_ds_val, or rel_ds_train and rel_ds_val
history = model.fit(
    pos_ds_train,
    epochs=10000,
    callbacks=[early_stopping_cb, model_checkpoint_callback],
    validation_data=pos_ds_val,
)

In [None]:
# loss graph: training loss, validation loss and their absolute difference per epoch
loss_delta = np.array(history.history['loss']) - np.array(history.history['val_loss'])
loss_delta = np.abs(loss_delta)

plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.plot(loss_delta)
plt.title("model loss")
plt.ylabel('loss')
plt.xlabel('epoch')
# the second curve is the validation loss, not a test loss
plt.legend(['train', 'val', "delta"], loc='upper left')
plt.show()

In [None]:
# accuracy graph: training accuracy, validation accuracy and their absolute difference per epoch
accuracy_delta = np.array(history.history["accuracy"]) - np.array(history.history["val_accuracy"])
accuracy_delta = np.abs(accuracy_delta)

plt.plot(history.history["accuracy"])
plt.plot(history.history["val_accuracy"])
plt.plot(accuracy_delta)
plt.title("model accuracy")
plt.ylabel("accuracy")
plt.xlabel("epoch")
# the second curve is the validation accuracy, not a test accuracy
plt.legend(["train", "val", "delta"], loc="upper left")
plt.show()

In [None]:
# evaluation on test data (change var name according to what was used in training)
# The training cell fits on pos_ds_train / pos_ds_val, so the matching test set is
# pos_ds_test; use rel_ds_test or comb_ds_test when training on those datasets.
model.evaluate(pos_ds_test)

In [None]:
# save model
# ("<model_name>" in the path is presumably a placeholder to be replaced — confirm)
model.save("./models/<model_name>.keras")

In [None]:
# save history
# Only the plain dict of per-epoch metrics (history.history) is pickled. The
# History callback object itself also references the model, which makes the
# pickle depend on the model being picklable and loadable.
with open("./models/<model_name>.pkl", "wb") as file:
    pkl.dump(history.history, file)

In [None]:
# Load a model from a checkpoint file. The custom layer class has to be handed
# to Keras so that it can be restored.
loaded_model = keras.models.load_model(
    "./models/ckpt/<checkpoint_name>.keras",
    custom_objects={"SkeletonTransformerLayerV2": model_builder.SkeletonTransformerLayerV2},
)

## Test model on excluded videos

In [None]:
# pick up changes made to helper_func without restarting the notebook kernel
importlib.reload(helper_func)

In [None]:
# Load the model that is to be tested on the excluded videos. The custom layer
# class has to be handed to Keras so that it can be restored.
loaded_model = keras.models.load_model(
    "./models/ckpt_new/<model_name>.keras",
    custom_objects={"SkeletonTransformerLayerV2": model_builder.SkeletonTransformerLayerV2},
)

In [None]:
# Run the loaded model on the excluded videos.
# The arguments must reflect how the training data was prepared.
result_dict = helper_func.predict_videos(
    loaded_model,
    input_mode="comb",
    num_frames=250,
    overlap=150,
    batch_size=128,
    start_buffer=0,
    end_buffer=0,
)

In [None]:
# create additional metrics for excluded videos
#
# For every video two decisions are derived and stored back in its entry of
# result_dict:
#   "y_guess"          - the label with the highest count in "y_count"
#   "y_guess_weighted" - the mean of "y_pred" over axis 0; its argmax is the
#                        weighted decision
# "y_corr" / "y_corr_weighted" record whether the decision equals the label "y".
p_cutoff = 0.4  # minimum value of the largest weighted score for the "p cutoff" counters

count_true = 0
count_false = 0
count_true_weighted = 0
count_false_weighted = 0
count_true_weighted_p_cutoff = 0
count_false_weighted_p_cutoff = 0

for vid, pred_dict in result_dict.items():
    # NOTE(review): assumes pred_dict["y_count"] is (labels, counts) as returned by
    # np.unique(..., return_counts=True) — TODO confirm in helper_func.predict_videos.
    # The position of the largest count only equals the class label when every
    # class occurs, so the label is looked up at that position.
    labels, counts = pred_dict["y_count"][0], pred_dict["y_count"][1]
    pred_dict["y_guess"] = labels[counts.argmax()]
    pred_dict["y_corr"] = pred_dict["y"] == pred_dict["y_guess"]

    pred_dict["y_guess_weighted"] = pred_dict["y_pred"].sum(axis=0) / pred_dict["y_pred"].shape[0]
    pred_dict["y_corr_weighted"] = pred_dict["y_guess_weighted"].argmax() == pred_dict["y"]

    if pred_dict["y_corr"]:
        count_true += 1
    else:
        count_false += 1

    # the cutoff counters only include videos whose top weighted score reaches p_cutoff
    above_cutoff = np.max(pred_dict["y_guess_weighted"]) >= p_cutoff
    if pred_dict["y_corr_weighted"]:
        count_true_weighted += 1
        if above_cutoff:
            count_true_weighted_p_cutoff += 1
    else:
        count_false_weighted += 1
        if above_cutoff:
            count_false_weighted_p_cutoff += 1

    # print(f'{vid=}\n{pred_dict["y"]=}\n{pred_dict["y_guess"]=}\n{pred_dict["y_corr"]=}\n{pred_dict["y_guess_weighted"]=}\nCorrect: {pred_dict["y_corr_weighted"]}\n')

# every video is counted once per decision, so both totals equal the number of videos
n_videos = count_true + count_false
if n_videos == 0:
    # avoid a ZeroDivisionError when no video was predicted
    print("No videos in result_dict, accuracies are undefined.")
else:
    print(f"Argmax accuracy: {count_true / n_videos}\n({count_true=} {count_false=})")
    print(f"Weighted accuracy: {count_true_weighted / n_videos}\n({count_true_weighted=} {count_false_weighted=})")
# print(f"Weighted p cutoff accuracy: {count_true_weighted_p_cutoff / (count_true_weighted_p_cutoff + count_false_weighted_p_cutoff)} on {100 * (count_true_weighted_p_cutoff + count_false_weighted_p_cutoff) / (count_true_weighted + count_false_weighted)}% of Videos\n({count_true_weighted_p_cutoff=} {count_false_weighted_p_cutoff=})")