In [None]:
import os
import json
import math
import pickle
import random
from tqdm import tqdm

from collections import Counter, defaultdict

### Pay attention to reproducibility: every question cell below re-seeds `random` before sampling.

In [None]:
data_dir=f"/shared/data3/bowenj4/llm-graph-plugin/data/processed_data/legal"
downstream_dir=f"/shared/data3/bowenj4/llm-graph-plugin/data/raw_data/legal"

In [None]:
# Read the processed graph. Top-level keys are node-type collections
# (e.g. 'court_nodes', 'docket_nodes', 'opinion_nodes', 'opinion_cluster_nodes'),
# each mapping a node id to a record with 'features' and 'neighbors'.
# The `with` block closes the file handle once parsing is done.
with open(os.path.join(data_dir, 'graph.json')) as graph_file:
    graph = json.load(graph_file)
print(graph.keys())

In [None]:
# Scratch check: print the ids among the first 1000 opinion clusters whose
# syllabus is an actual string (presumably missing ones are NaN/None — verify).
cluster_nodes = graph['opinion_cluster_nodes']
first_cluster_ids = list(cluster_nodes.keys())[:1000]
for cluster_id in first_cluster_ids:
    syllabus = cluster_nodes[cluster_id]['features']['syllabus']
    if isinstance(syllabus, str):
        print(cluster_id)

In [None]:
# Peek at which neighbor types this opinion cluster links to.
graph["opinion_cluster_nodes"]["opc-8599951"]["neighbors"].keys()

In [None]:
# Show one full opinion-cluster record (features + neighbors).
graph["opinion_cluster_nodes"]["opc-4601433"]

In [None]:
# Show one full docket record.
graph["docket_nodes"]["d-65862636"]

In [None]:
# Docket neighbor list of a single opinion cluster.
graph["opinion_cluster_nodes"]["opc-8599951"]["neighbors"]["docket"]

In [None]:
# Show one full opinion record.
graph["opinion_nodes"]["op-7344185"]

In [None]:
# Show another opinion-cluster record.
graph["opinion_cluster_nodes"]["opc-6381448"]

In [None]:
# First ten opinion ids, in the graph's stored order.
list(graph["opinion_nodes"])[:10]

In [None]:
# Show another opinion record.
graph["opinion_nodes"]["op-8044194"]

In [None]:
# Number of court nodes (the court-sampling cells below draw up to 1000 court indices).
len(graph["court_nodes"])

In [None]:
# Sanity check: every node id is unique across all node-type collections.
# The loop variable is deliberately NOT named `k`: `k` is the global
# per-question sample count used by every question cell, and re-running this
# cell must not overwrite it with a node-type name.
id_set = set()

for node_type, node_ids in graph.items():
    print(node_type)
    for idd in tqdm(node_ids):
        assert idd not in id_set, f"duplicate node id {idd!r} in {node_type}"
        id_set.add(idd)

In [None]:
def random_int(start, end, num, target_list):
    """Draw ``num`` distinct random indices in ``[start, end]`` and return the items at them.

    Args:
        start: smallest index that may be drawn (inclusive).
        end: largest index that may be drawn (inclusive).
        num: how many distinct indices to draw. It is clamped to the size of
            the range, so asking for more items than exist returns all of them
            instead of looping forever.
        target_list: sequence indexed by the drawn indices.

    Returns:
        A list of ``target_list`` items. NOTE: items come back in Python set
        iteration order (a hash-table order that depends on the index values),
        not in draw order. Callers that keep only the first ``k`` results
        therefore do not get a uniformly random prefix. This order is kept so
        that samples generated under the same seed stay reproducible.
    """
    # Without this clamp, an oversize request could never fill the set.
    num = min(num, end - start + 1)
    res = set()
    while len(res) < num:
        # Re-adding an already drawn index is a no-op; the sequence of RNG calls is unchanged.
        res.add(random.randint(start, end))
    return [target_list[i] for i in res]

