In [None]:
from datasets import load_dataset
import numpy as np

# Multi-LexSum, pinned to the "v20230518" release. Only the test split is used, and
# cases whose short summary is missing (None) are dropped.
multi_lexsum = load_dataset("allenai/multi_lexsum", name="v20230518")
modified_dataset = multi_lexsum["test"].filter(lambda x: x["summary/short"] is not None)

In [None]:
import tiktoken
import torch

import spacy

from tqdm import tqdm

# To get the tokeniser corresponding to a specific model in the OpenAI API:
# NOTE(review): `tokenizer`, `device` and `nlp` are not used by the cells shown here;
# presumably they are consumed by later cells — confirm before removing.

# tiktoken encoding registered for the "gpt-3.5" OpenAI model name.
tokenizer = tiktoken.encoding_for_model("gpt-3.5")

# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Small English spaCy pipeline; the character limit is raised to 3M so long documents
# can be processed by a single `nlp(...)` call.
nlp = spacy.load("en_core_web_sm")
nlp.max_length = 3_000_000


In [3]:
from summarizer import Summarizer

# Alternative backbone kept for reference (currently disabled):
# from summarizer.sbert import SBertSummarizer
# model_summ = SBertSummarizer("paraphrase-MiniLM-L6-v2")

# Extractive summarizer on top of DistilBERT. `hidden=[-1, -2]` with `hidden_concat=True`
# presumably builds sentence embeddings from the last two hidden layers concatenated
# (per the bert-extractive-summarizer docs — confirm).
# NOTE(review): gpu_id is hard-coded to 0 even though `device` above falls back to CPU.
model_summ = Summarizer(
    "distilbert-base-uncased",
    hidden=[-1, -2],
    hidden_concat=True,
    gpu_id=0,
)

In [4]:
import warnings

# Suppress every Python warning for the rest of the kernel session.
warnings.simplefilter("ignore")

def get_extractive_summary(doc, limit_sentences = 5):
    """Extractively summarize ``doc`` with the module-level ``model_summ``.

    Args:
        doc: text to summarize.
        limit_sentences: number of sentences requested from the summarizer
            (forwarded as ``num_sentences``).

    Returns:
        Whatever ``model_summ`` returns for ``return_as_list=True`` — presumably the
        list of selected sentences (confirm against bert-extractive-summarizer).
    """
    summarizer_kwargs = {
        # presumably: do not force the first sentence into the summary — confirm
        "use_first": False,
        "return_as_list": True,
        "num_sentences": limit_sentences,
    }
    return model_summ(doc, **summarizer_kwargs)

In [5]:
from collections import defaultdict

np.random.seed(42)
# build in layers, from smallest number of doc types to largest
def select_docs(docket, docket_metadata, limit_docket_docs = 10):
    doc_types = defaultdict(list)
    for doc, doc_type in zip(docket, docket_metadata["doc_type"]):
        doc_types[doc_type].append(doc)


    # first pass, add at least 1 doc type 
    limit_counter = 0
    docs = []
    aux = doc_types.copy()
    for doc_type, documents in doc_types.items():
        if len(documents) == 1:
            docs.append(documents[0])
            aux.pop(doc_type)
        else:
            random_idx = np.random.randint(0, len(documents))
            docs.append(documents[random_idx])
            documents.pop(random_idx)
        limit_counter += 1

    while limit_counter < limit_docket_docs and aux.keys():
        # compute softmax prob weights to perform weighted sampling, by choosing least present documents preponderentely
        prob_weights = [1/len(documents) for documents in aux.values()]
        prob_weights = np.asarray(prob_weights)/np.sum(prob_weights)
        random_key = np.random.choice(list(aux.keys()), p = prob_weights)

        # uniform sampling across documents from the chosen key
        random_idx = np.random.randint(0, len(aux[random_key]))
        docs.append(aux[random_key][random_idx])

        # remove doc so it won't be present in future sampling
        aux[random_key].pop(random_idx)

        # if there are no more documents to this key, remove the key entirely
        if len(aux[random_key]) == 0:
            aux.pop(random_key)

        limit_counter += 1

    return docs

In [6]:
import json
import os

# Summarize up to `limit_docket_docs` selected documents per docket and write one JSON
# file per docket (one `get_extractive_summary` result per selected document).
# Already-written dockets are skipped, so the cell can be resumed after an interruption.
limit_docket = len(modified_dataset)
limit_docket_docs = 10
iterator = zip(modified_dataset["sources"][:limit_docket], modified_dataset["sources_metadata"][:limit_docket])
path = "extracted_sums/extracted_sums_json_random_selection_bert"

# Base seed for document selection. Each docket reseeds the global RNG with
# SELECTION_SEED + docket_id, so its selection no longer depends on how many dockets
# were processed or skipped before it (previously a resumed run picked different
# documents than an uninterrupted one). Selections therefore differ from files written
# under the old single-stream scheme.
SELECTION_SEED = 42

# Create the output directory on a fresh checkout instead of failing in os.listdir.
os.makedirs(path, exist_ok = True)
# List finished dockets once rather than on every iteration.
already_done = set(os.listdir(path))

for docket_id, (docket, docket_metadata) in enumerate(tqdm(iterator, total = limit_docket)):
    out_name = f"{docket_id}.json"
    if out_name in already_done:
        continue

    np.random.seed(SELECTION_SEED + docket_id)
    documents = select_docs(docket, docket_metadata, limit_docket_docs = limit_docket_docs)

    summaries = []
    for doc in documents:
        summary_aux = get_extractive_summary(doc, limit_sentences = 5)
        summaries.append(summary_aux)

    # Write to a temporary file, then atomically rename it. An interrupted write can no
    # longer leave a truncated <id>.json that a resumed run would treat as finished,
    # and the `with` block guarantees the handle is flushed and closed.
    out_file = os.path.join(path, out_name)
    tmp_file = out_file + ".tmp"
    with open(tmp_file, "w") as fh:
        json.dump(summaries, fh, indent = 2)
    os.replace(tmp_file, out_file)

100%|██████████| 616/616 [36:16<00:00,  3.53s/it]
