This code aims to decode XNAT segmentations. XNAT stores segmentations in a stacked form: all segmentations of all slices are saved in one single DICOM file, stacked on top of each other.

Step 1: Separate segmentation
Step 2: Original Image
Step 3: Map segment to the original image


## Setup

### Variables, libraries, and functions

In [None]:
#%pip install pydicom pandas matplotlib numpy
#%pip install GDCM pylibjpeg pylibjpeg-libjpeg pylibjpeg-openjpeg  # required packages for showing the dicom data using matplotlib
#%pip install ipywidgets # for the scrollbar over multiple stacked slices

# variables
patient_directory=r"Sample_Data\CT\1000029"

#libraires
import os
import xml.etree.ElementTree as ET
from typing import Optional, Dict, Any
import logging
import pydicom
# Configure logging
logging.basicConfig(level=logging.INFO)
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import pandas as pd 
import numpy as np
import ipywidgets as widgets
from ipywidgets import interact, IntSlider
# functions
def read_dicom(dicom_path):
    """
    Load a DICOM file, logging the outcome instead of raising.

    Parameters:
    - dicom_path: The full path to the DICOM file.

    Returns:
    - The DICOM dataset if read successfully, otherwise None.
    """
    loaded = None
    try:
        loaded = pydicom.dcmread(dicom_path)
        logging.info(f"Dicom data loaded from: {dicom_path}.")
    except Exception as exc:
        # Any failure is reported through the log and signalled to the caller as None.
        logging.error(f"Error in reading '{dicom_path}': {exc}")
        loaded = None
    return loaded

# functions to use in future
def get_nested_element(dataset, tags):
    """
    Navigate through the DICOM dataset using a list of tags and return the final element.

    For every tag except the last one, the element must hold a sequence and
    navigation continues inside the first item of that sequence. The element
    found for the last tag is returned as-is (a sequence element is not
    unwrapped), so callers can still iterate over its items.

    Parameters:
        dataset (pydicom.dataset.Dataset): The DICOM dataset.
        tags (list): A list of tuples representing the tags to navigate through.

    Returns:
        The final element in the DICOM dataset specified by the tags
        (the dataset itself when `tags` is empty).

    Raises:
        KeyError: If a tag is not present at its nesting level.
        ValueError: If an intermediate tag does not hold a sequence, or the
            sequence is empty.
    """
    tags = list(tags)
    current_level = dataset
    last_position = len(tags) - 1

    for position, raw_tag in enumerate(tags):
        tag = pydicom.tag.Tag(raw_tag)
        if tag not in current_level:
            raise KeyError(f"Tag {tag} not found in the dataset.")
        element = current_level[tag]

        if position == last_position:
            return element

        # Indexing a dataset by tag yields a data element; the sequence itself
        # lives in its `.value`. The previous implementation tested the element
        # (never a Sequence), so it could not descend into nested sequences.
        items = getattr(element, 'value', None)
        if not isinstance(items, pydicom.sequence.Sequence):
            raise ValueError(f"Tag {tag} is not a sequence; cannot navigate into it.")
        if len(items) == 0:
            raise ValueError(f"Sequence at tag {tag} is empty.")

        # Assume we want the first item of an intermediate sequence.
        current_level = items[0]

    return current_level

### FUNC: slice CT Viz 

In [None]:
def show_slice_image(slice_path):
    """
    Load a single DICOM slice from disk and render it in grayscale.

    Parameters:
    - slice_path: The full path to the DICOM file.

    Returns:
    - The DICOM dataset if read successfully, otherwise None.
    """
    normalized = os.path.normpath(slice_path)

    # Guard: nothing to show when the file is missing.
    if not os.path.exists(normalized):
        logging.error(f"File does not exist: {normalized}")
        return None

    dataset = read_dicom(normalized)
    if not dataset:
        logging.error("Failed to read DICOM data.")
        return None

    # Display the DICOM image
    plt.imshow(dataset.pixel_array, cmap=plt.cm.gray)
    plt.title(f"Slice: {os.path.basename(normalized)}")
    plt.show()
    return dataset

# Example usage
# Example usage.
# NOTE: this is a machine-specific absolute Windows path; point it at a DICOM
# file that exists on your own system before running the cell.
slice_path=r"C:\Users\LEGION\Documents\GIT\XNAT-OHIF_read-overlay-visualize-segmentation-file\Sample_Data\CT\1000029\SCANS\5\DICOM\2.1334547622669903440304287630751540644425-5-20-1l4193e.dcm"
# Displays the slice and keeps the loaded dataset (None if loading failed).
dicom_slice_data = show_slice_image(slice_path)

### FUNC: series CT Viz 

In [None]:
def read_dicom_pixel_data_with_sop_instance_uid(directory_path):
    """
    Reads pixel data and SOP Instance UID from all DICOM files in the given directory.

    Files are matched by a case-insensitive '.dcm' extension and visited in
    sorted filename order, so slices sharing an instance number keep a
    deterministic order. Files that cannot be read are logged and skipped.

    Parameters:
    - directory_path: Path to the directory containing DICOM files.

    Returns:
    - List of dictionaries with 'pixel_data', 'sop_instance_uid', and 'instance_number',
      sorted by 'instance_number' (0 when the file has no usable InstanceNumber).
    """
    dicom_files = [
        os.path.join(directory_path, file_name)
        for file_name in sorted(os.listdir(directory_path))
        if file_name.lower().endswith('.dcm')
    ]
    dicom_data_list = []

    for dicom_file in dicom_files:
        try:
            dicom_data = pydicom.dcmread(dicom_file)
            sop_instance_uid = dicom_data.SOPInstanceUID
            pixel_data = dicom_data.pixel_array

            # The attribute can exist but be empty; an empty value would make
            # the sort below fail when compared against integers, so treat it
            # like a missing InstanceNumber.
            instance_number = getattr(dicom_data, 'InstanceNumber', None)
            if instance_number is None or instance_number == '':
                instance_number = 0

            dicom_data_list.append({
                'pixel_data': pixel_data,
                'sop_instance_uid': sop_instance_uid,
                'instance_number': instance_number
            })
        except Exception as e:
            logging.error(f"Error reading DICOM file '{dicom_file}': {e}")

    # Sort by InstanceNumber (stable: ties keep filename order)
    dicom_data_list.sort(key=lambda x: x['instance_number'])

    return dicom_data_list

def show_dicom_stack(directory_path):
    """
    Show every DICOM slice of a directory in a viewer driven by an integer slider.

    Parameters:
    - directory_path: Path to the directory containing DICOM files.
    """
    slices = read_dicom_pixel_data_with_sop_instance_uid(directory_path)
    if not slices:
        logging.error("No DICOM files found or failed to read any DICOM files.")
        return

    total = len(slices)

    def plot_slice(index):
        # Render the slice picked by the slider together with its UID.
        current = slices[index]
        plt.imshow(current['pixel_data'], cmap=plt.cm.gray)
        plt.title(f"Slice {index + 1}/{total}\nSOP Instance UID: {current['sop_instance_uid']}")
        plt.axis('off')
        plt.show()

    slider = IntSlider(min=0, max=total - 1, step=1, value=0)
    interact(plot_slice, index=slider)

# Example usage
# Example usage.
# NOTE: machine-specific absolute Windows path; replace it with the DICOM
# series folder on your own system before running the cell.
directory_path = r"C:\Users\LEGION\Documents\GIT\XNAT-OHIF_read-overlay-visualize-segmentation-file\Sample_Data\CT\1000029\SCANS\5\DICOM"
show_dicom_stack(directory_path)