In [None]:
def text_process(text):
    """Hook for normalising node text before it is put into a question; currently the identity."""
    return text

def ignore_text(text):
    """Return True when a feature value is missing or empty and must not be used.

    Treated as missing: ``None`` (JSON ``null``), float NaN, and the empty
    string. Everything else (non-empty strings, numbers, other objects) is kept.
    """
    if text is None:
        return True
    if isinstance(text, str):
        return len(text) == 0
    try:
        return math.isnan(text)
    except TypeError:
        # Non-numeric, non-string values (e.g. lists) are not "missing";
        # math.isnan would otherwise raise on them.
        return False

In [None]:
# Count how often each opinion text / cluster syllabus occurs in the graph.
# Question cells only use texts whose count is exactly 1, so that the text
# identifies a single node.
opinion_text_dict = defaultdict(int)
for opinion_id in tqdm(graph['opinion_nodes'].keys()):
    opinion_text = graph['opinion_nodes'][opinion_id]['features']['plain_text']
    opinion_text_dict[opinion_text] += 1

opinion_cluster_text_dict = defaultdict(int)
for opinion_cluster_id in tqdm(graph['opinion_cluster_nodes'].keys()):
    opinion_cluster_text = graph['opinion_cluster_nodes'][opinion_cluster_id]['features']['syllabus']
    opinion_cluster_text_dict[opinion_cluster_text] += 1

# truncation function
from transformers import AutoTokenizer
truncate_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def truncate(text, max_len=256):
    """Cut ``text`` to its first ``max_len`` tokens, returning a prefix of the ORIGINAL string.

    bert-base-uncased lower-cases and accent-strips its input, so decoding its
    token ids would give back lower-cased, accent-stripped text with
    re-spaced punctuation. With a fast tokenizer, this function instead takes
    the character offset where the last kept token ends and slices the
    original string, which preserves casing and whitespace. Slow tokenizers
    (which have no offset mapping) fall back to decoding.

    Args:
        text: the string to shorten.
        max_len: maximum number of tokens (special tokens not counted).
    """
    if not truncate_tokenizer.is_fast:
        return truncate_tokenizer.decode(truncate_tokenizer.encode(text, truncation=True, max_length=max_len, add_special_tokens=False))
    encoding = truncate_tokenizer(text, truncation=True, max_length=max_len,
                                  add_special_tokens=False, return_offsets_mapping=True)
    offsets = encoding['offset_mapping']
    if not offsets:
        # No tokens at all (e.g. whitespace-only text).
        return ''
    return text[:offsets[-1][1]]

In [None]:
# Maps (question_template, answer_template) -> list of dicts holding the slot
# values for that template; each question cell below adds one entry.
all_generated_data = dict()
# Number of samples generated per question template.
k = 10

### Question design (one question type per cell)

### Easy Questions

In [None]:
## question (easy): what is the start date of court xxx?
# Shuffle all courts under a fixed seed and keep the first k whose
# start_date is a string.

random.seed(2023)

question = "what is the start date of court {court_name}?"
answer = "{start_date}"
generated_data = []

court_ids = list(graph['court_nodes'].keys())
random.shuffle(court_ids)

for court_id in court_ids:
    features = graph['court_nodes'][court_id]['features']
    court_name, start_date = features['full_name'], features['start_date']
    if isinstance(start_date, str):
        generated_data.append({"court_name": court_name, "start_date": start_date})
    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

In [None]:
## question (easy): what is the end date of court xxx?
# Shuffle all courts under a fixed seed and keep the first k whose
# end_date is a string.

random.seed(2024)

question = "what is the end date of court {court_name}?"
answer = "{end_date}"
generated_data = []

court_ids = list(graph['court_nodes'].keys())
random.shuffle(court_ids)

for court_id in court_ids:
    features = graph['court_nodes'][court_id]['features']
    court_name, end_date = features['full_name'], features['end_date']
    if isinstance(end_date, str):
        generated_data.append({"court_name": court_name, "end_date": end_date})
    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

