In [None]:
import os
import sys
import yaml
from scipy.io import loadmat, savemat

import numpy as np
import torch

from dannce.cli import get_parser, build_clarg_params, load_params
from dannce.interface import dannce_predict, sdannce_predict, com_predict
from dannce.engine.utils.vis import visualize_pose_predictions
from dannce.engine.data.io import load_sync

In [None]:
# Run configuration — everything a user is expected to tune lives here.
RUNNING_MODE = "predict"
MODEL_TYPE = "sdannce"  # DANNCE CLI mode, e.g. "sdannce" or "dannce"

# Machine-specific absolute paths. The literals below are only defaults; set the
# matching environment variable to point the notebook elsewhere without editing it.
EXPROOT = os.environ.get(
    "SDANNCE_EXPROOT",
    '/mnt/d/Project/SDANNCE-Models/666-6CAM/SD-20250910-c55toe1-5cam/',
)

MODEL_CHECKPOINT = os.environ.get(
    "SDANNCE_MODEL_CHECKPOINT",
    "/mnt/d/Project/SDANNCE-Models/555-5CAM/Weights/checkpoint-SDANNCE-5cam-best.pth",
)
COM_CHECKPOINT = os.environ.get(
    "SDANNCE_COM_CHECKPOINT",
    "/mnt/d/Project/SDANNCE-Models/555-5CAM/Weights/checkpoint-COM-5cam-best-epoch1000.pth",
)

CONFIG_FOLDER = os.environ.get("SDANNCE_CONFIG_FOLDER", '/home/bezver/sdannce/configs/')

N_ANIMALS = 1
N_VIEWS = 5
INFERENCE_ALL = False  # True: predict every frame in the sync file and ignore STARTFRAME/ENDFRAME
STEPSIZE = 10000  # frames per inference segment; decrease this if you hit CUDA OOM
SKIP_VIDEO = True  # True: skip rendering the pose-overlay video after prediction

# Frame range; only used when INFERENCE_ALL is False.
STARTFRAME = 64000
ENDFRAME = -1  # -1 means "through the last frame of the sync file"

In [None]:
# Pick the (s)DANNCE / COM config pair for this camera count: the 4-camera
# files carry no view-count suffix, every other count is encoded in the name.
CONFIG, COM_CONFIG = (
    (
        os.path.join(CONFIG_FOLDER, "custom_sdannce_config.yaml"),
        os.path.join(CONFIG_FOLDER, "custom_com_config.yaml"),
    )
    if N_VIEWS == 4
    else (
        os.path.join(CONFIG_FOLDER, f"custom{N_VIEWS}_sdannce_config.yaml"),
        os.path.join(CONFIG_FOLDER, f"custom{N_VIEWS}_com_config.yaml"),
    )
)

In [None]:
def generate_io_yaml(
    exproot=EXPROOT,
):
    """Write a default ``io.yaml`` into ``exproot``, creating the folder if needed.

    The file points the COM and DANNCE train/predict directories (and the COM
    result file) at paths relative to the experiment folder, and sets a few
    data/augmentation flags. Keys are written in the order listed below
    (``sort_keys=False``), in block style.
    """
    os.makedirs(exproot, exist_ok=True)

    io_settings = dict([
        ('com_train_dir', './COM/train00'),
        ('com_predict_dir', './COM/predict00'),
        ('com_exp', None),
        ('exp', None),
        ('dannce_train_dir', './DANNCE/train00/'),
        ('dannce_predict_dir', './DANNCE/predict00/'),
        ('com_file', './COM/predict00/com3d.mat'),
        ('use_npy', True),
        ('rand_view_replace', True),
        ('mirror_augmentation', False),
        ('augment_hue', False),
        ('augment_brightness', False),
    ])

    destination = os.path.join(exproot, 'io.yaml')
    with open(destination, 'w') as handle:
        yaml.dump(io_settings, handle, sort_keys=False, default_flow_style=False)

    print(f"YAML config saved to: {destination}")

