In [2]:
import json
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from nltk.tokenize import sent_tokenize as sent_tokenize_uncached
import nltk
from functools import cache
import tqdm
tqdm.tqdm.pandas()
import os
import csv
import re

import pandas as pd
import itertools

nltk.download('punkt')
import traceback

In [2]:
from dotenv import load_dotenv
"""
This can also be done directly through os.environ or shell command
env file format -
HF_TOKEN=<your-key-here>
OPENAI_API_KEY=<your-key-here>
"""
load_dotenv(dotenv_path="regnlp/docQAgent/.env.local")

True

# Repass provided code

In [1]:
# Set up device for torch operations
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Correct path to the trained model and tokenizer
# model_name = './models/obligation-classifier-legalbert'
model_name = "yashmalviya/legal-bert-base-uncased-regnlp-obligation-classifier"

# Load the tokenizer and model for obligation detection
obligation_tokenizer = AutoTokenizer.from_pretrained(model_name)
obligation_model = AutoModelForSequenceClassification.from_pretrained(model_name)
obligation_model.to(device)
obligation_model.eval()

# Load NLI model and tokenizer for obligation coverage using Microsoft's model
coverage_nli_model = pipeline("text-classification", model="microsoft/deberta-large-mnli", device=device)

# Load NLI model and tokenizer for entailment and contradiction checks
nli_tokenizer = AutoTokenizer.from_pretrained('cross-encoder/nli-deberta-v3-xsmall')
nli_model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/nli-deberta-v3-xsmall')
nli_model.to(device)
nli_model.eval()

In [45]:
# Define a cached version of sentence tokenization
@cache
def sent_tokenize(passage: str):
    return sent_tokenize_uncached(passage)

def softmax(logits):
    e_logits = np.exp(logits - np.max(logits, axis=1, keepdims=True))
    return e_logits / np.sum(e_logits, axis=1, keepdims=True)

def get_nli_probabilities(premises, hypotheses):
    features = nli_tokenizer(premises, hypotheses, padding=True, truncation=True, return_tensors="pt").to(device)
    nli_model.eval()
    with torch.no_grad():
        logits = nli_model(**features).logits.detach().cpu().numpy()
    probabilities = softmax(logits)
    features.input_ids.detach()
    features.token_type_ids.detach()
    features.attention_mask.detach()
    return probabilities

def get_nli_matrix(passages, answers):
    entailment_matrix = np.zeros((len(passages), len(answers)))
    contradiction_matrix = np.zeros((len(passages), len(answers)))

    batch_size = 16
    # for i, pas in enumerate(tqdm.tqdm(passages)):
    for i, pas in enumerate(passages):
        for b in range(0, len(answers), batch_size):
            e = b + batch_size
            probs = get_nli_probabilities([pas] * len(answers[b:e]), answers[b:e])  # Get NLI probabilities
            entailment_matrix[i, b:e] = probs[:, 1]
            contradiction_matrix[i, b:e] = probs[:, 0]
    return entailment_matrix, contradiction_matrix

def calculate_scores_from_matrix(nli_matrix, score_type='entailment'):
    if nli_matrix.size == 0:
        return 0.0
    return np.round(np.mean(np.max(nli_matrix, axis=0)), 5)

def classify_obligations(sentences):
    inputs = obligation_tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(device)
    with torch.no_grad():
        logits = obligation_model(**inputs).logits
    predictions = torch.argmax(logits, dim=1).detach().cpu().numpy()
    inputs.input_ids.detach()
    inputs.token_type_ids.detach()
    inputs.attention_mask.detach()
    return predictions

def calculate_obligation_coverage_score(passages, answers):
    # Filter obligation sentences from passages
    obligation_sentences_source = []
    for passage in passages:
        sentences = sent_tokenize(passage)
        is_obligation = classify_obligations(sentences)
        obligation_sentences_source.extend([sent for sent, label in zip(sentences, is_obligation) if label == 1])

    # Filter obligation sentences from answers
    obligation_sentences_answer = []
    for answer in answers:
        sentences = sent_tokenize(answer)
        is_obligation = classify_obligations(sentences)
        obligation_sentences_answer.extend([sent for sent, label in zip(sentences, is_obligation) if label == 1])

    # Calculate coverage based on NLI entailment
    covered_count = 0
    for obligation in obligation_sentences_source:
        for answer_sentence in obligation_sentences_answer:
            nli_result = coverage_nli_model(f"{answer_sentence} [SEP] {obligation}")
            if nli_result[0]['label'].lower() == 'entailment' and nli_result[0]['score'] > 0.7:
                covered_count += 1
                break

    # TODO: Should be else 1, if obligation_sentences_source is empty give max obligation score
    return covered_count / len(obligation_sentences_source) if obligation_sentences_source else 0


def calculate_final_composite_score(passages, answers):
    passage_sentences = [sent for passage in passages for sent in sent_tokenize(passage)]
    answer_sentences = [sent for answer in answers for sent in sent_tokenize(answer)]
    entailment_matrix, contradiction_matrix = get_nli_matrix(passage_sentences, answer_sentences)
    entailment_score = calculate_scores_from_matrix(entailment_matrix, 'entailment')
    contradiction_score = calculate_scores_from_matrix(contradiction_matrix, 'contradiction')
    obligation_coverage_score = calculate_obligation_coverage_score(passages, answers)

    composite_score = (obligation_coverage_score + entailment_score - contradiction_score + 1) / 3
    return np.round(composite_score, 5), entailment_score, contradiction_score, obligation_coverage_score

In [None]:
def find_ngrams(input_list, n):
    if len(input_list) <= n:
        return input_list
    return [" ".join(e) for e in (zip(*[input_list[i:] for i in range(n)]))]

def get_nli_matrix_ngram(passages, answers, ngram_size):
    passages = find_ngrams(passages, ngram_size)
    answers = find_ngrams(answers, ngram_size)

    entailment_matrix = np.zeros((len(passages), len(answers)))
    contradiction_matrix = np.zeros((len(passages), len(answers)))

    batch_size = 16
    # for i, pas in enumerate(tqdm.tqdm(passages)):
    for i, pas in enumerate(passages):
        for b in range(0, len(answers), batch_size):
            e = b + batch_size
            probs = get_nli_probabilities([pas] * len(answers[b:e]), answers[b:e])  # Get NLI probabilities
            entailment_matrix[i, b:e] = probs[:, 1]
            contradiction_matrix[i, b:e] = probs[:, 0]
    return entailment_matrix, contradiction_matrix


