In [20]:
import os
import json
import torch
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils.quantization_config import BitsAndBytesConfig


# --- Locations of local artefacts (relative to the notebook's working dir) ---
DATA_PROCESSED_PATH = "../data/processed"
FINETUNED_RETRIEVER_PATH = "../models/retriever_finetuned_bge_base"
LOCAL_GENERATOR_PATH = "../models/generator_qwen"

WIKIPEDIA_CHUNKS_FILE = "wikipedia_chunks_bge_base.jsonl"
CORPUS_EMBEDDINGS_FILE = "corpus_embeddings_finetuned_bge_base.npy"

# --- Quantization preferences for loading the generator ---
USE_GENERATOR_QUANTIZATION_ON_LOAD = True
GENERATOR_QUANTIZATION_TYPE_ON_LOAD = "int8"

# --- RAG tuning knobs ---
TOP_K_RETRIEVER = 3  # number of passages handed to the generator per query
MAX_CONTEXT_TOKENS_FOR_GENERATOR = 2048

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

wikipedia_chunks_path = os.path.join(DATA_PROCESSED_PATH, WIKIPEDIA_CHUNKS_FILE)
corpus_embeddings_path = os.path.join(DATA_PROCESSED_PATH, CORPUS_EMBEDDINGS_FILE)

# Fail fast, before any expensive model loading, if a required artefact is
# missing. Each entry pairs the error-message prefix with the path to check.
for _missing_message, _required_path in (
    ("Fine-tuned retriever not found", FINETUNED_RETRIEVER_PATH),
    ("Generator not found locally", LOCAL_GENERATOR_PATH),
    ("Wikipedia corpus file not found", wikipedia_chunks_path),
    ("Corpus embeddings file not found", corpus_embeddings_path),
):
    if not os.path.exists(_required_path):
        raise FileNotFoundError(f"{_missing_message}: {_required_path}")

Using device: cuda


In [21]:
# Load the fine-tuned sentence-embedding model used to encode queries.
retriever_model = SentenceTransformer(FINETUNED_RETRIEVER_PATH, device=DEVICE)
print("Retriever loaded.")

# The corpus is stored as JSON Lines: one passage object per line.
# Whitespace-only lines (e.g. a trailing newline at EOF) are skipped so they
# do not crash json.loads.
with open(wikipedia_chunks_path, "r", encoding="utf-8") as chunks_file:
    corpus_passages_data = [
        json.loads(raw_line) for raw_line in chunks_file if raw_line.strip()
    ]
print(f"Loaded {len(corpus_passages_data)} passages.")

# Pre-computed passage embeddings, one row per passage.
corpus_embeddings = np.load(corpus_embeddings_path)
print(f"Corpus embeddings loaded. Shape: {corpus_embeddings.shape}")

# Retrieval maps an embedding row index straight to a passage index, so the
# two collections must line up; otherwise lookups would fail or silently
# return the wrong passage.
if len(corpus_passages_data) != corpus_embeddings.shape[0]:
    raise ValueError(
        f"Corpus/embedding mismatch: {len(corpus_passages_data)} passages "
        f"but {corpus_embeddings.shape[0]} embedding rows."
    )

Retriever loaded.
Loaded 36508 passages.
Corpus embeddings loaded. Shape: (36508, 768)


In [22]:
# Tokenizer for the locally stored generator model.
generator_tokenizer = AutoTokenizer.from_pretrained(
    LOCAL_GENERATOR_PATH,
    use_fast=True,
    trust_remote_code=True
)
# generate() is later given an explicit pad_token_id; fall back to EOS when the
# tokenizer does not define a dedicated padding token.
if generator_tokenizer.pad_token_id is None:
    generator_tokenizer.pad_token_id = generator_tokenizer.eos_token_id
print("Generator's tokenizer loaded from local path.")

model_gen_kwargs = {
    "device_map": "auto",
    "trust_remote_code": True
}

