Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/jabs/pose_estimation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from .pose_est_v4 import PoseEstimationV4
from .pose_est_v5 import PoseEstimationV5
from .pose_est_v6 import PoseEstimationV6
from .pose_est_v7 import PoseEstimationV7
from .pose_est_v8 import PoseEstimationV8


def open_pose_file(path: Path, cache_dir: Path | None = None):
Expand All @@ -25,6 +27,10 @@ def open_pose_file(path: Path, cache_dir: Path | None = None):
return PoseEstimationV5(path, cache_dir)
elif path.name.endswith("v6.h5"):
return PoseEstimationV6(path, cache_dir)
elif path.name.endswith("v7.h5"):
return PoseEstimationV7(path, cache_dir)
elif path.name.endswith("v8.h5"):
return PoseEstimationV8(path, cache_dir)
else:
raise ValueError("not a valid pose estimate filename")

Expand All @@ -44,7 +50,11 @@ def get_pose_path(video_path: Path):
file_base = video_path.with_suffix("")

# default to the highest version pose file for a video
if video_path.with_name(file_base.name + "_pose_est_v6.h5").exists():
if video_path.with_name(file_base.name + "_pose_est_v8.h5").exists():
return video_path.with_name(file_base.name + "_pose_est_v8.h5")
elif video_path.with_name(file_base.name + "_pose_est_v7.h5").exists():
return video_path.with_name(file_base.name + "_pose_est_v7.h5")
elif video_path.with_name(file_base.name + "_pose_est_v6.h5").exists():
return video_path.with_name(file_base.name + "_pose_est_v6.h5")
elif video_path.with_name(file_base.name + "_pose_est_v5.h5").exists():
return video_path.with_name(file_base.name + "_pose_est_v5.h5")
Expand Down
18 changes: 10 additions & 8 deletions src/jabs/pose_estimation/pose_est_v6.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,16 @@ def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30
# transpose seg_data similar to the way the points are transposed.

# sort the segmentation data
self._segmentation_dict["seg_data"] = self._segmentation_sort(
self._segmentation_dict["seg_data"],
self._segmentation_dict["longterm_seg_id"],
)
self._segmentation_dict["seg_external_flag"] = self._segmentation_sort(
self._segmentation_dict["seg_external_flag"],
self._segmentation_dict["longterm_seg_id"],
)
if self._segmentation_dict["seg_data"] is not None:
self._segmentation_dict["seg_data"] = self._segmentation_sort(
self._segmentation_dict["seg_data"],
self._segmentation_dict["longterm_seg_id"],
)
if self._segmentation_dict["seg_external_flag"] is not None:
self._segmentation_dict["seg_external_flag"] = self._segmentation_sort(
self._segmentation_dict["seg_external_flag"],
self._segmentation_dict["longterm_seg_id"],
)

def get_seg_id(self, frame_index: int, identity: int) -> np.ndarray[Any, Any] | None:
"""get segmentation for a given frame and identity."""
Expand Down
13 changes: 13 additions & 0 deletions src/jabs/pose_estimation/pose_est_v7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from .pose_est_v6 import PoseEstimationV6


class PoseEstimationV7(PoseEstimationV6):
"""Pose estimation version 7

Currently handled the same as v6 because we're not using the v7 dynamic_objects dataset yet.
"""

@property
def format_major_version(self) -> int:
"""Returns the major version of the pose file format."""
return 7
176 changes: 176 additions & 0 deletions src/jabs/pose_estimation/pose_est_v8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from pathlib import Path

import h5py
import numpy as np

from jabs.constants import COMPRESSION, COMPRESSION_OPTS_DEFAULT

from .pose_est_v7 import PoseEstimationV7


class PoseEstimationV8(PoseEstimationV7):
"""Pose estimation version 8

Adds bounding box support.
"""

def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30) -> None:
super().__init__(file_path, cache_dir, fps)
self._has_bounding_boxes = False
self._bboxes: np.ndarray | None = None

# Try to load reorganized bboxes (identity-first) from cache
# if not able to use cached bboxes, load from source
if not self._load_bboxes_from_cache(cache_dir):
self._load_from_h5(cache_dir)

