In [1]:
import logging
import re
import threading

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM


In [2]:
# Create the application object that the /chat route is registered on
app = FastAPI()

# Let the browser frontend, which is served from another origin, call this API
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["http://localhost:5173"],  # Frontend URL allowed to call the API
    allow_headers=["*"],  # Accept every request header
    allow_methods=["*"],  # Accept every HTTP method (GET, POST, etc.)
)

In [3]:
# Run on the GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Directory holding the fine-tuned checkpoint and its tokenizer files
model_path = "./opt_collegebot"

# Tokenizer: reuse the EOS token for padding when no pad token is defined
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Model: load the weights, move them to the chosen device, switch to eval mode
model = AutoModelForCausalLM.from_pretrained(model_path)
model = model.to(device)
model.eval()

# Define chat function with improved answer extraction
def chat(question, max_new_tokens=150):
    """Generate an answer to ``question`` with the loaded causal language model.

    The question is wrapped in a ``Question: ... / Answer:`` prompt, a
    continuation is sampled from ``model``, and the first answer block of that
    continuation is returned.

    Args:
        question: The user's question as plain text.
        max_new_tokens: Upper bound on the number of tokens sampled after the
            prompt.

    Returns:
        The stripped answer text, or a fixed apology message when the model
        produced no usable text.
    """
    prompt = f"Question: {question}\nAnswer:"
    # The prompt is truncated to 1024 tokens before generation.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            top_p=0.9,  # Adjust for diversity
            temperature=0.7,  # Adjust for randomness
            repetition_penalty=2.0,  # Prevent repetition
            eos_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens. Decode only what
    # follows them, so text inside the user's question (e.g. a literal
    # "Answer:") can never be mistaken for the model's answer.
    prompt_length = inputs["input_ids"].shape[-1]
    generated = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

    # Keep the first answer block: everything before a blank line or before the
    # model starts inventing a follow-up "Question:".
    answer = re.split(r"\n\n|Question:", generated.lstrip(), maxsplit=1)[0].strip()
    if not answer:
        # An empty generation would otherwise be returned as "".
        answer = "Sorry, I couldn't generate a proper response. Please try again."

    return answer

# Create Pydantic model for the request body
class ChatRequest(BaseModel):
    # Request body accepted by the POST /chat endpoint.
    # question: the user's question text; the endpoint passes it to chat().
    question: str

class ChatResponse(BaseModel):
    # Response body of the POST /chat endpoint (declared as its response_model).
    # answer: the answer text returned by chat().
    answer: str

# Define the chat API endpoint
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Answer a question posted to ``/chat``.

    Args:
        request: Parsed request body carrying the user's ``question``.

    Returns:
        A ``ChatResponse`` whose ``answer`` is the text returned by ``chat``.

    Raises:
        HTTPException: Status 500 when ``chat`` raises.
    """
    # NOTE(review): chat() is a synchronous call inside an async endpoint, so
    # the event loop is presumably blocked while the model generates; confirm
    # whether requests are expected to be served concurrently.
    try:
        answer = chat(request.question)
    except Exception as exc:
        # Returning {"error": ...} here would not satisfy
        # response_model=ChatResponse (no "answer" field), so the client would
        # never receive it. Log the real cause server-side and signal the
        # failure with a proper HTTP error that hides internal details.
        logging.getLogger(__name__).exception("Answer generation failed")
        raise HTTPException(status_code=500, detail="Failed to generate an answer.") from exc
    return ChatResponse(answer=answer)

# Function to run the API using uvicorn in a separate thread
def run_uvicorn(host="127.0.0.1", port=8000):
    """Serve ``app`` with uvicorn on the given address.

    Takes no required arguments so it can be used directly as a thread target.

    Args:
        host: Interface to bind to; defaults to the loopback address.
        port: TCP port to listen on; defaults to 8000.
    """
    uvicorn.run(app, host=host, port=port)

# Run the server in a separate thread
if __name__ == "__main__":
    # Daemon thread: the server runs in the background and does not keep the
    # process alive on its own.
    server_thread = threading.Thread(target=run_uvicorn, daemon=True)
    server_thread.start()
    print("Server is running in the background.")


Server is running in the background.


INFO:     Started server process [154846]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)


INFO:     127.0.0.1:45938 - "POST /chat HTTP/1.1" 200 OK