### FUNC: on-the-fly CT and one segmentation Viz
This merges and renders the overlay on the fly, every time the scrollbar changes. If you have a fast processor this is the better option, as there is no noticeable delay when you move the scrollbar. Otherwise, use the next function, which generates all images before visualization (longer start-up time, because every image is rendered in advance).

In [None]:
def merge_dictionaries(original_dic, segment_dic):
    """
    Pair every original slice with the segmentation that shares its SOP Instance UID.

    Parameters:
    - original_dic: List of dictionaries with 'pixel_data', 'sop_instance_uid', and 'instance_number'.
    - segment_dic: List of dictionaries with 'pixel_data', 'sop_instance_uid', and 'instance_number'.

    Returns:
    - List with one dictionary per original slice holding 'sop_instance_uid',
      'original_pixel' and 'segmentation_pixel' (None when no segmentation matches).
    """
    # Index the segmentation pixels by UID; a later duplicate UID overrides an earlier one.
    mask_by_uid = {}
    for seg_entry in segment_dic:
        uid = seg_entry['sop_instance_uid']
        mask_by_uid[uid] = seg_entry['pixel_data']

    return [
        {
            'sop_instance_uid': entry['sop_instance_uid'],
            'original_pixel': entry['pixel_data'],
            'segmentation_pixel': mask_by_uid.get(entry['sop_instance_uid']),
        }
        for entry in original_dic
    ]

def prepare_images(merged_data, segment_overlay_transparency, segment_overlay_color):
    """
    Prepares images with overlays from merged data.

    Each slice is drawn on its own figure, the canvas is rendered, and the
    rendered RGB pixels are stored so they can be shown later without re-plotting.

    Parameters:
    - merged_data: List of dictionaries with 'original_pixel', 'segmentation_pixel', and 'sop_instance_uid'.
    - segment_overlay_transparency: Transparency level for the segmentation overlay.
    - segment_overlay_color: Hex color code for the segmentation overlay.

    Returns:
    - List of prepared images (uint8 RGB arrays, one per slice).
    """
    prepared_images = []
    overlay_rgba = to_rgba(segment_overlay_color, alpha=segment_overlay_transparency)

    for data in merged_data:
        original_pixel = data['original_pixel']
        segmentation_pixel = data['segmentation_pixel']

        fig, ax = plt.subplots()
        try:
            ax.imshow(original_pixel, cmap=plt.cm.gray)

            if segmentation_pixel is not None:
                # Create an overlay with the specified color and transparency
                overlay = np.zeros((*segmentation_pixel.shape, 4))
                overlay[..., :3] = overlay_rgba[:3]
                overlay[..., 3] = (segmentation_pixel > 0) * segment_overlay_transparency

                # NOTE(review): the transparency is applied both in the alpha
                # channel above and through `alpha=` here — confirm the
                # resulting (weaker) opacity is intended.
                ax.imshow(overlay, cmap=None, alpha=segment_overlay_transparency)

            ax.set_title(f"SOP Instance UID: {data['sop_instance_uid']}")
            ax.axis('off')
            fig.canvas.draw()

            # Convert the canvas to an image. `tostring_rgb()` is deprecated and
            # removed in recent Matplotlib releases; read the RGBA buffer and
            # drop the alpha channel instead. Copy so the array outlives the figure.
            image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
            prepared_images.append(image)
        finally:
            # Always release the figure, even if rendering failed.
            plt.close(fig)

    return prepared_images

def show_merged_dicom_stack(original_dic, segment_dic, segment_overlay_transparency=0.4, segment_overlay_color='#FFA500',label_for_overlay_image=""):
    """
    Displays a stack of merged DICOM images with overlay in an interactive scrollable format.

    The overlay is composed each time the slider moves (nothing is pre-rendered).

    Parameters:
    - original_dic: List of dictionaries with 'pixel_data', 'sop_instance_uid', and 'instance_number'.
    - segment_dic: List of dictionaries with 'pixel_data', 'sop_instance_uid', and 'instance_number'.
    - segment_overlay_transparency: Transparency level for the segmentation overlay.
    - segment_overlay_color: Hex color code for the segmentation overlay.
    - label_for_overlay_image: Text appended as the last line of each slice title.
    """
    merged_data = merge_dictionaries(original_dic, segment_dic)
    if not merged_data:
        # Same guard as show_dicom_stack: a slider over an empty range
        # (min=0, max=-1) is meaningless.
        logging.error("No slices to display: the original image list is empty.")
        return

    overlay_rgba = to_rgba(segment_overlay_color, alpha=segment_overlay_transparency)
    total = len(merged_data)

    def plot_slice(index):
        current = merged_data[index]
        segmentation_pixel = current['segmentation_pixel']

        plt.imshow(current['original_pixel'], cmap=plt.cm.gray)

        if segmentation_pixel is not None:
            # Create an overlay with the specified color and transparency
            overlay = np.zeros((*segmentation_pixel.shape, 4))
            overlay[..., :3] = overlay_rgba[:3]
            overlay[..., 3] = (segmentation_pixel > 0) * segment_overlay_transparency

            plt.imshow(overlay, cmap=None, alpha=segment_overlay_transparency)

        plt.title(f"Slice {index + 1}/{total}\nSOP Instance UID: {current['sop_instance_uid']}\n{label_for_overlay_image}")
        plt.axis('off')
        plt.show()

    interact(plot_slice, index=IntSlider(min=0, max=total - 1, step=1, value=0))


# Example usage (machine-specific absolute Windows paths; adjust before running).
# NOTE: both paths point to the same folder here, so the CT series is overlaid
# on itself; set segment_directory_path to a real segmentation folder.
original_directory_path = r"C:\Users\LEGION\Documents\GIT\XNAT-OHIF_read-overlay-visualize-segmentation-file\Sample_Data\CT\1000029\SCANS\5\DICOM"
segment_directory_path = r"C:\Users\LEGION\Documents\GIT\XNAT-OHIF_read-overlay-visualize-segmentation-file\Sample_Data\CT\1000029\SCANS\5\DICOM"

# Load both series as lists of {'pixel_data', 'sop_instance_uid', 'instance_number'}.
original_dic = read_dicom_pixel_data_with_sop_instance_uid(original_directory_path)
segment_dic = read_dicom_pixel_data_with_sop_instance_uid(segment_directory_path)

show_merged_dicom_stack(original_dic, segment_dic, segment_overlay_transparency=0.4, segment_overlay_color='#FFA500', label_for_overlay_image="desired label for seg")


In [None]:
def merge_dictionaries(original_dic, segment_dic):
    """
    Join the original slices with their segmentation by SOP Instance UID.

    Parameters:
    - original_dic: List of dictionaries with 'pixel_data', 'sop_instance_uid', and 'instance_number'.
    - segment_dic: List of dictionaries with 'pixel_data', 'sop_instance_uid', and 'instance_number'.

    Returns:
    - List of dictionaries ('sop_instance_uid', 'original_pixel', 'segmentation_pixel'),
      one per original slice; 'segmentation_pixel' is None when nothing matches.
    """
    # UID -> segmentation pixels (the last entry wins for a repeated UID).
    uid_to_mask = dict(
        (segment_item['sop_instance_uid'], segment_item['pixel_data'])
        for segment_item in segment_dic
    )

    combined = []
    for source in original_dic:
        uid = source['sop_instance_uid']
        matched_mask = uid_to_mask.get(uid)
        combined.append(
            dict(
                sop_instance_uid=uid,
                original_pixel=source['pixel_data'],
                segmentation_pixel=matched_mask,
            )
        )
    return combined

