In [None]:
import os
import sys

import torch

from IPython.display import Video, display

from dannce.cli import get_parser, build_clarg_params
from dannce.interface import dannce_predict, sdannce_predict
from dannce.engine.utils.vis import visualize_pose_predictions
from dannce.engine.data.io import load_sync

In [None]:
# Host platform identifier (e.g. "linux" or "win32"); checked later to remap paths on Windows.
my_os = sys.platform

# Demo directory; relative inputs such as the config path are resolved against it.
cwd = '/home/bezver/sdannce/demo'

In [None]:
# DANNCE CLI command and model family used to build the command line.
RUNNING_MODE, MODEL_TYPE = "predict", "sdannce"

# Experiment folder (the COM file and sync file are looked up relative to it) and trained weights.
EXPROOT = '/mnt/d/Project/SDANNCE-Models/555-5CAM/SD-20250605B'
MODEL_CHECKPOINT = "/mnt/d/Project/SDANNCE-Models/4CAM-3D-2ETUP/Weights/checkpoint-SDANNCE-best.pth"

# On native Windows, map the WSL-style paths above onto the D: drive.
if my_os == "win32":
    cwd = cwd.replace("/home/bezver", "D:/Repository")
    EXPROOT, MODEL_CHECKPOINT = (
        path.replace("/mnt/d", "D:") for path in (EXPROOT, MODEL_CHECKPOINT)
    )

# Model config YAML, relative to `cwd`.
CONFIG = '../configs/custom_sdannce_config.yaml'

# Number of animal instances tracked per frame.
N_ANIMALS = 1
# True: run over the whole recording in fixed-size segments; False: run a single window.
INFERENCE_ALL = True

# Frame window for a single, non-segmented run (end frame is exclusive).
STARTFRAME, ENDFRAME = 25000, 26000

In [None]:
def launch_inference(
    exproot=EXPROOT,
    n_animals=N_ANIMALS,
    config=CONFIG,
    model_checkpoint=MODEL_CHECKPOINT,
    com_file="./COM/predict00/com3d.mat",
    dannce_predict_dir="./DANNCE/predict00",
    max_num_samples=ENDFRAME,
    start_sample=STARTFRAME,
    cameras="1,2,3,4",
    animal="rat16",
):
    """Run (s)DANNCE prediction on one frame window and render a preview video.

    The DANNCE CLI is driven programmatically: a
    ``dannce <RUNNING_MODE> <MODEL_TYPE> <config> --flag value ...`` command line
    is assembled, parsed with ``get_parser``, converted to params with
    ``build_clarg_params`` and handed to the predict entry point that matches
    ``MODEL_TYPE``. ``sys.argv`` and the working directory are restored afterwards,
    even if prediction fails.

    Args:
        exproot: Experiment directory. Prediction runs with it as the working
            directory, so ``com_file`` and ``dannce_predict_dir`` are resolved
            against it.
        n_animals: Number of animal instances (``--n-instances``).
        config: Config YAML path, joined onto ``cwd``.
        model_checkpoint: Weights path, joined onto ``cwd`` (an absolute path is
            used unchanged).
        com_file: COM prediction file (``--com-file``).
        dannce_predict_dir: Output directory for predictions (``--dannce-predict-dir``).
        max_num_samples: End frame of the window (``--max-num-samples``); the
            preview renders ``max_num_samples - start_sample`` frames.
        start_sample: First frame of the window (``--start-sample``).
        cameras: Comma-separated camera ids drawn in the preview video.
        animal: Skeleton name passed to the visualizer.

    Returns:
        The value returned by ``visualize_pose_predictions`` (the preview video
        path). A path string is made absolute, so it stays valid after the working
        directory is restored.

    Raises:
        ValueError: If ``MODEL_TYPE`` is not "dannce"/"sdannce", or if the frame
            window is empty.
    """
    predictors = {"dannce": dannce_predict, "sdannce": sdannce_predict}
    if MODEL_TYPE not in predictors:
        raise ValueError(
            f"Unsupported MODEL_TYPE {MODEL_TYPE!r}; expected one of {sorted(predictors)}"
        )
    if max_num_samples <= start_sample:
        raise ValueError(
            f"Empty frame window: start_sample={start_sample}, "
            f"max_num_samples={max_num_samples}"
        )

    # Resolve the config/weights against the demo directory *before* chdir-ing away.
    config = os.path.join(cwd, config)
    model_checkpoint = os.path.join(cwd, model_checkpoint)
    print(model_checkpoint)

    # CLI overrides for the config file's defaults.
    arguments = {
        "dannce-predict-model": model_checkpoint,
        "dannce-predict-dir": dannce_predict_dir,
        "com-file": com_file,
        "start-sample": start_sample,
        "max-num-samples": max_num_samples,
        "batch-size": 1,
        "n-instances": n_animals,
    }
    cmds = ["dannce", RUNNING_MODE, MODEL_TYPE, config]
    for flag, value in arguments.items():
        cmds += [f"--{flag}", str(value)]

    previous_dir = os.getcwd()
    previous_argv = sys.argv
    try:
        # DANNCE must run within the experiment directory (relative COM/output paths).
        os.chdir(exproot)
        # get_parser() takes no arguments, so the command line is passed via sys.argv.
        sys.argv = cmds
        args = get_parser()
        params = build_clarg_params(
            args,
            dannce_net=args.mode in ("dannce", "sdannce"),
            prediction=(args.command == "predict"),
        )
        predictors[MODEL_TYPE](params)

        # Release cached CUDA memory between runs to avoid OOM in later segments.
        torch.cuda.empty_cache()

        video_path = visualize_pose_predictions(
            exproot=exproot,
            expfolder=args.dannce_predict_dir,
            datafile=f"save_data_AVG{start_sample}.mat",
            n_frames=max_num_samples - start_sample,
            start_frame=start_sample,
            cameras=cameras,
            animal=animal,
            n_animals=n_animals,
            zoom_in=True,
            zoom_window_size=80,
        )
        # Resolve while still inside exproot; the working directory is restored below.
        if isinstance(video_path, (str, os.PathLike)):
            video_path = os.path.abspath(video_path)
    finally:
        # Do not leak the fake command line or the chdir into the notebook kernel.
        sys.argv = previous_argv
        os.chdir(previous_dir)
    return video_path

