## Regulatory Information Retrieval and Answer Generation

This notebook runs the original approach from "A Hybrid Approach To Information Retrieval And Answer Generation For Regulatory Texts", with minimal adjustments so that it runs on the Irish S.I. dataset.

In [1]:
import os
import json
import numpy as np
import pandas as pd
from typing import Dict
from tqdm import tqdm
from re import compile
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.tokenize import word_tokenize
from contractions import fix as fix_contractions
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from functools import partial

from sklearn.feature_extraction import text as sk_text
from trectools import TrecRun, TrecQrel, TrecEval # type: ignore
import matplotlib.pyplot as plt



## Dataset preparation

Kept as in the original notebook from "A Hybrid Approach To Information Retrieval And Answer Generation For Regulatory Texts".

In [2]:
def load_qrels(docs_dir: str, fqrels: str) -> Dict[str, Dict[str, int]]:
    """Read a question/answer JSON file and return binary relevance judgments.

    Args:
        docs_dir: Not used by this function; kept so existing calls keep working.
        fqrels: Path to a UTF-8 JSON file holding a list of entries, each with a
            "QuestionID" and a "Passages" list whose items carry "DocumentID"
            and "PassageID".

    Returns:
        A mapping ``question id -> {"<DocumentID>-<PassageID>": 1}``. A question
        whose "Passages" list is empty gets no entry at all.
    """
    with open(fqrels, encoding='utf-8') as handle:
        entries = json.load(handle)

    judgments = {}
    for entry in entries:
        question_id = entry["QuestionID"]
        for passage in entry["Passages"]:
            passage_key = f"{passage['DocumentID']}-{passage['PassageID']}"
            # Create the per-question dict lazily, on the first listed passage.
            judgments.setdefault(question_id, {})[passage_key] = 1

    return judgments

file_type = 'test'
qrels = load_qrels("", "./QnA_complete.json")

# The output directory is not guaranteed to exist on a fresh checkout; create
# it so the qrels file (and the run files later cells write there) can be opened.
os.makedirs("./data", exist_ok=True)

# One judgment per line: "<qid> Q0 <pid> <rel>". Written as UTF-8, matching the
# encoding used to read the source JSON.
with open("./data/qrels", "w", encoding='utf-8') as f:
    for qid, rels in qrels.items():
        for pid, rel in rels.items():
            line = f"{qid} Q0 {pid} {rel}"
            f.write(line + "\n")

with open('../all_data.json', 'r', encoding='utf-8') as f:
    all_data = json.load(f)

# Build the retrieval collection: one record per unique "<DocumentID>-<PassageID>".
collection = []
seen = set()

for doc in all_data:
    for psg in doc['Passages']:
        psg_id = f"{psg['DocumentID']}-{psg['PassageID']}"
        if psg_id in seen:
            continue

        # The passage id is prepended to the passage body so it is part of the
        # text that gets indexed.
        passage_text = psg['PassageID'] + " " + psg['Passage']

        # Passages of 100 characters or fewer are dropped.
        # NOTE(review): the qrels above are not filtered by this rule, so a
        # relevant passage this short can never be retrieved.
        if len(passage_text) <= 100:
            continue

        collection.append(
            dict(
                text=passage_text,
                ID=psg_id,
                DocumentId=psg['DocumentID'],
                PassageId=psg['PassageID'],
            )
        )
        # Only accepted passages are marked as seen, so a later duplicate id
        # with long enough text can still be added.
        seen.add(psg_id)

print(f"Loaded {len(collection)} passages into collection")

Loaded 32810 passages into collection


In [3]:
# Stop-word list: scikit-learn's English list merged with NLTK's English list.
stop_words = sk_text.ENGLISH_STOP_WORDS.union(stopwords.words('english'))
stemmer = SnowballStemmer(language='english')

# Newline, tab and U+200E (left-to-right mark) are treated as whitespace.
pattern_newline = compile(r'[\n\t\u200e]')
# Runs of spaces, collapsed to one space by the cleaning helpers.
pattern_multiple_spaces = compile(r' +')
# Anything that is not a lowercase ASCII letter or a digit.
pattern_non_alphanumeric = compile(r'[^a-z0-9]')

