In [1]:
# Reload project modules (e.g. the src.* imports below) automatically when their
# source files change, so edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

sys.path.append("../")

import os
import matplotlib
import numpy as np
from typing import List
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

matplotlib.use("TkAgg")
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

from src.utils.file import get_sample_paths_list
from src.datasets.sequence_dataset import SequenceDataset

SEQUENCE_LENGTH = 5

In [3]:
# Build the sequence dataset from every sample path returned by get_sample_paths_list(),
# grouping frames into sequences of SEQUENCE_LENGTH with the dataset's transforms enabled.
sample_paths_list = get_sample_paths_list()
dataset = SequenceDataset(
    sample_paths_list=sample_paths_list,
    sequence_length=SEQUENCE_LENGTH,
    with_transforms=True,
)
print(f"Dataset size: {len(dataset)}")
# Fail fast here instead of with an opaque IndexError when sample 0 is animated below.
assert len(dataset) > 0, "SequenceDataset is empty - check the paths returned by get_sample_paths_list()"

Dataset size: 1785


In [4]:
def _frames_to_display_images(frames) -> np.ndarray:
    """Convert a normalized frame stack from the dataset into uint8 images for display.

    Reverses the normalization implied by ``x * 0.5 + 0.5`` (which maps [-1, 1] to
    [0, 1]) and scales to [0, 255]. NOTE(review): this assumes the dataset's transforms
    normalize pixels to [-1, 1] with mean=0.5, std=0.5 -- confirm against SequenceDataset.

    Args:
        frames: Array-like frame stack, i.e. ``dataset[index][0]``. The axis handling
            below suggests a 5-D layout such as (time, stack, channel, height, width)
            -- TODO confirm.

    Returns:
        uint8 array with axis 2 moved last and axis 1 averaged out, i.e.
        (time, height, width, channel) under the assumed layout.
    """
    # np.asarray accepts NumPy arrays and CPU torch tensors alike; without it,
    # np.moveaxis would call torch.Tensor.transpose, whose signature differs.
    images = np.asarray(frames)
    images = (images * 0.5 + 0.5) * 255          # undo normalization -> [0, 255]
    images = np.moveaxis(images, 2, -1)          # axis 2 (presumably channels) last, as imshow expects
    images = np.clip(images, 0, 255).astype(np.uint8)
    # Average over axis 1; np.mean returns float64, so cast back to uint8.
    return np.mean(images, axis=1).astype(np.uint8)


def animate_sample(dataset, index: int, interval_ms: int = 500) -> "FuncAnimation":
    """Play one dataset sample as a looping animation with a ground-truth heatmap overlay.

    Keyboard controls (while the figure window has focus):
        e      toggle between the global heatmap and the per-frame heatmap
        space  close the figure

    Args:
        dataset: Indexable dataset whose items unpack as
            ``(frames, ground_truths, global_ground_truth)``.
        index: Index of the sample to display.
        interval_ms: Delay between animation frames, in milliseconds.

    Returns:
        The ``FuncAnimation`` driving the figure. It is also stored on the figure,
        so discarding the return value does not stop playback.
    """
    frames, ground_truths, global_ground_truth = dataset[index]
    images = _frames_to_display_images(frames)
    num_frames = images.shape[0]

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.set_title(f"Sample {index}  ([e] toggle heatmap, [space] close)")
    # [:, :, ::-1] reverses the last (channel) axis -- presumably BGR -> RGB; confirm.
    img_plot = ax.imshow(images[0][:, :, ::-1], cmap="gray")
    # Overlays use a fixed [0, 1] color scale (heatmaps presumably lie in [0, 1]).
    global_heatmap_plot = ax.imshow(global_ground_truth, cmap="jet", alpha=0.5, vmin=0, vmax=1)
    dynamic_heatmap_plot = ax.imshow(ground_truths[0], cmap="hot", alpha=0.5, vmin=0, vmax=1)
    dynamic_heatmap_plot.set_visible(False)

    show_global_heatmap = True

    def update(frame_idx):
        """Draw frame ``frame_idx`` and show whichever heatmap is currently selected."""
        img_plot.set_array(images[frame_idx][:, :, ::-1])
        if not show_global_heatmap:
            dynamic_heatmap_plot.set_array(ground_truths[frame_idx])
        # The global heatmap never changes, so only its visibility is updated.
        global_heatmap_plot.set_visible(show_global_heatmap)
        dynamic_heatmap_plot.set_visible(not show_global_heatmap)
        return img_plot, global_heatmap_plot, dynamic_heatmap_plot

    def on_key(event):
        """Keyboard handler: 'e' toggles the heatmap, space closes the window."""
        nonlocal show_global_heatmap
        if event.key == " ":
            plt.close(fig)
        elif event.key == "e":
            show_global_heatmap = not show_global_heatmap

    fig.canvas.mpl_connect("key_press_event", on_key)

    anim = FuncAnimation(fig, update, frames=num_frames, interval=interval_ms, repeat=True)
    # Matplotlib requires a live reference to the Animation or it is garbage-collected
    # and stops. plt.show() can return immediately in interactive mode, so pin the
    # animation to the figure (alive while the window is open) and return it.
    fig._sample_animation = anim
    plt.show()
    return anim

# Opens the animation window for the first sample; press "e" to toggle the heatmap, space to close.
animate_sample(dataset, index=0);