# This torch Masking pipeline was created basrd on this [script](https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_t5_mlm_flax.py)

In [1]:
import json
args_dict = {
  "n_gpu": 8,
  "model_type": 't5',
  "config_name":"../configs",
  "config_path":"../configs/t5-small.json",
  "model_name_or_path":"../configs",
  "tokenizer_name":None,  
  "dataset_name":"wikitext",
  "dataset_config_name":"wikitext-103-raw-v1", 
  "max_seq_length": 128 ,
  "output_dir": './test',
  "overwrite_output_dir": True,
  "per_device_train_batch_size": 10,
  "per_device_eval_batch_size": 10,
  "learning_rate": 0.005,
  "num_train_epochs": 1,
  "do_train": True,
  "do_eval":True,
  "logging_steps":500,
  "save_steps":10000,
  "eval_steps":500,
  "report_to":"wandb",
  "evauation_strategy":"steps",
#   "train_file_path":"Data/128train_data.pt",
#   "valid_file_path":"Data/128valid_data.pt",    
}
with open('args.json', 'w') as f:
    json.dump(args_dict, f)

In [2]:
%%time
from torch.utils.data import DataLoader
import logging
import os
import sys
import json
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional
import torch
import transformers
import numpy as np
from datasets import load_dataset
import datasets
from datasets import load_dataset
from tqdm import tqdm
import math
from itertools import chain
from transformers.integrations import TensorBoardCallback
from huggingface_hub import Repository
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right

from transformers.file_utils import get_full_repo_name
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_MASKED_LM_MAPPING,
    AutoTokenizer,
    BatchEncoding,
    HfArgumentParser,
    PreTrainedTokenizerBase,
    T5Config,
    TrainingArguments,
    is_tensorboard_available,
    set_seed,
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer    
)

from sklearn.metrics import accuracy_score
from datasets import load_metric
from transformers.trainer_utils import get_last_checkpoint

logger = logging.getLogger(__name__)

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization."
            "Don't set if you want to train a model from scratch."
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " }, #+ ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    config_path: Optional[str] = field(
        default=None, metadata={"help": "Full path if not the same as model_name"}
    )
    
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    train_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
    )
    validation_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    
    train_file_path: Optional[str] = field(
        default= None,
        metadata={"help": "Path for cached train dataset"},
    )
    valid_file_path: Optional[str] = field(
        default= None,
        metadata={"help": "Path for cached valid dataset"},
    )
    
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total input sequence length after tokenization and masking. Sequences longer than this will be truncated. Default to the max input length of the model."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    mlm_probability: float = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for span masked language modeling loss"}
    )
    mean_noise_span_length: float = field(
        default=3.0,
        metadata={"help": "Mean span length of masked tokens"},
    )
    
    
    
def process_datasets(model_args, data_args, training_args, tokenizer, expanded_inputs_length, save_path):
    raw_datasets = None
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
        )
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            raw_datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)

        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            raw_datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )


    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    else:
        column_names = raw_datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    if data_args.max_seq_length is None:
        max_seq_length = tokenizer.model_max_length
        if max_seq_length > 512:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
            )
            max_seq_length = 512
    else:
        if data_args.max_seq_length > tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
            )
        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
        

    def tokenize_function(examples):

        return tokenizer(examples[text_column_name], return_attention_mask=False)

    print("******************** Tokenize data set  ************************")
    with training_args.main_process_first(desc="dataset map tokenization"):
        tokenized_datasets = raw_datasets.map(
            tokenize_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on every text in dataset",
        )
    # Main data processing function that will concatenate all texts from our dataset and generate chunks of
    # expanded_inputs_length.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= expanded_inputs_length:
            total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
            for k, t in concatenated_examples.items()
        }
        return result

        # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
        # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
        # might be slower to preprocess.
        #
        # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
        # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    with training_args.main_process_first(desc="grouping texts together"):
        tokenized_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
            desc=f"Grouping texts in chunks of {max_seq_length}",
        )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = tokenized_datasets["train"]
#         if data_args.max_train_samples is not None:
#             train_dataset = train_dataset.select(range(data_args.max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = tokenized_datasets["validation"]

    if "test" not in tokenized_datasets:
        raise ValueError("--do_predict requires a test dataset")
    test_dataset = tokenized_datasets["test"]

    # cach the dataset, so we can load it directly for training
    torch.save(train_dataset, save_path+str(max_seq_length)+'train_data.pt') 
    torch.save(eval_dataset, save_path+str(max_seq_length)+'valid_data.pt')
    torch.save(test_dataset, save_path+str(max_seq_length)+'test_data.pt')
    return train_dataset, eval_dataset, test_dataset
    
    