def clean_text(text: str) -> str:
    """Normalise text for the lexical (BM25) index.

    Expands contractions, lowercases, replaces every character outside
    ``[a-z0-9]`` with a space, drops stop words and stems what remains.

    Args:
        text: Raw passage or query text.

    Returns:
        The surviving stems joined by single spaces.
    """
    lowered = fix_contractions(text).lower()

    # Whitespace-like control characters first, then everything non-alphanumeric.
    stripped = pattern_non_alphanumeric.sub(' ', pattern_newline.sub(' ', lowered))

    stems = []
    for token in word_tokenize(stripped):
        # Stop words are matched on the unstemmed token.
        if token in stop_words:
            continue
        stems.append(stemmer.stem(token))

    return pattern_multiple_spaces.sub(' ', ' '.join(stems)).strip()

def simple_cleaning(query: str) -> str:
    """Light cleaning for the dense retriever: whitespace normalisation only.

    Newlines, tabs and U+200E become spaces, runs of spaces collapse to one,
    and the ends are stripped. Case and punctuation are left untouched.
    """
    flattened = pattern_newline.sub(' ', query)
    return pattern_multiple_spaces.sub(' ', flattened).strip()

def tokenizer(text:str)-> list:
    """Split text on whitespace and return its unigrams followed by its bigrams.

    Args:
        text: Whitespace-separated text (typically the output of `clean_text`).

    Returns:
        All words in order, then every adjacent word pair joined by one space.
        Empty input gives an empty list.
    """
    words = text.split()

    # Pair each word with its successor; zip stops at the shorter slice.
    pairs = [" ".join(pair) for pair in zip(words, words[1:])]

    return words + pairs

In [4]:
# Lexical view of the collection: each passage is cleaned and stemmed, then
# expanded into unigram + bigram tokens, in the same order as `collection`.
tokenized_corpus = [tokenizer(clean_text(doc['text'])) for doc in collection]

In [5]:
# Array copy of the passage records so that top-k index arrays can be used to
# select passages directly (see the retrieval loops below).
collection_array = np.array(collection)

# NOTE(review): the numbers in the trailing comment look stale (carried over
# from another dataset); the recorded output of this cell is 32810.
len(tokenized_corpus) # 10592 (originally 13732)

32810

## Lexical Retriever: BM25

Let us evaluate the lexical retriever using BM25 (Baseline)

In [6]:
# Build the BM25 index over the unigram + bigram token lists, with k1 and b
# set explicitly.
bm25 = BM25Okapi(tokenized_corpus, k1=1.5, b=0.75)

In [7]:
def sintactic_query_bm5(query: str, bm5_instance: BM25Okapi) -> np.array:
    """Score the whole indexed corpus against a query with BM25.

    The query goes through the same `clean_text` + `tokenizer` pipeline as the
    corpus, so query tokens line up with the indexed tokens.

    Args:
        query: Raw question text.
        bm5_instance: BM25 index to query.

    Returns:
        Whatever `bm5_instance.get_scores` returns for the tokenised query.
    """
    return bm5_instance.get_scores(tokenizer(clean_text(query)))

In [8]:
# Bind the BM25 index so the retriever can be called with just the query string.
sintactic_bm25_retriever = partial(sintactic_query_bm5, bm5_instance=bm25)

In [9]:
retrieved = {}
top_n = 20

# Only the JSON parse needs the file handle, so it is closed before the
# scoring loop starts.
with open("./QnA_complete.json", encoding='utf-8') as f:
    data = json.load(f)

for e in tqdm(data):
    query = e['Question']

    scores = sintactic_bm25_retriever(query)

    # argpartition needs kth < len(scores); when the collection has top_n or
    # fewer passages, every passage is a candidate.
    if top_n < len(scores):
        top_k = np.argpartition(-scores, top_n)[:top_n]
    else:
        top_k = np.arange(len(scores))

    # The partition is unordered, so sort the candidates by descending score.
    top_k = top_k[np.argsort(-scores[top_k])]

    top_docs = collection_array[top_k]

    top_scores = scores[top_k]

    top_results = [{**doc, 'score': score} for doc, score in zip(top_docs, top_scores)]

    retrieved[e["QuestionID"]] = top_results

100%|██████████| 240/240 [00:44<00:00,  5.35it/s]


