#### Imports


In [1]:
import os
import sys
from flask import Flask, request, jsonify, render_template_string
from threading import Thread
from langchain_community.llms import Ollama
from langchain_ollama import OllamaEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import DocArrayInMemorySearch
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader

sys.path.append('../..')

#### Model import 

In [2]:
# LLM used by the RAG chain: the "llama3" model served by a local Ollama instance.
# temperature=0.0 keeps sampling randomness to a minimum so answers are repeatable;
# verbose=True is passed through to the LangChain wrapper.
model = Ollama(
    model="llama3",
    temperature=0.0,
    base_url="http://localhost:11434",
    verbose=True,
)

  model = Ollama(


#### RAG setup

In [3]:
def load_db(file, chain_type, k, chunk_size=1000, chunk_overlap=150,
            embedding_model="llama3", llm=None):
    """Build a conversational retrieval QA chain over a single PDF.

    The PDF is loaded with PyPDFLoader, split into overlapping chunks with
    RecursiveCharacterTextSplitter, embedded with OllamaEmbeddings and indexed
    in an in-memory DocArray store that is queried by similarity search.

    Args:
        file: Path to the PDF file to index.
        chain_type: Combine-documents strategy passed to
            ``ConversationalRetrievalChain.from_llm`` (e.g. ``"stuff"``).
        k: Number of chunks the retriever returns for each question.
        chunk_size: ``chunk_size`` passed to RecursiveCharacterTextSplitter.
        chunk_overlap: ``chunk_overlap`` passed to RecursiveCharacterTextSplitter.
        embedding_model: Ollama model name used to compute embeddings.
        llm: LLM that generates answers; defaults to the notebook-level ``model``.

    Returns:
        A ConversationalRetrievalChain configured to also return the source
        documents and the generated question with each answer.
    """
    if llm is None:
        # Fall back to the module-level LLM built in the "Model import" cell.
        llm = model

    loader = PyPDFLoader(file)
    documents = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    docs = text_splitter.split_documents(documents)

    embeddings = OllamaEmbeddings(model=embedding_model)
    db = DocArrayInMemorySearch.from_documents(docs, embeddings)
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": k})

    qa = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type=chain_type,
        retriever=retriever,
        return_source_documents=True,
        return_generated_question=True,
    )
    return qa

#### Chatbot

In [4]:
app = Flask(__name__)

# The active ConversationalRetrievalChain, replaced by /load_pdf. Defined up front
# so /ask can report "no PDF loaded" instead of crashing with a NameError.
qa_chain = None

@app.route('/ask', methods=['POST'])
def ask():
    """Answer a question against the currently loaded PDF.

    Expects a JSON body of the form ``{"question": "..."}``.

    Returns:
        200 with ``answer``, ``source_documents`` (source/page/content for each
        retrieved chunk) and ``generated_question``;
        400 if no PDF has been loaded yet or the question is missing/empty;
        500 with the exception text if the chain raises.
    """
    if qa_chain is None:
        return jsonify({'error': 'No QA chain initialized'}), 400

    # silent=True: a missing or non-JSON body yields None instead of raising,
    # so the client gets the 400 below rather than a 415/500.
    payload = request.get_json(silent=True)
    user_input = payload.get('question') if isinstance(payload, dict) else None
    if not user_input:
        return jsonify({'error': 'Question is missing in the request'}), 400

    try:
        # chat_history is always empty: each request is answered independently.
        response = qa_chain.invoke({
            "question": user_input,
            "chat_history": []
        })

        # LangChain Document objects are not JSON-serializable; flatten them.
        serializable_response = {
            'answer': response.get('answer', ''),
            'source_documents': [
                {
                    'source': doc.metadata.get('source', ''),
                    'page': doc.metadata.get('page', ''),
                    'content': doc.page_content
                }
                for doc in response.get('source_documents', [])
            ],
            'generated_question': response.get('generated_question', '')
        }
    except Exception as e:
        return jsonify({'error': str(e)}), 500

    return jsonify(serializable_response), 200

