In [None]:
# Install required packages
!pip install torch torchvision torchaudio
!pip install gradio
!pip install 'git+https://github.com/facebookresearch/detectron2.git'

In [None]:
import cv2
import numpy as np
import torch
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
import gradio as gr

from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog

# Colour-vision-deficiency (CVD) simulation matrices, keyed by deficiency name.
# Each 3x3 matrix M maps a pixel's (R, G, B) values scaled to [0, 1] to the
# simulated (R', G', B'): output channel i = sum_j M[i][j] * input channel j
# (callers compute `pixels @ M.T` and clip the result to [0, 1]).
# The matrices are applied directly to the 8-bit values divided by 255; there is
# no sRGB linearisation step anywhere before them.
# NOTE(review): the source of these coefficients is not cited; confirm their
# provenance and whether applying them to gamma-encoded values is intended.
# The "achromatopsia" rows are the ITU-R BT.601 luma weights (0.299, 0.587, 0.114),
# so every output channel becomes the same grey value.
# The key order also defines the order of the deficiency dropdown in the UI.
COLOR_BLINDNESS_MATRICES = {
    "protanopia": np.array([[0.567, 0.433, 0.0], [0.558, 0.442, 0.0], [0.0, 0.242, 0.758]]),
    "deuteranopia": np.array([[0.625, 0.375, 0.0], [0.7, 0.3, 0.0], [0.0, 0.3, 0.7]]),
    "tritanopia": np.array([[0.95, 0.05, 0.0], [0.0, 0.433, 0.567], [0.0, 0.475, 0.525]]),
    "protanomaly": np.array([[0.817, 0.183, 0.0], [0.333, 0.667, 0.0], [0.0, 0.125, 0.875]]),
    "deuteranomaly": np.array([[0.8, 0.2, 0.0], [0.258, 0.742, 0.0], [0.0, 0.142, 0.858]]),
    "tritanomaly": np.array([[0.967, 0.033, 0.0], [0.0, 0.733, 0.267], [0.0, 0.183, 0.817]]),
    "achromatopsia": np.array([[0.299, 0.587, 0.114], [0.299, 0.587, 0.114], [0.299, 0.587, 0.114]]),
    "achromatomaly": np.array([[0.618, 0.320, 0.062], [0.163, 0.775, 0.062], [0.163, 0.320, 0.516]])
}

# Scratch directory for pipeline artefacts: run_cvd_enhancement writes
# original.jpg and detected_objects_with_edges.json here. Created at import time.
os.makedirs("/tmp/cvd_outputs", exist_ok=True)

def simulate_color_blindness(image, deficiency_type):
    """Simulate how an image appears to a viewer with a colour vision deficiency.

    Every pixel's (R, G, B) values, scaled to [0, 1], are multiplied by the
    3x3 matrix registered for ``deficiency_type`` in ``COLOR_BLINDNESS_MATRICES``,
    clipped to [0, 1] and rounded back to 8-bit values.

    Args:
        image: Path to an image file, or an image array. Arrays may be
            grayscale (H, W), RGB (H, W, 3) or RGBA (H, W, 4); values are
            assumed to be 0-255 (they are divided by 255) -- TODO confirm for
            non-uint8 inputs.
        deficiency_type: A key of ``COLOR_BLINDNESS_MATRICES``.

    Returns:
        A uint8 RGB array with the same height and width as the input.

    Raises:
        ValueError: If ``deficiency_type`` is not a known key.
        FileNotFoundError: If ``image`` is a path OpenCV cannot read.
    """
    if deficiency_type not in COLOR_BLINDNESS_MATRICES:
        raise ValueError(f"Invalid deficiency type. Choose from {list(COLOR_BLINDNESS_MATRICES.keys())}.")

    if isinstance(image, str):
        img = cv2.imread(image)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise FileNotFoundError(f"Could not read image file: {image}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    else:
        img = image
        if img.ndim == 2:  # Grayscale
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

    normalized = img.astype(np.float64) / 255.0
    matrix = COLOR_BLINDNESS_MATRICES[deficiency_type]
    transformed = normalized.reshape(-1, 3) @ matrix.T
    transformed = np.clip(transformed, 0.0, 1.0).reshape(img.shape)

    # Round instead of truncating: a bare astype(np.uint8) floors values such
    # as 254.9999 down to 254, biasing every channel downward.
    return np.rint(transformed * 255.0).astype(np.uint8)

# Cache for the (predictor, cfg) pair built by setup_detectron2_model, so the
# Mask R-CNN weights are only loaded once per process.
_DETECTRON2_CACHE = {}


def setup_detectron2_model():
    """Build (once) and return the Detectron2 Mask R-CNN predictor and its config.

    The model is the model-zoo "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x"
    configuration and checkpoint, with detections scoring below 0.5 dropped.
    It runs on CUDA when torch reports it available, otherwise on the CPU.

    The pair is cached after the first call. run_cvd_enhancement calls this on
    every run, and rebuilding the predictor each time reloaded the weights.
    Callers must treat the returned cfg as read-only, because it is shared.

    Returns:
        Tuple ``(predictor, cfg)``.
    """
    cached = _DETECTRON2_CACHE.get("mask_rcnn")
    if cached is not None:
        return cached

    config_name = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_name))
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_name)
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    # Set after merge_from_file so the YAML can never override the device choice.
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    cached = (DefaultPredictor(cfg), cfg)
    _DETECTRON2_CACHE["mask_rcnn"] = cached
    return cached