In [None]:
def predict_com(
    exproot=EXPROOT,
    n_animals=N_ANIMALS,
    com_checkpoint=COM_CHECKPOINT,
    config=COM_CONFIG
):
    """Run center-of-mass (COM) prediction for the experiment in ``exproot``.

    Builds a ``com predict com <config>`` command line, parses it with the
    DANNCE CLI parser, and calls ``com_predict`` on the resulting params.
    Output goes to ``./COM/predict00`` relative to ``exproot``.

    Args:
        exproot: Experiment directory. The process ``chdir``s into it because
            DANNCE must run within the experiment directory.
        n_animals: Number of animal instances (``--n-instances``).
        com_checkpoint: Trained COM network weights (``--com-predict-weights``).
        config: COM YAML config passed as the positional config argument.

    Side effects:
        Changes the process working directory and overwrites ``sys.argv``;
        clears the CUDA cache afterwards.
    """
    arguments = {
        "com-predict-weights": com_checkpoint,
        "com-predict-dir": "./COM/predict00",
        "batch-size": 16,
        "n-instances": n_animals,
    }

    os.chdir(exproot)

    # Always "predict": this helper unconditionally calls com_predict, so the
    # parsed params must be prediction params regardless of RUNNING_MODE.
    cmds = ["com", "predict", "com", config]
    for k, v in arguments.items():
        cmds += [f"--{k}", str(v)]

    # The DANNCE CLI parser reads its arguments from sys.argv.
    sys.argv = cmds
    args = get_parser()
    params = build_clarg_params(
        args,
        dannce_net=args.mode in ("dannce", "sdannce"),
        prediction=(args.command == "predict"),
    )
    com_predict(params)

    # Release cached GPU memory before the (s)DANNCE stage runs.
    torch.cuda.empty_cache()

In [None]:
def launch_inference(
    exproot=EXPROOT,
    n_animals=N_ANIMALS,
    config=CONFIG,
    model_checkpoint=MODEL_CHECKPOINT,
    com_file="./COM/predict00/com3d.mat",
    dannce_predict_dir="./DANNCE/predict00",
    max_num_samples=ENDFRAME,
    start_sample=STARTFRAME,
    skip_video=SKIP_VIDEO
): # Adapted from SDANNCE Repo
    """Run (s)DANNCE pose prediction on one frame range; optionally render a video.

    Composes a ``dannce predict <MODEL_TYPE> <config>`` command line, parses it
    with the DANNCE CLI parser, and runs the predictor matching the parsed mode
    (``sdannce_predict`` for "sdannce", ``dannce_predict`` otherwise). Unless
    ``skip_video`` is set, it then renders a pose-overlay video from
    ``save_data_AVG{start_sample}.mat`` in ``dannce_predict_dir``.

    Args:
        exproot: Experiment directory; DANNCE must run with it as the CWD.
        n_animals: Number of animal instances (``--n-instances``).
        config: DANNCE YAML config passed as the positional config argument.
        model_checkpoint: Trained weights (``--dannce-predict-model``).
        com_file: COM predictions, relative to ``exproot``.
        dannce_predict_dir: Prediction output directory, relative to ``exproot``.
        max_num_samples: Forwarded as ``--max-num-samples``. The visualization
            treats it as the exclusive end frame (``n_frames = max_num_samples
            - start_sample``), so pass a resolved frame index, not the ``-1``
            sentinel, when rendering video.
        start_sample: First frame to predict (``--start-sample``).
        skip_video: If True, return right after prediction.

    Returns:
        The value returned by ``visualize_pose_predictions``, or ``None`` when
        ``skip_video`` is True.

    Side effects:
        Changes the process working directory and overwrites ``sys.argv``.
    """
    arguments = {
        "dannce-predict-model": model_checkpoint,
        "dannce-predict-dir": dannce_predict_dir,
        "com-file": com_file,
        "start-sample": start_sample,
        "max-num-samples": max_num_samples,
        "batch-size": 1,
        "n-instances": n_animals,
    }

    # DANNCE must run within the experiment directory
    os.chdir(exproot)

    # Always "predict": this helper only ever calls a *_predict function, so the
    # parsed params must be prediction params regardless of RUNNING_MODE.
    cmds = ['dannce', "predict", MODEL_TYPE, config]
    for k, v in arguments.items():
        cmds += [f"--{k}", str(v)]

    # The DANNCE CLI parser reads its arguments from sys.argv.
    sys.argv = cmds
    args = get_parser()
    params = build_clarg_params(
        args,
        dannce_net=args.mode in ("dannce", "sdannce"),
        prediction=(args.command == "predict"),
    )
    # Honour MODEL_TYPE instead of always running the sDANNCE predictor.
    predict_fn = sdannce_predict if args.mode == "sdannce" else dannce_predict
    predict_fn(params)

    # please manually clear the CUDA cache to avoid OOM
    torch.cuda.empty_cache()

    if skip_video:
        return None

    # visualize predictions
    video_path = visualize_pose_predictions(
        exproot=exproot,
        expfolder=args.dannce_predict_dir,
        datafile=f"save_data_AVG{start_sample}.mat",
        n_frames=max_num_samples - start_sample,
        start_frame=start_sample,
        cameras=','.join(map(str, range(1, N_VIEWS + 1))),
        animal="rat16",
        n_animals=n_animals,
        zoom_in=True,
        zoom_window_size=80
    )
    return video_path