def prepare_images(merged_data, segment_overlay_transparency, segment_overlay_color, figsize=(10, 10), label_for_overlay_image=""):
    """
    Prepares images with overlays from merged data.

    Each slice is drawn on its own figure, the canvas is rendered, and the
    rendered RGB pixels are stored so they can be shown later without re-plotting.

    Parameters:
    - merged_data: List of dictionaries with 'original_pixel', 'segmentation_pixel', and 'sop_instance_uid'.
    - segment_overlay_transparency: Transparency level for the segmentation overlay.
    - segment_overlay_color: Hex color code for the segmentation overlay.
    - figsize: Tuple representing the figure size (width, height).
    - label_for_overlay_image: Text added as the second line of each figure title.

    Returns:
    - List of prepared images (uint8 RGB arrays, one per slice).
    """
    prepared_images = []
    overlay_rgba = to_rgba(segment_overlay_color, alpha=segment_overlay_transparency)

    for data in merged_data:
        original_pixel = data['original_pixel']
        segmentation_pixel = data['segmentation_pixel']

        fig, ax = plt.subplots(figsize=figsize)
        try:
            ax.imshow(original_pixel, cmap=plt.cm.gray)

            if segmentation_pixel is not None:
                # Create an overlay with the specified color and transparency
                overlay = np.zeros((*segmentation_pixel.shape, 4))
                overlay[..., :3] = overlay_rgba[:3]
                overlay[..., 3] = (segmentation_pixel > 0) * segment_overlay_transparency

                # NOTE(review): the transparency is applied both in the alpha
                # channel above and through `alpha=` here — confirm the
                # resulting (weaker) opacity is intended.
                ax.imshow(overlay, cmap=None, alpha=segment_overlay_transparency)

            ax.set_title(f"SOP Instance UID: {data['sop_instance_uid']}\n{label_for_overlay_image}")
            ax.axis('off')
            fig.canvas.draw()

            # Convert the canvas to an image. `tostring_rgb()` is deprecated and
            # removed in recent Matplotlib releases; read the RGBA buffer and
            # drop the alpha channel instead. Copy so the array outlives the figure.
            image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
            prepared_images.append(image)
        finally:
            # Always release the figure, even if rendering failed.
            plt.close(fig)

    return prepared_images

def show_merged_dicom_stack_fast(original_dic, segment_dic, segment_overlay_transparency=0.4, segment_overlay_color='#FFA500', figsize=(10, 10),label_for_overlay_image=""):
    """
    Displays a stack of merged DICOM images with overlay in an interactive scrollable format.

    All overlay images are rendered up front with prepare_images, so moving
    the slider only has to display an already-rendered image.

    Parameters:
    - original_dic: List of dictionaries with 'pixel_data', 'sop_instance_uid', and 'instance_number'.
    - segment_dic: List of dictionaries with 'pixel_data', 'sop_instance_uid', and 'instance_number'.
    - segment_overlay_transparency: Transparency level for the segmentation overlay.
    - segment_overlay_color: Hex color code for the segmentation overlay.
    - figsize: Tuple representing the figure size (width, height) used for pre-rendering.
    - label_for_overlay_image: Text shown as the second line of the viewer title.
    """
    merged_data = merge_dictionaries(original_dic, segment_dic)
    if not merged_data:
        # Same guard as show_dicom_stack: a slider over an empty range
        # (min=0, max=-1) is meaningless.
        logging.error("No slices to display: the original image list is empty.")
        return

    prepared_images = prepare_images(merged_data, segment_overlay_transparency, segment_overlay_color, figsize=figsize)
    total = len(prepared_images)

    def plot_slice(index):
        plt.title(f"Slice {index + 1}/{total}\n{label_for_overlay_image}")
        plt.imshow(prepared_images[index])

        plt.axis('off')
        plt.show()

    interact(plot_slice, index=IntSlider(min=0, max=total - 1, step=1, value=0))


# Example usage (machine-specific absolute Windows paths; adjust before running).
# NOTE: both paths point to the same folder here, so the CT series is overlaid
# on itself; set segment_directory_path to a real segmentation folder.
original_directory_path = r"C:\Users\LEGION\Documents\GIT\XNAT-OHIF_read-overlay-visualize-segmentation-file\Sample_Data\CT\1000029\SCANS\5\DICOM"
segment_directory_path = r"C:\Users\LEGION\Documents\GIT\XNAT-OHIF_read-overlay-visualize-segmentation-file\Sample_Data\CT\1000029\SCANS\5\DICOM"

# Load both series as lists of {'pixel_data', 'sop_instance_uid', 'instance_number'}.
original_dic = read_dicom_pixel_data_with_sop_instance_uid(original_directory_path)
segment_dic = read_dicom_pixel_data_with_sop_instance_uid(segment_directory_path)

# All overlay images are rendered first (at figsize 20x20), then shown with a slider.
show_merged_dicom_stack_fast(original_dic, segment_dic, segment_overlay_transparency=0.4, segment_overlay_color='#FFA500', figsize=(20,20), label_for_overlay_image="desired label for seg")


### FUNC: Multiple Segmentations 

In [None]:
def MULTI_merge_dictionaries(original_dic, segments_list):
    """
    Attach, to every original slice, the matching slice of each supplied segmentation.

    Parameters:
    - original_dic: List of dictionaries with 'pixel_data', 'sop_instance_uid', and 'instance_number'.
    - segments_list: List of tuples where each tuple contains a segmentation dictionary and its metadata
      (segment_dic, segment_overlay_color, label_for_overlay_image).

    Returns:
    - List with one dictionary per original slice ('sop_instance_uid', 'original_pixel',
      'segmentations'); 'segmentations' has one entry per tuple in segments_list, whose
      'segmentation_pixel' is None when that segmentation has no slice with the same UID.
    """
    # Build one UID -> pixels index per segmentation, keeping its color and label alongside.
    indexed_segments = []
    for seg_slices, color, label in segments_list:
        pixels_by_uid = {}
        for seg_slice in seg_slices:
            uid_key = seg_slice['sop_instance_uid']
            pixels_by_uid[uid_key] = seg_slice['pixel_data']
        indexed_segments.append((pixels_by_uid, color, label))

    combined = []
    for source in original_dic:
        uid = source['sop_instance_uid']
        overlays = [
            {
                'segmentation_pixel': pixels_by_uid.get(uid),
                'segment_overlay_color': color,
                'label_for_overlay_image': label,
            }
            for pixels_by_uid, color, label in indexed_segments
        ]
        combined.append({
            'sop_instance_uid': uid,
            'original_pixel': source['pixel_data'],
            'segmentations': overlays,
        })

    return combined

def MULTI_prepare_images(merged_data, segment_overlay_transparency, figsize=(10, 10)):
    """
    Prepares images with overlays from merged data.

    Each slice is drawn with every available segmentation overlaid in that
    segmentation's own color, the canvas is rendered, and the rendered RGB
    pixels are stored so they can be shown later without re-plotting.

    Parameters:
    - merged_data: List of dictionaries with 'original_pixel', 'segmentations', and 'sop_instance_uid'.
    - segment_overlay_transparency: Transparency level for the segmentation overlay.
    - figsize: Tuple representing the figure size (width, height).

    Returns:
    - List of prepared images (uint8 RGB arrays, one per slice).
    """
    prepared_images = []

    for data in merged_data:
        original_pixel = data['original_pixel']
        segmentations = data['segmentations']

        fig, ax = plt.subplots(figsize=figsize)
        try:
            ax.imshow(original_pixel, cmap=plt.cm.gray)

            for segmentation in segmentations:
                segmentation_pixel = segmentation['segmentation_pixel']
                if segmentation_pixel is None:
                    # This segmentation has nothing on the current slice.
                    continue

                overlay_rgba = to_rgba(segmentation['segment_overlay_color'], alpha=segment_overlay_transparency)
                overlay = np.zeros((*segmentation_pixel.shape, 4))
                overlay[..., :3] = overlay_rgba[:3]
                overlay[..., 3] = (segmentation_pixel > 0) * segment_overlay_transparency

                # NOTE(review): the transparency is applied both in the alpha
                # channel above and through `alpha=` here — confirm the
                # resulting (weaker) opacity is intended.
                ax.imshow(overlay, cmap=None, alpha=segment_overlay_transparency)

            labels = ', '.join([seg['label_for_overlay_image'] for seg in segmentations])
            ax.set_title(f"SOP Instance UID: {data['sop_instance_uid']}\n{labels}")
            ax.axis('off')
            fig.canvas.draw()

            # Convert the canvas to an image. `tostring_rgb()` is deprecated and
            # removed in recent Matplotlib releases; read the RGBA buffer and
            # drop the alpha channel instead. Copy so the array outlives the figure.
            image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
            prepared_images.append(image)
        finally:
            # Always release the figure, even if rendering failed.
            plt.close(fig)

    return prepared_images

