# Read Before Use:

The annotation used in this notebook matches those used in Table 1 of the paper.

Note that this notebook only illustrates some of the configurations, but it gives instructions for reproducing the others as well. For instance, the notebook only shows how to use the auslaw embedding to construct the RAG database for the retrieval experiments. If you want to use the OpenAI embedding instead, follow these three steps:
1. Import the library and set up the function:
```
import chromadb.utils.embedding_functions as embedding_functions
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
                api_key="<YOUR API KEY>",
                model_name="text-embedding-3-large"
            )
```
2. Add `embedding_function=openai_ef` to every `get_collection` call you see, for example:
`client.get_collection('casetext', embedding_function=openai_ef)`

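3. There is no need to pass any embeddings manually. Comment out the lines containing embeddings (when querying, pass the raw query strings via `query_texts` instead so that Chroma embeds them with `openai_ef`), for example: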
```
casetext_collection.upsert(
    documents = text,
    # embeddings = text_embeddings,
    metadatas = meta_data_text,
    ids = citation
)
rag_output_casetext = client.get_collection('casetext').query(
    # query_embeddings=test_data_rag_input_text_only_embeddings,
    n_results=5,
)
```

Also, the LLM used throughout this notebook is GPT-4o. To use a different LLM, replace the `gpt4o_prompt` function with a call to that LLM's API.


In [None]:
!pip install chromadb
!pip install -U sentence-transformers

In [None]:
from openai import OpenAI
# Shared OpenAI client used by gpt4o_prompt and gpt4o_mini_prompt below.
# Replace the placeholder with a real API key before running.
openai_client = OpenAI(
    api_key="<YOUR API KEY>",
)

def gpt4o_prompt(sys_content, content, max_retries=None, retry_delay=1):
    """Send a system + user message pair to GPT-4o and return the reply and its cost.

    Args:
        sys_content: Text of the system message.
        content: Text of the user message.
        max_retries: Number of failed attempts tolerated before the last
            exception is re-raised. ``None`` (the default) retries forever.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        A ``(reply_text, total_cost)`` tuple. ``reply_text`` is the stripped
        content of the first choice (an empty string if the API returned no
        content); ``total_cost`` is computed at 2.5 per million prompt tokens
        plus 10 per million completion tokens.
    """
    # Imported locally: the notebook never imports ``time`` at top level, so
    # the retry handler itself used to fail with NameError.
    import time

    failures = 0
    while True:
        try:
            response = openai_client.chat.completions.create(
              model="gpt-4o",
              messages = [{"role": "system", "content": sys_content}, {"role": "user", "content": content}],
              temperature=0,
            )
            break
        except Exception:
            failures += 1
            if max_retries is not None and failures > max_retries:
                raise
            time.sleep(retry_delay)

    prompt_cost = (response.usage.prompt_tokens / 1000000) * 2.5
    completion_cost = (response.usage.completion_tokens / 1000000) * 10
    total_cost = prompt_cost + completion_cost

    # ``content`` can be None; fall back to an empty string so .strip() is safe.
    reply_text = response.choices[0].message.content or ""
    return reply_text.strip(), total_cost

def gpt4o_mini_prompt(sys_content, content, max_retries=None, retry_delay=1):
    """Send a system + user message pair to GPT-4o-mini and return the reply and its cost.

    Args:
        sys_content: Text of the system message.
        content: Text of the user message.
        max_retries: Number of failed attempts tolerated before the last
            exception is re-raised. ``None`` (the default) retries forever.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        A ``(reply_text, total_cost)`` tuple. ``reply_text`` is the stripped
        content of the first choice (an empty string if the API returned no
        content); ``total_cost`` is computed at 0.15 per million prompt tokens
        plus 0.6 per million completion tokens.
    """
    # Imported locally: the notebook never imports ``time`` at top level, so
    # the retry handler itself used to fail with NameError.
    import time

    failures = 0
    while True:
        try:
            response = openai_client.chat.completions.create(
              model="gpt-4o-mini",
              messages = [{"role": "system", "content": sys_content}, {"role": "user", "content": content}],
              temperature=0,
            )
            break
        except Exception:
            failures += 1
            if max_retries is not None and failures > max_retries:
                raise
            time.sleep(retry_delay)

    prompt_cost = (response.usage.prompt_tokens / 1000000) * 0.15
    completion_cost = (response.usage.completion_tokens / 1000000) * 0.6
    total_cost = prompt_cost + completion_cost

    # ``content`` can be None; fall back to an empty string so .strip() is safe.
    reply_text = response.choices[0].message.content or ""
    return reply_text.strip(), total_cost

In [None]:
from sentence_transformers import SentenceTransformer
# Two sample sentences, used only to try out the embedding model below.
sentences = ["This is an example sentence", "Each sentence is converted"]

# Load the auslaw embedding model; later cells reuse `model` for all encoding.
model = SentenceTransformer('adlumal/auslaw-embed-v1.0')
# NOTE: `embeddings` is rebound to an empty dict in the retrieval section.
embeddings = model.encode(sentences)


In [None]:
import json
import pickle
import chromadb
from tqdm import tqdm
import numpy as np

In [None]:
# Load the test split. The encoding is pinned so the read does not depend on
# the platform's default text encoding.
with open('/content/citation_data_test_original.json', 'r', encoding='utf-8') as f:
    test_set = json.load(f)

In [None]:
# Load the candidate cases. The encoding is pinned so the read does not depend
# on the platform's default text encoding.
with open('/content/citation_data_combined_index_by_casename_test_only.json', 'r', encoding='utf-8') as f:
    citation_data = json.load(f)

# **LLM-only Approach**

In [None]:
# LLM-only baseline: ask GPT-4o for up to 5 citations per test item, with the
# gold case name masked as <CASENAME> in both the text and the description.
test_data_for_gpt_4o = []
for sample in tqdm(test_set):
    gold_name = sample['cited_case_name']
    masked_text = sample['citation_text'].replace(gold_name, '<CASENAME>')
    masked_reason = sample['citation_reason'].replace(gold_name, '<CASENAME>')

    record = {
        'instruction': "The following description belongs to a case in the NSW Case Law. You will be given a brief text, and a brief description of a potential citation required. Your task is to predict the citation by listing up-to 5 potential citations, spearated by ';'.",
        'input': 'Text: ' + masked_text + '\nDescription: ' + masked_reason,
        'output': gold_name,
    }
    # The cost returned alongside the reply is not tracked here.
    reply, _ = gpt4o_prompt(record['instruction'], record['input'])
    record['predicted'] = reply
    test_data_for_gpt_4o.append(record)

# **Retrieval-only Approach**

In [None]:
embeddings = {}

In [None]:
len(citation_data)

In [None]:
# Cache one embedding per candidate-case field, keyed by the field's own text:
# the ';'-joined citation reasons, the catchwords, and the full case text.
for record in tqdm(citation_data):
    fields_to_embed = (
        ';'.join(record['citation_reasons']),
        record['catchwords'],
        record['text'],
    )
    for field_text in fields_to_embed:
        embeddings[field_text] = model.encode(field_text)

In [None]:
# Cache embeddings for the two query variants of every test item:
# masked text only, and masked text followed by the masked citation reason.
for sample in tqdm(test_set):
    gold_name = sample['cited_case_name']

    masked_text = sample['citation_text'].replace(gold_name, '<CASENAME>')
    embeddings[masked_text] = model.encode(masked_text)

    masked_text_and_reason = masked_text + ' ' + sample['citation_reason'].replace(gold_name, '<CASENAME>')
    embeddings[masked_text_and_reason] = model.encode(masked_text_and_reason)

In [None]:
with open("auslaw_embeddings.pkl", "wb") as f:
    pickle.dump(embeddings, f)

In [None]:
saved_embeddings = embeddings

In [None]:
# with open("/content/auslaw_embeddings.pkl", "rb") as f:
#     saved_embeddings = pickle.load(f)

In [None]:
client = chromadb.PersistentClient(path="db")

In [None]:
# One Chroma collection per retrieval corpus. get_or_create_collection is used
# because the client persists to "db": create_collection fails on any re-run
# once the collections already exist on disk.
casetext = client.get_or_create_collection("casetext")
catchwords = client.get_or_create_collection("catchwords")
citation_reasons = client.get_or_create_collection("citation_reasons")

In [None]:
# Parallel lists fed to the three collection upserts below.
# Ids and documents.
citation, text, catchwords, citation_reasons = [], [], [], []
# Precomputed embeddings, looked up by document text.
text_embeddings, catchwords_embeddings, citation_reasons_embeddings = [], [], []
# One single-entry metadata dict per case: {citation: document}.
meta_data_text, meta_data_catchwords, meta_data_citation_reasons = [], [], []

for record in tqdm(citation_data):
    case_id = record['citation']
    case_text = record['text']
    case_catchwords = record['catchwords']
    # Must use the same ';' separator as the key stored in the embedding cache.
    joined_reasons = ';'.join(record['citation_reasons'])

    citation.append(case_id)

    text.append(case_text)
    text_embeddings.append(saved_embeddings[case_text])
    meta_data_text.append({case_id: case_text})

    catchwords.append(case_catchwords)
    catchwords_embeddings.append(saved_embeddings[case_catchwords])
    meta_data_catchwords.append({case_id: case_catchwords})

    citation_reasons.append(joined_reasons)
    citation_reasons_embeddings.append(saved_embeddings[joined_reasons])
    meta_data_citation_reasons.append({case_id: joined_reasons})

In [None]:
# Upsert every case into the 'catchwords' collection: catchwords as documents,
# their precomputed embeddings, and the case citation as the id.
catchwords_collection = client.get_collection('catchwords')
catchwords_collection.upsert(
    documents = catchwords,
    embeddings = catchwords_embeddings,
    metadatas = meta_data_catchwords,
    ids = citation
)

In [None]:
# Upsert every case into the 'citation_reasons' collection: ';'-joined citation
# reasons as documents, their precomputed embeddings, and the citation as the id.
citation_reasons_collection = client.get_collection('citation_reasons')
citation_reasons_collection.upsert(
    documents = citation_reasons,
    embeddings = citation_reasons_embeddings,
    metadatas = meta_data_citation_reasons,
    ids = citation
)

In [None]:
# Upsert every case into the 'casetext' collection: full case text as documents,
# their precomputed embeddings, and the case citation as the id.
casetext_collection = client.get_collection('casetext')
casetext_collection.upsert(
    documents = text,
    embeddings = text_embeddings,
    metadatas = meta_data_text,
    ids = citation
)

In [None]:
# Gather, in test-set order, the gold case names and the cached embeddings of
# both query variants (masked text only / masked text + masked reason).
gold_citation = []
test_data_rag_input_text_only_embeddings = []
test_data_rag_input_text_reason_embeddings = []

for sample in tqdm(test_set):
    gold_name = sample['cited_case_name']
    gold_citation.append(gold_name)

    # The keys must be rebuilt exactly as they were when the cache was filled.
    masked_text = sample['citation_text'].replace(gold_name, '<CASENAME>')
    test_data_rag_input_text_only_embeddings.append(saved_embeddings[masked_text])

    masked_text_and_reason = masked_text + ' ' + sample['citation_reason'].replace(gold_name, '<CASENAME>')
    test_data_rag_input_text_reason_embeddings.append(saved_embeddings[masked_text_and_reason])

In [None]:
def cal_acc(preds, gold_labels):
    """Print Acc@1 and Acc@5 for ranked predictions.

    Args:
        preds: One ranked candidate sequence per example, indexed in step with
            ``gold_labels`` (in this notebook: Chroma id lists or the lines of
            a ranker reply).
        gold_labels: The gold label of each example.

    Prints the two accuracies and returns ``None``. An empty ``gold_labels``
    reports 0.0 instead of raising ZeroDivisionError, and an example with no
    candidates counts as a miss instead of raising IndexError.
    """
    acc_at_1 = 0
    acc_at_5 = 0
    for i, gold in enumerate(gold_labels):
        candidates = preds[i]
        # Acc@5: the gold label must be an element of the candidate sequence.
        if gold in candidates:
            acc_at_5 += 1
        # Acc@1: `in` is applied to the first candidate itself, so for string
        # candidates this is a substring test rather than an equality test.
        if candidates and gold in candidates[0]:
            acc_at_1 += 1
    # Avoid dividing by zero when there is nothing to score.
    denominator = max(len(gold_labels), 1)
    print("Acc@1: ", acc_at_1/denominator)
    print("Acc@5: ", acc_at_5/denominator)

def cal_acc_1(preds, gold_labels):
    """Print Acc@1 for single predictions.

    Args:
        preds: One prediction per example, indexed in step with
            ``gold_labels``. An example is a hit when ``gold in pred`` holds,
            which is a substring test when the prediction is a string.
        gold_labels: The gold label of each example.

    Prints the accuracy and returns ``None``. An empty ``gold_labels`` reports
    0.0 instead of raising ZeroDivisionError.
    """
    acc_at_1 = sum(1 for i, gold in enumerate(gold_labels) if gold in preds[i])
    # Avoid dividing by zero when there is nothing to score.
    denominator = max(len(gold_labels), 1)
    print("Acc@1: ", acc_at_1/denominator)

***Catchwords***

In [None]:
# Query the 'catchwords' collection with the text-only query embeddings,
# requesting 5 results per query.
rag_output_catchwords = client.get_collection('catchwords').query(
    query_embeddings=test_data_rag_input_text_only_embeddings,
    n_results=5,
)

In [None]:
with open("auslaw_rag_output_catchwords_top5_citation.json", "w") as file:
    json.dump(rag_output_catchwords['ids'], file)

In [None]:
cal_acc(rag_output_catchwords['ids'], gold_citation)

***RoC Aggregations***

In [None]:
# Query the 'citation_reasons' collection with the text-only query embeddings,
# requesting 5 results per query.
rag_output_citation_reasons = client.get_collection('citation_reasons').query(
    query_embeddings=test_data_rag_input_text_only_embeddings,
    n_results=5,
)

In [None]:
cal_acc(rag_output_citation_reasons['ids'], gold_citation)

***Full Cases***

In [None]:
# Query the 'casetext' collection with the text-only query embeddings,
# requesting 5 results per query.
rag_output_casetext = client.get_collection('casetext').query(
    query_embeddings=test_data_rag_input_text_only_embeddings,
    n_results=5,
)

In [None]:
with open("auslaw_rag_output_casetext_top5_citation.json", "w") as file:
    json.dump(rag_output_casetext['ids'], file)

In [None]:
cal_acc(rag_output_casetext['ids'], gold_citation)

# **(Hybrid Approach) Query Expansion**

In [None]:
# Load the GPT-4o generated citation reasons: one JSON document per line.
test_data_gpt_4o_generated_reason = []
with open('/content/test_data_gpt_4o_generated_reason.jsonl', 'r', encoding='utf-8') as f:
    # Stream the file instead of materialising it with readlines().
    for raw_line in f:
        raw_line = raw_line.strip()
        if not raw_line:
            # A blank or trailing line would make json.loads fail.
            continue
        test_data_gpt_4o_generated_reason.append(json.loads(raw_line))

In [None]:
# Load the Saul-54B generated citation reasons: one JSON document per line.
test_data_saul_54b_generated_reason = []
with open('/content/saul_54b_test_citation_reason_pred.jsonl', 'r', encoding='utf-8') as f:
    # Stream the file instead of materialising it with readlines().
    for raw_line in f:
        raw_line = raw_line.strip()
        if not raw_line:
            # A blank or trailing line would make json.loads fail.
            continue
        test_data_saul_54b_generated_reason.append(json.loads(raw_line))

In [None]:
# Build one query string per item: the text after 'Text: ' in the input,
# followed by the text after 'Citation Reason: ' in the GPT-4o prediction.
# NOTE(review): only element [0] of the loaded file is iterated, so the first
# line is presumably a JSON list holding every item; confirm the file layout.
test_data_rag_input_text_generated_reason = [
    sample['input'].split('Text: ')[1] + ' ' + sample['predicted'].split('Citation Reason: ')[1]
    for sample in test_data_gpt_4o_generated_reason[0]
]

In [None]:
# Build one query string per item: the text after 'Text: ' in the input,
# followed by the Saul-54B generated reason. If the prediction lacks the
# 'Citation Reason: ' marker, the whole prediction is used as the reason, so
# the list stays aligned with the test set.
test_data_rag_input_text_saul_54b_generated_reason = []
for sample in test_data_saul_54b_generated_reason:
    input_text = sample['input'].split('Text: ')[1]
    predicted = sample['predicted']
    # Explicit check instead of a bare `except:`, which would also swallow
    # unrelated errors such as KeyboardInterrupt.
    reason_parts = predicted.split('Citation Reason: ')
    reason = reason_parts[1] if len(reason_parts) > 1 else predicted
    test_data_rag_input_text_saul_54b_generated_reason.append(input_text + ' ' + reason)

In [None]:
# Embed every GPT-4o-expanded query string, one encode call per string.
test_data_rag_input_text_generated_reason_embeddings = [
    model.encode(query_string)
    for query_string in tqdm(test_data_rag_input_text_generated_reason)
]

In [None]:
# Embed every Saul-54B-expanded query string, one encode call per string.
test_data_rag_input_text_saul_54b_generated_reason_embeddings = [
    model.encode(query_string)
    for query_string in tqdm(test_data_rag_input_text_saul_54b_generated_reason)
]

***Catchwords (RoC generated by Saul 54B)***

In [None]:
# Query the 'catchwords' collection with the Saul-54B-expanded query
# embeddings, requesting 5 results per query.
rag_output_catchwords2 = client.get_collection('catchwords').query(
    query_embeddings=test_data_rag_input_text_saul_54b_generated_reason_embeddings,
    n_results=5,
)

In [None]:
cal_acc(rag_output_catchwords2['ids'], gold_citation)

***RoC Aggregations (RoC generated by Saul 54B)***

In [None]:
# Query the 'citation_reasons' collection with the Saul-54B-expanded query
# embeddings, requesting 5 results per query.
rag_output_citation_reasons2 = client.get_collection('citation_reasons').query(
    query_embeddings=test_data_rag_input_text_saul_54b_generated_reason_embeddings,
    n_results=5,
)

In [None]:
cal_acc(rag_output_citation_reasons2['ids'], gold_citation)

***Full Cases (RoC generated by Saul 54B)***

In [None]:
# Query the 'casetext' collection with the Saul-54B-expanded query embeddings,
# requesting 5 results per query.
rag_output_casetext2 = client.get_collection('casetext').query(
    query_embeddings=test_data_rag_input_text_saul_54b_generated_reason_embeddings,
    n_results=5,
)

In [None]:
cal_acc(rag_output_casetext2['ids'], gold_citation)

In [None]:
# Load the predictions stored in citation_pred_test_10e.jsonl: one JSON
# document per line.
citation_pred_test_llama = []
with open('/content/citation_pred_test_10e.jsonl', 'r', encoding='utf-8') as f:
    # Stream the file instead of materialising it with readlines().
    for raw_line in f:
        raw_line = raw_line.strip()
        if not raw_line:
            # A blank or trailing line would make json.loads fail.
            continue
        citation_pred_test_llama.append(json.loads(raw_line))


In [None]:
# Load the predictions stored in citation_pred_test_saul_7b.jsonl: one JSON
# document per line.
citation_pred_test_saul = []
with open('/content/citation_pred_test_saul_7b.jsonl', 'r', encoding='utf-8') as f:
    # Stream the file instead of materialising it with readlines().
    for raw_line in f:
        raw_line = raw_line.strip()
        if not raw_line:
            # A blank or trailing line would make json.loads fail.
            continue
        citation_pred_test_saul.append(json.loads(raw_line))


In [None]:
citation_pred_test_llama[0]

In [None]:
# Split each prediction into the citation (the text between the first '<' and
# the following '>') and the reason (the text before the first '<', with any
# occurrence of the citation removed).
llama_pred_citation = []
llama_pred_citation_reason = []
for entry in citation_pred_test_llama:
    segments = entry['predicted'].split('<')
    # A prediction without '<...>' used to raise IndexError. It now yields an
    # empty citation so both lists stay aligned with the test set.
    pred_citation = segments[1].split('>')[0].strip() if len(segments) > 1 else ''
    llama_pred_citation.append(pred_citation)
    llama_pred_citation_reason.append(segments[0].strip().replace(pred_citation, ''))

In [None]:
# Split each prediction into the citation (the text between the first '<' and
# the following '>') and the reason (the text before the first '<', with any
# occurrence of the citation removed).
saul_pred_citation = []
saul_pred_citation_reason = []
for entry in citation_pred_test_saul:
    segments = entry['predicted'].split('<')
    # A prediction without '<...>' used to raise IndexError. It now yields an
    # empty citation so both lists stay aligned with the test set.
    pred_citation = segments[1].split('>')[0].strip() if len(segments) > 1 else ''
    saul_pred_citation.append(pred_citation)
    saul_pred_citation_reason.append(segments[0].strip().replace(pred_citation, ''))

In [None]:
# Embed "<input> <generated reason>" for every llama prediction; the reason
# list drives the iteration and the matching input is looked up by position.
test_data_rag_input_text_llama_generated_reason_embeddings = [
    model.encode(citation_pred_test_llama[idx]['input'] + ' ' + generated_reason)
    for idx, generated_reason in enumerate(tqdm(llama_pred_citation_reason))
]

In [None]:
# Embed "<input> <generated reason>" for every Saul prediction; the reason
# list drives the iteration and the matching input is looked up by position.
test_data_rag_input_text_saul_generated_reason_embeddings = [
    model.encode(citation_pred_test_saul[idx]['input'] + ' ' + generated_reason)
    for idx, generated_reason in enumerate(tqdm(saul_pred_citation_reason))
]

***Catchwords (RoC generated by our SFT Saul-7B)***

In [None]:
# Query the 'catchwords' collection with the Saul-expanded query embeddings,
# requesting 5 results per query.
rag_output_catchwords4 = client.get_collection('catchwords').query(
    query_embeddings=test_data_rag_input_text_saul_generated_reason_embeddings,
    n_results=5,
)

In [None]:
cal_acc(rag_output_catchwords4['ids'], gold_citation)

***RoC Aggregations (RoC generated by our SFT Saul-7B)***

In [None]:
# Query the 'citation_reasons' collection with the Saul-expanded query
# embeddings, requesting 5 results per query.
rag_output_citation_reasons4 = client.get_collection('citation_reasons').query(
    query_embeddings=test_data_rag_input_text_saul_generated_reason_embeddings,
    n_results=5,
)

In [None]:
cal_acc(rag_output_citation_reasons4['ids'], gold_citation)

***Full Cases (RoC generated by our SFT Saul-7B)***

In [None]:
# Query the 'casetext' collection with the Saul-expanded query embeddings,
# requesting 5 results per query.
rag_output_casetext4 = client.get_collection('casetext').query(
    query_embeddings=test_data_rag_input_text_saul_generated_reason_embeddings,
    n_results=5,
)

In [None]:
cal_acc(rag_output_casetext4['ids'], gold_citation)

# **(Hybrid Approach) Voting Ensemble**

In [None]:
cal_acc_1(saul_pred_citation, gold_citation)

In [None]:
rag_output_citation_reasons4['ids'][0]

In [None]:
# Voting ensemble: keep the model's citation when it also appears among the
# retrieved ids for that item; otherwise fall back to the top retrieved id.
agg_pred_citation = []
num_valid_model_pred = 0   # items where the model prediction was kept
num_agg_rag_pred = 0       # items where the retrieval fallback was used
for idx, model_citation in enumerate(saul_pred_citation):
    retrieved_ids = rag_output_citation_reasons4['ids'][idx]
    if model_citation not in retrieved_ids:
        num_agg_rag_pred += 1
        agg_pred_citation.append(retrieved_ids[0])
        continue
    num_valid_model_pred += 1
    agg_pred_citation.append(model_citation)

In [None]:
num_valid_model_pred

In [None]:
num_agg_rag_pred

In [None]:
cal_acc_1(agg_pred_citation, gold_citation)

In [None]:
num_agg_rag_pred

# **(Hybrid Approach) RAG + GPT-4o Ranker**

In [None]:
# System prompt for ranking 5 candidates by their catchwords.
# This is a regular (non-raw) string, so the '\n' in the instruction reaches the
# model as an actual newline character.
catchwords_rank_sys_prompt = """
The following description belongs to a case in the NSW Case Law, but with a missing citation showing <CASENAME>. You will be given a brief text, 5 potential citations and their corresponding catchwords. Your task is to rank the 5 potential citations according to what is most likely to be the correct citation in the text. Show your ranking result in a list, separated by '\n'.
"""

# User prompt template. TEXT, CITATION1-5 and CATCHWORDS1-5 are placeholders
# that the ranking loop fills in with str.replace.
catchwords_prompt = """
Text:
TEXT

Potential Citations:

CITATION1
Catchwords: CATCHWORDS1

CITATION2
Catchwords: CATCHWORDS2

CITATION3
Catchwords: CATCHWORDS3

CITATION4
Catchwords: CATCHWORDS4

CITATION5
Catchwords: CATCHWORDS5

"""

In [None]:
# System prompt for ranking 5 candidates by their citation reasons.
# This is a regular (non-raw) string, so the '\n' in the instruction reaches the
# model as an actual newline character.
roc_rank_sys_prompt = """
The following description belongs to a case in the NSW Case Law, but with a missing citation showing <CASENAME>. You will be given a brief text, 5 potential citations and their corresponding citation reasons. Your task is to rank the 5 potential citations according to what is most likely to be the correct citation in the text. Show your ranking result in a list, separated by '\n'.
"""

# User prompt template. TEXT, CITATION1-5 and CITATIONREASON1-5 are
# placeholders that the ranking loop fills in with str.replace.
roc_prompt = """
Text:
TEXT

Potential Citations:

CITATION1
Citation Reasons: CITATIONREASON1

CITATION2
Citation Reasons: CITATIONREASON2

CITATION3
Citation Reasons: CITATIONREASON3

CITATION4
Citation Reasons: CITATIONREASON4

CITATION5
Citation Reasons: CITATIONREASON5

"""

In [None]:
# System prompt for ranking 5 candidates by their case text.
# This is a regular (non-raw) string, so the '\n' in the instruction reaches the
# model as an actual newline character.
casetext_rank_sys_prompt = """
The following description belongs to a case in the NSW Case Law, but with a missing citation showing <CASENAME>. You will be given a brief text, 5 potential citations and their corresponding case text. Your task is to rank the 5 potential citations according to what is most likely to be the correct citation in the text. Show your ranking result in a list, separated by '\n'.
"""

# User prompt template. INPUTTEXT, CITATION1-5 and CASETEXT1-5 are placeholders
# that the ranking loop fills in with str.replace. The query placeholder is
# INPUTTEXT here: a plain 'TEXT' placeholder would also match inside
# 'CASETEXT1'..'CASETEXT5'.
casetext_prompt = """
Text:
INPUTTEXT

Potential Citations:

CITATION1
Case Text: CASETEXT1

CITATION2
Case Text: CASETEXT2

CITATION3
Case Text: CASETEXT3

CITATION4
Case Text: CASETEXT4

CITATION5
Case Text: CASETEXT5

"""

**Catchwords**

In [None]:
# Re-rank the 5 catchwords-retrieved candidates of every test item with GPT-4o.
ranked_citations_catchwords = []
catchwords_retrieved_text_only_top_5_citations = rag_output_catchwords['ids']

# Build the citation -> catchwords lookup once. The previous version rescanned
# citation_data for every query, shadowed the builtin `dict` with its loop
# variable, and silently reused stale values when a citation was not found.
catchwords_by_citation = {entry['citation']: entry['catchwords'] for entry in citation_data}

for i in tqdm(range(len(test_set))):
    # Named query_text so the `text` document list used by the casetext upsert
    # is not overwritten.
    query_text = test_set[i]['citation_text'].replace(test_set[i]['cited_case_name'], '<CASENAME>')
    candidates = [catchwords_retrieved_text_only_top_5_citations[i][rank] for rank in range(5)]

    # Fill placeholders in the original order: TEXT, CITATION1-5, CATCHWORDS1-5.
    catchwords_prompt_input = catchwords_prompt.replace('TEXT', query_text)
    for rank, candidate in enumerate(candidates, start=1):
        catchwords_prompt_input = catchwords_prompt_input.replace(f'CITATION{rank}', candidate)
    for rank, candidate in enumerate(candidates, start=1):
        catchwords_prompt_input = catchwords_prompt_input.replace(f'CATCHWORDS{rank}', catchwords_by_citation[candidate])

    response, _ = gpt4o_prompt(catchwords_rank_sys_prompt, catchwords_prompt_input)
    ranked_citations_catchwords.append(response)

In [None]:
ranked_citations_catchwords_preds = [pred.split('\n') for pred in ranked_citations_catchwords]
cal_acc(ranked_citations_catchwords_preds, gold_citation)

**RoC Aggregations**

In [None]:
# Re-rank the 5 citation-reason-retrieved candidates of every test item with GPT-4o.
ranked_citations_citation_reasons = []
citation_reasons_retrieved_text_only_top_5_citations = rag_output_citation_reasons['ids']

# Build the citation -> '; '-joined reasons lookup once. The previous version
# rescanned citation_data for every query, shadowed the builtin `dict` with its
# loop variable, and silently reused stale values when a citation was not found.
reasons_by_citation = {entry['citation']: '; '.join(entry['citation_reasons']) for entry in citation_data}

for i in tqdm(range(len(test_set))):
    # Named query_text so the `text` document list used by the casetext upsert
    # is not overwritten.
    query_text = test_set[i]['citation_text'].replace(test_set[i]['cited_case_name'], '<CASENAME>')
    candidates = [citation_reasons_retrieved_text_only_top_5_citations[i][rank] for rank in range(5)]

    # Fill placeholders in the original order: TEXT, CITATION1-5, CITATIONREASON1-5.
    citation_reasons_prompt_input = roc_prompt.replace('TEXT', query_text)
    for rank, candidate in enumerate(candidates, start=1):
        citation_reasons_prompt_input = citation_reasons_prompt_input.replace(f'CITATION{rank}', candidate)
    for rank, candidate in enumerate(candidates, start=1):
        citation_reasons_prompt_input = citation_reasons_prompt_input.replace(f'CITATIONREASON{rank}', reasons_by_citation[candidate])

    response, _ = gpt4o_prompt(roc_rank_sys_prompt, citation_reasons_prompt_input)
    ranked_citations_citation_reasons.append(response)

In [None]:
ranked_citations_citation_reasons_preds = [pred.split('\n') for pred in ranked_citations_citation_reasons]
cal_acc(ranked_citations_citation_reasons_preds, gold_citation)

**Full Cases**

In [None]:
# Re-rank the 5 case-text-retrieved candidates of every test item. This cell
# calls gpt4o_mini_prompt and truncates each case text to its first 20000
# characters.
ranked_citations_casetext = []
casetext_retrieved_text_only_top_5_citations = rag_output_casetext['ids']

# Build the citation -> case text lookup once. The previous version rescanned
# citation_data for every query, shadowed the builtin `dict` with its loop
# variable, and silently reused stale values when a citation was not found.
casetext_by_citation = {entry['citation']: entry['text'] for entry in citation_data}

for i in tqdm(range(len(test_set))):
    # Named query_text so the `text` document list used by the casetext upsert
    # is not overwritten.
    query_text = test_set[i]['citation_text'].replace(test_set[i]['cited_case_name'], '<CASENAME>')
    candidates = [casetext_retrieved_text_only_top_5_citations[i][rank] for rank in range(5)]

    # Fill placeholders in the original order: INPUTTEXT, CITATION1-5, CASETEXT1-5.
    casetext_prompt_input = casetext_prompt.replace('INPUTTEXT', query_text)
    for rank, candidate in enumerate(candidates, start=1):
        casetext_prompt_input = casetext_prompt_input.replace(f'CITATION{rank}', candidate)
    for rank, candidate in enumerate(candidates, start=1):
        casetext_prompt_input = casetext_prompt_input.replace(f'CASETEXT{rank}', casetext_by_citation[candidate][:20000])

    response, _ = gpt4o_mini_prompt(casetext_rank_sys_prompt, casetext_prompt_input)
    ranked_citations_casetext.append(response)


In [None]:
ranked_citations_casetext_preds = [pred.split('\n') for pred in ranked_citations_casetext]
cal_acc(ranked_citations_casetext_preds, gold_citation)