In [6]:
from datasets import load_dataset
import numpy as np

multi_lexsum = load_dataset("allenai/multi_lexsum", name="v20220616")
# Keep only test cases that have a short reference summary (the field is None otherwise).
modified_dataset = multi_lexsum["test"].filter(lambda x: x["summary/short"] is not None)

In [18]:
import tiktoken
import torch
from summarizer import Summarizer
from tqdm import tqdm

# from summarizer.sbert import SBertSummarizer

# Tokeniser matching the target OpenAI model; used below to count prompt tokens.
tokenizer = tiktoken.encoding_for_model("gpt-3.5")

# NOTE(review): `device` is not referenced by the cells shown here; the
# summariser selects its GPU through `gpu_id` instead.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Extractive BERT summariser using hidden layers -1 and -2 (hidden_concat=True).
model_summ = Summarizer(
    "distilbert-base-uncased",
    hidden=[-1, -2],
    hidden_concat=True,
    gpu_id=0,
)
# Alternative sentence-transformer backend:
# model_summ = SBertSummarizer("paraphrase-MiniLM-L6-v2")

In [8]:
# user_prompt = """Write a concise summary of the following legal texts. Include as many relevant facts as possible. A fact is relevant if it mentions plaintiffs, counsel, type of action, filing date, name of the court, description of class, defendants, statutory basis, remedy sought, judges, consolidated class, whether it is a class action, date of decree, citations, duration of decrees, last action in case."""
user_prompt = "Write a concise summary of the following legal texts."

max_context_length = 16384
max_output_length = 130
length_user_prompt = len(tokenizer.encode(user_prompt))
print(length_user_prompt)

10


In [19]:
from sklearn.metrics.pairwise import cosine_similarity
from lexrank.algorithms.summarizer import degree_centrality_scores

import warnings
warnings.filterwarnings('ignore') 

# def embed_sentences(sentences):
#     return embedding_model.encode(sentences, convert_to_numpy = True)

def prompt_from_sources(sources, summary_size=5):
    """Build a zero-shot summarisation prompt from one docket's source documents.

    Each source is reduced to an extractive summary of ``summary_size``
    sentences with ``model_summ``. Summaries are kept in input order until the
    next one would leave fewer than ``max_output_length`` tokens of the
    ``max_context_length`` window for the model's reply. The remaining sources
    are skipped and are not summarised.

    Args:
        sources: iterable of source-document strings for one docket
            (presumably a Multi-LexSum ``"sources"`` entry -- confirm).
        summary_size: number of sentences ``model_summ`` extracts per source.

    Returns:
        ``user_prompt``, then the newline-joined summaries wrapped in
        ``{`` ... ``}``, then a trailing blank line.
    """
    # TODO: use embeddings from dense models instead of idf, and segment with
    # high granularity (each sentence within the paragraph).

    def assemble(parts):
        return user_prompt + "\n{" + "\n".join(parts) + "}\n\n"

    def fits(n_tokens):
        # A prompt is usable only if the reply still fits in the context window.
        return n_tokens + max_output_length <= max_context_length

    # Start from the cost of the instruction plus the "{...}" wrapper, not just
    # the instruction, so the wrapper cannot push the prompt over budget.
    token_size = len(tokenizer.encode(assemble([])))
    separator_tokens = len(tokenizer.encode("\n"))

    extracted_sentences = []
    truncated = False
    for source in sources:
        summary = model_summ(source, num_sentences=summary_size)
        # Every summary after the first is preceded by a "\n" joiner.
        cost = len(tokenizer.encode(summary)) + (separator_tokens if extracted_sentences else 0)
        if not fits(token_size + cost):
            truncated = True
            break
        token_size += cost
        extracted_sentences.append(summary)

    prompt = assemble(extracted_sentences)
    prompt_tokens = len(tokenizer.encode(prompt))
    # BPE token counts are not strictly additive across concatenation, so
    # confirm the assembled prompt fits and drop trailing summaries if not.
    while extracted_sentences and not fits(prompt_tokens):
        extracted_sentences.pop()
        truncated = True
        prompt = assemble(extracted_sentences)
        prompt_tokens = len(tokenizer.encode(prompt))

    if truncated:
        print(f"Number of summaries from documents: {len(extracted_sentences)} || Number of tokens: {prompt_tokens}")

    return prompt

