# data

The `data` module handles preprocessing of datasets.

In [None]:
#| default_exp data

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from transformers import AutoTokenizer
from datasets import Dataset
from datasets import load_dataset, concatenate_datasets

class Data:
    """
    Tokenize paired input/target texts for a sequence-to-sequence model.

    Constructing an instance loads the tokenizer for `model_id` and tokenizes
    both text lists right away; the results are available as
    `tokenized_inputs` and `tokenized_outputs`.
    """

    def __init__(self, inputs, outputs, model_id="google/flan-t5-base", verbose=1):
        """
        Preprocess a sequence-to-sequence dataset

        Args:
            inputs: list of source texts.
            outputs: list of target texts.
            model_id: name or path passed to `AutoTokenizer.from_pretrained`.
            verbose: when truthy, print the longest tokenized source and
                target lengths.
        """
        self.ins = inputs
        self.outs = outputs
        # Honor the caller's choice; this used to be hard-coded to 1, which
        # made it impossible to silence the length report.
        self.verbose = verbose
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.tokenized_inputs, self.tokenized_outputs = self.truncate()

    def _tokenize(self, texts, label):
        """
        Tokenize `texts` with truncation and report the longest sequence.

        `label` only selects the wording of the verbose message
        ("source" or "target").
        """
        dataset = Dataset.from_dict({'text': texts})
        # batched=True hands the tokenizer a list of texts per call; the raw
        # "text" column is dropped so only the tokenizer's outputs remain.
        tokenized = dataset.map(
            lambda batch: self.tokenizer(batch["text"], truncation=True),
            batched=True,
            remove_columns=["text"],
        )
        max_length = max(len(ids) for ids in tokenized["input_ids"])
        if self.verbose:
            print(f"Max {label} length: {max_length}")
        return tokenized

    def truncate(self):
        """
        Truncate input and target texts.

        Both lists are tokenized with `truncation=True` and no explicit
        `max_length`, so the limit is whatever the tokenizer applies by
        default. No padding is applied at this stage.

        Returns:
            A `(tokenized_inputs, tokenized_targets)` tuple of datasets.
        """
        tokenized_inputs = self._tokenize(self.ins, "source")
        tokenized_targets = self._tokenize(self.outs, "target")
        return tokenized_inputs, tokenized_targets
    

In [None]:
# Load the SAMSum dataset and pull out the dialogue/summary pairs.
dataset_id = "samsum"
dataset = load_dataset(dataset_id)
# Concatenate the train and test splits once and reuse the result, instead of
# rebuilding the same combined dataset for each column.
combined = concatenate_datasets([dataset["train"], dataset["test"]])
inputs = [row['dialogue'] for row in combined]
outputs = [row['summary'] for row in combined]

Found cached dataset samsum (/root/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)


  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Tokenize the dialogue/summary pairs using Data's default model_id and verbosity.
ds = Data(inputs, outputs)

Map:   0%|          | 0/15551 [00:00<?, ? examples/s]

Max source length: 512


Map:   0%|          | 0/15551 [00:00<?, ? examples/s]

Max target length: 95


In [None]:
# Sanity check: the longest tokenized target should be 95 tokens long.
longest_target = max(len(ids) for ids in ds.tokenized_outputs["input_ids"])
assert longest_target == 95

In [None]:
#| hide
# Run nbdev's export step for this notebook's `#| export` cells.
import nbdev; nbdev.nbdev_export()