# In [None]:
import io
import sys

import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from keras.models import load_model as keras_load_model
# NOTE(review): keras.src.* is a private module path and may not exist in other
# Keras releases — confirm against the pinned Keras version.
from keras.src.engine.sequential import Sequential
from PIL import Image

# ASGI application object: the `/predict` route below is registered on it via
# the `@app.post` decorator, and the `__main__` block serves it with uvicorn.
app = FastAPI()

def load_model(path: str) -> Sequential:
    """Deserialize and return the saved Keras model found at ``path``.

    This is a thin pass-through to ``keras.models.load_model``; ``path`` is
    forwarded unchanged and whatever Keras returns is handed back.

    NOTE(review): ``keras.models.load_model`` can also return non-Sequential
    (e.g. functional) models, so the ``Sequential`` annotation may be narrower
    than reality — confirm the saved model's type.
    """
    restored = keras_load_model(path)
    return restored

def predict_digit(model: Sequential, data_point: list) -> str:
    """Classify a single flattened image and return the winning class index.

    Args:
        model: Object exposing ``predict``; it receives one batch of shape
            ``(1, 28, 28, 1)`` (one sample, one channel).
        data_point: 784 pixel values (or anything NumPy can reshape to
            28x28x1), e.g. the output of ``format_image``.

    Returns:
        The index of the largest value in the model's (flattened) output,
        rendered as a string such as ``"7"``.
    """
    batch = np.array(data_point).reshape(1, 28, 28, 1)
    return str(np.argmax(model.predict(batch)))

def format_image(image_bytes: bytes) -> list:
    """Decode raw image bytes into 784 grayscale intensities in ``[0, 1]``.

    Steps: decode with Pillow, convert to 8-bit grayscale (``"L"`` mode),
    resize to 28x28 (aspect ratio is not preserved), scale by 1/255, and
    flatten row-major into a plain Python list of floats.

    Args:
        image_bytes: Encoded image file contents in any format Pillow can open.

    Returns:
        A list of 784 floats, row-major over the 28x28 image.
    """
    buffer = io.BytesIO(image_bytes)
    grayscale = Image.open(buffer).convert('L')
    pixels = np.asarray(grayscale.resize((28, 28)), dtype='float32') / 255
    return pixels.ravel().tolist()

# API endpoint to predict the digit
@app.post('/predict')
async def predict(file: UploadFile = File(...)):
    """Classify the handwritten digit in an uploaded image.

    The upload is decoded with ``format_image`` and scored by the module-level
    ``model`` through ``predict_digit``.

    Returns:
        JSONResponse whose body is ``{"digit": "<predicted class index>"}``.

    Raises:
        HTTPException: 503 if no ``model`` has been loaded in this process
            (e.g. the app was served without running the ``__main__`` block);
            400 if the upload is empty or cannot be decoded as an image.
    """
    # ``model`` is bound by the ``__main__`` block; look it up defensively so a
    # missing model produces a clear 503 rather than a NameError / 500.
    loaded_model = globals().get("model")
    if loaded_model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded.")

    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    try:
        data_point = format_image(image_bytes)
    except (OSError, ValueError) as err:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # undecodable data and OSError for truncated files; these are client
        # errors, not server faults.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # Keras inference is blocking CPU work; run it in the threadpool so the
    # event loop keeps serving other requests meanwhile.
    digit = await run_in_threadpool(predict_digit, loaded_model, data_point)
    return JSONResponse(content={"digit": digit})

if __name__ == "__main__":
    # uvicorn is only needed when launching the server from this script, and
    # must be imported before ``uvicorn.run`` is called below.
    import uvicorn

    # Fail with a usage message instead of an IndexError traceback.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} MODEL_PATH")

    model_path = sys.argv[1]  # Path to the model passed as command line argument
    # Bound at module level so the ``/predict`` handler can find it.
    model = load_model(model_path)
    uvicorn.run(app, host="0.0.0.0", port=8000)
