<a href="https://colab.research.google.com/github/Sidhtang/bert-project/blob/main/fast_api_for_the_mode_l.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
import torch
from transformers import AutoTokenizer
import os

# Check if MPS is available and set the device accordingly
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Load the model
model_name = "/content/distilbert_general_router_model.pth"

# Check if the model file exists and is not empty
if not os.path.exists(model_name) or os.stat(model_name).st_size == 0:
    print(f"Error: Model file not found or empty: {model_name}")
else:
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    try:
        model = torch.load(model_name, map_location=device)
    except RuntimeError as e:
        print(f"Error loading model: {e}")
    else:
        model.eval()  # Put the model in evaluation mode
        model.to(device)

        # Function for inference
        def classify_query(query):
            inputs = tokenizer(query, return_tensors='pt', max_length=256, truncation=True, padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}  # Move inputs to the correct device
            with torch.no_grad():
                outputs = model(**inputs)

            logits = outputs.logits
            prediction = torch.argmax(logits, dim=1).item()
            return "Personalization" if prediction == 1 else "Customer_support"

        # Test the model with queries
        queries = [
            "Hey there, you guys got some nice hoodies for me?"
        ]
        for query in queries:
            result = classify_query(query)
            print(f"Query: {query}")
            print(f"Prediction: {result}")
            print()

Using device: cpu


  model = torch.load(model_name, map_location=device)


Query: Hey there, you guys got some nice hoodies for me?
Prediction: Personalization



In [None]:
!pip install fastapi

Collecting fastapi
  Downloading fastapi-0.115.0-py3-none-any.whl.metadata (27 kB)
Collecting starlette<0.39.0,>=0.37.2 (from fastapi)
  Downloading starlette-0.38.6-py3-none-any.whl.metadata (6.0 kB)
Downloading fastapi-0.115.0-py3-none-any.whl (94 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m94.6/94.6 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading starlette-0.38.6-py3-none-any.whl (71 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.5/71.5 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: starlette, fastapi
Successfully installed fastapi-0.115.0 starlette-0.38.6


In [None]:
!pip install uvicorn

Collecting uvicorn
  Downloading uvicorn-0.31.0-py3-none-any.whl.metadata (6.6 kB)
Collecting h11>=0.8 (from uvicorn)
  Downloading h11-0.14.0-py3-none-any.whl.metadata (8.2 kB)
Downloading uvicorn-0.31.0-py3-none-any.whl (63 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading h11-0.14.0-py3-none-any.whl (58 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/58.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.3/58.3 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: h11, uvicorn
Successfully installed h11-0.14.0 uvicorn-0.31.0


In [None]:
import torch
from transformers import AutoTokenizer
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import asyncio
import uvicorn

# Initialize FastAPI app
app = FastAPI(title="Text Classification API")

# Check if MPS is available and set the device accordingly
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Load the model
model_name = "/content/distilbert_general_router_model.pth"

# Check if the model file exists and is not empty
if not os.path.exists(model_name) or os.stat(model_name).st_size == 0:
    raise FileNotFoundError(f"Error: Model file not found or empty: {model_name}")

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

try:
    model = torch.load(model_name, map_location=device)
    model.eval()  # Put the model in evaluation mode
    model.to(device)
except RuntimeError as e:
    raise RuntimeError(f"Error loading model: {e}")

class Query(BaseModel):
    text: str

def classify_query(query: str):
    inputs = tokenizer(query, return_tensors='pt', max_length=256, truncation=True, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}  # Move inputs to the correct device
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    prediction = torch.argmax(logits, dim=1).item()
    return "Personalization" if prediction == 1 else "Customer_support"

@app.post("/classify")
async def classify_text(query: Query):
    try:
        result = classify_query(query.text)
        return {"query": query.text, "classification": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
async def root():
    return {"message": "Welcome to the Text Classification API"}

if __name__ == "__main__":
    # Use uvicorn.Config to create a server instance
    config = uvicorn.Config(app, host="0.0.0.0", port=8000)
    server = uvicorn.Server(config)

    # Use asyncio to run the server in the current event loop
    loop = asyncio.get_event_loop()
    loop.run_until_complete(server.serve())

  model = torch.load(model_name, map_location=device)


Using device: cpu


RuntimeError: This event loop is already running