# Honour the quantization switches from the configuration cell instead of
# unconditionally loading in 8-bit.
quantization_config_gen = None
if USE_GENERATOR_QUANTIZATION_ON_LOAD:
    if GENERATOR_QUANTIZATION_TYPE_ON_LOAD == "int8":
        quantization_config_gen = BitsAndBytesConfig(load_in_8bit=True)
    elif GENERATOR_QUANTIZATION_TYPE_ON_LOAD == "int4":
        quantization_config_gen = BitsAndBytesConfig(load_in_4bit=True)
    else:
        raise ValueError(
            "Unsupported GENERATOR_QUANTIZATION_TYPE_ON_LOAD: "
            f"{GENERATOR_QUANTIZATION_TYPE_ON_LOAD!r} (expected 'int8' or 'int4')"
        )
    model_gen_kwargs["quantization_config"] = quantization_config_gen

generator_model = AutoModelForCausalLM.from_pretrained(
    LOCAL_GENERATOR_PATH,
    **model_gen_kwargs
)
print("Generator loaded from local path.")


Generator's tokenizer loaded from local path.




Generator loaded from local path.


In [23]:
def retrieve_contexts(query_text, top_k=TOP_K_RETRIEVER):
    """Return the passages most similar to a query, best match first.

    The query is encoded with ``retriever_model`` and compared against every
    row of ``corpus_embeddings`` using cosine similarity.

    Args:
        query_text: The question to search for.
        top_k: Maximum number of passages to return. Values <= 0 yield empty
            results.

    Returns:
        A tuple ``(passages, scores)``: ``passages`` is a list of entries from
        ``corpus_passages_data`` and ``scores`` the matching cosine
        similarities as Python floats, both ordered by descending similarity.
    """
    # A non-positive top_k would otherwise reach the `[:top_k]` slices below,
    # where a negative value selects nearly the whole corpus.
    if top_k <= 0:
        return [], []

    query_embedding = retriever_model.encode(query_text, convert_to_numpy=True)
    if query_embedding.ndim == 1:
        # Make the single query an explicit batch of one row.
        query_embedding = query_embedding.reshape(1, -1)

    # Row 0 holds the similarities of the (only) query to every passage.
    cosine_scores = util.cos_sim(query_embedding, corpus_embeddings)[0].numpy()

    if top_k >= len(cosine_scores):
        # Asking for everything: a full descending sort is all that is needed.
        top_k_indices = np.argsort(cosine_scores)[::-1][:top_k]
    else:
        # Partition to isolate the best top_k entries, then order just those.
        top_k_indices_unsorted = np.argpartition(-cosine_scores, range(top_k))[:top_k]
        top_k_indices = top_k_indices_unsorted[np.argsort(-cosine_scores[top_k_indices_unsorted])]

    retrieved_passages = [corpus_passages_data[idx] for idx in top_k_indices]
    retrieved_scores = [float(cosine_scores[idx]) for idx in top_k_indices]
    return retrieved_passages, retrieved_scores


def format_rag_prompt(query_text, retrieved_contexts_data):
    """Build the chat-formatted generator prompt for a question and its contexts.

    Args:
        query_text: The user's question.
        retrieved_contexts_data: Iterable of passage dicts; each must contain a
            ``'passage_text'`` key.

    Returns:
        A tuple ``(formatted_prompt, full_prompt_content)``:
        ``formatted_prompt`` is the string produced by the tokenizer's chat
        template (with the generation prompt appended) and
        ``full_prompt_content`` is the raw user-message text before templating.
    """
    # Number the contexts from 1. The section labels are in Polish
    # ("Kontekst" = context, "Dostarczone konteksty" = provided contexts,
    # "Pytanie" = question) and are part of the prompt sent to the model.
    context_str = "".join(
        f"Kontekst [{position}]: {context_data['passage_text']}\n\n"
        for position, context_data in enumerate(retrieved_contexts_data, start=1)
    )

    # Adjacent literals are concatenated, so each piece must carry its own
    # trailing space to keep the sentences from running together.
    instruction = (
        "Answer the given question USING ONLY information in provided contexts. "
        "If the exact answer doesn't appear in provided context, then EXACTLY THIS: "
        "'Sorry, I don't know the answer based on the articles provided.' "
        "and don't say anything after that\n\n"
    )

    full_prompt_content = f"{instruction}Dostarczone konteksty:\n{context_str}Pytanie: {query_text}"

    # Single-turn conversation; the template adds the model-specific markers.
    messages = [{"role": "user", "content": full_prompt_content}]
    formatted_prompt = generator_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    return formatted_prompt, full_prompt_content


