#### create training dataset from aeon raw data

In [None]:
# exp 0.2
# solo BAA-1101818 (tail+ear painted)
# 2022-06-23 08:39:04.261089801	BAA-1101818	26.4	Enter 
# 2022-06-23 11:14:46.121759892	BAA-1101818	28.0	Exit
# 2022-06-24 09:32:37.183360100	BAA-1101818	26.9	Enter (ear repainted)
# 2022-06-24 12:29:54.365859985	BAA-1101818	27.8	Exit

# solo BAA-1101819
# 2022-06-21 13:28:10.593659878	BAA-1101819	25.4	Enter
# 2022-06-21 16:34:29.241280079	BAA-1101819	26.4	Exit

# multianimal BAA-1101818 and BAA-1101819
# 2022-06-22 10:40:00	        BAA-1101819	24.9	Enter
# 2022-06-22 13:29:04.050240040	BAA-1101818	28.4	Exit 
# 2022-06-23 11:24:23.876420021	BAA-1101819	25.6	Enter
# 2022-06-23 14:19:39.241819859	BAA-1101818	26.4	Exit

# exp 0.3
# multianimal BAA-1102505 and BAA-1102506
# 1904-01-03 22:03:16.696000000	BAA-1102505	20	Enter
# 1904-01-03 22:03:30.928000000	BAA-1102506	20	Enter
# 1904-01-03 23:57:31.952000000	BAA-1102505	20	Exit
# 1904-01-03 23:57:37.824000000	BAA-1102506	20	Exit



# AEON
#### create training dataset from aeon raw data

In [1]:
import pandas as pd
import aeon.io.api as aeon
import matplotlib.pyplot as plt
import numpy.typing as npt
import cv2

from aeon.schema.dataset import exp02
from aeon.analysis.utils import *
from aeon.io.video import frames, export
from dotmap import DotMap
from pathlib import Path
from scipy import stats
from matplotlib import path


def get_raw_tracking_data(
    root: str,
    subj_id: str,
    start: "pd.Timestamp",
    end: "pd.Timestamp",
    source_frame: "aeon.io.reader.Video" = exp02.CameraTop.Video,
):
    """
    Retrieves pos tracking and video data and assigns subject ID.

    Video frames and top-camera position samples are aligned with a
    nearest-timestamp merge (1 ms tolerance). Rows with any missing value,
    e.g. frames without a position sample within 1 ms, are dropped.

    :param root: The root path, or prioritised sequence of paths, where epoch data is stored.
    :param subj_id: The subject ID string to be assigned (overwrites the tracked ``id`` column).
    :param start: The left bound of the time range to extract.
    :param end: The right bound of the time range to extract.
    :param optional source_frame: The video reader whose frames are loaded (defaults to ``exp02.CameraTop.Video``).
    :returns: A pandas data frame with columns x, y, id, area, _frame, _path.
    """
    subj_video = aeon.load(root, source_frame, start=start, end=end)
    subj_pos = aeon.load(root, exp02.CameraTop.Position, start=start, end=end)
    subj_data = pd.merge_asof(
        subj_video,
        subj_pos,
        left_index=True,
        right_index=True,
        direction="nearest",
        tolerance=pd.Timedelta("1ms"),
    )[["x", "y", "id", "area", "_frame", "_path"]]
    # dropna()/assign() return new frames instead of editing a column
    # selection in place (which pandas may flag with SettingWithCopyWarning).
    return subj_data.dropna().assign(id=subj_id)


def sample_n_from_bins(
    subj_data: "pd.DataFrame",
    n_samples: int = 1,
    n_bins: int = 50,
    range: "npt.ArrayLike" = None,
):
    """
    Uniformly samples n number of data from x number of bins.

    Rows are assigned to a 2d grid of ``n_bins`` x ``n_bins`` spatial bins on
    (x, y). ``n_samples`` rows are drawn with replacement from every occupied
    bin, and exact duplicate rows are then dropped. A bin may therefore
    contribute fewer than ``n_samples`` rows.

    :param subj_data: A pandas data frame containing pos tracking and video data, and subject ID.
        It is not modified.
    :param optional n_samples: The number of samples to take from each bin.
    :param optional n_bins: The number of bins to use for sampling.
    :param optional range: The leftmost and rightmost edges of the bins along each dimension: [[xmin, xmax], [ymin, ymax]].
        Defaults to [[0, 1440], [0, 1080]]. Values outside this range are treated as outliers and
        get their own out-of-range bin numbers.
    :returns: A pandas data frame containing uniformly-sampled pos tracking and video data, subject ID,
        and a "bin" column holding each row's linearised 2d bin number.
    """
    if range is None:
        # Built per call to avoid a shared mutable default.
        range = [[0, 1440], [0, 1080]]
    # Only the per-row bin numbers are used. With statistic="count" scipy
    # allocates a result of shape (len(values), total_bins). Passing a single
    # 1-d array keeps that at one row instead of one row per sample, which
    # happens when the whole data frame is passed.
    hist_data = stats.binned_statistic_2d(
        subj_data.x,
        subj_data.y,
        values=subj_data.x,
        statistic="count",
        bins=n_bins,
        range=range,
    )
    # assign() returns a new frame, so the caller's data is left untouched.
    binned = subj_data.assign(bin=hist_data.binnumber)
    return (
        binned.groupby(["bin"]).sample(n=n_samples, replace=True).drop_duplicates()
    )


