In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
from src.datasets.video_dataset import make_videodataset

In [2]:
# Test configuration consumed by the dataset/loader construction cell.
frames_per_clip = 8  # forwarded as `frames_per_clip`
batch_size = 4  # forwarded as `batch_size`

# Index file(s) forwarded as `data_paths`. The sample here is a NumPy `.npy`
# file; presumably it lists the videos to load - TODO confirm expected contents.
data_paths = [
    "../data_path_sample.npy",
]

In [3]:
# Build the video dataset and its loader from the test parameters.
from torchvision import transforms

# Transform handed to make_videodataset: resize to 224x224, then ToTensor.
resize_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# make_videodataset returns three values; only the dataset and the loader are
# kept here, the third one is discarded.
dataset, data_loader, _ = make_videodataset(
    # --- input data and transforms ---
    data_paths=data_paths,
    datasets_weights=None,
    transform=resize_transform,
    shared_transform=None,
    # --- clip sampling ---
    frames_per_clip=frames_per_clip,
    frame_step=2,
    num_clips=1,
    random_clip_sampling=True,
    allow_clip_overlap=False,
    filter_short_videos=False,
    filter_long_videos=10**9,  # very large threshold; same value as int(10**9)
    # --- batching / loader ---
    batch_size=batch_size,
    collator=None,
    drop_last=True,
    num_workers=0,  # Set to 0 for testing on local machines
    pin_mem=False,
    # --- process layout: rank 0 of world size 1 ---
    rank=0,
    world_size=1,
)


In [None]:
# Fetch one batch of data and display every frame of its first entry.
#
# `next(iter(...), None)` pulls exactly one batch. Unlike a `for ... break`
# loop, it lets us detect a loader that yields nothing (possible because the
# loader is built with drop_last=True, so an index with fewer samples than
# batch_size produces zero batches) instead of silently showing no output.
batch = next(iter(data_loader), None)

if batch is None:
    print(
        "data_loader yielded no batches - check that data_paths contains at "
        "least batch_size samples (the loader is built with drop_last=True)."
    )
else:
    videos, labels, clip_indices = batch

    print(f"Videos shape: {videos[0].shape}, Labels: {labels}, Clip indices: {clip_indices}")

    # Visualize the first entry of `videos`.
    # NOTE(review): assumed to be a single video whose iteration yields
    # (C, H, W) frames - confirm against the dataset's output layout.
    video = videos[0]

    for i, frame in enumerate(video):
        # Move axis 0 to the end so imshow receives a channels-last array.
        frame_np = frame.permute(1, 2, 0).numpy()

        # One figure per frame, using the explicit Axes interface.
        fig, ax = plt.subplots()
        ax.imshow(frame_np)
        ax.set_title(f"Frame {i + 1}")
        ax.axis("off")
        plt.show()