# In [None]:
import os
import numpy as np
import itertools
import random
import matplotlib.pyplot as plt
from PIL import Image

# ================================================================
# ZACH-ViT - 04_0_2_SSDA_Generation.ipynb
# ================================================================
# Semi-supervised ShuffleStrides Data Augmentation (0₂-SSDA)
# ---------------------------------------------------------------
# This notebook extends 0-SSDA by also shuffling the temporal order
# of frames within each positional stride using prime-number–seeded
# randomization while preserving intra-stride coherence.
# The approach introduces controlled variability that mimics
# real-world probe motion without breaking anatomical logic.
# ================================================================

# CONFIGURATION
# INPUT_ROOT: folder holding one "TALOS<patient_id>" sub-folder per patient,
#             each with "pos_<n>" sub-folders of .png frames.
# OUTPUT_DIR: folder that receives the generated permutation PNGs.
# NUM_POSITIONS: number of "pos_<n>" folders (strides) read per patient.
INPUT_ROOT = "../Processed_ROI"
OUTPUT_DIR = "../SSDA_0_2"
NUM_POSITIONS = 4
# Created eagerly, as soon as this cell/module runs, so later saves cannot
# fail on a missing output folder.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Prime numbers used as reproducible seeds for intra-stride shuffling.
# The seed for a stride is picked by index (patient_id + pos) % len(PRIME_SEEDS).
PRIME_SEEDS = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

# ---------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------
def load_frames_from_position(path):
    """Load every ``.png`` frame found in a position folder.

    Files are read in lexicographic filename order, converted to 8-bit
    grayscale and scaled to ``float32`` values in ``[0, 1]``.

    Args:
        path: Directory expected to contain the frames of one position.

    Returns:
        A list with one 2-D ``float32`` array per frame, or ``None`` when
        ``path`` is not a directory or contains no ``.png`` files.
    """
    # isdir (rather than exists): a stray regular file carrying the folder's
    # name is reported as "no frames" instead of crashing os.listdir.
    if not os.path.isdir(path):
        return None

    frames = []
    for fname in sorted(f for f in os.listdir(path) if f.endswith(".png")):
        # The context manager releases the file handle deterministically.
        with Image.open(os.path.join(path, fname)) as img:
            frames.append(np.array(img.convert("L"), dtype=np.float32) / 255.0)
    return frames or None

def shuffle_frames(frames, seed):
    """Return the frames of one stride in a reproducibly shuffled order.

    Args:
        frames: Sequence of frames; it is left unmodified.
        seed: Seed that fully determines the resulting order (the pipeline
            passes one of the prime seeds).

    Returns:
        A new list containing the same frames in shuffled order.
    """
    # A private generator yields the same permutation as seeding the
    # module-level one, but leaves the global `random` state untouched so
    # other code relying on it is not silently re-seeded.
    rng = random.Random(seed)
    shuffled = list(frames)
    rng.shuffle(shuffled)
    return shuffled

def concatenate_frames_horizontally(frames):
    """Join frames side by side (along axis 1) into one image.

    Returns ``None`` when ``frames`` is ``None`` or empty.
    """
    has_frames = frames is not None and len(frames) > 0
    return np.concatenate(frames, axis=1) if has_frames else None

def zero_pad_to_width(image, target_width):
    """Right-pad a 2-D image with zeros until it is ``target_width`` wide.

    The input is returned unchanged when it is already at least that wide.
    """
    _, width = image.shape
    missing = target_width - width
    if missing <= 0:
        return image
    # Pad only on the right-hand side of the column axis.
    padding = ((0, 0), (0, missing))
    return np.pad(image, padding, mode='constant', constant_values=0)

def save_augmented_vis(patient_id, perm_idx, vis_image):
    """Save one augmented VIS image as a PNG inside ``OUTPUT_DIR``.

    Args:
        patient_id: Patient number used in the ``TALOS<patient_id>`` prefix.
        perm_idx: Permutation index, zero-padded to two digits in the name.
        vis_image: Image array with intensities expected in ``[0, 1]``.

    Returns:
        The path of the written file.
    """
    # Clip before the cast: converting out-of-range floats to uint8 is not
    # well defined and typically wraps around, corrupting the saved image.
    vis_uint8 = (np.clip(vis_image, 0.0, 1.0) * 255).astype(np.uint8)
    out_path = os.path.join(OUTPUT_DIR, f"TALOS{patient_id}_perm_{perm_idx:02d}.png")
    Image.fromarray(vis_uint8).save(out_path)
    return out_path