def calculate_final_composite_score_ngram(passages, answers, ngram_size=3):
    passage_sentences = [sent for passage in passages for sent in sent_tokenize(passage)]
    answer_sentences = [sent for answer in answers for sent in sent_tokenize(answer)]
    entailment_matrix, contradiction_matrix = get_nli_matrix_ngram(passage_sentences, answer_sentences, ngram_size)
    entailment_score = calculate_scores_from_matrix(entailment_matrix, 'entailment')
    contradiction_score = calculate_scores_from_matrix(contradiction_matrix, 'contradiction')
    obligation_coverage_score = calculate_obligation_coverage_score(passages, answers)

    composite_score = (obligation_coverage_score + entailment_score - contradiction_score + 1) / 3
    return np.round(composite_score, 5), entailment_score, contradiction_score, obligation_coverage_score

In [40]:
output_dir = f'regnlp/repass-eval'
os.makedirs(output_dir, exist_ok=True)

def run_eval(test_data, result_file_name, ngram_size=0):

    # Define the paths for result files
    output_file_csv = os.path.join(output_dir, f'{result_file_name}.csv')
    output_file_txt = os.path.join(output_dir, f'{result_file_name}.txt')

    processed_question_ids = set()
    composite_scores = []
    entailment_scores = []
    contradiction_scores = []
    obligation_coverage_scores = []

    # Check if the output CSV file already exists and read processed QuestionIDs
    if os.path.exists(output_file_csv):
        with open(output_file_csv, 'r') as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                processed_question_ids.add(row['QuestionID'])

    # with open(input_file_path, 'r') as file:
    #     test_data = json.load(file)

    # Open the CSV file for appending results
    with open(output_file_csv, 'a', newline='') as csvfile:
        writer = csv.writer(csvfile)
        if not processed_question_ids:
            # Write the header if the file is empty or new
            writer.writerow(['QuestionID', 'entailment_score', 'contradiction_score', 'obligation_coverage_score', 'composite_score'])

        for item in tqdm.tqdm(test_data):
            question_id = item['QuestionID']

            # Skip if the QuestionID has already been processed
            if question_id in processed_question_ids:
                print(f"Skipping QuestionID {question_id}, already processed.")
                continue

            # Skip if the "Answer" is null or empty
            if not item.get('Answer') or not item['Answer'].strip():
                print(f"Skipping QuestionID {question_id}, no answer.")
                continue

            # Merge "RetrievedPassages" if it's a list
            if isinstance(item['RetrievedPassages'], list):
                item['RetrievedPassages'] = " ".join(item['RetrievedPassages'])

            passages = [item['RetrievedPassages']]
            answers = [item['Answer']]
            composite_score, entailment_score, contradiction_score, obligation_coverage_score = 0.0, 0.0, 1.0, 0.0
            try:
                with torch.no_grad():
                    if ngram_size > 0:
                        composite_score, entailment_score, contradiction_score, obligation_coverage_score = calculate_final_composite_score(passages, answers, ngram_size)
                    else:
                        composite_score, entailment_score, contradiction_score, obligation_coverage_score = calculate_final_composite_score(passages, answers)
            except Exception as e:
                print(f"Error processing QuestionID {question_id}: {e}")
                traceback.print_exc()
                raise e

            # Append the scores to the lists
            composite_scores.append(composite_score)
            entailment_scores.append(entailment_score)
            contradiction_scores.append(contradiction_score)
            obligation_coverage_scores.append(obligation_coverage_score)

            # Write the result to the CSV file
            writer.writerow([question_id, entailment_score, contradiction_score, obligation_coverage_score, composite_score])

    # Calculate averages
    avg_entailment = np.mean(entailment_scores)
    avg_contradiction = np.mean(contradiction_scores)
    avg_obligation_coverage = np.mean(obligation_coverage_scores)
    avg_composite = np.mean(composite_scores)

    # Print and save results to a text file
    results = (
        f"Average Entailment Score: {avg_entailment}\n"
        f"Average Contradiction Score: {avg_contradiction}\n"
        f"Average Obligation Coverage Score: {avg_obligation_coverage}\n"
        f"Average Final Composite Score: {avg_composite}\n"
    )

    print(results)

    with open(output_file_txt, 'w') as txtfile:
        txtfile.write(results)

    print(f"Processing complete. Results saved to {output_dir}")

# Experimentation

In [22]:
# ObliQADataset test data path
df = pd.read_json("regnlp/ObliQADataset/ObliQA_test.json")
df

Unnamed: 0,QuestionID,Question,Passages,Group
0,777e7a14-fea3-4c37-a0e6-9ffb50024d5c,Can the ADGM provide clarity on the level of d...,"[{'DocumentID': 1, 'PassageID': '14.2.3.Guidan...",2
1,0eb99ea8-3810-492c-9986-7739006b5708,Are there any exceptions or specific circumsta...,"[{'DocumentID': 19, 'PassageID': '100)', 'Pass...",2
2,d34e3516-f053-4652-a0ac-ede703144b9a,What type of procedures must a Third Party Pro...,"[{'DocumentID': 3, 'PassageID': '20.14.1.(2)',...",1
3,6d876d32-7557-4149-875e-8c66f13f3485,Are there any exceptions or specific circumsta...,"[{'DocumentID': 13, 'PassageID': '4.15.16', 'P...",3
4,2efd28f4-8677-4f05-82cd-d9989fb72409,What specific areas of inventory and delivery ...,"[{'DocumentID': 34, 'PassageID': '35)', 'Passa...",3
...,...,...,...,...
2781,b7f623dc-3b61-447f-b481-d47af4115d07,Your petroleum reporting entity is finalizing ...,"[{'DocumentID': 31, 'PassageID': '10)', 'Passa...",1
2782,cd8b132c-30b0-4322-bd65-e685a5237ad9,Can the regulatory authority provide guidance ...,"[{'DocumentID': 33, 'PassageID': '113)', 'Pass...",3
2783,eb183134-d0be-4a3c-a935-7a39e9831fef,What should the detailed limit structure for a...,"[{'DocumentID': 13, 'PassageID': 'APP4.A4.1.Gu...",1
2784,c82f4441-bd41-4587-b59c-de77576370ad,In cases where an Authorised Person's internal...,"[{'DocumentID': 13, 'PassageID': 'APP6.A6.9.2....",10