def create_session_dataset(
    session: dict,
    subj_ids: list = None,
    plot_dist: bool = True,
):
    """
    Creates a dataset for a given session dict.

    Subjects whose key contains "multi_" are sampled more sparsely for manual
    annotation: rows with area < 500 are dropped, and 4 samples are drawn from
    each of 10 x 10 bins. All other subjects use the create_subject_dataset defaults.

    :param session: A dictionary containing the root path, subject IDs, and their start and end times.
    :param optional subj_ids: A list of subject ids to extract. If None, all subjects are extracted.
    :param optional plot_dist: Whether to plot the 1d and 2d histograms of x, y pos tracking for each subject.
    :returns: A pandas data frame containing uniformly-sampled pos tracking and video data, and subject ID.
    """
    if not subj_ids:
        subj_ids = session["subjects"].keys()
    per_subject = []
    for subj in subj_ids:
        subj_dict = {
            "id": subj,
            "root": session["root"],
            "start": session["subjects"][subj]["start"],
            "end": session["subjects"][subj]["end"],
        }
        # sample fewer points for manual annotation of multi-animal windows
        sampling_kwargs = (
            dict(min_area=500, n_samples=4, n_bins=10) if "multi_" in subj else {}
        )
        per_subject.append(create_subject_dataset(subj_dict, **sampling_kwargs))
    # Concatenate once, rather than re-copying the accumulated frame every
    # iteration, and without an empty seed frame (deprecated in pandas concat).
    all_subj_data = pd.concat(per_subject) if per_subject else pd.DataFrame()
    if plot_dist:
        fig = plot_position_histograms(all_subj_data)
        fig.show()
    return all_subj_data


def plot_position_histograms(data: "pd.DataFrame", n_bins: int = 50):
    """
    Plots the 1d and 2d histograms of x, y pos tracking for each subject in a given data frame.

    The figure has one column of axes per subject. The top row overlays the
    1d histograms of x and y; the bottom row shows the 2d x-y histogram.

    :param data: A pandas data frame containing x, y pos tracking and subject ID(s).
    :param optional n_bins: The number of bins along each axis, used for both the 1d and 2d histograms.
    :returns: The matplotlib figure containing the histograms.
    """
    subj_ids = data["id"].unique()
    # squeeze=False keeps `ax` 2-d (2 x n_subjects) even for a single subject,
    # so one loop handles every case.
    fig, ax = plt.subplots(2, len(subj_ids), squeeze=False)
    for i, subj_id in enumerate(subj_ids):
        subj_data = data[data["id"] == subj_id]
        subj_data[["x", "y"]].plot.hist(
            bins=n_bins, alpha=0.5, ax=ax[0, i], title=subj_id
        )
        ax[1, i].hist2d(
            subj_data.x, subj_data.y, bins=(n_bins, n_bins), cmap=plt.cm.jet
        )

    plt.tight_layout()
    return fig


def create_subject_dataset(
    subject: dict,
    min_area: float = None,
    n_samples: int = 1,
    n_bins: int = 50,
):
    """
    Builds a spatially-uniform sample of tracking/video rows for one subject.

    :param subject: Dict with "root", "id", "start" and "end" entries saying where and when to load the subject's data.
    :param optional min_area: If truthy, rows whose tracked ``area`` is below this value are discarded
        before sampling (e.g. frames where animals fuse into one blob).
    :param optional n_samples: Rows drawn from each spatial bin.
    :param optional n_bins: Number of spatial bins per axis.
    :returns: A pandas data frame containing uniformly-sampled pos tracking and video data, and subject ID.
    """
    raw = get_raw_tracking_data(
        subject["root"], subject["id"], subject["start"], subject["end"]
    )
    # Area filter is skipped when min_area is None or 0.
    kept = raw[raw.area >= min_area] if min_area else raw
    return sample_n_from_bins(kept, n_samples=n_samples, n_bins=n_bins)


def extract_ma_videos(session: dict, dest_root: str = "", subj_ids: list = None):
    """
    Extracts multi-animal videos for a given session.

    For every selected subject key containing "multi_", the tracked top-camera
    frames in that window are exported to an AVI at 50 fps. The file name is
    built from three parts:
    - the first frame's source file stem, truncated at the first "-00-00";
    - the window's end hour;
    - a fixed "_pattern_tattoo" suffix.

    :param session: A session dictionary.
    :param dest_root: Destination directory. If empty, the current working directory is used.
        It is joined as a path, so a trailing separator is optional.
    :param optional subj_ids: A list of subject ids to extract. If None, all subjects are extracted.
    :returns list: A list of extracted video paths (str).
    :raises ValueError: If a multi-animal window contains no tracked frames.
    """

    paths = []
    if not subj_ids:
        subj_ids = session["subjects"].keys()
    for subj in subj_ids:
        if "multi_" not in subj:
            continue
        window = session["subjects"][subj]
        subj_data = get_raw_tracking_data(
            session["root"],
            "multi_animal",
            window["start"],
            window["end"],
            source_frame=exp02.CameraTop.Video,
        )
        if subj_data.empty:
            raise ValueError(
                f"No tracked frames for {subj} between {window['start']} and {window['end']}"
            )
        chunk_prefix = Path(subj_data.iloc[0]["_path"]).stem.split("-00-00")[0]
        filename = f'{chunk_prefix}-{window["end"].hour}-00_pattern_tattoo.avi'
        # Join as a path so dest_root works with or without a trailing separator.
        filepath = str(Path(dest_root) / filename)
        paths.append(filepath)
        export(frames(subj_data), filepath, fps=50)
    return paths


def create_ma_dataset(session: dict, subj_ids: list = None):
    """
    Creates a multi-animal-frames-only dataset for a given session dict.

    No spatial sub-sampling is applied. Every tracked frame in each "multi_"
    window is kept, and its ``id`` is set to that window's subject key
    (e.g. "multi_animal2").

    :param session: A session dictionary.
    :param optional subj_ids: Subject ids to consider. If None, all keys of session["subjects"] are considered.
        Only ids containing "multi_" are extracted.
    :returns: A pandas data frame containing multi-animal pos tracking and video data, and subject ID.
        It is empty if no multi-animal subject is selected.
    """
    if not subj_ids:
        subj_ids = session["subjects"].keys()
    ma_frames = [
        get_raw_tracking_data(
            session["root"],
            subj,
            session["subjects"][subj]["start"],
            session["subjects"][subj]["end"],
            source_frame=exp02.CameraTop.Video,
        )
        for subj in subj_ids
        if "multi_" in subj
    ]
    # Concatenate once instead of growing a frame inside the loop.
    return pd.concat(ma_frames) if ma_frames else pd.DataFrame()

