In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch

from monai.transforms import (
    Compose, LoadImage, EnsureChannelFirst, Orientation, CropForeground,
    Resize, ScaleIntensityRangePercentiles, KeepLargestConnectedComponent, Lambda
)

In [None]:
import torch
import scipy.ndimage as ndi

def brain_outer_mask(x, thr_frac=0.2, erode_iters=0):
    """
    Build a binary outer-boundary mask from a single-channel volume.

    Pipeline: global threshold at ``thr_frac * mean(volume)`` -> keep the
    largest connected component -> fill its enclosed holes -> optionally erode.

    Args:
        x: torch tensor [1, D, H, W] (a MONAI MetaTensor also works). All
            singleton dimensions are squeezed before processing.
        thr_frac: fraction of the volume mean used as the foreground
            threshold (default 0.2, the previously hard-coded value).
        erode_iters: number of binary-erosion iterations applied after hole
            filling (default 0 = no erosion). Each iteration peels one voxel
            layer off the mask boundary.

    Returns:
        float32 tensor on ``x.device`` with 1 inside the mask and 0 outside,
        shaped as the squeezed volume with a leading axis added (i.e.
        [1, D, H, W] for the documented input). If no voxel exceeds the
        threshold, returns ``torch.ones_like(x)`` so multiplying by the mask
        is a no-op.
    """
    vol = x.detach().squeeze().cpu().numpy()

    # Simple global foreground threshold relative to the mean intensity.
    init_mask = vol > vol.mean() * thr_frac

    # Label connected components with scipy's default structuring element
    # (face connectivity: 6-connected in 3D); label 0 is background.
    labels, n_components = ndi.label(init_mask)
    if n_components < 1:
        return torch.ones_like(x)  # fallback: nothing above threshold -> no masking

    # Component sizes without the background bin; +1 maps the index back to a label id.
    component_sizes = np.bincount(labels.ravel())[1:]
    largest = labels == (np.argmax(component_sizes) + 1)

    # Fill cavities fully enclosed by the component so they stay inside the mask.
    filled = ndi.binary_fill_holes(largest)

    # Optional erosion. The guard is required: scipy interprets iterations < 1
    # as "repeat until the result stops changing", which would erase the mask.
    if erode_iters > 0:
        filled = ndi.binary_erosion(filled, iterations=erode_iters)

    mask = torch.from_numpy(filled).float().unsqueeze(0)
    return mask.to(x.device)


In [None]:
def build_transforms(step=False, target_shape=(128,128,128)):
    """
    Build a MONAI preprocessing pipeline that takes a NIfTI file path.

    step:
        falsy (e.g. False) -> load, channel-first, RAS reorientation, crop to foreground.
        any truthy value (e.g. 'norm') -> additionally resize to ``target_shape`` and
            rescale intensities from the 1st-99th percentile range to [0, 1], clipping.
        'remove_bg' -> everything above, then multiply by ``brain_outer_mask``.
    """
    base_stage = [
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        Orientation("RAS"),
        CropForeground(),
    ]
    if not step:
        return Compose(base_stage)

    normalize_stage = [
        Resize(target_shape),
        ScaleIntensityRangePercentiles(lower=1, upper=99, b_min=0, b_max=1.0, clip=True),
    ]
    if step == 'remove_bg':
        background_stage = [Lambda(lambda x: x * brain_outer_mask(x))]
    else:
        background_stage = []
    return Compose(base_stage + normalize_stage + background_stage)

# Despite the "lcc" names, tfm_with_lcc removes background via brain_outer_mask
# (which keeps the largest connected component internally).
tfm_orig = build_transforms(step=False)
tfm_no_lcc = build_transforms(step='norm')
tfm_with_lcc = build_transforms(step='remove_bg')

In [None]:
def plot_histograms(original, masked, out_path, bins=200):
    """
    Save an overlaid, log-scaled intensity histogram of two volumes.

    Both histograms share the SAME bin edges, computed from the combined data,
    so their bar heights are directly comparable. Calling ``hist(..., bins=200)``
    separately would give each histogram its own edges.

    Args:
        original: array-like volume, drawn opaque and labelled "Original".
        masked: array-like volume, drawn semi-transparent on top ("After LCC").
        out_path: path of the PNG file to write.
        bins: number of shared bins (default 200).
    """
    orig = np.asarray(original).ravel()
    maskd = np.asarray(masked).ravel()
    edges = np.histogram_bin_edges(np.concatenate([orig, maskd]), bins=bins)

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.hist(orig, bins=edges, label="Original", color="blue")
        ax.hist(maskd, bins=edges, alpha=0.5, label="After LCC", color="orange")
        ax.set_yscale("log")
        ax.set_xlabel("Intensity")
        ax.set_ylabel("Voxel count (log scale)")
        ax.set_title("Intensity Histogram: Before vs After LCC")
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        # Always release the figure; this runs once per image inside a loop.
        plt.close(fig)


