In [1]:
import os
import sys
import yaml

import torch

from dannce.cli import get_parser, build_clarg_params 
from dannce.interface import dannce_predict, sdannce_predict, com_predict
from dannce.engine.utils.vis import visualize_pose_predictions
from dannce.engine.data.io import load_sync

In [2]:
# Command and model type placed on the DANNCE command line built by
# launch_inference(); RUNNING_MODE is also used by predict_com().
RUNNING_MODE = "predict"
MODEL_TYPE = "sdannce"

# Experiment folder: io.yaml is written here and the pipeline functions
# os.chdir() into it before calling DANNCE.
EXPROOT = '/mnt/d/Project/SDANNCE-Models/555-5CAM/SD-20250605B'

# Weights passed as --dannce-predict-model and --com-predict-weights respectively.
MODEL_CHECKPOINT = "/mnt/d/Project/SDANNCE-Models/4CAM-3D-2ETUP/Weights/checkpoint-SDANNCE-best.pth"
COM_CHECKPOINT = "/mnt/d/Project/SDANNCE-Models/4CAM-3D-2ETUP/Weights/checkpoint-COM-best-epoch1000.pth"

# Base config files handed to the DANNCE argument parser (pose model / COM model).
CONFIG = '/home/bezver/sdannce/configs/custom_sdannce_config.yaml'
COM_CONFIG = '/home/bezver/sdannce/configs/custom_com_config.yaml'

# N_ANIMALS is forwarded as --n-instances and written to io.yaml as n_instances.
N_ANIMALS = 1
# When True, segmented_and_call_inference() starts at frame 0 and takes the
# frame count from the sync file instead of STARTFRAME/ENDFRAME.
INFERENCE_ALL = False
# Maximum number of frames handled by a single launch_inference() call.
STEPSIZE = 2000

# Only used when INFERENCE_ALL is False
STARTFRAME = 24000
ENDFRAME = 28000

In [3]:
def generate_io_yaml(
    exproot=EXPROOT,
    n_animals=N_ANIMALS
):
    """Write an ``io.yaml`` file into the experiment folder.

    The folder is created when it does not exist yet. The file holds a fixed
    set of COM / DANNCE path entries and data options; only ``n_instances``
    varies and is taken from ``n_animals``.

    Args:
        exproot: Folder that receives ``io.yaml``.
        n_animals: Value stored under the ``n_instances`` key.
    """
    # Path-like entries (written with a leading "./"; presumably resolved
    # against the experiment folder by DANNCE -- confirm against the library).
    locations = {
        'com_train_dir': './COM/train00',
        'com_predict_dir': './COM/predict00',
        'com_exp': None,
        'exp': None,
        'dannce_train_dir': './DANNCE/train00/',
        'dannce_predict_dir': './DANNCE/predict00/',
        'com_file': './COM/predict00/com3d.mat',
    }
    sampling = {
        'use_npy': True,
        'rand_view_replace': True,
        'n_rand_views': 4,
        'n_instances': n_animals,
    }
    augmentation = {
        'mirror_augmentation': False,
        'augment_hue': False,
        'augment_brightness': False,
    }
    # The merge order is the key order in the file because sort_keys=False.
    io_settings = {**locations, **sampling, **augmentation}

    os.makedirs(exproot, exist_ok=True)
    yaml_path = os.path.join(exproot, 'io.yaml')
    with open(yaml_path, 'w') as stream:
        yaml.dump(io_settings, stream, default_flow_style=False, sort_keys=False)

    print(f"YAML config saved to: {yaml_path}")

