In [1]:
import torch
from torch import nn
import torchvision.models as models

backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
feature_extractor = feature_extractor.to("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
from classes import VisionModule
from config import configuration

vision_model_obj = VisionModule( feature_extraction_model=feature_extractor, configuration=configuration )

In [3]:
from classes import TextModule

text_model_obj = TextModule()

In [10]:
import os
import cv2
import numpy as np
import joblib


def test_architecture_pipeline(img_path):

    print("\n==================== TEST PIPELINE ====================")

    # ---------------------------------------------------
    # 1. Preprocess → get crops folder
    # ---------------------------------------------------
    crops_path = vision_model_obj.run_image_preprocessing(img_path)

    if crops_path is None or not os.path.exists(crops_path):
        print("[ERROR] run_image_preprocessing() did not return a valid folder.")
        return None


    # ---------------------------------------------------
    # 2. Prepare save dirs (TEMP STORAGE for test)
    # ---------------------------------------------------
    base = os.path.dirname(os.path.dirname(img_path))

    deep_dir = os.path.join(base, "test_deep_features")
    hand_dir = os.path.join(base, "test_hand_features")
    trans_dir = os.path.join(base, "test_transformer_features")

    os.makedirs(deep_dir, exist_ok=True)
    os.makedirs(hand_dir, exist_ok=True)
    os.makedirs(trans_dir, exist_ok=True)

    img_id = os.path.splitext(os.path.basename(img_path))[0]

    # ---------------------------------------------------
    # 3. Extract features for each quadrant
    # ---------------------------------------------------
    quadrant_list = ["Q1", "Q2", "Q3", "Q4"]
    quadrant_vectors = []
    handcrafted_all_quadrants = {}

    for q in quadrant_list:

        crop_path = os.path.join(crops_path, f"{img_id}_{q}.png")

        if not os.path.exists(crop_path):
            print(f"[WARN] Missing crop: {crop_path}")
            continue


        # LOAD CROP
        crop = cv2.imread(crop_path)
        crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)

        # ---------------------------------------------------
        # DEEP FEATURES
        # ---------------------------------------------------
        cnn_ready = vision_model_obj.preprocess_for_cnn(crop)
        deep_vec = vision_model_obj.extract_deep_features(
            tensor=cnn_ready,
            save_dir=deep_dir,
            img_name=img_id,
            quadrant=q
        )

        # ---------------------------------------------------
        # HANDCRAFT FEATURES
        # ---------------------------------------------------
        hand_vec = vision_model_obj.handcrafted_features(
            cropped_img_pth=crop_path,
            save_dir=hand_dir,
            img_name=img_id,
            quadrant=q
        )
        handcrafted_all_quadrants[q] = hand_vec

        hand_vec = np.array(list(hand_vec.values()), dtype=np.float32)

        # ---------------------------------------------------
        # TRANSFORMER FEATURES
        # ---------------------------------------------------
        trans_vec = text_model_obj.extract_transformer_features(
            img_path=crop_path,
            save_dir=trans_dir,
            img_name=img_id,
            quadrant=q
        )

        # ---------------------------------------------------
        # CONCATENATE 3 MODALITIES FOR THIS QUADRANT
        # ---------------------------------------------------
        fused_vec = np.concatenate([deep_vec, hand_vec, trans_vec], axis=0)
        quadrant_vectors.append(fused_vec)

    if len(quadrant_vectors) == 0:
        print("[ERROR] No valid quadrant vectors extracted.")
        return None

    # ---------------------------------------------------
    # 4. FINAL IMAGE FEATURE = AVERAGE OF QUADRANTS
    # ---------------------------------------------------
    final_feature_vec = np.mean(quadrant_vectors, axis=0)


    # ---------------------------------------------------
    # 5. LOAD TRAINED XGBOOST MODEL
    # ---------------------------------------------------
    model_path = r"saved_models/kc_classifier.pkl"
    if not os.path.exists(model_path):
        print("[ERROR] XGBoost model file not found:", model_path)
        return None

    model = joblib.load(model_path)

    # ---------------------------------------------------
    # 6. PREDICTION
    # ---------------------------------------------------
    pred_prob = model.predict_proba(final_feature_vec.reshape(1, -1))[0]
    pred_class = model.predict(final_feature_vec.reshape(1, -1))[0]

    label_map = {0: "normal", 1: "Keratoconus"}
    pred_label = label_map[int(pred_class)]

    print("\n==================== RESULT ====================")
    print(f"Prediction: {pred_label}")
    print(f"Probability: normal={pred_prob[0]:.4f}, KC={pred_prob[1]:.4f}")
    print("=================================================\n")

    # ---------------------------------------------------
    # 7. OUTPUTS FOR LLM REPORT
    # ---------------------------------------------------
    final_prediction = {
        "image_id": img_id,
        "predicted_label": pred_label,
        "probabilities": {
            "normal": float(pred_prob[0]),
            "keratoconus": float(pred_prob[1])
        },
        "handcrafted_summary": handcrafted_all_quadrants
    }
    
    plain_prompt = f"""
    You are a clinical ophthalmology assistant specialized in interpreting corneal topography,
    keratoconus patterns, and handcrafted corneal image features.
    
    Your task is to provide a structured, clinically accurate summary.
    
    ==============================
    INPUT DATA
    ==============================
    Image ID: {img_id}
    
    Model Output:
    - Predicted Class: {pred_label}
    - Probability (Normal): {final_prediction['probabilities']['normal']:.4f}
    - Probability (Keratoconus): {final_prediction['probabilities']['keratoconus']:.4f}
    
    Handcrafted Feature Summary (All Quadrants):
    {handcrafted_all_quadrants}
    
    ==============================
    YOUR TASK
    ==============================
    
    Using the data above, produce a concise but clinically meaningful interpretation that includes:
    
    1. **Overall Diagnostic Impression**
       - What the predicted class suggests clinically.
       - Whether the probability supports a strong or weak suspicion.
    
    2. **Quadrant-by-Quadrant Interpretation**
       For each quadrant explicitly mentioned in the handcrafted summary:
       - Describe what the color intensity, texture, dominant color, or extracted features may indicate.
       - Whether the quadrant shows signs consistent with:
         * normal corneal structure
         * mild irregularity
         * early keratoconus changes
         * focal thinning or localized steepening
         * abnormal reflection patterns
       - Explain what these abnormalities typically mean clinically.
    
    3. **Handcrafted Feature Interpretation**
       - Explain what the extracted handcrafted cues (e.g., color dominance, pixel distribution,
         contrast, asymmetry across quadrants) could imply about corneal shape or regularity.
       - Highlight any quadrant asymmetry (superior vs inferior, temporal vs nasal).
    
    4. **Final Clinical Summary**
       - Summarize the findings in professional medical language.
       - Clearly state whether the results lean toward a normal cornea, subclinical keratoconus,
         or keratoconus.
       - If appropriate, suggest what further clinical tests would normally be recommended
         (e.g., tomography, pachymetry, topographic maps).
    
    ==============================
    RULES
    ==============================
    - Do NOT hallucinate missing measurements (e.g., K1, K2, pachymetry) if not provided.
    - Base your interpretation strictly on the handcrafted features and quadrant descriptors.
    - Be concise but clinically meaningful.
    - Do not output code; only the clinical interpretation.
    
    Now provide the clinical interpretation:
    """

    
    return final_prediction, plain_prompt
    

