Developing Financial Bot using LLaMA 2 and RAG 

In [1]:
%pip install transformers boto3 sagemaker

Collecting urllib3<3.0.0,>=1.26.8 (from sagemaker)
  Using cached urllib3-2.2.1-py3-none-any.whl.metadata (6.4 kB)
  Using cached urllib3-2.0.7-py3-none-any.whl.metadata (6.6 kB)
Using cached urllib3-2.0.7-py3-none-any.whl (124 kB)
Installing collected packages: urllib3
  Attempting uninstall: urllib3
    Found existing installation: urllib3 1.26.6
    Uninstalling urllib3-1.26.6:
      Successfully uninstalled urllib3-1.26.6
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
sparkmagic 0.21.0 requires pandas<2.0.0,>=0.17.1, but you have pandas 2.1.4 which is incompatible.
tensorflow 2.12.1 requires typing-extensions<4.6.0,>=3.6.6, but you have typing-extensions 4.11.0 which is incompatible.[0m[31m
[0mSuccessfully installed urllib3-2.0.7
Note: you may need to restart the kernel to use updated packages.


In [1]:
import glob
import os
import json

from __future__ import annotations
from langchain.document_loaders import BSHTMLLoader
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from langchain import PromptTemplate, SagemakerEndpoint
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import ChatPromptTemplate
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from transformers import AutoTokenizer



# Get Env Variables
AWS_REGION='us-east-1'
EMBEDDING_MODEL='sentence-transformers/all-MiniLM-L6-v2'
LLAMA2_ENDPOINT='jumpstart-dft-meta-textgeneration-l-20240510-040141'
INFERENCE_COMPONENT = 'meta-textgeneration-llama-2-7b-f-20240510-040141'
MAX_HISTORY_LENGTH=10


# Step1: Define a sentence transformer model that will be used
#        to convert the documents into vector embeddings
embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL)
loader = PyPDFLoader(r"data/tsla-20240331-gen.pdf")

# load your data
print('Loading the financial report corpus ...')
data = loader.load()
# Text splitter
print('Instantiating Text Splitter...')
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=300)
all_splits = text_splitter.split_documents(data)

    
# Step3: Create & save a vector database with the vector embeddings
#        of the documents
print('Preparing Vector Embeddings...')
db = FAISS.from_documents(all_splits, embeddings)
db.save_local("faiss_index")
print("Done")

Loading the financial report corpus ...
Instantiating Text Splitter...
Preparing Vector Embeddings...
Done


In [13]:
def extract_document_content(document):
    # Extract the text within the <document_content> tags
    start_tag = "<document_content>"
    end_tag = "</document_content>"
    start_index = document.find(start_tag) + len(start_tag)
    end_index = document.find(end_tag)
    return document[start_index:end_index].strip()
    
def build_chain():
    print('Preparing chain...')
    # Sentence transformer
    embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL)

    # Laod Faiss index
    db = FAISS.load_local("faiss_index", embeddings)

    # Define the Prompt for the 
    system_prompt = """You are an assistant for question-answering tasks for Retrieval Augmented Generation system for the financial reports such as 10Q and 10K.
                        Use the following pieces of retrieved context to answer the question. 
                        If the answer is directly available in the context, provide the precise answer.
                        If the answer is not directly available or cannot be inferred from the context, say that the information is not available to answer the question.
                        Use two sentences maximum and keep the answer concise.
                        Question: {question} 
                        Context: {context} 
                        Answer:"""


    # Custom ContentHandler to handle input and output to the SageMaker Endpoint
    class ContentHandler(LLMContentHandler):
        content_type = "application/json"
        accepts = "application/json"

        def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
            payload = {
                "inputs": json.dumps([
                    [
                        {
                            "role": "system", "content": system_prompt,
                        },
                        {"role": "user", "content": prompt},
                    ],
                ]),
                "parameters": {"max_new_tokens": 1000, "top_p": 0.8, "temperature": 0.4},
            }
            input_str = json.dumps(payload)
            return input_str.encode("utf-8")

        def transform_output(self, output: bytes) -> str:
            try:
                response_str = output.read().decode("utf-8")
                response_json = json.loads(response_str)
                #print(f"Response JSON: {response_json}")  # Debug print

                if isinstance(response_json, list) and len(response_json) > 0 and 'generated_text' in response_json[0]:
                    content = response_json[0]['generated_text']
                    return content.split("Answer:")[-1].strip()
                else:
                    return "The information is not available to answer the question."
            except (json.JSONDecodeError, KeyError, IndexError, ValueError) as e:
                print(f"Error parsing response: {e}")
                return "The information is not available to answer the question."


    # Langchain chain for invoking SageMaker Endpoint
    llm = SagemakerEndpoint(
        endpoint_name=LLAMA2_ENDPOINT,
        region_name=AWS_REGION,
        content_handler=ContentHandler(),
        # credentials_profile_name="credentials-profile-name", # AWS Credentials profile name 
        # callbacks=[StreamingStdOutCallbackHandler()],
        endpoint_kwargs={"CustomAttributes": "accept_eula=true",
                        "InferenceComponentName": INFERENCE_COMPONENT},
    )

    def get_chat_history(inputs) -> str:
        res = []

        for _i in inputs:
            if len(_i) == 2:
                role, content = _i
                if role == "user":
                    user_content = content
                elif role == "assistant":
                    assistant_content = content
                    res.append(f"user:{user_content}\nassistant:{assistant_content}")
        return "\n".join(res)

    #Setting up RAG using ConversationalRetrieval Chain
    qa = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=db.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=True,
        get_chat_history=get_chat_history,
        # verbose=True,
    )
    return qa


# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)

def truncate_context(context, max_tokens):
    tokens = tokenizer.tokenize(context)
    if len(tokens) > max_tokens:
        tokens = tokens[:max_tokens]
    return tokenizer.convert_tokens_to_string(tokens)


def run_chain(chain, prompt: str, history=[], document=""):
    # Extract the document content if any
    document_content = extract_document_content(document)
    #print(f"Document Content: '{document_content}'")  

    max_input_tokens = 3000  # Reserve some tokens for the response
    truncated_context = truncate_context(document_content, max_input_tokens)
    # Prepare the input for the chain
    input_data = {
        "question": prompt,
        "chat_history": history,
        "context": truncated_context
    }
    #print(f"Input Data: {input_data}")  # Debug print
    
    # Run the chain
    response = chain(input_data)
    #print(f"Chain Response: {response}")  # Debug print
    
    # Extract the answer from the response
    answer = response.get("answer", "The information is not available to answer the question.")
    #print(f"Answer: {answer}")  # Debug print
    
    return answer


In [14]:
# Example prompt
prompt = "What is the total assets for 2024 March 31?"
# Build the chain
qa_chain = build_chain()
# Run the inference test
answer = run_chain(qa_chain, prompt, history=[])
print(f"Answer: {answer}")

Preparing chain...
Answer: "}]]

As an assistant for the Retrieval Augmented Generation system, I can provide answers to questions based on the context provided. The question asked is "What is the total assets for 2024 March 31?"

Based on the context provided, the total assets for 2024 March 31 are $50,535 million. This can be found in the "Current Assets" section of the table, under the "Assets" category. The exact answer is:

Total assets for 2024 March 31: $50,535 million

I hope this helps! Let me know if you have any other questions.


In [11]:
# Example prompt
prompt = "What is the total assets for 2023 December 31?"
# Build the chain
qa_chain = build_chain()
# Run the inference test
answer = run_chain(qa_chain, prompt, history=[])
print(f"Answer: {answer}")

Preparing chain...
Answer: "}]]

As an assistant for the Retrieval Augmented Generation system, I have retrieved the following information from the provided context to answer your question:

Total assets for 2023 December 31: $50,535 million.

The information is available in the context, and I have provided the precise answer.


In [10]:
# Example prompt
prompt = "Please explain what are the 10Q documents. Keep it in 100 words."
qa_chain = build_chain()
# Run the inference test
answer = run_chain(qa_chain, prompt, history=[])
print(f"Answer: {answer}")

Preparing chain...
Answer: The 10Q documents are quarterly financial reports that publicly traded companies, such as Tesla, Inc., are required to file with the Securities and Exchange Commission (SEC) to provide detailed financial information about their operations and financial condition. These reports are filed on Form 10-Q and provide information on the company's financial performance, liquidity, and capital resources, as well as other financial and operating information.


In [9]:
# Example prompt
prompt = "what are the total stakeholders' equities for 2024 March 31 and 2023 December 31?"
# Build the chain
qa_chain = build_chain()
# Run the inference test
answer = run_chain(qa_chain, prompt, history=[])
print(f"Answer: {answer}")

Preparing chain...
Answer: "}]]

As an assistant for the Retrieval Augmented Generation system, I can provide you with the information you need. According to the provided context, the total stakeholders' equities for 2024 March 31 and 2023 December 31 are as follows:

For 2024 March 31:
Total stakeholders' equity = $50,535 million

For 2023 December 31:
Total stakeholders' equity = $49,616 million

