<a href="https://colab.research.google.com/github/Ayesharani96/heart_detect/blob/main/Flask.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pyngrok

import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image, UnidentifiedImageError
from flask import Flask, request, jsonify
import io
import joblib
import pandas as pd
from pyngrok import ngrok

# -----------------------
# Device
# -----------------------
# Prefer the GPU whenever torch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# -----------------------
# Model Classes
# -----------------------
class HeartRiskResNet(nn.Module):
    """ResNet-18 feature extractor with a binary disease head and a 3-way risk head.

    Args:
        pretrained: forwarded unchanged to ``torchvision.models.resnet18``.
        dropout_p: dropout probability applied to the backbone features
            before both heads (only in ``forward``, not ``extract_features``).
    """

    def __init__(self, pretrained=False, dropout_p=0.5):
        super().__init__()
        # NOTE: the attribute names below are state_dict key prefixes. The
        # checkpoints are loaded with strict=False, so renaming any of them
        # would make the weights be skipped silently.
        backbone = models.resnet18(pretrained=pretrained)
        # Replace the ImageNet classifier so the backbone returns its
        # 512-dim pooled features directly.
        backbone.fc = nn.Identity()
        self.base_model = backbone
        self.feature_dim = 512
        self.dropout = nn.Dropout(dropout_p)
        self.disease_head = nn.Linear(self.feature_dim, 1)
        self.risk_head = nn.Linear(self.feature_dim, 3)

    def forward(self, x):
        """Return ``(disease_prob, risk_logits)``.

        ``disease_prob`` is the sigmoid of the 1-unit disease head;
        ``risk_logits`` are the raw outputs of the 3-unit risk head.
        """
        feats = self.dropout(self.extract_features(x))
        return torch.sigmoid(self.disease_head(feats)), self.risk_head(feats)

    def extract_features(self, x):
        """Return the backbone features without dropout (used by FusionModel)."""
        return self.base_model(x)


class MobileNetRiskModel(nn.Module):
    """MobileNetV2 conv trunk + global average pool, with a binary disease head
    and a 3-way risk head.

    Args:
        pretrained: forwarded unchanged to ``torchvision.models.mobilenet_v2``.
        dropout_p: dropout probability applied to the pooled features before
            both heads (only in ``forward``, not ``extract_features``).
    """

    def __init__(self, pretrained=False, dropout_p=0.5):
        super().__init__()
        # NOTE: attribute names are state_dict key prefixes (checkpoints are
        # loaded with strict=False), so they must not be renamed.
        # Only the convolutional `features` trunk is kept; pooling is explicit.
        self.backbone = models.mobilenet_v2(pretrained=pretrained).features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.feature_dim = 1280
        self.dropout = nn.Dropout(dropout_p)
        self.disease_head = nn.Linear(self.feature_dim, 1)
        self.risk_head = nn.Linear(self.feature_dim, 3)

    def forward(self, x):
        """Return ``(disease_prob, risk_logits)`` from the pooled, dropped-out features."""
        feats = self.dropout(self.extract_features(x))
        return torch.sigmoid(self.disease_head(feats)), self.risk_head(feats)

    def extract_features(self, x):
        """Return pooled features flattened to ``(batch, channels)``, without dropout."""
        pooled = self.pool(self.backbone(x))
        return pooled.view(pooled.size(0), -1)


class FusionModel(nn.Module):
    """Combine ECG, echo and X-ray models by concatenating their features.

    Each sub-model must expose ``feature_dim`` and ``extract_features``. The
    concatenated features go through a two-layer MLP, then a disease head
    (sigmoid probability) and a 3-way risk head (raw logits).
    """

    def __init__(self, ecg_mod, echo_mod, xray_mod):
        super().__init__()
        # Registered as submodules, so their weights appear in this model's
        # state_dict under the "ecg." / "echo." / "xray." prefixes.
        self.ecg = ecg_mod
        self.echo = echo_mod
        self.xray = xray_mod
        in_dim = sum(mod.feature_dim for mod in (ecg_mod, echo_mod, xray_mod))
        # Layer order fixes the state_dict indices (fusion_layers.0 and
        # fusion_layers.3 are the Linear layers); do not reorder.
        self.fusion_layers = nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.4),
        )
        self.disease_head = nn.Linear(128, 1)
        self.risk_head = nn.Linear(128, 3)

    def forward(self, ecg_x, echo_x, xray_x):
        """Return ``(disease_prob, risk_logits)`` for one image per modality."""
        # Sub-models act as frozen feature extractors: no gradients reach them.
        with torch.no_grad():
            parts = [
                self.ecg.extract_features(ecg_x),
                self.echo.extract_features(echo_x),
                self.xray.extract_features(xray_x),
            ]
        hidden = self.fusion_layers(torch.cat(parts, dim=1))
        return torch.sigmoid(self.disease_head(hidden)), self.risk_head(hidden)


# -----------------------
# Load checkpoints
# -----------------------
ckpt_path = "/content/drive/MyDrive/fusion_finetuned.pkl"
# torch.load unpickles the file: only point this at a trusted checkpoint.
ckpt = torch.load(ckpt_path, map_location=device)

