In [3]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset
from fastapi import FastAPI
from pydantic import BaseModel
import nest_asyncio
import uvicorn

# Directory holding the trained intent-classification model and its tokenizer.
model_path = "./intent_classifier"

# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and the sequence-classification model from the same
# directory; the model is moved onto the chosen device straight away.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)

# Human-readable intent names are not loaded, so /predict returns the raw
# class index rather than a label string. One way to recover the names is
#   load_dataset("clinc_oos", "plus")["train"].features["intent"].names
# NOTE(review): model.config.id2label may already carry these names — confirm
# before adding a dataset download just for the mapping.

# FastAPI application; the /predict route is registered on it below.
app = FastAPI()

# Request body for POST /predict: a JSON object with a single string field
# "text" holding the utterance to classify (validated by pydantic).
class TextRequest(BaseModel):
    text: str

# Inference helper used by the /predict endpoint.
def predict_intent(text):
    """Classify *text* and return the index of the highest-scoring intent.

    The text is tokenized into PyTorch tensors (``truncation=True`` clips
    inputs longer than the tokenizer's maximum length), every tensor is moved
    to the module-level ``device``, and the model is run with gradient
    tracking disabled.

    Args:
        text: The utterance to classify.

    Returns:
        int: Position of the largest logit along the class dimension.
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        scores = model(**batch).logits
    return scores.argmax(dim=1).item()

# POST /predict: classify the request's text and respond with
# {"intent": <class index>}.
@app.post("/predict")
def predict(request: TextRequest):
    return {"intent": predict_intent(request.text)}

# Launch the server only when executed directly (a notebook kernel also runs
# as "__main__"), never on import: uvicorn.run() blocks until the server is
# stopped with CTRL+C, so an unguarded call would hang any importer such as
# `uvicorn module:app` or a test suite.
if __name__ == "__main__":
    # nest_asyncio patches asyncio so uvicorn can start its event loop inside
    # an environment whose loop is already running (e.g. Jupyter).
    nest_asyncio.apply()
    # NOTE(review): host "0.0.0.0" listens on every network interface, making
    # this unauthenticated endpoint reachable from other machines; switch to
    # "127.0.0.1" if only local access is intended.
    uvicorn.run(app, host="0.0.0.0", port=8000)


INFO:     Started server process [22368]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


INFO:     127.0.0.1:57267 - "GET / HTTP/1.1" 404 Not Found
INFO:     127.0.0.1:57267 - "GET /favicon.ico HTTP/1.1" 404 Not Found
INFO:     127.0.0.1:57269 - "GET / HTTP/1.1" 404 Not Found
INFO:     127.0.0.1:57269 - "GET /favicon.ico HTTP/1.1" 404 Not Found
INFO:     127.0.0.1:57269 - "GET /docs/ HTTP/1.1" 307 Temporary Redirect
INFO:     127.0.0.1:57269 - "GET /docs HTTP/1.1" 200 OK
INFO:     127.0.0.1:57269 - "GET /openapi.json HTTP/1.1" 200 OK
INFO:     127.0.0.1:57274 - "POST /predict HTTP/1.1" 200 OK
INFO:     127.0.0.1:57274 - "POST /predict HTTP/1.1" 200 OK
INFO:     127.0.0.1:57275 - "POST /predict HTTP/1.1" 200 OK


INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [22368]