In [10]:
# Persist the BM25 ranking, one line per hit:
# "<qid> 0 <passage id> <rank> <score> bm25", with ranks starting at 1.
with open("./data/rankings_sintactic.trec", "w") as f:
    for qid, hits in retrieved.items():
        f.writelines(
            f"{qid} 0 {hit['ID']} {rank} {hit['score']} bm25\n"
            for rank, hit in enumerate(hits, start=1)
        )

In [11]:
# Load the judgments and the BM25 run written by the earlier cells and score
# the run at depth 10.
# Note: this rebinds `qrels`, replacing the dict returned by `load_qrels`.
qrels = TrecQrel("./data/qrels")
run = TrecRun("./data/rankings_sintactic.trec")
te = TrecEval(run, qrels)

bm25_recall_10 = te.get_recall(depth=10)
bm25_map_cut_10 = te.get_map(depth=10)

print(f"bm25_recall_10             \tall\t{bm25_recall_10:.4f}")
print(f"bm25_map_cut_10            \tall\t{bm25_map_cut_10:.4f}")

bm25_recall_10             	all	0.4875
bm25_map_cut_10            	all	0.2991


In [12]:
# Same qrels and BM25 run files as the depth-10 evaluation, reloaded here;
# only the depth changes (20).
qrels = TrecQrel("./data/qrels")
run = TrecRun("./data/rankings_sintactic.trec")
te = TrecEval(run, qrels)

bm25_recall_20 = te.get_recall(depth=20)
bm25_map_cut_20 = te.get_map(depth=20)

print(f"bm25_recall_20             \tall\t{bm25_recall_20:.4f}")
print(f"bm25_map_cut_20            \tall\t{bm25_map_cut_20:.4f}")

bm25_recall_20             	all	0.5500
bm25_map_cut_20            	all	0.3034


In [13]:
# Bar chart of the four BM25 metrics, saved to "<retriever_name>_metrics.png".
metrics = {
    "bm25_recall_10": bm25_recall_10,
    "bm25_map_cut_10": bm25_map_cut_10,
    "bm25_recall_20": bm25_recall_20,
    "bm25_map_cut_20": bm25_map_cut_20
}

retriever_name = "BM25-Baseline"  # Hybrid / Vector

names = list(metrics)
values = [metrics[name] for name in names]

fig, ax = plt.subplots(figsize=(8,4))
ax.bar(names, values)
ax.set_ylabel("score")
ax.set_xlabel("metric")
ax.set_title(f"{retriever_name} Metrics")
fig.tight_layout()
fig.savefig(f"{retriever_name}_metrics.png")
# Close the figure so it is written to disk without being shown inline.
plt.close(fig)


## Vector Retriever: Fine-Tuned BAAI/bge-small-en-v1.5

Vector retriever using a fine-tuned model based on `BAAI/bge-small-en-v1.5`

In [14]:
# Load the fine-tuned bi-encoder (the name is presumably a Hugging Face Hub id).
# NOTE(review): the device is pinned to 'cuda', so this cell presumably fails
# on a machine without a CUDA GPU — confirm before running elsewhere.
sentence_transformer_model = SentenceTransformer(
    'raul-delarosa99/bge-small-en-v1.5-RIRAG_ObliQA',
    device='cuda'
)

In [15]:
def semantic_query(query: str, corpus_embeddings_matrix: np.array, 
                   sentence_transformer_model: SentenceTransformer) -> np.array:
    """Score every corpus embedding against a query embedding.

    Args:
        query: Raw question text; only whitespace is normalised before encoding.
        corpus_embeddings_matrix: Matrix with one passage embedding per row.
        sentence_transformer_model: Encoder used to embed the query (on 'cuda').

    Returns:
        One dot-product score per corpus row. The query embedding is
        normalised, so when the corpus rows are normalised too these are
        cosine similarities.
    """
    cleaned_query = simple_cleaning(query)
    query_matrix = sentence_transformer_model.encode(
        [cleaned_query],
        device='cuda',
        normalize_embeddings=True,
    )

    # A single query was encoded, so row 0 holds its score against each passage.
    similarity = query_matrix @ corpus_embeddings_matrix.T
    return similarity[0]

