# In [None]:
import sys
import os
import numpy as np
import torch
from PIL import features  
import pydicom
import cv2
import matplotlib.pyplot as plt
import logging
from scipy.ndimage import zoom
import traceback
import csv

# Work around a pydicom/Pillow issue: when pydicom's Pillow pixel-decoder
# module has already been imported, expose PIL's `features` module on it as
# the `features` attribute. Nothing happens if that module is not present in
# sys.modules at this point.
if 'pydicom.pixels.decoders.pillow' in sys.modules:
    sys.modules['pydicom.pixels.decoders.pillow'].features = features

# Set up logging: INFO level and above, timestamped records, emitted through
# a single stream handler.
logging.basicConfig(level=logging.INFO, 
                   format='%(asctime)s - %(levelname)s - %(message)s',
                   handlers=[logging.StreamHandler()])
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)

# Select the torch device: CUDA when torch reports it as available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device} ‚ú®")

# ====================== DICOM Data Loading Functions ======================

def load_dicom_series(folder_path, recursive=True):
    """Load every CT slice found under ``folder_path``, ordered along z.

    Files are selected by a ``.dcm`` suffix and a ``Modality`` of ``'CT'``.
    Files that cannot be read, and CT files whose
    ``ImagePositionPatient`` cannot be read, are logged and skipped.

    Args:
        folder_path: Directory to scan.
        recursive: When true, descend into sub-directories as well.

    Returns:
        A list of pydicom datasets sorted by ascending
        ``ImagePositionPatient[2]``, or an empty list when no CT slice was
        found.
    """
    # Position and dataset are stored as one pair so that a slice whose
    # position cannot be read can never leave the two out of step.
    positioned_slices = []

    def scan_for_ct_files(path):
        try:
            for item in os.listdir(path):
                item_path = os.path.join(path, item)
                if os.path.isdir(item_path) and recursive:
                    scan_for_ct_files(item_path)
                elif item.endswith('.dcm'):
                    try:
                        dcm = pydicom.dcmread(item_path)
                        if hasattr(dcm, 'Modality') and dcm.Modality == 'CT':
                            # Read the position first: if it is missing the
                            # slice is skipped instead of being half-recorded.
                            z_position = float(dcm.ImagePositionPatient[2])
                            positioned_slices.append((z_position, dcm))
                    except Exception as e:
                        logger.warning(f"Error reading file {item_path}: {str(e)}")
        except Exception as e:
            logger.error(f"Error scanning directory {path}: {str(e)}")

    logger.info(f"üîç Scanning for CT files in: {folder_path}")
    scan_for_ct_files(folder_path)

    if not positioned_slices:
        logger.error(f"‚ùå No CT images found in {folder_path} or its subdirectories")
        return []

    logger.info(f"‚úÖ Found {len(positioned_slices)} CT slices")
    # Sort on the position only. Sorting the raw tuples would fall back to
    # comparing the datasets themselves whenever two positions are equal.
    positioned_slices.sort(key=lambda pair: pair[0])
    return [dcm for _, dcm in positioned_slices]

def find_structure_file(folder_path, recursive=True):
    """Find an RTSTRUCT file under ``folder_path``.

    Every ``.dcm`` file is read and kept when its ``Modality`` is
    ``'RTSTRUCT'``. Files that cannot be read are skipped; the reason is
    logged at DEBUG level so that a missing structure set can be diagnosed.

    Args:
        folder_path: Directory to scan.
        recursive: When true, descend into sub-directories as well.

    Returns:
        The first RTSTRUCT dataset encountered (in ``os.listdir`` order), or
        ``None`` when there is none.
    """
    found_files = []

    def scan_for_rtstruct(path):
        try:
            for item in os.listdir(path):
                item_path = os.path.join(path, item)
                if os.path.isdir(item_path) and recursive:
                    scan_for_rtstruct(item_path)
                elif item.endswith('.dcm'):
                    try:
                        dcm = pydicom.dcmread(item_path)
                        if hasattr(dcm, 'Modality') and dcm.Modality == 'RTSTRUCT':
                            found_files.append(dcm)
                            logger.info(f"‚úÖ Found RTSTRUCT file: {item_path}")
                    except Exception as e:
                        # Best effort: an unreadable file must not abort the
                        # scan, but it should leave a trace.
                        logger.debug(f"Skipping unreadable file {item_path}: {e}")
        except Exception as e:
            logger.error(f"Error scanning directory {path}: {str(e)}")

    logger.info(f"üîç Scanning for RTSTRUCT files in: {folder_path}")
    scan_for_rtstruct(folder_path)

    if not found_files:
        logger.warning(f"‚ö†Ô∏è No RTSTRUCT files found in {folder_path}")
        return None

    return found_files[0]

def find_dose_file(folder_path, recursive=True):
    """Find an RTDOSE file under ``folder_path``.

    Every ``.dcm`` file is read and kept when its ``Modality`` is
    ``'RTDOSE'``. Files that cannot be read are skipped; the reason is
    logged at DEBUG level so that a missing dose file can be diagnosed.

    Args:
        folder_path: Directory to scan.
        recursive: When true, descend into sub-directories as well.

    Returns:
        The first RTDOSE dataset encountered (in ``os.listdir`` order), or
        ``None`` when there is none.
    """
    found_files = []

    def scan_for_rtdose(path):
        try:
            for item in os.listdir(path):
                item_path = os.path.join(path, item)
                if os.path.isdir(item_path) and recursive:
                    scan_for_rtdose(item_path)
                elif item.endswith('.dcm'):
                    try:
                        dcm = pydicom.dcmread(item_path)
                        if hasattr(dcm, 'Modality') and dcm.Modality == 'RTDOSE':
                            found_files.append(dcm)
                            logger.info(f"‚úÖ Found RTDOSE file: {item_path}")
                    except Exception as e:
                        # Best effort: an unreadable file must not abort the
                        # scan, but it should leave a trace.
                        logger.debug(f"Skipping unreadable file {item_path}: {e}")
        except Exception as e:
            logger.error(f"Error scanning directory {path}: {str(e)}")

    logger.info(f"üîç Scanning for RTDOSE files in: {folder_path}")
    scan_for_rtdose(folder_path)

    if not found_files:
        logger.warning(f"‚ö†Ô∏è No RTDOSE files found in {folder_path}")
        return None

    return found_files[0]

# ====================== Mask Creation Functions ======================

