In [69]:
# Reload imported modules (e.g. the src.* helpers) before each cell executes,
# so edits to project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [70]:
import sys

# Make the project root importable so the `src.*` imports below resolve.
# NOTE(review): assumes the notebook runs from a direct subdirectory of the
# project root -- confirm.
sys.path.append("../")

import os
import matplotlib
import numpy as np
from typing import List
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
# Use the Tk GUI backend so plt.show() opens an interactive window; the
# animation cell below plays there and closes it on a space-bar key press.
matplotlib.use("TkAgg")

# Presumably a workaround for the OpenMP "duplicate runtime" abort that occurs
# when two libraries each load their own OpenMP runtime. It only disables the
# check; it does not remove the duplicate load.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

from src.datasets.sequence_dataset import SequenceDataset
from src.config import SAMPLES_PATH
from src.utils.file import get_files_recursive

# Sequence length passed to SequenceDataset below.
SEQUENCE_LENGTH: int = 5

In [71]:
def get_sample_path_files(pattern: str = "*.pkl") -> List[List[str]]:
    """
    Get sample files under SAMPLES_PATH and group them by their parent folder.

    Files are listed with ``get_files_recursive(SAMPLES_PATH, pattern)``. Folders
    appear in order of their first file. Within a folder, paths keep the order in
    which ``get_files_recursive`` returned them.

    Args:
    pattern (str): File pattern passed to ``get_files_recursive``.
        Defaults to ``"*.pkl"``.

    Returns:
    List[List[str]]: One list of sample file paths per folder (an empty list if
    nothing matches).
    """
    samples_by_folder = {}
    for path in get_files_recursive(SAMPLES_PATH, pattern):
        # os.path.dirname understands the platform's path separator(s). Splitting
        # on "/" by hand put every file into a single "" group whenever paths used
        # only backslashes (Windows), and failed on Path objects.
        samples_by_folder.setdefault(os.path.dirname(path), []).append(path)
    # NOTE(review): within-folder order is whatever get_files_recursive yields;
    # if it is glob-based, that order is filesystem-dependent. Sort here if
    # SequenceDataset relies on frame order -- confirm.
    return list(samples_by_folder.values())
sample_paths_list = get_sample_path_files()
# Fail fast with a clear message if no sample files were found.
assert sample_paths_list, f"No sample files found under {SAMPLES_PATH}"

In [100]:
dataset = SequenceDataset(
    sample_paths_list=sample_paths_list,
    sequence_length=SEQUENCE_LENGTH,
    with_transforms=True,
)
dataset_size = len(dataset)
print(f"Dataset size: {dataset_size}")
# The animation cell below indexes dataset[0]; fail here with a clear message
# instead of deep inside that call. NOTE(review): presumably a folder needs at
# least SEQUENCE_LENGTH samples to contribute an item -- confirm.
assert dataset_size > 0, (
    f"SequenceDataset is empty (sequence_length={SEQUENCE_LENGTH}, "
    f"{len(sample_paths_list)} sample folders)"
)

Dataset size: 161


In [101]:
def animate_sample(dataset, index, interval=500):
    """Animate one dataset item: each sample's image with its ground-truth heatmap overlaid.

    Args:
        dataset: Indexable dataset. ``dataset[index]`` must return an iterable of
            samples exposing ``image_series`` and ``ground_truth``.
        index (int): Index of the item to animate.
        interval (int): Delay between animation frames in milliseconds.

    Returns:
        FuncAnimation: The animation. Keep a reference to it, because matplotlib
        stops an animation whose object has been garbage-collected.

    Raises:
        ValueError: If the item contains no samples.

    Press the space bar in the figure window to close it.
    """
    sequence = dataset[index]
    # One displayed frame per sample: the mean over axis 0 of its image series.
    # NOTE(review): assumes image_series is (T, H, W[, C]) with 0-255 values -- confirm.
    frames = [np.mean(sample.image_series, axis=0).astype(np.uint8) for sample in sequence]
    ground_truths = [sample.ground_truth for sample in sequence]

    assert len(frames) == len(ground_truths), "❌ Mismatch in series length and ground truth length"
    num_frames = len(frames)
    if num_frames == 0:
        raise ValueError(f"Dataset item {index} contains no samples to animate")

    def to_display(frame):
        """Reverse the channel axis of 3-D frames (presumably BGR -> RGB); pass 2-D frames through."""
        return frame[:, :, ::-1] if frame.ndim == 3 else frame

    fig, ax = plt.subplots(figsize=(10, 5))

    # Initialise both layers from sample 0. The overlay starts from the ground
    # truth (not the image). vmin/vmax pin its colour scale to [0, 1] so frames
    # are comparable. NOTE(review): assumes ground truth lies in [0, 1] -- confirm.
    img_plot = ax.imshow(to_display(frames[0]), cmap='gray')
    heatmap_plot = ax.imshow(ground_truths[0], cmap='hot', alpha=0.6, vmin=0, vmax=1)

    def update(frame):
        """Draw sample ``frame``: its image, its ground-truth overlay and the title."""
        img_plot.set_array(to_display(frames[frame]))
        heatmap_plot.set_array(ground_truths[frame])
        ax.set_title(f"Item {index} - frame {frame + 1}/{num_frames}")
        return img_plot, heatmap_plot

    def on_key(event):
        """Close the figure when the space bar is pressed."""
        if event.key == ' ':
            plt.close(fig)

    fig.canvas.mpl_connect('key_press_event', on_key)

    animation = FuncAnimation(fig, update, frames=num_frames, interval=interval, repeat=True)
    plt.show()
    return animation

# Example usage: animate the first item in the dataset. Bind the returned
# animation so it stays referenced even if plt.show() does not block.
sample_animation = animate_sample(dataset, 0)
