In [None]:
!pip install -r requirements.txt

In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [None]:
!pip install -U bitsandbytes

In [1]:
!pip install FlagEmbedding
!pip install imblearn

Collecting FlagEmbedding
  Downloading FlagEmbedding-1.3.5.tar.gz (163 kB)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting sentence_transformers (from FlagEmbedding)
  Downloading sentence_transformers-5.1.2-py3-none-any.whl.metadata (16 kB)
Collecting ir-datasets (from FlagEmbedding)
  Downloading ir_datasets-0.5.11-py3-none-any.whl.metadata (12 kB)
Collecting sentencepiece (from FlagEmbedding)
  Downloading sentencepiece-0.2.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (10 kB)
Collecting protobuf (from FlagEmbedding)
  Downloading protobuf-6.33.0-cp39-abi3-manylinux2014_x86_64.whl.metadata (593 bytes)
Collecting beautifulsoup4>=4.4.1 (from ir-datasets->FlagEmbedding)
  Downloading beautifulsoup4-4.14.2-py3-none-any.whl.metadata (3.8 kB)
Collecting inscriptis>=2.2.0 (from ir-datasets->FlagEmbedding)
  Downloading inscriptis-2.6.0-py3

In [5]:
!pip install matplotlib

Collecting matplotlib
  Using cached matplotlib-3.10.7-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Using cached contourpy-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Using cached fonttools-4.60.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (112 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib)
  Using cached kiwisolver-1.4.9-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (6.3 kB)
Collecting pyparsing>=3 (from matplotlib)
  Using cached pyparsing-3.2.5-py3-none-any.whl.metadata (5.0 kB)
Using cached matplotlib-3.10.7-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (8.7 MB)
Using cached contourpy-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3

In [6]:
import os
import pandas as pd
import numpy as np
import glob
import itertools
import pickle
from tqdm import tqdm
import torch

from FlagEmbedding import BGEM3FlagModel

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel, PeftConfig

from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.utils import shuffle
from sklearn.metrics import recall_score
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from imblearn.ensemble import BalancedBaggingClassifier
from sklearn.metrics import fbeta_score, classification_report, precision_recall_fscore_support

import matplotlib.pyplot as plt

from scipy.special import softmax

from src.util import load_samples
from src.data.data_collator import LegalDataCollatorWithPadding

## Hyperparameters

In [7]:
EVAL_ID = "R04"
TEST_ID = "R05"

RAW_DATA_DIR = "data"
DATA_OUTPUT_DIR = "data/synthesys"
QUERY_PATH = os.path.join(RAW_DATA_DIR, "COLIEE2025statute_data-English/train")
ARTICLE_PATH = os.path.join(RAW_DATA_DIR, "full_en_civil_code_df_24.csv")

CHECKPOINT_DIR = "checkpoints"
STEP1_CHECKPOINT_DIR = f"{CHECKPOINT_DIR}/step1_bge_pre_retrieval"
STEP2_CHECKPOINT_DIR = f"{CHECKPOINT_DIR}/step2_rankllama_retrieval"


# TODO: fix bug
BUG_ARTICLE_POSTFIX = "(1)"  # In the R04's task 3 label, there are some ground truth labels having "(1)" postfix. We need to remove them.


# Step 1
BGE_TOP = 100
HISTOGRAM_N_POSITIVE_REPLICATES = 300


# Step 2
# TODO: review H18 to get the better threshold
RANKLLAMA_THRESHOLD = -3.5  # preserve about 50 candidates for each query

## Step 0: Create dataset

In [8]:
# 1. Load data
# Load the article data
en_article_df = pd.read_csv(ARTICLE_PATH)
en_article_df.rename(columns={"article": "article_id", "content": "article_content"}, inplace=True)

# Load the query data
query_files = glob.glob(f"{QUERY_PATH}/*.xml")

queries = []
for query_file in query_files:
    queries += load_samples(query_file)

en_query_df = pd.DataFrame(queries)
en_query_df = en_query_df.rename(columns={"index": "query_id",
                                          "content": "query_content",
                                          "result": "task3_label",
                                          "label": "task4_label"})

In [11]:
en_query_df[:2]