In [None]:
## question (easy): what is the citation string of court xxx?
# Shuffle all courts under a fixed seed and keep the first k with a
# non-missing citation string.

random.seed(2025)

question = "what is the citation string of court {court_name}?"
answer = "{citation_string}"
generated_data = []

court_ids = list(graph['court_nodes'].keys())
random.shuffle(court_ids)

for court_id in court_ids:
    features = graph['court_nodes'][court_id]['features']
    court_name = features['full_name']
    citation_string = features['citation_string']

    if not ignore_text(citation_string):
        generated_data.append({"court_name": court_name, "citation_string": citation_string})
        if len(generated_data) == k:
            break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

In [None]:
## question (easy): which court is handling the case listed under the PACER docket number xxx?
# Sample dockets, follow each docket's single court edge, and keep the first k
# dockets that have a usable PACER case id.

random.seed(2026)

question = "which court is handling the case listed under the PACER docket number {pacer_id}?"
answer = "{court_name}"
generated_data = []

docket_nodes = graph['docket_nodes']
docket_ids_list = list(docket_nodes.keys())
docket_ids = random_int(0, len(docket_ids_list)-1, 1000 * k, docket_ids_list)

for docket_id in docket_ids:
    docket = docket_nodes[docket_id]
    pacer_id = docket['features']['pacer_case_id']
    court_neighbors = docket['neighbors']['court']
    assert len(court_neighbors) == 1
    court_id = court_neighbors[0]

    if ignore_text(pacer_id):
        continue

    # NOTE(review): ignore_text's NaN branch suggests pacer_case_id may be stored
    # as a float, which would render as e.g. "123.0" in the question — confirm.
    court_name = graph['court_nodes'][court_id]['features']['full_name']
    generated_data.append({"pacer_id": pacer_id, "court_name": court_name})

    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

In [None]:
## question (easy): Who are the attorneys for the case corresponding to this opinion cluster?
# Sample opinion clusters; keep those whose syllabus is present and unique in
# the graph (so it identifies one cluster) and whose attorneys field is set.

# NOTE(review): the judges question further down also seeds with 2027; every
# other question cell uses its own seed. Left unchanged for reproducibility.
random.seed(2027)

question = "Who are the attorneys for the case corresponding to this opinion cluster: {opinion_cluster_text}?"
answer = "{attorneys}"
generated_data = []

cluster_nodes = graph['opinion_cluster_nodes']
opinion_cluster_ids_list = list(cluster_nodes.keys())
opinion_cluster_ids = random_int(0, len(opinion_cluster_ids_list)-1, 100 * k, opinion_cluster_ids_list)

for opinion_cluster_id in opinion_cluster_ids:
    cluster_features = cluster_nodes[opinion_cluster_id]['features']
    syllabus = cluster_features['syllabus']
    if ignore_text(syllabus) or opinion_cluster_text_dict[syllabus] != 1:
        continue

    attorneys = cluster_features['attorneys']
    if ignore_text(attorneys):
        continue

    generated_data.append({"opinion_cluster_text": truncate(syllabus), "attorneys": attorneys})

    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

### Medium questions

In [None]:
## question (medium): Which members of the judiciary are responsible for the group of rulings that includes the following opinion: {opinion_text}?
# Sample opinions with unique, non-empty text, hop to their single opinion
# cluster, and ask for that cluster's judges.

# NOTE(review): seed 2027 is also used by the attorneys question, and this
# template has no trailing "?" unlike the others. Both are kept as-is so the
# generated data and its dictionary key stay reproducible.
random.seed(2027)

question = "Which members of the judiciary are responsible for the group of rulings that includes the following opinion: {opinion_text}"
answer = "{judges}"
generated_data = []

opinion_nodes = graph['opinion_nodes']
opinion_ids_list = list(opinion_nodes.keys())
opinion_ids = random_int(0, len(opinion_ids_list)-1, 5 * k, opinion_ids_list)

