In [1]:
from datasets import load_dataset
import numpy as np

multi_lexsum = load_dataset("allenai/multi_lexsum", name="v20230518")
modified_dataset = multi_lexsum["test"].filter(lambda x: x["summary/short"] != None)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import tiktoken
import torch


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# To get the tokeniser corresponding to a specific model in the OpenAI API:
tokenizer = tiktoken.encoding_for_model("gpt-3.5")

In [3]:
modified_dataset["summary/short"][0]

"The plaintiffs filed a lawsuit on March 8, 2014, alleging that the City of Montgomery, Alabama, improperly imprisoned them for failing to pay traffic fines. They alleged that they did not have an ability to pay the fines due to their financial circumstances and that the city did not consider their ability to pay. On May 1, 2014, the District Court granted the plaintiffs motion for a preliminary injunction, preventing the city from collecting more money from traffic tickets of plaintiffs'. On October 31, 2014 the parties filed to dismiss the case pursuant to a settlement agreement, which included numerous changes to Municipal Court proceedings. The case is now closed."

In [4]:
# user_prompt = "Summarize concisely the following legal texts. Include as many relevant facts as possible. A fact is relevant if it mentions plaintiffs, counsel, type of action, filling date, name of the court, description of class, defendants, statuatory basis, rought remedy, judges, consolidated class, whether it is a class action, date of decree, citations, duration of decrees, last action in case."
system_prompt = "You are a legal expert. You must answer concisely and truthfully, including only information that is relevant to the conversation. Stay faithful to the original text and keep the exact wording as found in the text as closely as possible. Only include facts relevant to the text, without any filler words."

max_context_length = 18000
max_output_length = 130
length_system_prompt = len(tokenizer.encode(system_prompt))
print(length_system_prompt)

60


In [5]:
import json
import os

from env_utils import load_env_from_file
from groq import Groq

load_env_from_file(".")

client = Groq(
    api_key=os.environ.get("GROQ_API_KEY")
)

In [6]:
user_prompt_summary_basic = "Summarize the text:<INPUT>{INPUT}</INPUT>\nSummary:"
# user_prompt_summary_detailed = "Imagine you're a legal scholar tasked with distilling the essence of a complex case law into a concise and compelling summary. Your summary must encompass a rich tapestry of legal elements, including the identities of the plaintiffs and defendants, the brilliant minds behind the legal arguments (counsel), the decisive actions taken, the chronological tapestry of dates, the distinguished court presiding over the matter, the statutory framework underpinning the case, the sought-after remedy, the esteemed judges imparting wisdom, any case consolidations adding layers of complexity, the potential for class action ramifications, the pivotal date of decrees shaping the outcome, the temporal span of decree effectiveness, the authoritative citations grounding the legal analysis, and the climactic last action in the case. Your summary should weave these elements together seamlessly, ensuring a vivid and comprehensive portrayal of the legal landscape: <INPUT>{INPUT}</INPUT>\nSummary:"


In [11]:
from tqdm import tqdm
import pickle
import time
import logging
from tenacity import retry, stop_after_attempt, wait_random_exponential

def from_extracted(path, test_size, current_length, sentences=2):
    summs = []
    files = os.listdir(path)
    files = sorted(files, key = lambda x: int(x.split(".")[0]))
    for file in files:
        if int(file.split(".")[0]) < test_size:
            docs = json.load(open(path + file, "r"))

            doc = "".join(["".join(doc_sentences[:sentences]) for doc_sentences in docs])
            if len(tokenizer.encode(doc)) + current_length > max_context_length:
                doc = ""
                sent_limit = sentences-1
                while (doc == "" and len(doc) + current_length < max_context_length) and sent_limit > 0:
                    doc = "".join(["".join(doc_sentences[:sent_limit]) for doc_sentences in docs])
                    sent_limit -= 1

            summs.append(doc)

    return summs

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
def completion_with_retry(file_path, **kwargs):
    # file_path: model/prompt_type/extract_sum_type/idx
    model, prompt_type, extract_sum_type, idx = file_path.split("/")[1:]
    success = False
    try:
        chat_completion = client.chat.completions.create(**kwargs)
        json.dump([*message_history, {"role": "assistant", "content": chat_completion.choices[0].message.content}], open(f"{file_path}.json", "w"), indent=2)  
        success = True  
    except Exception as exc:
        logging.error(f"{type(exc).__name__} ({exc.args}) - {prompt_type} - {extract_sum_type} - {idx}")
        success = False  

    return success