def detect_objects(image, predictor, cfg):
    """Detect object instances with Detectron2 and describe each one.

    Args:
        image: RGB image array (H, W, 3). It is converted to BGR for the
            predictor; arrays of any other shape are passed through unchanged.
        predictor: Callable returned by setup_detectron2_model.
        cfg: The config used to build ``predictor``. ``cfg.DATASETS.TRAIN[0]``
            selects the metadata that supplies class names.

    Returns:
        Tuple ``(detected_objects, vis_image)``:
        - ``detected_objects``: one dict per instance with "class",
          "confidence" and "bounding_box" (x1, y1, x2, y2). When the model
          predicts masks, each dict also has "edges": a list of [x, y] point
          lists, one per external contour of the instance mask.
        - ``vis_image``: uint8 RGB visualisation of the predictions, drawn at
          1.2x scale.
    """
    # The comment in the original pipeline notes Detectron2 expects BGR input.
    if image.ndim == 3 and image.shape[2] == 3:
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    else:
        image_bgr = image

    outputs = predictor(image_bgr)
    instances = outputs["instances"].to("cpu")
    boxes = instances.pred_boxes.tensor.numpy()
    scores = instances.scores.numpy()
    class_ids = instances.pred_classes.numpy()
    masks = instances.pred_masks.numpy() if instances.has("pred_masks") else None

    metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
    class_names = metadata.thing_classes

    # Visualizer takes an RGB image and get_image() returns RGB. Feeding it BGR
    # and converting afterwards (as before) swapped red/blue in the overlays.
    visualizer = Visualizer(image, metadata=metadata, scale=1.2)
    vis_image = visualizer.draw_instance_predictions(instances).get_image()

    detected_objects = []
    for idx, (box, score, class_id) in enumerate(zip(boxes, scores, class_ids)):
        class_name = class_names[class_id] if class_id < len(class_names) else "Unknown"
        x1, y1, x2, y2 = (float(coord) for coord in box)
        obj = {
            "class": class_name,
            "confidence": float(score),
            "bounding_box": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
        }

        if masks is not None:
            mask = (masks[idx] * 255).astype(np.uint8)
            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            obj["edges"] = [contour.reshape(-1, 2).tolist() for contour in contours]

        detected_objects.append(obj)

    return detected_objects, vis_image

def save_detected_objects_with_edges(detected_objects, output_path):
    """Persist the detections that carry contour data as indented JSON.

    Only entries containing an ``"edges"`` key are written, in their original
    order. Any existing file at ``output_path`` is overwritten.

    Args:
        detected_objects: Iterable of detection dicts.
        output_path: Destination file path.

    Returns:
        ``output_path``, unchanged.
    """
    objects_with_edges = list(filter(lambda item: "edges" in item, detected_objects))
    with open(output_path, "w") as handle:
        json.dump(objects_with_edges, handle, indent=2)
    return output_path

