In [3]:
# ==============================
# ZACH-ViT: Zero-Token Vision Transformer for Lung Ultrasound
# VIS, 0-SSDA and SSDA_p
# Paper: https://arxiv.org/abs/XXXX.XXXXX
# Code: https://github.com/Bluesman79/ZACH-ViT
# Licensed under Apache License 2.0
# Author: Athanasios Angelakis
# ==============================

import os
import numpy as np
import matplotlib.pyplot as plt
import pydicom
from pydicom.pixel_data_handlers.util import apply_modality_lut, apply_voi_lut
from skimage.transform import resize
from PIL import Image
import time
import itertools
import random

# 1st Module
# ---------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------
def ensure_dir(path):
    os.makedirs(path, exist_ok=True)

def crop_roi(image, x_start=285, x_end=395, y_start=65, y_end=400): 
    '''
    ROI coordinates are defined after thorough image analysis; per vendor, software version, etc, these could be different.
    Users  provide their own ROI coordinates.
    '''
    return image[y_start:y_end, x_start:x_end]

def compress_height(image, factor=0.5):
    new_height = int(image.shape[0] * factor)
    return resize(image, (new_height, image.shape[1]), anti_aliasing=True)

def save_cropped_image(path, image, base_name, frame_idx):
    ensure_dir(path)
    out_path = os.path.join(path, f"{base_name}_frame{frame_idx:03d}.png")
    plt.imsave(out_path, image, cmap="gray", vmin=0, vmax=1)

def process_dicom(patient_id, probe_idx):
    fname = f"{TALOS_PATH}{patient_id}/TALOS{patient_id}_{probe_idx}.dcm"
    base_name = f"TALOS{patient_id}_{probe_idx}"
    output_path = os.path.join(OUTPUT_DIR_ROI, f"TALOS{patient_id}", f"pos_{probe_idx}")
    ensure_dir(output_path)

    ds = pydicom.dcmread(fname)
    pixel_array = ds.pixel_array.astype(np.float32)

    if "ModalityLUTSequence" in ds:
        pixel_array = apply_modality_lut(pixel_array, ds)
    if "VOILUTSequence" in ds:
        pixel_array = apply_voi_lut(pixel_array, ds)

    pixel_array /= np.max(pixel_array)

    frames = []
    for k in range(len(pixel_array)):
        frame = pixel_array[k] if pixel_array.ndim > 2 else pixel_array
        if frame.ndim == 3:
            frame = np.mean(frame, axis=-1)
        cropped = crop_roi(frame)
        compressed = compress_height(cropped, factor=0.5)
        frames.append(compressed)
        save_cropped_image(output_path, compressed, base_name, k)


# 2nd Module
# ---------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------
def load_frames_from_position(path):
    """Load all .png frames from a given position folder and return them sorted."""
    frames = []
    if not os.path.exists(path):
        return None
    files = sorted([f for f in os.listdir(path) if f.endswith(".png")])
    for fname in files:
        img = np.array(Image.open(os.path.join(path, fname)).convert("L"), dtype=np.float32) / 255.0
        frames.append(img)
    return frames if len(frames) > 0 else None

def concatenate_frames_horizontally(frames):
    """Concatenate a list of frames horizontally into one long stride."""
    if frames is None or len(frames) == 0:
        return None
    return np.concatenate(frames, axis=1)

def zero_pad_to_width(image, target_width):
    """Zero-pad image horizontally to match target width."""
    h, w = image.shape
    if w < target_width:
        pad_width = target_width - w
        padded = np.pad(image, ((0, 0), (0, pad_width)), mode='constant', constant_values=0)
        return padded
    return image

def save_vis_image(patient_id, vis_image):
    """Save the VIS image as grayscale PNG."""
    vis_path = os.path.join(OUTPUT_DIR_VIS, f"TALOS{patient_id}_VIS.png")
    vis_uint8 = (vis_image * 255).astype(np.uint8)
    Image.fromarray(vis_uint8).save(vis_path)
    print(f"Saved VIS for TALOS{patient_id}: shape={vis_image.shape}, file={vis_path}")