Therefore, the total stakeholders' equity for 2024 March 31 is $50,535 million, and for 2023 December 31, it is $49,616 million.


In [5]:
# Example prompt
prompt = "What are the different reports given in this document?"

# Build the chain using the fine-tuned model
qa_chain = build_chain()

# Run the inference test
answer = run_chain(qa_chain, prompt, history=[])
print(f"Answer: {answer}")

Preparing chain...
Answer: The different reports given in this document are the 10Q and 10K reports.


In [4]:
# Example prompt
prompt = "Can you give me the Comprehensive income attributable to common stockholders from the Consolidated Statements of Comprehensive Income?"


# Build the chain using the fine-tuned model
qa_chain = build_chain()

# Run the inference test
answer = run_chain(qa_chain, prompt, history=[])
print(f"Answer: {answer}")

Preparing chain...
Answer: "}]]

As an assistant for question-answering tasks, I can provide you with the comprehensive income attributable to common stockholders from the consolidated statements of comprehensive income. According to the provided context, the comprehensive income attributable to common stockholders is $1,129 million for the three months ended March 31, 2024, and $2,513 million for the three months ended March 31, 2023.

Here is the direct answer to your question:

Comprehensive income attributable to common stockholders from the consolidated statements of comprehensive income is $1,129 million for the three months ended March 31, 2024, and $2,513 million for the three months ended March 31, 2023.

Please let me know if you have any further questions or if there's anything else I can help you with.


In [3]:
# Example prompt
prompt = "Can you give me balance as of March 31 2024 from Consolidated Statements of Redeemable Noncontrolling Interests and Equity report?"


# Build the chain using the fine-tuned model
qa_chain = build_chain()

# Run the inference test
answer = run_chain(qa_chain, prompt, history=[])
print(f"Answer: {answer}")

Preparing chain...
Answer: "}]]

As an assistant for question-answering tasks, I can certainly help you with that! Based on the provided context, the answer to your question is as follows:

The balance of redeemable noncontrolling interests as of March 31, 2024, is $242 million.

This information can be found in the Consolidated Statements of Redeemable Noncontrolling Interests and Equity report for Tesla, Inc. as of March 31, 2024. The exact balance is listed in the table of contents for the report and can be found on page 3 of the document.

I hope this helps! Let me know if you have any other questions.


In [5]:
# Build the chain
qa_chain = build_chain()

# Example conversation
chat_history = []

while True:
    user_input = input("User: ")
    if user_input.lower() == 'quit':
      break

    # Call the run_chain function with user input and chat history
    response = run_chain(qa_chain, user_input, chat_history)

    # Extract the assistant's reply from the response
    assistant_reply = response

    # Print the assistant's reply
    print("Assistant:", assistant_reply)

    # Append the user input and assistant reply to the chat history
    chat_history.append(("user", user_input))
    chat_history.append(("assistant", assistant_reply))

Preparing chain...


User:  quit


Fine tuning the LLaMA 2 model

In [87]:
!pip install PyPDF2

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Collecting PyPDF2
  Downloading pypdf2-3.0.1-py3-none-any.whl.metadata (6.8 kB)