def analyze_iou_and_lab_similarity(image_rgb, json_path, color_threshold=20.0, iou_threshold=0):
    """Find pairs of detected objects that overlap enough and look alike in CIELAB.

    Objects are loaded from ``json_path``: a list of dicts with "class" and
    "edges", where "edges" is a list of [x, y] point lists. Each object's mask
    is the union of its filled contours. For every pair (i, j) with i < j, the
    pair is reported when IoU(mask_i, mask_j) >= ``iou_threshold`` and the
    Euclidean distance between the objects' mean LAB colours (CIE76 delta E)
    is < ``color_threshold``.

    Objects whose mask covers no pixels are never paired. Previously they got
    LAB (0, 0, 0), i.e. black, so they always matched each other and black
    objects. Because adjust_similar_object_colors skips empty masks, such
    pairs could never be resolved.

    Args:
        image_rgb: RGB image, assumed uint8 0-255 (it is divided by 255 before
            rgb2lab). Mask dimensions come from its height and width.
        json_path: Path to the JSON list of objects.
        color_threshold: Maximum delta E (exclusive) for two objects to count as similar.
        iou_threshold: Minimum IoU (inclusive). With 0, every pair is compared.

    Returns:
        Tuple ``(similar_pairs, highlighted, masks)``:
        - ``similar_pairs``: dicts with obj1_index, obj2_index (indices into the
          loaded list and into ``masks``), label1, label2, iou, deltaE.
        - ``highlighted``: copy of ``image_rgb`` with the contours of every paired
          object drawn in (0, 255, 0), thickness 3.
        - ``masks``: per-object uint8 masks, 255 inside the object and 0 elsewhere.
    """
    from itertools import combinations

    height, width = image_rgb.shape[:2]

    with open(json_path, "r") as f:
        objects = json.load(f)

    def to_contours(edges):
        # OpenCV contour format: int32 array of shape (N, 1, 2).
        return [np.array(edge, dtype=np.int32).reshape((-1, 1, 2)) for edge in edges]

    def contour_to_mask(edges):
        mask = np.zeros((height, width), dtype=np.uint8)
        # Fill contours one at a time, so overlapping contours are unioned
        # rather than cancelling out.
        for contour in to_contours(edges):
            cv2.drawContours(mask, [contour], -1, 255, thickness=cv2.FILLED)
        return mask

    def compute_iou(mask1, mask2):
        inside1 = mask1 == 255
        inside2 = mask2 == 255
        inter = np.logical_and(inside1, inside2).sum()
        union = np.logical_or(inside1, inside2).sum()
        return inter / union if union > 0 else 0

    masks = [contour_to_mask(obj["edges"]) for obj in objects]

    # rgb2lab is a per-pixel conversion, so converting the image once and
    # averaging each mask's slice gives the same means as converting each
    # object's pixels separately.
    lab_image = rgb2lab(image_rgb / 255.0) if objects else None
    lab_means = []
    for mask in masks:
        inside = mask == 255
        lab_means.append(lab_image[inside].mean(axis=0) if inside.any() else None)

    similar_pairs = []
    for i, j in combinations(range(len(objects)), 2):
        if lab_means[i] is None or lab_means[j] is None:
            continue  # an object with no pixels has no colour to compare
        iou = compute_iou(masks[i], masks[j])
        if iou < iou_threshold:
            continue
        delta_e = float(np.linalg.norm(lab_means[i] - lab_means[j]))
        if delta_e < color_threshold:
            similar_pairs.append({
                "obj1_index": i,
                "obj2_index": j,
                "label1": objects[i]["class"],
                "label2": objects[j]["class"],
                "iou": float(iou),
                "deltaE": delta_e
            })

    # Outline every object that takes part in at least one similar pair. Each
    # object is drawn once, since redrawing the same outline changes nothing.
    highlighted = image_rgb.copy()
    paired_ids = sorted({idx for pair in similar_pairs for idx in (pair["obj1_index"], pair["obj2_index"])})
    for obj_id in paired_ids:
        cv2.drawContours(highlighted, to_contours(objects[obj_id]["edges"]), -1, (0, 255, 0), thickness=3)

    return similar_pairs, highlighted, masks