In [4]:
def predict_com(
    exproot=EXPROOT,
    n_animals=N_ANIMALS,
    com_checkpoint=COM_CHECKPOINT,
    config=COM_CONFIG
):
    """Run COM prediction by driving the DANNCE command-line machinery in-process.

    A synthetic command line is assembled, parsed with ``get_parser`` and
    ``build_clarg_params``, and the resulting params are passed to
    ``com_predict``.

    Args:
        exproot: Experiment folder. The process changes its working directory
            to this folder and stays there after the call.
        n_animals: Forwarded as ``--n-instances``.
        com_checkpoint: Forwarded as ``--com-predict-weights``.
        config: Config file path given to the parser as the positional argument.

    Side effects:
        Changes the current working directory to ``exproot``. ``sys.argv`` is
        replaced only for the duration of the call and restored afterwards;
        the CUDA cache is emptied even when prediction raises.
    """
    # parameter arguments
    arguments = {
        "com-predict-weights": com_checkpoint,
        "com-predict-dir": "./COM/predict00",
        "batch-size": 8,
        "n-instances": n_animals,
    }

    # The output directory above is relative, so run from the experiment folder.
    os.chdir(exproot)

    # First element plays the role of the program name in argv.
    cmds = ['com', RUNNING_MODE, "com", config]
    for k, v in arguments.items():
        cmds += [f"--{k}", str(v)]

    # get_parser() reads sys.argv; swap it in temporarily so the kernel's own
    # argv is not left clobbered for later cells.
    original_argv = sys.argv
    sys.argv = cmds
    try:
        args = get_parser()
        params = build_clarg_params(
            args,
            dannce_net=args.mode in ("dannce", "sdannce"),
            prediction=(args.command == "predict"),
        )
        com_predict(params)
    finally:
        sys.argv = original_argv
        # Release cached GPU memory on failure as well as on success.
        torch.cuda.empty_cache()

In [5]:
def launch_inference(
    exproot=EXPROOT,
    n_animals=N_ANIMALS,
    config=CONFIG,
    model_checkpoint=MODEL_CHECKPOINT,
    com_file="./COM/predict00/com3d.mat",
    dannce_predict_dir="./DANNCE/predict00",
    max_num_samples=ENDFRAME,
    start_sample=STARTFRAME,
):
    """Predict poses for one frame window and render a visualisation of the result.

    A synthetic DANNCE command line is assembled from ``RUNNING_MODE``,
    ``MODEL_TYPE`` and the arguments below, parsed with ``get_parser`` and
    ``build_clarg_params``, and handed to the prediction entry point that
    matches the parsed mode. Afterwards ``visualize_pose_predictions`` is
    called on ``save_data_AVG{start_sample}.mat``.

    Args:
        exproot: Experiment folder. The process changes its working directory
            to this folder and stays there after the call.
        n_animals: Forwarded as ``--n-instances`` and to the visualiser.
        config: Config file path given to the parser as the positional argument.
        model_checkpoint: Forwarded as ``--dannce-predict-model``.
        com_file: Forwarded as ``--com-file``.
        dannce_predict_dir: Forwarded as ``--dannce-predict-dir``.
        max_num_samples: Forwarded as ``--max-num-samples``. Callers in this
            notebook pass the end frame of the window; the number of frames
            visualised is ``max_num_samples - start_sample``.
        start_sample: Forwarded as ``--start-sample``; also the first
            visualised frame and part of the prediction file name.

    Returns:
        Whatever ``visualize_pose_predictions`` returns (named ``video_path`` here).

    Side effects:
        Changes the current working directory to ``exproot``. ``sys.argv`` is
        replaced only while the prediction runs and restored afterwards; the
        CUDA cache is emptied even when prediction raises.
    """
    # parameter arguments
    arguments = {
        "dannce-predict-model": model_checkpoint,
        "dannce-predict-dir": dannce_predict_dir,
        "com-file": com_file,
        "start-sample": start_sample,
        "max-num-samples": max_num_samples,
        "batch-size": 1,
        "n-instances": n_animals,
    }

    # DANNCE must run within the experiment directory
    os.chdir(exproot)

    # compose the DANNCE command; override default arguments if specified
    cmds = ['dannce', RUNNING_MODE, MODEL_TYPE, config]
    for k, v in arguments.items():
        cmds += [f"--{k}", str(v)]

    # get_parser() reads sys.argv; swap it in temporarily so the kernel's own
    # argv is not left clobbered for later cells.
    original_argv = sys.argv
    sys.argv = cmds
    try:
        args = get_parser()
        params = build_clarg_params(
            args,
            dannce_net=args.mode in ("dannce", "sdannce"),
            prediction=(args.command == "predict"),
        )
        # Keep the entry point in step with the mode that was put on the
        # command line, so MODEL_TYPE = "dannce" is not run through the
        # sdannce predictor.
        predict_fn = dannce_predict if args.mode == "dannce" else sdannce_predict
        predict_fn(params)
    finally:
        sys.argv = original_argv
        # Clear the CUDA cache to avoid OOM, on failure as well as on success.
        torch.cuda.empty_cache()

    # visualize predictions
    video_path = visualize_pose_predictions(
        exproot=exproot,
        expfolder=args.dannce_predict_dir,
        datafile=f"save_data_AVG{start_sample}.mat",
        n_frames=max_num_samples - start_sample,
        start_frame=start_sample,
        cameras="1,2,3,4",
        animal="rat16",
        n_animals=n_animals,
        zoom_in=True,
        zoom_window_size=80
    )
    return video_path