def create_tokenizer(model_args):
    config = None 
    config = T5Config.from_json_file(json_file=model_args.config_path)
   
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
        return tokenizer
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path, config=config)

        return tokenizer
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

@dataclass
class DataCollatorForT5MLM:
    """
    Data collator used for T5 span-masked language modeling.
    It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of fixed length.
    For more information on how T5 span-masked language modeling works, one can take a look
    at the `official paper <https://arxiv.org/pdf/1910.10683.pdf>`__
    or the `official code for preprocessing <https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py>`__ .
    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        noise_density (:obj:`float`):
            The probability with which to (randomly) mask tokens in the input.
        mean_noise_span_length (:obj:`float`):
            The average span length of the masked tokens.
        input_length (:obj:`int`):
            The expected input length after masking.
        target_length (:obj:`int`):
            The expected target length after masking.
        pad_token_id: (:obj:`int`):
            The pad token id of the model
        decoder_start_token_id: (:obj:`int):
            The decoder start token id of the model
    """

    tokenizer: PreTrainedTokenizerBase
    noise_density: float
    mean_noise_span_length: float
    input_length: int
    target_length: int
    pad_token_id: int

    def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
        '''
        1. random_spans_noise_mask
        2. random span noise 
        3. create_sentinel_ids
        4. filter_input_ids
        if max length = 512 then len(example['input_ids']) = 568
        '''
        
        batch = BatchEncoding( # batch.keys() = input_ids
            {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
        )
        print('-----------------------------------------------------------------   samples before masking using special tokens -------------------------------------------------------------------')
        for i , x in enumerate(batch["input_ids"]):
            print(self.tokenizer.decode(x))
        print("**************************************************************************************************************************************")   
        input_ids = batch["input_ids"]
        batch_size, expandend_input_length = input_ids.shape
        mask_indices = np.asarray([self.random_spans_noise_mask(expandend_input_length) for i in range(batch_size)])
        labels_mask = ~mask_indices
        input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
        labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))
        inputs = self.filter_input_ids(input_ids, input_ids_sentinel)
        batch["input_ids"] = torch.tensor(inputs)
        labels = self.filter_input_ids(input_ids, labels_sentinel)
        batch["labels"] = torch.tensor(labels)
        if batch["input_ids"].shape[-1] != self.input_length:
            raise ValueError(
                f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but should be {self.target_length}."
            )
        if batch["labels"].shape[-1] != self.target_length:
            raise ValueError(
                f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be {self.target_length}."
            )
        # to check that tokens are correctly proprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
        print('---------------------------   input_ids  ------------------------------------')
        for i , x in enumerate(batch["input_ids"]):
            print(self.tokenizer.decode(x))
        print("---------------------------------------------------------------------------")
        print('---------------------------------- labels ---------------------------------')
        for i,x in enumerate(batch["labels"]):
            print(self.tokenizer.decode(x))
#         print("self.tokenizer.batch_decode(labels):  ",  self.tokenizer.decode(labels))   
        # to check that tokens are correctly proprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
