# Part 1: Chunk Text

In [1]:
import torch
from langchain.chains import RetrievalQA
from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

# Load every "*text.txt" file under data/rag. Each file is expected to already
# be split by hand into sections separated by THREE newlines ("\n\n\n"); the
# text splitter in the next cell relies on that.
# Decode explicitly as UTF-8 instead of the platform's default encoding, which
# differs between machines and can garble or reject non-ASCII text.
loader = DirectoryLoader(
    'data/rag',
    glob='**/*text.txt',
    loader_cls=TextLoader,
    loader_kwargs={'encoding': 'utf-8'},
)
documents = loader.load()

# with open("data/rag/gene_query_docs.txt", "r") as doc_fd:
#     ref_text = doc_fd.read().split("\n\n\n")
#     ref_text = list(map(lambda s: s.strip(), ref_text))

In [None]:
# One chunk per "\n\n\n"-delimited section: chunk_size=1 is deliberately tiny
# so neighbouring sections are never merged into a shared chunk (LangChain may
# log a warning that each chunk exceeds chunk_size; that is expected here).
text_splitter = CharacterTextSplitter(
    separator="\n\n\n",
    chunk_size=1,
    chunk_overlap=0,
)
docs = text_splitter.split_documents(
    documents
)

In [None]:
# Spot-check one chunk to confirm the "\n\n\n" split produced whole sections.
docs[2].page_content

# Part 2: Vectorstore

In [None]:
from langchain.embeddings.base import Embeddings

class CustomHuggingFaceEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter for a raw Hugging Face model + tokenizer.

    Every text is encoded on its own and reduced to a single vector by
    averaging the model's ``last_hidden_state`` over the real (non-padding)
    tokens, as marked by the tokenizer's ``attention_mask``.

    Args:
        model: Model whose forward output exposes ``last_hidden_state``
            (e.g. a ``transformers.AutoModel``). It is moved to ``device``.
        tokenizer: Tokenizer matching ``model``.
        device: Device to run on. ``None`` picks ``cuda:1`` when a second GPU
            exists, else ``cuda:0`` when CUDA is available, else the CPU.
    """

    def __init__(self, model, tokenizer, device=None):
        self.model = model
        self.tokenizer = tokenizer
        # Compare against None explicitly: a falsy but valid device such as the
        # integer ordinal 0 must not be silently replaced by the default.
        if device is None:
            if torch.cuda.device_count() > 1:
                device = torch.device('cuda:1')
            elif torch.cuda.is_available():
                device = torch.device('cuda:0')
            else:
                device = torch.device('cpu')
        self.device = device
        self.model.to(self.device)

    def embed_documents(self, texts):
        """Embed each string in ``texts``; returns one list of floats per text."""
        return [self._embed(text) for text in texts]

    def embed_query(self, text):
        """Embed a single query string; returns a list of floats."""
        return self._embed(text)

    def _embed(self, text):
        """Return the attention-masked mean of the last hidden state for ``text``."""
        # A single sequence needs no padding. Padding to max_length would waste
        # compute, would dilute a plain mean with pad-token states, and would
        # require the tokenizer to define a pad token.
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
            hidden = outputs.last_hidden_state  # expected (1, seq_len, hidden_dim)
            mask = inputs.get('attention_mask')
            if mask is None:
                mask = torch.ones(hidden.shape[:2], device=hidden.device)
            mask = mask.unsqueeze(-1).to(hidden.dtype)
            # Average only over real tokens; clamp guards against an all-zero mask.
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        # LangChain's Embeddings interface expects plain lists of floats.
        return pooled[0].float().cpu().tolist()

In [3]:
from transformers import AutoTokenizer, AutoModel
import torch

# Embedding model picked from the MTEB leaderboard:
# https://huggingface.co/spaces/mteb/leaderboard
embedding_model_name = 'Snowflake/snowflake-arctic-embed-l'

# Prefer the second GPU, but fall back to the first GPU or the CPU rather than
# failing with an invalid-device error on machines with fewer GPUs.
if torch.cuda.device_count() > 1:
    embedding_device = "cuda:1"
elif torch.cuda.is_available():
    embedding_device = "cuda:0"
else:
    embedding_device = "cpu"
model_kwargs = {"device": embedding_device}

# Alternative path: load the tokenizer/model manually and wrap them with
# CustomHuggingFaceEmbeddings. Adding a pad token is only needed for tokenizers
# that do not define one (e.g. Llama 3.1, see
# https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/).
# tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
# tokenizer.add_special_tokens({'pad_token': '<|finetune_right_pad_id|>'})
# model = AutoModel.from_pretrained(embedding_model_name, torch_dtype=torch.float32).to(embedding_device)

embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name, model_kwargs=model_kwargs)
# embeddings = CustomHuggingFaceEmbeddings(model=model, tokenizer=tokenizer, device=embedding_device)

In [None]:
# Embed every chunk into a FAISS index and persist it under data/rag, so later
# sessions can reload it (next cell) instead of re-embedding everything.
vectorstore = FAISS.from_documents(docs, embeddings)
vectorstore.save_local(
    folder_path="data/rag",
    index_name="faiss_index",
)

In [4]:
# Reload the persisted index. load_local unpickles the saved docstore, which is
# why allow_dangerous_deserialization must be set; this is acceptable only
# because the files were written by save_local in this notebook. Never point
# this at index files from an untrusted source.
vectorstore = FAISS.load_local(
    folder_path="data/rag",
    index_name="faiss_index",
    embeddings=embeddings,
    allow_dangerous_deserialization=True,
)

In [5]:
# Two sample questions used to eyeball retrieval quality (top 16 chunks each).
sample_queries = [
    "give me uniprot id of the gene with the entrez gene id of 1017",
    "What are the symbol and Ensembl gene ID for genes in species 9669 with a symbol starting with 'LOC123388108'?",
]
retriever = vectorstore.as_retriever(search_kwargs=dict(k=16))
out = retriever.batch(sample_queries)

In [6]:
# Show the documents retrieved for the first sample query (up to k=16).
out[0]

[Document(metadata={'source': 'data/rag/compact_desc_with_context.txt'}, page_content="entrezgene - This field contains a string representing the unique identifier assigned by NCBI's Entrez Gene database. It is used to uniquely identify genes in this centralized resource."),
 Document(metadata={'source': 'data/rag/compact_desc_with_context.txt'}, page_content='ensembl.gene - The Ensembl identifier for the gene, which is a unique reference used to describe the gene in the Ensembl database.'),
 Document(metadata={'source': 'data/rag/compact_desc_with_context.txt'}, page_content='reagent.GNF_mm-kinase_lenti-shRNA.id - Unique identifier for a lentiviral shRNA reagent targeting mouse kinase genes in the GNF library.'),
 Document(metadata={'source': 'data/rag/compact_desc_with_context.txt'}, page_content='genomic_pos_mm9.chr - The chromosome number for the gene using the mouse genome assembly version mm9.'),
 Document(metadata={'source': 'data/rag/compact_desc_with_context.txt'}, page_conten

# Part 3: Inference

In [None]:
import outlines

# Prompt template for the RAG step. outlines renders the docstring below as the
# template: {{ relevant_docs }} is filled with the retrieved context (presumably
# the joined strings from process_retrieved_docs) and {{ instruction }} with the
# user's request.
# NOTE(review): the template hard-codes Llama 3 chat special tokens
# (<|begin_of_text|>, <|start_header_id|>, <|eot_id|>), but the model loaded
# later in this notebook is HuggingFaceH4/zephyr-7b-beta; confirm the template
# matches the model it is actually sent to.
# NOTE(review): the docstring IS the template, so any edit to it changes the
# prompt text sent to the model.
@outlines.prompt
def rag_prompt(instruction, relevant_docs):
    """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Use the documentation to complete the user-given task.
