In [None]:
import os
from sys import getsizeof
from time import sleep
import json
import re
import inspect
from warnings import filterwarnings

import numpy as np
import nltk
import matplotlib.pyplot as plt
import torch
from transformers import (
	BartTokenizer, BartForConditionalGeneration,
	T5Tokenizer, T5ForConditionalGeneration,
	PegasusForConditionalGeneration, PegasusTokenizerFast,
	GPT2TokenizerFast
)
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from dotenv import load_dotenv

from utils.helpers import *
from utils.encoders import *
from utils.pipelines import *
from utils.trainer_utils import *
from utils.evaluator_utils import *

In [None]:
# Session-wide settings for the rest of the notebook.
filterwarnings("ignore")  # silence library warnings from here on
inf = float("inf")  # open upper bound used for word-count filters below
# Compute device chosen by the project helper; assign device = "cpu" to force CPU.
device = get_device()
# Load variables from a local .env file (presumably API credentials for the GPT pipelines).
load_dotenv()

In [None]:
# Root of the data/model tree. Set DATA_DIR in the environment (or in the .env
# file loaded above) to point elsewhere, e.g. a local laptop checkout; it falls
# back to the server path used so far.
data_dir = os.environ.get("DATA_DIR", "/home/nchibbar/Data")

crs_dir = f"{data_dir}/GovReport/crs"
gao_dir = f"{data_dir}/GovReport/gao"
# os.listdir order is filesystem-dependent. Sort the listings so later cells
# that pick "the first N matching files" are reproducible across machines.
crs_files = sorted(os.listdir(crs_dir))
gao_files = sorted(os.listdir(gao_dir))

print(f"crs files: {len(crs_files)}, gao files: {len(gao_files)}")

out_dir = f"{data_dir}/GovReport/processed"

In [None]:
# Sentence transformer
# Automatically loads into gpu if available
sent_dir = f"{data_dir}/Models/Sent-Transformer"
sent_encoder = SentenceTransformer(sent_dir).to("cpu")

# BART
bart_dir = f"{data_dir}/Models/BART"
bart_tokenizer = BartTokenizer.from_pretrained(bart_dir)
bart_model = BartForConditionalGeneration.from_pretrained(bart_dir)
bart_context_size = bart_model.config.max_position_embeddings

# T5
t5_dir = f"{data_dir}/Models/T5"
t5_tokenizer = T5Tokenizer.from_pretrained(t5_dir)
t5_model = T5ForConditionalGeneration.from_pretrained(t5_dir)
t5_context_size = t5_model.config.n_positions

# Pegasus
pegasus_dir = f"{data_dir}/Models/PEGASUS"
pegasus_tokenizer = PegasusTokenizerFast.from_pretrained(pegasus_dir)
pegasus_model = PegasusForConditionalGeneration.from_pretrained(pegasus_dir)
pegasus_context_size = pegasus_model.config.max_position_embeddings

# GPT 3.5 turbo tokenizer
gpt_dir = f"{data_dir}/Models/GPT-3.5-turbo-tokenizer"
gpt_tokenizer = GPT2TokenizerFast.from_pretrained(gpt_dir)
gpt_model = "gpt-3.5-turbo"
gpt_context_size = 4096

bart_context_size, t5_context_size, pegasus_context_size

In [None]:
# Generated summaries are used as-is (no post-processing step).
postprocessor = None
# Shared cleaner: applied to reference summaries below and handed to every encoder.
preprocessor = TextProcessor(preprocessing=True)

## GovReport

In [None]:
def combine_subsections(sections):
	"""Flatten a GovReport section tree into plain text.

	Each section is a dict with optional keys ``section_title`` (str),
	``paragraphs`` (list of str) and ``subsections`` (list of sections, same
	shape). A section becomes "Section <title>:" (when titled) followed by its
	paragraphs, then its subsections are flattened recursively. All non-empty
	pieces are joined with blank lines ("\\n\\n").

	Args:
		sections: iterable of section dicts.

	Returns:
		The combined text; "" when no section contributes any text.
	"""
	parts = []
	for sec in sections:
		sec_text = "\n\n".join(sec.get("paragraphs") or [])
		title = sec.get("section_title")
		if title:
			sec_text = f"Section {title}:\n\n{sec_text}"
		# Skip empty pieces so they don't leave dangling "\n\n" separators.
		if sec_text:
			parts.append(sec_text)
		subsections = sec.get("subsections")
		if subsections:
			sub_text = combine_subsections(subsections)
			if sub_text:
				parts.append(sub_text)
	# A single join avoids quadratic repeated string concatenation on long reports.
	return "\n\n".join(parts)

In [None]:
# Convert each raw CRS report into {"text", "summary"} JSON under out_dir.
os.makedirs(out_dir, exist_ok=True)  # the writes below fail if the folder is missing