In [2]:
# dictionaries for each session
# Each session dict holds:
#   "root":     raw-data root directory passed to aeon.load
#   "subjects": subject key -> {"start", "end"} time window; keys containing
#               "multi_" mark multi-animal windows, all other keys are
#               single-subject windows
#   "session":  session name, used as the stem of the generated .csv/.slp files
aeon2 = {
    "root": "/ceph/aeon/aeon/data/raw/AEON2/experiment0.2/",
    "subjects": {
        # first solo visit of BAA-1101818 (2022-06-23 Enter/Exit in the exp 0.2 notes)
        "BAA-1101818": {
            "start": pd.Timestamp("2022-06-23 08:39:04.261089801"),
            "end": pd.Timestamp("2022-06-23 11:14:46.121759892"),
        },
        # solo visit of BAA-1101819 (2022-06-21 Enter/Exit in the exp 0.2 notes)
        "BAA-1101819": {
            "start": pd.Timestamp("2022-06-21 13:28:10.593659878"),
            "end": pd.Timestamp("2022-06-21 16:34:29.241280079"),
        },
    },
    "session": "BAA-1101818_819",
}

# AEON3 presocial0.1 sessions read from the locally mapped "Z:" drive
# (video paths are rewritten to /ceph/aeon later via update_slp_video_paths).
aeon3 = {
    "root": "Z:/aeon/data/raw/AEON3/presocial0.1",
    "subjects": {
        "AEON3_NTP": {
            "start": pd.Timestamp("2023-03-03 16:40:00"),
            "end": pd.Timestamp("2023-03-03 16:55:00"),
        },
        "AEON3_TP": {
            "start": pd.Timestamp("2023-03-03 17:01:00"),
            "end": pd.Timestamp("2023-03-03 17:22:00"),
        },
        "multi_animal": {
            "start": pd.Timestamp("2023-03-03 17:23:00"),
            "end": pd.Timestamp("2023-03-03 17:43:00"),
        },
    },
    "session": "AEON3_NTP_TP_local",
}

aeon3b = {
    "root": "Z:/aeon/data/raw/AEON3/presocial0.1",
    "subjects": {
        "AEON3B_NTP": {
            "start": pd.Timestamp("2023-03-16 15:05:00"),
            "end": pd.Timestamp("2023-03-16 15:44:00"),
        },
        "AEON3B_TP": {
            "start": pd.Timestamp("2023-03-16 16:00:00"),
            "end": pd.Timestamp("2023-03-16 16:36:00"),
        },
        "multi_animal": {
            "start": pd.Timestamp("2023-03-16 16:37:00"),
            "end": pd.Timestamp("2023-03-16 17:19:00"),
        },
    },
    "session": "AEON3B_NTP_TP_local",
}

aeon3b_pattern = {
    "root": "Z:/aeon/data/raw/AEON3/presocial0.1/",
    "subjects": {
        "AEON3B_NTP": {
            "start": pd.Timestamp("2023-03-29 09:40:40"),
            "end": pd.Timestamp("2023-03-29 10:07:00"),
        },
        "AEON3B_TP1": {
            "start": pd.Timestamp("2023-03-29 10:10:44"),
            "end": pd.Timestamp("2023-03-29 10:46:30"),
        },
        "AEON3B_TP2": {
            "start": pd.Timestamp("2023-03-29 10:50:10"),
            "end": pd.Timestamp("2023-03-29 11:22:00"),
        },
        "multi_animal": {
            "start": pd.Timestamp("2023-03-29 11:23:30"),
            "end": pd.Timestamp("2023-03-29 12:31:00"),
        }
    },
    "session": "AEON3B_pattern_local",
}

# Same subject windows as aeon3b_pattern, plus "nest_coords".
aeon3b_pattern_nest = {
    "root": "Z:/aeon/data/raw/AEON3/presocial0.1/",
    "subjects": {
        "AEON3B_NTP": {
            "start": pd.Timestamp("2023-03-29 09:40:40"),
            "end": pd.Timestamp("2023-03-29 10:07:00"),
        },
        "AEON3B_TP1": {
            "start": pd.Timestamp("2023-03-29 10:10:44"),
            "end": pd.Timestamp("2023-03-29 10:46:30"),
        },
        "AEON3B_TP2": {
            "start": pd.Timestamp("2023-03-29 10:50:10"),
            "end": pd.Timestamp("2023-03-29 11:22:00"),
        },
        "multi_animal": {
            "start": pd.Timestamp("2023-03-29 11:23:30"),
            "end": pd.Timestamp("2023-03-29 12:31:00"),
        }
    },
    # Four (x, y) points, presumably image-pixel corners of the nest region —
    # TODO confirm. Not read by any function defined in this notebook.
    "nest_coords": [
        (1223,475),
        (1223,588),
        (1352,479),
        (1352,581),
    ],
    "session": "AEON3B_pattern_local_nest",
}

# Trailing "# BAA-..."/"# CAA-..." comments give the animal ID recorded in each window.
aeon3b_pattern_tattoo = {
    "root": "Z:/aeon/data/raw/AEON3/presocial0.1/",
    "subjects": {
        "AEON3B_NTP": { # BAA-1103352
            "start": pd.Timestamp("2023-06-16 16:13:30"),
            "end": pd.Timestamp("2023-06-16 17:00"),
        },
        "AEON3B_TP1": { # BAA-1103353
            "start": pd.Timestamp("2023-06-16 17:04"),
            "end": pd.Timestamp("2023-06-16 18:04"),
        },
        "AEON3B_TP2": { # BAA-1103351
            "start": pd.Timestamp("2023-06-16 18:08:19"),
            "end": pd.Timestamp("2023-06-16 18:38:52"),
        },
        "multi_animal": {
            "start": pd.Timestamp("2023-06-16 18:41:30"),
            "end": pd.Timestamp("2023-06-16 19:12:00"),
        }
    },
    "session": "AEON3B_pattern_tattoo",
}

