In [139]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import d2l.torch as d2l
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import random
import os
import math

In [140]:
def transpose_qkv(X, num_heads):
    batch_size, length = X.shape[0], X.shape[1]
    X = X.reshape(batch_size, length, num_heads, -1)

    X = torch.permute(X, (0, 2, 1, 3))
    X = X.reshape((-1, X.shape[2], X.shpe[3]))
    return X

def transpose_out(X, num_heads):
    X = X.reshape((-1, num_heads, X.shape[1], X.shape[2]))
    X = torch.permute(X, (0, 2, 1, 3))
    X = X.reshape(X.shape[0], X.shape[1], -1)
    return X

In [141]:
def sequence_mask(X, valid_len, value=0):
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X

In [142]:
def masked_softmax(X, valid_lens):
    if valid_lens is None:
        return F.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return F.softmax(X.reshape(shape), dim=-1)


class MultiheadAttention(nn.Module):
    def __init__(self, num_hiddens, qkv_dim, num_heads, dropout, 
                 bias=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.W_q = nn.Linear(num_hiddens, qkv_dim, bias=bias)
        self.W_k = nn.Linear(num_hiddens, qkv_dim, bias=bias)
        self.W_v = nn.Linear(num_hiddens, qkv_dim, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        q = transpose_qkv(self.W_q(queries), self.num_heads)
        k = transpose_qkv(self.W_k(keys), self.num_heads)
        v = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0
            )

        d = q.shape[-1]
        scores = torch.bmm(q, torch.transpose(k, 1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        out = torch.bmm(self.dropout(self.attention_weights), v)

        out = transpose_out(out, self.num_heads)
        return self.W_o(out)

In [143]:
x = torch.randn(2, 4)
torch.transpose(x, 1, 0).shape

torch.Size([4, 2])

In [144]:
class Addnorm(nn.Module):
    def __init__(self, normalized_shape, dropout, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.layernorm = nn.LayerNorm(normalized_shape)
        self.dropout = nn.Dropout(dropout)

    def forward(self, X, Y):
        return self.layernorm(X + self.dropout(Y))

In [145]:
class PositionwiseFFN(nn.Module):
    def __init__(self, num_inputs, num_hiddens, num_outputs, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.lin1 = nn.Linear(num_inputs, num_hiddens)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(num_hiddens, num_outputs)
    
    def forward(self, X):
        return self.lin2(self.relu(self.lin1(X)))

class Block(nn.Module):
    '''Multihead Attention + Addnorm + FFN'''
    def __init__(self, num_hiddens, qkv_dim, num_heads, norm_shape, ffn_num_inputs,
                 ffn_num_hiddens, dropout, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.attention = MultiheadAttention(embedding_dim=num_hiddens, qkv_dim=qkv_dim, num_heads=num_heads)
        self.addnorm1 = Addnorm(norm_shape, dropout)
        self.ffn = PositionwiseFFN(ffn_num_inputs, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = Addnorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        '''embedding: batch_size, length, embedding_dim'''
        X = self.addnorm1(X, self.attention(X,X,X, valid_lens))
        X = self.addnorm2(self.ffn(X), X)
        return X


In [146]:
class BertEncoder(nn.Module):
    '''block'''
    def __init__(self, vocab_size, num_hiddens, num_blocks, max_len=1000,
                 key_size=768, query_size=768, value_size=768,
                  *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        self.segment_embedding = nn.Embedding(2, num_hiddens)

        self.blks = nn.Sequential()
        for i in range(num_blocks):
            self.blks.add_module(f'block {i}', Block(
                key_size, query_size, value_size, num_hiddens))
            
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))
    
    def forward(self, tokens, segments, valid_lens):
        '''embedding: (batch_size, length_of_sentence, embedding_dim)'''
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        X = X + self.pos_embedding.data[:, :X.shape[1], :]
        for blk in self.blks:
            X = blk(X, valid_lens)
        return X

In [147]:
class MaskLM(nn.Module):
    def __init__(self, vocab_size, num_hiddens, num_inputs=768, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mlp = nn.Sequential(
            nn.Linear(num_inputs, num_hiddens), nn.ReLU(),
            nn.LayerNorm(num_hiddens), nn.Linear(num_hiddens, vocab_size)
        )
    
    def forward(self, X, predict_pos):
        num_pred_positions = predict_pos.shape[1]
        batch_idx = torch.arange(X.shape[0])
        batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)
    
        pred = X[batch_idx, predict_pos.reshape(-1)]
        return self.mlp(pred)

class NSP(nn.Module):
    def __init__(self, hid_in_features, num_hiddens, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mlp = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),
                                    nn.Tanh(), nn.Linear(num_hiddens, 2))

    def forward(self, X):
        # 直接预测单词(不用softmax？)
        return self.mlp(X)


class  BertModel(nn.Module):
    '''embedding + BertEncoder + MaskLM + NSP'''
    def __init__(self, embedding_dim, max_len, vocab, num_blocks, num_heads):
        super(BertModel, self).__init__()
        self.encoder = BertEncoder(num_blocks=num_blocks, 
                                   embedding_dim=embedding_dim, 
                                   qkv_dim=embedding_dim, 
                                   num_heads=num_heads)
        self.masklm = MaskLM(embedding_dim, vocab)
        self.nsp = NSP(embedding_dim)

    def forward(self, tokens, segments, valid_lens=None, predict_pos=None):
        encoded_X = self.encoder(tokens, segments, valid_lens)
        if predict_pos is not None:
            masked_Y_hat = self.masklm(encoded_X, predict_pos)
        else:
            masked_Y_hat = None

        nsp_Y_hat = self.nsp(encoded_X[:, 0, :])

        return encoded_X, masked_Y_hat, nsp_Y_hat

In [148]:
#---------------------- test ------------------#
# MaskLM
# x = torch.arange(32).reshape(2, 4, 4).float()
# pos = torch.arange(4).reshape(2, 2)
# vocab = [i for i in range(10)]
# masklm = MaskLM(4, vocab)
# masklm(x, pos).shape

# BertModel
values = torch.randint(0, 3, (2, 4))
segments = torch.cat((torch.zeros(2, 2), torch.ones(2, 2)), dim=1).int()
predict_pos = torch.tensor([
    [1, 2],
    [2, 3]
])
vocab = [i for i in range(10)]
bert = BertModel(embedding_dim=4, max_len=4, vocab=vocab, num_blocks=2, num_heads=2)
masked_tokens, nsp = bert(values, segments, predict_pos)
masked_tokens.shape, nsp.shape


TypeError: __init__() missing 2 required positional arguments: 'vocab_size' and 'num_hiddens'

# 数据加载

In [149]:
def count_corpus(tokens):
    '''Count token frequencies'''
    if len(tokens) == 0 or isinstance(tokens[0], list):
        tokens = [token for line in tokens 
                  for token in line]
    return Counter(tokens)

In [150]:
class Vocab():
    '''Vocabulary for text'''
    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        if tokens is None:
            tokens = []
        if reserved_tokens is None:
            reserved_tokens = []
        
        counter = count_corpus(tokens)
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1],
                                   reverse=True)

        self.idx_to_token = ['<unk>'] + reserved_tokens
        self.token_to_idx = {token : i for 
                             i, token in enumerate(self.idx_to_token)}
        
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1
        
    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]
    
    def __len__(self):
        return len(self.id_to_token)
    
    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.to_tokens(index) for index in indices]
    
    @property
    def unk(self):
        return 0
        


In [151]:
def get_tokens_and_segments(sentence1, sentence2=None):
    sentence = ['<cls>'] + sentence1 + ['<seg>']
    segments = [0] * len(sentence1)

    if sentence2 is not None:
        sentence += sentence2 + ['<seg>']
        segments += [1] * (len(sentence2) + 1)
    return sentence, segments

def _get_nsp_data(paragraphs, max_len):
    examples = []
    is_next = True
    for paragraph in paragraphs:
        for i in range(len(paragraph) - 1):
            sentence1 = paragraph[i]
            if random.random() < 0.5:
                sentence2 = paragraph[i + 1]
                is_next = True
            else:
                sentence2 = random.choice(random.choice(paragraphs))
                is_next = False
            
            if len(sentence1) + len(sentence2) + 3 > max_len:
                continue

            sentence, segments = get_tokens_and_segments(sentence1, sentence2)
            examples.append((sentence, segments, is_next))

    return examples

In [152]:
def _get_mlm_data_from_tokens(tokens, vocab):
    candidate_pos = []
    for i, token in enumerate(tokens):
        if token not in ('<cls>', '<sep>'):
            candidate_pos.append(i)
    
    predict = []
    random.shuffle(candidate_pos)
    mask_token = None
    for i in range(max(1, round(0.15 * (len(tokens) - 3)))):
        if random.random() < 0.8:
            mask_token = '<mask>'
        else:
            if random.random() < 0.5:
                mask_token = tokens[candidate_pos[i]]
            else:
                mask_token = random.choice(vocab.idx_to_token)
        predict.append((candidate_pos[i], tokens[candidate_pos[i]]))
        tokens[candidate_pos[i]] = mask_token
    
    predict = sorted(predict, key=lambda x: x[0])

    predict_pos = [v[0] for v in predict]
    predict_label = [v[1] for v in predict]

    return vocab[tokens], predict_pos, vocab[predict_label]

In [153]:
def _padding_inputs(examples, max_len, vocab):
    max_masked_len = round(max_len * 0.15)
    all_tokens, all_segments, all_weights, all_pred_pos, all_pred_labels, all_is_next = [], [], [], [], [], []
    valid_lens = []
    for tokens, predict_pos, predict_label, segments, is_next in examples:
        all_tokens.append(torch.tensor(
            tokens + [vocab['<pad>']] * (max_len - len(tokens)), dtype=torch.long))
        all_segments.append(torch.tensor(
            segments + [0] * (max_len - len(segments)), dtype=torch.long
        ))
        valid_lens.append(torch.tensor(len(tokens), dtype=torch.float))

        all_pred_pos.append(torch.tensor(
            predict_pos + [0] * (max_masked_len - len(predict_pos)), dtype=torch.long
        ))
        all_pred_labels.append(torch.tensor(
            predict_label + [0] * (max_masked_len - len(predict_label)), dtype=torch.long
        ))
        all_weights.append(torch.tensor(
            [1.0] *  len(predict_label) + [0.0] * (max_masked_len - len(predict_label)), dtype=torch.float
        ))
        all_is_next.append(torch.tensor(
            is_next, dtype=torch.long
        ))
    
    return (all_tokens, all_segments, valid_lens, all_pred_pos, all_pred_labels,
            all_weights, all_is_next)
        
    
    

In [154]:
class _WikiDataset(Dataset):
    def __init__(self, data_path, max_len):
        with open(data_path, 'r') as f:
            paragraphs = f.readlines()
        paragraphs = [paragraph.strip().lower().split('.')
                    for paragraph in paragraphs if len(paragraph.split('.') >= 2)]
        random.shuffle(paragraphs)

        paragraphs = [[sentence.split() for sentence in paragraph] 
                      for paragraph in paragraphs]  # 三维列表
        sentences = [sentence for paragraph in paragraphs 
                     for sentence in paragraph] # 转为二维列表
        self.vocab = Vocab(sentences, min_freq=5, reserved_tokens=[
            '<pad>', '<mask>', '<cls>', '<sep>'])

        # 获取下一句子预测任务的数据
        examples = _get_nsp_data(paragraphs=paragraphs, max_len=max_len)

        examples = [_get_mlm_data_from_tokens(tokens, self.vocab) 
                    + (segments, is_next)
                    for tokens, segments, is_next in examples]
        
        # padding bert
        (self.all_tokens, self.all_segments, self.valid_lens, self.all_pred_pos, self.all_pred_labels,
            self.all_weights, self.all_is_next) = _padding_inputs(examples, max_len, self.vocab)
    
    def __getitem__(self, idx):
        return (self.all_tokens[idx], self.all_segments[idx], 
                self.valid_lens[idx], self.all_pred_pos[idx], 
                self.all_pred_labels[idx], self.all_weights[idx],
                self.all_is_next[idx])
    
    def __len__(self):
        return len(self.all_is_next)



In [155]:
def load_data_wiki(data_path, batch_size, max_len):
    """加载wikiText-2数据"""
    train_set = _WikiDataset(data_path=data_path, max_len=max_len)
    train_iter = DataLoader(train_set, batch_size, shuffle=True )
    return train_iter, train_set.vocab


In [156]:
def _get_batch_loss(net, loss, vocab_size, tokens_X,
                         segments_X, valid_lens_x,
                         pred_positions_X, mlm_weights_X,
                         mlm_Y, nsp_y):
    # 前向传播
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1),
                                  pred_positions_X)
    # 计算遮蔽语言模型损失
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
    mlm_weights_X.reshape(-1, 1)
    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
    # 计算下一句子预测任务的损失
    nsp_l = loss(nsp_Y_hat, nsp_y).sum()
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l

