In [None]:
import glob
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import pickle
# From arm
import re
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.optim as optim
from scipy.signal import savgol_filter
from sklearn.linear_model import Lasso, LassoCV, LinearRegression, Ridge, RidgeCV
from sklearn.model_selection import KFold, train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm

from lifting_transformer.lifting_transformer import (
    criterion,
    data,
    helper,
    inference,
    model,
    training,
)
from mausspaun.data_processing.dlc import DLC_TO_MUJOCO_MAPPING
from mausspaun.visualization.plot_3D_video import plot_split_3d_video

#sys.path.append('../../../mouse-arm')
#sys.path.append('../../../DataJoint_mathis')

plt.style.use('cyhsm')

%load_ext autoreload
%autoreload 2

In [None]:
def extract_info(path):
    """Parse session metadata out of a rig-video file path.

    Returns a tuple ``(mouse_name, day, attempt, camera, part)`` with the four
    trailing fields converted to ``int``, or ``None`` when the path does not follow
    the ``mouse-<name>_day-<d>_attempt-<a>_camera-<c>_part-<p>_`` naming scheme.
    """
    found = re.search(
        r"mouse-(?P<mouse_name>\w+)_day-(?P<day>\d+)_attempt-(?P<attempt>\d+)_camera-(?P<camera>\d+)_part-(?P<part>\d+)_",
        path,
    )
    if found is None:
        return None
    fields = found.groupdict()
    numeric = tuple(int(fields[key]) for key in ("day", "attempt", "camera", "part"))
    return (fields["mouse_name"],) + numeric


def prep_ground_truth(full_session_names, gt_dataloader):
    """Collect every labelled sample of ``gt_dataloader`` so they can be batch-processed.

    ``gt_dataloader`` is iterated in lockstep with ``full_session_names`` (``zip`` stops
    at the shorter of the two). A sample is kept only if its ground-truth tensor has at
    least one entry that is neither NaN nor exactly zero.

    Args:
        full_session_names: session descriptors; element i is paired with the i-th
            item yielded by ``gt_dataloader``.
        gt_dataloader: iterable of ``(camera1, camera2, labels, gt_labels)`` items that
            also accepts attribute assignment (e.g. a ``torch.utils.data.DataLoader``).

    Returns:
        The same ``gt_dataloader`` object with three attributes attached:
        ``all_in`` (kept ``camera1`` tensors concatenated along dim 0), ``all_gt``
        (kept ``gt_labels`` concatenated along dim 0) and ``all_names`` (the paired
        session names). When nothing is kept, ``all_in`` and ``all_gt`` are empty lists.
    """
    kept_inputs, kept_gt, kept_names = [], [], []
    for session_name, (camera1, _camera2, _labels, gt_labels) in zip(full_session_names, gt_dataloader):
        # Skip samples with no usable label: every entry is NaN, or every non-NaN entry is
        # zero (zeros appear to be "unlabelled" placeholders -- TODO confirm).
        has_label = ~torch.isnan(gt_labels) & (gt_labels != 0)
        if not bool(has_label.any()):
            continue
        kept_inputs.append(camera1)
        kept_gt.append(gt_labels)
        kept_names.append(session_name)
    print('Finished processing GT')
    # Concatenate once at the end. Concatenating inside the loop re-copied the growing
    # tensor for every kept sample, which is quadratic in the number of labelled samples.
    gt_dataloader.all_in = torch.cat(kept_inputs, dim=0) if kept_inputs else []
    gt_dataloader.all_gt = torch.cat(kept_gt, dim=0) if kept_gt else []
    gt_dataloader.all_names = kept_names
    return gt_dataloader


