<a href="https://colab.research.google.com/github/NarendraKumarMadireddy/Denosing-Dynamic-PET-Images-using-DAE/blob/main/projectPoison.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install nibabel pydicom matplotlib tensorflow scipy

Collecting nibabel
  Downloading nibabel-5.3.2-py3-none-any.whl.metadata (9.1 kB)
Collecting pydicom
  Downloading pydicom-3.0.1-py3-none-any.whl.metadata (9.4 kB)
Collecting tensorflow
  Downloading tensorflow-2.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting astunparse>=1.6.0 (from tensorflow)
  Downloading astunparse-1.6.3-py2.py3-none-any.whl.metadata (4.4 kB)
Collecting flatbuffers>=24.3.25 (from tensorflow)
  Downloading flatbuffers-25.2.10-py2.py3-none-any.whl.metadata (875 bytes)
Collecting google-pasta>=0.1.1 (from tensorflow)
  Downloading google_pasta-0.2.0-py3-none-any.whl.metadata (814 bytes)
Collecting libclang>=13.0.0 (from tensorflow)
  Downloading libclang-18.1.1-py2.py3-none-manylinux2010_x86_64.whl.metadata (5.2 kB)
Collecting tensorboard~=2.19.0 (from tensorflow)
  Downloading tensorboard-2.19.0-py3-none-any.whl.metadata (1.8 kB)
Collecting tensorflow-io-gcs-filesystem>=0.23.1 (from tensorflow)
  Downloading tensorflow_io

In [3]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import nibabel as nib
import os
import pydicom
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
from google.colab import drive
import zipfile
import shutil
from tqdm import tqdm

In [4]:
DICOM_FOLDER_PATH = '/content/drive/MyDrive/PT_80p 150_30 OSEM'  # CHANGE THIS PATH

# Set where you want to save results
RESULTS_FOLDER = '/content/drive/MyDrive/results3'  # CHANGE THIS PATH
IMAGE_OUTPUT_FOLDER = os.path.join(RESULTS_FOLDER, 'images')  # Folder for image outputs

# Create results folders if they don't exist
os.makedirs(RESULTS_FOLDER, exist_ok=True)
os.makedirs(IMAGE_OUTPUT_FOLDER, exist_ok=True)

In [5]:
def load_dicom_files(directory):
    """Load all DICOM files from a directory structure."""
    dicom_files = []

    # Walk through directory structure
    for root, _, files in os.walk(directory):
        for file in files:
            if file.lower().endswith('.dcm'):
                dicom_files.append(os.path.join(root, file))

    print(f"Found {len(dicom_files)} DICOM files")
    return dicom_files

In [6]:
def organize_pet_series(dicom_files):
    """Organize DICOM files into different PET series."""
    series_dict = {}

    for file_path in tqdm(dicom_files, desc="Reading DICOM metadata"):
        try:
            ds = pydicom.dcmread(file_path)

            # Skip non-PET images if needed
            if hasattr(ds, 'Modality') and ds.Modality != 'PT':
                continue

            # Use SeriesInstanceUID as the key
            series_id = ds.SeriesInstanceUID

            if series_id not in series_dict:
                series_dict[series_id] = []

            series_dict[series_id].append((file_path, ds))
        except Exception as e:
            print(f"Error reading {file_path}: {e}")

    print(f"Found {len(series_dict)} unique PET series")
    return series_dict

