In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch

from monai.transforms import (
    Compose, LoadImage, EnsureChannelFirst, Orientation, CropForeground,
    Resize, ScaleIntensityRangePercentiles, KeepLargestConnectedComponent
)

In [None]:
class _KeepLargestForegroundComponent:
    """Zero every voxel outside the largest connected component of ``img > threshold``.

    ``KeepLargestConnectedComponent`` treats its input as a *label* image: with
    ``applied_labels=[1]`` it only considers voxels whose value is exactly 1.
    On a continuous intensity volume that is only the saturated voxels, so this
    wrapper first binarises the image, runs the LCC on that binary mask, and
    then applies the resulting mask to the original intensities.

    Expects a torch tensor input (the MetaTensor produced by the preceding
    MONAI transforms in ``build_transforms``).
    """

    def __init__(self, threshold=0.0):
        self.threshold = threshold
        self._lcc = KeepLargestConnectedComponent(applied_labels=[1])

    def __call__(self, img):
        # Binary label image (1 = foreground, 0 = background) in the image's dtype.
        foreground = (img > self.threshold).to(img.dtype)
        largest = self._lcc(foreground)
        # Keep original intensities only inside the largest component.
        return img * largest


def build_transforms(use_lcc=False, target_shape=(128, 128, 128), lcc_threshold=0.0):
    """Build the MONAI preprocessing pipeline applied to a single image path.

    Steps: load -> channel-first -> reorient to RAS -> crop to foreground
    (``CropForeground`` defaults) -> resize to ``target_shape`` -> rescale the
    1st..99th intensity percentiles to [0, 1] with clipping -> optionally mask
    the volume by its largest connected foreground component.

    Args:
        use_lcc: if True, append the largest-connected-component masking step.
        target_shape: spatial size passed to ``Resize``.
        lcc_threshold: intensity (on the rescaled [0, 1] range) above which a
            voxel counts as foreground for the LCC step. Ignored when
            ``use_lcc`` is False.

    Returns:
        Compose: callable mapping an image path to a channel-first tensor.
    """
    steps = [
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        Orientation("RAS"),
        CropForeground(),
        Resize(target_shape),
        ScaleIntensityRangePercentiles(lower=1, upper=99, b_min=0.0, b_max=1.0, clip=True),
    ]
    if use_lcc:
        # Applying KeepLargestConnectedComponent directly to the scaled image would
        # only affect voxels equal to exactly 1.0; threshold into a mask first.
        steps.append(_KeepLargestForegroundComponent(threshold=lcc_threshold))
    return Compose(steps)

# Both pipelines come from the same factory, so the LCC step is the only
# difference between the two outputs that compare_lcc puts side by side.
tfm_no_lcc = build_transforms()
tfm_with_lcc = build_transforms(use_lcc=True)

In [None]:
def plot_histograms(original, masked, out_path, bins=200):
    """Save overlaid log-scale intensity histograms of two volumes.

    Both histograms are drawn on the same bin edges, so their bars line up and
    the per-bin counts are directly comparable.

    Args:
        original: array-like volume before LCC; flattened before plotting.
        masked: array-like volume after LCC; flattened before plotting.
        out_path: file path the figure is written to (format from the extension).
        bins: number of shared histogram bins (default 200).
    """
    orig = np.ravel(original)
    maskd = np.ravel(masked)
    # Separate hist() calls would each choose edges from their own min/max;
    # derive one set of edges from both arrays so the distributions align.
    edges = np.histogram_bin_edges(np.concatenate([orig, maskd]), bins=bins)

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.hist(orig, bins=edges, alpha=0.5, label="Original", color="blue")
        ax.hist(maskd, bins=edges, alpha=0.5, label="After LCC", color="orange")
        ax.set_yscale("log")
        ax.set_xlabel("Intensity")
        ax.set_ylabel("Voxel count (log scale)")
        ax.set_title("Intensity Histogram: Before vs After LCC")
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        # Release the figure even if saving fails, so repeated calls don't leak.
        plt.close(fig)


