In [1]:
import json
from collections import defaultdict
from pathlib import Path
from typing import Optional, Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import HTML

# Never truncate long text cells when DataFrames are rendered (None = unlimited width).
pd.options.display.max_colwidth = None
import math
import shutil

from tqdm import tqdm

from common.constants import *

In [None]:
# Later cells rewrite the annotation parquet files in place, so always start
# from a fresh copy of the originals: wipe any previous copy, then copy again.
src_annot_dir = Path("original_data", "train", "train_annotation")
dst_annot_dir = Path("data", "annotation")

if dst_annot_dir.exists():
    shutil.rmtree(dst_annot_dir)
shutil.copytree(src_annot_dir, dst_annot_dir)

### Create test.csv with the annotated videos from labs that appear in the test set

In [None]:
from helpers import has_annotation_df


meta = pd.read_csv("original_data/train/train.csv")

# Keep a video when its lab is one of the test labs AND
# has_annotation_df(lab_id, video_id) is truthy for it.
in_test_lab = meta["lab_id"].isin(LAB_NAMES_IN_TEST)
annotated = meta.apply(
    lambda video: has_annotation_df(video["lab_id"], video["video_id"]), axis=1
)
mask_keep = in_test_lab & annotated
meta.loc[mask_keep].to_csv("data/test.csv")

### Create train.csv with `has_annotation`, `tracked_<keypoint>` and `cnt_frames` columns

In [None]:
train_meta = pd.read_csv("original_data/train/train.csv")

# Union of the body parts tracked in any video (body_parts_tracked is a JSON list).
all_keypoints = set()
for s in train_meta["body_parts_tracked"]:
    all_keypoints |= set(json.loads(s))

# Disabled helper: prints the sorted values of each mouse categorical column
# as `ALL_<name>s = [...]`.
# print(f"Keypoints: {list(sorted(all_keypoints))}")
# categorical = defaultdict(set)
# for name in ["strain", "color", "sex", "age", "condition"]:
#     for mouse in range(1, 5):
#         categorical[name] |= set(train_meta[f"mouse{mouse}_{name}"].dropna().unique())

# for k, v in categorical.items():
#     if k == "condition":
#         v = [x for x in v if "lights on" not in x and "lights off" not in x]
#         v.append("lights on")
#         v.append("lights off")
#     v = list(sorted(v))
#     print(f"ALL_{k}s = {v}")

N = len(train_meta)

# One boolean column per body part: True when that part is tracked in the video.
# Built in sorted order because str hashing is randomized per interpreter session:
# iterating the raw set would give a different tracked_* column order in
# train.csv on every run.
keypoint_columns = {}
for k in sorted(all_keypoints):
    keypoint_columns[f"tracked_{k}"] = np.zeros((N,), dtype=bool)

has_annotation = np.zeros((N,), dtype=bool)
cnt_frames = np.zeros((N, ), dtype=np.int32)
cnt_missing_annotation_file = 0
for i, video_meta in enumerate(train_meta.to_dict(orient="records")):
    lab = video_meta["lab_id"]
    video = video_meta["video_id"]
    tracking_path = Path("data") / "tracking" / lab / f"{video}.parquet"
    assert tracking_path.exists(), tracking_path
    track = pd.read_parquet(tracking_path)
    # Assumes video_frame is a 0-based frame index, so max + 1 is the frame count.
    cnt_frames[i] = track.video_frame.max() + 1
    # behaviors_labeled is a string for labeled videos (presumably NaN otherwise).
    if isinstance(video_meta["behaviors_labeled"], str):
        has_annotation[i] = True
    annotation_path = Path("data") / "annotation" / lab / f"{video}.parquet"
    if has_annotation[i] and not annotation_path.exists():
        # Labeled video without an annotation file: write an empty annotation
        # with the expected columns in its place.
        cnt_missing_annotation_file += 1
        empty_df = pd.DataFrame(
            columns=["agent_id", "target_id", "action", "start_frame", "stop_frame"]
        )
        # The lab folder is missing when none of its videos shipped an annotation file.
        annotation_path.parent.mkdir(parents=True, exist_ok=True)
        empty_df.to_parquet(annotation_path)
        print(f"Write empty df to: {annotation_path}")
    keypoints = set(json.loads(video_meta["body_parts_tracked"]))
    for k in keypoints:
        keypoint_columns[f"tracked_{k}"][i] = True

for k, col in keypoint_columns.items():
    train_meta[k] = col
train_meta["has_annotation"] = has_annotation
train_meta["cnt_frames"] = cnt_frames

print(
    f"Videos with annotation: {np.sum(has_annotation)}, {np.sum(has_annotation) / N * 100:.2f}%"
)
print(
    f"Cnt missing annotation files (nothing happened in the video): {cnt_missing_annotation_file}"
)