In [16]:
# Embed every passage once up front, with the same light cleaning that
# `semantic_query` applies to queries. `normalize_embeddings=True` matches the
# query side, so the dot product taken there is a cosine similarity.
# NOTE(review): `max_length` is not among the documented `encode` arguments —
# presumably it is ignored or forwarded to the model; confirm that it really
# controls truncation.
corpus_embeddings_matrix = sentence_transformer_model.encode([simple_cleaning(doc['text']) for doc in collection_array],
                          normalize_embeddings=True,
                          show_progress_bar=True,
                          max_length=512,
                          )

Batches:   0%|          | 0/1026 [00:00<?, ?it/s]

In [17]:
# Bind the passage embeddings and the encoder so the retriever can be called
# with just the query string.
semantic_retriever = partial(semantic_query, corpus_embeddings_matrix=corpus_embeddings_matrix,
                             sentence_transformer_model=sentence_transformer_model)

In [18]:
retrieved = {}
top_n = 20

# Only the JSON parse needs the file handle, so it is closed before the
# scoring loop starts.
with open("./QnA_complete.json", encoding='utf-8') as f:
    data = json.load(f)

for e in tqdm(data):
    query = e['Question']

    scores = semantic_retriever(query)

    # argpartition needs kth < len(scores); when the collection has top_n or
    # fewer passages, every passage is a candidate.
    if top_n < len(scores):
        top_k = np.argpartition(-scores, top_n)[:top_n]
    else:
        top_k = np.arange(len(scores))

    # The partition is unordered, so sort the candidates by descending score.
    top_k = top_k[np.argsort(-scores[top_k])]

    top_docs = collection_array[top_k]

    top_scores = scores[top_k]

    top_results = [{**doc, 'score': score} for doc, score in zip(top_docs, top_scores)]

    retrieved[e["QuestionID"]] = top_results

100%|██████████| 240/240 [00:03<00:00, 66.10it/s]


In [19]:
# Persist the dense ranking, one line per hit:
# "<qid> 0 <passage id> <rank> <score> dense", with ranks starting at 1.
with open("./data/rankings_semantic.trec", "w") as f:
    for qid, hits in retrieved.items():
        f.writelines(
            f"{qid} 0 {hit['ID']} {rank} {hit['score']} dense\n"
            for rank, hit in enumerate(hits, start=1)
        )

In [20]:
# Score the dense run file against the judgments at depth 10.
qrels = TrecQrel("./data/qrels")
run = TrecRun("./data/rankings_semantic.trec")
te = TrecEval(run, qrels)

vector_recall_10 = te.get_recall(depth=10)
vector_map_cut_10 = te.get_map(depth=10)

print(f"vector_recall_10             \tall\t{vector_recall_10:.4f}")
print(f"vector_map_cut_10            \tall\t{vector_map_cut_10:.4f}")

vector_recall_10             	all	0.4458
vector_map_cut_10            	all	0.2529


In [21]:
# Same qrels and dense run files as the depth-10 evaluation, reloaded here;
# only the depth changes (20).
qrels = TrecQrel("./data/qrels")
run = TrecRun("./data/rankings_semantic.trec")
te = TrecEval(run, qrels)

vector_recall_20 = te.get_recall(depth=20)
vector_map_cut_20 = te.get_map(depth=20)

print(f"vector_recall_20             \tall\t{vector_recall_20:.4f}")
print(f"vector_map_cut_20            \tall\t{vector_map_cut_20:.4f}")

vector_recall_20             	all	0.5000
vector_map_cut_20            	all	0.2564


In [22]:
# Bar chart of the four dense-retriever metrics, saved to
# "<retriever_name>_metrics.png".
metrics = {
    "vector_recall_10": vector_recall_10,
    "vector_map_cut_10": vector_map_cut_10,
    "vector_recall_20": vector_recall_20,
    "vector_map_cut_20": vector_map_cut_20
}

retriever_name = "Vector-Baseline"  # Hybrid / BM25

names = list(metrics)
values = [metrics[name] for name in names]

fig, ax = plt.subplots(figsize=(8,4))
ax.bar(names, values)
ax.set_ylabel("score")
ax.set_xlabel("metric")
ax.set_title(f"{retriever_name} Metrics")
fig.tight_layout()
fig.savefig(f"{retriever_name}_metrics.png")
# Close the figure so it is written to disk without being shown inline.
plt.close(fig)