def adjust_similar_object_colors(image_rgb, masks, similar_pairs, color_shift=60, shift_mode='auto'):
    """Shift the LAB colour of the second object in each similar-colour pair.

    The image is converted to CIELAB, and ``color_shift`` is added to one
    chroma channel of each pair's ``obj2_index`` object. Each object is shifted
    at most once, and objects with empty masks are skipped. Shifted values are
    clipped to L* in [0, 100] and a*/b* in [-128, 127], then the image is
    converted back to RGB.

    Args:
        image_rgb: RGB image, assumed 0-255 (it is divided by 255 before rgb2lab).
        masks: Per-object masks in which 255 marks the object's pixels.
        similar_pairs: Dicts with at least an "obj2_index" key indexing ``masks``.
        color_shift: Amount added to the chosen LAB channel. May be fractional.
        shift_mode: 'a' shifts a*, 'b' shifts b*. Any other value selects auto
            mode, which shifts whichever of a*/b* has the smaller absolute mean
            over the object (b* on ties).

    Returns:
        The adjusted image as a uint8 RGB array.
    """
    lab_image = rgb2lab(image_rgb.astype(np.float32) / 255.0)
    adjusted = lab_image.copy()

    # Per-channel clip bounds for (L*, a*, b*).
    lower = np.array([0.0, -128.0, -128.0])
    upper = np.array([100.0, 127.0, 127.0])

    adjusted_ids = set()
    for pair in similar_pairs:
        target_id = pair["obj2_index"]
        if target_id in adjusted_ids:
            continue  # Skip if already adjusted

        inside = masks[target_id] == 255
        object_pixels = adjusted[inside]
        if object_pixels.size == 0:
            continue  # Skip empty masks

        mean_color = object_pixels.mean(axis=0)

        # Float vector, so a fractional color_shift is not truncated. The old
        # auto branch used an int array, which truncated it.
        shift = np.zeros(3)
        if shift_mode == 'a':
            shift[1] = color_shift
        elif shift_mode == 'b':
            shift[2] = color_shift
        elif abs(mean_color[1]) < abs(mean_color[2]):
            shift[1] = color_shift  # auto: a channel
        else:
            shift[2] = color_shift  # auto: b channel

        adjusted[inside] = np.clip(object_pixels + shift, lower, upper)
        adjusted_ids.add(target_id)

    adjusted_rgb = lab2rgb(adjusted)
    # Round rather than truncate. The whole image makes an RGB->LAB->RGB round
    # trip, and flooring values such as 254.9999 would darken untouched pixels
    # on every call; run_cvd_enhancement calls this once per iteration.
    return np.clip(np.rint(adjusted_rgb * 255.0), 0, 255).astype(np.uint8)

def simulate_color_blindness_on_region(rgb_image, deficiency_type, region_mask):
    """Apply the CVD simulation matrix only to the pixels inside a mask.

    Pixels where ``region_mask == 255`` are scaled to [0, 1], multiplied by the
    ``COLOR_BLINDNESS_MATRICES`` entry for ``deficiency_type``, clipped and
    rounded back to 8 bits. All other pixels are copied unchanged. Previously
    they went through a float32 /255*255 round trip and truncation, which
    could alter them.

    Args:
        rgb_image: RGB image (H, W, 3), assumed uint8 0-255.
        deficiency_type: A key of ``COLOR_BLINDNESS_MATRICES``.
        region_mask: (H, W) mask in which 255 marks the pixels to transform.

    Returns:
        A new uint8 RGB array. The input is not modified.

    Raises:
        ValueError: If ``deficiency_type`` is not a known key.
    """
    if deficiency_type not in COLOR_BLINDNESS_MATRICES:
        raise ValueError(f"Invalid deficiency type. Choose from {list(COLOR_BLINDNESS_MATRICES.keys())}.")

    result = np.asarray(rgb_image).astype(np.uint8)  # astype copies, so the input is untouched
    selected = region_mask == 255
    if not np.any(selected):
        return result

    pixels = result[selected].astype(np.float64) / 255.0
    transformed = np.clip(pixels @ COLOR_BLINDNESS_MATRICES[deficiency_type].T, 0.0, 1.0)
    # Round instead of truncating so values like 254.9999 become 255, not 254.
    result[selected] = np.rint(transformed * 255.0).astype(np.uint8)
    return result

