In [4]:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline
import torch
import os


# Connection
def start_text_generation_api(api_params: dict, model_name: str, model_params: dict):
    app = FastAPI(**api_params)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16, trust_remote_code=True, device_map=device
    )
    generation_config = GenerationConfig.from_pretrained(model_name, **model_params)
    text_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        generation_config=generation_config,
    )

    class InputText(BaseModel):
        text: str

    @app.post("/generate-text")
    def generate_text(input_text: InputText):
        try:
            generated_text = text_pipeline(input_text.text, max_length=250)[0][
                "generated_text"
            ]
            return {"result": generated_text}
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Error generating text: {str(e)}"
            )

    uvicorn_cmd = f"uvicorn {__name__}:app --host {api_params['host']} --port {api_params['port']} --reload"
    print(
        f"FastAPI is ongoing, please use the following address ：http://{api_params['host']}:{api_params['port']}"
    )
    os.system(uvicorn_cmd)


api_params = {
    "host": "127.0.1.1",
    "port": 8889,
}

model_name = "meta-llama/Llama-2-7b-hf"

model_params = {
    "max_new_tokens": 1024,
    "temperature": 0.0001,
    "top_p": 0.95,
    "do_sample": True,
    "repetition_penalty": 1.15,
}

start_text_generation_api(api_params, model_name, model_params)

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

FastAPI is ongoing, please use the following address ：http://127.0.1.1:8889


INFO:     Will watch for changes in these directories: ['/home/lujun/local/causalllm']
INFO:     Uvicorn running on http://127.0.1.1:8889 (Press CTRL+C to quit)
INFO:     Started reloader process [749066] using WatchFiles
ERROR:    Error loading ASGI app. Attribute "app" not found in module "__main__".
ERROR:    Error loading ASGI app. Attribute "app" not found in module "__main__".
ERROR:    Error loading ASGI app. Attribute "app" not found in module "__main__".
ERROR:    Error loading ASGI app. Attribute "app" not found in module "__main__".
ERROR:    Error loading ASGI app. Attribute "app" not found in module "__main__".
ERROR:    Error loading ASGI app. Attribute "app" not found in module "__main__".


In [None]:
import requests

input_text = {"text": "Your input is here for the verification."}
response = requests.post("http://127.0.1.1:8889/generate-text", json=input_text)

if response.status_code == 200:
    result = response.json()["result"]
    print(f"Generated Text: {result}")
else:
    print(f"Error: {response.status_code}, {response.text}")