for opinion_id in opinion_ids:
    opinion = opinion_nodes[opinion_id]
    opinion_text = opinion['features']['plain_text']
    if ignore_text(opinion_text) or opinion_text_dict[opinion_text] != 1:
        continue

    cluster_neighbors = opinion['neighbors']['opinion_cluster']
    assert len(cluster_neighbors) == 1
    judges = graph['opinion_cluster_nodes'][cluster_neighbors[0]]['features']['judges']
    if ignore_text(judges):
        continue

    generated_data.append({"opinion_text": truncate(opinion_text), "judges": judges})

    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

In [None]:
## question (medium): What docket includes this opinion: xxxx? Please answer with the pacer case ID.
# Walk opinion -> opinion cluster -> docket and ask for the docket's PACER id,
# using only opinions whose text is present and unique in the graph.

random.seed(2028)

question = "What docket includes this opinion: {opinion_plain_text}? Please answer with the pacer case ID."
answer = "{pacer_id}"
generated_data = []

opinion_ids_list = list(graph['opinion_nodes'].keys())
opinion_ids = random_int(0, len(opinion_ids_list)-1, 10000, opinion_ids_list)

for opinion_id in opinion_ids:
    opinion = graph['opinion_nodes'][opinion_id]
    opinion_content = opinion['features']['plain_text']
    cluster_neighbors = opinion['neighbors']['opinion_cluster']
    assert len(cluster_neighbors) == 1

    docket_neighbors = graph['opinion_cluster_nodes'][cluster_neighbors[0]]['neighbors']['docket']
    assert len(docket_neighbors) == 1

    pacer_id = graph['docket_nodes'][docket_neighbors[0]]['features']['pacer_case_id']

    usable = not (ignore_text(opinion_content) or ignore_text(pacer_id) or opinion_text_dict[opinion_content] != 1)
    if usable:
        generated_data.append({"opinion_plain_text": truncate(text_process(opinion_content)), "pacer_id": pacer_id})
        if len(generated_data) == k:
            break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

In [None]:
## question (medium): Which court is this opinion cluster syllabus published: {opinion_cluster_text}? Please answer the court full name.
# Walk opinion cluster -> docket -> court; use only clusters whose syllabus is
# present and unique in the graph.

random.seed(2029)

question = "Which court is this opinion cluster syllabus published: {opinion_cluster_text}?"
answer = "{court_name}"
generated_data = []

opinion_cluster_ids_list = list(graph['opinion_cluster_nodes'].keys())
opinion_cluster_ids = random_int(0, len(opinion_cluster_ids_list)-1, 1000, opinion_cluster_ids_list)

for opinion_cluster_id in opinion_cluster_ids:
    cluster = graph['opinion_cluster_nodes'][opinion_cluster_id]
    opinion_cluster_content = cluster['features']['syllabus']

    docket_neighbors = cluster['neighbors']['docket']
    assert len(docket_neighbors) == 1
    court_neighbors = graph['docket_nodes'][docket_neighbors[0]]['neighbors']['court']
    assert len(court_neighbors) == 1
    court_name = graph['court_nodes'][court_neighbors[0]]['features']['full_name']

    if ignore_text(opinion_cluster_content) or opinion_cluster_text_dict[opinion_cluster_content] != 1:
        continue

    generated_data.append({"opinion_cluster_text": truncate(text_process(opinion_cluster_content)), "court_name": court_name})

    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

## Easy counting questions (single-hop neighbor counts)

In [None]:
## question (medium): How many dockets have been processed in court xxx?
# Answer = number of docket neighbors of the court. Only courts with 2..29
# dockets are kept (presumably so answers stay small and checkable).

random.seed(2030)

question = "How many dockets have been processed in court {court_name}?"
answer = "{num}"

generated_data = []

court_ids_list = list(graph['court_nodes'].keys())
court_ids = random_int(0, len(court_ids_list)-1, 100 * k, court_ids_list)

