# In [7]:
import os

import torch
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import BertTokenizer, BertForQuestionAnswering

# Set WANDB_DISABLED so wandb logging is turned off; it is not needed for serving.
os.environ["WANDB_DISABLED"] = "true"

# Load the question-answering model and its tokenizer once, at import time, so
# every request reuses the same objects. 'fine_tuned_model' must identify the
# fine-tuned checkpoint (presumably a local directory -- confirm the path).
model = BertForQuestionAnswering.from_pretrained('fine_tuned_model')
tokenizer = BertTokenizer.from_pretrained('fine_tuned_model')

# Create the FastAPI application that the route decorator below attaches to.
app = FastAPI()

# Request body schema for the POST /answer/ endpoint, validated by Pydantic.
class QuestionRequest(BaseModel):
    # The question text that will be answered against the shared context.
    question: str

# Fixed passage that every question is answered against. It is a module-level
# value here; replace it, or build it per request, if answers should be drawn
# from different documents.
context = "This document discusses various political issues and positions in the 2024 elections, including economic policies, healthcare, and reform initiatives."

# Answering function
def answer_question(question: str, context: str) -> str:
    """Extract the answer to ``question`` from ``context`` with the QA model.

    The question/context pair is encoded with the module-level ``tokenizer``,
    run through the module-level ``model`` without gradient tracking, and the
    span between the highest-scoring start and end positions is decoded back
    to text.

    Args:
        question: The question to answer.
        context: The passage to search for the answer.

    Returns:
        The decoded answer text, or ``"No answer found"`` when the predicted
        end precedes the predicted start or the span contains nothing but
        special tokens.
    """
    inputs = tokenizer(question, context, return_tensors='pt', truncation=True)

    with torch.no_grad():
        outputs = model(**inputs)

    # Convert the 0-dim argmax results to plain ints so the comparison and the
    # slice below operate on ordinary Python integers.
    answer_start = int(torch.argmax(outputs.start_logits))
    answer_end = int(torch.argmax(outputs.end_logits)) + 1  # exclusive bound

    # Start and end are chosen independently, so the end can land before the
    # start; such a span has no usable text.
    if answer_start >= answer_end:
        return "No answer found"

    answer_ids = inputs['input_ids'][0][answer_start:answer_end].tolist()

    # Skip special tokens so markers such as [CLS] or [SEP] never leak into
    # the reply when the predicted span touches them.
    answer_tokens = tokenizer.convert_ids_to_tokens(
        answer_ids, skip_special_tokens=True
    )
    answer = tokenizer.convert_tokens_to_string(answer_tokens).strip()

    # A span made up only of special tokens decodes to an empty string.
    return answer or "No answer found"

# Define the endpoint to handle question-answering requests
@app.post("/answer/")
async def get_answer(request: QuestionRequest):
    """Answer the question in the request body against the shared context.

    Args:
        request: Validated request body carrying the ``question`` string.

    Returns:
        A dict with the original ``question`` and the extracted ``answer``.
    """
    question = request.question
    # Model inference is blocking, CPU-bound work. Running it directly inside
    # this coroutine would stall the event loop for every other request, so it
    # is handed off to the worker threadpool and awaited.
    answer = await run_in_threadpool(answer_question, question, context)
    return {"question": question, "answer": answer}

# Run the app from a terminal rather than from inside this script.
# Assuming this file is saved as main.py, start the server with `uvicorn main:app --reload`.