num_crs_files = len(crs_files)
for i, file in enumerate(crs_files):
	full_path = os.path.join(crs_dir, file)
	with open(full_path, encoding="utf-8") as fp:
		data = json.load(fp)
	clear_stdout()
	print(f"{num_crs_files - i} files left", end="")
	# data["reports"] is a single root section; wrap it in a list so
	# combine_subsections can walk it like a list of sections.
	text = f"{data['title']}\n\n" + combine_subsections([data["reports"]])
	# NOTE(review): unlike the GAO cell, the report text is stored without
	# preprocessor.process (encoders receive `preprocessor` themselves) — confirm
	# this asymmetry is intended.
	summary = preprocessor.process(" ".join(data["summary"]))
	with open(os.path.join(out_dir, file), "w", encoding="utf-8") as fp:
		json.dump({
			"text": text,
			"summary": summary
		}, fp)

In [None]:
# Convert each raw GAO report into {"text", "summary"} JSON under out_dir.
os.makedirs(out_dir, exist_ok=True)  # the writes below fail if the folder is missing

num_gao_files = len(gao_files)
for i, file in enumerate(gao_files):
	# Keep `file` as the bare filename: it is reused for the output path below,
	# so the input path goes into a separate variable.
	full_path = os.path.join(gao_dir, file)
	with open(full_path, encoding="utf-8") as fp:
		data = json.load(fp)
	clear_stdout()
	print(f"{num_gao_files - i} files left", end="")
	# data["report"] is already a list of top-level sections.
	text = preprocessor.process(combine_subsections(data["report"]))
	# NOTE(review): assumes data["highlight"] is a list of strings; if it holds
	# section dicts instead, "\n".join raises TypeError — confirm the schema.
	summary = preprocessor.process("\n".join(data["highlight"]))
	with open(os.path.join(out_dir, file), "w", encoding="utf-8") as fp:
		json.dump({
			"text": text,
			"summary": summary
		}, fp)

## BigPatent

In [None]:
bigpatent_dir = f"{data_dir}/BigPatent/train/a"

bigpatent_files = os.listdir(bigpatent_dir)

# Word count of every patent "description" in this split. Each file holds one
# JSON object per line.
word_counts = []
for file in bigpatent_files:
	with open(os.path.join(bigpatent_dir, file), encoding="utf-8") as fp:
		for line in fp:  # stream lines instead of loading the whole file
			if not line.strip():
				continue  # json.loads rejects blank lines
			data = json.loads(line)
			word_counts.append(count_words(data["description"]))

# Square-root rule for the bin count; at least one bin so an empty list still plots.
bins = max(1, int(len(word_counts) ** .5))
fig, ax = plt.subplots()
ax.hist(word_counts, bins=bins)
ax.set(
	title="BigPatent train/a: description length",
	xlabel="Words per description",
	ylabel="Number of patents",
)
plt.show()

In [None]:
np.asarray(word_counts).max()  # longest BigPatent description, in words

## Rough work (exploratory experiments)

In [None]:
# Collect up to `max_texts` processed CRS reports whose word count lies
# strictly between `min_words` and `max_words`.
min_words = 70_000
max_words = inf
max_texts = 10
texts, summaries = [], []
for file in crs_files:
	# Checking before reading handles max_texts == 0 and avoids a separate counter.
	if len(texts) >= max_texts:
		break
	with open(os.path.join(out_dir, file), encoding="utf-8") as fp:
		data = json.load(fp)
	if min_words < count_words(data["text"]) < max_words:
		texts.append(data["text"])
		summaries.append(data["summary"])
num_texts = len(texts)

len(texts)

In [None]:
# Segmenter built on NLTK sentence splitting; the word threshold presumably sets
# the minimum segment size — see TextSegmenter for the exact rule.
segment_min_words = 20
text_segmenter = TextSegmenter(
	nltk.sent_tokenize,
	segment_min_words,
)

In [None]:
min_tokens_frac = .5
min_summary_tokens = 400
head_size = .5
threshold = .7
boost = .03
seed = 69
system_prompt = "You will be given some segments of a very long document. Your task is to summarize the entire document as a whole by extracting key information and ideas from the segments. Generate a detailed, concise, and coherent summary in 500 words. Do not refer to the document in the summary in any way."

sent_encoder.to(device)


def _make_encoders(tokenizer, context_size):
	"""Build the four context-reduction encoders for one model.

	Every size argument is derived from this model's own ``context_size``, so no
	model inherits another model's limits.

	Returns:
		[TruncateMiddle, UniformSampler, SegmentSampler, RemoveRedundancy], in
		that order; later cells index into this list by position.
	"""
	# Minimum-token argument for the sampling encoders: a fraction of the context size.
	min_tokens = min_tokens_frac * context_size
	return [
		TruncateMiddle(
			tokenizer, context_size, head_size, preprocessor, True
		),
		UniformSampler(
			tokenizer, min_tokens, context_size,
			text_segmenter, preprocessor, True, seed
		),
		SegmentSampler(
			tokenizer, min_tokens, context_size,
			text_segmenter, sent_encoder, preprocessor, True, threshold, boost, seed
		),
		RemoveRedundancy(
			tokenizer, min_tokens, context_size,
			text_segmenter, sent_encoder, preprocessor, True, threshold, seed
		),
	]


