In [67]:
import torch
from torch import nn
import torchvision.models as models

backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
feature_extractor = feature_extractor.to("cuda" if torch.cuda.is_available() else "cpu")

In [68]:
from classes import VisionModule
from config import configuration

vision_model_obj = VisionModule( feature_extraction_model=feature_extractor, configuration=configuration )

In [104]:
from classes import TextModule

text_model_obj = TextModule()

In [143]:
import os
import cv2
import numpy as np
import joblib


def test_architecture_pipeline(img_path):

    print("\n==================== TEST PIPELINE ====================")

    # ---------------------------------------------------
    # 1. Preprocess → get crops folder
    # ---------------------------------------------------
    crops_path = vision_model_obj.run_image_preprocessing(img_path)

    if crops_path is None or not os.path.exists(crops_path):
        print("[ERROR] run_image_preprocessing() did not return a valid folder.")
        return None


    # ---------------------------------------------------
    # 2. Prepare save dirs (TEMP STORAGE for test)
    # ---------------------------------------------------
    base = os.path.dirname(os.path.dirname(img_path))

    deep_dir = os.path.join(base, "test_deep_features")
    hand_dir = os.path.join(base, "test_hand_features")
    trans_dir = os.path.join(base, "test_transformer_features")

    os.makedirs(deep_dir, exist_ok=True)
    os.makedirs(hand_dir, exist_ok=True)
    os.makedirs(trans_dir, exist_ok=True)

    img_id = os.path.splitext(os.path.basename(img_path))[0]

    # ---------------------------------------------------
    # 3. Extract features for each quadrant
    # ---------------------------------------------------
    quadrant_list = ["Q1", "Q2", "Q3", "Q4"]
    quadrant_vectors = []
    handcrafted_all_quadrants = {}

    for q in quadrant_list:

        crop_path = os.path.join(crops_path, f"{img_id}_{q}.png")

        if not os.path.exists(crop_path):
            print(f"[WARN] Missing crop: {crop_path}")
            continue


        # LOAD CROP
        crop = cv2.imread(crop_path)
        crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)

        # ---------------------------------------------------
        # DEEP FEATURES
        # ---------------------------------------------------
        cnn_ready = vision_model_obj.preprocess_for_cnn(crop)
        deep_vec = vision_model_obj.extract_deep_features(
            tensor=cnn_ready,
            save_dir=deep_dir,
            img_name=img_id,
            quadrant=q
        )

        # ---------------------------------------------------
        # HANDCRAFT FEATURES
        # ---------------------------------------------------
        hand_vec = vision_model_obj.handcrafted_features(
            cropped_img_pth=crop_path,
            save_dir=hand_dir,
            img_name=img_id,
            quadrant=q
        )
        handcrafted_all_quadrants[q] = hand_vec

        hand_vec = np.array(list(hand_vec.values()), dtype=np.float32)

        # ---------------------------------------------------
        # TRANSFORMER FEATURES
        # ---------------------------------------------------
        trans_vec = text_model_obj.extract_transformer_features(
            img_path=crop_path,
            save_dir=trans_dir,
            img_name=img_id,
            quadrant=q
        )

        # ---------------------------------------------------
        # CONCATENATE 3 MODALITIES FOR THIS QUADRANT
        # ---------------------------------------------------
        fused_vec = np.concatenate([deep_vec, hand_vec, trans_vec], axis=0)
        quadrant_vectors.append(fused_vec)

    if len(quadrant_vectors) == 0:
        print("[ERROR] No valid quadrant vectors extracted.")
        return None

    # ---------------------------------------------------
    # 4. FINAL IMAGE FEATURE = AVERAGE OF QUADRANTS
    # ---------------------------------------------------
    final_feature_vec = np.mean(quadrant_vectors, axis=0)


    # ---------------------------------------------------
    # 5. LOAD TRAINED XGBOOST MODEL
    # ---------------------------------------------------
    model_path = r"saved_models/kc_classifier.pkl"
    if not os.path.exists(model_path):
        print("[ERROR] XGBoost model file not found:", model_path)
        return None

    model = joblib.load(model_path)

    # ---------------------------------------------------
    # 6. PREDICTION
    # ---------------------------------------------------
    pred_prob = model.predict_proba(final_feature_vec.reshape(1, -1))[0]
    pred_class = model.predict(final_feature_vec.reshape(1, -1))[0]

    label_map = {0: "normal", 1: "Keratoconus"}
    pred_label = label_map[int(pred_class)]

    print("\n==================== RESULT ====================")
    print(f"Prediction: {pred_label}")
    print(f"Probability: normal={pred_prob[0]:.4f}, KC={pred_prob[1]:.4f}")
    print("=================================================\n")

    # ---------------------------------------------------
    # 7. OUTPUTS FOR LLM REPORT
    # ---------------------------------------------------
    final_prediction = {
        "image_id": img_id,
        "predicted_label": pred_label,
        "probabilities": {
            "normal": float(pred_prob[0]),
            "keratoconus": float(pred_prob[1])
        },
        "handcrafted_summary": handcrafted_all_quadrants
    }
    
    plain_prompt = f"""
        Image ID: {img_id}
        
        Prediction: {pred_label}
        Probability Normal: {final_prediction['probabilities']['normal']:.4f}
        Probability Keratoconus: {final_prediction['probabilities']['keratoconus']:.4f}
        
        Handcrafted features per quadrant:
        {handcrafted_all_quadrants}
        
        Please summarize this result clinically.
        """
    
    return final_prediction, plain_prompt
    

