In [None]:
import ast
import json
import os

import chromadb
from tqdm import tqdm

from chroma import Chroma
from embeddings import OpenAIEmbeddings
# Supply the key through the environment (e.g. `export OPENAI_API_KEY=...`) instead of
# hardcoding it here. setdefault keeps an already-exported key rather than blanking it.
os.environ.setdefault('OPENAI_API_KEY', '')
import pickle
import pandas as pd

In [None]:
# Embedding model handed to the vector store as its embedding function.
embd = OpenAIEmbeddings()
# On-disk Chroma client rooted at ./lawllm; the wrapper is given the same directory.
client = chromadb.PersistentClient(path="lawllm")
db = Chroma(
    client=client,
    embedding_function=embd,
    persist_directory='lawllm',
)

In [None]:
def get_text(data):
    """Flatten a (possibly nested) text structure into a single string.

    - str  -> returned unchanged
    - list -> each element flattened, joined with three spaces + newline
    - dict -> its 'text' entry flattened ('' when the key is absent)
    - None -> ''
    - any other value -> str(value)

    The fallbacks guarantee a str is always returned, so callers that feed the
    result into str.join never hit a TypeError on an unexpected element.
    """
    if isinstance(data, str):
        return data
    if isinstance(data, list):
        return '   \n'.join(get_text(item) for item in data)
    if isinstance(data, dict):
        return get_text(data.get('text', ''))
    return '' if data is None else str(data)

In [None]:
# Load every *.json case file in test_folder.
#   test_data  : filename -> parsed case dict
#   test_cases : list of the parsed cases in sorted filename order (os.listdir order is
#                arbitrary, so sorting keeps downstream results reproducible)
# Unparseable files are reported and skipped, so test_cases never contains raw filenames.
test_data = dict()
for fname in sorted(f for f in os.listdir('test_folder') if f.endswith('.json')):
    path = 'test_folder/' + fname
    # JSON files are expected to be UTF-8 (RFC 8259); don't rely on the platform default.
    with open(path, 'r', encoding='utf-8') as file:
        try:
            test_data[fname] = json.load(file)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            print('error in file: ' + path + f' ({e})')
test_cases = list(test_data.values())
for fname, case in test_data.items():
    assert isinstance(case, dict) and 'casebody' in case, f"{fname}: missing 'casebody'"
    assert 'data' in case['casebody'], f"{fname}: missing 'casebody.data'"


In [None]:
# Load every *.json case file in train_folder.
#   train_data  : filename -> parsed case dict
#   train_cases : list of the parsed cases in sorted filename order (os.listdir order is
#                 arbitrary, so sorting keeps downstream results reproducible)
# Unparseable files are reported and skipped, so train_cases never contains raw filenames.
train_data = dict()
for fname in sorted(f for f in os.listdir('train_folder') if f.endswith('.json')):
    path = 'train_folder/' + fname
    # JSON files are expected to be UTF-8 (RFC 8259); don't rely on the platform default.
    with open(path, 'r', encoding='utf-8') as file:
        try:
            train_data[fname] = json.load(file)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            print('error in file: ' + path + f' ({e})')
train_cases = list(train_data.values())
for fname, case in train_data.items():
    assert isinstance(case, dict) and 'casebody' in case, f"{fname}: missing 'casebody'"
    assert 'data' in case['casebody'], f"{fname}: missing 'casebody.data'"


In [None]:
# Build one text document per training case, save a text -> case mapping, and index the
# documents in the vector store with the case's short name as metadata.
# NOTE(review): relies on the element-level get_text(data) from the earlier cell. A later
# cell redefines get_text to take a whole case, so re-run the earlier definition first if
# re-executing this cell out of order.
data = []
metadata = []
for case in train_cases:
    metadata.append(case)
    parts = []
    for k, v in case['casebody']['data'].items():
        if isinstance(v, str):
            parts.append(f"{k}: {v}\n")
        elif isinstance(v, list):
            tmp = "   \n".join(map(get_text, v))
            parts.append(f"{k}: \n{tmp}\n")
        else:
            print(f"Unknown type: {type(v)}")
    data.append("".join(parts))
# Search hits return only text, so map text back to the full case dict.
# Identical texts collapse to the last case seen.
HASH = dict(zip(data, metadata))
with open('train_cases_content2metadata.pkl', 'wb') as fh:
    pickle.dump(HASH, fh)
metadata = [{'name': case['name_abbreviation']} for case in metadata]
db.add_texts(data, metadata)

In [None]:
# Build one text document per test case, save a text -> case mapping, and index the
# documents in the vector store with the case's short name as metadata.
# NOTE(review): relies on the element-level get_text(data) from the earlier cell. A later
# cell redefines get_text to take a whole case, so re-run the earlier definition first if
# re-executing this cell out of order.
data = []
metadata = []
for case in test_cases:
    metadata.append(case)
    parts = []
    for k, v in case['casebody']['data'].items():
        if isinstance(v, str):
            parts.append(f"{k}: {v}\n")
        elif isinstance(v, list):
            tmp = "   \n".join(map(get_text, v))
            parts.append(f"{k}: \n{tmp}\n")
        else:
            print(f"Unknown type: {type(v)}")
    data.append("".join(parts))
