In [None]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
import torch
import os
import re


In [None]:
# Function to generate embeddings for documents and queries
def generate_embeddings(embedding_model, texts, device, batch_size=160):
    """Encode a list of texts into embedding vectors.

    Args:
        embedding_model: Object exposing a SentenceTransformer-style
            ``encode(texts, show_progress_bar=..., batch_size=..., device=...)``
            method.
        texts: Sequence of strings to embed.
        device: Device forwarded unchanged to ``encode``.
        batch_size: Number of texts encoded per batch. Defaults to 160 (the
            previously hard-coded value); lower it if the device runs out of
            memory.

    Returns:
        Whatever ``embedding_model.encode`` returns for ``texts``.
    """
    return embedding_model.encode(
        texts,
        show_progress_bar=True,
        batch_size=batch_size,
        device=device,
    )

In [None]:
# Function to retrieve top k most similar documents
def retrieve_documents(query, doc_embeddings, documents, embedding_model, device, k=3):
    """Return the ``k`` documents whose embeddings are most similar to ``query``.

    Args:
        query: Query string to search with.
        doc_embeddings: Precomputed document embeddings; row ``i`` must be the
            embedding of ``documents[i]`` because the ranked row indices are
            used to index ``documents``.
        documents: Sequence of documents aligned with ``doc_embeddings``.
        embedding_model: Model handed to ``generate_embeddings`` to embed the
            query.
        device: Device handed to ``generate_embeddings``.
        k: Maximum number of documents to return. Values <= 0 yield an empty
            list.

    Returns:
        A list of at most ``k`` documents, most similar first.
    """
    # ``scores[-0:]`` selects *every* element, so k == 0 has to be handled
    # before slicing; an empty corpus has nothing to rank either.
    if k <= 0 or len(documents) == 0:
        return []

    # Only the query is embedded here; document embeddings are precomputed.
    query_embedding = generate_embeddings(embedding_model, [query], device)

    # Row 0 holds the similarity of the single query to every document.
    similarities = cosine_similarity(query_embedding, doc_embeddings)[0]

    # argsort is ascending, so reverse it and keep the first k indices.
    top_indices = similarities.argsort()[::-1][:k]
    return [documents[i] for i in top_indices]

In [None]:
# Function to generate an answer based on retrieved documents
def generate_answer(generator_model, tokenizer, query, retrieved_docs, max_new_tokens=512):
    """Generate an answer to ``query`` grounded in ``retrieved_docs``.

    The documents are numbered from 1 and placed in a user message together
    with the query, the message list is rendered with the tokenizer's chat
    template, and only the newly generated tokens are decoded.

    Args:
        generator_model: Model exposing ``device`` and ``generate(...)``.
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``__call__`` and
            ``batch_decode``.
        query: The user's question.
        retrieved_docs: Iterable of document strings used as context.
        max_new_tokens: Upper bound on generated tokens. Defaults to 512 (the
            previously hard-coded value).

    Returns:
        The decoded completion for the single prompt, as a string.
    """
    # Construct the prompt using the retrieved documents
    # NOTE(review): "anwser" below is a typo in the prompt text; it is left
    # untouched here because changing the prompt changes model output.
    prompt = "Given the following documents:\n"
    prompt += "\n".join(f"{i}. {doc}" for i, doc in enumerate(retrieved_docs, 1))
    prompt += f"\n\nUser query: {query}\n\n"
    prompt += "Based on the above documents, provide a concise, clear and short answer to the user's query.\n"
    prompt += "Don't need too much explanation and keep the anwser in 50 words.\n"

    messages = [
        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]

    # Render the chat as plain text ending with the assistant-turn marker.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize as a batch of one and move it to the model's device.
    model_inputs = tokenizer([text], return_tensors="pt").to(generator_model.device)

    generated_ids = generator_model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens
    )

    # generate() returns prompt + completion; strip the prompt prefix so only
    # the new tokens are decoded.
    completion_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    return tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0]



In [None]:
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Step 1: Load pre-trained models
embedding_model = SentenceTransformer('BAAI/bge-small-en').to(device)

# NOTE(review): this is the base checkpoint, while generate_answer formats the
# prompt with the tokenizer's chat template — confirm whether an
# instruction-tuned checkpoint was intended.
model_name = "Qwen/Qwen2.5-0.5B"  # "Qwen/Qwen2.5-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Half precision is only used on GPU; on CPU, float16 is slow and some ops
# may be unsupported, so fall back to float32 there.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32
generator_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=model_dtype,
).to(device)



In [None]:
# Step 2: Index dataset
df = pd.read_csv('../data/1K_news.csv', encoding='utf-8')
data_list = df.values.tolist()
# Columns are addressed by position: index 1 is used as the title and index 2
# as the body text.
documents = ['title: {}.  text: {}'.format(d[1], d[2])  for d in data_list]

torch.cuda.empty_cache()

embeddings_path = '../data/document_embeddings.npy'

# Reuse cached embeddings when present.
documents_embedding = np.load(embeddings_path) if os.path.exists(embeddings_path) else None

# A cache whose row count differs from the document count is stale (e.g. the
# CSV changed); using it would misalign indices at retrieval time, so rebuild
# and overwrite it in that case.
if documents_embedding is None or len(documents_embedding) != len(documents):
    documents_embedding = generate_embeddings(embedding_model, documents, device)

    # Save embeddings to a numpy file
    np.save(embeddings_path, documents_embedding)

In [None]:
# Step 3: Perform Retrieval
query = "What option do civil servants in Malaysia have for their working hours during Ramadan, according to Communications Minister Fahmi Fadzil?"
retrieved_docs = retrieve_documents(query, documents_embedding, documents, embedding_model, device, k=2)

# view retrieved documents
for doc in retrieved_docs:
    # Show only the part before the first "text:" marker (the title section).
    # str.partition is used instead of a regex so a title containing a newline
    # (which `.` does not match) cannot produce a None match and crash.
    print(doc.partition('text:')[0].strip())

In [None]:
# Step 4: Perform Generation
# Ask the generator model to answer the query using the retrieved documents,
# then show the result under a heading line.
answer = generate_answer(
    generator_model,
    tokenizer,
    query,
    retrieved_docs,
)
print("Generated Answer:", answer, sep="\n")

In [None]:
# Display the compute device selected earlier in the notebook.
print(device)

In [None]:
# Ask PyTorch's CUDA caching allocator to release its unused cached memory.
torch.cuda.empty_cache()