- reference
    - https://github.com/HMJiangGatech/pytorch-pretrained-BERT/blob/master/examples/lm_finetuning/pregenerate_training_data.py
    - https://paul-hyun.github.io/bert-01/

In [67]:
import os, inspect
import sys
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0,parentdir) 
from transformer.transformer_utils import *

In [6]:
# vocab loading
import sentencepiece as spm

VOCAB_PATH = "/home/henry/Documents/wrapper/source"
vocab_file = f"{VOCAB_PATH}/kowiki.model"
vocab = spm.SentencePieceProcessor()
vocab.load(vocab_file)

True

In [36]:
from types import SimpleNamespace
config = dict({
    "n_enc_vocab": len(vocab),
    "n_enc_seq": 256,
    "n_seg_type": 2,
    "n_layer": 6,
    "d_hidn": 256,
    "i_pad": 0,
    "d_ff": 1024,
    "n_head": 4,
    "d_head": 64,
    "dropout": 0.1,
    "layer_norm_epsilon": 1e-12
})
config = SimpleNamespace(**config)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config.device = device

# BERT

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [83]:
class EncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        self.self_attn = MultiHeadAttention(self.config)
        self.layer_norm1 = nn.LayerNorm(self.config.d_hidn, eps=self.config.layer_norm_epsilon)
        self.pos_ffn = PoswiseFeedForwardNet(self.config)
        self.layer_norm2 = nn.LayerNorm(self.config.d_hidn, eps=self.config.layer_norm_epsilon)
        
    def forward(self, inputs, attn_mask):
        attn_outputs, attn_prob = self.self_attn(inputs, inputs, inputs, attn_mask)
        attn_outputs = self.layer_norm1(inputs + attn_outputs)
        
        ffn_outputs = self.pos_ffn(attn_outputs)
        ffn_outputs = self.layer_norm2(ffn_outputs + attn_outputs)
        
        return ffn_outputs, attn_prob

class Encoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        self.enc_emb = nn.Embedding(self.config.n_enc_vocab, self.config.d_hidn)
        self.pos_emb = nn.Embedding(self.config.n_enc_seq+1, self.config.d_hidn)
        self.seg_emb = nn.Embedding(self.config.n_seg_type, self.config.d_hidn)
        
        self.layers = nn.ModuleList([EncoderLayer(self.config) for _ in range(self.config.n_layer)])
        
    def forward(self, inputs, segments):
        positions = torch.arange(inputs.size(1), device=inputs.device, dtype=inputs.dtype)\
            .expand(inputs.size(0), inputs.size(1)).contiguous() + 1
        pos_mask = inputs.eq(self.config.i_pad)
        positions.masked_fill_(pos_mask, 0)
        
        outputs = self.enc_emb(inputs) + self.pos_emb(positions) + self.seg_emb(segments)
        attn_mask = get_attn_pad_mask(inputs, inputs, self.config.i_pad)
        
        attn_probs = []
        for layer in self.layers:
            outputs, attn_prob = layer(outputs, attn_mask)
            attn_probs.append(attn_prob)
        return outputs, attn_probs

In [85]:
class BERT(nn.Module):
    """
    outputs: [bs, len_seq, d_hidn] <- 잘 임베딩된 input seq
    outputs_cls = outputs[:, 0].contiguous(): [bs, d_hidn]
    - classification은 [cls] token의 임베딩만 사용

    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encoder = Encoder(self.config)
        self.linear = nn.Linear(config.d_hidn, config.d_hidn)
        self.activation = torch.tanh
    
    def forward(self, inputs, segments):
        outputs, self_attn_probs = self.encoder(inputs, segments)
        outputs_cls = outputs[:, 0].contiguous()
        outputs_cls = self.linear(outputs_cls)
        outputs_cls = self.activation(outputs_cls)
        return outputs, outputs_cls, self_attn_probs
    
    def save(self, epoch, loss, path):
        torch.save({
            'epoch': epoch,
            'loss': loss,
            'state_dict': self.state_dict()
                   }, path)
        
    def load(self, path):
        save = torch.load(path)
        self.load_state_dict(save['state_dict'])
        return save['epoch'], save['loss']

In [14]:
class BERTpretrain(nn.Module):
    """
    self.feedforward_lm.weight
    - transformer encoder의 pretrained embedding layer weight 사용
    logits_cls: [bs, 2]
    - binary classification
    logits_lm: [bs, len_enc_seq, n_enc_vocab]
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.bert  = BERT(self.config)
        # cls
        self.feedforward_cls = nn.Linear(self.config.d_hidn, 2, bias=False)
        # lm
        self.feedforward_lm = nn.Linear(self.config.d_hidn, self.config.n_enc_vocab, bias=False)
        self.feedforward_lm.weight = self.bert.encoder.enc_emb.weight
    
    def forward(self, inputs, segments):
        outputs, outputs_cls, attn_probs = self.bert(inputs, segments)
        logits_cls = self.feedforward_cls(outputs_cls)
        logits_lm = self.feedforward_lm(outputs)
        return logits_cls, logits_lm, attn_probs

