In [26]:
# Automatically reload edited project modules (e.g. the src.* imports below)
# before each cell executes, so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
import sys

sys.path.append("../")

import torch
import matplotlib
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

matplotlib.use("TkAgg")

from src.models.livesal import LiveSAL
from src.config import MODELS_PATH
from src.utils.file import get_paths_recursive
from src.datasets.salicon_dataset import get_dataloaders as get_salicon_dataloaders
from src.datasets.dhf1k_dataset import get_dataloaders as get_dhf1k_dataloaders
from src.config import (
    SEED,
    DEVICE,
    SPLITS,
    N_WORKERS,
    PROCESSED_SALICON_PATH,
    PROCESSED_DHF1K_PATH,
)

# When True, load the newest *.pth checkpoint in STATE_FOLDER_PATH;
# otherwise load the explicitly pinned STATE_FILE_PATH.
USE_LATEST_MODEL = False
STATE_FOLDER_PATH = f"{MODELS_PATH}/livesal/"
# Join with pathlib: STATE_FOLDER_PATH already ends with "/", so string
# formatting with another "/" produced a doubled separator ("livesal//...").
STATE_FILE_PATH = str(Path(STATE_FOLDER_PATH) / "20241121-234622_livesal.pth")

In [31]:
# NOTE(review): this cell and the DHF1K cell below both bind train_loader /
# val_loader / test_loader, so whichever runs last wins. In the saved run this
# cell executed after the DHF1K cell, so the visualization below used SALICON.
sample_folder_paths = get_paths_recursive(
    folder_path=PROCESSED_SALICON_PATH,
    match_pattern="*",
    file_type="d",
)
train_loader, val_loader, test_loader = get_salicon_dataloaders(
    sample_folder_paths=sample_folder_paths,
    # Train / val / test split settings from the shared config.
    train_split=SPLITS[0],
    val_split=SPLITS[1],
    test_split=SPLITS[2],
    # Loader behaviour.
    batch_size=1,
    with_transforms=True,
    train_shuffle=True,
    n_workers=N_WORKERS,
    seed=SEED,
)

🌱 Setting the seed to 0 for generating dataloaders.


In [17]:
# NOTE(review): this cell rebinds train_loader / val_loader / test_loader, so
# running it after the SALICON cell above replaces those loaders.
sample_folder_paths = get_paths_recursive(
    folder_path=PROCESSED_DHF1K_PATH,
    match_pattern="*",
    file_type="d",
)
train_loader, val_loader, test_loader = get_dhf1k_dataloaders(
    sample_folder_paths=sample_folder_paths,
    # TODO: not with salicon
    # NOTE(review): the visualization indexes input frames by output-map index,
    # so this presumably needs to match the model's output_channels (5) — confirm.
    sequence_length=5,
    # Train / val / test split settings from the shared config.
    train_split=SPLITS[0],
    val_split=SPLITS[1],
    test_split=SPLITS[2],
    # Loader behaviour.
    batch_size=1,
    with_transforms=True,
    train_shuffle=True,
    n_workers=N_WORKERS,
    seed=SEED,
)

🌱 Setting the seed to 0 for generating dataloaders.


In [29]:
# Build LiveSAL; these hyperparameters must match the checkpoint's architecture,
# since load_state_dict (strict by default) rejects missing/unexpected/mis-shaped weights.
model = LiveSAL(
    hidden_channels=64,
    output_channels=5,
    with_absolute_positional_embeddings=True,
    with_relative_positional_embeddings=True,
    n_heads=1,
    neighbor_radius=1,
    n_iterations=5,
    with_graph_processing=True,
    freeze_encoder=True,
    with_depth_information=True,
    fusion_level=4,
)
if USE_LATEST_MODEL:
    # list() so an empty result is detectable even if a generator is returned.
    state_file_paths = list(
        get_paths_recursive(folder_path=STATE_FOLDER_PATH, match_pattern="*.pth", file_type="f")
    )
    if not state_file_paths:
        raise FileNotFoundError(f"No .pth checkpoint found in {Path(STATE_FOLDER_PATH).resolve()}")
    # Checkpoint names start with a YYYYMMDD-HHMMSS timestamp (e.g.
    # 20241121-234622_livesal.pth), so the largest file stem is the newest.
    # Path(...).stem handles native Windows separators and Path objects,
    # unlike splitting the string on "/".
    state_file_path = max(state_file_paths, key=lambda path: Path(path).stem)
else:
    state_file_path = STATE_FILE_PATH
# map_location="cpu": GPU-saved checkpoints also load on CPU-only machines
# (the model is moved to DEVICE when it is used).
# weights_only=True: only tensors/containers are unpickled, so a .pth file
# cannot execute arbitrary code on load.
model.load_state_dict(torch.load(state_file_path, map_location="cpu", weights_only=True))
print(f"✅ Loaded state from {Path(state_file_path).resolve()}")

✅ Loaded state from C:\Users\arnau\cours\master_project\git-estimation\data\models\livesal\20241121-234622_livesal.pth


In [33]:
from matplotlib.widgets import Button
from matplotlib.animation import FuncAnimation