# Note: the TP2 window (15:36-16:16) precedes the NTP window here; windows
# need not be listed in chronological order.
aeon3b_pattern_tattoo2 = {
    "root": "Z:/aeon/data/raw/AEON3/multianimal-test/",
    "subjects": {
        "AEON3B_NTP": { # BAA-1103352
            "start": pd.Timestamp("2023-07-04 16:28:10"),
            "end": pd.Timestamp("2023-07-04 17:01:57"),
        },
        "AEON3B_TP1": { # CAA-1120139
            "start": pd.Timestamp("2023-07-04 17:06:22"),
            "end": pd.Timestamp("2023-07-04 17:34:52"),
        },
        "AEON3B_TP2": { # BAA-1103351
            "start": pd.Timestamp("2023-07-04 15:36:19"),
            "end": pd.Timestamp("2023-07-04 16:16:19"),
        },
        "multi_animal": {
            "start": pd.Timestamp("2023-07-04 17:37:08"),
            "end": pd.Timestamp("2023-07-04 18:10:13"),
        }
    },
    "session": "AEON3B_pattern_tattoo2",
}

# Session used by all cells below.
# - The NTP and TP2 windows reuse the 2023-07-04 recordings from
#   aeon3b_pattern_tattoo2.
# - TP1 is a new animal recorded on 2023-07-28.
# - multi_animal..multi_animal5 are separate multi-animal windows on later dates.
aeon3b_pattern_tattoo3 = {
    "root": "Z:/aeon/data/raw/AEON3/multianimal-test/",
    "subjects": {
        "AEON3B_NTP": { # BAA-1103352
            "start": pd.Timestamp("2023-07-04 16:28:10"),
            "end": pd.Timestamp("2023-07-04 17:01:57"),
        },
        "AEON3B_TP1": { # BAA-1103369
            "start": pd.Timestamp("2023-07-28 14:24:00"),
            "end": pd.Timestamp("2023-07-28 15:15:00"),
        },
        "AEON3B_TP2": { # BAA-1103351
            "start": pd.Timestamp("2023-07-04 15:36:19"),
            "end": pd.Timestamp("2023-07-04 16:16:19"),
        },
        "multi_animal": {
            "start": pd.Timestamp("2023-07-28 15:21:00"),
            "end": pd.Timestamp("2023-07-28 16:18:00"),
        },
        "multi_animal2": {
            "start": pd.Timestamp("2023-08-01 10:19:00"),
            "end": pd.Timestamp("2023-08-01 11:07:00"),
        },
        "multi_animal3": {
            "start": pd.Timestamp("2023-08-03 10:40:00"),
            "end": pd.Timestamp("2023-08-03 11:28:00"),
        },
        "multi_animal4": {
            "start": pd.Timestamp("2023-08-08 14:14:10.119999886"),
            "end": pd.Timestamp("2023-08-08 15:01:32.288000107"),
        }, 
        "multi_animal5": {
            "start": pd.Timestamp("2023-08-11 11:54:01.340000153"),
            "end": pd.Timestamp("2023-08-11 13:20:36.303999901"),
        }, 
    },
    "session": "AEON3B_pattern_tattoo3",
}

In [None]:
# Extract start end times for MA session
# The window starts at the first "Experiment" environment state at/after the
# last subject entry. It ends at the last "Maintenance" state at/before the
# first subject exit.
query_start = pd.Timestamp("2023-08-11 00:00:00")
query_end = pd.Timestamp("2023-08-12 00:00:00")
# get subject enter, exit times
subject_visits = visits(aeon.load(aeon3b_pattern_tattoo3["root"], exp02.ExperimentalMetadata.SubjectState, start=query_start, end=query_end))
# get maintenance, experiment times
env_states = aeon.load(aeon3b_pattern_tattoo3["root"], exp02.ExperimentalMetadata.EnvironmentState, start=query_start, end=query_end)
env_states = env_states[~env_states.index.duplicated(keep="first")]
# get env_states timestamps after largest enter and before smallest exit
maintenance_states = env_states[env_states["state"] == "Maintenance"]
experiment_states = env_states[env_states["state"] == "Experiment"]
# get_indexer returns -1 when there is no match. Indexing with -1 would
# silently pick the LAST state, so check explicitly.
start_pos = experiment_states.index.get_indexer([subject_visits["enter"].max()], method="bfill")[0]
end_pos = maintenance_states.index.get_indexer([subject_visits["exit"].min()], method="ffill")[0]
if start_pos == -1:
    raise ValueError("No 'Experiment' state found at or after the last subject entry")
if end_pos == -1:
    raise ValueError("No 'Maintenance' state found at or before the first subject exit")
start_time = pd.Timestamp(experiment_states.index[start_pos])
end_time = pd.Timestamp(maintenance_states.index[end_pos])
print(start_time, end_time)

In [None]:
# extract single + multi-animal frames as csv
# Writes the bin-sampled rows for every subject window of the session to
# <session>.csv in the current working directory. The SLEAP section reads
# this file to build the training .slp.
all_subj_data = create_session_dataset(aeon3b_pattern_tattoo3)
all_subj_data.to_csv(f'{aeon3b_pattern_tattoo3["session"]}.csv')

In [None]:
# extract multi-animal (only) frames as csv
# Writes every tracked frame of the "multi_" windows (no bin sampling) to
# <session>_ma.csv. The SLEAP section builds one .slp per window from it.
ma_subj_data = create_ma_dataset(aeon3b_pattern_tattoo3)
ma_subj_data.to_csv(f'{aeon3b_pattern_tattoo3["session"]}_ma.csv')

# SLEAP

In [1]:
import pandas as pd
import random
import sleap
import numpy as np

from pathlib import Path

