In [1]:
from datasets import load_dataset
import numpy as np

# Load the pinned Multi-LexSum release and keep only the test cases that have a
# short reference summary; rows whose "summary/short" is None are dropped.
multi_lexsum = load_dataset("allenai/multi_lexsum", name="v20230518")
modified_dataset = multi_lexsum["test"].filter(lambda x: x["summary/short"] is not None)

In [2]:
import tiktoken
# import torch


# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Tokeniser used for the token-budget calculations in this notebook.
# "gpt-3.5-turbo" is the canonical model name; it resolves to the same
# encoding as the "gpt-3.5" shorthand without relying on that alias existing
# in the installed tiktoken version.
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")

# embedding_model = SentenceTransformer("../models/multi-qa-mpnet-base-dot-v1", device = device).half()

In [3]:
# Show the first reference summary. Index the row first so that only one
# example is read instead of materialising the whole "summary/short" column.
modified_dataset[0]["summary/short"]

"The plaintiffs filed a lawsuit on March 8, 2014, alleging that the City of Montgomery, Alabama, improperly imprisoned them for failing to pay traffic fines. They alleged that they did not have an ability to pay the fines due to their financial circumstances and that the city did not consider their ability to pay. On May 1, 2014, the District Court granted the plaintiffs motion for a preliminary injunction, preventing the city from collecting more money from traffic tickets of plaintiffs'. On October 31, 2014 the parties filed to dismiss the case pursuant to a settlement agreement, which included numerous changes to Municipal Court proceedings. The case is now closed."

In [4]:
# user_prompt = "Summarize concisely the following legal texts. Include as many relevant facts as possible. A fact is relevant if it mentions plaintiffs, counsel, type of action, filing date, name of the court, description of class, defendants, statutory basis, sought remedy, judges, consolidated class, whether it is a class action, date of decree, citations, duration of decrees, last action in case."
# Token budget: total context window and the number of tokens reserved for the
# model's reply when checking whether a prompt fits.
max_context_length = 16_384
max_output_length = 130

# System message sent with every request.
system_prompt = (
    "You are a legal expert. "
    "You must answer concisely and truthfully, including only information that is relevant to the conversation. "
    "Stay faithful to the original text and keep the exact wording as found in the text as closely as possible. "
    "Only include facts relevant to the text, without any filler words."
)

# Token cost of the system message, counted once and reused in the budget checks.
system_prompt_tokens = tokenizer.encode(system_prompt)
length_system_prompt = len(system_prompt_tokens)
print(length_system_prompt)

60


In [5]:
from openai import OpenAI
from env_utils import load_env_from_file
import json
import os

# Load environment settings before the API client is created.
# NOTE(review): `load_env_from_file` is a project helper not shown here —
# presumably it reads an env file under "." and exports OPENAI_API_KEY; confirm.
load_env_from_file(".")
# The client is built with no explicit arguments, so any credentials it needs
# must already be available from the environment at this point.
client = OpenAI()

In [60]:
# user_prompt_cod = """TEXT:<INPUT>{INPUT}</INPUT>
# Generate increasingly concise, fact-dense summaries of the text above. Repeat the following 2 steps 3 times:
# Step 1. Identify 1-3 relevant facts from the INPUT which are missing from the previously generated summary.
# Step 2. Write a new, denser summary of identical length which covers all facts from the previous summary that includes the relevant facts you found.
# A relevant fact is:
# * plaintiffs
# * counsel
# * taken actions
# * dates
# * name of court
# * defendants
# * statutory basis
# * sought remedy
# * judges
# * case consolidation
# * class action
# * date of decrees
# * duration of decrees
# * citations
# * last action in the case
# Guidelines:
# * the first summary must be 130 words long, containing general information about the text
# * rewrite the first summary and add more relevant facts
# * make space by removing any filler words or phrases that are uninformative
# * the summaries should be highly dense and concise
# * relevant facts can be added anywhere in the summary
# * never remove relevant facts from a previous summary. If it is not possible to make more space, add fewer new facts
# """

