- reference
    - https://github.com/HMJiangGatech/pytorch-pretrained-BERT/blob/master/examples/lm_finetuning/pregenerate_training_data.py
    - https://paul-hyun.github.io/bert-01/

In [2]:
# vocab loading
import sentencepiece as spm

VOCAB_PATH = "/home/henry/Documents/wrapper/source"
vocab_file = f"{VOCAB_PATH}/kowiki.model"
vocab = spm.SentencePieceProcessor()
vocab.load(vocab_file)

True

In [2]:
from types import SimpleNamespace
import torch.nn as nn
import torch.nn.functional as F
config = dict({
    "n_enc_vocab": len(vocab),
    "n_enc_seq": 256,
    "n_seg_type": 2,
    "n_layer": 6,
    "d_hidn": 256,
    "i_pad": 0,
    "d_ff": 1024,
    "n_head": 4,
    "d_head": 64,
    "dropout": 0.1,
    "layer_norm_epsilon": 1e-12
})
config = SimpleNamespace(**config)

In [11]:
print(config)

namespace(d_ff=1024, d_head=64, d_hidn=256, dropout=0.1, i_pad=0, layer_norm_epsilon=1e-12, n_enc_seq=256, n_enc_vocab=8007, n_head=4, n_layer=6, n_seg_type=2)


# BERT

In [12]:
class BERT(nn.Module):
    """
    outputs: [bs, len_seq, d_hidn] <- 잘 임베딩된 input seq
    ointput_cls = outputs[:, 0].contiguous(): [bs, d_hidn]
    - classification은 [cls] token의 임베딩만 사용

    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encoder = Encoder(self.config)
        self.linear = nn.Linear(config.d_hidn, config.d_hidn)
        self.activation = torch.tanh
    
    def forward(slef, inputs, segment):
        outputs, self_attn_probs = self.encoder(inputs, segments)
        ointput_cls = outputs[:, 0].contiguous()
        outputs_cls = self.linear(outputs_cls)
        outputs_cls = self.activation(outputs_cls)
        return outputs, outputs_cls, self_attn_probs
    
    def save(self, epoch, loss, path):
        torch.save({
            'epoch': epoch,
            'loss': loss,
            'state_dict': self.state_dict()
                   }, path)
        
    def load(self, path):
        save = torch.load(path)
        self.load_state_dict(save['state_dict'])
        return save['epoch'], save['loss']

In [13]:
class BERTpretrain(nn.Module):
    """
    self.feedforward_lm.weight
    - transformer encoder의 pretrained embedding layer weight 사용
    logits_cls: [bs, 2]
    - binary classification
    logits_lm: [bs, len_enc_seq, n_enc_vocab]
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.bert  = BERT(self.config)
        # cls
        self.feedforward_cls = nn.Linear(self.config.d_hidn, 2, bias=False)
        # lm
        self.feedforward_lm = nn.Linear(self.config.d_hidn, self.config.n_enc_vocab, bias=False)
        self.feedforward_lm.weight = self.bert.encoder.enc_emb.weight
    
    def forward(self, inputs, segments):
        outputs, outputs_cls, attn_probs = self.bert(inputs, segments)
        logits_cls = self.feedforward_cls(outputs)
        logits_lm = self.feedforward_lm(outputs)
        return logits_cls, logits_lm, attn_probs

# Masking
- https://github.com/codertimo/BERT-pytorch/blob/master/bert_pytorch/trainer/pretrain.py

In [325]:
def create_pretrain_mask(tokens, mask_cnt, vocab_list):
    """
    masking subwords(15% of entire subwords)
    - mask_cnt: len(subwords) * 0.15
    - [MASK]: 80% of masking candidate token
    - original token: 10% of masking candidate token
    - another token: 10% of masking candidate token
    """
    candidate_idx = []

    ## subwords in the same list augment a sementic word 
    ## eg. [[0], [1], [2], [4, 5]] -> token_idx 4 + 5 is semantic word
    # A list represent a sementic word
    for i, token in enumerate(tokens):
        if token == '[CLS]' or token == '[SEP]':
            continue
        if 0 < len(candidate_idx) and token.find(u'\u2581') < 0: #  LOWER ONE EIGHTH BLOCK
