# DewanLab DeepLabCut Model Training
## 1. Import Dependencies
#### *You can ignore any errors from tensorflow about oneDNN, cuBLAS, libnvinfer, or TensorRT*

In [None]:
import os
from pathlib import Path
os.environ['DLClight']="True"
os.environ['PYTHONPYCACHEPREFIX'] = './tmp'
%matplotlib ipympl
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

import torch

if not torch.cuda.is_available():
    raise Exception("GPU not found!")

try:
    import deeplabcut
except Exception as e:
    print("Error importing deeplabcut!")
    print(e)
    
print("Dependencies successfully imported!")

## 2. Set User Configurables

In [None]:
## Training Dataset Creation
# How many shuffles create_training_dataset (section 4) should generate.
new_shuffle_num = 1

## Training Parameters
# NOTE(review): display_iters and save_iters are not passed to train_network in
# section 5 as written, so changing them currently has no effect — confirm.
display_iters = 10
save_iters = 5000
# Shuffle index handed to train_network (section 5).
training_shuffle = 1

## Video Analysis Parameters
# Folder searched for videos (relative to the notebook's working directory).
video_dir = './videos/Evaluation_Vid'
# Extensions (without the leading dot) that count as videos.
video_file_extensions = ['avi', 'mp4', 'mkv']
# True -> rglob (include subfolders); False -> only the top level of video_dir.
recursive_video_search = False

# Shuffle whose model is used by analyze_videos (section 6b).
novel_video_shuffle = 1
# Forwarded to analyze_videos(save_as_csv=...).
save_as_csv = True
# Only referenced by the commented-out create_labeled_video call in section 6b.
trailpoints = 0

# Likelihood cutoff used when scoring outputs in section 7.
output_threshold = 0.8

# Snapshot epochs compared in section 7; an empty list means nothing is scored.
epochs = [] 

## 3. Get Config Path

In [None]:
# The DeepLabCut project's config.yaml is expected in the directory the notebook
# runs from (the project root).
current_dir = Path.cwd()
config_path = current_dir / 'config.yaml'

## 4. Create New Training Dataset (If Needed)
#### If continuing a previous round of training, this is not needed!

In [None]:
# Create the project's training dataset with `new_shuffle_num` shuffle(s).
# Skip this cell when resuming an earlier round of training.
deeplabcut.create_training_dataset(config_path, num_shuffles=new_shuffle_num)

## 5. Train Model

In [None]:
# Train the network for shuffle `training_shuffle`.
# NOTE(review): `display_iters` and `save_iters` from the configurables cell are
# not passed here, so they currently have no effect on training. Confirm the
# matching keyword names for the installed DLC version/engine before wiring them in.
deeplabcut.train_network(config_path, shuffle=training_shuffle)

## 6: Analyze Videos
### 6a: Get New Videos

In [None]:
# The notebook should be in the root directory of the project, so a relative
# `video_dir` is resolved against it.
new_videos_dir = Path(video_dir)

if not new_videos_dir.exists():
    raise FileNotFoundError(f'The path {{{new_videos_dir}}} does not exist!')
if not new_videos_dir.is_dir():
    raise NotADirectoryError(f'The path {{{new_videos_dir}}} is not a directory!')

# Normalise the configured extensions once. A leading dot is optional and case
# is ignored, so 'avi', '.avi' and 'AVI' all match 'clip.AVI' (plain glob
# patterns are case-sensitive on Linux/macOS).
_wanted_suffixes = {ext.lower().lstrip('.') for ext in video_file_extensions}

_candidates = new_videos_dir.rglob('*') if recursive_video_search else new_videos_dir.glob('*')

# Keep real files with a wanted extension. Skip anything with 'labeled' in its
# name, which presumably are overlays written by deeplabcut.create_labeled_video.
# Sorting gives a deterministic processing order (glob order is arbitrary).
video_paths = sorted(
    path for path in _candidates
    if path.is_file()
    and path.suffix.lower().lstrip('.') in _wanted_suffixes
    and 'labeled' not in path.name
)

