In [3]:
import glob
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import pickle
# From arm
import re
import sys
from copy import deepcopy

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.optim as optim
from scipy.signal import savgol_filter
from sklearn.linear_model import Lasso, LassoCV, LinearRegression, Ridge, RidgeCV
from sklearn.model_selection import KFold, train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm

import datajoint as dj
from lifting_transformer.lifting_transformer import (
    criterion,
    helper,
    inference,
    model,
    training,
)
from mausspaun.data_processing.dlc import DLC_TO_MUJOCO_MAPPING
from mausspaun.visualization.gui import utils
from mausspaun.visualization.plot_3D_video import plot_split_3d_video

#sys.path.append('../../../mouse-arm')
sys.path.append('/home/markus/projects/arm')
sys.path.append('/home/markus/projects/DataJoint_mathis')
import analysis
import data

data.s1_connect()
from schema import an, dlc, exp, meso, mice, mousearm, shared, traj

# Project-specific matplotlib style sheet (presumably a custom 'cyhsm' style installed locally — TODO confirm it ships with the repo).
plt.style.use('cyhsm')

# Re-import edited project modules (e.g. `analysis`, `inference`) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

DataJoint verified plugin `datajoint_connection_hub` detected.
Connecting markus@128.178.51.6:3306


In [None]:
def extract_info(path):
    """Parse session metadata out of a rig video / DLC file name.

    Looks for the first ``mouse-<name>_day-<d>_attempt-<a>_camera-<c>_part-<p>_`` run in `path`.

    Returns
    -------
    tuple or None
        ``(mouse_name, day, attempt, camera, part)`` with the four numeric fields converted to
        ``int``, or ``None`` when `path` does not contain the pattern.
    """
    found = re.search(
        r"mouse-(?P<mouse_name>\w+)_day-(?P<day>\d+)_attempt-(?P<attempt>\d+)_camera-(?P<camera>\d+)_part-(?P<part>\d+)_",
        path,
    )
    if found is None:
        return None
    numeric_fields = tuple(int(found[name]) for name in ("day", "attempt", "camera", "part"))
    return (found["mouse_name"],) + numeric_fields


def _get_new_coordinates(data_2d, bodyparts_dlc):
    dlc_camera_1_coordinates = {}
    for source in bodyparts_dlc:
        if source not in DLC_TO_MUJOCO_MAPPING:
            continue
        target = DLC_TO_MUJOCO_MAPPING[source]
        idx = bodyparts_dlc.index(source)
        dlc_camera_1_coordinates[target] = data_2d[..., idx]
    return dlc_camera_1_coordinates


def load_from_file(scan_key_list, save_path='/data/markus/mausspaun/GLM/data/steschema_Xy_base_thesis.npy'):
    """Return per-scan aligned data, read from a cache file or built from DataJoint.

    If the cache file exists it is returned as-is. `scan_key_list` is then ignored, so delete
    the file to force a rebuild after changing the scan list.

    Otherwise the joined ``mousearm.Alignment`` parts are fetched for each scan key with
    ``fetch1()``, and joint angles/torques are attached via ``load_angles_and_torques``.
    Scans that fail either step are reported, with the underlying error, and skipped.
    KeyboardInterrupt is no longer swallowed.

    NOTE(review): ``load_angles_and_torques`` is not defined in this notebook section. Before
    this change the resulting NameError was hidden by a bare ``except:``. Confirm where it
    comes from.

    Parameters
    ----------
    scan_key_list : iterable of dict
        DataJoint restriction keys, one per scan.
    save_path : str or path-like
        Cache location. ``np.save`` appends '.npy' when the suffix is missing, so the same
        normalisation is applied before checking for an existing cache.

    Returns
    -------
    dict
        Maps ``str(scan_key)`` to the fetched row dict, extended with 'joint_angles' and
        'joint_torques'.
    """
    save_path = os.fspath(save_path)
    if not save_path.endswith('.npy'):
        save_path += '.npy'
    if os.path.exists(save_path):
        # [()] unwraps the 0-d object array that np.save wraps around a dict.
        return np.load(save_path, allow_pickle=True)[()]

    # Loop-invariant join, built once instead of once per scan.
    align = (mousearm.Alignment.Annotations * mousearm.Alignment.Joystick * mousearm.Alignment.Mousearm *
             mousearm.Alignment.DeepLabCut2D * mousearm.Alignment.Traces)
    all_data = {}
    for scan_key in scan_key_list:
        print('Running {}'.format(scan_key))
        try:
            dj_data = (align & scan_key).fetch1()
        except Exception as err:
            print('No data in DJ for {} ({}: {})'.format(scan_key, type(err).__name__, err))
            continue
        try:
            joint_angles, joint_torques = load_angles_and_torques(scan_key)
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if joint_angles.shape[0] != joint_torques.shape[0]:
                raise ValueError('{} joint-angle frames vs {} joint-torque frames'.format(
                    joint_angles.shape[0], joint_torques.shape[0]))
        except Exception as err:
            print('Did not find joint angles and torques for {} ({}: {})'.format(
                scan_key, type(err).__name__, err))
            continue
        # Only scans with the full set of data (DJ row + angles + torques) are stored.
        dj_data['joint_angles'] = joint_angles
        dj_data['joint_torques'] = joint_torques
        all_data[str(scan_key)] = dj_data

    if all_data:
        np.save(save_path, all_data)
    else:
        # An empty cache would be returned forever on later calls; don't write one.
        print('No scans could be loaded; not writing an empty cache to {}'.format(save_path))
    return all_data