def get_roi_contours(rs_file, roi_name):
    """Return the ContourSequence of the ROI called ``roi_name``.

    The name is matched against ``StructureSetROISequence`` ignoring case
    and surrounding whitespace; the first match wins. Its ROI number is then
    looked up in ``ROIContourSequence``.

    Returns:
        The matching ``ContourSequence``, or an empty list when the ROI is
        unknown, has no contour entry, or its entry has no
        ``ContourSequence``.
    """
    wanted = roi_name.lower().strip()

    # First ROI whose normalized name equals the requested one.
    roi = next(
        (candidate for candidate in rs_file.StructureSetROISequence
         if candidate.ROIName.lower().strip() == wanted),
        None,
    )
    if roi is None:
        logger.warning(f"‚ö†Ô∏è ROI {roi_name} not found")
        return []

    roi_number = roi.ROINumber
    logger.info(f"‚úÖ Found ROI: {roi.ROIName} (Number: {roi_number})")

    for roi_contour in rs_file.ROIContourSequence:
        if not (roi_contour.ReferencedROINumber == roi_number):
            continue
        # Only the first entry referencing the ROI is considered.
        if not hasattr(roi_contour, 'ContourSequence'):
            logger.warning(f"‚ö†Ô∏è No ContourSequence found for ROI: {roi_name}")
            return []
        logger.info(f"‚úÖ Found contour data for ROI: {roi_name}")
        return roi_contour.ContourSequence

    logger.warning(f"‚ö†Ô∏è No matching contour data found for ROI: {roi_name}")
    return []

def convert_dicom_to_pixel_coordinates(contour_data, ct_slice):
    """Convert flat DICOM contour coordinates to pixel (column, row) indices.

    Args:
        contour_data: Flat sequence ``x0, y0, z0, x1, y1, z1, ...`` in
            patient coordinates.
        ct_slice: Object exposing ``PixelSpacing`` and
            ``ImagePositionPatient`` for the slice the contour lies on.

    Returns:
        An ``(N, 2)`` int32 array of ``[x_pixel, y_pixel]`` pairs, or an
        empty array when fewer than three values are supplied or the
        conversion fails (the failure is logged).

    Note:
        ``ImageOrientationPatient`` is not consulted, so this assumes rows
        run along patient x and columns along patient y
        (NOTE(review): confirm for non-standard orientations).
    """
    try:
        # DICOM PixelSpacing is (row spacing, column spacing): the first
        # value is the distance between adjacent rows (the y step), the
        # second the distance between adjacent columns (the x step).
        row_spacing, col_spacing = (float(v) for v in ct_slice.PixelSpacing)
        x_origin, y_origin, _ = (float(v) for v in ct_slice.ImagePositionPatient)

        if len(contour_data) < 3:  # Not even one complete point
            return np.array([])

        contour_array = np.asarray(contour_data, dtype=np.float64).reshape(-1, 3)

        # World -> pixel: x advances one column per col_spacing, y advances
        # one row per row_spacing.
        x_pixel = np.round((contour_array[:, 0] - x_origin) / col_spacing).astype(np.int32)
        y_pixel = np.round((contour_array[:, 1] - y_origin) / row_spacing).astype(np.int32)

        return np.stack([x_pixel, y_pixel], axis=1)
    except Exception as e:
        logger.error(f"Error converting coordinates: {str(e)}")
        return np.array([])

def create_masks(contour_data_list, ct_slices, z_tolerance=0.5):
    """Rasterize ROI contours into one binary mask per CT slice.

    Args:
        contour_data_list: Iterable of contour items exposing a flat
            ``ContourData`` sequence (``x, y, z`` triples).
        ct_slices: CT slices; each must expose ``Rows``, ``Columns`` and
            ``ImagePositionPatient`` (plus whatever
            ``convert_dicom_to_pixel_coordinates`` reads).
        z_tolerance: Maximum distance, in the units of the DICOM
            coordinates, between a contour's z and a slice's z for the
            contour to be drawn on that slice. This is an absolute
            tolerance, not derived from the slice thickness.

    Returns:
        A uint8 array of shape ``(len(ct_slices), Rows, Columns)`` holding 1
        inside filled contours and 0 elsewhere.
    """
    # Read each contour's z once, instead of re-parsing every contour for
    # every slice. As before, a contour needs more than 6 values to be used.
    usable_contours = []
    for contour in contour_data_list:
        contour_values = getattr(contour, 'ContourData', None)
        if contour_values is None or len(contour_values) <= 6:
            continue
        # The z of the first point stands for the whole planar contour.
        usable_contours.append((float(contour_values[2]), contour_values))

    masks = []
    for ct_slice in ct_slices:
        mask = np.zeros((ct_slice.Rows, ct_slice.Columns), dtype=np.uint8)
        slice_z = float(ct_slice.ImagePositionPatient[2])

        for contour_z, contour_values in usable_contours:
            if abs(contour_z - slice_z) >= z_tolerance:
                continue

            pixel_coords = convert_dicom_to_pixel_coordinates(contour_values, ct_slice)
            if len(pixel_coords) > 2:  # Need at least 3 points to form a polygon
                cv2.fillPoly(mask, [pixel_coords], 1)

        masks.append(mask)

    return np.array(masks)

# ====================== Dose Processing Functions ======================

def get_dose_coordinates(dose_dcm):
    """Extract the physical coordinates of the dose grid.

    Args:
        dose_dcm: RTDOSE dataset exposing ``ImagePositionPatient``,
            ``PixelSpacing``, ``GridFrameOffsetVector`` and
            ``ImageOrientationPatient``.

    Returns:
        A dict with:
            ``origin``: float32 array copied from ``ImagePositionPatient``.
            ``pixel_spacing``: float32 array copied from ``PixelSpacing``.
            ``z_spacing``: signed difference between the first two frame
                offsets as a float; ``1.0`` when fewer than two offsets
                exist (the same default the CT helper uses for one slice).
            ``orientation``: 3x3 array whose rows are the row direction,
                the column direction and their cross product.
    """
    origin = np.array(dose_dcm.ImagePositionPatient, dtype=np.float32)
    pixel_spacing = np.array(dose_dcm.PixelSpacing, dtype=np.float32)

    # With a single frame there is no spacing to measure; returning the lone
    # offset (usually 0) would lead to a division by zero downstream.
    frame_offsets = np.asarray(dose_dcm.GridFrameOffsetVector, dtype=np.float64).ravel()
    if frame_offsets.size > 1:
        # Kept signed: a descending offset vector yields a negative step.
        z_spacing = float(frame_offsets[1] - frame_offsets[0])
    else:
        z_spacing = 1.0

    # Direction cosines: first triple is the row direction, second the
    # column direction; the third axis is their cross product.
    orientation = np.array(dose_dcm.ImageOrientationPatient, dtype=np.float32).reshape(2, 3)
    row_direction = orientation[0]
    col_direction = orientation[1]
    z_direction = np.cross(row_direction, col_direction)

    return {
        'origin': origin,
        'pixel_spacing': pixel_spacing,
        'z_spacing': z_spacing,
        'orientation': np.vstack((row_direction, col_direction, z_direction))
    }

