In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pydicom
from pydicom.pixel_data_handlers.util import apply_modality_lut, apply_voi_lut
from skimage.transform import resize

# ================================================================
# ZACH-ViT Preprocessing and ROI Extraction (TALOS Dataset)
# ================================================================
# This script extracts pleural-line ROIs from TALOS ultrasound
# DICOM recordings, applies 50% vertical compression, and saves
# standardized grayscale frames for VIS generation.
# ================================================================

# CONFIGURATION
# Input prefix. process_dicom appends the patient ID directly (no path
# separator), e.g. "../TALOS" + 147 -> "../TALOS147/TALOS147_<pos>.dcm".
TALOS_PATH = "../TALOS"
# Root for extracted PNG frames: <OUTPUT_DIR>/TALOS<id>/pos_<pos>/.
OUTPUT_DIR = "../Processed_ROI"
# Example patient IDs. NOTE(review): despite the name this is a list of
# individual IDs, not a range -- the main loop processes only 147 and 199.
# Use range(147, 200) if every ID in between is intended.
PATIENT_RANGE = [147, 199]
# Probe positions per patient; the main loop numbers them 1..NUM_POSITIONS.
NUM_POSITIONS = 4

# HELPER FUNCTIONS
def ensure_dir(path):
    """Create directory *path*, including missing parents; no-op if it exists."""
    target = path
    os.makedirs(target, mode=0o777, exist_ok=True)

def crop_roi(image, x_start=285, x_end=395, y_start=65, y_end=400):
    """Return the sub-region ``image[y_start:y_end, x_start:x_end]``.

    The first axis is rows (vertical extent) and the second is columns
    (horizontal extent). Ordinary slicing rules apply, so bounds beyond the
    image edge are truncated rather than raising.
    """
    row_window = slice(y_start, y_end)
    col_window = slice(x_start, x_end)
    return image[row_window, col_window]

def compress_height(image, factor=0.5):
    """Rescale *image* vertically by *factor*, keeping its width unchanged.

    The target height is ``int(rows * factor)`` (truncated), clamped to at
    least 1 so a very short input never requests a zero-height output.
    Resampling uses ``skimage.transform.resize`` with anti-aliasing enabled.

    Args:
        image: Array whose first axis is rows and second axis is columns.
        factor: Positive vertical scale factor (0.5 halves the height).

    Returns:
        The resized array.

    Raises:
        ValueError: if *factor* is not positive.
    """
    if factor <= 0:
        raise ValueError(f"factor must be positive, got {factor!r}")
    new_height = max(1, int(image.shape[0] * factor))
    return resize(image, (new_height, image.shape[1]), anti_aliasing=True)

def save_cropped_image(path, image, base_name, frame_idx):
    """Write *image* as a grayscale PNG ``<base_name>_frame<NNN>.png`` under *path*.

    *path* is created first if missing. The colour mapping uses a fixed
    [0, 1] display window (``vmin=0, vmax=1``) rather than per-image
    autoscaling, so frames are expected to be normalized to that range.
    """
    ensure_dir(path)
    file_name = f"{base_name}_frame{frame_idx:03d}.png"
    destination = os.path.join(path, file_name)
    plt.imsave(destination, image, vmin=0, vmax=1, cmap="gray")

def compare_images(img0, img1, threshold=0.02):
    """Return True when two frames differ by more than *threshold* on average.

    The mean absolute per-pixel difference is compared against *threshold*.
    The default of 0.02 corresponds to 2% of the full range for frames
    normalized to [0, 1]. Both inputs are promoted to float64 before
    subtracting, so unsigned-integer frames cannot wrap around.

    Raises:
        ValueError: if the two frames do not have the same shape.
    """
    first = np.asarray(img0, dtype=np.float64)
    second = np.asarray(img1, dtype=np.float64)
    if first.shape != second.shape:
        raise ValueError(
            f"frame shapes differ: {first.shape} vs {second.shape}"
        )
    return bool(np.mean(np.abs(first - second)) > threshold)

def _load_normalized_pixels(ds):
    """Return the dataset's pixel data as float32, LUT-mapped and divided by its maximum."""
    pixels = ds.pixel_array
    # LUTs map *stored* values, so apply them to the raw array. They can come
    # back with the LUT's integer dtype, hence the float cast afterwards.
    if "ModalityLUTSequence" in ds:
        pixels = apply_modality_lut(pixels, ds)
    if "VOILUTSequence" in ds:
        pixels = apply_voi_lut(pixels, ds)
    # astype always copies, so in-place scaling below never touches ds's cache.
    pixels = np.asarray(pixels).astype(np.float32)

    peak = float(np.max(pixels)) if pixels.size else 0.0
    # Skip scaling for all-zero (or non-positive) data to avoid NaN frames.
    if peak > 0:
        pixels /= peak
    return pixels