In [None]:
def merge_pred(predict_path=None):  # Adapted from SDANNCE Repo
    """Merge per-segment prediction files into a single ``save_data_AVG.mat``.

    Each inference segment writes ``save_data_AVG<start>.mat``. This collects
    every file of exactly that form (non-numeric suffixes and non-``.mat``
    files are ignored), orders them by ``<start>``, and concatenates their
    ``pred``, ``data`` and ``p_max`` arrays along axis 0 and ``sampleID``
    along axis 1. The result is saved as ``save_data_AVG.mat`` in the same
    directory.

    Args:
        predict_path: Directory holding the segment files. When ``None``,
            ``dannce_predict_dir`` is read from ``EXPROOT/io.yaml`` and resolved
            relative to ``EXPROOT`` (absolute values are used as-is).

    Raises:
        ValueError: ``predict_path`` is ``None`` and io.yaml does not set
            ``dannce_predict_dir``.
        FileNotFoundError: No segment files were found.
    """
    if predict_path is None:
        # Try to get it from io.yaml
        params = load_params(os.path.join(EXPROOT, "io.yaml"))
        param_dir = params.get("dannce_predict_dir")
        if param_dir is None:
            raise ValueError(
                "Either predict_path (clarg) or dannce_predict_dir (in io.yaml) must be specified for merge"
            )
        # os.path.join handles "./x", "x" and absolute paths alike, unlike the
        # former split("./")[1], which failed on paths without "./".
        predict_path = os.path.normpath(os.path.join(EXPROOT, param_dir))

    print(f"predict_path:{predict_path}")

    prefix, suffix = "save_data_AVG", ".mat"
    indexed_files = []
    for name in os.listdir(predict_path):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        start_index = name[len(prefix):-len(suffix)]
        # Skips the merged output itself (empty suffix) and any stray files
        # that would otherwise crash int().
        if start_index.isdigit():
            indexed_files.append((int(start_index), name))

    if not indexed_files:
        raise FileNotFoundError("No prediction files were found.")
    pred_files = [name for _, name in sorted(indexed_files)]

    # Load all of the data
    pred, data, p_max, sampleID = [], [], [], []
    for file in pred_files:
        M = loadmat(os.path.join(predict_path, file))
        pred.append(M["pred"])
        data.append(M["data"])
        p_max.append(M["p_max"])
        sampleID.append(M["sampleID"])

    # save to a single file.
    fn = os.path.join(predict_path, "save_data_AVG" + ".mat")
    savemat(
        fn,
        {
            "pred": np.concatenate(pred, axis=0),
            "data": np.concatenate(data, axis=0),
            "p_max": np.concatenate(p_max, axis=0),
            # loadmat returns sampleID as a (1, n) row, hence axis=1.
            "sampleID": np.concatenate(sampleID, axis=1),
        },
    )