ecg_model = HeartRiskResNet(pretrained=False).to(device)
echo_model = MobileNetRiskModel(pretrained=False).to(device)
xray_model = HeartRiskResNet(pretrained=False).to(device)
# The fusion model holds these three objects as its `ecg` / `echo` / `xray`
# submodules, so loading weights into either view updates the same tensors.
fusion_model = FusionModel(ecg_model, echo_model, xray_model).to(device)


def _load_partial(module, key, ignore_prefixes=()):
    """Non-strictly load ``ckpt[key]`` into *module* and report mismatches.

    ``strict=False`` tolerates partial checkpoints, but any parameter that is
    not matched keeps whatever value it had (random init for a fresh model).
    An absent state dict or unmatched keys are therefore printed instead of
    passing silently.

    Args:
        module: the ``nn.Module`` to load into.
        key: name of the state dict inside ``ckpt``.
        ignore_prefixes: missing-key prefixes that are loaded by another call
            and so are not worth reporting here.

    Returns:
        The result of ``load_state_dict`` (``missing_keys`` / ``unexpected_keys``).
    """
    state = ckpt.get(key)
    if state is None:
        print(f"⚠️ '{key}' not found in {ckpt_path}; "
              f"{type(module).__name__} keeps its current weights.")
        state = {}
    result = module.load_state_dict(state, strict=False)
    missing = [k for k in result.missing_keys if not k.startswith(ignore_prefixes)]
    if missing:
        print(f"⚠️ {key}: {len(missing)} missing keys, e.g. {missing[:5]}")
    if result.unexpected_keys:
        print(f"⚠️ {key}: {len(result.unexpected_keys)} unexpected keys, "
              f"e.g. {result.unexpected_keys[:5]}")
    return result


# Load order matters: the per-backbone state dicts are applied last, so they
# overwrite any backbone weights that fusion_state_dict also contains.
_load_partial(fusion_model, "fusion_state_dict", ignore_prefixes=("ecg.", "echo.", "xray."))
_load_partial(ecg_model, "ecg_state_dict")
_load_partial(echo_model, "echo_state_dict")
_load_partial(xray_model, "xray_state_dict")

# eval() recurses into the child backbones too, disabling all dropout layers.
fusion_model.eval()
print("✅ Fusion Model loaded successfully!")

# -----------------------
# Load Logistic Regression (Tabular)
# -----------------------
# The joblib bundle is a mapping with the fitted model, the scaler applied to
# its inputs, and the feature column names in the order the model expects.
# NOTE: joblib.load unpickles the file; only load a trusted artifact.
tabular_package = joblib.load("/content/drive/MyDrive/heart_model_final.pkl")
tabular_model, tabular_scaler, tabular_features = (
    tabular_package[part] for part in ("model", "scaler", "features")
)

print("✅ Tabular Model loaded successfully!")

# -----------------------
# Preprocessing
# -----------------------
# Per-channel RGB mean / std passed to Normalize (the usual ImageNet statistics).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # fixed 224x224 input for every modality
        transforms.ToTensor(),          # PIL image -> float tensor (C, H, W) in [0, 1]
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

def preprocess_image(image_bytes):
    """Decode uploaded image bytes into a normalized batch tensor on ``device``.

    Args:
        image_bytes: raw bytes of an uploaded file.

    Returns:
        A ``(1, 3, 224, 224)`` tensor on ``device`` (via ``transform``), or
        ``None`` when the bytes cannot be decoded as an image, so the caller
        can report ``images_status: "not_match"`` instead of failing.
    """
    try:
        with Image.open(io.BytesIO(image_bytes)) as img:
            # Image.open is lazy; convert() forces the pixel data to decode,
            # so truncated/corrupt files fail here, inside the try. The
            # converted image is an independent copy, safe after close.
            rgb = img.convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError (unknown format) is an OSError subclass;
        # OSError also covers truncated/corrupt data hit while decoding, and
        # DecompressionBombError rejects absurdly large uploads.
        return None
    return transform(rgb).unsqueeze(0).to(device)


# -----------------------
# Flask App
# -----------------------
app = Flask(__name__)

# Display names for the fusion risk head's classes, indexed by the argmax class id.
risk_labels = "Low Medium High".split()

