In [None]:
!pip install SimpleITK pydicom TotalSegmentator scipy

In [1]:
# Cell 1: Setup
import os
import glob
import numpy as np
import pydicom
import datetime
from pydicom.uid import generate_uid, ExplicitVRLittleEndian
from pydicom.dataset import FileDataset, FileMetaDataset

#from google.colab import drive

from totalsegmentator.python_api import totalsegmentator
from scipy.ndimage import label as ndimage_label
import SimpleITK as sitk
import shutil

#drive.mount('/content/drive')

In [24]:
# Cell 1.1

root_dir = "/content/" # update as needed

shutil.unpack_archive("c0138-main.zip", root_dir)

root_dir = "/content/c0138-main/"

base_dicom_path = os.path.join(root_dir, "dataset/3D")
# User input:
scan_number = "44" # < Enter scan number (Provided: 44, 108, 171)
#

In [36]:
# Cell 1.2: path configuration
ct_dir = os.path.join(base_dicom_path, scan_number, "CT/")
pet_dir = os.path.join(base_dicom_path, scan_number, "PET/")

output_dir = os.path.join(base_dicom_path, scan_number, "PET_scrubbed/")

segm_dir = os.path.join(base_dicom_path, scan_number, "segm_dir/")
multilabel_liver_nii_filename = "liver_segments.nii"
multilabel_liver_nii_path = os.path.join(segm_dir, multilabel_liver_nii_filename)


os.makedirs(output_dir, exist_ok=True)
os.makedirs(segm_dir, exist_ok=True)

In [37]:
# Cell 2
def load_series(directory):
    """Return a list of (dataset, pixel_array) sorted by instance number."""
    paths = sorted(
        glob.glob(os.path.join(directory, "**", "*"), recursive=True),
        key=lambda p: int(pydicom.dcmread(p, stop_before_pixels=True).InstanceNumber)
    )
    series = []
    for p in paths:
        ds = pydicom.dcmread(p)
        arr = ds.pixel_array        # keep native dtype
        series.append((ds, arr))
    return series


