In [3]:
from fastapi import FastAPI, UploadFile, File

In [21]:
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [22]:
from PIL import Image
from torchvision import models, transforms

In [23]:
app = FastAPI()

In [24]:
model = models.resnet50(pretrained=True)




In [25]:
model.fc = torch.nn.Linear(2048,2)


In [26]:
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [27]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


In [31]:
def generate_heatmap(image_path):

    image = cv2.imread(image_path)
    heatmap = cv2.applyColorMap(image, cv2.COLORMAP_JET)
    blended = cv2.addWeighted(image, 0.5, heatmap, 0.5, 0)

    heatmap_path = f"static/heatmap.jpg"
    cv2.imwrite(heatmap_path, blended)

    return f"https://your-api-url.com/{heatmap_path}"


In [32]:
def predict_xray(image_path):

    image = Image.open(image_path).convert("RGB")
    image_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():
        output = model(image_tensor)
        prediction = torch.argmax(output, dim=1).item()

    diagnosis = "Disease Detected" if prediction == 1 else "Normal"
    heatmap_path = generate_heatmap(image_path)

    return diagnosis, heatmap_path


In [35]:
@app.post("/analyze-xray/")
async def analyze_xray(file: UploadFile = File(...)):

    image_path = f"temp/{file.filename}"
    with open(image_path, "wb") as buffer:
        buffer.write(await file.read())

    diagnosis, heatmap_url = predict_xray(image_path)

    return {"diagnosis": diagnosis, "heatmap": heatmap_url}