# ---------------------------------------------------------------
# Main 0₂-SSDA generation
# ---------------------------------------------------------------
def create_permuted_vis_with_frame_shuffling(patient_id):
    """Generate every stride-order permutation of a patient's VIS image.

    For each of the ``NUM_POSITIONS`` positions, frames are loaded from
    ``INPUT_ROOT/TALOS<patient_id>/pos_<n>``, shuffled with a seed taken from
    ``PRIME_SEEDS`` and concatenated horizontally into one stride. Strides
    are zero-padded to a common width, scaled by their global maximum and
    stacked vertically in every possible order; each stacked image is saved
    through ``save_augmented_vis``.

    Args:
        patient_id: Integer patient number (it takes part in seed selection).

    Returns:
        The list of written file paths, or ``[]`` when at least one position
        has no frames.
    """
    patient_dir = os.path.join(INPUT_ROOT, f"TALOS{patient_id}")
    strides = []

    # Load and shuffle each stride internally using prime-number seeds.
    # Every position is probed (no early exit) so all gaps get reported.
    for pos in range(1, NUM_POSITIONS + 1):
        frames = load_frames_from_position(os.path.join(patient_dir, f"pos_{pos}"))
        if frames is None:
            print(f"⚠️ TALOS{patient_id}: missing frames for pos {pos}")
            continue

        seed = PRIME_SEEDS[(patient_id + pos) % len(PRIME_SEEDS)]
        strides.append(concatenate_frames_horizontally(shuffle_frames(frames, seed)))

    if len(strides) < NUM_POSITIONS:
        print(
            f"⚠️ TALOS{patient_id}: insufficient valid strides "
            f"({len(strides)}/{NUM_POSITIONS}). Skipping."
        )
        return []

    # Equalize width via zero padding so strides can be stacked vertically.
    max_width = max(s.shape[1] for s in strides)
    padded = [zero_pad_to_width(s, max_width) for s in strides]

    # The maximum is the same for every permutation, so normalize once up
    # front. An all-black input (peak == 0) is left as-is instead of being
    # divided by zero, which would fill the image with NaN.
    peak = max(s.max() for s in padded)
    if peak > 0:
        padded = [s / peak for s in padded]

    # Generate all stride permutations (24 for four positions).
    out_paths = []
    for idx, perm in enumerate(itertools.permutations(range(NUM_POSITIONS)), start=1):
        vis = np.concatenate([padded[i] for i in perm], axis=0)
        out_paths.append(save_augmented_vis(patient_id, idx, vis))

    print(f"✅ TALOS{patient_id}: generated {len(out_paths)} permuted & shuffled VIS images.")
    return out_paths

# ---------------------------------------------------------------
# Visualization
# ---------------------------------------------------------------
def visualize_example(patient_id):
    """Plot the first six permutation images of a patient in a 2x3 grid.

    The images are (re)generated first by calling
    ``create_permuted_vis_with_frame_shuffling``; panels whose image file is
    not found in ``OUTPUT_DIR`` are left blank.

    Args:
        patient_id: Patient number whose ``TALOS<patient_id>`` images to show.
    """
    create_permuted_vis_with_frame_shuffling(patient_id)
    _, axes = plt.subplots(2, 3, figsize=(12, 6))
    for i, ax in enumerate(axes.flat[:6]):
        path = os.path.join(OUTPUT_DIR, f"TALOS{patient_id}_perm_{i+1:02d}.png")
        if os.path.exists(path):
            # Read inside a context manager so the file handle is released.
            with Image.open(path) as handle:
                img = np.array(handle)
            ax.imshow(img, cmap="gray", aspect="auto")
            ax.set_title(f"perm {i+1}")
        # Hide the frame and ticks for every panel, including those without
        # an image, so a missing file does not show up as an empty plot.
        ax.axis("off")
    plt.suptitle(f"0₂-SSDA Examples for TALOS{patient_id}")
    plt.tight_layout()
    plt.show()

# ---------------------------------------------------------------
# RUN PIPELINE
# ---------------------------------------------------------------
# NOTE(review): despite its name, PATIENT_RANGE is iterated as an explicit
# list of patient IDs, so only patients 147 and 199 are processed here, not
# every ID in between. Confirm whether the full span 147..199 was intended.
PATIENT_RANGE = [147, 199]  # Example patient IDs
for pid in PATIENT_RANGE:
    create_permuted_vis_with_frame_shuffling(pid)

# ---------------------------------------------------------------
# Example visualization
# ---------------------------------------------------------------
# visualize_example regenerates the patient's images before plotting them,
# so patient 147 is processed a second time here.
visualize_example(147)