def MULTI_show_merged_dicom_stack_fast(prepared_images, merged_data=None):
    """
    Displays a stack of pre-rendered overlay images in an interactive scrollable format.

    Parameters:
    - prepared_images: List of already rendered images (e.g. the output of MULTI_prepare_images).
    - merged_data: Optional list of merged slice dictionaries (as returned by
      MULTI_merge_dictionaries) whose 'segmentations' labels are listed in the title.
      When omitted, a notebook-level variable named `merged_data` is used if it exists
      (this is what the function previously relied on implicitly); otherwise the title
      only shows the slice counter.
    """
    if not prepared_images:
        logging.error("No prepared images to display.")
        return

    if merged_data is None:
        # Backward compatibility: older calls pass only `prepared_images` and
        # expect the labels to come from the global `merged_data`.
        merged_data = globals().get('merged_data')

    total = len(prepared_images)

    def plot_slice(index):
        title = f"Slice {index + 1}/{total}"
        if merged_data is not None:
            # Join outside the f-string: a backslash inside an f-string
            # expression is a SyntaxError before Python 3.12.
            labels = "\n".join([seg['label_for_overlay_image'] for seg in merged_data[index]['segmentations']])
            title = title + "\n" + labels
        plt.title(title)
        plt.imshow(prepared_images[index])

        plt.axis('off')
        plt.show()

    interact(plot_slice, index=IntSlider(min=0, max=total - 1, step=1, value=0))

#assuming we have segments_list and segment_overlay_transparency and figsize
# Assumes `segments_list` (tuples of segment_dic, overlay color, label),
# `segment_overlay_transparency` and `figsize` were defined in an earlier cell;
# none of them is assigned in the visible part of this notebook, so define them
# first or this cell fails with a NameError.
merged_data = MULTI_merge_dictionaries(original_dic, segments_list)
prepared_images = MULTI_prepare_images(merged_data, segment_overlay_transparency, figsize=figsize)
MULTI_show_merged_dicom_stack_fast(prepared_images)


# Step 0:   Select Segmentation

## Find all segmentation in folder 

In [None]:
def assessors_path(patient_directory: str) -> Optional[str]:
    """
    Locate the ASSESSORS folder directly inside a patient directory.

    Parameters:
    - patient_directory (str): The directory path of the patient.

    Returns:
    - Optional[str]: The path to the ASSESSORS directory if it exists, otherwise None.
    """
    candidate = os.path.join(patient_directory, 'ASSESSORS')

    if os.path.exists(candidate):
        logging.info(f"'ASSESSORS' folder exists: {candidate}")
        return candidate

    # Without this folder there are no segmentations to look for.
    logging.info("No segmentation exists")
    return None


def get_segmentations_from_assessors_path(assessors_path: str) -> Dict[str, Dict[str, Any]]:
    """
    This function searches the ASSESSORS folder and creates a dictionary of all segmentations.

    Every sub-folder of ASSESSORS that contains a 'SEG' directory is scanned for
    XML catalog files; the first catalog entry of each XML file describes one
    segmentation DICOM file.

    Parameters:
    - assessors_path (str): The path to the ASSESSORS directory. None, an empty string,
      or a path that is not an existing directory yields an empty result.

    Returns:
    - Dict[str, Dict[str, Any]]: A dictionary keyed by "<createdBy>>><createdTime>".
      Each key maps to another dictionary containing 'created_by', 'created_time',
      'dicom_name', and 'dicom_fullpath'.
    """
    segmentations: Dict[str, Dict[str, Any]] = {}

    # assessors_path() returns None when there is no ASSESSORS folder; without
    # this guard os.listdir(None) would silently list the current directory.
    if not assessors_path or not os.path.isdir(assessors_path):
        logging.warning(f"ASSESSORS directory not available: {assessors_path}")
        return segmentations

    catalog_namespaces = {'cat': 'http://nrg.wustl.edu/catalog'}

    for seg in os.listdir(assessors_path):
        seg_dir = os.path.join(assessors_path, seg, 'SEG')
        if not os.path.isdir(seg_dir):
            continue

        # Find XML files
        xml_files = [f for f in os.listdir(seg_dir) if f.endswith('.xml')]

        for xml_file in xml_files:
            xml_path = os.path.join(seg_dir, xml_file)
            try:
                # Parse the XML file
                root = ET.parse(xml_path).getroot()
            except ET.ParseError as e:
                logging.error(f"Error parsing XML file {xml_file}: {e}")
                continue
            except OSError as e:
                logging.error(f"Unexpected error: {e}")
                continue

            # Extract createdBy and createdTime
            entry = root.find('.//cat:entry', namespaces=catalog_namespaces)
            if entry is None:
                continue

            dicom_name = entry.get('name')
            if not dicom_name:
                # Without a file name there is no DICOM path to build.
                logging.warning(f"Catalog entry without a 'name' attribute in {xml_path}; skipped.")
                continue

            created_by = entry.get('createdBy')
            created_time = entry.get('createdTime')

            # Save to dictionary
            key = f"{created_by}>>{created_time}"
            segmentations[key] = {
                'created_by': created_by,
                'created_time': created_time,
                'dicom_name': dicom_name,
                'dicom_fullpath': os.path.join(seg_dir, dicom_name)
            }
            logging.info(segmentations[key])

    return segmentations

# Locate the ASSESSORS folder of the patient (None when it does not exist),
# then collect every segmentation described by the XML catalogs inside it.
ass_path = assessors_path(patient_directory)
all_segmentation_dic = get_segmentations_from_assessors_path(ass_path)


## Select the desired segmentation