---
# Order of Pretrain func call
- `make_pretrain_data` -> `create_pretrain_instances` -> `create_pretrain_mask`
---
# Masking
- https://github.com/codertimo/BERT-pytorch/blob/master/bert_pytorch/trainer/pretrain.py

In [15]:
def create_pretrain_mask(tokens, mask_cnt, vocab_list):
    """
    masking subwords(15% of entire subwords)
    - mask_cnt: len(subwords) * 0.15
    - [MASK]: 80% of masking candidate token
    - original token: 10% of masking candidate token
    - another token: 10% of masking candidate token
    """
    candidate_idx = []

    ## subwords in the same list augment a sementic word 
    ## eg. [[0], [1], [2], [4, 5]] -> token_idx 4 + 5 is semantic word
    # A list represent a sementic word
    for i, token in enumerate(tokens):
        if token == '[CLS]' or token == '[SEP]':
            continue
        if 0 < len(candidate_idx) and token.find(u'\u2581') < 0: #  LOWER ONE EIGHTH BLOCK
#        if 0 < len(candidate_idx) and token.find('_') < 0: #  test code
            candidate_idx[-1].append(i)
        else:
            candidate_idx.append([i])
    np.random.shuffle(candidate_idx)

    mask_lms = []
    for idx_set in candidate_idx:
        # check if len(mask_lms) exceeds threshold
        if len(mask_lms) >= mask_cnt:
            break
        if len(mask_lms) + len(idx_set) > mask_cnt:
            continue

        ## masking subwords with 15% probability
        ## mask_cnt is len(subwords) * 0.15 
        # iter subwords idx
        for sub_idx in idx_set:
            masked_token = None

            ### assign value to masked token: [MASK], original token, random token
            # 80% of masking candidate are replaced with '[MASK]' token
            if np.random.uniform() < 0.8:
                masked_token = '[MASK]'
            # remainng 20% of masking candidate
            else:
                # 10% of remaining preserve original token
                if np.random.uniform() < 0.5:
                    masked_token = tokens[sub_idx]
                # 10% of ones are replaced with rnadom token    
                else:
                    masked_token = np.random.choice(vocab_list)

                ### replace subword with masked_token value    
                mask_lms.append({'idx': sub_idx, 'label':tokens[sub_idx]})
                tokens[sub_idx] = masked_token
                
    mask_lms = sorted(mask_lms, key=lambda x: x['idx'])
    mask_idx = [mask_dict['idx'] for mask_dict in mask_lms]
    mask_label = [mask_dict['label'] for mask_dict in mask_lms]
#     print(candidate_idx)
#     print(mask_lms)
    print(mask_idx, mask_label)
    return tokens, mask_idx, mask_label

In [16]:
def truncate_token(tokenA, tokenB, max_seq):
    """
    truncate long sequence
    """
    while True:
        total_len = len(tokenA) + len(tokenB)
        print('max token {}\ntotal_len {} = {} + {}'.format(max_seq, total_len, len(tokenA), len(tokenB)))
        if total_len <= max_seq:
            break
        if len(tokenA) > len(tokenB):
            tokenA.pop()
        else:
            tokenB.pop()

### - example code

In [265]:
vocab_list = ['_I am', '_on', '_it', '_so wh', 'at',]
tokens = ['_I am', '_on', '_it', '[CLS]', '_so wh', 'at',]
mask_cnt = int(len(tokens) * 0.15)+1
create_pretrain_mask(tokens, mask_cnt, vocab_list)

(['_I am', '_on', '_it', '[CLS]', '_so wh', 'at'], [], [])

---
# pretrain dataset for each paragraph

