In [22]:
import os
import numpy as np
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix

import torch
from utils.predict import load_model
from utils.preprocess import transform
from utils.patchify import split_into_patches
from utils.ocr import extract_text_from_image

from language_model.language_risk import predict_language_risk
from language_model.text_features import compute_raw_language_features

import config


In [25]:
# ---------------------------------------------------------------------------
# Settings
# ---------------------------------------------------------------------------

# Evaluation data
TEST_DIR = "data/dataset/test"     # expects one sub-folder per class label
MAX_IMAGES_PER_CLASS = 50          # upper bound on files evaluated per class

# Handwriting model input
RESIZE_IMAGE_TO = (1024, 1024)     # pages are resized to this before patching
MAX_PATCHES = 25                   # only the first MAX_PATCHES patches are scored

# Language model gating
HANDWRITING_SKIP_THRESHOLD = 0.15  # below this handwriting risk, OCR is skipped
MIN_WORDS_LANGUAGE = 20            # fewer OCR words than this -> no language risk

# Screening decision: a risk >= this value is counted as a positive prediction
SCREENING_THRESHOLD = 0.4

# Memoised OCR text, so images that were already processed are not OCR'd again
OCR_CACHE = {}

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the handwriting model, move it to the selected device and switch it
# to evaluation mode before any inference is done.
model = load_model()
model.to(device)
model.eval()

print(f"Handwriting model loaded on {device}")

✅ Handwriting model loaded on cuda


In [28]:
# Handwriting risk
def handwriting_risk_from_image(image: Image.Image) -> float:
    """Return the fraction of scored patches that the handwriting model flags.

    The image is resized to ``RESIZE_IMAGE_TO`` and split into patches of
    ``config.IMAGE_SIZE`` using a stride of the same size. At most the first
    ``MAX_PATCHES`` patches are passed through ``model`` one at a time; a patch
    is flagged when the model output is greater than 0.5.

    Args:
        image: The page image to score.

    Returns:
        ``flagged / scored`` as a float, or ``0.0`` when splitting produces
        no patches.
    """
    resized_np = np.array(image.resize(RESIZE_IMAGE_TO))

    all_patches = split_into_patches(
        resized_np,
        patch_size=config.IMAGE_SIZE,
        stride=config.IMAGE_SIZE,
    )
    if len(all_patches) == 0:
        return 0.0

    # Bound the per-image cost: only the leading patches are scored.
    scored_patches = all_patches[:MAX_PATCHES]

    def _is_flagged(patch) -> bool:
        """Run a single patch through the model and apply the 0.5 cut-off."""
        rgb_patch = Image.fromarray(patch).convert("RGB")
        batch = transform(rgb_patch).unsqueeze(0).to(device)
        return model(batch).item() > 0.5

    # Inference only, so gradient tracking is disabled for the whole pass.
    with torch.no_grad():
        flagged = sum(1 for patch in scored_patches if _is_flagged(patch))

    return flagged / len(scored_patches)

In [29]:
# Multimodal fusion
def multimodal_fusion(handwriting_risk, language_risk):
    """Blend handwriting and language risk into a single screening score.

    When ``language_risk`` is ``None`` the handwriting risk is returned
    unchanged. Otherwise the language risk is scaled by a gate that is 0 for
    values at or below 0.5 and rises linearly to 1 at 1.0, and the score is
    ``0.65 * handwriting_risk + 0.35 * gate * language_risk``. If the
    handwriting risk is below 0.2 the blended score is limited to 0.45.

    Args:
        handwriting_risk: Handwriting risk score.
        language_risk: Language risk score, or ``None`` when unavailable.

    Returns:
        The fused risk score.
    """
    if language_risk is None:
        return handwriting_risk

    # Gate: only language risk above 0.5 contributes; clamp to [0, 1].
    gate = (language_risk - 0.5) / 0.5
    gate = min(gate, 1.0)
    gate = max(0.0, gate)

    blended = 0.65 * handwriting_risk + 0.35 * gate * language_risk

    # Conservative cap: low handwriting risk limits how far language
    # evidence alone can raise the score.
    low_handwriting_risk = handwriting_risk < 0.2
    return min(blended, 0.45) if low_handwriting_risk else blended

In [None]:
# Evaluation: handwriting-only vs. multimodal fusion
#
# For every image under TEST_DIR/<class>/ this cell computes the handwriting
# risk, optionally a language risk from OCR text, fuses the two, and records
# thresholded predictions for both variants alongside the true label.
y_true = []
y_pred_handwriting = []
y_pred_multimodal = []