@app.route('/load_pdf', methods=['POST'])
def load_pdf():
    """Save an uploaded PDF under ``docs/`` and rebuild the global QA chain from it.

    Expects a multipart/form-data request with a file field named ``file``.

    Returns:
        200 on success;
        400 if no usable file was provided;
        500 with the exception text if building the chain fails. In that case
        the previously loaded chain, if any, stays active.
    """
    global qa_chain
    # .get() rather than [...] so a missing field yields our JSON 400 instead of
    # werkzeug's HTML BadRequest page.
    file = request.files.get('file')
    # The client controls the filename: drop any directory components so an
    # upload named e.g. "../../x.pdf" cannot be written outside docs/.
    filename = os.path.basename(file.filename or '') if file else ''
    if filename in ('', '.', '..'):
        return jsonify({'error': 'No file provided'}), 400

    os.makedirs('docs', exist_ok=True)
    file_path = os.path.join('docs', filename)
    file.save(file_path)

    try:
        new_chain = load_db(file=file_path, chain_type='stuff', k=3)
    except Exception as e:
        return jsonify({'error': str(e)}), 500

    # Swap only after a successful build so a bad upload does not discard the
    # working chain.
    qa_chain = new_chain
    return jsonify({'message': 'PDF loaded successfully'}), 200



@app.route('/')
def home():
    """Serve a minimal HTML page for uploading a PDF and asking questions.

    The question form posts JSON to /ask from JavaScript and shows the ``answer``
    field (or the ``error`` field) in an alert. The upload form does a plain
    multipart POST to /load_pdf.
    """
    return render_template_string("""
        <h1>Welcome to the Chat API</h1>
        <p>Send a POST request to /ask to ask questions.</p>
        <form id="ask-form">
            <label for="question">Write Your Question:</label>
            <input type="text" id="question" name="question" required>
            <button type="submit">Ask</button>
        </form>
        <form action="/load_pdf" method="POST" enctype="multipart/form-data">
            <label for="file">Choose a PDF to load:</label>
            <input type="file" name="file" accept=".pdf" required>
            <button type="submit">Upload PDF</button>
        </form>

        <script>
            document.getElementById('ask-form').onsubmit = async function(event) {
                event.preventDefault();
                
                const question = document.getElementById('question').value;

                const response = await fetch('/ask', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({ question: question })
                });

                let data;
                try {
                    data = await response.json();
                } catch (err) {
                    alert("Error: server returned HTTP " + response.status);
                    return;
                }
                if (data.error) {
                    alert("Error: " + data.error);
                } else {
                    alert("Response: " + data.answer);
                }
            };
        </script>
    """)


def run_flask(host='localhost', port=5000, debug=True):
    """Run the Flask development server for ``app``; blocks until it stops.

    Args:
        host: Interface to bind. Keep it 'localhost' while ``debug`` is True:
            the Werkzeug debugger allows arbitrary code execution from the browser.
        port: TCP port to listen on.
        debug: Enable Flask debug mode.
    """
    # The reloader relies on signal handlers, which Python only allows in the
    # main thread. This function is meant to run in a background thread, so the
    # reloader is always disabled.
    app.run(debug=debug, use_reloader=False, host=host, port=port)

def start_flask_thread():
    """Start ``run_flask`` in a background thread so the notebook stays usable.

    Returns:
        The started ``threading.Thread``.
    """
    # daemon=True: the server loop never returns on its own, and a non-daemon
    # thread would block interpreter (kernel) shutdown indefinitely.
    flask_thread = Thread(target=run_flask, name="flask-server", daemon=True)
    flask_thread.start()
    return flask_thread

In [5]:
# Start the server only if this cell has not already started one that is still
# running. A second server would fail to bind the same host/port.
if 'flask_thread' not in globals() or not flask_thread.is_alive():
    flask_thread = start_flask_thread()

 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on http://localhost:5000
[33mPress CTRL+C to quit[0m
Ignoring wrong pointing object 110 0 (offset 0)
127.0.0.1 - - [16/Dec/2024 13:37:25] "POST /load_pdf HTTP/1.1" 200 -
127.0.0.1 - - [16/Dec/2024 13:38:02] "POST /ask HTTP/1.1" 200 -
127.0.0.1 - - [16/Dec/2024 13:38:52] "POST /ask HTTP/1.1" 200 -
127.0.0.1 - - [16/Dec/2024 13:41:36] "POST /load_pdf HTTP/1.1" 200 -
127.0.0.1 - - [16/Dec/2024 13:41:58] "GET / HTTP/1.1" 200 -
