In [None]:
import os
import random
import shutil
import numpy as np
import base64
from io import BytesIO
import datetime

from flask import Flask, request, render_template, redirect, url_for, session, send_file

# --------- PyTorch / ViT imports ----------
import torch
import torch.nn as nn
import timm
from torchvision import transforms, models
from PIL import Image

# --------- TensorFlow / ResNet imports (for Grad-CAM) ----------
import tensorflow as tf
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.preprocessing import image as kimage

import cv2

# --------- PDF (Report) imports ----------
from reportlab.pdfgen import canvas
from reportlab.lib.pagesizes import A4
from reportlab.lib.utils import ImageReader

# ======================================================
# Flask Setup
# ======================================================
app = Flask(__name__)
# Signs the session cookie. Anyone who knows this value can forge
# session['username'] and skip the login check, so it must not be a constant
# committed to source. Read it from FLASK_SECRET_KEY; otherwise fall back to a
# random per-process key (sessions then do not survive a restart and are not
# shared between worker processes -- set the variable when that matters).
app.secret_key = os.environ.get("FLASK_SECRET_KEY") or os.urandom(32).hex()

# Directories (created at import time if missing)
STATIC_FOLDER = "static"
UPLOAD_FOLDER = os.path.join(STATIC_FOLDER, "uploads")
GRADCAM_FOLDER = os.path.join(STATIC_FOLDER, "gradcam")
REPORT_FOLDER = os.path.join(STATIC_FOLDER, "reports")

os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(GRADCAM_FOLDER, exist_ok=True)
os.makedirs(REPORT_FOLDER, exist_ok=True)

# Path to test images for random prediction (one sub-folder per class)
DATASET_PATH_PNEU = r"C://archive (11)//chest_xray//test"
DATASET_PATH_TB   = r"C://tb_chest_xray//test"
DATASET_PATH_DR   = r"C://retino//test"   # folders: DR, No_DR

# ======================================================
# 1️⃣ Common transforms and device
# ======================================================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

IMG_SIZE = 224
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std  = [0.229, 0.224, 0.225]

# Resize to IMG_SIZE x IMG_SIZE, convert to tensor, normalise with ImageNet stats.
base_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Both the ViT models and the image-type classifier share the same preprocessing.
vit_transform = base_transform
imgtype_transform = base_transform

# ======================================================
# 2️⃣ ViT + BiLSTM Models (PyTorch)  → PREDICTION
# ======================================================
VIT_PNEU_MODEL_PATH = "vit_bilstm_pneumonia_best.pth"
VIT_TB_MODEL_PATH   = "vit_bilstm_tb_best.pth"
VIT_DR_MODEL_PATH   = "vit_bilstm_retino_best.pth"   # DR model