In [65]:
from tqdm import tqdm
import pickle

def check_prompt_length(prompt):
    """Tell whether `prompt` is too long for the model's context window.

    The check adds the prompt's token count to the tokens reserved for the
    reply (`max_output_length`) and for the system message
    (`length_system_prompt`).

    Returns:
        True when that total EXCEEDS `max_context_length` (the prompt does not
        fit), False otherwise.
    """
    prompt_tokens = len(tokenizer.encode(prompt))
    required_tokens = prompt_tokens + max_output_length + length_system_prompt
    return required_tokens > max_context_length

def from_extracted(path, test_size, current_length, sentences=2):
    """Build one source text per case from the pre-extracted sentence files.

    Reads the numerically named JSON files in `path` (e.g. ``0.json``) whose
    number is below `test_size`, in ascending numeric order. Each file is
    indexed as a list of documents, each document being a list of sentence
    strings; the first `sentences` sentences of every document are
    concatenated into a single text.

    If that text does not fit the token budget, the number of sentences kept
    per document is reduced one at a time until it fits. If even one sentence
    per document is too long, the text is cut at token level so the request
    cannot exceed the context window.

    Args:
        path: Directory containing the ``<number>.json`` files.
        test_size: Only files numbered strictly below this value are read.
        current_length: Tokens already reserved (prompt template, system
            prompt, expected output); the text may use at most
            ``max_context_length - current_length`` tokens.
        sentences: Maximum number of leading sentences taken per document.

    Returns:
        A list of strings, one per file read.
    """
    token_budget = max_context_length - current_length

    def join_leading(docs, limit):
        # Concatenate the first `limit` sentences of every document.
        return "".join("".join(doc_sentences[:limit]) for doc_sentences in docs)

    # Keep only entries whose stem is a number; stray entries such as
    # ".DS_Store" or ".ipynb_checkpoints" would otherwise crash int().
    numbered_files = []
    for file in os.listdir(path):
        try:
            number = int(file.split(".")[0])
        except ValueError:
            continue
        numbered_files.append((number, file))
    numbered_files.sort(key=lambda item: item[0])

    summs = []
    for number, file in numbered_files:
        if number >= test_size:
            # Files are sorted, so every remaining one is out of range too.
            break

        # `with` closes the handle; os.path.join works with or without a
        # trailing separator on `path`.
        with open(os.path.join(path, file), "r") as handle:
            docs = json.load(handle)

        doc = join_leading(docs, sentences)
        tokens = tokenizer.encode(doc)

        # Shrink sentence by sentence, re-measuring in TOKENS after every
        # step, until the text fits or one sentence per document is reached.
        sent_limit = sentences - 1
        while len(tokens) > token_budget and sent_limit > 0:
            doc = join_leading(docs, sent_limit)
            tokens = tokenizer.encode(doc)
            sent_limit -= 1

        # Last resort: hard-truncate the token sequence to the budget.
        if len(tokens) > token_budget:
            doc = tokenizer.decode(tokens[:max(token_budget, 0)])

        summs.append(doc)

    return summs

def completion_with_retry(**kwargs):
    """Create a chat completion, retrying transient API failures.

    All keyword arguments are forwarded unchanged to
    `client.chat.completions.create`. Rate-limit, connection/timeout and
    server-side errors are retried with exponential backoff; any other
    exception propagates immediately. After the final attempt the last
    error is re-raised unchanged, so callers see the same exception types as
    with a direct call.

    Returns:
        Whatever `client.chat.completions.create` returns.
    """
    # Local imports keep this cell self-contained.
    import time

    import openai

    max_attempts = 5
    delay_seconds = 1.0

    for attempt in range(1, max_attempts + 1):
        try:
            return client.chat.completions.create(**kwargs)
        except (openai.RateLimitError, openai.APIConnectionError, openai.InternalServerError):
            if attempt == max_attempts:
                raise
            time.sleep(delay_seconds)
            delay_seconds *= 2  # back off: 1s, 2s, 4s, 8s