In [23]:
test_df_index = df.set_index("QuestionID")

In [24]:
import json
import glob

def load_json_files_from_directory(directory_path):
  """Loads all JSON files from a given directory into a list of JSON objects."""
  json_files = glob.glob(directory_path + "/*.json")
  json_data_list = []
  for json_file in json_files:
    with open(json_file, 'r') as f:
      try:
        json_data = json.load(f) 
        json_data_list.append((json_file, json_data))
      except json.JSONDecodeError as e:
        print(f"Error decoding JSON in file {json_file}: {e}")
  return json_data_list

# Path of structured json files provided with ObliQADataset
directory_path = "regnlp/ObliQADataset/StructuredRegulatoryDocuments" 
json_data_list = load_json_files_from_directory(directory_path)

In [25]:
flattened_json_data_list = []

for json_file, json_data in json_data_list:
        flattened_json_data_list.extend(json_data)
len(flattened_json_data_list)

13732

In [26]:
flattened_json_data_dict = {}

for item in flattened_json_data_list:
    flattened_json_data_dict[item["ID"]] = item

In [27]:
flattened_json_data_dict_dp_id = {f'{passage["DocumentID"]}:{passage["PassageID"]}'.replace(' ', '_'): passage for passage in flattened_json_data_list}

## Concatatenation

### Ground Truth
If concated ground truth passages are passed as the answer

At the time we were not aware RePASs takes context passages as input not ground truth passages

In [11]:
df.iloc[0]["Passages"]

[{'DocumentID': 1,
  'PassageID': '14.2.3.Guidance.10.',
  'Passage': 'Relevant Persons should comply with guidance issued by the EOCN with regard to identifying and reporting suspicious activity and Transactions relating to money laundering, terrorist financing and proliferation financing.'}]

In [54]:
test_data = []

for index, row in df.iterrows():
    test_data.append({
        "QuestionID": row['QuestionID'],
        "RetrievedPassages": [element["Passage"] for element in row['Passages']],
        "Answer": "\n".join([element["Passage"] for element in row['Passages']])
    })

test_data[:10]