def get_ct_coordinates(ct_slices):
    """Describe the geometry of a CT volume from its slice list.

    The in-plane information comes from the first slice; the z spacing is
    the absolute distance between the first two slice positions, or 1.0 when
    only one slice is given.

    Returns:
        A dict with ``origin`` and ``pixel_spacing`` (float32 arrays),
        ``z_spacing`` (float), ``orientation`` (3x3 array: row direction,
        column direction, their cross product) and ``shape``
        ``(n_slices, Rows, Columns)``.
    """
    reference = ct_slices[0]
    n_slices = len(ct_slices)

    volume_origin = np.array(reference.ImagePositionPatient, dtype=np.float32)
    in_plane_spacing = np.array(reference.PixelSpacing, dtype=np.float32)

    # Default spacing applies when there is no second slice to measure from.
    slice_gap = 1.0
    if n_slices > 1:
        z_first = float(ct_slices[0].ImagePositionPatient[2])
        z_second = float(ct_slices[1].ImagePositionPatient[2])
        slice_gap = abs(z_second - z_first)

    # Two direction-cosine triples; the third axis is their cross product.
    row_direction, col_direction = np.array(
        reference.ImageOrientationPatient, dtype=np.float32
    ).reshape(2, 3)
    axes = np.vstack((row_direction, col_direction, np.cross(row_direction, col_direction)))

    return {
        'origin': volume_origin,
        'pixel_spacing': in_plane_spacing,
        'z_spacing': slice_gap,
        'orientation': axes,
        'shape': (n_slices, reference.Rows, reference.Columns)
    }

def extract_dose_data(dose_dcm):
    """Read the dose grid from an RTDOSE dataset and collect summary metadata.

    The pixel array is converted to float32 and multiplied by
    ``DoseGridScaling`` when that attribute exists (a warning is logged
    otherwise).

    Returns:
        ``(dose_array, dose_meta)`` on success, ``(None, None)`` on any
        error. ``dose_meta`` always has ``shape``, ``min_dose``,
        ``max_dose``, ``mean_dose`` and ``median_dose``; ``origin``,
        ``pixel_spacing`` and ``z_spacing`` are added when the matching
        attribute is present. Note that ``z_spacing`` holds the full list of
        ``GridFrameOffsetVector`` values as floats, not a single spacing.
    """
    try:
        dose_array = dose_dcm.pixel_array.astype(np.float32)

        # Scale stored values by DoseGridScaling when available.
        if not hasattr(dose_dcm, 'DoseGridScaling'):
            logger.warning("‚ö†Ô∏è DoseGridScaling not found in RTDOSE file!")
        else:
            dose_array = dose_array * dose_dcm.DoseGridScaling

        # Summary statistics of the scaled grid.
        dose_meta = {
            'shape': dose_array.shape,
            'min_dose': float(dose_array.min()),
            'max_dose': float(dose_array.max()),
            'mean_dose': float(dose_array.mean()),
            'median_dose': float(np.median(dose_array))
        }

        # Optional grid attributes, copied through unchanged.
        for attribute, key in (('ImagePositionPatient', 'origin'),
                               ('PixelSpacing', 'pixel_spacing')):
            if hasattr(dose_dcm, attribute):
                dose_meta[key] = getattr(dose_dcm, attribute)

        if hasattr(dose_dcm, 'GridFrameOffsetVector'):
            dose_meta['z_spacing'] = list(map(float, dose_dcm.GridFrameOffsetVector))

        logger.info(f"üíä Dose range: {dose_meta['min_dose']:.4f} to {dose_meta['max_dose']:.4f} Gy, median: {dose_meta['median_dose']:.4f} Gy")
        logger.info(f"üìä Dose shape: {dose_meta['shape']}")
    except Exception as e:
        logger.error(f"‚ùå Error extracting dose data: {str(e)}")
        traceback.print_exc()
        return None, None

    return dose_array, dose_meta

def _axis_interpolation_terms(ct_len, ct_start, ct_step, dose_start, dose_step, dose_len):
    """Return CT indices, bracketing dose indices and weights along one axis.

    Only CT samples that fall inside the dose grid along this axis are
    returned. A tiny tolerance keeps samples that land on the first or last
    dose plane from being dropped by floating-point noise.
    """
    ct_index = np.arange(ct_len)
    # Fractional dose-grid index of every CT sample on this axis.
    dose_pos = (ct_start + ct_index * ct_step - dose_start) / dose_step

    eps = 1e-4
    inside = (dose_pos >= -eps) & (dose_pos <= (dose_len - 1) + eps)
    ct_index = ct_index[inside]
    dose_pos = np.clip(dose_pos[inside], 0, dose_len - 1)

    lower = np.floor(dose_pos).astype(int)
    upper = np.minimum(lower + 1, dose_len - 1)
    weight = (dose_pos - lower).astype(np.float32)
    return ct_index, lower, upper, weight


