# In [None]:
# This notebook serves a two-class ResNet-18 image classifier through a FastAPI POST /predict endpoint (run with uvicorn on port 8000).

import nest_asyncio
# Patch asyncio to permit nested event loops so that uvicorn.run() at the end
# of this cell can start its own loop; presumably needed because the notebook
# kernel already has a running loop. Must execute before uvicorn.run().
nest_asyncio.apply()

from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import uvicorn
import io, torch
from PIL import Image
from torchvision import transforms, models

app = FastAPI()

# Checkpoint holding the trained weights as a state_dict (see load below).
model_path = "../models/saved_model.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-18 backbone with a 2-way classification head. All weights come from
# the checkpoint, so no pretrained weights are requested (weights=None is the
# replacement for the deprecated pretrained=False).
model = models.resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 2)
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs, and avoids executing arbitrary pickled code.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# load_state_dict copies values into the module's existing (CPU) parameters,
# so the module itself must be moved; otherwise inputs sent to `device` in
# /predict would hit a device-mismatch error on CUDA hosts.
model.to(device)
model.eval()

# Preprocessing applied to every uploaded image: force a fixed 224x224 size,
# then convert the PIL image to a float tensor of shape (C, H, W).
# NOTE(review): there is no Normalize step; this must match whatever
# preprocessing the checkpoint was trained with — confirm before changing.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the module-level ResNet-18 model.

    Args:
        file: Multipart upload containing an image in any format PIL can decode.

    Returns:
        JSONResponse ``{"prediction": [[p0, p1]]}`` holding the softmax
        probabilities over the two classes for the single uploaded image.
        If the upload cannot be decoded as an image, a 400 JSONResponse with
        an ``error`` message is returned instead.
    """
    data = await file.read()
    try:
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # The upload is untrusted: PIL raises UnidentifiedImageError (an
        # OSError) for non-images, OSError for truncated data, and
        # DecompressionBombError for oversized images. Report a client error
        # rather than letting it surface as a 500.
        return JSONResponse(
            {"error": "Uploaded file is not a valid image"}, status_code=400
        )
    # Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
    x = transform(img).unsqueeze(0).to(device)
    # Inference only: skip autograd graph construction to save time and memory.
    with torch.no_grad():
        out = model(x).softmax(1)
    return JSONResponse({"prediction": out.tolist()})

# Start the server only when executed directly (in a notebook cell __name__ is
# "__main__", so the cell behaves as before). This keeps an import of an
# exported .py version of this file from blocking on a running server.
# NOTE(review): host="0.0.0.0" listens on all interfaces; restrict to
# 127.0.0.1 if external access is not intended.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