Unnamed: 0,task3_label,query_content,query_id,task4_label,task3_false_articles
0,"[15, 11]",The family court may decide to commence an ass...,R02-1-A,N,"[223, 218, 117, 237, 126, 275, 226, 491, 659, ..."
1,[15],The issuance of a decision for commencement of...,R02-1-I,N,"[223, 218, 117, 237, 126, 275, 226, 491, 659, ..."


In [9]:
# TODO: fix bug
en_query_df["task3_label"] = en_query_df["task3_label"].apply(lambda articles: [article.replace(BUG_ARTICLE_POSTFIX, "") for article in articles])

In [10]:
# 2. Get the out of label articles
def get_out_of_label(true_labels, articles):
    out_of_label = list(set(articles) - set(true_labels))
    return out_of_label


en_query_df["task3_false_articles"] = en_query_df.apply(lambda x: get_out_of_label(x["task3_label"], en_article_df["article_id"].values), axis=1)

## Step 1: BGE Pre-retrieval

### 1.1. BGE Embedding

In [12]:
model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True, device='cuda')

Fetching 30 files: 100%|██████████| 30/30 [00:55<00:00,  1.86s/it]


In [13]:
# article embedding
article_embeddings = model.encode(en_article_df["article_content"].tolist(),
                                  batch_size=32,
                                  max_length=1024
                                  )['dense_vecs']


# query embedding
query_embeddings = model.encode(en_query_df["query_content"].tolist(),
                                batch_size=32,
                                max_length=1024
                                )['dense_vecs']

pre tokenize: 100%|██████████| 24/24 [00:00<00:00, 177.28it/s]
You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Inference Embeddings: 100%|██████████| 24/24 [00:00<00:00, 30.01it/s]
pre tokenize: 100%|██████████| 38/38 [00:00<00:00, 213.39it/s]
Inference Embeddings: 100%|██████████| 38/38 [00:00<00:00, 66.46it/s]


In [16]:
### SAVE #####
import joblib

joblib.dump(article_embeddings, f"./{STEP1_CHECKPOINT_DIR}/article_embeddings.pkl")
joblib.dump(query_embeddings, f"./{STEP1_CHECKPOINT_DIR}/query_embeddings.pkl")


['./checkpoints/step1_bge_pre_retrieval/query_embeddings.pkl']

In [None]:
###LOAD####
article_embeddings = joblib.load("article_embeddings.pkl")
query_embeddings = joblib.load("query_embeddings.pkl")

In [17]:
article_embedding_dict = dict(zip(en_article_df["article_id"].tolist(), article_embeddings))
query_embedding_dict = dict(zip(en_query_df["query_id"].tolist(), query_embeddings))

### 1.2. Make X, y for training Histogram-based Gradient Boosting

In [20]:
train_query_df = en_query_df[
    (~en_query_df["query_id"].str.startswith(EVAL_ID)) & ~en_query_df["query_id"].str.startswith(TEST_ID)
]
eval_query_df = en_query_df[en_query_df["query_id"].str.startswith(EVAL_ID)]
test_query_df = en_query_df[en_query_df["query_id"].str.startswith(TEST_ID)]

In [22]:
train_query_df

Unnamed: 0,task3_label,query_content,query_id,task4_label,task3_false_articles
0,"[15, 11]",The family court may decide to commence an ass...,R02-1-A,N,"[223, 218, 117, 237, 126, 275, 226, 491, 659, ..."
1,[15],The issuance of a decision for commencement of...,R02-1-I,N,"[223, 218, 117, 237, 126, 275, 226, 491, 659, ..."
2,[18],If the grounds of commencement of assistance c...,R02-1-U,N,"[223, 218, 117, 237, 126, 275, 226, 491, 659, ..."
3,[17],If the assistant does not consent to an act fo...,R02-1-E,Y,"[223, 218, 117, 237, 126, 275, 226, 491, 659, ..."
4,[35],A foreign juridical person permitted possesses...,R02-2-E,N,"[223, 218, 117, 237, 126, 275, 226, 491, 659, ..."
...,...,...,...,...,...
1201,[297],A holder of a right of retention may collect f...,H27-36-A,N,"[223, 218, 117, 237, 126, 275, 226, 491, 659, ..."
1202,[405],In cases where the seller's obligation to deli...,H27-36-I,Y,"[223, 218, 117, 237, 126, 275, 226, 491, 659, ..."
1203,"[462, 459, 459-2]","If a person, who has become a guarantor withou...",H27-36-U,N,"[223, 218, 117, 237, 126, 275, 226, 491, 659, ..."
1204,[545],In cases where a seller canceled the contract ...,H27-36-E,Y,"[223, 218, 117, 237, 126, 275, 226, 491, 659, ..."


