# From Black-Box to Glass-Box: Neuro-Symbolic Pharyngitis Triage

This notebook demonstrates the **Neuro-Symbolic AI** system presented in the paper.
It combines **Deep Learning (YOLOv8)** for perception with **Causal Bayesian Networks** for reasoning, an architecture designed for transparency and safety.

### Key Features:
1.  **Visual Perception**: See how the AI segments the throat (Tonsils, Pus, Petechiae).
2.  **Glass-Box Reasoning**: Interactively modify symptoms to see how the diagnosis changes in real-time (<10ms).
3.  **Transparency**: Visualize the **Cognitive Conflict** between Subjective symptoms and Objective findings.
4.  **Verification**: Run a statistical benchmark to validate performance.

## 1. Environment Setup

In [None]:
# @title 1. Environment Setup (Run this first)
import os

# ------------------------------------------------------------------
# [CONFIGURATION] Please set your repository URL here before running
REPO_URL = "https://github.com/Lug2/LM-Pharyngitis-Autonomous-Triage.git" # @param {type:"string"}
# ------------------------------------------------------------------

if not os.path.exists('src'):
    print("üîÑ Cloning repository...")
    !git clone $REPO_URL repo
    %cd repo
    print("‚úÖ Repository Cloned.")
else:
    print("‚úÖ Src directory found (Already cloned or Local mode).")

# Install dependencies
print("üîÑ Installing dependencies...")
!pip install -r requirements.txt -q

# Add src to path
import sys
sys.path.append(os.path.join(os.getcwd(), 'src'))

print("‚úÖ Environment Ready.")

## 2. Visual Perception (YOLOv8)

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output
from ultralytics import YOLO
from PIL import Image
import glob
import matplotlib.pyplot as plt
import io
import numpy as np

# Load Model
# The repository is expected to provide fine-tuned weights at
# 'models/yolov8s-seg.pt'. If that file is missing we fall back to the stock
# Ultralytics 'yolov8s-seg.pt' checkpoint. NOTE: the stock checkpoint is the
# generic COCO-pretrained model, not the throat-segmentation model, so its
# masks will not show Tonsils / Pus / Petechiae.
model_path = 'models/yolov8s-seg.pt'
if not os.path.exists(model_path):
    print("⚠️ Model not found. Downloading standard YOLOv8s-seg (generic COCO weights, not throat-specific)...")
    model = YOLO('yolov8s-seg.pt')  # Fallback
else:
    model = YOLO(model_path)

# UI Components
style = {'description_width': 'initial'}

# Sample images: every .jpg / .jpeg / .png file under datasets/ (recursive),
# matched case-insensitively so '.JPG' etc. are included; sorted by path and
# truncated to the first 10 for the demo.
_IMAGE_EXTS = ('.jpg', '.jpeg', '.png')
dataset_images = sorted(
    p for p in glob.glob('datasets/**/*', recursive=True)
    if os.path.isfile(p) and p.lower().endswith(_IMAGE_EXTS)
)[:10]  # Show first 10 for demo

dropdown = widgets.Dropdown(
    # ('Select Image...', None) is a placeholder entry with a None value.
    options=[('Select Image...', None)] + [(os.path.basename(p), p) for p in dataset_images],
    value=None,
    description='Sample Images:',
    style=style
)

uploader = widgets.FileUpload(
    accept='image/*',
    multiple=False,
    description='Upload Your Own'
)

# Output area that process_image() clears and renders into.
out = widgets.Output()

def _upload_content(upload_data):
    """Return the raw bytes of the first file in a ``FileUpload.value`` payload.

    ipywidgets 8 stores the value as a tuple of dicts (``content`` is a
    ``memoryview``); ipywidgets 7 stores it as a dict keyed by filename whose
    values hold ``content``. Both shapes are accepted.
    """
    if isinstance(upload_data, dict):
        first = next(iter(upload_data.values()))
    else:
        first = upload_data[0]
    return bytes(first['content'])


def process_image(img_path=None, upload_data=None):
    """Run YOLOv8 segmentation on one image and plot it beside the original.

    Args:
        img_path: Path of an image file on disk; used only when ``upload_data``
            is falsy.
        upload_data: A ``FileUpload.value`` payload; takes precedence over
            ``img_path``.

    The ``out`` widget is cleared first and then receives the status line and
    the figure. Does nothing (after clearing) if neither input is provided.
    Returns None.
    """
    with out:
        clear_output()
        img = None
        if upload_data:
            # Process Upload
            img = Image.open(io.BytesIO(_upload_content(upload_data))).convert('RGB')
            print("✅ Processing Uploaded Image...")
        elif img_path:
            # Process File; the context manager closes the file handle.
            with Image.open(img_path) as src_img:
                img = src_img.convert('RGB')
            print(f"✅ Processing {os.path.basename(img_path)}...")

        if img is None:
            return

        # Inference (first and only result for a single image)
        res = model(img, verbose=False)[0]

        # Visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        ax1.imshow(img)
        ax1.set_title("Original Image")
        ax1.axis('off')

        # Results.plot() returns a BGR array (OpenCV channel order) while
        # matplotlib expects RGB; reverse the channel axis so colours are true.
        # This matters here because throat redness is a diagnostic cue.
        ax2.imshow(res.plot()[..., ::-1])
        ax2.set_title("AI Segmentation (YOLOv8)")
        ax2.axis('off')
        plt.show()

