In [1]:
# References
# https://huggingface.co/blog/ngxson/make-your-own-rag

In [2]:
import json
from datasets import Dataset
with open('CombinedLaws.json', 'r', encoding='utf-8') as file:
  dataset: list = json.load(file)

In [3]:
def PrepareDatasetWithArticles(dataset):
    data = []

    for law in dataset:
        try:
            for article in law["Articles"]:
                entry = {
                    'id': "{}-{}".format(law["LawId"], article["Number"]),
                    'text': "{}\n{}".format(law['LawTitle'], article['Point']),
                    'metadata': {
                        'lawTitle':law['LawTitle'],
                        'lawId':law['LawId'],
                        'lawUrl':law['LawSource'],
                        'articleNumber':article['Number'],
                        'articleText':article['Point']
                    }
                }
                data.append(entry)
        except  Exception as error:
            print(error, law['LawSource'])

    return Dataset.from_list(data)

In [4]:
def PrepareDatasetWithLaws(dataset, ArticlePoints = 0):
    data = []

    for law in dataset:

        try:
            articleGroups = []
            if ArticlePoints == 0:
                text = law['LawTitle']
                numbers = []
                for article in law["Articles"]:
                    text += '\n' + article['Point']
                    numbers.append(article['Number'])
                articleGroups.append((text,numbers))
            else:
                text = law['LawTitle']
                numbers = []
                for article in law["Articles"]:
                    text += '\n' + article['Point']
                    numbers.append(article['Number'])
                    if len(numbers) == ArticlePoints:
                        articleGroups.append((text,numbers))
                        text = law['LawTitle']
                        numbers = []
                if len(numbers) > 0:
                    articleGroups.append((text,numbers))
            for group in articleGroups:
                if len(articleGroups) == 1: id = law["LawId"]
                else: id = law["LawId"] + '-' + str(group[1][0]) + '-' + str(group[1][-1])
                entry = {
                    'id': id,
                    'text': group[0],
                    'metadata': {
                        'lawTitle':law['LawTitle'],
                        'lawUrl':law['LawSource'],
                        'Articles': group[1]
                    }
                }
                data.append(entry)
        except  Exception as error:
            print(error, law['LawSource'])

    return Dataset.from_list(data)

In [5]:
articleData = PrepareDatasetWithArticles(dataset)
lawData = PrepareDatasetWithLaws(dataset)
twoArticleGroupData = PrepareDatasetWithLaws(dataset,2)

In [None]:
import ollama

EMBEDDING_MODEL = 'bge-m3'
LANGUAGE_MODEL = 'alibayram/erurollm-9b-instruct'
RERANKING_MODEL = 'BAAI/bge-reranker-v2-m3'

In [7]:
def add_chunk_to_database(id, text, article_list):
  embedding = ollama.embed(model=EMBEDDING_MODEL, input=text)['embeddings'][0]
  return [id, text, article_list, embedding]

In [8]:
import os
import pickle
from tqdm.auto import tqdm

batch_size = 64

def EmbedData(filename, data):
    vector_db = []
    if not(os.path.exists(filename)):
        for i in tqdm(range(0, len(data), batch_size)):
            i_end = min(len(data), i+batch_size)
            for x in range(i,i_end):
                chunk = data[x]
                article_list = data[x]['metadata']['Articles'] if 'Articles' in data[x]['metadata'].keys() else [data[x]['metadata']['articleNumber']]
                vector_db.append(add_chunk_to_database(chunk['id'], chunk['text'], article_list))
        with open(filename, "wb") as f:
            pickle.dump(vector_db, f)
    else:
        with open(filename, "rb") as f:
            vector_db = pickle.load(f)
    return vector_db


In [9]:
filename = 'VectorDatabases/PilniLikumi.pkl'
law_db = EmbedData(filename, lawData)

In [10]:
filename = 'VectorDatabases/LikumiPaPantiem.pkl'
article_db = EmbedData(filename, articleData)

In [11]:
filename = 'VectorDatabases/LikumiPaDivuPantuGrupām.pkl'
articleTwoGroup_db = EmbedData(filename, twoArticleGroupData)

In [None]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import heapq