def prompt_from_sources_loop(dockets, prompt_type, **prompt_args):
    """Build a prompt for every docket and report any that do not fit the model.

    Args:
        dockets: iterable of per-docket inputs; each is passed as the first
            argument to ``prompt_type``.
        prompt_type: prompt builder such as ``prompt_from_sources``, called as
            ``prompt_type(docket, **prompt_args)``.
        **prompt_args: extra keyword arguments forwarded to ``prompt_type``.

    Returns:
        List of the generated prompts, one per docket, in input order.
        Oversized prompts are included, and their indices are printed.
    """
    # The prompt must leave max_output_length tokens of the context window free
    # for the completion; use the configured limits rather than a magic number.
    prompt_limit = max_context_length - max_output_length

    prompts = []
    for idx, docket in enumerate(tqdm(dockets)):
        prompt = prompt_type(docket, **prompt_args)
        if len(tokenizer.encode(prompt)) > prompt_limit:
            print(f"Oh no {idx}")
        prompts.append(prompt)
    return prompts

def prompt_from_sources_fewshot(sources, example, summary_size=10, max_prompt_tokens=None):
    """Build a one-shot summarisation prompt for one docket.

    The prompt contains the following parts, in order:

    * the demonstration input ``example[0]``;
    * its summary ``example[1]``;
    * a "Here is your input:" / "INPUT:" header;
    * the extractive summaries of ``sources``, separated by ``---`` lines;
    * a trailing ``SUMMARY: `` for the model to complete.

    Args:
        sources: iterable of source-document strings for the target docket.
        example: two-item sequence of (demonstration input, demonstration
            summary). Presumably ``prompt_from_sources`` output for another
            docket plus that docket's reference summary; confirm with the
            caller.
        summary_size: number of sentences ``model_summ`` extracts per source.
        max_prompt_tokens: optional token budget for the whole prompt. When
            given, sources are added in order only while the assembled prompt
            stays within it. ``None`` (the default) keeps every source.

    Returns:
        The prompt string.
    """
    # TODO: use embeddings from dense models instead of idf, and segment with
    # high granularity (each sentence within the paragraph).
    header = example[0] + "\n" + example[1] + "\n\n" + "Here is your input:" + "\nINPUT:\n"
    footer = "\n---\nSUMMARY: "
    separator = "\n---\n"

    def assemble(parts):
        return header + separator.join(parts) + footer

    budgeted = max_prompt_tokens is not None
    if budgeted:
        token_size = len(tokenizer.encode(assemble([])))
        separator_tokens = len(tokenizer.encode(separator))

    summaries = []
    for source in sources:
        summary = model_summ(source, num_sentences=summary_size)
        if budgeted:
            # Every summary after the first is preceded by a separator.
            cost = len(tokenizer.encode(summary)) + (separator_tokens if summaries else 0)
            if token_size + cost > max_prompt_tokens:
                break
            token_size += cost
        summaries.append(summary)

    prompt = assemble(summaries)
    if budgeted:
        # BPE token counts are not strictly additive across concatenation, so
        # confirm the assembled prompt fits and drop trailing summaries if not.
        while summaries and len(tokenizer.encode(prompt)) > max_prompt_tokens:
            summaries.pop()
            prompt = assemble(summaries)

    return prompt


# print(prompt_from_sources_fewshot(modified_dataset["sources"][1], example = (prompt_from_sources(modified_dataset["sources"][0]), modified_dataset["summary/short"][0])))
# The run takes hours (see the progress bar below), so keep the loop's return
# value in a variable rather than letting the cell echo or drop it.
zero_shot_prompts = prompt_from_sources_loop(modified_dataset["sources"], prompt_from_sources)

  6%|▌         | 34/616 [09:23<2:40:51, 16.58s/it]


KeyboardInterrupt: 