def generate_answer_with_rag(query_text, max_new_tokens=250, temperature=0.1, top_p=0.9, do_sample=False):
    """Answer a question by retrieving passages and prompting the generator.

    Retrieves ``TOP_K_RETRIEVER`` passages, prints them with their scores,
    builds the RAG prompt, runs generation, then prints and returns the answer.

    Args:
        query_text: The user's question.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; only applied when ``do_sample`` is True.
        top_p: Nucleus-sampling threshold; only applied when ``do_sample`` is True.
        do_sample: Whether to sample instead of decoding greedily.

    Returns:
        The generated answer text, or a fixed apology string when retrieval
        returns no passages.
    """
    retrieved_contexts, scores = retrieve_contexts(query_text, top_k=TOP_K_RETRIEVER)
    if not retrieved_contexts:
        return "Sorry, I failed to find any articles to answer this question."

    print("\nLoaded passages:")
    for rank, (ctx, score) in enumerate(zip(retrieved_contexts, scores), start=1):
        print(f"  Rank {rank} ID: {ctx['passage_id']} (Score: {score:.4f}):\"{ctx['passage_text']}...\" (Dokument: {ctx['document_title']})")

    # Only the chat-formatted prompt is needed here; the raw content is unused.
    rag_prompt_formatted, _ = format_rag_prompt(query_text, retrieved_contexts)

    # Tokenize once and place the tensors on the device the model was actually
    # loaded to (it is loaded with device_map="auto").
    inputs = generator_tokenizer(
        rag_prompt_formatted,
        return_tensors="pt",
        padding=False,
        truncation=False
    ).to(generator_model.device)
    prompt_length = inputs.input_ids.shape[-1]

    with torch.no_grad():
        outputs = generator_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Neutral values are passed when not sampling so the sampling
            # arguments have no effect on greedy decoding.
            temperature=temperature if do_sample else 1.0,
            top_p=top_p if do_sample else 1.0,
            do_sample=do_sample,
            pad_token_id=generator_tokenizer.pad_token_id,
            eos_token_id=generator_tokenizer.eos_token_id
        )

    # The output sequence starts with the prompt tokens; decode only what follows.
    response_text = generator_tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    print(response_text)
    return response_text

In [None]:
# Three smoke-test questions for the end-to-end RAG pipeline.
test_query_rag1 = "When did the French Revolution start?"
test_query_rag2 = "What is the main export product of Japan according to the provided texts?"  # question that requires context
test_query_rag3 = "Who is the current president of Mars based on the documents?"  # question that most likely has no answer

# Visual divider printed between consecutive test runs.
section_divider = "\n" + "#" * 70 + "\n"

print(f"{test_query_rag1}\n")
generate_answer_with_rag(test_query_rag1)

print(section_divider)

print(f"{test_query_rag2}\n")
generate_answer_with_rag(test_query_rag2)

print(section_divider)

print(f"{test_query_rag3}\n")
generate_answer_with_rag(test_query_rag3)



--- Testing full RAG System ---
When did the French Revolution start?


Loaded passages:
  Rank 1 ID: wiki_153_chunk_0 (Score: 0.9542):"The French Revolution (French: Révolution française [ʁevɔlysjɔ̃ fʁɑ̃sɛːz]) was a period of political and societal change in France which began with the Estates General of 1789 and ended with the Coup of 18 Brumaire on 9 November 1799. Many of the revolution's ideas are considered fundamental principles of liberal democracy, and its values remain central to modern French political discourse. The causes of the revolution were a combination of social, political, and economic factors which the ancien régime ("old regime") proved unable to manage. A financial crisis and widespread social distress led to the convocation of the Estates General in May 1789, its first meeting since 1614. The representatives of the Third Estate broke away and re-constituted themselves as a National Assembly in June. The Storming of the Bastille in Paris on 14 July was followed 

"Sorry, I don't know the answer based on the articles provided."