In [None]:
def segmented_and_call_inference(
    inference_all = INFERENCE_ALL,
    sync_path = os.path.join(EXPROOT,"sync_dannce.mat"),
    duration = ENDFRAME - STARTFRAME,
    step_size = STEPSIZE,
    start_frame = STARTFRAME,
):
    """Run ``launch_inference`` over a frame window, split into segments if needed.

    Args:
        inference_all: When True, start at frame 0 and take the number of
            frames from the sync file; ``start_frame`` and ``duration`` are
            then ignored.
        sync_path: Sync file passed to ``load_sync``; only read when
            ``inference_all`` is True.
        duration: Number of frames to process when ``inference_all`` is False.
        step_size: Maximum number of frames per ``launch_inference`` call.
        start_frame: First frame to process when ``inference_all`` is False.

    Returns:
        The single ``launch_inference`` result when the requested window fits
        in one step and ``inference_all`` is False; otherwise a list with one
        result per segment, in frame order.

    Raises:
        ValueError: If ``step_size`` is not positive.
    """
    if step_size <= 0:
        raise ValueError(f"step_size must be a positive integer, got {step_size}")

    if not inference_all and duration <= step_size:
        # Pass the requested window explicitly; relying on launch_inference's
        # defaults would silently ignore start_frame/duration given by the caller.
        return launch_inference(
            max_num_samples=start_frame + duration,
            start_sample=start_frame,
        )

    # inference_all is True, or the window is longer than one step:
    # perform segmented inference.
    if inference_all:
        start_frame = 0
        data_sync = load_sync(sync_path)
        # Total number of frames comes from the 'data_frame' field of the sync data.
        total_frames = len(data_sync[0]['data_frame'][0])
    else:
        total_frames = duration
    print(f"Found {total_frames} frames to inference.")

    # One segment per step_size frames; the last one is clipped to the window end.
    end_frame = start_frame + total_frames
    segment_starts = range(start_frame, end_frame, step_size)
    num_segment = len(segment_starts)

    def _show_progress(done):
        # Redraw the single-line progress bar for `done` completed segments.
        bar_length = 50
        percent = done / num_segment * 100
        filled_length = bar_length * done // num_segment
        bar = '█' * filled_length + '-' * (bar_length - filled_length)
        sys.stdout.write(f'\rProgress: |{bar}| {percent:.2f}% ({done}/{num_segment} segments) -  ')
        sys.stdout.flush()

    video_path = []
    if num_segment:
        _show_progress(0)

    for done, start_frame_seg in enumerate(segment_starts, start=1):
        end_frame_seg = min(start_frame_seg + step_size, end_frame)

        # Call the inference function for the current segment
        video_path.append(
            launch_inference(max_num_samples=end_frame_seg, start_sample=start_frame_seg)
        )
        # Report the segment only once it has actually finished.
        _show_progress(done)

    sys.stdout.write('\nDone!\n')
    sys.stdout.flush()
    return video_path

In [7]:
if "io.yaml" not in os.listdir(EXPROOT):
    generate_io_yaml()

COM_dir = os.path.join(EXPROOT,"com","predict00")
os.makedirs(COM_dir,exist_ok=True)

if "com3d.mat" not in os.listdir(COM_dir):
    print(f"COM file not found in {COM_dir}, predicting COM now...")
    predict_com()

if "sync_dannce.mat" not in os.listdir(EXPROOT):
    sync_file = [f for f in os.listdir(EXPROOT) if f.endswith("_dannce.mat")]
    if sync_file:
        sync_path = os.path.join(EXPROOT, sync_file[0])
        video_path = segmented_and_call_inference(sync_path=sync_path)
        print(f"Inferenced video saved to {video_path}")
    else:
        print("Error: No sync file found in project folder.")
