In [2]:
import os
import warnings
warnings.filterwarnings("ignore")
import torch
from torch import Tensor
import torch.nn as nn
import pandas as pd
import numpy as np
import random

In [3]:
# Report GPU visibility first, then the installed PyTorch version.
if not torch.cuda.is_available():
    print("cuda not available")
else:
    print("cuda is available")
    print(f"{torch.cuda.device_count()} GPU detected")
print(f"torch version as {torch.__version__}")

cuda is available
1 GPU detected
torch version as 1.13.1


In [2]:
# Path of the WikiText-2 training token *file* (not a directory, despite the name).
data_dir = os.path.join(
    '../../data',
    'bert/wikitext-2',
    'wiki.train.tokens',
)

In [12]:
def _read_wiki(data_dir):
    with open(data_dir, 'r') as f:
        lines = f.readlines() # list of strings. each string is several sentences(joined by ' . '), that is a paragraph
    paragraphs = [line.strip().lower().split(' . ') for line in lines if len(line.split(' . ')) > 2] # list of sentence lists
    random.shuffle(paragraphs)
    return paragraphs

In [18]:
# Load the shuffled multi-sentence paragraphs, then show how many were kept
# and how many sentences the first one has.
paragraphs = _read_wiki(data_dir)
(len(paragraphs), len(paragraphs[0]))

(14222, 2)

In [24]:
def _get_next_sentence(sentence, next_sentence, paragraphs):
    # paragraphs: num_paragraphs 个 paragraph, 每个paragraph是list of sentences(or list of tokens_lst)
    if random.random() < 0.5:
        is_next = True
    else:
        next_sentence = random.choice(random.choice(paragraphs)) # 随机选一个paragraph, 再从该段随机选一句
        is_next = False
    return sentence, next_sentence, is_next # 输出1个sentence, next sentence, 是否连接flag

In [25]:
def get_tokens_segments(tokens_a, tokens_b=None):
    """Wrap one or two token lists in BERT special tokens and build segment ids.

    The result is ``<cls> tokens_a <sep>`` (segment 0), followed, when
    ``tokens_b`` is given, by ``tokens_b <sep>`` (segment 1).

    Args:
        tokens_a: list of tokens for the first sentence.
        tokens_b: optional list of tokens for the second sentence.

    Returns:
        ``(tokens, segments)``: two new lists of equal length; the input
        lists are not modified.
    """
    head = ['<cls>'] + tokens_a + ['<sep>']
    head_segments = [0] * len(head)
    if tokens_b is None:
        return head, head_segments
    tail = tokens_b + ['<sep>']
    return head + tail, head_segments + [1] * len(tail)

In [26]:
def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
    """Turn each adjacent sentence pair of one paragraph into an NSP example.

    Args:
        paragraph: list of sentences, each a list of tokens.
        paragraphs: all paragraphs; the pool ``_get_next_sentence`` draws
            negative (random) second sentences from.
        vocab: not used by this function; kept so the signature is unchanged.
        max_len: maximum combined length including the one ``<cls>`` and two
            ``<sep>`` tokens; longer pairs are skipped.

    Returns:
        A list of ``[tokens, segments, is_next]`` examples: the joined token
        list, its segment ids, and whether the second sentence is the true
        successor.
    """
    examples = []
    for first, second in zip(paragraph, paragraph[1:]):
        tokens_a, tokens_b, is_next = _get_next_sentence(first, second, paragraphs)
        # +3 accounts for <cls> and the two <sep> separators.
        if len(tokens_a) + len(tokens_b) + 3 <= max_len:
            tokens, segments = get_tokens_segments(tokens_a, tokens_b)
            examples.append([tokens, segments, is_next])
    return examples


In [27]:
# 对一个token list, 输入可mask的token positions(<cls>和<sep>不可被替换), 以及该token list要mask的token数量
# 输出替换后的token list, 替换的<mask>在list中的positions, 以及真实label token list
def _replace_mlm_tokens(tokens, candidate_mask_positions, num_mlm_masks, vocab):
    mlm_input_tokens = [token for token in tokens] # 将输入token list拷贝下来
    mask_positions_labels = [] # 记录(mask_position, mask_label) pair
    random.shuffle(candidate_mask_positions) # 打乱
    for mask_position in candidate_mask_positions:
        if len(mask_positions_labels) >= num_mlm_masks: # 当已经作了足够次数mask操作, 退出循环
            break
        mask_token = None
        if random.random() < 0.8: # 80%的概率, 用<mask>去mask
            mask_token = '<mask>'
        else:
            if random.random() < 0.5: # 10%的概率, 用随机token去mask
                mask_token = random.choice(vocab.idx_to_tokens)
            else: # 10%的概率, 用自身去mask
                mask_token = tokens[mask_position]
        mlm_input_tokens[mask_position] = mask_token # mask操作
        mask_positions_labels.append( (mask_position, tokens[mask_position]) ) # 记录被mask的token位置, 以及真实token
    return mlm_input_tokens, mask_positions_labels # 输出token list, list of (position_idx, true_label)

In [28]:
# For one token list: choose the maskable positions and the masking budget,
# then delegate the actual masking to `_replace_mlm_tokens`.
def _get_mlm_data_from_tokens(tokens, vocab):
    """Build masked-language-model inputs and targets for one example.

    Every position whose token is not ``<cls>``/``<sep>`` is a candidate; the
    budget is 15% of the total token count (rounded, at least 1).

    Args:
        tokens: list of tokens, including special tokens.
        vocab: vocabulary; ``vocab[list_of_tokens]`` is assumed to return the
            matching list of indices (inferred from usage — confirm).

    Returns:
        ``(masked_token_ids, mask_positions, mask_label_ids)`` with positions
        in ascending order and label ids aligned to them.
    """
    special_tokens = ('<cls>', '<sep>')
    maskable = [pos for pos, tok in enumerate(tokens) if tok not in special_tokens]
    budget = max(1, round(len(tokens) * 0.15))
    masked_tokens, position_label_pairs = _replace_mlm_tokens(
        tokens, maskable, budget, vocab)
    # Order targets by position so positions and labels line up left to right.
    ordered = sorted(position_label_pairs, key=lambda pair: pair[0])
    mask_positions = [pos for pos, _ in ordered]
    mask_labels = [label for _, label in ordered]
    return vocab[masked_tokens], mask_positions, vocab[mask_labels]

In [None]:
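# Combined per-example output of the two pretraining tasks (MLM + NSP), as a flat tuple:
# (two_sentence_token_idx_list, mask_position_list, mask_label_token_idx_list) from MLM,
# followed by (two_sentence_segment_list, is_next_flag) from NSP.
# For batching, pad two_sentence_token_idx_list / mask_position_list /
# mask_label_token_idx_list / two_sentence_segment_list to uniform lengths.
# Alongside mask_position_list / mask_label_token_idx_list, build an mlm_weight_list
# whose entries are 1 for real predictions and 0 for padded slots.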

In [None]:
def _pad_bert_inputs(examples, max_len, vocab):
    max_num_mlm_masks = round(max_len*0.15) # token list的最大长度乘以0.15
    all_tokens_idx, all_segments, valid_lens = [], [], []
    all_mask_positions, all_mlm_weights, all_mlm_labels_idx = [], [], []
    nsp_labels = []
    for tokens_idx, mask_positions, mlm_labels_idx, segments, if_next in examples:
        # pad tokens_idx
        pad_tokens_idx = tokens_idx + ['<pad>']*(max_len - len(tokens_idx))
        all_tokens_idx.append( torch.tensor(vocab[pad_tokens_idx], dtype=torch.int64) )
        # pad segments
        pad_segments = segments + [0]*(max_len - len(segments))
        all_segments.append( torch.tensor(pad_segments, dtype=torch.int64) )
        # valid_lens
        valid_lens.append( torch.tensor(len(tokens_idx), dtype=torch.float32) )
        # pad mask_positions
        pad_mask_positions = mask_positions + [0]*(max_num_mlm_masks - len(mask_positions))
        all_mask_positions.append( torch.tensor(pad_mask_positions, dtype=torch.int64) )
        # pad mlm_labels_idx
        pad_mlm_labels_idx = mlm_labels_idx + [0]*(max_num_mlm_masks - len(mlm_labels_idx))
        all_mlm_labels_idx.append( torch.tensor(pad_mlm_labels_idx, dtype=torch.int64) )
        # weights for mlm_labels_idx: 0 for pad
        mlm_labels_weight = [1]*len(mlm_labels_idx) + [0]*(max_num_mlm_masks - len(mlm_labels_idx))
        all_mlm_weights.append( torch.tensor(mlm_labels_weight, dtype=torch.float32) )
        nsp_labels.append( torch.tensor(if_next, dtype=torch.int64) )
    return all_tokens_idx, all_segments, valid_lens, all_mask_positions, all_mlm_weights, all_mlm_labels_idx, nsp_labels

In [2]:
# Scratch check: the first element of each inner list, collected into a tuple -> (1, 4).
x = [[1, 2, 3], [4, 5, 6]]
tuple(row[0] for row in x)

<generator object <genexpr> at 0x7fa0e5faa810>

In [8]:
# Scratch check: list.remove drops only the first matching element.
x = list('abca')
x.remove('a')

In [9]:
# Display the list after removing the first 'a' (expected: ['b', 'c', 'a']).
x

['b', 'c', 'a']