In [2]:
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi import UploadFile, File, Form, HTTPException
import transformers
from typing import Annotated
from pydantic import BaseModel
import torch
import numpy as np
import uvicorn
import nest_asyncio
import mlflow
import time
import psutil

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# FastAPI application instance; the route decorators in this file register on it.
app = FastAPI()

# Permissive CORS so that browser clients served from any origin can call the API.
# allow_credentials is a boolean flag (not a list of patterns like the other
# options), so it is passed as True rather than ['*'].
# NOTE(review): wildcard origins combined with credentials is very permissive;
# consider listing explicit origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)

def logging_run(text):
    """Print ``text`` to stdout behind an ``INFO:`` prefix padded with five spaces.

    Args:
        text: Message (or any formattable object) to print.
    """
    prefix = "INFO:     "
    print(f"{prefix}{text}")

class InputJson(BaseModel):
    """Request body accepted by the ``/predict`` endpoint.

    Attributes:
        texts: List of items to classify. The element type is not constrained
            by this model; presumably the elements are strings — verify
            against the clients that call the endpoint.
    """
    texts: list

# Number of decimal places kept when rounding reported probabilities and scores.
round_number = 4

# Hugging Face pipeline built once at module level and reused by every prediction call.
# return_all_scores=True requests a score for every label instead of only the best one;
# the prediction functions in this file read those scores by position (index 0 and 1).
# NOTE(review): this assumes index 0 is the negative class and index 1 the positive
# class — confirm against the model's label mapping.
classifier = transformers.pipeline(model='Movasaghi/finetuning-sentiment-rottentomatoes', 
                                   return_all_scores=True)
# classifier.eval()

def monitor_hardware():
    """Sample current host resource usage via psutil.

    Returns:
        tuple: ``(cpu_percent, memory_percent)`` as reported by
        ``psutil.cpu_percent()`` and ``psutil.virtual_memory().percent``.
    """
    return psutil.cpu_percent(), psutil.virtual_memory().percent


def logging_mlflow(start_time):
    """Report request latency and host resource usage to the console and to MLflow.

    Args:
        start_time: ``time.time()`` timestamp captured when the request started.
    """
    elapsed = time.time() - start_time
    logging_run(f"Latency: {elapsed:.5f}s")

    # Hardware is sampled after the latency is computed so it does not inflate it.
    cpu_usage, memory_usage = monitor_hardware()
    metrics = (
        ("Latency", elapsed),
        ("CPU Usage", cpu_usage),
        ("Memory Usage", memory_usage),
    )
    for metric_name, metric_value in metrics:
        mlflow.log_metric(metric_name, metric_value)


async def batch_prediction(texts):
    """Classify several texts and aggregate them into one overall verdict.

    Args:
        texts: Sequence of texts passed to the module-level ``classifier``.

    Returns:
        dict: A mapping with three keys:
            * ``"sentiment"`` — ``"Positive"`` when strictly more texts were
              classified positive than negative, otherwise ``"Negative"``
              (a tie therefore yields ``"Negative"``).
            * ``"score"`` — fraction of texts classified positive, rounded to
              ``round_number`` decimals.
            * ``"detail"`` — one ``{"text", "sentiment", "probability"}`` dict
              per input text, in input order.
        For an empty ``texts`` the model is not called and ``"sentiment"`` and
        ``"score"`` stay ``None`` with an empty ``"detail"`` list.
    """
    logging_run(f"Batch Inference (Number: {len(texts)})")
    results = {
        "sentiment": None,
        "score": None,
        "detail": []
        }
    if not texts:
        # Nothing to classify: skip the model call and the division below,
        # which would otherwise raise ZeroDivisionError.
        return results

    output = classifier(texts)
    positive_count = 0
    negative_count = 0
    for text, scores in zip(texts, output):
        # Scores are read by position: index 0 is treated as the negative
        # class and index 1 as the positive class.
        negative_score = scores[0]['score']
        positive_score = scores[1]['score']
        if negative_score > positive_score:
            sentiment, probability = "Negative", negative_score
            negative_count += 1
        else:
            sentiment, probability = "Positive", positive_score
            positive_count += 1
        results["detail"].append({
            "text": text,
            "sentiment": sentiment,
            "probability": round(probability, round_number),
        })

    results['sentiment'] = "Positive" if positive_count > negative_count else "Negative"
    results['score'] = round(positive_count / (negative_count + positive_count), round_number)
    return results



async def single_prediction(text):
    """Classify one text with the module-level ``classifier``.

    Args:
        text: The text to classify.

    Returns:
        dict: ``{"text", "sentiment", "probability"}`` where ``"sentiment"`` is
        ``"Negative"`` when the score at index 0 is strictly greater than the
        score at index 1 and ``"Positive"`` otherwise, and ``"probability"`` is
        the winning score rounded to ``round_number`` decimals.
    """
    logging_run("Single Inference")
    scores = classifier(text)[0]
    negative_score = scores[0]['score']
    positive_score = scores[1]['score']

    if negative_score > positive_score:
        label, confidence = "Negative", negative_score
    else:
        label, confidence = "Positive", positive_score

    return {
        "text": text,
        "sentiment": label,
        "probability": round(confidence, round_number),
    }



@app.post("/predict")
async def predict(data: InputJson):
    """Handle ``POST /predict``: classify the submitted texts and log metrics.

    Args:
        data: Parsed request body carrying the ``texts`` list.

    Returns:
        The result of ``single_prediction`` when exactly one text is sent,
        otherwise the result of ``batch_prediction``.

    Raises:
        HTTPException: Status 400 when ``texts`` is empty.
    """
    start_time = time.time()

    text_count = len(data.texts)
    if text_count == 0:
        raise HTTPException(status_code=400, detail="You must send at least one text.")

    if text_count == 1:
        result = await single_prediction(data.texts[0])
    else:
        result = await batch_prediction(data.texts)

    # Metrics are only recorded for requests that produced a prediction.
    logging_mlflow(start_time)

    return result

In [8]:
# Start the API server only when this code runs as the main program.
if __name__ == "__main__":
    # Presumably needed because the notebook kernel already runs an asyncio
    # event loop that uvicorn.run would otherwise conflict with — confirm.
    nest_asyncio.apply()
    # Serves `app` on port 8001; host 0.0.0.0 listens on all network interfaces.
    uvicorn.run(app, host="0.0.0.0", port=8001)

INFO:     Started server process [5105]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8001 (Press CTRL+C to quit)


INFO:     Batch Inference (Number: 6)
INFO:     Latency: 0.13576s
INFO:     127.0.0.1:55000 - "POST /predict HTTP/1.1" 200 OK
INFO:     Batch Inference (Number: 6)
INFO:     Latency: 0.20246s
INFO:     127.0.0.1:55045 - "POST /predict HTTP/1.1" 200 OK
INFO:     Batch Inference (Number: 6)
INFO:     Latency: 0.19129s
INFO:     127.0.0.1:55242 - "POST /predict HTTP/1.1" 200 OK


INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [5105]
