In [1]:
import torch
from torch import nn
import torchvision.models as models

# Load a ResNet-50 initialised with the ImageNet-1K (V1) pretrained weights.
backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Keep every child module except the last one (the classification head in
# torchvision's ResNet), so the network emits features instead of class logits.
feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
feature_extractor = feature_extractor.to("cuda" if torch.cuda.is_available() else "cpu")

# The extractor is only used for inference: eval() makes BatchNorm layers use
# their stored running statistics (and stops them being updated), so the
# features extracted for an image do not depend on the batch it is in.
feature_extractor.eval()

In [2]:
from classes import VisionModule
from config import configuration

# Build the vision module from the feature extractor defined above and the
# imported project configuration.
vision_model_obj = VisionModule(
    feature_extraction_model=feature_extractor,
    configuration=configuration,
)

In [3]:
from classes import TextModule

# Instantiate the text module with its default constructor arguments; the
# pipeline below calls its extract_transformer_features() method per quadrant.
text_model_obj = TextModule()

In [9]:
import os
import cv2
import numpy as np
import joblib


def test_architecture_pipeline(img_path, model_path=r"saved_models/kc_classifier.pkl"):
    """Run preprocessing, feature extraction and classification for one image.

    Steps performed:
      1. ``vision_model_obj.run_image_preprocessing`` produces a folder of crops.
      2. For each quadrant crop (``Q1``..``Q4``) deep, handcrafted and
         transformer features are extracted and concatenated.
      3. The per-quadrant vectors are averaged into one image-level vector.
      4. A classifier loaded from ``model_path`` predicts the class and
         class probabilities.
      5. A result dictionary and an LLM prompt describing the result are built.

    Relies on the module-level ``vision_model_obj`` and ``text_model_obj``.

    Args:
        img_path: Path of the input image. Feature folders are created two
            directory levels above it.
        model_path: Path of the pickled classifier. Defaults to the location
            previously hard-coded in this function.

    Returns:
        ``(final_prediction, plain_prompt)`` on success, where
        ``final_prediction`` is a dict with keys ``image_id``,
        ``predicted_label``, ``probabilities`` and ``handcrafted_summary``
        and ``plain_prompt`` is the prompt string.
        ``None`` if preprocessing fails, no quadrant could be processed, or
        the model file is missing (callers must check before unpacking).
    """

    print("\n==================== TEST PIPELINE ====================")

    # ---------------------------------------------------
    # 1. Preprocess → get crops folder
    # ---------------------------------------------------
    crops_path = vision_model_obj.run_image_preprocessing(img_path)

    if crops_path is None or not os.path.exists(crops_path):
        print("[ERROR] run_image_preprocessing() did not return a valid folder.")
        return None

    # ---------------------------------------------------
    # 2. Prepare save dirs (TEMP STORAGE for test)
    # ---------------------------------------------------
    base = os.path.dirname(os.path.dirname(img_path))

    deep_dir = os.path.join(base, "test_deep_features")
    hand_dir = os.path.join(base, "test_hand_features")
    trans_dir = os.path.join(base, "test_transformer_features")

    for feature_dir in (deep_dir, hand_dir, trans_dir):
        os.makedirs(feature_dir, exist_ok=True)

    img_id = os.path.splitext(os.path.basename(img_path))[0]

    # ---------------------------------------------------
    # 3. Extract features for each quadrant
    # ---------------------------------------------------
    quadrant_list = ["Q1", "Q2", "Q3", "Q4"]
    quadrant_vectors = []
    handcrafted_all_quadrants = {}

    for q in quadrant_list:

        crop_path = os.path.join(crops_path, f"{img_id}_{q}.png")

        if not os.path.exists(crop_path):
            print(f"[WARN] Missing crop: {crop_path}")
            continue

        # LOAD CROP
        # cv2.imread signals failure by returning None instead of raising, so a
        # corrupt/unreadable file must be skipped before cvtColor sees it.
        crop = cv2.imread(crop_path)
        if crop is None:
            print(f"[WARN] Unreadable crop: {crop_path}")
            continue
        crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)

        # ---------------------------------------------------
        # DEEP FEATURES
        # ---------------------------------------------------
        cnn_ready = vision_model_obj.preprocess_for_cnn(crop)
        deep_vec = vision_model_obj.extract_deep_features(
            tensor=cnn_ready,
            save_dir=deep_dir,
            img_name=img_id,
            quadrant=q
        )

        # ---------------------------------------------------
        # HANDCRAFT FEATURES
        # ---------------------------------------------------
        hand_vec = vision_model_obj.handcrafted_features(
            cropped_img_pth=crop_path,
            save_dir=hand_dir,
            img_name=img_id,
            quadrant=q
        )
        # Keep the named features (a mapping) for the report ...
        handcrafted_all_quadrants[q] = hand_vec

        # ... and flatten its values, in mapping order, for the classifier.
        hand_vec = np.array(list(hand_vec.values()), dtype=np.float32)

        # ---------------------------------------------------
        # TRANSFORMER FEATURES
        # ---------------------------------------------------
        trans_vec = text_model_obj.extract_transformer_features(
            img_path=crop_path,
            save_dir=trans_dir,
            img_name=img_id,
            quadrant=q
        )

        # ---------------------------------------------------
        # CONCATENATE 3 MODALITIES FOR THIS QUADRANT
        # ---------------------------------------------------
        # assumes deep_vec and trans_vec are 1-D array-likes — TODO confirm
        fused_vec = np.concatenate([deep_vec, hand_vec, trans_vec], axis=0)
        quadrant_vectors.append(fused_vec)

    if len(quadrant_vectors) == 0:
        print("[ERROR] No valid quadrant vectors extracted.")
        return None

    # ---------------------------------------------------
    # 4. FINAL IMAGE FEATURE = AVERAGE OF QUADRANTS
    # ---------------------------------------------------
    final_feature_vec = np.mean(quadrant_vectors, axis=0)

    # ---------------------------------------------------
    # 5. LOAD TRAINED XGBOOST MODEL
    # ---------------------------------------------------
    if not os.path.exists(model_path):
        print("[ERROR] XGBoost model file not found:", model_path)
        return None

    # SECURITY NOTE: joblib.load unpickles the file, which can execute
    # arbitrary code. Only load model files from a trusted source.
    model = joblib.load(model_path)

    # ---------------------------------------------------
    # 6. PREDICTION
    # ---------------------------------------------------
    # The model expects a 2-D (n_samples, n_features) input; reshape once.
    model_input = final_feature_vec.reshape(1, -1)
    pred_prob = model.predict_proba(model_input)[0]
    pred_class = model.predict(model_input)[0]

    label_map = {0: "normal", 1: "Keratoconus"}
    pred_label = label_map[int(pred_class)]

    print("\n==================== RESULT ====================")
    print(f"Prediction: {pred_label}")
    print(f"Probability: normal={pred_prob[0]:.4f}, KC={pred_prob[1]:.4f}")
    print("=================================================\n")

    # ---------------------------------------------------
    # 7. OUTPUTS FOR LLM REPORT
    # ---------------------------------------------------
    final_prediction = {
        "image_id": img_id,
        "predicted_label": pred_label,
        "probabilities": {
            "normal": float(pred_prob[0]),
            "keratoconus": float(pred_prob[1])
        },
        "handcrafted_summary": handcrafted_all_quadrants
    }

    plain_prompt = f"""
    You are a clinical ophthalmology assistant specialized in interpreting corneal topography,
    keratoconus patterns, and handcrafted corneal image features.
    
    Your task is to provide a structured, clinically accurate summary.
    
    ==============================
    INPUT DATA
    ==============================
    Image ID: {img_id}
    
    Model Output:
    - Predicted Class: {pred_label}
    - Probability (Normal): {final_prediction['probabilities']['normal']:.4f}
    - Probability (Keratoconus): {final_prediction['probabilities']['keratoconus']:.4f}
    
    Handcrafted Feature Summary (All Quadrants):
    {handcrafted_all_quadrants}
    
    ==============================
    YOUR TASK
    ==============================
    
    Using the data above, produce a concise but clinically meaningful interpretation that includes:
    
    1. **Overall Diagnostic Impression**
       - What the predicted class suggests clinically.
       - Whether the probability supports a strong or weak suspicion.
    
    2. **Quadrant-by-Quadrant Interpretation**
       For each quadrant explicitly mentioned in the handcrafted summary:
       - Describe what the color intensity, texture, dominant color, or extracted features may indicate.
       - Whether the quadrant shows signs consistent with:
         * normal corneal structure
         * mild irregularity
         * early keratoconus changes
         * focal thinning or localized steepening
         * abnormal reflection patterns
       - Explain what these abnormalities typically mean clinically.
    
    3. **Handcrafted Feature Interpretation**
       - Explain what the extracted handcrafted cues (e.g., color dominance, pixel distribution,
         contrast, asymmetry across quadrants) could imply about corneal shape or regularity.
       - Highlight any quadrant asymmetry (superior vs inferior, temporal vs nasal).
    
    4. **Final Clinical Summary**
       - Summarize the findings in professional medical language.
       - Clearly state whether the results lean toward a normal cornea, subclinical keratoconus,
         or keratoconus.
       - If appropriate, suggest what further clinical tests would normally be recommended
         (e.g., tomography, pachymetry, topographic maps).
    
    ==============================
    RULES
    ==============================
    - Do NOT hallucinate missing measurements (e.g., K1, K2, pachymetry) if not provided.
    - Base your interpretation strictly on the handcrafted features and quadrant descriptors.
    - Be concise but clinically meaningful.
    - Do not output code; only the clinical interpretation.
    
    Now provide the clinical interpretation:
    """

    return final_prediction, plain_prompt
    