[{'QuestionID': '777e7a14-fea3-4c37-a0e6-9ffb50024d5c',
  'RetrievedPassages': ['Relevant Persons should comply with guidance issued by the EOCN with regard to identifying and reporting suspicious activity and Transactions relating to money laundering, terrorist financing and proliferation financing.'],
  'Answer': 'Relevant Persons should comply with guidance issued by the EOCN with regard to identifying and reporting suspicious activity and Transactions relating to money laundering, terrorist financing and proliferation financing.'},
 {'QuestionID': '0eb99ea8-3810-492c-9986-7739006b5708',
  'RetrievedPassages': ['REGULATORY REQUIREMENTS FOR AUTHORISED PERSONS ENGAGED IN REGULATED ACTIVITIES IN RELATION TO VIRTUAL ASSETS\nMarket Abuse, Transaction Reporting and Misleading Impressions (FSMR)\nSimilar to the reporting requirements imposed on Recognised Investment Exchanges and MTFs in relation to Financial Instruments, MTFs (pursuant to FSMR Section 149) are required to report details o

In [None]:
result_file_name = "results-concat"

run_eval(test_data, result_file_name)

### BM25

In [11]:
# trec file (see retriever for more detail) for BM25
ret_df = pd.read_csv("regnlp/ObliQADataset/bm25-rankings-topk-100.trec", sep=" ", names=["QuestionID", "IterId", "PassageId", "RetrievalPos", "Score", "ModelType"])
ret_df

Unnamed: 0,QuestionID,IterId,PassageId,RetrievalPos,Score,ModelType
0,777e7a14-fea3-4c37-a0e6-9ffb50024d5c,0,91cd8922-2b83-43f1-b258-40ea02eecce8,1,11.3694,bm25
1,777e7a14-fea3-4c37-a0e6-9ffb50024d5c,0,a068d4e0-2329-407f-8fa3-06bf38c0a3f5,2,10.6749,bm25
2,777e7a14-fea3-4c37-a0e6-9ffb50024d5c,0,fe6b58fc-14fb-46e4-a790-902c6dae6498,3,10.4845,bm25
3,777e7a14-fea3-4c37-a0e6-9ffb50024d5c,0,8f2d6ed9-f3a0-4c87-9abc-93c720355393,4,10.4292,bm25
4,777e7a14-fea3-4c37-a0e6-9ffb50024d5c,0,e246e133-be84-43b6-9018-08d6dd75dd9e,5,9.7701,bm25
...,...,...,...,...,...,...
278595,235c1a96-e7b2-4812-bd48-4fcc4d4f4202,0,6f1e9b26-ccce-40ff-9384-6b5ebfcc4ed9,96,8.6269,bm25
278596,235c1a96-e7b2-4812-bd48-4fcc4d4f4202,0,594d1d7c-01ab-4505-9df6-ae31218e6469,97,8.6233,bm25
278597,235c1a96-e7b2-4812-bd48-4fcc4d4f4202,0,715e2534-e356-4225-a409-dd70f18bffa6,98,8.5898,bm25
278598,235c1a96-e7b2-4812-bd48-4fcc4d4f4202,0,ababc144-e3e9-4d05-89c3-c882544fec2e,99,8.5797,bm25


In [12]:
ret_df = ret_df[ret_df["RetrievalPos"] <= 2]
ret_df["Passage"] = ret_df["PassageId"].apply(lambda x: flattened_json_data_dict[x])
ret_df["Question"] = ret_df["QuestionID"].apply(lambda x: test_df_index.loc[x]["Question"])

ret_df_grouped = ret_df.groupby(["QuestionID", "Question"])["Passage"].apply(list).reset_index()
ret_df_grouped["GroundTruth"] = ret_df_grouped["QuestionID"].apply(lambda x: test_df_index.loc[x]["Passages"])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  ret_df["Passage"] = ret_df["PassageId"].apply(lambda x: flattened_json_data_dict[x])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  ret_df["Question"] = ret_df["QuestionID"].apply(lambda x: test_df_index.loc[x]["Question"])


In [13]:
test_data = []

for index, row in ret_df_grouped.iterrows():
    test_data.append({
        "QuestionID": row['QuestionID'],
        "RetrievedPassages": [element["Passage"] for element in row['GroundTruth']],
        "Answer": "\n\n".join([element["Passage"] for element in row['Passage']])
    })

test_data[:10]

[{'QuestionID': '0048b1b6-e739-49dc-91cb-3ec9061276d0',
  'RetrievedPassages': ['A Person, referred to in this chapter as an applicant, who intends to carry on the Regulated Activity of Operating a Representative Office must apply to the Regulator for a Financial Services Permission in such form as the Regulator shall prescribe.'],
  'Answer': 'A Representative Office seeking to have its Financial Services Permission withdrawn must submit a request in writing stating:\n(1)\tthe reasons for the request;\n(2)\tthat it has ceased or will cease to carry on the Regulated Activity of Operating a Representative Office in or from the ADGM; and\n(3)\tthe date on which it ceased or will cease to carry on the Regulated Activity of Operating a Representative Office in or from the ADGM.\n\nBecause of the limited nature of the Regulated Activity of Operating a Representative Office, much of the ADGM Rulebook has been disapplied for Representative Offices. While most of the key provisions applying to

In [None]:
result_file_name = "results-concat-bm25-top2-2newline"

run_eval(test_data, result_file_name)


Average Entailment Score: 0.3921376288659794
Average Contradiction Score: 0.1243821207658321
Average Obligation Coverage Score: 0.6000736377025038
Average Final Composite Score: 0.6226098085419736


: 

### BGE-EN-ICL 5 shot single passages

In [44]:
"""
Zero shot examples used - 

[{'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
  'query': 'What are the expectations of the ADGM regarding the level of access to information about numbered accounts and their holders that must be granted to staff involved in Anti-Money Laundering (AML) and Counter-Terrorist Financing (CTF) functions?',
  'response': 'If a Relevant Person uses a numbered account with an abbreviated name, it must ensure that:\n(a)\tsuch an account is used only for internal purposes;\n(b)\tit has undertaken the same CDD procedures in relation to the account holder as are required for other account holders;\n(c)\tit maintains the same information in relation to the account and account holder as is required for other accounts and account holders; and\n(d)\tstaff performing AML/TFS functions, including staff responsible for identifying and monitoring transactions for suspicious activity, and staff performing compliance and audit functions, have full access to information about the account and the account holder.'},
 {'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
  'query': "What are the ADGM's expectations for stress and scenario testing in liquidity risk management, and how should these tests be documented and reviewed?",
  'response': 'Liquidity risk. A Recognised Clearing House must:\n(a)\tdetermine the amount of its minimum liquid resources;\n(b)\tmaintain sufficient liquid resources to be able to effect same-day, intra-day or multi-day settlement, as applicable, of its payment obligations with a high degree of confidence under a wide range of potential stress scenarios;\n(c)\tensure that all resources held for the purposes of meeting its minimum liquid resource requirement are available when needed;\n(d)\thave a well-documented rationale to support the amount and form of total liquid resources it maintains for the purposes of \u200e(b) and \u200e(c); and\n(e)\thave appropriate arrangements in order to be able to maintain, on an on-going basis, such amount and form of its total liquid resources.'},
 {'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
  'query': 'Could the FSRA outline the criteria that determine whether an Authorised Person\'s data processing activities qualify as "adequate protection" under the Data Protection Regulations?',
  'response': 'SPECIFIC FSRA GUIDANCE ON THE VIRTUAL ASSET FRAMEWORK\nData protection obligations for Authorised Persons\nADGM’s data protection regime protects individuals’ right to privacy by controlling how personal information is used by organisations and businesses registered in ADGM.  All entities registered in ADGM that hold or process the personal data of an individual must protect personal data in compliance with the ADGM Data Protection Regulations 2015 (the “Data Protection Regulations”).  Specifically, an Authorised Person, as a data controller, will be responsible for determining the purposes for which, and the manner in which, personal data is processed in compliance with the Data Protection Regulations. Failure to do so risks enforcement action and compensation claims from individuals, each of which are considered data subjects under the Data Protection Regulations.\n'},
 {'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
  'query': "When evaluating a Non Abu Dhabi Global Market Firm's adherence to Threshold Conditions, what types of opinions may the Regulator consider from Non Abu Dhabi Global Market Regulators?",
  'response': 'In determining whether a Non Abu Dhabi Global Market Firm is satisfying or will satisfy, and continue to satisfy, any one or more of the Threshold Conditions, the Regulator may have regard to any opinion notified to it by a Non Abu Dhabi Global Market Regulator which relates to the Non Abu Dhabi Global Market Firm and appears to the Regulator to be relevant to compliance with those conditions.'},
 {'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
  'query': 'What constitutes "readily accessible by the market" for disclosures made outside of the annual report or periodic financial statements?',
  'response': 'An Authorised Person may disclose the items marked as quantitative in App12 in a medium or location other than its annual report or periodic financial statements, provided that:\n(a)\tit has prior written approval of the Regulator to do so;\n(b)\tthe annual report or periodic financial statements contain clear references to the location of such disclosures; and\n(c)\tsuch disclosures are readily accessible by the market.'}]
"""
ret_df = pd.read_csv("regnlp/ObliQADataset/bge-en-icl-5-shot-single-5-rankings-exact-topk-100.trec", sep=" ", names=["QuestionID", "IterId", "PassageId", "RetrievalPos", "Score", "ModelType"])
ret_df

Unnamed: 0,QuestionID,IterId,PassageId,RetrievalPos,Score,ModelType
0,777e7a14-fea3-4c37-a0e6-9ffb50024d5c,0,b626aec6-0610-4ad3-91e3-19f401aa3295,1,10,bge-en-icl
1,777e7a14-fea3-4c37-a0e6-9ffb50024d5c,0,91cd8922-2b83-43f1-b258-40ea02eecce8,2,10,bge-en-icl
2,777e7a14-fea3-4c37-a0e6-9ffb50024d5c,0,4d0bf908-0c5b-44b5-b7c6-29de0659158f,3,10,bge-en-icl
3,777e7a14-fea3-4c37-a0e6-9ffb50024d5c,0,cb2b9e90-74f0-4f59-8727-0596db814a53,4,10,bge-en-icl
4,777e7a14-fea3-4c37-a0e6-9ffb50024d5c,0,ca75fba7-00e1-4256-84b4-2d78c322102f,5,10,bge-en-icl
...,...,...,...,...,...,...
278595,235c1a96-e7b2-4812-bd48-4fcc4d4f4202,0,c8afc13a-f574-42ed-be48-f919bcd31e12,96,10,bge-en-icl
278596,235c1a96-e7b2-4812-bd48-4fcc4d4f4202,0,9eba26da-ae52-47b9-a5fa-fa3aab05752e,97,10,bge-en-icl
278597,235c1a96-e7b2-4812-bd48-4fcc4d4f4202,0,96e3ea38-01f8-4ddb-9ce7-d89177c97e38,98,10,bge-en-icl
278598,235c1a96-e7b2-4812-bd48-4fcc4d4f4202,0,441d4e0f-4330-48d0-9488-0775dc0ad3a2,99,10,bge-en-icl


In [45]:
ret_df = ret_df[ret_df["RetrievalPos"] <= 2]
ret_df["Passage"] = ret_df["PassageId"].apply(lambda x: flattened_json_data_dict[x])
ret_df["Question"] = ret_df["QuestionID"].apply(lambda x: test_df_index.loc[x]["Question"])

ret_df_grouped = ret_df.groupby(["QuestionID", "Question"])["Passage"].apply(list).reset_index()
ret_df_grouped["GroundTruth"] = ret_df_grouped["QuestionID"].apply(lambda x: test_df_index.loc[x]["Passages"])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  ret_df["Passage"] = ret_df["PassageId"].apply(lambda x: flattened_json_data_dict[x])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  ret_df["Question"] = ret_df["QuestionID"].apply(lambda x: test_df_index.loc[x]["Question"])


In [46]:
test_data = []

for index, row in ret_df_grouped.iterrows():
    test_data.append({
        "QuestionID": row['QuestionID'],
        "RetrievedPassages": [element["Passage"] for element in row['GroundTruth']],
        "Answer": "\n\n".join([element["Passage"] for element in row['Passage']])
    })

test_data[:10]

[{'QuestionID': '0048b1b6-e739-49dc-91cb-3ec9061276d0',
  'RetrievedPassages': ['A Person, referred to in this chapter as an applicant, who intends to carry on the Regulated Activity of Operating a Representative Office must apply to the Regulator for a Financial Services Permission in such form as the Regulator shall prescribe.'],
  'Answer': 'An applicant will only be authorised to carry on the Regulated Activity of Operating a Representative Office if the Regulator is satisfied that the applicant is fit and proper to hold a Financial Services Permission. In making this assessment the Regulator may consider:\n(1)\twhether the applicant is subject to supervision by a Non-ADGM Financial Services Regulator;\n(2)\twhether the applicant’s Non-ADGM Financial Services Regulator in its home state has been made aware of the proposed application and has expressed itself as having no objection to the establishment by the applicant of a Representative Office in the ADGM; and\n(3)\tany other rele

In [None]:
result_file_name = "results-concat-bge-en-icl-5-shot-single-5-top2-2newline"

run_eval(test_data, result_file_name)

## Same sentence

In [29]:
sent_tokenize_uncached(df.iloc[0]["Passages"][0]["Passage"])

['Relevant Persons should comply with guidance issued by the EOCN with regard to identifying and reporting suspicious activity and Transactions relating to money laundering, terrorist financing and proliferation financing.']

In [31]:
sent_set = set()
for element in df.iloc[10]["Passages"]:
    sent_set.update(sent_tokenize_uncached(element["Passage"]))
sent_set

{'A Listed Entity must, on the occurrence of an event specified in column 1, undertake the requirements detailed in column 2, within the time specified in column 3, in respect of the Securities identified with a "\uf0fc" in column 4, of this Table.',
 'A Recognised Body must ensure that appropriate procedures are adopted for it to make rules, for keeping its rules under review and for amending them.',
 'APP 1.A.2.2.1\t/Table Start\nEVENT \tREQUIREMENTS\tTIME\tStructured Products \tShares\tWarrants over Shares\tWarrants over Debentures\tDebentures\tCertificates\tUnits\nShares\tDebentures\nREGISTRATION\n1.',
 "Any proposed decision with regard to any change in its board of directors or Shari'a Supervisory Board\tConsult with the Regulator\tIn advance\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\n/Table End",
 "Appointment of an independent Shari'a Supervisory Board to evaluate the Shari'a compliance of the Islamic equity Securities on an annual basis\tNotify the Regula

In [None]:
tokenizer = nltk.load(f"tokenizers/punkt/english.pickle")

In [53]:
sent_tokenize_uncached(" ".join([re.sub(tokenizer._lang_vars._re_sent_end_chars, "", sent).lower() for sent in sent_set]))

['any proposed decision with regard to any change in its board of directors or shari\'a supervisory board\tconsult with the regulator\tin advance\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\n/table end a recognised body must ensure that appropriate procedures are adopted for it to make rules, for keeping its rules under review and for amending them appointment of an independent shari\'a supervisory board to evaluate the shari\'a compliance of the islamic equity securities on an annual basis\tnotify the regulator\tannually\t\t\uf0fc\t\uf0fc\t\t\t\uf0fc\n2 a listed entity must, on the occurrence of an event specified in column 1, undertake the requirements detailed in column 2, within the time specified in column 3, in respect of the securities identified with a "\uf0fc" in column 4, of this table the procedures must include the arrangements for:\n(a)\ttaking decisions about making and amending business rules, clearing rules and default rules, including the level at w

In [55]:
test_data = []

for index, row in df.iterrows():
    sent_set = set()
    for element in row["Passages"]:
        sent_set.update(sent_tokenize_uncached(element["Passage"]))
    test_data.append({
        "QuestionID": row['QuestionID'],
        "RetrievedPassages": [element["Passage"] for element in row['Passages']],
        "Answer": " ".join([re.sub(tokenizer._lang_vars._re_sent_end_chars, "", sent).lower() for sent in sent_set])
    })

test_data[:10]

[{'QuestionID': '777e7a14-fea3-4c37-a0e6-9ffb50024d5c',
  'RetrievedPassages': ['Relevant Persons should comply with guidance issued by the EOCN with regard to identifying and reporting suspicious activity and Transactions relating to money laundering, terrorist financing and proliferation financing.'],
  'Answer': 'relevant persons should comply with guidance issued by the eocn with regard to identifying and reporting suspicious activity and transactions relating to money laundering, terrorist financing and proliferation financing'},
 {'QuestionID': '0eb99ea8-3810-492c-9986-7739006b5708',
  'RetrievedPassages': ['REGULATORY REQUIREMENTS FOR AUTHORISED PERSONS ENGAGED IN REGULATED ACTIVITIES IN RELATION TO VIRTUAL ASSETS\nMarket Abuse, Transaction Reporting and Misleading Impressions (FSMR)\nSimilar to the reporting requirements imposed on Recognised Investment Exchanges and MTFs in relation to Financial Instruments, MTFs (pursuant to FSMR Section 149) are required to report details of

In [None]:
result_file_name = "results-same-line"

run_eval(test_data, result_file_name)

# LLM for answer generations

In [16]:
# You will have to replace localhost:11434 with your ollama hosted instance
!curl http://localhost:11434/api/generate -d '{"model": "llama3.1:8b-instruct-q4_K_M", "prompt": "Why is the sky blue?", "stream": false}'

{"model":"llama3.1:8b-instruct-q4_K_M","created_at":"2024-11-30T20:34:39.117375049Z","response":"The sky appears blue because of a phenomenon called Rayleigh scattering, named after the British physicist Lord Rayleigh, who first described it in the late 19th century. Here's what happens:\n\n1. **Sunlight enters Earth's atmosphere**: When sunlight enters our atmosphere, it encounters tiny molecules of gases such as nitrogen (N2) and oxygen (O2).\n\n2. **Scattering occurs**: These gas molecules are much smaller than the wavelength of light. As a result, when sunlight hits them, they scatter the light in all directions.\n\n3. **Shorter wavelengths scattered more**: However, shorter wavelengths like blue and violet are scattered more than longer wavelengths like red and orange. This is because the smaller molecules can only interact with the shorter wavelengths effectively due to their size relative to the wavelength of light.\n\n4. **Blue light dominates**: Due to this scattering effect, 

In [32]:
from langchain_ollama.llms import OllamaLLM

model = OllamaLLM(base_url="http://localhost:11434", model="llama3.1:8b-instruct-q4_K_M", num_ctx=2048, num_predict=512)

## rlm/rag-promp

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama.llms import OllamaLLM
from langchain import hub

rag_prompt = hub.pull("rlm/rag-prompt")

print("*" * 20 + "Prompt[rlm/rag-prompt]" + "*" * 20)
rag_prompt.pretty_print()

In [None]:
def crawl(row):
    answer = chain.invoke({"question": row["Question"], "context": "\n---\n".join([element["Passage"] for element in row["Passages"]])})
    return {
        "QuestionID": row['QuestionID'],
        "RetrievedPassages": [element["Passage"] for element in row['Passages']],
        "Answer": answer
    }

In [32]:
crawl(df.iloc[10])

{'QuestionID': '9bd38f26-b6ac-4062-b076-500c38063b1f',
 'RetrievedPassages': ['A Listed Entity must, on the occurrence of an event specified in column 1, undertake the requirements detailed in column 2, within the time specified in column 3, in respect of the Securities identified with a "\uf0fc" in column 4, of this Table.\n\nAPP 1.A.2.2.1\t/Table Start\nEVENT \tREQUIREMENTS\tTIME\tStructured Products \tShares\tWarrants over Shares\tWarrants over Debentures\tDebentures\tCertificates\tUnits\nShares\tDebentures\nREGISTRATION\n1.\tAppointment of an independent Shari\'a Supervisory Board to evaluate the Shari\'a compliance of the Islamic equity Securities on an annual basis\tNotify the Regulator\tAnnually\t\t\uf0fc\t\uf0fc\t\t\t\uf0fc\n2.\tAny proposed decision with regard to any change in its board of directors or Shari\'a Supervisory Board\tConsult with the Regulator\tIn advance\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\n/Table End',
  'A Recognised Body must ensur

In [33]:
from multiprocessing.pool import ThreadPool

test_data = []
with ThreadPool(2) as p:
    l = [row for _, row in df.iterrows()]
    test_data = list(tqdm.tqdm(p.imap(crawl, l), total=len(l)))

test_data[:10]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2786/2786 [1:17:39<00:00,  1.67s/it]


[{'QuestionID': '777e7a14-fea3-4c37-a0e6-9ffb50024d5c',
  'RetrievedPassages': ['Relevant Persons should comply with guidance issued by the EOCN with regard to identifying and reporting suspicious activity and Transactions relating to money laundering, terrorist financing and proliferation financing.'],
  'Answer': "The ADGM doesn't provide explicit details on the level of documentation required for a report of suspicious activity. However, Relevant Persons are expected to comply with guidance issued by the EOCN regarding identifying and reporting suspicious activities. This implies that the EOCN's guidance should be consulted for clarity on regulatory standards."},
 {'QuestionID': '0eb99ea8-3810-492c-9986-7739006b5708',
  'RetrievedPassages': ['REGULATORY REQUIREMENTS FOR AUTHORISED PERSONS ENGAGED IN REGULATED ACTIVITIES IN RELATION TO VIRTUAL ASSETS\nMarket Abuse, Transaction Reporting and Misleading Impressions (FSMR)\nSimilar to the reporting requirements imposed on Recognised Inv

In [None]:
# save the test_data
with open(f"{output_dir}/rag_llama3_1_8b_test_data.jsonl", "w") as f:
    for item in test_data:
        f.write(json.dumps(item) + "\n")

In [None]:
result_file_name = "results-rag-llama3_1_8b"

run_eval(test_data, result_file_name)

## rlm/rag-promp without concise and don't answer part

In [9]:
from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate(
        template="""You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question.
Question: {question} 
Context: {context} 
Answer:""",
        input_variables=["context", "question"],
    )

In [13]:
chain = prompt | model

In [14]:
def crawl(row):
    answer = chain.invoke({"question": row["Question"], "context": "\n---\n".join([element["Passage"] for element in row["Passages"]])})
    return {
        "QuestionID": row['QuestionID'],
        "RetrievedPassages": [element["Passage"] for element in row['Passages']],
        "Answer": answer
    }

In [15]:
crawl(df.iloc[10])

{'QuestionID': '9bd38f26-b6ac-4062-b076-500c38063b1f',
 'RetrievedPassages': ['A Listed Entity must, on the occurrence of an event specified in column 1, undertake the requirements detailed in column 2, within the time specified in column 3, in respect of the Securities identified with a "\uf0fc" in column 4, of this Table.\n\nAPP 1.A.2.2.1\t/Table Start\nEVENT \tREQUIREMENTS\tTIME\tStructured Products \tShares\tWarrants over Shares\tWarrants over Debentures\tDebentures\tCertificates\tUnits\nShares\tDebentures\nREGISTRATION\n1.\tAppointment of an independent Shari\'a Supervisory Board to evaluate the Shari\'a compliance of the Islamic equity Securities on an annual basis\tNotify the Regulator\tAnnually\t\t\uf0fc\t\uf0fc\t\t\t\uf0fc\n2.\tAny proposed decision with regard to any change in its board of directors or Shari\'a Supervisory Board\tConsult with the Regulator\tIn advance\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\t\uf0fc\n/Table End',
  'A Recognised Body must ensur

In [16]:
from multiprocessing.pool import ThreadPool

test_data = []
with ThreadPool(2) as p:
    l = [row for _, row in df.iterrows()]
    test_data = list(tqdm.tqdm(p.imap(crawl, l), total=len(l)))

test_data[:10]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2786/2786 [2:33:46<00:00,  3.31s/it]


[{'QuestionID': '777e7a14-fea3-4c37-a0e6-9ffb50024d5c',
  'RetrievedPassages': ['Relevant Persons should comply with guidance issued by the EOCN with regard to identifying and reporting suspicious activity and Transactions relating to money laundering, terrorist financing and proliferation financing.'],
  'Answer': "According to the provided context, the Abu Dhabi Global Market (ADGM) can likely provide clarity on the level of detail and documentation required for a report of suspicious activity. This is because Relevant Persons are expected to comply with guidance issued by the Emirate of Dubai's Economic Office Committee for Non-Resident Affairs (EOCN), which suggests that clear guidelines exist for reporting suspicious activities. However, it is advisable to consult directly with the ADGM or the relevant regulatory bodies within the Emirate of Dubai for the most accurate and up-to-date information on regulatory standards."},
 {'QuestionID': '0eb99ea8-3810-492c-9986-7739006b5708',
  

In [None]:
result_file_name = "results-rag-llama3_1_8b-prompt-without"

run_eval(test_data, result_file_name)

## RegNLP paper prompt

In [33]:
from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate(
        template="""You are a regulatory compliance assistant. Provide a detailed answer for the question that fully integrates all the obligations and best practices from the given passages. Ensure your response is cohesive and directly addresses the question. Synthesize the information from all passages into a single, unified answer.
question: {question} 
passages: {context} 
answer:""",
        input_variables=["context", "question"],
    )

In [34]:
chain = prompt | model

In [21]:
# Llama generated answers
ret_df = pd.read_json("regnlp/ObliQADataset/RIRAGSharedTask/yay-2.json")
ret_df

Unnamed: 0,QuestionID,RetrievedPassages,Answer
0,0048b1b6-e739-49dc-91cb-3ec9061276d0,[An application to Operate a Representative Of...,An application to Operate a Representative Off...
1,005042ec-464a-426d-8c61-066360ea73d2,[For the purposes of determining the net prese...,For the purposes of determining the net presen...
2,0076d200-5221-48d7-ad69-84766de8827a,[Principle 2 – High Standards for Authorisatio...,Principle 2 – High Standards for Authorisation...
3,0083176b-36bd-4a0b-b891-ba348b814227,[Compliance risk . The Senior Executive Office...,Compliance risk . The Senior Executive Officer...
4,0090be4b-b92a-4c48-a8d2-4d06861f9844,[Non-submission of ESG disclosures by in-scope...,Non-submission of ESG disclosures by in-scope ...
...,...,...,...
2781,ff96547f-63b3-4140-9aed-4abecbf601e0,[An Authorised Person with a Financial Service...,An Authorised Person with a Financial Services...
2782,ffb8ab74-8b2e-408e-82fc-e68a532f0067,[\nWhen determining whether it has satisfactor...,\nWhen determining whether it has satisfactory...
2783,ffc7adc3-a684-46a2-bfc1-fd93420dcc49,[REGULATORY REQUIREMENTS FOR AUTHORISED PERSON...,REGULATORY REQUIREMENTS FOR AUTHORISED PERSONS...
2784,ffca564d-b579-4af5-b2b8-4367004a0957,[Guidance on risks to be covered as part of th...,Guidance on risks to be covered as part of the...


In [28]:
ret_df["Question"] = ret_df["QuestionID"].apply(lambda x: test_df_index.loc[x]["Question"])
ret_df

Unnamed: 0,QuestionID,RetrievedPassages,Answer,Question
0,0048b1b6-e739-49dc-91cb-3ec9061276d0,[An application to Operate a Representative Of...,An application to Operate a Representative Off...,To whom should an applicant submit their appli...
1,005042ec-464a-426d-8c61-066360ea73d2,[For the purposes of determining the net prese...,For the purposes of determining the net presen...,Could you please provide detailed guidance on ...
2,0076d200-5221-48d7-ad69-84766de8827a,[Principle 2 – High Standards for Authorisatio...,Principle 2 – High Standards for Authorisation...,In the context of Principle 2 – High Standards...
3,0083176b-36bd-4a0b-b891-ba348b814227,[Compliance risk . The Senior Executive Office...,Compliance risk . The Senior Executive Officer...,Can the ADGM provide a detailed outline of the...
4,0090be4b-b92a-4c48-a8d2-4d06861f9844,[Non-submission of ESG disclosures by in-scope...,Non-submission of ESG disclosures by in-scope ...,Can the ADGM provide guidance on best practice...
...,...,...,...,...
2781,ff96547f-63b3-4140-9aed-4abecbf601e0,[An Authorised Person with a Financial Service...,An Authorised Person with a Financial Services...,Can an Authorised Person expect to pay more th...
2782,ffb8ab74-8b2e-408e-82fc-e68a532f0067,[\nWhen determining whether it has satisfactor...,\nWhen determining whether it has satisfactory...,Could you provide clarification on the process...
2783,ffc7adc3-a684-46a2-bfc1-fd93420dcc49,[REGULATORY REQUIREMENTS FOR AUTHORISED PERSON...,REGULATORY REQUIREMENTS FOR AUTHORISED PERSONS...,In the case of a conflict between the inherent...
2784,ffca564d-b579-4af5-b2b8-4367004a0957,[Guidance on risks to be covered as part of th...,Guidance on risks to be covered as part of the...,Can you elaborate on what constitutes adequate...


In [35]:
def crawl(row):
    # print(row)
    answer = chain.invoke({"question": row["Question"], "context": "\n---\n".join(row["RetrievedPassages"])})
    return {
        "QuestionID": row['QuestionID'],
        "RetrievedPassages": row['RetrievedPassages'],
        "Answer": answer
    }

In [36]:
crawl(ret_df.iloc[10])

{'QuestionID': '0175a778-59d5-467a-a74a-899fde4f6b2f',
 'RetrievedPassages': ['PROSPECTUS DISCLOSURE. Importantly, Rule 11.3.1(1) requires that in addition to complying with Chapter 4 of the Rules, a Prospectus that includes a statement about Exploration Targets, Exploration Results, Mineral Resources, Ore Reserves or Production Targets must comply with Rule 11.2.1 (in that such Prospectus must also (in terms of the disclosures made within such Prospectus) comply with a Mining Reporting Standard and Chapter 11 of the Rules.\n',
  '\nRule 11.2.1 applies to all disclosures made or required to be made under the Rules which include a statement about Exploration Targets, Exploration Results, Mineral Resources, Ore Reserves or Production Targets, including within a Prospectus, Exempt Offer document, bidder’s and target’s statements, annual reports, financial statements, technical papers, presentations, and website content and disclosures.\nIn order to ensure consistency of its disclosures, I

In [37]:
from multiprocessing.pool import ThreadPool

test_data = []
with ThreadPool(2) as p:
    l = [row for _, row in ret_df.iterrows()]
    test_data = list(tqdm.tqdm(p.imap(crawl, l), total=len(l)))

test_data[:10]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2786/2786 [8:49:18<00:00, 11.40s/it]


[{'QuestionID': '0048b1b6-e739-49dc-91cb-3ec9061276d0',
  'RetrievedPassages': ['An application to Operate a Representative Office may only be made by a Person who is:\n(a)\tincorporated; and\n(b)\tregulated by a Non-ADGM Financial Services Regulator\nin a jurisdiction other than the ADGM.',
   'A Person, referred to in this chapter as an applicant, who intends to carry on the Regulated Activity of Operating a Representative Office must apply to the Regulator for a Financial Services Permission in such form as the Regulator shall prescribe.'],
  'Answer': 'To whom should an applicant submit their application for permission to operate a Representative Office?\n\nBased on the provided passages, it is clear that only a Person who meets specific criteria can submit an application to operate a Representative Office. The Person must be:\n\n1. Incorporated\n2. Regulated by a Non-ADGM Financial Services Regulator in a jurisdiction other than the ADGM.\n\nGiven these requirements, an applicant 

In [41]:
# dump jsonl
with open(f"{output_dir}/rag_llama3_1_8b_reglp_paper_prompt_test_data_score_filter.jsonl", "w") as f:
    for item in test_data:
        f.write(json.dumps(item) + "\n")

In [None]:
result_file_name = "results-rag-llama3_1_8b-regnlp-paper-prompt-retrieved-score-filter"

run_eval(test_data, result_file_name)

# Bruteforce Programtic Optimisation of RePASs

## Ground Truth

In [None]:
test_data = []

for index, row in tqdm.tqdm(list(df.iterrows())):
    passages = [element["Passage"] for element in row['Passages']]

    sent_set = set()
    for element in passages:
        sent_set.update(sent_tokenize_uncached(element))

    temp_answers = " ".join([sent for sent in sent_set])

    passage_boundaries = []

    i = 0
    for passage in passages:
        for sent in sent_tokenize(passage):
            i += 1
        passage_boundaries.append(i)

    passage_sentences = [sent for passage in passages for sent in sent_tokenize(passage)]
    answer_sentences = [sent for sent in sent_tokenize(temp_answers)]
    entailment_matrix, contradiction_matrix = get_nli_matrix(passage_sentences, answer_sentences)
    
    obligation_sentences_source = []
    source_obligation_count = 0
    for passage in passage_sentences:
        sentences = sent_tokenize(passage)
        is_obligation = classify_obligations(sentences)
        sentence_labels = [(sent, label) for sent, label in zip(sentences, is_obligation)]
        obligation_sentences_source.extend(sentence_labels)
        source_obligation_count += sum(is_obligation)

    obligation_sentences_indexes = [i for i, (sent, label) in enumerate(obligation_sentences_source) if label == 1]

    # Filter obligation sentences from answers
    obligation_sentences_answer = []
    for answer in answer_sentences:
        sentences = sent_tokenize(answer)
        is_obligation = classify_obligations(sentences)
        obligation_sentences_answer.extend([(sent, label) for sent, label in zip(sentences, is_obligation)])
    
    obligation_matrix = np.zeros((len(passage_sentences), len(answer_sentences)))
    obligation_mask = np.zeros((len(passage_sentences), len(answer_sentences)))

    for idx_source, (obligation, source_label) in enumerate(obligation_sentences_source):
        for idx_answer, (answer_sentence, answer_label) in enumerate(obligation_sentences_answer):
            if source_label == 1:
                obligation_mask[idx_source, idx_answer] = 1
                nli_result = coverage_nli_model(f"{answer_sentence} [SEP] {obligation}")
                if nli_result[0]['label'].lower() == 'entailment' and nli_result[0]['score'] > 0.7:
                    obligation_matrix[idx_source, idx_answer] = 1
    
    e_score = np.max(entailment_matrix, axis=0)
    c_score = np.max(contradiction_matrix, axis=0)

    obligation_matrix_w_ignore = obligation_matrix + (1 - obligation_mask)

    # Try all subsets of answer sentences

    best_score = 0
    best_subset = list(range(len(answer_sentences)))
    for i in range(1, len(answer_sentences) + 1):
        for subset in itertools.combinations(range(len(answer_sentences)), i):
            e = np.mean(e_score[list(subset)])
            c = np.mean(c_score[list(subset)]) + 1
            obl = np.mean(np.max(obligation_matrix[np.ix_(obligation_sentences_indexes, list(subset))], axis=1))
            # if no obligation sentence nan is returned, falling back to 0 in this case as done in metric's implementation
            if np.isnan(obl) is np.nan:
                obl = 0.0
            subset_score =  (e - c + obl + 1) / 3
            if subset_score > best_score:
                best_score = subset_score
                best_subset = list(subset)

    test_data.append({
        "QuestionID": row['QuestionID'],
        "RetrievedPassages": [element["Passage"] for element in row['Passages']],
        "Answer": "\n".join(np.array(list(sent_set))[best_subset])
    })

test_data[:10]

In [None]:
result_file_name = "results-bruteforce-optimised"

run_eval(test_data, result_file_name)