class ViT_BiLSTM(nn.Module):
    """ViT backbone whose patch tokens are summarised by a bidirectional LSTM.

    The timm ``vit_base_patch16_224`` model produces a token sequence; the
    first (CLS) token is dropped and the patch tokens are run through a
    single-layer BiLSTM. The last forward and backward hidden states are
    concatenated, passed through dropout and projected to ``num_classes``
    logits.

    The attribute names ``vit``, ``bilstm``, ``dropout`` and ``fc`` are the
    state_dict keys of the saved checkpoints and must not be renamed.
    """

    def __init__(self, num_classes=1, pretrained=False, lstm_hidden=256):
        super().__init__()
        backbone = timm.create_model(
            "vit_base_patch16_224",
            pretrained=pretrained
        )
        # Remove the classification head so forward_features yields embeddings.
        backbone.reset_classifier(0)
        self.vit = backbone
        token_dim = backbone.num_features

        self.bilstm = nn.LSTM(
            input_size=token_dim,
            hidden_size=lstm_hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(0.5)
        # Forward + backward directions -> twice the hidden width.
        self.fc = nn.Linear(2 * lstm_hidden, num_classes)

    def forward_features_tokens(self, x):
        """Return the ViT token sequence for ``x`` without the leading CLS token."""
        features = self.vit.forward_features(x)
        # Accept either a plain tensor or a dict holding the tokens under "x".
        if isinstance(features, dict):
            features = features["x"]
        return features[:, 1:, :]

    def forward(self, x):
        """Map an image batch to logits of shape (batch, num_classes)."""
        sequence = self.forward_features_tokens(x)
        _, (hidden, _) = self.bilstm(sequence)
        # One bidirectional layer: hidden[-2] is the final forward state and
        # hidden[-1] the final backward state.
        summary = torch.cat([hidden[-2], hidden[-1]], dim=1)
        return self.fc(self.dropout(summary))


def _load_vit_bilstm(weights_path, num_classes, label):
    """Build a ViT_BiLSTM, load its state_dict from ``weights_path`` and set eval mode.

    ``weights_only=True`` makes torch.load reject arbitrary pickled objects and
    accept only tensors/plain containers -- all a state_dict needs -- and
    silences the FutureWarning newer torch versions emit without it.
    """
    net = ViT_BiLSTM(num_classes=num_classes, pretrained=False).to(DEVICE)
    state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
    net.load_state_dict(state_dict)
    net.eval()
    print(f"✅ ViT+BiLSTM {label} model loaded")
    return net


# ---------- Pneumonia model ----------
vit_pneu_model = _load_vit_bilstm(VIT_PNEU_MODEL_PATH, num_classes=1, label="Pneumonia")

# ---------- TB model ----------
vit_tb_model = _load_vit_bilstm(VIT_TB_MODEL_PATH, num_classes=1, label="TB")

# ---------- DR model (2-class: DR vs No_DR) ----------
vit_dr_model = _load_vit_bilstm(VIT_DR_MODEL_PATH, num_classes=2, label="DR")


class_labels_pneu = {0: "Normal", 1: "Pneumonia"}
class_labels_tb   = {0: "Normal", 1: "Tuberculosis"}
# DR mapping handled manually (2-logit softmax head, see predict_and_explain)

# ======================================================
# 2.5️⃣ Image-type classifier (ResNet18) → chest_xray / fundus / other
# ======================================================
IMGTYPE_MODEL_PATH = "imgtype_resnet18.pth"   # your trained model

imgtype_model = models.resnet18(weights=None)
in_features = imgtype_model.fc.in_features
# Replace the ImageNet head with a 3-way head: chest_xray, fundus, other.
imgtype_model.fc = nn.Linear(in_features, 3)
# weights_only=True: the checkpoint is a plain state_dict, so refuse to
# unpickle arbitrary objects from the file.
state_dict_type = torch.load(IMGTYPE_MODEL_PATH, map_location=DEVICE, weights_only=True)
imgtype_model.load_state_dict(state_dict_type)
imgtype_model = imgtype_model.to(DEVICE)
imgtype_model.eval()
print("✅ Image-type classifier (chest_xray / fundus / other) loaded")

# Output index -> modality name used by classify_modality.
idx_to_modality = {0: "xray", 1: "fundus", 2: "other"}


def classify_modality(pil_img):
    """Predict the kind of image ``pil_img`` is with the ResNet18 type classifier.

    Returns "xray", "fundus" or "other"; a predicted index missing from
    ``idx_to_modality`` is reported as "other".
    """
    batch = imgtype_transform(pil_img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        best_index = imgtype_model(batch).argmax(dim=1).item()
    return idx_to_modality.get(best_index, "other")

# ======================================================
# 3️⃣ ResNet50 Models (Keras)  → GRAD-CAM for all 3 diseases
# ======================================================

RESNET_PNEU_MODEL_PATH = "ResNet50_pneumonia_model.h5"
RESNET_TB_MODEL_PATH   = "ResNet50_tb_model_best.h5"
RESNET_DR_MODEL_PATH   = "ResNet50_dr_model_best.h5"


def _load_gradcam_model(model_path, ready_message):
    """Load a Keras model used only for Grad-CAM: uncompiled, weights frozen."""
    keras_model = load_model(model_path, compile=False)
    keras_model.trainable = False
    print(ready_message)
    return keras_model


resnet_pneu_model = _load_gradcam_model(
    RESNET_PNEU_MODEL_PATH, "✅ ResNet50 Pneumonia model loaded for Grad-CAM"
)
resnet_tb_model = _load_gradcam_model(
    RESNET_TB_MODEL_PATH, "✅ ResNet50 TB model loaded for Grad-CAM"
)
resnet_dr_model = _load_gradcam_model(
    RESNET_DR_MODEL_PATH, "✅ ResNet50 DR model loaded for Grad-CAM"
)

# Login table: username -> password.
# NOTE(review): demo credentials stored in plain text in source; replace with
# hashed passwords / a real user store before any deployment.
patients = {
    "devraj": "1234",
    "jagadeesh": "5678"
}

# Layer whose activations Grad-CAM weights. NOTE(review): this is the final
# conv-block output name in keras.applications.ResNet50 -- presumably the .h5
# models were built on it; confirm if a model fails get_layer().
LAST_CONV_LAYER_NAME = "conv5_block3_out"


def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """Compute a Grad-CAM heatmap for one preprocessed image.

    Args:
        img_array: input batch as fed to ``model``; only the first item is used
            to build the map.
        model: Keras model containing a layer named ``last_conv_layer_name``.
        last_conv_layer_name: convolutional layer whose activations are weighted.
        pred_index: output unit to explain. Defaults to the arg-max of the
            model's prediction for the first image (previous behaviour).

    Returns:
        numpy array over the conv layer's spatial grid, scaled to [0, 1].
    """
    last_conv_layer = model.get_layer(last_conv_layer_name)
    # `model.inputs` is already a list; wrapping it again ([model.inputs]) nests
    # the input structure, which Keras 3 flags as a structure mismatch.
    grad_model = Model(model.inputs, [last_conv_layer.output, model.output])

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(predictions[0])
        class_score = predictions[:, pred_index]

    grads = tape.gradient(class_score, conv_outputs)
    # One weight per channel: mean gradient over batch and spatial positions.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = tf.squeeze(conv_outputs[0] @ pooled_grads[..., tf.newaxis])

    # ReLU first, then scale by the max of the *clipped* map. Dividing by the
    # raw max could hit a negative or ~0 denominator when every weighted
    # activation is <= 0, producing NaN that later breaks the uint8 cast.
    # When the raw max is positive the result is unchanged.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = heatmap / (tf.reduce_max(heatmap) + 1e-10)
    return heatmap.numpy()


def generate_resnet_gradcam(img_path, disease_norm):
    """Render a Grad-CAM overlay for ``img_path`` with the ResNet50 model for a disease.

    ``disease_norm`` "tb" or "dr" (case-insensitive) selects that model; any
    other value uses the pneumonia model. The overlay is JPEG-encoded once,
    written under GRADCAM_FOLDER (for the PDF report) and also returned inline.

    Returns:
        (gradcam_dataurl, gradcam_path): a ``data:image/jpeg;base64,...`` URL
        for the page and the saved file's path with forward slashes.
    """
    disease_norm = disease_norm.lower()

    if disease_norm == "tb":
        model = resnet_tb_model
    elif disease_norm == "dr":
        model = resnet_dr_model
    else:  # default to pneumonia
        model = resnet_pneu_model

    img = kimage.load_img(img_path, target_size=(224, 224))
    # NOTE(review): pixels scaled to [0, 1]; assumed to match how the ResNet50
    # checkpoints were trained -- confirm against the training pipeline.
    img_array = np.expand_dims(kimage.img_to_array(img) / 255.0, axis=0)

    heatmap = make_gradcam_heatmap(img_array, model, LAST_CONV_LAYER_NAME)
    heatmap_uint8 = np.uint8(255 * cv2.resize(heatmap, (224, 224)))
    heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)  # BGR output

    orig_bgr = cv2.cvtColor(kimage.img_to_array(img).astype("uint8"), cv2.COLOR_RGB2BGR)
    alpha = 0.4
    superimposed = cv2.addWeighted(heatmap_color, alpha, orig_bgr, 1 - alpha, 0)
    overlay_pil = Image.fromarray(cv2.cvtColor(superimposed, cv2.COLOR_BGR2RGB))

    # Encode once and reuse the same JPEG bytes for the file and the data URL.
    buffer = BytesIO()
    overlay_pil.save(buffer, format="JPEG")
    img_bytes = buffer.getvalue()

    # ---- Save to disk for report ----
    # The content is JPEG, so use a .jpg name; previously a .png upload
    # produced "gradcam_<disease>_<name>.png" containing JPEG data.
    stem = os.path.splitext(os.path.basename(img_path))[0]
    gradcam_disk_path = os.path.join(GRADCAM_FOLDER, f"gradcam_{disease_norm}_{stem}.jpg")
    with open(gradcam_disk_path, "wb") as fh:
        fh.write(img_bytes)

    # ---- Also create data URL for UI ----
    b64 = base64.b64encode(img_bytes).decode("utf-8")
    gradcam_dataurl = f"data:image/jpeg;base64,{b64}"

    # Web-style path for templates / PDF
    return gradcam_dataurl, gradcam_disk_path.replace("\\", "/")

# ======================================================
# 4️⃣ Precautions / Next Steps helper
# ======================================================
# disease key -> (positive result label,
#                 (precautions, next_steps) when positive,
#                 (precautions, next_steps) for any other result)
_RECOMMENDATIONS = {
    "pneumonia": (
        "Pneumonia",
        (
            "Avoid strenuous activity, drink plenty of fluids, "
            "avoid smoking and second-hand smoke, and monitor fever or breathing difficulty.",
            "Consult a qualified doctor or pulmonologist as soon as possible. "
            "Bring this report and the chest X-ray for clinical evaluation and additional tests if needed.",
        ),
        (
            "Maintain good respiratory hygiene, stay vaccinated (as advised by a doctor), "
            "and avoid exposure to smoke and pollutants.",
            "If symptoms like cough, fever or breathing difficulty persist, "
            "still meet a doctor for a physical examination.",
        ),
    ),
    "tb": (
        "Tuberculosis",
        (
            "Limit close contact with others until evaluated, cover mouth while coughing, "
            "and ensure good ventilation in rooms.",
            "Visit a government TB centre or chest specialist urgently for confirmation tests "
            "(sputum, GeneXpert, etc.) and to start proper anti-TB treatment.",
        ),
        (
            "Maintain good nutrition and hygiene, and avoid close contact with known TB patients "
            "without a mask.",
            "If you have long-lasting cough, weight loss or night sweats, "
            "consult a doctor even if this screening appears normal.",
        ),
    ),
    "dr": (
        "Diabetic Retinopathy",
        (
            "Keep blood sugar, blood pressure and cholesterol under control, "
            "avoid smoking, and follow a diabetes-friendly lifestyle.",
            "Take this report to an eye specialist (ophthalmologist) for a detailed retinal examination. "
            "They may advise OCT, retinal laser or injections depending on the severity.",
        ),
        (
            "Continue regular diabetes check-ups, maintain good sugar and blood pressure control, "
            "and schedule routine eye exams.",
            "Even with a normal screening result, diabetic patients should visit an eye specialist "
            "at least once a year or as advised.",
        ),
    ),
}


def get_recommendations(disease_norm, result):
    """Return ``(precautions, next_steps)`` advice text for a screening result.

    ``disease_norm`` is matched case-insensitively against "pneumonia", "tb"
    and "dr". The positive-case advice is used only when ``result`` equals
    that disease's positive label; any other result gets the normal-case
    advice. A ``None`` result or an unknown disease yields two empty strings.
    """
    disease_key = disease_norm.lower()
    if result is None:
        return "", ""

    rule = _RECOMMENDATIONS.get(disease_key)
    if rule is None:
        return "", ""

    positive_label, when_positive, when_normal = rule
    return when_positive if result == positive_label else when_normal

# ======================================================
# 5️⃣ Prediction helper
# ======================================================
def predict_and_explain(img_path, disease="pneumonia", threshold=0.5):
    """Screen one image with the ViT+BiLSTM model for ``disease`` and explain it.

    Args:
        img_path: path of the image on disk.
        disease: "tb", "dr" or "pneumonia" (case-insensitive); None or any
            other value uses the pneumonia model.
        threshold: minimum positive-class probability to report the disease.

    Returns:
        8-tuple ``(result, confidence_pct, positive_prob_pct, gradcam_dataurl,
        gradcam_path, precautions, next_steps, error_msg)``. When the image is
        unreadable or of the wrong modality, the first seven items are None and
        ``error_msg`` explains why; on success ``error_msg`` is None.
    """
    disease_norm = (disease or "pneumonia").lower()

    def _fail(message):
        return (None,) * 7 + (message,)

    try:
        # `with` closes the file handle; convert() returns an independent copy.
        with Image.open(img_path) as raw_img:
            pil_img = raw_img.convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for corrupt or
        # non-image files; report it like the other validation failures instead
        # of letting it surface as a server error.
        return _fail("Could not read the image file. "
                     "Please upload a valid PNG or JPEG image.")

    # ---------- Strong modality validation via classifier ----------
    modality = classify_modality(pil_img)   # "xray", "fundus", "other"

    if disease_norm in ("pneumonia", "tb") and modality != "xray":
        return _fail("Invalid image for chest X-ray model. "
                     "Please upload a proper chest X-ray image.")

    if disease_norm == "dr" and modality != "fundus":
        return _fail("Invalid image for DR model. Please upload a retina/fundus "
                     "photograph instead of other images.")

    # --------- Select disease model ----------
    if disease_norm == "tb":
        model, labels_map = vit_tb_model, class_labels_tb
    elif disease_norm == "dr":
        # 2-logit head; softmax column 0 is treated as the DR class.
        # NOTE(review): assumed to follow alphabetical folder order ("DR" < "No_DR").
        model, labels_map = vit_dr_model, {0: "Normal", 1: "Diabetic Retinopathy"}
    else:
        model, labels_map = vit_pneu_model, class_labels_pneu

    input_tensor = vit_transform(pil_img).unsqueeze(0).to(DEVICE)
    model.eval()
    with torch.no_grad():
        logits = model(input_tensor)
        if disease_norm == "dr":
            prob_positive = torch.softmax(logits, dim=1)[0, 0].item()
        else:
            # Single-logit binary head.
            prob_positive = torch.sigmoid(logits)[0, 0].item()

    # ---------- Convert prob -> label & confidence ----------
    if prob_positive >= threshold:
        result, confidence = labels_map[1], prob_positive
    else:
        result, confidence = labels_map[0], 1 - prob_positive
    raw_positive_prob = prob_positive * 100.0

    # --------- Precautions & next steps ----------
    precautions, next_steps = get_recommendations(disease_norm, result)

    # Grad-CAM for all three diseases
    gradcam_dataurl, gradcam_path = generate_resnet_gradcam(img_path, disease_norm)

    return (
        result,
        confidence * 100.0,
        raw_positive_prob,
        gradcam_dataurl,
        gradcam_path,
        precautions,
        next_steps,
        None
    )

# ======================================================
# 6️⃣ Random image helper
# ======================================================
def get_random_test_image(disease="pneumonia"):
    """Copy a random test-set image for ``disease`` into UPLOAD_FOLDER.

    A class folder is picked at random first, then a random image inside it.
    "tb" and "dr" (case-insensitive) select their datasets; anything else uses
    the pneumonia dataset.

    Returns:
        (copied_path_with_forward_slashes, class_folder_name), or (None, None)
        when the chosen class folder is missing or holds no image files.
    """
    disease_norm = (disease or "pneumonia").lower()

    if disease_norm == "tb":
        categories, dataset_path = ["NORMAL", "TUBERCULOSIS"], DATASET_PATH_TB
    elif disease_norm == "dr":
        categories, dataset_path = ["DR", "No_DR"], DATASET_PATH_DR
    else:
        categories, dataset_path = ["NORMAL", "PNEUMONIA"], DATASET_PATH_PNEU

    category = random.choice(categories)
    folder_path = os.path.join(dataset_path, category)
    if not os.path.isdir(folder_path):
        return None, None

    # Only regular image files: listdir also returns sub-folders and stray
    # files (Thumbs.db, .DS_Store, ...) that would break the copy or prediction.
    image_exts = ('.png', '.jpg', '.jpeg')
    candidates = sorted(
        name for name in os.listdir(folder_path)
        if name.lower().endswith(image_exts)
        and os.path.isfile(os.path.join(folder_path, name))
    )
    if not candidates:
        return None, None

    img_name = random.choice(candidates)
    dst_path = os.path.join(UPLOAD_FOLDER, img_name)
    shutil.copy(os.path.join(folder_path, img_name), dst_path)
    return dst_path.replace("\\", "/"), category

# ======================================================
# 7️⃣ PDF Download helper
# ======================================================
def _draw_sentences(c, text, x, y, max_width,
                    font_name="Helvetica", font_size=10, leading=14):
    """Draw ``text`` one sentence per line, wrapping long sentences; return the new y."""
    from reportlab.lib.utils import simpleSplit

    c.setFont(font_name, font_size)
    for sentence in text.split(". "):
        sentence = sentence.strip()
        if not sentence:
            continue
        # split() removed the period from every sentence except the last,
        # which still has its own -- don't add a second one ("..").
        if not sentence.endswith("."):
            sentence += "."
        # Wrap to the printable width instead of running off the page.
        for line in simpleSplit(sentence, font_name, font_size, max_width):
            c.drawString(x, y, line)
            y -= leading
    return y


@app.route('/download-report')
def download_report():
    """Build a PDF report of the session's last prediction and send it.

    Redirects to the prediction route when no ``last_prediction`` is stored
    in the session. The generated PDF is also kept under REPORT_FOLDER.
    """
    if 'last_prediction' not in session:
        return redirect(url_for('upload_and_predict'))

    details = session['last_prediction']

    # Create PDF file. Microseconds keep two reports generated within the same
    # second from overwriting each other on disk.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    pdf_path = os.path.join(REPORT_FOLDER, f"report_{timestamp}.pdf")

    c = canvas.Canvas(pdf_path, pagesize=A4)
    width, height = A4
    margin = 40
    text_width = width - 2 * margin
    y = height - margin

    c.setFont("Helvetica-Bold", 18)
    c.drawString(margin, y, "AI Screening Report")
    y -= 30

    c.setFont("Helvetica", 11)
    c.drawString(margin, y, f"Disease Type : {details['disease_label']}")
    y -= 16
    c.drawString(margin, y, f"Result       : {details['result']}")
    y -= 16
    c.drawString(margin, y, f"Confidence   : {details['confidence']:.2f}%")
    y -= 16
    c.drawString(margin, y, f"Probability  : {details['prob']:.2f}%")
    y -= 24

    c.setFont("Helvetica-Bold", 12)
    c.drawString(margin, y, "Precautions:")
    y -= 16
    y = _draw_sentences(c, details.get('precautions') or "", margin, y, text_width)

    y -= 10

    c.setFont("Helvetica-Bold", 12)
    c.drawString(margin, y, "What should the patient do next?")
    y -= 16
    y = _draw_sentences(c, details.get('next_steps') or "", margin, y, text_width)

    y -= 20
    img_height = 220
    img_width = 200
    # Both images share the same bottom edge.
    y_img = y - img_height - 5

    # Images are best-effort: a missing file is logged and the PDF still builds.
    try:
        c.setFont("Helvetica-Bold", 11)
        c.drawString(margin, y, "Input Image")
        c.drawImage(ImageReader(details['input_image']), margin, y_img,
                    width=img_width, height=img_height)
    except Exception as e:
        print("Error adding input image to PDF:", e)

    try:
        if details.get('gradcam_image'):
            grad_x = margin + img_width + 40
            c.setFont("Helvetica-Bold", 11)
            c.drawString(grad_x, y, "Grad-CAM")
            c.drawImage(
                ImageReader(details['gradcam_image']),
                grad_x,
                y_img,
                width=img_width,
                height=img_height
            )
    except Exception as e:
        print("Error adding Grad-CAM image to PDF:", e)

    c.showPage()
    c.save()

    return send_file(pdf_path, as_attachment=True,
                     download_name="AI_Screening_Report.pdf")

# ======================================================
# 8️⃣ Routes (MODIFIED FOR 3-PAGE FLOW)
# ======================================================

@app.route('/')
def root():
    """Send logged-in users to the home page and everyone else to the login page."""
    target = 'home_page' if 'username' in session else 'login'
    return redirect(url_for(target))

@app.route('/home')
def home_page():
    """Render the landing page for a logged-in user; otherwise go to login."""
    if 'username' in session:
        return render_template('index.html', active_page='home')
    return redirect(url_for('login'))

@app.route('/login', methods=['GET', 'POST'])
def login():
    """Show the login form (GET) or verify submitted credentials (POST).

    On success the session is reset, the username stored and the user sent to
    the home page; otherwise the form is re-rendered with an error message.
    """
    import hmac

    if request.method != "POST":
        return render_template('login.html')

    # .get(): a form missing a field counts as bad credentials instead of
    # aborting the request with 400 Bad Request.
    username = request.form.get('username', '')
    password = request.form.get('password', '')
    expected = patients.get(username)

    # Constant-time comparison; compare bytes because compare_digest rejects
    # non-ASCII str arguments.
    if expected is not None and hmac.compare_digest(
            password.encode("utf-8"), expected.encode("utf-8")):
        # Drop anything left over from a previous user on this browser (e.g.
        # their 'last_prediction', which /download-report would otherwise serve).
        session.clear()
        session['username'] = username
        return redirect(url_for('home_page'))
    return render_template('login.html', error="Invalid credentials!")

@app.route('/logout')
def logout():
    """Log the user out, discard all per-user session data and go to login."""
    # Clearing the whole session (not just 'username') also removes
    # 'last_prediction', so the next person using this browser cannot download
    # the previous user's screening report.
    session.clear()
    return redirect(url_for('login'))

@app.route('/doctors')
def doctors_page():
    """Render the doctors page for a logged-in user; otherwise go to login."""
    if 'username' in session:
        return render_template('doctors.html', active_page='doctors')
    return redirect(url_for('login'))

@app.route('/contact')
def contact_page():
    """Render the contact page for a logged-in user; otherwise go to login."""
    if 'username' in session:
        return render_template('contact.html', active_page='contact')
    return redirect(url_for('login'))


# ======================================================
# SERVICES & PREDICTION FLOW (UPDATED)
# ======================================================

@app.route('/services')
def services_page():
    """
    Show either a single screening page or the list of services.

    With ``?type=pneumonia|tb|dr`` the dedicated 'screening.html' page for that
    disease is rendered; any other (or missing) type renders 'services.html'.
    Anonymous users are redirected to the login page.
    """
    if 'username' not in session:
        return redirect(url_for('login'))

    service_type = request.args.get('type')  # 'pneumonia','tb','dr'
    if service_type not in ('pneumonia', 'tb', 'dr'):
        return render_template('services.html', active_page='services')

    return render_template(
        'screening.html',
        active_page='services',
        selected_disease=service_type
    )


@app.route('/predict', methods=['GET', 'POST'])
def upload_and_predict():
    """
    Handle an uploaded image and screen it for the selected disease.

    GET redirects to the services list. On success renders 'result.html' and
    stores the prediction in the session for the PDF report; on any error
    re-renders 'screening.html' with an error message.
    """
    import uuid
    from werkzeug.utils import secure_filename

    if 'username' not in session:
        return redirect(url_for('login'))

    if request.method == 'GET':
        return redirect(url_for('services_page'))

    disease = (request.form.get('disease') or 'pneumonia').strip().lower()
    # Unknown values would run the pneumonia model while the page showed the
    # raw value; normalise up front so model, label and template agree.
    if disease not in ('pneumonia', 'tb', 'dr'):
        disease = 'pneumonia'

    def render_error(message):
        return render_template(
            'screening.html',
            error=message,
            selected_disease=disease,
            active_page='services'
        )

    # If no file uploaded
    file = request.files.get('file')
    if file is None or file.filename == '':
        return render_error("No file selected.")

    # Invalid file extension
    if not file.filename.lower().endswith(('.png', '.jpg', '.jpeg')):
        return render_error("Invalid file format.")

    # file.filename is client-controlled: secure_filename strips path
    # separators and ".." so the upload cannot be written outside
    # UPLOAD_FOLDER. The random prefix keeps concurrent uploads that share a
    # name from overwriting each other mid-prediction.
    stem, ext = file.filename.rsplit('.', 1)
    safe_stem = secure_filename(stem) or "upload"
    file_path = os.path.join(
        UPLOAD_FOLDER, f"{uuid.uuid4().hex[:8]}_{safe_stem}.{ext.lower()}"
    )
    file.save(file_path)

    (result,
     confidence,
     raw_positive_prob,
     gradcam_dataurl,
     gradcam_path,
     precautions,
     next_steps,
     error_msg) = predict_and_explain(file_path, disease=disease)

    # ---------------- ERROR CASE ----------------
    if error_msg is not None:
        return render_error(error_msg)

    # ---------------- SUCCESS CASE ----------------
    img_web_path = file_path.replace("\\", "/")

    # Save for PDF
    disease_label = {"tb": "Tuberculosis",
                     "dr": "Diabetic Retinopathy"}.get(disease, "Pneumonia")

    session['last_prediction'] = {
        "disease_label": disease_label,
        "result": result,
        "confidence": confidence,
        "prob": raw_positive_prob,
        "precautions": precautions,
        "next_steps": next_steps,
        "input_image": img_web_path,
        "gradcam_image": gradcam_path
    }

    return render_template(
        'result.html',
        result=result,
        confidence=confidence,
        raw_pneu_prob=raw_positive_prob,
        img_path=img_web_path,
        gradcam_dataurl=gradcam_dataurl,
        selected_disease=disease,
        precautions=precautions,
        next_steps=next_steps,
        active_page='services'
    )


@app.route('/predict-random')
def predict_random():
    """
    Screen a random test-set image for the disease given in ``?disease=``.

    On success renders 'result.html' (including the image's true class folder
    as ``actual_label``) and stores the prediction for the PDF report; on any
    error re-renders 'screening.html' with a message.
    """
    if 'username' not in session:
        return redirect(url_for('login'))

    disease = (request.args.get('disease') or 'pneumonia').strip().lower()
    # Unknown values would fall back to the pneumonia dataset and model while
    # the page still showed the raw value; normalise so everything agrees.
    if disease not in ('pneumonia', 'tb', 'dr'):
        disease = 'pneumonia'

    def render_error(message):
        return render_template(
            'screening.html',
            error=message,
            selected_disease=disease,
            active_page='services'
        )

    img_web_path, actual_label = get_random_test_image(disease=disease)

    # Error: No test images found
    if img_web_path is None:
        return render_error("No test images found in dataset folder.")

    (result,
     confidence,
     raw_positive_prob,
     gradcam_dataurl,
     gradcam_path,
     precautions,
     next_steps,
     error_msg) = predict_and_explain(img_web_path, disease=disease)

    # ---------------- ERROR CASE ----------------
    if error_msg is not None:
        return render_error(error_msg)

    # ---------------- SUCCESS CASE ----------------
    disease_label = {"tb": "Tuberculosis",
                     "dr": "Diabetic Retinopathy"}.get(disease, "Pneumonia")

    session['last_prediction'] = {
        "disease_label": disease_label,
        "result": result,
        "confidence": confidence,
        "prob": raw_positive_prob,
        "precautions": precautions,
        "next_steps": next_steps,
        "input_image": img_web_path,
        "gradcam_image": gradcam_path
    }

    return render_template(
        'result.html',
        result=result,
        confidence=confidence,
        raw_pneu_prob=raw_positive_prob,
        img_path=img_web_path,
        gradcam_dataurl=gradcam_dataurl,
        actual_label=actual_label,
        selected_disease=disease,
        precautions=precautions,
        next_steps=next_steps,
        active_page='services'
    )


if __name__ == '__main__':
    # use_reloader=False: the Werkzeug reloader re-runs the script in a child
    # process, which would repeat all of the model loading above.
    app.run(use_reloader=False, debug=True)

Using device: cuda


  state_dict_pneu = torch.load(VIT_PNEU_MODEL_PATH, map_location=DEVICE)


✅ ViT+BiLSTM Pneumonia model loaded


  state_dict_tb = torch.load(VIT_TB_MODEL_PATH, map_location=DEVICE)


✅ ViT+BiLSTM TB model loaded


  state_dict_dr = torch.load(VIT_DR_MODEL_PATH, map_location=DEVICE)


✅ ViT+BiLSTM DR model loaded


  state_dict_type = torch.load(IMGTYPE_MODEL_PATH, map_location=DEVICE)


✅ Image-type classifier (chest_xray / fundus / other) loaded
✅ ResNet50 Pneumonia model loaded for Grad-CAM
✅ ResNet50 TB model loaded for Grad-CAM
✅ ResNet50 DR model loaded for Grad-CAM
 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [16/Dec/2025 08:13:05] "GET / HTTP/1.1" 302 -
127.0.0.1 - - [16/Dec/2025 08:13:05] "GET /login HTTP/1.1" 200 -
127.0.0.1 - - [16/Dec/2025 08:13:05] "GET /favicon.ico HTTP/1.1" 404 -
127.0.0.1 - - [16/Dec/2025 08:13:15] "POST /login HTTP/1.1" 302 -
127.0.0.1 - - [16/Dec/2025 08:13:15] "GET /home HTTP/1.1" 200 -
127.0.0.1 - - [16/Dec/2025 08:13:15] "GET /static/img/gemini-3.png HTTP/1.1" 200 -
127.0.0.1 - - [16/Dec/2025 08:13:15] "GET /static/css/style.css HTTP/1.1" 200 -
127.0.0.1 - - [16/Dec/2025 08:13:16] "GET /static/js/main.js HTTP/1.1" 200 -
127.0.0.1 - - [16/Dec/2025 08:13:21] "GET /services HTTP/1.1" 200 -
127.0.0.1 - - [16/Dec/2025 08:13:21] "GET /static/css/style.css HTTP/1.1" 304 -
127.0.0.1 - - [16/Dec/2025 08:13:21] "GET /static/js/main.js HTTP/1.1" 304 -
127.0.0.1 - - [16/Dec/2025 08:13:21] "GET /static/img/pneumonia.png HTTP/1.1" 200 -
127.0.0.1 - - [16/Dec/2025 08:13:21] "GET /static/img/retina.png HTT