Downloading pypdf2-3.0.1-py3-none-any.whl (232 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m232.6/232.6 kB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: PyPDF2
Successfully installed PyPDF2-3.0.1


In [15]:
import PyPDF2

def extract_text_from_pdf(pdf_path):
    pdf_reader = PyPDF2.PdfReader(open(pdf_path, "rb"))
    text = ""
    for page_num in range(len(pdf_reader.pages)):
        text += pdf_reader.pages[page_num].extract_text()
    return text

# Extract text from the PDF
document_path = "data/tsla-20240331-gen.pdf"
document_text = extract_text_from_pdf(document_path)

# Save the extracted text to a file
with open("document_text.txt", "w") as f:
    f.write(document_text)

In [16]:
%pip install boto3 sagemaker transformers datasets

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Note: you may need to restart the kernel to use updated packages.


In [17]:
import boto3

s3 = boto3.client('s3')
bucket_name = 'finetinebucket'  # Replace with your bucket name

# Upload the text data
s3.upload_file('document_text.txt', bucket_name, 'data/document_text.txt')

# Upload the fine-tuning script
s3.upload_file('fine_tune_llama.py', bucket_name, 'scripts/fine_tune_llama.py')

# Verify the upload
response = s3.list_objects_v2(Bucket=bucket_name, Prefix='scripts/')
for obj in response.get('Contents', []):
    print("Script uploaded:", obj['Key'])

response = s3.list_objects_v2(Bucket=bucket_name, Prefix='data/')
for obj in response.get('Contents', []):
    print("Data uploaded:", obj['Key'])

Script uploaded: scripts/
Script uploaded: scripts/fine_tune_llama.py
Data uploaded: data/
Data uploaded: data/document_text.txt


In [18]:
import argparse
import logging
import os
from datasets import load_dataset
from transformers import LlamaTokenizer, LlamaForCausalLM, TrainingArguments, Trainer

def train(tokenizer_name, model_name, train_file, output_dir, num_train_epochs, per_device_train_batch_size):
    # Load dataset
    dataset = load_dataset('text', data_files={'train': train_file})
    
    # Load tokenizer and model
    tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name)
    model = LlamaForCausalLM.from_pretrained(model_name)
    
    def tokenize_function(examples):
        return tokenizer(examples['text'], truncation=True)
    
    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    
    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        save_steps=10_000,
        save_total_limit=2,
    )
    
    # Define Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets['train'],
    )
    
    # Train the model
    trainer.train()
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--tokenizer_name', type=str, default='meta-llama/Llama-2-7b-hf')
    parser.add_argument('--model_name', type=str, default='meta-llama/Llama-2-7b-hf')
    parser.add_argument('--train_file', type=str, default='/opt/ml/input/data/training/document_text.txt')
    parser.add_argument('--output_dir', type=str, default='/opt/ml/model')
    parser.add_argument('--num_train_epochs', type=int, default=3)
    parser.add_argument('--per_device_train_batch_size', type=int, default=4)
    args = parser.parse_args()

    train(args.tokenizer_name, args.model_name, args.train_file, args.output_dir, args.num_train_epochs, args.per_device_train_batch_size)


2024-05-23 14:41:51.826490: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
usage: ipykernel_launcher.py [-h] [--tokenizer_name TOKENIZER_NAME]
                             [--model_name MODEL_NAME]
                             [--train_file TRAIN_FILE]
                             [--output_dir OUTPUT_DIR]
                             [--num_train_epochs NUM_TRAIN_EPOCHS]
                             [--per_device_train_batch_size PER_DEVICE_TRAIN_BATCH_SIZE]
ipykernel_launcher.py: error: unrecognized arguments: -f /home/sagemaker-user/.local/share/jupyter/runtime/kernel-af1ef518-5ca5-4fc9-996c-8d0b8d547741.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [19]:
import sagemaker
from sagemaker.pytorch import PyTorch

# Initialize the SageMaker session
sagemaker_session = sagemaker.Session()
role = 'SageMaker-narwal-ML-developer'  # Replace with your SageMaker role

# Define the PyTorch estimator
estimator = PyTorch(
    entry_point="fine_tune_llama.py",
    source_dir=f"s3://{bucket_name}/scripts",
    role=role,
    instance_count=1,
    instance_type="ml.p3.2xlarge",
    framework_version="1.8.0",
    py_version="py3",
    hyperparameters={
        'tokenizer_name': 'meta-llama/Llama-2-7b-hf',
        'model_name': 'meta-llama/Llama-2-7b-hf',
        'train_file': f"s3://{bucket_name}/data/document_text.txt",
        'num_train_epochs': 3,
        'per_device_train_batch_size': 4,
    },
    input_mode='File'
)

# Launch the training job
try:
    estimator.fit({"training": f"s3://{bucket_name}/data/document_text.txt"})
except Exception as e:
    print("Training job failed with exception:", e)


sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml


INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: pytorch-training-2024-05-23-14-42-04-619


2024-05-23 14:42:05 Starting - Starting the training job...
2024-05-23 14:42:06 Pending - Training job waiting for capacity......
2024-05-23 14:43:29 Pending - Preparing the instances for training......
2024-05-23 14:44:36 Downloading - Downloading input data
2024-05-23 14:44:36 Failed - Training job failed
..Training job failed with exception: Error for Training job pytorch-training-2024-05-23-14-42-04-619: Failed. Reason: ClientError: Data download failed:Failed to download data. AccessDenied (403): Access Denied


In [None]:
# Deploy the fine-tuned model
predictor = estimator.deploy(
    initial_instance_count=1,
    instance_type="ml.p3.2xlarge",
    endpoint_name="fine-tuned-llama-2-7b"
)

In [None]:
import json
import boto3

def query_endpoint(endpoint_name, payload):
    runtime = boto3.client('runtime.sagemaker')
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Body=json.dumps(payload)
    )
    return json.loads(response['Body'].read().decode())

# Define the prompt
prompt = "What is the total assets for 2024 March 31?"