@property
def format_major_version(self) -> int:
"""Returns the major version of the pose file format."""
return 8

def _load_from_h5(self, cache_dir: Path | None) -> None:
"""Load bounding boxes from source HDF5 file, reorganizing by identity.

Args:
cache_dir: directory to use for caching reorganized pose files, or None to disable caching.
"""
with h5py.File(self._path, "r") as pose_h5:
ds = pose_h5.get("poseest/bbox")
if ds is None or not ds.attrs.get("bboxes_generated", False):
# No bounding box data
# Update cache to reflect absence of bboxes
if cache_dir is not None:
try:
filename = self._path.name.replace(".h5", "_cache.h5")
cache_file_path = self._cache_dir / filename
with h5py.File(cache_file_path, "a") as cache_h5:
grp = cache_h5.require_group("poseest")
if "bboxes" in grp:
del grp["bboxes"]
empty_ds = grp.create_dataset("bboxes", shape=(0,), dtype=np.float32)
empty_ds.attrs["bboxes_generated"] = False
except OSError:
pass
return

bboxes = ds[:]

# Load identity mapping arrays, needed to reorganize bboxes by identity
instance_embed_id = pose_h5["poseest/instance_embed_id"][:]
id_mask = pose_h5["poseest/id_mask"][:]

# Determine number of identities the same way v4 does: mask out invalids and take max
if instance_embed_id.shape[1] > 0:
valid = id_mask == 0
if valid.any():
# instance_embed_id is 1-based; take max over valid entries
self._num_identities = int(instance_embed_id[valid].max())
else:
print(f"Warning: All identities masked in pose file: {self._path}")
self._num_identities = 0
else:
print(f"Warning: No identities found in pose file: {self._path}")
self._num_identities = 0

# Prepare an array grouped by identity, matching the v4 keypoint transform logic.
# Shapes:
# bboxes: [frame][ident_instance][2][2]
# id_mask: [frame][ident_instance] with 0 where valid, 1 where padded/missing
num_frames = bboxes.shape[0]
bboxes_tmp = np.full(
(num_frames, self._num_identities, 2, 2), np.nan, dtype=bboxes.dtype
)

# First use instance_embed_id to group bboxes by identity
# IMPORTANT: valid entries are where id_mask == 0 (not == 1).
valid = id_mask == 0
if valid.any() and self._num_identities > 0:
ids_flat = instance_embed_id[valid]
pos = ids_flat > 0
# Align rows and source slices with the filtered positives
rows = np.where(valid)[0][pos]
ids0 = ids_flat[pos] - 1
# Guard against any out-of-range IDs
in_range = ids0 < self._num_identities
if in_range.any():
rows = rows[in_range]
ids0 = ids0[in_range]
src = bboxes[valid, :, :][pos, :, :][in_range, :, :]
bboxes_tmp[rows, ids0, :, :] = src

# Transpose so that identity becomes the first index
# Before: [frame][ident][2][2]
# After: [ident][frame][2][2]
bboxes_by_ident = np.transpose(bboxes_tmp, (1, 0, 2, 3))
self._bboxes = bboxes_by_ident.astype(np.float32)

self._has_bounding_boxes = True

# Write reorganized bboxes to cache for faster future loads
if cache_dir is not None:
filename = self._path.name.replace(".h5", "_cache.h5")
cache_file_path = self._cache_dir / filename
with h5py.File(cache_file_path, "a") as cache_h5:
grp = cache_h5.require_group("poseest")
if "bboxes" in grp:
del grp["bboxes"]
ds_out = grp.create_dataset(
"bboxes",
data=self._bboxes,
compression=COMPRESSION,
compression_opts=COMPRESSION_OPTS_DEFAULT,
)
ds_out.attrs["bboxes_generated"] = True