def show_slices(original, masked, out_path):
    """Save a 2x3 grid of mid-volume slices: original (top row) vs after LCC (bottom row).

    The volumes are assumed to be ordered (x, y, z) as produced by MONAI
    ``Orientation("RAS")`` once the channel axis is dropped. In that layout:
    axis 0 runs left->right, axis 1 posterior->anterior, and axis 2
    inferior->superior. A fixed index on axis 2 is therefore an axial plane,
    on axis 1 a coronal plane, and on axis 0 a sagittal plane.

    All six panels share one colour scale, so voxels removed by LCC appear as a
    real drop in intensity instead of being rescaled away per panel.

    Args:
        original: 3-D array-like volume before LCC.
        masked: 3-D array-like volume after LCC, same shape as ``original``.
        out_path: file path the figure is written to.

    Raises:
        ValueError: if the inputs are not non-empty 3-D volumes of equal shape.
    """
    o = np.asarray(original)
    m = np.asarray(masked)
    if o.ndim != 3 or o.shape != m.shape or o.size == 0:
        raise ValueError(
            f"expected two non-empty 3-D volumes of equal shape, got {o.shape} and {m.shape}"
        )

    cx, cy, cz = (n // 2 for n in o.shape)
    vmin = float(min(o.min(), m.min()))
    vmax = float(max(o.max(), m.max()))

    planes = [
        ("Axial", o[:, :, cz], m[:, :, cz]),
        ("Coronal", o[:, cy, :], m[:, cy, :]),
        ("Sagittal", o[cx, :, :], m[cx, :, :]),
    ]

    fig, axes = plt.subplots(2, 3, figsize=(12, 6))
    try:
        for col, (name, before, after) in enumerate(planes):
            for row, (img, tag) in enumerate(((before, "Original"), (after, "After LCC"))):
                ax = axes[row, col]
                # Transpose + origin="lower" makes the second in-plane RAS axis
                # (anterior for axial, superior otherwise) point up on screen.
                ax.imshow(img.T, cmap="hot", origin="lower", vmin=vmin, vmax=vmax)
                ax.set_title(f"{name} – {tag}")
                ax.axis("off")
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        plt.close(fig)

In [None]:
def compare_lcc(IMG_PATH, OUT_DIR):
    """Preprocess one image with and without LCC and save comparison figures.

    Runs the module-level ``tfm_no_lcc`` and ``tfm_with_lcc`` pipelines on
    ``IMG_PATH``. Writes an intensity-histogram overlay and a mid-slice grid
    into ``OUT_DIR``, creating the directory if needed.

    Args:
        IMG_PATH: path to the input image (file or directory LoadImage can read).
        OUT_DIR: directory the two PNG figures are written to.

    Returns:
        tuple[str, str]: paths of the histogram figure and the slice figure.

    Raises:
        ValueError: if ``IMG_PATH`` or ``OUT_DIR`` is empty.
        FileNotFoundError: if ``IMG_PATH`` does not exist.
    """
    # Fail early with actionable messages instead of deep inside makedirs/LoadImage.
    if not IMG_PATH:
        raise ValueError("IMG_PATH is empty; set it to the image to analyse.")
    if not os.path.exists(IMG_PATH):
        raise FileNotFoundError(f"Input image not found: {IMG_PATH}")
    if not OUT_DIR:
        raise ValueError("OUT_DIR is empty; set it to a directory for the output figures.")
    os.makedirs(OUT_DIR, exist_ok=True)

    img_no_lcc = tfm_no_lcc(IMG_PATH)
    img_with_lcc = tfm_with_lcc(IMG_PATH)

    # Pipelines return channel-first tensors; take channel 0 (single-channel
    # input assumed) instead of squeeze(), which would also drop any size-1
    # spatial axis.
    img_no_lcc_np = img_no_lcc[0].detach().cpu().numpy()
    img_with_lcc_np = img_with_lcc[0].detach().cpu().numpy()

    hist_path = os.path.join(OUT_DIR, "histogram_before_after_LCC.png")
    slice_path = os.path.join(OUT_DIR, "slices_before_after_LCC.png")

    plot_histograms(img_no_lcc_np, img_with_lcc_np, hist_path)
    show_slices(img_no_lcc_np, img_with_lcc_np, slice_path)

    print("Done!")
    print(f"Histogram saved to: {hist_path}")
    print(f"Slice comparisons saved to: {slice_path}")
    return hist_path, slice_path

In [None]:
# Input volume to analyse: set it here, or export LCC_IMG_PATH before starting Jupyter.
IMG_PATH = os.environ.get("LCC_IMG_PATH", "")
# Directory the comparison figures are written to (created if missing);
# relative to the notebook's working directory unless an absolute path is given.
OUT_DIR = os.environ.get("LCC_OUT_DIR", "lcc_comparison")

if not IMG_PATH:
    raise ValueError(
        "Set IMG_PATH (or the LCC_IMG_PATH environment variable) to an input image before running."
    )

compare_lcc(IMG_PATH, OUT_DIR)