for court_id in court_ids:
    court = graph['court_nodes'][court_id]
    court_name = court['features']['full_name']
    num = len(court['neighbors']['docket'])

    if 1 < num < 30:
        generated_data.append({"court_name": court_name, "num": num})

    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

## TODO: we can add more constraints later, e.g., restricting to year xxx

In [None]:
## question (medium): How many opinions are citing this opinion xxx?
# Answer = size of the opinion's cited_by neighbor list. Only unique, non-empty
# opinion texts with 2..29 citing opinions are used.

random.seed(2031)

question = "How many opinions are citing this opinion: {opinion_text}?"
answer = "{num}"

generated_data = []

opinion_ids_list = list(graph['opinion_nodes'].keys())
opinion_ids = random_int(0, len(opinion_ids_list)-1, 1000, opinion_ids_list)

for opinion_id in opinion_ids:
    opinion = graph['opinion_nodes'][opinion_id]
    opinion_text = opinion['features']['plain_text']

    if ignore_text(opinion_text) or opinion_text_dict[opinion_text] != 1:
        continue

    num = len(opinion['neighbors']['cited_by'])
    if 1 < num < 30:
        generated_data.append({"opinion_text": truncate(text_process(opinion_text)), "num": num})

    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

## TODO: we can add more constraints later, e.g., restricting to year xxx

## Medium counting questions (multi-hop aggregation)

In [None]:
## question (medium): How many times has the xxx case been judged in different courts?
# Answer = number of distinct court full names reached via cluster -> docket -> court,
# over every opinion cluster that shares the sampled cluster's case_name.
# NOTE(review): unrelated cases that happen to share a name are pooled by this
# definition; the 2..10 filter below limits but does not remove that effect.

random.seed(2032)

question = "How many times has the case {case_name} been judged in different courts?"
answer = "{num}"

generated_data = []

opinion_cluster_ids_list = list(graph['opinion_cluster_nodes'].keys())
opinion_cluster_ids = random_int(0, len(opinion_cluster_ids_list)-1, 5 * k, opinion_cluster_ids_list)
seen_name = set()

# case_name -> set of court full names, built over ALL opinion clusters.
case_court_set = defaultdict(set)
for op in tqdm(opinion_cluster_ids_list):
    case_name = graph['opinion_cluster_nodes'][op]['features']['case_name']
    docket_id = graph['opinion_cluster_nodes'][op]['neighbors']['docket']
    assert len(docket_id) == 1
    docket_id = docket_id[0]
    court_id = graph['docket_nodes'][docket_id]['neighbors']['court']
    assert len(court_id) == 1
    court_id = court_id[0]
    case_court_set[case_name].add(graph['court_nodes'][court_id]['features']['full_name'])

for opinion_cluster_id in opinion_cluster_ids:
    case_name = graph['opinion_cluster_nodes'][opinion_cluster_id]['features']['case_name']

    # Missing/empty case names would otherwise yield questions such as
    # "the case None" or "the case " whose count pools unrelated clusters.
    if ignore_text(case_name):
        continue

    num = len(case_court_set[case_name])
    if num > 10 or num < 2 or case_name in seen_name:
        continue

    generated_data.append({"case_name": case_name, "num": num})
    seen_name.add(case_name)

    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

## TODO: we can add more constraints later, e.g., restricting to year xxx

In [None]:
## question (medium): How many opinions are contained in the opinion clusters about {case_name}?
# Answer = total number of opinion neighbors over all opinion clusters that
# share the sampled cluster's case_name.

random.seed(2033)

question = "How many opinions are contained in the opinion clusters about {case_name}?"
answer = "{num}"

generated_data = []

opinion_cluster_ids_list = list(graph['opinion_cluster_nodes'].keys())
opinion_cluster_ids = random_int(0, len(opinion_cluster_ids_list)-1, 5 * k, opinion_cluster_ids_list)
seen_name = set()

