In [None]:
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt

def _load_mri_grayscale(mri_image_path, brightness_factor):
    """Load the MRI scan as one channel, min-max normalize it, and scale it to [0, 255].

    Returns ``(size, array)``: ``size`` is the PIL ``(width, height)`` of the scan and
    ``array`` is a 2-D float32 array, brightened by ``brightness_factor`` and clipped.
    """
    with Image.open(mri_image_path) as mri_image:
        # Multi-band (RGB, RGBA, LA, ...) and palette images must be collapsed to a
        # single channel, otherwise the array is (H, W, C) and the later
        # ``[..., None]`` broadcast against the (H, W, 3) signal fails.
        # Single-band modes ("L", "I", "I;16", "F") are kept as-is so that
        # high-bit-depth scans keep their full dynamic range for normalization.
        if mri_image.mode == "P" or len(mri_image.getbands()) > 1:
            mri_image = mri_image.convert("L")
        size = mri_image.size
        mri_array = np.array(mri_image, dtype=np.float32)

    # Min-max normalize to [0, 1]. A constant image has zero range; guard it so the
    # division does not produce NaNs (which would become garbage after uint8 cast).
    low = float(np.min(mri_array))
    value_range = float(np.max(mri_array)) - low
    if value_range > 0:
        mri_array = (mri_array - low) / value_range
    else:
        mri_array = np.zeros_like(mri_array)
    mri_array = np.clip(mri_array * brightness_factor * 255, 0, 255)
    return size, mri_array


def _plot_merged_grid(merged_images, titles, output_folder, max_display=9, cols=3):
    """Show the first ``max_display`` merged images in a grid and save it as Merged_All_Images.png."""
    num_images = min(len(merged_images), max_display)
    if num_images <= 0:
        print("⚠ No images to display.")
        return

    rows = (num_images + cols - 1) // cols

    # squeeze=False always yields a 2-D array of axes, so flatten() is safe for any grid shape.
    fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
    axes = axes.flatten()

    for ax, img, title in zip(axes[:num_images], merged_images[:num_images], titles[:num_images]):
        ax.imshow(img)
        ax.set_title(title, fontsize=12)
        ax.axis("off")

    # Hide any unused subplots
    for ax in axes[num_images:]:
        fig.delaxes(ax)

    fig.tight_layout()

    # Save combined figure
    combined_save_path = os.path.join(output_folder, "Merged_All_Images.png")
    fig.savefig(combined_save_path, dpi=300)
    print(f"✅ Combined merged image saved at: {combined_save_path}")

    plt.show()
    # Release the figure so repeated calls (or non-interactive backends, where show()
    # is a no-op) do not accumulate open figures.
    plt.close(fig)