In [None]:
def _segment_bounds(total_frames, step_size):
    """Split frames [0, total_frames) into consecutive (start, end) windows of at most step_size frames."""
    return [
        (start, min(start + step_size, total_frames))
        for start in range(0, total_frames, step_size)
    ]


def _render_progress(done, total, bar_length=50):
    """Format a text progress bar for segment `done` out of `total`."""
    filled = bar_length * done // total
    bar = '█' * filled + '-' * (bar_length - filled)
    return f"Progress: |{bar}| {done / total * 100:.2f}% ({done}/{total} segments)"


def segmented_and_call_inference(
    inference_all = INFERENCE_ALL,
    sync_path = os.path.join(EXPROOT,"sync_dannce.mat"),
    step_size = 2000,
):
    """Run inference on one configured window, or on the whole recording in chunks.

    Args:
        inference_all: If False, run ``launch_inference`` once over the
            [STARTFRAME, ENDFRAME) window. If True, read the recording length from
            ``sync_path`` and run ``launch_inference`` on consecutive windows of
            ``step_size`` frames covering [0, total_frames).
        sync_path: DANNCE sync file used to determine the number of frames.
        step_size: Frames per segment; must be positive.

    Returns:
        A single ``launch_inference`` result when ``inference_all`` is False;
        otherwise a list with one result per segment (empty if the recording has
        no frames).

    Raises:
        ValueError: If ``inference_all`` is True and ``step_size`` is not positive.
    """
    if not inference_all:
        # Single run over the configured frame window.
        return launch_inference(max_num_samples=ENDFRAME, start_sample=STARTFRAME)

    if step_size <= 0:
        raise ValueError(f"step_size must be a positive number of frames, got {step_size}")

    data_sync = load_sync(sync_path)
    # Recording length is taken from the first sync entry's 'data_frame' field
    # (presumably the first camera — TODO confirm all cameras share this length).
    total_frames = len(data_sync[0]['data_frame'][0])
    print(f"Found {total_frames} frames in {os.path.dirname(sync_path)} videos")

    segments = _segment_bounds(total_frames, step_size)
    video_paths = []
    for number, (start_frame, end_frame) in enumerate(segments, start=1):
        # Printed on its own line so DANNCE's log output does not overwrite it.
        print(f"{_render_progress(number, len(segments))} - frames {start_frame}-{end_frame}")
        video_paths.append(
            launch_inference(max_num_samples=end_frame, start_sample=start_frame)
        )

    print("Done!")
    return video_paths

In [None]:
video_path = segmented_and_call_inference()

# A cell only auto-renders a bare trailing expression; inside if/while blocks the
# Video object has to be passed to display() explicitly or nothing is shown.
if isinstance(video_path, list):
    if not video_path:
        print("\nNo video files were generated.")
    else:
        print(f"\n{len(video_path)} video file(s) generated. Please choose a video to display:")
        for number, vid in enumerate(video_path, start=1):
            print(f"{number}. {os.path.basename(vid)}")
        print("0. Exit without displaying/opening anything.")

        while True:
            try:
                choice = input(f"Enter your choice (1-{len(video_path)}, 0 to exit): ")
            except (EOFError, NotImplementedError) as e:
                # No interactive stdin (e.g. a headless "Run All"): skip the display.
                print(f"No interactive input available ({e}); not displaying a video.")
                break

            # Only the number parsing is guarded, so other errors are not misreported as bad input.
            try:
                choice_int = int(choice)
            except ValueError:
                print("Invalid input. Please enter a number.")
                continue

            if choice_int == 0:
                print("Exiting.")
                break
            if not 1 <= choice_int <= len(video_path):
                print("Invalid choice. Please enter a number within the valid range.")
                continue

            selected_vid = video_path[choice_int - 1]  # menu is 1-based
            print(f"\nAttempting to display/open: {os.path.basename(selected_vid)}")
            try:
                display(Video(selected_vid, embed=True, width=600, height=400))
            except (OSError, ValueError) as e:
                # e.g. the video file is missing or cannot be embedded
                print(f"An unexpected error occurred: {e}")
            break
else:
    display(Video(video_path, embed=True, width=600, height=400))