def _load_bboxes_from_cache(self, cache_dir: Path | None) -> bool:
"""Attempt to load bounding boxes from cache.

Args:
cache_dir: directory to use for caching reorganized pose files, or None to disable caching.

Returns:
True if bounding boxes were successfully loaded from cache, False otherwise.
"""
use_cache = False
if cache_dir is not None:
try:
filename = self._path.name.replace(".h5", "_cache.h5")
cache_file_path = self._cache_dir / filename
with h5py.File(cache_file_path, "r") as cache_h5:
if "poseest/bboxes" in cache_h5:
ds_cache = cache_h5["poseest/bboxes"]
bgen_cache = ds_cache.attrs.get("bboxes_generated", False)
if bgen_cache and ds_cache.size > 0:
self._bboxes = ds_cache[:]
self._has_bounding_boxes = True
use_cache = True
else:
# Cached dataset exists but marked as not generated; treat as absent
# set use_cache to True to skip source loading
use_cache = True
except (OSError, KeyError):
# Cache missing or unreadable; fall back to source
pass
return use_cache

def get_bounding_boxes(self, identity: int) -> np.ndarray | None:
"""Get bounding box array for an identity index.

Args:
identity: identity index (0 to num_identities-1)

Returns:
bounding box array of shape [num_frames, 2, 2] or None if no bounding box data
is available. bounding box format is [upper_left_x, upper_left_y], [lower_right_x, lower_right_y].
"""
if self._bboxes is None:
return None
if identity < 0 or identity >= self._num_identities:
raise ValueError(f"Identity {identity} out of range (0 to {self._num_identities - 1})")
return self._bboxes[identity, :, :, :]

@property
def has_bounding_boxes(self) -> bool:
"""Returns True if bounding box data is available."""
return self._has_bounding_boxes
10 changes: 5 additions & 5 deletions src/jabs/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,13 +402,13 @@ def archive_behavior(self, behavior: str):
# archive labels and unfragmented_labels
archived_labels = {}
for video in self._video_manager.videos:
labels = self._video_manager.load_video_labels(video)
pose = self.load_pose_est(self._video_manager.video_path(video))
labels = self._video_manager.load_video_labels(video, pose)

# if no labels for video skip it
if labels is None:
continue

pose = self.load_pose_est(self._video_manager.video_path(video))
annotations = labels.as_dict(pose)

# ensure archive structure exists for this video:
Expand Down Expand Up @@ -536,7 +536,9 @@ def get_labeled_features(
if should_terminate_callable:
should_terminate_callable()

video_labels = self._video_manager.load_video_labels(video)
video_path = self._video_manager.video_path(video)
pose_est = self.load_pose_est(video_path)
video_labels = self._video_manager.load_video_labels(video, pose_est)

# if there are no labels for this video, skip it
if video_labels is None:
Expand All @@ -546,8 +548,6 @@ def get_labeled_features(
progress_callable()
continue

video_path = self._video_manager.video_path(video)
pose_est = self.load_pose_est(video_path)
# fps used to scale some features from per pixel time unit to
# per second
fps = get_fps(str(video_path))
Expand Down
19 changes: 14 additions & 5 deletions src/jabs/project/video_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
import sys
from pathlib import Path

from jabs.pose_estimation import get_frames_from_file, get_pose_path, open_pose_file
from jabs.pose_estimation import (
PoseEstimation,
get_frames_from_file,
get_pose_path,
open_pose_file,
)
from jabs.video_reader import VideoReader

from .project_paths import ProjectPaths
Expand Down Expand Up @@ -75,11 +80,14 @@ def total_project_identities(self) -> int:
"""Get the total number of identities across all videos."""
return self._total_project_identities

def load_video_labels(self, video_name) -> VideoLabels | None:
def load_video_labels(
self, video_name: Path | str, pose: PoseEstimation = None
) -> VideoLabels | None:
"""load labels for a video

Args:
video_name: filename of the video: string or pathlib.Path
pose: optional PoseEstimation object to use for identity mapping, if None we will open the pose file

Returns:
initialized VideoLabels object if annotations exist, otherwise None
Expand All @@ -92,9 +100,10 @@ def load_video_labels(self, video_name) -> VideoLabels | None:
# if annotations already exist for this video file in the project open them
if path.exists():
# VideoLabels.load can use pose to convert identity index to the display identity
pose = open_pose_file(
get_pose_path(self.video_path(video_filename)), self._paths.cache_dir
)
if pose is None:
pose = open_pose_file(
get_pose_path(self.video_path(video_filename)), self._paths.cache_dir
)
with path.open() as f:
return VideoLabels.load(json.load(f), pose)
else:
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading