In [1]:
# ---- data construction ----
batch_size = 6    # must be even (asserted in Preprocess.build_dataset)
max_pred = 5      # upper bound on masked positions per sample
max_len = 30      # samples are zero-padded to this length; also sizes the position table

# ---- model dimensions ----
num_hidden = 768
d_k = 64          # per-head query/key width
d_v = 64          # per-head value width
num_heads = 12
d_ff = num_hidden * 4   # inner width of the position-wise feed-forward net
num_layers = 6

In [None]:
class Preprocess:
    """Turn a raw text corpus into BERT pre-training samples.

    The corpus is split on newlines, one sentence per line.  The sample
    builder reads the module-level settings ``batch_size``, ``max_pred`` and
    ``max_len``.
    """

    def __init__(self,corpus) -> None:
        self.corpus=corpus

    def _drop_specialchar(self) -> None:
        """Lower-case the corpus, remove ``. , ! ? -`` and split it into lines."""
        import re
        self.sentences = re.sub("[.,!?\\-]", '', self.corpus.lower()).split('\n')

    def _get_worddict(self):
        """Build ``self.worddict`` (token -> id) and ``self.vocab_size``.

        Ids 0-3 are reserved for the special tokens; corpus words start at 4.
        """
        # sorted() instead of raw set order keeps the ids identical between runs
        words=sorted(set(' '.join(self.sentences).split()))
        self.worddict={'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
        for index,word in enumerate(words):
            self.worddict[word]=index+4
        self.vocab_size=len(self.worddict)

    def _get_sentencetoken(self):
        """Encode every sentence as a list of word ids in ``self.sentencetoken``."""
        self.sentencetoken=[
            [self.worddict[word] for word in sentence.split()]
            for sentence in self.sentences
        ]

    def build_dataset(self):
        """Create ``batch_size`` samples, half "is next" and half "is not next".

        Each sample is ``[input_ids, segment_ids, target_token, masked_position,
        is_next]``.  ``input_ids`` / ``segment_ids`` are zero-padded up to
        ``max_len`` and the two prediction lists are zero-padded to ``max_pred``.
        Sentence pairs longer than ``max_len`` are not truncated.

        Returns:
            The list of samples (also stored in ``self.dataset``).

        Raises:
            ValueError: if the corpus has fewer than two lines, because no
                consecutive ("is next") pair could ever be drawn.
        """
        import random
        self._drop_specialchar()
        self._get_worddict()
        self._get_sentencetoken()
        # the message reads "batch_size should be even"
        assert not batch_size%2,'batch_size应该是偶数'
        n_sentences=len(self.sentences)
        if n_sentences<2:
            raise ValueError('corpus must contain at least two lines to build next-sentence pairs')
        half=batch_size//2
        cls_id=self.worddict['[CLS]']
        sep_id=self.worddict['[SEP]']
        mask_id=self.worddict['[MASK]']
        self.dataset=[]
        positive=negative=0
        # keep sampling until each class has reached its quota
        while positive<half or negative<half:
            sentence_a,sentence_b=random.randrange(n_sentences),random.randrange(n_sentences)
            is_next=sentence_a+1==sentence_b
            if (positive if is_next else negative)>=half:
                continue  # this class is already full, draw another pair
            token_a,token_b=self.sentencetoken[sentence_a],self.sentencetoken[sentence_b]
            input_ids=[cls_id]+token_a+[sep_id]+token_b+[sep_id]
            segment_ids=[0]*(1+len(token_a)+1)+[1]*(len(token_b)+1)
            # mask about 15% of the tokens: at least one, at most max_pred
            n_pred=min(max_pred,max(1,int(round(len(input_ids)*0.15))))
            position_canbemasked=[i for i,token in enumerate(input_ids) if token!=cls_id and token!=sep_id]
            random.shuffle(position_canbemasked)
            target_token,masked_position=[],[]
            for pos in position_canbemasked[:n_pred]:
                masked_position.append(pos)
                target_token.append(input_ids[pos])
                chance=random.random()
                if chance<0.8:
                    input_ids[pos]=mask_id
                elif chance>0.9:
                    # replace with a random non-special token
                    input_ids[pos]=random.randrange(4,self.vocab_size)
                # otherwise the original token is left in place
            n_pad=max_len-len(input_ids)
            input_ids.extend([0]*n_pad)
            segment_ids.extend([0]*n_pad)
            # pad by what was actually selected; there may be fewer candidates than n_pred
            n_pad=max_pred-len(masked_position)
            target_token.extend([0]*n_pad)
            masked_position.extend([0]*n_pad)
            self.dataset.append([input_ids,segment_ids,target_token,masked_position,is_next])
            if is_next:
                positive+=1
            else:
                negative+=1
        return self.dataset


In [None]:
# Raw training text; Preprocess splits it on newlines. Empty here as a placeholder.
corpus = ""

In [None]:
# Build the training samples, then publish the vocabulary size that the
# model cells read as a module-level name.
rawdata = Preprocess(corpus)
dataset = rawdata.build_dataset()
vocab_size = rawdata.vocab_size

In [None]:
import torch
from torch import nn

In [None]:
class Embedding(nn.Module):
    """Token + position + segment embeddings, summed and layer-normalised.

    Table sizes come from the module-level ``vocab_size``, ``max_len`` and
    ``num_hidden``.
    """

    def __init__(self) -> None:
        super(Embedding,self).__init__()
        self.tok_embed=nn.Embedding(vocab_size,num_hidden)
        self.pos_embed=nn.Embedding(max_len,num_hidden)
        self.seg_embed=nn.Embedding(2,num_hidden)
        self.norm=nn.LayerNorm(num_hidden)

    def forward(self,x,seg):
        """Embed token ids ``x`` and segment ids ``seg``.

        Both arguments are indexed as 2-D id tensors of the same shape; the
        second dimension of ``x`` is taken as the sequence length.
        """
        seq_len=x.size(1)
        # create the indices on the input's device so the lookup also works off-CPU
        pos=torch.arange(seq_len,dtype=torch.long,device=x.device)
        pos=pos.unsqueeze(0).expand_as(x)
        embedding=self.tok_embed(x)+self.pos_embed(pos)+self.seg_embed(seg)
        return self.norm(embedding)


In [None]:
import numpy as np

In [None]:
class ScaledDotProductAttention(nn.Module):
    """Parameter-free attention: softmax(Q K^T / sqrt(d_k)) V with masking."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self,Q,K,V,attn_mask):
        """Return the attention-weighted combination of ``V``.

        Positions where ``attn_mask`` is true are filled with -1e9 before the
        softmax.  The scale uses the module-level ``d_k``.
        """
        logits=torch.matmul(Q,K.transpose(-2,-1))
        logits=logits/np.sqrt(d_k)
        logits.masked_fill_(attn_mask,-1e9)
        weights=torch.softmax(logits,dim=-1)
        return torch.matmul(weights,V)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with an output projection, residual and LayerNorm.

    Sizes come from the module-level ``num_hidden``, ``num_heads``, ``d_k``
    and ``d_v``.
    """

    def __init__(self) -> None:
        super(MultiHeadAttention,self).__init__()
        self.W_Q=nn.Linear(num_hidden,d_k*num_heads)
        self.W_K=nn.Linear(num_hidden,d_k*num_heads)
        self.W_V=nn.Linear(num_hidden,d_v*num_heads)
        self.attention=ScaledDotProductAttention()
        # built once here so they are registered parameters and get trained;
        # creating them inside forward() would re-randomise them on every call
        self.linear=nn.Linear(num_heads*d_v,num_hidden)
        self.norm=nn.LayerNorm(num_hidden)

    def forward(self,Q,K,V,attn_mask):
        """Attend from ``Q`` to ``K``/``V`` and return a tensor shaped like ``Q``.

        ``attn_mask`` is given one mask per batch element and is repeated
        across the heads here.
        """
        residual,n_batch=Q,Q.size(0)
        # split the projections into heads and move the head axis to dim 1
        q_s=self.W_Q(Q).view(n_batch,-1,num_heads,d_k).transpose(1,2)
        k_s=self.W_K(K).view(n_batch,-1,num_heads,d_k).transpose(1,2)
        v_s=self.W_V(V).view(n_batch,-1,num_heads,d_v).transpose(1,2)

        attn_mask=attn_mask.unsqueeze(1).repeat(1,num_heads,1,1)
        context=self.attention(q_s,k_s,v_s,attn_mask)
        # merge the heads back; the result has to be assigned, view() is not in-place
        context=context.transpose(1,2).contiguous().view(n_batch,-1,num_heads*d_v)
        output=self.linear(context)
        return self.norm(output+residual)

In [None]:
def gelu(x):
    """Exact (erf-based) GELU activation applied element-wise to tensor ``x``."""
    import math
    half_x=x*0.5
    return half_x*(1.0+torch.erf(x/math.sqrt(2.0)))

In [None]:
class PoswiseFeedForwardNet(nn.Module):
    """Two linear layers (num_hidden -> d_ff -> num_hidden) with GELU in between."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1=nn.Linear(num_hidden,d_ff)
        self.fc2=nn.Linear(d_ff,num_hidden)

    def forward(self,x):
        """Expand ``x`` to ``d_ff``, apply ``gelu`` and project back."""
        expanded=self.fc1(x)
        activated=gelu(expanded)
        return self.fc2(activated)

In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by the feed-forward net."""

    def __init__(self) -> None:
        super().__init__()
        self.enc_self_attn=MultiHeadAttention()
        self.pos_ffn=PoswiseFeedForwardNet()

    def forward(self,enc_inputs,enc_self_attn_mask):
        """Run self-attention over ``enc_inputs`` and feed the result to the FFN."""
        # the same tensor serves as query, key and value
        attended=self.enc_self_attn(
            enc_inputs,enc_inputs,enc_inputs,enc_self_attn_mask
        )
        return self.pos_ffn(attended)

In [None]:
def get_attn_pad_mask(seq_q,seq_k):
    """Build a boolean attention mask that hides padding keys.

    Args:
        seq_q: 2-D id tensor ``(batch, len_q)`` for the queries.
        seq_k: 2-D id tensor ``(batch, len_k)`` for the keys; id 0 is padding.

    Returns:
        Bool tensor ``(batch, len_q, len_k)`` that is True wherever the key
        position holds id 0.
    """
    batch_size,len_q=seq_q.size()
    len_k=seq_k.size(1)
    # test the key ids themselves (not the integer length) against the pad id
    pad_attn_mask=seq_k.eq(0).unsqueeze(1)
    return pad_attn_mask.expand(batch_size,len_q,len_k)

In [None]:
class BERT(nn.Module):
    """Encoder stack with a sentence-pair classifier head and a masked-LM head.

    Sizes come from the module-level ``num_hidden``, ``num_layers`` and
    ``vocab_size``.
    """

    def __init__(self) -> None:
        super(BERT,self).__init__()
        self.embedding=Embedding()
        self.layers=nn.ModuleList([EncoderLayer() for _ in range(num_layers)])
        # pooling head applied to the first position of the sequence
        self.fc=nn.Sequential(
            nn.Linear(num_hidden,num_hidden),
            nn.Dropout(0.5),
            nn.Tanh()
        )
        self.classifier=nn.Linear(num_hidden,2)
        self.linear=nn.Linear(num_hidden,num_hidden)
        self.activ2=gelu
        # the LM output projection shares its weight with the token embedding table
        embed_weight=self.embedding.tok_embed.weight
        self.fc2=nn.Linear(num_hidden,vocab_size,bias=False)
        self.fc2.weight=embed_weight

    def forward(self,input_ids,segment_ids,masked_pos):
        """Return ``(logits_lm, logits_clsf)``.

        ``logits_lm`` holds vocabulary logits for every index in
        ``masked_pos``; ``logits_clsf`` holds the two-way logits computed from
        the first sequence position.
        """
        # Embedding.forward takes only the token and segment ids
        output=self.embedding(input_ids,segment_ids)
        enc_self_attn_mask=get_attn_pad_mask(input_ids,input_ids)
        for layer in self.layers:
            output=layer(output,enc_self_attn_mask)
        h_pooled=self.fc(output[:,0])
        logits_clsf=self.classifier(h_pooled)
        # broadcast each masked index over the hidden axis so gather() can pick whole vectors
        masked_pos=masked_pos[:,:,None].expand(-1,-1,num_hidden)
        h_masked=torch.gather(output,1,masked_pos)
        h_masked=self.activ2(self.linear(h_masked))
        logits_lm=self.fc2(h_masked)
        return logits_lm,logits_clsf

In [None]:
from torch import optim

In [None]:
# Instantiate the network together with its loss function and optimiser.
model = BERT()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters(), lr=0.001)

In [None]:
# The dataset is a list of samples made of plain Python lists, so it has to be
# stacked column-wise into batched LongTensors before the model can use it.
# An empty dataset leaves nothing to train on, so the cell does nothing.
if dataset:
    input_ids,segment_ids,masked_tokens,masked_pos,isNext=(
        torch.tensor(column,dtype=torch.long) for column in zip(*dataset)
    )
    for epoch in range(100):
        optimizer.zero_grad()
        logits_lm,logits_clsf=model(input_ids,segment_ids,masked_pos)
        # flatten (batch, n_pred, vocab) against (batch, n_pred) for the token loss
        loss_lm=criterion(logits_lm.reshape(-1,vocab_size),masked_tokens.reshape(-1))
        loss_clsf=criterion(logits_clsf,isNext)
        loss=loss_lm+loss_clsf
        if (epoch+1)%10==0:
            print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss.item()))
        loss.backward()
        optimizer.step()