video_strings = [str(video) for video in video_paths]

print(f'Found the following video(s): {video_strings}')
if not video_paths:
    print(f'No videos with extensions {video_file_extensions} found in {new_videos_dir}')

### 6b: Process Videos

In [None]:
# Run analyze_videos on each discovered video. Errors are reported per video and
# the loop carries on, so one bad file does not abort the whole batch; the
# failures are summarised at the end.
failed_videos = []
for video in video_paths:
    # Path.suffix keeps the leading dot (e.g. '.mp4').
    video_type = video.suffix
    try:
        deeplabcut.analyze_videos(str(config_path), str(video), shuffle=novel_video_shuffle, save_as_csv=save_as_csv, videotype=video_type)
        # To also render labeled videos, uncomment the next line (it uses `trailpoints` from the configurables cell):
        # deeplabcut.create_labeled_video(str(config_path), str(video), shuffle=novel_video_shuffle, videotype=video_type, trailpoints=trailpoints)
    except Exception as e:
        print(f"An error has occurred while processing video {{{video}}}")
        print(e)
        failed_videos.append(video)

if failed_videos:
    print(f"{len(failed_videos)} of {len(video_paths)} video(s) failed: {[str(v) for v in failed_videos]}")

## 7: (Optional) Analyze Analysis Output

In [None]:
import re

# Get H5 Files
print(f"Percentile of frames at or above threshold ({output_threshold}) for each video:\n")

_epochs = [str(epoch) for epoch in epochs]

# Search the same way the videos were found. analyze_videos writes its output
# next to each video by default, so a recursive video search needs a recursive
# .h5 search too.
_h5_search = Path(video_dir).rglob if recursive_video_search else Path(video_dir).glob
_h5_files = list(_h5_search('*.h5'))

# Assign files to epochs by matching the epoch as a whole number (optionally
# zero-padded). A plain substring test would let epoch '10' also claim the files
# of epoch '100' or '110'.
sorted_h5_files = {}
for epoch in _epochs:
    _epoch_pattern = re.compile(rf'(?<!\d)0*{re.escape(epoch)}(?!\d)')
    sorted_h5_files[epoch] = [file for file in _h5_files if _epoch_pattern.search(file.stem)]

percentiles = []

for epoch in _epochs:
    print(f"Analyzing epoch {epoch}\n=======================\n")
    _percentiles = []
    for h5_file in sorted_h5_files[epoch]:
        df = pd.read_hdf(h5_file)
        if df.empty:
            # Avoid a ZeroDivisionError on a file with no frames.
            print(f'Video: {h5_file.stem}\nNo frames found, skipping.\n')
            continue
        # Column index 2 is assumed to be the likelihood of the first (nose)
        # bodypart, i.e. columns laid out as x, y, likelihood per bodypart —
        # TODO confirm against the bodyparts order in config.yaml.
        nose_like_mask = (df[df.columns[2]] >= output_threshold)
        # Percentage of frames whose likelihood meets the threshold.
        nose_like_percentile = (nose_like_mask.sum() * 100) / len(nose_like_mask)
        _percentiles.append(nose_like_percentile)
        result = f'Video: {h5_file.stem}\nNose Percentile: {round(nose_like_percentile, 2)}%\n'
        print(result)
    if _percentiles:
        percentiles.append(np.mean(_percentiles))
    else:
        # NaN instead of 0, so the plot shows a gap rather than a fake 0% score.
        print(f"No usable .h5 files found for epoch {epoch}\n")
        percentiles.append(np.nan)

In [None]:
%matplotlib inline

# We average the percentile for each video and plot it over the epochs
# This gives us a rough approximation of the models performance at accurately predicting the position of the nose

fig, ax = plt.subplots()
ax.plot(epochs, percentiles)
_ = ax.set_xticks(epochs)
ax.set_xlabel('Epochs')
ax.set_ylabel(f'Average Percent of labels >= {output_threshold}')
plt.suptitle('Performance of Model per Epoch')