In [1]:
import os
import shutil
import tempfile

import requests
from bs4 import BeautifulSoup
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_openai import OpenAI, OpenAIEmbeddings
from langchain_ollama import OllamaLLM
from termcolor import colored
import re
from langchain.prompts import PromptTemplate

import mlflow

In [2]:
def create_faiss_database(document_path, database_save_directory, chunk_size=500, chunk_overlap=10):
    """
    Creates and saves a FAISS database using documents from the specified file.

    Args:

retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=vector_db.as_retriever())


# Log the retrievalQA chain
def load_retriever(persist_directory):
    embeddings = OpenAIEmbeddings()
    vectorstore = FAISS.load_local(
        document_path (str): Path to the file containing documents.
        database_save_directory (str): Directory where the FAISS database will be saved.
        chunk_size (int, optional): Size of each document chunk. Default is 500.
        chunk_overlap (int, optional): Overlap between consecutive chunks. Default is 10.

    Returns:
        FAISS database instance.
    """
    # Load documents from the specified file
    document_loader = TextLoader(document_path)
    raw_documents = document_loader.load()

    # Split documents into smaller chunks with specified size and overlap
    document_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    document_chunks = document_splitter.split_documents(raw_documents)

    # Generate embeddings for each document chunk
    embedding_generator = OpenAIEmbeddings()
    faiss_database = FAISS.from_documents(document_chunks, embedding_generator)

    # Save the FAISS database to the specified directory
    faiss_database.save_local(database_save_directory)

    return faiss_database

def print_answer_formatted(answer, max_line_length=100):
    """
    Prints the answer with the following requirements:
    1. Max length of each line is 160.
    2. <think> ... </think> content is printed in a light color.
    3. After <think> content, print 2 empty lines.
    """
    # Extract <think> ... </think> content
    think_match = re.search(r"<think>(.*?)</think>", answer, re.DOTALL)
    if think_match:
        think_content = think_match.group(1).strip()
        rest_content = answer.replace(think_match.group(0), "").strip()
    else:
        think_content = ""
        rest_content = answer

    # Helper to print with max line length
    def print_wrapped(text, color=None):
        words = text.split()
        line = ""
        for word in words:
            if len(line) + len(word) + 1 <= max_line_length:
                line += word + " "
            else:
                if color:
                    print(colored(line.rstrip(), color))
                else:
                    print(line.rstrip())
                line = word + " "
        if line:
            if color:
                print(colored(line.rstrip(), color))
            else:
                print(line.rstrip())

    # Print <think> content in light color (e.g., 'cyan')
    if think_content:
        print_wrapped(think_content, color="cyan")
        print("\n")

    # Print the rest
    if rest_content:
        print_wrapped(rest_content)

In [3]:
temporary_directory = tempfile.mkdtemp()

# doc_path = os.path.join(temporary_directory, "docs.txt")
doc_path = "local_text/paper.txt"
persist_dir = os.path.join(temporary_directory, "faiss_index")

# fetch_and_save_documents(url_listings, doc_path)

vector_db = create_faiss_database(doc_path, persist_dir)

Created a chunk of size 1701, which is longer than the specified 500
Created a chunk of size 865, which is longer than the specified 500
Created a chunk of size 1327, which is longer than the specified 500
Created a chunk of size 1495, which is longer than the specified 500
Created a chunk of size 2024, which is longer than the specified 500
Created a chunk of size 971, which is longer than the specified 500
Created a chunk of size 680, which is longer than the specified 500
Created a chunk of size 855, which is longer than the specified 500
Created a chunk of size 737, which is longer than the specified 500
Created a chunk of size 1842, which is longer than the specified 500
Created a chunk of size 1267, which is longer than the specified 500
Created a chunk of size 517, which is longer than the specified 500
Created a chunk of size 590, which is longer than the specified 500
Created a chunk of size 665, which is longer than the specified 500
Created a chunk of size 1169, which is lon

In [4]:
mlflow.set_experiment("Ollama RAG")
mlflow.openai.autolog()


code_path = "ollama_pyfunction.py"
with mlflow.start_run() as run:
    model_info = mlflow.pyfunc.log_model(
        name="gemma3-12b",
        python_model=code_path,
        artifacts={
            "persist_directory": persist_dir,
        },
    )




Downloading artifacts:   0%|          | 0/2 [00:00<?, ?it/s]



In [None]:
model_info.model_uri

In [5]:
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)



In [6]:
query1 = {"query": "What is GCPATr?"}
answer1 = loaded_model.predict(query1)
print_answer_formatted (answer1['result'])

GCPATr is a Graph Convolution Position Aware Transformer architecture that enhances the
non-Euclidean interdependency modeling power of PATr by incorporating graph convolution operations.


In [None]:
vector_db.as_retriever().get_relevant_documents("What is GCPATr?")

AttributeError: 'PyFuncModel' object has no attribute 'vectorstore'

In [None]:
answer2 = loaded_model.predict({"query": "Can you summerize this paper?"})
print_answer_formatted (answer2['result'])

In [None]:
answer3 = loaded_model.predict({"query": "Can you repeat my previous question?"})
print_answer_formatted (answer3['result'])

In [None]:
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

In [None]:
model_info.model_uri

In [None]:
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)