In [None]:
import os
from sys import getsizeof
from time import sleep
import pickle
import json
import re
import inspect
from warnings import filterwarnings

import numpy as np
import nltk
import matplotlib.pyplot as plt
import torch
from transformers import (
	BartTokenizer, BartForConditionalGeneration,
	T5Tokenizer, T5ForConditionalGeneration,
	PegasusForConditionalGeneration, PegasusTokenizerFast,
	GPT2TokenizerFast
)
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv

from configs import *
from utils.helpers import *
from utils.encoders import *
from utils.pipelines import *
from utils.trainer_utils import *
from utils.evaluator_utils import *

def plot_histogram(data, bins=None):
	"""Plot a histogram of ``data`` and display it immediately.

	Args:
		data: Sized sequence of numeric values (e.g. per-document word counts).
		bins: Number of bins. When ``None`` (default) the square-root rule
			``int(sqrt(len(data)))`` is used, clamped to at least 1.
	"""
	if bins is None:
		# Square-root rule; clamp to 1 so an empty sequence doesn't produce
		# bins=0, which numpy/matplotlib reject as a non-positive bin count.
		bins = max(1, int(len(data) ** .5))
	plt.hist(data, bins=bins)
	plt.show()

# Silence every warning for the rest of the session to keep cell output short.
filterwarnings(action="ignore")
# Torch device chosen by the project helper; presumably GPU vs CPU based on
# GPU_USAGE_TOLERANCE — TODO confirm in utils.helpers.
device = get_device(GPU_USAGE_TOLERANCE)
# Load variables from a .env file into os.environ (presumably API credentials
# for the OpenAI pipeline).
load_dotenv()

In [None]:
# Which summarizer to load; the model-loading cell handles "bart", "t5",
# "pegasus" and "gpt".
model_name = "pegasus"

sent_dir = f"{MODELS_DIR}/sent-transformer"
model_dir = f"{MODELS_DIR}/{model_name.lower()}"

govreport_dir = f"{BASE_DIR}/GovReport/processed"
bigpatent_dir = f"{BASE_DIR}/BigPatent/processed"
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so
# any "first N files" selection is reproducible across machines and runs.
govreport_files = sorted(os.listdir(govreport_dir))
bigpatent_files = sorted(os.listdir(bigpatent_dir))

len(govreport_files), len(bigpatent_files)

In [None]:
# Sentence transformer
# Automatically loads into gpu if available
sent_encoder = SentenceTransformer(sent_dir, device=device)

match model_name:

	case "bart":
		tokenizer = BartTokenizer.from_pretrained(model_dir)
		model = BartForConditionalGeneration.from_pretrained(model_dir)
		context_size = model.config.max_position_embeddings

	case "t5":
		tokenizer = T5Tokenizer.from_pretrained(model_dir)
		model = T5ForConditionalGeneration.from_pretrained(model_dir)
		context_size = model.config.n_positions

	case "pegasus":
		tokenizer = PegasusTokenizerFast.from_pretrained(model_dir)
		model = PegasusForConditionalGeneration.from_pretrained(model_dir)
		context_size = model.config.max_position_embeddings

	case "gpt":
		tokenizer = GPT2TokenizerFast.from_pretrained(model_dir)
		model = "gpt-3.5-turbo"
		context_size = 4096

context_size

In [None]:
# No summary postprocessing (presumably the pipelines skip the step when None).
postprocessor = None
# Input-text preprocessing shared by every encoder below.
preprocessor = TextProcessor(preprocessing=True)

## BigPatent — document length (word count) distribution

In [None]:
# Word count of every document in every processed BigPatent file.
word_counts = []
for file in bigpatent_files:
	file_path = f"{bigpatent_dir}/{file}"
	# JSON is UTF-8 by spec; don't rely on the platform default encoding
	# (e.g. cp1252 on Windows), which fails on non-ASCII text.
	with open(file_path, encoding="utf-8") as fp:
		data = json.load(fp)
	word_counts.extend(count_words(text) for text in data["texts"])

plot_histogram(word_counts)

In [None]:
# (longest document, mean length, number of documents), all in words.
(
	max(word_counts),
	np.mean(word_counts),
	len(word_counts),
)

In [None]:
# Number of BigPatent documents longer than LONG_DOC_WORDS words.
LONG_DOC_WORDS = 40_000
# Summing booleans counts the matches without building an intermediate list.
sum(count > LONG_DOC_WORDS for count in word_counts)

## Rough work (exploratory scratch cells)

In [None]:
# Collect GovReport (text, summary) pairs whose text word count lies strictly
# between MIN_WORDS and MAX_WORDS, stopping once MAX_TEXTS pairs are found.
texts, summaries = [], []
num_texts = 0
# Iterate in sorted order so the selected subset is reproducible
# (os.listdir order is filesystem-dependent).
for file in sorted(govreport_files):
	file_path = f"{govreport_dir}/{file}"
	# JSON is UTF-8 by spec; avoid the platform default encoding.
	with open(file_path, encoding="utf-8") as fp:
		data = json.load(fp)
	if MIN_WORDS < count_words(data["text"]) < MAX_WORDS:
		texts.append(data["text"])
		summaries.append(data["summary"])
		num_texts += 1
	if num_texts == MAX_TEXTS:
		break

num_texts

In [None]:
# Segmenter built on nltk sentence splitting; SEGMENT_MIN_WORDS is presumably
# the minimum word count per segment — TODO confirm in TextSegmenter.
SEGMENT_MIN_WORDS = 20
text_segmenter = TextSegmenter(nltk.sent_tokenize, SEGMENT_MIN_WORDS)