In [7]:
def build_pet_volume(series_files):
    """Convert a DICOM series into a 3D numpy array."""
    # Sort files by slice location or instance number
    try:
        sorted_files = sorted(series_files, key=lambda x: float(x[1].SliceLocation) if hasattr(x[1], 'SliceLocation') else x[1].InstanceNumber)
    except:
        # Fallback sorting by filename if SliceLocation not available
        sorted_files = sorted(series_files, key=lambda x: x[0])

    # Get dimensions from first image
    ds = sorted_files[0][1]
    rows = ds.Rows
    cols = ds.Columns

    # Initialize volume array
    volume = np.zeros((len(sorted_files), rows, cols))

    # Fill the volume with pixel data
    for i, (_, ds) in enumerate(tqdm(sorted_files, desc="Building volume")):
        # Rescale the pixel values if needed
        rescale_slope = 1.0
        rescale_intercept = 0.0

        if hasattr(ds, 'RescaleSlope'):
            rescale_slope = float(ds.RescaleSlope)
        if hasattr(ds, 'RescaleIntercept'):
            rescale_intercept = float(ds.RescaleIntercept)

        # Extract pixel array and apply rescaling
        pixel_array = ds.pixel_array
        volume[i, :, :] = pixel_array * rescale_slope + rescale_intercept

    return volume

def normalize_volume(volume):
    """Normalize volume to 0-1 range."""
    min_val = np.min(volume)
    max_val = np.max(volume)

    if max_val > min_val:
        return (volume - min_val) / (max_val - min_val)
    else:
        return volume

In [8]:
def add_poisson_noise(clean_data, snr=10):
    """Add Poisson noise to simulate low-dose PET acquisition."""
    # Scale the data
    scaled_data = clean_data * snr

    # Add Poisson noise
    noisy_data = np.random.poisson(scaled_data)

    # Scale back to original range
    noisy_data = noisy_data / snr

    # Clip to [0, 1] for normalized data
    noisy_data = np.clip(noisy_data, 0, 1)

    return noisy_data

In [9]:
def extract_patches(volume, patch_size=(5, 5, 5), stride=2):
    """Extract 3D patches from a volume."""
    # Volume shape: [Z, X, Y]
    Z, X, Y = volume.shape

    patches = []
    positions = []

    # Extract patches using sliding window
    for z in range(0, Z - patch_size[0] + 1, stride):
        for x in range(0, X - patch_size[1] + 1, stride):
            for y in range(0, Y - patch_size[2] + 1, stride):
                # Extract patch
                patch = volume[z:z+patch_size[0],
                               x:x+patch_size[1],
                               y:y+patch_size[2]]

                # Reshape to vector
                patch_vector = patch.flatten()

                patches.append(patch_vector)
                positions.append((z, x, y))

    return np.array(patches), positions

In [10]:
def prepare_training_data(clean_volumes, noisy_volumes, patch_size=(5, 5, 5), stride=2):
    """Prepare training data by extracting patches from clean and noisy volumes."""
    all_clean_patches = []
    all_noisy_patches = []

    for clean_vol, noisy_vol in zip(clean_volumes, noisy_volumes):
        clean_patches, _ = extract_patches(clean_vol, patch_size, stride)
        noisy_patches, _ = extract_patches(noisy_vol, patch_size, stride)

        all_clean_patches.append(clean_patches)
        all_noisy_patches.append(noisy_patches)

    # Concatenate patches from all volumes
    X_train = np.vstack(all_noisy_patches)
    y_train = np.vstack(all_clean_patches)

    return X_train, y_train

In [11]:
from tensorflow.keras.layers import Input, Dense, BatchNormalization, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import tensorflow as tf

# Custom thresholded accuracy metric
def thresholded_accuracy(y_true, y_pred):
    return tf.reduce_mean(tf.cast(tf.abs(y_true - y_pred) <= 0.1, tf.float32))

