In [1]:
import tensorflow as tf
import numpy as np

import random

import pickle
import collections

def load_pkl(file_path) :
    """Unpickle and return the single object stored in ``file_path``.

    Security: ``pickle.load`` can execute arbitrary code, so only call this
    on files you produced yourself or otherwise trust.
    """
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)

In [27]:
def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary.

    Each line of ``vocab_file`` is one token; surrounding whitespace is
    stripped and the token is mapped to its 0-based line number. If a token
    appears on several lines, the last line number wins. Reading stops at the
    first empty read (end of file). Lines returned as bytes are decoded as UTF-8.
    """
    vocab = collections.OrderedDict()
    line_number = 0
    with tf.io.gfile.GFile(vocab_file, "r") as reader:
        while True:
            line = reader.readline()
            if not isinstance(line, str):
                line = line.decode('utf8')
            if not line:
                break
            vocab[line.strip()] = line_number
            line_number += 1
    return vocab

In [224]:
def create_training_instances(input_files,
                              tokenizer,
                              max_seq_length,
                              dupe_factor,
                              short_seq_prob,
                              masked_lm_prob,
                              max_predictions_per_seq,
                              rng) :
    """Tokenize pickled documents and turn them into `TrainingInstance`s.

    Each file in ``input_files`` is loaded with `load_pkl`. Every value reached
    through its ``keys()`` is treated as one document: an iterable of texts,
    each passed to ``tokenizer`` (falsy texts are skipped). Documents that end
    up empty are dropped. The whole document list is then processed
    ``dupe_factor`` times with `create_instances_from_document`.

    NOTE(review): the pickle schema (e.g. a dict or DataFrame of sentence
    lists) is inferred from usage only — confirm against the producer.
    NOTE(review): shuffling of documents/instances is currently disabled, so
    the output is ordered by pass, then by document.
    """
    vocab_size = tokenizer.vocab_size

    documents = []
    for input_file in input_files:
        inputs = load_pkl(input_file)
        for key in list(inputs.keys()):
            # One list of token ids per (non-empty) text.
            document = [tokenizer(text)['input_ids'] for text in inputs[key] if text]
            if document:
                documents.append(document)

    instances = []
    for _ in range(dupe_factor):
        for document_index, _document in enumerate(documents):
            instances.extend(
                create_instances_from_document(
                    documents, document_index, max_seq_length, short_seq_prob,
                    masked_lm_prob, max_predictions_per_seq, rng, vocab_size))
    return instances
    
        

In [225]:
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
    """Truncates a pair of sequences to a maximum sequence length.

    Both lists are modified in place and nothing is returned. Each step drops
    one token from the longer list (``tokens_b`` on a tie), taking it from the
    front or the back with equal probability (one ``rng.random()`` draw per
    step), which avoids always biasing toward one end.
    """
    while len(tokens_a) + len(tokens_b) > max_num_tokens:
        longer = tokens_b if len(tokens_b) >= len(tokens_a) else tokens_a
        assert len(longer) >= 1

        if rng.random() < 0.5:
            longer.pop(0)
        else:
            del longer[-1]
            

In [242]:
def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, rng, vocab_size) :
    
    # 15% 이상으론 Masking하지 않겠다.
    # 15%는 반드시 Masking 하겠다.
    max_masked_tokens = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))))
    
    num_tokens = len(tokens)
    
    masked_grams = []
    masked_tokens = [False] * num_tokens
    
    while sum(masked_tokens) < max_masked_tokens : # and sum(len(s) for s in ngrams.values())):
    
        # Choose a random n-gram of the given size.
        idx = random.choices(range(1, num_tokens))[0] # masked_idx
        
        if tokens[idx] in [0, 1, 2, 3, 4] : # [PAD], [SEQ], [CLS], [MASK], [UNK] 마스킹 제외
            continue
        
        # Check if any of the tokens in this gram have already been masked.
        if masked_tokens[idx]:
            continue

        # Found a usable n-gram!  Mark its tokens as masked and add it to return.
        masked_tokens[idx] = True
        masked_grams.append(idx)
        
    #  output_ngrams [token[idx1], token[idx2], token[idx3], ... token[idx_max_predictions_per_seq]]
    
    masked_lms = []
    output_tokens = list(tokens)
    
    for gram_idx in masked_grams :
        
        
        if rng.random() < 0.8 :
            replace_action = lambda idx: 4 # ["MASK"]
        else :
            if rng.random() < 0.5 :
                replace_action = lambda idx : idx
            else :
                replace_action = lambda idx : rng.choice(range(5, vocab_size)) # [PAD], [SEQ], [CLS], [MASK], [UNK] 제외
                
        output_tokens[gram_idx] = replace_action(gram_idx)
        masked_lms.append([gram_idx, tokens[gram_idx]])
        
    assert len(masked_lms) <= max_masked_tokens
    
    masked_lm_positions = []
    masked_lm_labels = []
    
    for p in masked_lms:
        masked_lm_positions.append(p[0])
        masked_lm_labels.append(p[1])
        
    return output_tokens, masked_lm_positions, masked_lm_labels
    

In [256]:

class TrainingInstance(object):
    """A single training instance (sentence pair).

    Holds the token ids of the pair, the per-token segment ids, the
    next-sentence flag, and the masked-LM positions with their original ids.
    """

    def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
                 is_random_next):
        self.tokens = tokens
        self.segment_ids = segment_ids
        self.masked_lm_positions = masked_lm_positions
        self.masked_lm_labels = masked_lm_labels
        self.is_random_next = is_random_next

    def __str__(self):
        def spaced(values):
            return " ".join(map(str, values))

        parts = [
            "tokens: %s\n" % spaced(self.tokens),
            "segment_ids: %s\n" % spaced(self.segment_ids),
            "is_random_next: %s\n" % self.is_random_next,
            "masked_lm_positions: %s\n" % spaced(self.masked_lm_positions),
            "masked_lm_labels: %s\n" % spaced(self.masked_lm_labels),
            "\n",
        ]
        return "".join(parts)

    def __repr__(self):
        return str(self)


In [257]:
def create_instances_from_document(all_documents
                                    , document_index
                                    , max_seq_length
                                    , short_seq_prob
                                    , masked_lm_prob
                                    , max_predictions_per_seq
                                    , rng
                                    , vocab_size) :
    
    current_document = all_documents[document_index]
    tokens = []
    # Sequence 길이 제한
    max_num_tokens = max_seq_length - 3 
    target_seq_length = max_num_tokens
    
    if rng.random() < short_seq_prob:
        target_seq_length = rng.randint(2, max_num_tokens)
    
    instances = []
    current_chunk = []
    current_length = 0
    i = 0
    
    while i < len(current_document) :

        current_statement = current_document[i][1:-1] # [CLS], [SEP] 제거
        
        current_chunk.append(current_statement)
        current_length += len(current_statement)
        
        if i == len(current_document) - 1 or current_length >= target_seq_length:
            
            if current_chunk:
            # `a_end` is how many segments from `current_chunk` go into the `A`
            # (first) sentence.
                a_end = 1
                if len(current_chunk) >= 2:
                    a_end = rng.randint(1, len(current_chunk) - 1)

                tokens_a = []
                for j in range(a_end):
                    tokens_a.extend(current_chunk[j])

        
                is_random_next = False
                tokens_b = []
                
                if (len(current_chunk) == 1) or (rng.random() > 0.5) :

                    is_random_next = True 
                    target_b_length = target_seq_length - len(tokens_a) 
                    
                    r_document_ind = document_index
                    while r_document_ind == document_index :
                        r_document_ind = np.random.randint(0, len(all_documents))

                    r_document = all_documents[r_document_ind]
                    random_start = rng.randint(0, len(r_document) - 1)
                    
                    for j in range(random_start, len(r_document)):
                        tokens_b.extend(r_document[j][1:-1])
                        if len(tokens_b) >= target_b_length:
                            break

                    r_statement_ind = np.random.randint(0, len(r_document))
                    next_statement = r_document[r_statement_ind][1:-1]
                    
                    num_unused_segments = len(current_chunk) - a_end
                    i -= num_unused_segments
                    
                else :

                    is_random_next = False #isNext
                    for j in range(a_end, len(current_chunk)):
                        tokens_b.extend(current_chunk[j])

                truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)

                assert len(tokens_a) >= 1
                assert len(tokens_b) >= 1

                tokens = []
                segment_ids = []
                tokens.append(2)
                segment_ids.append(0)
                
                for token in tokens_a:
                    tokens.append(token)
                    segment_ids.append(0)

                tokens.append(3)
                segment_ids.append(0)

                for token in tokens_b:
                    tokens.append(token)
                    segment_ids.append(1)
                tokens.append(3)
                segment_ids.append(1)
                
                (tokens, masked_lm_positions
                 , masked_lm_labels) = create_masked_lm_predictions(
                    tokens, masked_lm_prob, max_predictions_per_seq, rng
                                    , vocab_size)
                
                instance = TrainingInstance(
                    tokens = tokens,
                    segment_ids = segment_ids,
                    is_random_next = is_random_next,
                    masked_lm_positions = masked_lm_positions,
                    masked_lm_labels = masked_lm_labels)
                
                instances.append(instance)
                
            current_chunk = []
            current_length = 0
        i += 1

        # current_statement
        # next_statement
        # is_random_next
        
    return instances

In [271]:
def create_int_feature(values):
    """Wrap an iterable of integers in a ``tf.train.Feature`` (int64_list)."""
    int64_list = tf.train.Int64List(value=list(values))
    return tf.train.Feature(int64_list=int64_list)


def create_float_feature(values):
    """Wrap an iterable of floats in a ``tf.train.Feature`` (float_list)."""
    float_list = tf.train.FloatList(value=list(values))
    return tf.train.Feature(float_list=float_list)

In [320]:
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
                                    max_predictions_per_seq, output_files,
                                    gzip_compress, use_v2_feature_names):
    """Creates TF example files from `TrainingInstance`s."""
    writers = []
    for output_file in output_files:
        writers.append(
            tf.io.TFRecordWriter(
                output_file, options="GZIP" if gzip_compress else ""))

    writer_index = 0

    total_written = 0
    for (inst_index, instance) in enumerate(instances):
        input_ids = instance.tokens
#         input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
        input_mask = [1] * len(input_ids)
        segment_ids = list(instance.segment_ids)
        assert len(input_ids) <= max_seq_length

        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        masked_lm_positions = list(instance.masked_lm_positions)
        masked_lm_ids = instance.masked_lm_labels
        masked_lm_weights = [1.0] * len(masked_lm_ids)

        while len(masked_lm_positions) < max_predictions_per_seq:
            masked_lm_positions.append(0)
            masked_lm_ids.append(0)
            masked_lm_weights.append(0.0)

        next_sentence_label = 1 if instance.is_random_next else 0

        features = collections.OrderedDict()
        if use_v2_feature_names:
            features["input_word_ids"] = create_int_feature(input_ids)
            features["input_type_ids"] = create_int_feature(segment_ids)
        else:
            features["input_ids"] = create_int_feature(input_ids)
            features["segment_ids"] = create_int_feature(segment_ids)

        features["input_mask"] = create_int_feature(input_mask)
        features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
        features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
        features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
        features["next_sentence_labels"] = create_int_feature([next_sentence_label])

        tf_example = tf.train.Example(features=tf.train.Features(feature=features))

        writers[writer_index].write(tf_example.SerializeToString())
        writer_index = (writer_index + 1) % len(writers)

        total_written += 1

    for writer in writers:
        writer.close()

## TEST