# HASH is rebound here to the test-only mapping. The train mapping survives only in its
# pickle file, so code that resolves arbitrary store hits must merge both saved files.
HASH = dict(zip(data, metadata))
with open('test_cases_content2metadata.pkl', 'wb') as fh:
    pickle.dump(HASH, fh)
metadata = [{'name': case['name_abbreviation']} for case in metadata]
db.add_texts(data, metadata)

In [None]:
def get_text(i):
    """Render a whole case as the document text used for indexing and lookup.

    NOTE: this redefines the element-level ``get_text`` from an earlier cell; from here on
    ``get_text`` takes a case dict. Its output is used as a key into the text -> case
    mapping, so it must reproduce the indexing cells' strings exactly.

    Args:
        i: a case dict whose ``i['casebody']['data']`` maps section name -> str or list.

    Returns:
        The concatenation of "name: value\\n" for each str section and
        "name: \\n<flattened list>\\n" for each list section. Sections of any other type
        are reported and skipped.
    """
    def _get_text(data):
        # Flatten nested str / list / {'text': ...} structures; always return a str so
        # the join below cannot fail on an unexpected element.
        if isinstance(data, str):
            return data
        if isinstance(data, list):
            return '   \n'.join(_get_text(item) for item in data)
        if isinstance(data, dict):
            return _get_text(data.get('text', ''))
        return '' if data is None else str(data)

    parts = []
    for k, v in i['casebody']['data'].items():
        if isinstance(v, str):
            parts.append(f"{k}: {v}\n")
        elif isinstance(v, list):
            tmp = "   \n".join(map(_get_text, v))
            parts.append(f"{k}: \n{tmp}\n")
        else:
            print(f"Unknown type: {type(v)}")
    return "".join(parts)


In [None]:
# Summary table decoded as latin-1; the next cell reads its columns '1' and '2'.
df = pd.read_csv(filepath_or_buffer='capstone_1000.csv', encoding='latin1')

In [None]:
# Build id2summary: case id -> {'data': parsed summary, 'content': the row's column-'2' text}.
# Column '1' is expected to hold the repr of a Python dict with an 'id' key; empty cells
# come back from pandas as NaN (a float) and are skipped.
summary = df['1'].to_list()
content = df['2'].to_list()
id2summary = dict()

n_bad = 0
for raw, text in zip(summary, content):
    if isinstance(raw, float):
        continue
    try:
        # literal_eval, never eval: the CSV is external data and eval would execute it.
        parsed = ast.literal_eval(raw)
        id2summary[parsed['id']] = {'data': parsed, 'content': text}
    except (ValueError, TypeError, SyntaxError, KeyError, MemoryError, RecursionError):
        # Best-effort: skip malformed rows, but count them instead of hiding them.
        n_bad += 1
print(f"parsed {len(id2summary)} summaries; skipped {n_bad} malformed rows")


In [None]:
# Build the retrieval evaluation set. For every test case that has a summary, fetch its
# nearest neighbours from the vector store. Neighbours that also have a summary become
# the candidate answers ("choice"), in similarity order.
#
# The store holds BOTH train and test documents, but HASH at this point maps only test
# texts (the test-indexing cell rebinds it). Neighbour lookups therefore use the union of
# the two text -> case mappings saved by the indexing cells.
K_NEIGHBOURS = 110  # neighbours requested per query

content2meta = {}
for path in ('train_cases_content2metadata.pkl', 'test_cases_content2metadata.pkl'):
    # These files are written by this notebook's own earlier cells, so unpickling is trusted.
    with open(path, 'rb') as fh:
        content2meta.update(pickle.load(fh))

ans = []
for case in tqdm(test_cases, desc='retrieving'):
    text = get_text(case)
    case_meta = content2meta[text]
    case_id = case_meta['id']
    if case_id not in id2summary:
        continue
    entry = dict(case_text=id2summary[case_id]['content'], case_data=case_meta, choice=[])
    # Exclude the query case itself (matched by id rather than assuming it ranks first).
    # Also exclude repeated ids, which can appear if the persistent store holds duplicates
    # from re-running the indexing cells (NOTE(review): confirm add_texts semantics).
    seen = {case_id}
    for doc in db.similarity_search(text, k=K_NEIGHBOURS):
        hit_meta = content2meta.get(doc.page_content)
        if hit_meta is None:  # document not produced by the saved indexing mappings
            continue
        hit_id = hit_meta['id']
        if hit_id in seen or hit_id not in id2summary:
            continue
        seen.add(hit_id)
        entry['choice'].append(id2summary[hit_id])
    ans.append(entry)

with open('case_test.pkl', 'wb') as fh:
    pickle.dump(ans, fh)

In [None]:
prompt = f"""
You are a legal expert who specializes in comparing user-supplied legal cases to a list of candidate legal cases , which includes titles and content. Your main function is to identify and output the title of the most similar case from the list based on the description provided.
You should only output the case title and not any other information.
"""

In [None]:
# Render one multiple-choice prompt per retrieved case, with at most `num_choices`
# candidates each, and print it for inspection.
num_choices = 10
for entry in ans:
    pieces = [prompt, "\nThis is the case:\n", entry['case_text'], "\n\n", "Here are the choices:\n"]
    for cand in entry['choice'][:num_choices]:
        pieces += [cand['data']['name'], " : ", cand['content'], "\n\n"]
    pieces.append("Please choose the most similar one to the case above. You should only output the case title and not any other information.")
    print("".join(pieces))