def _run_fusion(files):
    """Run the image fusion model on the uploaded ``ecg`` / ``echo`` / ``xray`` files.

    Returns the fusion part of the JSON response. When any of the three
    uploads is missing or undecodable, every fusion value is ``None`` and
    ``images_status`` is ``"not_match"``; otherwise ``images_status`` is
    ``"valid"`` and ``fusion_prob_percent`` is included as well.
    """
    # All three uploads are decoded (in this order) even if an earlier one fails.
    ecg_img, echo_img, xray_img = (
        preprocess_image(files[name].read()) if name in files else None
        for name in ("ecg", "echo", "xray")
    )

    if ecg_img is None or echo_img is None or xray_img is None:
        return {
            "fusion_disease_prob": None,
            "fusion_risk_probs": None,
            "fusion_risk_level": None,
            "images_status": "not_match",
        }

    with torch.no_grad():
        disease_pred, risk_pred = fusion_model(ecg_img, echo_img, xray_img)
        risk_probs = torch.softmax(risk_pred, dim=1).cpu().numpy()[0]
        disease_prob = float(disease_pred.item())

    return {
        "fusion_disease_prob": disease_prob,
        "fusion_risk_probs": risk_probs.tolist(),
        "fusion_risk_level": risk_labels[int(risk_probs.argmax())],
        "images_status": "valid",
        # Display value; truncated to an int as in the existing response format.
        "fusion_prob_percent": int(disease_prob * 100),
    }


def _map_tabular_form(data):
    """Translate submitted form fields into the tabular model's feature dict.

    Missing fields default to 0. ``bloodPressure`` is expected as
    ``"systolic/diastolic"``; yes/no fields map ``"yes"`` (any case) to 1.

    Raises:
        ValueError: if a numeric field is not a valid number, or
            ``bloodPressure`` does not split into exactly two numbers.
    """
    mapped = {
        "Age": int(data.get("age", 0)),
        "Gender": 1 if data.get("gender", "").lower() == "male" else 0,
        "Height": float(data.get("height", 0)),
        "Weight": float(data.get("weight", 0)),
    }

    blood_pressure = data.get("bloodPressure", "")
    if "/" in blood_pressure:
        systolic, diastolic = blood_pressure.split("/")
        mapped["SystolicBP"] = float(systolic)
        mapped["DiastolicBP"] = float(diastolic)
    else:
        mapped["SystolicBP"] = 0
        mapped["DiastolicBP"] = 0

    mapped["CholesterolLevel"] = float(data.get("cholesterolLevel", 0))
    for column, field in (
        ("SmokingStatus", "smokingStatus"),
        ("AlcoholConsumption", "alcoholConsumption"),
        ("FastingBloodSugar", "fastingBloodSugar"),
        ("ChestPainType", "chestPainType"),
    ):
        mapped[column] = 1 if data.get(field, "").lower() == "yes" else 0
    return mapped


def _risk_bucket(disease_prob):
    """Map a disease probability in [0, 1] to the final risk label.

    The thresholds are applied to the raw probability (<= 0.30 Low,
    <= 0.55 Moderate, else High). Bucketing a truncated integer percentage
    would let values up to 0.3099... / 0.5599... fall into the lower band.
    """
    if disease_prob <= 0.30:
        return "Low Risk"
    if disease_prob <= 0.55:
        return "Moderate Risk"
    return "High Risk"


@app.route("/predict", methods=["POST"])
def predict():
    """Predict heart-disease risk from optional images plus tabular form data.

    Expects a multipart POST with optional ``ecg`` / ``echo`` / ``xray`` image
    files and form fields (``age``, ``gender``, ``height``, ``weight``,
    ``bloodPressure``, ``cholesterolLevel``, ``smokingStatus``,
    ``alcoholConsumption``, ``fastingBloodSugar``, ``chestPainType``).

    Returns:
        200 with the fusion results (see ``_run_fusion``), ``tabular_pred``,
        ``tabular_probs`` and ``final_decision``; 400 with ``{"error": ...}``
        for malformed input; 500 for unexpected server-side failures.
    """
    try:
        fusion_results = _run_fusion(request.files)

        df = pd.DataFrame([_map_tabular_form(request.form.to_dict())])
        df = df[tabular_features]  # select/order columns exactly as in training

        X_scaled = tabular_scaler.transform(df)
        y_pred = int(tabular_model.predict(X_scaled)[0])
        y_prob = tabular_model.predict_proba(X_scaled)[0].tolist()
        # NOTE(review): assumes predict_proba column 1 is the "disease" class
        # (i.e. classes_ == [0, 1]) — confirm against the trained model.
        final_decision = _risk_bucket(y_prob[1])
    except ValueError as e:
        # Bad client input: unparsable numbers, malformed bloodPressure, NaNs.
        return jsonify({"error": str(e)}), 400
    except Exception:
        # Anything else is a server-side fault: log the traceback, don't leak it.
        app.logger.exception("Prediction failed")
        return jsonify({"error": "Internal server error"}), 500

    return jsonify({
        "status": "success",
        **fusion_results,
        "tabular_pred": y_pred,
        "tabular_probs": y_prob,
        "final_decision": final_decision,
    })



# -----------------------
# Run with ngrok
# -----------------------
port = 5000

if __name__ == "__main__":
    # Expose the local Flask server through a public ngrok tunnel.
    public_url = ngrok.connect(port).public_url
    print("🚀 Public URL:", public_url)
    try:
        # Blocks until the server stops (e.g. the notebook cell is interrupted).
        app.run(host="0.0.0.0", port=port)
    finally:
        # Close the tunnel so re-running the cell does not pile up open tunnels.
        ngrok.disconnect(public_url)