In [7]:
from IPython.display import Markdown, display
import ipywidgets as widgets
from classes import GEMINIAGENT

def run_with_chat(img_path, api_key=None):
    """Run the prediction pipeline, show an LLM report, then open a chat UI.

    Args:
        img_path: Path of the image passed to ``test_architecture_pipeline``.
        api_key: Gemini API key. When omitted, the ``GEMINI_API_KEY``
            environment variable is used, so the key never has to be written
            into the notebook.

    Returns:
        None. The report and the chat widgets are rendered as cell output.

    Raises:
        ValueError: If no API key is supplied and ``GEMINI_API_KEY`` is unset.
    """
    import os  # local import so this cell does not depend on another cell's imports

    # Resolve the credential first so a missing key fails before the
    # (comparatively expensive) vision pipeline runs.
    api_key = api_key or os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise ValueError(
            "No Gemini API key: pass api_key=... or set the GEMINI_API_KEY environment variable."
        )

    # 1. Run your vision + XGBoost pipeline
    # The pipeline returns None on failure; unpacking that would raise TypeError.
    pipeline_result = test_architecture_pipeline(img_path)
    if pipeline_result is None:
        print("[ERROR] Pipeline failed; no report or chat available.")
        return None
    _, prompt_text = pipeline_result

    # 2. Create Gemini agent
    agent = GEMINIAGENT(
        api_key=api_key,
        system_message="""
        You are a medical assistant and OCULUS - PENTACAM 4 Maps analysis.
        """,
        model="gemini-2.5-flash"
    )

    # 3. Auto clinical report
    report = agent.ask(prompt_text)
    display(Markdown("## Clinical Report\n\n" + report))

    # -------- CHAT MODE --------
    print("\n===== CHAT MODE =====")

    # Create output area for chat history
    chat_output = widgets.Output()

    input_box = widgets.Text(
        placeholder="Type your question...",
        description="You:",
        layout=widgets.Layout(width="500px")
    )

    send_button = widgets.Button(
        description="Send",
        button_style="success"
    )

    # Container for input controls
    input_container = widgets.HBox([input_box, send_button])

    def send_message(_):
        """Send the text box content to the agent and append the exchange."""
        user_input = input_box.value.strip()
        if not user_input:
            return

        input_box.value = ""  # Clear box

        if user_input.lower() in ["exit", "quit"]:
            with chat_output:
                print("Chat ended.")
            # Actually end the session: without this the widgets stay live
            # and further messages would still be sent after "Chat ended.".
            input_box.disabled = True
            send_button.disabled = True
            return

        # Display user message
        with chat_output:
            display(Markdown(f"**You:** {user_input}"))

        # Get and display assistant reply; errors are shown in the chat area
        # so one failed request does not kill the widget callbacks.
        try:
            reply = agent.ask(user_input)
            with chat_output:
                display(Markdown(f"**Assistant:** {reply}\n\n---\n"))
        except Exception as e:
            with chat_output:
                display(Markdown(f"**Error:** {str(e)}\n\n---\n"))

    # Allow Enter key to send message.
    # NOTE(review): Text.on_submit is deprecated in recent ipywidgets releases.
    # The suggested replacement (continuous_update=False + observe on 'value')
    # would also fire when the box loses focus, so it is kept here on purpose.
    input_box.on_submit(send_message)
    send_button.on_click(send_message)

    # Display the interface
    display(input_container)
    display(chat_output)