def align_dose_to_ct(dose_array, dose_coords, ct_coords):
    """Resample a dose grid onto the CT grid with trilinear interpolation.

    Both grids are treated as axis-aligned (the ``orientation`` entries are
    not used), so the interpolation is carried out one axis at a time; this
    equals trilinear interpolation while avoiding full-volume index grids.

    Args:
        dose_array: 3-D dose array indexed ``[z, y, x]``.
        dose_coords: Dict with ``origin`` (x, y, z), ``pixel_spacing``
            (DICOM order: row spacing, column spacing) and ``z_spacing``
            (may be negative).
        ct_coords: Dict with the same keys plus ``shape`` ``(z, y, x)``.

    Returns:
        A float32 array of shape ``ct_coords['shape']``. CT voxels outside
        the dose grid are 0. On any error the failure is logged and an
        all-zero array is returned.
    """
    logger.info("üîÑ Aligning dose grid to CT grid with improved method...")

    # Spacing in array-axis order (z, y, x). DICOM PixelSpacing is
    # (row spacing, column spacing): element 0 is the step between rows
    # (y axis), element 1 the step between columns (x axis).
    dose_origin = dose_coords['origin']
    dose_spacing = np.array([
        dose_coords['z_spacing'],
        dose_coords['pixel_spacing'][0],
        dose_coords['pixel_spacing'][1]
    ], dtype=np.float64)

    ct_origin = ct_coords['origin']
    ct_spacing = np.array([
        ct_coords['z_spacing'],
        ct_coords['pixel_spacing'][0],
        ct_coords['pixel_spacing'][1]
    ], dtype=np.float64)

    ct_shape = ct_coords['shape']
    dose_shape = dose_array.shape

    logger.info(f"üìè Dose origin: {dose_origin}, CT origin: {ct_origin}")
    logger.info(f"üìè Dose spacing: {dose_spacing}, CT spacing: {ct_spacing}")
    logger.info(f"üìè Dose shape: {dose_shape}, CT shape: {ct_shape}")

    try:
        if dose_array.ndim != 3:
            raise ValueError(f"expected a 3-D dose array, got shape {dose_shape}")
        if np.any(dose_spacing == 0):
            raise ValueError("dose grid spacing must be non-zero")

        aligned_dose = np.zeros(ct_shape, dtype=np.float32)

        # Array axis k (0=z, 1=y, 2=x) corresponds to origin component 2-k,
        # because origins are stored as (x, y, z).
        axis_terms = [
            _axis_interpolation_terms(
                ct_shape[axis],
                float(ct_origin[2 - axis]),
                float(ct_spacing[axis]),
                float(dose_origin[2 - axis]),
                float(dose_spacing[axis]),
                dose_shape[axis],
            )
            for axis in range(3)
        ]
        (ct_z, z0, z1, wz), (ct_y, y0, y1, wy), (ct_x, x0, x1, wx) = axis_terms

        mapped_voxels = ct_z.size * ct_y.size * ct_x.size
        if mapped_voxels:
            # Linear interpolation along z, then y, then x.
            block = (dose_array[z0] * (1 - wz)[:, None, None]
                     + dose_array[z1] * wz[:, None, None])
            block = (block[:, y0] * (1 - wy)[None, :, None]
                     + block[:, y1] * wy[None, :, None])
            block = block[:, :, x0] * (1 - wx) + block[:, :, x1] * wx
            aligned_dose[np.ix_(ct_z, ct_y, ct_x)] = block

        logger.info(f"‚úÖ Dose aligned successfully: {mapped_voxels} CT voxels mapped from dose grid")

        return aligned_dose

    except Exception as e:
        logger.error(f"‚ùå Error during dose alignment: {str(e)}")
        traceback.print_exc()

        # Return zeros in case of failure
        return np.zeros(ct_shape, dtype=np.float32)

def align_dose_to_ct_sitk(dose_array, dose_metadata, ct_slices):
    """Resample a dose grid onto the CT grid using SimpleITK.

    Args:
        dose_array: 3-D dose array indexed ``[z, y, x]``.
        dose_metadata: Dict with ``origin`` (x, y, z), ``pixel_spacing``
            (DICOM order: row spacing, column spacing) and ``z_spacing``.
            ``z_spacing`` may be a single number or the list of frame
            offsets produced by ``extract_dose_data``.
        ct_slices: CT slices ordered along z; the first one provides origin,
            in-plane spacing and orientation.

    Returns:
        The dose resampled to ``(n_slices, Rows, Columns)`` with linear
        interpolation (0 outside the dose grid), or ``None`` when SimpleITK
        is not installed or the resampling fails, so that the caller can
        fall back to another method.
    """
    try:
        # Check if SimpleITK is installed
        import SimpleITK as sitk
    except ImportError:
        logger.error("‚ùå SimpleITK not installed. Please install with: pip install SimpleITK")
        logger.info("‚ú® Using the default alignment method instead.")
        return None

    logger.info("üîÑ Aligning dose to CT using SimpleITK (high accuracy)...")

    try:
        # 1. Dose geometry. The metadata may carry the whole frame-offset
        # list rather than one spacing value; reduce it to a signed step.
        z_value = dose_metadata['z_spacing']
        if np.ndim(z_value) == 0:
            dose_dz = float(z_value)
        else:
            frame_offsets = np.asarray(z_value, dtype=np.float64).ravel()
            dose_dz = float(frame_offsets[1] - frame_offsets[0]) if frame_offsets.size > 1 else 1.0
        if dose_dz == 0:
            raise ValueError("dose z spacing must be non-zero")

        dose_volume = np.asarray(dose_array, dtype=np.float32)
        dose_origin = [float(v) for v in dose_metadata['origin']]
        if dose_dz < 0:
            # SimpleITK needs positive spacing: reverse the frame order and
            # move the origin to what used to be the last frame.
            dose_origin[2] += dose_dz * (dose_volume.shape[0] - 1)
            dose_volume = dose_volume[::-1]
            dose_dz = -dose_dz

        dose_row_spacing, dose_col_spacing = (float(v) for v in dose_metadata['pixel_spacing'])

        # GetImageFromArray maps array axes (z, y, x) to image axes
        # (x, y, z); origin and spacing are therefore given in (x, y, z)
        # order: column spacing, row spacing, frame spacing.
        dose_sitk = sitk.GetImageFromArray(np.ascontiguousarray(dose_volume))
        dose_sitk.SetOrigin(dose_origin)
        dose_sitk.SetSpacing([dose_col_spacing, dose_row_spacing, dose_dz])

        # 2. Reference image describing the CT grid.
        first_slice = ct_slices[0]
        ref_origin = [float(v) for v in first_slice.ImagePositionPatient]
        ct_row_spacing, ct_col_spacing = (float(v) for v in first_slice.PixelSpacing)

        if len(ct_slices) > 1:
            ct_dz = abs(float(ct_slices[1].ImagePositionPatient[2]) -
                        float(ct_slices[0].ImagePositionPatient[2]))
        else:
            ct_dz = 1.0  # Default if only one slice

        # SimpleITK stores the direction row-major with each COLUMN being
        # the direction of one image axis, hence column_stack.
        # NOTE(review): the slice direction is the cross product of the
        # in-plane directions; this assumes the slice order follows that
        # normal - confirm for unusual orientations.
        orientation = np.array(first_slice.ImageOrientationPatient, dtype=np.float64).reshape(2, 3)
        row_direction = orientation[0]
        col_direction = orientation[1]
        slice_direction = np.cross(row_direction, col_direction)
        direction_matrix = np.column_stack(
            (row_direction, col_direction, slice_direction)
        ).flatten().tolist()

        # sitk.Image takes (width, height, depth) = (Columns, Rows, slices).
        ct_shape = (len(ct_slices), first_slice.Rows, first_slice.Columns)
        ref_image = sitk.Image(int(ct_shape[2]), int(ct_shape[1]), int(ct_shape[0]), sitk.sitkFloat32)
        ref_image.SetOrigin(ref_origin)
        ref_image.SetSpacing([ct_col_spacing, ct_row_spacing, ct_dz])
        ref_image.SetDirection(direction_matrix)

        # 3. Resample dose onto the CT grid.
        resampler = sitk.ResampleImageFilter()
        resampler.SetReferenceImage(ref_image)  # Use CT grid as reference
        resampler.SetInterpolator(sitk.sitkLinear)
        resampler.SetDefaultPixelValue(0.0)  # Areas outside the dose grid will be 0

        aligned_dose_sitk = resampler.Execute(dose_sitk)

        # Back to a numpy array indexed [z, y, x].
        aligned_dose = sitk.GetArrayFromImage(aligned_dose_sitk)

        logger.info(f"‚úÖ Dose aligned successfully with SimpleITK: Original shape: {dose_array.shape}, Aligned shape: {aligned_dose.shape}")

        return aligned_dose

    except Exception as e:
        logger.error(f"‚ùå Error aligning dose with SimpleITK: {e}")
        logger.info("‚ú® Fallback to default alignment method.")
        return None