In [17]:
def create_pretrain_instances(paragraph_ls, paragraph_idx, paragraph, n_seq, mask_prob, vocab_list):
    """
    create NSP train set
    """
    # 3 special token: [CLS], [SEP] for sent A, [SEP] for sent B
    max_seq_len = n_seq - 2 - 1
    target_seq_len = max_seq_len # [CLS], segmentA, segmentA, ..., [SEP], segmentB, segmentB, ...

    instances = []
    temp_sentence = []
    temp_sent_seq_length = 0 # num of tokens

    max_num_tokens = 256
    target_seq_len = np.random.randint(2, max_num_tokens) # min len of tokens
    for i, sent in enumerate(paragraph):
        ## A. not the last sentence of the paragraph
        temp_sentence.append(sent)
        temp_sent_seq_length += len(sent)

        ## B. check if it is the last sentence of the paragraph
        ## or temp_sent_seq_length is longer than or equal to target_seq_len
        if i == len(paragraph) - 1 or temp_sent_seq_length >= target_seq_len:
            if temp_sentence:
                ## A. sentence A segment: from 0 to a_end
                a_end = 1
                if len(temp_sentence) != 1:
                    a_end = np.random.randint(1, len(temp_sentence))
                # append the sentences to tokenA 
                # from the front to the back
                tokenA = []
                for _, s in enumerate(temp_sentence[:a_end]):
                    tokenA.extend(s)

                ## B. sentence B segment
                tokenB = []
                # A. Actual next
                # is_next will be the label for NSP pretrain
                if len(temp_sentence) > 1 and np.random.uniform() >= 0.5:
                    is_next = True
                    for j in range(a_end, len(temp_sentence)):
                        tokenB.extend(temp_sentence[j])
                # B. random next
                else:
                    is_next = False
                    tokenB_len = target_seq_len - len(tokenA)
                    random_para_idx = para_idx
                    while para_idx == random_para_idx:
                        random_para_idx = np.random.randint(0, len(paragraph_ls))
                    random_para = paragraph[random_para_idx]

                    random_start = np.random.randint(0, len(random_para))
                    for j in range(random_start, len(random_para)):
                        tokenB.extend(random_para[j])

                truncate_token(tokenA, tokenB, max_seq)
                assert 0 < len(tokenA)
                assert 0 < len(tokenB)

                tokens = ["[CLS]"] + tokenA + ["[SEP]"] + tokenB + ["[SEP]"]
                segment = [0]*(len(tokenA)  + 2) + [1]*(len(tokenB) + 1)
                
                tokens, mask_idx, mask_label = \
                    create_pretrain_mask(tokens, int((len(tokens)-3)*mask_prob), vocab_list)
                instance = {
                    'tokens': tokens,
                    'segment': segment,
                    'is_next': is_next,
                    'mask_idx': mask_idx,
                    'mask_label': mask_label
                }

                instances.append(instance)

            # reset segment candidate
            temp_sentence = []
            temp_sent_seq_length = 0
    
    return instances

- example code

In [297]:
para_idx = 1
paragraph = paragraph_ls[para_idx]
print(len(paragraph), paragraph[1])

2 ['▁카', '터', '는', '▁1970', '년대', '▁후반', '▁당시', '▁대한민국', '▁등', '▁인', '권', '▁후', '진', '국의', '▁국민', '들의', '▁인', '권을', '▁지', '키', '기', '▁위해', '▁노력', '했으며', ',', '▁취임', '▁이후', '▁계속', '해서', '▁도', '덕', '정', '치를', '▁내', '세', '웠다', '.']


- i = 0
    - add 1st sequence to meet target sequence length
    - max_len_doc == 8

In [298]:
n_seq = 256
max_seq = n_seq - 3
target_seq_len = max_seq

instances = []
temp_sentences = []
temp_sent_seq_length = 0
print('target seq len:', target_seq_len)

target seq len: 253


In [299]:
# the case of 1st sentence in 2nd paragraph 
para_idx = 1
paragraph = paragraph_ls[para_idx] # 2nd paragraph

temp_sentences = []
i = 0 
temp_sentences.append(paragraph[i]) # 1st sentence
temp_sent_seq_length += len(paragraph[i])

print('current seq length:', temp_sent_seq_length)
print('# of sentences in the paragraph: {}'.format(len(paragraph)))
print('temp_sentences len: {}'.format(len(temp_sentences)), '\n', temp_sentences)

if i == len(paragraph) - 1 or temp_sent_seq_length >= target_seq_len:
    print('Is it the last sentence of the paragraph \nor current chunk longer than target_seq_len?', True)

current seq length: 50
# of sentences in the paragraph: 2
temp_sentences len: 1 
 [['▁그러나', '▁이것은', '▁공', '화', '당', '과', '▁미국의', '▁유대', '인', '▁단', '체의', '▁반', '발', '을', '▁일으', '켰', '다', '.', '▁1979', '년', '▁백', '악', '관', '에서', '▁양', '국', '▁간의', '▁평화', '조', '약', '으로', '▁이끌', '어졌다', '.', '▁또한', '▁소련', '과', '▁제', '2', '차', '▁전략', '▁무', '기', '▁제한', '▁협', '상에', '▁조', '인', '했다', '.']]


- i = 1
    - add one more sentece to segments A to meet target sequence length

In [300]:
i = 1
temp_sentences.append(paragraph[i]) # 1st sentence
temp_sent_seq_length += len(paragraph[i])

print('current seq length:', temp_sent_seq_length)
print('# of sentences in the paragraph: {}'.format(len(paragraph)))
print('temp_sentences len: {}'.format(len(temp_sentences)), '\n', temp_sentences)

if i == len(paragraph) - 1 or temp_sent_seq_length >= target_seq_len:
    print('Is it the last sentence of the paragraph \nor current chunk longer than target_seq_len?', True)