In [None]:
def select_segmentation_from_valid(all_segmentation_dic: Dict[str, Dict[str, Any]], default_selected_segmentation: Any = None) -> Optional[str]:
    """
    This function allows the user to select a segmentation from the dictionary of segmentations.

    The selection is interpreted first as a zero-based index into the dictionary
    keys (an int or a numeric string); if it is not a valid index it is used as
    a key name.

    Parameters:
    - all_segmentation_dic (Dict[str, Dict[str, Any]]): Dictionary containing segmentations.
    - default_selected_segmentation (int or str, optional): Index or key of the segmentation
      to select. When None (or an empty string) the options are printed and the user is
      prompted. Defaults to None.

    Returns:
    - Optional[str]: The full path to the selected DICOM file, or None if not found.
    """
    valid_options_to_select = list(all_segmentation_dic)

    if not valid_options_to_select:
        logging.warning("No valid segmentations available.")
        return None

    # Compare against None/"" explicitly: a plain truthiness test would treat
    # the perfectly valid index 0 as "no default" and prompt the user.
    if default_selected_segmentation is None or default_selected_segmentation == "":
        show_options_to_select = "\n".join(
            f"{i}: {option}" for i, option in enumerate(valid_options_to_select)
        )
        print(f"Available segmentations:\n{show_options_to_select}")
        selected_segmentation = input("Enter the index or name of your desired segmentation to visualize: ")
    else:
        selected_segmentation = default_selected_segmentation

    try:
        selected_segmentation_index = int(selected_segmentation)
        if not 0 <= selected_segmentation_index < len(valid_options_to_select):
            raise IndexError
        selected_segmentation_name = valid_options_to_select[selected_segmentation_index]
    except (ValueError, IndexError):
        # Not a usable index: fall back to treating the input as a key name.
        selected_segmentation_name = selected_segmentation

    segmentation_info = all_segmentation_dic.get(selected_segmentation_name)
    if not segmentation_info:
        logging.error(f"Selected segmentation '{selected_segmentation_name}' not found in the dictionary.")
        return None

    dicom_fullpath = segmentation_info.get('dicom_fullpath')
    logging.info(f"Selected SEG: {selected_segmentation_name}")
    logging.info(f"Path to selected SEG: {dicom_fullpath}")
    return dicom_fullpath

# Example usage
# Example usage
# The default is given as an index (12) into the list of available
# segmentations; a value that is not a valid index is looked up as a key name.
selected_SEGdicom_fullpath = select_segmentation_from_valid(all_segmentation_dic,default_selected_segmentation=12)
# read_dicom returns None when the file cannot be read.
selected_SEGdicom_data = read_dicom(selected_SEGdicom_fullpath)

In [None]:
def read_selected_segmentation_for_widgets(selected_segmentation_name, all_segmentation_dic):
    """
    Look up the DICOM path of a segmentation by its key.

    Parameters:
    - selected_segmentation_name: Key of the segmentation in all_segmentation_dic.
    - all_segmentation_dic: Dictionary of segmentations; each value holds 'dicom_fullpath'.

    Returns:
    - The 'dicom_fullpath' of the selected segmentation, or None when the key is
      missing or maps to an empty entry.
    """
    info = all_segmentation_dic.get(selected_segmentation_name)

    if not info:
        logging.error(f"Selected segmentation '{selected_segmentation_name}' not found in the dictionary.")
        return None

    logging.info(f"Selected SEG: {selected_segmentation_name}")
    logging.info(f"Path to selected SEG: {info.get('dicom_fullpath')}")
    return info.get('dicom_fullpath')

# Step 1:   Segmentation file

## Extract data of segmentation from dicom

#### Reference Series UID

In [None]:
def get_referenced_series_UID(segmentation_dicom) -> Optional[str]:
    """
    Extract the Series Instance UID referenced by a DICOM segmentation dataset.

    The items of tag (0008,1115) (Referenced Series Sequence) are scanned and the
    first Series Instance UID found is returned.

    Parameters:
        segmentation_dicom (pydicom.dataset.Dataset): The DICOM dataset containing segmentation data.

    Returns:
        Optional[str]: The referenced Series Instance UID, or None when the tag is
        missing or no item carries a Series Instance UID.
    """
    try:
        referenced_series_sequence = get_nested_element(segmentation_dicom, [(0x0008, 0x1115)])
        for series_instance in referenced_series_sequence:
            if 'SeriesInstanceUID' in series_instance:
                Ref_series_UID = series_instance.SeriesInstanceUID
                logging.info(f"Referenced Series Instance UID: {Ref_series_UID}")
                return Ref_series_UID
    except (KeyError, ValueError) as e:
        logging.error(f"Error extracting Referenced Series Instance UID: {e}")
        return None

    # Make the "nothing found" outcome visible instead of returning None silently.
    logging.warning("No Series Instance UID found in the Referenced Series Sequence.")
    return None

# Series Instance UID of the image series the selected segmentation refers to.
Ref_series_UID = get_referenced_series_UID(selected_SEGdicom_data)

#### Segmentation number-name dict map

In [None]:

def create_segment_number_to_label_map(segmentation_dicom) -> dict:
    """
    Build a lookup from Segment Number to Segment Label for a DICOM segmentation object.

    Parameters:
        segmentation_dicom (pydicom.dataset.Dataset): The DICOM dataset containing segmentation data.

    Returns:
        dict: A dictionary mapping Segment Number (int) to Segment Label (str);
        empty when the segment sequence cannot be read.
    """
    number_tag = (0x0062, 0x0004)
    label_tag = (0x0062, 0x0005)
    labels_by_number = {}

    try:
        segments = get_nested_element(segmentation_dicom, [(0x0062, 0x0002)])
        for segment in segments:
            # Only items carrying both the number and the label are usable.
            if number_tag not in segment or label_tag not in segment:
                continue
            number = segment[number_tag].value
            label = segment[label_tag].value
            labels_by_number[number] = label
        logging.info(f"Segment map: {labels_by_number}")
    except (KeyError, ValueError) as e:
        logging.error(f"Error creating segment map: {e}")

    return labels_by_number

# Segment Number -> Segment Label lookup for the selected segmentation.
segment_map = create_segment_number_to_label_map(selected_SEGdicom_data)

#### slice data: RefSOPUID + Segment number (and label) + pixel data

In [None]:
def get_segmentation_data_including_RefSOPUID_refSegNum_pixelData(segmentation_dicom, segment_map):
    """
    Reads a segmentation DICOM file and creates a dictionary for each slice.

    Every frame of the segmentation is paired with the SOP Instance UID of the
    source image it was derived from and grouped under the label of the segment
    it belongs to. Frames that cannot be processed are logged and skipped.

    Parameters:
    - segmentation_dicom: The DICOM dataset for the segmentation.
    - segment_map: Dictionary to map segment numbers to segment labels.

    Returns:
    - Dictionary with segment labels as keys and lists of dictionaries as values. Each dictionary contains 'pixel_data' and 'sop_instance_uid'.
      Empty when the required attributes cannot be read.
    """
    segmentation_data = {}

    try:
        # Get the number of frames (coerced so that a string-like value also works with range)
        num_frames = int(segmentation_dicom.NumberOfFrames)
        logging.info(f"Number of frames in the DICOM: {num_frames}")

        # Get pixel data
        pixel_data = np.asarray(segmentation_dicom.pixel_array)
        if pixel_data.ndim == 2:
            # A single-frame object comes back as one 2-D image; add the frame
            # axis so pixel_data[0] is the whole frame rather than its first row.
            pixel_data = pixel_data[np.newaxis, ...]

        per_frame_groups = segmentation_dicom.PerFrameFunctionalGroupsSequence

        # Get referenced instance UID and segmentation number for each frame
        for frame_index in range(num_frames):
            try:
                frame = per_frame_groups[frame_index]
                referenced_sop_instance_uid = frame.DerivationImageSequence[0].SourceImageSequence[0].ReferencedSOPInstanceUID
                segment_number = frame.SegmentIdentificationSequence[0].ReferencedSegmentNumber

                if segment_number not in segment_map:
                    logging.warning(f"Segment number {segment_number} not found in segment_map.")
                    continue

                segment_label = segment_map[segment_number]
                segmentation_data.setdefault(segment_label, []).append({
                    'pixel_data': pixel_data[frame_index],
                    'sop_instance_uid': referenced_sop_instance_uid
                })

            except Exception as e:
                logging.error(f"Error processing frame {frame_index}: {e}")

        # Log the count of slices for each segment label
        for segment_label, slices in segmentation_data.items():
            logging.info(f"Segment '{segment_label}': {len(slices)} slices")

    except AttributeError as e:
        logging.error(f"Error reading DICOM attributes: {e}")
    except Exception as e:
        logging.error(f"Unexpected error: {e}")

    return segmentation_data