def visualize_batches(model, dataloader, interval=500):
    """Interactively animate model predictions next to inputs and ground truths.

    Opens a Matplotlib figure with three panels (input image, predicted
    saliency, ground-truth saliency) and a "Next Batch" button. For the batch
    on display, the first sample is animated over dim 1 of the model output
    (one animation frame per output map). Clicking the button fetches the next
    batch; when the dataloader is exhausted it is restarted from the
    beginning. Blocks until the window is closed.

    Args:
        model: Model to evaluate. Must expose ``output_channels``; it is put in
            eval mode and moved to ``DEVICE``.
        dataloader: Iterable yielding ``(frame, ground_truths,
            global_ground_truth)`` tuples.
        interval: Delay between animation frames, in milliseconds.

    Raises:
        ValueError: If the dataloader yields no batches at all, or if an input
            batch is neither 4-D nor 5-D.
    """
    model.eval()
    model = model.to(DEVICE)

    fig, (ax_input, ax_output, ax_gt) = plt.subplots(1, 3, figsize=(20, 10))
    panels = (ax_input, ax_output, ax_gt)
    fig.subplots_adjust(bottom=0.2)  # make room for the button below the panels

    # State shared with (and rebound by) the callbacks below.
    dataloader_iter = iter(dataloader)
    current_batch = None  # (inputs, outputs, targets) for the batch on display
    anim = None  # holding a reference keeps the animation from being garbage-collected

    def process_batch():
        """Fetch the next batch and run the model on it; raises StopIteration when exhausted."""
        frame, ground_truths, global_ground_truth = next(dataloader_iter)
        # animate() only knows how to pick a frame from 4-D or 5-D inputs.
        if frame.dim() not in (4, 5):
            raise ValueError(f"Expected a 4-D or 5-D input batch, got shape {tuple(frame.shape)}")

        inputs = frame.to(DEVICE)
        with torch.no_grad():
            outputs = model(inputs)

        # Give outputs without a map axis one, so they index like multi-map outputs.
        if outputs.dim() == 3:
            outputs = outputs.unsqueeze(1)

        # Single-map models are compared against the global ground truth.
        if model.output_channels == 1:
            targets = global_ground_truth.unsqueeze(1).to(DEVICE)
        else:
            targets = ground_truths.to(DEVICE)

        return inputs, outputs, targets

    def next_batch():
        """Return the next processed batch, restarting the dataloader once when it runs out."""
        nonlocal dataloader_iter
        try:
            return process_batch()
        except StopIteration:
            print("Reached the end of the dataset. Restarting...")
            dataloader_iter = iter(dataloader)
        # A second StopIteration right after restarting means the loader is
        # empty; fail loudly instead of restarting forever.
        try:
            return process_batch()
        except StopIteration:
            raise ValueError("The dataloader yields no batches.") from None

    def init_animation():
        """Clear all panels before the animation starts."""
        for ax in panels:
            ax.clear()
        return panels

    def animate(frame_idx):
        """Draw map ``frame_idx`` of the current batch's first sample."""
        inputs, outputs, targets = current_batch
        n_frames = outputs.shape[1]

        # 4-D inputs hold one image per sample; 5-D inputs hold one image per time step.
        input_frame = inputs[0] if inputs.dim() == 4 else inputs[0, frame_idx]

        for ax in panels:
            ax.clear()

        # CHW -> HWC for imshow. NOTE(review): if the transforms normalize the
        # image (e.g. mean/std), imshow clips values outside [0, 1] — confirm.
        ax_input.imshow(input_frame.permute(1, 2, 0).detach().cpu().numpy())
        ax_output.imshow(outputs[0, frame_idx].detach().cpu().numpy(), cmap="jet", vmin=0, vmax=1)
        ax_gt.imshow(targets[0, frame_idx].detach().cpu().numpy(), cmap="jet", vmin=0, vmax=1)

        # The frame count comes from the output itself rather than a hard-coded number.
        for ax, label in zip(panels, ("Input", "Output", "Ground Truth")):
            ax.set_title(f"{label} (Sequence {frame_idx + 1}/{n_frames})")
            ax.axis("off")

        return panels

    def update_plot(event=None):
        """Load the next batch and (re)start the animation on it; also the button callback."""
        nonlocal current_batch, anim

        current_batch = next_batch()

        # Stop the previous batch's animation before replacing it.
        if anim is not None:
            anim.event_source.stop()

        # blit=False: every frame clears the axes and rewrites the titles, which
        # lie outside the axes bounding boxes that blitting restores; a full
        # redraw keeps the titles from going stale or overlapping.
        anim = FuncAnimation(
            fig,
            animate,
            init_func=init_animation,
            frames=current_batch[1].shape[1],
            interval=interval,
            repeat=True,
            blit=False,
        )

        fig.canvas.draw_idle()

    button_ax = fig.add_axes([0.4, 0.05, 0.2, 0.075])
    button = Button(button_ax, "Next Batch")
    button.on_clicked(update_plot)

    # Show the first batch before handing control to the GUI loop.
    update_plot()

    plt.show(block=True)

# Interactively browse predictions on the train split. train_loader is bound by
# whichever dataloader cell above ran last (SALICON in the saved run).
visualize_batches(model=model, dataloader=train_loader, interval=500)

>>> torch.Size([1, 3, 331, 331])


: 