# case_name -> number of opinions, summed over ALL opinion clusters with that name.
case_opinion_num = defaultdict(int)
for op in tqdm(opinion_cluster_ids_list):
    case_name = graph['opinion_cluster_nodes'][op]['features']['case_name']
    opinion_ids = graph['opinion_cluster_nodes'][op]['neighbors']['opinion']
    case_opinion_num[case_name] += len(opinion_ids)

for opinion_cluster_id in opinion_cluster_ids:
    case_name = graph['opinion_cluster_nodes'][opinion_cluster_id]['features']['case_name']

    # Missing/empty case names would otherwise yield questions such as
    # "clusters about None" whose count pools unrelated clusters.
    if ignore_text(case_name):
        continue

    num = case_opinion_num[case_name]
    if num > 10 or num < 2 or case_name in seen_name:
        continue

    generated_data.append({"case_name": case_name, "num": num})
    seen_name.add(case_name)

    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

## TODO: we can add more constraints later, e.g., restricting to year xxx

In [None]:
## question (medium): How many opinions are contained in the opinion cluster with syllabus {opinion_cluster_text}?
# Answer = number of opinion neighbors of the single cluster whose syllabus is
# the given text; only present, unique syllabi with 2..10 opinions are used.

random.seed(2034)

question = "How many opinions are contained in the opinion cluster with syllabus: {opinion_cluster_text}?"
answer = "{num}"

generated_data = []

opinion_cluster_ids_list = list(graph['opinion_cluster_nodes'].keys())
opinion_cluster_ids = random_int(0, len(opinion_cluster_ids_list)-1, 100000, opinion_cluster_ids_list)
seen_name = set()

# syllabus -> number of clusters with it, and -> total opinions in those clusters.
cluster_text_num = defaultdict(int)
cluster_opinion_num = defaultdict(int)
for op in tqdm(opinion_cluster_ids_list):
    cluster_text = graph['opinion_cluster_nodes'][op]['features']['syllabus']
    cluster_text_num[cluster_text] += 1

    opinion_ids = graph['opinion_cluster_nodes'][op]['neighbors']['opinion']
    cluster_opinion_num[cluster_text] += len(opinion_ids)

for opinion_cluster_id in opinion_cluster_ids:
    cluster_text = graph['opinion_cluster_nodes'][opinion_cluster_id]['features']['syllabus']
    # A missing/empty syllabus cannot identify a cluster, and truncate() expects text.
    if ignore_text(cluster_text) or cluster_text_num[cluster_text] > 1:
        continue

    num = cluster_opinion_num[cluster_text]
    if num > 10 or num < 2 or cluster_text in seen_name:
        continue

    generated_data.append({"opinion_cluster_text": truncate(cluster_text), "num": num})
    seen_name.add(cluster_text)

    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

## TODO: we can add more constraints later, e.g., restricting to year xxx

In [None]:
## question (medium): How many opinions are contained in the opinion cluster with opinion xxx?
# Answer = number of opinion neighbors of the (single) cluster that contains
# the sampled opinion; only present, unique opinion texts with 2..10 are used.

random.seed(2035)

question = "How many opinions are contained in the opinion cluster with opinion {opinion_text}?"
answer = "{num}"

generated_data = []

opinion_ids_list = list(graph['opinion_nodes'].keys())
opinion_ids = random_int(0, len(opinion_ids_list)-1, 100000, opinion_ids_list)
seen_name = set()

# opinion_text_dict (built in the duplication-check cell, which also defines
# truncate) already counts how many opinions share each plain_text, so no
# second full pass over all opinions is needed here.
for opinion_id in opinion_ids:
    opinion_text = graph['opinion_nodes'][opinion_id]['features']['plain_text']
    # Missing/empty text cannot identify an opinion, and truncate() expects text.
    if ignore_text(opinion_text) or opinion_text_dict[opinion_text] > 1:
        continue

    opinion_cluster_id = graph['opinion_nodes'][opinion_id]['neighbors']['opinion_cluster']
    assert len(opinion_cluster_id) == 1
    opinion_cluster_id = opinion_cluster_id[0]

    num = len(graph['opinion_cluster_nodes'][opinion_cluster_id]['neighbors']['opinion'])
    if num > 10 or num < 2 or opinion_text in seen_name:
        continue

    generated_data.append({"opinion_text": truncate(opinion_text), "num": num})
    seen_name.add(opinion_text)

    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