def generate_filepaths(base_videos, base_h5, mouse_name, day, attempt, part,
                       dlc_scorer='DLC_resnet50_MackenzieJan21shuffle1_700000', rig=5):
    """Build the file-name patterns for one session and look them up on the server.

    Parameters
    ----------
    base_videos, base_h5 : str
        Root directories of the raw videos and of the DLC .h5 outputs.
    mouse_name, day, attempt, part :
        Session identifiers interpolated into the file names.
    dlc_scorer : str
        DLC scorer suffix appended to the video stem of the .h5 files. The default is the
        previously hard-coded network.
    rig : int
        Rig number in the file names. The default is 5, as previously hard-coded.

    Returns
    -------
    Whatever ``utils.find_paths_on_server`` returns for these patterns (callers index it as
    ``result[0][0]``).
    """
    # Shared prefix; the camera number follows it in the actual file names.
    common_file_part = f"rigVideo_mouse-{mouse_name}_day-{day}_attempt-{attempt}_camera-"
    video_suffix = f"_part-{part}_doe-*_rig-{rig}"
    h5_file_pattern = f"{video_suffix}{dlc_scorer}.h5"
    mp4_file_pattern = f"{video_suffix}.mp4"
    return utils.find_paths_on_server(base_videos, base_h5, common_file_part, h5_file_pattern, mp4_file_pattern)


def save_video(all_pred_positions, labeled_2d_video, eval_key, seq_length, smoothing_window=3):
    all_pred_positions_smooth = deepcopy(all_pred_positions)
    if smoothing_window > 0:
        for key, item in all_pred_positions_smooth.items():
            all_pred_positions_smooth[key] = savgol_filter(item, smoothing_window, 3, axis=0)
    run_name = f"{eval_key['mouse_name']}_{eval_key['day']}_{eval_key['attempt']}_sequence_length{seq_length}"

    plot_split_3d_video(
        labeled_2d_video,
        all_pred_positions,
        cam_positions=all_pred_positions_smooth,
        dpi=150,
        frames=np.arange(0, 50),  #None,  #np.concatenate([np.arange(0, 50), np.arange(400, 450)]),
        fn_save="/data/mausspaun/videos/videos_3D_averageweights/{}".format(run_name))

In [None]:
# Scan keys of the manually curated 'base_thesis' group, in reversed fetch order.
scan_key_list = (meso.Scan() & (meso.ManualGroup() & {'group_name': 'base_thesis'})).fetch(dj.key)[::-1]
for i, key in enumerate(scan_key_list):
    print(f'{i}\t{key}')
# NOTE(review): this uses `analysis.load_from_file`, not the `load_from_file` defined in this notebook.
all_data = analysis.load_from_file(scan_key_list, save_path='/data/markus/mausspaun/GLM/data/base_thesis_27Oct2023.npy')

In [None]:
# Re-run 3D lifting (averaged weights) for every session and render a short preview video.
# The joint-angle / joint-torque recomputation and the final np.save below are commented out,
# so the new 3D data only lives in the in-memory `all_data`.
seq_length = 7
for key, animal_data in all_data.items():
    #     if not ((eval(key)['mouse_name'] == 'HoneyBee') and (eval(key)['day'] == 77)):
    #         continue
    print('Running: {}'.format(key))
    # Keys are presumably str(scan_key) dicts (as built by load_from_file); eval turns them back into dicts.
    # NOTE(review): eval executes arbitrary code — acceptable only because the cache is self-produced.
    eval_key = eval(key)

    data_2d = animal_data['dlc_2d']
    data_2d_muj = _get_new_coordinates(data_2d, animal_data['dlc_bodyparts'])
    inference_preds = inference.run_inference(
        data_2d_muj, seq_length=seq_length, model_weights='averaged_weights_epoch_5_12.pt'
    )  # previously: 'weights_withleft_noelb_12_loss0.029_seq7_cutoff0.999_lossweights[1, 25, 1].pt'

    # Stack per-bodypart predictions on a new last axis: (frames, 3) each -> (frames, 3, n_bodyparts),
    # e.g. (40332, 3, 25). Column order matches `mujoco_bodyparts`.
    data_3d_muj = np.stack(list(inference_preds.values()), axis=-1)

    # `animal_data` is the same dict object as all_data[key].
    animal_data['mujoco_dlc_3d'] = data_3d_muj
    animal_data['mujoco_bodyparts'] = list(inference_preds.keys())

    # Plot video from the raw recording of part 0.
    original_filepaths = generate_filepaths('/data/mausspaun/' + 'videos/videos_base/',
                                            '/data/mausspaun/' + 'emissions/', eval_key['mouse_name'], eval_key['day'],
                                            eval_key['attempt'], 0)
    if len(original_filepaths) == 0:
        # Don't let one missing video abort the loop for all remaining sessions.
        print('No raw video found for {}; skipping preview video'.format(key))
        continue
    save_video(inference_preds, original_filepaths[0][0], eval_key, seq_length=seq_length, smoothing_window=5)