def retrieve_rerankers(query, vector_db, top_k=25, top_n=3, reranker = None, batch_size=10):
    # Generate query embedding
    query_embedding = np.array(ollama.embed(model=EMBEDDING_MODEL, input=query)['embeddings'][0]).reshape(1, -1)

    # Extract ids, texts, and embeddings
    ids, chunks, article_numbers, embeddings = zip(*vector_db)  # Unpack into separate lists
    embeddings = np.vstack(embeddings)  # Convert list of arrays into a 2D NumPy array

    # Compute cosine similarities for all embeddings at once
    similarities = cosine_similarity(query_embedding, embeddings).flatten()  # Get 1D array of similarities

    # Find top-K results efficiently (instead of min_n)
    top_k_indices = np.argpartition(similarities, -top_k)[-top_k:]  # Get indices of top-K elements
    sorted_top_k_indices = top_k_indices[np.argsort(similarities[top_k_indices])[::-1]]  # Sort them

    top_results = [(ids[i], chunks[i], article_numbers[i], similarities[i]) for i in sorted_top_k_indices]

    if reranker:
        reranked_results = []
        doc_to_article = {doc_id: article_no for doc_id, article_no in zip(ids, article_numbers)}

        # print(sum([len(doc[1]) for doc in top_results]),end=' ')
        # Process reranker inputs in batches
        for i in range(0, len(top_results), batch_size):
            batch = top_results[i:i + batch_size]
            doc_ids, docs, _, _ = zip(*batch)

            # print(sum([len(doc) for doc in docs]),end=' ')

            reranker_scores = reranker.rank(query=query, docs=docs, doc_ids=doc_ids)

            # Fetch reranked results efficiently
            reranked_results.extend(
                (example.doc_id, example.document.text, doc_to_article[example.doc_id], example.score)
                for example in reranker_scores.results
            )
        # print()
        # Sort reranked results by reranker score
        reranked_results.sort(key=lambda x: x[3], reverse=True)

        return reranked_results[:top_n]
    
    return top_results[:top_n]

In [None]:
import os
from os import listdir
from os.path import isfile, join

def GetCustomRagInstructions(instructions, RagFile, VectorDb, ranker, top_k=500, top_n=8, batch_size=10):
    prompts = {}
    for i in range(1,top_n,2):
        prompts[i] = []

    for i in tqdm(range(0, len(instructions))):
        instruction = instructions[i]
        input_query = instruction['question']
        
        retrieved_knowledge = retrieve_rerankers(input_query ,VectorDb, top_k=top_k, top_n=top_n, reranker=ranker, batch_size=batch_size)

        _, chunks, _, _ = zip(*retrieved_knowledge)

        # Prepare non rereanked prompt
        prompt = 'Izmantojot dotās atsauces, sniedz juridiski pareizu atbildi uz doto jautājumu:\n'
        prompt += 'Atsauces: ' + '\n'.join(chunks) + '\n'
        prompt += 'Jautājums: ' + input_query

        prompts[i].append({
            "id": instruction['id'],
            "rag_prompt": prompt,
            "question": input_query,
            "gold": instruction['gold']
        })

        saveFile = f'RAG/{top_k}_{i}_'+RagFile
        
        if isfile(saveFile):
            with open(saveFile, 'r', encoding='utf-8') as f:
                temp = json.load(f)
                temp.extend(prompts[i])
                prompts[i] = temp

        with open(saveFile, 'wt', encoding='utf-8') as f:
            json.dump(prompts[i], f, ensure_ascii=False, indent=4)

In [None]:
with open('ModelInstructions.json', encoding='utf-8') as f:
    instructions = [item for item in json.load(f) if 'rag_prompt' != '']

In [15]:
from rerankers import Reranker

ranker = Reranker(RERANKING_MODEL, model_type='cross-encoder')

Loading TransformerRanker model BAAI/bge-reranker-v2-m3 (this message can be suppressed by setting verbose=0)
No device set
Using device cuda
No dtype set
Using dtype torch.float32
Loaded model BAAI/bge-reranker-v2-m3
Using device cuda.
Using dtype torch.float32.


In [None]:
RankFile = 'RagWithRerankInstructionsArticles.json'
GetCustomRagInstructions(instructions, RankFile, article_db, ranker, top_k=100, top_n=3, batch_size=5)

  0%|          | 0/90 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]