In [21]:
def make_pairs(query_id, labels):
    return list(itertools.product([query_id], labels))


train_positive_pairs = train_query_df.apply(lambda x: make_pairs(x["query_id"], x["task3_label"]), axis=1)
eval_positive_pairs = eval_query_df.apply(lambda x: make_pairs(x["query_id"], x["task3_label"]), axis=1)
test_positive_pairs = test_query_df.apply(lambda x: make_pairs(x["query_id"], x["task3_label"]), axis=1)

train_negative_pairs = train_query_df.apply(lambda x: make_pairs(x["query_id"], x["task3_false_articles"]), axis=1)
eval_negative_pairs = eval_query_df.apply(lambda x: make_pairs(x["query_id"], x["task3_false_articles"]), axis=1)
test_negative_pairs = test_query_df.apply(lambda x: make_pairs(x["query_id"], x["task3_false_articles"]), axis=1)

train_positive_pairs = sum(train_positive_pairs, [])
eval_positive_pairs = sum(eval_positive_pairs, [])
test_positive_pairs = sum(test_positive_pairs, [])

train_negative_pairs = sum(train_negative_pairs, [])
eval_negative_pairs = sum(eval_negative_pairs, [])
test_negative_pairs = sum(test_negative_pairs, [])

In [23]:
train_positive_pairs

