In [28]:
%%writefile bones_api.py
# bones_api.py
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont
import io, base64

app = FastAPI(title="Bone Fracture Detection API")

# Permit cross-origin requests from any origin, with any HTTP method and header.
# NOTE(review): a wildcard origin list is convenient for development; consider
# restricting allow_origins before exposing the service publicly.
_CORS_OPTIONS = dict(allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# ---- Load Model ----
def load_model(weights_path="model.pt", num_classes=7):
    """Build a Faster R-CNN (ResNet-50 FPN) detector and load trained weights.

    Args:
        weights_path: path to a state_dict file readable by ``torch.load``.
            Defaults to ``"model.pt"`` in the current working directory.
        num_classes: output size of the replacement box-predictor head. It must
            match the head stored in the checkpoint, otherwise
            ``load_state_dict`` fails on the shape mismatch.

    Returns:
        The detector in eval mode, with all tensors mapped to CPU.

    NOTE(review): ``torch.load`` may unpickle arbitrary objects; only point
    ``weights_path`` at trusted checkpoint files.
    """
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
    # Replace the stock box predictor so its class count matches the checkpoint.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    model.eval()
    return model


model = load_model()

# ---- Class Names (same as your training dataset) ----
# Indexed directly by the integer label the model emits.
# NOTE(review): torchvision detection models conventionally reserve label 0 for
# background; if training followed that convention, index 0 here is never
# produced and names may be off by one — confirm against the training label map.
CLASS_NAMES = [
    'elbow positive', 'fingers positive', 'forearm fracture',
    'humerus fracture', 'humerus', 'shoulder fracture', 'wrist positive',
]

# ---- Assign colors per class ----
# Box/caption colour for each class, paired positionally with CLASS_NAMES so the
# class strings are written only once.
CLASS_COLORS = dict(zip(
    CLASS_NAMES,
    ["red", "orange", "green", "blue", "purple", "yellow", "cyan"],
))

# ---- Prediction Endpoint ----
def _load_label_font(size):
    """Return PIL's default font at ``size`` pixels.

    ``ImageFont.load_default`` accepts a size argument only on Pillow >= 10.1;
    older Pillow raises TypeError. In that case fall back to the fixed-size
    default font instead of failing the whole request.
    """
    try:
        return ImageFont.load_default(size=size)
    except TypeError:
        return ImageFont.load_default()


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Detect fractures in an uploaded image and return an annotated copy.

    The upload is decoded with PIL and converted to RGB, then passed through the
    module-level Faster R-CNN ``model``. Every detection scoring above 0.5 is
    drawn onto the image as a coloured box with a ``"<class>: <score>"`` caption.

    Returns:
        On success (HTTP 200): ``{"detections": <int>, "image_base64": <str>}``.
        ``image_base64`` is the annotated image as base64-encoded JPEG.
        On failure: ``{"error": <message>}``, with HTTP 400 when the upload
        cannot be decoded as an image and HTTP 500 for any other error.

    NOTE(review): the forward pass runs synchronously inside this ``async``
    handler, so it blocks the event loop while inference is in progress.
    """
    threshold = 0.5
    font_size = 30

    try:
        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except OSError as e:
            # PIL raises UnidentifiedImageError (an OSError subclass) for empty or
            # non-image uploads, and OSError for truncated files: a client error.
            return JSONResponse(status_code=400, content={"error": str(e)})

        img_tensor = transforms.ToTensor()(image).unsqueeze(0)

        with torch.no_grad():
            output = model(img_tensor)[0]

        # Compute the confidence mask once, and hand PIL plain Python numbers
        # rather than 0-d tensors.
        keep = output['scores'] > threshold
        boxes = output['boxes'][keep].tolist()
        scores = output['scores'][keep].tolist()
        labels = output['labels'][keep].tolist()

        draw = ImageDraw.Draw(image)
        font = _load_label_font(font_size)

        for (x1, y1, x2, y2), label_idx, score in zip(boxes, labels, scores):
            if 0 <= label_idx < len(CLASS_NAMES):
                label_name = CLASS_NAMES[label_idx]
            else:
                label_name = f"cls_{label_idx}"
            color = CLASS_COLORS.get(label_name, "red")

            # Draw bounding box, with its caption just inside the top-left corner.
            draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3)
            draw.text((x1 + 5, y1 + 5), f"{label_name}: {score:.2f}", fill=color, font=font)

        # Convert image to base64 for React Native
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {
            "detections": len(boxes),
            "image_base64": img_str,
        }

    except Exception as e:
        # Same body shape as before, but signal the failure through the status code.
        return JSONResponse(status_code=500, content={"error": str(e)})


if __name__ == "__main__":
    # Launch a server when the file is run directly (`python bones_api.py`).
    # uvicorn is imported here because only this path needs it.
    import uvicorn

    # Pass the already-built app object. An import string ("bones_api:app") would
    # make uvicorn import this file a second time under the name `bones_api`,
    # re-running load_model() at import.
    uvicorn.run(app, host="0.0.0.0", port=8000)


Overwriting bones_api.py


In [2]:
%%writefile bones.py
import streamlit as st
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import cv2
import numpy as np
from PIL import Image
import tempfile

# ---- Setup ----
# Page header. The title emoji is U+1FA7B (X-ray).
st.title("🩻 Bone Fracture Detection (Faster R-CNN)")
st.write("Upload an X-ray image to detect fractures using the trained model.")

# ---- Load model ----
@st.cache_resource
def load_model():
    """Build the Faster R-CNN detector, load ``model.pt`` and return ``(model, class_names)``.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded model
    across script reruns instead of rebuilding it each time.
    """
    class_names = [
        'elbow positive', 'fingers positive', 'forearm fracture',
        'humerus fracture', 'humerus', 'shoulder fracture', 'wrist positive',
    ]
    num_classes = 7  # same as training

    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
    # Swap in a box predictor whose output size matches the trained checkpoint.
    head_in = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in, num_classes)

    state = torch.load("model.pt", map_location="cpu")
    detector.load_state_dict(state)
    detector.eval()
    return detector, class_names


model, classes = load_model()

# ---- File upload ----
uploaded_file = st.file_uploader("Upload an image (jpg/png)", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Decode straight from the in-memory, file-like upload; no temporary file is
    # needed, so nothing has to be flushed, closed or cleaned up afterwards.
    image = Image.open(uploaded_file).convert("RGB")
    img_np = np.array(image)  # RGB channel order, as produced by PIL
    img_tensor = torchvision.transforms.functional.to_tensor(img_np).unsqueeze(0)

    # Inference
    # NOTE(review): Streamlit reruns the whole script on every widget change, so
    # moving the threshold slider repeats inference; consider caching per upload.
    with torch.no_grad():
        pred = model(img_tensor)[0]

    boxes = pred['boxes'].cpu().numpy()
    scores = pred['scores'].cpu().numpy()
    labels = pred['labels'].cpu().numpy()

    threshold = st.slider("Confidence threshold", 0.0, 1.0, 0.5, 0.05)
    img_draw = img_np.copy()

    for box, score, label_idx in zip(boxes, scores, labels):
        if score <= threshold:
            continue
        x1, y1, x2, y2 = (int(v) for v in box)
        # Guard against labels outside the known class list instead of raising IndexError.
        label = classes[label_idx] if 0 <= label_idx < len(classes) else f"cls_{label_idx}"
        # img_draw is RGB, so (0, 255, 0) is green and (255, 0, 0) is red.
        cv2.rectangle(img_draw, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Keep the caption on-screen when a box touches the top edge of the image.
        text_y = max(y1 - 10, 10)
        cv2.putText(img_draw, f"{label}: {score:.2f}", (x1, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)

    # NOTE(review): newer Streamlit releases deprecate use_column_width in favour
    # of use_container_width; confirm against the installed version.
    st.image(img_draw, caption="Detection Results", use_column_width=True)


Writing bones.py


In [4]:
# Start the Streamlit UI defined in bones.py. This shell command keeps running,
# and blocks the kernel, until the server is stopped (e.g. by an interrupt).
!streamlit run bones.py

^C