#         batch["decoder_input_ids"] = torch.tensor(shift_tokens_right(
#             labels, self.pad_token_id, self.decoder_start_token_id
#         ))


        return batch

    def create_sentinel_ids(self, mask_indices):
        """
        Sentinel ids creation given the indices that should be masked.
        The start indices of each mask are replaced by the sentinel ids in increasing
        order. Consecutive mask indices to be deleted are replaced with `-1`.
        """
        start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
        start_indices[:, 0] = mask_indices[:, 0]
        sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
        sentinel_ids = np.where(sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0)
        sentinel_ids -= mask_indices - start_indices

        return sentinel_ids

    def filter_input_ids(self, input_ids, sentinel_ids):
        """
        Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
        This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
        """
        batch_size = input_ids.shape[0]
        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
        input_ids = input_ids_full[input_ids_full > 0].reshape((batch_size, -1))
        input_ids = np.concatenate(
            [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
        )
        return input_ids

    def random_spans_noise_mask(self, length):

        """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
        Noise mask consisting of random spans of noise tokens.
        The number of noise tokens and the number of noise spans and non-noise spans
        are determined deterministically as follows:
        num_noise_tokens = round(length * noise_density)
        num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
        Spans alternate between non-noise and noise, beginning with non-noise.
        Subject to the above restrictions, all masks are equally likely.
        Args:
            length: an int32 scalar (length of the incoming token sequence)
            noise_density: a float - approximate density of output mask
            mean_noise_span_length: a number
        Returns:
            a boolean tensor with shape [length]
        """
        orig_length = length
        num_noise_tokens = int(np.round(length * self.noise_density))
        # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))

        # avoid degeneracy by ensuring positive number of noise spans
        num_noise_spans = max(num_noise_spans, 1)
        num_nonnoise_tokens = length - num_noise_tokens

        # pick the lengths of the noise spans and the non-noise spans
        def _random_segmentation(num_items, num_segments):
            """Partition a sequence of items randomly into non-empty segments.
            Args:
                num_items: an integer scalar > 0
                num_segments: an integer scalar in [1, num_items]
            Returns:
                a Tensor with shape [num_segments] containing positive integers that add
                up to num_items
            """
            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
            np.random.shuffle(mask_indices)
            first_in_segment = np.pad(mask_indices, [[1, 0]])
            segment_id = np.cumsum(first_in_segment)
            # count length of sub segments assuming that list is sorted
            _, segment_length = np.unique(segment_id, return_counts=True)
            return segment_length

        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)

        interleaved_span_lengths = np.reshape(
            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
        )
        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
        span_start_indicator = np.zeros((length,), dtype=np.int8)
        span_start_indicator[span_starts] = True
        span_num = np.cumsum(span_start_indicator)
        is_noise = np.equal(span_num % 2, 1)
        return is_noise[:orig_length]
          
def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
    """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .
    Training parameters to avoid padding with random_spans_noise_mask.
    When training a model with random_spans_noise_mask, we would like to set the other
    training hyperparmeters in a way that avoids padding.
    This function helps us compute these hyperparameters.
    We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
    and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
    This function tells us the required number of tokens in the raw example (for split_tokens())
    as well as the length of the encoded targets. Note that this function assumes
    the inputs and targets will have EOS appended and includes that in the reported length.
    Args:
        inputs_length: an integer - desired length of the tokenized inputs sequence
        noise_density: a float
        mean_noise_span_length: a float
    Returns:
        tokens_length: length of original text in tokens
        targets_length: an integer - length in tokens of encoded targets sequence
    """

    def _tokens_length_to_inputs_length_targets_length(tokens_length):
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # inputs contain all nonnoise tokens, sentinels for all noise spans
        # and one EOS token.
        _input_length = num_nonnoise_tokens + num_noise_spans + 1
        _output_length = num_noise_tokens + num_noise_spans + 1
        return _input_length, _output_length

    tokens_length = inputs_length

    while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
        tokens_length += 1

    inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)

    # minor hack to get the targets length to be equal to inputs length
    # which is more likely to have been set to a nice round number.
    if noise_density == 0.5 and targets_length > inputs_length:
        tokens_length -= 1
        targets_length -= 1
    return tokens_length, targets_length

                



def main():
    # 1: Parse argumen
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath('args.json'))
        
    tokenizer = create_tokenizer(model_args)


    # 2: Otput directory is exist and empty
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome."
        )

    #3:  Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        level="NOTSET",
        datefmt="[%X]",
    )
    # Log on each process the small summary:
    logger = logging.getLogger(__name__)
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")
    logger.info(f"Model arguments {model_args}")
    logger.info(f"Data arguments {data_args}")
    
    #4:  Set seed before initializing model.
    set_seed(training_args.seed)

    #5:  Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    #6:  create  tokenizer.
    tokenizer = create_tokenizer(model_args)    
    if data_args.max_seq_length is None:
        max_seq_length = tokenizer.model_max_length
        if max_seq_length > 512:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
            )
            max_seq_length = 512
    else:
        if data_args.max_seq_length > tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
            )
        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # 7: T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
    # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
    # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
    expanded_inputs_length, targets_length = compute_input_and_target_lengths(
        inputs_length=max_seq_length,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
    )

    #8:  Get datasets if they already exists or download and prepare them for using
    train_dataset, eval_dataset = None, None
    print('data_args.train_file_path:  -------------------------------------------------->', data_args.train_file_path)
    if data_args.train_file_path != None and data_args.valid_file_path != None:
        print('loading data')
        train_dataset  = torch.load(data_args.train_file_path)
        eval_dataset = torch.load(data_args.valid_file_path)
        print('loading done')    
    elif data_args.dataset_name is not None or data_args.train_file is not None: 
        train_dataset, eval_dataset, test_dataset = process_datasets(model_args, 
                                                       data_args, 
                                                       training_args, 
                                                       tokenizer, 
                                                       expanded_inputs_length, 
                                                       "../Data/wiki_103/")        
    else:
        raise ValueError(
            "No available datasets. You need to load a cashed dataset or process a dataset."
        )  
    
    # 9: create model     