In [38]:
# Cell 3 ── liver segmentation on CT series ─────────
def segment_liver_ct(ct_dir_local, pet_dir_local, multilabel_liver_nii_path_local):
    """
    Generates liver and fat masks in PET coordinate space.
    Handles multi-label liver segmentation output from TotalSegmentator.

    Args:
        ct_dir_local (str): Path to the directory containing CT DICOM series.
        pet_dir_local (str): Path to the directory containing PET DICOM series.
        multilabel_liver_nii_path_local (str): Full path for TotalSegmentator multi-label output NIfTI file.

    Returns:
        tuple[np.ndarray, np.ndarray]: liver_mask_pet, fat_mask_pet (boolean arrays)
    """

    # 1) run TotalSegmentator (if needed)

    if not os.path.exists(multilabel_liver_nii_path_local):
        print(f"Running TotalSegmentator for liver segments...")

        os.makedirs(os.path.dirname(multilabel_liver_nii_path_local), exist_ok=True)
        try:
            totalsegmentator(
                input=ct_dir_local,
                output=multilabel_liver_nii_path_local,
                task="liver_segments",
                ml=True,
                device="gpu"
            )
            print(f"TotalSegmentator finished. Expecting multi-label file: {multilabel_liver_nii_path_local}")
            if not os.path.exists(multilabel_liver_nii_path_local):
                 raise FileNotFoundError(f"TotalSegmentator did not produce expected file: {multilabel_liver_nii_path_local}")
        except Exception as e:
            print(f"Error running TotalSegmentator: {e}")
            try:
                reader_pet_check = sitk.ImageSeriesReader()
                uid_pet_check = reader_pet_check.GetGDCMSeriesIDs(pet_dir_local)[0]
                reader_pet_check.SetFileNames(reader_pet_check.GetGDCMSeriesFileNames(pet_dir_local, uid_pet_check))
                pet_ref_check = reader_pet_check.Execute()
                pet_shape = sitk.GetArrayFromImage(pet_ref_check).shape
                return np.zeros(pet_shape, dtype=bool), np.zeros(pet_shape, dtype=bool)
            except Exception as pet_e:
                 print(f"Could not read PET dimensions to create empty masks: {pet_e}")
                 return np.zeros((1,1,1), dtype=bool), np.zeros((1,1,1), dtype=bool)

    print(f"Loading multi-label liver segmentation from: {multilabel_liver_nii_path_local}")
    liver_multilabel_nii = sitk.ReadImage(multilabel_liver_nii_path_local)

    liver_binary_mask_nii = (liver_multilabel_nii >= 1) & (liver_multilabel_nii <= 8)
    liver_binary_mask_nii = sitk.Cast(liver_binary_mask_nii, sitk.sitkUInt8)
    print("Combined liver segments 1-8 into a single binary mask.")

    print(f"Loading CT series from: {ct_dir_local}")
    reader_ct = sitk.ImageSeriesReader()
    uid_ct = reader_ct.GetGDCMSeriesIDs(ct_dir_local)[0]
    reader_ct.SetFileNames(reader_ct.GetGDCMSeriesFileNames(ct_dir_local, uid_ct))
    ct_img = reader_ct.Execute()
    slope = float(ct_img.GetMetaData('0028|1053') if ct_img.HasMetaDataKey('0028|1053') else 1.0)
    intercept = float(ct_img.GetMetaData('0028|1052') if ct_img.HasMetaDataKey('0028|1052') else 0.0)
    ct_arr = sitk.GetArrayFromImage(ct_img)
    hu_arr = ct_arr * slope + intercept
    fat_ct_arr = ((hu_arr >= -190) & (hu_arr <= -20)).astype(np.uint8)
    fat_img = sitk.GetImageFromArray(fat_ct_arr)
    fat_img.CopyInformation(ct_img)
    print(f"Generated CT fat mask (shape: {fat_ct_arr.shape})")

    print(f"Loading PET series reference geometry from: {pet_dir_local}")
    reader_pet = sitk.ImageSeriesReader()
    uid_pet = reader_pet.GetGDCMSeriesIDs(pet_dir_local)[0]
    reader_pet.SetFileNames(reader_pet.GetGDCMSeriesFileNames(pet_dir_local, uid_pet))
    pet_ref = reader_pet.Execute()
    print(f"PET reference image loaded (shape: {sitk.GetArrayFromImage(pet_ref).shape})")

    print("Resampling masks to PET grid...")
    def resample_to_pet(img_to_resample, interpolation=sitk.sitkNearestNeighbor):
        ref_size = pet_ref.GetSize()
        ref_spacing = pet_ref.GetSpacing()
        ref_origin = pet_ref.GetOrigin()
        ref_direction = pet_ref.GetDirection()
        transform = sitk.Transform()
        resampled_img = sitk.Resample(
            img_to_resample, ref_size, transform, interpolation, ref_origin,
            ref_spacing, ref_direction, 0, img_to_resample.GetPixelID()
        )
        return resampled_img

    liver_rs = resample_to_pet(liver_binary_mask_nii, sitk.sitkNearestNeighbor)
    fat_rs = resample_to_pet(fat_img, sitk.sitkNearestNeighbor)
    print("Resampling complete.")

    liver_mask_pet = sitk.GetArrayFromImage(liver_rs).astype(bool)
    fat_mask_pet   = sitk.GetArrayFromImage(fat_rs).astype(bool)
    print(f"Final PET-space liver mask shape: {liver_mask_pet.shape}, "
          f"Fat mask shape: {fat_mask_pet.shape}")

    return liver_mask_pet, fat_mask_pet

In [39]:
# Cell 3.1 - Helper functions