#        if 0 < len(candidate_idx) and token.find('_') < 0: #  test code
            candidate_idx[-1].append(i)
        else:
            candidate_idx.append([i])
    np.random.shuffle(candidate_idx)

    mask_lms = []
    for idx_set in candidate_idx:
        # check if len(mask_lms) exceeds threshold
        if len(mask_lms) >= mask_cnt:
            break
        if len(mask_lms) + len(idx_set) > mask_cnt:
            continue

        ## masking subwords with 15% probability
        ## mask_cnt is len(subwords) * 0.15 
        # iter subwords idx
        for sub_idx in idx_set:
            masked_token = None

            ### assign value to masked token: [MASK], original token, random token
            # 80% of masking candidate are replaced with '[MASK]' token
            if np.random.uniform() < 0.8:
                masked_token = '[MASK]'
            # remainng 20% of masking candidate
            else:
                # 10% of remaining preserve original token
                if np.random.uniform() < 0.5:
                    masked_token = tokens[sub_idx]
                # 10% of ones are replaced with rnadom token    
                else:
                    masked_token = np.random.choice(vocab_list)

                ### replace subword with masked_token value    
                mask_lms.append({'idx': sub_idx, 'label':tokens[sub_idx]})
                tokens[sub_idx] = masked_token
                
    mask_lms = sorted(mask_lms, key=lambda x: x['idx'])
    mask_idx = [mask_dict['idx'] for mask_dict in mask_lms]
    mask_label = [mask_dict['label'] for mask_dict in mask_lms]
#     print(candidate_idx)
#     print(mask_lms)
    return tokens, mask_idx, mask_label

In [None]:
def truncate_token(tokenA, tokenB, max_seq):
    """
    truncate long sequence
    """
    while True:
        total_len = len(tokenA) + len(tokenB)
        if total_len <= max_seq:
            break
        if len(tokenA) > len(tokenB):
            tokenA.pop()
        else:
            tokenB.pop()

- example code

In [324]:
vocab_list = ['_I am', '_on', '_it', '_so wh', 'at',]
tokens = ['_I am', '_on', '_it', '[CLS]', '_so wh', 'at',]
mask_cnt = int(len(tokens) * 0.15)+1
create_pretrain_mask(tokens, mask_cnt, vocab_list)

[[0], [2], [4, 5], [1]]
[{'idx': 1, 'label': '_on'}]


(['_I am', '_on', '_it', '[CLS]', '_so wh', 'at'], [1], ['_on'])

---
# load data and make pre train data

In [514]:
def make_pretrain_data(vocab, in_file, out_file, count, n_seq, mask_prob):
    """
    read text and return train data set format
    """
    vocab_list = []
    for id_ in range(vocab.get_piece_size()):
        if not vocab.is_unknown(id_):
            vocab_list.append(vocab.id_to_piece(id_))
    
    sent_cnt = 0
    with open(in_file, 'r') as in_f:
        for sent in in_f:
            sent_cnt += 1
    
    paragraph_ls = []
    with open(in_file, 'r') as in_f:
        paragraph = []
        for i, sent in enumerate(in_f):
            sent = sent.strip()
            
            ## blank means end of the paragraph
            if sent == '':
                # if not the beggining of the paragraph
                # it is the end of the paragraph
                if 0 < len(paragraph):
                    paragraph_ls.append(paragraph)
                    paragraph = [] # generate new paragraph list
                    # check if exceeding 100 thaousand paragraphs
                    if 1e+5 < len(paragraph_ls): 
                        break 
                        
            ## subwords in list is part of semantic token
            # eg. ['▁지','미','▁카','터']
            else:
                pieces = vocab.encode_as_pieces(sent)
                if 0 < len(pieces):
                    paragraph.append(pieces)
        if paragraph:
            paragraph_ls.append(paragraph)
    
    # masking def: create_pretrain_mask
    for index in range(count):
        output = out_file.format(index)
        if os.path.isfile(output):
            continue
        with open(output, 'w') as out_f:
            for i, paragraph in enumerate(paragraph_ls):
                masking_info = create_pretrain_instances(paragraph_ls, i, paragraph, n_seq, mask_prob, vocab_list)
                for elem in masking_info:
                    out_f.write(json.dumps(elem))
                    out_f.write('\n')    

- example code

In [21]:
# test data
in_PATH = '/home/henry/Documents/wrapper/source/kowiki.txt'
with open(in_PATH, 'r') as in_f:
        for i, sent in enumerate(in_f):
            # '': paragraph delimiter
            if i == 7:
                sentences.append('')
            if i == 13:
                break
            sentences.append(sent)