# Despite its name, slice_info_list is a dictionary: segment label -> list of
# {'pixel_data', 'sop_instance_uid'} entries, one per frame of that segment.
slice_info_list = get_segmentation_data_including_RefSOPUID_refSegNum_pixelData(selected_SEGdicom_data,segment_map)
# Bare expression so the notebook displays the result of the cell.
slice_info_list

## Summary: All Data 

In [None]:
# Summary cell: recompute everything extracted from the selected segmentation.
# 1) Series Instance UID of the referenced image series.
Ref_series_UID = get_referenced_series_UID(selected_SEGdicom_data)
# 2) Segment Number -> Segment Label lookup.
segment_map = create_segment_number_to_label_map(selected_SEGdicom_data)
# 3) Per-label lists of frames with the SOP Instance UID each frame refers to.
slice_info_list = get_segmentation_data_including_RefSOPUID_refSegNum_pixelData(selected_SEGdicom_data,segment_map)

----

# Step 2:   Original CT (via XML created by XNAT)

### Load XML for folders series

In [None]:
import os
import xml.etree.ElementTree as ET
import logging

def extract_series_data(patient_directory):
    """
    Extracts series data from XML files within the SCANS directory of the given patient directory.

    Every sub-folder of ``<patient_directory>/SCANS`` is treated as one series and is
    expected to hold a ``DICOM`` folder containing a catalog ``.xml`` file. A series is
    logged and skipped (instead of aborting the whole extraction) when it has no
    ``DICOM`` folder, no XML file, no ``UID`` attribute on the XML root, or an XML file
    that cannot be parsed.

    Parameters:
    - patient_directory: Path to the patient's directory containing the SCANS folder.

    Returns:
    - Dictionary keyed by series folder name. Each value is a dictionary with:
        'original_RefSeriesUID': the 'UID' attribute of the XML root element.
        'original_RefInstanceUID_dict': maps the 'UID' attribute of each catalog entry
          to the entry's 'URI' attribute joined onto the series' DICOM folder.

    Raises:
    - OSError (e.g. FileNotFoundError) if the SCANS directory itself cannot be listed.
    """
    catalog_namespaces = {'cat': 'http://nrg.wustl.edu/catalog'}
    scans_dir = os.path.join(patient_directory, 'SCANS')
    series_folders = [d for d in os.listdir(scans_dir) if os.path.isdir(os.path.join(scans_dir, d))]

    series_data = {}

    for series in series_folders:
        series_path = os.path.join(scans_dir, series, 'DICOM')

        # A series folder without a DICOM sub-folder must not abort the other series.
        if not os.path.isdir(series_path):
            logging.warning(f"No DICOM directory found for series: {series_path}")
            continue

        # Sorted so the chosen file does not depend on the arbitrary os.listdir order.
        xml_files = sorted(f for f in os.listdir(series_path) if f.endswith('.xml'))

        if not xml_files:
            logging.warning(f"No XML files found in series directory: {series_path}")
            continue

        # Assuming there's only one XML file per series directory
        xml_file = xml_files[0]
        xml_path = os.path.join(series_path, xml_file)

        try:
            root = ET.parse(xml_path).getroot()

            # Extract original_RefSeriesUID
            ref_series_uid = root.attrib.get('UID', None)
            if not ref_series_uid:
                logging.warning(f"No UID found in XML root for series {series}")
                continue

            # Extract original_RefInstanceUID_dict (entries lacking UID or URI are ignored)
            ref_instance_uid_dict = {}
            for entry in root.findall('.//cat:entry', namespaces=catalog_namespaces):
                uid = entry.attrib.get('UID')
                uri = entry.attrib.get('URI')
                if uid and uri:
                    ref_instance_uid_dict[uid] = os.path.join(series_path, uri)

            series_data[series] = {
                'original_RefSeriesUID': ref_series_uid,
                'original_RefInstanceUID_dict': ref_instance_uid_dict
            }
        except ET.ParseError as e:
            logging.error(f"Error parsing XML file {xml_file}: {e}")
        except Exception as e:
            logging.error(f"Unexpected error processing series {series}: {e}")

    logging.info(f"Successfully read the XML file for {patient_directory}. Series and slice counts:")
    for key, value in series_data.items():
        original_RefInstanceUID_dict = value.get('original_RefInstanceUID_dict', {})
        logging.info(f"    {key}: {len(original_RefInstanceUID_dict)}")

    return series_data

# Example usage: index every series folder of the notebook-level patient directory,
# then echo the resulting dictionary (series folder -> series UID + instance UID/path map).
original_series_data = extract_series_data(patient_directory)

original_series_data

# Step 2: Original CT using native dicom

In [None]:
# TODO: complete in the future.
# Loop through the series folders and read the series UID ('original_RefSeriesUID') directly from the DICOM files.
# Add each slice file to 'original_RefInstanceUID_dict' as a dictionary with the instance UID as key and the file path as value,
# then use the result in place of the XML-based data above.

# Step 3: Merge segmentation and original image

In [None]:
def get_path_to_matched_original_series(original_series_data, Ref_series_UID, patient_dir=None):
    """
    Loads the original series whose series UID matches the one referenced by the segmentation.

    Parameters:
    - original_series_data: Dictionary keyed by series folder name; each value must hold
      an 'original_RefSeriesUID' entry (as produced by extract_series_data).
    - Ref_series_UID: Series UID to look for.
    - patient_dir: Optional patient directory containing the SCANS folder. When omitted,
      the notebook-level `patient_directory` variable is used (the original behaviour).

    Returns:
    - Whatever read_dicom_pixel_data_with_sop_instance_uid returns for the matched
      '<patient_dir>/SCANS/<folder>/DICOM' directory, or None when no series matches
      or the loaded data is empty.
    """
    # Must be initialised: without a match the name would otherwise be unbound below.
    matched_original_series_data = None

    for original_folder, original_info_dic in original_series_data.items():
        if original_info_dic['original_RefSeriesUID'] != Ref_series_UID:
            continue
        logging.info(f"Successfully find a match for series UIDs in folder: '{original_folder}'")
        # Resolve the fallback lazily so the global is only needed when a match exists.
        base_directory = patient_directory if patient_dir is None else patient_dir
        matched_original_series_data = read_dicom_pixel_data_with_sop_instance_uid(
            os.path.join(base_directory, "SCANS", original_folder, "DICOM")
        )

    if matched_original_series_data:
        logging.info(" The matched series of original image created successfully.")
        return matched_original_series_data

    logging.warning(f"No matched series UID was found for '{Ref_series_UID}'.")
    return None

# Load the original series referenced by the segmentation (matched by series UID).
matched_series_original_data_dic = get_path_to_matched_original_series(original_series_data,Ref_series_UID)

# SHOW: one segmentation merged file

## Select segmentation to show

In [None]:

# Build a one-line menu of the available segment labels with their slice counts.
segmentations_to_chose = " ".join([f"<><><>Segment '{segment_label}': {len(slices)} slices" for segment_label, slices in slice_info_list.items()])
# Ask the user to type one of the labels; the typed text is used as the dictionary key as-is.
selected_segmentation_label= input(f"Select one of the segmentation labels to merge with original image: {segmentations_to_chose}")
# Raises KeyError when the typed text is not one of the keys of slice_info_list.
selected_segmentation_data_dic = slice_info_list[selected_segmentation_label]
selected_segmentation_data_dic

In [None]:

# Show the original series with the selected segment overlaid in orange (#FFA500) at 0.4 transparency.
show_merged_dicom_stack_fast(matched_series_original_data_dic, selected_segmentation_data_dic, segment_overlay_transparency=0.4, segment_overlay_color='#FFA500', figsize=(20,20), label_for_overlay_image=f"Segment Label:{str(selected_segmentation_label)}")


