# In [None]:
import torch
import random
import sys
import json
import functools
import argparse
from transformers import T5Tokenizer

"""
This code is heavily based on the TensorFlow preprocessing code from the T5 paper, available here:
https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py

Adapted for use with huggingface (torch) by Aaron Mueller.
"""

def to_dict(text, tokenizer, include_eos=True):
    """Tokenize ``text`` and wrap the ids in this file's example-dict layout.

    Args:
        text: Raw string to encode.
        tokenizer: Object exposing ``encode(text)`` that returns a sequence of ids.
        include_eos: When False, the last id returned by ``encode`` is dropped.
            (Assumes the tokenizer appends the end-of-sequence id last, as the
            callers in this file do -- confirm for other tokenizers.)

    Returns:
        ``{'inputs': "", 'targets': <1-D tensor of token ids>}``.
    """
    token_ids = tokenizer.encode(text)
    if not include_eos:
        # Strip the trailing id (the EOS marker for the tokenizers used here).
        token_ids = token_ids[:-1]
    return {'inputs': "", 'targets': torch.tensor(token_ids)}


def load_data(in_file, tokenizer):
    """Read a JSON-lines file of utterance/intent pairs and tokenize both sides.

    Each non-blank line of ``in_file`` must be a JSON object of the format
    `{"translation": {"src": utterance, "tgt": label}}`.

    Both strings are stripped, and a "." is appended when they do not already
    end in one of ``. ; ! ? ,``.

    Args:
        in_file: Path of the JSON-lines file (read as UTF-8).
        tokenizer: Tokenizer forwarded to ``to_dict``.

    Returns:
        A tuple ``(utterances, intents)`` of two equally long lists. Every
        element has the format `{"inputs": "", "targets": tensor([encoded_text])}`.
        Utterances are encoded with ``include_eos=False`` and intents with
        ``include_eos=True``.
    """
    utterances, intents = [], []
    punc = (".", ";", "!", "?", ",")
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(in_file, 'r', encoding="utf-8") as datastrings:
        for datastring in datastrings:
            if not datastring.strip():
                # Blank lines (e.g. a trailing newline at end of file) carry no record.
                continue
            pair = json.loads(datastring)["translation"]
            utterance = pair["src"].strip()
            intent = pair["tgt"].strip()
            if not utterance.endswith(punc):
                utterance += "."
            if not intent.endswith(punc):
                intent += "."

            utterances.append(to_dict(utterance, tokenizer, include_eos=False))
            intents.append(to_dict(intent, tokenizer, include_eos=True))
    return (utterances, intents)


def write_data(dataset, out_name, tokenizer):
    """Decode every example and write it to ``out_name`` as one JSON object per line.

    Args:
        dataset: Iterable of mappings with "inputs" and "targets" entries that
            ``tokenizer.decode`` accepts.
        out_name: Path of the output file; it is opened in "w" mode, so any
            existing content is replaced.
        tokenizer: Object exposing ``decode(ids)`` that returns a string.
    """
    with open(out_name, "w") as sink:
        for example in dataset:
            # Decode in the fixed order inputs -> targets so key order is stable.
            record = {field: tokenizer.decode(example[field])
                      for field in ("inputs", "targets")}
            sink.write(json.dumps(record) + "\n")