In [23]:
len(sentences), sentences[:4]

(24,
 ['지미 카터',
  '',
  '제임스 얼 "지미" 카터 주니어(, 1924년 10월 1일 ~ )는 민주당 출신 미국 39번째 대통령 (1977년 ~ 1981년)이다.',
  '지미 카터는 조지아주 섬터 카운티 플레인스 마을에서 태어났다. 조지아 공과대학교를 졸업하였다. 그 후 해군에 들어가 전함·원자력·잠수함의 승무원으로 일하였다. 1953년 미국 해군 대위로 예편하였고 이후 땅콩·면화 등을 가꿔 많은 돈을 벌었다. 그의 별명이 "땅콩 농부" (Peanut Farmer)로 알려졌다.'])

In [24]:
paragraph_ls = []
paragraph = []
for i, sent in enumerate(sentences):
    sent = sent.strip()
    ## blank means end of paragraph
    if sent == '':
        # if not the beggining of sentence
        # it is the end of the paragraph
        # generate new doc list
        if 0 < len(paragraph):
            paragraph_ls.append(paragraph)
            paragraph = []
            # check if 100thaousand paragraph
            if 1e+5 < len(paragraph_ls):
                break
                
    ## subwords in list is part of semantic token
    # eg. ['▁지','미','▁카','터']
    else:
        pieces = vocab.encode_as_pieces(sent)
        if 0 < len(pieces):
            paragraph.append(pieces)
if paragraph:
    paragraph_ls.append(paragraph)

In [25]:
print('num of paragraph:', len(paragraph_ls))

num of paragraph: 3


In [425]:
print('1st paragraph - 1st sentence')
print(paragraph_ls[0][0])

1st paragraph - 1st sentence
['▁지', '미', '▁카', '터']


In [426]:
print('1st paragraph - 2nd sentence')
print(paragraph_ls[1][0])

1st paragraph - 2nd sentence
['▁제임스', '▁얼', '▁"', '지', '미', '"', '▁카', '터', '▁주', '니어', '(,', '▁192', '4', '년', '▁10', '월', '▁1', '일', '▁~', '▁)', '는', '▁민주', '당', '▁출신', '▁미국', '▁3', '9', '번째', '▁대통령', '▁(19', '7', '7', '년', '▁~', '▁1981', '년', ')', '이다', '.']


In [427]:
print('2nd paragraph - 1st sentence')
print(paragraph_ls[1][1])

2nd paragraph - 1st sentence
['▁지', '미', '▁카', '터', '는', '▁조지', '아', '주', '▁섬', '터', '▁카', '운', '티', '▁플', '레', '인', '스', '▁마을', '에서', '▁태어났다', '.', '▁조지', '아', '▁공', '과', '대학교', '를', '▁졸업', '하였다', '.', '▁그', '▁후', '▁해', '군에', '▁들어가', '▁전', '함', '·', '원', '자', '력', '·', '잠', '수', '함', '의', '▁승', '무', '원으로', '▁일', '하였다', '.', '▁195', '3', '년', '▁미국', '▁해군', '▁대', '위로', '▁예', '편', '하였고', '▁이후', '▁땅', '콩', '·', '면', '화', '▁등을', '▁가', '꿔', '▁많은', '▁돈', '을', '▁벌', '었다', '.', '▁그의', '▁별', '명이', '▁"', '땅', '콩', '▁농', '부', '"', '▁(', 'P', 'e', 'an', 'ut', '▁F', 'ar', 'm', 'er', ')', '로', '▁알려', '졌다', '.']


In [541]:
in_PATH = '/home/henry/Documents/wrapper/source/kowiki.txt'
out_PATH = '/home/henry/Documents/wrapper/source/out_kowiki.txt'

count = 1
n_seq = 256
mask_prob = 0.15

make_pretrain_data(vocab, in_PATH, out_PATH, count, n_seq, mask_prob)

---
# pretrain dataset for each paragraph