# SHOW: Multiple

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import ipywidgets as widgets
from IPython.display import display
from ipywidgets import interact, IntSlider
import matplotlib.pyplot as plt

def MULTI_merge_dictionaries(original_dic, segments_list):
    """
    Attaches, to every original slice, the matching frame of each segmentation.

    Matching is done on the SOP Instance UID. A segmentation that has no frame for a
    given slice still contributes an entry, with 'segmentation_pixel' set to None.

    Parameters:
    - original_dic: Iterable of dictionaries providing 'pixel_data' and 'sop_instance_uid'.
    - segments_list: Iterable of (segment_dic, segment_overlay_color, label_for_overlay_image)
      tuples, where segment_dic is an iterable of dictionaries providing
      'sop_instance_uid' and 'pixel_data'.

    Returns:
    - List with one dictionary per original slice, holding 'sop_instance_uid',
      'original_pixel' and 'segmentations' (one entry per tuple of segments_list,
      in the same order).
    """
    # One UID -> pixel lookup per segmentation; a repeated UID keeps its last frame.
    overlays = []
    for frames, overlay_color, overlay_label in segments_list:
        pixels_by_uid = {}
        for frame in frames:
            pixels_by_uid[frame['sop_instance_uid']] = frame['pixel_data']
        overlays.append((pixels_by_uid, overlay_color, overlay_label))

    merged = []
    for source_slice in original_dic:
        uid = source_slice['sop_instance_uid']
        overlays_for_slice = [
            {
                'segmentation_pixel': pixels_by_uid.get(uid),
                'segment_overlay_color': overlay_color,
                'label_for_overlay_image': overlay_label,
            }
            for pixels_by_uid, overlay_color, overlay_label in overlays
        ]
        merged.append({
            'sop_instance_uid': uid,
            'original_pixel': source_slice['pixel_data'],
            'segmentations': overlays_for_slice,
        })

    return merged

def MULTI_prepare_images(merged_data, segment_overlay_transparency, figsize=(10, 10)):
    """
    Renders every merged slice (original image plus segmentation overlays) to an RGB array.

    Each slice is drawn on its own Matplotlib figure, rasterised, and the figure is
    closed again, so that later display only needs to show a ready-made image.

    Parameters:
    - merged_data: List of dictionaries with 'original_pixel', 'segmentations', and 'sop_instance_uid'.
      Every item of 'segmentations' provides 'segmentation_pixel' (or None),
      'segment_overlay_color' and 'label_for_overlay_image'.
    - segment_overlay_transparency: Transparency level for the segmentation overlay.
    - figsize: Tuple representing the figure size (width, height).

    Returns:
    - List of prepared images (one height x width x 3 uint8 array per slice).
    """
    prepared_images = []

    for data in merged_data:
        segmentations = data['segmentations']

        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(data['original_pixel'], cmap=plt.cm.gray)

        for segmentation in segmentations:
            segmentation_pixel = segmentation['segmentation_pixel']
            if segmentation_pixel is None:
                # This segment has no frame for the current slice.
                continue

            overlay_rgba = to_rgba(segmentation['segment_overlay_color'], alpha=segment_overlay_transparency)
            overlay = np.zeros((*segmentation_pixel.shape, 4))
            overlay[..., :3] = overlay_rgba[:3]
            # Only pixels inside the mask become visible.
            overlay[..., 3] = (segmentation_pixel > 0) * segment_overlay_transparency

            # NOTE(review): the transparency is applied both in the alpha channel above
            # and through imshow's alpha argument - confirm the combined effect is intended.
            ax.imshow(overlay, cmap=None, alpha=segment_overlay_transparency)

        ax.set_title(f"SOP Instance UID: {data['sop_instance_uid']}\n{', '.join([seg['label_for_overlay_image'] for seg in segmentations])}")
        ax.axis('off')
        fig.canvas.draw()

        # Convert the canvas to an image.
        # buffer_rgba() replaces tostring_rgb(), which is deprecated (Matplotlib 3.8) and
        # removed in later releases; it also carries its own height/width, so no reshape
        # based on get_width_height() is needed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Copy so the stored image does not alias the canvas buffer once the figure is closed.
        prepared_images.append(rgba[..., :3].copy())

        plt.close(fig)

    return prepared_images

def MULTI_show_merged_dicom_stack_fast(prepared_images, color_map, figsize=(15, 15)):
    """
    Displays a stack of merged DICOM images with overlay in an interactive scrollable format.

    A horizontal slider selects the slice; the caption below the image shows the slice
    position and the label -> colour legend.

    Parameters:
    - prepared_images: List of prepared images with overlays.
    - color_map: Dictionary mapping segment labels to their corresponding colors.
    - figsize: Tuple representing the figure size (width, height) for display.

    Returns:
    - None. When prepared_images is empty a warning is logged and nothing is displayed.
    """
    slice_count = len(prepared_images)
    if slice_count == 0:
        # A slider over an empty stack would need max=-1 (< min), which is not a valid range.
        logging.warning("No prepared images to display.")
        return

    # The legend is identical for every slice, so build it once.
    legend = ", ".join([f"{key}: {value}" for key, value in color_map.items()])

    def plot_slice(index):
        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(prepared_images[index])
        ax.axis('off')

        title = f"Slice {index + 1}/{slice_count}\n" + legend
        fig.text(0.5, 0.01, title, ha='center', fontsize=12)
        plt.tight_layout(pad=0)
        plt.show()

    interact(plot_slice, index=IntSlider(min=0, max=slice_count - 1, step=1, value=0, orientation='horizontal'))


# Overlay palette: hex codes used for drawing. Entries are index-aligned with
# `color_names` below, which supplies the human-readable name shown in the legend.
colors_list = [
    '#FF0000',  # Red
    '#00FF00',  # Green
    '#0000FF',  # Blue
    '#FFFF00',  # Yellow
    '#FF00FF',  # Magenta
    '#00FFFF',  # Cyan
    '#800000',  # Maroon
    '#808000',  # Olive
    '#008080',  # Teal
    '#800080',  # Purple
    '#FFA500',  # Orange
    '#A52A2A',  # Brown
]

# Display names, in the same order as `colors_list`.
color_names = [
    'Red',
    'Green',
    'Blue',
    'Yellow',
    'Magenta',
    'Cyan',
    'Maroon',
    'Olive',
    'Teal',
    'Purple',
    'Orange',
    'Brown'
]

# multiple_segment_list: (frames, hex colour, caption) tuples consumed by MULTI_merge_dictionaries.
# color_map: segment label -> colour name, shown as the legend under each slice.
multiple_segment_list = []
color_map = {}
i = 0
for key, item in slice_info_list.items():
    color = colors_list[i % len(colors_list)]  # Ensure the index stays within the length of colors_list
    multiple_segment_list.append((item, color, f"Label: {str(key)}"))
    color_name = color_names[i % len(color_names)] 
    color_map[key] = color_name
    i += 1

# NOTE(review): `original_dic` is not assigned anywhere in this section of the notebook;
# presumably the matched original series (`matched_series_original_data_dic`) is meant - confirm.
merged_data = MULTI_merge_dictionaries(original_dic, multiple_segment_list)
# Pre-render every slice once so that scrolling only has to show a finished image.
prepared_images = MULTI_prepare_images(merged_data, segment_overlay_transparency=0.4, figsize=(20, 20))

MULTI_show_merged_dicom_stack_fast(prepared_images, color_map, figsize=(8, 8))


In [None]:
# Echo the merged per-slice structure (original pixels plus one entry per segmentation) for inspection.
merged_data