from sleap.io.dataset import Labels
from sleap.io.video import Video
from sleap.io.pathutils import fix_path_separator
from sleap.gui.suggestions import VideoFrameSuggestions
from sleap.nn.config import *
from sleap.nn.inference import main as sleap_track
from sleap.nn.inference import TopDownMultiClassPredictor, Predictor, TopDownPredictor, Tracker

In [1]:
def generate_slp_dataset(session: dict, subj_data: "pd.DataFrame", skeleton: "sleap.Skeleton"):
    """
    Builds a SLEAP Labels object from sampled tracking rows.

    Every subject key of ``session["subjects"]`` without "multi_" becomes a
    sleap Track. Each unique (_path, _frame) row becomes one labeled frame:
    - a single-subject row gets one "centroid" instance on that subject's track;
    - a multi-animal row gets one "centroid" instance per track, all at the
      same (x, y). These are presumably placeholders to be corrected during
      manual annotation.

    :param session: A dictionary containing the root path, subject IDs, and their start and end times.
    :param subj_data: A pandas DataFrame with x, y, id, _frame and _path columns.
    :param skeleton: A sleap Skeleton object that has a "centroid" node.
    :returns: A sleap Labels object containing labeled frames.
    """
    track_by_subject = {
        name: sleap.Track(spawned_on=0, name=name)
        for name in session["subjects"].keys()
        if "multi_" not in name
    }
    video_by_path = {
        path: sleap.Video.from_filename(path) for path in subj_data._path.unique()
    }

    def _centroid_instance(track, x, y):
        # One-point instance placed at the tracked centroid.
        return sleap.Instance(
            skeleton=skeleton,
            track=track,
            points={"centroid": sleap.instance.Point(x, y)},
        )

    labeled_frames = []
    unique_rows = subj_data.drop_duplicates(subset=["_path", "_frame"])
    for _, row in unique_rows.iterrows():
        if "multi_" in row.id:
            # same centroid duplicated onto every known track
            targets = list(track_by_subject.values())
        else:
            targets = [track_by_subject[row.id]]
        instances = [_centroid_instance(track, row.x, row.y) for track in targets]
        labeled_frames.append(
            sleap.instance.LabeledFrame(
                video=video_by_path[row._path],
                frame_idx=row._frame,
                instances=instances,
            )
        )

    return sleap.Labels(labeled_frames=labeled_frames)


def update_slp_video_paths(labels: "sleap.Labels", old_path: str, new_path: str):
    """
    Re-points every video in a SLEAP labels object at a new location (e.g., to move training from local to remote machine).

    Each video filename is normalised with ``fix_path_separator``. Every
    occurrence of ``old_path`` is then replaced with ``new_path``. Labeled
    frames are rebuilt against the new videos and reuse their existing
    instances. Only labeled frames are passed to the new Labels object.

    :param labels: SLEAP Labels object.
    :param old_path: Path prefix (or substring) to replace in video filenames.
    :param new_path: Replacement path.
    :returns: SLEAP Labels object with updated video paths.
    """
    remapped_videos = []
    for video in labels.videos:
        renamed = fix_path_separator(video.filename).replace(old_path, new_path)
        remapped_videos.append(sleap.Video.from_filename(renamed))

    relinked_frames = [
        sleap.instance.LabeledFrame(
            video=remapped_videos[labels.videos.index(frame.video)],
            frame_idx=frame.frame_idx,
            instances=frame.instances,
        )
        for frame in labels.labeled_frames
    ]

    return sleap.Labels(labeled_frames=relinked_frames)

In [None]:
# create new skeleton
# Single-node skeleton: each animal is represented only by a "centroid" point,
# the point name used by generate_slp_dataset.
skeleton = sleap.Skeleton()
skeleton.add_node("centroid")

In [None]:
# generate slp training dataset for all subjects
# Reads <session>.csv (bin-sampled single- and multi-animal rows from the AEON
# section) and writes <session>.slp to the current working directory.
subj_data = pd.read_csv(f'{aeon3b_pattern_tattoo3["session"]}.csv')
labels = generate_slp_dataset(aeon3b_pattern_tattoo3, subj_data, skeleton)
sleap.Labels.save_file(labels, f'{aeon3b_pattern_tattoo3["session"]}.slp')

In [None]:
# generate multi-animal "user-labeled" slp datasets for each session for prediction/evaluation
# Writes one <session>_<multi key>.slp per "multi_" window, built from the
# unsampled rows in <session>_ma.csv.
# NOTE(review): these files are written to the current working directory, but
# the predict cell reads them from `root` — presumably copied there manually;
# confirm.
# NOTE(review): the loop variable reuses the global name `subj_id`; the train
# and predict parameter cells reassign it before use.
subj_data = pd.read_csv(f'{aeon3b_pattern_tattoo3["session"]}_ma.csv')
subj_ids = [subj for subj in aeon3b_pattern_tattoo3["subjects"].keys() if "multi_" in subj]
for subj_id in subj_ids:
    labels = generate_slp_dataset(aeon3b_pattern_tattoo3, subj_data[subj_data["id"]==subj_id], skeleton)
    sleap.Labels.save_file(labels, f'{aeon3b_pattern_tattoo3["session"]}_{subj_id}.slp')

#### train

In [2]:
# set initial parameters
# subj_id is the session name; it is the stem of every dataset/model file below.
subj_id = f'{aeon3b_pattern_tattoo3["session"]}'
run_name_centroid = f'{subj_id}_topdown_top.centroid'
run_name_centered_instance = f'{subj_id}_topdown_top.centered_instance_multiclass'
root = "/ceph/aeon/aeon/code/scratchpad/sleap/tail_pattern/"
runs_folder = root + "models/"
predictions_folder = root + "predictions/"
groundtruth_folder = root + "groundtruth/"

# Recreate the skeleton only if this kernel has not defined it yet
# (e.g. when training in a fresh kernel without running the cells above).
try:
    skeleton
except NameError:
    # create new skeleton
    skeleton = sleap.Skeleton()
    skeleton.add_node("centroid")