def save_video(epoch,
               test_preds,
               cam_positions,
               test_loss,
               cutoff,
               loss_weights,
               seq_length,
               video_path=('/data/mausspaun/videos/videos_dlc2/'
                           'rigVideo_mouse-Jaguar_day-19_attempt-1_camera-1_part-0_doe-20180813_rig-5.mp4'),
               save_prefix='/data/markus/mausspaun/3D/withcut0_',
               frames=None):
    """Render a split 2D/3D video of the last-timestep predictions via ``plot_split_3d_video``.

    Args:
        epoch: written verbatim into the output name (callers pass the 0-based epoch).
        test_preds: indexed as ``test_preds[:, -1, i, :]``; presumably
            (samples, seq, joints, coords), with joint ``i`` matching the i-th key of
            ``cam_positions`` -- TODO confirm.
        cam_positions: mapping whose keys name the predicted joints; also forwarded to
            ``plot_split_3d_video``.
        test_loss, cutoff, loss_weights, seq_length: only used to build the file name.
        video_path: 2D video rendered next to the 3D reconstruction. Defaults to the
            Jaguar day-19 part-0 session; presumably the predictions should come from
            that same session -- confirm.
        save_prefix: output path prefix; the run name is appended to it.
        frames: frame indices to render; defaults to ``np.arange(0, 500)``.
    """
    all_pred_positions = {key: test_preds[:, -1, i, :] for i, key in enumerate(cam_positions.keys())}
    weights_str = str([
        loss_weights['mse'], loss_weights['continuity'], loss_weights['connectivity'], loss_weights['ground_truth']
    ])
    run_name = f"{epoch}_loss{test_loss:.3f}_seq{seq_length}_cutoff{cutoff}_lossweights{weights_str}"
    if frames is None:
        frames = np.arange(0, 500)
    plot_split_3d_video(
        video_path,
        all_pred_positions,
        cam_positions=cam_positions,
        dpi=150,
        frames=frames,
        fn_save=save_prefix + run_name)


def run_model(train_dataloader,
              train_gt_dataloader,
              test_dataloader,
              test_gt_dataloader,
              relative_displacements,
              all_aggregated_tensor,
              loss_weights=None,
              num_epochs=25,
              save_weights=False,
              save_video_every=False,
              seq_length=7,
              num_joints=None,
              cam_positions=None,
              cutoff=None,
              lr=0.0002):
    """Train a fresh ``model.SimpleTransformer`` and evaluate it on the test split every epoch.

    Args:
        train_dataloader, train_gt_dataloader: passed to ``training.train``.
        test_dataloader, test_gt_dataloader: passed to ``inference.evaluate``.
        relative_displacements, all_aggregated_tensor: forwarded to the training /
            evaluation helpers.
        loss_weights: dict with keys 'mse', 'continuity', 'connectivity' and
            'ground_truth'. Defaults to 1, 25, 1 and 1e-5 respectively.
        num_epochs: number of training epochs.
        save_weights: also save the model ``state_dict``. Weights are only saved on
            epochs where a video is saved.
        save_video_every: save a prediction video every this many epochs. A falsy or
            non-positive value disables both video and weight saving.
        seq_length: only used to label saved files.
        num_joints: joint count of the transformer. Defaults to the notebook-level
            ``num_joints``.
        cam_positions, cutoff: used when saving videos/weights. They default to the
            notebook-level values of the same name.
        lr: Adam learning rate.

    Returns:
        ``(losses, predictions)``. ``losses`` is an array with one row per epoch of
        ``(train loss, mean test GT loss at the last timestep, same without relative
        displacements)``. ``predictions`` holds ``(test_gt_preds, test_preds)`` per epoch.
    """
    # Backward compatibility: these values used to be read implicitly from notebook globals.
    if num_joints is None:
        num_joints = globals()['num_joints']
    saving = save_video_every > 0
    if saving:
        if cam_positions is None:
            cam_positions = globals()['cam_positions']
        if cutoff is None:
            cutoff = globals()['cutoff']

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transformer = model.SimpleTransformer(num_joints=num_joints).to(device)
    train_criterion = criterion.masked_loss_nan
    eval_criterion = criterion.masked_loss_nan
    optimizer = optim.Adam(transformer.parameters(), lr=lr)

    if loss_weights is None:
        loss_weights = {
            "mse": 1,
            "continuity": 25,
            "connectivity": 1,
            "ground_truth": 0.00001,
        }
    loss_weights_str = str([loss_weights[key] for key in ('mse', 'continuity', 'connectivity', 'ground_truth')])

    losses, predictions = [], []
    for epoch in range(num_epochs):
        train_loss = training.train(
            transformer,
            train_dataloader,
            train_gt_dataloader,
            device,
            train_criterion,
            optimizer,
            loss_weights,
            relative_displacements,
            all_aggregated_tensor,
        )
        print(f"Epoch {epoch+1} training loss: {train_loss:.5f}")
        test_loss, test_gt_loss, test_gt_loss_norel, test_preds, test_gt_preds = inference.evaluate(
            transformer, test_dataloader, test_gt_dataloader, device, eval_criterion, relative_displacements)
        # [:, -1, ...] keeps the last timestep; assumes the loss tensors are (batch, seq, ...) -- TODO confirm.
        last_step_gt_loss = torch.mean(test_gt_loss[:, -1, ...])
        last_step_gt_loss_norel = torch.mean(test_gt_loss_norel[:, -1, ...])
        print(f"Epoch {epoch+1} test loss: {test_loss:.5f}, ground truth loss (sum): {torch.sum(test_gt_loss):.5f}, "
              f"ground truth loss (mean): {torch.mean(test_gt_loss):.5f}, TS: {last_step_gt_loss:.5f}, "
              f"ground truth loss (norel): {last_step_gt_loss_norel:.5f}")

        test_loss_norel = last_step_gt_loss_norel.cpu().detach().numpy()
        losses.append((train_loss, last_step_gt_loss.cpu().detach().numpy(), test_loss_norel))
        predictions.append((test_gt_preds, test_preds))

        if saving and (epoch + 1) % save_video_every == 0:
            save_video(epoch, test_preds, cam_positions, test_loss, cutoff, loss_weights, seq_length)
            if save_weights:
                # The file name previously hard-coded "cutoff0.999" regardless of the cutoff
                # actually used; it now records the real value.
                torch.save(
                    transformer.state_dict(),
                    f"/data/markus/mausspaun/3D/weights_withFixNewGT_{epoch}_loss{test_loss_norel}_seq{seq_length}_cutoff{cutoff}_lossweights{loss_weights_str}.pt"
                )

    return np.array(losses), predictions