#     model = create_model(model_args, tokenizer) 

    # 10: Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = DataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
        input_length=max_seq_length,
        target_length=targets_length,
        pad_token_id=0  # model.config.pad_token_id,
    )
    train_dataloader = DataLoader(
                train_dataset,
                batch_size=1,
                collate_fn=data_collator,
                num_workers=0,
                pin_memory=True,
            )
    for step, batch in enumerate(tqdm(train_dataloader, desc="Training...", position=1)):
        if step == 1:
            break   


main()

 This a JAX bug; please report an issue at https://github.com/google/jax/issues
  _warn(f"cloud_tpu_init failed: {repr(exc)}\n This a JAX bug; please report "
I0220 10:30:34.624964 3248029 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Not attempting to load libtpu.so in this process.
file ../tokenizer/config.json not found
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
[10:30:35] - INFO - __main__ -   Training/evaluation parameters TrainingArguments(
_n_gpu=8,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_find_unused_parameters=None,
debug=[],
deepspeed=None,
disable_tqdm=False,
do_eval=True,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_steps=500,
evaluation_strategy=IntervalStrategy.NO,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O

data_args.train_file_path:  --------------------------------------------------> None


[10:30:35] - DEBUG - urllib3.connectionpool -   https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/datasets/wikitext/wikitext.py HTTP/1.1" 200 0
[10:30:35] - DEBUG - urllib3.connectionpool -   Starting new HTTPS connection (1): raw.githubusercontent.com:443
[10:30:36] - DEBUG - urllib3.connectionpool -   https://raw.githubusercontent.com:443 "HEAD /huggingface/datasets/1.15.1/datasets/wikitext/wikitext.py HTTP/1.1" 200 0
[10:30:36] - DEBUG - urllib3.connectionpool -   Starting new HTTPS connection (1): raw.githubusercontent.com:443
[10:30:36] - DEBUG - urllib3.connectionpool -   https://raw.githubusercontent.com:443 "HEAD /huggingface/datasets/1.15.1/datasets/wikitext/dataset_infos.json HTTP/1.1" 200 0
[10:30:36] - INFO - datasets.info -   Loading Dataset Infos from /home/arij/.cache/huggingface/modules/datasets_modules/datasets/wikitext/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20
[10:30:36] - INFO - datasets.builder -   Overwrite dataset info fro

  0%|          | 0/3 [00:00<?, ?it/s]



******************** Tokenize data set  ************************


Grouping texts in chunks of 128:   0%|          | 0/5 [00:00<?, ?ba/s]

[10:30:36] - INFO - datasets.arrow_dataset -   Caching processed dataset at /home/arij/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20/cache-94a903d644427d72.arrow


Grouping texts in chunks of 128:   0%|          | 0/1802 [00:00<?, ?ba/s]

[10:30:36] - INFO - datasets.arrow_dataset -   Caching processed dataset at /home/arij/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20/cache-32859008fe3fe8bd.arrow


Grouping texts in chunks of 128:   0%|          | 0/4 [00:00<?, ?ba/s]

[10:31:39] - INFO - datasets.arrow_dataset -   Caching processed dataset at /home/arij/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20/cache-b971abb1884cd9c7.arrow

Training...:   0%|                                                                                                                                                                                                                                     | 0/1105878 [00:00<?, ?it/s][A

-----------------------------------------------------------------   samples before masking using special tokens -------------------------------------------------------------------
</s> = valkyria chronicles iii = </s></s> senjō no valkyria 3 : unrecorded chronicles ( japanese : 戦場のヴァルキュリア3, lit. valkyria of the battlefield 3 ), commonly referred to as valkyria chronicles iii outside japan, is a tactical role @-@ playing video game developed by sega and media.vision for the playstation portable. released in january 2011 in japan, it is the third game in the valkyria series. employing the same fusion of tactical
**************************************************************************************************************************************
---------------------------   input_ids  ------------------------------------
</s> = valkyria chronicles iii = </s></s> senjō no valkyria 3 <extra_id_99> unrecorded chronicles ( japanese : 戦場のヴァ<extra_id_98>リア3, lit. valkyria of the battlefield 3 


Training...:   0%|                                                                                                                                                                                                                        | 1/1105878 [00:04<1424:50:41,  4.64s/it][A

-----------------------------------------------------------------   samples before masking using special tokens -------------------------------------------------------------------
and real @-@ time gameplay as its predecessors, the story runs parallel to the first game and follows the " nameless ", a penal military unit serving the nation of gallia during the second europan war who perform secret black operations and are pitted against the imperial unit " calamaty raven ". </s> the game began development in 2010, carrying over a large portion of the work done on valkyria chronicles ii. while it retained the standard features of the series, it also underwent multiple adjustments, such as making the game more forgiving for series newcomers 
**************************************************************************************************************************************
---------------------------   input_ids  ------------------------------------
and real @-@ time gameplay as its pred




In [3]:
import re
line = '''</s> = valkyria chronicles iii = </s></s> senjō no valkyria 3 : unrecorded chronicles ( japanese : 戦場のヴァルキュリア3, lit. valkyria of the battlefield 3 ), commonly referred to as valkyria chronicles iii outside japan, is a tactical role @-@ playing video game developed by sega and media.vision for the playstation portable. released in january 2011 in japan, it is the third game in the valkyria series. employing the same fusion of tactical and real @-@ time gameplay as its predecessors, the story runs parallel to the first game and follows the " nameless ", a penal military unit serving the nation of gallia during the second europan war who perform secret black operations and are pitted against the imperial unit " calamaty raven ". </s> the game began development in 2010, carrying over a large portion of the work done on valkyria chronicles ii. while it retained the standard features of the series, it also underwent multiple adjustments, such as making the game more forgiving for series newcomers. character designer raita honjou and composer hitoshi sakimoto both returned from previous entries, along with valkyria chronicles ii director takeshi ozawa. a large team of writers handled the script. the game's opening theme was sung by may 'n. </s> it met with positive sales in japan, and was praised by both japanese and western critics. after release, it received downloadable content, along with an expanded edition in november of that year. it was also adapted into manga and an original video animation series. due to low sales of valkyria chronicles ii, valkyria chronicles iii was not localized, but a fan translation compatible with the game's expanded edition was released in 2014. media.vision would return to the franchise with the development of valkyria : azure revolution for the playstation 4. </s></s> = = gameplay = = </s></s> as with previous valkyira chronicles games, valkyria chronicles iii is a tactical role @-@ playing game where players take control of a military unit and take part in missions against enemy forces. stories are told through comic book '''
line.replace("</s>","")

' = valkyria chronicles iii =  senjō no valkyria 3 : unrecorded chronicles ( japanese : 戦場のヴァルキュリア3, lit. valkyria of the battlefield 3 ), commonly referred to as valkyria chronicles iii outside japan, is a tactical role @-@ playing video game developed by sega and media.vision for the playstation portable. released in january 2011 in japan, it is the third game in the valkyria series. employing the same fusion of tactical and real @-@ time gameplay as its predecessors, the story runs parallel to the first game and follows the " nameless ", a penal military unit serving the nation of gallia during the second europan war who perform secret black operations and are pitted against the imperial unit " calamaty raven ".  the game began development in 2010, carrying over a large portion of the work done on valkyria chronicles ii. while it retained the standard features of the series, it also underwent multiple adjustments, such as making the game more forgiving for series newcomers. characte

In [4]:
print(line.replace("</s>","\n"))


 = valkyria chronicles iii = 

 senjō no valkyria 3 : unrecorded chronicles ( japanese : 戦場のヴァルキュリア3, lit. valkyria of the battlefield 3 ), commonly referred to as valkyria chronicles iii outside japan, is a tactical role @-@ playing video game developed by sega and media.vision for the playstation portable. released in january 2011 in japan, it is the third game in the valkyria series. employing the same fusion of tactical and real @-@ time gameplay as its predecessors, the story runs parallel to the first game and follows the " nameless ", a penal military unit serving the nation of gallia during the second europan war who perform secret black operations and are pitted against the imperial unit " calamaty raven ". 
 the game began development in 2010, carrying over a large portion of the work done on valkyria chronicles ii. while it retained the standard features of the series, it also underwent multiple adjustments, such as making the game more forgiving for series newcomers. chara