# Jupyter widget-based app

In [None]:
# Notebook-level patient directory; get_path_to_matched_original_series reads this global.
# Raw string, so the backslashes of the Windows-style path are kept literally.
patient_directory=r"Sample_Data\CT\1000029"

In [None]:
import ipywidgets as widgets
from IPython.display import display

# Create a text box for directory input (the folder whose sub-folders are listed as patients)
directory_textbox = widgets.Text(
    value='C:/XNAT_BACK/archive/PanCanAID/arc001',  # Default value
    description='Directory:',
    placeholder='Enter directory path'
)

# Create a dropdown for folder selection; starts empty and disabled until
# update_folder_dropdown fills it from a valid directory
folder_dropdown = widgets.Dropdown(
    description='Folders:',
    options=[],
    disabled=True
)

# Create a dropdown for segmentation selection; starts empty and disabled until
# update_segmentation_dropdown fills it for the chosen folder
segmentation_dropdown = widgets.Dropdown(
    description='Segmentations:',
    options=[],
    disabled=True
)

# Output widget that receives everything display_selected_segmentation renders
output = widgets.Output()
def update_folder_dropdown(*args):
    """
    Refreshes the folder dropdown from the directory typed into the text box.

    The dropdown is always emptied first. For an existing directory it is refilled with
    the directory's sub-folders and enabled; otherwise both the folder and the
    segmentation dropdowns are disabled.

    Parameters:
    - *args: Ignored; accepted so the function can be used as a widget observer.
    """
    folder_dropdown.options = []
    root_dir = directory_textbox.value

    if not os.path.isdir(root_dir):
        folder_dropdown.disabled = True
        segmentation_dropdown.disabled = True
        return

    subfolders = []
    for entry in os.listdir(root_dir):
        if os.path.isdir(os.path.join(root_dir, entry)):
            subfolders.append(entry)

    folder_dropdown.options = subfolders
    folder_dropdown.disabled = False

def update_segmentation_dropdown(*args):
    """
    Refreshes the segmentation dropdown for the currently selected folder.

    The dropdown is filled with the segmentation names found under the folder's
    assessors directory and enabled. It is disabled when no folder is selected or the
    assessors directory does not exist.

    Parameters:
    - *args: Ignored; accepted so the function can be used as a widget observer.
    """
    chosen_folder = folder_dropdown.value
    root_dir = directory_textbox.value

    if not chosen_folder:
        segmentation_dropdown.disabled = True
        return

    assessors_directory = assessors_path(os.path.join(root_dir, chosen_folder))
    if not os.path.isdir(assessors_directory):
        segmentation_dropdown.disabled = True
        return

    available = get_segmentations_from_assessors_path(assessors_directory)
    segmentation_dropdown.options = list(available.keys())
    segmentation_dropdown.disabled = False

def _assign_overlay_colors(segment_slices_by_label):
    """Pair every segment label with an overlay colour, cycling through a fixed palette.

    Returns a tuple (segments, color_map): `segments` is a list of
    (slices, hex colour, "Label: <label>") tuples and `color_map` maps each label to
    the colour's display name.
    """
    palette = [
        ('#FF0000', 'Red'),
        ('#00FF00', 'Green'),
        ('#0000FF', 'Blue'),
        ('#FFFF00', 'Yellow'),
        ('#FF00FF', 'Magenta'),
        ('#00FFFF', 'Cyan'),
        ('#800000', 'Maroon'),
        ('#808000', 'Olive'),
        ('#008080', 'Teal'),
        ('#800080', 'Purple'),
        ('#FFA500', 'Orange'),
        ('#A52A2A', 'Brown'),
    ]
    segments = []
    color_map = {}
    for position, (label, slices) in enumerate(segment_slices_by_label.items()):
        # Modulo keeps the index inside the palette when there are more labels than colours.
        hex_color, color_name = palette[position % len(palette)]
        segments.append((slices, hex_color, f"Label: {str(label)}"))
        color_map[label] = color_name
    return segments, color_map


def _load_matched_original_series(original_series_data, ref_series_uid, patient_dir):
    """Load the series of `patient_dir` whose series UID equals `ref_series_uid`, or return None."""
    for series_folder, series_info in original_series_data.items():
        if series_info['original_RefSeriesUID'] == ref_series_uid:
            logging.info(f"Successfully find a match for series UIDs in folder: '{series_folder}'")
            return read_dicom_pixel_data_with_sop_instance_uid(
                os.path.join(patient_dir, "SCANS", series_folder, "DICOM")
            )
    return None


def display_selected_segmentation(*args):
    """
    Renders the selected segmentation on top of its original series inside `output`.

    Reads the current directory, folder and segmentation selections, loads the SEG
    DICOM, finds the original series it references inside the *selected* patient
    folder, and shows the scrollable overlay stack. Nothing is drawn when a selection
    is missing; a short message is printed when the SEG file cannot be read or no
    original series matches.

    Parameters:
    - *args: Ignored; accepted so the function can be used as a widget observer.
    """
    with output:
        output.clear_output()
        selected_segmentation_name = segmentation_dropdown.value
        directory = directory_textbox.value
        folder = folder_dropdown.value

        # A dropdown without options has no value; joining a path with None would raise.
        if not folder or not selected_segmentation_name:
            return

        selected_patient_directory = os.path.join(directory, folder)
        assessors_directory = assessors_path(selected_patient_directory)
        segmentations = get_segmentations_from_assessors_path(assessors_directory)
        dicom_fullpath = read_selected_segmentation_for_widgets(selected_segmentation_name, segmentations)
        if not dicom_fullpath:
            return

        # read_dicom logs the error and returns None when the file cannot be read.
        selected_SEGdicom_data = read_dicom(dicom_fullpath)
        if selected_SEGdicom_data is None:
            print(f"Could not read the segmentation file: {dicom_fullpath}")
            return

        Ref_series_UID = get_referenced_series_UID(selected_SEGdicom_data)
        segment_map = create_segment_number_to_label_map(selected_SEGdicom_data)
        slice_info_list = get_segmentation_data_including_RefSOPUID_refSegNum_pixelData(selected_SEGdicom_data, segment_map)

        original_series_data = extract_series_data(selected_patient_directory)
        # Load from the selected patient folder. get_path_to_matched_original_series
        # builds its path from the notebook-level `patient_directory` instead, which
        # points at a different patient than the one chosen in the dropdown.
        matched_series_original_data_dic = _load_matched_original_series(
            original_series_data, Ref_series_UID, selected_patient_directory
        )
        if not matched_series_original_data_dic:
            print(f"No original series matches the referenced series UID: {Ref_series_UID}")
            return

        multiple_segment_list, color_map = _assign_overlay_colors(slice_info_list)

        merged_data = MULTI_merge_dictionaries(matched_series_original_data_dic, multiple_segment_list)
        prepared_images = MULTI_prepare_images(merged_data, segment_overlay_transparency=0.4, figsize=(20, 20))

        MULTI_show_merged_dicom_stack_fast(prepared_images, color_map, figsize=(8, 8))


# Add observers to update dropdowns: each widget's 'value' change triggers the next
# stage (directory -> folder list -> segmentation list -> rendered overlay).
directory_textbox.observe(update_folder_dropdown, 'value')
folder_dropdown.observe(update_segmentation_dropdown, 'value')
segmentation_dropdown.observe(display_selected_segmentation, 'value')

# Initial display of the widgets. No observer has fired yet, so the dropdowns
# stay empty and disabled until the directory text is changed.
display(directory_textbox, folder_dropdown, segmentation_dropdown, output)

# [ ] revise the code to show the data
# [ ] revise the code to save the merged dictionary in a pickle file