def merge_mri_with_signal(mri_image_path, signal_images_folder, output_folder, alpha=0.6,
                          brightness_factor=1.5, signal_boost=1.2,
                          file_prefix="Colored_71", max_display=9):
    """
    Merges MRI anatomical image with processed (colored) signal images, overlaying the signal map on the anatomical scan.

    Parameters:
    - mri_image_path: Path to the MRI image. Multi-band or palette images are converted to grayscale.
    - signal_images_folder: Path to the folder containing processed colored signal images.
    - output_folder: Directory to save merged images (created if missing).
    - alpha: Transparency level for the signal overlay (0 = fully transparent, 1 = fully opaque).
    - brightness_factor: Factor to increase MRI image brightness (applied after min-max normalization).
    - signal_boost: Factor to increase signal intensity (result clipped to 255).
    - file_prefix: Only files in signal_images_folder whose names start with this prefix
      and end with ".tif" are merged (default "Colored_71").
    - max_display: Maximum number of merged images shown in the combined figure (default 9).

    Output:
    - Saves each merged image as "merged_<signal file name>" in output_folder.
    - Displays up to max_display merged images in a 3-column grid and saves that grid
      as "Merged_All_Images.png" in output_folder.
    - Returns None. If no matching signal files are found, prints a warning and returns early.
    """
    os.makedirs(output_folder, exist_ok=True)  # Ensure output directory exists

    # Load MRI image as grayscale, normalized and brightened to [0, 255]
    mri_size, mri_array = _load_mri_grayscale(mri_image_path, brightness_factor)

    merged_images = []
    titles = []

    # Mapping for metabolite and time
    metabolite_map = {"m1": "3FDGlucose", "m3": "3FDGluconic Acid", "m4": "3FDSorbitol"}
    time_map = {"t1": "1h", "t2": "2h", "t3": "3h"}

    # 🔍 Find signal images dynamically (sorted for a deterministic order)
    signal_files = [
        f for f in sorted(os.listdir(signal_images_folder))
        if f.startswith(file_prefix) and f.endswith(".tif")
    ]

    if not signal_files:
        print("⚠ No valid images found for merging. Please check the filenames and directory.")
        return

    for file_name in signal_files:
        signal_image_path = os.path.join(signal_images_folder, file_name)
        # The context manager closes the file handle; Pillow keeps multi-frame
        # TIFFs open after load otherwise, leaking one handle per file.
        with Image.open(signal_image_path) as raw_signal:
            signal_img = raw_signal.convert("RGB")  # Ensure signal image is in RGB

        # Extract metabolite and time information.
        # NOTE: substring match, so e.g. "m1" would also match a "m10" token.
        metabolite_key = next((key for key in metabolite_map if key in file_name), "Unknown")
        time_key = next((key for key in time_map if key in file_name), "Unknown")

        title = f"{metabolite_map.get(metabolite_key, 'Unknown')} - {time_map.get(time_key, 'Unknown')}"
        titles.append(title)

        # Resize signal image to match MRI (PIL sizes are (width, height))
        signal_img_resized = signal_img.resize(mri_size, Image.BICUBIC)
        signal_array = np.array(signal_img_resized, dtype=np.float32)

        # Boost signal intensity
        signal_array = np.clip(signal_array * signal_boost, 0, 255)

        # Alpha-blend: the (H, W) MRI gets a trailing axis so it broadcasts over the RGB channels.
        merged_img = (1 - alpha) * mri_array[..., None] + alpha * signal_array
        merged_img = np.clip(merged_img, 0, 255).astype(np.uint8)

        # Save merged image
        save_path = os.path.join(output_folder, f"merged_{file_name}")
        Image.fromarray(merged_img).save(save_path)
        print(f"✅ Saved merged image: {save_path}")

        merged_images.append(merged_img)

    # Display merged images (at most max_display of them)
    _plot_merged_grid(merged_images, titles, output_folder, max_display)


# 🔹 **Define Paths for MRI and Signal Images**
repo_directory = r"YOUR_LOCAL_PATH_HERE/FMRI-Metabolism-Quantification-main" # ← UPDATE THIS PATH!

# 🎯 **Choose probe ('p1' for 3FDG or 'p2' for 3FDGal)**
probe = "p1"  # Change to "p2" for 3FDGal

# Single source of truth for probe -> tracer name. Both the example-image folder
# and the MRI filename are derived from it, so they cannot drift apart. An
# unknown probe fails loudly instead of silently falling back to 3FDGal.
probe_tracers = {"p1": "3FDG", "p2": "3FDGal"}
if probe not in probe_tracers:
    raise ValueError(f"❌ ERROR: Unknown probe '{probe}'; expected one of {sorted(probe_tracers)}")
tracer = probe_tracers[probe]

# 📂 **Set directories for MRI images**
mri_images_folder = os.path.join(repo_directory, f"example_images_{tracer}")

# 🔍 Set the exact MRI image filename based on the probe
mri_image_filename = f"brain_image_{tracer}.tif"
mri_image_path = os.path.join(mri_images_folder, mri_image_filename)

# ❌ If the file doesn't exist, raise an error
if not os.path.exists(mri_image_path):
    raise FileNotFoundError(f"❌ ERROR: Expected MRI image '{mri_image_filename}' not found in {mri_images_folder}")

print(f"✅ Using MRI image: {mri_image_path}")


# 📂 **Path to the directory containing colored signal images**
signal_images_folder = os.path.join(repo_directory, "output_images")  # Directory with processed colored SNR images
output_directory = os.path.join(repo_directory, "merged_signal_brain")  # Folder for merged images

# 🔍 **Run the merge function**
merge_mri_with_signal(mri_image_path, signal_images_folder, output_directory)
