In [2]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import SimpleITK as sitk
from ipywidgets import interact, IntSlider, Dropdown, Checkbox, VBox, Output, Layout
from glob import glob

# --- Helper Functions ---
def load_mhd_image(mhd_path):
    """Read a MetaImage volume with SimpleITK and return it with its geometry.

    Returns
    -------
    tuple
        ``(array, spacing, origin)``. ``array`` is the voxel data indexed
        ``[z, y, x]``. ``spacing`` and ``origin`` are NumPy arrays reversed
        from SimpleITK's ``(x, y, z)`` order into ``(z, y, x)`` so that they
        line up with the array axes.
    """
    itk_img = sitk.ReadImage(mhd_path)
    voxels = sitk.GetArrayFromImage(itk_img)  # shape: [z, y, x]

    # SimpleITK reports geometry as (x, y, z); flip it to match the array axes.
    zyx_spacing = np.array(itk_img.GetSpacing())[::-1]
    zyx_origin = np.array(itk_img.GetOrigin())[::-1]

    return voxels, zyx_spacing, zyx_origin

def world_to_voxel(world, origin, spacing):
    """Convert a world-space coordinate into integer voxel indices.

    ``world``, ``origin`` and ``spacing`` must share one axis order. In this
    file that is ``(z, y, x)``, as returned by ``load_mhd_image``. Inputs
    broadcast, so ``world`` may also be an ``(N, 3)`` array.

    The offset from the origin is taken as an absolute value, kept from the
    original implementation. As a result, points on the "negative" side of the
    origin map to positive indices rather than negative ones.
    NOTE(review): presumably this mirrors a reference helper for this dataset;
    confirm before relying on it for flipped-direction scans.

    Returns
    -------
    numpy.ndarray
        Integer voxel indices, rounded to the nearest voxel. Callers compare
        ``round(vz)`` with the slice number, so plain truncation placed
        nodules up to one voxel (possibly one slice) too low.
    """
    stretched = np.abs(np.asarray(world, dtype=float) - np.asarray(origin, dtype=float))
    voxel = stretched / np.asarray(spacing, dtype=float)
    return np.rint(voxel).astype(int)

# --- Viewer Function ---
def build_scroll_mhd_with_dropdown(annotations_df, image_dir, label_dir, mhd_root="input_data"):
    """Build a widget for picking a CT series and scrolling through it.

    The dropdown offers a series UID only if it appears in
    ``annotations_df['seriesuid']`` and at least one PNG exists in
    ``<image_dir>/positives`` whose filename starts with ``<uid>_``.

    Parameters
    ----------
    annotations_df : pandas.DataFrame
        Annotation table with a ``seriesuid`` column. It also needs the
        columns read by ``scroll_mhd_with_overlays``.
    image_dir : str
        Directory whose ``positives`` sub-folder holds the exported PNG slices.
    label_dir : str
        Directory of YOLO label files, forwarded to ``scroll_mhd_with_overlays``.
    mhd_root : str, optional
        Root searched as ``<mhd_root>/*/<uid>.mhd``. Defaults to
        ``"input_data"``, the previously hard-coded location.

    Returns
    -------
    ipywidgets.VBox
        The dropdown stacked above an output area that hosts the viewer. If
        no series qualifies, the dropdown is empty and disabled, and the output
        area explains why (instead of raising ``IndexError``).
    """
    # Series UIDs that have at least one exported positive PNG (``<uid>_*.png``).
    positive_files = glob(os.path.join(image_dir, 'positives', '*.png'))
    uids_present = {os.path.basename(f).split('_')[0] for f in positive_files}

    # Only keep UIDs that are also in the annotations.
    all_uids = sorted(annotations_df['seriesuid'].unique())
    valid_uids = [uid for uid in all_uids if uid in uids_present]

    dropdown = Dropdown(
        options=valid_uids,
        description='Series UID:',
        value=valid_uids[0] if valid_uids else None,
        disabled=not valid_uids,
        layout=Layout(width='600px'),
    )
    out = Output()

    def update_viewer(change):
        with out:
            out.clear_output()
            uid = dropdown.value
            if uid is None:
                print("No annotated series with PNGs found in "
                      f"{os.path.join(image_dir, 'positives')}")
                return
            # Sorted so that the choice is deterministic if a UID exists in several subfolders.
            matches = sorted(glob(os.path.join(mhd_root, '*', f"{uid}.mhd")))
            if not matches:
                print(f"No .mhd file found for {uid} under {os.path.join(mhd_root, '*')}")
                return
            scroll_mhd_with_overlays(matches[0], annotations_df, label_dir)

    dropdown.observe(update_viewer, names='value')

    # Render the initially selected series (or the "nothing found" message).
    update_viewer(None)

    return VBox([dropdown, out])

