In [1]:
# from langchain_ollama import OllamaLLM
from langchain_ollama import ChatOllama

# llm = OllamaLLM(model='deepseekmini')
llm = ChatOllama(model='dsq4km')

In [2]:
from langchain.embeddings.base import Embeddings
from transformers import AutoTokenizer, AutoModel
import torch

class MiniLM(Embeddings):
    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2", device="cpu"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)
        self.device = device

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def embed_documents(self, texts):
        return [self._embed(text) for text in texts]

    def embed_query(self, text):
        return self._embed(text)

    def _embed(self, text):
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True
        ).to(self.device)

        with torch.no_grad():
            model_output = self.model(**inputs)

        embedding = self.mean_pooling(model_output, inputs["attention_mask"])
        return embedding[0].cpu().numpy().tolist()

  from scipy.sparse import csr_matrix, issparse


In [3]:
embedding_fn = MiniLM(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    device="cuda"  # or "cpu"
)

In [4]:
from langchain_qdrant import QdrantVectorStore, RetrievalMode
from qdrant_client import QdrantClient

client = QdrantClient(url="http://localhost:6333")

qdrant = QdrantVectorStore(
    client=client,
    collection_name="wikipedia",
    embedding=embedding_fn,
    retrieval_mode=RetrievalMode.DENSE,
)

In [59]:
from pydantic import BaseModel, Field
from langchain_core.prompts import PromptTemplate

class RetrievalResponse(BaseModel):
    response: str = Field(..., title="Retrieval Necessity Judgment", 
                         description="Whether retrieval is necessary for the query. Answer only 'Yes' or 'No'.")

retrieval_prompt = PromptTemplate(
    input_variables=["query"],
    template="""Determine if the following user query requires factual, real-world knowledge to answer. This includes questions about events, people, places, companies, historical facts, scientific concepts, or recent news.

Query: {query}

Answer only 'Yes' or 'No'. Do not output anything else."""
)

class RelevanceResponse(BaseModel):
    response: str = Field(..., title="Relevance Judgment", 
                         description="Whether the retrieved context is relevant to the query. Answer only 'Relevant' or 'Irrelevant'.")

relevance_prompt = PromptTemplate(
    input_variables=["query", "context"],
    template="Given the query '{query}' and the retrieved context '{context}', determine if the context is relevant to the query and provides useful information to complete the task. Only answer 'Relevant' or 'Irrelevant'. Do not output anything else."
)

class GenerationResponse(BaseModel):
    response: str = Field(..., title="Generated Response", 
                         description="The response generated based on the query and context.")

generation_prompt = PromptTemplate(
    input_variables=["query", "context"],
    template="""You are a helpful AI assistant. Generate a response to the query based on the provided context. If the context is relevant, use it to answer the query accurately. If the context is irrelevant, rely on your own knowledge but indicate any uncertainties or lack of information. Be concise and informative. If you are presented a multiple choice, only answer with the letter, do not include anything else.

Query: {query}
Context: {context}

Response:"""
)

class SupportResponse(BaseModel):
    response: str = Field(..., title="Support Judgment", 
                         description="Whether the response is supported by the context. Answer only 'Fully supported', 'Partially supported', or 'No support'.")

support_prompt = PromptTemplate(
    input_variables=["response", "context"],
    template="""Given the response '{response}' to a query, and the information provided '{context}' for response generation, determine if the response is supported by the information. 

Use the following entailment scale to generate a score:
- Fully supported: All information in output is supported by the evidence, or extraction from the evidence. This is only applicable when the output and part of the evidence are almost identical.
- Partially supported: The output is supported by the evidence to some extent, but there is major information in the output that is not discussed in the evidence.
- No support: The output completely ignores evidence, is unrelated to the evidence, or contradicts the evidence.

Make sure to not use any external information/knowledge to judge whether the output is true or not. Only check whether the output is supported by the evidence.

Only answer 'Fully supported', 'Partially supported', or 'No support'. Do not output anything else."""
)