## Hybrid Retriever (BM25 + Fine-Tuned BAAI/bge-small-en-v1.5)

Hybrid retriever that combines the min-max-normalised BM25 scores with the scores of the fine-tuned `BAAI/bge-small-en-v1.5` vector retriever as a weighted average.

In [23]:
def hybrid_query_avg(query: str, sintactic_retriever: partial, semantic_retriever: partial, 
                     alpha: float = 0.5) -> np.array:    
    """Blend lexical and semantic scores after min-max normalising each.

    Args:
        query: Raw question text, passed unchanged to both retrievers.
        sintactic_retriever: Callable returning one lexical score per passage.
        semantic_retriever: Callable returning one semantic score per passage.
        alpha: Weight of the semantic scores; the lexical scores get ``1 - alpha``.

    Returns:
        ``alpha * semantic + (1 - alpha) * lexical``, with each side first
        rescaled to [0, 1]. A side whose scores are all equal carries no
        ranking signal and contributes zeros. Previously that case divided by
        zero and turned every score into NaN, for example when no query term
        occurs in the lexical index.
    """

    def _min_max(raw_scores):
        # Rescale to [0, 1]; a constant vector maps to zeros, not 0/0.
        lowest = raw_scores.min()
        spread = raw_scores.max() - lowest
        if spread == 0:
            return np.zeros_like(raw_scores)
        return (raw_scores - lowest) / spread

    sintactic_scores = _min_max(sintactic_retriever(query))
    semantic_scores = _min_max(semantic_retriever(query))

    scores = alpha * semantic_scores + (1 - alpha) * sintactic_scores

    return scores

In [24]:
retrieved = {}
top_n = 20

# Only the JSON parse needs the file handle, so it is closed before the
# scoring loop starts.
with open("./QnA_complete.json", encoding='utf-8') as f:
    data = json.load(f)

for e in tqdm(data):
    query = e['Question']

    # alpha=0.65 weights the semantic scores at 0.65 and BM25 at 0.35.
    scores = hybrid_query_avg(
        query,
        sintactic_retriever=sintactic_bm25_retriever,
        semantic_retriever=semantic_retriever,
        alpha=0.65,
    )

    # argpartition needs kth < len(scores); when the collection has top_n or
    # fewer passages, every passage is a candidate.
    if top_n < len(scores):
        top_k = np.argpartition(-scores, top_n)[:top_n]
    else:
        top_k = np.arange(len(scores))

    # The partition is unordered, so sort the candidates by descending score.
    top_k = top_k[np.argsort(-scores[top_k])]

    top_docs = collection_array[top_k]

    top_scores = scores[top_k]

    top_results = [{**doc, 'score': score} for doc, score in zip(top_docs, top_scores)]

    retrieved[e["QuestionID"]] = top_results

100%|██████████| 240/240 [00:55<00:00,  4.33it/s]


In [25]:
# Persist the hybrid ranking, one line per hit:
# "<qid> 0 <passage id> <rank> <score> hybrid", with ranks starting at 1.
with open("./data/rankings_hybrid.trec", "w") as f:
    for qid, hits in retrieved.items():
        f.writelines(
            f"{qid} 0 {hit['ID']} {rank} {hit['score']} hybrid\n"
            for rank, hit in enumerate(hits, start=1)
        )

In [26]:
# Score the hybrid run file against the judgments at depth 10.
qrels = TrecQrel("./data/qrels")
run = TrecRun("./data/rankings_hybrid.trec")
te = TrecEval(run, qrels)

hybrid_recall_10 = te.get_recall(depth=10)
hybrid_map_cut_10 = te.get_map(depth=10)

print(f"hybrid_recall_10             \tall\t{hybrid_recall_10:.4f}")
print(f"hybrid_map_cut_10            \tall\t{hybrid_map_cut_10:.4f}")

hybrid_recall_10             	all	0.4958
hybrid_map_cut_10            	all	0.2989


In [27]:
# Same qrels and hybrid run files as the depth-10 evaluation, reloaded here;
# only the depth changes (20).
qrels = TrecQrel("./data/qrels")
run = TrecRun("./data/rankings_hybrid.trec")
te = TrecEval(run, qrels)

