In [4]:
%pip install python-multipart

Collecting python-multipart
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Downloading python_multipart-0.0.20-py3-none-any.whl (24 kB)
Installing collected packages: python-multipart
Successfully installed python-multipart-0.0.20
Note: you may need to restart the kernel to use updated packages.


In [2]:
# === FastAPI micro-API inside Jupyter ===
import io
from pathlib import Path

import torch
from torch import nn
from torchvision import transforms as T
import torchvision
from PIL import Image

from fastapi import FastAPI, HTTPException, UploadFile, File
from pydantic import BaseModel
import nest_asyncio
import uvicorn

# --- Setup ---
CKPT_DIR = Path("checkpoints_multiclass_strong")

# NOTE(review): torch.load unpickles the file, so only point this at a
# checkpoint from a trusted source.
ckpt = torch.load(CKPT_DIR / "best.pt", map_location="cpu")

# Metadata stored in the checkpoint next to the weights.
CLASSES = ckpt["classes"]
IMG_SIZE = ckpt["img_size"]
BACKBONE = ckpt["backbone"]

def build_model(backbone: str, n_classes: int):
    """Create an uninitialised torchvision classifier with an ``n_classes``-wide head.

    Args:
        backbone: One of "resnet18", "resnet50", "efficientnet_b0" or "vit_b_16".
        n_classes: Output size of the replacement final linear layer.

    Returns:
        The torchvision model, constructed with ``weights=None`` and with its
        final linear layer swapped for a new ``nn.Linear``.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """
    # Both ResNet variants keep their classifier in the ``fc`` attribute.
    if backbone in ("resnet18", "resnet50"):
        net = getattr(torchvision.models, backbone)(weights=None)
        net.fc = nn.Linear(net.fc.in_features, n_classes)
        return net

    if backbone == "efficientnet_b0":
        net = torchvision.models.efficientnet_b0(weights=None)
        old_head = net.classifier[1]
        net.classifier[1] = nn.Linear(old_head.in_features, n_classes)
        return net

    if backbone == "vit_b_16":
        net = torchvision.models.vit_b_16(weights=None)
        old_head = net.heads.head
        net.heads.head = nn.Linear(old_head.in_features, n_classes)
        return net

    raise ValueError(f"Unknown backbone: {backbone}")

# Rebuild the architecture named in the checkpoint, restore its weights and
# switch the network to evaluation mode.
MODEL = build_model(BACKBONE, len(CLASSES))
MODEL.load_state_dict(ckpt["model_state"])
MODEL.eval()

# Preprocessing applied to every uploaded image: resize to 1.2x the target
# size, centre-crop to IMG_SIZE, convert to a tensor, then normalise per channel.
TFORM = T.Compose(
    [
        T.Resize(int(IMG_SIZE * 1.2)),
        T.CenterCrop(IMG_SIZE),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# --- FastAPI app ---
app = FastAPI(title="Damage Multiclass API")

# Response schema of POST /predict.
class PredictResponse(BaseModel):
    pred: str  # predicted class label (an entry of CLASSES)
    conf: float  # softmax probability assigned to that label

@app.post("/predict", response_model=PredictResponse)
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top class with its probability.

    Args:
        file: Multipart upload (form field ``file``) holding the encoded image.

    Returns:
        A mapping with ``pred`` (the entry of ``CLASSES`` at the arg-max index)
        and ``conf`` (the softmax probability of that class).

    Raises:
        HTTPException: Status 400 when the upload is empty or cannot be decoded
            as an image, instead of surfacing as an internal server error.
    """
    img_bytes = await file.read()

    # Decoding untrusted bytes can fail (empty body, wrong file type, truncated
    # data); convert() forces the actual pixel load, so it belongs in the try.
    try:
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except (OSError, ValueError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from exc

    # Add the batch dimension expected by the model.
    x = TFORM(img).unsqueeze(0)

    # NOTE(review): the forward pass runs synchronously inside this coroutine,
    # so the event loop is busy for its duration.
    with torch.no_grad():
        logits = MODEL(x)
        probs = torch.softmax(logits, dim=1).squeeze(0)
        top = int(probs.argmax().item())
    return {"pred": CLASSES[top], "conf": float(probs[top])}

In [4]:
import asyncio
import uvicorn
import nest_asyncio

# Patch the notebook's already-running event loop so run_until_complete() can
# be called from inside a cell.
nest_asyncio.apply()

loop = asyncio.get_event_loop()

# Make the cell safe to re-run: a second server cannot bind the port the first
# one still holds, so ask a still-running server left over from a previous
# execution of this cell to exit and wait until it has shut down.
_prev_server = globals().get("server")
_prev_task = globals().get("task")
if (
    isinstance(_prev_server, uvicorn.Server)
    and asyncio.isfuture(_prev_task)
    and not _prev_task.done()
):
    _prev_server.should_exit = True
    loop.run_until_complete(_prev_task)

config = uvicorn.Config(app, host="127.0.0.1", port=8000, log_level="info")
server = uvicorn.Server(config)

# Schedule the server on the notebook's loop instead of blocking the cell;
# it serves requests whenever the kernel's loop is free to run.
task = loop.create_task(server.serve())

print("🚀 FastAPI is running at http://127.0.0.1:8000")

🚀 FastAPI is running at http://127.0.0.1:8000


INFO:     Started server process [36409]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)


In [None]:
import asyncio

import nest_asyncio
import requests

test_img = "disaster-ai/data/xbd/tier1/images/socal-fire_00001323_post_disaster.png"


def _post_image(path, url="http://127.0.0.1:8000/predict", timeout=30):
    """Upload the file at ``path`` as multipart field ``file`` and return the response."""
    with open(path, "rb") as f:
        return requests.post(url, files={"file": f}, timeout=timeout)


# The API server runs as a task on this kernel's event loop (see the cell that
# calls server.serve()). A blocking requests.post() here would stall that loop,
# so the server could never answer and the call would hang. Send the request
# from a worker thread and keep the loop running until the response arrives.
nest_asyncio.apply()  # idempotent; allows run_until_complete on the running loop
_loop = asyncio.get_event_loop()
r = _loop.run_until_complete(_loop.run_in_executor(None, _post_image, test_img))

# Fail loudly on 4xx/5xx instead of trying to parse an error body as JSON.
r.raise_for_status()
print(r.json())