class UtilityResponse(BaseModel):
    response: int = Field(..., title="Utility Score", 
                         description="The utility score of the response from 1 to 5, where 5 is highest utility.")

utility_prompt = PromptTemplate(
    input_variables=["query", "response"],
    template="""Given the query '{query}' and the response '{response}', rate the perceived utility score of the response from 1 (lowest) to 5 (highest). 

The detailed criterion is as follows:
5: The response provides a complete, highly detailed, and informative response to the query, fully satisfying the information needs.
4: The response mostly fulfills the need in the query, while there can be some minor improvements such as discussing more detailed information, having better structure of the response, or improving coherence.
3: The response is acceptable, but some major additions or improvements are needed to satisfy users' needs.
2: The response still addresses the main request, but it is not complete or not relevant to the query.
1: The response is barely on-topic or completely irrelevant.

Only answer a single number between 1 and 5. Do not output anything else."""
)

In [60]:
retrieval_chain = retrieval_prompt | llm.with_structured_output(RetrievalResponse)
relevance_chain = relevance_prompt | llm.with_structured_output(RelevanceResponse)
generation_chain = generation_prompt | llm.with_structured_output(GenerationResponse)
support_chain = support_prompt | llm.with_structured_output(SupportResponse)
utility_chain = utility_prompt | llm.with_structured_output(UtilityResponse)

In [7]:
def score_response(response_tuple):
    response, support, utility = response_tuple
    support_score = 3 if support == 'fully supported' else (2 if support == 'partially supported' else 1)
    return support_score * 10 + utility

In [8]:
def self_rag(query, vectorstore, top_k=5):
    log = []  # collect logs here
    
    def log_msg(message):
        log.append(message)
    
    # Step 1: Determine if retrieval is necessary
    input_data = {"query": query}
    retrieval_decision = retrieval_chain.invoke(input_data).response.strip().lower()
    log_msg(f"Retrieval decision: {retrieval_decision}")
    
    if "yes" in retrieval_decision:
        # Step 2: Retrieve relevant documents
        docs = vectorstore.similarity_search(query, k=top_k)
        contexts = [doc.page_content for doc in docs]
        
        # Step 3: Evaluate relevance of retrieved documents
        relevant_contexts = []
        for i, context in enumerate(contexts):
            input_data = {"query": query, "context": context}
            relevance = relevance_chain.invoke(input_data).response.strip().lower()
            log_msg(f"Document {i+1} relevance: {relevance}")
            if relevance == 'relevant':
                relevant_contexts.append(context)
        
        log_msg(f"No of relevant context: {len(relevant_contexts)}")
        
        # If no relevant contexts found, generate without retrieval
        if not relevant_contexts:
            log_msg("No relevant contexts found. Generating without retrieval...")
            input_data = {"query": query, "context": "No relevant context found. Answer the query anyway."}
            return generation_chain.invoke(input_data).response, "\n".join(log)
        
        # Step 4: Generate response using relevant contexts
        responses = []
        for i, context in enumerate(relevant_contexts):
            input_data = {"query": query, "context": context}
            response = generation_chain.invoke(input_data).response
            
            # Step 5: Assess support
            input_data = {"response": response, "context": context}
            support = support_chain.invoke(input_data).response.strip().lower()
            log_msg(f"Support assessment for response {i+1}: {support}")
            
            # Step 6: Evaluate utility
            input_data = {"query": query, "response": response}
            utility = int(utility_chain.invoke(input_data).response)
            log_msg(f"Utility score for response {i+1}: {utility}")
            
            responses.append((response, support, utility))
        
        # Select the best response based on support and utility
        best_response = max(responses, key=score_response)
        response, support, utility = best_response

        log_msg(f"Best response support: {support}, utility: {utility}")

        if len(context) > 100:
            log_msg(f"Used context: {context[:100]}...")
        else:
            log_msg(f"Used context: {context}")
        
        return response, "\n".join(log)
    else:
        # Generate without retrieval
        log_msg("Generating without retrieval...")
        input_data = {"query": query, "context": "No retrieval necessary as the query can be answered with general knowledge."}
        return generation_chain.invoke(input_data).response, "\n".join(log)