def get_dataloaders(c1_train, c1_test, c2_train, c2_test, y_train, y_test, gt_train, gt_test, seq_length,
                    train_session_names=None, test_session_names=None, batch_size=128):
    """Wrap the train/test arrays in ``data.PositionDatasetGT`` datasets and DataLoaders.

    Each split gets two loaders:
    - a batched loader (``batch_size``; shuffled for training only);
    - a ``batch_size=1`` ground-truth loader, whose labelled samples
      ``prep_ground_truth`` pre-collects into ``all_in`` / ``all_gt`` / ``all_names``.

    Args:
        c1_*, c2_*, y_*, gt_*: camera-1 inputs, camera-2 inputs, targets and ground truth
            for the train and test splits.
        seq_length: sequence length passed to ``PositionDatasetGT``.
        train_session_names, test_session_names: session descriptors paired with each
            split's samples by ``prep_ground_truth``. They default to the notebook-level
            ``full_session_names`` (backward-compatible). NOTE(review): that list covers
            *all* frames, so with the defaults ``all_names`` is not aligned with the
            split's samples. Pass names indexed with the split's indices instead, and
            confirm how dataset items map to frames when ``seq_length > 1``.
        batch_size: batch size of the main train/test loaders.

    Returns:
        ``(train_dataset, train_dataloader, train_gt_dataloader),
        (test_dataset, test_dataloader, test_gt_dataloader)``.
    """
    if train_session_names is None:
        train_session_names = globals()['full_session_names']
    if test_session_names is None:
        test_session_names = globals()['full_session_names']

    # Train set: shuffled batches for optimisation, unshuffled singletons for GT collection.
    train_dataset = data.PositionDatasetGT(c1_train, c2_train, y_train, gt_train, seq_length=seq_length)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    train_gt_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False)

    # Test set: never shuffled.
    test_dataset = data.PositionDatasetGT(c1_test, c2_test, y_test, gt_test, seq_length=seq_length)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    test_gt_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # Pre-collect labelled samples so they can be processed in one batch.
    train_gt_dataloader = prep_ground_truth(train_session_names, train_gt_dataloader)
    test_gt_dataloader = prep_ground_truth(test_session_names, test_gt_dataloader)

    return (train_dataset, train_dataloader, train_gt_dataloader), (test_dataset, test_dataloader, test_gt_dataloader)


def _fill_ground_truth(X_gt, gt, marker_columns):
    """Write labelled (x, y, z) rows of ``gt`` into ``X_gt[frame, column, :]`` in place."""
    for frame in gt['frame'].unique():
        frame_rows = gt[gt['frame'] == frame]
        for marker, column in marker_columns.items():
            gt_values = frame_rows[frame_rows['bodypart'] == marker][['x', 'y', 'z']].values
            if len(gt_values) == 0:
                # Marker not labelled in this frame. Assigning the empty selection used to
                # raise a broadcasting error; leave the entry NaN instead.
                continue
            X_gt[frame, column, :] = gt_values


