In [None]:
# Single Cell for Non-Interactive Sequential Visualization (First 10 Steps)

import os
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import sys

# --- Matplotlib Magic Command (For static plots in output) ---
%matplotlib inline

# --- << SET THIS GLOBAL VARIABLE >> ---
# Path of the episode .pth file to visualize -- edit this before running the cell.
# Another example: "./dataset/<chess>_pseudo_illegal_epi1.pth"
EPISODE_FILE_PATH = "./dataset/<chess>_epi1.pth"

# --- Configuration (keep in sync with the dataset-generation settings) ---
# Expected per-frame tensor shape, ordered (height, width, channels).
FRAME_SIZE = 512, 512, 3
# dtype applied when a frame arrives as something other than a torch.Tensor.
TORCH_DATA_TYPE = torch.float32
# Action tokens: the one used for the first action of an episode (adjust to
# whatever token your dataset used) and the one expected for the final action.
GAME_NAME_TOKEN, FINAL_ACTION_TOKEN = "<chess>", "<exit>"
# Only the first STEPS_TO_SHOW transitions are plotted.
STEPS_TO_SHOW = 10

# --- Helper Function ---
def tensor_to_pil(tensor):
    """Convert an (H, W, C) frame scaled to [-1, 1] back into an RGB PIL image.

    Args:
        tensor: Frame data, ideally a ``torch.Tensor`` whose shape equals
            ``FRAME_SIZE``. Non-tensor input is first converted with
            ``torch.tensor(..., dtype=TORCH_DATA_TYPE)``.

    Returns:
        A ``PIL.Image.Image`` of the frame. This function never raises; on
        failure it prints a message and returns a solid-colour placeholder
        sized from ``FRAME_SIZE``:

        * red     -- non-tensor input could not be converted to a tensor of
          shape ``FRAME_SIZE``;
        * magenta -- tensor input has a shape other than ``FRAME_SIZE``;
        * yellow  -- the final tensor-to-PIL conversion failed.
    """
    error_size = (FRAME_SIZE[1], FRAME_SIZE[0])  # PIL sizes are (width, height)

    if not isinstance(tensor, torch.Tensor):
        print(f"Warning: Expected a Tensor, got {type(tensor)}. Attempting conversion.")
        try:
            tensor = torch.tensor(tensor, dtype=TORCH_DATA_TYPE)
            if tensor.shape != FRAME_SIZE:
                raise ValueError(f"Converted tensor shape {tensor.shape} doesn't match FRAME_SIZE {FRAME_SIZE}")
        except Exception as e:
            print(f"Error converting input to tensor: {e}")
            return Image.new('RGB', error_size, (255, 0, 0))  # Red error image

    # Detach from any autograd graph and move to CPU (a no-op for CPU tensors).
    tensor = tensor.detach().cpu().to(torch.float32)
    if tensor.shape != FRAME_SIZE:
        print(f"Error: Input tensor shape {tensor.shape} does not match expected FRAME_SIZE {FRAME_SIZE}")
        return Image.new('RGB', error_size, (255, 0, 255))  # Magenta error image

    # [-1, 1] -> [0, 1], clipping anything that falls outside the range.
    unit = torch.clamp((tensor + 1.0) / 2.0, 0, 1)
    try:
        # Quantise to uint8 with rounding. Handing a float tensor to
        # ToPILImage would truncate (mul(255).byte()), so float error such as
        # 199.9999 would come out as 199 instead of 200.
        pixels = (unit * 255.0).round().to(torch.uint8).permute(2, 0, 1)  # HWC -> CHW
        pil_image = transforms.ToPILImage()(pixels)
    except Exception as e:
        print(f"Error during transforms.ToPILImage(): {e}")
        return Image.new('RGB', error_size, (255, 255, 0))  # Yellow error image
    return pil_image

# --- Main Execution Logic ---

print("--- Starting Visualization ---")

# --- Basic File Check ---
# `valid_file` gates everything below; a missing .pth extension only warns.
valid_file = False
if not isinstance(EPISODE_FILE_PATH, str) or not EPISODE_FILE_PATH:
    print("Error: EPISODE_FILE_PATH is not set correctly.")
