Follow these: 
https://github.com/langchain-ai/langgraph/blob/main/examples/rag/langgraph_rag_agent_llama3_local.ipynb
https://github.com/langchain-ai/langgraph/blob/main/examples/rag/langgraph_self_rag_local.ipynb

TODO: change the embedding model; adapt the JSON output for Groq (use `json_mode`), see: https://api.python.langchain.com/en/latest/chat_models/langchain_groq.chat_models.ChatGroq.html#langchain_groq.chat_models.ChatGroq.with_structured_output


In [None]:
# NOTE(review): `!` shell commands presumably run in a separate subshell, so
# sourcing ~/.zshrc here likely does not change this kernel's environment —
# confirm whether this line is actually needed.
!source ~/.zshrc
# specify your working directory
# (absolute, machine-specific path; later cells build the parquet file,
# embedding-model cache and vector DB locations from it)
working_dir = "/Users/pietro/open-modular-rag"

In [None]:
from dotenv import load_dotenv
import os
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from torch import cuda
from typing import Callable
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
import chromadb

import pandas as pd
import re
import string

In [None]:
# load environment variables via python-dotenv, then look up the Groq API key
# (the lookup yields None when the variable is not set)
load_dotenv()
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

### Utils

In [None]:
def parse_metadata(metadata_str: str):
    """ transforms relevant data from a data frame column into a dict format
    Args:
        metadata_str (_type_): column of a dataframe

    Returns:
        _type_: column in a dict format needed for the metadata chroma function
    """
    metadata_dict = {}
    if pd.notna(metadata_str):
        # Assuming metadata is a string formatted as "key: value, key: value"
        for part in metadata_str.split(", "):
            if ": " in part:
                key, value = part.split(": ", 1)
                metadata_dict[key.strip()] = value.strip()
    return metadata_dict

## Load the docs

In [None]:
# load the preprocessed and chunked data from the working directory
parquet_path = f"{working_dir}/moreAgentsPaper.parquet"
combined_df = pd.read_parquet(parquet_path, engine="fastparquet")
combined_df.head()


In [None]:
# parse every raw metadata cell into a dict, then preview the first two results
parsed_metadata = combined_df["Metadata"].apply(parse_metadata)
combined_df["Metadata"] = parsed_metadata
combined_df["Metadata"].to_list()[:2]

In [None]:
# extract elements from dataframe and put them in a format suitable for chromadb
# (plain Python lists, which are passed to the collection upsert later on)
metadatas = combined_df["Metadata"].tolist()
# Read each column directly instead of a row-wise apply + ' '.join over a
# single column: missing values still become "", and non-string values are
# converted with str() rather than raising a TypeError inside the join.
ids = ["" if pd.isna(chunk_id) else str(chunk_id) for chunk_id in combined_df["Chunk_ID"]]
documents_all = ["" if pd.isna(text) else str(text) for text in combined_df["Content"]]

### Initialize embedding model and embed chunks

In [None]:
# identifier of the sentence-transformers model used to embed the chunks
embed_model_id = 'sentence-transformers/all-mpnet-base-v2'

# use the current CUDA device when one is available, otherwise the CPU
if cuda.is_available():
    device = f'cuda:{cuda.current_device()}'
else:
    device = 'cpu'

# Initialize embedding model (model files are cached under the working directory)
embedding_model = HuggingFaceEmbeddings(
    cache_folder=f'{working_dir}/emb_model',
    model_name=embed_model_id,
    model_kwargs={'device': device},
    encode_kwargs={'device': device, 'batch_size': 32},
)

In [None]:
# Embed every chunk text, then report how many vectors came back and their size
embeddings = embedding_model.embed_documents(documents_all)

n_embeddings = len(embeddings)
embedding_dim = len(embeddings[0])
print(f"We have {n_embeddings} doc embeddings, each with a dimensionality of {embedding_dim}.")

In [None]:
# ChromaDB setup: create the persistent client whose data lives under the
# working directory; the collection holding all documents is created from it
# (in case of errors, perform pip uninstall chromadb and pip install chromadb)
vectordb_path = f"{working_dir}/vectordb"
chroma_client = chromadb.PersistentClient(path=vectordb_path)

In [None]:
# name used to set up and reference the vector index
collection_name = "more_agents_paper_self_rag"
# similarity search metric the vector index is initialized with
collection_metadata = {"hnsw:space": "cosine"}
# reuse the collection if it already exists, otherwise create it
vectorstore = chroma_client.get_or_create_collection(
    collection_name,
    metadata=collection_metadata,
)

In [None]:
# update the vector index with the prepared data
# (ids, texts, metadata and embeddings are parallel lists, one entry per chunk)
vectorstore.upsert(
    ids=ids,
    documents=documents_all,
    metadatas=metadatas,
    embeddings=embeddings,
)

In [None]:
# Save the vectordb as a langchain object

- I need to initialize a persistent chromadb client.
- Then, I need to do `get_or_create_collection` to initialize a new collection
- Then, I need to update the vector store.

Then, in a new notebook, I call `get_or_create_collection` again and initialize the LangChain retriever from the ChromaDB collection, as shown [here](https://python.langchain.com/docs/integrations/vectorstores/chroma/#passing-a-chroma-client-into-langchain).