def on_dropdown_change(change):
    """Observer for the sample-image dropdown.

    Runs segmentation on the newly selected path; the placeholder entry
    (value None) is ignored.
    """
    selected_path = change['new']
    if not selected_path:
        return
    process_image(img_path=selected_path)

def on_upload_change(change):
    """Observer for the FileUpload widget.

    Forwards a non-empty upload payload to process_image; empty values
    (e.g. when the upload is cleared) are ignored.
    """
    payload = change['new']
    if not payload:
        return
    process_image(upload_data=payload)

# Hook each input widget to its handler; both react to changes of 'value'.
for _widget, _handler in ((dropdown, on_dropdown_change), (uploader, on_upload_change)):
    _widget.observe(_handler, names='value')

# Show the input row, then the output area the handlers render into.
display(widgets.HBox([dropdown, uploader]), out)

## 3. Glass-Box Reasoning (Interactive)
This widget lets you drive the reasoning engine interactively.
The system separates **Subjective Symptoms (Patient Story)** from **Objective Findings (Doctor's View)** and calculates a **Cognitive Conflict Score** between the two views.
Try creating a conflict (e.g., set **Pain Severity = Severe** on the *Subjective* tab while leaving **Redness = Normal** and **Tonsil Exudate = None** on the *Objective* tab) to see the **Safety Net** in action.

In [None]:
from src.causal_brain_v6 import CausalBrainV6
import time

# Initialize Logic Engine (configuration read from the repository's YAML file)
ai = CausalBrainV6('src/model_config.yaml')

# Define Widgets
# Each Dropdown uses (label, value) pairs: the label is shown to the user, the
# value is what run_inference() places into the evidence dict.
style = {'description_width': '120px'}
layout = widgets.Layout(width='auto')

# --- Context ---
w_age = widgets.Dropdown(options=['Child', 'YoungAdult', 'Adult', 'Senior'], value='Adult', description='Age Group:', style=style)
w_epi = widgets.Dropdown(options=['None', 'Flu_Warning', 'GAS_Warning'], value='None', description='Epidemic:', style=style)

# --- Subjective (Psub) ---
w_fatigue = widgets.Dropdown(options=[('Absent', 'Absent'), ('Present', 'Present')], value='Absent', description='Fatigue:', style=style)
w_joint = widgets.Dropdown(options=[('None', 'None'), ('Severe', 'Severe')], value='None', description='Joint Pain:', style=style)
w_pain_sev = widgets.Dropdown(options=[('Mild', 'Mild'), ('Severe', 'Severe')], value='Mild', description='Pain Severity:', style=style)
w_pain_lat = widgets.Dropdown(options=[('Bilateral', 'Bilateral'), ('Unilateral', 'Unilateral')], value='Bilateral', description='Pain Laterality:', style=style)
w_onset = widgets.Dropdown(options=[('Sudden', 'Sudden'), ('Gradual', 'Gradual')], value='Gradual', description='Onset:', style=style)
# Labels must partition the day count: the previous '<3 Days' / '>=4 Days'
# left day 3 unassigned. 'Acute' is now the exact complement of 'Subacute'.
w_duration = widgets.Dropdown(options=[('<=3 Days', 'Acute'), ('>=4 Days', 'Subacute')], value='Acute', description='Duration:', style=style)

# --- Objective (Pobj) ---
w_temp = widgets.Dropdown(options=[('Normal', 'Normal'), ('Mild', 'Mild'), ('High (>38C)', 'High')], value='Normal', description='Fever:', style=style)
w_cough = widgets.Dropdown(options=[('Absent', 'Absent'), ('Present', 'Present')], value='Present', description='Cough:', style=style)
w_lymph = widgets.Dropdown(options=[('Normal', 'Normal'), ('Anterior', 'Anterior'), ('Posterior', 'Posterior'), ('Bilateral', 'Both')], value='Normal', description='Lymph Nodes:', style=style)
w_rash = widgets.Dropdown(options=[('Absent', 'Absent'), ('Present', 'Present')], value='Absent', description='Skin Rash:', style=style)
w_pet = widgets.Dropdown(options=[('Normal', 'Normal'), ('Prominent', 'Prominent')], value='Normal', description='Petechiae:', style=style)
w_exudate = widgets.Dropdown(options=[('None', 'None'), ('Low', 'Low'), ('High', 'High')], value='None', description='Tonsil Exudate:', style=style)
w_color = widgets.Dropdown(options=[('Normal', 'Normal'), ('Red', 'Red'), ('Dark Red', 'DarkRed')], value='Normal', description='Redness:', style=style)
w_eye = widgets.Dropdown(options=[('Normal', 'Normal'), ('Conjunctivitis', 'Conjunctivitis')], value='Normal', description='Eye:', style=style)

def run_inference(age, epi, fatigue, joint, pain_sev, pain_lat, onset, duration,
                  temp, cough, lymph, rash, pet, exudate, color, eye):
    """Run the causal engine on the current widget values and print a report.

    Each argument is the selected value of the matching dropdown (this function
    is the callback given to ``widgets.interactive_output``). The values are
    packed into the evidence dict passed to ``ai.diagnose``; the call is timed
    and the diagnosis, clinical decision, subjective/objective probabilities,
    conflict score and explanation are printed.

    Returns:
        None. All output is printed to stdout.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time(), so
    # millisecond-scale timings are meaningful.
    start_time = time.perf_counter()

    evidence = {
        # Context
        'Age_Group': age, 'C_epidemic': epi,
        # Subjective (Psub)
        'C_fatigue': fatigue, 'C_joint': joint, 'C_pain_sev': pain_sev,
        'C_pain_lat': pain_lat, 'C_onset': onset, 'C_duration': duration,
        # Objective (Pobj)
        'C_temp': temp, 'C_cough': cough, 'C_lymph': lymph, 'C_rash': rash,
        'V_vessel': pet,
        # Exudate_Gen is the boolean parent of V_white: True for any exudate
        # level other than 'None'.
        'Exudate_Gen': exudate != 'None',
        'V_white': exudate, 'V_color': color, 'C_eye': eye,
    }

    # Run Cognitive Inference
    result = ai.diagnose(evidence, enable_safety_net=True)

    elapsed_ms = (time.perf_counter() - start_time) * 1000

    # --- Visualization ---
    print(f"⏱️ Inference Time: {elapsed_ms:.2f} ms (Edge-Device Ready)")
    print("=" * 60)

    # Diagnosis
    diag = result['diagnosis']
    prob = result['probability']
    print(f"🏥 Diagnosis: {diag} ({prob:.1%})")

    # AADT Decision
    c_dec = result.get('clinical_decision', {})
    print(f"💊 Action:    {c_dec.get('decision', 'N/A')} (Pediatric Mode: {c_dec.get('is_pediatric_mode', False)})")
    print("-" * 60)

    # Cognitive Transparency
    cog = result.get('cognitive', {})
    print("🧠 Cognitive Transparency")
    print(f"   - Subjective View (Psub): {cog.get('prob_subjective', 0):.1%}")
    print(f"   - Objective View  (Pobj): {cog.get('prob_objective', 0):.1%}")
    print(f"   - Conflict Score:         {cog.get('conflict_score', 0):.4f}")
    print(f"   - Triage Pattern:         {cog.get('triage_type', 'N/A')}")

    # Explainability: read with .get() like the optional sections above, so a
    # partial result still renders instead of raising KeyError.
    expl = result.get('explanation', {})
    print("=" * 60)
    print("📝 Explanation:")
    print(f"   {expl.get('summary', 'N/A')}")
    print("\n   [Supporting Evidence]")
    for p in expl.get('positive', []):
        print(f"   + {p}")
    print("\n   [Conflicting Evidence]")
    for n in expl.get('negative', []):
        print(f"   - {n}")

    alerts = expl.get('alerts', [])
    if alerts:
        print("\n🔔 ALERTS:")
        for a in alerts:
            print(f"   {a}")


# Three tabs: shared context, patient-reported symptoms, examination findings.
v_ctx = widgets.VBox([widgets.Label("Context"), w_age, w_epi])
v_sub = widgets.VBox([widgets.Label("Subjective Symptoms (Patient Story)"), w_fatigue, w_joint, w_pain_sev, w_pain_lat, w_onset, w_duration])
v_obj = widgets.VBox([widgets.Label("Objective Findings (Examination)"), w_temp, w_cough, w_lymph, w_rash, w_pet, w_exudate, w_color, w_eye])

ui = widgets.Tab()
ui.children = [v_ctx, v_sub, v_obj]
# Titles are assigned after the children exist, in the same tab order.
for _tab_idx, _tab_title in enumerate(('Context', 'Subjective', 'Objective')):
    ui.set_title(_tab_idx, _tab_title)

# Map each run_inference keyword argument to the widget that supplies it;
# interactive_output re-invokes run_inference when any of them changes.
_inference_controls = dict(
    age=w_age, epi=w_epi,
    fatigue=w_fatigue, joint=w_joint, pain_sev=w_pain_sev, pain_lat=w_pain_lat,
    onset=w_onset, duration=w_duration,
    temp=w_temp, cough=w_cough, lymph=w_lymph, rash=w_rash, pet=w_pet,
    exudate=w_exudate, color=w_color, eye=w_eye,
)
out_infer = widgets.interactive_output(run_inference, _inference_controls)

display(ui)
display(out_infer)

## 4. Benchmark Validation
Finally, we run a statistical benchmark (N=50 samples) to check that the system performs as expected.

In [None]:
# Run the repository's benchmark script with `--task standard --n_samples 50`.
# `!` hands the line to the shell, so `python` is whichever interpreter is
# first on PATH. NOTE(review): confirm it is the same interpreter as the kernel
# (where requirements were installed); `{sys.executable}` would pin it.
!python experiments/Benchmark/runner.py --task standard --n_samples 50