elif not os.path.isfile(EPISODE_FILE_PATH):
    print(f"Error: File not found at '{EPISODE_FILE_PATH}'")
elif not EPISODE_FILE_PATH.endswith(".pth"):
    print(f"Warning: File '{os.path.basename(EPISODE_FILE_PATH)}' does not end with .pth.")
    valid_file = True  # Allow processing anyway
else:
    valid_file = True

if valid_file:
    # --- Load Data ---
    print(f"Loading episode data from: {EPISODE_FILE_PATH}")
    try:
        # NOTE(review): torch.load unpickles the file, so only point it at
        # episode files you trust. If your torch version supports it and the
        # file holds only tensors/str/list/dict, consider weights_only=True.
        episode_data = torch.load(EPISODE_FILE_PATH, map_location=torch.device('cpu'))

        if not isinstance(episode_data, dict):
            print(f"Error: Loaded file is not a dictionary. Found type: {type(episode_data)}")
        else:
            # --- Data Validation ---
            required_keys = ['previous_frames', 'actions', 'target_frames']
            missing_keys = [key for key in required_keys if key not in episode_data]
            if missing_keys:
                print(f"Error: Loaded dictionary is missing required keys: {missing_keys}")
            else:
                prev_frames = episode_data['previous_frames']
                actions = episode_data['actions']
                target_frames = episode_data['target_frames']
                num_transitions = len(actions)
                print(f"  Total transitions in file: {num_transitions}")

                # Both frame stacks must be tensors of shape (num_transitions, *FRAME_SIZE).
                if not isinstance(prev_frames, torch.Tensor) or not isinstance(target_frames, torch.Tensor) or \
                   prev_frames.shape[0] != num_transitions or target_frames.shape[0] != num_transitions or \
                   prev_frames.shape[1:] != FRAME_SIZE or target_frames.shape[1:] != FRAME_SIZE:
                    print("Error: Data validation failed (type, count, or shape mismatch).")
                else:
                    # --- Sequential Plotting Loop ---
                    num_to_show = min(STEPS_TO_SHOW, num_transitions)
                    print(f"Displaying the first {num_to_show} transitions sequentially:")

                    for i in range(num_to_show):
                        print(f"\n--- Transition {i+1}/{num_to_show} ---")
                        prev_frame_tensor = prev_frames[i]
                        action = actions[i]
                        target_frame_tensor = target_frames[i]

                        # Heuristic: an ordinary action (not the first/final
                        # token) whose target frame equals the previous frame
                        # is labelled as an illegal attempt.
                        is_illegal_attempt = (
                            action not in (GAME_NAME_TOKEN, FINAL_ACTION_TOKEN)
                            and torch.allclose(prev_frame_tensor, target_frame_tensor, atol=1e-6)
                        )

                        prev_img = tensor_to_pil(prev_frame_tensor)
                        target_img = tensor_to_pil(target_frame_tensor)

                        # A new figure per step, so each transition renders as its own inline plot.
                        fig, ax = plt.subplots(1, 2, figsize=(9, 4.5))  # Smaller size for inline
                        try:
                            ax[0].imshow(prev_img)
                            ax[0].set_title(f"Prev Frame ({i+1})")
                            ax[0].axis("off")

                            ax[1].imshow(target_img)
                            title_suffix = " (Illegal Attempt!)" if is_illegal_attempt else ""
                            ax[1].set_title(f"Action: '{action}'{title_suffix}")
                            ax[1].axis("off")

                            fig.tight_layout()

                            # pyplot.show() takes no figure argument (only the
                            # keyword `block`); it displays all open figures.
                            plt.show()
                        finally:
                            # Release the figure even if plotting failed midway.
                            plt.close(fig)

    except Exception as e:
        print("\nAn error occurred while loading or visualizing the file:")
        print(e)
        import traceback
        traceback.print_exc()

print("\n--- Visualization cell finished ---")