In [9]:
from langchain.chains import RetrievalQA

# Create a prompt template that includes context
prompt_template = """Use the following pieces of context to answer the question. If you don't know the answer based on the context, or context is irrelevant, say so and answer with general knowledge.

Context: {context}

Question: {question}

Answer:"""

PROMPT = PromptTemplate(
    template=prompt_template, 
    input_variables=["context", "question"]
)

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=qdrant.as_retriever(search_kwargs={"k": 5}),
    return_source_documents=True,
    chain_type_kwargs={"prompt": PROMPT}
)

In [10]:
# query = '''
# "Which naturalised American electrical/mechanical engineer and inventor (1856-1943) has given his name to the SI unit of magnetic flux density?"
# '''

In [11]:
# # Dokumen yang ditemukan
# found_docs = qdrant.similarity_search(query, k=5)
# for doc in found_docs:
#     print(doc)

In [12]:
# response, log = self_rag(query, qdrant)

In [13]:
# # Jawaban dengan self-rag
# print(response)

In [14]:
# # Jawaban tanpa apa-apa
# llm.invoke(query).content

In [15]:
# # Jawaban dengan rag
# result = qa_chain.invoke({"query": query})

# print(f"{result['result']}")

In [16]:
import os
from pathlib import Path

data_dir = Path.cwd().parent / "data"

triviaqa_dir = data_dir / "triviaqa" / "triviaqa-unfiltered" / "test.json"
popqa_dir = data_dir / "popqa" / "test.json"
arc_dir = data_dir / "arc" / "arc.json"
result_dir = data_dir / "results"

In [29]:
import json
import random

