In [1]:
import torch
import os
import evaluate
rouge_score = evaluate.load("rouge")

gpu_id = 2
device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
print(device)

cuda:2


In [3]:
from datasets import load_dataset
multi_lexsum = load_dataset("allenai/multi_lexsum", name="v20220616")
modified_dataset = multi_lexsum["test"].filter(lambda x: x["summary/short"] != None)
# modified_dataset = modified_dataset.map(lambda x: {"sources": [a.encode('utf-8').decode('utf-8').strip().split(' ') for a in x["sources"]]})

In [7]:
import re

def split_into_paras(sources):
    all_docs = []
    for source in sources:
        sents = []
        for doc in source:
            text = re.split("\n", doc)
            text = [sentence for sentence in text if sentence != ""]
            sents.append(text)
        all_docs.extend(sents)
    return all_docs

In [2]:
from lexrank import STOPWORDS, LexRank
import os
import dill

if "lexrank.pickle" not in os.listdir():
    training_docs = split_into_paras(modified_dataset["sources"][:100])
    lex_rank = LexRank(training_docs, stopwords = STOPWORDS["en"], show_progress = True)
    dill.dump(lex_rank, open("lexrank.pickle", "wb"))
else:
    lex_rank = dill.load(open("lexrank.pickle", "rb"))

In [6]:
# t = "test \x9b"
# print(t.encode("utf-8").decode())

In [8]:
# law_prompt = f'''\nYou are given a number of summaries from legal documents. Create one summary that encompasses all of them:\n'''
# law_prompt = f'''Task: Create one summary from the following chunks of legal text. It must be aroung 130 words long.\n'''
law_prompt = f'''You are a legal expert, knowledgeable in all legal cases, their structure, and what they contain. You are tasked with creating one summary from multiple legal texts. Your summary must contain around 130 words.'''

# our_system_prompt = "\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n" # Please do NOT change this
our_system_prompt = ""
# system_prompt = f"<s>[INST] <<SYS>>{our_system_prompt}{law_prompt}<</SYS>>\n\n [/INST]"
system_prompt = f"""<s>[INST] <<SYS>>{our_system_prompt}
{law_prompt}
<</SYS>>
"""
system_prompt_size = len(tokenizer.encode(system_prompt))
print(system_prompt)
print(system_prompt_size)

<s>[INST] <<SYS>>
You are a legal expert, knowledgeable in all legal cases, their structure, and what they contain. You are tasked with creating one summary from multiple legal texts. Your summary must contain around 130 words.
<</SYS>>

62


In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("../models/law-chat", local_files_only = True, device_map=device, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("../models/law-chat", local_files_only = True, use_fast=False, device_map=device)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


In [5]:
# # # NOTE:
# # # If you want to apply your own system prompt, please integrate it into the instruction part following our system prompt like this:
# # your_system_prompt = "Please, answer this question faithfully."

# inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(model.device)
# outputs = model.generate(input_ids=inputs, max_new_tokens = 4096)[0]

# answer_start = int(inputs.shape[-1])
# pred = tokenizer.decode(outputs[answer_start:], skip_special_tokens=True)

# print(f'### User Input:\n{user_input}\n\n### Assistant Output:\n{pred}')
# print(len(pred.split(" ")))

'''
<s> <<SYS>> [INST] {system prompt} <</SYS>> {user_input} [/INST]
'''

In [16]:
def prompt_from_sources(sources):
    # truncate each chunk so max output is < 4096
    # use embeddings from dense models instead of idf
    # segment with high granularity - each sentence within the paragraph
    token_size = 0
    prompt = ""
    summary_too_long = False
    summary_size = 5
    while not summary_too_long:
        summ = []
        for test_doc in split_into_paras([sources]):
            summary = lex_rank.get_summary(test_doc, threshold = 0.3, summary_size=summary_size)
            token_size += len(tokenizer.encode(summary))
            if token_size + 250 + system_prompt_size > 4096: # too much information, the model would not be able to generate the output
                summary_too_long = True
                break
            summ.append(summary)

        if summary_too_long:
            summary_size -= 1
            print("Too long")
            if summary_size < 1:
                break 
            summary_too_long = False
        else:
            break


    user_input = "\n\nTask: Summarize the following legal texsts into one summary. Include as much relevant information as possible, and carefully think about the texts:\n\n" + "\n\n".join(["\n".join(summary) for summary in summ]) + "\n\n{ANSWER} [/INST]"
    prompt = system_prompt + user_input

    # print(len(tokenizer.encode(prompt)))
    # print(prompt)

    return prompt

print(prompt_from_sources(modified_dataset["sources"][100]))

<s>[INST] <<SYS>>
You are a legal expert, knowledgeable in all legal cases, their structure, and what they contain. You are tasked with creating one summary from multiple legal texts. Your summary must contain around 130 words.
<</SYS>>


Task: Summarize the following legal texsts into one summary. Include as much relevant information as possible, and carefully think about the texts:

118. The aforesaid pattern or practice of age discrimination by Best Buy was willful. 119. The Plaintiffs are each 40 years of age or older, and are within the class of persons protected against age discrimination by the ADEA, 29 U.S.c. § 621, et seq., and the MHRA, Minn. Stat. § 363A.02, et seq. 120. The Plaintiffs are among the former employees of Best Buy who have been adversely affected by the aforesaid pattern or practice of age discrimination. Verne A. Hall 121. Hall was employed with Best Buy as a Software Engineer. 122. Hall was well qualified for his Software Engineer position and performed his d

In [19]:
from tqdm import tqdm
import numpy as np

predicted_summaries = []
slice_idx = slice(100,102)
for sources in tqdm(modified_dataset["sources"][slice_idx]):
    prompt = prompt_from_sources(sources)
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=4095).input_ids.to(model.device)
    outputs = model.generate(input_ids=inputs, max_new_tokens = np.min([250, 4096 - len(inputs[0])]))[0] # 

    answer_start = int(inputs.shape[-1])
    pred = tokenizer.decode(outputs[answer_start:], skip_special_tokens=True)

    # print(f'### User Input:\n{user_input}\n\n### Assistant Output:\n{pred}')
    print(len(pred.split(" ")))

    predicted_summaries.append(pred)
    torch.cuda.empty_cache()

 50%|█████     | 1/2 [00:21<00:21, 21.45s/it]

139


100%|██████████| 2/2 [00:24<00:00, 12.11s/it]

17





In [20]:
print(predicted_summaries[0])

118. The aforesaid pattern or practice of age discrimination by Best Buy was willful. 119. The Plaintiffs are each 40 years of age or older, and are within the class of persons protected against age discrimination by the ADEA, 29 U.S.c. § 621, et seq., and the MHRA, Minn. Stat. § 363A.02, et seq. 120. The Plaintiffs are among the former employees of Best Buy who have been adversely affected by the aforesaid pattern or practice of age discrimination. Verne A Hall 121. Hall was employed with Best Buy as a Software Engineer. 122. Hall was well qualified for his Software Engineer position and performed his duties in a proper, satisfactory and competent manner. 123. Best Buy terminated Hall's employment on or about October 14, 2003, when Hall, who was born on July 23, 1943, was 60 years old.


In [18]:
print(rouge_score.compute(predictions = predicted_summaries, references = modified_dataset["summary/short"][slice_idx]))

IndexError: list index out of range