In [None]:
import torch
import transformers
import huggingface_hub
import datasets

In [None]:
# Sanity-check the runtime before any model is loaded: report whether PyTorch
# can see CUDA and how many CUDA devices are visible.
print("Is CUDA available for PyTorch:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())

In [None]:
# Interactive Hugging Face Hub login for this notebook session.
# NOTE(review): presumably required for the push_to_hub() call in the last cell.
huggingface_hub.notebook_login()

In [None]:
# Load the source dataset. The preprocessing function defined below reads the
# "heading" and "paragraph" columns of every row.
wiki_wtp: datasets.DatasetDict = datasets.load_dataset("YawKar/wikitext_with_entitled_paragraphs")

In [None]:
# Display the loaded dataset as the cell output (bare last expression).
wiki_wtp

In [None]:
def cook_gpt2_tokenizer():
    """Build the GPT-2 tokenizer used to tokenize the final training text.

    The tokenizer is loaded from the ``openai-community/gpt2`` checkpoint and
    its padding token is set to its end-of-sequence token.

    Returns:
        The configured ``transformers.GPT2Tokenizer`` instance.
    """
    gpt2_tokenizer = transformers.GPT2Tokenizer.from_pretrained("openai-community/gpt2")
    # Pad with EOS (presumably because the checkpoint has no dedicated pad token).
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
    return gpt2_tokenizer

In [None]:
def cook_summarizer(device: "int | str | None" = None) -> "transformers.Pipeline":
    """Build the BART summarization pipeline used to shorten long paragraphs.

    Args:
        device: Device passed straight to ``transformers.pipeline``. When left
            as ``None`` the first CUDA device (``0``) is used if PyTorch
            reports CUDA as available, otherwise ``-1`` (CPU). Previously the
            device was hard-coded to ``0``, which fails on CPU-only machines.

    Returns:
        A ``"summarization"`` pipeline for ``facebook/bart-large-cnn`` created
        with ``truncation=True``.
    """
    if device is None:
        # Keep the historical default (GPU 0) whenever a GPU exists.
        device = 0 if torch.cuda.is_available() else -1
    return transformers.pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        device=device,
        truncation=True,
    )