In [None]:
# update local video paths to ceph
# Loads <session>.slp from the working directory, rewrites "Z:" video paths
# to "/ceph/aeon", and saves the result under `root`, where the split cell
# loads it from.
new_labels = update_slp_video_paths(
    labels=sleap.load_file(f'{subj_id}.slp'), 
    old_path="Z:", 
    new_path="/ceph/aeon")
sleap.Labels.save_file(new_labels,  f'{root}{subj_id}.slp')

In [None]:
# split labels into train/val/test
labels = sleap.load_file(f'{root}{subj_id}.slp')

# generate a 0.8/0.1/0.1 train/val/test split
# NOTE(review): no seed is set here, so the split may differ between runs —
# confirm whether Labels.split is deterministic.
labels_train, labels_val_test = labels.split(n=0.8) 
labels_val, labels_test = labels_val_test.split(n=0.5)

# Save the splits. with_images=True is commented out, so frames are presumably
# not embedded despite the ".pkg.slp" suffix (the files reference the videos).
labels_train.save(f'{root}{subj_id}.train.pkg.slp')#, with_images=True)
labels_val.save(f'{root}{subj_id}.val.pkg.slp')#, with_images=True)
labels_test.save(f'{root}{subj_id}.test.pkg.slp')#, with_images=True)

In [None]:
# centroid model
# Stage 1 of the top-down pipeline: this run (run_name_centroid) is later passed
# as centroid_model_path to TopDownMultiClassPredictor.
# initialise default training job config
cfg = TrainingJobConfig()
cfg.data.labels.training_labels = f'{root}{subj_id}.train.pkg.slp'
cfg.data.labels.validation_labels = f'{root}{subj_id}.val.pkg.slp'
cfg.data.labels.test_labels = f'{root}{subj_id}.test.pkg.slp'
# NOTE(review): unlike the centered-instance cell, cfg.data.labels.skeletons is
# not set here; presumably the skeleton is taken from the training labels — confirm.

# preprocessing and training params
# (trailing "#<value>" comments presumably record previously tried settings)
cfg.data.preprocessing.input_scaling = 0.75 #0.5
cfg.data.instance_cropping.center_on_part = "centroid"
cfg.data.instance_cropping.crop_size = 128 # set crop size manually
cfg.optimization.augmentation_config.rotate = True
cfg.optimization.epochs = 600 #200
cfg.optimization.batch_size = 4

# learning-rate schedule: multiply the LR by reduction_factor on plateau, down to min_learning_rate
cfg.optimization.initial_learning_rate = 0.0001
cfg.optimization.learning_rate_schedule.reduce_on_plateau = True
cfg.optimization.learning_rate_schedule.reduction_factor = 0.5
cfg.optimization.learning_rate_schedule.plateau_min_delta = 1e-06 
cfg.optimization.learning_rate_schedule.plateau_patience = 20 #5
cfg.optimization.learning_rate_schedule.plateau_cooldown = 3
cfg.optimization.learning_rate_schedule.min_learning_rate = 1e-08

cfg.optimization.early_stopping.stop_training_on_plateau = True
cfg.optimization.early_stopping.plateau_min_delta = 1e-08
cfg.optimization.early_stopping.plateau_patience = 30 #20

# configure nn and model
cfg.model.backbone.unet = UNetConfig(
    max_stride=16,
    filters=16,
    filters_rate=2.00,
    output_stride=2,
    #up_interpolate=True, # save computations but may lower accuracy
)
# head output_stride matches the backbone's output_stride above
cfg.model.heads.centroid = CentroidsHeadConfig(
    anchor_part="centroid",
    sigma=2.5,
    output_stride=2
)

# configure outputs
cfg.outputs.run_name = run_name_centroid
cfg.outputs.save_outputs = True
cfg.outputs.runs_folder = runs_folder
cfg.outputs.save_visualizations = True
cfg.outputs.checkpointing.initial_model = True
cfg.outputs.checkpointing.best_model = True

trainer = sleap.nn.training.Trainer.from_config(cfg)
trainer.setup()
trainer.train()

In [None]:
# part detection model: centered instance + multi-class
# Stage 2 of the top-down pipeline: predicts the centroid confidence map plus a
# class (track identity) vector per instance crop. This run
# (run_name_centered_instance) is later passed as confmap_model_path.
# initialise default training job config
cfg = TrainingJobConfig()

# update path to 0.8/0.1/0.1 train/val/test split
cfg.data.labels.training_labels = f'{root}{subj_id}.train.pkg.slp'
cfg.data.labels.validation_labels = f'{root}{subj_id}.val.pkg.slp'
cfg.data.labels.test_labels = f'{root}{subj_id}.test.pkg.slp'
cfg.data.labels.skeletons = [skeleton] # load skeleton

# preprocessing and training params
# (trailing "#<value>" comments presumably record previously tried settings)
cfg.data.preprocessing.input_scaling = 1.0
cfg.data.instance_cropping.center_on_part = "centroid"
cfg.data.instance_cropping.crop_size = 128 # set crop size manually
cfg.optimization.augmentation_config.rotate = True
cfg.optimization.epochs = 600
cfg.optimization.batch_size = 8 # 4

cfg.optimization.initial_learning_rate = 0.0001
cfg.optimization.learning_rate_schedule.reduce_on_plateau = True
cfg.optimization.learning_rate_schedule.reduction_factor = 0.1 #0.5
cfg.optimization.learning_rate_schedule.plateau_min_delta = 1e-08 #1e-06 
cfg.optimization.learning_rate_schedule.plateau_patience = 20 #5
cfg.optimization.learning_rate_schedule.plateau_cooldown = 3
cfg.optimization.learning_rate_schedule.min_learning_rate = 1e-08

cfg.optimization.early_stopping.stop_training_on_plateau = True
cfg.optimization.early_stopping.plateau_min_delta = 1e-08
cfg.optimization.early_stopping.plateau_patience = 30 #20