In [11]:
from IPython.display import Markdown, display
import ipywidgets as widgets
from classes import GEMINIAGENT

def run_with_chat(img_path):
    # 1. Run your vision + XGBoost pipeline
    final_prediction, prompt_text = test_architecture_pipeline(img_path)

    # 2. Create Gemini agent
    agent = GEMINIAGENT(
        api_key="API KEYYYY HEEEREE !!!!",
        system_message="""
        You are a medical assistant and OCULUS - PENTACAM 4 Maps analysis.
        """,
        model="gemini-2.5-flash"
    )

    # 3. Auto clinical report
    report = agent.ask(prompt_text)
    display(Markdown("## Clinical Report\n\n" + report))

    # -------- CHAT MODE --------
    print("\n===== CHAT MODE =====")
    
    # Create output area for chat history
    chat_output = widgets.Output()
    
    input_box = widgets.Text(
        placeholder="Type your question...",
        description="You:",
        layout=widgets.Layout(width="500px")
    )

    send_button = widgets.Button(
        description="Send",
        button_style="success"
    )
    
    # Container for input controls
    input_container = widgets.HBox([input_box, send_button])

    def send_message(_):
        user_input = input_box.value.strip()
        if not user_input:
            return

        input_box.value = ""  # Clear box

        if user_input.lower() in ["exit", "quit"]:
            with chat_output:
                print("Chat ended.")
            return

        # Display user message
        with chat_output:
            display(Markdown(f"**You:** {user_input}"))
        
        # Get and display assistant reply
        try:
            reply = agent.ask(user_input)
            with chat_output:
                display(Markdown(f"**Assistant:** {reply}\n\n---\n"))
        except Exception as e:
            with chat_output:
                display(Markdown(f"**Error:** {str(e)}\n\n---\n"))

    # Allow Enter key to send message
    def handle_submit(sender):
        send_message(None)
    
    input_box.on_submit(handle_submit)
    send_button.on_click(send_message)
    
    # Display the interface
    display(input_container)
    display(chat_output)

In [12]:
run_with_chat("dataset/test/images/4.jpg")



Prediction: Keratoconus
Probability: normal=0.4631, KC=0.5369

Initializing GEMINIAGENT chat for model gemini-2.5-flash.
GEMINIAGENT initialized successfully for model gemini-2.5-flash.
GEMINIAGENT.ask() completed successfully with reply length 5393 chars.


## Clinical Report