def get_default_shift_mode(deficiency_type):
    """Pick the LAB channel to shift by default for a given deficiency type.

    Red-green deficiencies (prot-/deuter-) shift along b* (blue-yellow), and
    blue-yellow deficiencies (trit-) shift along a* (red-green). Every other
    value, including the achromatic types, returns 'auto'. In auto mode
    adjust_similar_object_colors chooses one of a*/b* per object.
    """
    blue_yellow_types = ('tritanopia', 'tritanomaly')
    red_green_types = ('protanopia', 'deuteranopia', 'protanomaly', 'deuteranomaly')

    if deficiency_type in blue_yellow_types:
        return 'a'
    if deficiency_type in red_green_types:
        return 'b'
    return 'auto'

def run_cvd_enhancement(image, deficiency_type, color_threshold, iou_threshold, color_shift, shift_mode, max_iterations):
    """Run the full CVD enhancement pipeline and return the outputs for the UI.

    Steps:
    1. Normalise the input to RGB and save a copy to /tmp/cvd_outputs/original.jpg.
    2. Simulate the deficiency over the whole image.
    3. Detect objects with Detectron2 and save their contours as JSON.
    4. Find object pairs that look similar in the simulated image.
    5. Iteratively shift the colours of one object per pair, re-simulate the
       shifted regions and re-check, up to ``max_iterations`` times.

    Args:
        image: RGB/RGBA/grayscale array, a file path, or None.
        deficiency_type: A key of ``COLOR_BLINDNESS_MATRICES``.
        color_threshold: Delta E below which two objects count as similar.
        iou_threshold: Minimum mask IoU for a pair to be compared.
        color_shift: LAB shift applied per iteration.
        shift_mode: 'a', 'b' or 'auto' (see adjust_similar_object_colors).
        max_iterations: Maximum number of enhancement iterations.

    Returns:
        A 6-tuple ``(original, cvd_simulated, detection, similar_objects,
        enhanced, log_text)``. When there is no usable image, the first five
        entries are None and ``log_text`` holds the error.
    """
    if image is None:
        return None, None, None, None, None, "Error: No image uploaded"

    # Convert the input to a 3-channel RGB array.
    if isinstance(image, str):
        loaded = cv2.imread(image)
        # cv2.imread returns None instead of raising on unreadable files.
        if loaded is None:
            return None, None, None, None, None, f"Error: Could not read image file: {image}"
        image = cv2.cvtColor(loaded, cv2.COLOR_BGR2RGB)
    elif image.ndim == 2:  # Grayscale
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:  # RGBA
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    # UI sliders may deliver ints or floats; normalise once.
    color_threshold = float(color_threshold)
    iou_threshold = float(iou_threshold)
    color_shift = float(color_shift)
    max_iterations = int(max_iterations)

    original_path = "/tmp/cvd_outputs/original.jpg"
    cv2.imwrite(original_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    # Step 1: Simulate color blindness
    transformed_image = simulate_color_blindness(image, deficiency_type)

    # Step 2: Detect objects on the original (unsimulated) image
    predictor, cfg = setup_detectron2_model()
    detected_objects, vis_image = detect_objects(image, predictor, cfg)

    # Step 3: Save detected objects (only those with contours)
    json_path = "/tmp/cvd_outputs/detected_objects_with_edges.json"
    save_detected_objects_with_edges(detected_objects, json_path)

    # Step 4: Analyze similarity as seen through the simulation
    similar_pairs, highlighted, masks = analyze_iou_and_lab_similarity(
        transformed_image,
        json_path,
        color_threshold=color_threshold,
        iou_threshold=iou_threshold
    )

    if not similar_pairs:
        return (image,
                transformed_image,
                vis_image,
                highlighted,
                transformed_image,
                "No similar object pairs found. No enhancement needed.")

    # Step 5: Iterative enhancement.
    # NOTE(review): the image being adjusted starts as the CVD-simulated image,
    # and each iteration re-applies the CVD matrix to the adjusted objects. The
    # returned "enhanced" image is therefore in simulated-perception space, not
    # an image to show a CVD viewer. Confirm this is intended.
    adjusted_image = transformed_image.copy()
    log_messages = [f"Found {len(similar_pairs)} similar object pairs. Starting enhancement..."]

    for iteration in range(1, max_iterations + 1):
        adjusted_rgb = adjust_similar_object_colors(
            adjusted_image,
            masks,
            similar_pairs,
            color_shift=color_shift,
            shift_mode=shift_mode
        )

        # Combined mask of every object that was adjusted this iteration
        adjusted_mask = np.zeros(image.shape[:2], dtype=np.uint8)
        for pair in similar_pairs:
            adjusted_mask = cv2.bitwise_or(adjusted_mask, masks[pair["obj2_index"]])

        adjusted_image = simulate_color_blindness_on_region(adjusted_rgb, deficiency_type, adjusted_mask)

        # Re-analyze; the highlight image from later passes is not displayed.
        similar_pairs, _, masks = analyze_iou_and_lab_similarity(
            adjusted_image,
            json_path,
            color_threshold=color_threshold,
            iou_threshold=iou_threshold
        )

        log_messages.append(f"Iteration {iteration}: {len(similar_pairs)} similar pairs remaining")

        if not similar_pairs:
            log_messages.append("All similar objects now distinguishable. Enhancement complete!")
            break
    else:
        # Loop ran out of iterations without resolving every pair.
        log_messages.append(
            f"Stopped after {max_iterations} iteration(s); "
            f"{len(similar_pairs)} similar pairs still remain."
        )

    return (image,
            transformed_image,
            vis_image,
            highlighted,
            adjusted_image,
            "\n".join(log_messages))

# Create the Gradio interface
def create_interface():
    """Build the Gradio Blocks app for the CVD enhancement tool.

    Layout: a header, a "How It Works" panel, an input row (image upload and
    settings), a processing-status row, and result rows. The result rows show
    the original, CVD-simulation and enhanced images, plus tabs for the
    detection and similar-object visualisations.

    Clicking "Run Enhancement" hides the inputs and shows the status box, then
    runs run_cvd_enhancement. Afterwards it shows the results and keeps the
    status box visible, so errors and the iteration log can be read.
    "Process Another Image" restores the input view.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    # Custom CSS for better styling
    custom_css = """
    .app-header {
        text-align: center;
        margin: 0 auto;
        max-width: 1200px;
        padding: 1rem 0;
    }

    .app-title {
        font-size: 2.5rem !important;
        font-weight: bold;
        margin-bottom: 0.5rem;
        color: #1f77b4;
    }

    .app-description {
        font-size: 1.2rem !important;
        margin-bottom: 2rem;
        max-width: 900px;
        margin-left: auto;
        margin-right: auto;
    }

    .how-it-works {
        background-color: #222222;
        color: #ffffff;
        border-radius: 8px;
        padding: 1.5rem;
        margin-bottom: 2rem;
        font-size: 1.2rem !important;
        max-width: 1000px;
        margin-left: auto;
        margin-right: auto;
        box-shadow: 0 2px 10px rgba(0,0,0,0.2);
    }

    .how-it-works h3 {
        text-align: center;
        font-size: 1.5rem !important;
        margin-bottom: 1rem;
        color: #ffffff;
    }

    .how-it-works ol {
        padding-left: 2rem;
    }

    .how-it-works li {
        margin-bottom: 0.5rem;
    }

    .how-it-works p {
        margin-top: 1rem;
    }

    /* Fix image scaling */
    .upload-image-container .wrap {
        display: flex !important;
        justify-content: center !important;
        align-items: center !important;
        height: auto !important;
        min-height: 450px !important;
    }

    .upload-image-container img {
        max-width: 100% !important;
        max-height: 100% !important;
        width: auto !important;
        height: auto !important;
        object-fit: contain !important;
    }

    .image-results-container {
        flex-wrap: wrap !important;
        justify-content: center !important;
    }

    /* Result images */
    .result-image-container {
        height: auto !important;
        min-height: 450px !important;
        display: flex !important;
        align-items: center !important;
        justify-content: center !important;
    }

    .result-image-container img {
        max-width: 100% !important;
        max-height: 100% !important;
        width: auto !important;
        height: auto !important;
        object-fit: contain !important;
        margin: 0 auto !important;
    }

    .results-header {
        text-align: center;
        font-size: 2rem !important;
        margin: 1.5rem 0;
    }

    .result-column h3 {
        text-align: center;
        font-size: 1.5rem !important;
        margin-bottom: 0.5rem;
    }

    .footer-button {
        display: block;
        margin: 2rem auto;
        max-width: 300px;
    }

    /* Tab styling */
    .detail-tabs .tab-nav {
        background-color: #f5f5f5;
        border-radius: 6px;
        padding: 5px;
    }

    .detail-tabs .tabitem {
        padding: 0.5rem !important;
    }
    """

    with gr.Blocks(title="Color Vision Deficiency Enhancement Tool", css=custom_css) as app:
        # Title and description section (always visible)
        with gr.Row(elem_classes=["app-header"]):
            with gr.Column():
                gr.Markdown("# Color Vision Deficiency Enhancement Tool", elem_classes=["app-title"])
                gr.Markdown(
                    "Transform images for better visibility by people with color vision deficiency. "
                    "This intelligent system detects objects in images, identifies those that might appear similar to people with color blindness, "
                    "and enhances color differences to make them more distinguishable.",
                    elem_classes=["app-description"]
                )

        # How It Works section - centralized above the input section
        with gr.Row(visible=True, elem_classes=["how-it-works-container"]) as how_it_works_section:
            with gr.Column():
                gr.Markdown("""
                <div class="how-it-works">
                <h3>How It Works</h3>

                <ol>
                <li><strong>Upload</strong> your image that needs enhancement</li>
                <li><strong>Select</strong> the type of color vision deficiency to accommodate</li>
                <li>The system <strong>detects objects</strong> in your image using AI</li>
                <li>It <strong>finds objects</strong> that might look similar to someone with the selected CVD</li>
                <li>Colors are <strong>enhanced</strong> to make objects more distinguishable</li>
                <li>View side-by-side <strong>comparisons</strong> of original, CVD simulation, and enhanced versions</li>
                </ol>

                <p>Fine-tune the enhancement with advanced settings for optimal results.</p>
                </div>
                """)

        # Input section
        with gr.Row(visible=True) as input_section:
            # Left panel - Image upload
            with gr.Column(scale=3):
                input_image = gr.Image(
                    label="Upload Image",
                    type="numpy",
                    elem_classes=["upload-image-container"],
                    height=450,
                    image_mode="RGB",
                    sources=["upload", "clipboard"]
                )

            # Right panel - Settings
            with gr.Column(scale=2):
                deficiency_type = gr.Dropdown(
                    choices=list(COLOR_BLINDNESS_MATRICES.keys()),
                    value="protanopia",
                    label="Color Vision Deficiency Type"
                )

                with gr.Accordion("Advanced Settings", open=False):
                    color_threshold = gr.Slider(minimum=5, maximum=50, value=20, step=1,
                                              label="Color Similarity Threshold (ΔE)")
                    iou_threshold = gr.Slider(minimum=0, maximum=0.5, value=0, step=0.01,
                                            label="Object Overlap Threshold (IoU)")
                    color_shift = gr.Slider(minimum=20, maximum=100, value=50, step=5,
                                          label="Color Shift Amount")
                    shift_mode = gr.Radio(["auto", "a", "b"], value="b",
                                        label="Color Shift Mode")
                    max_iterations = gr.Slider(minimum=1, maximum=10, value=3, step=1,
                                             label="Maximum Enhancement Iterations")

                # Update shift mode based on deficiency type
                deficiency_type.change(
                    fn=lambda x: gr.update(value=get_default_shift_mode(x)),
                    inputs=[deficiency_type],
                    outputs=[shift_mode]
                )

                run_button = gr.Button("Run Enhancement", variant="primary", size="lg")

        # Processing log: shown while the pipeline runs and kept visible with the
        # results, so errors and the iteration log are not hidden.
        with gr.Row(visible=False) as processing_section:
            output_log = gr.Textbox(label="Processing Status", lines=3)

        # Results header
        with gr.Row(visible=False, elem_classes=["results-header"]) as results_header:
            gr.Markdown("## Results")

        # Main results row - showing 3 main images
        with gr.Row(visible=False, equal_height=True) as main_results:
            with gr.Column(elem_classes=["result-column"]):
                gr.Markdown("### Original")
                original_output = gr.Image(elem_classes=["result-image-container"])

            with gr.Column(elem_classes=["result-column"]):
                gr.Markdown("### CVD Simulation")
                cvd_output = gr.Image(elem_classes=["result-image-container"])

            with gr.Column(elem_classes=["result-column"]):
                gr.Markdown("### Enhanced for CVD")
                enhanced_output = gr.Image(elem_classes=["result-image-container"])

        # Optional detailed view with tabs (same scale as main results)
        with gr.Row(visible=False) as detailed_results:
            # Use a Column to make it span the full width
            with gr.Column():
                with gr.Tabs(elem_classes=["detail-tabs"]):
                    with gr.TabItem("Object Detection", id="tab-detection"):
                        detection_output = gr.Image(
                            elem_classes=["result-image-container"],
                            height=350
                        )

                    with gr.TabItem("Similar Objects", id="tab-similar"):
                        similar_output = gr.Image(
                            elem_classes=["result-image-container"],
                            height=350
                        )

        # Rerun button (visible after processing)
        with gr.Row(visible=False) as rerun_section:
            rerun_button = gr.Button("Process Another Image", variant="secondary", size="lg", elem_classes=["footer-button"])

        # Every layout section whose visibility the workflow toggles.
        layout_sections = [input_section, how_it_works_section, processing_section,
                           results_header, main_results, detailed_results, rerun_section]

        # Define workflow
        def show_processing():
            return {
                input_section: gr.update(visible=False),
                how_it_works_section: gr.update(visible=False),
                processing_section: gr.update(visible=True),
                # Clear the previous run's log so stale status is not shown.
                output_log: gr.update(value="Processing..."),
                results_header: gr.update(visible=False),
                main_results: gr.update(visible=False),
                detailed_results: gr.update(visible=False),
                rerun_section: gr.update(visible=False)
            }

        def show_results():
            return {
                input_section: gr.update(visible=False),
                how_it_works_section: gr.update(visible=False),
                # Keep the status log visible: it carries error messages and the
                # enhancement log produced by run_cvd_enhancement.
                processing_section: gr.update(visible=True),
                results_header: gr.update(visible=True),
                main_results: gr.update(visible=True),
                detailed_results: gr.update(visible=True),
                rerun_section: gr.update(visible=True)
            }

        def reset_interface():
            return {
                input_section: gr.update(visible=True),
                how_it_works_section: gr.update(visible=True),
                processing_section: gr.update(visible=False),
                results_header: gr.update(visible=False),
                main_results: gr.update(visible=False),
                detailed_results: gr.update(visible=False),
                rerun_section: gr.update(visible=False)
            }

        # Connect event handlers
        run_button.click(
            fn=show_processing,
            outputs=layout_sections + [output_log]
        ).then(
            fn=run_cvd_enhancement,
            inputs=[input_image, deficiency_type, color_threshold, iou_threshold, color_shift, shift_mode, max_iterations],
            outputs=[original_output, cvd_output, detection_output, similar_output, enhanced_output, output_log]
        ).then(
            fn=show_results,
            outputs=layout_sections
        )

        rerun_button.click(
            fn=reset_interface,
            outputs=layout_sections
        )

    return app

# Build the interface at module level, so tooling that looks up a module-level
# `demo` can still find it. Only launch the server when this code runs as a
# script or notebook cell, not on import.
demo = create_interface()

if __name__ == "__main__":
    # NOTE(review): share=True requests a public share link for the app;
    # confirm this exposure is intended.
    demo.launch(debug=True, share=True)