# configure nn and model
cfg.model.backbone.unet = UNetConfig(
    max_stride=16, #32,
    output_stride=2, #4,
    filters=16, #24,
    filters_rate=1.5,
    #up_interpolate=True, # save computations but may lower accuracy
)
confmaps=CenteredInstanceConfmapsHeadConfig(
    anchor_part="centroid",
    sigma=1.5, #2.5, 
    output_stride=2, #4, 
    loss_weight=1.0, 
) 
# load labels.slp to get track names
# The classes are the track names in the full (pre-split) labels, i.e. one
# class per single-subject key of the session.
labels = sleap.load_file(f'{root}{subj_id}.slp')
class_vectors=ClassVectorsHeadConfig(
    classes = [track.name for track in labels.tracks],
    output_stride=2, #16, #4,
    num_fc_layers=3,
    num_fc_units=256,
    global_pool=True,
    loss_weight=0.01 # TODO: try 1.0
)
cfg.model.heads.multi_class_topdown = MultiClassTopDownConfig(
    confmaps=confmaps,
    class_vectors=class_vectors
)

# configure outputs
cfg.outputs.run_name = run_name_centered_instance
cfg.outputs.save_outputs = True
cfg.outputs.runs_folder = runs_folder
cfg.outputs.save_visualizations = True
cfg.outputs.checkpointing.initial_model = True
cfg.outputs.checkpointing.best_model = True

trainer = sleap.nn.training.Trainer.from_config(cfg)

trainer.setup()
trainer.train()

In [None]:
# resume training
# NOTE(review): model_path is hard-coded to an older run ("AEON3_NTP_TP_local")
# and is relative to the working directory. The training cells above write to
# f'{runs_folder}{run_name_centroid}'. Point this at the run you actually want
# to resume.
# Load config.
model_path = "models/AEON3_NTP_TP_local_topdown_top.centroid/"
cfg = sleap.load_config(model_path)

# Create and initialize the trainer.
trainer = sleap.nn.training.Trainer.from_config(cfg)
trainer.setup()

# Replace the randomly initialized weights with the saved weights.
trainer.keras_model.load_weights(f'{model_path}best_model.h5')

# Epoch count is changed after setup(); presumably train() reads it from
# trainer.config at fit time — confirm.
trainer.config.optimization.epochs = 200
trainer.train()

In [None]:
# Inspect the output tensor shapes of the most recently built trainer.
# NOTE(review): two outputs (confmaps, class vectors) are expected from the
# centered-instance multi-class model. A single-head model (e.g. the centroid
# model from the resume cell) may have no outputs[1].
print(trainer.keras_model.outputs[0].shape) # confmaps
print(trainer.keras_model.outputs[1].shape) # id part

#### predict

In [2]:
# set initial parameters
# Same values as the train section's parameter cell, so the prediction cells
# can run in a kernel where training was not run.
subj_id = f'{aeon3b_pattern_tattoo3["session"]}'
run_name_centroid = f'{subj_id}_topdown_top.centroid'
run_name_centered_instance = f'{subj_id}_topdown_top.centered_instance_multiclass'
root = "/ceph/aeon/aeon/code/scratchpad/sleap/tail_pattern/"
runs_folder = root + "models/"
predictions_folder = root + "predictions/"
groundtruth_folder = root + "groundtruth/"

print(run_name_centroid, run_name_centered_instance)

In [None]:
# predict on a single multi-animal video
# subj_idx selects which "multi_" window of the session to run on.
# video_idx selects the video within that window's .slp file.
subj_idx = 0
video_idx = 0

multi_subj_ids = [subj_id for subj_id in aeon3b_pattern_tattoo3["subjects"].keys() if "multi_" in subj_id]

# select one multi-animal session only
input_file = f'{root}{aeon3b_pattern_tattoo3["session"]}_{multi_subj_ids[subj_idx]}.slp'

# infer on user-labeled frames of the selected video only, using both top-down
# models. Arguments are passed CLI-style to sleap.nn.inference.main.
output_file_pr = f'{predictions_folder}{subj_id}_{multi_subj_ids[subj_idx]}_pr.slp'
sleap_track(
    [
        input_file,
        "--model",	f'{runs_folder}{run_name_centroid}',
        "--model",	f'{runs_folder}{run_name_centered_instance}',
        "--only-labeled-frames",
        "--video.index", str(video_idx),
        "--output", output_file_pr,
    ]
)

In [None]:
# extract 100 suggestions based on low-scoring frames to be proofread
# Steps:
# 1. Keep predicted frames with fewer than 3 predicted instances (the session
#    has three single-subject tracks).
# 2. Ask SLEAP for low prediction-score suggestions among those frames.
# 3. Keep at most 100 random suggestions.
# 4. Save the predictions on those frames as user (sleap.Instance) centroids
#    for proofreading.
# NOTE(review): labels_missing.videos[video_idx] raises IndexError if every
# frame already has 3 instances. The output name contains "_1_", unlike the
# rand_gt file below — confirm that is intended.
labels_pr = sleap.load_file(output_file_pr)
output_file_low_gt = f'{groundtruth_folder}{subj_id}_1_{multi_subj_ids[subj_idx]}_low_gt.slp'
labels_missing = sleap.Labels([label for label in labels_pr.labels if label.n_predicted_instances < 3])
suggestions = VideoFrameSuggestions.suggest(
    labels=labels_missing,
    params=dict(
        videos=[labels_missing.videos[video_idx]],
        method="prediction_score",
        score_limit=0.5,
        instance_limit_lower=1,
        instance_limit_upper=3,
    ),
)

if len(suggestions) > 100:
    suggestions = random.sample(suggestions, 100)

# copy each suggested frame's shown instances into new user-instance frames
lfs = []
for suggestion in suggestions:
    matching_frames = labels_missing.find(video=labels_missing.videos[video_idx], frame_idx=suggestion.frame_idx)
    if matching_frames:
        lf = matching_frames[0]
        instances = []
        for instance in lf.instances_to_show:
            instances.append(
                sleap.Instance(
                    skeleton=instance.skeleton,
                    track=instance.track,
                    points={"centroid": sleap.instance.Point(instance.points[0].x, instance.points[0].y)}

                )
            )
        lfs.append(sleap.instance.LabeledFrame(
            video=lf.video,
            frame_idx=lf.frame_idx,
            instances=instances,
        ))