# ================ Basic Visualization Functions ================
def visualize_ct_slice(ct_array, slice_idx=None, save_path=None):
    """Plot one slice of a CT volume in grayscale.

    Args:
        ct_array: Volume indexed by slice along the first axis.
        slice_idx: Slice to draw; the middle slice when ``None``.
        save_path: When truthy, the figure is written there and closed;
            otherwise it is shown.
    """
    # Default to the central slice.
    slice_idx = ct_array.shape[0] // 2 if slice_idx is None else slice_idx

    plt.figure(figsize=(6, 6))
    plt.imshow(ct_array[slice_idx], cmap='gray')
    plt.title(f'CT Slice {slice_idx} üîç')
    plt.colorbar(label='HU')
    plt.axis('off')

    if not save_path:
        plt.show()
        return

    plt.savefig(save_path, dpi=150)
    plt.close()

def visualize_mask_slice(ct_array, mask_array, slice_idx=None, roi_name="", save_path=None):
    """Plot a mask as a semi-transparent overlay on the matching CT slice.

    Args:
        ct_array: CT volume indexed by slice along the first axis.
        mask_array: Mask volume with the same slice indexing.
        slice_idx: Slice to draw. When ``None``, the slice with the largest
            mask sum is used, or the middle slice if the mask is empty.
        roi_name: Name shown in the title.
        save_path: When truthy, the figure is written there and closed;
            otherwise it is shown.
    """
    if slice_idx is None:
        # Per-slice mask content decides which slice is most informative.
        per_slice_total = np.sum(mask_array, axis=(1, 2))
        has_content = np.max(per_slice_total) > 0
        slice_idx = np.argmax(per_slice_total) if has_content else mask_array.shape[0] // 2

    plt.figure(figsize=(8, 8))
    plt.imshow(ct_array[slice_idx], cmap='gray')
    plt.imshow(mask_array[slice_idx], cmap='hot', alpha=0.5)
    plt.title(f'Mask {roi_name} (Slice {slice_idx}) üé≠')
    plt.axis('off')

    if not save_path:
        plt.show()
        return

    plt.savefig(save_path, dpi=150)
    plt.close()

def visualize_dose_slice(ct_array, dose_array, slice_idx=None, save_path=None):
    """Plot a CT slice next to the same slice with the dose overlaid.

    Args:
        ct_array: CT volume indexed by slice along the first axis.
        dose_array: Dose volume with the same slice indexing.
        slice_idx: Slice to draw. When ``None``, the slice with the largest
            dose sum is used, or the middle slice if the dose is all zero.
        save_path: When truthy, the figure is written there and closed;
            otherwise it is shown.
    """
    if slice_idx is None:
        # Per-slice dose content decides which slice is most informative.
        per_slice_total = np.sum(dose_array, axis=(1, 2))
        has_dose = np.max(per_slice_total) > 0
        slice_idx = np.argmax(per_slice_total) if has_dose else dose_array.shape[0] // 2

    plt.figure(figsize=(10, 5))

    # Left panel: plain CT.
    plt.subplot(1, 2, 1)
    plt.imshow(ct_array[slice_idx], cmap='gray')
    plt.title(f'CT (Slice {slice_idx}) üñºÔ∏è')
    plt.axis('off')

    # Right panel: CT with the dose drawn on top; the colorbar refers to the
    # dose image, which is the last one drawn.
    plt.subplot(1, 2, 2)
    plt.imshow(ct_array[slice_idx], cmap='gray')
    plt.imshow(dose_array[slice_idx], cmap='jet', alpha=0.7)
    plt.title(f'Dose Overlay (Slice {slice_idx}) üíä')
    plt.colorbar(label='Dose (Gy)')
    plt.axis('off')

    if not save_path:
        plt.show()
        return

    plt.savefig(save_path, dpi=150)
    plt.close()

# ====================== Main Processing Functions ======================

def process_dicom_ct(ct_slices):
    """Stack CT slices into a float32 volume of shape (slices, Rows, Columns).

    A slice that carries both ``RescaleSlope`` and ``RescaleIntercept`` is
    converted with ``pixel * slope + intercept``; any other slice is stored
    as its raw pixel values. Rows/Columns are taken from the first slice.
    """
    reference = ct_slices[0]
    volume = np.zeros((len(ct_slices), reference.Rows, reference.Columns), dtype=np.float32)

    for position, dcm in enumerate(ct_slices):
        can_rescale = hasattr(dcm, 'RescaleSlope') and hasattr(dcm, 'RescaleIntercept')
        if not can_rescale:
            # No rescale information: keep stored values as they are.
            volume[position] = dcm.pixel_array.astype(np.float32)
            continue
        raw_pixels = dcm.pixel_array.astype(np.float32)
        volume[position] = raw_pixels * dcm.RescaleSlope + dcm.RescaleIntercept

    return volume