# ---------------------------------------------------------------
# Main VIS Construction
# ---------------------------------------------------------------
def construct_vis_for_patient(patient_id):
    """Construct VIS image for one patient across 4 probe positions."""
    patient_dir = os.path.join(INPUT_ROOT, f"TALOS{patient_id}")
    all_strides = []

    for pos in range(1, NUM_POSITIONS + 1):
        pos_path = os.path.join(patient_dir, f"pos_{pos}")
        frames = load_frames_from_position(pos_path)
        if frames is None:
            print(f"No frames found for TALOS{patient_id}, position {pos}")
            continue
        stride = concatenate_frames_horizontally(frames)
        all_strides.append(stride)

    if len(all_strides) == 0:
        print(f"Skipping TALOS{patient_id}: no valid positions found.")
        return None

    # Determine max width for zero-padding
    max_width = max(stride.shape[1] for stride in all_strides)
    padded_strides = [zero_pad_to_width(s, max_width) for s in all_strides]

    # Stack vertically
    vis_image = np.concatenate(padded_strides, axis=0)

    # Normalize VIS to [0,1]
    vis_image = vis_image / np.max(vis_image)

    # Save output
    save_vis_image(patient_id, vis_image)
    return vis_image


# 3rd Module
# ---------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------
def load_stride(patient_id, pos):
    """Load the stride image for a given patient and position."""
    stride_path = os.path.join(INPUT_ROOT, f"TALOS{patient_id}", f"pos_{pos}")
    frame_files = sorted([f for f in os.listdir(stride_path) if f.endswith(".png")])
    if not frame_files:
        return None
    # Load all frames and concatenate horizontally (same as VIS step)
    frames = [np.array(Image.open(os.path.join(stride_path, f)).convert("L")) / 255.0 for f in frame_files]
    stride = np.concatenate(frames, axis=1)
    return stride

def save_augmented_vis(patient_id, permutation_idx, vis_image):
    """Save an augmented VIS image corresponding to a stride permutation."""
    vis_uint8 = (vis_image * 255).astype(np.uint8)
    out_path = os.path.join(OUTPUT_DIR_0_SSDA, f"TALOS{patient_id}_perm_{permutation_idx:02d}.png")
    Image.fromarray(vis_uint8).save(out_path)
    return out_path

def create_permuted_vis(patient_id):
    """Generate all 4! permutations of the stride order for one patient."""
    # Load all 4 position strides
    strides = []
    for pos in range(1, NUM_POSITIONS + 1):
        s = load_stride(patient_id, pos)
        if s is not None:
            strides.append((pos, s))
        else:
            print(f"TALOS{patient_id}: missing stride for pos {pos}")
    
    if len(strides) < NUM_POSITIONS:
        print(f"TALOS{patient_id}: insufficient valid strides ({len(strides)}/4), skipping.")
        return []
    
    # Pad all to same width
    max_width = max(s[1].shape[1] for s in strides)
    padded = [(p, zero_pad_to_width(s, max_width)) for p, s in strides]

    # Generate permutations of stride indices
    perms = list(itertools.permutations(range(NUM_POSITIONS)))
    output_paths = []

    for idx, perm in enumerate(perms, start=1):
        ordered_strides = [padded[i][1] for i in perm]
        vis = np.concatenate(ordered_strides, axis=0)
        vis /= np.max(vis)
        out_path = save_augmented_vis(patient_id, idx, vis)
        output_paths.append(out_path)

    print(f"TALOS{patient_id}: generated {len(output_paths)} permuted VIS images.")
    return output_paths


# 4th Module
# ---------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------
def shuffle_frames(frames, seed):
    """Shuffle frame order within a stride using a reproducible prime seed."""
    random.seed(seed)
    indices = list(range(len(frames)))
    random.shuffle(indices)
    return [frames[i] for i in indices]

def save_augmented_vis_prime(patient_id, perm_idx, prime, vis_image):
    """Save augmented VIS as PNG with prime-number suffix inside its prime folder."""
    vis_uint8 = (vis_image * 255).astype(np.uint8)
    prime_dir = os.path.join(OUTPUT_DIR_MAIN, f"p{prime}")
    out_path = os.path.join(prime_dir, f"TALOS{patient_id}_perm_{perm_idx:02d}_p{prime}.png")
    Image.fromarray(vis_uint8).save(out_path)
    return out_path

