# In [None]:
import os
import numpy as np
import itertools
import matplotlib.pyplot as plt
from PIL import Image

# ================================================================
# ZACH-ViT - 03_0_SSDA_Generation.ipynb
# ================================================================
# ShuffleStrides Data Augmentation (0-SSDA)
# ---------------------------------------------------------------
# For each patient, this notebook generates all 4! = 24 permutations
# of the 4 positional strides, creating augmented VIS images.
# Each permutation preserves internal anatomical coherence within a stride
# while altering the probe-order relationships.
# ================================================================

# CONFIGURATION
# Root folder read by load_stride; frames are looked up under
# INPUT_ROOT/TALOS<patient_id>/pos_<pos>/*.png.
INPUT_ROOT = "../Processed_ROI"
# Folder that receives the permuted VIS PNGs written by save_augmented_vis.
OUTPUT_DIR = "../SSDA_0"
# Number of probe positions (strides) per patient; positions are numbered
# 1..NUM_POSITIONS and NUM_POSITIONS! permutations are generated per patient.
NUM_POSITIONS = 4
# Create the output folder up front; exist_ok=True makes re-running the cell safe.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ---------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------
def load_stride(patient_id, pos):
    """Load the stride image for a given patient and position.

    Every ``.png`` file in ``INPUT_ROOT/TALOS<patient_id>/pos_<pos>`` is read
    in sorted (lexicographic) filename order, converted to 8-bit grayscale,
    scaled to ``[0, 1]`` and concatenated horizontally (same as the VIS step).

    Args:
        patient_id: Patient number used to build the ``TALOS<patient_id>``
            folder name.
        pos: Probe position index used to build the ``pos_<pos>`` folder name.

    Returns:
        A 2-D float array with all frames side by side, or ``None`` when the
        position folder does not exist or contains no ``.png`` files.
    """
    stride_path = os.path.join(INPUT_ROOT, f"TALOS{patient_id}", f"pos_{pos}")
    try:
        frame_files = sorted(f for f in os.listdir(stride_path) if f.endswith(".png"))
    except (FileNotFoundError, NotADirectoryError):
        # A missing folder is the same situation as an empty one for the
        # caller, which reports the stride as missing and skips the patient.
        return None
    if not frame_files:
        return None

    frames = []
    for frame_name in frame_files:
        # The context manager closes the underlying file handle right after
        # the pixel data has been copied into the array.
        with Image.open(os.path.join(stride_path, frame_name)) as img:
            frames.append(np.array(img.convert("L")) / 255.0)
    return np.concatenate(frames, axis=1)

def zero_pad_to_width(image, target_width):
    """Right-pad a 2-D image with zero columns up to ``target_width``.

    An image that is already at least ``target_width`` wide is returned
    unchanged (the very same object, not a copy).
    """
    _, current_width = image.shape
    missing_cols = target_width - current_width
    if missing_cols <= 0:
        return image
    # Rows are left untouched; all extra columns are appended on the right.
    return np.pad(
        image,
        [(0, 0), (0, missing_cols)],
        mode="constant",
        constant_values=0,
    )

def save_augmented_vis(patient_id, permutation_idx, vis_image):
    """Write one permuted VIS image into ``OUTPUT_DIR`` as an 8-bit PNG.

    Args:
        patient_id: Patient number used in the ``TALOS<patient_id>`` file prefix.
        permutation_idx: Permutation number, zero-padded to two digits in the
            file name.
        vis_image: Image array; it is multiplied by 255 and cast to ``uint8``
            (presumably expected to be scaled to ``[0, 1]`` by the caller).

    Returns:
        Path of the PNG file that was written.
    """
    pixels = (vis_image * 255).astype(np.uint8)
    file_name = f"TALOS{patient_id}_perm_{permutation_idx:02d}.png"
    out_path = os.path.join(OUTPUT_DIR, file_name)
    png = Image.fromarray(pixels)
    png.save(out_path)
    return out_path