def process_dicom_and_masks(patient_folder, target_rois, output_dir):
    """Convert one patient's CT series and selected ROIs to ``.npy`` files.

    Under ``output_dir/<patient folder name>`` this writes
    ``available_rois.txt``, ``CT_raw.npy``, one ``Mask_<roi>.npy`` per ROI
    that has contour data, and PNG previews in ``visualizations/``.

    Args:
        patient_folder: Folder searched recursively for CT and RTSTRUCT
            files. A trailing path separator is tolerated.
        target_rois: ROI names to rasterize.
        output_dir: Root directory for all outputs; created if needed.

    Returns:
        None. Missing CT/RTSTRUCT data and all errors are logged; a failure
        in one ROI does not stop the remaining ROIs.
    """
    try:
        os.makedirs(output_dir, exist_ok=True)

        # Load CT slices - search recursively
        logger.info(f"üîç Searching for CT slices in {patient_folder}...")
        ct_slices = load_dicom_series(patient_folder, recursive=True)

        if not ct_slices:
            logger.error(f"‚ùå No CT slices found in {patient_folder}, skipping...")
            return

        # Find RTSTRUCT file - search recursively
        logger.info(f"üîç Searching for RTSTRUCT files in {patient_folder}...")
        rs_file = find_structure_file(patient_folder, recursive=True)

        if rs_file is None:
            logger.warning(f"‚ö†Ô∏è No RTSTRUCT file found in {patient_folder}, skipping...")
            return

        # List all available ROIs for debugging
        logger.info("üìã Available ROIs in the RTSTRUCT file:")
        roi_names = []
        for roi in rs_file.StructureSetROISequence:
            roi_names.append(roi.ROIName)
            logger.info(f"- {roi.ROIName} (Number: {roi.ROINumber})")

        # normpath drops a trailing separator; without it basename() would
        # return '' and the outputs would land directly in output_dir.
        patient_name = os.path.basename(os.path.normpath(patient_folder))
        patient_output_path = os.path.join(output_dir, patient_name)
        os.makedirs(patient_output_path, exist_ok=True)

        # Save available ROIs list. An explicit encoding keeps non-ASCII ROI
        # names from failing on platforms with a narrow default encoding.
        rois_listing = os.path.join(patient_output_path, "available_rois.txt")
        with open(rois_listing, "w", encoding="utf-8") as f:
            f.writelines(f"{name}\n" for name in roi_names)

        # Process CT slices to create volume
        ct_volume = process_dicom_ct(ct_slices)

        # Save CT volume (raw data)
        np.save(os.path.join(patient_output_path, "CT_raw.npy"), ct_volume)
        logger.info(f"üíæ Saved CT volume for {patient_name}, shape: {ct_volume.shape}")

        # Visualize a slice from CT
        visualizations_dir = os.path.join(patient_output_path, "visualizations")
        os.makedirs(visualizations_dir, exist_ok=True)
        visualize_ct_slice(ct_volume, save_path=os.path.join(visualizations_dir, "ct_slice.png"))

        # Process each target ROI independently
        for roi_name in target_rois:
            try:
                logger.info(f"‚≠ê Processing ROI: {roi_name}")

                contour_data = get_roi_contours(rs_file, roi_name)

                if not contour_data:
                    logger.warning(f"‚ö†Ô∏è No contour data found for ROI: {roi_name}, skipping...")
                    continue

                masks = create_masks(contour_data, ct_slices)

                np.save(os.path.join(patient_output_path, f"Mask_{roi_name}.npy"), masks)
                logger.info(f"üíæ Saved mask for ROI '{roi_name}' in {patient_name}, shape: {masks.shape}")

                vis_path = os.path.join(visualizations_dir, f"mask_{roi_name}.png")
                visualize_mask_slice(ct_volume, masks, roi_name=roi_name, save_path=vis_path)

            except Exception as e:
                logger.error(f"‚ùå Error processing ROI '{roi_name}' in {patient_folder}: {str(e)}")
                traceback.print_exc()

    except Exception as e:
        logger.error(f"‚ùå Error processing folder {patient_folder}: {str(e)}")
        traceback.print_exc()

def _append_dose_metadata_row(csv_path, row):
    """Append one row to the dose metadata CSV, writing the header first if the file is new."""
    write_header = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        if write_header:
            writer.writerow([
                "PatientID", "OriginalShape", "AlignedShape",
                "MinDose", "MaxDose", "MeanDose", "MedianDose", "AlignmentMethod"
            ])
        writer.writerow(row)


def _zoom_dose_to_ct_shape(dose_array, dose_coords, ct_coords):
    """Rescale the dose by the spacing ratio only (no origin shift) and crop/pad to the CT shape."""
    # Just for visualization - import zoom here to avoid dependency issues
    from scipy.ndimage import zoom

    # Spacing per array axis (z, y, x); PixelSpacing is (row, column).
    dose_spacing = np.array([
        abs(dose_coords['z_spacing']),
        dose_coords['pixel_spacing'][0],
        dose_coords['pixel_spacing'][1]
    ], dtype=np.float64)
    ct_spacing = np.array([
        ct_coords['z_spacing'],
        ct_coords['pixel_spacing'][0],
        ct_coords['pixel_spacing'][1]
    ], dtype=np.float64)

    # A coarser dose grid must be enlarged: factor = dose step / CT step.
    scaling_factors = dose_spacing / ct_spacing
    zoomed = zoom(dose_array, scaling_factors, order=1, mode='nearest')

    target_shape = ct_coords['shape']
    if zoomed.shape == target_shape:
        return zoomed

    # Copy the overlapping corner into a zero volume of the CT shape.
    fitted = np.zeros(target_shape, dtype=np.float32)
    nz, ny, nx = (min(have, want) for have, want in zip(zoomed.shape, target_shape))
    fitted[:nz, :ny, :nx] = zoomed[:nz, :ny, :nx]
    return fitted


