In [1]:
!pip install python-multipart

Collecting python-multipart
  Using cached python_multipart-0.0.12-py3-none-any.whl.metadata (1.9 kB)
Using cached python_multipart-0.0.12-py3-none-any.whl (23 kB)
Installing collected packages: python-multipart
Successfully installed python-multipart-0.0.12


In [None]:
import io

import nest_asyncio
import requests
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image

# Application object: the /predict/ route is registered on it and uvicorn serves it below.
app = FastAPI()

# Define your model architecture (use the same architecture as during training)
class LogisticRegression(nn.Module):
    """Fully connected classifier over flattened inputs.

    The network is a stack of ``Linear -> ReLU -> Dropout`` stages (one per
    entry in ``hidden_layers``) followed by a final ``Linear`` output layer.

    Args:
        input_dim: Number of features per sample after flattening.
        output_dim: Number of raw output scores produced per sample.
        hidden_layers: Sequence of hidden-layer widths. May be empty, in which
            case the network is a single ``Linear`` layer.
        dropout_rate: Dropout probability applied after each hidden activation.
    """

    def __init__(self, input_dim, output_dim, hidden_layers, dropout_rate):
        super().__init__()
        # Prepending the input width lets one loop build every hidden stage and
        # keeps the Sequential indices (and so the state_dict keys) unchanged:
        # 0=Linear, 1=ReLU, 2=Dropout, 3=Linear, ...
        widths = [input_dim, *hidden_layers]
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))

        # widths[-1] is input_dim when there are no hidden layers.
        layers.append(nn.Linear(widths[-1], output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each sample to one dimension and run it through the network.

        Args:
            x: Tensor whose first dimension is the batch dimension.

        Returns:
            Tensor with one row of ``output_dim`` scores per sample.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(x.size(0), -1)
        return self.model(x)

# Rebuild the network with the same layer sizes as the saved checkpoint;
# load_state_dict raises if the parameter names or shapes do not match.
model = LogisticRegression(input_dim=28*28, output_dim=10, hidden_layers=[512, 256, 128], dropout_rate=0.3)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code when it is loaded.
model.load_state_dict(
    torch.load("model.pth", map_location=torch.device('cpu'), weights_only=True)
)
model.eval()  # Evaluation mode: Dropout layers become no-ops during inference

# The preprocessing steps are the same for every request, so the pipeline is
# built once at import time rather than on each call.
_PREPROCESS = transforms.Compose([
    transforms.Resize((28, 28)),  # Resize the image to match MNIST input size
    transforms.ToTensor(),
    # NOTE(review): presumably the mean/std used when the model was trained — confirm.
    transforms.Normalize((0.1307,), (0.3081,)),
])


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted class index.

    Args:
        file: Uploaded image file in any format PIL can decode.

    Returns:
        A dict ``{"prediction": <int>}`` holding the index of the highest
        model output.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image.
    """
    payload = await file.read()
    try:
        # Convert to single-channel grayscale before preprocessing.
        image = Image.open(io.BytesIO(payload)).convert('L')
    except OSError as exc:
        # PIL reports unreadable, truncated or empty image data as OSError
        # subclasses; report that as a client error instead of a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    image_tensor = _PREPROCESS(image).unsqueeze(0)  # Add batch dimension

    # Perform model inference without tracking gradients.
    with torch.no_grad():
        output = model(image_tensor)
        prediction = torch.argmax(output, dim=1).item()  # Get the predicted class

    return {"prediction": prediction}

# Patch asyncio so uvicorn can start its event loop from inside the notebook
# kernel, which is already running one. Must happen before uvicorn.run().
nest_asyncio.apply()

if __name__ == "__main__":
    # Blocks this cell while the server is running. host="0.0.0.0" listens on
    # all network interfaces, not just localhost.
    uvicorn.run(app, host="0.0.0.0", port=8000)



INFO:     Started server process [473]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