[('R02-1-A', '15'),
 ('R02-1-A', '11'),
 ('R02-1-I', '15'),
 ('R02-1-U', '18'),
 ('R02-1-E', '17'),
 ('R02-2-E', '35'),
 ('R02-3-A', '95'),
 ('R02-3-I', '120'),
 ('R02-3-U', '121-2'),
 ('R02-3-O', '95'),
 ('R02-4-A', '117'),
 ('R02-4-I', '117'),
 ('R02-4-U', '117'),
 ('R02-5-A', '166'),
 ('R02-5-I', '126'),
 ('R02-5-U', '724'),
 ('R02-5-E', '169'),
 ('R02-5-O', '168'),
 ('R02-8-E', '192'),
 ('R02-8-O', '192'),
 ('R02-8-O', '193'),
 ('R02-9-A', '200'),
 ('R02-9-I', '200'),
 ('R02-9-U', '200'),
 ('R02-9-E', '192'),
 ('R02-9-O', '201'),
 ('R02-10-I', '268'),
 ('R02-10-U', '265'),
 ('R02-10-E', '269-2'),
 ('R02-11-E', '338'),
 ('R02-12-A', '396'),
 ('R02-13-A', '398-11'),
 ('R02-13-A', '376'),
 ('R02-13-E', '398-6'),
 ('R02-13-E', '398-4'),
 ('R02-15-I', '412-2'),
 ('R02-15-U', '422-2'),
 ('R02-15-E', '536'),
 ('R02-15-O', '567'),
 ('R02-16-A', '424-7'),
 ('R02-16-I', '424-5'),
 ('R02-16-I', '424'),
 ('R02-16-U', '424'),
 ('R02-16-E', '425'),
 ('R02-16-O', '425-4'),
 ('R02-16-O', '425-2'),

In [24]:
train_negative_pairs

[('R02-1-A', '223'),
 ('R02-1-A', '218'),
 ('R02-1-A', '117'),
 ('R02-1-A', '237'),
 ('R02-1-A', '126'),
 ('R02-1-A', '275'),
 ('R02-1-A', '226'),
 ('R02-1-A', '491'),
 ('R02-1-A', '659'),
 ('R02-1-A', '564'),
 ('R02-1-A', '183'),
 ('R02-1-A', '182'),
 ('R02-1-A', '623'),
 ('R02-1-A', '505'),
 ('R02-1-A', '703'),
 ('R02-1-A', '530'),
 ('R02-1-A', '568'),
 ('R02-1-A', '374'),
 ('R02-1-A', '688'),
 ('R02-1-A', '169'),
 ('R02-1-A', '230'),
 ('R02-1-A', '418'),
 ('R02-1-A', '699'),
 ('R02-1-A', '199'),
 ('R02-1-A', '31'),
 ('R02-1-A', '221'),
 ('R02-1-A', '268'),
 ('R02-1-A', '36'),
 ('R02-1-A', '500'),
 ('R02-1-A', '324'),
 ('R02-1-A', '94'),
 ('R02-1-A', '296'),
 ('R02-1-A', '681'),
 ('R02-1-A', '301'),
 ('R02-1-A', '250'),
 ('R02-1-A', '417'),
 ('R02-1-A', '249'),
 ('R02-1-A', '37'),
 ('R02-1-A', '326'),
 ('R02-1-A', '465-8'),
 ('R02-1-A', '466-3'),
 ('R02-1-A', '474'),
 ('R02-1-A', '557'),
 ('R02-1-A', '697'),
 ('R02-1-A', '657'),
 ('R02-1-A', '325'),
 ('R02-1-A', '520-15'),
 ('R02-1-A

In [25]:
def distance_function(query_emb, article_emb):
    return query_emb - article_emb


def get_distance(query_id, article_id, query_embedding_dict, article_embedding_dict):
    query_emb = query_embedding_dict[query_id]
    article_emb = article_embedding_dict[article_id]

    return distance_function(query_emb, article_emb)


def make_X_y(positive_pairs, negative_pairs, query_embedding_dict, article_embedding_dict, n_positive_replicate=0, do_shuffle=False):
    X_pos = list(map(lambda x: get_distance(*x, query_embedding_dict, article_embedding_dict), positive_pairs))
    X_neg = list(map(lambda x: get_distance(*x, query_embedding_dict, article_embedding_dict), negative_pairs))

    X_pos = np.array(X_pos)
    if n_positive_replicate > 0:
        X_pos = np.repeat(X_pos, n_positive_replicate, axis=0)

    X_neg = np.array(X_neg)

    X = np.concatenate([X_pos, X_neg])
    y = np.concatenate([np.ones(X_pos.shape[0]), np.zeros(X_neg.shape[0])])

    return shuffle(X, y) if do_shuffle else (X, y)


X_train, y_train = make_X_y(train_positive_pairs, train_negative_pairs,
                            query_embedding_dict, article_embedding_dict,
                            n_positive_replicate=HISTOGRAM_N_POSITIVE_REPLICATES,
                            do_shuffle=True)

X_eval, y_eval = make_X_y(eval_positive_pairs, eval_negative_pairs,
                          query_embedding_dict, article_embedding_dict)

X_test, y_test = make_X_y(test_positive_pairs, test_negative_pairs,
                          query_embedding_dict, article_embedding_dict)

# For backup
X_train_no_positive_replicate, y_train_no_positive_replicate = make_X_y(train_positive_pairs, train_negative_pairs,
                                                                          query_embedding_dict, article_embedding_dict)

### 1.3. Train Histogram-based Gradient Boosting

In [27]:
model = HistGradientBoostingClassifier(
    max_bins=76,
    max_iter=200,
    warm_start=True,
    learning_rate=0.13525541463963714,
    l2_regularization=0.07809942471674647,
    max_leaf_nodes=21,
    max_depth=18,
    verbose=1,
    random_state=0
)

model.fit(X_train, y_train)
joblib.dump(model,f"./{STEP1_CHECKPOINT_DIR}/histogram_classifier.pkl")

Binning 8.444 GB of training data: 30.631 s
Binning 0.938 GB of validation data: 16.016 s
Fitting gradient boosted rounds:
Fit 200 trees in 715.693 s, (4200 total leaves)
Time spent computing histograms: 304.896s
Time spent finding best splits:  165.031s
Time spent applying splits:      168.199s
Time spent predicting:           3.281s


FileNotFoundError: [Errno 2] No such file or directory: './{STEP1_CHECKPOINT_DIR}/histogram_classifier.pkl'

In [28]:
joblib.dump(model,f"./{STEP1_CHECKPOINT_DIR}/histogram_classifier.pkl")

['./checkpoints/step1_bge_pre_retrieval/histogram_classifier.pkl']

In [None]:
### LOAD MODEL #####
model = joblib.load(f"./{STEP1_CHECKPOINT_DIR}/histogram_classifier.pkl")

### 1.4 Test Histogram-based Gradient Boosting

In [29]:
def get_top_preds(group):
    group = group.sort_values("step1_score", ascending=False)

    cut_off_score = group.iloc[BGE_TOP]["step1_score"]
    group["keep"] = group["step1_score"] > cut_off_score

    return group

On train dataset (with no positive samples replication)

In [30]:
y_pred = model.predict_proba(X_train_no_positive_replicate)

train_df_step1 = pd.DataFrame(train_positive_pairs + train_negative_pairs, columns=["query_id", "article_id"])
train_df_step1["step1_score"] = y_pred[:, 1]
train_df_step1["label"] = y_train_no_positive_replicate

train_df_step1 = train_df_step1.groupby("query_id")[train_df_step1.columns.tolist()]\
                               .apply(get_top_preds)\
                               .reset_index(drop=True)

In [31]:
train_df_step1

Unnamed: 0,query_id,article_id,step1_score,label,keep
0,H18-1-1,572,0.989175,1.0,True
1,H18-1-1,559,0.208653,0.0,True
2,H18-1-1,520,0.202565,0.0,True
3,H18-1-1,87,0.118228,0.0,True
4,H18-1-1,707,0.107317,0.0,True
...,...,...,...,...,...
764923,R1-24-U,167,0.000010,0.0,False
764924,R1-24-U,326,0.000010,0.0,False
764925,R1-24-U,465-9,0.000010,0.0,False
764926,R1-24-U,398-16,0.000008,0.0,False


In [32]:
recall_score(train_df_step1["label"], train_df_step1["keep"])

1.0

On eval dataset

In [33]:
y_pred = model.predict_proba(X_eval)

eval_df_step1 = pd.DataFrame(eval_positive_pairs + eval_negative_pairs, columns=["query_id", "article_id"])
eval_df_step1["step1_score"] = y_pred[:, 1]
eval_df_step1["label"] = y_eval

eval_df_step1 = eval_df_step1.groupby("query_id")[eval_df_step1.columns.tolist()]\
                             .apply(get_top_preds)\
                             .reset_index(drop=True)

In [34]:
recall_score(eval_df_step1["label"], eval_df_step1["keep"])

0.9769230769230769

On test dataset

In [35]:
y_pred = model.predict_proba(X_test)

test_df_step1 = pd.DataFrame(test_positive_pairs + test_negative_pairs, columns=["query_id", "article_id"])
test_df_step1["step1_score"] = y_pred[:, 1]
test_df_step1["label"] = y_test

test_df_step1 = test_df_step1.groupby("query_id")[test_df_step1.columns.tolist()]\
                             .apply(get_top_preds)\
                             .reset_index(drop=True)

In [36]:
recall_score(test_df_step1["label"], test_df_step1["keep"])

0.9307692307692308

In [37]:
joblib.dump(model,f"./{STEP1_CHECKPOINT_DIR}/test_df_step1.pkl")

['./checkpoints/step1_bge_pre_retrieval/test_df_step1.pkl']

In [38]:
del model

## Step 2: RankLlama for 2nd stage retrieval

In [40]:
def get_model(peft_model_name):
    config = PeftConfig.from_pretrained(peft_model_name)
    base_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path,
                                                                    num_labels=1,
                                                                    torch_dtype=torch.bfloat16,
                                                                    device_map="auto")
    model = PeftModel.from_pretrained(base_model, peft_model_name)
    model = model.merge_and_unload()
    model.eval()
    return model


def make_prompt(query, title, passage):
    return f'query: {query}<s>document: {title} {passage}'


def get_scores(model, tokenizer, df, batch_size, max_len, data_collator):
    scores = []

    for i in tqdm(range(0, len(df), batch_size)):
        batch = df[i:i+batch_size]

        text = batch.apply(lambda x: make_prompt(x['query_content'], x['article_id'], x['article_content']), axis=1)
        text = list(text)

        inputs = tokenizer(text, max_length=max_len, truncation=True)
        inputs = [dict(zip(inputs.keys(), values)) for values in zip(*inputs.values())]  # convert to list of dicts

        inputs = data_collator(inputs)
        inputs = inputs.to(model.device)

        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits.squeeze()
        scores.extend(logits.tolist())

    return scores

In [41]:
# Load the tokenizer and model
model = get_model('castorini/rankllama-v1-7b-lora-passage')

tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
tokenizer.pad_token = "<unk>"
model.config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

Downloading shards: 100%|██████████| 2/2 [09:02<00:00, 271.09s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.88s/it]
Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-2-7b-hf and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [42]:
# Save model
model.save_pretrained(f"./{STEP2_CHECKPOINT_DIR}/rankllama_model")

# Save tokenizer
tokenizer.save_pretrained(f"./{STEP2_CHECKPOINT_DIR}/rankllama_tokenizer")

('./checkpoints/step2_rankllama_retrieval/rankllama_tokenizer/tokenizer_config.json',
 './checkpoints/step2_rankllama_retrieval/rankllama_tokenizer/special_tokens_map.json',
 './checkpoints/step2_rankllama_retrieval/rankllama_tokenizer/tokenizer.model',
 './checkpoints/step2_rankllama_retrieval/rankllama_tokenizer/added_tokens.json',
 './checkpoints/step2_rankllama_retrieval/rankllama_tokenizer/tokenizer.json')

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM  # or your model class

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("rankllama_tokenizer")
tokenizer.pad_token = "<unk>"

# Load model
model = AutoModelForCausalLM.from_pretrained("rankllama_model")
model.config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)


### 2.1. Get RankLlama scores

In [43]:
batch_size = 16
max_len = 1024

data_collator = LegalDataCollatorWithPadding(tokenizer)

For train dataset

In [44]:
temp_df = train_df_step1[train_df_step1["keep"]].copy(deep=True)

temp_df = temp_df.merge(en_query_df[["query_id", "query_content"]], how="left")
temp_df = temp_df.merge(en_article_df[["article_id", "article_content"]], how="left")

train_step2_scores = get_scores(model, tokenizer, temp_df, batch_size, max_len, data_collator)

joblib.dump(train_step2_scores, f"./{STEP2_CHECKPOINT_DIR}/train_step2_scores.pkl")

100%|██████████| 6225/6225 [1:34:55<00:00,  1.09it/s]  


['./checkpoints/step2_rankllama_retrieval/train_step2_scores.pkl']

In [None]:
#### LOAD ########
train_step2_scores = joblib.load("train_step2_scores.pkl")

In [None]:
train_df_step2 = train_df_step1.copy(deep=True)

train_df_step2["step2_score"] = -5.0
train_df_step2.loc[train_df_step2["keep"], "step2_score"] = train_step2_scores

train_df_step2["keep"] &= train_df_step2["step2_score"] > RANKLLAMA_THRESHOLD

NameError: name 'train_df_step1' is not defined

In [46]:
train_df_step2.to_csv(os.path.join(STEP2_CHECKPOINT_DIR, f"train_df_threshold.{RANKLLAMA_THRESHOLD}.csv"), index=False)

In [47]:
recall_score(train_df_step2["label"], train_df_step2["keep"])

0.9339622641509434

For eval dataset

In [48]:
temp_df = eval_df_step1[eval_df_step1["keep"]].copy(deep=True)

temp_df = temp_df.merge(en_query_df[["query_id", "query_content"]], how="left")
temp_df = temp_df.merge(en_article_df[["article_id", "article_content"]], how="left")

eval_step2_scores = get_scores(model, tokenizer, temp_df, batch_size, max_len, data_collator)

joblib.dump(eval_step2_scores, f"./{STEP2_CHECKPOINT_DIR}/eval_step2_scores.pkl")

100%|██████████| 632/632 [09:38<00:00,  1.09it/s]


['./checkpoints/step2_rankllama_retrieval/eval_step2_scores.pkl']

In [None]:
### LOAD ########
eval_step2_scores = joblib.load("eval_step2_scores.pkl")

In [49]:
eval_df_step2 = eval_df_step1.copy(deep=True)

eval_df_step2["step2_score"] = -5.0
eval_df_step2.loc[eval_df_step2["keep"], "step2_score"] = eval_step2_scores

eval_df_step2["keep"] &= eval_df_step2["step2_score"] > RANKLLAMA_THRESHOLD

In [50]:
eval_df_step2.to_csv(os.path.join(STEP2_CHECKPOINT_DIR, f"eval_df_threshold.{RANKLLAMA_THRESHOLD}.csv"), index=False)

In [51]:
recall_score(eval_df_step2["label"], eval_df_step2["keep"])

0.9692307692307692

For test dataset

In [52]:
temp_df = test_df_step1[test_df_step1["keep"]].copy(deep=True)

temp_df = temp_df.merge(en_query_df[["query_id", "query_content"]], how="left")
temp_df = temp_df.merge(en_article_df[["article_id", "article_content"]], how="left")

test_step2_scores = get_scores(model, tokenizer, temp_df, batch_size, max_len, data_collator)

joblib.dump(test_step2_scores, f"./{STEP2_CHECKPOINT_DIR}/test_step2_scores.pkl")


100%|██████████| 682/682 [10:24<00:00,  1.09it/s]


['./checkpoints/step2_rankllama_retrieval/test_step2_scores.pkl']

In [None]:
#### LOAD ########
test_step2_scores = joblib.load("test_step2_scores.pkl")

In [53]:
test_df_step2 = test_df_step1.copy(deep=True)

test_df_step2["step2_score"] = -5.0
test_df_step2.loc[test_df_step2["keep"], "step2_score"] = test_step2_scores

test_df_step2["keep"] &= test_df_step2["step2_score"] > RANKLLAMA_THRESHOLD

In [54]:
test_df_step2.to_csv(os.path.join(STEP2_CHECKPOINT_DIR, f"test_df_threshold.{RANKLLAMA_THRESHOLD}.csv"), index=False)

In [55]:
recall_score(test_df_step2["label"], test_df_step2["keep"])

0.9307692307692308

### 2.2. Visualize RankLlama scores, exaplaination for choosing RANKLLAMA_THRESHOLD

In [None]:
query_id_to_visualize = "H18"


df = train_df_step2[train_df_step2["query_id"].str.startswith(query_id_to_visualize)]

data_dict = {
    'should be DROPPED': df[df["label"] == 0]["step2_score"],
    'should be KEPT': df[df["label"] == 1]["step2_score"],
}

plt.boxplot(data_dict.values(), tick_labels=data_dict.keys())
plt.show()

## Step 3: Model Finetuning

Please refer to **run_step3.sh** for model finetuning & inference.

Don't forget to set your Huggingface and Wandb tokens in the script.
If you don't want to use Wandb, you can change the `report_to` flag from `wandb` to `none` in the training configuration files.

In [57]:
temp_df

Unnamed: 0,query_id,article_id,step1_score,label,keep,query_content,article_content
0,R05-01-A,537,0.490752,1.0,True,The validity of a third party beneficiary cont...,Article 537 (1) If one of the parties promise...
1,R05-01-A,282,0.273778,0.0,True,The validity of a third party beneficiary cont...,Article 282 (1) One of the co-owners of land ...
2,R05-01-A,538,0.192004,0.0,True,The validity of a third party beneficiary cont...,Article 538 (1) After rights of the third par...
3,R05-01-A,93,0.184400,0.0,True,The validity of a third party beneficiary cont...,Article 93 (1) The validity of a manifestatio...
4,R05-01-A,96,0.149929,0.0,True,The validity of a third party beneficiary cont...,Article 96 (1) A manifestation of intention b...
...,...,...,...,...,...,...,...
10895,R05-36-U,540,0.006542,0.0,True,A special agreement to prohibit a set-off made...,Article 540 (1) If one of the parties has the...
10896,R05-36-U,501,0.006525,0.0,True,A special agreement to prohibit a set-off made...,Article 501 (1) A person that is subrogated t...
10897,R05-36-U,349,0.006524,0.0,True,A special agreement to prohibit a set-off made...,"Article 349 The pledgor may not, either by th..."
10898,R05-36-U,457,0.006401,0.0,True,A special agreement to prohibit a set-off made...,Article 457 (1) The postponement of expiry of...


In [None]:
train_step2_scores[0]

5.15625

: 