# CafChem Teaching - Simple RAG-based chat with Gemma

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/MauricioCafiero/CafChemTeach/blob/main/notebooks/Simple_Rag-Chat__CafChem.ipynb)

## This notebook allows you to:
- Upload a document and interact with it via RAG and the Gemma LLM.

## Requirements:
- The notebook installs all needed libraries.
- A GPU runtime is required; an L4 or higher is recommended.
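
To confirm that the runtime actually has a GPU before installing anything, a quick check along these lines can help. This is a minimal sketch and not part of the original notebook: the helper name `gpu_summary` is hypothetical, and it assumes the NVIDIA driver's `nvidia-smi` tool is on the PATH, as it normally is on GPU runtimes.

```python
import shutil
import subprocess
from typing import Optional


def gpu_summary() -> Optional[str]:
    """Return 'name, total memory' as reported by nvidia-smi, or None if no GPU is visible."""
    # nvidia-smi is only present when an NVIDIA driver is installed.
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0:
        return None
    return result.stdout.strip() or None


summary = gpu_summary()
print(summary if summary else "No GPU detected; switch the runtime type to a GPU.")
```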

## Set-up

### Install libraries

In [1]:
!pip -q install pypdf
!pip -q install langchain_community
!pip -q install langchain_huggingface
!pip -q install langchain_chroma
!pip install -U bitsandbytes

### Import libraries and define functions

In [1]:
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
import torch
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
import gradio as gr

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
chat_history = []

def clear_history():
  '''
    Clear the chat history.

      Args:
        None
      Returns:
        None
  '''
  global chat_history
  chat_history = []

class use_model():
  '''
    Class to use a Gemma model to perform retrieval augmented generation (RAG).
  '''
  def __init__(self, device: str, model_id = "google/gemma-3-1b-it"):
    '''
    Initialize the class.

      Args:
        device: The device to use for the model.
        model_id: The model to use.
      Also defines:
        store_names: A list of the names of the stores created.
        store_flag: A flag to indicate if a store has been created.
    '''
    self.model_id = model_id
    self.device = device
    self.store_names = []
    self.store_flag = False

  def start_model_tokenizer(self):
    '''
      Downloads and loads the model and tokenizer.

      Args:
        None
      Returns:
        None
      Also defines:
        model: The model to use.
        tokenizer: The tokenizer to use.
    '''
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)

    self.model = AutoModelForCausalLM.from_pretrained(
        self.model_id, quantization_config=quantization_config
    ).eval()

    self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

    print(f"Model loaded on {self.device}")

  def create_store(self, doc_path: str, store_name: str, num_docs: int):
    '''
      Creates a vector store from a PDF document.

      Args:
        doc_path: The path to the PDF document.
        store_name: The name of the store; also used as the directory name under /content.
        num_docs: The number of documents (k) to retrieve for a query.
      Returns:
        None
      Also defines:
        store: The vector store.
        embeddings: The embeddings to use.
    '''
    self.k = num_docs

    if store_name in self.store_names:
      print("Store already exists, calling load_store()")
      self.load_store(store_name)
    else:
      self.store_names.append(store_name)
      self.directory_name = f"/content/{store_name}"


      os.mkdir(self.directory_name)
      print(f"Directory '{self.directory_name}' created successfully.")

      loader = PyPDFLoader(doc_path)
      pages = loader.load()

      text_splitter = RecursiveCharacterTextSplitter(
          chunk_size=2000,
          chunk_overlap=200,
          length_function=len,
          add_start_index=True,
      )
      texts = text_splitter.split_documents(pages)

      model_name = "BAAI/bge-base-en-v1.5"
      model_kwargs = {"device":self.device}
      encode_kwargs = {'normalize_embeddings':True}
      self.embeddings = HuggingFaceEmbeddings(model_name = model_name, model_kwargs = model_kwargs,
                                        encode_kwargs = encode_kwargs)

      self.store = Chroma.from_documents(documents = texts,
                                    embedding = self.embeddings,
                                    persist_directory = self.directory_name)
      print(f"Store {store_name} created")
      self.store_flag = True

  def load_store(self, store_name):
    '''
      Loads a vector store from a directory.

      Args:
        store_name: The name of the store.
      Returns:
        None
    '''
    if store_name in self.store_names:
      self.directory_name = f"/content/{store_name}"
      self.store = Chroma(persist_directory = self.directory_name, embedding_function = self.embeddings)
      print("Store loaded")
      self.store_flag = True
    else:
      print("Store not found, call create store with a document path.")

  def search_store(self, query: str):
    '''
      Searches the vector store for a query.

      Args:
        query: The query to search for.
      Returns:
        results: The k results of the search.
    '''
    results = self.store.similarity_search(query, k=self.k)
    return results


  def chat(self, raw_prompt: str):
    '''
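      Chats with the model. Only the current prompt (plus any retrieved context) is sent
      to the model; chat_history is kept for display only.

      Args:
        raw_prompt: The prompt to send to the model.
      Returns:
        An empty string (used to clear the Gradio textbox) and the chat history.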
    '''
    global chat_history

    role_text = "You are a helpful assistant. "
    prompt = ""
    context = ""

    if self.store_flag:
      relevant_docs = self.search_store(raw_prompt)
      context += "\n\n".join([doc.page_content for doc in relevant_docs])
      prompt += f"RELEVANT INFORMATION: {context}\n\nQUERY: {raw_prompt}"
      role_text += " Use the RELEVANT INFORMATION to answer the QUERY."

      messages = [[{
                "role": "system",
                "content": [{"type": "text", "text": role_text},]
            },{
                "role": "user",
                "content": [{"type": "text", "text": prompt},]
            }]]

    else:
      messages = [[{
                  "role": "system",
                  "content": [{"type": "text", "text": role_text},]
              },{
                  "role": "user",
                  "content": [{"type": "text", "text": raw_prompt},]
              }]]

    chat_history.append({"role": "user", "content": raw_prompt})

    inputs = self.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(self.model.device) #.to(torch.bfloat16)

    with torch.inference_mode():
        outputs = self.model.generate(**inputs, max_new_tokens=1000)

    outputs = self.tokenizer.batch_decode(outputs)

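    # Keep only the model's turn, then remove the literal end-of-turn marker.
    # Note: str.strip('<end_of_turn>') would strip a set of characters, not the marker.
    parts = outputs[0].split('<start_of_turn>model')
    response = parts[1].replace('<end_of_turn>', '').strip()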

    chat_history.append({"role": "assistant", "content": response})

    return '', chat_history #response, context

def chatbot():
  '''
    A simple chatbot to interact with the model. Includes a clear button to clear the chat history, and
    a submit button to send a message to the model.
    Displays the entire chat history in a gradio chat interface unless the chat history is cleared.

      Args:
        None
      Returns:
        None
  '''
  with gr.Blocks() as forest:
    gr.Markdown(
        """
        # Chat with Gemma using RAG.
        ### Enter your messages below.
        """)


    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox(label="Type your messages here and hit enter.")
    chat_btn = gr.Button(value = "Send")

    clear = gr.ClearButton([msg, chatbot])
    clear.click(clear_history)

    chat_btn.click(chat_with_rag.chat, [msg], [msg,chatbot])
    msg.submit(chat_with_rag.chat, [msg], [msg,chatbot])

  forest.launch(share=True)
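
The decoded output of `generate` contains the whole conversation with Gemma's turn markers, so the reply has to be cut out of it. The sketch below shows one way to do that on a hand-written string (the sample text is made up for illustration, and `extract_reply` is a hypothetical helper). It also shows why calling `str.strip` with the marker is risky: `strip` treats its argument as a set of characters, not as a literal marker.

```python
START = "<start_of_turn>model"
END = "<end_of_turn>"


def extract_reply(decoded: str) -> str:
    """Return the text of the last model turn without turn markers."""
    # Everything after the last model-turn marker is the generated reply.
    reply = decoded.split(START)[-1]
    # Remove the literal end marker, then trim surrounding whitespace.
    return reply.replace(END, "").strip()


# Made-up decoded text in Gemma's turn format, for illustration only.
decoded = (
    "<bos><start_of_turn>user\nWhat is RAG?<end_of_turn>\n"
    "<start_of_turn>model\nRetrieval augmented<end_of_turn>"
)

clean = extract_reply(decoded)
print(clean)

# strip() removes any of the characters in its argument from both ends,
# so it also eats into the end of the word "augmented".
naive = decoded.split(START)[1].strip(END)
print(repr(naive))
```

The second print shows that part of the final word is lost, which is why the sketch uses `replace` followed by a plain `strip()`.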

## Set up the chat
- Upload your document for RAG.
- Add the path to your document to the `create_store` call below.
- Run the next three cells before starting the chatbot.
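
Before calling `create_store`, it can be worth checking that the uploaded file is where you think it is. The following is a minimal sketch; `check_pdf_path` is a hypothetical helper, and the temporary file only stands in for an uploaded document.

```python
import tempfile
from pathlib import Path


def check_pdf_path(doc_path: str) -> Path:
    """Return doc_path as a Path if it points to an existing .pdf file."""
    path = Path(doc_path)
    if path.suffix.lower() != ".pdf":
        raise ValueError(f"Expected a .pdf file, got: {path.name}")
    if not path.is_file():
        raise FileNotFoundError(f"No file at {path}; upload it to the runtime first.")
    return path


# Stand-in for an uploaded document, created only for this demonstration.
with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "paper.pdf"
    sample.write_bytes(b"%PDF-1.4\n")
    checked = check_pdf_path(str(sample))
    print(checked.name)
```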

In [3]:
model_name = "google/gemma-3-1b-it"
#model_name = "google/gemma-3-4b-it"

chat_with_rag = use_model(device, model_name)

In [4]:
chat_with_rag.start_model_tokenizer()

Model loaded on cuda


In [5]:
chat_with_rag.create_store('/content/Variable_temperature_inference_Cafiero_2025_advance.pdf','anneal', 5)

Directory '/content/anneal' created successfully.
Store anneal created
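
`create_store` splits the PDF text into chunks of up to 2000 characters with a 200-character overlap before embedding them. The sketch below illustrates only the overlap idea with a plain sliding window. It is a simplification, not the algorithm `RecursiveCharacterTextSplitter` uses (which prefers to break on separators such as paragraph and line breaks). The function name and the toy sizes are assumptions for illustration.

```python
from typing import List


def sliding_chunks(text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
    """Split text into fixed-size windows that share chunk_overlap characters."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be non-negative and smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        # Stop once a window has reached the end of the text.
        if start + chunk_size >= len(text):
            break
    return chunks


text = "abcdefghijklmnopqrstuvwxyz"
chunks = sliding_chunks(text, chunk_size=10, chunk_overlap=3)
for chunk in chunks:
    print(chunk)
```

Each chunk starts with the last three characters of the previous one, so a sentence that straddles a chunk boundary still appears whole in at least one chunk when it is shorter than the overlap.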


## Test RAG
- Optional cells for checking retrieval and generation before launching the chatbot.
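
`search_store` returns the `k` chunks whose embeddings are closest to the query embedding. The toy example below shows the idea with hand-made 3-dimensional vectors; real embeddings have hundreds of dimensions, and the vectors, texts, and `top_k` helper here are invented for illustration. Because the notebook normalizes its embeddings, the sketch normalizes too. For unit-length vectors, ranking by cosine similarity and ranking by Euclidean distance give the same order.

```python
import math
from typing import List, Tuple


def normalize(vec: List[float]) -> List[float]:
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def top_k(query: List[float], docs: List[Tuple[str, List[float]]], k: int) -> List[str]:
    """Return the texts of the k documents most similar to the query."""
    q = normalize(query)
    scored = []
    for text, vec in docs:
        d = normalize(vec)
        # For unit vectors the dot product equals the cosine similarity.
        scored.append((sum(a * b for a, b in zip(q, d)), text))
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [text for _, text in scored[:k]]


# Made-up texts and vectors, for illustration only.
docs = [
    ("temperature ramps", [0.9, 0.1, 0.0]),
    ("catechol energies", [0.0, 0.2, 0.9]),
    ("token entropy", [0.6, 0.7, 0.1]),
]
best = top_k([1.0, 0.2, 0.0], docs, k=2)
print(best)
```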

In [12]:
#chat_with_rag.load_store('anneal')

Store not found, call create store with a document path.


In [6]:
_, history = chat_with_rag.chat("What is the best temperature ramp to use for molecule generation and why?")

In [7]:
print(history)




In [13]:
_, history = chat_with_rag.chat("What is a challenging token?")

In [14]:
print(history)


According to the text, a challenging token is one with high Shannon entropy, meaning that there are multiple possibilities for that token at that point in the generation process that all have similar probabilities.
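
The answer above refers to "the text" because `chat` prepends the retrieved chunks to the question before sending it to Gemma. The sketch below reproduces that prompt assembly as a standalone function so the message structure is easy to inspect. `build_messages` is a hypothetical helper and the chunk strings are placeholders, not real retrieval results. The notebook additionally wraps the message list in an outer list, as a batch of one conversation.

```python
from typing import Dict, List


def build_messages(query: str, chunks: List[str]) -> List[Dict]:
    """Build system and user messages in the layout the notebook's chat method uses."""
    role_text = "You are a helpful assistant. "
    user_text = query
    if chunks:
        # Retrieved chunks are joined and placed in front of the question.
        context = "\n\n".join(chunks)
        user_text = f"RELEVANT INFORMATION: {context}\n\nQUERY: {query}"
        role_text += " Use the RELEVANT INFORMATION to answer the QUERY."
    return [
        {"role": "system", "content": [{"type": "text", "text": role_text}]},
        {"role": "user", "content": [{"type": "text", "text": user_text}]},
    ]


messages = build_messages(
    "What is a challenging token?",
    ["placeholder chunk 1", "placeholder chunk 2"],
)
print(messages[1]["content"][0]["text"])
```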


In [24]:
chat_with_rag.create_store('/content/harle-cafiero-2025-benchmark-ccsd(t)-and-density-functional-theory-calculations-of-biologically-relevant-catecholic.pdf','catechols',5)

Directory '/content/catechols' created successfully.
Store catechols created


In [26]:
_, history = chat_with_rag.chat("What is the best DFT method for catecholic molecules?")

In [None]:
print(history)

## Chatbot
- Click the link below to open the chat full-screen.
- You may also share the link with others. It will work as long as this notebook is running.
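
One caveat when sharing the link: the notebook keeps the conversation in a module-level `chat_history` list, so everyone using the same link appends to, and sees, the same history. A possible alternative, sketched below under assumptions, is to pass the history into and out of the callback so that each session keeps its own copy. `respond` and `fake_reply` are hypothetical names, and `fake_reply` stands in for the real model call.

```python
from typing import Callable, Dict, List, Tuple

Message = Dict[str, str]


def respond(
    message: str,
    history: List[Message],
    reply_fn: Callable[[str], str],
) -> Tuple[str, List[Message]]:
    """Return an empty textbox value and a new history that includes this exchange."""
    new_history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": reply_fn(message)},
    ]
    return "", new_history


def fake_reply(message: str) -> str:
    """Placeholder for the model call; it just echoes the message."""
    return f"You asked: {message}"


# Two independent sessions, each holding its own history.
session_a: List[Message] = []
session_b: List[Message] = []
_, session_a = respond("What is RAG?", session_a, fake_reply)
_, session_b = respond("Hello", session_b, fake_reply)
print(len(session_a), len(session_b))
```

With Gradio, the `Chatbot` component would be listed as both an input and an output of the callback so that its current value arrives as `history`; the reply function would then need to be bound in advance (for example with `functools.partial`). That wiring is not shown here.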

In [6]:
chatbot()

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://a992c51ccd8879ec06.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