In [8]:
# Run the pipeline, render the report and open the chat UI for one test image.
run_with_chat("dataset/test/images/4.jpg")



Prediction: Keratoconus
Probability: normal=0.4631, KC=0.5369

Initializing GEMINIAGENT chat for model gemini-2.5-flash.
GEMINIAGENT initialized successfully for model gemini-2.5-flash.
GEMINIAGENT.ask() completed successfully with reply length 3207 chars.


## Clinical Report

**Clinical Summary of Oculus Pentacam 4 Maps Analysis (Image ID: 4)**

The AI model predicts **Keratoconus** with a probability of 53.69%, indicating a slight lean towards ectatic corneal disease, though the probability is close to the normal threshold (46.31%). This suggests a borderline, early, or atypical presentation of Keratoconus.

**Key features supporting the Keratoconus prediction include:**

1.  **Localized Steepening/Thinning (Center-Periphery Ratio):** All four corneal quadrants exhibit a `center_periphery_ratio` greater than 1, implying that the central area of each quadrant is relatively steeper or thinner compared to its periphery. This finding is particularly pronounced in the **Superior Temporal (Q1)** quadrant (ratio of 1.176), suggesting a more significant localized conical protrusion or thinning in this region.
2.  **Corneal Asymmetry (Radial Symmetry):** The **Superior Temporal (Q1)** quadrant shows the highest `radial_symmetry` value (0.086), indicating a higher degree of radial asymmetry in this area, which is a common characteristic of irregular corneal surfaces found in ectatic diseases.
3.  **Dominant Color Features (Hue, Saturation, Value):** The **Superior Nasal (Q2)** quadrant displays remarkably high dominant saturation (0.86) and value (0.87). On standard Pentacam maps, such values often correspond to "hotter" colors (e.g., red, orange) which typically signify areas of significant steepening or elevation, suggesting a focal change in this quadrant.
4.  **Localized Inferior Steepening (Inferior-Superior Ratio):** While not a global index, the `inferior_superior_ratio` within the **Superior Temporal (Q1)** (1.055) and **Inferior Nasal (Q4)** (1.044) quadrants indicates that the inferior portions of these respective quadrants are slightly steeper or more elevated than their superior counterparts. This finding, especially in Q4, can be an indicator of early or localized inferior asymmetry, a known sign of keratoconus.

**Overall Impression:**
The analysis reveals several features indicative of corneal irregularity, localized steepening or thinning, and asymmetry, which are consistent with early or forme fruste Keratoconus. The most prominent findings are noted in the superior temporal (Q1) and superior nasal (Q2) quadrants, suggesting a possible atypical or superior cone presentation rather than the classic inferior-temporal location.

**Recommendation:**
Given the borderline probability and the presence of features suggestive of ectasia, clinical correlation is strongly recommended. This should include:
*   Detailed patient history, including family history of keratoconus and visual symptoms.
*   Best-corrected visual acuity and refraction, checking for irregular astigmatism.
*   Slit-lamp examination for subtle signs of keratoconus (e.g., Vogt's striae, Fleischer's ring, prominent corneal nerves).
*   Further review of the complete Pentacam 4 Maps (including Axial/Tangential Curvature, Elevation Maps, and Pachymetry Map) to assess the specific pattern and magnitude of corneal steepening, posterior elevation, and corneal thinning.
*   Consider serial examinations to monitor for progression, especially in younger patients.


===== CHAT MODE =====


  input_box.on_submit(handle_submit)


HBox(children=(Text(value='', description='You:', layout=Layout(width='500px'), placeholder='Type your questio…

Output()