In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Set before torch / numpy are imported: the Intel OpenMP runtime consults this
# variable when it initializes, so setting it after the heavy imports may be too
# late. NOTE: this is the runtime's own "unsafe, unsupported" workaround for a
# duplicated libiomp; prefer fixing the environment when possible.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# Make the project root importable so the `src.*` packages resolve.
# Assumes the notebook is run from a direct subdirectory of the project root.
sys.path.append("../")

import matplotlib

# Choose the backend before pyplot is imported so pyplot binds to it directly;
# the interactive viewer in this notebook opens in a native Tk window.
matplotlib.use("TkAgg")

import numpy as np
import torch
import matplotlib.pyplot as plt

from src.utils.parser import get_config
from src.datasets.salicon_dataset import SaliconDataModule
from src.lightning_models.lightning_model import LightningModel
from src.models.tempsal import TempSAL
from src.utils.file import get_paths_recursive
from src.config import (
    SEED,
    DEVICE,
    N_WORKERS,
    CONFIG_PATH,
    MODELS_PATH,
    SEQUENCE_LENGTH,
    PROCESSED_SALICON_PATH,
)

# Checkpoint of the trained TempSAL run to visualize, and the config it was trained with.
STATE_FILE_PATH = f"{MODELS_PATH}/tempsal/20241130-144934_tempsal/epoch=9-val_loss=0.67-v1.ckpt"
CONFIG_FILE_PATH = f"{CONFIG_PATH}/tempsal/default.yml"

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def _as_bool(value) -> bool:
    """Interpret a config value as a boolean.

    Real booleans (and numbers) go through ``bool``. Strings are matched
    case-insensitively, because ``bool("false")`` would silently be ``True``.

    Raises:
        ValueError: If a string is not a recognised true/false spelling.
    """
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in {"true", "yes", "on", "1"}:
            return True
        if normalized in {"false", "no", "off", "0", ""}:
            return False
        raise ValueError(f"Cannot interpret config value {value!r} as a boolean")
    return bool(value)


# Explicit casts: some YAML loaders (e.g. PyYAML) read values such as 1e-4 as strings.
config = get_config(CONFIG_FILE_PATH)
n_epochs = int(config["n_epochs"])
learning_rate = float(config["learning_rate"])
weight_decay = float(config["weight_decay"])
evaluation_steps = int(config["evaluation_steps"])
splits = tuple(map(float, config["splits"]))
# The data module below consumes splits[0], splits[1], splits[2] as train/val/test.
assert len(splits) == 3, f"expected (train, val, test) splits, got {splits}"
save_model = _as_bool(config["save_model"])
with_transforms = _as_bool(config["with_transforms"])
freeze_encoder = _as_bool(config["freeze_encoder"])
hidden_channels_list = list(map(int, config["hidden_channels_list"]))

In [4]:
# Gather the processed SALICON sample folders
# (`path_type="d"` presumably selects directories — verify in get_paths_recursive).
sample_folder_paths = get_paths_recursive(
    path_type="d",
    match_pattern="*",
    folder_path=PROCESSED_SALICON_PATH,
)

data_module = SaliconDataModule(
    # Data source and reproducibility.
    sample_folder_paths=sample_folder_paths,
    seed=SEED,
    # Train / val / test fractions, in config order.
    train_split=splits[0],
    val_split=splits[1],
    test_split=splits[2],
    # Loading options; batch_size=1 because only the first sample of each
    # batch is visualized later in this notebook.
    batch_size=1,
    n_workers=N_WORKERS,
    with_transforms=with_transforms,
)

# Only the held-out test split is needed for visualization.
data_module.setup("test")
test_loader = data_module.test_dataloader()

Seed set to 0


🌱 Setting the seed to 0 for generating dataloaders.


In [5]:
# Build the bare TempSAL architecture with the same settings used for training.
model = TempSAL(
    hidden_channels_list=hidden_channels_list,
    freeze_encoder=freeze_encoder,
)

# load_from_checkpoint forwards the extra keyword arguments to LightningModel's
# constructor, then restores the checkpoint's state_dict into the wrapper.
lightning_model = LightningModel.load_from_checkpoint(
    checkpoint_path=STATE_FILE_PATH,
    name="tempsal",
    dataset="salicon",
    model=model,
    weight_decay=weight_decay,
    learning_rate=learning_rate,
)

In [11]:
from matplotlib.widgets import Button
from matplotlib.animation import FuncAnimation