'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(Request ID: 58707953-f3d2-4bbb-913c-2d0299f7a6bb)')' thrown while requesting HEAD https://huggingface.co/google/vit-base-patch16-224-in21k/resolve/main/preprocessor_config.json
Retrying in 1s [Retry 1/5].


In [164]:
from IPython.display import Markdown, display
import ipywidgets as widgets
from classes import GEMINIAGENT

def run_with_chat(img_path):
    # 1. Run your vision + XGBoost pipeline
    final_prediction, prompt_text = test_architecture_pipeline(img_path)

    # 2. Create Gemini agent
    agent = GEMINIAGENT(
        api_key="AIzaSyC1r4egWoskczgSYmXkxCQMhJ6szqOw6XM",
        system_message="""
        You are a medical assistant and OCULUS - PENTACAM 4 Maps analysis.
        """,
        model="gemini-2.5-flash"
    )

    # 3. Auto clinical report
    report = agent.ask(prompt_text)
    display(Markdown("## Clinical Report\n\n" + report))

    # -------- CHAT MODE --------
    print("\n===== CHAT MODE =====")
    
    # Create output area for chat history
    chat_output = widgets.Output()
    
    input_box = widgets.Text(
        placeholder="Type your question...",
        description="You:",
        layout=widgets.Layout(width="500px")
    )

    send_button = widgets.Button(
        description="Send",
        button_style="success"
    )
    
    # Container for input controls
    input_container = widgets.HBox([input_box, send_button])

    def send_message(_):
        user_input = input_box.value.strip()
        if not user_input:
            return

        input_box.value = ""  # Clear box

        if user_input.lower() in ["exit", "quit"]:
            with chat_output:
                print("Chat ended.")
            return

        # Display user message
        with chat_output:
            display(Markdown(f"**You:** {user_input}"))
        
        # Get and display assistant reply
        try:
            reply = agent.ask(user_input)
            with chat_output:
                display(Markdown(f"**Assistant:** {reply}\n\n---\n"))
        except Exception as e:
            with chat_output:
                display(Markdown(f"**Error:** {str(e)}\n\n---\n"))

    # Allow Enter key to send message
    def handle_submit(sender):
        send_message(None)
    
    input_box.on_submit(handle_submit)
    send_button.on_click(send_message)
    
    # Display the interface
    display(input_container)
    display(chat_output)

In [165]:
run_with_chat("dataset/test/images/2.jpg")



Prediction: normal
Probability: normal=0.6984, KC=0.3016

Initializing GEMINIAGENT chat for model gemini-2.5-flash.
GEMINIAGENT initialized successfully for model gemini-2.5-flash.
GEMINIAGENT.ask() completed successfully with reply length 2206 chars.


## Clinical Report

**Clinical Summary of Oculus Pentacam 4 Maps Analysis (Image ID: 2)**

The automated analysis predicts a **normal cornea** with a probability of 69.84%. This prediction is well-supported by the detailed evaluation of handcrafted features across all four quadrants.

**Key Findings:**

1.  **Overall Symmetry:** The various symmetry metrics consistently indicate a healthy, symmetrical corneal shape.
    *   **Inferior-Superior Ratios** are all very close to 1 (ranging from 0.999 to 1.014), indicating excellent vertical symmetry without significant inferior steepening, which is a hallmark of keratoconus.
    *   **Left-Right Ratios** are also close to 1 (ranging from 0.980 to 0.995), suggesting good horizontal symmetry.
    *   **Radial Symmetry** values are low across all quadrants (0.006 to 0.032), confirming a high degree of radial regularity and an absence of localized, asymmetric bulging.
    *   **Diagonal Differences (diag1_difference, diag2_difference)** are very small, further supporting overall corneal symmetry.

2.  **Corneal Shape and Smoothness:**
    *   The **average intensity** is consistently high across all quadrants (0.867 to 0.894), and the **standard deviation of intensity** is low (0.147 to 0.180). This suggests a generally uniform and smooth corneal surface without significant localized irregularities or abrupt changes in elevation/curvature.
    *   **Dominant Value (dom_v)** is similar across all quadrants (0.46 to 0.50), implying a consistent distribution of the underlying topographical feature represented by intensity.

3.  **Center-Periphery Relationship:** While there are some variations in the center-periphery ratio per quadrant (e.g., Q3 at 0.981 indicating a slightly less intense center compared to the periphery, while Q1 and Q4 are >1), these variations do not present a consistent pattern indicative of a pathological cone.

**Conclusion:**

Based on the quantitative analysis of the provided handcrafted features, there are **no clinical signs suggestive of keratoconus or other significant corneal ectasia**. The cornea demonstrates excellent symmetry and regularity, strongly supporting the automated prediction of a normal corneal topography.


===== CHAT MODE =====


  input_box.on_submit(handle_submit)


HBox(children=(Text(value='', description='You:', layout=Layout(width='500px'), placeholder='Type your questio…

Output()