def build_dae_model(input_dim, hidden_dim=128, n_layers=5):
    """Build a deep autoencoder model for denoising."""

    # Input layer
    inputs = Input(shape=(input_dim,))

    # Encoder
    x = Dense(hidden_dim * 2, activation='elu')(inputs)
    x = BatchNormalization()(x)
    x = Dropout(0.1)(x)

    # Hidden layers
    x = Dense(hidden_dim, activation='elu')(x)
    x = BatchNormalization()(x)

    for i in range(n_layers - 3):
        x = Dense(hidden_dim // 2, activation='elu')(x)
        x = BatchNormalization()(x)

    # Bottleneck layer
    x = Dense(hidden_dim // 4, activation='relu')(x)

    # Decoder
    for i in range(n_layers - 3):
        x = Dense(hidden_dim // 2, activation='elu')(x)
        x = BatchNormalization()(x)

    x = Dense(hidden_dim, activation='elu')(x)
    x = BatchNormalization()(x)

    x = Dense(hidden_dim * 2, activation='elu')(x)
    x = BatchNormalization()(x)

    # Output layer
    outputs = Dense(input_dim, activation='sigmoid')(x)

    # Create model
    model = Model(inputs=inputs, outputs=outputs)

    # Compile model with custom accuracy
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss='mean_squared_error',
        metrics=[thresholded_accuracy]
    )

    return model


In [12]:
def train_dae(X_train, y_train, model, epochs=100, batch_size=64, validation_split=0.2):
    """Train the DAE model."""
    # Define callbacks
    early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
    checkpoint = ModelCheckpoint(os.path.join(RESULTS_FOLDER, 'dae_model.h5'),
                                monitor='val_loss', save_best_only=True)

    # Train model
    history = model.fit(
        X_train, y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_split=validation_split,
        callbacks=[early_stopping, checkpoint],
        verbose=1
    )

    return model, history


In [13]:
def denoise_volume(model, noisy_volume, patch_size=(5, 5, 5), stride=2):
    """Apply DAE to denoise a full PET volume."""
    # Get volume shape
    Z, X, Y = noisy_volume.shape

    # Initialize output volume and count matrix (for averaging overlapping patches)
    denoised_volume = np.zeros_like(noisy_volume)
    count_matrix = np.zeros_like(noisy_volume)

    # Extract patches and their positions
    patches, positions = extract_patches(noisy_volume, patch_size, stride)

    # Denoise patches in batches to avoid memory issues
    batch_size = 1000
    n_batches = (patches.shape[0] + batch_size - 1) // batch_size

    denoised_patches = []
    for i in range(n_batches):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, patches.shape[0])
        batch_denoised = model.predict(patches[start_idx:end_idx], verbose=0)
        denoised_patches.append(batch_denoised)

    denoised_patches = np.vstack(denoised_patches)

    # Put denoised patches back into the volume
    for i, (z, x, y) in enumerate(positions):
        # Reshape denoised patch back to original shape
        denoised_patch = denoised_patches[i].reshape(patch_size)

        # Add denoised patch to the output volume
        denoised_volume[z:z+patch_size[0], x:x+patch_size[1], y:y+patch_size[2]] += denoised_patch

        # Update count matrix
        count_matrix[z:z+patch_size[0], x:x+patch_size[1], y:y+patch_size[2]] += 1

    # Average overlapping patches
    denoised_volume = np.divide(denoised_volume, count_matrix, out=np.zeros_like(denoised_volume), where=count_matrix!=0)

    return denoised_volume

In [14]:
def save_slice_images(original_volume, noisy_volume, denoised_volume, series_idx, metrics=None):
    """Save individual slice images for comparison."""
    # Create folder for this series
    series_folder = os.path.join(IMAGE_OUTPUT_FOLDER, f'series_{series_idx}')
    os.makedirs(series_folder, exist_ok=True)

    # Get dimensions
    Z, X, Y = original_volume.shape

    # Calculate slice positions for visualization (start, middle, end)
    slice_positions = [
        Z // 4,           # First quarter
        Z // 2,           # Middle
        3 * Z // 4        # Third quarter
    ]

    # Save selected slices
    for pos_idx, slice_pos in enumerate(slice_positions):
        if slice_pos >= Z:
            continue

        # Save original slice
        plt.figure(figsize=(6, 6))
        plt.imshow(original_volume[slice_pos], cmap='hot')
        plt.colorbar()
        plt.title(f'Original Image - Slice {slice_pos}')
        plt.savefig(os.path.join(series_folder, f'original_slice_{pos_idx}.png'))
        plt.close()

        # Save noisy slice
        plt.figure(figsize=(6, 6))
        plt.imshow(noisy_volume[slice_pos], cmap='hot')
        plt.colorbar()
        psnr_text = f' - PSNR: {metrics["psnr_noisy"]:.2f} dB' if metrics else ''
        plt.title(f'Noisy Image - Slice {slice_pos}{psnr_text}')
        plt.savefig(os.path.join(series_folder, f'noisy_slice_{pos_idx}.png'))
        plt.close()

        # Save denoised slice
        plt.figure(figsize=(6, 6))
        plt.imshow(denoised_volume[slice_pos], cmap='hot')
        plt.colorbar()
        psnr_text = f' - PSNR: {metrics["psnr_denoised"]:.2f} dB' if metrics else ''
        plt.title(f'Denoised Image - Slice {slice_pos}{psnr_text}')
        plt.savefig(os.path.join(series_folder, f'denoised_slice_{pos_idx}.png'))
        plt.close()

        # Save side-by-side comparison
        plt.figure(figsize=(18, 6))

        plt.subplot(1, 3, 1)
        plt.imshow(original_volume[slice_pos], cmap='hot')
        plt.colorbar()
        plt.title(f'Original - Slice {slice_pos}')

        plt.subplot(1, 3, 2)
        plt.imshow(noisy_volume[slice_pos], cmap='hot')
        plt.colorbar()
        psnr_text = f' - PSNR: {metrics["psnr_noisy"]:.2f} dB' if metrics else ''
        plt.title(f'Noisy{psnr_text}')

        plt.subplot(1, 3, 3)
        plt.imshow(denoised_volume[slice_pos], cmap='hot')
        plt.colorbar()
        psnr_text = f' - PSNR: {metrics["psnr_denoised"]:.2f} dB' if metrics else ''
        plt.title(f'Denoised{psnr_text}')

        plt.tight_layout()
        plt.savefig(os.path.join(series_folder, f'comparison_slice_{pos_idx}.png'))
        plt.close()

    # Save additional views (axial, coronal, sagittal)
    save_multiview_comparison(original_volume, noisy_volume, denoised_volume, series_folder, metrics)

In [15]:
def save_multiview_comparison(original_volume, noisy_volume, denoised_volume, output_folder, metrics=None):
    """Save comparison of axial, coronal, and sagittal views."""
    Z, X, Y = original_volume.shape

    # Define positions for each view
    axial_pos = Z // 2
    coronal_pos = X // 2
    sagittal_pos = Y // 2

    # --------------- Axial View (z-plane) ---------------
    plt.figure(figsize=(18, 6))

    plt.subplot(1, 3, 1)
    plt.imshow(original_volume[axial_pos], cmap='hot')
    plt.colorbar()
    plt.title('Original - Axial View')

    plt.subplot(1, 3, 2)
    plt.imshow(noisy_volume[axial_pos], cmap='hot')
    plt.colorbar()
    psnr_text = f' - PSNR: {metrics["psnr_noisy"]:.2f} dB' if metrics else ''
    plt.title(f'Noisy{psnr_text}')

    plt.subplot(1, 3, 3)
    plt.imshow(denoised_volume[axial_pos], cmap='hot')
    plt.colorbar()
    psnr_text = f' - PSNR: {metrics["psnr_denoised"]:.2f} dB' if metrics else ''
    plt.title(f'Denoised{psnr_text}')

    plt.tight_layout()
    plt.savefig(os.path.join(output_folder, 'axial_view_comparison.png'))
    plt.close()

    # --------------- Coronal View (x-plane) ---------------
    plt.figure(figsize=(18, 6))

    plt.subplot(1, 3, 1)
    plt.imshow(original_volume[:, coronal_pos, :], cmap='hot')
    plt.colorbar()
    plt.title('Original - Coronal View')

    plt.subplot(1, 3, 2)
    plt.imshow(noisy_volume[:, coronal_pos, :], cmap='hot')
    plt.colorbar()
    plt.title(f'Noisy{psnr_text}')

    plt.subplot(1, 3, 3)
    plt.imshow(denoised_volume[:, coronal_pos, :], cmap='hot')
    plt.colorbar()
    plt.title(f'Denoised{psnr_text}')

    plt.tight_layout()
    plt.savefig(os.path.join(output_folder, 'coronal_view_comparison.png'))
    plt.close()

    plt.figure(figsize=(18, 6))

    plt.subplot(1, 3, 1)
    plt.imshow(original_volume[:, :, sagittal_pos], cmap='hot')
    plt.colorbar()
    plt.title('Original - Sagittal View')

    plt.subplot(1, 3, 2)
    plt.imshow(noisy_volume[:, :, sagittal_pos], cmap='hot')
    plt.colorbar()
    plt.title(f'Noisy{psnr_text}')

    plt.subplot(1, 3, 3)
    plt.imshow(denoised_volume[:, :, sagittal_pos], cmap='hot')
    plt.colorbar()
    plt.title(f'Denoised{psnr_text}')

    plt.tight_layout()
    plt.savefig(os.path.join(output_folder, 'sagittal_view_comparison.png'))
    plt.close()

In [16]:
def save_mip_comparison(original_volume, noisy_volume, denoised_volume, output_folder, metrics=None):
    """Save maximum intensity projection (MIP) comparisons."""
    # Create MIPs
    original_mip = np.max(original_volume, axis=0)
    noisy_mip = np.max(noisy_volume, axis=0)
    denoised_mip = np.max(denoised_volume, axis=0)

    # Save comparison
    plt.figure(figsize=(18, 6))

    plt.subplot(1, 3, 1)
    plt.imshow(original_mip, cmap='hot')
    plt.colorbar()
    plt.title('Original - MIP')

    plt.subplot(1, 3, 2)
    plt.imshow(noisy_mip, cmap='hot')
    plt.colorbar()
    psnr_text = f' - PSNR: {metrics["psnr_noisy"]:.2f} dB' if metrics else ''
    plt.title(f'Noisy{psnr_text}')

    plt.subplot(1, 3, 3)
    plt.imshow(denoised_mip, cmap='hot')
    plt.colorbar()
    psnr_text = f' - PSNR: {metrics["psnr_denoised"]:.2f} dB' if metrics else ''
    plt.title(f'Denoised{psnr_text}')

    plt.tight_layout()
    plt.savefig(os.path.join(output_folder, 'mip_comparison.png'))
    plt.close()

    # Save individual MIPs at higher resolution
    plt.figure(figsize=(10, 10))
    plt.imshow(original_mip, cmap='hot')
    plt.colorbar()
    plt.title('Original - Maximum Intensity Projection')
    plt.savefig(os.path.join(output_folder, 'original_mip.png'))
    plt.close()

    plt.figure(figsize=(10, 10))
    plt.imshow(noisy_mip, cmap='hot')
    plt.colorbar()
    plt.title(f'Noisy - Maximum Intensity Projection{psnr_text}')
    plt.savefig(os.path.join(output_folder, 'noisy_mip.png'))
    plt.close()

    plt.figure(figsize=(10, 10))
    plt.imshow(denoised_mip, cmap='hot')
    plt.colorbar()
    plt.title(f'Denoised - Maximum Intensity Projection{psnr_text}')
    plt.savefig(os.path.join(output_folder, 'denoised_mip.png'))
    plt.close()

In [17]:
def save_difference_images(original_volume, noisy_volume, denoised_volume, output_folder):
    """Save difference images to show noise patterns removed."""
    # Calculate difference volumes
    noisy_diff = np.abs(noisy_volume - original_volume)
    denoised_diff = np.abs(denoised_volume - original_volume)
    improvement = noisy_diff - denoised_diff  # Positive values show improvement

    # Middle slice
    middle_slice = original_volume.shape[0] // 2

    # Save difference images
    plt.figure(figsize=(18, 6))

    plt.subplot(1, 3, 1)
    plt.imshow(noisy_diff[middle_slice], cmap='viridis')
    plt.colorbar()
    plt.title('Original vs Noisy Difference')

    plt.subplot(1, 3, 2)
    plt.imshow(denoised_diff[middle_slice], cmap='viridis')
    plt.colorbar()
    plt.title('Original vs Denoised Difference')

    plt.subplot(1, 3, 3)
    plt.imshow(improvement[middle_slice], cmap='RdBu_r', vmin=-0.2, vmax=0.2)
    plt.colorbar()
    plt.title('Improvement (Red = Better)')

    plt.tight_layout()
    plt.savefig(os.path.join(output_folder, 'difference_analysis.png'))
    plt.close()

    # Save MIP of difference volumes
    noisy_diff_mip = np.max(noisy_diff, axis=0)
    denoised_diff_mip = np.max(denoised_diff, axis=0)
    improvement_mip = np.max(improvement, axis=0)

    plt.figure(figsize=(18, 6))

    plt.subplot(1, 3, 1)
    plt.imshow(noisy_diff_mip, cmap='viridis')
    plt.colorbar()
    plt.title('Original vs Noisy Difference - MIP')

    plt.subplot(1, 3, 2)
    plt.imshow(denoised_diff_mip, cmap='viridis')
    plt.colorbar()
    plt.title('Original vs Denoised Difference - MIP')

    plt.subplot(1, 3, 3)
    plt.imshow(improvement_mip, cmap='RdBu_r', vmin=-0.2, vmax=0.2)
    plt.colorbar()
    plt.title('Improvement - MIP (Red = Better)')

    plt.tight_layout()
    plt.savefig(os.path.join(output_folder, 'difference_analysis_mip.png'))
    plt.close()

In [18]:
print("Loading DICOM files...")
dicom_files = load_dicom_files(DICOM_FOLDER_PATH)

if len(dicom_files) == 0:
    print(f"No DICOM files found in {DICOM_FOLDER_PATH}")
else:
    # Organize into series
    series_dict = organize_pet_series(dicom_files)

    if len(series_dict) == 0:
        print("No valid PET series found")
    else:
        # Build volumes from the first few series for training
        volumes = []
        series_ids = list(series_dict.keys())

        # Limit to maximum 5 series for training to avoid memory issues
        max_training_series = min(5, len(series_ids))

        print(f"Building volumes from {max_training_series} series for training...")
        for i in range(max_training_series):
            series_id = series_ids[i]
            print(f"Processing series {i+1}/{max_training_series}: {series_id[:8]}...")
            volume = build_pet_volume(series_dict[series_id])
            volumes.append(normalize_volume(volume))

        # Create noisy versions for training
        noisy_volumes = []
        clean_volumes = []

        print("Generating noisy training data...")
        for vol in volumes:
            # Use the original as clean
            clean_volumes.append(vol)

            # Add noise for training
            noisy_vol = add_poisson_noise(vol, snr=5)
            noisy_volumes.append(noisy_vol)

        # Prepare training data
        print("Preparing training patches...")
        patch_size = (5, 5, 5)  # 3D patch size
        stride = 2              # Stride for patch extraction

        X_train, y_train = prepare_training_data(clean_volumes, noisy_volumes, patch_size, stride)
        print(f"Training data shape: {X_train.shape}, {y_train.shape}")

        # Build and train model
        print("Building DAE model...")
        input_dim = X_train.shape[1]
        model = build_dae_model(input_dim, hidden_dim=128, n_layers=5)
        model.summary()

        print("Training DAE model...")
        model, history = train_dae(X_train, y_train, model, epochs=3, batch_size=128)

        plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        plt.plot(history.history['loss'], label='Training Loss')
        plt.plot(history.history['val_loss'], label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training vs Validation Loss')
        plt.legend()

# Plot training and validation accuracy
        plt.subplot(1, 2, 2)
        plt.plot(history.history['thresholded_accuracy'], label='Training Accuracy')
        plt.plot(history.history['val_thresholded_accuracy'], label='Validation Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title('Training vs Validation Accuracy')
        plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(RESULTS_FOLDER, 'training_history_with_accuracy.png'))
        plt.close()


        # Process all series for denoising
        print("\nDenoising all PET series...")
        for i, series_id in enumerate(series_ids):
            print(f"Denoising series {i+1}/{len(series_ids)}: {series_id[:8]}...")

            # Build original volume
            original_volume = build_pet_volume(series_dict[series_id])
            original_volume_norm = normalize_volume(original_volume)

            # Add noise for testing (simulate low-dose)
            noisy_volume = add_poisson_noise(original_volume_norm, snr=5)

            # Denoise volume
            denoised_volume = denoise_volume(model, noisy_volume, patch_size, stride)

            # Calculate metrics
            mse_noisy = np.mean((noisy_volume - original_volume_norm)**2)
            mse_denoised = np.mean((denoised_volume - original_volume_norm)**2)

            psnr_noisy = 10 * np.log10(1.0 / mse_noisy)  # Normalized to 0-1
            psnr_denoised = 10 * np.log10(1.0 / mse_denoised)

            metrics = {
                'mse_noisy': mse_noisy,
                'mse_denoised': mse_denoised,
                'psnr_noisy': psnr_noisy,
                'psnr_denoised': psnr_denoised
            }

            print(f"MSE - Noisy: {mse_noisy:.4f}, Denoised: {mse_denoised:.4f}")
            print(f"PSNR - Noisy: {psnr_noisy:.4f} dB, Denoised: {psnr_denoised:.4f} dB")

            # Save detailed visualizations
            save_slice_images(original_volume_norm, noisy_volume, denoised_volume, i, metrics)

            # Save MIP comparisons
            series_folder = os.path.join(IMAGE_OUTPUT_FOLDER, f'series_{i}')
            save_mip_comparison(original_volume_norm, noisy_volume, denoised_volume, series_folder, metrics)

            # Save difference images
            save_difference_images(original_volume_norm, noisy_volume, denoised_volume, series_folder)

            np.save(os.path.join(RESULTS_FOLDER, f'original_volume_series_{i}.npy'), original_volume_norm)
            np.save(os.path.join(RESULTS_FOLDER, f'noisy_volume_series_{i}.npy'), noisy_volume)
            np.save(os.path.join(RESULTS_FOLDER, f'denoised_volume_series_{i}.npy'), denoised_volume)

Loading DICOM files...
Found 1706 DICOM files


Reading DICOM metadata: 100%|██████████| 1706/1706 [00:37<00:00, 45.99it/s] 


Found 1 unique PET series
Building volumes from 1 series for training...
Processing series 1/1: 1.2.840....


Building volume: 100%|██████████| 1706/1706 [00:01<00:00, 1509.52it/s]


Generating noisy training data...
Preparing training patches...
Training data shape: (7519436, 125), (7519436, 125)
Building DAE model...


Training DAE model...
Epoch 1/3
[1m46993/46997[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 8ms/step - loss: 0.0026 - thresholded_accuracy: 0.9821



[1m46997/46997[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m408s[0m 9ms/step - loss: 0.0026 - thresholded_accuracy: 0.9821 - val_loss: 1.3934e-05 - val_thresholded_accuracy: 0.9999
Epoch 2/3
[1m46997/46997[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m408s[0m 9ms/step - loss: 3.0865e-05 - thresholded_accuracy: 0.9995 - val_loss: 1.9187e-05 - val_thresholded_accuracy: 0.9999
Epoch 3/3
[1m46991/46997[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 8ms/step - loss: 2.5633e-05 - thresholded_accuracy: 0.9996



[1m46997/46997[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m414s[0m 9ms/step - loss: 2.5633e-05 - thresholded_accuracy: 0.9996 - val_loss: 7.5792e-06 - val_thresholded_accuracy: 0.9999

Denoising all PET series...
Denoising series 1/1: 1.2.840....


Building volume: 100%|██████████| 1706/1706 [00:00<00:00, 5150.39it/s]


MSE - Noisy: 0.0007, Denoised: 0.0000
PSNR - Noisy: 31.7896 dB, Denoised: 48.4536 dB