Docs: {{ relevant_docs }}\n<|eot_id|><|start_header_id|>user<|end_header_id|>
{{ instruction }}. Write an API call.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

In [None]:
def process_retrieved_docs(doc_batches):
    """Collapse each batch of retrieved documents into one context string.

    Args:
        doc_batches: Iterable of batches (e.g. the result of
            ``retriever.batch``); each batch is an iterable of objects that
            have a ``page_content`` attribute.

    Returns:
        A list with one string per batch: that batch's page contents joined
        by a blank line ("\\n\\n").
    """
    contexts = []
    for batch in doc_batches:
        contexts.append("\n\n".join(doc.page_content for doc in batch))
    return contexts

In [None]:
# Scratch check of how per-position values combine. starmap takes exactly one
# iterable of argument tuples, so zip the inputs first; zip stops at the
# shortest input, so this evaluates to [(1, 4, 5)].
from itertools import starmap

list(starmap(lambda x, y, z: (x, y, z), zip([1, 2, 3], [4], [5])))

In [None]:
from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline

# 4-bit NF4 quantization so the 7B chat model fits in less GPU memory
# (requires the bitsandbytes package and a CUDA GPU). This must be defined
# before it is passed via model_kwargs below.
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_use_double_quant=True,
)

# Greedy decoding (do_sample=False) with a mild repetition penalty; only the
# generated continuation is returned, not the prompt.
llm = HuggingFacePipeline.from_model_id(
    model_id="HuggingFaceH4/zephyr-7b-beta",
    task="text-generation",
    pipeline_kwargs=dict(
        max_new_tokens=512,
        do_sample=False,
        repetition_penalty=1.03,
        return_full_text=False,
    ),
    model_kwargs={"quantization_config": quantization_config},
)

# Chat wrapper that applies the model's chat template to message lists.
chat_model = ChatHuggingFace(llm=llm)

In [None]:
from langchain_core.messages import (
    HumanMessage,
    SystemMessage,
)

# Smoke test for chat_model: one system turn plus one user turn, single reply.
system_turn = SystemMessage(content="You're a helpful assistant")
user_turn = HumanMessage(
    content="What happens when an unstoppable force meets an immovable object?"
)
messages = [system_turn, user_turn]
ai_msg = chat_model.invoke(messages)