user_prompt_summary_basic = "Summarize the text:<INPUT>{INPUT}</INPUT>\nSummary:"
# user_prompt_summary_detailed = "Imagine you're a legal scholar tasked with distilling the essence of a complex case law into a concise and compelling summary. Your summary must encompass a rich tapestry of legal elements, including the identities of the plaintiffs and defendants, the brilliant minds behind the legal arguments (counsel), the decisive actions taken, the chronological tapestry of dates, the distinguished court presiding over the matter, the statutory framework underpinning the case, the sought-after remedy, the esteemed judges imparting wisdom, any case consolidations adding layers of complexity, the potential for class action ramifications, the pivotal date of decrees shaping the outcome, the temporal span of decree effectiveness, the authoritative citations grounding the legal analysis, and the climactic last action in the case. Your summary should weave these elements together seamlessly, ensuring a vivid and comprehensive portrayal of the legal landscape: <INPUT>{INPUT}</INPUT>\nSummary:"
user_prompt_length = len(tokenizer.encode(user_prompt_summary_basic))
test_size = 75
extract_types = np.asarray(["random_selection", "first5last5", "random_selection_bert", "first5last5_bert"])
model = "gpt-3.5-turbo-1106"
prompt_type = "basic"

# Tokens that are not available to the extracted text: the prompt template,
# the system prompt and the room reserved for the model's answer.
reserved_length = user_prompt_length + length_system_prompt + max_output_length

# extract_types[1:] skips "random_selection"; only the last three strategies run.
for extract_sum_type in extract_types[1:]:
    # Reset per strategy: after the loop `responses` holds the answers of the
    # LAST strategy processed, in case order.
    responses = []
    basic_responses = []
    path = f"extracted_sums/extracted_sums_json_{extract_sum_type}/"

    # One output directory per model / prompt type / extraction strategy.
    # makedirs creates missing parents (including "answers") and is a no-op
    # when the directory already exists.
    answers_dir = f"answers/{model}/{prompt_type}/{extract_sum_type}"
    os.makedirs(answers_dir, exist_ok=True)

    extracted_summaries = from_extracted(path, test_size=test_size, sentences=2, current_length=reserved_length)

    for idx, summ in enumerate(tqdm(extracted_summaries[:test_size])):
        answer_path = f"{answers_dir}/{idx}.json"

        # Resume support: reuse an answer already saved for THIS strategy
        # instead of paying for another API call.
        if os.path.exists(answer_path):
            with open(answer_path, "r") as answer_file:
                responses.append(json.load(answer_file)[-1]["content"])
            continue

        basic_prompt = user_prompt_summary_basic.format(INPUT=summ)
        message_history = [
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": basic_prompt
            }
        ]
        chat_completion = completion_with_retry(
            messages=message_history,
            model=model,
            temperature=0,
            frequency_penalty=0,
            presence_penalty=0,
        )

        answer = chat_completion.choices[0].message.content
        responses.append(answer)

        # Store the full exchange (system, user, assistant) next to its index.
        with open(answer_path, "w") as answer_file:
            json.dump([*message_history, {"role": "assistant", "content": answer}], answer_file, indent=2)

100%|██████████| 75/75 [02:10<00:00,  1.74s/it]
100%|██████████| 75/75 [01:54<00:00,  1.52s/it]
100%|██████████| 75/75 [01:43<00:00,  1.38s/it]


In [None]:
# from tqdm import tqdm
# import pickle

# def from_extracted(path, test_size):
#     summs = []
#     files = os.listdir(path)
#     files = sorted(files, key = lambda x: int(x.split(".")[0]))
#     for file in files:
#         if int(file.split(".")[0]) < test_size:
#             docs = json.load(open(path + file, "r"))
#             doc = "".join(["".join(doc_sentences[:2]) for doc_sentences in docs])
#             summs.append(doc)

#     return summs

# ## CoT summarization
# test_size = 616
# extract_types = np.asarray(["random_selection", "first5last5", "random_selection_bert", "first5last5_bert"])
# prompt_type = "1shot_cot_summarization"
# for extract_sum_type in extract_types[[0,2]]:
#     responses = []
#     basic_responses = []
#     path = f"extracted_sums/extracted_sums_json_{extract_sum_type}/"