sleap.Labels.save_file(sleap.Labels(labeled_frames=lfs), output_file_low_gt)

In [None]:
# extract 100 random consecutive frames as ground truth data for evaluation
# Selected predicted frames are copied as user-labeled (sleap.Instance)
# centroids, so the saved file can be proofread and used as ground truth.
output_file_rand_gt = f'{groundtruth_folder}{subj_id}_{multi_subj_ids[subj_idx]}_rand_gt.slp'
random_frame_from = True  # if False, a consecutive selection starts at the first predicted frame
consecutive = True  # if False, num_frames distinct frames are sampled at random
num_frames = 100 # 3000 = 1min

n_available = len(labels_pr.labeled_frames)
if num_frames > n_available:
    # Fail early with a clear message instead of a randint/IndexError below.
    raise ValueError(
        f"Requested {num_frames} ground-truth frames but only {n_available} predicted frames are available"
    )

if consecutive:
    # "consecutive" refers to positions in labels_pr.labeled_frames
    frame_from = random.randint(0, n_available - num_frames) if random_frame_from else 0
    selected_frame_idx = list(range(frame_from, frame_from + num_frames))
else:
    selected_frame_idx = random.sample(range(n_available), num_frames)

sfs = []
for idx in selected_frame_idx:
    sf = labels_pr.labeled_frames[idx]
    instances = [
        sleap.Instance(
            skeleton=instance.skeleton,
            track=instance.track,
            points={"centroid": sleap.instance.Point(instance.points[0].x, instance.points[0].y)},
        )
        for instance in sf.instances_to_show
    ]
    sfs.append(sleap.instance.LabeledFrame(
        video=sf.video,
        frame_idx=sf.frame_idx,
        instances=instances,
    ))

sleap.Labels.save_file(sleap.Labels(labeled_frames=sfs), output_file_rand_gt)

#### evaluate

In [None]:
# evaluate model on test/gt data
subj_idx = 0
multi_subj_ids = [subj_id for subj_id in aeon3b_pattern_tattoo3["subjects"].keys() if "multi_" in subj_id]

test_data = f'{root}{subj_id}.test.pkg.slp'
ground_truth_data = f'{groundtruth_folder}{subj_id}_{multi_subj_ids[subj_idx]}_rand_gt.slp'

# load model
predictor = TopDownMultiClassPredictor.from_trained_models(
    centroid_model_path=f'{runs_folder}{run_name_centroid}',
    confmap_model_path=f'{runs_folder}{run_name_centered_instance}',
)

# load ground truth data and predict on the same frames
labels_gt = sleap.load_file(ground_truth_data)
labels_pr = predictor.predict(labels_gt)
metrics = sleap.nn.evals.evaluate(labels_gt, labels_pr, oks_scale=128)

# match predicted to ground-truth instances (same scale as oks_scale above)
framepairs = sleap.nn.evals.find_frame_pairs(labels_gt, labels_pr)
matches = sleap.nn.evals.match_frame_pairs(framepairs, scale=128)
positive_pairs = matches[0]
false_negatives = matches[1]
# a matched pair is correctly identified when both instances carry the same track name
correct_id = [
    positive_pair[0].track.name == positive_pair[1].track.name
    for positive_pair in positive_pairs
]

# compute occupancy matrix: one row per predicted track, one column per gt frame.
# - Rows are sized from labels_pr.tracks because the row index below comes from
#   that list. Sizing by labels_gt.tracks overflowed when predictions had more
#   tracks.
# - At least one row is kept for the track-less case (track_i = 0).
# - Columns assume labels_pr lists frames in the same order as labels_gt —
#   TODO confirm.
occupancy_matrix = np.zeros(
    (max(len(labels_pr.tracks), 1), len(labels_gt.labeled_frames)), dtype=np.uint8
)

for frame_i, lf in enumerate(labels_pr):
    for inst in lf.instances_to_show:
        # Assumes either all instances have tracks or no instances have tracks
        track_i = 0 if inst.track is None else labels_pr.tracks.index(inst.track)
        occupancy_matrix[track_i, frame_i] = 1

# A frame is "missing tracks" when fewer than all predicted tracks occupy it.
n_expected_tracks = len(labels_pr.tracks)
n_frames_missing_tracks = len(
    np.where(np.sum(occupancy_matrix, axis=0) < n_expected_tracks)[0]
)
n_gt_frames = len(labels_gt.labeled_frames)
print(f'{ground_truth_data} predicted with:')
print(f'{run_name_centroid}')
print(f'{run_name_centered_instance}')
print(
    "Frames with <",
    n_expected_tracks,
    "tracks:",
    n_frames_missing_tracks / n_gt_frames if n_gt_frames else float("nan"),
    "(",
    n_frames_missing_tracks,
    "/",
    n_gt_frames,
    ")",
)
print("Track occupancy:", np.mean(occupancy_matrix, axis=1))
print("Tracks identified:", len(correct_id))
print("Tracks correctly identified:", sum(correct_id))
print("Total tracks:", len(labels_gt.all_instances))
# guard against division by zero when nothing was matched
print("ID accuracy:", sum(correct_id) / len(correct_id) if correct_id else float("nan"))

#### export

In [None]:
# export model
# Rebuilds the combined top-down multi-class predictor from both trained runs.
# It is exported to "<session>_topdown_multiclass", relative to the current
# working directory.
predictor = TopDownMultiClassPredictor.from_trained_models(
    centroid_model_path=f'{runs_folder}{run_name_centroid}',
    confmap_model_path=f'{runs_folder}{run_name_centered_instance}',
)
predictor.export_model(aeon3b_pattern_tattoo3["session"] + "_topdown_multiclass")