# PMC RAG Demo and Deep Dive
## Why RAG
By embracing RAG, you can unlock a range of benefits for your organization:

 * Improved decision-making: Access to the right, trustworthy information empowers better choices and strategies.
 * Enhanced customer experience: Delivering reliable answers and insights builds trust and satisfaction.
 * Reduced risk and improved compliance: Curated data sources minimize the risk of misinformation and help ensure compliance with regulations.
 * Increased efficiency: Streamlining access to information saves time and resources.
 
The beauty of RAG lies in its focus on data quality, not just data quantity. We're moving beyond the “bigger is better” mentality of massive models trained on internet data that often includes misinformation and biases. RAG shifts the emphasis toward smaller, more focused models grounded in curated, trustworthy data sources.


## The Standard RAG Workflow Empowered by TileDB

1. **User:** Uploads documents, and the system converts them into vectors (numeric representations) using sentence embeddings. In our case, the user submits a tab-separated text file of PubMed IDs (PMIDs), which the system uses to retrieve articles from PubMed Central (PMC) and abstracts from PubMed.
2. **User:** Stores these document vectors, along with the documents and metadata, in TileDB (a database that supports vector search).
3. **User:** Asks a question.
4. **Embedding Model:** Processes the user's question by embedding it into a vector and sends this vector to TileDB.
5. **VectorDB:** Searches through the stored document vectors and retrieves the most relevant documents.
6. **Retriever:** Takes these relevant documents and constructs a new query for the LLM, instructing it to use these documents as context.
7. **LLM:** Uses the relevant documents as context to generate and deliver the final answer back to the user.
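The retrieval steps (4 to 6) can be sketched in plain Python. This is a toy illustration, not the TileDB implementation: `toy_embed` is a hypothetical bag-of-words stand-in for a real sentence-embedding model, and the brute-force cosine ranking stands in for TileDB's vector index search. Only the shape of the flow is meant to carry over.

```python
import math
from collections import Counter
from typing import List


def toy_embed(text: str, vocab: List[str]) -> List[float]:
    """Hypothetical stand-in for a sentence-embedding model: word counts over a fixed vocabulary."""
    counts = Counter(text.lower().split())
    return [float(counts[word]) for word in vocab]


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity; returns 0.0 when either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def retrieve(question: str, docs: List[str], vocab: List[str], k: int = 2) -> List[str]:
    """Steps 4-5: embed the question and return the k most similar documents."""
    q_vec = toy_embed(question, vocab)
    ranked = sorted(docs, key=lambda d: cosine(q_vec, toy_embed(d, vocab)), reverse=True)
    return ranked[:k]


def build_prompt(question: str, context_docs: List[str]) -> str:
    """Step 6: wrap the retrieved documents into a prompt for the LLM."""
    context = "\n\n".join(context_docs)
    return f"Answer using only the context below.\n\nContext:\n{context}\n\nQuestion: {question}"


# Usage with three tiny "documents"
docs = [
    "brca1 mutations increase breast cancer risk",
    "insulin regulates blood glucose",
    "tp53 is a tumor suppressor gene in cancer",
]
vocab = sorted({word for doc in docs for word in doc.split()})
question = "which gene mutations raise cancer risk"
top = retrieve(question, docs, vocab, k=1)
prompt = build_prompt(question, top)
print(prompt)
```

In the real pipeline, step 7 sends a prompt like this to the LLM; swapping `toy_embed` for a real embedding model and the sort for an index query changes quality, not the overall flow.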

## PHASE 1: The Notebook and Ingestion Pipelines
The cells below are meant to be run in order. Ingestion takes a while; later in this document we discuss how you can improve and iterate on this example depending on your objectives.

### Imports
The imports below support pushing the local file to our remote repo, as well as the `initialize_step()` function that builds our end-to-end pipeline and submits a DAG directly.

In [None]:
import requests
import os
import pandas as pd
import tarfile
import urllib.request
import xml.etree.ElementTree as ET
import hashlib
import subprocess
import shutil
import math
import warnings
import time
from datetime import datetime
from typing import Optional, List, Union, Dict, Iterator

# Langchain Imports
from langchain._api import LangChainDeprecationWarning
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=DeprecationWarning) 
warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
from langchain.document_loaders import PyPDFLoader
from langchain.document_loaders.parsers.pdf import PyPDFParser
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.document_loaders.pdf import BasePDFLoader
from langchain.document_loaders.blob_loaders import Blob
from langchain.schema import Document
from langchain.memory import ConversationSummaryMemory
from langchain.chains import ConversationalRetrievalChain, ConversationChain
from langchain.vectorstores.tiledb import TileDB

# TileDB Imports 
import tiledb
import tiledb.cloud
from tiledb.cloud.dag import dag
from tiledb.vector_search.object_api import object_index
from tiledb.vector_search.object_readers import DirectoryTextReader
from tiledb.vector_search.embeddings import SentenceTransformersEmbedding, LangChainEmbedding
from tiledb.cloud import groups
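# Embedding libraries (the langchain_community import supersedes the older langchain.embeddings path)
from langchain_community.embeddings import HuggingFaceEmbeddings
from sentence_transformers import SentenceTransformer  # Backend used by HuggingFaceEmbeddings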
# Environment Variables 
os.environ['TOKENIZERS_PARALLELISM'] = "true"

### Credentials
The cell below lets Git cache your credentials during the initial push. The first push of the local file may fail; if so, enter your credentials and push the file manually from a terminal. After that, the credentials should be cached and the remaining cells should run without issues.

In [None]:
# Run this cell before the cells below. The first push to the repo (if necessary) may fail, and you will need to push this file manually.
# After that, the credentials should be temporarily cached.
#!git config --global user.email "you@example.com"
#!git config --global user.name "Your Name"
# Use ONE credential helper: 'cache' keeps credentials in memory until the timeout expires,
# while 'store' writes them in plain text to ~/.git-credentials.
#!git config --global credential.helper cache
#!git config --global credential.helper 'cache --timeout=3600'
#!git config --global credential.helper store

### Upload "pipeline" Helpers
For now, this is a quick simulation of a pipeline for updating the local file. The first cell contains helper functions for our notebook pipeline (run in our notebook).

In [None]:
def add_and_commit_files(message: str):
    """
    Adds all files to the Git staging area, commits them with a provided message, and pushes the changes to the remote repository.
    
    :param message: Commit message to be used in the Git commit.
    """
    
    # Print the current working directory
    print(f"Working directory is {os.getcwd()}")

    # Stage all changes (new, modified, deleted) in the current Git repository
    subprocess.run(["git", "add", "-A"])

    # Commit the staged changes with the provided commit message
    subprocess.run(["git", "commit", "-m", f"{message}"])

    # Push the committed changes to the current branch's upstream remote (as configured in Git)
    subprocess.run(["git", "push"])

def hash_file(file_path: str) -> str:
    """
    Computes the SHA256 hash of the contents of a file.

    :param file_path: Path to the file to be hashed.
    :return: The SHA256 hash of the file contents as a hex string.
    """
    
    # Create a new SHA256 hash object
    hasher = hashlib.sha256()
    
    # Open the file in binary read mode
    with open(file_path, 'rb') as f:
        # Read the file contents and update the hash object
        buffer = f.read()
        hasher.update(buffer)
    
    # Return the hexadecimal digest of the hash
    return hasher.hexdigest()

def estimate_resources(job_df: pd.DataFrame):
    """
    Estimates the CPU and memory usage required for processing a given DataFrame, adding a 2 GB overhead
    and rounding the total memory up to the next whole GB.

    :param job_df: A pandas DataFrame representing job data.
    :return: A tuple containing the estimated number of CPUs per job and the total memory usage in GB (rounded up).
    """
    
    # Print the data types of each column in the DataFrame for reference
    print("Data types of the DataFrame:")
    print(job_df.dtypes)

    # Calculate the memory usage of each column in the DataFrame in MB (deep=True considers the actual memory usage)
    memory_usage_per_column = job_df.memory_usage(deep=True) / (1024 ** 2)  # Convert bytes to MB
    total_memory_usage = memory_usage_per_column.sum()  # Sum of the memory usage of all columns
    
    # Convert total memory usage to GB and add an additional 2 GB overhead for processing
    total_memory_usage_gb = total_memory_usage / 1024  # Convert MB to GB
    total_memory_usage_gb += 2  # Add 2 GB overhead for processing
    
    # Round the total memory usage up to the next whole GB
    rounded_memory_usage_gb = math.ceil(total_memory_usage_gb)
    
    # Print the memory usage for each column and the total estimated memory usage
    print("Memory usage per column (MB):")
    print(memory_usage_per_column)
    
    print(f"Total memory usage for job (GB, rounded up, with overhead): {rounded_memory_usage_gb}")

    # Estimate the CPU usage per job (adjustable based on job complexity)
    cpu_per_job = 1  # Example: assuming 1 CPU per job
    
    # Return the estimated CPU count and the rounded memory usage in GB
    return cpu_per_job, rounded_memory_usage_gb
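`hash_file` reads the whole file into memory, which is fine for a small article list. The sketch below (the names `hash_file_chunked` and `estimate_memory_gb` are hypothetical) shows a streaming variant that produces the same SHA-256 digest with bounded memory, plus the arithmetic `estimate_resources` applies: bytes to GB, plus the 2 GB overhead, rounded up with `math.ceil`. For example, a 500 MB DataFrame gives about 0.49 GB + 2 GB, which rounds up to 3 GB.

```python
import hashlib
import math
import os
import tempfile


def hash_file_chunked(file_path: str, chunk_size: int = 1 << 20) -> str:
    """Hypothetical streaming variant of hash_file: same SHA-256 digest, bounded memory."""
    hasher = hashlib.sha256()
    with open(file_path, "rb") as f:
        # iter() keeps calling read() until it returns the b"" sentinel at end of file
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest()


def estimate_memory_gb(total_bytes: int, overhead_gb: int = 2) -> int:
    """The estimate_resources arithmetic: bytes -> GB, plus overhead, rounded up."""
    return math.ceil(total_bytes / 1024 ** 3 + overhead_gb)


# Usage: hash a temporary file in 7-byte chunks and compare with a one-shot hash
data = b"PMID\n12345\n" * 1000
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(data)
    path = tmp.name
digest = hash_file_chunked(path, chunk_size=7)
os.remove(path)

print(digest == hashlib.sha256(data).hexdigest())  # True
print(estimate_memory_gb(500 * 1024 ** 2))          # 3
```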


### The "Upload Pipeline" 
This next cell mimics a DevOps runner pipeline, in which a user commits an updated file and the run builds a container tagged with the run hash. The goal is to detect whether the local file has changed and, if so, update the stored hash. The next stage of the pipeline uses that hash to decide whether a new run is necessary. How often to check depends on how often the file is updated and how closely we want to tune the quality of our LLM's outputs. In the future, the trigger could be a webhook on a Git update or an S3 storage event. The benefit of Git here is a "GitOps"-like workflow: changes to the RAG ingestion document list are tracked, and we can use that history to understand their impact on our model outputs.

#### The Code Below in Plain English

1. **Change to Home Directory:** It starts by navigating to the user's home directory.

2. **Clone or Pull Repository:** It checks if a given repository (based on the URL) already exists locally. If it exists, it updates the repository by pulling the latest changes. If it doesn't exist, it clones the repository from the given URL.

3. **Hash a Local File:** It calculates a hash (a fingerprint) of the local file's contents to check whether it has been modified.

4. **Compare the Hash:** It checks whether a previously saved hash (in a hash.txt file) exists. If it does, the new hash is compared to the saved one to determine if the file has changed.

5. **Skip or Proceed:** If the file hasn't changed (i.e., the current hash matches the previous one), it skips any further action. If the file has changed, it proceeds.

6. **Update the Hash File:** It writes the new hash into a hash.txt file in the repository.

7. **Commit Changes to Git:** Finally, it navigates to the repository directory, then adds, commits, and pushes the new or modified file and the updated hash to the Git repository.

In [None]:
def handle_local_file(repo_url: str, local_file_name: str, hash_file_name: str):
    """
    Clones or pulls a Git repository, checks if a local file has changed by comparing its hash with a stored hash,
    and if the file has been modified, pushes the updated file and hash to the repository.
    
    :param repo_url: The URL of the Git repository.
    :param local_file_name: The local file whose hash will be checked for modifications.
    :param hash_file_name: The file in the repo that stores the previous hash of the local file.
    """
    
    # Step 1: Resolve the local file path, then navigate to the home directory to clone or pull the repo
    local_file_path = os.path.abspath(local_file_name)  # Resolve against the notebook's working directory before any chdir
    home_directory = os.path.expanduser("~")  # Get the user's home directory
    os.chdir(home_directory)  # Change the working directory to the home directory
    
    # Extract the repository name from the repo URL (assumes .git format at the end)
    repo_name = repo_url.split('/')[-1].replace('.git', '')
    
    # Check if the repository already exists locally
    if os.path.exists(repo_name): 
        # If the repo exists, navigate into it and pull the latest changes
        os.chdir(repo_name)
        subprocess.run(["git", "pull"])  # Pull latest changes from the remote repo
        os.chdir("..")  # Go back to the home directory after pulling changes
    else:
        # If the repo doesn't exist, clone it from the provided URL
        subprocess.run(["git", "clone", repo_url])  # Clone the repository
    
    # Get the full path to the local repository
    local_repo_path = os.path.join(os.getcwd(), repo_name)
    print(f"Using repository at {local_repo_path}")  # Print the path to the repo
    
    # Step 2: Hash the contents of the local file
    current_hash = hash_file(local_file_path)  # Hash the local file
    
    # Step 3: Check if the hash file exists in the repository and compare hashes
    hash_file_path = os.path.join(local_repo_path, hash_file_name)  # Path to the hash file in the repo
    
    if os.path.exists(hash_file_path):
        # If the hash file exists, read the previous hash
        with open(hash_file_path, 'r') as f:
            previous_hash = f.read().strip()  # Strip any extra whitespace
    else:
        previous_hash = ""  # If the hash file doesn't exist, assume no previous hash
    
    # Step 4: Compare the current hash with the previous hash
    if current_hash == previous_hash:
        # If the hashes match, the file hasn't changed
        print("File has not changed, skipping submission.")
        return
    else:
        # If the hashes differ, the file has changed, so proceed with updating the repo
        print("File has changed, proceeding with submission.")
    
    # Step 5: Update the hash file in the repository with the new hash
    with open(hash_file_path, 'w') as f:
        f.write(current_hash)  # Write the new hash to the file
    
    # Step 6: Copy the local file into the repository so it is committed alongside the new hash
    repo_file_path = os.path.join(local_repo_path, os.path.basename(local_file_path))
    if os.path.abspath(repo_file_path) != local_file_path:  # Skip if the file already lives in the repo
        shutil.copy(local_file_path, repo_file_path)
    
    # Step 7: Commit the changes (the updated file and the new hash) to the Git repository
    os.chdir(local_repo_path)  # Change directory to the repo
    add_and_commit_files("added local articles file to the repository")  # Stage, commit, and push the changes
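The change-detection core of `handle_local_file` (Steps 2 to 5) can be isolated into a small, testable function. `file_changed` is a hypothetical helper, not part of the notebook; it skips the Git operations and works in a temporary directory so it can run anywhere.

```python
import hashlib
import os
import tempfile


def file_changed(file_path: str, hash_path: str) -> bool:
    """
    Hypothetical helper: hash the file, compare it with the stored hash, and record
    the new hash when it differs. Returns True if the file changed or no hash was stored.
    """
    with open(file_path, "rb") as f:
        current = hashlib.sha256(f.read()).hexdigest()
    previous = ""
    if os.path.exists(hash_path):
        with open(hash_path) as f:
            previous = f.read().strip()
    if current == previous:
        return False
    with open(hash_path, "w") as f:
        f.write(current)
    return True


# Usage: the first check and any edit report a change; an unchanged file does not
workdir = tempfile.mkdtemp()
articles = os.path.join(workdir, "rag-article-list.txt")
hash_txt = os.path.join(workdir, "hash.txt")
with open(articles, "w") as f:
    f.write("PMID Gene-disease\n12345\n")
first = file_changed(articles, hash_txt)   # no stored hash yet
second = file_changed(articles, hash_txt)  # same contents
with open(articles, "a") as f:
    f.write("67890\n")
third = file_changed(articles, hash_txt)   # contents edited
print(first, second, third)  # True False True
```

Keeping the decision separate from the Git side effects makes it easy to test the skip-or-proceed logic on its own.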


In [None]:
# Example usage
#handle_local_file(
#    repo_url="https://github.com/TileDB-Inc/pmc-llm.git", 
#    local_file_name="rag-article-list.txt", 
#    hash_file_name="hash.txt"
#)

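### Ingestion Tasks for the TileDB DAG
The cells below define the tasks that run inside the DAG: `consolidate_chunks` merges the per-job history logs, `delete_unwanted_files` removes files that are not listed in the consolidated history, and `pmid_ingestion` downloads an article PDF or abstract for each PMID.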

In [None]:
def consolidate_chunks(total_jobs, bucket_region, object_path, prefix):
    """
    Consolidates a list of chunked files into a single file in S3 and deletes the chunked files afterward.

    :param total_jobs: Number of chunked files to process
    :param bucket_region: The S3 bucket region
    :param object_path: The S3 path to the files
    :param prefix: The prefix of the chunked files
    """
    import tiledb
    
    # Create a TileDB context with the specified S3 region
    ctx = tiledb.Ctx({"vfs.s3.region": bucket_region})
    # Initialize the TileDB VFS (Virtual File System) to interact with S3
    vfs = tiledb.VFS(ctx=ctx)
    
    # Initialize tmp_data as an empty bytes object to store the consolidated data
    tmp_data = b""  # Use bytes since the files are being read in binary mode
    
    # Track which chunk files were read so they can be deleted after the consolidated file is written
    found_chunks = []
    
    # Loop through all the chunked files and concatenate them
    for i in range(total_jobs):
        # Construct the file path for each chunk
        file_path = f"{object_path}/history/{prefix}_{i}"
        
        # Check if the file exists in the S3 bucket
        if vfs.is_file(file_path):
            # If the file exists, open it in binary read mode and append its contents
            with vfs.open(file_path, 'rb') as f:
                tmp_data += f.read()  # Read and concatenate the binary content
            found_chunks.append(file_path)
        else:
            # If the file doesn't exist, print a message and continue to the next file
            print(f"No such file path {file_path}. Moving on to the next file.")
    
    # Now `tmp_data` contains the concatenated data from all chunked files
    # Define the path for the consolidated file in the S3 bucket
    consolidated_file_path = f"{object_path}/history/consolidated_{prefix}"
    
    # Open the consolidated file in binary write mode and write the consolidated data
    with vfs.open(consolidated_file_path, 'wb') as f:
        f.write(tmp_data)  # Write the concatenated binary data to the new file
    
    # Delete the chunk files only after the consolidated file has been written,
    # so a failed write does not lose the per-job history
    for file_path in found_chunks:
        vfs.remove_file(file_path)
        print(f"Deleted chunk file: {file_path}")
    
    # Print a success message with the path of the consolidated file
    print(f"Consolidation complete. Consolidated file uploaded to: {consolidated_file_path}")
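The consolidation logic can be exercised locally without S3. This sketch uses the local filesystem in place of TileDB VFS (an assumption made so it runs anywhere); `consolidate_local_chunks` is hypothetical. It reads every numbered chunk that exists, writes the consolidated file, and deletes the chunks only after that write succeeds.

```python
import os
import tempfile
from typing import List


def consolidate_local_chunks(directory: str, prefix: str, total_jobs: int) -> str:
    """Hypothetical local sketch: merge {prefix}_0..{prefix}_{total_jobs-1} into consolidated_{prefix}."""
    found: List[str] = []
    data = b""
    for i in range(total_jobs):
        path = os.path.join(directory, f"{prefix}_{i}")
        if os.path.isfile(path):  # missing chunks (jobs that logged nothing) are skipped
            with open(path, "rb") as f:
                data += f.read()
            found.append(path)
    out_path = os.path.join(directory, f"consolidated_{prefix}")
    with open(out_path, "wb") as f:
        f.write(data)
    # Delete only after the consolidated file exists
    for path in found:
        os.remove(path)
    return out_path


# Usage: two jobs wrote history chunks, a third job wrote nothing
workdir = tempfile.mkdtemp()
for i, names in enumerate([["PMC1.pdf"], ["2_abstract.txt", "PMC3.pdf"]]):
    with open(os.path.join(workdir, f"file_history_{i}"), "w") as f:
        f.write("\n".join(names) + "\n")
out = consolidate_local_chunks(workdir, "file_history", total_jobs=3)
with open(out) as f:
    consolidated = f.read().splitlines()
print(consolidated)  # ['PMC1.pdf', '2_abstract.txt', 'PMC3.pdf']
```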


In [None]:
def delete_unwanted_files(bucket_region, bucket_path):
    """
    Deletes files in an S3 bucket that are not listed in the 'consolidated_file_history' file.
    
    :param bucket_region: The S3 bucket region.
    :param bucket_path: The S3 path to the bucket.
    """
    import tiledb
    import os
    
    # Initialize TileDB context and VFS with the specified S3 region
    ctx = tiledb.Ctx({"vfs.s3.region": bucket_region})
    vfs = tiledb.VFS(ctx=ctx)
    
    # Path to the 'consolidated_file_history' file in S3
    consolidated_history_path = os.path.join(bucket_path, "history/consolidated_file_history")
    
    # Step 1: Load the list of valid files from 'consolidated_file_history'
    valid_files = set()  # Initialize an empty set to store valid file names
    if vfs.is_file(consolidated_history_path):
        # If 'consolidated_file_history' exists, open and read it
        with vfs.open(consolidated_history_path, 'rb') as f:
            # Read each line, decode from bytes to string, and strip newline characters
            valid_files = set(line.decode('utf-8').strip() for line in f.read().splitlines())
    else:
        # Raise an error if the 'consolidated_file_history' file is not found
        raise FileNotFoundError(f"{consolidated_history_path} not found.")
    
    # Step 2: List all files in the S3 bucket
    all_files = []  # Initialize a list to store all file paths
    if vfs.is_dir(bucket_path):
        # List the bucket contents and keep only objects (skip sub-directories such as history/)
        for file in vfs.ls(bucket_path):
            if vfs.is_file(file):
                all_files.append(file)
    
    print("Begin the search!")
    
    # Step 3: Delete unwanted files
    for file_path in all_files:
        # Extract the file name from the full path
        file_name = os.path.basename(file_path)
        
        # Skip the 'consolidated_file_history' file to avoid deleting it
        if file_name == "consolidated_file_history":
            print(f"Skipping deletion of {file_name}, as it is the consolidated file.")
            continue  # Skip further checks for this file
        
        # Check if the file name is NOT in the consolidated list, delete if not
        if file_name not in valid_files:
            print(f"Deleting {file_path} because it's not in consolidated_file_history.")
            vfs.remove_file(file_path)  # Remove the file from S3
        else:
            print(f"Keeping {file_path}, it's in consolidated_file_history.")
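The keep-or-delete rule in `delete_unwanted_files` can be expressed as a pure function, which makes a dry run possible before anything is removed. `files_to_delete` and the example S3 paths below are hypothetical; in the notebook, the resulting list would be passed to `vfs.remove_file`.

```python
import os
from typing import Iterable, List, Set


def files_to_delete(all_paths: Iterable[str], valid_names: Set[str],
                    history_name: str = "consolidated_file_history") -> List[str]:
    """Return paths whose basename is neither the history file itself nor listed in the history."""
    doomed = []
    for path in all_paths:
        name = os.path.basename(path)
        if name != history_name and name not in valid_names:
            doomed.append(path)
    return doomed


# Usage: parse a history file the same way the notebook does, then do a dry run
history_text = "Run launched at 2024-01-01 00:00:00\nPMC1.pdf\n2_abstract.txt\n"
valid = {line.strip() for line in history_text.splitlines()}
paths = [
    "s3://example-bucket/pmc/PMC1.pdf",
    "s3://example-bucket/pmc/PMC9.pdf",
    "s3://example-bucket/pmc/2_abstract.txt",
]
for path in files_to_delete(paths, valid):
    print(f"Would delete {path}")
```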


In [None]:
def pmid_ingestion(job_df, object_directory: str, bucket_region: str, job_id: int, index_uri: str, embedding_model_name: str):
    """
    Main function to handle the ingestion of articles using PubMed API and TileDB VFS.
    
    :param job_df: DataFrame containing PubMed PMIDs and gene-disease data.
    :param object_directory: Path to the S3 bucket for file storage.
    :param bucket_region: Region of the S3 bucket.
    :param job_id: ID of the job to track file history.
    :param index_uri: TileDB index URI where the articles will be stored.
    :param embedding_model_name: Name of the embedding model to store as metadata in TileDB.
    """
    # Imports are kept inside the function so the task carries them when it runs remotely in the DAG
    import glob
    import os
    import tarfile
    import time
    import urllib.request
    import xml.etree.ElementTree as ET
    from datetime import datetime

    import requests
    import tiledb
    from langchain_community.embeddings import HuggingFaceEmbeddings

    now = datetime.now()
    ctx = tiledb.Ctx({"vfs.s3.region": bucket_region})
    vfs = tiledb.VFS(ctx=ctx)

    try:
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    except Exception as e:
        raise RuntimeError(f"Error initializing HuggingFace embeddings with model {embedding_model_name}: {e}")

    # Initialize the file history log for the job
    new_file_path = f"{object_directory}/history/file_history_{job_id}"
    missed_file_path = f"{object_directory}/history/missed_paper_history_{job_id}"

    with vfs.open(new_file_path, 'wb') as f:
        log_entry = f"Run launched at {now.strftime('%Y-%m-%d %H:%M:%S')}\n"
        f.write(log_entry.encode('utf-8'))
        
    with vfs.open(missed_file_path, 'wb') as f:
        log_entry = f"Run launched at {now.strftime('%Y-%m-%d %H:%M:%S')}\n"
        f.write(log_entry.encode('utf-8'))        

    start_time = time.time()

    def convert_pmid_to_pmcid_and_ingest(pmid):
        """Convert PMID to PMCID and attempt to download the article or abstract."""
        url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi"
        # Replace tool/email with your own values; NCBI asks E-utilities users to identify themselves
        params = {"db": "pubmed", "id": pmid, "retmode": "json", "tool": "your_tool_name", "email": "your_email@example.com"}
        max_retries = 5
        retry_delay = 5

        for attempt in range(max_retries):
            try:
                response = requests.get(url, params=params)
                if response.status_code == 200:
                    data = response.json()
                    pmid_str = str(pmid)

                    # Check if the result contains the expected data
                    if "result" in data and pmid_str in data["result"]:
                        result = data["result"][pmid_str]
                        pmcid = None

                        # Check if a PMCID is available in the article IDs
                        for article_id in result.get('articleids', []):
                            if article_id['idtype'] == 'pmc':
                                pmcid = article_id['value']
                                if download_pmc_article(pmcid):
                                    add_to_new_file(f"{pmcid}.pdf", "file_history")
                                    return  # Exit the function once successfully processed
                                else:
                                    add_to_new_file(f"{pmcid}.pdf", "missed_paper_history")
                                    return  # Exit after processing (even if unsuccessful)

                        # If no PMCID, try fetching the abstract
                        if fetch_abstract(pmid):
                            add_to_new_file(f"{pmid}_abstract.txt", "file_history")
                            return  # Exit after successfully fetching the abstract
                        else:
                            add_to_new_file(f"{pmid}_abstract.txt", "missed_paper_history")
                            return  # Exit after failing to fetch the abstract

                    # If no valid result is found in the response, break the loop and exit
                    print(f"No valid result found for PMID {pmid}. Exiting.")
                    return

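                elif response.status_code == 429:
                    print(f"Rate limit hit for PMID {pmid}, retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                    retry_delay *= 2
                else:
                    # Back off on other HTTP errors too, instead of retrying immediately
                    print(f"HTTP {response.status_code} for PMID {pmid}, retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                    retry_delay *= 2

            except requests.exceptions.RequestException as e:
                print(f"Request error for PMID {pmid} on attempt {attempt + 1}: {e}. Retrying...")
                time.sleep(retry_delay)
                retry_delay *= 2

        print(f"Failed to retrieve data for PMID {pmid} after {max_retries} attempts.")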

    def download_pmc_article(pmcid):
        """
        Download the article package by PMCID and upload the PDF using TileDB VFS.
        Returns True if the PDF is stored (now or in an earlier run), otherwise False.
        """
        file_name = os.path.join(object_directory, f"{pmcid}.pdf")
        
        # Skip the download if the PDF was already ingested
        if vfs.is_file(file_name):
            print(f"File {pmcid}.pdf already exists.")
            return True

        print(f"File {pmcid}.pdf does not exist, downloading...")
        url = "https://www.ncbi.nlm.nih.gov/pmc/utils/oa/oa.fcgi"
        max_retries = 5
        retry_delay = 5

        for attempt in range(max_retries):
            try:
                # requests (rather than urllib) exposes status_code and raises RequestException subclasses
                response = requests.get(url, params={"id": pmcid})
                if response.status_code == 200:
                    root = ET.fromstring(response.content)
                    # The OA service answers with an <error> element, e.g. when the article is not open access
                    if root.find('.//error') is not None:
                        return False
                    # Download and extract the article package
                    return download_and_extract_article(root, pmcid)
                print(f"HTTP {response.status_code} for PMCID {pmcid}, retrying in {retry_delay} seconds...")
            except requests.exceptions.RequestException as e:
                print(f"Request error for PMCID {pmcid} on attempt {attempt + 1}: {e}")
            time.sleep(retry_delay)
            retry_delay *= 2

        # Every attempt failed, so report the article as missed instead of ingested
        return False
    
    def download_and_extract_article(root, pmcid):
        """
        Download and extract the article tarball, then upload its PDF using TileDB VFS.
        
        :param root: XML root element.
        :param pmcid: PubMed Central ID.
        :return: True if a PDF was uploaded, otherwise False.
        """
        import shutil
        records = root.find('records')
        record = records.find(f'record[@id="{pmcid}"]') if records is not None else None
        link = record.find('link[@format="tgz"]') if record is not None else None
        if link is None:
            print(f"No tgz package found for PMCID {pmcid}.")
            return False

        tar_url = link.get('href')
        tar_file_name = f"{pmcid}.tar.gz"

        print(f"Downloading {tar_url}...")
        urllib.request.urlretrieve(tar_url, tar_file_name)

        print(f"Extracting {tar_file_name}...")
        with tarfile.open(tar_file_name, 'r:gz') as tar:
            tar.extractall(pmcid)
        os.remove(tar_file_name)

        try:
            # Search the extracted package without changing the working directory
            pdf_files = glob.glob(os.path.join(pmcid, "**", "*.pdf"), recursive=True)
            if not pdf_files:
                print(f"No PDF found in the package for {pmcid}")
                return False

            # A package may contain several PDFs (e.g. supplements); the first match is used
            pdf_file = pdf_files[0]
            print(f"Found PDF: {pdf_file}")

            with open(pdf_file, 'rb') as local_pdf_file:
                pdf_content = local_pdf_file.read()

            with vfs.open(os.path.join(object_directory, f"{pmcid}.pdf"), 'wb') as vfs_pdf_file:
                vfs_pdf_file.write(pdf_content)

            print(f"PDF {pdf_file} successfully written to TileDB VFS.")
            return True
        finally:
            # Remove the extracted directory whether or not a PDF was found
            shutil.rmtree(pmcid, ignore_errors=True)
    
    def fetch_abstract(pmid):
        """Fetch the abstract from PubMed by PMID."""
        url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
        params = {'db': 'pubmed', 'id': pmid, 'retmode': 'xml'}
        file_name = os.path.join(object_directory, f"{pmid}_abstract.txt")

        if vfs.is_file(file_name):
            print(f"Abstract {pmid}_abstract.txt already exists.")
            return True
        print(f"Fetching abstract for PMID {pmid}...")
        max_retries = 5
        retry_delay = 5
        for attempt in range(max_retries):
            try:
                response = requests.get(url, params=params)
                if response.status_code == 200:
                    root = ET.fromstring(response.content)
                    # itertext() also collects text inside nested markup (e.g. <i>) and never yields None
                    abstract_text = "\n".join(
                        "".join(abstract.itertext()) for abstract in root.findall(".//AbstractText")
                    )
                    if not abstract_text.strip():
                        print(f"No abstract available for PMID {pmid}.")
                        return False
                    with vfs.open(file_name, 'wb') as file:
                        file.write(abstract_text.encode('UTF-8'))
                    print(f"Abstract saved to {file_name}.")
                    return True

                if response.status_code == 429:
                    print(f"Rate limit hit for PMID {pmid}, retrying in {retry_delay} seconds...")
                else:
                    print(f"HTTP {response.status_code} for PMID {pmid}, retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
                retry_delay *= 2
            except requests.exceptions.RequestException as e:
                print(f"Network error on attempt {attempt + 1}: {e}")
                time.sleep(retry_delay)
                retry_delay *= 2
        print(f"Failed to retrieve abstract for PMID {pmid} after {max_retries} attempts.")
        return False

    # Add to file log function
    def add_to_new_file(file_name, prefix):
        """Append the name of the processed file to the file history log in S3."""
        new_file_path = f"{object_directory}/history/{prefix}_{job_id}"
        tmp_data = b""
        try:
            if vfs.is_file(new_file_path):
                with vfs.open(new_file_path, 'rb') as f:
                    tmp_data = f.read()

            file_name_as_bytes = (file_name + '\n').encode('utf-8')
            tmp_data += file_name_as_bytes

            with vfs.open(new_file_path, 'wb') as f:
                f.write(tmp_data)
        except Exception as e:
            print(f"Error appending to file history: {e}")

    # Processing each entry in the DataFrame
    # Drop rows without a PMID (reassign instead of inplace so the caller's DataFrame is not modified)
    job_df = job_df.dropna(subset=['PMID Gene-disease'])
    total_entries = len(job_df)  
    for index, row in job_df.iterrows():
        try:
            pmid = str(int(row["PMID Gene-disease"]))
            convert_pmid_to_pmcid_and_ingest(pmid)
        except ValueError as e:
            continue

    total_time = time.time() - start_time
    print(f"Processed {total_entries} entries in {total_time:.2f} seconds.")
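`pmid_ingestion` repeats the same retry loop (5 attempts, with a delay that starts at 5 seconds and doubles) in three places. One way to factor it out is a small exponential-backoff helper. `with_backoff` is hypothetical, and the convention that `call()` returns `None` for a retryable failure is an assumption of this sketch. The `sleep` parameter is injectable so the behavior can be checked without waiting.

```python
import time
from typing import Callable, List, Optional, TypeVar

T = TypeVar("T")


def with_backoff(call: Callable[[], Optional[T]], max_retries: int = 5, initial_delay: float = 5.0,
                 sleep: Callable[[float], None] = time.sleep) -> Optional[T]:
    """
    Hypothetical helper: run call() up to max_retries times. call() returns a result on
    success or None for a retryable failure; the delay doubles after each failure.
    """
    delay = initial_delay
    for attempt in range(max_retries):
        result = call()
        if result is not None:
            return result
        if attempt < max_retries - 1:  # no point sleeping after the last attempt
            sleep(delay)
            delay *= 2
    return None


# Usage: a fake request that fails twice (e.g. HTTP 429) and then succeeds.
# Passing list.append as `sleep` records the delays instead of waiting.
responses = iter([None, None, "ok"])
delays: List[float] = []
result = with_backoff(lambda: next(responses), sleep=delays.append)
print(result, delays)  # ok [5.0, 10.0]
```

Unlike the loops in the notebook, this helper does not sleep after the final attempt, since no retry follows.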


### Pipeline Factory Function
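The `initialize_step` will build out our DAG, starting with the ingestion steps. The `pipeline_step()` function below pulls the repo, checks the input file for changes using a hash, validates the index type and ingestion mode, and divides the input file into jobs for batch processing.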
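The job-splitting step (Step 7 in `pipeline_step`) is not shown in this section. A plausible approach, sketched here as an assumption rather than the notebook's actual code, is to cut the row range into `num_jobs` contiguous slices of `ceil(total_rows / num_jobs)` rows; each `(start, end)` pair could then select `df.iloc[start:end]` as the `job_df` for one `pmid_ingestion` task. `job_ranges` is a hypothetical name.

```python
import math
from typing import List, Tuple


def job_ranges(total_rows: int, num_jobs: int) -> List[Tuple[int, int]]:
    """
    Hypothetical sketch: split rows [0, total_rows) into at most num_jobs contiguous
    (start, end) ranges of ceil(total_rows / num_jobs) rows each.
    """
    if total_rows <= 0 or num_jobs <= 0:
        return []
    size = math.ceil(total_rows / num_jobs)
    return [(start, min(start + size, total_rows)) for start in range(0, total_rows, size)]


# Usage: 10 rows over 4 jobs -> slices of 3, 3, 3 and 1 rows
ranges = job_ranges(10, 4)
print(ranges)  # [(0, 3), (3, 6), (6, 9), (9, 10)]
```

With fewer rows than jobs, fewer ranges come back, so the DAG would simply launch fewer tasks.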

In [None]:
def pipeline_step(access_credentials: str, repo_url: str, out_directory: str, bucket_path: str, index_uri: str,
                  in_file: str = "rag-article-list.txt", num_jobs: int = 4, bucket_region: str = "us-west-2",
                  hash_check: bool = True, model_name: str = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb",
                  index_type: str = "IVF_FLAT", dimensions: int = 768, ingestion_mode: str = "BATCH"):
    """
    Pipeline step to pull the repo, check for changes using a hash, and divide the input file into jobs for batch processing.

    :param access_credentials: Credentials to access TileDB Cloud for DAG execution.
    :param repo_url: Git repository URL to clone or pull the latest version from.
    :param out_directory: Output directory where the files will be stored.
    :param bucket_path: S3 bucket path for file storage.
    :param index_uri: TileDB index URI where the embeddings are stored.
    :param in_file: Tab-separated input file name, defaulting to 'rag-article-list.txt'.
    :param num_jobs: Number of jobs to divide the input file into, default is 4.
    :param bucket_region: S3 bucket region, default is 'us-west-2'.
    :param hash_check: If True, the function checks if the file has changed using a hash before proceeding.
    :param model_name: Name of the embedding model to store in TileDB metadata.
    :param index_type: Type of TileDB index: one of "FLAT", "IVF_FLAT", "IVF_PQ", or "VAMANA".
    :param dimensions: The dimensionality of the embeddings; must match the output size of model_name.
    :param ingestion_mode: Either "BATCH" or "REALTIME".
    """

    # Supported index types
    supported_index_types = ["FLAT", "IVF_FLAT", "IVF_PQ", "VAMANA"]    
    # Validate index type
    if index_type not in supported_index_types:
        raise ValueError(f"Unsupported index type '{index_type}'. Supported types are: {supported_index_types}")

    # Supported ingestion modes
    supported_ingestion_modes = ["BATCH", "REALTIME"]
    if ingestion_mode not in supported_ingestion_modes:
        raise ValueError(f"Unsupported ingestion_mode '{ingestion_mode}'. Supported modes are: {supported_ingestion_modes}")
    
    # Step 1: Prepare paths
    full_bucket_path = f"{bucket_path}/{out_directory}"  # Path to the S3 bucket

    # Step 2: Pull the latest changes from the repository
    home_directory = os.path.expanduser("~")
    os.chdir(home_directory)
    
    repo_name = repo_url.split('/')[-1].replace('.git', '')  # Extract repo name from URL

    if os.path.exists(repo_name):
        os.chdir(repo_name)
        subprocess.run(["git", "pull"])  # Pull the latest changes
    else:
        subprocess.run(["git", "clone", repo_url])  # Clone the repo if it doesn't exist
        os.chdir(repo_name)

    # Step 3: Check and handle file hash
    previous_hash_path = "previous_hash.txt"
    current_hash_path = "hash.txt"
    current_hash = ""

    if os.path.exists(current_hash_path):
        with open(current_hash_path, 'r') as f:
            current_hash = f.read().strip()
            print(f"Current hash: {current_hash}")
    else:
        if os.path.exists(in_file):
            current_hash = hash_file(in_file)  # Compute the hash of the input file
            with open(current_hash_path, 'w') as fh:
                fh.write(current_hash)
        else:
            print(f"{in_file} does NOT exist. Please submit a valid file and rerun this function.")
            return

    # Step 4: Compare hash with previous hash
    if os.path.exists(previous_hash_path):
        with open(previous_hash_path, 'r') as f:
            previous_hash = f.read().strip()
            if current_hash == previous_hash:
                print("File did not change.")
                if hash_check:
                    return  # Stop if hash_check is True and no changes
            else:
                print("File has changed, proceeding with processing.")
                with open(previous_hash_path, 'w') as fh:
                    fh.write(current_hash)
    else:
        with open(previous_hash_path, 'w') as fh:
            fh.write(current_hash)

    # Step 5: Commit the updated hash and file to the Git repository
    add_and_commit_files("Updating the hash and file to keep the repo up to date")  # Commit changes

    # Step 6: Load the DataFrame from the input file
    df = pd.read_csv(in_file, sep='\t', engine='python')  # Load input file into a DataFrame
    total_rows = len(df)
    print(f"Total rows in the DataFrame: {total_rows}")

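    # Step 7: Divide the DataFrame into jobs based on num_jobs
    num_jobs = max(1, int(num_jobs))  # Gradio number inputs may arrive as floats
    if total_rows == 0:
        print(f"{in_file} contains no rows; nothing to ingest.")
        return
    if num_jobs > total_rows:
        num_jobs = total_rows
        print(f"Number of jobs reduced to {num_jobs} to match total rows.")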

    job_size = math.ceil(total_rows / num_jobs)
    jobs = [df.iloc[i:i + job_size] for i in range(0, total_rows, job_size)]
    print(f"Divided into {len(jobs)} jobs")

    # Step 8: Set up TileDB Cloud DAG for batch processing
    dag = tiledb.cloud.dag.DAG(name="document_batch", mode=tiledb.cloud.dag.Mode.BATCH)

    # Consolidation and deletion tasks
    consolidate_node = dag.submit(consolidate_chunks, num_jobs, bucket_region, full_bucket_path, "file_history", access_credentials_name=access_credentials)
    consolidate_node_2 = dag.submit(consolidate_chunks, num_jobs, bucket_region, full_bucket_path, "missed_paper_history", access_credentials_name=access_credentials)
    delete_unwanted_node = dag.submit(delete_unwanted_files, bucket_region, full_bucket_path, access_credentials_name=access_credentials)
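    delete_unwanted_node.depends_on(consolidate_node)
    delete_unwanted_node.depends_on(consolidate_node_2)  # Clean up only after both histories are consolidated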

    # Step 9: Submit each job for ingestion
    for i, job_df in enumerate(jobs):
        if job_df.empty:
            print("Skipping empty job.")
            continue
        cpu, mem = estimate_resources(job_df)
        print(f"Processing {len(job_df)} rows, estimated CPU: {cpu}, estimated memory: {mem:.2f} GB.")

        ingest_node = dag.submit(
            pmid_ingestion,
            job_df,
            full_bucket_path,
            bucket_region,
            i,
            index_uri,
            model_name,
            access_credentials_name=access_credentials,
            image_name="vectorsearch",
            resources={"cpu": f"{cpu}", "memory": f"{mem}Gi"}
        )
        consolidate_node.depends_on(ingest_node)
        consolidate_node_2.depends_on(ingest_node)

    dag.compute()
    dag.wait()
    print(f"Finished processing all jobs, consolidating history, and cleaning up undesired S3 files at {full_bucket_path}")

    # Step 10: Index creation and embedding ingestion
    embedding = SentenceTransformersEmbedding(model_name_or_path=model_name, dimensions=dimensions)
    reader = DirectoryTextReader(
        search_uri=full_bucket_path,
        include="*",  
        suffixes=[".pdf", ".txt"],
        exclude=["[.]*", "*/[.]*"],  
        text_splitter="RecursiveCharacterTextSplitter",
        text_splitter_kwargs={"chunk_size": 1000, "chunk_overlap": 100}
    )

    # Clean the existing index
    print("Ensuring we have a clean index")
    if tiledb.object_type(index_uri) == "group":
        print("Deleting existing index...")
        group = tiledb.Group(index_uri, "m")
        group.delete(recursive=True)
        print("Deleted old index")

    print("Creating a new index")
    index = object_index.create(
        uri=index_uri,
        index_type=index_type,
        object_reader=reader,
        embedding=embedding,
    )
    if ingestion_mode == "BATCH":
        embeddings_generation_driver_mode = tiledb.cloud.dag.Mode.BATCH
    elif ingestion_mode == "REALTIME":
        embeddings_generation_driver_mode = tiledb.cloud.dag.Mode.REALTIME

    print("Kicking off the ingestion")
    index.update_index(
        embeddings_generation_driver_mode=embeddings_generation_driver_mode,
        embeddings_generation_mode=embeddings_generation_driver_mode,  # Workers run in the same mode as the driver
        files_per_partition=50,
        extra_worker_modules=["transformers==4.37.1", "PyMuPDF", "beautifulsoup4"],
        embedding=embedding,
        worker_access_credentials_name=access_credentials,
        workers=num_jobs,  # May need to be removed for REALTIME mode; the worker count may be capped (possibly at 15)
        worker_resources={"cpu": f"{cpu}", "memory": f"{mem}Gi"}
    )
    
    return f"Pipeline completed with file: {in_file}"




In [None]:
!pip install sentence-transformers

### TileDB Ingestion Pipeline
The pipeline can take a while to run, depending on the size of the ingestion. While it runs, let's walk through what it does; read the code above for the full details.

#### Key Features and Value Highlights of the Ingestion Pipeline Demo

The ingestion pipeline has two major parts:
1. **PMC/PMCID Ingestion**: Scraping PubMed for documents using PMIDs to gather relevant medical articles and their abstracts.
2. **Multi-Modal TileDB Ingestion**: Ingesting the processed documents and data into TileDB, creating an index that allows for fast and scalable searches across multiple modalities (e.g., PDFs, text).

This function highlights several key elements:

#### Credential Management
One of the core benefits of the pipeline is how credentials and distributed tasks are abstracted across the entire process, simplifying the implementation and making it scalable.

The `access_credentials` parameter allows for secure, centralized management of credentials, removing the need for hard-coded keys in scripts. This enables TileDB Cloud to handle requests automatically and securely.

```python
def pipeline_step(access_credentials: str, repo_url: str, out_directory: str, bucket_path: str, index_uri: str, in_file: str = "rag-article-list.txt", 
                  num_jobs: int = 4, bucket_region: str = "us-west-2", hash_check: bool = True, model_name: str = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb",
                  index_type: str = "IVF_FLAT", dimensions: int = 768, ingestion_mode: str = "BATCH"):

```

#### Scalable Vector Search and Indexing

TileDB’s vector search capabilities allow for efficient search and indexing across datasets. The pipeline supports several index types, such as **FLAT**, **IVF_FLAT**, **VAMANA**, and **IVF_PQ**, providing flexibility to optimize for either speed or memory efficiency.

Toggling between **batch (BATCH)** and **real-time (REALTIME)** ingestion modes lets you match the workload: batch mode maximizes throughput, while real-time mode enables rapid updates to the index.

#### Unified Ingestion for Multi-Modal Data

The pipeline supports **multi-modal data ingestion**, meaning it can handle PDFs, text files, and abstracts using the same loader. This eliminates the need for specialized logic for each data type.

In cases where documents are stored in compressed tarballs, the pipeline is capable of **recursive extraction** of PDFs, making it ideal for complex medical or academic data ingestion scenarios.
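
A minimal sketch of recursive PDF extraction from nested tarballs, using only Python's standard `tarfile` module. The function name `extract_pdfs` is hypothetical and is not the pipeline's actual helper; the sketch assumes each archive fits in memory and treats `.tar`, `.tar.gz`, and `.tgz` members as nested archives.

```python
import io
import tarfile


def extract_pdfs(archive_bytes: bytes, prefix: str = "") -> dict:
    """Return {path: pdf_bytes} for every PDF in the archive, descending into nested tarballs."""
    pdfs = {}
    # Mode "r:*" lets tarfile detect gzip, bz2, or xz compression automatically
    with tarfile.open(fileobj=io.BytesIO(archive_bytes), mode="r:*") as tar:
        for member in tar.getmembers():
            if not member.isfile():
                continue  # skip directories, links, and other non-regular members
            data = tar.extractfile(member).read()
            path = f"{prefix}{member.name}"
            name = member.name.lower()
            if name.endswith(".pdf"):
                pdfs[path] = data
            elif name.endswith((".tar", ".tar.gz", ".tgz")):
                # Recurse into the nested archive, keeping its path as a prefix
                pdfs.update(extract_pdfs(data, prefix=f"{path}/"))
    return pdfs
```

Keying results by the full nested path (for example `inner.tar.gz/b.pdf`) keeps the origin of each PDF traceable, which helps when the file history is later used for debugging.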

#### Seamless Integration with TileDB Cloud DAG

TileDB’s **Cloud DAG** allows you to manage complex workflows with distributed jobs like consolidating chunks and data cleanup. The DAG framework ensures tasks are executed in the correct order while supporting parallel processing and efficient scaling.

This feature simplifies managing dependencies between tasks, such as ensuring that data cleanup only occurs after ingestion has been completed.
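
The ordering the DAG enforces can be modeled locally with the standard library's `graphlib.TopologicalSorter`. This is only a sketch of the dependency shape, not TileDB Cloud's scheduler: the node names are illustrative, every `ingest_i` node must finish before both consolidation nodes, and this sketch makes cleanup wait for both consolidation nodes.

```python
from graphlib import TopologicalSorter


def build_ingestion_graph(num_jobs: int) -> dict:
    """Map each node to the set of nodes it depends on."""
    ingest = {f"ingest_{i}" for i in range(num_jobs)}
    graph = {name: set() for name in ingest}
    graph["consolidate_file_history"] = set(ingest)
    graph["consolidate_missed_paper_history"] = set(ingest)
    graph["delete_unwanted_files"] = {"consolidate_file_history", "consolidate_missed_paper_history"}
    return graph


def execution_waves(graph: dict) -> list:
    """Group nodes into waves; nodes in the same wave have no mutual dependencies and could run in parallel."""
    sorter = TopologicalSorter(graph)
    sorter.prepare()
    waves = []
    while sorter.is_active():
        ready = sorted(sorter.get_ready())
        waves.append(ready)
        sorter.done(*ready)
    return waves
```

`execution_waves(build_ingestion_graph(3))` yields three waves: the three ingest nodes, then the two consolidation nodes, then `delete_unwanted_files`.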

#### Efficient Use of TileDB's Virtual File System (VFS)

The pipeline leverages **TileDB's Virtual File System (VFS)** to interact directly with object stores like S3, without the need for local storage. This results in faster data transfers and reduces the storage footprint during the ingestion process.

Additionally, the VFS allows tracking of file history and logs in S3, enabling easier debugging, auditing, and progress monitoring without needing to rely on local files.

#### Support for Embedding Models and Metadata Storage

The pipeline integrates with Hugging Face embedding models such as **BioBERT**, allowing embeddings to be stored alongside the original documents. This is crucial for **retrieval-augmented generation (RAG)** systems or any pipeline that requires similarity-based search.

The ability to store model metadata alongside the embeddings ensures **traceability**, allowing for reproducibility when running searches or processing documents in future versions of the pipeline.

##### Why Experimenting with an Embedding Model is Important

Experimenting with medical-specific embedding models like BioBERT is vital because it ensures the embeddings capture domain-specific context, terminology, and nuances. This precision enhances the quality of responses from LLMs when generating or retrieving information from scientific papers. By tying embeddings to the LLM's response, the system delivers more accurate, contextually relevant answers, crucial for applications in medical research or clinical insights where even slight misinterpretations can lead to incorrect conclusions.

Another way to think about this is in the context of "human in the loop" scenarios where humans label complex datasets. Just as medical embedding models require domain-specific training to understand intricate concepts, human labelers also need specialized knowledge to correctly annotate data. For example, labeling medical datasets is far more complex than identifying basic objects like "this is a bird." If a non-expert incorrectly labels a medical condition, it could misinform the model, leading to inaccurate predictions or insights. Embedding models, like humans in these tasks, need to be properly trained with the right context to ensure meaningful and accurate representations.

#### Real-Time Feedback with Logging and File History

A key strength of the pipeline is its ability to provide **real-time feedback** on the progress of document ingestion. By logging successfully ingested files and noting failed attempts (through `file_history` and `missed_paper_history`), the pipeline allows you to quickly diagnose and address issues.

This feedback loop ensures that failed files can be reprocessed without having to re-run the entire ingestion pipeline. Logs are stored directly in the object store for easy retrieval and auditing.
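
A minimal sketch of how the two history logs can drive a targeted retry. It assumes each log is plain text with one PMID per line; the actual format written by the pipeline may differ, and `parse_ids` and `ids_to_retry` are hypothetical helpers.

```python
def parse_ids(log_text: str) -> set:
    """Parse one identifier per line, ignoring blank lines and surrounding whitespace."""
    return {line.strip() for line in log_text.splitlines() if line.strip()}


def ids_to_retry(requested: list, file_history: str, missed_history: str) -> list:
    """Return requested IDs that were missed or never recorded as ingested, in input order."""
    ingested = parse_ids(file_history)
    missed = parse_ids(missed_history)
    # An ID needs another attempt if it failed, or if it never made it into file_history
    return [pmid for pmid in requested if pmid in missed or pmid not in ingested]
```

The returned list could be written out as a new, smaller input file so that only those articles are re-ingested.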

#### Flexible Configuration and Extensibility

The pipeline is highly configurable, with parameters such as:

- **Number of jobs**: To control workload distribution.
- **Bucket path**: To specify object store locations.
- **Index type**: To optimize for different search strategies (e.g., IVF_FLAT for faster searches or IVF_PQ for memory efficiency).
- **Embedding dimensions**: To configure the model’s vector size (e.g., 768 for BioBERT).

This flexibility allows the pipeline to scale based on dataset size and processing needs. Furthermore, the ability to switch between **batch** and **real-time** modes allows for dynamic processing depending on the workload.
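
Step 7 of `pipeline_step` splits the input rows into contiguous chunks of size `ceil(total_rows / num_jobs)`. The same arithmetic on a plain list (a sketch without pandas; `split_into_jobs` is a hypothetical name) shows why the last job can be smaller and why fewer jobs than requested may be created:

```python
import math


def split_into_jobs(rows: list, num_jobs: int) -> list:
    """Split rows into contiguous chunks, mirroring the ceil-based split in pipeline_step."""
    if not rows:
        return []
    num_jobs = max(1, min(int(num_jobs), len(rows)))  # never more jobs than rows
    job_size = math.ceil(len(rows) / num_jobs)
    return [rows[i:i + job_size] for i in range(0, len(rows), job_size)]
```

With 10 rows and 4 jobs the chunk sizes are 3, 3, 3, 1; with 9 rows and 4 jobs the chunk size is still 3, so only 3 jobs are produced. In other words, the number of jobs actually created (`len(jobs)`) can be smaller than `num_jobs`.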

#### Summary

This ingestion pipeline leverages **distributed abstraction**, **secure credential handling**, and **scalable vector search capabilities**, while remaining flexible for multi-modal data ingestion. Built-in features such as **change detection**, **failure tracking for reprocessing**, and **embedding model integration** make this pipeline robust and scalable, handling even the most complex medical document ingestion and processing tasks.
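
The change detection in steps 3 and 4 of `pipeline_step` compares a hash of the input file against the hash recorded on the previous run. The notebook's own `hash_file` helper is defined earlier and not shown here; the sketch below assumes a SHA-256 content hash, and `sha256_of_file` and `file_changed` are hypothetical names.

```python
import hashlib
from pathlib import Path


def sha256_of_file(path: Path, chunk_size: int = 65536) -> str:
    """SHA-256 hex digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def file_changed(path: Path, previous_hash_path: Path) -> bool:
    """Return True if the file differs from the last recorded hash, then record the new hash."""
    current = sha256_of_file(path)
    previous = previous_hash_path.read_text().strip() if previous_hash_path.exists() else None
    previous_hash_path.write_text(current)
    return current != previous
```

The first call returns `True` (no previous hash exists), an immediate second call returns `False`, and editing the file makes it return `True` again.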


## Indexing and Ingestion: Exploration and Guidelines
### Driver and Embedding Generation
For a deeper exploration: during the indexing/ingestion of the PDFs and text, two DAGs are created. One of them is the `ingest_embeddings_with_driver` task.

The `ingest_embeddings_with_driver` function acts as the driver: it orchestrates and schedules tasks to the workers (pods) by setting up and managing the task execution flow.

#### Workers and Ingestion Tasks

* **Task Distribution:** The total set of files (or partitions) is divided into smaller groups based on `object_partitions_per_worker`. These groups are then assigned to workers, with each worker processing a specific number of partitions in parallel.

* **Worker Allocation:** The number of workers (pods) is determined by the `workers` parameter. Tasks are distributed across the workers; if there are fewer tasks than workers, not every worker is used.

* **Dynamic Worker Usage:** If there are fewer partitions than workers, only the necessary workers will be used, and any excess workers will not be spun up. If there are more partitions than workers, the system will continue assigning tasks to available workers until all partitions are processed.
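
A rough sketch of the distribution described above: files are grouped into partitions of `files_per_partition` files (50 in `pipeline_step`), and partitions are dealt out to at most `workers` workers. The round-robin assignment and the `plan_workers` name are assumptions for illustration; TileDB's driver may group partitions differently (for example, via `object_partitions_per_worker`).

```python
import math


def plan_workers(num_files: int, files_per_partition: int, workers: int) -> list:
    """Return, for each worker that is actually used, the partition IDs it processes."""
    num_partitions = math.ceil(num_files / files_per_partition)
    # Only spin up as many workers as there are partitions
    active = min(workers, num_partitions)
    if active == 0:
        return []
    plan = [[] for _ in range(active)]
    for partition_id in range(num_partitions):
        plan[partition_id % active].append(partition_id)  # round-robin assignment
    return plan
```

With 120 files, 50 files per partition, and 4 workers, only 3 workers receive work (one partition each); with 500 files the 10 partitions are spread over all 4 workers.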

### Task Scheduling Guidelines 
When optimizing ingestion tasks in a distributed environment (e.g., Kubernetes), consider the following guidelines to maximize parallelism and resource efficiency:

* **Task Size:** Keep ingestion tasks small enough to fit comfortably within a pod’s memory and CPU limits. Ideally, each task should process one file or partition at a time to reduce memory footprint and avoid resource contention.

* **Memory Footprint:** Assign memory just sufficient for loading and processing a single file. This prevents over-allocation of resources while ensuring that each pod can handle its workload efficiently. Adjust memory requests/limits based on file size and processing complexity.

* **Parallelism:** Use smaller pods with fewer tasks to maximize parallelism. More, smaller pods allow Kubernetes to distribute tasks across multiple CPU cores or nodes, improving throughput.

* **Resource Allocation:** Ensure that each pod’s resource requests (CPU, memory) are tuned to its task size. Smaller resource requests give the scheduler more placement flexibility, which can reduce queueing time and speed up overall execution.

By adhering to these guidelines, you can achieve better resource utilization, faster task execution, and improved scalability across your infrastructure.

## Optional Phase: Download the consolidated file from S3 and write it to your local Git repository
The same steps can be repeated for the consolidated missed-paper history (`missed_paper_history`).

In [None]:
# Optional: Download the consolidated file from S3 and write it to your local Git repository.
# Uncomment the lines below to run this step.
# import os
# import tiledb

# Set up a TileDB context with S3 credentials (replace with your actual credentials)
# ctx = tiledb.Ctx({
#     "vfs.s3.region": "",                 # S3 region
#     "vfs.s3.aws_access_key_id": "",      # AWS access key (add your key here)
#     "vfs.s3.aws_secret_access_key": "",  # AWS secret access key (add your key here)
# })

# S3 path to the consolidated file
# consolidated_file_path = ""

# Initialize TileDB VFS to interact with S3
# vfs = tiledb.VFS(ctx=ctx)

# Step 1: Read the contents of the consolidated file from S3
# with vfs.open(consolidated_file_path, "rb") as f:  # Open the file in binary read mode
#     file_contents = f.read()                       # Read the entire file contents

# Step 2: Write the contents to the local file system.
# You may want to change to your Git repository directory first:
# os.chdir("pmc-llm")

# Create the local path for the consolidated file in the current working directory
# path = os.path.join(os.getcwd(), "consolidated_file_history")

# Write the contents from S3 to the local file in binary write mode
# with open(path, "wb") as f:
#     f.write(file_contents)

# Step 3: Commit the file to Git. `add_and_commit_files` (defined earlier) adds the file,
# commits it with a message, and pushes the changes.
# add_and_commit_files("updating the consolidated file history for tracking purposes")


## Optional Phase: Exploring our Ingested Files (WIP)
Every array registered with TileDB Cloud must be accessed using a URI of the form `tiledb://<namespace>/<array-name>`, where `<namespace>` is the user or organization that owns the array and `<array-name>` is the array name set by the owner upon array registration. This URI is [displayed on the console when viewing the array details](https://docs.tiledb.com/cloud/how-to/arrays/view-array-details).
Set the TileDB configuration parameters `rest.username` and `rest.password` with your TileDB Cloud username and password, or alternatively set `rest.token` with [the API token you created](https://docs.tiledb.com/cloud/how-to/account/create-api-tokens). **Accessing arrays by setting an API token is typically faster than using your username and password.** More details [here](https://docs.tiledb.com/cloud/api-reference/array-access). First, let's set up some variables and explore some outputs.
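
A small helper for building and splitting `tiledb://<namespace>/<array-name>` URIs can catch typos before a request is sent. These functions are hypothetical conveniences, not part of the TileDB API. Everything after the first `/` is treated as the array name, so URIs such as `tiledb://<namespace>/s3://bucket/path` (the form the LLM query code later builds) also split cleanly.

```python
TILEDB_SCHEME = "tiledb://"


def make_tiledb_uri(namespace: str, array_name: str) -> str:
    """Build a tiledb://<namespace>/<array-name> URI."""
    if not namespace or "/" in namespace:
        raise ValueError(f"invalid namespace: {namespace!r}")
    if not array_name:
        raise ValueError("array name must not be empty")
    return f"{TILEDB_SCHEME}{namespace}/{array_name}"


def split_tiledb_uri(uri: str) -> tuple:
    """Return (namespace, array_name); everything after the first '/' is the array name."""
    if not uri.startswith(TILEDB_SCHEME):
        raise ValueError(f"not a tiledb:// URI: {uri!r}")
    namespace, sep, array_name = uri[len(TILEDB_SCHEME):].partition("/")
    if not namespace or not sep or not array_name:
        raise ValueError(f"expected tiledb://<namespace>/<array-name>, got {uri!r}")
    return namespace, array_name
```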


We can then view the schema of the array itself to better understand our vectors. 

TileDB stores the data in an efficient, compressed format. When you query it, the vector data is returned as a flat 1D array, which keeps storage and retrieval simple; you can reshape it into one row per vector.
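
Reshaping a flat buffer recovers one row per vector. This sketch uses NumPy and a synthetic buffer rather than a real query result; 768 matches the BioBERT dimensionality used in this demo, and `as_vectors` is a hypothetical helper.

```python
import numpy as np


def as_vectors(flat: np.ndarray, dimensions: int = 768) -> np.ndarray:
    """View a flat float32 buffer as a (num_vectors, dimensions) matrix."""
    if flat.ndim != 1 or flat.size % dimensions != 0:
        raise ValueError(f"buffer of size {flat.size} is not a multiple of {dimensions}")
    return flat.reshape(-1, dimensions)  # a view, not a copy, for a contiguous buffer
```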

We can then work out the total expected size of our array of vectors from its non-empty domain, which spans the partitions.

| Algorithm         | Description                               | Accuracy | Speed  | Memory Efficiency | When to Use                                            |
|-------------------|-------------------------------------------|----------|--------|-------------------|--------------------------------------------------------|
| **flat_index**     | FlatIndex (Brute-force search)            | Highest  | Slow   | Low               | Small datasets where accuracy is the top priority       |
| **ivf_flat_index** | IVFFlat Index (Inverted File with Flat)   | Medium   | Fast   | Medium            | Large-scale datasets, good balance between speed and accuracy |
| **vamana_index**   | Vamana Index (Graph-based search)         | High     | Fast   | Low               | When high accuracy is required with faster search times compared to flat_index |
| **ivf_pq_index**   | IVFPQ Index (Inverted File with PQ)       | Low      | Very Fast | High            | Very large datasets, where memory efficiency is a priority over accuracy |
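
The `flat_index` row is the easiest to picture: compare the query with every stored vector and keep the k closest. A minimal NumPy sketch using squared Euclidean distance (TileDB's own distance metric and implementation may differ; `flat_search` is a hypothetical name):

```python
import numpy as np


def flat_search(vectors: np.ndarray, query: np.ndarray, k: int):
    """Exact k-NN: return (ids, squared L2 distances), nearest first."""
    if k < 1:
        raise ValueError("k must be at least 1")
    dists = np.sum((vectors - query) ** 2, axis=1)  # one distance per stored vector
    k = min(k, len(dists))
    top = np.argpartition(dists, k - 1)[:k]  # the k smallest, in arbitrary order
    order = top[np.argsort(dists[top])]  # sort only those k
    return order, dists[order]
```

Every query touches every vector, which is why accuracy is highest but speed degrades as the dataset grows.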


Each partition ID in your TileDB array points to a batch of 258 vectors, and each of those vectors has 768 float32 values. The external ID serves as a unique identifier for each vector within that batch.

So, to summarize:

* Partition ID: This is the key that points to a "cell" in the TileDB array.
* Batch of 258 Vectors: Each cell contains 258 vectors, where each vector represents one embedding or feature array of 768 float32 values.
* External ID: For every vector in the batch, there's an external ID that uniquely maps to that specific vector. The external ID helps link or identify which specific vector (or array) the data corresponds to within the batch.

You have 7 partition IDs (0 to 6), each containing 258 vectors, so the total number of vectors in your TileDB array is 7 × 258 = 1,806.
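
The arithmetic behind these numbers, extended to the raw storage of the vectors before any compression TileDB applies (4 bytes per float32 value):

```python
num_partitions = 7
vectors_per_partition = 258
dimensions = 768
bytes_per_float32 = 4

total_vectors = num_partitions * vectors_per_partition  # 1,806 vectors
raw_bytes = total_vectors * dimensions * bytes_per_float32  # 5,548,032 bytes
print(f"{total_vectors:,} vectors, {raw_bytes / 2**20:.2f} MiB uncompressed")
# prints: 1,806 vectors, 5.29 MiB uncompressed
```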

Example scenario:

* Partition contains: 258 vectors (batch).
* Query: You want to find the top 10 nearest neighbors (k=10).
* Result: You load one partition (258 vectors), compute the similarity between your query vector and all 258 vectors, and extract the top 10 most similar vectors. Since k=10 is less than 258, you don't need to fetch another partition, making this operation highly efficient. This assumes the query probes a single partition (`nprobe=1`); probing more partitions improves recall at the cost of more I/O.

When you might need more partitions:
If k > 258, you may need to fetch additional partitions to get enough vectors for the query, but with good locality, you can still minimize the number of partitions you need to pull.


Summary:
* Locality advantage: Similar vectors are likely stored close together in the same partition, so once you find one similar vector, you can search the entire batch efficiently.
* Efficient for small k: If k < 258, you can find all your nearest neighbors within a single partition, reducing the need for additional I/O.
* Improved performance: Fewer disk reads, fewer memory loads, and better cache utilization make the search faster and more efficient.
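
To make the partition discussion concrete, here is a toy IVF-style search in NumPy: each vector is assigned to its nearest centroid (its partition), and a query scans only the `nprobe` partitions whose centroids are closest. The centroids are simply sampled vectors rather than k-means output, so this is a sketch of the idea, not TileDB's IVF_FLAT implementation. Probing every partition, as the `conversation` function later in this notebook does with `nprobe=vector_db.index.partitions`, makes the search exhaustive.

```python
import numpy as np


def build_ivf(vectors: np.ndarray, num_partitions: int, seed: int = 0):
    """Pick centroids (sampled vectors, not k-means) and assign every vector to its nearest one."""
    rng = np.random.default_rng(seed)
    centroids = vectors[rng.choice(len(vectors), size=num_partitions, replace=False)]
    # (n, p) matrix of squared distances from each vector to each centroid
    d = ((vectors[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    assignment = d.argmin(axis=1)
    partitions = [np.flatnonzero(assignment == p) for p in range(num_partitions)]
    return centroids, partitions


def ivf_search(vectors, centroids, partitions, query, k, nprobe):
    """Scan only the nprobe partitions whose centroids are closest to the query."""
    centroid_d = ((centroids - query) ** 2).sum(axis=1)
    probed = np.argsort(centroid_d)[:nprobe]
    candidates = np.concatenate([partitions[p] for p in probed])
    d = ((vectors[candidates] - query) ** 2).sum(axis=1)
    return candidates[np.argsort(d)[:k]]
```

With `nprobe=1` only one partition's vectors are compared, which is fast but can miss true neighbors that landed in a different partition; raising `nprobe` trades I/O for recall.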


## Hugging Face Pipelines
A Hugging Face pipeline provides a high-level API that simplifies the entire process of loading models and performing tasks such as text generation, sentiment analysis, translation, and more. It abstracts away the complexities of:

* **Loading the model:** Automatically retrieves and initializes the model from the Hugging Face model hub.
* **Tokenization:** Tokenizes the input text before passing it to the model.
* **Running inference:** Submits the tokenized input to the model for processing.
* **Post-processing:** Converts model outputs back into a human-readable format.

This makes it easier to work with NLP models without manually handling every step.

## LLM Response
Now that we've explored our embeddings, it's time to get an LLM response. This section of the code can and should be executed on a GPU notebook. Because your saved code can always access the data through its TileDB URI and TileDB's authentication methods, you can choose which type of notebook to use for each task. All of the tasks above use task graphs and standard Python, so they do not benefit from GPUs; the tasks below use a GPU to improve response times.

In [None]:
import transformers
from transformers import AutoTokenizer, pipeline
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.vectorstores.tiledb import TileDB
import tiledb.cloud
from tiledb.vector_search.object_api import object_index

# Function to initialize the LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens):
    tokenizer = AutoTokenizer.from_pretrained(llm_model)
    from transformers import BitsAndBytesConfig
    quantization_config = BitsAndBytesConfig(load_in_4bit=True,
                                         llm_int8_threshold=200.0)
    hf_pipeline = transformers.pipeline(
        "text-generation",
        model=llm_model,
        model_kwargs={
            'quantization_config': quantization_config,
        },
        tokenizer=tokenizer,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype="bfloat16",
        max_new_tokens=max_tokens,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )
    llm = HuggingFacePipeline(pipeline=hf_pipeline, model_kwargs={'temperature': temperature})

    # Define the prompt template
    template = """You are a helpful assistant. 
Context:
{context}

Conversation:
{chat_history}
User: {question}  
Use the context information only if related to the question.
Generate only one assistant response.
Output:"""
    prompt = PromptTemplate.from_template(template)
    llm_chain = LLMChain(prompt=prompt, llm=llm)

    return llm_chain

# Initialize the LLM and Vector Search
def initialize_LLM(index_uri, llm_name, llm_temperature, max_tokens):
    index = object_index.ObjectIndex(index_uri, memory_budget=-1)
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens)
    return index, qa_chain

# Format chat history for display
def format_chat_history(message, chat_history):
    recent_history = chat_history[-10:]  # Keep at most the 10 most recent turns
    formatted_chat_history = ""
    for user_message, bot_message in recent_history:
        formatted_chat_history += f"User: {user_message}\n"
        formatted_chat_history += f"Assistant: {bot_message}\n"
    return formatted_chat_history

# Perform conversation with LLM and VectorDB
def conversation(vector_db, topk, qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    question = message
    _, _, results = vector_db.query(
        {"text": [question]}, 
        k=topk,
        nprobe=vector_db.index.partitions,
        return_objects=False,
        return_metadata=True,
    )

    # Build context from the retrieved documents
    context = ""
    for i, text in enumerate(results["text"][0]):
        context += f"Context for file with path: {results['file_path'][0][i]} at page: {results['page'][0][i]}\n{text.strip()}\n"
    
    # Print the retrieved documents for debugging (uncomment to enable)
    # print("\nRetrieved Documents:")
    # for doc in results["text"][0]:
    #     print(doc)

    # Print the file paths of the retrieved documents
    print("\nFile Paths:")
    for path in results["file_path"][0]:
        print(path)
    # Generate response from LLM
    response_answer = qa_chain.run(question=question, context=context, chat_history=formatted_chat_history)
    
    # Append to chat history
    new_history = history + [(message, response_answer)]
    
    return vector_db, qa_chain, new_history
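
The prompt assembly in `conversation` and `format_chat_history` is plain string handling, so it can be checked without a GPU or an index. The sketch below assumes `results` has the shape the code above indexes into (one list per metadata field, with a leading per-query dimension); `render_history` and `render_context` are hypothetical names chosen so they do not shadow the notebook's functions, and the sample data in the usage note is made up.

```python
def render_history(chat_history: list, max_turns: int = 10) -> str:
    """Render the most recent (user, assistant) turns as 'User:'/'Assistant:' lines."""
    lines = []
    for user_message, bot_message in chat_history[-max_turns:]:
        lines.append(f"User: {user_message}")
        lines.append(f"Assistant: {bot_message}")
    return "\n".join(lines) + ("\n" if lines else "")


def render_context(results: dict, query_idx: int = 0) -> str:
    """Concatenate retrieved chunks, each prefixed with its source file and page."""
    parts = []
    for text, path, page in zip(results["text"][query_idx],
                                results["file_path"][query_idx],
                                results["page"][query_idx]):
        parts.append(f"Context for file with path: {path} at page: {page}\n{text.strip()}\n")
    return "".join(parts)
```

For example, `render_context({"text": [["chunk one"]], "file_path": [["s3://b/a.pdf"]], "page": [[1]]})` produces a single block headed by that file path and page, in the same format the `conversation` function feeds to the LLM.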


In [None]:
!pip install gradio
!pip install transformers --upgrade
!pip install torch --upgrade

In [None]:
import gradio as gr  # Ensure Gradio is imported
def llm_call(query: str, index_directory: str, max_tokens: int, top_k: int, llm_temperature: float):
    import transformers
    from transformers import AutoTokenizer, pipeline
    from langchain.llms import HuggingFacePipeline
    from langchain.prompts import PromptTemplate
    from langchain.chains import LLMChain
    from langchain.vectorstores.tiledb import TileDB
    import tiledb.cloud
    from tiledb.vector_search.object_api import object_index
    
    # Get the vector DB URI
    user_profile = tiledb.cloud.user_profile()
    index_uri = f"tiledb://{user_profile.username}/{user_profile.default_s3_path.rstrip('/')}{index_directory}"
    llm_name = "BioMistral/BioMistral-7B"
    
    # Initialize LLM Chain
    def initialize_llmchain(llm_model, temperature, max_tokens, top_k):
        tokenizer = AutoTokenizer.from_pretrained(llm_model)
        from transformers import BitsAndBytesConfig
        quantization_config = BitsAndBytesConfig(load_in_4bit=True, llm_int8_threshold=200.0)
        
        hf_pipeline = transformers.pipeline(
            "text-generation",
            model=llm_model,
            model_kwargs={'quantization_config': quantization_config},
            tokenizer=tokenizer,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype="bfloat16",
            max_new_tokens=max_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
        llm = HuggingFacePipeline(pipeline=hf_pipeline, model_kwargs={'temperature': temperature})

        # Define the prompt template
        template = """You are a helpful assistant. 
        Context:
        {context}
        
        Conversation:
        {chat_history}
        User: {question}  
        
        Use the context information only if related to the question.
        Generate only one assistant response.
        Output:"""
        
        prompt = PromptTemplate.from_template(template)
        llm_chain = LLMChain(prompt=prompt, llm=llm)
        return llm_chain

    # Initialize LLM and Vector Search
    def initialize_LLM(index_uri, llm_name, llm_temperature, max_tokens, top_k):
        index = object_index.ObjectIndex(index_uri, memory_budget=-1)
        qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k)
        return index, qa_chain

    # Format chat history
    def format_chat_history(message, chat_history):
        recent_history = chat_history[-10:]  # Keep at most the 10 most recent turns
        formatted_chat_history = ""
        for user_message, bot_message in recent_history:
            formatted_chat_history += f"User: {user_message}\n"
            formatted_chat_history += f"Assistant: {bot_message}\n"
        return formatted_chat_history

    # Perform conversation
    def conversation(vector_db, topk, qa_chain, message, history):
        formatted_chat_history = format_chat_history(message, history)
        question = message
        _, _, results = vector_db.query(
            {"text": [question]}, 
            k=topk,
            nprobe=vector_db.index.partitions,
            return_objects=False,
            return_metadata=True,
        )

        # Build context from retrieved documents
        context = ""
        for i, text in enumerate(results["text"][0]):
            context += f"Context for file with path: {results['file_path'][0][i]} at page: {results['page'][0][i]}\n{text.strip()}\n"

        # Collect unique file paths, preserving retrieval order
        unique_file_paths = list(dict.fromkeys(results["file_path"][0]))

        # Combine the file paths into a single string
        combined_file_paths = "\n".join(unique_file_paths)

        # Generate response from LLM
        response_answer = qa_chain.run(question=question, context=context, chat_history=formatted_chat_history)

        # Append to chat history
        new_history = history + [(message, response_answer)]
        return vector_db, qa_chain, new_history, combined_file_paths

    # Initialize LLM and Vector Search
    vector_db, qa_chain = initialize_LLM(index_uri, llm_name, llm_temperature, max_tokens, top_k)

    # Example conversation
    history = []
    vector_db, qa_chain, history, combined_file_paths = conversation(vector_db, top_k, qa_chain, query, history)

    # Format final response for display
    final_response = f"Retrieved Files:\n{combined_file_paths}\n\n"
    for user_msg, assistant_msg in history:
        final_response += f"User: {user_msg}\n\nAssistant: {assistant_msg}\n\n"
    
    return final_response


In [None]:
def fetch_and_display_pdf(aws_access_key_id, aws_secret_access_key, s3_region, s3_url):
    # Set up the TileDB VFS context with the provided credentials
    ctx = tiledb.Ctx({
        "vfs.s3.region": s3_region,
        "vfs.s3.aws_access_key_id": aws_access_key_id,
        "vfs.s3.aws_secret_access_key": aws_secret_access_key
    })
    
    # Extract the bucket and path from the s3_url
    bucket_name = s3_url.split('/')[2]
    object_path = '/'.join(s3_url.split('/')[3:])
    
    # Set up TileDB VFS to read the PDF
    vfs = tiledb.VFS(ctx=ctx)
    
    # Temporary file to store the PDF locally
    local_pdf_path = "/tmp/fetched_pdf.pdf"
    
    # Fetch the PDF from S3
    with vfs.open(f"s3://{bucket_name}/{object_path}", "rb") as remote_pdf:
        with open(local_pdf_path, "wb") as local_pdf:
            local_pdf.write(remote_pdf.read())
    
    # Now return the PDF file path for visualization in Gradio
    return local_pdf_path
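
`fetch_and_display_pdf` splits `s3_url` on `/` by position. A slightly stricter sketch with `urllib.parse` rejects non-S3 URLs early; `parse_s3_url` is a hypothetical helper, not part of TileDB or Gradio.

```python
from urllib.parse import urlparse


def parse_s3_url(s3_url: str) -> tuple:
    """Return (bucket, key) for an s3://bucket/key URL."""
    parsed = urlparse(s3_url)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"expected s3://bucket/key, got {s3_url!r}")
    # Note: keys containing '?' or '#' would be split off by urlparse; this sketch does not handle them
    key = parsed.path.lstrip("/")
    if not key:
        raise ValueError(f"no object key in {s3_url!r}")
    return parsed.netloc, key
```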

In [None]:
jupyterhub_user = os.getenv("JUPYTERHUB_USER")
jupyterhub_api_url = "https://us-west-2.aws.jupyterhub.cloud.tiledb.com"
gradio_port = 7830
gradio_path = f"/user/{jupyterhub_user}/proxy/{gradio_port}/"
gradio_frontend_url = f"{jupyterhub_api_url}{gradio_path}"
# Define the Gradio interface
def demo():

    print(f"JupyterHub User = {jupyterhub_user} jupyterhub_api_url = {jupyterhub_api_url} gradio port = {gradio_port} gradio path = {gradio_path} gradio_frontend_url = {gradio_frontend_url}")
    with gr.Blocks() as interface:
        with gr.Tab("Ingestion Pipeline"):
            gr.Markdown("### Enter the details for the ingestion pipeline.")

            in_file = gr.Textbox(label="Enter File Name (e.g., 'data.txt')", value="rag-article-list.txt")
            access_credentials = gr.Textbox(label="Access Credentials", value="", type="password")
            repo_url = gr.Textbox(label="Repo URL", value="https://github.com/TileDB-Inc/pmc-llm.git")
            out_directory = gr.Textbox(label="Out Directory", value="pmc/rag/ingestion")
            bucket_path = gr.Textbox(label="Bucket Path", value="")
            model_name = gr.Textbox(label="Model Name", value="")
            index_uri = gr.Textbox(label="Index URI", value="")
            num_jobs = gr.Number(label="Number of Jobs", value=4, precision=0)  # precision=0 passes an int, not a float
            bucket_region = gr.Textbox(label="Bucket Region", value="")
            hash_check = gr.Checkbox(label="Enable Hash Check", value=False)
            index_type = gr.Textbox(label="Index Type", value="IVF_FLAT")
            dimensions = gr.Number(label="Embedding Dimensions", value=768, precision=0)
            ingestion_output = gr.Textbox(label="Ingestion Output")

            ingest_button = gr.Button("Start Ingestion")
            ingest_button.click(
                fn=pipeline_step,
                inputs=[access_credentials, repo_url, out_directory, bucket_path, index_uri, in_file, num_jobs, bucket_region, hash_check, model_name, index_type, dimensions],
                outputs=ingestion_output
            )

        with gr.Tab("LLM Query"):
            gr.Markdown("### Query the LLM with a custom question.")
            
            query_input = gr.Textbox(label="Query", value="What are the effects of halothane on the fetus?")
            index_directory = gr.Textbox(label="Index Directory", value="/pmc/rag/index")
            max_tokens = gr.Slider(label="Max Tokens", minimum=100, maximum=1000, step=100, value=500)
            top_k = gr.Slider(label="Top K", minimum=1, maximum=10, step=1, value=3)
            llm_temperature = gr.Slider(label="LLM Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.7)
            llm_output = gr.Textbox(label="LLM Output")

            query_button = gr.Button("Query LLM")
            query_button.click(
                fn=llm_call,
                inputs=[query_input, index_directory, max_tokens, top_k, llm_temperature],
                outputs=llm_output
            )
        with gr.Tab("PDF Viewer"):
            gr.Markdown("### Enter AWS credentials and S3 URL to fetch the PDF.")

            aws_access_key_id = gr.Textbox(label="AWS Access Key ID", value="")
            aws_secret_access_key = gr.Textbox(label="AWS Secret Access Key", value="", type="password")
            s3_region = gr.Textbox(label="S3 Region", value="us-west-2")
            s3_url = gr.Textbox(label="S3 URL", value="s3://your-bucket/path-to-pdf.pdf")

            pdf_output = gr.File(label="Fetched PDF")

            fetch_button = gr.Button("Fetch PDF")
            fetch_button.click(
                fn=fetch_and_display_pdf,
                inputs=[aws_access_key_id, aws_secret_access_key, s3_region, s3_url],
                outputs=pdf_output
            )

    # Launch outside the Blocks context, after the whole layout has been defined
    interface.launch(
        debug=False,
        quiet=True,
        show_error=True,
        height=1100,
        root_path=gradio_path,  # JupyterHub proxy path for the current user and port
        server_name="0.0.0.0",  # Bind to all interfaces, not just localhost
        server_port=gradio_port,  # Must match the port used in gradio_path
    )
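The `launch` call needs a `root_path` that matches the per-user JupyterHub proxy prefix. Hardcoding one user's ID means the demo only works for that user. Below is a minimal sketch of deriving both the root path and the browser URL from `JUPYTERHUB_USER`. It assumes the `/user/<name>/proxy/<port>/` layout the notebook already uses. `jupyterhub_proxy_paths` is a hypothetical helper, not part of JupyterHub or Gradio.

```python
import os


def jupyterhub_proxy_paths(hub_url, port, user=None):
    """Return (root_path, frontend_url) for a local server behind the JupyterHub proxy.

    Hypothetical helper; assumes the /user/<name>/proxy/<port>/ layout used above.
    """
    # Fall back to the JupyterHub environment variable when no user is given
    user = user or os.getenv("JUPYTERHUB_USER")
    if not user:
        raise RuntimeError("JUPYTERHUB_USER is not set; pass user= explicitly")
    root_path = f"/user/{user}/proxy/{port}/"
    frontend_url = hub_url.rstrip("/") + root_path
    return root_path, frontend_url


root_path, frontend_url = jupyterhub_proxy_paths(
    "https://us-west-2.aws.jupyterhub.cloud.tiledb.com", 7830, user="example-user"
)
print(root_path, frontend_url)
```

In the notebook, `root_path, gradio_frontend_url = jupyterhub_proxy_paths(jupyterhub_api_url, gradio_port)` would replace the manual string assembly. Unlike the manual version, it fails loudly instead of building a `/user/None/...` path when the variable is missing.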



In [None]:
# Run the Gradio interface
if __name__ == "__main__":
    demo()

In [None]:
from IPython.display import display, IFrame
print(f"gradio_frontend_url= {gradio_frontend_url}")
IFrame(gradio_frontend_url, width="100%", height="1100")
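If the IFrame shows a blank page or a proxy error, the Gradio server may not be listening yet. Here is a small sketch of a TCP check that uses only the standard library. `is_port_open` is a hypothetical helper, and 7830 is the port the notebook passes to `launch`.

```python
import socket


def is_port_open(host, port, timeout=1.0):
    """Return True if something accepts TCP connections on host:port."""
    try:
        # create_connection raises OSError (e.g. ConnectionRefusedError) if nothing listens
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


# Port 7830 is the server_port the notebook passes to interface.launch()
if not is_port_open("127.0.0.1", 7830):
    print("Gradio does not appear to be listening on 7830; re-run the demo() cell.")
```

This check only confirms that something is listening on the port. It does not verify that the JupyterHub proxy path is correct.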