current seq length: 87
# of sentences in the paragraph: 2
temp_sentences len: 2 
 [['▁그러나', '▁이것은', '▁공', '화', '당', '과', '▁미국의', '▁유대', '인', '▁단', '체의', '▁반', '발', '을', '▁일으', '켰', '다', '.', '▁1979', '년', '▁백', '악', '관', '에서', '▁양', '국', '▁간의', '▁평화', '조', '약', '으로', '▁이끌', '어졌다', '.', '▁또한', '▁소련', '과', '▁제', '2', '차', '▁전략', '▁무', '기', '▁제한', '▁협', '상에', '▁조', '인', '했다', '.'], ['▁카', '터', '는', '▁1970', '년대', '▁후반', '▁당시', '▁대한민국', '▁등', '▁인', '권', '▁후', '진', '국의', '▁국민', '들의', '▁인', '권을', '▁지', '키', '기', '▁위해', '▁노력', '했으며', ',', '▁취임', '▁이후', '▁계속', '해서', '▁도', '덕', '정', '치를', '▁내', '세', '웠다', '.']]
Is it the last sentence of the paragraph 
or current chunk longer than target_seq_len? True


- i = 2

In [303]:
# i = 2
# temp_sentences.append(paragraph[i]) # 1st sentence
# temp_sent_seq_length += len(paragraph[i])

# print('current seq length:', temp_sent_seq_length)
# print('# of sentences in the paragraph: {}'.format(len(paragraph)))
# print('temp_sentences len: {}'.format(len(temp_sentences)), '\n', temp_sentences)

# if i == len(paragraph) - 1 or temp_sent_seq_length >= target_seq_len:
#     print('Is it the last sentence of the paragraph \nor current chunk longer than target_seq_len?', True)

In [304]:
# i = 3
# temp_sentences.append(paragraph[i]) # 1st sentence
# temp_sent_seq_length += len(paragraph[i])

# print('current seq length:', temp_sent_seq_length)
# print('# of sentences in the paragraph: {}'.format(len(paragraph)))
# print('temp_sentences len: {}'.format(len(temp_sentences)), '\n', temp_sentences)

# if i == len(paragraph) - 1 or temp_sent_seq_length >= target_seq_len:
#     print('Is it the last sentence of the paragraph \nor current chunk longer than target_seq_len?', True)

- tokenA

In [305]:
if 0 < len(temp_sentences):
    a_end = 1
    if 1 < len(temp_sentences):
        a_end = np.random.randint(1, len(temp_sentences))
        
    tokenA = []
    for j in range(a_end):
        tokenA.extend(temp_sentences[j]) # convert sentence to segment
        
print('tokenA consists of {} sentences'.format(a_end), '\n', tokenA)

tokenA consists of 1 sentences 
 ['▁그러나', '▁이것은', '▁공', '화', '당', '과', '▁미국의', '▁유대', '인', '▁단', '체의', '▁반', '발', '을', '▁일으', '켰', '다', '.', '▁1979', '년', '▁백', '악', '관', '에서', '▁양', '국', '▁간의', '▁평화', '조', '약', '으로', '▁이끌', '어졌다', '.', '▁또한', '▁소련', '과', '▁제', '2', '차', '▁전략', '▁무', '기', '▁제한', '▁협', '상에', '▁조', '인', '했다', '.']


- tokenB random sentence with 1/2 probability

In [311]:
tokenB = []
if len(temp_sentences) == 1 or np.random.uniform() < 0.5:
    is_next = False
    tokenB_len = target_seq_len - len(tokenA) # minumum len of tokenB
    print('tokenB avaliable length: {}'.format(tokenB_len))
    
    # choose sentence from other paragraph
    random_para_idx = para_idx
    while para_idx == random_para_idx:
        random_para_idx = np.random.randint(0, len(paragraph_ls))
    # this is random paragraph    
    random_paragraph = paragraph_ls[random_para_idx]
    
    # add a series of sentences from random paragragraph
    random_start = np.random.randint(0, len(random_paragraph))
    for j in range(random_start, len(random_paragraph)):
        tokenB.extend(random_paragraph[j])
        if len(tokenB) > tokenB_len:
            break
print('tokenB len {} should be longer than {}, \nor truncate segments'.format(len(tokenB), tokenB_len))
print(tokenB[:10])

tokenB avaliable length: 203
tokenB len 198 should be longer than 203, 
or truncate segments
['▁196', '2', '년', '▁조지', '아', '▁주', '▁상', '원', '▁의원', '▁선거']


- tokenB the sentences right after tokenA

In [312]:
is_next = 1
for j in range(a_end, len(temp_sentences)):
    tokenB.extend(temp_sentences[j])
print('the end of the tokenA: \n{}'.format(tokenA[-10:]))
print('the beggining of the tokenB: \n{}'.format(tokenB[:10]))

the end of the tokenA: 
['▁전략', '▁무', '기', '▁제한', '▁협', '상에', '▁조', '인', '했다', '.']
the beggining of the tokenB: 
['▁196', '2', '년', '▁조지', '아', '▁주', '▁상', '원', '▁의원', '▁선거']


In [314]:
tokens = ["[CLS]"] + tokenA + ["[SEP]"] + tokenB + ["[SEP]"]
segment = [0]*(len(tokenA)  + 2) + [1]*(len(tokenB) + 1)
print(tokens)