def remove_small_3d_clusters(mask_3d, min_voxels_3d):
    if not mask_3d.any() or min_voxels_3d <= 0:
        return mask_3d

    labeled_mask, num_labels = ndimage_label(mask_3d)

    if num_labels == 0:
         return mask_3d

    label_sizes = np.bincount(labeled_mask.ravel())


    too_small_labels = np.where((label_sizes[1:] < min_voxels_3d))[0] + 1

    small_cluster_mask = np.isin(labeled_mask, too_small_labels)

    output_mask = mask_3d.copy()
    output_mask[small_cluster_mask] = False

    print(f"  Removed {len(too_small_labels)} 3D clusters smaller than {min_voxels_3d} voxels.")
    return output_mask


In [60]:
# Cell 4  ── BAT detection & suppression  ────────────────────────────

def suppress_bat(pet_series,
                                          liver_mask_pet, fat_mask_pet,
                                          liver_mult=1.2,
                                          min_3d_cluster_voxels=50):

    print(f"\n--- Starting suppress_bat ---")
    print(f"Total slices: {len(pet_series)}")
    print(f"Input liver mask shape: {liver_mask_pet.shape}, Any True: {liver_mask_pet.any()}")
    print(f"Input fat mask shape: {fat_mask_pet.shape}, Any True: {fat_mask_pet.any()}")
    print(f"Liver multiplier: {liver_mult}, Min 3D Cluster Voxels: {min_3d_cluster_voxels}")

    if not liver_mask_pet.any() or not fat_mask_pet.any():
        print("ERROR: Input liver or fat mask is empty!")
        return [arr for ds, arr in pet_series]

    # --- Step 1: Stack PET volume ---
    try:
        pet_volume_f32 = np.stack([arr.astype(np.float32, copy=False) for ds, arr in pet_series], axis=0)
        if pet_volume_f32.shape != liver_mask_pet.shape:
             raise ValueError(f"Shape mismatch: PET Volume {pet_volume_f32.shape} vs Masks {liver_mask_pet.shape}")
    except Exception as e:
        print(f"ERROR: Could not stack PET slices or shape mismatch: {e}")
        return [arr for ds, arr in pet_series]

    # --- Step 2: Calculate global medians ---
    pet_liver_values_3d = pet_volume_f32[liver_mask_pet]
    pet_fat_values_3d = pet_volume_f32[fat_mask_pet]
    if pet_liver_values_3d.size == 0 or pet_fat_values_3d.size == 0:
        print("ERROR: Masks cover zero voxels in PET volume!")
        return [arr for ds, arr in pet_series]

    global_liver_median = np.median(pet_liver_values_3d)
    global_adipose_median = np.median(pet_fat_values_3d)
    global_bat_threshold = liver_mult * global_liver_median

    print(f"Global Liver Median: {global_liver_median:.2f}")
    print(f"Global Adipose Median: {global_adipose_median:.2f}")
    print(f"Global BAT Threshold (>= {liver_mult} * liver_med): {global_bat_threshold:.2f}")
    target_patch_value = np.round(global_adipose_median).astype(pet_series[0][1].dtype)
    print(f"Target Patch Value: {target_patch_value}") # Verify patch value

    # --- Step 3: Determine Z-range based on liver mask ---
    liver_z_indices = np.where(np.any(liver_mask_pet, axis=(1, 2)))[0]
    if liver_z_indices.size == 0:
         print("WARNING: No liver mask found. Processing all slices.")
         min_process_z = 0
         max_process_z_ignored = pet_volume_f32.shape[0] -1
    else:
         # process from top slice down to the minimum index containing liver
         min_process_z = np.min(liver_z_indices)
         max_process_z_ignored = pet_volume_f32.shape[0] - 1
         print(f"Processing slices from index {max_process_z_ignored} down to {min_process_z} (inclusive, based on liver mask extent).")

    # --- Step 4: Build initial mask ---
    initial_bat_mask_3d = np.zeros_like(liver_mask_pet, dtype=bool)
    slices_considered_for_mask = 0
    # iterate through all slices, but only apply threshold if within range
    for idx in range(pet_volume_f32.shape[0]):
        if idx < min_process_z: # skip slices below the liver's minimum extent
             continue

        pet_f32_slice = pet_volume_f32[idx]
        fat_mask_slice = fat_mask_pet[idx]

        if not fat_mask_slice.any(): continue

        bat_slice = fat_mask_slice & (pet_f32_slice >= global_bat_threshold)
        initial_bat_mask_3d[idx] = bat_slice
        if bat_slice.any():
            slices_considered_for_mask += 1

    total_initial_candidates = np.sum(initial_bat_mask_3d)
    print(f"Found {total_initial_candidates} initial BAT candidate voxels across relevant slices (index >= {min_process_z}).")

    # --- Step 5: 3D cluster filtering ---
    final_bat_mask_3d = remove_small_3d_clusters(initial_bat_mask_3d, min_3d_cluster_voxels)
    total_final_bat_voxels = np.sum(final_bat_mask_3d)
    print(f"Total BAT voxels remaining after 3D cluster filter: {total_final_bat_voxels}")

    # --- Step 6: Patch slices ---
    new_pet_arrays = []
    slices_modified = 0
    for idx, (ds, pet_arr_orig) in enumerate(pet_series):

         bat_mask_slice = final_bat_mask_3d[idx]

         if bat_mask_slice.any():
             patched = pet_arr_orig.copy()
             patched[bat_mask_slice] = target_patch_value
             new_pet_arrays.append(patched)
             slices_modified += 1
         else:
              # No BAT on this slice after filtering, or slice outside Z range
              new_pet_arrays.append(pet_arr_orig)

    print(f"\n--- suppress_bat Finished ---")
    print(f"Slices actually modified (within index range >= {min_process_z}): {slices_modified}")
    print(f"Total BAT voxels patched (3D count): {total_final_bat_voxels}")
    print(f"-----------------------------\n")
    if slices_modified == 0:
        print("WARNING: No slices were modified.")
    if total_final_bat_voxels == 0 and total_initial_candidates > 0:
         print("WARNING: Initial BAT candidates found, but zero voxels remained after 3D cluster filtering.")

    return new_pet_arrays, final_bat_mask