**1. Overall Diagnostic Impression**

The model predicts a class of Keratoconus with a probability of 0.5369 (53.69%), while Normal is predicted at 0.4631 (46.31%). This indicates a slight lean towards keratoconus but with a relatively low confidence level, suggesting a weak suspicion rather than a strong diagnostic certainty based solely on the model's output.

**2. Quadrant-by-Quadrant Interpretation**

*   **Q1 (Superior-Temporal Quadrant):**
    *   This quadrant shows a relatively high average intensity (0.93) and a notable center-periphery ratio (1.176), indicating that the central portion of this quadrant is significantly brighter or steeper than its periphery. The radial symmetry is moderate (0.086).
    *   *Interpretation:* These features suggest localized irregularity or mild steepening within the superior-temporal region. It is not entirely consistent with a normal corneal structure due to the pronounced central-peripheral difference in reflectivity/steepness.

*   **Q2 (Superior-Nasal Quadrant):**
    *   Characterized by a high dominant saturation (0.865) and value (0.874), which might indicate a distinct reflection pattern. The average intensity is high (0.855) with a slightly elevated standard deviation (0.162) compared to Q1 and Q3. The center-periphery ratio is mildly elevated (1.038).
    *   *Interpretation:* While showing some intensity variation, this quadrant appears relatively more uniform than Q1 and Q4, with less pronounced central steepening. It may indicate mild irregularities but is closer to a normal corneal structure compared to other quadrants.

*   **Q3 (Inferior-Nasal Quadrant):**
    *   This quadrant exhibits high average intensity (0.905) but the lowest standard deviation (0.142) and the lowest radial symmetry value (0.028) among all quadrants, indicating high uniformity. The center-periphery ratio is mild (1.033).
    *   *Interpretation:* The features in this quadrant are highly consistent with a normal corneal structure, showing good uniformity and minimal signs of irregularity or steepening.

*   **Q4 (Inferior-Temporal Quadrant):**
    *   This quadrant displays a high average intensity (0.854) but, critically, the highest standard deviation of intensity (0.192) among all quadrants. This high standard deviation signifies significant variability in pixel intensity or reflection patterns. There's also a slight inferior-superior asymmetry (ratio 1.043) and moderate radial asymmetry (0.059).
    *   *Interpretation:* The pronounced texture irregularity and variability in intensity are highly suggestive of early keratoconus changes, focal thinning, or localized steepening in the inferior-temporal cornea, which is a classic location for ectasia. This indicates an abnormal reflection pattern.

**3. Handcrafted Feature Interpretation**

The handcrafted features highlight significant asymmetries and irregularities:

*   **Intensity Variability:** The most striking finding is the significantly elevated standard deviation of intensity in Q4 (inferior-temporal quadrant), indicating considerable local irregularity and variations in corneal surface texture or steepness. This contrasts with the higher uniformity in Q3.
*   **Central Steepening:** All quadrants show a center-periphery ratio greater than 1, implying a general trend of central steepening or brighter reflections compared to the periphery. This effect is most pronounced in Q1 (superior-temporal).
*   **Inferior-Superior Asymmetry:** Q1 (1.055) and Q4 (1.043) show a subtle inferior-superior asymmetry in intensity, suggesting that the inferior aspects of these quadrants are slightly brighter/steeper. This pattern, particularly in the inferior-temporal region (Q4), is a common early indicator of keratoconus.
*   **Overall Asymmetry:** The differing radial symmetry values across quadrants (lowest in Q3, highest in Q1) further confirm the non-uniform nature of the corneal surface. The overall distribution of irregularities, with the inferior-temporal quadrant being the most irregular, is highly suspicious.

**4. Final Clinical Summary**

The combined findings, particularly the model's suspicion of keratoconus and the detailed analysis of handcrafted features, point towards a cornea with significant irregularities. The most concerning features are the marked intensity variability and subtle inferior-superior steepening in the inferior-temporal quadrant (Q4), alongside localized central steepening in the superior-temporal quadrant (Q1). While Q3 appears normal, the overall picture suggests a lack of corneal regularity and symmetry.

Based on these results, the findings lean towards **subclinical keratoconus or early keratoconus**.

**Recommended Further Clinical Tests:**
To confirm this suspicion and provide a definitive diagnosis, the following tests would be highly recommended:
*   **Corneal Tomography:** Essential for assessing posterior corneal elevation, pachymetry maps (to identify localized thinning), and comprehensive topographic patterns.
*   **Corneal Pachymetry:** To measure corneal thickness and identify areas of localized thinning, especially in the inferior-temporal region.
*   **Corneal Biomechanics:** To evaluate the cornea's resistance to deformation, which can detect early ectatic changes.
*   **Serial Topographic Maps:** To monitor for progression over time if initial findings are borderline.


===== CHAT MODE =====


  input_box.on_submit(handle_submit)


HBox(children=(Text(value='', description='You:', layout=Layout(width='500px'), placeholder='Type your questio…

Output()