In [3]:
# from transformers import AutoTokenizer, AutoModel
import torch

import pandas as pd

from transformers import XLMRobertaConfig
from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaModel as XLMRobertaModelHF

from transformers import AutoTokenizer, AutoModelForCausalLM

from sentence_transformers import SentenceTransformer

import faiss


In [4]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cuda


### Part 1: Embed the Documents

In [5]:
# Create the encoder model
encoder_model_name = "FacebookAI/xlm-roberta-base"
config = hf_config = XLMRobertaConfig.from_pretrained(encoder_model_name)
encoder_model = XLMRobertaModelHF(config, add_pooling_layer=True).to(device)
encoder_tokenizer = AutoTokenizer.from_pretrained(encoder_model_name)

In [None]:
# Create the document loader
def embed_inputs(texts, tokenizer, model, max_length=64, pool_method="mean_pool"):

    assert pool_method in ["mean_pool", "cle"]

    # Tokenize
    inputs = tokenizer(
        texts, 
        padding=True, 
        truncation=True,
        max_length=max_length, 
        return_tensors="pt"
    ).to(model.device)

    # Forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    # Compress to vector form:
    hidden_states = outputs.last_hidden_state
    if pool_method == "mean_pool":
        mask = inputs['attention_mask'].unsqueeze(-1).expand(hidden_states.size()).float()
        embeds = (hidden_states * mask).sum(1) / mask.sum(1)
    else:
        embeds = hidden_states[:,0,:]
    
    return embeds

In [None]:
# Load the documents
df = pd.read_csv("./sample_data/rag_documents.csv")
texts = df["content"].tolist()

# Embed the documents
embeds = embed_inputs(texts, encoder_tokenizer, encoder_model, max_length=64, method="mean_pooling").cpu().numpy()

# Add the documents to the VDB
index = faiss.IndexFlatIP(embeds.shape[1])
faiss.normalize_L2(embeds)
index.add(embeds)

### Part 2 - Answer Questions with RAG

In [None]:
# Load in the decoder model
decoder_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Replace with Qwen 3 - 0.6B ( or maybe 1.7B)
decoder_tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
decoder_model = AutoModelForCausalLM.from_pretrained(
    decoder_model_name,
    dtype=torch.bfloat16
).to(device)


In [10]:
# Load in the questions
df = pd.read_csv("./sample_data/rag_queries.csv")
queries = df["query"].tolist()

query = queries[0]

# TODO: Add clarification of the query with the decoder model

# Embed the query with the encoder model
embedding = embed_inputs([query], encoder_tokenizer, encoder_model).cpu().numpy()

# Find the top-k documents
k = 3
dist, idx = index.search(embedding, k)

# Generate an answer given the context
context = ""
for i in range(k):
    context += "\n" + texts[idx[0][i]]
answer_prompt = f"Answer the following question using only the context provided.\nContext: {context}\n\nQuestion: {query}\n\nAnswer:"

inputs = decoder_tokenizer(answer_prompt, return_tensors="pt").to(decoder_model.device)

outputs = decoder_model.generate(
    **inputs,
    max_new_tokens=150,
    temperature=0.7,
    top_p=0.9,
    do_sample=True
)

response = decoder_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

prompt = f"Rephrase the following query and identify what information is useful for answering it. \nUser Query: {query}\nClarified Query:"


Answer the following question using only the context provided.
Context: 
The solar system consists of the Sun and all celestial objects bound by its gravity, including eight planets.
Climate change refers to long-term alterations in temperature and weather patterns, largely caused by human activity and greenhouse gases.
The French Revolution was an uprising from 1789 to 1799 that overthrew the monarchy and led to the rise of Napoleon Bonaparte.

Question: What is the role of chlorophyll in plants?

Answer:
Chlorophyll is a pigment found in plants that absorbs sunlight and uses it to convert carbon dioxide and water into glucose. It is responsible for the green color of leaves and the production of food for the plant.

Context: 
Plants use chlorophyll to absorb light and convert it into energy for growth and reproduction. Without chlorophyll, plants would not be able to photosynthesize.

Climate change is caused by the greenhouse effect, which is the greenhouse effect refers to the fact