In [1]:
import os
import pandas as pd
import numpy as np
import glob
import itertools
import pickle
from tqdm import tqdm
import torch

from FlagEmbedding import BGEM3FlagModel

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel, PeftConfig

from scipy.special import softmax

from src.util import load_samples
from src.data.data_collator import LegalDataCollatorWithPadding

  from .autonotebook import tqdm as notebook_tqdm


## Hyperparameters

In [21]:
# --- Data ---------------------------------------------------------------------
# Fallback query file, used when QUERY_PATH has no query whose id starts with
# SELECTED_ID. NOTE(review): despite the "TEST2025" name this file lives in the
# train/ directory — confirm it is the intended source.
TEST2025_EN_PATH = "data/COLIEE2025statute_data-English/train/R05_en.xml"

# Query-id prefix of the round evaluated by this notebook (ids look like "R05-01-A").
SELECTED_ID = "R05"

RAW_DATA_DIR = "data"
DATA_OUTPUT_DIR = "data/synthesys"
QUERY_PATH = os.path.join(RAW_DATA_DIR, "COLIEE2025statute_data-English/train")
ARTICLE_PATH = os.path.join(RAW_DATA_DIR, "full_en_civil_code_df_24.csv")

# --- Checkpoints / intermediate outputs ---------------------------------------
CHECKPOINT_DIR = "checkpoints"
STEP1_CHECKPOINT_DIR = f"{CHECKPOINT_DIR}/step1_bge_pre_retrieval"
STEP2_CHECKPOINT_DIR = f"{CHECKPOINT_DIR}/step2_rankllama_retrieval"
STEP3_CHECKPOINT_DIR = f"{CHECKPOINT_DIR}/step3_final_retrieval"
# Step-2 candidate file and per-model Step-3 logits live under here
# (value is "checkpoints/inference", derived like the step directories above).
INFERENCE_DIR = f"{CHECKPOINT_DIR}/inference"

# Step-3 models whose logits are ensembled in Step 4.
ACCEPTED_MODELS = [
    "e5_mistral_7b_instruct",
    "gemma_2_9b_it",
    "gemma_2_27b_it",
    "phi_3_medium_4k_instruct",
]


# TODO: fix bug
# In the R04 task-3 labels some ground-truth article ids carry a "(1)" postfix
# that has to be removed. NOTE(review): not referenced by any cell of this notebook.
BUG_ARTICLE_POSTFIX = "(1)"


# --- Step 1: BGE-M3 embeddings + histogram gradient-boosting classifier -------
BGE_TOP = 100                      # candidates kept per query after Step 1
BGE_SEQUENCE_MAX_LENGTH = 1024     # max tokens per text for the BGE encoder
HISTOGRAM_N_POSITIVE_REPLICATES = 300  # not used in this notebook (presumably a training setting)


# --- Step 2: RankLLaMA re-ranking ---------------------------------------------
RANKLLAMA_MAX_LENGTH = 1024        # truncation length of the RankLLaMA prompt
# Minimum RankLLaMA score; applied as a hard floor in the Step-4 filter
# (tuning note: preserves about 50 candidates for each query).
RANKLLAMA_THRESHOLD = -3.5
RANKLLAMA_TOP = 50                 # rank whose score sets the per-query Step-2 cut-off


# --- Step 4: weighted ensemble ------------------------------------------------
# A candidate is predicted when dot(model probabilities, WEIGHTS) exceeds this.
CUT_OFF_THRESHOLD = 0.3687529996711829
# One weight per Step-3 model. The logit files are loaded from *sorted* paths,
# so the columns these weights multiply follow alphabetical directory order
# (e5_mistral_7b_instruct, gemma_2_27b_it, gemma_2_9b_it,
# phi_3_medium_4k_instruct) — NOT the ACCEPTED_MODELS order above.
WEIGHTS = np.array([0.23716786, 0.21487627, 0.3068145 , 0.24114137])

In [3]:
# Inspect the ensemble weights defined in the config cell (one per Step-3 model).
WEIGHTS

array([0.23716786, 0.21487627, 0.3068145 , 0.24114137])

## Step 0: Create dataset

In [22]:
# 1. Load data
# Articles: rename to the column names used throughout the pipeline.
en_article_df = pd.read_csv(ARTICLE_PATH).rename(
    columns={"article": "article_id", "content": "article_content"}
)

# Queries: glob order depends on the filesystem, so sort it to make the query
# order (and hence every downstream frame and the submission row order) reproducible.
query_files = sorted(glob.glob(f"{QUERY_PATH}/*.xml"))

queries = []
for query_file in query_files:
    queries.extend(load_samples(query_file))

en_query_df = pd.DataFrame(queries).rename(columns={"index": "query_id",
                                                    "content": "query_content",
                                                    "result": "task3_label",
                                                    "label": "task4_label"})

In [23]:
# Keep only the queries of the selected round (ids start with SELECTED_ID); if the
# query directory has none, load them from the dedicated fallback file instead.
# en_query_df is deliberately kept (no `del`) so this cell can be re-run.
test_query_df = en_query_df[en_query_df["query_id"].str.startswith(SELECTED_ID)].copy(deep=True)

if len(test_query_df) == 0:
    queries = load_samples(TEST2025_EN_PATH)

    test_query_df = pd.DataFrame(queries).rename(columns={"index": "query_id",
                                                          "content": "query_content",
                                                          "result": "task3_label",
                                                          "label": "task4_label"})

assert len(test_query_df) > 0, f"no queries found for SELECTED_ID={SELECTED_ID!r}"


In [24]:
# Remove the task-3 ground truth when present (presumably so it cannot leak into
# the inference frames built from test_query_df below).
test_query_df = test_query_df.drop(columns=["task3_label"], errors="ignore")

## Step 1: BGE Pre-retrieval

### 1.1. BGE Embedding

In [6]:
# Dense BGE-M3 embeddings for every article and every test query.
# A dedicated name (instead of the generic `model`, which later cells rebind to
# other models) keeps the notebook free of hidden cross-cell state.
bge_model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True, device='cuda')
bge_encode_kwargs = dict(batch_size=32, max_length=BGE_SEQUENCE_MAX_LENGTH)

# article embedding
article_embeddings = bge_model.encode(en_article_df["article_content"].tolist(),
                                      **bge_encode_kwargs)['dense_vecs']

# query embedding
query_embeddings = bge_model.encode(test_query_df["query_content"].tolist(),
                                    **bge_encode_kwargs)['dense_vecs']

# The encoder is not used after this cell; release it and its cached GPU memory
# before the much larger Step-2 model is loaded.
del bge_model
torch.cuda.empty_cache()

Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 29739.81it/s]
pre tokenize: 100%|██████████| 24/24 [00:00<00:00, 322.82it/s]
You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Inference Embeddings: 100%|██████████| 24/24 [00:00<00:00, 33.19it/s]
pre tokenize: 100%|██████████| 3/3 [00:00<00:00, 284.96it/s]
Inference Embeddings: 100%|██████████| 3/3 [00:00<00:00, 62.49it/s]


In [7]:
# id -> embedding look-ups. Each embedding matrix was encoded from the frame's
# column in row order, so zipping ids with rows pairs them up correctly.
article_embedding_dict = {article_id: emb for article_id, emb in zip(en_article_df["article_id"].tolist(), article_embeddings)}
query_embedding_dict = {query_id: emb for query_id, emb in zip(test_query_df["query_id"].tolist(), query_embeddings)}

### 1.2. Retrieval with Histogram-based Gradient Boosting
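Data: pair every test query with every article and compute each pair's embedding-difference feature vector.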

In [8]:
def make_pairs(query_id, labels):
    """Pair ``query_id`` with every element of ``labels``.

    Returns a list of ``(query_id, label)`` tuples in the order of ``labels``.
    """
    return [(query_id, label) for label in labels]


def distance_function(query_emb, article_emb):
    """Feature vector of a (query, article) pair: the element-wise ``query - article``."""
    delta = query_emb - article_emb
    return delta


def get_distance(query_id, article_id, query_embedding_dict, article_embedding_dict):
    """Look up both embeddings by id and return their ``distance_function`` feature.

    Raises KeyError if either id has no embedding in its dictionary.
    """
    return distance_function(query_embedding_dict[query_id],
                             article_embedding_dict[article_id])


# Every (query, article) combination in query-major order — one feature row per pair.
# itertools.product builds the list in a single pass; the previous
# `sum(list_of_lists, [])` re-copied the growing list for every query (quadratic).
query_article_pairs = list(itertools.product(test_query_df["query_id"].tolist(),
                                             en_article_df["article_id"].tolist()))

X_test = np.array([get_distance(query_id, article_id, query_embedding_dict, article_embedding_dict)
                   for query_id, article_id in query_article_pairs])

In [9]:
import joblib

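Infer: score every pair with the trained histogram classifier and keep the `BGE_TOP` highest-scoring articles per query.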

In [10]:
def get_top_preds(group):
    """Rank one query's candidates by ``step1_score`` and flag the best ``BGE_TOP``.

    Returns the group sorted by descending score with a boolean ``keep`` column
    that is True for exactly the first ``BGE_TOP`` rows (all rows if fewer).
    """
    ranked = group.sort_values("step1_score", ascending=False)
    ranked["keep"] = np.arange(len(ranked)) < BGE_TOP
    return ranked


# Load the Step-1 classifier. joblib.load takes the path and closes the file
# itself (the previous bare `open(...)` handle was never closed).
# NOTE: joblib/pickle files can execute arbitrary code when loaded — only load
# trusted checkpoints.
step1_classifier = joblib.load(f"{STEP1_CHECKPOINT_DIR}/histogram_classifier.pkl")
y_pred = step1_classifier.predict_proba(X_test)

test_df_step1 = pd.DataFrame(query_article_pairs, columns=["query_id", "article_id"])
# Column 1 of predict_proba — assumed to be the "relevant" class (classes_ == [0, 1]; confirm).
test_df_step1["step1_score"] = y_pred[:, 1]

test_df_step1 = test_df_step1.groupby("query_id")[test_df_step1.columns.tolist()]\
                             .apply(get_top_preds)\
                             .reset_index(drop=True)

# `keep` is boolean, so it can be used as the mask directly.
test_df_step1 = test_df_step1[test_df_step1["keep"]]

## Step 2: RankLlama for 2nd-stage retrieval

In [11]:
def get_model(peft_model_name):
    """Load a LoRA adapter, merge it into its base model and return it in eval mode.

    The base model named in the adapter's PEFT config is loaded as a sequence
    classifier with a single output logit, in bfloat16, with ``device_map="auto"``.
    """
    peft_config = PeftConfig.from_pretrained(peft_model_name)
    backbone = AutoModelForSequenceClassification.from_pretrained(
        peft_config.base_model_name_or_path,
        num_labels=1,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    merged = PeftModel.from_pretrained(backbone, peft_model_name).merge_and_unload()
    merged.eval()
    return merged


def make_prompt(query, title, passage):
    """Format a RankLLaMA input as ``query: <query><s>document: <title> <passage>``."""
    return 'query: {}<s>document: {} {}'.format(query, title, passage)


def get_scores(model, tokenizer, df, batch_size, max_len, data_collator):
    """Score every (query, article) row of ``df`` with a sequence-classification model.

    Args:
        model: model whose output has ``.logits`` (e.g. the merged RankLLaMA
            from ``get_model``, which has a single output logit).
        tokenizer: tokenizer used to encode the prompts built by ``make_prompt``.
        df: frame with ``query_content``, ``article_id`` and ``article_content``.
        batch_size: rows per forward pass.
        max_len: truncation length in tokens.
        data_collator: pads a list of encoded examples into a batch object that
            supports ``.to(device)``.

    Returns:
        A list with one entry per row of ``df``, in row order. For single-logit
        models each entry is a float, and the result is a list even when ``df``
        has a single row.
    """
    scores = []

    for start in tqdm(range(0, len(df), batch_size)):
        batch = df.iloc[start:start + batch_size]

        text = batch.apply(lambda x: make_prompt(x['query_content'], x['article_id'], x['article_content']), axis=1)
        text = list(text)

        encoded = tokenizer(text, max_length=max_len, truncation=True)
        # The tokenizer returns one list per field; the collator expects one dict per example.
        features = [dict(zip(encoded.keys(), values)) for values in zip(*encoded.values())]

        inputs = data_collator(features).to(model.device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores.extend(outputs.logits.tolist())

    # Drop only the trailing single-logit axis. The previous `.squeeze()` also
    # removed the row axis when df had one row, so `.tolist()` returned a bare float.
    score_array = np.asarray(scores, dtype=float)
    if score_array.ndim == 2 and score_array.shape[1] == 1:
        score_array = score_array[:, 0]
    return score_array.tolist()


def get_top_preds(group, top_k=None, score_col="step2_score"):
    """Rank one query's candidates by score and flag the ones that survive Step 2.

    Args:
        group: candidate rows of a single query; must contain ``score_col``.
        top_k: 0-based rank whose score sets the cut-off; defaults to ``RANKLLAMA_TOP``.
        score_col: column to rank by.

    Returns:
        ``group`` sorted by descending score with a boolean ``keep`` column. A row
        is kept when its score exceeds (cut-off score - 1e-5), where the cut-off
        is the score at 0-based position ``top_k``. That keeps the top
        ``top_k + 1`` rows plus ties.
        NOTE(review): one more than ``top_k``. Kept as-is because Step-3 logits
        computed from the saved step-2 file depend on this exact selection.
        Groups with ``top_k`` or fewer rows are kept entirely; the previous
        version raised IndexError for them.
    """
    if top_k is None:
        top_k = RANKLLAMA_TOP

    group = group.sort_values(score_col, ascending=False)
    if group.empty:
        return group.assign(keep=np.zeros(0, dtype=bool))

    cut_off_score = group[score_col].iloc[min(top_k, len(group) - 1)]
    group["keep"] = group[score_col] > cut_off_score - 1e-5

    return group

Infer: score the Step-1 candidates with RankLLaMA and keep the top-ranked candidates of each query.

In [12]:
# Load RankLLaMA (LoRA adapter merged onto its base) and the base tokenizer.
model = get_model('castorini/rankllama-v1-7b-lora-passage')
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')

# Pad with "<unk>" and keep the model config's pad id in sync with the tokenizer.
# NOTE(review): presumably needed because the Llama-2 tokenizer has no pad token.
tokenizer.pad_token = "<unk>"
model.config.pad_token_id = tokenizer.convert_tokens_to_ids("<unk>")

Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.25s/it]
Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-2-7b-hf and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
data_collator = LegalDataCollatorWithPadding(tokenizer)

# Step-2 candidate frame: Step-1 survivors joined with their query and article text.
test_df_step2 = (
    test_df_step1.copy(deep=True)
    .merge(test_query_df[["query_id", "query_content"]], on="query_id", how="left")
    .merge(en_article_df[["article_id", "article_content"]], on="article_id", how="left")
)

test_step2_scores = get_scores(model, tokenizer, test_df_step2,
                               batch_size=16,
                               max_len=RANKLLAMA_MAX_LENGTH,
                               data_collator=data_collator)

  0%|          | 0/519 [00:00<?, ?it/s]

100%|██████████| 519/519 [07:41<00:00,  1.13it/s]


In [5]:
# Inspect the Step-2 candidate frame.
# NOTE(review): the saved NameError below appears to come from a fresh kernel
# (execution count 5, earlier than the In [13] that creates test_df_step2);
# running the notebook top to bottom defines it.
test_df_step2

NameError: name 'test_df_step2' is not defined

In [6]:
# Inspect the raw RankLLaMA scores (one per row of test_df_step2).
# NOTE(review): the saved NameError below appears to come from a fresh kernel
# (execution count 6); running the notebook top to bottom defines it.
test_step2_scores

NameError: name 'test_step2_scores' is not defined

In [16]:
# Attach the RankLLaMA scores (same row order as test_df_step2) and keep each
# query's top candidates.
# NOTE(review): this cell filters test_df_step2 itself, so it must run exactly once
# after the scoring cell; re-running it fails with a length mismatch.
test_df_step2["step2_score"] = test_step2_scores
test_df_step2 = (
    test_df_step2.groupby("query_id")[test_df_step2.columns.tolist()]
    .apply(get_top_preds)
    .reset_index(drop=True)
)

test_df_step2 = test_df_step2.loc[test_df_step2["keep"]]

In [17]:
# Persist the Step-2 candidates. The Step-3/4 cells read this file back, and it
# is presumably also the input consumed by run.sh.
os.makedirs(INFERENCE_DIR, exist_ok=True)
test_df_step2.to_json(f"{INFERENCE_DIR}/{SELECTED_ID}_step2.jsonl", lines=True, orient="records")

## Step 3: LLM Inference

In [None]:
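# Step 3 runs outside this notebook: execute run.sh to produce each model's predicted logits, which are read below from {INFERENCE_DIR}/{SELECTED_ID}eval/<model>/{SELECTED_ID}_step2_logits.npy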

In [25]:
# Reload the saved Step-2 candidates so Step 4 can run without redoing Steps 1-2.
# Row order must match the order the Step-3 logits were produced in.
test_df_step2 = pd.read_json(f"{INFERENCE_DIR}/{SELECTED_ID}_step2.jsonl", lines=True, orient="records")

In [28]:
def load_logits(file_paths):
    """Load per-model logits files and return their class-1 probabilities.

    Each ``.npy`` file holds a 2-D array ``(n_rows, n_classes)``. A softmax is
    taken over the class axis and column 1 is kept (assumed to be the
    "relevant" class). Despite the name, the result holds probabilities, not
    logits: an array of shape ``(n_rows, len(file_paths))``, one column per
    file in the given order.
    """
    positive_probs = [softmax(np.load(path), axis=1)[:, 1] for path in file_paths]
    return np.array(positive_probs).T


# One logits file per Step-3 model. The paths are sorted, so the columns of
# all_logits follow alphabetical model-directory order (e.g. gemma_2_27b_it sorts
# before gemma_2_9b_it), not the order of ACCEPTED_MODELS. WEIGHTS must be given
# in this sorted order; the sort is kept so the tuned weights still line up.
all_logit_path = sorted(
    f"{INFERENCE_DIR}/{SELECTED_ID}eval/{model_name}/{SELECTED_ID}_step2_logits.npy"
    for model_name in ACCEPTED_MODELS
)

all_logits = load_logits(all_logit_path)

# Fail loudly if the logits do not line up with the step-2 rows or the weights.
expected_shape = (len(test_df_step2), len(WEIGHTS))
assert all_logits.shape == expected_shape, (
    f"expected logits of shape {expected_shape}, got {all_logits.shape}"
)

## Step 4: Ensemble

In [29]:
# TODO: review the top filter (is 2 the right fallback size?)
def top_filter(group_df, top_k=2):
    """Return the ``top_k`` highest-``step3_score`` rows of one query's candidates.

    The result is sorted by descending ``step3_score`` with a fresh 0..k-1 index.
    ``top_k`` defaults to 2, the previously hard-coded value.
    """
    ranked = group_df.sort_values(by=["step3_score"],
                                  ascending=False,
                                  ignore_index=True)
    return ranked.head(top_k)


def fill_none_predicted(row, step3_top2_df):
    """Give a submission row with no predicted articles its fallback list.

    Args:
        row: one submission row. ``row["article_id"]`` is either a list of
            predicted ids or a missing value (e.g. NaN from a right-merge) when
            nothing was predicted for the query.
        step3_top2_df: frame with a ``query_id`` column and an ``article_id``
            column holding each query's fallback list of ids.

    Returns:
        ``row`` itself when it already holds a list. Otherwise a copy whose
        ``article_id`` is the query's fallback list, or ``[]`` when the query
        has no fallback entry (this previously raised IndexError). The input
        row is never mutated.
    """
    if isinstance(row["article_id"], list):
        return row

    fallback = step3_top2_df.loc[step3_top2_df["query_id"] == row["query_id"], "article_id"]
    filled = row.copy()
    filled["article_id"] = fallback.iloc[0] if len(fallback) > 0 else []
    return filled

In [30]:
# Step-4 decision: the weighted ensemble of Step-3 probabilities must exceed
# CUT_OFF_THRESHOLD *and* the RankLLaMA score must exceed RANKLLAMA_THRESHOLD.
preds = (np.dot(all_logits, WEIGHTS) > CUT_OFF_THRESHOLD).astype(int)

rankllama_ok = test_df_step2["step2_score"] > RANKLLAMA_THRESHOLD
test_df_step3 = test_df_step2.copy(deep=True)
test_df_step3["keep"] = preds & rankllama_ok
test_df_step3 = test_df_step3.loc[test_df_step3["keep"]]

In [31]:
# Display the candidates kept by the Step-4 ensemble filter.
test_df_step3

Unnamed: 0,query_id,article_id,step1_score,keep,query_content,article_content,step2_score
0,R05-01-A,537,0.538559,True,The validity of a third party beneficiary cont...,Article 537 (1) If one of the parties promise...,4.437500
55,R05-02-A,713,0.020527,True,Mental capacity means the capacity to apprecia...,Article 713 A person who has inflicted damage...,3.484375
56,R05-02-A,712,0.001173,True,Mental capacity means the capacity to apprecia...,Article 712 If a minor has inflicted damage o...,2.250000
60,R05-02-A,3-2,0.001198,True,Mental capacity means the capacity to apprecia...,Article 3-2 If the person making a juridical ...,-0.941406
108,R05-02-E,526,0.999860,True,If an offeror of a contract comes to be in a c...,"Article 526 If an offeror dies, comes to be i...",8.875000
...,...,...,...,...,...,...,...
5373,R05-36-E,705,0.999874,True,A person that has paid money or delivered anyt...,Article 705 A person that has paid money or d...,7.000000
5424,R05-36-I,466-5,0.938789,True,A special agreement to restrict assignment mad...,Article 466-5 (1) Notwithstanding the provisi...,7.312500
5425,R05-36-I,466-4,0.968195,True,A special agreement to restrict assignment mad...,Article 466-4 (1) The provisions of Article 4...,6.281250
5427,R05-36-I,466,0.630611,True,A special agreement to restrict assignment mad...,Article 466 (1) A claim may be assigned; prov...,6.000000


In [32]:
# Ensemble score of every Step-2 candidate (same weighting as the Step-4 filter).
step3_score_df = test_df_step2.copy(deep=True)
step3_score_df["step3_score"] = np.dot(all_logits, WEIGHTS)

# Fallback predictions: each query's best candidates as chosen by top_filter,
# collected into one article_id list per query.
step3_top2_df = (
    step3_score_df
    .drop_duplicates(subset=["query_id", "article_id"])
    .groupby("query_id")[step3_score_df.columns]
    .apply(top_filter)
    .reset_index(drop=True)
    .groupby("query_id")["article_id"]
    .apply(list)
    .reset_index()
)

# One row per test query; the right-merge keeps queries with no surviving article
# (their article_id is missing until the fallback below fills it).
submission_df = (
    test_df_step3
    .copy(deep=True)
    .groupby("query_id")["article_id"]
    .apply(list)
    .reset_index()
    .merge(test_query_df, on="query_id", how="right")
)

# In some cases, we can't find any predicted articles. We need to fill them with the top 2 articles from step 3
submission_df = submission_df.apply(lambda row: fill_none_predicted(row, step3_top2_df), axis=1)

In [33]:
# One JSON record per test query (query_id, article_id list, query text, task-4 label).
submission_path = f"{SELECTED_ID}_submission.jsonl"
submission_df.to_json(submission_path, lines=True, orient="records")

In [34]:
# Final submission: one row per test query with its list of predicted article ids.
submission_df

Unnamed: 0,query_id,article_id,query_content,task4_label
0,R05-01-A,[537],The validity of a third party beneficiary cont...,Y
1,R05-02-A,"[713, 712, 3-2]",Mental capacity means the capacity to apprecia...,N
2,R05-02-I,[3-2],If a party to a contract did not have mental c...,Y
3,R05-02-U,"[121-2, 3-2]",If a party to a contract did not have mental c...,Y
4,R05-02-E,"[526, 97]",If an offeror of a contract comes to be in a c...,Y
...,...,...,...,...
104,R05-36-A,[107],In the case where an agent performs an act tha...,N
105,R05-36-I,"[466-5, 466-4, 466]",A special agreement to restrict assignment mad...,Y
106,R05-36-U,[505],A special agreement to prohibit a set-off made...,Y
107,R05-36-E,[705],A person that has paid money or delivered anyt...,N


In [37]:
import json
import numpy as np
import pandas as pd

# -----------------------
# Config: largest k for which precision@k / recall@k are reported.
# None -> use the length of the longest prediction list in preds_df (at least 1).
MAX_K = None   # or set e.g. 5 or 10

# -----------------------
# Helpers
def precision_at_k_single(preds, gold_set, k):
    """Precision@k = (# of the first k predictions found in ``gold_set``) / k.

    If fewer than k predictions exist, the missing slots count as misses (the
    denominator is still k). Returns 0.0 for ``k <= 0``.
    """
    if k > 0:
        hits = len([p for p in preds[:k] if p in gold_set])
        return hits / k
    return 0.0

def recall_at_k_single(preds, gold_set, k):
    """Recall@k = (# distinct gold ids among the first k predictions) / |gold|.

    Each gold id counts at most once, so a repeated prediction can no longer push
    recall above 1. Returns 0.0 when the gold set is empty or ``k <= 0``; the
    latter is consistent with ``precision_at_k_single`` and avoids negative
    slicing.
    """
    gold = set(gold_set)
    if not gold or k <= 0:
        return 0.0
    hits = len(set(preds[:k]) & gold)
    return hits / len(gold)

def average_precision(preds, gold_set):
    """Average precision of a ranked prediction list.

    AP is the sum of precision@i over the ranks i whose prediction is relevant,
    divided by |gold|. A gold id only counts at its first occurrence, so
    duplicated predictions cannot push AP above 1; later duplicates just occupy
    a (non-relevant) rank. Returns 0.0 when the gold set is empty or nothing
    relevant is retrieved.
    """
    gold = set(gold_set)
    if not gold:
        return 0.0
    seen = set()
    hits = 0
    sum_precisions = 0.0
    for rank, p in enumerate(preds, start=1):
        if p in gold and p not in seen:
            seen.add(p)
            hits += 1
            sum_precisions += hits / rank
    return sum_precisions / len(gold)

def f_beta(prec, rec, beta=2.0):
    """F-beta score of a precision/recall pair (beta > 1 favours recall; default F2).

    Returns 0.0 when both precision and recall are zero.
    """
    beta_sq = beta * beta
    numerator = (1 + beta_sq) * (prec * rec)
    denominator = beta_sq * prec + rec
    return 0.0 if prec == 0 and rec == 0 else numerator / denominator

# -----------------------
# Load gold: query_id -> record whose "retrieved_list" holds the gold article ids.
GOLD_PATH = f"./kg/data_parsed/{SELECTED_ID}_data.json"
with open(GOLD_PATH, "r", encoding="utf-8") as f:
    gold = json.load(f)

# Normalize gold ids to stripped strings so they compare equal to the predictions.
for qid, info in gold.items():
    info["retrieved_list"] = [str(x).strip() for x in info.get("retrieved_list", [])]

# -----------------------
# Prepare preds df: a copy of the in-memory submission_df (one row per query,
# predicted ids in "article_id").
preds_df = submission_df.copy()

# Ensure preds are lists of strings
def ensure_list_of_str(x):
    """Coerce one prediction cell into a list of article-id strings.

    * list / tuple / numpy array -> one string per element
    * missing value (None or NaN, e.g. a query with no prediction) -> []
      (previously ["nan"] / ["None"], which were scored as real predictions)
    * a string that looks like a list ("['1', '2']") -> crude comma split
    * any other string or scalar -> single-element list

    Elements are stripped, matching how the gold ids are normalized.
    """
    if isinstance(x, (list, tuple, np.ndarray)):
        return [str(v).strip() for v in x]
    if x is None or (isinstance(x, float) and np.isnan(x)):
        return []
    if isinstance(x, str):
        s = x.strip()
        if s.startswith("[") and s.endswith("]"):
            # crude parse: split on commas and drop surrounding quotes
            return [it.strip().strip("'\"") for it in s[1:-1].split(",") if it.strip() != ""]
        return [s]
    # fallback: any other scalar
    return [str(x).strip()]

preds_df["preds"] = preds_df["article_id"].apply(ensure_list_of_str)

# Determine K for precision@k / recall@k
if MAX_K is None:
    # Longest prediction list (at least 1); gold lengths are not considered.
    max_pred_len = preds_df["preds"].apply(len).max()
    K = int(max(1, max_pred_len))
else:
    K = int(MAX_K)

# -----------------------
# Evaluate per query
rows = []
prec_at_k_matrix = []  # one [P@1, ..., P@K] list per query
rec_at_k_matrix = []   # one [R@1, ..., R@K] list per query
missing_gold = []      # evaluated queries that have no entry in the gold file

for _, r in preds_df.iterrows():
    qid = r["query_id"]
    # Count each predicted id once at its first (best) rank; duplicates would
    # otherwise inflate hits and let recall exceed 1. A distinct name avoids
    # clobbering the Step-4 `preds` array.
    pred_ids = list(dict.fromkeys(r["preds"]))
    if qid not in gold:
        missing_gold.append(qid)
    gold_set = set(gold.get(qid, {}).get("retrieved_list", []))

    # overall precision / recall of the full predicted list (len(pred_ids) as denominator)
    n_hits = sum(1 for p in pred_ids if p in gold_set)
    full_prec = n_hits / len(pred_ids) if pred_ids else 0.0
    full_rec = n_hits / len(gold_set) if gold_set else 0.0

    # F2 using full_prec, full_rec
    f2 = f_beta(full_prec, full_rec, beta=2.0)

    # AP
    ap = average_precision(pred_ids, gold_set)

    # precision@k and recall@k for k=1..K
    prec_at_k_matrix.append([precision_at_k_single(pred_ids, gold_set, k) for k in range(1, K + 1)])
    rec_at_k_matrix.append([recall_at_k_single(pred_ids, gold_set, k) for k in range(1, K + 1)])

    rows.append({
        "query_id": qid,
        "n_pred": len(pred_ids),
        "n_gold": len(gold_set),
        "precision_full": full_prec,
        "recall_full": full_rec,
        "f2_full": f2,
        "AP": ap,
    })

df_metrics = pd.DataFrame(rows)

if missing_gold:
    # These queries score 0 on every metric; surface them instead of hiding them in the means.
    print(f"WARNING: {len(missing_gold)} evaluated queries have no gold entry: {missing_gold[:10]}")

# -----------------------
# Aggregate summary
mean_precision_full = df_metrics["precision_full"].mean()
mean_recall_full = df_metrics["recall_full"].mean()
mean_f2_full = df_metrics["f2_full"].mean()

map_all = df_metrics["AP"].mean()
map_relevant = df_metrics.loc[df_metrics["n_gold"] > 0, "AP"].mean()

# Precision@k and Recall@k averaged across queries (treating missing preds as non-relevant)
prec_at_k_arr = np.mean(np.array(prec_at_k_matrix), axis=0) if len(prec_at_k_matrix) > 0 else np.zeros(K)
rec_at_k_arr = np.mean(np.array(rec_at_k_matrix), axis=0) if len(rec_at_k_matrix) > 0 else np.zeros(K)

# -----------------------
# Print summary
print("=== Aggregate retrieval metrics ===")
print(f"Queries evaluated : {len(df_metrics)}")
print(f"Mean Precision (full predicted lists) : {mean_precision_full:.4f}")
print(f"Mean Recall (full predicted lists)    : {mean_recall_full:.4f}")
print(f"Mean F2 (β=2)                         : {mean_f2_full:.4f}")
print(f"MAP (all queries; AP=0 for empty)    : {map_all:.4f}")
print(f"MAP (only queries with >=1 gold)     : {map_relevant:.4f}")
print()
print("Precision@k (k=1..{}):".format(K))
for k, val in enumerate(prec_at_k_arr, start=1):
    print(f"  P@{k}: {val:.4f}")
print()
print("Recall@k (k=1..{}):".format(K))
for k, val in enumerate(rec_at_k_arr, start=1):
    print(f"  R@{k}: {val:.4f}")

# -----------------------
# Expose results for further inspection
# df_metrics contains per-query metrics
# prec_at_k_arr and rec_at_k_arr contain averaged P@k / R@k across queries
# You can examine per-query AP distribution:
df_metrics_sorted = df_metrics.sort_values("AP", ascending=False).reset_index(drop=True)

# show top/bottom problematic queries
print("\nTop 5 queries by AP:")
display(df_metrics_sorted.head(5))
print("\nBottom 5 queries by AP (including zero AP):")
display(df_metrics_sorted.tail(5))


=== Aggregate retrieval metrics ===
Queries evaluated : 109
Mean Precision (full predicted lists) : 0.7385
Mean Recall (full predicted lists)    : 0.8532
Mean F2 (β=2)                         : 0.8114
MAP (all queries; AP=0 for empty)    : 0.8203
MAP (only queries with >=1 gold)     : 0.8203

Precision@k (k=1..4):
  P@1: 0.8073
  P@2: 0.4817
  P@3: 0.3333
  P@4: 0.2523

Recall@k (k=1..4):
  R@1: 0.7339
  R@2: 0.8303
  R@3: 0.8486
  R@4: 0.8532

Top 5 queries by AP:


Unnamed: 0,query_id,n_pred,n_gold,precision_full,recall_full,f2_full,AP
0,R05-01-A,1,1,1.0,1.0,1.0,1.0
1,R05-02-I,1,1,1.0,1.0,1.0,1.0
2,R05-02-U,2,2,1.0,1.0,1.0,1.0
3,R05-02-E,2,1,0.5,1.0,0.833333,1.0
4,R05-03-A,1,1,1.0,1.0,1.0,1.0



Bottom 5 queries by AP (including zero AP):


Unnamed: 0,query_id,n_pred,n_gold,precision_full,recall_full,f2_full,AP
104,R05-20-E,2,1,0.0,0.0,0.0,0.0
105,R05-19-O,1,1,0.0,0.0,0.0,0.0
106,R05-27-A,2,1,0.0,0.0,0.0,0.0
107,R05-26-I,1,1,0.0,0.0,0.0,0.0
108,R05-26-E,1,2,0.0,0.0,0.0,0.0


In [64]:
# First 100 per-query metric rows.
# NOTE(review): the saved output below shows R06 ids and columns (precision, recall,
# f2) that the evaluation cell does not build for SELECTED_ID="R05" — it is stale
# from an earlier run; re-run to refresh.
df_metrics.iloc[:100]

Unnamed: 0,query_id,precision,recall,f2,AP,n_gold
0,R06-01-A,1.0,1.0,1.000000,1.00,2
1,R06-01-E,0.5,1.0,0.833333,1.00,1
2,R06-01-I,1.0,1.0,1.000000,1.00,1
3,R06-01-O,0.0,0.0,0.000000,0.00,1
4,R06-03-A,1.0,1.0,1.000000,1.00,1
...,...,...,...,...,...,...
78,R06-31-U,1.0,1.0,1.000000,1.00,1
79,R06-37-A,1.0,1.0,1.000000,1.00,1
80,R06-37-E,0.5,1.0,0.833333,0.50,1
81,R06-37-I,1.0,1.0,1.000000,1.00,1