In [61]:
# Cell 5

def save_modified_pet(pet_series, new_arrays, out_dir):
    """
    pet_series: list of (ds, _) tuples from load_series(pet_dir)
    new_arrays:  list of numpy arrays, same length and shape as pet_series
    """
    os.makedirs(out_dir, exist_ok=True)

    new_series_uid = generate_uid()

    for idx, ((ds, _), arr) in enumerate(zip(pet_series, new_arrays), start=1):

        rows, cols = int(ds.Rows), int(ds.Columns)
        if arr.shape != (rows, cols):
            raise ValueError(
                f"Slice {idx:03d} shape mismatch: arr is {arr.shape}, "
                f"but DICOM says {(rows, cols)}"
            )

        ds_mod = ds.copy()
        ds_mod.SeriesInstanceUID = new_series_uid
        ds_mod.SOPInstanceUID    = generate_uid()

        # enforce correct dimensions
        ds_mod.Rows    = rows
        ds_mod.Columns = cols


        ds_mod.PixelData = arr.tobytes()

        # 0001.dcm, 0002.dcm, …
        out_path = os.path.join(out_dir, f"{idx:04d}.dcm")
        ds_mod.save_as(out_path)

    print(f"Wrote {len(new_arrays)} full-resolution PET slices to {out_dir}")


In [62]:
# Utility: sort a (dataset, array) series by absolute z-position
def sort_series_by_z(series):
    return sorted(series,
                  key=lambda pair: float(pair[0].ImagePositionPatient[2]))

In [63]:
# Cell 6: New MIP

