# In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

"""
🖌 **Apply External Colormaps to MRI SNR Images**
- Computes SNR and removes background noise.
- Enhances visualization using external colormap images.
- Saves processed images in a structured output directory.
"""

def compute_snr(
    image,
    *,
    signal_slices=(slice(7, 24), slice(10, 20)),
    noise_slices=(slice(2, 31), slice(2, 4)),
):
    """
    🧮 Computes Signal-to-Noise Ratio (SNR) for the given image.

    SNR here is ``max(signal ROI) / std(noise ROI)``. The default ROIs are the
    fixed regions this script has always used (signal: rows 7-23, cols 10-19;
    noise: rows 2-30, cols 2-3). Pass other ``(row_slice, col_slice)`` tuples
    to use different regions.

    Args:
        image: 2-D array-like, indexed as ``image[rows, cols]``.
        signal_slices: ``(row_slice, col_slice)`` selecting the signal ROI.
        noise_slices: ``(row_slice, col_slice)`` selecting the noise ROI.

    Returns:
        ``(snr, max_signal, noise, noise_region)``, where ``noise`` is the
        standard deviation of the noise ROI and ``noise_region`` is the ROI
        array itself. ``snr`` is ``0.0`` when ``noise`` is not positive.

    Raises:
        ValueError: if either ROI selects no pixels (e.g. the image is
            smaller than the ROI coordinates).
    """
    image = np.asarray(image)
    signal_region = image[signal_slices]
    noise_region = image[noise_slices]

    # An empty ROI would give a cryptic numpy error (max) or a NaN std with a
    # silent SNR of 0; fail loudly instead.
    if signal_region.size == 0 or noise_region.size == 0:
        raise ValueError(
            f"ROI outside image of shape {image.shape}: "
            f"signal={signal_slices}, noise={noise_slices}"
        )

    noise = np.std(noise_region)
    max_signal = np.max(signal_region)
    # Guard against division by zero on perfectly flat noise.
    snr = max_signal / noise if noise > 0 else 0.0

    return snr, max_signal, noise, noise_region

def apply_external_colormap(image, colormap_image, noise_threshold):
    """
    🎨 Applies external colormap image to grayscale MRI images.

    Pixels below ``noise_threshold`` are zeroed. The remaining intensities are
    rescaled to weights in ``[0, 1)`` relative to ``noise_threshold`` and the
    image maximum. Those weights scale the colormap image after it is resized
    to the MRI image's size.

    Args:
        image: 2-D grayscale array-like.
        colormap_image: object with a PIL-style ``resize((width, height))``
            whose result converts to an array of shape ``(H, W)`` or
            ``(H, W, C)``.
        noise_threshold: intensity below which pixels are treated as
            background.

    Returns:
        ``uint8`` array shaped like the resized colormap: ``(H, W, C)``, or
        ``(H, W)`` for a single-channel colormap. If no pixel exceeds
        ``noise_threshold``, the result is all zeros.
    """
    image = np.array(image, dtype=np.float32)  # copy: never mutate the caller's data
    image[image < noise_threshold] = 0

    span = np.max(image) - noise_threshold
    if span > 0:
        norm_img = (image - noise_threshold) / (span + 1e-6)
    else:
        # Nothing above the threshold. Without this guard the zeroed image
        # yields (-t) / (-t + eps) ~= 1, painting every pixel at full
        # intensity.
        norm_img = np.zeros_like(image)
    norm_img = np.clip(norm_img, 0.0, 1.0)

    colormap_array = np.array(colormap_image.resize((image.shape[1], image.shape[0])))
    # Single-channel colormaps must not gain a broadcast axis (which would
    # produce an (H, W, W) array). Multi-channel ones need one per channel.
    # NOTE(review): an alpha channel, if present, is scaled as well.
    weights = norm_img if colormap_array.ndim == 2 else norm_img[..., None]
    # Clip before the cast so colormaps wider than 8 bits saturate rather
    # than wrap around.
    colored_img = np.clip(weights * colormap_array, 0, 255).astype(np.uint8)
    return colored_img

def format_title(file_name, probe):
    """
    🏷 Formats filenames into readable labels (Metabolite + Timepoint).

    Expects names shaped like ``<x>_<y>_<time>_<metabolite>.tif`` (or
    ``.tiff``). The third ``_``-separated field is mapped through ``t1/t2/t3``
    to hours. The fourth is mapped to a metabolite name for the given probe
    (``p1`` = 3FDG family, ``p2`` = 3FDGal family). Unknown codes or probes
    fall back to the raw field text.

    Args:
        file_name: image file name.
        probe: probe identifier, e.g. ``"p1"`` or ``"p2"``.

    Returns:
        ``"<metabolite> - <time>"``, or ``file_name`` unchanged if it has
        fewer than four ``_``-separated fields.
    """
    time_map = {"t1": "1hr", "t2": "2hr", "t3": "3hr"}
    metabolite_map = {
        "p1": {"m1": "3FDGlucose", "m3": "3FDGluconic Acid", "m4": "3FDSorbitol"},
        "p2": {"m1": "3FDGalactose", "m3": "3FDGalactonic Acid", "m4": "3FDGalactitol"},
    }

    # Strip only a trailing TIFF extension. Replacing ".tif" anywhere turned
    # "x.tiff" into "xf" and corrupted the metabolite code.
    stem = file_name
    for ext in (".tiff", ".tif"):
        if stem.lower().endswith(ext):
            stem = stem[: -len(ext)]
            break

    parts = stem.split("_")
    if len(parts) < 4:
        return file_name

    time = time_map.get(parts[2], parts[2])
    # Unknown probe: show the raw metabolite code rather than raising KeyError.
    metabolite = metabolite_map.get(probe, {}).get(parts[3], parts[3])

    return f"{metabolite} - {time}"

