# In [None]:
from PIL import Image, ImageEnhance
import numpy as np
import os
import matplotlib.pyplot as plt

"""
🧠 **Merge MRI Brain Image with Colored Signal Maps**
- Overlays **processed signal maps** on top of anatomical MRI scans.
- Uses **adjustable transparency** for better visualization.
- Saves **merged images** individually & generates a **combined figure**.
"""

# Probe -> {filename token: metabolite name}, used to title each merged image.
_METABOLITE_MAP = {
    "p1": {"m1": "3FDGlucose", "m3": "3FDGluconic Acid", "m4": "3FDSorbitol"},
    "p2": {"m1": "3FDGalactose", "m3": "3FDGalactonic Acid", "m4": "3FDGalactitol"},
}
# Filename token -> acquisition time label.
_TIME_MAP = {"t1": "1hr", "t2": "2hr", "t3": "3hr"}
# Probe -> substring identifying that probe's colored SNR images in a filename.
_PROBE_FILE_TAG = {"p1": "Colored_712", "p2": "Colored_711"}
# Maximum number of merged images shown in the combined figure.
_MAX_PREVIEW = 9


def _load_mri_intensity(mri_image_path, brightness_factor):
    """Load the MRI image as a brightened 2-D float32 array on a 0-255 scale.

    Returns ``(intensity, size)`` where ``size`` is the PIL ``(width, height)``.
    """
    with Image.open(mri_image_path) as mri_image:
        size = mri_image.size
        # Multi-band scans (RGB, RGBA, LA, ...) would yield a 3-D array that cannot
        # be broadcast against the RGB signal map, so collapse them to grayscale.
        gray = mri_image.convert("L") if len(mri_image.getbands()) > 1 else mri_image
        intensity = np.asarray(gray, dtype=np.float32)

    lo = float(intensity.min())
    span = float(intensity.max()) - lo
    if span > 0:
        intensity = (intensity - lo) / span
    else:
        # A constant image would divide by zero (NaN everywhere); render it black.
        intensity = np.zeros_like(intensity)
    intensity = np.clip(intensity * brightness_factor * 255, 0, 255)
    return intensity, size


def _describe_signal_file(file_name, metabolites, time_map):
    """Build a "<metabolite> - <time>" title from the first tokens found in ``file_name``."""
    metabolite = next((name for key, name in metabolites.items() if key in file_name), "Unknown")
    time_label = next((label for key, label in time_map.items() if key in file_name), "Unknown")
    return f"{metabolite} - {time_label}"


def _overlay(mri_intensity, signal_rgb, alpha, signal_boost):
    """Alpha-blend a boosted (H, W, 3) signal map onto an (H, W) intensity array.

    Both inputs are on a 0-255 scale; the result is an (H, W, 3) uint8 array.
    """
    boosted = np.clip(signal_rgb * signal_boost, 0, 255)
    blended = (1 - alpha) * mri_intensity[..., None] + alpha * boosted
    return np.clip(blended, 0, 255).astype(np.uint8)