In [None]:
def segmented_and_call_inference(
    inference_all = INFERENCE_ALL,
    sync_path = os.path.join(EXPROOT,"sync_dannce.mat"),
    step_size = STEPSIZE,
    start_frame = STARTFRAME,
    end_frame = ENDFRAME,
):
    """Run inference over a frame range, in chunks of ``step_size`` frames.

    The frame count comes from the ``data_frame`` field of the first entry
    returned by ``load_sync(sync_path)``.

    Args:
        inference_all: If True, process every frame (``start_frame`` and
            ``end_frame`` are ignored).
        sync_path: Sync file used to determine the total frame count.
        step_size: Maximum number of frames per ``launch_inference`` call.
        start_frame: First frame to process (inclusive).
        end_frame: End frame (exclusive); ``-1`` means the last frame.

    Returns:
        The ``launch_inference`` result when the range fits in one segment
        (and ``inference_all`` is False). Otherwise, a list with one
        ``"<path>\\n"`` string per segment. In that case the segment files are
        merged afterwards with ``merge_pred()``.

    Raises:
        ValueError: ``step_size`` is not positive or the frame range is empty.
    """
    if step_size <= 0:
        raise ValueError(f"step_size must be positive, got {step_size}")

    data_sync = load_sync(sync_path)
    video_total_frames = len(data_sync[0]['data_frame'][0])

    if inference_all:
        start_frame, end_frame = 0, video_total_frames
    elif end_frame == -1:
        end_frame = video_total_frames
    total_frames = end_frame - start_frame
    if total_frames <= 0:
        raise ValueError(
            f"Empty frame range: start_frame={start_frame}, end_frame={end_frame}"
        )

    if not inference_all and total_frames <= step_size:
        # Forward the resolved range. Calling launch_inference() bare used the
        # STARTFRAME/ENDFRAME globals instead of this function's arguments and
        # passed the unresolved -1 sentinel through.
        return launch_inference(max_num_samples=end_frame, start_sample=start_frame)

    print(f"Found {total_frames} frames to inference.")

    stop_frame = start_frame + total_frames
    segment_starts = range(start_frame, stop_frame, step_size)
    num_segment = len(segment_starts)
    bar_length = 50
    video_path = []

    for i, start_frame_seg in enumerate(segment_starts):
        # The last segment is clipped to stop_frame.
        end_frame_seg = min(start_frame_seg + step_size, stop_frame)

        # Progress bar
        done = i + 1
        percent = done / num_segment * 100
        filled_length = int(bar_length * done // num_segment)
        bar = '█' * filled_length + '-' * (bar_length - filled_length)
        sys.stdout.write(f'\rProgress: |{bar}| {percent:.2f}% ({done}/{num_segment} segments) -  ')
        sys.stdout.flush()

        path = launch_inference(max_num_samples=end_frame_seg, start_sample=start_frame_seg)
        video_path.append(f"{path}\n")

    merge_pred()
    sys.stdout.write('\nDone!\n')
    sys.stdout.flush()
    return video_path

In [None]:
# Ensure the experiment folder and its io.yaml exist. Checking with isfile
# avoids os.listdir crashing before generate_io_yaml can create EXPROOT.
if not os.path.isfile(os.path.join(EXPROOT, "io.yaml")):
    generate_io_yaml()

# COM predictions are required before (s)DANNCE inference. The folder must be
# "COM" (upper-case) to match the "./COM/predict00" used by io.yaml and
# predict_com; otherwise COM is re-predicted on case-sensitive filesystems.
COM_dir = os.path.join(EXPROOT, "COM", "predict00")
os.makedirs(COM_dir, exist_ok=True)

if not os.path.isfile(os.path.join(COM_dir, "com3d.mat")):
    print(f"COM file not found in {COM_dir}, predicting COM now...")
    predict_com()

if os.path.isfile(os.path.join(EXPROOT, "sync_dannce.mat")):
    video_path = segmented_and_call_inference()
    if not SKIP_VIDEO:
        print(f"Inferenced video saved to {video_path}")
else:
    # Fall back to any *_dannce.mat in the folder; sorted so the choice is
    # deterministic when several files match.
    sync_file = sorted(f for f in os.listdir(EXPROOT) if f.endswith("_dannce.mat"))
    if sync_file:
        sync_path = os.path.join(EXPROOT, sync_file[0])
        video_path = segmented_and_call_inference(sync_path=sync_path)
        if not SKIP_VIDEO:
            print(f"Inferenced video saved to {video_path}")
    else:
        print("Error: No sync file found in project folder.")