#     user_prompt_summary_basic = "Summarize the text below in 130 words. Let's think about it carefully, considering the importance of each fact in the final summary."\
#           + "\n\nSOURCE:{{\n{SOURCE}\n}}\n\nSUMMARY:{{\n{SUMMARY}\n}}\n\nSOURCE:{{\n{SOURCE_Q}\n}}\n\nSUMMARY:"
#     user_prompt_revision = "Let's have another look through the summary and source text. Include more important facts, namely: plaintiffs, counsel, taken actions, dates, name of court, defendants, statutory basis, sought remedy, judges, case consolidation, class action, date of decrees, duration of decrees, citations, last action in the case, but within 130 words."
#     extracted_summaries = from_extracted(path, test_size=test_size)

#     for summ in tqdm(extracted_summaries[1:]):
#         basic_prompt = user_prompt_summary_basic.format(SOURCE=extracted_summaries[0], SUMMARY=modified_dataset["summary/short"][0], SOURCE_Q=summ)
#         # original summary
#         completion = client.chat.completions.create(
#             model = "gpt-3.5-turbo-16k",
#             messages=[
#                 {"role": "system", "content": system_prompt},
#                 {"role": "user", "content": basic_prompt}
#             ],
#             frequency_penalty=0,
#             presence_penalty=0,
#             top_p=0.2,
#             max_tokens=250,
#             stop=["SOURCE"]
#         )

#         basic_summary = completion.choices[0].message.content

#         # # get elements from text
#         completion = client.chat.completions.create(
#             model = "gpt-3.5-turbo-1106",
#             messages=[
#                 {"role": "system", "content": system_prompt},
#                 {"role": "user", "content": basic_prompt},
#                 {"role": "assistant", "content": basic_summary},
#                 {"role": "user", "content": user_prompt_revision}
#             ],
#             frequency_penalty=0,
#             presence_penalty=0,
#             top_p=0.2,
#             max_tokens=250,
#             stop=["SOURCE"]
#         )

#         responses.append(completion.choices[0].message.content)
#         basic_responses.append(basic_summary)

#     pickle.dump(responses, open(f"{test_size}_predicted_text_{prompt_type}_{extract_sum_type}.pickle", "wb"))
#     pickle.dump(basic_responses, open(f"{test_size}_predicted_basic_text_{prompt_type}_{extract_sum_type}.pickle", "wb"))

In [None]:
import evaluate

# ['led: rouge1: 45.89', 'led: rouge2: 23.00', 'led: rougeL: 31.17', 'led: rougeLsum: 32.01']
# ['primera: rouge1: 42.87', 'primera: rouge2: 20.79', 'primera: rougeL: 29.31', 'primera: rougeLsum: 29.79']

# Score the `responses` list against the first `test_size` reference short
# summaries with ROUGE (stemming enabled) and print the resulting scores.
rouge_scoring = evaluate.load("rouge")
reference_summaries = modified_dataset[:test_size]["summary/short"]
rouge_scores = rouge_scoring.compute(
    predictions=responses,
    references=reference_summaries,
    use_stemmer=True,
)
print(rouge_scores)

In [None]:
# simple-random: {'rouge1': 0.3669443053395881, 'rouge2': 0.12299475909651572, 'rougeL': 0.2148706258231231, 'rougeLsum': 0.22091825483890215}
# simple-5/5: {'rouge1': 0.3600777579472476, 'rouge2': 0.11723337466920167, 'rougeL': 0.21617088924165473, 'rougeLsum': 0.22142449133012626}
# simple-randombert: {'rouge1': 0.3542897797369077, 'rouge2': 0.10311088709231235, 'rougeL': 0.2069714451531819, 'rougeLsum': 0.20992633649991677}
# simple-5/5bert: {'rouge1': 0.35138148214931236, 'rouge2': 0.104968022054537, 'rougeL': 0.20538624503075653, 'rougeLsum': 0.20893048554674404}