def evaluate(benchmark, mode="selfrag", name="default", limit=None, shuffle=False):
    """
    Evaluate benchmark questions with different modes.
    
    mode options:
        - "selfrag": uses self_rag(query, qdrant)
        - "norag":   uses llm.invoke(query)
        - "rag":     uses qa_chain.invoke({"query": query})
    """

    response_file = result_dir / (name + "-r.jsonl")
    log_file = result_dir / (name + "-l.jsonl")
    # load benchmark data
    with open(benchmark, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if shuffle:
        random.shuffle(data)

    if limit > len(data) or limit is None or limit < 0:
        limit = len(data)
    
    with open(response_file, 'w', encoding='utf-8') as rf, open(log_file, 'w', encoding='utf-8') as lf:
        for idx, item in enumerate(data[:limit]):
            query = item['Question']
            answer = item['Answer']

            # --- Run according to mode ---
            if mode == "selfrag":
                response, log = self_rag(query, qdrant)

            elif mode == "norag":
                response = llm.invoke(query).content
                log = None  # no log

            elif mode == "rag":
                result = qa_chain.invoke({"query": query})
                response = result['result']
                docs = result.get('source_documents', [])
                log = [doc.page_content for doc in docs]  # store docs content

            else:
                raise ValueError(f"Unknown mode: {mode}")

            # --- Check correctness ---
            responselower = response.lower()
            anslower = [ans.lower() for ans in answer]
            correct = 1 if any(ans in responselower for ans in anslower) else 0

            # --- Write response JSONL ---
            json.dump({
                'id': idx,
                'query': query,
                'response': response,
                'answer': answer,
                'correct': correct
            }, rf, ensure_ascii=False)
            rf.write('\n')

            # --- Write log JSONL ---
            json.dump({
                'id': idx,
                'query': query,
                'log': log
            }, lf, ensure_ascii=False)
            lf.write('\n')

            if (idx + 1) % 10 == 0:
                print(f"Processed {idx + 1} / {limit} questions.")

In [41]:
def compute_accuracy(response_file):
    with open(response_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    
    total = len(lines)
    correct = sum(1 for line in lines if json.loads(line)['correct'] == 1)
    
    accuracy = correct / total if total > 0 else 0
    return accuracy

In [49]:
def analyze(log_file, response_file):
    # Track IDs by retrieval decision
    no_retrieval_ids = set()
    yes_retrieval_ids = set()
    
    with open(log_file, 'r', encoding='utf-8') as f:
        for line in f:
            log_entry = json.loads(line)
            if "Retrieval decision: no" in log_entry['log']:
                no_retrieval_ids.add(log_entry['id'])
            elif "Retrieval decision: yes" in log_entry['log']:
                yes_retrieval_ids.add(log_entry['id'])
    
    # Initialize counters
    no_correct = no_wrong = 0
    yes_correct = yes_wrong = 0
    total_correct = total_wrong = 0
    
    with open(response_file, 'r', encoding='utf-8') as f:
        for line in f:
            response_entry = json.loads(line)
            rid = response_entry['id']
            correct = response_entry['correct'] == 1
            
            # Update totals
            if correct:
                total_correct += 1
            else:
                total_wrong += 1
            
            # Update retrieval-specific counts
            if rid in no_retrieval_ids:
                if correct:
                    no_correct += 1
                else:
                    no_wrong += 1
            elif rid in yes_retrieval_ids:
                if correct:
                    yes_correct += 1
                else:
                    yes_wrong += 1
    
    total_answers = total_correct + total_wrong
    
    return {
        "total_answers": total_answers,
        "total_correct": total_correct,
        "total_wrong": total_wrong,
        "no_retrieval_correct": no_correct,
        "no_retrieval_wrong": no_wrong,
        "yes_retrieval_correct": yes_correct,
        "yes_retrieval_wrong": yes_wrong,
    }

In [63]:
def count_inconsistent_retrieval(log_file):
    inconsistent_count = 0
    with open(log_file, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            log_entry = json.loads(line)
            if "Retrieval decision:" in log_entry['log']:
                if ("Retrieval decision: no" not in log_entry['log'] and
                    "Retrieval decision: yes" not in log_entry['log']):
                    inconsistent_count += 1
                    # print what comes after "Retrieval decision:"
                    print(log_entry['log'].split("Retrieval")[1].strip())
    return inconsistent_count


In [73]:
def accuracy_analysis(name):
    response_file = result_dir / f"{name}-r.jsonl"
    log_file = result_dir / f"{name}-l.jsonl"

    accuracy = compute_accuracy(response_file)
    try:
        analysis = analyze(log_file, response_file)
        inconsistent_count = count_inconsistent_retrieval(log_file)
    except Exception:
        analysis = {
            "total_answers": 0,
            "total_correct": 0,
            "total_wrong": 0,
            "no_retrieval_correct": 0,
            "no_retrieval_wrong": 0,
            "yes_retrieval_correct": 0,
            "yes_retrieval_wrong": 0,
        }
        inconsistent_count = 0

    print(f"Overall Accuracy: {accuracy:.2%}")
    print(f"Total Answers: {analysis['total_answers']}")
    print(f"Total Correct: {analysis['total_correct']}")
    print(f"Total Wrong: {analysis['total_wrong']}")
    print(f"No Retrieval - Correct: {analysis['no_retrieval_correct']}, Wrong: {analysis['no_retrieval_wrong']}")
    print(f"Yes Retrieval - Correct: {analysis['yes_retrieval_correct']}, Wrong: {analysis['yes_retrieval_wrong']}")
    print(f"Inconsistent Retrieval Decisions: {inconsistent_count}")


In [68]:
name = 'test'

In [69]:
evaluate(triviaqa_dir, mode="selfrag", name=name, limit=100, shuffle=True)

Processed 10 / 100 questions.
Processed 20 / 100 questions.
Processed 30 / 100 questions.
Processed 40 / 100 questions.
Processed 50 / 100 questions.
Processed 60 / 100 questions.
Processed 70 / 100 questions.
Processed 80 / 100 questions.
Processed 90 / 100 questions.
Processed 100 / 100 questions.


In [70]:
accuracy_analysis(name)

decision: the question requires factual real-world knowledge to answer.
Generating without retrieval...
decision: partially correct information
Generating without retrieval...
decision: partially correct answer
Generating without retrieval...
Overall Accuracy: 47.00%
Total Answers: 100
Total Correct: 47
Total Wrong: 53
No Retrieval - Correct: 16, Wrong: 21
Yes Retrieval - Correct: 29, Wrong: 31
Inconsistent Retrieval Decisions: 3


In [56]:
name = 'poptest'
evaluate(popqa_dir, mode="selfrag", name=name, limit=100, shuffle=True)

Processed 10 / 100 questions.
Processed 20 / 100 questions.
Processed 30 / 100 questions.
Processed 40 / 100 questions.
Processed 50 / 100 questions.
Processed 60 / 100 questions.
Processed 70 / 100 questions.
Processed 80 / 100 questions.
Processed 90 / 100 questions.
Processed 100 / 100 questions.


In [None]:
accuracy_analysis(name)

Overall Accuracy: 18.00%
Total Answers: 100
Total Correct: 18
Total Wrong: 82
No Retrieval - Correct: 10, Wrong: 34
Yes Retrieval - Correct: 8, Wrong: 48
Inconsistent Retrieval Decisions: 0


In [None]:
name = 'arctest'
evaluate(arc_dir, mode="selfrag", name=name, limit=100, shuffle=True)
accuracy_analysis(name)

# name = 'arctest2'
# evaluate(arc_dir, mode="selfrag", name=name, limit=100, shuffle=True)
# accuracy_analysis(name)

Processed 10 / 100 questions.
Processed 20 / 100 questions.
Processed 30 / 100 questions.
Processed 40 / 100 questions.
Processed 50 / 100 questions.
Processed 60 / 100 questions.
Processed 70 / 100 questions.
Processed 80 / 100 questions.
Processed 90 / 100 questions.
Processed 100 / 100 questions.
true
Generating without retrieval...
Overall Accuracy: 70.00%
Total Answers: 100
Total Correct: 70
Total Wrong: 30
No Retrieval - Correct: 20, Wrong: 12
Yes Retrieval - Correct: 49, Wrong: 18
Inconsistent Retrieval Decisions: 1


In [75]:
# normal mode
name = 'defaulttriviaqa'
evaluate(triviaqa_dir, mode="norag", name=name, limit=100, shuffle=True)
accuracy_analysis(name)

name = 'defaultpopqa'
evaluate(popqa_dir, mode="norag", name=name, limit=100, shuffle=True)
accuracy_analysis(name)

name = 'defaultarc'
evaluate(arc_dir, mode="norag", name=name, limit=100, shuffle=True)
accuracy_analysis(name)

Processed 10 / 100 questions.
Processed 20 / 100 questions.
Processed 30 / 100 questions.
Processed 40 / 100 questions.
Processed 50 / 100 questions.
Processed 60 / 100 questions.
Processed 70 / 100 questions.
Processed 80 / 100 questions.
Processed 90 / 100 questions.
Processed 100 / 100 questions.
Overall Accuracy: 58.00%
Total Answers: 0
Total Correct: 0
Total Wrong: 0
No Retrieval - Correct: 0, Wrong: 0
Yes Retrieval - Correct: 0, Wrong: 0
Inconsistent Retrieval Decisions: 0
Processed 10 / 100 questions.
Processed 20 / 100 questions.
Processed 30 / 100 questions.
Processed 40 / 100 questions.
Processed 50 / 100 questions.
Processed 60 / 100 questions.
Processed 70 / 100 questions.
Processed 80 / 100 questions.
Processed 90 / 100 questions.
Processed 100 / 100 questions.
Overall Accuracy: 23.00%
Total Answers: 0
Total Correct: 0
Total Wrong: 0
No Retrieval - Correct: 0, Wrong: 0
Yes Retrieval - Correct: 0, Wrong: 0
Inconsistent Retrieval Decisions: 0
Processed 10 / 100 questions.
Pr