def display_and_save_colored_images(images, colormap_files, probe, output_dir):
    """
    📸 **Displays and saves MRI images with colormap enhancement.**

    For each ``(file_name, img)`` in ``images``:
      1. compute the SNR with ``compute_snr``;
      2. pick the colormap whose key ``k`` appears as ``_<k>.tif`` in the
         lower-cased file name;
      3. apply it above a noise threshold of ``mean + 3 * std`` of the noise
         region;
      4. plot it in a grid titled with ``format_title`` and the SNR;
      5. save it as ``Colored_<lower-cased file_name>`` in ``output_dir``.

    The whole grid is also saved as ``All_Colored_SNR_Images_<probe>.png``
    at 300 dpi, then shown.

    Args:
        images: dict mapping file name -> image (anything ``np.array``
            accepts, e.g. a PIL image).
        colormap_files: dict mapping metabolite key (``"m1"``, ``"m3"``,
            ``"m4"``) -> colormap image. Images whose key has no colormap are
            shown and saved unmodified.
        probe: probe identifier forwarded to ``format_title``.
        output_dir: output directory; created if missing.

    Returns:
        List of the processed image arrays in ``images`` order (empty when
        ``images`` is empty).
    """
    os.makedirs(output_dir, exist_ok=True)  # Ensure output directory exists

    if not images:
        # The grid layout below would divide by zero with no images.
        print("⚠️ No images to process.")
        return []

    colormap_mapping = {key: colormap_files.get(key) for key in ("m1", "m3", "m4")}

    num_images = len(images)
    cols = min(3, num_images)
    rows = -(-num_images // cols)  # ceiling division

    # squeeze=False always yields a 2-D grid of axes, even for a single image.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 5), squeeze=False)
    axes = axes.ravel()

    all_colored_images = []

    for ax, (file_name, img) in zip(axes, images.items()):
        img_array = np.array(img)
        snr, _max_signal, _noise, noise_region = compute_snr(img_array)

        noise_threshold = np.mean(noise_region) + 3 * np.std(noise_region)

        file_name = file_name.lower()
        # First key whose "_<key>.tif" tag occurs in the name wins.
        colormap_image = next(
            (cmap for key, cmap in colormap_mapping.items() if f"_{key}.tif" in file_name),
            None,
        )

        if colormap_image is not None:
            colored_img = apply_external_colormap(img_array, colormap_image, noise_threshold)
        else:
            colored_img = img_array

        ax.imshow(colored_img)
        formatted_title = format_title(file_name, probe)
        ax.set_title(f"{formatted_title}\nSNR: {snr:.2f}", fontsize=14, fontweight="bold")
        ax.axis("off")

        # 📂 Save each processed image
        save_path = os.path.join(output_dir, f"Colored_{file_name}")
        Image.fromarray(colored_img).save(save_path)
        print(f"✅ Saved: {save_path}")

        all_colored_images.append(colored_img)

    # Hide unused grid cells when the image count is not a multiple of cols.
    for ax in axes[num_images:]:
        ax.axis("off")

    fig.tight_layout()

    # 📂 Save combined figure
    combined_save_path = os.path.join(output_dir, f"All_Colored_SNR_Images_{probe}.png")
    fig.savefig(combined_save_path, dpi=300)
    print(f"✅ Combined image saved: {combined_save_path}")

    plt.show()

    return all_colored_images

# 🔹 **User Input: Define Data Paths**
repo_directory = r"YOUR_LOCAL_PATH_HERE/FMRI-Metabolism-Quantification-main"

# 🎯 **Choose probe ('p1' for 3FDG or 'p2' for 3FDGal)**
probe = "p1"

# Folder-name suffix for each probe, and the colormap file used for each
# metabolite key of that probe.
PROBE_SUFFIX = {"p1": "3FDG", "p2": "3FDGal"}
COLORMAP_FILENAMES = {
    "p1": {"m1": "green_cropped_32.tif", "m3": "red_cropped_32.tif", "m4": "yellow_cropped_32.tif"},
    "p2": {"m1": "cyan_cropped_32.tif", "m3": "magenta_cropped_32.tif", "m4": "gold_cropped_32.tif"},
}

# Fail fast on a typo instead of silently treating any other value as 'p2'.
if probe not in PROBE_SUFFIX:
    raise ValueError(f"Unknown probe {probe!r}; expected one of {sorted(PROBE_SUFFIX)}")

# 📂 **Set directories for images and colormap files**
example_images_dir = os.path.join(repo_directory, f"example_images_{PROBE_SUFFIX[probe]}")
colormap_images_dir = os.path.join(repo_directory, f"colormaps_{PROBE_SUFFIX[probe]}")

# 📂 **Define output directory for colored images**
output_directory = os.path.join(repo_directory, "colored_snr_results")

# 🔍 **Load colormap images**
colormap_files = {
    key: Image.open(os.path.join(colormap_images_dir, name))
    for key, name in COLORMAP_FILENAMES[probe].items()
}

# 🔍 **Load and process images**
# Every TIFF in the probe's example folder, in sorted order so the grid
# layout is deterministic. Keys are the file names, which drive colormap
# selection and titles.
images = {
    name: Image.open(os.path.join(example_images_dir, name))
    for name in sorted(os.listdir(example_images_dir))
    if name.lower().endswith((".tif", ".tiff"))
}
display_and_save_colored_images(images, colormap_files, probe, output_directory)