def visualize_batches(lightning_model, dataloader, interval=500):
    """Interactively browse model predictions next to ground truth, one batch at a time.

    Opens a blocking matplotlib window with three panels: input image, predicted
    saliency and ground-truth saliency. It animates over the ``SEQUENCE_LENGTH``
    steps of the first sample of the current batch. A "Next Batch" button
    fetches the next batch; when the dataloader is exhausted it starts over.

    Args:
        lightning_model: Lightning wrapper whose ``.model(inputs)`` returns a
            pair; the second element is the prediction shown. Presumably either
            ``(B, H, W)`` (one global map) or ``(B, SEQUENCE_LENGTH, H, W)`` —
            TODO confirm against TempSAL.
        dataloader: Iterable yielding ``(frame, ground_truths, global_ground_truth)``.
        interval: Delay between animation frames, in milliseconds.

    Raises:
        ValueError: If ``dataloader`` yields no batches at all (previously this
            recursed until RecursionError).
    """
    lightning_model.eval()
    lightning_model = lightning_model.to(DEVICE)

    n_frames = SEQUENCE_LENGTH
    dataloader_iter = iter(dataloader)

    fig, (ax_input, ax_output, ax_gt) = plt.subplots(1, 3, figsize=(20, 10))
    plt.subplots_adjust(bottom=0.2)  # leave room below the panels for the button
    axes = (ax_input, ax_output, ax_gt)

    # Closure state shared by the callbacks below (rebound via `nonlocal`).
    current_batch = None
    anim = None

    def next_raw_batch():
        """Return the next batch, restarting the dataloader once if it is exhausted."""
        nonlocal dataloader_iter
        try:
            return next(dataloader_iter)
        except StopIteration:
            print("Reached the end of the dataset. Restarting...")
        dataloader_iter = iter(dataloader)
        try:
            return next(dataloader_iter)
        except StopIteration:
            # A fresh iterator is empty too: fail clearly instead of looping forever.
            raise ValueError("Dataloader yielded no batches; nothing to visualize.") from None

    def process_batch():
        """Run the model on the next batch; return (inputs, output, ground_truth) on DEVICE."""
        frame, ground_truths, global_ground_truth = next_raw_batch()
        inputs = frame.to(DEVICE)
        with torch.no_grad():
            _, output = lightning_model.model(inputs)

        if output.dim() == 3:
            # One map per sample: tile it (and the global ground truth) across
            # the sequence so the animation can index it like a temporal output.
            output = output.unsqueeze(1).repeat(1, n_frames, 1, 1)
            ground_truth = global_ground_truth.unsqueeze(1).repeat(1, n_frames, 1, 1).to(DEVICE)
        else:
            ground_truth = ground_truths.to(DEVICE)

        return inputs, output, ground_truth

    def init_animation():
        for ax in axes:
            ax.clear()
        return axes

    def animate(frame_idx):
        """Draw step ``frame_idx`` of the first sample in ``current_batch``."""
        for ax in axes:
            ax.clear()

        inputs, output, ground_truth = current_batch

        # The input image stays fixed; prediction and ground truth step through time.
        input_img = inputs[0].permute(1, 2, 0).detach().cpu().numpy()
        # Clip before the uint8 cast: values outside [0, 1] (e.g. normalized
        # inputs) would otherwise overflow the cast and show as garbage pixels.
        input_img = (np.clip(input_img, 0.0, 1.0) * 255).astype(np.uint8)

        output_map = output[0, frame_idx].detach().cpu().numpy()
        gt_map = ground_truth[0, frame_idx].detach().cpu().numpy()

        ax_input.imshow(input_img)
        ax_output.imshow(output_map, cmap="jet", vmin=0, vmax=1)
        ax_gt.imshow(gt_map, cmap="jet", vmin=0, vmax=1)

        # Use the real sequence length instead of a hard-coded "/3".
        position = f"{frame_idx + 1}/{n_frames}"
        ax_input.set_title(f"Input (Sequence {position})")
        ax_output.set_title(f"Output (Sequence {position})")
        ax_gt.set_title(f"Ground Truth (Sequence {position})")
        for ax in axes:
            ax.axis("off")

        return axes

    def update_plot(event=None):
        """Load the next batch and restart the animation (also the button callback)."""
        nonlocal current_batch, anim

        current_batch = process_batch()

        # Stop the previous timer so two animations never draw into the same axes.
        if anim is not None:
            anim.event_source.stop()

        # Keep the Animation referenced in `anim`; matplotlib requires a live
        # reference for the animation to keep running.
        anim = FuncAnimation(
            fig,
            animate,
            init_func=init_animation,
            frames=n_frames,
            interval=interval,
            repeat=True,
        )
        fig.canvas.draw_idle()

    ax_button = plt.axes([0.4, 0.05, 0.2, 0.075])
    # Keep `button` referenced while the window is open so it stays responsive.
    button = Button(ax_button, "Next Batch")
    button.on_clicked(update_plot)

    update_plot()  # show the first batch immediately
    plt.show(block=True)

# Browse test-set predictions (blocking window; 500 ms between sequence steps).
visualize_batches(
    lightning_model=lightning_model,
    dataloader=test_loader,
    interval=500,
)