def generate_frontal_mip(dicom_directory, output_dicom_path):
    """
    Generates a frontal (coronal) MIP from correctly scaled data, displays it,
    and saves it as a standard-compliant DICOM file.
    """
    print(f"Generating MIP for '{dicom_directory}'...")
    series = sort_series_by_z(load_series(dicom_directory))

    if not series:
        print(f"No DICOM files found in '{dicom_directory}'.")
        return

    # create a scaled volume by applying each slice's specific slope/intercept
    scaled_slices = []
    for ds, arr in series:
        slope = float(ds.get('RescaleSlope', 1.0))
        intercept = float(ds.get('RescaleIntercept', 0.0))
        scaled_slices.append(arr.astype(np.float32) * slope + intercept)

    volume_scaled = np.stack(scaled_slices, axis=0)
    mip = np.max(volume_scaled, axis=1)

    template_ds = series[0][0]
    position_template_ds = series[-1][0]
    mip_rows, mip_cols = mip.shape

    mip_for_dicom = np.flipud(mip)


    mip_for_dicom[mip_for_dicom < 0] = 0

    vmax = np.max(mip_for_dicom)

    if vmax > 0:
        # Scale to fit in uint16 if values are large, otherwise cast directly
        if vmax > 65535:
             mip_for_dicom = (mip_for_dicom / vmax) * 65535

    mip_for_storage = mip_for_dicom.astype(np.uint16)

    wc = None
    ww = None
    if mip.max() > 0:

        lower_bound = np.percentile(mip[mip > 0], 33.8)
        upper_bound = np.percentile(mip[mip > 0], 97.4)

        ww = upper_bound - lower_bound
        wc = lower_bound + (ww / 2)

    file_meta = FileMetaDataset()
    file_meta.MediaStorageSOPClassUID = pydicom.uid.SecondaryCaptureImageStorage
    file_meta.MediaStorageSOPInstanceUID = generate_uid()
    file_meta.TransferSyntaxUID = ExplicitVRLittleEndian
    file_meta.ImplementationClassUID = pydicom.uid.PYDICOM_IMPLEMENTATION_UID

    ds = FileDataset(output_dicom_path, {}, file_meta=file_meta)

    ds.PatientName = template_ds.get('PatientName', 'UNKNOWN')
    ds.PatientID = template_ds.get('PatientID', 'UNKNOWN')
    ds.StudyInstanceUID = template_ds.StudyInstanceUID
    ds.SeriesInstanceUID = generate_uid()
    ds.SOPClassUID = file_meta.MediaStorageSOPClassUID
    ds.SOPInstanceUID = file_meta.MediaStorageSOPInstanceUID
    ds.StudyID = template_ds.get('StudyID', '1')
    ds.SeriesNumber = 999
    ds.Modality = "PT"
    dt = datetime.datetime.now()
    ds.ContentDate = dt.strftime('%Y%m%d')
    ds.ContentTime = dt.strftime('%H%M%S.%f')

    if wc is not None and ww is not None:
        ds.WindowCenter = str(round(wc, 4))
        ds.WindowWidth = str(round(ww, 4))

    ds.RescaleSlope = "1"
    ds.RescaleIntercept = "0"

    ds.Rows = mip_rows
    ds.Columns = mip_cols
    z_spacing = abs(series[1][0].ImagePositionPatient[2] - series[0][0].ImagePositionPatient[2])
    ds.PixelSpacing = [str(z_spacing), str(template_ds.PixelSpacing[1])]
    ds.ImageOrientationPatient = [1, 0, 0, 0, 0, -1]
    ds.ImagePositionPatient = position_template_ds.ImagePositionPatient
    ds.PhotometricInterpretation = "MONOCHROME2"
    ds.SamplesPerPixel = 1
    ds.PixelRepresentation = 0 # Unsigned integer
    ds.BitsAllocated = 16
    ds.BitsStored = 16
    ds.HighBit = 15

    ds.PixelData = mip_for_storage.tobytes()

    ds.save_as(output_dicom_path, enforce_file_format=True)
    print(f"MIP saved as a DICOM file to: {output_dicom_path}")