# --- Main Viewer with Overlays ---
def scroll_mhd_with_overlays(
    mhd_path,
    annotations_df,
    yolo_label_dir,
):
    """Show an interactive z-slice viewer for one CT volume with overlays.

    Loads the ``.mhd`` volume and prints a table of the annotated nodules for
    its series UID. It then shows an ``ipywidgets.interact`` slider over the
    slices. Each slice can optionally show:

    * red circles for annotations whose rounded voxel z equals the slice, with
      radius ``diameter_mm / spacing_y / 2`` pixels;
    * blue rectangles read from ``<yolo_label_dir>/<uid>_<z>.txt``.

    The slider starts on the slice of the first annotation for this series,
    clipped to the volume. It starts on slice 0 when the series has no
    annotations.

    Parameters
    ----------
    mhd_path : str
        Path to the ``.mhd`` header. Its basename without extension is used as
        the series UID.
    annotations_df : pandas.DataFrame
        Table with ``seriesuid``, ``coordX``, ``coordY``, ``coordZ`` and
        ``diameter_mm`` columns. It presumably holds world coordinates in the
        image's frame (TODO confirm against the data source).
    yolo_label_dir : str
        Directory containing per-slice YOLO ``.txt`` label files.
    """
    volume, spacing, origin = load_mhd_image(mhd_path)
    uid = os.path.splitext(os.path.basename(mhd_path))[0]
    height, width = volume.shape[1:]

    nodules, nodule_info = _collect_nodules(annotations_df, uid, origin, spacing)

    # Use the *first collected* nodule. The DataFrame index label of the
    # filtered rows is usually not 0, so testing the label left most scans
    # opening on slice 0.
    first_nodule_z = 0  # fallback when the series has no annotations
    if nodule_info:
        first_nodule_z = int(np.clip(nodule_info[0][1], 0, volume.shape[0] - 1))

    _print_nodule_table(uid, nodule_info)

    def view(z, show_nodules, show_yolo_boxes):
        # Clip to [-1000, 400] (presumably Hounsfield units) and rescale to 8-bit grey.
        img = np.clip(volume[z], -1000, 400)
        img = ((img + 1000) / 1400 * 255).astype(np.uint8)

        _, ax = plt.subplots()
        ax.imshow(img, cmap='gray')
        ax.set_title(f"Slice {z}")
        ax.axis('off')

        if show_nodules:
            for (vz, vy, vx), diameter in nodules:
                if int(round(vz)) == z:
                    # Pixel radius uses the y spacing only; assumes square in-plane pixels.
                    radius_px = (diameter / spacing[1]) / 2
                    ax.add_patch(plt.Circle((vx, vy), radius_px, color='red', fill=False, lw=1, label='Annotations ROI'))

        if show_yolo_boxes:
            _draw_yolo_boxes(ax, os.path.join(yolo_label_dir, f"{uid}_{z}.txt"), width, height)

        # Deduplicate legend entries: every patch of a kind shares one label.
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        if by_label:
            ax.legend(by_label.values(), by_label.keys())

        plt.show()

    interact(
        view,
        z=IntSlider(min=0, max=volume.shape[0] - 1, step=1, value=first_nodule_z),
        show_nodules=Checkbox(value=True, description='Show Annotations ROIs'),
        show_yolo_boxes=Checkbox(value=True, description='Show preprocessed box')
    )


def _collect_nodules(annotations_df, uid, origin, spacing):
    """Return ``(nodules, nodule_info)`` for series ``uid``, in DataFrame row order.

    ``nodules`` holds ``((vz, vy, vx), diameter_mm)`` tuples used for drawing.
    ``nodule_info`` holds ``(row_label, z, y, x, diameter)`` tuples used for printing.
    """
    scan_df = annotations_df[annotations_df["seriesuid"] == uid]
    nodules = []
    nodule_info = []
    for row_label, row in scan_df.iterrows():
        world = [row["coordZ"], row["coordY"], row["coordX"]]
        diameter = row["diameter_mm"]
        vz, vy, vx = world_to_voxel(world, origin, spacing)
        nodules.append(((vz, vy, vx), diameter))
        nodule_info.append((row_label, int(round(vz)), round(vy), round(vx), round(diameter, 1)))
    return nodules, nodule_info


def _print_nodule_table(uid, nodule_info):
    """Print a fixed-width table of the nodules found for series ``uid``."""
    print(f"\nFound {len(nodule_info)} nodule(s) in scan: {uid}")
    print("Index | Slice (z) | y | x | Diameter (mm)")
    print("-" * 40)
    for idx, z, y, x, d in nodule_info:
        print(f"{idx:^5} | {z:^9} | {y:^3} | {x:^3} | {d:^12}")
    print("-----------------------------------------------------\n")


def _draw_yolo_boxes(ax, yolo_file, width, height):
    """Draw the boxes listed in ``yolo_file`` onto ``ax``; do nothing if the file is absent.

    Each non-blank line is read as ``cls x_center y_center w h``, with values
    given as fractions of the image size. They are scaled by ``width`` and
    ``height`` to get pixels. Any fields after the fifth are ignored.
    """
    if not os.path.exists(yolo_file):
        return
    with open(yolo_file, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # blank/trailing line: previously crashed the tuple unpack
            _cls, x_center, y_center, bw, bh = map(float, fields[:5])
            x1 = (x_center - bw / 2) * width
            y1 = (y_center - bh / 2) * height
            ax.add_patch(plt.Rectangle(
                (x1, y1), bw * width, bh * height,
                edgecolor='blue', facecolor='none', lw=1, label='Preprocessed box'
            ))

# Load the nodule annotations and render the series picker and slice viewer.
# `display` is not imported here; presumably the IPython/Jupyter built-in.
df = pd.read_csv("annotations.csv")
viewer_ui = build_scroll_mhd_with_dropdown(
    annotations_df=df,
    image_dir="yolov8_dataset/images",
    label_dir="yolov8_dataset/labels/positives",
)
display(viewer_ui)


VBox(children=(Dropdown(description='Series UID:', layout=Layout(width='600px'), options=('1.3.6.1.4.1.14519.5…