In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
import os
from pathlib import Path

# ====================== Inverse Transform Functions ======================
def unpad_volume(padded_volume, original_shape):
    """Center-crop a padded volume back to ``original_shape``.

    Cropping is applied to the first three axes (z, y, x). When the amount to
    remove on an axis is odd, the extra voxel is taken from the end of that axis.

    Args:
        padded_volume: Array whose first three axes are (z, y, x).
        original_shape: Target (z, y, x) sizes; each must be <= the current size.

    Returns:
        The cropped array (a view of ``padded_volume``), or ``padded_volume``
        itself, after printing an error, when any target size exceeds the
        current size.
    """
    current_shape = padded_volume.shape
    excess = [current_shape[axis] - original_shape[axis] for axis in range(3)]

    # A negative excess would mean padding, not cropping -- refuse and pass through.
    if any(amount < 0 for amount in excess):
        print(f"‚ùå Cannot unpad: current shape {current_shape} is smaller than target {original_shape}")
        return padded_volume

    # Assumes padding was added symmetrically: start half of the excess in.
    window = tuple(
        slice(amount // 2, amount // 2 + original_shape[axis])
        for axis, amount in enumerate(excess)
    )
    cropped = padded_volume[window]

    print(f"üìè Unpadding from {current_shape} to {original_shape}")

    return cropped

def unresize_volume(resized_volume, original_shape, is_mask=False):
    """Resample a volume back to ``original_shape`` with ``scipy.ndimage.zoom``.

    Args:
        resized_volume: 3-D array (z, y, x) produced by the preprocessing resize.
        original_shape: Desired (z, y, x) output shape.
        is_mask: When True use nearest-neighbour interpolation (order 0) so label
            values are never blended; otherwise use linear interpolation (order 1).

    Returns:
        The resampled array.
    """
    current_shape = resized_volume.shape
    scale_factors = tuple(original_shape[axis] / current_shape[axis] for axis in range(3))

    print(f"üîÑ Resizing back from {current_shape} to {original_shape}")

    # order=0 keeps discrete labels intact; order=1 is linear for continuous data.
    spline_order = 0 if is_mask else 1
    return zoom(resized_volume, scale_factors, order=spline_order, mode='nearest')

def denormalize_dose(normalized_dose, normalization_stats):
    """Invert dose normalization (and optional gamma enhancement) on a prediction.

    Args:
        normalized_dose: Array of normalized dose values, nominally in [0, 1].
        normalization_stats: Dict saved at preprocessing time, or None. Keys read:
            'normalized' (bool), 'normalization_method' ('percentile' [default],
            'minmax' or 'fixed'), 'visual_enhancement' (bool), optional 'gamma'
            (default 0.7), 'normalization_factor' (percentile/fixed) or
            'min_orig'/'max_orig' (minmax).

    Returns:
        The denormalized array. The input is returned unchanged when the stats
        are missing or not flagged as normalized. For an unknown method only the
        gamma step (if enabled) is undone and an error message is printed.

    Raises:
        KeyError: if the scaling keys required by the chosen method are missing.
    """
    if not normalization_stats or not normalization_stats.get('normalized', False):
        print("‚ö†Ô∏è No normalization stats provided, returning as-is")
        return normalized_dose

    method = normalization_stats.get('normalization_method', 'percentile')

    # Reverse visual enhancement if applied. The forward transform is assumed to
    # be x**gamma, so the inverse raises to 1/gamma.
    if normalization_stats.get('visual_enhancement', False):
        gamma = normalization_stats.get('gamma', 0.7)
        # Model outputs can dip slightly below 0, and a fractional power of a
        # negative number is NaN -- clamp first so those voxels map to 0 dose.
        normalized_dose = np.power(np.maximum(normalized_dose, 0.0), 1.0 / gamma)
        print(f"üé® Reversed gamma correction (gamma={1.0/gamma:.2f})")

    # Reverse the linear normalization according to the recorded method.
    if method == 'percentile':
        factor = normalization_stats['normalization_factor']
        denormalized = normalized_dose * factor
        print(f"üíä Denormalized dose: [0,1] -> [0, {factor:.4f}] using percentile method")

    elif method == 'minmax':
        min_orig = normalization_stats['min_orig']
        max_orig = normalization_stats['max_orig']
        denormalized = normalized_dose * (max_orig - min_orig) + min_orig
        print(f"üíä Denormalized dose: [0,1] -> [{min_orig:.4f}, {max_orig:.4f}] using min-max")

    elif method == 'fixed':
        factor = normalization_stats['normalization_factor']
        denormalized = normalized_dose * factor
        print(f"üíä Denormalized dose: [0,1] -> [0, {factor:.4f}] using fixed method")

    else:
        print(f"‚ùå Unknown normalization method: {method}")
        denormalized = normalized_dose

    return denormalized

def unrescale_ct_from_range(rescaled_ct, min_hu=-1000, max_hu=1000, input_range=(0, 255)):
    """Linearly map CT intensities from ``input_range`` back to Hounsfield units.

    ``input_range[0]`` maps to ``min_hu`` and ``input_range[1]`` maps to
    ``max_hu``. Values outside the input range are extrapolated, not clipped.

    Returns:
        The volume in HU.
    """
    min_in, max_in = input_range

    # Relative position within the input range (0 at min_in, 1 at max_in).
    fraction = (rescaled_ct - min_in) / (max_in - min_in)
    ct_hu = fraction * (max_hu - min_hu) + min_hu

    print(f"üîÑ Unrescaled CT: [{min_in}, {max_in}] -> [{min_hu}, {max_hu}] HU")

    return ct_hu

# ====================== Complete Postprocessing Pipeline ======================
def postprocess_prediction(predicted_volume, original_metadata, volume_type='dose', 
                         output_path=None, save_npy=True, show_visualization=True):
    """Undo preprocessing on a model prediction: de-normalize, unpad, then un-resize.

    Args:
        predicted_volume: 3-D (z, y, x) array output by the model.
        original_metadata: Dict recorded at preprocessing time. Keys read (all
            optional): 'normalization_stats' (dose), 'rescaled', 'min_hu',
            'max_hu', 'rescale_range' (ct), 'original_shape_before_padding' and
            'original_shape'. A shape key whose value is None is treated as absent.
        volume_type: 'dose', 'ct' or 'mask'. Selects the intensity inversion and,
            for 'mask', nearest-neighbour resampling.
        output_path: Destination for the result as a float32 ``.npy`` file.
        save_npy: Save to ``output_path`` when both are set.
        show_visualization: Plot three orthogonal slices of the result.

    Returns:
        The postprocessed array. ``predicted_volume`` itself is not modified.
    """
    print(f"üîÑ Starting postprocessing for {volume_type}...")
    print(f"üìä Input shape: {predicted_volume.shape}")

    # Step 1: Undo intensity normalization on a copy, so the caller's array is untouched.
    processed_volume = predicted_volume.copy()

    if volume_type == 'dose' and 'normalization_stats' in original_metadata:
        processed_volume = denormalize_dose(processed_volume, original_metadata['normalization_stats'])
    elif volume_type == 'ct' and original_metadata.get('rescaled', False):
        min_hu = original_metadata.get('min_hu', -1000)
        max_hu = original_metadata.get('max_hu', 1000)
        input_range = original_metadata.get('rescale_range', (0, 255))
        processed_volume = unrescale_ct_from_range(processed_volume, min_hu, max_hu, input_range)

    # Step 2: Unpad. Metadata may carry None here to mean "no padding", so test
    # the value rather than only the key's presence.
    target_shape = original_metadata.get('original_shape_before_padding')
    if target_shape is not None:
        processed_volume = unpad_volume(processed_volume, target_shape)

    # Step 3: Resample back to the original dimensions.
    original_shape = original_metadata.get('original_shape')
    if original_shape is not None:
        is_mask = (volume_type == 'mask')
        processed_volume = unresize_volume(processed_volume, original_shape, is_mask)

    print(f"‚úÖ Final shape: {processed_volume.shape}")

    # Step 4: Save as .npy if requested
    if save_npy and output_path:
        np.save(output_path, processed_volume.astype(np.float32))
        print(f"üíæ Saved postprocessed volume to: {output_path}")

    # Step 5: Show visualization if requested
    if show_visualization:
        visualize_three_planes(processed_volume, volume_type, 
                             title=f"Postprocessed {volume_type.upper()}")

    return processed_volume

# ====================== Visualization Functions ======================
def visualize_three_planes(volume, volume_type='dose', slice_indices=None, title="Medical Volume"):
    """Plot axial, coronal and sagittal slices of a 3-D volume and print its statistics.

    Args:
        volume: 3-D array indexed (z, y, x).
        volume_type: 'dose' (hot colormap, window [0, 95th percentile]),
            'ct' (gray, 1st-99th percentile), 'mask' (gray, [0, 1]) or anything
            else (viridis, full range).
        slice_indices: Optional dict with any of 'axial', 'coronal', 'sagittal'
            (indices along z, y and x). Missing keys default to the middle slice.
        title: Figure title; also heads the printed statistics.

    Returns:
        None. Prints an error and returns early if ``volume`` is not 3-D.
    """
    if volume is None or len(volume.shape) != 3:
        print("‚ùå Invalid volume for visualization")
        return

    z_size, y_size, x_size = volume.shape

    # Middle slices by default; caller-supplied entries override individually,
    # so a partial dict no longer raises KeyError.
    indices = {'axial': z_size // 2, 'coronal': y_size // 2, 'sagittal': x_size // 2}
    if slice_indices:
        indices.update(slice_indices)

    # Choose colormap and display window based on volume type
    if volume_type == 'dose':
        cmap = 'hot'
        vmin, vmax = 0, np.percentile(volume, 95)  # Avoid outliers
    elif volume_type == 'ct':
        cmap = 'gray'
        vmin, vmax = np.percentile(volume, 1), np.percentile(volume, 99)
    elif volume_type == 'mask':
        cmap = 'gray'
        vmin, vmax = 0, 1
    else:
        cmap = 'viridis'
        vmin, vmax = np.min(volume), np.max(volume)

    # (2-D slice, panel title, x label, y label). Slices are shown untransposed:
    # imshow puts array rows on the vertical axis and columns on the horizontal
    # axis, so volume[:, y, :] is Z-vertical / X-horizontal exactly as labelled.
    # Transposing would swap the axes and contradict the labels.
    panels = [
        (volume[indices['axial'], :, :], f'Axial (z={indices["axial"]})', 'X', 'Y'),
        (volume[:, indices['coronal'], :], f'Coronal (y={indices["coronal"]})', 'X', 'Z'),
        (volume[:, :, indices['sagittal']], f'Sagittal (x={indices["sagittal"]})', 'Y', 'Z'),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(title, fontsize=16, fontweight='bold')

    for ax, (plane, plane_title, xlabel, ylabel) in zip(axes, panels):
        # aspect='auto' fills the panel: slice count usually differs greatly from
        # the in-plane size, which would otherwise render as a thin strip.
        image = ax.imshow(plane, cmap=cmap, vmin=vmin, vmax=vmax, origin='lower', aspect='auto')
        ax.set_title(plane_title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.axis('on')
        fig.colorbar(image, ax=ax, shrink=0.8)

    plt.tight_layout()
    plt.show()

    # Print some statistics
    print(f"üìä {title} Statistics:")
    print(f"   Shape: {volume.shape}")
    print(f"   Range: [{np.min(volume):.4f}, {np.max(volume):.4f}]")
    print(f"   Mean: {np.mean(volume):.4f}")
    print(f"   Std: {np.std(volume):.4f}")

def compare_original_vs_processed(original_volume, processed_volume, volume_type='dose'):
    """
    Side-by-side comparison of original vs processed volume.

    The top row shows the middle axial/coronal/sagittal slices of
    ``original_volume``; the bottom row shows the same for ``processed_volume``.
    Each volume is sliced at its own centre, so the two may have different
    shapes. Each row uses its own display window ([0, 95th percentile] for
    'dose', 1st-99th percentile otherwise).

    Args:
        original_volume: 3-D array indexed (z, y, x).
        processed_volume: 3-D array indexed (z, y, x).
        volume_type: 'dose' uses the hot colormap; anything else uses gray.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle(f'Original vs Processed {volume_type.upper()}', fontsize=16, fontweight='bold')

    # Colormap settings
    if volume_type == 'dose':
        cmap = 'hot'
        vmin_orig = 0
        vmax_orig = np.percentile(original_volume, 95)
        vmin_proc = 0
        vmax_proc = np.percentile(processed_volume, 95)
    else:
        cmap = 'gray'
        vmin_orig = np.percentile(original_volume, 1)
        vmax_orig = np.percentile(original_volume, 99)
        vmin_proc = np.percentile(processed_volume, 1)
        vmax_proc = np.percentile(processed_volume, 99)

    def _middle_planes(volume):
        # Slices at this volume's own centre; using another volume's centre can
        # index out of bounds when the shapes differ.
        z_mid, y_mid, x_mid = (size // 2 for size in volume.shape)
        return [volume[z_mid, :, :], volume[:, y_mid, :], volume[:, :, x_mid]]

    # (view name, x label, y label); slices are untransposed so rows are vertical.
    views = [('Axial', 'X', 'Y'), ('Coronal', 'X', 'Z'), ('Sagittal', 'Y', 'Z')]
    rows = [
        ('Original', _middle_planes(original_volume), vmin_orig, vmax_orig),
        ('Processed', _middle_planes(processed_volume), vmin_proc, vmax_proc),
    ]

    for row, (label, planes, vmin, vmax) in enumerate(rows):
        for col, ((view, xlabel, ylabel), plane) in enumerate(zip(views, planes)):
            ax = axes[row, col]
            ax.imshow(plane, cmap=cmap, vmin=vmin, vmax=vmax, origin='lower', aspect='auto')
            ax.set_title(f'{label} {view}')
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.axis('on')

    plt.tight_layout()
    plt.show()



# ====================== Batch Processing Function ======================
def postprocess_batch(predictions_dir, metadata_dir, output_dir, volume_type='dose'):
    """Postprocess every ``.npy`` prediction in a directory.

    For each ``<name>.npy`` in ``predictions_dir``, metadata is read from
    ``<metadata_dir>/<name>_metadata.json``. If that file is missing, unreadable
    or not valid JSON, ``create_sample_metadata()`` is used instead, with a
    warning. Results are written to ``<output_dir>/postprocessed_<name>.npy``;
    no figures are shown.

    Args:
        predictions_dir: Folder containing the prediction ``.npy`` files.
        metadata_dir: Folder containing the matching ``*_metadata.json`` files.
        output_dir: Destination folder, created if it does not exist.
        volume_type: Passed to ``postprocess_prediction`` ('dose', 'ct', 'mask').
    """
    import json

    os.makedirs(output_dir, exist_ok=True)

    prediction_files = sorted(f for f in os.listdir(predictions_dir) if f.endswith('.npy'))

    print(f"üîÑ Processing {len(prediction_files)} predictions...")

    for pred_file in prediction_files:
        print(f"\nüìÅ Processing: {pred_file}")

        prediction = np.load(os.path.join(predictions_dir, pred_file))

        # Swap only the trailing extension; str.replace would also rewrite any
        # '.npy' that appears earlier in the file name.
        stem, _ = os.path.splitext(pred_file)
        meta_path = os.path.join(metadata_dir, f"{stem}_metadata.json")

        try:
            with open(meta_path, 'r') as f:
                metadata = json.load(f)
        except (OSError, ValueError):
            # Missing/unreadable file or malformed JSON (json.JSONDecodeError and
            # UnicodeDecodeError are ValueError subclasses). Anything else, e.g.
            # KeyboardInterrupt, now propagates instead of being swallowed.
            print(f"‚ö†Ô∏è Could not load metadata for {pred_file}, using default")
            metadata = create_sample_metadata()

        output_path = os.path.join(output_dir, f"postprocessed_{pred_file}")

        postprocess_prediction(
            predicted_volume=prediction,
            original_metadata=metadata,
            volume_type=volume_type,
            output_path=output_path,
            save_npy=True,
            show_visualization=False  # Don't show vis for batch processing
        )

        print(f"‚úÖ Saved: {output_path}")

    print(f"\nüéâ Batch postprocessing complete! Results saved to: {output_dir}")

# ====================== DICOM Reference Functions ======================

def read_dicom_metadata(dicom_folder_path):
    """Read geometry metadata from the ``*.dcm`` files in a folder.

    Only the first file (in filename order) is opened; the slice count is the
    number of ``.dcm`` files found.

    Args:
        dicom_folder_path: Folder containing the DICOM series.

    Returns:
        Dict with 'original_shape' (num_slices, rows, columns), 'pixel_spacing'
        [row, col], 'slice_thickness', 'num_slices', 'matrix_size',
        'dicom_folder' and the sorted 'dicom_files' list. Returns None if pydicom
        is not installed or no ``.dcm`` files are found.
    """
    try:
        import pydicom
        import glob
    except ImportError:
        print("‚ùå pydicom not installed. Install with: pip install pydicom")
        return None

    # Find all DICOM files
    dcm_files = glob.glob(os.path.join(dicom_folder_path, "*.dcm"))
    if not dcm_files:
        print(f"‚ùå No DICOM files found in {dicom_folder_path}")
        return None

    # NOTE(review): lexicographic filename order, not anatomical order
    # ("IM10" sorts before "IM2"). Only the first file's header is read here.
    dcm_files.sort()
    print(f"üìÅ Found {len(dcm_files)} DICOM files")

    # Read first file to get metadata
    first_dcm = pydicom.dcmread(dcm_files[0])

    # Get image dimensions
    rows = int(first_dcm.Rows)
    columns = int(first_dcm.Columns)
    num_slices = len(dcm_files)

    # PixelSpacing and SliceThickness may be absent or empty. Read them
    # independently so a missing one does not discard the other.
    try:
        pixel_spacing = [float(first_dcm.PixelSpacing[0]), float(first_dcm.PixelSpacing[1])]  # [row, col]
    except (AttributeError, IndexError, TypeError, ValueError):
        pixel_spacing = None
    try:
        slice_thickness = float(first_dcm.SliceThickness)
    except (AttributeError, TypeError, ValueError):
        slice_thickness = None

    if pixel_spacing is None or slice_thickness is None:
        print("‚ö†Ô∏è Could not read spacing information, using defaults")
        if pixel_spacing is None:
            pixel_spacing = [1.0, 1.0]
        if slice_thickness is None:
            slice_thickness = 1.0

    # Create metadata dictionary
    metadata = {
        'original_shape': (num_slices, rows, columns),  # (z, y, x)
        'pixel_spacing': pixel_spacing,  # [y, x]
        'slice_thickness': slice_thickness,  # z
        'num_slices': num_slices,
        'matrix_size': (rows, columns),
        'dicom_folder': dicom_folder_path,
        'dicom_files': dcm_files
    }

    print(f"üìä DICOM Metadata:")
    print(f"   Original shape: {metadata['original_shape']}")
    print(f"   Pixel spacing: {metadata['pixel_spacing']} mm")
    print(f"   Slice thickness: {metadata['slice_thickness']} mm")

    return metadata

def create_metadata_from_dicom(dicom_folder_path, preprocessing_params=None):
    """Build the postprocessing metadata dict for a DICOM series.

    Combines the geometry read from the DICOM folder with the preprocessing
    parameters, and derives the intermediate shapes needed to invert the
    preprocessing:

    * 'original_shape_before_padding' is (original z, target y, target x), the
      shape after the XY resize and before any Z padding.
    * 'preprocessed_shape' is the shape after optional Z padding/cropping to
      'target_slices'.

    Args:
        dicom_folder_path: Folder with the original ``.dcm`` files.
        preprocessing_params: Dict of preprocessing settings. It must contain
            'target_size' when given; other keys fall back to per-key defaults.
            When None, a full default set is used.

    Returns:
        The metadata dict with 'normalization_stats' set to None, or None if the
        DICOM metadata could not be read.
    """
    dicom_meta = read_dicom_metadata(dicom_folder_path)
    if dicom_meta is None:
        return None

    params = preprocessing_params
    if params is None:
        params = dict(
            target_size=(256, 256),      # Resize target
            target_slices=None,          # No slice padding/cropping
            normalization_method='percentile',
            percentile=95,
            rescale_ct=True,
            min_hu=-1000,
            max_hu=1000,
            rescale_range=(0, 255),
        )

    original_shape = dicom_meta['original_shape']
    target_size = params['target_size']
    target_slices = params.get('target_slices')

    complete_metadata = dict(
        # Original DICOM info
        original_shape=original_shape,
        pixel_spacing=dicom_meta['pixel_spacing'],
        slice_thickness=dicom_meta['slice_thickness'],
        dicom_folder=dicom_meta['dicom_folder'],
        # Preprocessing parameters
        target_size=params['target_size'],
        target_slices=params.get('target_slices'),
        # CT rescaling
        rescaled=params.get('rescale_ct', False),
        min_hu=params.get('min_hu', -1000),
        max_hu=params.get('max_hu', 1000),
        rescale_range=params.get('rescale_range', (0, 255)),
        # Dose normalization is left for the caller to fill in
        normalization_stats=None,
    )

    # XY resized, Z untouched: the shape just before any Z padding.
    resized_shape = (original_shape[0], target_size[0], target_size[1])
    complete_metadata['original_shape_before_padding'] = resized_shape

    # A truthy target_slices means Z was padded/cropped to that many slices.
    final_shape = (target_slices, target_size[0], target_size[1]) if target_slices else resized_shape
    complete_metadata['preprocessed_shape'] = final_shape

    print(f"‚úÖ Created complete metadata:")
    print(f"   Original: {original_shape}")
    print(f"   After resize: {resized_shape}")
    print(f"   Final preprocessed: {final_shape}")

    return complete_metadata

def postprocess_with_dicom_reference(predicted_volume, dicom_folder_path, 
                                   volume_type='dose', preprocessing_params=None,
                                   output_path=None, save_npy=True, show_visualization=True,
                                   normalization_stats=None):
    """Postprocess a prediction using geometry read from the original DICOM series.

    Args:
        predicted_volume: Model output, 3-D (z, y, x).
        dicom_folder_path: Folder with the original ``.dcm`` files.
        volume_type: 'dose', 'ct' or 'mask'.
        preprocessing_params: Forwarded to ``create_metadata_from_dicom``.
        output_path, save_npy, show_visualization: Forwarded to
            ``postprocess_prediction``.
        normalization_stats: Dose normalization stats saved during preprocessing.
            If omitted for a dose volume, hard-coded placeholder stats are used
            and a warning is printed. The resulting dose values are then NOT
            calibrated to the real data.

    Returns:
        The postprocessed array, or None if the DICOM metadata could not be read.
    """
    print(f"üîÑ Postprocessing with DICOM reference: {dicom_folder_path}")

    # Create metadata from DICOM
    metadata = create_metadata_from_dicom(dicom_folder_path, preprocessing_params)
    if metadata is None:
        print("‚ùå Failed to create metadata from DICOM")
        return None

    if volume_type == 'dose':
        if normalization_stats is None:
            # Placeholder values -- make their use visible instead of silent.
            print("WARNING: no normalization_stats given; using placeholder dose stats "
                  "(normalization_factor=7.5). Pass the stats saved during preprocessing "
                  "for calibrated dose values.")
            normalization_stats = {
                'normalized': True,
                'normalization_method': 'percentile',
                'normalization_factor': 7.5,
                'percentile_used': 95,
                'visual_enhancement': True,
                'min_orig': 0.0,
                'max_orig': 8.2,
                'mean_orig': 1.5
            }
        metadata['normalization_stats'] = normalization_stats

    # Use the standard postprocessing pipeline
    return postprocess_prediction(
        predicted_volume=predicted_volume,
        original_metadata=metadata,
        volume_type=volume_type,
        output_path=output_path,
        save_npy=save_npy,
        show_visualization=show_visualization
    )

def load_and_compare_with_dicom(predicted_volume, dicom_folder_path, 
                               volume_type='dose', preprocessing_params=None):
    """Postprocess a prediction and compare it against the original DICOM series.

    For 'ct', the original series is loaded, converted to HU and shown next to
    the postprocessed prediction. For other volume types, only the postprocessed
    result is visualized.

    Args:
        predicted_volume: Model output, 3-D (z, y, x).
        dicom_folder_path: Folder with the original ``.dcm`` files.
        volume_type: 'dose', 'ct' or 'mask'.
        preprocessing_params: Forwarded to ``postprocess_with_dicom_reference``.

    Returns:
        (original_volume, postprocessed). original_volume is None unless
        volume_type is 'ct'. Returns (None, None) if pydicom is missing or
        postprocessing failed.
    """
    try:
        import pydicom
        import glob
    except ImportError:
        print("‚ùå pydicom not installed. Install with: pip install pydicom")
        return None, None

    print(f"üìä Loading and comparing with DICOM reference...")

    # Postprocess prediction
    postprocessed = postprocess_with_dicom_reference(
        predicted_volume=predicted_volume,
        dicom_folder_path=dicom_folder_path,
        volume_type=volume_type,
        preprocessing_params=preprocessing_params,
        show_visualization=False  # We'll do custom comparison
    )

    if postprocessed is None:
        return None, None

    if volume_type != 'ct':
        # For dose/mask, just show the postprocessed result
        visualize_three_planes(postprocessed, volume_type, 
                             title=f"Postprocessed {volume_type.upper()}")
        return None, postprocessed

    # Load the original CT series, converting each slice to HU with its own
    # RescaleSlope/RescaleIntercept (per-image tags; identity when absent).
    # Each file is read exactly once.
    dcm_files = sorted(glob.glob(os.path.join(dicom_folder_path, "*.dcm")))
    hu_slices = []
    for dcm_file in dcm_files:
        dcm = pydicom.dcmread(dcm_file)
        slope = float(getattr(dcm, 'RescaleSlope', 1.0))
        intercept = float(getattr(dcm, 'RescaleIntercept', 0.0))
        hu_slices.append(dcm.pixel_array * slope + intercept)
    original_volume = np.array(hu_slices)

    print(f"üìä Loaded original DICOM: {original_volume.shape}")

    compare_original_vs_processed(original_volume, postprocessed, volume_type)

    return original_volume, postprocessed

# ====================== Jupyter Notebook Usage Examples ======================
def jupyter_example_with_dicom_reference():
    """
    Example using DICOM reference - most accurate method.

    Loads a dose prediction from a placeholder path and postprocesses it against
    the original DICOM series, saving ``dicom_referenced_result.npy``. Edit the
    placeholder paths before running.

    Returns:
        The postprocessed volume (or None if the DICOM metadata is unreadable).
    """
    print("üè• DICOM Reference Example")
    print("=" * 40)

    # Placeholder locations -- replace with real paths.
    dicom_folder = "/path/to/original/dicom/folder"
    prediction_file = "/path/to/model/prediction.npy"

    # The preprocessing parameters that were actually used.
    preprocessing_params = dict(
        target_size=(256, 256),
        target_slices=64,  # or None if no Z padding was used
        rescale_ct=True,
        min_hu=-1000,
        max_hu=1000,
        rescale_range=(0, 255),
        normalization_method='percentile',
        percentile=95,
    )

    return postprocess_with_dicom_reference(
        predicted_volume=np.load(prediction_file),
        dicom_folder_path=dicom_folder,
        volume_type='dose',  # or 'ct', 'mask'
        preprocessing_params=preprocessing_params,
        output_path="dicom_referenced_result.npy",
        save_npy=True,
        show_visualization=True,
    )

def jupyter_example_quick_dicom_check(dicom_folder):
    """
    Quick check of DICOM metadata.

    Returns:
        The dict from ``read_dicom_metadata`` on success, otherwise None.
    """
    print("üîç Quick DICOM Check")
    print("=" * 30)

    metadata = read_dicom_metadata(dicom_folder)
    if not metadata:
        print("‚ùå Failed to load DICOM metadata")
        return None

    print("‚úÖ DICOM metadata loaded successfully!")
    return metadata

def jupyter_example_with_real_data(prediction_file, metadata_dict=None):
    """Postprocess a saved dose prediction and write ``postprocessed_result.npy``.

    Args:
        prediction_file: Path to the prediction ``.npy`` file.
        metadata_dict: Preprocessing metadata. Falls back, with a warning, to
            ``create_sample_metadata()`` when None.

    Returns:
        The postprocessed volume.
    """
    print(f"üìÅ Loading prediction from: {prediction_file}")
    prediction = np.load(prediction_file)
    print(f"üìä Loaded prediction shape: {prediction.shape}")

    if metadata_dict is None:
        print("‚ö†Ô∏è No metadata provided, using sample metadata")
        metadata_dict = create_sample_metadata()

    return postprocess_prediction(
        predicted_volume=prediction,
        original_metadata=metadata_dict,
        volume_type='dose',  # change to 'ct' or 'mask' as needed
        output_path="postprocessed_result.npy",
        save_npy=True,
        show_visualization=True,
    )

def jupyter_quick_visualize(volume_path, volume_type='dose'):
    """Load a ``.npy`` volume, plot its three orthogonal planes, and return it."""
    print(f"üëÅÔ∏è Quick visualization of: {volume_path}")

    loaded = np.load(volume_path)
    heading = f"{volume_type.upper()} Volume"
    visualize_three_planes(loaded, volume_type, title=heading)

    return loaded

def check_padding_status(original_shape, prediction_shape):
    """Report how a prediction's shape differs from the original volume's.

    Prints whether the in-plane (y, x) size changed, and whether the slice axis
    (z) was padded or cropped.

    Args:
        original_shape: (z, y, x) of the original volume.
        prediction_shape: (z, y, x) of the model prediction.

    Returns:
        None when the z size is unchanged. Otherwise the prediction's z size,
        i.e. the value to use as 'target_slices'.
    """
    orig_z, orig_y, orig_x = original_shape
    pred_z, pred_y, pred_x = prediction_shape

    print(f"Original shape: {original_shape}")
    print(f"Prediction shape: {prediction_shape}")

    # Check the XY dimensions (these are usually resized).
    if (orig_y, orig_x) != (pred_y, pred_x):
        print(f"‚úÖ XY was resized: ({orig_y}, {orig_x}) -> ({pred_y}, {pred_x})")

    # Check the Z dimension: unchanged, padded or cropped.
    if pred_z == orig_z:
        print("‚úÖ Z dimension unchanged - no padding")
        return None

    if pred_z > orig_z:
        print(f"‚úÖ Z was padded: {orig_z} -> {pred_z} (added {pred_z - orig_z} slices)")
    else:
        print(f"‚úÖ Z was cropped: {orig_z} -> {pred_z} (removed {orig_z - pred_z} slices)")
    return pred_z

def auto_detect_preprocessing_params(original_shape, prediction_shape, 
                                   min_hu=-1000, max_hu=1000):
    """Guess the preprocessing parameters from the original and predicted shapes.

    The in-plane size is taken from the prediction. 'target_slices' is the
    prediction's z size when it differs from the original, else None. CT
    rescaling to (0, 255) and 95th-percentile dose normalization are assumed,
    not detected.

    Returns:
        The parameter dict (also printed).
    """
    orig_z, _, _ = original_shape
    pred_z, pred_y, pred_x = prediction_shape

    params = dict(
        target_size=(pred_y, pred_x),                         # in-plane size taken from the prediction
        target_slices=None if pred_z == orig_z else pred_z,
        rescale_ct=True,                                      # assumed: CT was rescaled
        min_hu=min_hu,
        max_hu=max_hu,
        rescale_range=(0, 255),
        normalization_method='percentile',
        percentile=95,
    )

    print("üîç Auto-detected preprocessing parameters:")
    for name, setting in params.items():
        print(f"   {name}: {setting}")

    return params

def create_manual_metadata(original_shape, prediction_shape, volume_type='dose'):
    """Build postprocessing metadata by hand when no DICOM series is available.

    Args:
        original_shape: (z, y, x) shape of the original volume.
        prediction_shape: (z, y, x) shape of the model prediction.
        volume_type: If 'dose', placeholder normalization stats are attached.

    Returns:
        Metadata dict for ``postprocess_prediction``.
        'original_shape_before_padding' is always (original z, prediction y,
        prediction x), the shape after the XY resize but before any Z padding.
        When Z was not padded this equals the prediction shape, so unpadding is
        a no-op.
    """
    orig_z, orig_y, orig_x = original_shape
    pred_z, pred_y, pred_x = prediction_shape

    metadata = {
        'original_shape': original_shape,
        'preprocessed_shape': prediction_shape,
        'target_size': (pred_y, pred_x),
        'target_slices': pred_z if pred_z != orig_z else None,
        'rescaled': True,
        'min_hu': -1000,
        'max_hu': 1000,
        'rescale_range': (0, 255)
    }

    # For dose: placeholder normalization stats.
    if volume_type == 'dose':
        metadata['normalization_stats'] = {
            'normalized': True,
            'normalization_method': 'percentile',
            'normalization_factor': 7.5,  # placeholder -- adjust to the real value
            'percentile_used': 95,
            'visual_enhancement': True,
            'min_orig': 0.0,
            'max_orig': 8.2,
            'mean_orig': 1.5
        }

    # Intermediate shape (after XY resize, before Z padding). Always store a
    # real shape, never None, so consumers can use the value without a None check.
    metadata['original_shape_before_padding'] = (orig_z, pred_y, pred_x)

    return metadata

def quick_postprocess(prediction_file, dicom_folder, volume_type='dose'):
    """One-call postprocessing: load a prediction, infer its preprocessing, invert it.

    Preprocessing parameters are guessed from the original DICOM shape and the
    prediction shape via ``auto_detect_preprocessing_params``. The result is
    saved as ``quick_result_<volume_type>.npy`` and visualized.

    Args:
        prediction_file: Path to the prediction ``.npy`` file.
        dicom_folder: Folder with the original ``.dcm`` files.
        volume_type: 'dose', 'ct' or 'mask'.

    Returns:
        The postprocessed array, or None if the DICOM metadata could not be read.
    """
    # Load the data
    prediction = np.load(prediction_file)
    metadata = read_dicom_metadata(dicom_folder)
    if metadata is None:
        # read_dicom_metadata has already printed the reason (missing pydicom or
        # no .dcm files); subscripting None below would raise TypeError.
        return None

    # Auto-detect parameters
    params = auto_detect_preprocessing_params(
        metadata['original_shape'], 
        prediction.shape
    )

    # Postprocess
    return postprocess_with_dicom_reference(
        predicted_volume=prediction,
        dicom_folder_path=dicom_folder,
        volume_type=volume_type,
        preprocessing_params=params,
        output_path=f"quick_result_{volume_type}.npy",
        save_npy=True,
        show_visualization=True
    )

In [5]:
# ====================== Example Usage Functions ======================
def create_sample_metadata():
    """
    Create sample metadata that would typically be saved during preprocessing.

    The shapes agree with each other: XY was resized 512 -> 256 ('target_size')
    and Z stayed at 64 ('target_slices' equals the original slice count, so Z
    padding is a no-op). 'original_shape_before_padding' is therefore
    (original z, *target_size).

    Returns:
        A new dict on every call.
    """
    return {
        'original_shape': (64, 512, 512),  # Original CT/dose shape
        # Shape after XY resizing but before Z padding. It must equal
        # (original z, *target_size), or unpadding the preprocessed volume fails.
        'original_shape_before_padding': (64, 256, 256),
        'preprocessed_shape': (64, 256, 256),  # Final preprocessed shape
        'target_size': (256, 256),  # XY resize target
        'target_slices': 64,  # Z padding/cropping target
        'rescaled': True,  # Whether CT was rescaled
        'min_hu': -1000,
        'max_hu': 1000,
        'rescale_range': (0, 255),
        'normalization_stats': {  # For dose normalization
            'normalized': True,
            'normalization_method': 'percentile',
            'normalization_factor': 7.0,
            'percentile_used': 95,
            'visual_enhancement': True,
            'min_orig': 0.0,
            'max_orig': 8.2,
            'mean_orig': 1.5
        }
    }

def example_usage(seed=0):
    """
    Example of how to use the postprocessing functions.

    Runs the full pipeline on a synthetic dose prediction with sample metadata,
    saves ``postprocessed_dose.npy`` and shows a visualization.

    Args:
        seed: Seed for the synthetic prediction, so repeated runs (and Restart &
            Run All) produce identical output.

    Returns:
        The postprocessed dose volume.
    """
    print("üöÄ Example Usage of Postprocessing Pipeline")
    print("="*50)

    # Simulated normalized prediction in [0, 0.8), already at the preprocessed
    # (resized + padded) shape. A seeded generator keeps the example reproducible.
    rng = np.random.default_rng(seed)
    predicted_dose = rng.random((64, 256, 256)) * 0.8

    # Create sample metadata (this would come from your preprocessing)
    metadata = create_sample_metadata()

    # Postprocess the prediction
    output_path = "postprocessed_dose.npy"

    postprocessed_dose = postprocess_prediction(
        predicted_volume=predicted_dose,
        original_metadata=metadata,
        volume_type='dose',
        output_path=output_path,
        save_npy=True,
        show_visualization=True
    )

    print(f"\n‚úÖ Postprocessing complete!")
    print(f"üìä Original prediction shape: {predicted_dose.shape}")
    print(f"üìä Final postprocessed shape: {postprocessed_dose.shape}")

    return postprocessed_dose

In [None]:
# ============ Postprocess with reference to the original DICOM series ============
# Cells run top to bottom: every name is defined before it is used, so this
# cell works on a fresh kernel.

# 1. Point at the original DICOM series and check its metadata first.
DICOM_FOLDER = "/path/to/original/dicom/folder"
metadata = read_dicom_metadata(DICOM_FOLDER)
assert metadata is not None, f"Could not read DICOM metadata from {DICOM_FOLDER}"
print(f"Original DICOM shape: {metadata['original_shape']}")

# 2. Load the prediction. Everything below reads prediction.shape.
prediction = np.load("your_model_output.npy")
print(f"Prediction shape: {prediction.shape}")

# 3. Use the shape check to find out whether Z was padded or cropped.
target_slices = check_padding_status(metadata['original_shape'], prediction.shape)

# 4a. Manually specified preprocessing parameters (edit to match your preprocessing).
preprocessing_params = {
    'target_size': (256, 256),       # in-plane size the slices were resized to
    'target_slices': target_slices,  # from the shape check; None = Z not padded
    'rescale_ct': True,              # whether CT was rescaled
    'min_hu': -1000,                 # HU window used
    'max_hu': 1000,
    'normalization_method': 'percentile',
    'percentile': 95
}

# 4b. ...or derive them automatically from the two shapes (used below).
auto_params = auto_detect_preprocessing_params(
    metadata['original_shape'], 
    prediction.shape
)

# 5. Postprocess with reference to the original DICOM.
result = postprocess_with_dicom_reference(
    predicted_volume=prediction,
    dicom_folder_path=DICOM_FOLDER,
    volume_type='dose',
    preprocessing_params=auto_params,
    output_path="result_no_padding.npy",
    save_npy=True,
    show_visualization=True
)

In [None]:
# ============ Usage with manual metadata (when no DICOM is available) ============
# Requires `prediction` and `auto_params` from the previous cell.
# Defined here so the comparison below does not depend on leftover kernel state.
volume_type = 'dose'  # 'dose', 'ct' or 'mask'

manual_meta = create_manual_metadata(
    original_shape=(64, 512, 512),  # specify the original shape
    prediction_shape=prediction.shape,
    volume_type=volume_type
)

result_manual = postprocess_prediction(
    predicted_volume=prediction,
    original_metadata=manual_meta,
    volume_type=volume_type,
    output_path="result_manual.npy",
    save_npy=True,
    show_visualization=True
)

# ============ Compare results ============
# load_and_compare_with_dicom only loads an original volume to compare against for CT.
if volume_type == 'ct':
    original, processed = load_and_compare_with_dicom(
        predicted_volume=prediction,
        dicom_folder_path="/path/to/original/dicom/folder",
        volume_type='ct',
        preprocessing_params=auto_params
    )