else:
    video_path = segmented_and_call_inference()
    print(f"Inferenced video saved to {video_path}")

Found 4000 frames to inference.
Progress: |█████████████████████████-------------------------| 50.00% (1/2 segments) -  

[32m2025-09-14 08:36:40.513[0m | [1mINFO    [0m | [36mdannce.cli[0m:[36mcombine[0m:[36m768[0m - [1mSetting com_train_dir to: ../weights[0m
[32m2025-09-14 08:36:40.513[0m | [1mINFO    [0m | [36mdannce.cli[0m:[36mcombine[0m:[36m768[0m - [1mSetting com_predict_dir to: ./COM/predict00[0m
[32m2025-09-14 08:36:40.513[0m | [1mINFO    [0m | [36mdannce.cli[0m:[36mcombine[0m:[36m768[0m - [1mSetting com_exp to: None[0m
[32m2025-09-14 08:36:40.513[0m | [1mINFO    [0m | [36mdannce.cli[0m:[36mcombine[0m:[36m768[0m - [1mSetting exp to: None[0m
[32m2025-09-14 08:36:40.514[0m | [1mINFO    [0m | [36mdannce.cli[0m:[36mcombine[0m:[36m768[0m - [1mSetting dannce_train_dir to: ./DANNCE/train00/[0m
[32m2025-09-14 08:36:40.514[0m | [1mINFO    [0m | [36mdannce.cli[0m:[36mcombine[0m:[36m768[0m - [1mSetting dannce_predict_dir to: ./DANNCE/predict00[0m
[32m2025-09-14 08:36:40.515[0m | [1mINFO    [0m | [36mdannce.cli[0m:[36mcombine[0m:

Saving checkpoint at 1000th batch


100%|██████████| 2000/2000 [01:55<00:00, 17.30it/s]
100%|██████████| 2000/2000 [04:31<00:00,  7.36it/s]


Visualization of n=2000 took 272.8024709224701 sec.
Progress: |██████████████████████████████████████████████████| 100.00% (2/2 segments) -  

[32m2025-09-14 08:43:27.468[0m | [1mINFO    [0m | [36mdannce.cli[0m:[36mcombine[0m:[36m768[0m - [1mSetting com_train_dir to: ../weights[0m
[32m2025-09-14 08:43:27.473[0m | [1mINFO    [0m | [36mdannce.cli[0m:[36mcombine[0m:[36m768[0m - [1mSetting com_predict_dir to: ./COM/predict00[0m
[32m2025-09-14 08:43:27.478[0m | [1mINFO    [0m | [36mdannce.cli[0m:[36mcombine[0m:[36m768[0m - [1mSetting com_exp to: None[0m
[32m2025-09-14 08:43:27.482[0m | [1mINFO    [0m | [36mdannce.cli[0m:[36mcombine[0m:[36m768[0m - [1mSetting exp to: None[0m
[32m2025-09-14 08:43:27.487[0m | [1mINFO    [0m | [36mdannce.cli[0m:[36mcombine[0m:[36m768[0m - [1mSetting dannce_train_dir to: ./DANNCE/train00/[0m
[32m2025-09-14 08:43:27.493[0m | [1mINFO    [0m | [36mdannce.cli[0m:[36mcombine[0m:[36m768[0m - [1mSetting dannce_predict_dir to: ./DANNCE/predict00[0m
[32m2025-09-14 08:43:27.499[0m | [1mINFO    [0m | [36mdannce.cli[0m:[36mcombine[0m:

Saving checkpoint at 3000th batch


100%|██████████| 2000/2000 [01:55<00:00, 17.36it/s]
100%|██████████| 2000/2000 [04:25<00:00,  7.52it/s]


Visualization of n=2000 took 267.0325131416321 sec.

Done!
Inferenced video saved to ['/mnt/d/Project/SDANNCE-Models/555-5CAM/SD-20250605B/./DANNCE/predict00/vis/frame0-2000_Camera1,2,3,4.mp4', '/mnt/d/Project/SDANNCE-Models/555-5CAM/SD-20250605B/./DANNCE/predict00/vis/frame2000-4000_Camera1,2,3,4.mp4']