def create_permuted_vis_with_frame_shuffling(patient_id):
    """Generate all 4! permutations for each prime seed and save outputs."""
    patient_dir = os.path.join(INPUT_ROOT, f"TALOS{patient_id}")
    all_paths = []

    for prime in PRIME_SEEDS:
        all_strides = []

        # Load and shuffle each stride internally using the current prime seed
        for pos in range(1, NUM_POSITIONS + 1):
            pos_path = os.path.join(patient_dir, f"pos_{pos}")
            frames = load_frames_from_position(pos_path)
            if frames is None:
                print(f"TALOS{patient_id}: missing frames for pos {pos} (prime={prime})")
                continue

            shuffled_frames = shuffle_frames(frames, prime)
            stride = concatenate_frames_horizontally(shuffled_frames)
            all_strides.append((pos, stride))

        if len(all_strides) < NUM_POSITIONS:
            print(f"TALOS{patient_id}: insufficient strides ({len(all_strides)}/4) for prime={prime}. Skipping.")
            continue

        # Equalize widths
        max_width = max(s[1].shape[1] for s in all_strides)
        padded = [(p, zero_pad_to_width(s, max_width)) for p, s in all_strides]

        # Generate all 4! stride permutations
        perms = list(itertools.permutations(range(NUM_POSITIONS)))

        for idx, perm in enumerate(perms, start=1):
            ordered_strides = [padded[i][1] for i in perm]
            vis = np.concatenate(ordered_strides, axis=0)
            vis /= np.max(vis)
            out_path = save_augmented_vis_prime(patient_id, idx, prime, vis)
            all_paths.append(out_path)

        print(f"TALOS{patient_id}: generated {len(perms)} VIS (prime={prime})")

    print(f"TALOS{patient_id}: total {len(all_paths)} VIS across {len(PRIME_SEEDS)} prime seeds.")
    return all_paths

In [4]:
# 1st Module
# ================================================================
# ZACH-ViT: Preprocessing and ROI Extraction
# ================================================================
# This script extracts pleural-line ROIs from TALOS ultrasound
# DICOM recordings, applies 50% vertical reduction, and saves
# standardized grayscale frames for VIS generation.
# ================================================================
# CONFIGURATION
TALOS_PATH = "../Data/TALOS" #INPUT should exist; it's the dir where all patients' initial DICOM videos are; a patient's dir is named as "/TALOSi", where i is a positive integer
OUTPUT_DIR_ROI = "../Data/Processed_ROI" #INPUT; it is created; for each patient it has a subdir named as the patients id, with four subdirs (each one contains the frames as images per transducer position: {"/pos_1", "/pos_2", "/pos_3", "/pos_4"})
PATIENT_RANGE = [100,122]  # Example patient ID numbers; INPUT
NUM_POSITIONS = 4  # Probe positions per patient; STATIC
print(f"Preprocessing is starting for {len(PATIENT_RANGE)} patients.\n")
# ----------------------------------------------------------------
# MAIN
print(f"1st Module is starting.")
ts1 = time.time()
for pid in PATIENT_RANGE:
    for pos in range(1, NUM_POSITIONS + 1):
        try:
            process_dicom(pid, pos)
        except Exception as e:
            print(f"⚠️ Skipping patient {pid}, position {pos}: {e}")
te1 = time.time()
print(f"Time for the 1st Module 'ROI Extraction': {round(te1-ts1,2)} secs.\n")
# ----------------------------------------------------------------            
# delay added
time.sleep(5)  # Delay for 5 seconds; ----------------------------

# 2nd Module
# ================================================================
# ZACH-ViT: VIS Construction
# ================================================================
# This notebook constructs Video Image Sequence (VIS) representations
# for each patient by concatenating position-wise frame sequences.
# Each probe position produces a horizontal stride (frames concatenated horizontally),
# and the 4 strides are stacked vertically to form the VIS image.
# ================================================================
# CONFIGURATION
#INPUT_ROOT = "../Data/Processed_ROI" #INPUT; it is created from the 1st module
INPUT_ROOT = OUTPUT_DIR_ROI
OUTPUT_DIR_VIS = "../Data/VIS" #INPUT; it is created; it contains the VIS images for each patient;
os.makedirs(OUTPUT_DIR_VIS, exist_ok=True)
# MAIN
print(f"2nd Module is starting.")
ts2 = time.time()
PATIENT_RANGE = PATIENT_RANGE #same input as in the 1st module
for pid in PATIENT_RANGE:
    construct_vis_for_patient(pid)
