In [7]:
import base64
import io

import cv2
import nest_asyncio
import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from PIL import Image

In [8]:
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

# Backbone shared by the prediction endpoint and the Grad-CAM explainer.
# TODO: replace this ImageNet ResNet-50 stand-in with the real Gastric Sentinel
# model loading logic.
# `nn.Module.eval()` returns the module itself, so inference mode is chained on.
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).eval()

def generate_explanations(img_tensor, img_np, clinical_data, target_layers=None):
    """Build the visual (Grad-CAM) and clinical (SHAP-style) explanations for one sample.

    Args:
        img_tensor: Model input batch, passed unchanged to Grad-CAM as
            ``input_tensor``. The ``predict`` endpoint builds it as a
            ``(1, 3, H, W)`` float tensor.
        img_np: The same image as an ``(H, W, 3)`` float array in ``[0, 1]``;
            the Grad-CAM heat-map is overlaid on it.
        clinical_data: Clinical inputs for the sample. Currently unused: the
            SHAP values below are a hard-coded placeholder.
        target_layers: Layers for Grad-CAM to attribute on. Defaults to
            ``[model.layer4[-1]]`` (ResNet layout). Pass this explicitly when
            ``model`` is swapped for an architecture without ``layer4``.

    Returns:
        ``(visualization, shap_data)``: the overlay image produced by
        ``show_cam_on_image`` (uint8 RGB in recent pytorch-grad-cam releases —
        confirm for the installed version) and a ``{feature_name: importance}``
        dict.
    """
    if target_layers is None:
        target_layers = [model.layer4[-1]]

    # 1. Grad-CAM (visual). GradCAM attaches hooks to the target layers of the
    # shared global `model`; release them explicitly so each request leaves the
    # model clean instead of depending on when `cam` is garbage-collected.
    # (`with GradCAM(...)` is avoided: its __exit__ swallows IndexError, which
    # would leave `grayscale_cam` unbound.)
    cam = GradCAM(model=model, target_layers=target_layers)
    try:
        # Batch of one: keep the single (H, W) saliency map.
        grayscale_cam = cam(input_tensor=img_tensor)[0, :]
    finally:
        cam.activations_and_grads.release()
    visualization = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)

    # 2. SHAP (feature significance).
    # TODO: placeholder values. In production, run shap.Explainer on the
    # clinical branch using `clinical_data`.
    shap_data = {"Age": 0.45, "Genomic_Marker_A": 0.30, "BMI": 0.15}

    return visualization, shap_data

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to C:\Users\KIIT0001/.cache\torch\hub\v0.10.0.zip




Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to C:\Users\KIIT0001/.cache\torch\hub\checkpoints\resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:21<00:00, 4.73MB/s]


In [9]:
app = FastAPI()

# Input geometry and per-channel normalization expected by the torchvision
# ImageNet ResNet-50 loaded above.
# TODO: update these if `model` is replaced by a network trained with
# different preprocessing.
MODEL_INPUT_SIZE = (224, 224)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _preprocess_upload(contents):
    """Decode uploaded image bytes into ``(img_np, img_tensor)``.

    ``img_np`` is an ``(H, W, 3)`` float32 RGB array in ``[0, 1]``, used as the
    Grad-CAM overlay background. ``img_tensor`` holds the same pixels as a
    mean/std-normalized ``(1, 3, H, W)`` float tensor for the model.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        image = Image.open(io.BytesIO(contents)).convert('RGB').resize(MODEL_INPUT_SIZE)
    except OSError as exc:  # PIL's UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image.") from exc

    img_np = np.array(image).astype(np.float32) / 255.0
    mean = torch.tensor(IMAGENET_MEAN, dtype=torch.float32).view(1, 3, 1, 1)
    std = torch.tensor(IMAGENET_STD, dtype=torch.float32).view(1, 3, 1, 1)
    # The pretrained backbone expects normalized input. The raw [0, 1] array is
    # kept separately because show_cam_on_image requires values in [0, 1].
    img_tensor = (torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0) - mean) / std
    return img_np, img_tensor


def _encode_jpeg_base64(rgb_image):
    """JPEG-encode an RGB image array and return it as a base64 ASCII string."""
    if rgb_image.dtype != np.uint8:
        # Float images are assumed to be in [0, 1]; scale and clamp to bytes.
        rgb_image = np.clip(rgb_image * 255.0, 0, 255).astype(np.uint8)
    # OpenCV encodes BGR channel order.
    ok, buffer = cv2.imencode('.jpg', cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
    if not ok:
        raise RuntimeError("cv2.imencode failed to encode the heatmap as JPEG")
    return base64.b64encode(buffer).decode('utf-8')


@app.post("/predict")
async def predict(file: UploadFile = File(...), age: str = Form(...), markers: str = Form(...)):
    """Run Grad-CAM + clinical explanations on an uploaded image.

    Returns a JSON object with a ``prediction`` label, the Grad-CAM overlay as
    a base64 JPEG under ``heatmap``, and the clinical ``shap_values`` dict.
    Responds 400 if the upload is not a decodable image.
    """
    contents = await file.read()
    img_np, img_tensor = _preprocess_upload(contents)

    # NOTE(review): model inference runs synchronously inside this async handler
    # and blocks the event loop for its duration.
    heatmap, clinical_shap = generate_explanations(img_tensor, img_np, [age, markers])

    # show_cam_on_image yields uint8 pixels already in [0, 255]. Scaling them
    # by 255 again would overflow uint8 and scramble the colours.
    heatmap_base64 = _encode_jpeg_base64(heatmap)

    return {
        # TODO(review): placeholder decision rule driven by the hard-coded SHAP
        # value, not by a model output. Replace with the real classifier score.
        "prediction": "High Risk" if clinical_shap["Age"] > 0.4 else "Low Risk",
        "heatmap": heatmap_base64,
        "shap_values": clinical_shap,
    }

In [None]:
from fastapi.middleware.cors import CORSMiddleware

# Front-end dev server on port 5500, listed under both the loopback IP and the
# hostname: browsers treat these as two distinct origins.
origins = ["http://127.0.0.1:5500", "http://localhost:5500"]

# Let those origins make credentialed cross-origin calls with any method/header.
# NOTE(review): this cell presumably has to run before the server is started.
# Recent Starlette releases refuse add_middleware on a running app, so verify
# this for the installed version.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
    allow_origins=origins,
)