hybrid_recall_20 = te.get_recall(depth=20)
hybrid_map_cut_20 = te.get_map(depth=20)

print(f"hybrid_recall_20             \tall\t{hybrid_recall_20:.4f}")
print(f"hybrid_map_cut_20            \tall\t{hybrid_map_cut_20:.4f}")

hybrid_recall_20             	all	0.5625
hybrid_map_cut_20            	all	0.3038


In [28]:
# Bar chart of the four hybrid-retriever metrics, saved to
# "<retriever_name>_metrics.png".
metrics = {
    "hybrid_recall_10": hybrid_recall_10,
    "hybrid_map_cut_10": hybrid_map_cut_10,
    "hybrid_recall_20": hybrid_recall_20,
    "hybrid_map_cut_20": hybrid_map_cut_20
}

retriever_name = "Hybrid-Baseline"  # Vector / BM25

names = list(metrics)
values = [metrics[name] for name in names]

fig, ax = plt.subplots(figsize=(8,4))
ax.bar(names, values)
ax.set_ylabel("score")
ax.set_xlabel("metric")
ax.set_title(f"{retriever_name} Metrics")
fig.tight_layout()
fig.savefig(f"{retriever_name}_metrics.png")
# Close the figure so it is written to disk without being shown inline.
plt.close(fig)


In [30]:
import matplotlib.pyplot as plt
import numpy as np

# Grouped bar chart comparing Recall@10 and Recall@20 across the three
# retrievers, saved to "comparison_recall.png".
retrievers = ["bm25", "vector", "hybrid"]
x = np.arange(len(retrievers))
width = 0.35

recall_10_vals = [bm25_recall_10, vector_recall_10, hybrid_recall_10]
recall_20_vals = [bm25_recall_20, vector_recall_20, hybrid_recall_20]

fig, ax = plt.subplots(figsize=(12,6))
ax.bar(x - width/2, recall_10_vals, width, label="Recall@10")
ax.bar(x + width/2, recall_20_vals, width, label="Recall@20")
ax.set_xticks(x)
ax.set_xticklabels(retrievers)
ax.set_ylabel("Recall")
ax.set_xlabel("retriever")
ax.set_title("Recall Comparison")
ax.set_ylim(0, 0.7)
ax.grid(axis="y", linestyle="--", alpha=0.5)

# Print each value just above its bar; the offset selects the left (@10) or
# right (@20) bar of every group.
for offset, series in ((-width/2, recall_10_vals), (width/2, recall_20_vals)):
    for i, v in enumerate(series):
        ax.text(x[i] + offset, v + 0.01, f"{v:.3f}", ha="center", fontsize=10)

ax.legend()
fig.tight_layout()
fig.savefig("comparison_recall.png")
plt.close(fig)


In [31]:
import matplotlib.pyplot as plt
import numpy as np

# Grouped bar chart comparing MAP@10 and MAP@20 across the three retrievers,
# saved to "comparison_map.png".
retrievers = ["bm25", "vector", "hybrid"]
x = np.arange(len(retrievers))
width = 0.35

map_10_vals = [bm25_map_cut_10, vector_map_cut_10, hybrid_map_cut_10]
map_20_vals = [bm25_map_cut_20, vector_map_cut_20, hybrid_map_cut_20]

fig, ax = plt.subplots(figsize=(12,6))
ax.bar(x - width/2, map_10_vals, width, label="MAP@10")
ax.bar(x + width/2, map_20_vals, width, label="MAP@20")
ax.set_xticks(x)
ax.set_xticklabels(retrievers)
ax.set_ylabel("MAP")
ax.set_xlabel("retriever")
ax.set_title("MAP Comparison")
ax.set_ylim(0, 0.5)
ax.grid(axis="y", linestyle="--", alpha=0.5)

# Print each value just above its bar; the offset selects the left (@10) or
# right (@20) bar of every group.
for offset, series in ((-width/2, map_10_vals), (width/2, map_20_vals)):
    for i, v in enumerate(series):
        ax.text(x[i] + offset, v + 0.01, f"{v:.3f}", ha="center", fontsize=10)

ax.legend()
fig.tight_layout()
fig.savefig("comparison_map.png")
plt.close(fig)