## TODO: we can add more constraints later, e.g., restricting to year xxx

In [None]:
## question (medium): Which court is this opinion ({opinion text}) published?
# Walk opinion -> opinion cluster -> docket -> court for opinions whose text
# is present and unique in the graph.

random.seed(2036)

question = "Which court is this opinion ({opinion_text}) published?"
answer = "{court_name}"

generated_data = []

opinion_ids_list = list(graph['opinion_nodes'].keys())
opinion_ids = random_int(0, len(opinion_ids_list)-1, 100000, opinion_ids_list)
seen_name = set()

# opinion_text_dict (built in the duplication-check cell, which also defines
# truncate) already counts how many opinions share each plain_text, so no
# second full pass over all opinions is needed here.
for opinion_id in opinion_ids:
    opinion_text = graph['opinion_nodes'][opinion_id]['features']['plain_text']
    # Missing/empty text cannot identify an opinion, and truncate() expects text.
    if ignore_text(opinion_text) or opinion_text_dict[opinion_text] > 1:
        continue

    opinion_cluster_id = graph['opinion_nodes'][opinion_id]['neighbors']['opinion_cluster']
    assert len(opinion_cluster_id) == 1
    opinion_cluster_id = opinion_cluster_id[0]

    docket_id = graph['opinion_cluster_nodes'][opinion_cluster_id]['neighbors']['docket']
    assert len(docket_id) == 1
    docket_id = docket_id[0]

    court_id = graph['docket_nodes'][docket_id]['neighbors']['court']
    assert len(court_id) == 1
    court_id = court_id[0]
    court_name = graph['court_nodes'][court_id]['features']['full_name']

    if opinion_text in seen_name:
        continue

    generated_data.append({"opinion_text": truncate(opinion_text), "court_name": court_name})
    seen_name.add(opinion_text)

    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

## TODO: we can add more constraints later, e.g., restricting to year xxx

In [None]:
## question (medium): What is the preferred cited court of judges in court {court_name}?
# For every source court, count how often its opinions cite (via 'reference')
# opinions of each target court. A question is asked only when one target
# court is the strict maximum.

random.seed(2037)

question = "What is the preferred court to cite of judges in court {source_court_name}?"
answer = "{target_court_name}"

generated_data = []

opinion_ids_list = list(graph['opinion_nodes'].keys())
court_ids_list = list(graph['court_nodes'].keys())
court_ids = random_int(0, len(court_ids_list)-1, 1000, court_ids_list)
seen_name = set()

def get_court_name(opinion_id):
    """Return the full name of the court an opinion belongs to.

    Follows opinion -> opinion_cluster -> docket -> court and asserts that
    each hop has exactly one neighbor.
    """
    opinion_cluster_id = graph['opinion_nodes'][opinion_id]['neighbors']['opinion_cluster']
    assert len(opinion_cluster_id) == 1
    opinion_cluster_id = opinion_cluster_id[0]

    docket_id = graph['opinion_cluster_nodes'][opinion_cluster_id]['neighbors']['docket']
    assert len(docket_id) == 1
    docket_id = docket_id[0]

    court_id = graph['docket_nodes'][docket_id]['neighbors']['court']
    assert len(court_id) == 1
    court_id = court_id[0]
    court_name = graph['court_nodes'][court_id]['features']['full_name']
    return court_name