bart_encoders = _make_encoders(bart_tokenizer, bart_context_size)
t5_encoders = _make_encoders(t5_tokenizer, t5_context_size)
gpt_encoders = _make_encoders(gpt_tokenizer, gpt_context_size)

bart_pipelines = [
	SummarizationPipeline(
		bart_model, enc, postprocessor, min_summary_tokens,
		bart_context_size, device
	) for enc in bart_encoders
]
t5_pipelines = [
	SummarizationPipeline(
		t5_model, enc, postprocessor, min_summary_tokens,
		t5_context_size, device
	) for enc in t5_encoders
]
gpt_pipelines = [
	OpenAIPipeline(
		gpt_model, enc, system_prompt=system_prompt
	) for enc in gpt_encoders
]
# NOTE(review): gpt_pipelines are built but not included here — presumably to
# avoid API calls by default; append them to evaluate GPT as well.
pipelines = bart_pipelines + t5_pipelines

In [None]:
# Token counts kept per text by BART's UniformSampler (index 1) vs. SegmentSampler (index 2).
encodings1, encodings2 = (
	bart_encoders[idx](texts, return_batch=False) for idx in (1, 2)
)

token_lengths1 = list(map(len, encodings1))
token_lengths2 = list(map(len, encodings2))

avg_tokens1, avg_tokens2 = np.mean(token_lengths1), np.mean(token_lengths2)

avg_tokens1, avg_tokens2

In [None]:
# Run BART's RemoveRedundancy encoder on each selected text, reporting progress;
# the returned encodings are not kept.
for i in range(len(texts)):
	text = texts[i]
	print("Processing text", i + 1)
	bart_encoders[3](text, return_batch=False)

In [None]:
# Preprocess the 4th selected text and segment it; `text` now holds the segment list.
text = text_segmenter(preprocessor(texts[3]))

list(map(count_words, text))

In [None]:
# Word count of the 2nd selected text after preprocessing.
text = preprocessor(texts[1])
count_words(text)

In [None]:
# Extra connective/modal words to drop on top of NLTK's English stop-word list.
my_stop_words = [
	"also", "however", "therefore", "thus", "hence", "moreover",
	"must", "may", "might", "could", "would", "shall", "need",
	"needs", "given", "since", "though",
]
# Read the NLTK corpus once instead of once per word in the loop below.
nltk_stop_words = nltk.corpus.stopwords.words("english")

# Show any extras that NLTK already covers.
for word in my_stop_words:
	if word in nltk_stop_words:
		print(word)

stop_words = nltk_stop_words + my_stop_words
# Also add capitalized forms (e.g. sentence-initial "However").
stop_words += [
	word.capitalize()
	for word in stop_words
	if not word.istitle()
]
# Drop duplicates, keeping first-seen order.
stop_words = list(dict.fromkeys(stop_words))

len(stop_words)

In [None]:
# Cleaner used only for keyword extraction: keeps words/numbers and (presumably,
# per `remove_nums`) strips the numbers.
keywords_preprocessor = TextProcessor(only_words_nums=True, remove_nums=True)

text_keywords = get_keywords(
	text, stop_words=stop_words, preprocessor=keywords_preprocessor
)

text_keywords

In [None]:
# Embed the keyword list as one pseudo-sentence and score each segment of `text`
# against it by dot product.
# NOTE(review): this equals cosine similarity only if the encoder returns
# unit-norm embeddings — confirm, or pass normalize_embeddings=True to encode().
keywords_vec = sent_encoder.encode(" ".join(text_keywords))

segments = text_segmenter(text)
if segments:
	# One batched encode call (2-D array, one row per segment) instead of one call per segment.
	segment_vecs = sent_encoder.encode(segments)
	segment_similarities = list(segment_vecs @ keywords_vec)
else:
	segment_similarities = []

segment_similarities

In [None]:
# Softmax over scaled similarities gives per-segment sampling probabilities;
# a larger scale concentrates the mass on the most similar segments.
similarity_scale = 10
probs = torch.nn.functional.softmax(
	torch.tensor(segment_similarities, dtype=torch.float64) * similarity_scale,
	dim=0,  # explicit: implicit-dim softmax is deprecated
).numpy()
# np.random.choice requires p to sum to 1 within a tight tolerance; compute in
# float64 and renormalize so rounding cannot trip that check.
probs = probs / probs.sum()

probs

In [None]:
fig, ax = plt.subplots()
ax.plot(probs)
ax.set(
	title="Segment sampling probabilities",
	xlabel="Segment index",
	ylabel="Probability",
)
plt.show()

In [None]:
# Draw two distinct segments weighted by `probs`, seeded with the notebook's
# `seed` so the draw is reproducible.
rng = np.random.default_rng(seed)
selected = rng.choice(segments, size=2, p=probs, replace=False)
selected

In [None]:
# Segments scoring below zero against the keyword embedding.
[
	seg
	for seg, sim in zip(segments, segment_similarities)
	if sim < 0
]

In [None]:
# Scratch check: strip phone-number-like runs of digits, "+" and "-".
re.compile(r"(\b|\+)[\d+-]+\b").sub("", "+1234-5678 +90")