In [1]:
import json
import os
import numpy as np
import re
import nltk
import torch
from transformers import BartTokenizer, BartForConditionalGeneration
from sentence_transformers import SentenceTransformer
import inspect

from pipelines import TruncateMiddle, UniformSampler
from utils import *

device = get_device()
device

'mps'

In [2]:
data_dir = "/Users/naman/Workspace/Data/UCCS-REU"

crs_files = os.listdir(crs_dir := f"{data_dir}/GovReport/crs")
gao_files = os.listdir(gao_dir := f"{data_dir}/GovReport/gao")

print(f"crs files: {len(crs_files)}, gao files: {len(gao_files)}")

crs_out = f"{data_dir}/GovReport/crs-processed"
gao_out = f"{data_dir}/GovReport/gao-processed"

preprocessor = TextPreprocessor()

crs files: 7238, gao files: 12228


In [3]:
tokenizer_dir = f"{data_dir}/Models/BART/tokenizer"
model_dir = f"{data_dir}/Models/BART/model"
checkpoint = "facebook/bart-large-cnn"

tokenizer = BartTokenizer.from_pretrained(tokenizer_dir)
model = BartForConditionalGeneration.from_pretrained(model_dir)

context_size, _ = max_lengths(model)
max_output_tokens = 500

context_size

1024

In [4]:
special_tokens = tokenizer.special_tokens_map.values()
postprocessor = TextPostprocessor(special_tokens)
special_tokens

dict_values(['<s>', '</s>', '<unk>', '</s>', '<pad>', '<s>', '<mask>'])

In [5]:
file = f"{crs_out}/{crs_files[0]}"

with open(file) as fp:
	data = json.load(fp)
count_words(data["text"]), count_words(data["summary"])

(8357, 479)

## gov-report

In [28]:
for file in crs_files:
	with open(f"{crs_dir}/{file}") as fp:
		data = json.load(fp)
	text = combine_subsections([data["reports"]])
	text = preprocessor.preprocess(text)
	summary = "\n".join(data["summary"])
	summary = preprocessor.preprocess(summary)
	with open(f"{crs_out}/{file}", "w") as fp:
		json.dump({
			"text": text,
			"summary": summary
		}, fp)

In [34]:
for file in gao_files:
	with open(f"{gao_dir}/{file}") as fp:
		data = json.load(fp)
	text = combine_subsections(data["report"])
	text = preprocessor.preprocess(text)
	print(data["highlight"])
	summary = "\n".join(data["highlight"])
	summary = preprocessor.preprocess(summary)
	with open(f"{gao_out}/{file}", "w") as fp:
		json.dump({
			"text": text,
			"summary": summary
		}, fp)

[]


## Scratch

In [None]:
sent_checkpoint = "sentence-transformers/all-MiniLM-L6-v2"
sent_save_dir = f"{data_dir}/Models/Sent-Transformer"

sent_model = SentenceTransformer(sent_save_dir)
sent_model

In [9]:
out = sent_model.encode([
	"hey bruh",
	"whats up huh?",
	data["text"]
])
out.shape

array([[-0.06765417, -0.03888808,  0.05449073, ...,  0.01187451,
         0.04905998,  0.01324423],
       [-0.11608477, -0.07559214,  0.05973152, ..., -0.01747406,
        -0.00257643,  0.05474273],
       [ 0.10994177, -0.04852243, -0.07865883, ...,  0.00304659,
         0.0611313 , -0.02873247]], dtype=float32)

In [6]:
summarizer = TruncateMiddle(
	preprocessor, postprocessor, tokenizer, model, context_size,
	max_output_tokens, .4
)

summarizer(data["text"])

['Glass-Steagall Act separated commercial banking from investment banking. Separation of commercial and investment banking can help insulate insured depositories from volatility in securities markets. The separation, by itself, does not address how investment banks are regulated within securities markets or how nonbanks can use securities activities to fund consumer and commercial debt.']

In [7]:
summarizer = UniformSampler(
	preprocessor, postprocessor, nltk.sent_tokenize, tokenizer, model,
	context_size, max_output_tokens
)

summarizer(data["text"])

['The law was passed to prevent a repeat of the financial crisis of the 1930s and 1940s. The law was designed to prevent banks from taking on too much of a risk. It was also intended to prevent the creation of a financial system that was too risky for banks to take on.']