def get_groundtruth(reload=False, cutoff=None,
                    save_path='/data/markus/mausspaun/nn_training_data/ground_truth_data.pkl'):
    """Build (or load from cache) the per-frame 3D ground-truth array for all sessions.

    For every camera-1/camera-2 pair returned by ``helper.get_paths()``:
    1. A NaN array shaped like that session's ``X_3d_train`` is created.
    2. It is filled with any labelled 3D ground truth from
       ``helper.load_and_process_ground_truth``.
    3. The per-session arrays are stacked along axis 0, and the result is pickled to
       ``save_path``.

    Args:
        reload: if False and ``save_path`` exists, return the cached result.
        cutoff: likelihood cutoff passed to ``helper.get_training_data``. Defaults to the
            notebook-level ``cutoff``. NOTE(review): the cache file name does not encode
            the cutoff, so a cache built with another cutoff is returned unchanged.
        save_path: cache file location (a trusted, locally written pickle).

    Returns:
        ``(full_gt, full_session_names, labeled_per_session, session_with_gt)``:
        - ``full_gt``: the stacked array;
        - ``full_session_names``: one session dict per row of ``full_gt``;
        - ``session_with_gt``: the sessions containing any ground truth;
        - ``labeled_per_session``: their labelled-frame counts.

    Raises:
        ValueError: if a camera-1 path does not match the session naming scheme, or if
            no session could be loaded.
    """
    if os.path.exists(save_path) and not reload:
        with open(save_path, 'rb') as fh:
            (full_gt, full_session_names, labeled_per_session, session_with_gt) = pickle.load(fh)
        return (full_gt, full_session_names, labeled_per_session, session_with_gt)

    if cutoff is None:
        cutoff = globals()['cutoff']  # backward compatibility with the implicit global

    # DLC marker -> column in the mausspaun joint order. Markers without a mausspaun joint are ignored.
    marker_columns = {
        marker: helper.mausspaun_keys.index(mujoco_name)
        for marker, mujoco_name in DLC_TO_MUJOCO_MAPPING.items()
        if mujoco_name in helper.mausspaun_keys
    }

    cam1_paths, cam2_paths = helper.get_paths()
    gt_blocks, full_session_names = [], []
    session_with_gt, labeled_per_session = [], []
    for cam1_path, cam2_path in zip(cam1_paths, cam2_paths):
        info = extract_info(cam1_path)
        if info is None:
            raise ValueError('Unrecognized session path: {}'.format(cam1_path))
        mouse_name, day, attempt, camera, part = info
        gt = helper.load_and_process_ground_truth(mouse_name, day, attempt, part, to_mujoco=True)
        print('Running: {}'.format(cam1_path))
        try:
            X_3d_train, _, _, _, _, _ = helper.get_training_data(cam1_path, cam2_path, likelihood_cutoff=cutoff)
        except ValueError as e:
            print(e)
            print('Could not load {}'.format(cam1_path))
            # Previously execution fell through and reused the *previous* session's
            # X_3d_train (or raised NameError on the first session).
            # NOTE(review): skipping assumes helper.load_data skips the same sessions, so
            # rows of full_gt stay aligned with full_X_c1 / full_y -- confirm.
            continue

        # NaN template with the same shape (and float dtype) as this session's 3D data.
        X_gt = X_3d_train * np.nan
        if gt is not None:  # labelled 3D ground truth exists for this session
            _fill_ground_truth(X_gt, gt, marker_columns)

        labeled_frames = np.nansum(X_gt, axis=(1, 2)) != 0
        num_labeled = np.sum(labeled_frames)
        print(f"Found {num_labeled} frames")

        session_name = {'mouse_name': mouse_name, 'day': day, 'attempt': attempt, 'camera': camera, 'part': part}
        gt_blocks.append(X_gt)
        full_session_names.extend([session_name] * X_gt.shape[0])  # one (shared) dict per frame row
        if np.nansum(X_gt) != 0.0:
            session_with_gt.append(session_name)
            labeled_per_session.append(num_labeled)

    if not gt_blocks:
        raise ValueError('No session could be loaded; nothing to save.')
    # Stack once instead of re-allocating with np.vstack for every session.
    full_gt = np.vstack(gt_blocks)
    with open(save_path, 'wb') as fh:
        pickle.dump((full_gt, full_session_names, labeled_per_session, session_with_gt), fh)
    return (full_gt, full_session_names, labeled_per_session, session_with_gt)