['[CLS]', '▁그러나', '▁이것은', '▁공', '화', '당', '과', '▁미국의', '▁유대', '인', '▁단', '체의', '▁반', '발', '을', '▁일으', '켰', '다', '.', '▁1979', '년', '▁백', '악', '관', '에서', '▁양', '국', '▁간의', '▁평화', '조', '약', '으로', '▁이끌', '어졌다', '.', '▁또한', '▁소련', '과', '▁제', '2', '차', '▁전략', '▁무', '기', '▁제한', '▁협', '상에', '▁조', '인', '했다', '.', '[SEP]', '▁196', '2', '년', '▁조지', '아', '▁주', '▁상', '원', '▁의원', '▁선거', '에서', '▁낙', '선', '하나', '▁그', '▁선거', '가', '▁부정', '선거', '▁', '였', '음을', '▁입', '증', '하게', '▁되어', '▁당선', '되고', ',', '▁196', '6', '년', '▁조지', '아', '▁주', '▁지', '사', '▁선거', '에', '▁낙', '선', '하지만', '▁1970', '년', '▁조지', '아', '▁주', '▁지', '사를', '▁역임', '했다', '.', '▁대통령', '이', '▁되', '기', '▁전', '▁조지', '아', '주', '▁상', '원의', '원을', '▁두', '번', '▁연', '임', '했으며', ',', '▁1971', '년부터', '▁1975', '년까지', '▁조지', '아', '▁지', '사로', '▁근무', '했다', '.', '▁조지', '아', '▁주', '지', '사로', '▁지', '내', '면서', ',', '▁미국', '에', '▁사는', '▁흑', '인', '▁등', '용', '법을', '▁내', '세', '웠다', '.', '▁1976', '년', '▁대통령', '▁선거', '에', '▁민주', '당', '▁후보', '로', '▁출', '마', '하여', '▁도'

In [318]:
tokens, mask_idx, mask_label = \
    create_pretrain_mask(tokens, int((len(tokens)-3)*mask_prob), ['test_mask_voc1', 'test_mask_voc2'])

In [330]:
for label, idx in zip(mask_label, mask_idx):
    print(label, idx)

▁미국의 7
인 9
test_mask_voc1 37
▁제한 44
▁협 45
상에 46
▁조 47
인 48
▁196 52
▁의원 60
▁선거 61
▁ 71
▁입 74
아 85
년 95
이 105
▁지 127
▁근무 129
내 138
▁내 149
. 152
▁1976 153
년 154
test_mask_voc2 159
로 161
워 172
지 186
을 188
대로 196
는 203
▁이스라엘 207
test_mask_voc2 215
다 222
트 223
▁대통령 224
동 235
▁1970 253
test_mask_voc2 272
, 274
해서 278
test_mask_voc2 285


- truncate long segmentsA, B

In [84]:
truncate_token(tokenA, tokenB, max_seq)

max token 253
total_len 253 = 127 + 126


In [104]:
assert 0 < len(tokenA)
assert 0 < len(tokenB)
tokens = ['[CLS]'] + tokenA + ['[SEP]'] + tokenB + ['[SEP]']
segment = [0]*(len(tokenA)  + 2) + [1]*(len(tokenB) + 1)

In [111]:
print(segment)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


---
# make pre train data

In [18]:
import os

def make_pretrain_data(vocab, in_file, out_file, count, n_seq, mask_prob):
    """
    read text and return train data set format
    """
    vocab_list = []
    for id_ in range(vocab.get_piece_size()):
        if not vocab.is_unknown(id_):
            vocab_list.append(vocab.id_to_piece(id_))
    
    paragraph_ls = []
    with open(in_file, 'r') as in_f:
        paragraph = []
        for i, sent in enumerate(in_f):
            sent = sent.strip()
            
            ## blank means end of the paragraph
            if sent == '':
                # if not the beggining of the paragraph
                # it is the end of the paragraph
                if 0 < len(paragraph):
                    paragraph_ls.append(paragraph)
                    paragraph = [] # generate new paragraph list
                    # check if exceeding 100 thaousand paragraphs
                    if 1e+5 < len(paragraph_ls): 
                        break 
                        
            ## subwords in list is part of semantic token
            # eg. ['▁지','미','▁카','터']
            else:
                pieces = vocab.encode_as_pieces(sent)
                if 0 < len(pieces):
                    paragraph.append(pieces)
        if paragraph:
            paragraph_ls.append(paragraph)
    # masking def: create_pretrain_mask
    for index in range(count):
        output = out_file.format(index)
#         if os.path.isfile(output):
#             continue
        with open(output, 'w') as out_f:
            for i, paragraph in enumerate(paragraph_ls):
                masking_info = create_pretrain_instances(paragraph_ls, i, paragraph, n_seq, mask_prob, vocab_list)
                for elem in masking_info:
                    out_f.write(json.dumps(elem))
                    out_f.write('\n')    

- example code

In [270]:
# test data
PATH = '/home/henry/Documents/wrapper/source/
#in_PATH = PATH + 'kowiki.txt'
in_PATH = PATH + 'kowiki_sample.txt' # sample text
sentences = []
with open(in_PATH, 'r') as in_f:
    for i, sent in enumerate(in_f):
        # '': paragraph delimiter
        if i == 7:
            sentences.append('')
        if i == 13:
            break
        sentences.append(sent)

In [133]:
len(sentences), sentences[:4]

(10,
 ['지미 카터\n',
  '제임스 얼 "지미" 카터 주니어(, 1924년 10월 1일 ~ )는 민주당 출신 미국 39번째 대통령 (1977년 ~ 1981년)이다.\n',
  '지미 카터는 조지아주 섬터 카운티 플레인스 마을에서 태어났다. 조지아 공과대학교를 졸업하였다. 그 후 해군에 들어가 전함·원자력·잠수함의 승무원으로 일하였다. 1953년 미국 해군 대위로 예편하였고 이후 땅콩·면화 등을 가꿔 많은 돈을 벌었다. 그의 별명이 "땅콩 농부" (Peanut Farmer)로 알려졌다.\n',
  '1962년 조지아 주 상원 의원 선거에서 낙선하나 그 선거가 부정선거 였음을 입증하게 되어 당선되고, 1966년 조지아 주 지사 선거에 낙선하지만 1970년 조지아 주 지사를 역임했다. 대통령이 되기 전 조지아주 상원의원을 두번 연임했으며, 1971년부터 1975년까지 조지아 지사로 근무했다. 조지아 주지사로 지내면서, 미국에 사는 흑인 등용법을 내세웠다.\n'])

In [271]:
paragraph_ls = []
paragraph = []
for i, sent in enumerate(sentences):
    sent = sent.strip()
    ## blank means end of paragraph
    if sent == '':
        # if not the beggining of sentence
        # it is the end of the paragraph
        # generate new doc list
        if 0 < len(paragraph):
            paragraph_ls.append(paragraph)
            paragraph = []
            # check if 100thaousand paragraph
            if 1e+5 < len(paragraph_ls):
                break
                
    ## subwords in list is part of semantic token
    # eg. ['▁지','미','▁카','터']
    else:
        pieces = vocab.encode_as_pieces(sent)
        if 0 < len(pieces):
            paragraph.append(pieces)
if paragraph:
    paragraph_ls.append(paragraph)

In [272]:
print('num of paragraph:', len(paragraph_ls))

num of paragraph: 2


In [273]:
print('1st paragraph - 1st sentence')
print(paragraph_ls[0][0])

1st paragraph - 1st sentence
['▁지', '미', '▁카', '터']


In [274]:
print('1st paragraph - 2nd sentence')
print(paragraph_ls[1][0])

1st paragraph - 2nd sentence
['▁그러나', '▁이것은', '▁공', '화', '당', '과', '▁미국의', '▁유대', '인', '▁단', '체의', '▁반', '발', '을', '▁일으', '켰', '다', '.', '▁1979', '년', '▁백', '악', '관', '에서', '▁양', '국', '▁간의', '▁평화', '조', '약', '으로', '▁이끌', '어졌다', '.', '▁또한', '▁소련', '과', '▁제', '2', '차', '▁전략', '▁무', '기', '▁제한', '▁협', '상에', '▁조', '인', '했다', '.']


In [275]:
print('2nd paragraph - 1st sentence')
print(paragraph_ls[1][1])

2nd paragraph - 1st sentence
['▁카', '터', '는', '▁1970', '년대', '▁후반', '▁당시', '▁대한민국', '▁등', '▁인', '권', '▁후', '진', '국의', '▁국민', '들의', '▁인', '권을', '▁지', '키', '기', '▁위해', '▁노력', '했으며', ',', '▁취임', '▁이후', '▁계속', '해서', '▁도', '덕', '정', '치를', '▁내', '세', '웠다', '.']


In [349]:
in_PATH = '/home/henry/Documents/wrapper/source/kowiki_sample.txt' 
out_PATH = '/home/henry/Documents/wrapper/source/out_kowiki_sample' + '_{}.json'

count = 1
n_seq = 200
mask_prob = 0.15

make_pretrain_data(vocab, in_PATH, out_PATH, count, n_seq, mask_prob)

max token 253
total_len 87 = 50 + 37
[6, 15, 19, 26, 27, 30, 35, 43, 44, 47, 68, 69, 77] ['과', '▁일으', '▁1979', '국', '▁간의', '약', '▁또한', '기', '▁제한', '▁조', '▁인', '권을', '▁취임']


In [350]:
# result
pretrain_output = []
with open('/home/henry/Documents/wrapper/source/out_kowiki_sample_0.json', 'r') as f:
    for line in f:
        temp = json.loads(line)
        pretrain_output.append(temp)

In [351]:
len(pretrain_output), pretrain_output[0].keys()

(1, dict_keys(['tokens', 'segment', 'is_next', 'mask_idx', 'mask_label']))

In [352]:
for k, v in pretrain_output[0].items():
    print(k)
    print(v)
    print()

tokens
['[CLS]', '▁그러나', '▁이것은', '▁공', '화', '당', '과', '▁미국의', '▁유대', '인', '▁단', '체의', '▁반', '발', '을', '▁일으', '켰', '다', '.', '▁사랑', '년', '▁백', '악', '관', '에서', '▁양', '국', '▁간의', '▁평화', '조', '약', '으로', '▁이끌', '어졌다', '.', '덫', '▁소련', '과', '▁제', '2', '차', '▁전략', '▁무', 'ム', '▁완전히', '▁협', '상에', '攻', '인', '했다', '.', '[SEP]', '▁카', '터', '는', '▁1970', '년대', '▁후반', '▁당시', '▁대한민국', '▁등', '▁인', '권', '▁후', '진', '국의', '▁국민', '들의', '▁인', '권을', '▁지', '키', '기', '▁위해', '▁노력', '했으며', ',', '▁반면', '▁이후', '▁계속', '해서', '▁도', '덕', '정', '치를', '▁내', '세', '웠다', '.', '[SEP]']

segment
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

is_next
True

mask_idx
[6, 15, 19, 26, 27, 30, 35, 43, 44, 47, 68, 69, 77]

mask_label
['과', '▁일으', '▁1979', '국', '▁간의', '약', '▁또한', '기', '▁제한', '▁조', '▁인', '권을', '▁

---
# Pretrain DataSet

In [None]:
class PretrainDataset(Dataset):
    """
    eg. instance
    {tokens:
        ['[CLS]', '▁지', ', '대학교', '를', '▁졸업', '하였다', '.', '▁그', '▁후', ...],
    segment:
        [0, 0, 0, 0, 0, 0, ..., 1, 1, 1],
    is_next: True,
    mask_idx: 
        [16, 21, ..., 41],
    mask_label:
        ['▁192', '▁1', '일', '▁~', '는', ..., '▁조지', '법을']}
    """
    def __init__(self, vocab, infile):
        self.vocab = vocab
        self.labels_cls = []
        self.label_lm_ls = []
        self.sentence_ls = []
        self.segments = []
        
        with open(infile, 'r') as f:
            for i, line in enumerate(f):
                instance = json.loads(line)
                self.labels_cls.append(instance['is_next'])
                sentence = [vocab.piece_to_id(p) for p in instance['tokens']]
                
                self.sentence_ls.append(sentence)
                self.segments.append(instance['segment'])
                
                mask_idx = np.array(instance['mask_idx'], dtype=np.int)
                mask_label = np.array([vocab.piece_to_id(p) for p in instance['mask_label']], dtype=np.int)
                label_lm = np.full(len(sentence), dtype=np.int, fill_value=-1)
                label_lm[mask_idx] = mask_label
                self.label_lm_ls.append(label_lm)
    
    def __len__(self):
        assert len(self.labels_cls) == len(self.label_lm_ls)
        assert len(self.labels_cls) == len(self.sentence_ls)
        assert len(self.labels_cls) == len(self.segments)
        return len(self.labels_cls)
    
    def __getitem__(self, idx):
        return (torch.tensor(self.labels_cls[idx]),
                torch.tensor(self.label_lm_ls[idx]),
                torch.tensor(self.sentence_ls[idx]),
                torch.tensor(self.segments[idx]),)

### - example code

In [370]:
vocab = vocab
labels_cls = []; labels_lm = []; sentences = []; segments = [];

with open(infile, 'r') as f:
    for i, line in enumerate(f):
        instance = json.loads(line)
        labels_cls.append(instance['is_next'])
        
        sentence = [vocab.piece_to_id(p) for p in instance['tokens']]
        sentences.append(sentence)
        
        segments.append(instance['segment'])
        mask_idx = np.array(instance['mask_idx'], dtype=np.int)
        mask_label = np.array([vocab.piece_to_id(p) for p in instance['mask_label']], dtype=np.int)
        
        label_lm = np.full(len(sentence), dtype=np.int, fill_value=-1)
        label_lm[mask_idx] = mask_label
        labels_lm.append(label_lm)

In [382]:
print('this is mask label: {}'.format([vocab.piece_to_id(p) for p in instance['mask_label']]),\
      '\nthis is idx of mask label in seq: {}'.format(instance['mask_idx']),\
     '\nthis is sentence with voc_idx: {}'.format([vocab.piece_to_id(p) for p in instance["tokens"]]))

this is mask label: [3635, 1226, 2962, 3634, 2712, 3818, 274, 3605, 1983, 53, 44, 904, 2612] 
this is idx of mask label in seq: [6, 15, 19, 26, 27, 30, 35, 43, 44, 47, 68, 69, 77] 
this is sentence with voc_idx: [5, 322, 1470, 41, 3676, 3718, 3635, 668, 2652, 3619, 167, 1321, 142, 3710, 3598, 1226, 4162, 3589, 3590, 1421, 3616, 456, 3918, 3698, 10, 224, 3634, 2712, 2791, 3667, 3818, 9, 1436, 2518, 3590, 6864, 1271, 3635, 30, 3610, 3741, 2918, 108, 6771, 2056, 623, 1790, 7079, 3619, 31, 3590, 4, 210, 3705, 3593, 1908, 592, 1808, 312, 408, 50, 44, 3821, 82, 3704, 134, 967, 247, 44, 904, 18, 3784, 3605, 233, 3366, 528, 3595, 2897, 165, 776, 869, 74, 4078, 3633, 1232, 115, 3682, 1844, 3590, 4]


In [383]:
infile = '/home/henry/Documents/wrapper/source/out_kowiki_sample_0.json'
dataset = PretrainDataset(vocab, infile)
sample_dataset = next(iter(dataset))
sample_dataset

(tensor(True),
 tensor([  -1,   -1,   -1,   -1,   -1,   -1, 3635,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1, 1226,   -1,   -1,   -1, 2962,   -1,   -1,   -1,   -1,
           -1,   -1, 3634, 2712,   -1,   -1, 3818,   -1,   -1,   -1,   -1,  274,
           -1,   -1,   -1,   -1,   -1,   -1,   -1, 3605, 1983,   -1,   -1,   53,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   44,  904,   -1,   -1,
           -1,   -1,   -1,   -1,   -1, 2612,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1]),
 tensor([   5,  322, 1470,   41, 3676, 3718, 3635,  668, 2652, 3619,  167, 1321,
          142, 3710, 3598, 1226, 4162, 3589, 3590, 1421, 3616,  456, 3918, 3698,
           10,  224, 3634, 2712, 2791, 3667, 3818,    9, 1436, 2518, 3590, 6864,
         1271, 3635,   30, 3610, 3741, 2918,  108, 6771, 2056,  623, 1790, 7079,
         3619,   31, 3590,    4,  210, 3705, 35

---
# Pretrain

In [53]:
def pretrain_collate_fn(inputs):
    """
    padding batch
    """
    labels_cls, labels_lm, inputs, segments = list(zip(*inputs))
    labels_lm = torch.nn.utils.rnn.pad_sequence(labels_lm, batch_first=True, padding_value=-1)
    inputs = torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True, padding_value=0)
    segments = torch.nn.utils.rnn.pad_sequence(segments, batch_first=True, padding_value=0)
    
    batch = [
        torch.stack(labels_cls, dim=0),
        labels_lm,
        inputs,
        segments,
    ]
    return batch

In [54]:
batch_size = 128
dataset = PretrainDataset(vocab, PATH+'kowiki_bert_0.json')
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size\
                                           , shuffle=True, collate_fn=pretrain_collate_fn)

In [49]:
def train_epoch(config, epoch, model, criterion_lm, criterion_cls, optimizer, train_loader):
    loss_ls = []
    model.train()
    print('model train')
    for i, value in enumerate(train_loader):
        labels_cls, labels_lm, inputs, segments = map(lambda x: x.to(config.device), value)
        
        optimizer.zero_grad()
        outputs = model(inputs, segments)
        logits_cls, logits_lm = outputs[0], outputs[1]
        
        loss_cls = criterion_cls(logits_cls, labels_cls)
        loss_lm = criterion_lm(logits_lm.view(-1, logits_lm.size(2)), labels_lm.view(-1))
        loss = loss_cls + loss_lm
        
        loss_val = loss_lm.item()
        loss_ls.append(loss_val)
        
        loss.backward()
        optimizer.step()
    
    return np.mean(loss_ls)

In [45]:
config.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(config)

learning_rate = 5e-5
n_epoch = 10

namespace(d_ff=1024, d_head=64, d_hidn=256, device=device(type='cpu'), dropout=0.1, i_pad=0, layer_norm_epsilon=1e-12, n_enc_seq=256, n_enc_vocab=8007, n_head=4, n_layer=6, n_seg_type=2)


In [None]:
model = BERTpretrain(config)
save_pretrain = PATH + 'bert_pretrain_weights.pkl'
best_epoch, best_loss = 0, 0
if os.path.isfile(save_pretrain):
    best_epoch, best_loss = model.bert.load(save_pretrain)
    best_epoch += 1

model.to(config.device)

criterion_lm = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='mean')
criterion_cls = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

loss_ls = []
offset = best_epoch
for step in range(n_epoch):
    epoch = step + offset
#     if 0 < step:
#         del train_loader
#         dataset = PretrainDataset(vocab, PATH + 'kowiki_bert_0.json')
#         train_loader = DataLoader(dataset, batch_size=batch_size, \
#                                   suffle=True, collate_fn=pretrain_collate_fn)
    loss = train_epoch(config, epoch, model, criterion_lm, criterion_cls,\
                      optimizer, train_loader)
    loss_ls.append(loss)
    model.bert.save(epoch, loss, save_pretrain)

model train