def process_rt_dose(patient_folder, output_dir):
    """Process RT Dose file with proper alignment to the CT dataset.

    Writes ``Dose_raw.npy`` and a dose overlay PNG under
    ``output_dir/<patient folder name>`` and appends one row to
    ``output_dir/dose_metadata.csv``. When SimpleITK alignment was used, a
    comparison image against a simple zoom-based alignment is also
    attempted; its failure is logged and does not affect the result.

    Args:
        patient_folder: Folder searched for RTDOSE and CT files. A trailing
            path separator is tolerated.
        output_dir: Root directory for all outputs.

    Returns:
        ``True`` on success, ``False`` when alignment or saving raised, and
        ``None`` when the dose file, the CT series or the dose data could
        not be obtained.
    """
    # normpath drops a trailing separator; without it basename() would
    # return '' and the outputs would land directly in output_dir.
    patient_name = os.path.basename(os.path.normpath(patient_folder))
    patient_output_path = os.path.join(output_dir, patient_name)
    os.makedirs(patient_output_path, exist_ok=True)

    # Find RT Dose file
    logger.info(f"üîç Searching for RT Dose file in {patient_folder}...")
    dose_dcm = find_dose_file(patient_folder)

    if dose_dcm is None:
        logger.warning(f"‚ö†Ô∏è No RT Dose file found in {patient_folder}")
        return

    # Load CT metadata for alignment
    ct_slices = load_dicom_series(patient_folder)
    if not ct_slices:
        logger.error(f"‚ùå Cannot process dose without CT data for alignment")
        return

    # Reuse the saved CT volume when an earlier step already produced it
    ct_file = os.path.join(patient_output_path, "CT_raw.npy")
    if os.path.exists(ct_file):
        ct_volume = np.load(ct_file)
    else:
        ct_volume = process_dicom_ct(ct_slices)

    logger.info(f"‚≠ê Processing RT Dose file for {patient_name}")

    # Extract dose data
    dose_array, dose_meta = extract_dose_data(dose_dcm)
    if dose_array is None:
        logger.error("‚ùå Failed to extract dose data")
        return

    try:
        dose_coords = get_dose_coordinates(dose_dcm)
        ct_coords = get_ct_coordinates(ct_slices)

        # Prefer SimpleITK; it returns None when unavailable or failing
        aligned_dose = align_dose_to_ct_sitk(dose_array, dose_meta, ct_slices)

        if aligned_dose is not None:
            alignment_method = "SimpleITK"
            logger.info("‚ú® Used SimpleITK for high-precision dose alignment")
        else:
            aligned_dose = align_dose_to_ct(dose_array, dose_coords, ct_coords)
            alignment_method = "Custom"
            logger.info("‚ú® Used custom method for dose alignment")

        # Save dose stats to the shared CSV
        _append_dose_metadata_row(
            os.path.join(output_dir, "dose_metadata.csv"),
            [
                patient_name,
                str(dose_meta['shape']),
                str(aligned_dose.shape),
                f"{dose_meta['min_dose']:.6f}",
                f"{dose_meta['max_dose']:.6f}",
                f"{dose_meta['mean_dose']:.6f}",
                f"{dose_meta['median_dose']:.6f}",
                alignment_method
            ],
        )

        # Save the aligned dose data
        np.save(os.path.join(patient_output_path, "Dose_raw.npy"), aligned_dose)

        # Visualize dose
        visualizations_dir = os.path.join(patient_output_path, "visualizations")
        os.makedirs(visualizations_dir, exist_ok=True)
        visualize_dose_slice(
            ct_volume, aligned_dose,
            save_path=os.path.join(visualizations_dir, "dose_overlay.png"),
        )

        # Optional comparison against the simple zoom-based alignment
        if alignment_method == "SimpleITK":
            try:
                old_aligned_dose = _zoom_dose_to_ct_shape(dose_array, dose_coords, ct_coords)
                comp_vis_path = os.path.join(visualizations_dir, "dose_alignment_comparison.png")
                visualize_alignment_comparison(ct_volume, old_aligned_dose, aligned_dose, comp_vis_path)
            except Exception as e:
                # The comparison is a diagnostic extra: report and carry on.
                logger.warning(f"Skipping alignment comparison visualization: {e}")

        logger.info(f"üíæ Saved aligned dose for {patient_name}")
        return True

    except Exception as e:
        logger.error(f"‚ùå Error processing dose file: {str(e)}")
        traceback.print_exc()
        return False

def visualize_alignment_comparison(ct_volume, old_aligned_dose, new_aligned_dose, save_path):
    """Save a three-panel image comparing two dose alignments on one CT slice.

    The panels are: the CT slice alone, the CT slice with ``old_aligned_dose``
    overlaid, and the CT slice with ``new_aligned_dose`` overlaid (``jet``
    colormap, alpha 0.7). The slice shown is the one with the largest summed
    dose in ``new_aligned_dose``; when no slice has a positive dose sum the
    middle slice is used instead.

    Args:
        ct_volume: CT array, indexed by slice along its first axis.
        old_aligned_dose: Dose array from the old alignment method, indexed by
            slice along its first axis, or None.
        new_aligned_dose: 3-D dose array from the new alignment method, or None.
        save_path: Path of the image file to write. Its parent directory is
            created if it does not exist yet.

    Returns:
        None. If either dose array is None, nothing is drawn and a warning is
        logged instead.
    """
    # Guard clause: both dose volumes are required for a comparison.
    if old_aligned_dose is None or new_aligned_dose is None:
        logger.warning("‚ö†Ô∏è Cannot create alignment comparison - missing data")
        return

    # Pick the slice with the maximum total dose in the new alignment,
    # falling back to the central slice when the dose is empty.
    dose_sums = np.sum(new_aligned_dose, axis=(1, 2))
    if np.max(dose_sums) > 0:
        slice_idx = np.argmax(dose_sums)
    else:
        slice_idx = new_aligned_dose.shape[0] // 2

    # savefig does not create missing folders, so make sure the target exists.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    fig = plt.figure(figsize=(15, 5))
    try:
        # CT only
        plt.subplot(1, 3, 1)
        plt.imshow(ct_volume[slice_idx], cmap='gray')
        plt.title(f'CT (Slice {slice_idx}) üñºÔ∏è')
        plt.axis('off')

        # Old alignment
        plt.subplot(1, 3, 2)
        plt.imshow(ct_volume[slice_idx], cmap='gray')
        plt.imshow(old_aligned_dose[slice_idx], cmap='jet', alpha=0.7)
        plt.title('Old Alignment Method üîç')
        plt.axis('off')

        # New alignment
        plt.subplot(1, 3, 3)
        plt.imshow(ct_volume[slice_idx], cmap='gray')
        plt.imshow(new_aligned_dose[slice_idx], cmap='jet', alpha=0.7)
        plt.title('New Alignment Method üéØ')
        plt.axis('off')

        plt.tight_layout()
        plt.savefig(save_path, dpi=150)
    finally:
        # Always release the figure: callers may swallow drawing/saving
        # errors, and unclosed figures would otherwise accumulate.
        plt.close(fig)

    logger.info(f"üíæ Saved alignment comparison to {save_path}")


