In [2]:
import time

import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM

# Directory containing the saved tokenizer files and model weights.
model_path = "./mistral-combined-finetuned-weights"

# Both loaders are restricted to files already on disk (local_files_only=True).
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    local_files_only=True,
)

# Load the weights in half precision with automatic device placement, then
# switch the module to evaluation mode (eval() returns the module itself).
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="auto",
    local_files_only=True,
).eval()

# ASGI application object that the route decorators below register against.
app = FastAPI(title="Financial Risk Analysis API")

# Define request schema
class AnalysisRequest(BaseModel):
    """Request body for the ``/analyze`` endpoint.

    The generation parameters are range-checked here so that out-of-range
    values are rejected during request validation instead of surfacing as
    an unhandled error inside ``model.generate``.
    """

    # Raw text to analyze; it is appended to the instruction prompt.
    input_text: str
    # Maximum number of tokens to generate; must be positive.
    max_new_tokens: int = Field(512, gt=0)
    # Sampling temperature; negative values are rejected.
    temperature: float = Field(0.7, ge=0)
    # Nucleus-sampling probability mass; must lie within [0, 1].
    top_p: float = Field(0.95, ge=0, le=1)

# Define your inference function
def generate_analysis(
    input_data: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> str:
    """Generate a financial-risk assessment for ``input_data``.

    The fixed instruction prompt is prepended to ``input_data``, tokenized,
    moved to ``model.device`` and passed to ``model.generate``. The first
    returned sequence is decoded with special tokens skipped, and the
    generation time is printed to stdout.

    Args:
        input_data: Text to analyze; appended verbatim to the prompt.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature. A value of ``0`` selects greedy
            decoding (sampling disabled) instead of failing inside
            ``generate``, which requires a strictly positive temperature
            when sampling.
        top_p: Nucleus-sampling threshold; only used when sampling.

    Returns:
        The decoded text of the first generated sequence.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")

    prompt = (
        "You are an expert financial risk analyst. Analyze the provided text for financial risks, "
        "and output a structured assessment in JSON format including risk detection, specific risk flags, "
        "financial exposure details, and analysis notes. "
        f"{input_data}"
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Sampling with temperature == 0 is rejected by transformers, so treat it
    # as a request for deterministic (greedy) decoding and omit the
    # sampling-only parameters in that case.
    if temperature > 0:
        decoding_kwargs = {
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
        }
    else:
        decoding_kwargs = {"do_sample": False}

    with torch.no_grad():
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments.
        start = time.perf_counter()
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Reuse EOS as the padding token id for generation.
            pad_token_id=tokenizer.eos_token_id,
            **decoding_kwargs,
        )
        elapsed = time.perf_counter() - start

    # NOTE(review): for causal LMs, generate() typically returns the prompt
    # tokens followed by the continuation, so the decoded text likely starts
    # with the prompt itself - confirm whether callers rely on that.
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\n⏱️ Generation time: {elapsed:.2f} seconds")
    return result

# Define the endpoint
@app.post("/analyze")
def analyze_risk(request: AnalysisRequest):
    """Handle ``POST /analyze``.

    Forwards the validated request fields to ``generate_analysis`` and wraps
    the generated text in a JSON object under the ``"analysis"`` key.
    """
    # Map request fields onto the keyword arguments of the inference helper.
    generation_args = {
        "input_data": request.input_text,
        "max_new_tokens": request.max_new_tokens,
        "temperature": request.temperature,
        "top_p": request.top_p,
    }
    analysis_text = generate_analysis(**generation_args)
    return {"analysis": analysis_text}




Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu and disk.


uvicorn main:app --reload --host 0.0.0.0 --port 8000