In [None]:
# Load (cached) camera-1/camera-2 inputs and targets, preprocessed with the likelihood cutoff below.
cutoff = 0.0
full_X_c1, full_X_c2, full_y = helper.load_data(
    reload=False,
    path=f'/data/markus/mausspaun/nn_training_data/data_with_left_{cutoff}.pkl',
    cutoff=cutoff,
)
print(f'X Camera 1: {full_X_c1.shape}, X Camera 2: {full_X_c2.shape}, Y: {full_y.shape}')

---
# Get ground truth

Load the per-frame 3D ground-truth labels (read from the on-disk cache unless `reload=True`).

In [None]:
# Served from the on-disk cache unless reload=True.
full_gt, full_session_names, labeled_per_session, session_with_gt = get_groundtruth(reload=False)

In [None]:
# Sequence / model hyper-parameters (num_joints is also used by run_model).
seq_length, num_joints = 13, 32

# Camera geometry from the test recording, used to derive the relative displacements.
_, _, _, cam_positions, likelihood_c1, likelihood_c2 = helper.load_test_data()
all_aggregated_tensor, relative_displacements = helper.get_relative_displacements(cam_positions)

In [None]:
# Override the GT sessions used as cross-validation folds with a single held-out session.
# Previously tried: HoneyBee, day 77, attempt 1, camera 1, parts 6 and 0.
session_with_gt = [dict(mouse_name='Jaguar', day=19, attempt=1, camera=1, part=0)]

In [None]:
# Cross-validate the ground-truth loss weight. For every candidate weight, each session in
# `session_with_gt` is held out once as the test fold, and the model is trained on the rest.
loss_weights = {"mse": 1, "continuity": 25, "connectivity": 1, "ground_truth": 1}
gt_weights = [0.0001]  # previously swept: [1, 0.1, 0.01, 0.001, 0.0001, 0.00001]
num_epochs = 10
# NOTE: when True the training set is the *full* data, which includes the held-out
# session, so the test losses are no longer out-of-sample.
use_full_training = False

# The fold splits depend only on the session, not on the GT weight, so build them once.
fold_splits = []
for session in session_with_gt:
    test_index = [i for i, x in enumerate(full_session_names) if x == session]
    train_index = [i for i, x in enumerate(full_session_names) if x != session]
    fold_splits.append((train_index, test_index))

# Labelled-frame count of each held-out session, counted as in get_groundtruth. The global
# `labeled_per_session` covers *every* GT session, so it is not aligned with an overridden
# `session_with_gt`, and pickling it mislabelled the per-session legend.
labeled_per_fold = [int(np.sum(np.nansum(full_gt[test_index], axis=(1, 2)) != 0)) for _, test_index in fold_splits]

losses = []
for gt in gt_weights:
    loss_weights["ground_truth"] = gt
    print(loss_weights)

    cv_losses = []
    for train_index, test_index in fold_splits:
        c1_train, c1_test = full_X_c1[train_index], full_X_c1[test_index]
        c2_train, c2_test = full_X_c2[train_index], full_X_c2[test_index]
        y_train, y_test = full_y[train_index], full_y[test_index]
        gt_train, gt_test = full_gt[train_index], full_gt[test_index]

        print('Size of training: {}, Size of Test: {}'.format(y_train.shape, y_test.shape))

        if use_full_training:
            splits = get_dataloaders(full_X_c1, c1_test, full_X_c2, c2_test, full_y, y_test, full_gt, gt_test,
                                     seq_length)
        else:
            splits = get_dataloaders(c1_train, c1_test, c2_train, c2_test, y_train, y_test, gt_train, gt_test,
                                     seq_length)
        (train_dataset, train_dataloader, train_gt_dataloader), (test_dataset, test_dataloader,
                                                                 test_gt_dataloader) = splits

        model_losses, predictions = run_model(train_dataloader,
                                              train_gt_dataloader,
                                              test_dataloader,
                                              test_gt_dataloader,
                                              relative_displacements,
                                              all_aggregated_tensor,
                                              loss_weights,
                                              save_weights=True,
                                              seq_length=seq_length,
                                              num_epochs=num_epochs,
                                              save_video_every=3)
        cv_losses.append(model_losses)
    losses.append(cv_losses)