#     joint_angles_data, joint_angles_dict, _ = run_inverse_kinematics(data_3d_muj, animal_data['mujoco_bodyparts'], scale=0.75)
#     joint_torques = run_jointspace(joint_angles_data.copy(), kp=600, kv=30)

#     all_data[key]['joint_angles'] = np.array(list(joint_angles_dict.values())).T
#     all_data[key]['joint_torques'] = np.array(list(joint_torques.values())).T

#     all_data[key]['joint_torque_names'] = list(joint_torques.keys())
#     all_data[key]['joint_angle_names'] = list(joint_angles_dict.keys())
#np.save('/data/markus/mausspaun/GLM/data/base_thesis_weights_seq7_13Feb2024.npy', all_data)

---
# Average the weights of two lifting-transformer checkpoints

In [4]:
def average_model_weights(model_path1, model_path2, output_path, map_location="cpu"):
    """
    Load two PyTorch state dicts, average their weights, and save the averaged weights.

    - Floating-point tensors are averaged element-wise and keep the first checkpoint's dtype.
    - Non-floating tensors (e.g. integer counters such as BatchNorm's num_batches_tracked) are
      copied from the first checkpoint, because averaging would silently turn them into floats.
    - Non-tensor entries are averaged arithmetically, as before.

    Parameters:
    - model_path1: Path to the first .pt state-dict file.
    - model_path2: Path to the second .pt state-dict file; must contain every key of the first.
      Keys present only here are ignored, with a warning.
    - output_path: Path where the averaged state dict will be saved.
    - map_location: Device both checkpoints are loaded onto (default "cpu"), so checkpoints saved
      on a GPU can be averaged on any machine.

    Raises:
    - KeyError: if the second checkpoint lacks keys present in the first.
    - ValueError: if a tensor has different shapes in the two checkpoints. Without this check,
      averaging could silently broadcast.
    """
    state_dict1 = torch.load(model_path1, map_location=map_location)
    state_dict2 = torch.load(model_path2, map_location=map_location)

    missing = [key for key in state_dict1 if key not in state_dict2]
    if missing:
        raise KeyError(f"{model_path2} is missing keys present in {model_path1}: {missing}")
    extra = [key for key in state_dict2 if key not in state_dict1]
    if extra:
        print(f"Warning: ignoring {len(extra)} keys only present in {model_path2}: {extra}")

    averaged_state_dict = {}
    for key, value1 in state_dict1.items():
        value2 = state_dict2[key]
        if isinstance(value1, torch.Tensor):
            if value1.shape != value2.shape:
                raise ValueError(
                    f"Shape mismatch for '{key}': {tuple(value1.shape)} vs {tuple(value2.shape)}")
            if torch.is_floating_point(value1):
                averaged_state_dict[key] = ((value1 + value2) / 2.0).to(value1.dtype)
            else:
                averaged_state_dict[key] = value1.clone()
        else:
            averaged_state_dict[key] = (value1 + value2) / 2.0

    torch.save(averaged_state_dict, output_path)
    print(f"Averaged model weights saved to {output_path}")

In [5]:
# Both input checkpoints and the averaged output live in the lifting-transformer weights folder.
# NOTE(review): per their file names the inputs were trained with seq2 vs seq7 — averaging assumes
# identical parameter shapes; confirm the architectures match.
WEIGHTS_DIR = "/home/markus/projects/mouse-arm/lifting_transformer/lifting_transformer/weights"
model_path1 = f"{WEIGHTS_DIR}/weights_epoch_5_loss0.11_seq2_cutoff0.999_lossweights[1, 25, 1, 1e-05].pt"
model_path2 = f"{WEIGHTS_DIR}/weights_withleft_noelb_12_loss0.029_seq7_cutoff0.999_lossweights[1, 25, 1].pt"
output_path = f"{WEIGHTS_DIR}/averaged_weights_epoch_5_12.pt"

In [6]:
# Write the element-wise average of the two checkpoints to `output_path`.
average_model_weights(model_path1=model_path1, model_path2=model_path2, output_path=output_path)

Averaged model weights saved to /home/markus/projects/mouse-arm/lifting_transformer/lifting_transformer/weights/averaged_weights_epoch_5_12.pt