In [None]:
def create_pretrain_instances(paragraph_ls, paragraph_idx, paragraph, n_seq, mask_prob, vocab_list):
    """
    create NSP train set
    """
    # 3 special token: [CLS], [SEP] for sent A, [SEP] for sent B
    max_seq_len = n_seq - 2 - 1
    target_seq_len = max_seq_len # [CLS], segA, segA, ..., [SEP], segB, ...

    instances = []
    current_chunk = []
    current_length = 0
    for i, sent in enumerate(paragraph):
        ## A. not the last sentence of the paragraph
        current_chunk.append(sent)
        current_length += len(sent)

        ## B. check if the last sentence of the paragraph
        ## or current_length is longer than or equal to target_seq_len
        if i == len(paragraph) - 1 or current_length >= target_seq_len:

            if current_chunk:
                ## A. sentence A segment: from 0 to a_end
                a_end = 1 # 1st sentence
                if len(current_chunk) != 1:
                    a_end = np.random.randint(1, len(current_chunk))
                # append sentence of current_chunk 
                # from the front to the back
                tokenA = []
                for _, s in enumerate(current_chunk[:a_end]):
                    tokenA.extend(s)

                ## B. sentence B segment
                tokenB = []
                # A. Actual next
                if len(current_chunk) > 1 and np.random.uniform() > 0.5:
                    is_next = True
                    for j in range(a_end, len(current_chunk)):
                        tokenB.extend(current_chunk[j])

                # B. random next
                else:
                    is_next = False
                    tokenB_len = target_seq_len - len(tokenA)
                    random_para_idx = para_idx
                    while para_idx == random_para_idx:
                        random_para_idx = np.random.randint(0, len(paragraph_ls))
                    random_para = paragraph[random_para_idx]

                    random_start = np.random.randint(0, len(random_para))
                    for j in range(random_start, len(random_doc)):
                        tokenB.extend(random_para[j])


                truncate_token(tokenA, tokenB, max_seq_len)
                assert 0 < len(tokenA)
                assert 0 < len(tokenB)

                token = ['[CLS]'] + tokenA + ['[SEP]'] + tokenB +['[SEP]']
                segment = [0]*(len(tokenA)+2) + [1]*(len(tokenB)+1)

                create_pretrain_mask(tokens, int((len(tokens)-3)*mask_prob), vocab_list)
                instance = {
                    "tokens": tokens,
                    "segment": segment,
                    "is_next": is_next,
                    "mask_idx": mask_idx,
                    "mask_label": mask_label
                }
                instances.append(instance)
            current_chunk = []
            current_length = 0


- example code

In [62]:
para_idx = 1
paragraph = paragraph_ls[para_idx]
print(len(paragraph), paragraph[1])

15 ['▁지', '미', '▁카', '터', '는', '▁조지', '아', '주', '▁섬', '터', '▁카', '운', '티', '▁플', '레', '인', '스', '▁마을', '에서', '▁태어났다', '.', '▁조지', '아', '▁공', '과', '대학교', '를', '▁졸업', '하였다', '.', '▁그', '▁후', '▁해', '군에', '▁들어가', '▁전', '함', '·', '원', '자', '력', '·', '잠', '수', '함', '의', '▁승', '무', '원으로', '▁일', '하였다', '.', '▁195', '3', '년', '▁미국', '▁해군', '▁대', '위로', '▁예', '편', '하였고', '▁이후', '▁땅', '콩', '·', '면', '화', '▁등을', '▁가', '꿔', '▁많은', '▁돈', '을', '▁벌', '었다', '.', '▁그의', '▁별', '명이', '▁"', '땅', '콩', '▁농', '부', '"', '▁(', 'P', 'e', 'an', 'ut', '▁F', 'ar', 'm', 'er', ')', '로', '▁알려', '졌다', '.']


In [58]:
current_chunk = []
current_length = 0
max_num_tokens = 256
target_seq_len = np.random.randint(2, max_num_tokens)
print(target_seq_len)

for i, sent in enumerate(paragraph):
    ## A. not the last sentence of the paragraph
    current_chunk.append(sent)
    current_length += len(sent)

    ## B. check if the last sentence of the paragraph
    ## or current_length is longer than or equal to target_seq_len
    if i == len(paragraph) - 1 or current_length >= target_seq_len:

        if current_chunk:
            ## A. sentence A segment: from 0 to a_end
            a_end = 1 # 1st sentence
            if len(current_chunk) != 1:
                a_end = np.random.randint(1, len(current_chunk))
            # append sentence of current_chunk 
            # from the front to the back
            tokenA = []
            for _, s in enumerate(current_chunk[:a_end]):
                tokenA.extend(s)

            ## B. sentence B segment
            tokenB = []
            # A. Actual next
            if len(current_chunk) > 1 and np.random.uniform() > 0.5:
                is_next = True
                for j in range(a_end, len(current_chunk)):
                    tokenB.extend(current_chunk[j])
            # B. random next
            else:
                is_next = False
                tokenB_len = target_seq_len - len(tokenA)
                random_para_idx = para_idx
                while para_idx == random_para_idx:
                    random_para_idx = np.random.randint(0, len(paragraph_ls))
                random_para = paragraph[random_para_idx]

                random_start = np.random.randint(0, len(random_para))
                for j in range(random_start, len(random_para)):
                    tokenB.extend(random_para[j])
            assert 0 < len(tokenA)
            assert 0 < len(tokenB)
            tokens = ["[CLS]"] + tokenA + ["[SEP]"] + tokenB + ["[SEP]"]
            print(tokens, '\n')