def combine_masks(output_path, target_rois):
    """Merge each patient's per-ROI masks into one binary mask.

    Every sub-directory of ``output_path`` is treated as one patient. For each
    name in ``target_rois`` the file ``Mask_<name>.npy`` is loaded if it
    exists; ROIs without a file are skipped. The merged mask is written to
    ``Combined_Mask.npy`` in the patient directory, and when ``CT_raw.npy`` is
    present ``visualize_mask_slice`` is called with the target path
    ``visualizations/combined_mask.png``.

    A voxel is 1 in the combined mask when it is non-zero in at least one ROI
    mask, and 0 otherwise. The union is accumulated as booleans rather than by
    in-place addition, so it cannot wrap around in small integer dtypes or
    fail when masks were saved with different dtypes. The saved array uses the
    dtype of the first mask that was loaded. Masks whose shape differs from
    the first loaded mask are skipped with a message instead of aborting the
    whole run.

    Args:
        output_path: Directory containing one sub-directory per patient.
        target_rois: Iterable of ROI names used to build the mask file names.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    for patient_id in os.listdir(output_path):
        patient_path = os.path.join(output_path, patient_id)
        if not os.path.isdir(patient_path):
            continue

        print(f"üîÑ Combining masks for Patient: {patient_id}")

        union = None       # boolean union of all masks seen so far
        out_dtype = None   # dtype of the first mask, reused for the output

        # Collect all masks
        for roi_name in target_rois:
            mask_path = os.path.join(patient_path, f"Mask_{roi_name}.npy")
            if not os.path.exists(mask_path):
                continue

            mask_array = np.load(mask_path)

            if union is None:
                union = mask_array != 0
                out_dtype = mask_array.dtype
            elif mask_array.shape != union.shape:
                print(
                    f"Skipping Mask_{roi_name}.npy for Patient: {patient_id} - "
                    f"shape {mask_array.shape} does not match {union.shape}"
                )
            else:
                union |= mask_array != 0

        if union is not None:
            # Values are 0 or 1 by construction
            combined_mask = union.astype(out_dtype)
            # Save combined mask
            combined_mask_path = os.path.join(patient_path, "Combined_Mask.npy")
            np.save(combined_mask_path, combined_mask)

            # Visualize combined mask
            ct_path = os.path.join(patient_path, "CT_raw.npy")
            if os.path.exists(ct_path):
                ct_array = np.load(ct_path)
                vis_path = os.path.join(patient_path, "visualizations", "combined_mask.png")
                visualize_mask_slice(ct_array, combined_mask, roi_name="Combined", save_path=vis_path)

            print(f"‚úÖ Combined mask saved for Patient: {patient_id}")
        else:
            print(f"‚ö†Ô∏è No masks to combine for Patient: {patient_id}")

def process_all_data(dicom_directory, output_directory, target_rois):
    """Run the CT/mask and dose export for one or many patients, then summarize.

    ``dicom_directory`` is handled as a multi-patient folder when it contains
    at least one sub-directory; each sub-directory is then processed as one
    patient. Otherwise the directory itself is processed as a single patient.
    For every patient ``process_dicom_and_masks`` and ``process_rt_dose`` are
    called, and ``combine_masks`` is run once over ``output_directory`` at the
    end. Finally the ``.npy`` files found under ``output_directory`` are
    counted and a summary is printed.

    Args:
        dicom_directory: Folder with the DICOM data (one patient, or one
            sub-folder per patient).
        output_directory: Folder receiving the exported files; created if
            missing.
        target_rois: ROI names forwarded to the mask-related steps.
    """
    os.makedirs(output_directory, exist_ok=True)

    # Decide between single-patient and multi-patient layout: any
    # sub-directory switches to the multi-patient mode.
    entries = os.listdir(dicom_directory)
    is_multi_patient = False
    for entry in entries:
        if os.path.isdir(os.path.join(dicom_directory, entry)):
            is_multi_patient = True
            break

    if is_multi_patient:
        for entry in entries:
            entry_path = os.path.join(dicom_directory, entry)
            if os.path.isdir(entry_path):
                logger.info(f"üë§ Processing patient folder: {entry}")
                # CT + masks first, then the dose
                process_dicom_and_masks(entry_path, target_rois, output_directory)
                process_rt_dose(entry_path, output_directory)
    else:
        patient_name = os.path.basename(dicom_directory)
        logger.info(f"üë§ Processing single patient folder: {patient_name}")
        # CT + masks first, then the dose
        process_dicom_and_masks(dicom_directory, target_rois, output_directory)
        process_rt_dose(dicom_directory, output_directory)

    # Masks are combined once, after every patient has been exported
    combine_masks(output_directory, target_rois)

    logger.info("üéâ All data processing completed successfully!")

    # Summary: every sub-directory of the output folder counts as a patient
    num_patients = sum(
        1
        for name in os.listdir(output_directory)
        if os.path.isdir(os.path.join(output_directory, name))
    )

    # Tally the .npy files by kind; exact names are looked up in the dict,
    # per-ROI masks are recognised by their "Mask_" prefix.
    named_counts = {"CT_raw.npy": 0, "Dose_raw.npy": 0, "Combined_Mask.npy": 0}
    total_files = 0
    mask_files = 0
    for _root, _dirs, filenames in os.walk(output_directory):
        for filename in filenames:
            if not filename.endswith('.npy'):
                continue
            total_files += 1
            if filename in named_counts:
                named_counts[filename] += 1
            elif filename.startswith("Mask_"):
                mask_files += 1

    ct_files = named_counts["CT_raw.npy"]
    dose_files = named_counts["Dose_raw.npy"]
    combined_masks = named_counts["Combined_Mask.npy"]

    print("\nüìä Processing Summary:")
    print(f"üë• Total patients processed: {num_patients}")
    print(f"üìÅ Total files created: {total_files}")
    print(f"üñºÔ∏è CT files: {ct_files}")
    print(f"üé≠ Mask files (all ROIs): {mask_files}")
    print(f"üíä Dose files: {dose_files}")
    print(f"üîÑ Combined Masks: {combined_masks}")

    print("\n‚ú® Data is ready for preprocessing step! ‚ú®")

# Input folder with the DICOM data and output folder for the exported arrays
dicom_directory = r"d:\ThesisMayn\Predict"
output_directory = r"d:\ThesisMayn\Data_numpy\v2\Predict"

# ROI names to export, in order: organs/target, then applicators 1-10 in both
# capitalizations, then needles 1-6.
# NOTE(review): the "needdle" spelling is kept exactly as in the original list;
# presumably it matches the ROI labels in the data - confirm before changing.
target_rois = (
    ['sigmoid_ct', 'bladder_ct', 'rectum_ct', 'bowel_ct', 'hrctv_ct']
    + [f"{prefix}{index}"
       for index in range(1, 11)
       for prefix in ('Applicator', 'applicator')]
    + [f"needdle{index}" for index in range(1, 7)]
)

# Process all data
process_all_data(dicom_directory, output_directory, target_rois)