In [73]:
# Cell 7: Main execution

if not os.path.isdir(ct_dir):
    print(f"ERROR: CT directory not found: {ct_dir}")
elif not os.path.isdir(pet_dir):
    print(f"ERROR: PET directory not found: {pet_dir}")
else:
    print("Input directories confirmed.")

    # load DICOM PET series (sorted)
    pet_series = sort_series_by_z(load_series(pet_dir))

    # get liver & fat masks
    liver_mask_pet, fat_mask_pet = segment_liver_ct(ct_dir, pet_dir, multilabel_liver_nii_path)

    # scrub
    scrubbed, _ = suppress_bat(
        pet_series,
        liver_mask_pet,
        fat_mask_pet,
        liver_mult=1.2,
        min_3d_cluster_voxels=10 # hyper parameters, found these a good compromise
    )

    save_modified_pet(pet_series, scrubbed, output_dir)


Input directories confirmed.
Loading multi-label liver segmentation from: /content/c0138-main/dataset/3D/44/segm_dir/liver_segments.nii
Combined liver segments 1-8 into a single binary mask.
Loading CT series from: /content/c0138-main/dataset/3D/44/CT/
Generated CT fat mask (shape: (301, 512, 512))
Loading PET series reference geometry from: /content/c0138-main/dataset/3D/44/PET/
PET reference image loaded (shape: (301, 200, 200))
Resampling masks to PET grid...
Resampling complete.
Final PET-space liver mask shape: (301, 200, 200), Fat mask shape: (301, 200, 200)