test_size = 616
extract_types = np.asarray(["random_selection", "first5last5", "random_selection_bert", "first5last5_bert"])
model = "mixtral-8x7b-32768"
# model = "gemma-7b-it"
user_prompt_length = len(tokenizer.encode(user_prompt_summary_basic))
prompt_type = "basic"
logger = logging.getLogger(model)
logging.basicConfig(filename=f"log_{model}.log", encoding="utf-8", level=logging.WARNING)
if model not in os.listdir("answers"):
    os.mkdir(f"answers/{model}")

if prompt_type not in os.listdir(f"answers/{model}"):
    os.mkdir(f"answers/{model}/{prompt_type}")

for extract_sum_type in extract_types:
    responses = []
    basic_responses = []
    path = f"extracted_sums/extracted_sums_json_{extract_sum_type}/"
    if extract_sum_type not in os.listdir(f"answers/{model}/{prompt_type}/"):
        os.mkdir(f"answers/{model}/{prompt_type}/{extract_sum_type}/")

    extracted_summaries = from_extracted(path, test_size=test_size, sentences=2, current_length=max_output_length+user_prompt_length+length_system_prompt)
    errors = 0
    iterator = tqdm(extracted_summaries[:test_size], desc=f"Errors={errors}")
    for idx, summ in enumerate(iterator):
        if f"{idx}.json" in os.listdir(f"answers/{model}/{prompt_type}/{extract_sum_type}/"):
            continue
        basic_prompt = user_prompt_summary_basic.format(INPUT=summ)
        message_history = [
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {
                    "role": "user",
                    "content": basic_prompt
                }
            ]
        chat_completion_success = completion_with_retry(
            file_path=f"answers/{model}/{prompt_type}/{extract_sum_type}/{idx}",
            messages=message_history,
            model=model,
            temperature=0
        )

        if not chat_completion_success:
            errors += 1
            iterator.set_description(f"Errors={errors}")

Errors=0: 100%|██████████| 616/616 [00:00<00:00, 3623.49it/s]
Errors=0: 100%|██████████| 616/616 [00:00<00:00, 3583.27it/s]
Errors=0: 100%|██████████| 616/616 [00:00<00:00, 3640.24it/s]
Errors=0: 100%|██████████| 616/616 [00:00<00:00, 3616.84it/s]


In [None]:
import evaluate

# ['led: rouge1: 45.89', 'led: rouge2: 23.00', 'led: rougeL: 31.17', 'led: rougeLsum: 32.01']
# ['primera: rouge1: 42.87', 'primera: rouge2: 20.79', 'primera: rougeL: 29.31', 'primera: rougeLsum: 29.79']

rouge_scoring = evaluate.load("rouge")
print(rouge_scoring.compute(predictions=responses, references=modified_dataset[:test_size]["summary/short"], use_stemmer = True))

{'rouge1': 0.38965949863751514, 'rouge2': 0.12448661981111782, 'rougeL': 0.21450728498797272, 'rougeLsum': 0.21450562157362763}


In [None]:
# simple-random: {'rouge1': 0.3669443053395881, 'rouge2': 0.12299475909651572, 'rougeL': 0.2148706258231231, 'rougeLsum': 0.22091825483890215}
# simple-5/5: {'rouge1': 0.3600777579472476, 'rouge2': 0.11723337466920167, 'rougeL': 0.21617088924165473, 'rougeLsum': 0.22142449133012626}
# simple-randombert: {'rouge1': 0.3542897797369077, 'rouge2': 0.10311088709231235, 'rougeL': 0.2069714451531819, 'rougeLsum': 0.20992633649991677}
# simple-5/5bert: {'rouge1': 0.35138148214931236, 'rouge2': 0.104968022054537, 'rougeL': 0.20538624503075653, 'rougeLsum': 0.20893048554674404}

