In [61]:
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi import UploadFile, File, Form, HTTPException
import transformers
from typing import Annotated
from pydantic import BaseModel
import torch
import numpy as np
import uvicorn
import nest_asyncio
import mlflow
import asyncio
import time
import psutil
import requests

In [62]:
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=['*'],
    allow_methods=['*'],
    allow_headers=['*'],
)

def logging_run(text):
    log = f"INFO:     {text}"
    print(log)

class InputJson(BaseModel):
    texts: list


# -- must set with database
model_A_url = 'http://0.0.0.0:8000/predict'
model_B_url = 'http://0.0.0.0:8001/predict'
output_other = []


def monitor_hardware():
    cpu_percent = psutil.cpu_percent() 
    memory_percent = psutil.virtual_memory().percent
    return cpu_percent, memory_percent


def logging_mlflow(start_time, item=0):
    latency = time.time() - start_time
    logging_run(f"Latency_{item}: {latency:.5f}s")
    cpu_percent, memory_percent = monitor_hardware()
    mlflow.log_metric(f"Latency_{item}", latency)
    mlflow.log_metric(f"CPU_Usage_{item}", cpu_percent) 
    mlflow.log_metric(f"Memory_Usage_{item}", memory_percent) 


async def model_inference(data, model_url, model_main=True):
    try:
        response = requests.post(model_url, json=data)
        if model_main:
            if response.status_code != 200:
                raise HTTPException(status_code=response.status_code)
        return response.json()
    except:
        raise HTTPException(status_code=500, detail="The model not working")
    


@app.post("/predict")
async def predict(data: InputJson):
    start_time0 = time.time()
    input_data = {"texts": data.texts}
    # -- base server inference
    start_time = time.time()
    response = await model_inference(input_data, model_A_url, True)
    logging_mlflow(start_time, "Base")
    
    # -- other servers inference
    start_time = time.time()
    asyncio.create_task(model_inference(input_data, model_B_url, False))
    logging_mlflow(start_time, "test")

    logging_mlflow(start_time0, "total")

    return response

In [63]:
if __name__ == "__main__":
    nest_asyncio.apply()
    uvicorn.run(app, host="0.0.0.0", port=8002)

INFO:     Started server process [4013]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8002 (Press CTRL+C to quit)


INFO:     Latency_Base: 0.39160s
INFO:     Latency_test: 0.00002s
INFO:     Latency_total: 0.41317s
INFO:     127.0.0.1:55240 - "POST /predict HTTP/1.1" 200 OK


INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [4013]