--- Starting suppress_bat ---
Total slices: 301
Input liver mask shape: (301, 200, 200), Any True: True
Input fat mask shape: (301, 200, 200), Any True: True
Liver multiplier: 1.2, Min 3D Cluster Voxels: 10
Global Liver Median: 28290.00
Global Adipose Median: 4588.00
Global BAT Threshold (>= 1.2 * liver_med): 33948.00
Target Patch Value: 4588
Processing slices from index 300 down to 158 (inclusive, based on liver mask extent

In [65]:
# Cell 8: Generate, display, and save the MIP from the scrubbed PET series
mip_dcm_path = os.path.join(base_dicom_path, scan_number, f"MIP_new.dcm")
generate_frontal_mip(output_dir, mip_dcm_path)

Generating MIP for '/content/c0138-main/dataset/3D/44/PET_scrubbed/'...
MIP saved as a DICOM file to: /content/c0138-main/dataset/3D/44/MIP_new.dcm


In [71]:
# Cell 9: Unit and Component Tests

print("\n--- Running Unit & Component Tests for Phase 2 ---")

# Test 1: remove_small_3d_clusters function
try:
    # Create a 3D mask with 3 clusters of various size
    test_mask = np.zeros((20, 20, 20), dtype=bool)
    test_mask[1:2, 1:2, 1:6] = True   # cluster 1: 5 voxels
    test_mask[5:6, 5:8, 5:10] = True  # Cluster 2: 15 voxels
    test_mask[10:15, 10:15, 10:11] = True # cluster 3: 25 voxels

    assert np.sum(test_mask) == 45, "Test 1 PRE-CHECK FAILED: Initial mask voxel count is incorrect."

    min_voxels = 10
    filtered_mask = remove_small_3d_clusters(test_mask, min_voxels)

    # Expected result: Clusters 2 and 3 remain (15 + 25 = 40 voxels)
    expected_voxels = 40
    actual_voxels = np.sum(filtered_mask)

    assert actual_voxels == expected_voxels, f"Test 1 FAILED: Expected {expected_voxels} voxels, but {actual_voxels} remained."
    print("Test 1 PASSED: remove_small_3d_clusters correctly filters small clusters.")

except Exception as e:
    print(f"Test 1 FAILED: {e}")


# Test 2: Heuristic Test for suppress_bat function
try:
    # create a synthetic PET/CT volume
    shape = (50, 50, 50) # Z-index 0 = feet, 49 = head
    pet_vol = np.zeros(shape, dtype=np.uint16)
    fat_mask = np.zeros(shape, dtype=bool)
    liver_mask = np.zeros(shape, dtype=bool)

    # define regions
    liver_region = (slice(10, 15), slice(10, 15), slice(10, 15))
    lesion_region = (slice(30, 35), slice(30, 35), slice(30, 35))
    bat_region = (slice(40, 45), slice(40, 45), slice(40, 45))

    pet_vol[:] = 100
    liver_mask[liver_region] = True
    pet_vol[liver_region] = 1000

    # define a general fat area
    general_fat_region = (slice(5, 30), slice(5, 30), slice(5, 30))
    fat_mask[general_fat_region] = True
    pet_vol[general_fat_region] = 500

    # add BAT
    fat_mask[bat_region] = True
    pet_vol[bat_region] = 2000

    fat_mask[lesion_region] = False
    pet_vol[lesion_region] = 2000

    dummy_ds = pydicom.dataset.Dataset()
    dummy_ds.Rows, dummy_ds.Columns = shape[1], shape[2]
    original_series_heuristic = [(dummy_ds, pet_vol[i]) for i in range(shape[0])]

    suppressed_arrays, final_bat_mask = suppress_bat(
        original_series_heuristic,
        liver_mask_pet=liver_mask,
        fat_mask_pet=fat_mask,
        liver_mult=1.5,
        min_3d_cluster_voxels=5
    )
    suppressed_vol = np.stack(suppressed_arrays)

    # --- Assertions ---
    suppressed_bat_val = np.mean(suppressed_vol[bat_region])
    assert 499 < suppressed_bat_val < 501, f"Test 2a FAILED: BAT not suppressed correctly. Mean value is {suppressed_bat_val}, expected ~500."

    original_lesion_val = np.mean(pet_vol[lesion_region])
    suppressed_lesion_val = np.mean(suppressed_vol[lesion_region])
    assert original_lesion_val == suppressed_lesion_val, "Test 2b FAILED: Lesion uptake was altered."

    background_region = (slice(0, 5), slice(0, 5), slice(0, 5))
    original_bg_val = np.mean(pet_vol[background_region])
    suppressed_bg_val = np.mean(suppressed_vol[background_region])
    assert original_bg_val == suppressed_bg_val, "Test 2c FAILED: Background tissue uptake was altered."

    print("Test 2 PASSED: suppress_bat function correctly and selectively suppresses synthetic BAT.")

except Exception as e:
    print(f"Test 2 FAILED: {e}")

print("-------------------------------------------------")


--- Running Unit & Component Tests for Phase 2 ---
  Removed 1 3D clusters smaller than 10 voxels.
Test 1 PASSED: remove_small_3d_clusters correctly filters small clusters.

--- Starting suppress_bat ---
Total slices: 50
Input liver mask shape: (50, 50, 50), Any True: True
Input fat mask shape: (50, 50, 50), Any True: True
Liver multiplier: 1.5, Min 3D Cluster Voxels: 5
Global Liver Median: 500.00
Global Adipose Median: 500.00
Global BAT Threshold (>= 1.5 * liver_med): 750.00
Target Patch Value: 500
Processing slices from index 49 down to 10 (inclusive, based on liver mask extent).
Found 125 initial BAT candidate voxels across relevant slices (index >= 10).
  Removed 0 3D clusters smaller than 5 voxels.
Total BAT voxels remaining after 3D cluster filter: 125

--- suppress_bat Finished ---
Slices actually modified (within index range >= 10): 5
Total BAT voxels patched (3D count): 125
-----------------------------

Test 2 PASSED: suppress_bat function correctly and selectively suppresse