def _iter_frames(pixels, n_frames):
    """Yield 2-D grayscale frames from a single- or multi-frame pixel array."""
    # Single-frame data has no leading frame axis; add one so that iteration
    # walks frames rather than image rows.
    stack = pixels if n_frames > 1 else pixels[np.newaxis, ...]
    for frame in stack:
        # Colour frames (rows, cols, samples) collapse to grayscale by channel mean.
        yield frame.mean(axis=-1) if frame.ndim == 3 else frame


def process_dicom(patient_id, probe_idx):
    """Extract, compress and save the ROI from every frame of one recording.

    Reads ``TALOS<patient_id>_<probe_idx>.dcm`` and applies any Modality/VOI
    LUT sequence. Pixel values are divided by their maximum (when positive).
    Each frame is cropped with ``crop_roi``, its height is halved with
    ``compress_height``, and one PNG per frame is written under
    ``OUTPUT_DIR/TALOS<patient_id>/pos_<probe_idx>``. Finally, the function
    prints the percentage of consecutive frame pairs that ``compare_images``
    reports as different.

    Args:
        patient_id: Patient number used in the input and output paths.
        probe_idx: Probe position number used in the input and output paths.

    Errors from reading or decoding the DICOM file propagate to the caller.
    """
    # NOTE(review): TALOS_PATH is concatenated directly with the patient ID
    # ("../TALOS" + "147" -> "../TALOS147/..."); confirm this matches the
    # on-disk layout rather than "../TALOS/TALOS147/...".
    fname = f"{TALOS_PATH}{patient_id}/TALOS{patient_id}_{probe_idx}.dcm"
    base_name = f"TALOS{patient_id}_{probe_idx}"
    output_path = os.path.join(OUTPUT_DIR, f"TALOS{patient_id}", f"pos_{probe_idx}")
    ensure_dir(output_path)

    ds = pydicom.dcmread(fname)
    pixels = _load_normalized_pixels(ds)
    # An empty/absent NumberOfFrames means a single-frame image.
    n_frames = int(getattr(ds, "NumberOfFrames", 1) or 1)

    frames = []
    for k, frame in enumerate(_iter_frames(pixels, n_frames)):
        compressed = compress_height(crop_roi(frame), factor=0.5)
        frames.append(compressed)
        save_cropped_image(output_path, compressed, base_name, k)

    diffs = [compare_images(prev, nxt) for prev, nxt in zip(frames, frames[1:])]
    diff_ratio = round(np.sum(diffs) / len(diffs) * 100, 2) if diffs else 0
    print(f"(Patient {patient_id}, Probe {probe_idx}) {len(frames)} frames | {diff_ratio}% differ across frames.")

# MAIN LOOP
# Build every (patient, probe position) pair up front, then process them in
# order. An exception from one recording is reported and that pair is skipped,
# so a single bad file does not stop the batch.
recordings = [
    (pid, pos)
    for pid in PATIENT_RANGE
    for pos in range(1, NUM_POSITIONS + 1)
]
for pid, pos in recordings:
    try:
        process_dicom(pid, pos)
    except Exception as e:
        print(f"⚠️ Skipping patient {pid}, position {pos}: {e}")

# MOCK VISUALIZATION
# Demonstrate crop_roi + compress_height on a synthetic 480x640 frame, so no
# DICOM file is needed. A fixed seed keeps the figure identical across runs.
# Raising uniform noise to the power 2.5 pushes intensities toward dark.
rng = np.random.default_rng(0)
synthetic_frame = rng.random((480, 640)) ** 2.5
synthetic_frame = np.clip(synthetic_frame * 255, 0, 255)
roi_example = crop_roi(synthetic_frame)
roi_compressed = compress_height(roi_example, factor=0.5)

panels = [
    (synthetic_frame, "Original Ultrasound Frame"),
    (roi_example, "Cropped ROI"),
    (roi_compressed, "50% Height ROI (Final Frame)"),
]
fig, axes = plt.subplots(1, 3, figsize=(10, 4))
for ax, (panel_img, panel_title) in zip(axes, panels):
    ax.imshow(panel_img, cmap="gray")
    ax.set_title(panel_title)
    ax.axis("off")

fig.suptitle("Example ROI Extraction and Compression (TALOS Dataset)")
fig.tight_layout()
plt.show()