train_meta.to_csv("data/train.csv")
print(train_meta.columns)
train_meta.head()

### Remove banned and rare actions from annotation files

In [None]:
# Drop every segment whose action is in ACTIONS_TO_REMOVE; rewrite only files that change.
# Files are visited in sorted order so the log is reproducible across runs.
for annot_path in sorted(dst_annot_dir.rglob("*.parquet")):
    annot = pd.read_parquet(annot_path)
    # Reset the index after filtering. Otherwise the gappy index is persisted by
    # to_parquet as an extra __index_level_0__ column, and rewritten files would
    # differ in schema from untouched ones.
    kept = annot[~annot.action.isin(ACTIONS_TO_REMOVE)].reset_index(drop=True)
    removed = len(annot) - len(kept)
    if removed == 0:
        continue
    print(f"path = {str(annot_path)}, removed segments = {removed}")
    kept.to_parquet(annot_path)

In [None]:
from helpers import get_annotation_by_video_meta, get_train_meta
from parse_utils import behaviors_labeled_to_str, parse_behaviors_labeled


def remove_banned_actions_from_meta(meta: pd.DataFrame) -> pd.DataFrame:
    """Remove ACTIONS_TO_REMOVE from every video's ``behaviors_labeled`` and sanity-check it.

    Args:
        meta: video metadata with at least ``behaviors_labeled``, ``has_annotation``
            and ``video_id`` columns (plus whatever ``get_annotation_by_video_meta``
            reads from a row). It is not modified.

    Returns:
        A copy of ``meta`` whose ``behaviors_labeled`` strings no longer mention
        any action in ACTIONS_TO_REMOVE. Missing (NaN) values are kept as-is.

    Raises:
        AssertionError: if a banned action survives in ``behaviors_labeled``, or if an
            annotated (agent, target, action) triple is absent from ``behaviors_labeled``.
    """
    def remove_actions_from_beh(beh: str | float):
        # Unlabeled videos carry NaN here; leave them untouched.
        if pd.isna(beh):
            return beh
        lst = parse_behaviors_labeled(beh)
        lst_new = [b for b in lst if b.action not in ACTIONS_TO_REMOVE]
        return behaviors_labeled_to_str(lst_new)

    # Work on a copy and use `meta` (not the global train_meta) throughout, so the
    # function validates and returns exactly the frame it was given.
    meta = meta.copy()
    meta["behaviors_labeled"] = meta["behaviors_labeled"].apply(remove_actions_from_beh)

    # sanity check
    # 1) no banned actions left
    # 2) all actual annotations appear in behaviors labeled
    for video_meta in tqdm(meta[meta.has_annotation].to_dict(orient="records")):
        behaviors_labeled = parse_behaviors_labeled(video_meta["behaviors_labeled"])
        assert not any(beh.action in ACTIONS_TO_REMOVE for beh in behaviors_labeled)

        labeled = {(beh.agent, beh.target, beh.action) for beh in behaviors_labeled}
        annot = get_annotation_by_video_meta(video_meta)
        for annot_row in annot.to_dict(orient="records"):
            triple = (
                int(annot_row["agent_id"]),
                int(annot_row["target_id"]),
                annot_row["action"],
            )
            # Single quotes inside the f-string: reusing the outer double quote
            # is a SyntaxError before Python 3.12 (PEP 701).
            assert (
                triple in labeled
            ), f"triple = {triple}, labeled = {labeled}, video_id = {video_meta['video_id']}"

    return meta


# Load the train metadata via helpers.get_train_meta, then prune banned actions
# from its behaviors_labeled column.
train_meta = remove_banned_actions_from_meta(get_train_meta())

### Remove labeled actions that never occur in any annotation of their lab

In [None]:
# lab_id -> set of every action that appears in at least one of that lab's annotations.
occurred_actions_by_lab = defaultdict(set)

annotated_videos = train_meta[train_meta.has_annotation].to_dict(orient="records")
for video_meta in annotated_videos:
    annot = get_annotation_by_video_meta(video_meta)
    occurred_actions_by_lab[video_meta["lab_id"]].update(annot.action.unique())

def get_new_behaviors_labeled(video_meta) -> str:
    lab_id = video_meta["lab_id"]
    beh = parse_behaviors_labeled(video_meta["behaviors_labeled"])
    new_beh = [b for b in beh if b.action in occurred_actions_by_lab[lab_id]]
    if len(new_beh) == len(beh):
        return video_meta["behaviors_labeled"]
    return behaviors_labeled_to_str(new_beh)

pruned_behaviors = train_meta.apply(get_new_behaviors_labeled, axis=1)
train_meta["behaviors_labeled"] = pruned_behaviors

# Overwrite train.csv with the pruned behaviors_labeled column.
train_meta.to_csv("data/train.csv")