SyntaxError: invalid syntax (2709165333.py, line 1)

In [None]:
def train_bert(train_iter, net, loss, devices, max_step):
    net = net.to(devices[0])
    trainer = torch.optim.Adam(net.parameters(), lr=0.01)
    step, timer = 0, d2l.Timer()
    num_steps_reached = False
    animator = d2l.Animator(xlabel='step', ylabel='loss', xlim=[1, max_step],
                            legend=['mlm', 'nsp'])
    metric = d2l.Accumulator(4)

    while(step < max_step and not num_steps_reached):
        for tokens, segments, valid_lens, pred_pos, \
            weights, pred_labels, is_next in train_iter:
            tokens = tokens.to(devices[0])
            segments = segments.to(devices[0])
            valid_lens = valid_lens.to(devices[0])
            pred_pos = pred_pos.to(devices[0])
            weights = weights.to(devices[0])
            pred_labels = pred_labels.to(devices[0])
            is_next = is_next.to(devices[0])

            trainer.zero_grad()
            timer.start()

            mlm_l, nsp_l, l = _get_batch_loss(
                net, loss, 
            )

            l.backward()
            trainer.step()

            metric.add(mlm_l, nsp_l, tokens.shape[0], 1)
            timer.stop()
            animator.add(step + 1,
                         (metric[0] / metric[3], metric[1] / metric[3]))
            
            step += 1
            if step == max_step:
                num_steps_reached = True
                break
    
    print(f'MLM loss {metric[0] / metric[3]: .3f}, NSP loss {metric[1] / metric[3]: .3f}')
    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on {str(devices)}')


            

In [None]:
def try_all_gpus():
    '''Return all available GPUs, or [cpu(), ], if no GPU exists'''
    devices = [torch.device(f"cuda:{i}")
               for i in range(torch.cuda.device_count())]
    return devices if devices else [torch.device('cpu')]

In [None]:
batch_size, max_len = 512, 64
data_path = os.path.join("../data/wikitext-2", "wiki.train.tokens")
train_iter, vocab = load_data_wiki(data_path, batch_size, max_len)

net = BertModel(embedding_dim=4, max_len=4, vocab=vocab, num_blocks=2, num_heads=2)
devices = try_all_gpus()
loss = nn.CrossEntropyLoss(reduction='none')

train_bert(net, )