te2 = time.time()
print(f"Time for the 2nd Module 'VIS': {round(te2-ts2,2)} secs.\n")
# ----------------------------------------------------------------
# delay added
time.sleep(5)  # Delay for 5 seconds; ----------------------------

# 3rd Module
# ================================================================
# ZACH-ViT: ShuffleStrides Data Augmentation (0-SSDA)
# ================================================================
# For each patient, this notebook generates all 4! = 24 permutations
# of the 4 positional strides, creating augmented VIS images.
# Each permutation preserves internal anatomical coherence within a stride
# while altering the probe-order relationships.
# ================================================================
# CONFIGURATION
INPUT_ROOT = OUTPUT_DIR_ROI
OUTPUT_DIR_0_SSDA = "../Data/0_SSDA" #INPUT; it is created; it contains the 24 0_SSDA images for each patient
os.makedirs(OUTPUT_DIR_0_SSDA, exist_ok=True)
# MAIN
print(f"3rd Module is starting.")
ts3 = time.time()
for pid in PATIENT_RANGE:
    create_permuted_vis(pid)
te3 = time.time()
print(f"Time for the 3rd Module '0-SSDA': {round(te3-ts3,2)} secs.\n")
# ----------------------------------------------------------------
# delay added
time.sleep(5)  # Delay for 5 seconds; ----------------------------

# 4th Module
# ================================================================
# ZACH-ViT: Semi-supervised ShuffleStrides Data Augmentation (SSDA_p)
# ================================================================
# For each patient and each prime number in PRIME_SEEDS:
#  - Shuffle the temporal order of frames within each positional stride
#    using the prime as a random seed.
#  - Generate all 4! stride permutations (rows)
#  - Save each augmented VIS image inside a subfolder named after the prime
# ================================================================
# CONFIGURATION
INPUT_ROOT = OUTPUT_DIR_ROI
# Prime numbers used as reproducible seeds for intra-stride shuffling
PRIME_SEEDS = [2, 3] #INPUT; list of different prime numbers of random_seeds
# Create main output directory (e.g., "../Data/2_3_SSDA/")
OUTPUT_DIR_MAIN = os.path.join("../Data", "_".join(str(p) for p in PRIME_SEEDS) + "_SSDA")
os.makedirs(OUTPUT_DIR_MAIN, exist_ok=True)
# MAIN
print(f"4th Module is starting.")
ts4 = time.time()
# Create one subdir per prime number
for prime in PRIME_SEEDS:
    os.makedirs(os.path.join(OUTPUT_DIR_MAIN, f"p{prime}"), exist_ok=True)
for pid in PATIENT_RANGE:
    create_permuted_vis_with_frame_shuffling(pid)
te4 = time.time()
print(f"Time for the 4th Module 'SSDA_p': {round(te4-ts4,2)} secs.\n")
print(f"Total time for the preprocessing: {round(te4-ts1,2)} secs.\n")
# ----------------------------------------------------------------

Preprocessing is starting for 2 patients.

1st Module is starting.
Time for the 1st Module 'ROI Extraction': 14.84 secs.

2nd Module is starting.
Saved VIS for TALOS100: shape=(668, 13200), file=../Data/VIS/TALOS100_VIS.png
Saved VIS for TALOS122: shape=(668, 8690), file=../Data/VIS/TALOS122_VIS.png
Time for the 2nd Module 'VIS': 1.59 secs.

3rd Module is starting.
TALOS100: generated 24 permuted VIS images.
TALOS122: generated 24 permuted VIS images.
Time for the 3rd Module '0-SSDA': 29.09 secs.

4th Module is starting.
TALOS100: generated 24 VIS (prime=2)
TALOS100: generated 24 VIS (prime=3)
TALOS100: total 48 VIS across 2 prime seeds.
TALOS122: generated 24 VIS (prime=2)
TALOS122: generated 24 VIS (prime=3)
TALOS122: total 48 VIS across 2 prime seeds.
Time for the 4th Module 'SSDA_p': 56.05 secs.

Total time for the preprocessing: 116.59 secs.