# Preprocessor for keyword extraction (per the flag names: keep only
# words/numbers, then drop numbers).
keywords_preprocessor = TextProcessor(only_words_nums=True, remove_nums=True)

stop_words = get_stop_words(extra_stop_words=EXTRA_STOP_WORDS)
len(stop_words)

In [None]:
# One encoder per context-selection strategy, each wrapped in its own
# summarization pipeline below. Construction order matches the original.
encoders = [
	# Third positional arg 1 vs HEAD_SIZE below — presumably the head size.
	TruncateMiddle(tokenizer, context_size, 1, preprocessor),
	TruncateMiddle(tokenizer, context_size, HEAD_SIZE, preprocessor, True),
	UniformSampler(
		tokenizer, MIN_TOKEN_FRAC * context_size, context_size,
		text_segmenter, preprocessor, True, SEED,
	),
	SegmentSampler(
		tokenizer, MIN_TOKEN_FRAC * context_size, context_size,
		text_segmenter, sent_encoder, preprocessor, THRESHOLD, PROB_BOOST, SEED,
	),
	RemoveRedundancy(
		tokenizer, MIN_TOKEN_FRAC * context_size, context_size,
		text_segmenter, sent_encoder, preprocessor, THRESHOLD, SEED,
	),
	KeywordScorer(
		tokenizer, context_size, text_segmenter, sent_encoder,
		preprocessor, NUM_KEYWORDS, keywords_preprocessor, stop_words,
	),
]

# "gpt" runs through the OpenAI pipeline with a system prompt; the local
# Hugging Face models use SummarizationPipeline with generation settings.
if model_name == "gpt":
	pipelines = [
		OpenAIPipeline(model, enc, postprocessor, SYSTEM_PROMPT)
		for enc in encoders
	]
else:
	pipelines = [
		SummarizationPipeline(
			model, enc, postprocessor, MIN_SUMMARY_TOKENS,
			context_size, device, TEMPERATURE, REPETITION_PENALTY, TOP_P,
		)
		for enc in encoders
	]

In [None]:
# For each text: extract `num_keywords` keywords, embed them as one string,
# and count the segments whose similarity to that embedding exceeds
# `threshold`.
processed_texts = preprocessor(texts)
threshold = .5
num_keywords = 20
num_segments_found = []
for text in processed_texts:
	keywords = " ".join(
		get_keywords(text, num_keywords, stop_words, keywords_preprocessor)
	)
	# Normalize so the dot product below is a cosine similarity in [-1, 1].
	# Otherwise a fixed 0.5 threshold depends on embedding magnitudes.
	# This is a no-op for models that already end in a Normalize layer.
	keyword_emb = sent_encoder.encode(keywords, normalize_embeddings=True)
	segments = text_segmenter(text)
	segment_embs = sent_encoder.encode(segments, normalize_embeddings=True)
	# One score per segment: (num_segments, dim) @ (dim,) -> (num_segments,).
	# A distinct name avoids clobbering the `scores` used by later cells.
	segment_scores = segment_embs @ keyword_emb
	num_segments_found.append(int((segment_scores > threshold).sum()))

In [None]:
# Matching-segment counts per text, in ascending order.
np.sort(np.asarray(num_segments_found))

In [None]:
# Load saved Pegasus / GovReport evaluation results.
# NOTE(review): pickle.load can execute arbitrary code; only load files you
# produced yourself.
with open(f"{BASE_DIR}/pegasus-govreport.pkl", "rb") as fp:
	results = pickle.load(fp)
scores = results["scores"]
sort1, sort2, sort3 = (results[key] for key in ("sort1", "sort2", "sort3"))
gen_summaries = results["gen_summaries"]
# Presumably scores[0] is one metric across texts and sort1 an index
# ordering over texts — TODO confirm against the script that wrote the pickle.
scores[0][sort1]

In [None]:
# Inspect the entry at position `ind` of the `sort1` ordering: keep its
# source text and print its generated summary. Requires `results` to still
# hold the pickle dict loaded in the previous cell.
ind = 0
doc_idx = sort1[ind]
problem_text = results["texts"][doc_idx]
print(gen_summaries[doc_idx])

In [None]:
# Encoder timing results (BART on BigPatent). Use a dedicated name so the
# pickled evaluation `results` loaded earlier is not overwritten. Otherwise,
# re-running the inspection cell after this one raises KeyError on "texts".
with open(f"{BASE_DIR}/bart-bigpatent-times.json", encoding="utf-8") as fp:
	timing_results = json.load(fp)
# Skip the first encoder's timing; the bar chart below labels the remaining
# entries.
times = np.array(timing_results["encoder_times"])[1:]
times

In [None]:
# Bar chart of per-encoder timings, one bar per context-selection method.
encoder_labels = [
	"Truncate\nMiddle", "Document\nSkimming",
	"Skimming w/\npost-sampling\nremoval",
	"Skimming\nw/ pre-\nsampling\nremoval", "Summarization\nw/ Keyword\nExtraction"
]
# Fail with a clear message instead of a shape-mismatch error from bar().
assert len(encoder_labels) == len(times), (
	f"{len(encoder_labels)} labels for {len(times)} timings"
)
fig, ax = plt.subplots()
ax.bar(encoder_labels, times, color="green")
# Units are whatever the timing JSON recorded; not stated in the data.
ax.set_ylabel("Encoder time")
ax.set_title("Encoder time per method")
# show() also suppresses the BarContainer repr as cell output.
plt.show()

In [None]:
# Scratch check of `matrix @ vector`: each row of `a` is dotted with `b`
# (here giving row sums). This is the same shape rule as scoring many segment
# embeddings against one keyword embedding.
a = np.arange(1, 7).reshape(3, 2)
b = np.ones(2, dtype=int)

a @ b