def span_corruption(utterances, intents,
                    sequence_length,
                    mean_noise_span_length=3.0,
                    noise_density=0.15,
                    seq_pack=False,
                    label_semantics="multiple choice",
                    label_noise_density=0.5):
    """Preprocessing for T5 denoising objective. Returns preprocessed
    tokenized and encoded data.
    Args:
        utterances -- list of dicts {"inputs": "", "targets": tensor} holding the tokenized
                      utterances (no trailing eos), as produced by `load_data`.
        intents -- list of dicts of the same format holding the tokenized labels (with eos),
                   aligned index-by-index with `utterances`.
        sequence_length -- dict of maximum sequence lengths; only `sequence_length['targets']`
                           is read here, to validate the expected target length.
        mean_noise_span_length -- average length of a corrupted span (None / 'concat' modes).
        noise_density -- fraction of tokens to corrupt (None / 'concat' modes).
        seq_pack -- pack inputs into sequences of length approximately `sequence_length`
                    (None / 'concat' modes).
        label_semantics -- Whether and how to mask the utterance and intent. Can take the following values:
                               None: only use utterances. Intents will not appear in the data.
                               'concat': append intents to utterances, noise as if it were one full sequence.
                               'full label': simply mask the entire label and none of the utterance.
                               'separate': mask tokens in utterance with `noise_density` probability, and mask
                                           tokens in label with `label_noise_density` probability.
                                           (not implemented in this function)
                               'label permute': try all possible ways of masking the tokens in the intent. Treat
                                                each permutation as a new training example.
                                                (not implemented in this function)
                               'multiple choice': treat as a multiple choice problem. Give correct intent and
                                                  a set of [2, 18] random intents with the utterance in the source
                                                  sequence. Transduce to intent.
        label_noise_density -- reserved for the 'separate' mode; currently unused.
    Returns:
        list of dicts {"inputs": ..., "targets": ...}, one per training example.
    Raises:
        ValueError -- if `label_semantics` is not a known strategy, or the expected target
                      length exceeds `sequence_length['targets']`.
        NotImplementedError -- if `label_semantics` is 'separate' or 'label permute'.
    """
    implemented = (None, "concat", "full label", "multiple choice")
    not_implemented = ("separate", "label permute")
    if label_semantics not in implemented + not_implemented:
        raise ValueError("Unrecognized label masking strategy. Must be None or one of "
                         "{'concat', 'full label', 'label permute', 'separate', 'multiple choice'}.")
    if label_semantics in not_implemented:
        # Fail loudly instead of silently falling off the end and returning None.
        raise NotImplementedError(f"label_semantics={label_semantics!r} is documented but "
                                  "has no implementation in span_corruption.")

    # Inputs longer than this many tokens are truncated (or chunked by the helpers).
    max_input_length = 512
    _, targets_length = random_spans_helper(inputs_length=max_input_length)

    if sequence_length['targets'] < targets_length:
        raise ValueError(f'Expected targets length for span corruption ({targets_length}) is '
                         f'greater than configured targets length '
                         f"({sequence_length['targets']})")

    tokenizer = T5Tokenizer.from_pretrained('t5-base')

    if label_semantics is None or label_semantics == "concat":
        # Utterance (and optionally label) are noised together as one sequence.
        if label_semantics is None:
            ds = utterances
        else:
            ds = [{'inputs': "",
                   'targets': torch.cat((utterance["targets"], intent["targets"]))}
                  for utterance, intent in zip(utterances, intents)]
        ds = select_random_chunk(ds)    # deal with inputs longer than 512 tokens
        if seq_pack:                    # pack sequences into training examples of ~512 tokens
            ds = random_concat(ds)
        return denoise(
            ds,
            tokenizer=tokenizer,
            inputs_fn=noise_span_to_unique_sentinel,
            targets_fn=nonnoise_span_to_unique_sentinel,
            noise_density=noise_density,
            noise_mask_fn=functools.partial(
                random_spans_noise_mask,
                mean_noise_span_length=mean_noise_span_length
            )
        )

    if label_semantics == "full label":  # mask full label, not utterance
        sentinel = torch.tensor([tokenizer.convert_tokens_to_ids("<extra_id_0>")])
        ds = []
        for utterance, intent in zip(utterances, intents):
            # if seq too long, truncate (slicing is a no-op for shorter sequences)
            source_tok = torch.cat((utterance["targets"], sentinel))[:max_input_length]
            target_tok = torch.cat((sentinel, intent["targets"]))
            ds.append({'inputs': source_tok, 'targets': target_tok})
        return ds

    # label_semantics == "multiple choice"
    # Loop-invariant pieces of the source sequence; [:-1] gets rid of the eos token.
    int_prefix = torch.tensor(tokenizer.encode("intents: ")[:-1])
    utt_prefix = torch.tensor(tokenizer.encode("utterance: ")[:-1])
    eos = torch.tensor([tokenizer.convert_tokens_to_ids("</s>")])
    ds = []
    for utterance, intent in zip(utterances, intents):
        # NOTE(review): int(uniform(2, 19)) yields 2..18 distractors (3..19 options with the
        # correct intent); the earlier comment/docstring spoke of up to 29/30. The value is
        # kept as-is so generated data does not change -- confirm which one was intended.
        num_choices = int(random.uniform(2, 19))
        # random.sample raises if asked for more items than exist (tiny datasets).
        num_choices = min(num_choices, len(intents))
        intents_list = random.sample(intents, num_choices)
        intents_list.append(intent)
        random.shuffle(intents_list)
        # concatenate intent list (without eos tokens)
        intents_choices = torch.cat([intent_item["targets"][:-1] for intent_item in intents_list])
        source_tok = torch.cat((int_prefix, intents_choices, utt_prefix, utterance["targets"], eos))
        if source_tok.shape[0] > max_input_length:   # if seq too long, only give the correct intent.
            source_tok = torch.cat((int_prefix, intent["targets"][:-1], utt_prefix,
                                    utterance["targets"], eos))
        # if seq still too long, truncate
        source_tok = source_tok[:max_input_length]
        ds.append({'inputs': source_tok, 'targets': intent["targets"]})
    return ds