152
['[CLS]', '▁제임스', '▁얼', '▁"', '지', '미', '"', '▁카', '터', '▁주', '니어', '(,', '▁192', '4', '년', '▁10', '월', '▁1', '일', '▁~', '▁)', '는', '▁민주', '당', '▁출신', '▁미국', '▁3', '9', '번째', '▁대통령', '▁(19', '7', '7', '년', '▁~', '▁1981', '년', ')', '이다', '.', '[SEP]', '▁지', '미', '▁카', '터', '는', '▁조지', '아', '주', '▁섬', '터', '▁카', '운', '티', '▁플', '레', '인', '스', '▁마을', '에서', '▁태어났다', '.', '▁조지', '아', '▁공', '과', '대학교', '를', '▁졸업', '하였다', '.', '▁그', '▁후', '▁해', '군에', '▁들어가', '▁전', '함', '·', '원', '자', '력', '·', '잠', '수', '함', '의', '▁승', '무', '원으로', '▁일', '하였다', '.', '▁195', '3', '년', '▁미국', '▁해군', '▁대', '위로', '▁예', '편', '하였고', '▁이후', '▁땅', '콩', '·', '면', '화', '▁등을', '▁가', '꿔', '▁많은', '▁돈', '을', '▁벌', '었다', '.', '▁그의', '▁별', '명이', '▁"', '땅', '콩', '▁농', '부', '"', '▁(', 'P', 'e', 'an', 'ut', '▁F', 'ar', 'm', 'er', ')', '로', '▁알려', '졌다', '.', '▁지', '미', '▁카', '터', '▁제임스', '▁얼', '▁"', '지', '미', '"', '▁카', '터', '▁주', '니어', '(,', '▁192', '4', '년', '▁10', '월', '▁1', '일', '▁~', '▁)', '는', '▁민주', '당', '▁출신', '▁미국', 

In [68]:
print('original paragraph\n')
for sent in paragraph:
    print(sent)

original paragraph

['▁제임스', '▁얼', '▁"', '지', '미', '"', '▁카', '터', '▁주', '니어', '(,', '▁192', '4', '년', '▁10', '월', '▁1', '일', '▁~', '▁)', '는', '▁민주', '당', '▁출신', '▁미국', '▁3', '9', '번째', '▁대통령', '▁(19', '7', '7', '년', '▁~', '▁1981', '년', ')', '이다', '.']
['▁지', '미', '▁카', '터', '는', '▁조지', '아', '주', '▁섬', '터', '▁카', '운', '티', '▁플', '레', '인', '스', '▁마을', '에서', '▁태어났다', '.', '▁조지', '아', '▁공', '과', '대학교', '를', '▁졸업', '하였다', '.', '▁그', '▁후', '▁해', '군에', '▁들어가', '▁전', '함', '·', '원', '자', '력', '·', '잠', '수', '함', '의', '▁승', '무', '원으로', '▁일', '하였다', '.', '▁195', '3', '년', '▁미국', '▁해군', '▁대', '위로', '▁예', '편', '하였고', '▁이후', '▁땅', '콩', '·', '면', '화', '▁등을', '▁가', '꿔', '▁많은', '▁돈', '을', '▁벌', '었다', '.', '▁그의', '▁별', '명이', '▁"', '땅', '콩', '▁농', '부', '"', '▁(', 'P', 'e', 'an', 'ut', '▁F', 'ar', 'm', 'er', ')', '로', '▁알려', '졌다', '.']
['▁지', '미', '▁카', '터']
['▁제임스', '▁얼', '▁"', '지', '미', '"', '▁카', '터', '▁주', '니어', '(,', '▁192', '4', '년', '▁10', '월', '▁1', '일', '▁~', '▁)', '는', '▁민주', '당', '▁출신', '▁미국',