# Shape: (gt weights, folds, epochs, 3) -- columns are train loss, test GT loss, test GT loss (no rel.).
losses = np.array(losses)
results_path = './losses_full_cut0_cv_to{}_sessions{}.p'.format(num_epochs, len(session_with_gt))
with open(results_path, 'wb') as fh:
    pickle.dump((gt_weights, losses, labeled_per_fold, session_with_gt), fh)

In [None]:
# Deliberate halt: `STOP` is an undefined name, so this cell raises NameError and
# stops "Run All" before the exploratory cells below.
STOP

In [None]:
# Interactive post-mortem debugger for the most recent exception; not part of the analysis.
%debug

In [None]:
# Deliberate halt: `STOP` is an undefined name, so this cell raises NameError and
# stops "Run All" before the result-plotting cells below.
STOP

In [None]:
# Load a saved cross-validation run for plotting. This rebinds gt_weights, losses,
# labeled_per_session and session_with_gt.
# NOTE: the file name is fixed and need not match the run produced above, which is
# saved as './losses_full_cut0_cv_to{num_epochs}_sessions{len(session_with_gt)}.p'.
with open('./losses_full2_cv_to50_sessions1.p', 'rb') as fh:
    (gt_weights, losses, labeled_per_session, session_with_gt) = pickle.load(fh)

In [None]:
# One x-position per epoch, taken from the loaded losses (as saved by the training cell:
# gt weights x folds x epochs x 3) instead of a hard-coded 50, which breaks for other runs.
epochs = np.arange(losses.shape[2])

In [None]:
# Loss columns saved per epoch by run_model:
#   0 = train loss
#   1 = test GT loss (last timestep)
#   2 = test GT loss without relative displacements (last timestep)
TEST_COL, TEST_NOREL_COL = 1, 2
# Error bars show the across-fold std divided by this factor (kept from the original
# figure); set it to 1 to show the plain standard deviation.
ERRBAR_SCALE = 100

fig, axes = plt.subplots(2, 2, figsize=(14, 9))
axes = axes.flatten()

parameters = gt_weights
colors = plt.cm.tab10(np.linspace(0, 1, len(parameters) + 1))

# Top row: for each GT weight, the mean +/- scaled std across folds, per epoch.
for ax, col, title in [(axes[0], TEST_COL, "Test Last Timestep"),
                       (axes[1], TEST_NOREL_COL, "Test Last Timestep No Rel. Displacement")]:
    for i, (color, param) in enumerate(zip(colors, parameters)):
        fold_losses = losses[i, :, :, col]
        ax.errorbar(epochs,
                    np.mean(fold_losses, axis=0),
                    yerr=np.std(fold_losses, axis=0) / ERRBAR_SCALE,
                    label=f"GT weight {param}",
                    color=color,
                    capsize=4)
    ax.set_title(title)
    ax.legend(fontsize=9)

# Select the GT weight by its test loss (mean over folds and epochs). The previous
# criterion averaged over all columns, so the training loss also influenced the choice.
best_weight = np.argmin(np.mean(losses[..., TEST_COL], axis=(1, 2)))
print('Best Weight: {}'.format(gt_weights[best_weight]))

# Bottom row: per-session (fold) curves for the selected weight.
legend_labels = [
    f"{sgt['mouse_name']}-Day{sgt['day']} | #Labeled: {labeled}"
    for sgt, labeled in zip(session_with_gt, labeled_per_session)
]
axes[2].plot(losses[best_weight, :, :, TEST_COL].T)
axes[2].legend(legend_labels, fontsize=7)
axes[2].set_title(f"Per-session test loss (GT weight {gt_weights[best_weight]})")
axes[3].plot(losses[best_weight, :, :, TEST_NOREL_COL].T)
axes[3].set_title(f"Per-session test loss, no rel. displacement (GT weight {gt_weights[best_weight]})")

for ax in axes:
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")

plt.tight_layout()
plt.show()

In [None]:
# Deliberate halt: `STOP` is an undefined name, so this cell raises NameError and
# ends "Run All" here.
STOP