def _plot_merged_grid(images, titles, save_path, cols=3):
    """Show ``images`` in a ``cols``-wide titled grid and save the figure to ``save_path``."""
    rows = (len(images) + cols - 1) // cols
    # squeeze=False keeps ``axes`` 2-D even when there is a single row.
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)
    axes = axes.ravel()
    for ax, img, title in zip(axes, images, titles):
        ax.imshow(img)
        ax.set_title(title, fontsize=12)
        ax.axis("off")
    # Remove the empty cells of a partially filled last row.
    for ax in axes[len(images):]:
        fig.delaxes(ax)
    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    print(f"✅ Combined merged image saved: {save_path}")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def merge_mri_with_signal(mri_image_path, signal_images_folder, output_folder, probe, alpha=0.6, brightness_factor=1.5, signal_boost=1.2):
    """Overlay colored SNR signal maps on an anatomical MRI image.

    Every ``.tif``/``.tiff`` file in ``signal_images_folder`` whose name contains
    the probe's tag (``Colored_712`` for ``"p1"``, ``Colored_711`` for ``"p2"``)
    is resized to the MRI image, alpha-blended over it and saved as
    ``merged_<original name>`` in ``output_folder``. Up to 9 merged images are
    then shown in a 3-column figure, which is also saved as
    ``Merged_All_Images_<probe>.png`` in ``output_folder``.

    Args:
        mri_image_path: Path to the anatomical MRI image (RARE scan). Multi-band
            images are converted to grayscale before use.
        signal_images_folder: Directory containing the processed colored SNR images.
        output_folder: Directory for the merged images; created if missing.
        probe: ``"p1"`` for 3FDG or ``"p2"`` for 3FDGal.
        alpha: Weight of the signal map in the blend (0 = MRI only, 1 = signal only).
        brightness_factor: Multiplier applied to the min-max normalized MRI.
        signal_boost: Multiplier applied to the signal map before blending.

    Returns:
        None. Prints a warning and returns early if no matching files are found.

    Raises:
        ValueError: If ``probe`` is not ``"p1"`` or ``"p2"``.
    """
    # Validate before touching the filesystem; an unknown probe would otherwise
    # silently select the "p2" file tag.
    if probe not in _METABOLITE_MAP:
        raise ValueError(f"probe must be one of {sorted(_METABOLITE_MAP)}, got {probe!r}")
    metabolites = _METABOLITE_MAP[probe]
    file_tag = _PROBE_FILE_TAG[probe]

    os.makedirs(output_folder, exist_ok=True)

    mri_intensity, mri_size = _load_mri_intensity(mri_image_path, brightness_factor)

    signal_files = [
        name
        for name in sorted(os.listdir(signal_images_folder))
        if file_tag in name and name.lower().endswith((".tif", ".tiff"))
    ]
    if not signal_files:
        print("⚠ No valid signal images found for merging. Please check filenames & directory.")
        return

    preview_images = []
    preview_titles = []
    for file_name in signal_files:
        signal_image_path = os.path.join(signal_images_folder, file_name)
        # The context manager closes the TIFF handle once the pixels are copied out.
        with Image.open(signal_image_path) as signal_img:
            resized = signal_img.convert("RGB").resize(mri_size, Image.BICUBIC)
            signal_rgb = np.asarray(resized, dtype=np.float32)

        merged_img = _overlay(mri_intensity, signal_rgb, alpha, signal_boost)

        save_path = os.path.join(output_folder, f"merged_{file_name}")
        Image.fromarray(merged_img).save(save_path)
        print(f"✅ Saved: {save_path}")

        # Only the images that will be displayed are kept in memory.
        if len(preview_images) < _MAX_PREVIEW:
            preview_images.append(merged_img)
            preview_titles.append(_describe_signal_file(file_name, metabolites, _TIME_MAP))

    # signal_files is non-empty here, so there is always at least one preview image.
    combined_save_path = os.path.join(output_folder, f"Merged_All_Images_{probe}.png")
    _plot_merged_grid(preview_images, preview_titles, combined_save_path)

# User input: root of the local repository checkout (replace the placeholder).
repo_directory = r"YOUR_LOCAL_PATH_HERE/FMRI-Metabolism-Quantification-main"

# Probe selection: "p1" processes 3FDG data; any other value (e.g. "p2") uses 3FDGal.
probe = "p1"

# The anatomical reference scan depends on the selected probe.
if probe == "p1":
    mri_filename = "brain_image_3FDG.tif"
else:
    mri_filename = "brain_image_3FDGal.tif"
mri_image_path = os.path.join(repo_directory, mri_filename)

# Input: processed colored SNR maps. Output: destination for the merged overlays.
signal_images_folder = os.path.join(repo_directory, "colored_snr_results")
output_directory = os.path.join(repo_directory, "merged_signal_brain")

# Run the overlay for the chosen probe.
merge_mri_with_signal(
    mri_image_path,
    signal_images_folder,
    output_directory,
    probe,
)