# court2court[source_court_name][target_court_name] = number of references from
# opinions of the source court to opinions of the target court. A source court
# only gets an entry once it has at least one reference.
court2court = defaultdict(lambda: defaultdict(int))
for op in tqdm(opinion_ids_list):
    source_court_name = get_court_name(op)
    for ref in graph['opinion_nodes'][op]['neighbors']['reference']:
        court2court[source_court_name][get_court_name(ref)] += 1

for court_id in court_ids:
    court_name = graph['court_nodes'][court_id]['features']['full_name']

    if court_name in seen_name or court_name not in court2court:
        continue

    # Most-cited target first; the stable sort keeps insertion order among ties.
    candidate = sorted(court2court[court_name].items(), key=lambda x: -x[1])

    # A tie for first place has no single "preferred" court. A court that cites
    # only one target court has an unambiguous answer, so the tie check must
    # not index candidate[1] in that case.
    if len(candidate) > 1 and candidate[0][1] == candidate[1][1]:
        continue

    generated_data.append({"source_court_name": court_name, "target_court_name": candidate[0][0]})
    seen_name.add(court_name)

    if len(generated_data) == k:
        break

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

## TODO: we can add more constraints later, e.g., restricting to year xxx

### Inductive Reasoning (hard)

In [None]:
# download data
# Load the LegalBench citation-prediction test splits and collect every docket
# case_name, so the next cells keep only LegalBench rows whose case is in the graph.

from datasets import load_dataset

lp_cls_dataset = load_dataset("nguha/legalbench", 'citation_prediction_classification', split="test")
lp_open_dataset = load_dataset("nguha/legalbench", 'citation_prediction_open', split="test")

docket_ids_list = list(graph['docket_nodes'].keys())
case_name_set = {
    graph['docket_nodes'][docket_id]['features']['case_name']
    for docket_id in tqdm(docket_ids_list)
}

In [None]:
# citation classification
## question (hard): Is the given sentence supported by the given case? Sentence: {text}, case: {case_name}.
# Keep LegalBench classification rows whose cited case exists as a docket
# case_name in the graph, shuffle under a fixed seed, and take the first k.

random.seed(2038)

question = "Is the given sentence supported by the given case? Sentence: {text}, case: {case_name}."
answer = "{answer}"

exist_data = [
    (row['text'], row['citation'], row['answer'])
    for row in lp_cls_dataset
    if row['citation'] in case_name_set
]
random.shuffle(exist_data)
print(len(exist_data))

generated_data = [
    {"text": sentence, "case_name": cited_case, "answer": label}
    for sentence, cited_case, label in exist_data[:k]
]

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

In [None]:
# citation open
## question (hard): Find a case which can support this sentence: {text}.
# Keep LegalBench open-citation rows whose answer case exists as a docket
# case_name in the graph, shuffle under a fixed seed, and take the first k.

random.seed(2039)

question = "Find a case which can support this sentence: {text}."
answer = "{case_name}"

exist_data = [
    (row['text'], row['answer'])
    for row in lp_open_dataset
    if row['answer'] in case_name_set
]
random.shuffle(exist_data)
print(len(exist_data))

generated_data = [
    {"text": sentence, "case_name": supporting_case}
    for sentence, supporting_case in exist_data[:k]
]

assert len(generated_data) == k
all_generated_data[(question, answer)] = generated_data

In [None]:
## save
# Persist every (question_template, answer_template) -> samples entry next to
# the processed graph. The `with` block closes the file, so the pickle is
# fully flushed to disk before the cell finishes.
with open(os.path.join(data_dir, 'preprocess_samples.pkl'), 'wb') as samples_file:
    pickle.dump(all_generated_data, samples_file)

print('Saving file of #questions, ', len(all_generated_data))

In [None]:
## Graph schema used by the cells above:
## nodes: opinion, opinion_cluster, court, docket
## opinion features: plain_text
## opinion_cluster features: syllabus, judges, case_name, attorneys
## court: full_name, start_date, end_date, citation_string
## docket: pacer_case_id, case_name, case_name_full