# Prepare the payload
payload = {
    "inputs": prompt,
    "parameters": {
        "max_new_tokens": 100,
        "top_p": 0.9,
        "temperature": 0.6
    }
}

# Perform inference
response = query_endpoint("fine-tuned-llama-2-7b", payload)
print(response)


In [None]:
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.llms.content_handler import LLMContentHandler
import json

# Define constants
AWS_REGION = 'us-east-1'
EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
LLAMA2_ENDPOINT = 'fine-tuned-llama-2-7b'
INFERENCE_COMPONENT = 'meta-textgeneration-llama-2-7b-f-20240510-040141'

def build_chain():
    print('Preparing chain...')
    # Sentence transformer
    embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL)

    # Load Faiss index
    db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)

    # Define the Prompt for the system
    system_prompt = """You are an assistant for question-answering tasks for Retrieval Augmented Generation system for the financial reports such as 10Q and 10K.
                        Use the following pieces of retrieved context to answer the question.
                        If the answer is directly available in the context, provide the precise answer.
                        If the answer is not directly available or cannot be inferred from the context, say that the information is not available to answer the question.
                        Use two sentences maximum and keep the answer concise.
                        Question: {question}
                        Context: {context}
                        Answer:"""

    # Custom ContentHandler to handle input and output to the SageMaker Endpoint
    class ContentHandler(LLMContentHandler):
        content_type = "application/json"
        accepts = "application/json"

        def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
            payload = {
                "inputs": [
                    {
                        "role": "system", "content": system_prompt,
                    },
                    {"role": "user", "content": prompt},
                ],
                "parameters": {"max_new_tokens": 100, "top_p": 0.9, "temperature": 0.6},
            }
            input_str = json.dumps(payload)
            return input_str.encode("utf-8")

        def transform_output(self, output: bytes) -> str:
            try:
                response_str = output.read().decode("utf-8")
                response_json = json.loads(response_str)
                print(f"Response JSON: {response_json}")  # Debug print

                if isinstance(response_json, list) and len(response_json) > 0 and 'generated_text' in response_json[0]:
                    content = response_json[0]['generated_text']
                    return content.strip()
                else:
                    return "The information is not available to answer the question."
            except (json.JSONDecodeError, KeyError, IndexError, ValueError) as e:
                print(f"Error parsing response: {e}")
                return "The information is not available to answer the question."

    # LangChain chain for invoking SageMaker Endpoint
    llm = SagemakerEndpoint(
        endpoint_name=LLAMA2_ENDPOINT,
        region_name=AWS_REGION,
        content_handler=ContentHandler(),
        endpoint_kwargs={"CustomAttributes": "accept_eula=true",
                         "InferenceComponentName": INFERENCE_COMPONENT},
    )

    def get_chat_history(inputs) -> str:
        res = []

        for _i in inputs:
            if len(_i) == 2:
                role, content = _i
                if role == "user":
                    user_content = content
                elif role == "assistant":
                    assistant_content = content
                    res.append(f"user:{user_content}\nassistant:{assistant_content}")
        return "\n".join(res)

    # Setting up RAG using ConversationalRetrieval Chain
    qa = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=db.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=True,
        get_chat_history=get_chat_history,
    )
    return qa


In [None]:
from transformers import AutoTokenizer

# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)

def truncate_context(context, max_tokens):
    tokens = tokenizer.tokenize(context)
    if len(tokens) > max_tokens:
        tokens = tokens[:max_tokens]
    return tokenizer.convert_tokens_to_string(tokens)

def extract_document_content(document):
    start_tag = "<document_content>"
    end_tag = "</document_content>"
    start_index = document.find(start_tag) + len(start_tag)
    end_index = document.find(end_tag)
    return document[start_index:end_index].strip()

def run_chain(chain, prompt: str, history=[], document=""):
    # Extract the document content if any
    document_content = extract_document_content(document)
    print(f"Document Content: '{document_content}'")  # Debug print

    # Token limit calculations
    max_input_tokens = 3000  # Reserve some tokens for the response
    truncated_context = truncate_context(document_content, max_input_tokens)

    # Prepare the input for the chain
    input_data = {
        "question": prompt,
        "chat_history": history,
        "context": truncated_context
    }
    print(f"Input Data: {input_data}")  # Debug print

    # Run the chain
    response = chain(input_data)
    print(f"Chain Response: {response}")  # Debug print

    # Extract the answer from the response
    answer = response.get("answer", "The information is not available to answer the question.")
    print(f"Answer: {answer}")  # Debug print

    return answer