for label_name in ["normal", "dyslexic"]:
    label = 0 if label_name == "normal" else 1
    folder = os.path.join(TEST_DIR, label_name)

    # os.listdir() returns entries in arbitrary order; sort first so the
    # subset selected by MAX_IMAGES_PER_CLASS is the same on every run.
    files = sorted(os.listdir(folder))[:MAX_IMAGES_PER_CLASS]

    print(f"\nEvaluating {label_name.upper()} ({len(files)} samples)")

    for fname in files:
        img_path = os.path.join(folder, fname)

        try:
            # Release the file handle as soon as the RGB copy is made.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")

            # ----------------------
            # Handwriting risk
            # ----------------------
            h_risk = handwriting_risk_from_image(image)

            # ----------------------
            # Language risk (GATED + CACHED)
            # ----------------------
            if h_risk < HANDWRITING_SKIP_THRESHOLD:
                # Handwriting risk is too low to justify running OCR.
                l_risk = None
            else:
                # Key the cache by full path: the same file name can exist in
                # both class folders and must not share cached OCR text.
                if img_path in OCR_CACHE:
                    text = OCR_CACHE[img_path]
                else:
                    text = extract_text_from_image(image)
                    OCR_CACHE[img_path] = text

                if len(text.split()) < MIN_WORDS_LANGUAGE:
                    # Too little text for a language-risk estimate.
                    l_risk = None
                else:
                    raw_features = compute_raw_language_features(text)
                    l_risk = (
                        predict_language_risk(raw_features, debug=False)
                        if raw_features is not None
                        else None
                    )

            # ----------------------
            # Fusion
            # ----------------------
            final_risk = multimodal_fusion(h_risk, l_risk)

            # ----------------------
            # Screening decisions
            # ----------------------
            # All three lists are appended together, after every step has
            # succeeded, so they always stay the same length.
            y_true.append(label)
            y_pred_handwriting.append(int(h_risk >= SCREENING_THRESHOLD))
            y_pred_multimodal.append(int(final_risk >= SCREENING_THRESHOLD))

            print(f"Processed: {fname}")

        except Exception as e:
            # Best effort: a file that fails is reported and left out of the
            # evaluation rather than aborting the whole run.
            print(f"❌ Error processing {fname}: {e}")


📂 Evaluating NORMAL (50 samples)
Processed: A-100.png
Processed: A-102.png
Processed: A-104.png
Processed: A-105.png
Processed: A-107.png
Processed: A-108.png
Processed: A-109.png
Processed: A-110.png
Processed: A-112.png
Processed: A-113.png
Processed: A-114.png
Processed: A-115.png
Processed: A-116.png
Processed: A-117.png
Processed: A-118.png
Processed: A-119.png
Processed: A-120.png
Processed: A-122.png
Processed: A-123.png
Processed: A-124.png
Processed: A-126.png
Processed: A-127.png
Processed: A-128.png
Processed: A-129.png
Processed: A-130.png
Processed: A-133.png
Processed: A-135.png
Processed: A-136.png
Processed: A-137.png
Processed: A-138.png
Processed: A-139.png
Processed: A-140.png
Processed: A-141.png
Processed: A-142.png
Processed: A-144.png
Processed: A-145.png
Processed: A-146.png
Processed: A-147.png
Processed: A-148.png
Processed: A-149.png
Processed: A-150.png
Processed: A-151.png
Processed: A-152.png
Processed: A-153.png
Processed: A-154.png
Processed: A-155.p

In [31]:
# Results
#
# Print one labelled report per prediction variant so the two outputs can be
# told apart. Labels are fixed to [0, 1] (0 = normal, 1 = dyslexic, as
# assigned in the evaluation loop) so the confusion matrix is always 2x2,
# even if one class has no samples or no predictions.
CLASS_LABELS = [0, 1]
CLASS_NAMES = ["normal", "dyslexic"]

for title, predictions in (
    ("Handwriting only", y_pred_handwriting),
    ("Multimodal fusion", y_pred_multimodal),
):
    print(f"=== {title} ===")
    print(
        classification_report(
            y_true,
            predictions,
            labels=CLASS_LABELS,
            target_names=CLASS_NAMES,
            zero_division=0,
        )
    )
    # Rows are true classes, columns are predicted classes.
    print(confusion_matrix(y_true, predictions, labels=CLASS_LABELS))
    print()


              precision    recall  f1-score   support

           0       0.51      0.58      0.54        50
           1       0.51      0.44      0.47        50

    accuracy                           0.51       100
   macro avg       0.51      0.51      0.51       100
weighted avg       0.51      0.51      0.51       100

[[29 21]
 [28 22]]
              precision    recall  f1-score   support

           0       0.51      0.58      0.54        50
           1       0.51      0.44      0.47        50

    accuracy                           0.51       100
   macro avg       0.51      0.51      0.51       100
weighted avg       0.51      0.51      0.51       100

[[29 21]
 [28 22]]
