In [1]:
import os
import random
import data_tokenize
import torch

In [2]:
#@save
def _read_xhj():
    """Load the conversation corpus, tokenize it per character and shuffle it.

    Returns:
        The object produced by ``data_tokenize.tokenize(..., token='char')``
        after its entries have been shuffled in place with ``random.shuffle``.
    """
    shuffled = data_tokenize.tokenize(data_tokenize.get_convs(), token='char')
    # Shuffling happens in place, so the order differs on every call unless
    # the ``random`` module has been seeded beforehand.
    random.shuffle(shuffled)
    return shuffled
# Load the shuffled, character-tokenized corpus and preview its first five entries.
xhj_data = _read_xhj()
xhj_data[:5]

[[['那', '我', '咬', '你', '一', '口'], ['我', '真', '的', '好', '饿']],
 [['你', '这', '都', '什', '么'], ['还', '是', '我', '不', '够', '温', '柔']],
 [['操', '他'],
  ['菊', '花', '松', '肯', '定', '被', '不', '少', '人', '爆', '过', '没', '快', '感']],
 [['你', '就', '知', '道', '吃'],
  ['嘿',
   '嘿',
   '嘿',
   '有',
   '个',
   '吃',
   '货',
   '主',
   '人',
   '当',
   '然',
   '会',
   '有',
   '能',
   '吃',
   '的',
   '我',
   '喽',
   '~',
   '~',
   '~',
   '~',
   '~']],
 [['你', '变', '成', '烤', '鸡'], ['把', '我', '变', '蜡', '笔', '小', '新']]]

In [3]:
def get_tokens_and_segments(tokens_a, tokens_b=None):  #@save
    """Wrap one or two token lists with ``<cls>``/``<sep>`` and build segment ids.

    Args:
        tokens_a: Token list of the first sequence.
        tokens_b: Optional token list of the second sequence.

    Returns:
        ``(tokens, segments)`` where ``tokens`` is
        ``['<cls>'] + tokens_a + ['<sep>']`` (followed by
        ``tokens_b + ['<sep>']`` when ``tokens_b`` is given) and ``segments``
        holds ``0`` for every position of the first part and ``1`` for every
        position of the second part.
    """
    first_part = ['<cls>'] + tokens_a + ['<sep>']
    if tokens_b is None:
        return first_part, [0] * len(first_part)
    second_part = tokens_b + ['<sep>']
    return first_part + second_part, [0] * len(first_part) + [1] * len(second_part)

In [4]:
def _get_next_sentence(sentence, next_sentence, paragraphs): #@save
    if random.random() < 0.5:
        is_next = True
    else:
        next_sentence = random.choice(random.choice(paragraphs))
        is_next = False
    return sentence, next_sentence, is_next

In [5]:
def _get_train_data_from_paragraph(paragraph, paragraphs, vocab, max_len): #@save
    """Build next-sentence-prediction examples from (first, second) pairs.

    Every element of ``paragraph`` is read as a pair: element ``[0]`` is the
    first sentence and element ``[1]`` its follow-up. Because each element is
    already a complete pair, all elements are visited, including the last one.

    Args:
        paragraph: Sequence of ``[tokens_a, tokens_b]`` pairs.
        paragraphs: Pool passed to ``_get_next_sentence`` for drawing random
            replacement sentences.
        vocab: Unused here; kept so the signature stays compatible with callers.
        max_len: Maximum allowed length of the assembled sequence, counting
            the ``<cls>`` token and the two ``<sep>`` tokens.

    Returns:
        List of ``(tokens, segments, is_next)`` tuples. Pairs whose assembled
        length would exceed ``max_len`` are skipped.
    """
    train_data_from_paragraph = []
    for pair in paragraph:
        tokens_a, tokens_b, is_next = _get_next_sentence(pair[0], pair[1], paragraphs)
        # +3 accounts for '<cls>' and the two '<sep>' tokens added below.
        if len(tokens_a) + len(tokens_b) + 3 > max_len:
            continue
        tokens, segments = get_tokens_and_segments(tokens_a, tokens_b)
        train_data_from_paragraph.append((tokens, segments, is_next))
    return train_data_from_paragraph

In [6]:
def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds, vocab): #@save
    mlm_input_tokens = [token for token in tokens]
    pred_positions_and_labels = []
    random.shuffle(candidate_pred_positions)
    for mlm_pred_position in candidate_pred_positions:
        if len(pred_positions_and_labels) >= num_mlm_preds:
            break
        masked_token = None
        if random.random() < 0.8:
            masked_token = '<mask>'
        else:
            if random.random() < 0.5:
                masked_token = tokens[mlm_pred_position]
            else:
                masked_token = random.randint(0, len(vocab) - 1)
        mlm_input_tokens[mlm_pred_position] = masked_token
        pred_positions_and_labels.append((mlm_pred_position, tokens[mlm_pred_position]))
    return mlm_input_tokens, pred_positions_and_labels

