In [6]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import sys

sys.path.append("../")

import torch
import matplotlib
import numpy as np
import matplotlib.pyplot as plt

matplotlib.use("TkAgg")

from src.models.unet import UNet
from src.config import MODELS_PATH
from src.utils.file import get_sample_paths_list
from src.datasets.sequence_dataset import get_dataloaders
from src.config import (
    SPLITS,
    N_WORKERS,
    SEQUENCE_LENGTH,
    WITH_TRANSFORMS,
)

# Checkpoint (saved weights) of the UNet training run to visualize; the file name is the run's timestamp.
STATE_PATH = "{}/unet/20241111-160602_unet.pth".format(MODELS_PATH)

In [8]:
# Collect the sample paths and split them into train / val / test loaders.
# batch_size=1 so that each batch pulled by the viewer below holds exactly one sample.
sample_paths_list = get_sample_paths_list()
train_loader, val_loader, test_loader = get_dataloaders(
    sample_paths_list=sample_paths_list,
    # split parameters, in (train, val, test) order
    train_split=SPLITS[0],
    val_split=SPLITS[1],
    test_split=SPLITS[2],
    # sequence construction
    sequence_length=SEQUENCE_LENGTH,
    with_transforms=WITH_TRANSFORMS,
    # loading
    batch_size=1,
    train_shuffle=True,
    n_workers=N_WORKERS,
)

🌱 Setting the seed to 0 for generating dataloaders
✅ Created dataloaders with 1375 training, 385 validation, and 165 test batches.


In [9]:
# Rebuild the architecture and load the trained weights from STATE_PATH.
# map_location="cpu" lets the checkpoint load even on a machine without the GPU it was
# saved from; visualize_batches moves the model to its target device afterwards.
# weights_only=True restricts unpickling to tensors and plain containers, which is all a
# state_dict needs (this is also PyTorch's default from 2.6 on).
model = UNet(freeze_encoder=False)
model.load_state_dict(torch.load(STATE_PATH, map_location="cpu", weights_only=True))

<All keys matched successfully>

In [11]:
from matplotlib.widgets import Button
from matplotlib.animation import FuncAnimation

def _frame_to_rgb_uint8(frame):
    """Convert a normalized [channel, height, width] frame tensor into an [H, W, C] uint8 image.

    The channel axis is reversed (the frames presumably are stored BGR — TODO confirm),
    a mean=0.5 / std=0.5 normalization is undone, and the result is clipped to [0, 255]
    before the cast so out-of-range pixels saturate instead of wrapping around in uint8.
    """
    image = frame.permute(1, 2, 0).cpu().detach().numpy()[:, :, ::-1]
    image = (image * 0.5 + 0.5) * 255
    return np.clip(image, 0, 255).astype(np.uint8)


def _num_frames(batch):
    """Number of sequence steps (axis 1) shared by the (inputs, outputs, ground_truths) of a batch."""
    inputs, outputs, ground_truths = batch
    return min(inputs.shape[1], outputs.shape[1], ground_truths.shape[1])


def visualize_batches(model, dataloader, device='cuda', interval=500):  # interval in ms
    """Interactively browse model predictions, one batch at a time.

    Opens a figure with three panels (input frame, model output, ground truth) that
    loop through the frames of the current batch's first sample as an animation.
    A "Next Batch" button pulls the next batch from ``dataloader``; once the loader is
    exhausted it is restarted from the beginning.

    Args:
        model: Network to evaluate. It is put in eval mode and moved to ``device``
            (for an ``nn.Module`` both happen in place, so the caller's model is affected).
        dataloader: Iterable yielding ``(frames, ground_truths, global_ground_truth)``
            tuples; ``global_ground_truth`` is ignored.
        device: Device the model and input frames are moved to.
        interval: Delay between animation frames, in milliseconds.

    Blocks until the figure window is closed (``plt.show(block=True)``). If the
    dataloader yields nothing at all, a message is printed and the figure stays empty.
    """
    model.eval()
    model = model.to(device)

    dataloader_iter = iter(dataloader)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 10))
    fig.subplots_adjust(bottom=0.2)  # make room for the button

    # State shared by the nested callbacks below (rebound via `nonlocal`, not globals).
    current_batch = None
    anim = None

    def process_batch():
        """Pull the next batch and run the model on it; raises StopIteration when exhausted."""
        frames, ground_truths, _global_ground_truth = next(dataloader_iter)
        inputs = frames.to(device)
        with torch.no_grad():
            outputs = model(inputs)
        return inputs, outputs, ground_truths

    def clear_axes():
        for ax in (ax1, ax2, ax3):
            ax.clear()
        return ax1, ax2, ax3

    def init_animation():
        return clear_axes()

    def animate(frame_idx):
        """Draw step ``frame_idx`` of the current batch's first sample in all three panels."""
        clear_axes()

        inputs, outputs, ground_truths = current_batch
        n_frames = _num_frames(current_batch)

        # Assumed layouts (TODO confirm against the dataset/model):
        #   inputs:        [batch, sequence, sample, channel, height, width]
        #   outputs:       [batch, sequence, height, width]
        #   ground_truths: [batch, sequence, height, width]
        input_np = _frame_to_rgb_uint8(inputs[0, frame_idx, 0])
        output_np = outputs[0, frame_idx].cpu().detach().numpy()
        gt_np = ground_truths[0, frame_idx].cpu().detach().numpy()

        ax1.imshow(input_np)
        ax2.imshow(output_np, cmap="jet", vmin=0, vmax=1)
        ax3.imshow(gt_np, cmap="jet", vmin=0, vmax=1)

        ax1.set_title(f'Input (Sequence {frame_idx + 1}/{n_frames})')
        ax2.set_title(f'Output (Sequence {frame_idx + 1}/{n_frames})')
        ax3.set_title(f'Ground Truth (Sequence {frame_idx + 1}/{n_frames})')
        for ax in (ax1, ax2, ax3):
            ax.axis('off')

        return ax1, ax2, ax3

    def update_plot(event=None):
        """Load the next batch and restart the animation on it (also the button callback)."""
        nonlocal current_batch, anim, dataloader_iter

        try:
            current_batch = process_batch()
        except StopIteration:
            print("Reached the end of the dataset. Restarting...")
            dataloader_iter = iter(dataloader)
            try:
                current_batch = process_batch()
            except StopIteration:
                # A fresh iterator is already exhausted, so the loader is empty.
                # Stop here instead of retrying forever (previously: RecursionError).
                print("The dataloader is empty; nothing to display.")
                return

        # Stop the previous animation so two timers don't draw into the same axes.
        if anim is not None:
            anim.event_source.stop()

        # blit=False: every frame clears and redraws whole axes, so blitting buys nothing,
        # and a stopped blitting animation re-arms its timer on window resize.
        anim = FuncAnimation(
            fig,
            animate,
            init_func=init_animation,
            frames=_num_frames(current_batch),
            interval=interval,  # ms between frames
            repeat=True,
            blit=False,
        )

        fig.canvas.draw_idle()

    # `button` must stay referenced for its callback to keep working; this frame holds it
    # alive while plt.show blocks below.
    ax_button = fig.add_axes([0.4, 0.05, 0.2, 0.075])
    button = Button(ax_button, 'Next Batch')
    button.on_clicked(update_plot)

    # Show the initial batch.
    update_plot()

    plt.show(block=True)

# Browse predictions on the test split; this call blocks until the figure window is closed.
visualize_batches(model=model, dataloader=test_loader, interval=500)