In [None]:
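# Visualization after Lightning Pose training:
#   1. inspect model predictions on labeled frames with FiftyOne
#   2. plot video prediction traces and unsupervised losses
#
# Reference: litpose_training_demo.ipynb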

In [1]:
import hydra
from omegaconf import DictConfig, OmegaConf
import os
import lightning.pytorch as pl

In [2]:
# --- Paths (absolute and machine-specific: edit them for your setup) ---
# Lightning Pose source checkout
litpose_dir = r"/home/yiting/Documents/GitHub/lightning-pose"
# Project directory that holds the YAML config, and the config file's name
config_path = r"/home/yiting/Documents/LP_projects/LP_240726"
config_name = "config_hand-1cam.yaml"
# Root directory of training outputs, and the run to visualize (relative to it)
output_dir = r"/home/yiting/Documents/GitHub/lightning-pose/outputs"
model_dir = r"2024-07-30/13-41-23"

# Load the project configuration from <config_path>/<config_name>
cfg = OmegaConf.load(os.path.join(config_path, config_name))


## Predictions/diagnostics for labeled data (FiftyOne)

### Create a FiftyOne dataset (`fo.Dataset`) of model predictions on labeled frames

In [None]:
# Override the default configs here:
# Training-run directories to evaluate. Append further run directories to this
# list to compare several models side by side.
cfg.eval.hydra_paths = [os.path.join(output_dir, model_dir)]
# Name under which the FiftyOne dataset is created
cfg.eval.fiftyone.dataset_name = "LP_240726"
# Display name(s) for the model(s); presumably one per entry of hydra_paths, same order
cfg.eval.fiftyone.model_display_names = ["single-view"]

In [None]:
import fiftyone as fo
from lightning_pose.utils.fiftyone import check_dataset, FiftyOneImagePlotter

# Build the plotter from `cfg` (including the eval/fiftyone overrides set in the
# previous cell).
fo_plotting_instance = FiftyOneImagePlotter(cfg=cfg)

# Build the FiftyOne dataset; the plotter internally loops over the models
# listed in cfg.eval.hydra_paths.
dataset = fo_plotting_instance.create_dataset()

# create metadata and print if there are problems, then print a summary of the dataset
check_dataset(dataset)
fo_plotting_instance.dataset_info_print()

In [None]:
# Launch the FiftyOne App on the dataset created above. Called without a dataset,
# the App opens with nothing selected. Keeping the returned session lets later
# cells update or close the App.
session = fo.launch_app(dataset)

### Reload and launch a previously created FiftyOne dataset

In [None]:
# Importing again lets this cell run on a fresh kernel without rebuilding the dataset.
import fiftyone as fo

# The name should match cfg.eval.fiftyone.dataset_name used when the dataset was created
dataset = fo.load_dataset("LP_240726")
session = fo.launch_app(dataset=dataset)

In [None]:
# Names of the FiftyOne datasets available to fo.load_dataset()
existing_datasets = fo.list_datasets()
existing_datasets

## Plot video predictions and unsupervised losses

### Load data

In [3]:
from collections import defaultdict
import pandas as pd
from pathlib import Path

from lightning_pose.apps.utils import build_precomputed_metrics_df, get_col_names, concat_dfs
from lightning_pose.apps.utils import update_vid_metric_files_list
from lightning_pose.apps.utils import get_model_folders, get_model_folders_vis
from lightning_pose.apps.plots import plot_precomputed_traces



In [6]:
# Video to visualize. Prediction/metric CSV files for it are collected from each
# model folder's "video_preds" directory by update_vid_metric_files_list
# (presumably matched on the file-name prefix).
video_name = "2023-11-21T10-29-36_camTo"

# Substrings marking a CSV as an unsupervised-loss metric file. Any other matched
# CSV is read as the model's keypoint predictions (two-level column header).
metric_file_tags = ("pca", "temporal", "pixel")

# select which model(s) to use
model_folders = get_model_folders(output_dir)

# get the last two levels of each path to be presented to user
model_names = get_model_folders_vis(model_folders)

# get prediction files for each model (one list per entry of model_folders)
prediction_files = update_vid_metric_files_list(video=video_name, model_preds_folders=model_folders)

# These three lists are consumed in parallel. A length mismatch would silently attach
# predictions to the wrong model name, so fail loudly instead.
if not (len(prediction_files) == len(model_folders) == len(model_names)):
    raise ValueError(
        f"Expected one list of prediction files per model folder, got "
        f"{len(prediction_files)} lists for {len(model_folders)} folders "
        f"({len(model_names)} display names)."
    )

# load data
#   dframes_metrics: model name -> {metric key -> DataFrame}; the raw predictions are
#                    also stored there under the "confidence" key.
#   dframes_traces:  model name -> raw keypoint-prediction DataFrame.
dframes_metrics = defaultdict(dict)
dframes_traces = {}
for model_name, model_folder, model_pred_files in zip(model_names, model_folders, prediction_files):
    for model_pred_file in model_pred_files:
        # Entries are used as file names/paths, so no buffer rewinding is needed.
        model_pred_file_path = os.path.join(model_folder, "video_preds", model_pred_file)
        # Classify on the file name only, so a tag appearing in a directory name
        # cannot turn a prediction file into a "metric" file.
        file_name = os.path.basename(str(model_pred_file))
        if any(tag in file_name for tag in metric_file_tags):
            dframe = pd.read_csv(model_pred_file_path, index_col=None)
            dframes_metrics[model_name][str(model_pred_file)] = dframe
        else:
            # header=[1, 2]: rows 1 and 2 (0-based) form a two-level column header
            dframe = pd.read_csv(model_pred_file_path, header=[1, 2], index_col=0)
            dframes_traces[model_name] = dframe
            dframes_metrics[model_name]["confidence"] = dframe

if not dframes_traces:
    raise FileNotFoundError(
        f"No keypoint-prediction CSV for video {video_name!r} was found under "
        f"'video_preds' in any of: {model_folders}"
    )

# compute metrics
# concat dataframes, collapsing hierarchy and making df fatter.
df_concat, keypoint_names = concat_dfs(dframes_traces)
df_metrics = build_precomputed_metrics_df(
    dframes=dframes_metrics, keypoint_names=keypoint_names)
metric_options = list(df_metrics.keys())

# print keypoint names; select one of these to plot below
print(keypoint_names)

# NOTE: you can ignore all errors and warnings of the type:
#    No runtime found, using MemoryCacheStorageManager

['Index_Tip', 'Index_DIP', 'Index_PIP', 'Middle_Tip', 'Middle_DIP', 'Middle_PIP', 'Ring_Tip', 'Ring_DIP', 'Ring_PIP', 'Small_Tip', 'Small_DIP', 'Small_PIP']


### Plot video traces

In [None]:
# Select a keypoint from the dropdown to plot its traces. The figure is drawn into an
# Output widget below the dropdown and replaced on each new selection, so this cell
# does not have to be re-run to switch keypoints.

from IPython.display import display, clear_output
import ipywidgets as widgets

# Dedicated output area for the figure. Output produced inside widget callbacks is
# not reliably shown in the cell (JupyterLab sends it to the log console). A bare
# clear_output() can also erase the dropdown itself. Rendering into this widget
# avoids both problems.
trace_output = widgets.Output()


def on_change(change):
    """Redraw the trace plot for the keypoint chosen in the dropdown.

    Args:
        change: traitlets change notification for the dropdown's ``value``;
            ``change["new"]`` is the selected keypoint name (or None).
    """
    if change["type"] != "change" or change["name"] != "value":
        return
    keypoint = change["new"]
    with trace_output:
        # wait=True: the old figure is removed only once new output is ready (no flicker)
        trace_output.clear_output(wait=True)
        if keypoint is None:
            return
        cols = get_col_names(keypoint, "x", dframes_metrics.keys())
        fig_traces = plot_precomputed_traces(df_metrics, df_concat, cols)
        fig_traces.show()


# create a Dropdown widget with no keypoint pre-selected (value=None), so the
# first selection triggers on_change
dropdown = widgets.Dropdown(
    options=keypoint_names,
    value=None,
    description="Select keypoint:",
)

# redraw the plot whenever the selected value changes
dropdown.observe(on_change, names="value")

# display the widget with the plot area below it
display(dropdown, trace_output)