In [None]:
def preprocess_wiki_wtp(
    wiki_wtp: "datasets.DatasetDict",
    tokenizer: "transformers.PreTrainedTokenizerBase",
    summarizer: "transformers.Pipeline",
    summarizer_max_tokens: int,
    max_tokens_length: int,
) -> "datasets.DatasetDict":
    """Tokenize ``heading + paragraph`` rows, summarizing paragraphs that are too long.

    For every row the heading and paragraph are concatenated and tokenized with
    ``tokenizer``. If that yields more than ``max_tokens_length`` tokens, the
    paragraph is split on ``"."``, the sentences are greedily packed into
    chunks of at most ``summarizer_max_tokens`` summarizer tokens, each chunk is
    summarized, and ``heading + ".".join(summaries)`` is tokenized instead.
    The resulting ``input_ids`` / ``attention_mask`` are finally cut to
    ``max_tokens_length`` entries.

    Args:
        wiki_wtp: Dataset whose batches expose ``"heading"`` and ``"paragraph"``
            string columns.
        tokenizer: Callable returning a mapping with ``"input_ids"`` and
            ``"attention_mask"`` for a string.
        summarizer: Summarization pipeline; it is called with a string plus
            ``min_length`` / ``max_length`` / ``truncation`` and must expose a
            ``tokenizer`` attribute used for token counting.
        summarizer_max_tokens: Token budget of one summarizer input chunk.
        max_tokens_length: Maximum number of tokens kept per row.

    Returns:
        The result of ``wiki_wtp.map(..., batched=True)``, where the mapped
        function returns ``"input_ids"`` and ``"attention_mask"`` lists.

    Raises:
        ValueError: If ``max_tokens_length`` or ``summarizer_max_tokens`` is
            not positive.
    """
    # Annotations above are quoted so that defining this function never has to
    # resolve third-party attribute paths.
    if max_tokens_length <= 0:
        raise ValueError(f"max_tokens_length isn't positive: {max_tokens_length}")
    if summarizer_max_tokens <= 0:
        raise ValueError(
            f"summarizer_max_tokens isn't positive: {summarizer_max_tokens}"
        )

    def count_summarizer_tokens(text: str) -> int:
        """Return how many tokens the summarizer's tokenizer produces for ``text``."""
        # len() of the encoding itself would measure the mapping (its keys),
        # not the token sequence, so index "input_ids" explicitly.
        return len(summarizer.tokenizer(text)["input_ids"])

    def split_into_chunks(paragraph: str) -> list[tuple[str, int]]:
        """Greedily pack sentences into ``(chunk_text, token_count)`` pairs."""
        chunks: list[tuple[str, int]] = []
        current: list[str] = []
        current_tokens = 0
        for sentence in paragraph.split("."):
            sentence_tokens = count_summarizer_tokens(sentence + ".")
            # Only close a chunk that actually holds sentences; a single
            # over-long sentence becomes its own chunk and is truncated by the
            # summarizer instead of triggering an empty summarization call.
            if current and current_tokens + sentence_tokens > summarizer_max_tokens:
                chunks.append((".".join(current), current_tokens))
                current = []
                current_tokens = 0
            current.append(sentence)
            current_tokens += sentence_tokens
        # str.split always yields at least one piece, so `current` is non-empty.
        chunks.append((".".join(current), current_tokens))
        return chunks

    def summarize_text(text: str, token_count: int) -> str:
        """Summarize one chunk, asking for fewer tokens than the chunk holds."""
        # Keep max_length positive and within the summarizer budget.
        max_length = max(1, min(token_count, summarizer_max_tokens) - 1)
        return summarizer(
            text,
            min_length=0,
            max_length=max_length,
            truncation=True,
        )[0]["summary_text"]

    def summarize_paragraph(paragraph: str) -> str:
        """Summarize ``paragraph`` chunk by chunk and join the summaries with '.'."""
        return ".".join(
            summarize_text(chunk_text, chunk_tokens)
            for chunk_text, chunk_tokens in split_into_chunks(paragraph)
            if chunk_text.strip()  # nothing to summarize in a blank chunk
        )

    def batch_concat_with_summarization(batch: dict[str, list[str]]) -> dict:
        """Map function: build token columns for one batch of rows."""
        processed = {
            "input_ids": [],
            "attention_mask": [],
        }
        for heading, paragraph in zip(batch["heading"], batch["paragraph"]):
            tokenized = tokenizer(heading + paragraph)
            if len(tokenized["input_ids"]) > max_tokens_length:
                tokenized = tokenizer(heading + summarize_paragraph(paragraph))
            # The summary may still be too long, so always enforce the cap.
            processed["input_ids"].append(tokenized["input_ids"][:max_tokens_length])
            processed["attention_mask"].append(
                tokenized["attention_mask"][:max_tokens_length]
            )
        return processed

    return wiki_wtp.map(batch_concat_with_summarization, batched=True)

In [None]:
# Run the full preprocessing pass: tokenize every row with the GPT-2 tokenizer
# and summarize over-long paragraphs with the BART pipeline.
# NOTE(review): 1024 is presumably the summarizer's input window and 1022 is
# presumably chosen to stay just under GPT-2's context size — TODO confirm.
preprocessed_wiki_wtp = preprocess_wiki_wtp(
    wiki_wtp,
    tokenizer=cook_gpt2_tokenizer(),
    summarizer=cook_summarizer(),
    summarizer_max_tokens=1024,
    max_tokens_length=1022,
)

In [None]:
# Upload the preprocessed dataset to the Hugging Face Hub under this repo name.
# NOTE(review): presumably relies on the notebook_login() cell near the top of
# the notebook for credentials — confirm before running unattended.
preprocessed_wiki_wtp.push_to_hub("summarized_and_tokenized_by_gpt2_wiki")