"""========== DRIVER CODE =========="""
if __name__ == "__main__":
    
    # Label-handling strategy forwarded to span_corruption() below.
    labelsemantics = "concat"

    # SET RANDOM SEED FOR REPLICABLE BEHAVIOR
    # Both torch's and Python's generators are seeded before any data is processed.
    torch.manual_seed(1248)
    random.seed(1248)

    # Input file; load_data() reads it line by line, one JSON object per line.
    dataset = "dataset/json/polyai-bank/polyai-bank_train.json"

    # Length limits handed to span_corruption().
    sequence_length = {'inputs': 512, 'targets': 512}
    # `tokenizer` and `utterances` are also read by the statements that follow
    # the RandomNoise class later in this file.
    tokenizer = T5Tokenizer.from_pretrained('t5-base')
    utterances, intents = load_data(dataset, tokenizer)
    # NOTE: `dataset` is rebound here from the input path to the processed examples.
    dataset = span_corruption(utterances, intents, sequence_length, seq_pack=False, label_semantics=labelsemantics)
    # Decode the processed examples back to text and write them as JSON lines.
    write_data(dataset, "random_noise.json", tokenizer)

# In [71]:
from dataclasses import dataclass
from transformers import T5Tokenizer, PreTrainedTokenizerBase

@dataclass
class RandomNoise:
    """Apply T5-style random span corruption to pre-tokenized examples.

    Contiguous spans of token ids are replaced by sentinel ids that count down
    from ``tokenizer.vocab_size - 1``. The produced "inputs" keep the non-noise
    tokens (one sentinel per dropped span) and the "targets" keep the noise
    tokens (one sentinel per kept span).
    """

    # Only `tokenizer.vocab_size` is read; it fixes the id of the first sentinel.
    # Presumably these ids are the `<extra_id_*>` tokens of a T5 tokenizer --
    # confirm before using a different tokenizer.
    tokenizer: PreTrainedTokenizerBase

    def span_corruption(self, utterances, mean_noise_span_length=3.0, noise_density=0.15,
                        max_length=512):
        """Corrupt every non-empty example in ``utterances``.

        Args:
            utterances: Iterable of dicts whose "targets" entry is a 1-D tensor of token ids.
            mean_noise_span_length: Average length of a corrupted span.
            noise_density: Approximate fraction of tokens to corrupt.
            max_length: Longer examples are cut to a random window of this many tokens.

        Returns:
            List of ``{"inputs": tensor, "targets": tensor}`` dicts, one per
            non-empty input example, in input order.
        """
        # Empty examples are skipped: there is nothing to mask in them.
        dataset = [self.get_random_segment(data, max_length=max_length)
                   for data in utterances if data["targets"].shape[0] > 0]
        noise_mask_fn = functools.partial(
            self.random_spans_noise_mask,
            mean_noise_span_length=mean_noise_span_length,
        )
        return self.denoise(dataset, noise_density=noise_density, noise_mask_fn=noise_mask_fn)

    def get_random_segment(self, data, max_length):
        """Extract a chunk from the data, given a maximum length.

        Examples shorter than ``max_length`` are returned whole; otherwise a
        window of exactly ``max_length`` tokens starting at a uniformly drawn
        offset is returned.
        """
        tokens = data["targets"]
        if tokens.shape[0] < max_length:
            return {"targets": tokens}
        # randint's upper bound is exclusive, hence the +1 to allow the last valid offset.
        start = torch.randint(0, tokens.shape[0] - max_length + 1, (1,)).item()
        return {"targets": tokens[start: start + max_length]}

    def random_spans_noise_mask(self, length, noise_density=0.15, mean_noise_span_length=3.0):
        """Calculate which spans to mask given input length.

        Returns a vector of Booleans of length `length`, where `True`
        corresponds to masking and `False` corresponds to keeping a token.
        The mask always starts with a non-noise span and alternates
        non-noise / noise spans of random lengths.
        """
        orig_length = length
        # Avoid degenerate length values: work on at least 2 positions and trim
        # the mask back to `orig_length` at the end.
        length = max(int(length), 2)

        # Counts are rounded on float32 tensors (as the tensor-based arithmetic
        # did) and then turned into plain ints so they can be used as sizes.
        num_noise_tokens = int(torch.round(
            torch.tensor(length, dtype=torch.float32) * noise_density))
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        num_noise_spans = int(torch.round(
            torch.tensor(num_noise_tokens, dtype=torch.float32) / mean_noise_span_length))
        num_noise_spans = max(num_noise_spans, 1)
        num_nonnoise_tokens = length - num_noise_tokens

        def _random_segmentation(num_items, num_segments):
            """Partition items randomly into non-empty segments."""
            # `num_segments - 1` randomly placed boundaries among the
            # `num_items - 1` gaps; the first item always opens segment 0.
            boundaries = (torch.arange(num_items - 1) < num_segments - 1).to(torch.int32)
            first_in_segment = torch.cat((boundaries.new_zeros(1), self.shuffle(boundaries)))
            segment_id = torch.cumsum(first_in_segment, 0)
            return self.segment_sum(torch.ones_like(segment_id), segment_id)

        # Pick lengths of noise spans and non-noise spans (noise first, to keep
        # the order in which random numbers are drawn).
        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)

        # Interleave as [nonnoise_0, noise_0, nonnoise_1, noise_1, ...].
        interleaved_span_lengths = torch.stack(
            [nonnoise_span_lengths, noise_span_lengths], dim=1).reshape(num_noise_spans * 2)
        span_starts = torch.cumsum(interleaved_span_lengths, 0)[:-1]
        span_start_indicator = self.segment_sum(
            torch.ones_like(span_starts), span_starts, length)
        span_num = torch.cumsum(span_start_indicator, 0)
        # Odd-numbered spans are the noise spans.
        is_noise = (span_num % 2) == 1
        return is_noise[:orig_length]

    def denoise(self, dataset, noise_density=0.15, noise_mask_fn=None):
        """Turn each example into an (inputs, targets) pair for the denoising objective.

        Args:
            dataset: Iterable of dicts whose "targets" entry is a 1-D tensor of token ids.
            noise_density: Passed as second positional argument to ``noise_mask_fn``.
            noise_mask_fn: Callable ``(length, noise_density) -> bool mask``.
                Defaults to ``self.random_spans_noise_mask``.

        Returns:
            List of ``{"inputs": tensor, "targets": tensor}`` dicts.
        """
        if noise_mask_fn is None:
            # The documented default used to crash with "'NoneType' is not callable".
            noise_mask_fn = self.random_spans_noise_mask
        vocab_size = self.tokenizer.vocab_size

        def map_fn(features):
            tokens = features['targets']
            noise_mask = noise_mask_fn(tokens.shape[0], noise_density)
            inputs = self.noise_span_to_unique_sentinel(tokens, noise_mask, vocab_size)
            return {
                'inputs': inputs,
                'targets': self.nonnoise_span_to_unique_sentinel(tokens, noise_mask, vocab_size)
            }

        return [map_fn(data) for data in dataset]

    def noise_span_to_unique_sentinel(self, tokens, noise_mask, vocab_size):
        """Replace every run of masked tokens by one sentinel id.

        The n-th run (1-based) is replaced by the id ``vocab_size - n``; the
        remaining tokens of the run are dropped.
        """
        # Mask shifted right by one position, padded with "not noise" in front.
        prev_token_is_noise = torch.cat((noise_mask.new_zeros(1), noise_mask[:-1]))

        first_noise_tokens = torch.logical_and(
            noise_mask, torch.logical_not(prev_token_is_noise))
        subsequent_noise_tokens = torch.logical_and(
            noise_mask, prev_token_is_noise)

        # Running count of runs seen so far selects the sentinel for each position.
        sentinel = vocab_size - torch.cumsum(first_noise_tokens.int(), 0)

        tokens = torch.where(first_noise_tokens, sentinel, tokens)
        return torch.masked_select(tokens, torch.logical_not(subsequent_noise_tokens))

    def nonnoise_span_to_unique_sentinel(self, tokens, noise_mask, vocab_size):
        """Same as ``noise_span_to_unique_sentinel`` with the mask inverted (builds the targets)."""
        return self.noise_span_to_unique_sentinel(
            tokens, torch.logical_not(noise_mask), vocab_size)

    # ============= UTILITY FUNCTIONS ===============
    def shuffle(self, value):
        """Randomly shuffle a tensor."""
        return value[torch.randperm(value.numel())].reshape(value.shape)

    def segment_sum(self, data, segment_ids, num_segments=None):
        """Compute the sum along segments of a tensor.

        ``segment_ids`` must be an int64 tensor. When ``num_segments`` is omitted
        it is taken as the number of distinct ids, which is only the right size
        if the ids are exactly ``0 .. k-1``.
        """
        if num_segments is None:
            num_segments = segment_ids.unique().numel()
        # int() also accepts a 0-dim integer tensor.
        shape = [int(num_segments)] + list(data.shape[1:])
        # new_zeros keeps the result on the same device/dtype as `data`.
        return data.new_zeros(shape).scatter_add_(0, segment_ids, data)


# Corrupt the utterances loaded by the driver code above with RandomNoise and
# write the decoded result as JSON lines.
random_denoiser = RandomNoise(tokenizer=tokenizer)
ds = random_denoiser.span_corruption(utterances=utterances)
write_data(ds, "random_noise.json", tokenizer)