def create_permuted_vis(patient_id):
    """Generate every stride-order permutation of one patient's VIS image.

    The ``NUM_POSITIONS`` strides are loaded, zero-padded to a common width,
    normalised by their joint maximum and then stacked vertically once per
    permutation of the stride order (``NUM_POSITIONS!`` images in total).
    Each stacked image is written with ``save_augmented_vis`` using a 1-based
    permutation number.

    Args:
        patient_id: Patient number identifying the ``TALOS<patient_id>`` data.

    Returns:
        List of written PNG paths in permutation order, or an empty list when
        at least one stride is missing and the patient is skipped.
    """
    # Load one stride per position; positions are numbered from 1.
    strides = []
    for pos in range(1, NUM_POSITIONS + 1):
        stride = load_stride(patient_id, pos)
        if stride is None:
            print(f"⚠️ TALOS{patient_id}: missing stride for pos {pos}")
        else:
            strides.append(stride)

    if len(strides) < NUM_POSITIONS:
        print(f"⚠️ TALOS{patient_id}: insufficient valid strides ({len(strides)}/{NUM_POSITIONS}), skipping.")
        return []

    # Vertical stacking requires every stride to have the same width.
    max_width = max(s.shape[1] for s in strides)
    padded = [zero_pad_to_width(s, max_width) for s in strides]

    # The maximum of the stacked image does not depend on the stride order,
    # so normalise once instead of once per permutation. An all-black patient
    # (peak == 0) is left as zeros rather than divided by zero into NaNs.
    peak = max(float(np.max(s)) for s in padded)
    if peak > 0:
        padded = [s / peak for s in padded]

    output_paths = []
    for idx, perm in enumerate(itertools.permutations(range(NUM_POSITIONS)), start=1):
        vis = np.concatenate([padded[i] for i in perm], axis=0)
        output_paths.append(save_augmented_vis(patient_id, idx, vis))

    print(f"✅ TALOS{patient_id}: generated {len(output_paths)} permuted VIS images.")
    return output_paths

# ---------------------------------------------------------------
# Example visualization
# ---------------------------------------------------------------
def visualize_example(patient_id):
    """Visualize example permutations for one patient.

    Regenerates the permuted VIS images for ``patient_id`` and shows up to
    the first six of them in a 2x3 grid. Nothing is plotted when generation
    was skipped for this patient (no images were written).

    Args:
        patient_id: Patient number identifying the ``TALOS<patient_id>`` data.
    """
    # Use the paths that were just written instead of rebuilding file names,
    # so stale files from an earlier run can never be displayed.
    paths = create_permuted_vis(patient_id)
    if not paths:
        return

    _fig, axes = plt.subplots(2, 3, figsize=(12, 6))
    # Hide ticks and frames on every subplot, including any left without an image.
    for ax in axes.flat:
        ax.axis("off")
    for perm_number, (ax, path) in enumerate(zip(axes.flat, paths), start=1):
        # Copy the pixels out before the file handle is closed.
        with Image.open(path) as img:
            pixels = np.array(img)
        ax.imshow(pixels, cmap="gray", aspect="auto")
        ax.set_title(f"Perm {perm_number}")
    plt.suptitle(f"0-SSDA Examples for TALOS{patient_id}")
    plt.tight_layout()
    plt.show()

# ---------------------------------------------------------------
# RUN PIPELINE
# ---------------------------------------------------------------
# NOTE: despite its name, PATIENT_RANGE is an explicit list of patient IDs,
# not a (start, end) range: only patients 147 and 199 are processed.
PATIENT_RANGE = [147, 199]  # Example patients
for pid in PATIENT_RANGE:
    create_permuted_vis(pid)

# ---------------------------------------------------------------
# Visualization for one patient
# ---------------------------------------------------------------
# visualize_example calls create_permuted_vis itself, so the images for
# patient 147 are generated a second time here before being displayed.
visualize_example(147)