def _mid_slices(vol):
    """Return the central (axial, coronal, sagittal) 2-D slices of a [D, H, W] volume."""
    depth, height, width = vol.shape
    return vol[depth // 2, :, :], vol[:, height // 2, :], vol[:, :, width // 2]


def show_slices(original, normed, masked_image, mask, out_path):
    """
    Save a 4x3 grid of central slices: rows are original / normalized /
    background-removed / mask, and columns are axial / coronal / sagittal.

    Each volume is sliced at its OWN centre. The volumes may have different
    shapes: in compare_lcc the original is cropped but not resized, while the
    rest are resized. Reusing the original's midpoints for every volume
    therefore showed off-centre slices, or raised IndexError when an original
    dimension was more than twice the resized size.

    Args:
        original, normed, masked_image, mask: 3-D arrays indexed [D, H, W].
        out_path: path of the PNG file to write.
    """
    rows = (
        (original, "Original"),
        (normed, "Normalized"),
        (masked_image, "Removed Background"),
        (mask, "Mask"),
    )
    view_names = ("Axial", "Coronal", "Sagittal")
    # Slice everything before opening a figure, so shape errors don't leak one.
    row_slices = [_mid_slices(vol) for vol, _ in rows]

    fig, axes = plt.subplots(len(rows), len(view_names), figsize=(12, 12))
    try:
        for row, ((_, label), slices) in enumerate(zip(rows, row_slices)):
            for col, (view, slc) in enumerate(zip(view_names, slices)):
                ax = axes[row, col]
                ax.imshow(slc, cmap="hot")
                ax.set_title(f"{view} – {label}")
                ax.axis("off")
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        plt.close(fig)

In [None]:
def compare_lcc(IMG_PATH, OUT_DIR):
    """
    Run the three preprocessing pipelines on one image and save comparison figures.

    Writes two PNGs into OUT_DIR, creating the directory if needed:
    - an intensity histogram of the normalized vs. background-removed volume;
    - a slice grid of the original / normalized / background-removed volumes
      plus the mask that brain_outer_mask computes from the normalized volume.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    image_name = os.path.basename(IMG_PATH)

    def to_numpy(t):
        # MetaTensor -> Tensor -> numpy, dropping singleton (e.g. channel) axes.
        return t.squeeze().cpu().numpy()

    raw_vol = tfm_orig(IMG_PATH)
    normed_vol = tfm_no_lcc(IMG_PATH)
    bg_removed_vol = tfm_with_lcc(IMG_PATH)
    mask_vol = brain_outer_mask(normed_vol)

    raw_np, normed_np, bg_removed_np, mask_np = (
        to_numpy(v) for v in (raw_vol, normed_vol, bg_removed_vol, mask_vol)
    )

    hist_path = os.path.join(OUT_DIR, f"histogram_before_after_LCC_{image_name}.png")
    slice_path = os.path.join(OUT_DIR, f"slices_before_after_LCC_{image_name}.png")

    plot_histograms(normed_np, bg_removed_np, hist_path)
    show_slices(raw_np, normed_np, bg_removed_np, mask_np, slice_path)

    print("Done!")
    print(f"Histogram saved to: {hist_path}")
    print(f"Slice comparisons saved to: {slice_path}")

In [None]:
# Output folder for the comparison figures. Set BRAINMASK_OUT_DIR to override
# this user-specific absolute path instead of editing the notebook.
OUT_DIR = os.environ.get(
    "BRAINMASK_OUT_DIR",
    "/Users/yu7637xi/Desktop/Projects/models/AI-PET/results/test_brainmask/",
)
# Root of the IDEAS PET data. Set PET_ROOT to override; the default reproduces
# the original absolute paths exactly.
PET_ROOT = os.environ.get("PET_ROOT", "/PET/IDEAS")
image_path_list = [
    os.path.join(PET_ROOT, "56228/PET_Florbetaben_Conv,_Non-Rigid_Reg_to_Std_Img_Vox_Size,_Uniform_Res,_Inten_Norm/1940-11-27_14_08_58.0/I10426629/IDEAS_56228_20240410113725466.nii"),
    os.path.join(PET_ROOT, "57735/PET_Florbetaben_Conv,_Non-Rigid_Reg_to_Std_Img_Vox_Size,_Uniform_Res,_Inten_Norm/1940-08-25_13_41_31.0/I10427747/IDEAS_57735_20240410140950159.nii"),
    os.path.join(PET_ROOT, "62368/PET_Florbetaben_Conv,_Non-Rigid_Reg_to_Std_Img_Vox_Size,_Uniform_Res,_Inten_Norm/1939-09-26_14_32_05.0/I10431031/IDEAS_62368_20240411091930342.nii"),
]

for IMG_PATH in image_path_list:
    print(f"Processing image: {IMG_PATH}")
    compare_lcc(IMG_PATH, OUT_DIR)