In [7]:
def _get_mlm_data_from_tokens(tokens, vocab): #@save
    """Produce masked-language-model inputs and labels for one token sequence.

    Args:
        tokens: Token list that may contain ``'<cls>'`` and ``'<sep>'``.
        vocab: Vocabulary; indexing it with a list of tokens is used to obtain
            the corresponding ids.

    Returns:
        ``(input_ids, pred_positions, label_ids)`` where ``input_ids`` is
        ``vocab[...]`` of the rewritten token list, ``pred_positions`` are the
        rewritten positions in ascending order and ``label_ids`` is
        ``vocab[...]`` of the original tokens at those positions.
    """
    special_tokens = ['<cls>', '<sep>']
    # Special tokens are never candidates for prediction.
    candidate_pred_positions = [
        position for position, token in enumerate(tokens)
        if token not in special_tokens
    ]
    # Predict 15% of the tokens, but always at least one.
    num_mlm_preds = max(1, round(len(tokens) * 0.15))
    masked_tokens, picked = _replace_mlm_tokens(
        tokens, candidate_pred_positions, num_mlm_preds, vocab)
    picked.sort(key=lambda position_and_label: position_and_label[0])
    pred_positions = [position for position, _ in picked]
    mlm_pred_labels = [label for _, label in picked]
    return vocab[masked_tokens], pred_positions, vocab[mlm_pred_labels]

In [8]:
def _pad_bert_inputs(examples, max_len, vocab):
    max_num_mlm_preds = round(max_len * 0.15)
    all_tokens_ids, all_segments, valid_lens = [], [], []
    all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []
    xhj_labels = []
    for (token_ids, pred_positions, mlm_pred_label_ids, segments, is_next) in examples:
        all_tokens_ids.append(torch.tensor(token_ids + [vocab['<pad>']] * (max_len - len(token_ids)), dtype=torch.long))
        all_segments.append(torch.tensor(segments + [0] * (max_len - len(segments)), dtype=torch.long))
        valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32))
        all_pred_positions.append(torch.tensor(pred_positions + [0] * (max_num_mlm_preds - len(pred_positions)), dtype=torch.long))
        all_mlm_weights.append(torch.tensor([1.0] * len(mlm_pred_label_ids) + [0.0] * (max_num_mlm_preds - len(pred_positions)), dtype=torch.float32))
        all_mlm_labels.append(torch.tensor(mlm_pred_label_ids + [0] * (max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.long))
        xhj_labels.append(torch.tensor(is_next, dtype=torch.long))
    return (all_tokens_ids, all_segments, valid_lens, all_pred_positions, all_mlm_weights, all_mlm_labels, xhj_labels)

In [9]:
#@save
class _xhjDataset(torch.utils.data.Dataset):
    def __init__(self, paragraphs, max_len):
        self.vocab = data_tokenize.Vocab(paragraphs, min_freq=5, reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])
        examples = _get_train_data_from_paragraph(paragraphs, paragraphs, self.vocab, max_len)
        # for paragraph in paragraphs:
        #     examples.extend(
        #         _get_train_data_from_paragraph(paragraph, paragraphs, self.vocab, max_len)
        #     )

        examples = [
            (_get_mlm_data_from_tokens(tokens, self.vocab) + (segments, is_next)) for tokens, segments, is_next in examples
        ]
        (self.all_token_ids, self.all_segments, self.valid_lens, self.all_pred_positions, self.all_mlm_weights, self.all_mlm_labels, self.xhj_labels) = _pad_bert_inputs(examples, max_len, self.vocab)

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx], self.valid_lens[idx], self.all_pred_positions[idx], self.all_mlm_weights[idx], self.all_mlm_labels[idx], self.xhj_labels[idx])
    
    def __len__(self):
        return len(self.all_token_ids)

In [1]:
#@save
def load_data_xhj(batch_size, max_len):
    """Read the corpus and wrap it in a shuffling ``DataLoader``.

    Args:
        batch_size: Number of examples per mini-batch.
        max_len: Maximum sequence length passed on to ``_xhjDataset``.

    Returns:
        ``(train_iter, vocab)``: the data loader over the dataset and the
        vocabulary stored on that dataset.
    """
    dataset = _xhjDataset(_read_xhj(), max_len)
    # num_workers=0 keeps loading in the main process.
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    return loader, dataset.vocab

In [11]:
# Mini-batch size and the maximum sequence length handed to the loader.
batch_size, max_len = 512, 64
# Build the training iterator together with the vocabulary of its dataset.
train_iter, vocab = load_data_xhj(batch_size, max_len)

In [12]:
# Print the seven fields of the first mini-batch only, then stop iterating.
for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weight_X, mlm_Y, xhj_y) in train_iter:
    print(tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weight_X, mlm_Y, xhj_y)
    break