# In[ ]:
import os
import re

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# ================================================================
# ZACH-ViT VIS Construction Notebook (TALOS Dataset)
# ================================================================
# This notebook constructs Video Image Sequence (VIS) representations
# for each patient by concatenating position-wise frame sequences.
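# Each probe position produces a horizontal stride (its frames concatenated left to right).
# The strides of the available positions (up to 4; positions without frames are skipped)
# are zero-padded to a common width and stacked vertically to form the VIS image.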
# ================================================================

# CONFIGURATION
# Both paths are relative, so they resolve against the notebook's working directory.
INPUT_ROOT: str = "../Processed_ROI"  # holds TALOS<id>/pos_<n>/*.png frame folders
OUTPUT_DIR: str = "../VIS"            # receives TALOS<id>_VIS.png files
NUM_POSITIONS: int = 4                # probe positions per patient (pos_1 .. pos_4)
os.makedirs(name=OUTPUT_DIR, exist_ok=True)  # create the output folder up front

# ---------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------
def _natural_sort_key(filename):
    """Return a sort key that compares embedded integers numerically.

    "frame_2.png" sorts before "frame_10.png" (plain string sorting would put
    "frame_10.png" first). The full name is appended as a tie-breaker so that
    names such as "f1" and "f01" still sort deterministically.
    """
    # re.split with a capturing group alternates text / digit chunks, so every
    # odd index holds a run of decimal digits.
    parts = re.split(r"(\d+)", filename)
    key = [int(part) if idx % 2 else part for idx, part in enumerate(parts)]
    return key, filename


def load_frames_from_position(path):
    """Load all PNG frames in a position folder as grayscale float arrays.

    Frames are read in natural filename order (numeric runs compared as
    integers, so frame_2 precedes frame_10). Each frame is converted to 8-bit
    grayscale ("L") and scaled to [0, 1] as float32.

    Args:
        path: Directory holding the frames of one probe position.

    Returns:
        A list of 2-D float32 arrays, or None when ``path`` is not a directory
        or contains no ``.png`` files (extension matched case-insensitively).
    """
    if not os.path.isdir(path):
        return None
    names = sorted(
        (
            name
            for name in os.listdir(path)
            if name.lower().endswith(".png")
            and os.path.isfile(os.path.join(path, name))
        ),
        key=_natural_sort_key,
    )
    frames = []
    for name in names:
        # The context manager releases the file handle even if decoding fails.
        with Image.open(os.path.join(path, name)) as img:
            gray = img.convert("L")
        frames.append(np.array(gray, dtype=np.float32) / 255.0)
    return frames or None

def concatenate_frames_horizontally(frames):
    """Join frames side by side (along axis 1) into a single stride.

    Returns None when *frames* is None or has no elements; otherwise the
    frames must share a height, as required by ``np.concatenate``.
    """
    has_frames = frames is not None and len(frames) > 0
    return np.concatenate(frames, axis=1) if has_frames else None

def zero_pad_to_width(image, target_width):
    """Right-pad a 2-D image with zero columns up to *target_width*.

    An image that is already at least *target_width* wide is returned as-is
    (the same object, not a copy).
    """
    _, width = image.shape
    shortfall = target_width - width
    if shortfall <= 0:
        return image
    return np.pad(image, ((0, 0), (0, shortfall)), mode="constant", constant_values=0)

def _vis_to_uint8(vis_image):
    """Convert a float image nominally in [0, 1] to uint8 for PNG output.

    Values are rounded to the nearest level rather than truncated (truncation
    maps 0.999 to 254). They are clipped to [0, 255], and NaN becomes 0, so
    out-of-range inputs never hit an undefined float-to-uint8 cast.
    """
    scaled = np.asarray(vis_image, dtype=np.float64) * 255.0
    scaled = np.nan_to_num(scaled, nan=0.0, posinf=255.0, neginf=0.0)
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)


def save_vis_image(patient_id, vis_image, output_dir=None):
    """Save the VIS image as an 8-bit grayscale PNG.

    Args:
        patient_id: Identifier appended to "TALOS" in the output filename.
        vis_image: 2-D float image, expected in [0, 1].
        output_dir: Destination folder; defaults to the module-level OUTPUT_DIR.

    Returns:
        The path of the written ``TALOS<patient_id>_VIS.png`` file.
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR
    vis_path = os.path.join(output_dir, f"TALOS{patient_id}_VIS.png")
    Image.fromarray(_vis_to_uint8(vis_image)).save(vis_path)
    print(f"✅ Saved VIS for TALOS{patient_id}: shape={vis_image.shape}, file={vis_path}")
    return vis_path

# ---------------------------------------------------------------
# Main VIS Construction
# ---------------------------------------------------------------
def _normalize_to_unit(image):
    """Divide *image* by its maximum so the brightest pixel becomes 1.0.

    If the maximum is not a positive finite number (e.g. an all-black VIS),
    the image is returned unchanged. Dividing by zero would fill it with NaN,
    which then becomes garbage when cast to uint8 for saving.
    """
    peak = float(np.max(image))
    if not np.isfinite(peak) or peak <= 0.0:
        return image
    return image / peak


def construct_vis_for_patient(patient_id):
    """Build, save, and return the VIS image for one patient.

    Reads ``INPUT_ROOT/TALOS<patient_id>/pos_<k>`` for k = 1..NUM_POSITIONS.
    Each position with frames becomes one horizontal stride; positions without
    frames are reported and skipped.

    The strides are then:
    - right-padded with zeros to the widest stride;
    - stacked top to bottom in position order;
    - rescaled so the maximum is 1.0;
    - written via ``save_vis_image``.

    Args:
        patient_id: Identifier appended to "TALOS" to form the patient folder.

    Returns:
        The VIS as a 2-D float array, or None if no position had frames.
    """
    patient_dir = os.path.join(INPUT_ROOT, f"TALOS{patient_id}")
    all_strides = []

    for pos in range(1, NUM_POSITIONS + 1):
        frames = load_frames_from_position(os.path.join(patient_dir, f"pos_{pos}"))
        if frames is None:
            print(f"⚠️ No frames found for TALOS{patient_id}, position {pos}")
            continue
        all_strides.append(concatenate_frames_horizontally(frames))

    if not all_strides:
        print(f"⚠️ Skipping TALOS{patient_id}: no valid positions found.")
        return None

    # Vertical stacking needs equal widths, so pad every stride to the widest.
    max_width = max(stride.shape[1] for stride in all_strides)
    padded_strides = [zero_pad_to_width(s, max_width) for s in all_strides]
    vis_image = np.concatenate(padded_strides, axis=0)

    # Rescale so the brightest pixel is 1.0 (guarded against an all-zero VIS).
    vis_image = _normalize_to_unit(vis_image)

    save_vis_image(patient_id, vis_image)
    return vis_image

# ---------------------------------------------------------------
# Example visualization
# ---------------------------------------------------------------
def visualize_vis_example(patient_id):
    """Build the VIS for *patient_id* and display it with matplotlib.

    This goes through construct_vis_for_patient, so the PNG is (re)written as
    a side effect. Nothing is plotted when construction returns None.
    """
    vis_image = construct_vis_for_patient(patient_id)
    if vis_image is None:
        return
    plt.figure(figsize=(12, 6))
    plt.imshow(vis_image, cmap="gray", aspect="auto")
    plt.title(f"Video Image Sequence (VIS) for TALOS{patient_id}")
    plt.axis("off")
    plt.show()

# ---------------------------------------------------------------
# RUN PIPELINE
# ---------------------------------------------------------------
# Patient IDs are processed inclusively from FIRST_PATIENT_ID to
# LAST_PATIENT_ID. A two-element list such as [147, 199] would only
# process the two endpoints. Patients without data are reported and
# skipped by construct_vis_for_patient.
FIRST_PATIENT_ID = 147
LAST_PATIENT_ID = 199
PATIENT_RANGE = range(FIRST_PATIENT_ID, LAST_PATIENT_ID + 1)
for pid in PATIENT_RANGE:
    construct_vis_for_patient(pid)

# ---------------------------------------------------------------
# Example visualization for one patient
# ---------------------------------------------------------------
# Note: visualize_vis_example rebuilds (and re-saves) this patient's VIS.
visualize_vis_example(FIRST_PATIENT_ID)
