In [12]:
#from transformer import Encoder
from torch import nn,optim
from torch.nn.functional import cross_entropy,softmax, relu
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

import torch
import utils
import os
import pickle

In [13]:
import import_ipynb

In [14]:
Encoder = __import__('5_transformer').Encoder

In [15]:
class GPT(nn.Module):

    def __init__(self, model_dim, max_len, num_layer, num_head, n_vocab, lr, max_seg=3, drop_rate=0.2,padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx
        self.n_vocab = n_vocab
        self.max_len = max_len
        
        self.word_emb = nn.Embedding(n_vocab,model_dim)
        self.word_emb.weight.data.normal_(0,0.1)

        self.segment_emb = nn.Embedding(num_embeddings= max_seg, embedding_dim=model_dim)
        self.segment_emb.weight.data.normal_(0,0.1)
        self.position_emb = torch.empty(1,max_len,model_dim)
        nn.init.kaiming_normal_(self.position_emb,mode='fan_out', nonlinearity='relu')
        self.position_emb = nn.Parameter(self.position_emb)


        self.encoder = Encoder(n_head=num_head, emb_dim=model_dim, drop_rate=drop_rate, n_layer=num_layer)
        self.task_mlm = nn.Linear(in_features=model_dim, out_features=n_vocab)
        self.task_nsp = nn.Linear(in_features=model_dim*self.max_len, out_features=2)

        self.opt = optim.Adam(self.parameters(),lr)
    
    def forward(self,seqs, segs, training=False):
        embed = self.input_emb(seqs, segs)
        z = self.encoder(embed, training, mask = self.mask(seqs))   # [n, step, model_dim]
        mlm_logits = self.task_mlm(z)   # [n, step, n_vocab]
        nsp_logits = self.task_nsp(z.reshape(z.shape[0],-1))    # [n, n_cls]
        return mlm_logits, nsp_logits
    
    def step(self, seqs, segs, seqs_, nsp_labels):
        self.opt.zero_grad()
        mlm_logits, nsp_logits = self(seqs, segs, training=True)
        pred_loss = cross_entropy(mlm_logits.reshape(-1,self.n_vocab),seqs_.reshape(-1))
        nsp_loss = cross_entropy(nsp_logits,nsp_labels.reshape(-1))
        loss = pred_loss + 0.2 * nsp_loss
        loss.backward()
        self.opt.step()
        return loss.cpu().data.numpy(), mlm_logits
    
    def input_emb(self,seqs, segs):
        # device = next(self.parameters()).device
        # self.position_emb = self.position_emb.to(device)
        return self.word_emb(seqs) + self.segment_emb(segs) + self.position_emb
    
    def mask(self, seqs):
        device = next(self.parameters()).device
        batch_size, seq_len = seqs.shape
        mask = torch.triu(torch.ones((seq_len,seq_len), dtype=torch.long), diagonal=1).to(device)  # [seq_len ,seq_len]
        pad = torch.eq(seqs,self.padding_idx)   # [n, seq_len]
        mask = torch.where(pad[:,None,None,:],1,mask[None,None,:,:]).to(device)   # [n, 1, seq_len, seq_len]
        return mask>0   # [n, 1, seq_len, seq_len]
    
    @property
    def attentions(self):
        attentions = {
            "encoder": [l.mh.attention.cpu().data.numpy() for l in self.encoder.encoder_layers]
        }
        return attentions

In [17]:
def train():
    MODEL_DIM = 256
    N_LAYER = 4
    LEARNING_RATE = 1e-4
    dataset = utils.MRPCData("./MRPC",2000)
    print("num word: ",dataset.num_word)
    model = GPT(
        model_dim=MODEL_DIM, max_len=dataset.max_len-1, num_layer=N_LAYER, num_head=4, n_vocab=dataset.num_word,
        lr=LEARNING_RATE, max_seg=dataset.num_seg, drop_rate=0.2, padding_idx=dataset.pad_id
    )
    if torch.cuda.is_available():
        print("GPU train avaliable")
        device =torch.device("cuda")
        model = model.cuda()
    else:
        device = torch.device("cpu")
        model = model.cpu()
    
    loader = DataLoader(dataset,batch_size=32,shuffle=True)

    for epoch in range(100):
        for batch_idx, batch in enumerate(loader):
            seqs, segs,xlen,nsp_labels = batch
            seqs, segs,nsp_labels = seqs.type(torch.LongTensor).to(device), segs.type(torch.LongTensor).to(device),nsp_labels.to(device)
            # pred: [n, step, n_vocab]
            loss,pred = model.step(seqs=seqs[:,:-1], segs= segs[:,:-1], seqs_=seqs[:,1:], nsp_labels=nsp_labels)
            if batch_idx %100 == 0:
                pred = pred[0].cpu().data.numpy().argmax(axis = 1) # [step]
                print(
                    "Epoch: ",epoch,
                "|batch: ", batch_idx,
                "| loss: %.3f" % loss,
                "\n| tgt: ", " ".join([dataset.i2v[i] for i in seqs[0, 1:].cpu().data.numpy()[:xlen[0].sum()+1]]),
                "\n| prd: ", " ".join([dataset.i2v[i] for i in pred[:xlen[0].sum()+1]]),
                )
#     os.makedirs("./visual/models/gpt",exist_ok=True)
#     torch.save(model.state_dict(),"./visual/models/gpt/model.pth")
#     export_attention(model,device,dataset)

In [18]:
def export_attention(model,device,data,name="gpt"):
    model.load_state_dict(torch.load("./visual/models/gpt/model.pth",map_location=device))
    seqs, segs,xlen,nsp_labels = data[:32]
    seqs, segs,xlen,nsp_labels = torch.from_numpy(seqs),torch.from_numpy(segs),torch.from_numpy(xlen),torch.from_numpy(nsp_labels)
    seqs, segs,nsp_labels = seqs.type(torch.LongTensor).to(device), segs.type(torch.LongTensor).to(device),nsp_labels.to(device)
    model(seqs[:,:-1],segs[:,:-1],False)
    seqs = seqs.cpu().data.numpy()
    data = {"src": [[data.i2v[i] for i in seqs[j]] for j in range(len(seqs))], "attentions": model.attentions}
    path = "./visual/tmp/%s_attention_matrix.pkl" % name
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(data, f)

In [19]:
if __name__ == "__main__":
    train()

num word:  12880
GPU train avaliable
Epoch:  0 |batch:  0 | loss: 10.043 
| tgt:  myanmar 's pro-democracy leader aung san suu kyi will return home late friday but will remain in detention after recovering from surgery at a yangon hospital , her personal physician said . <SEP> myanmar 's pro-democracy leader aung san suu kyi will be kept under house arrest following her release from a hospital where she underwent surgery , her personal physician said friday . 
| prd:  leighton structure jennette foolish tightness hit influential member wants asphyxiating pingeon dialogue district aboriginal particular security ash twin gillian <NUM>--available allegiant foolish communist awareness barnett losyukov losyukov radical addresses accounted multimillionaire significantly realized floors airline would-be accountable bitter pioneer nashville myanmar country village dallas-based village bodyguards movie gbi interfere sia biological gbi desperate keywords mclean preposterous toured dirk celebrati

Epoch:  16 |batch:  0 | loss: 3.095 
| tgt:  kelly killed himself after being exposed as the source for a bbc report which claimed the government had embellished evidence of iraq 's banned weapons to justify the war . <SEP> he killed himself after being exposed as the source for a bbc report which claimed the government exaggerated the case for war against iraq . 
| prd:  the a <NUM> and the and in the united of the statement of said is the united . been of . the . . . . the the united . <SEP> the said in with the the in the new to the statement of to the the first said the united . the . the .
Epoch:  17 |batch:  0 | loss: 3.281 
| tgt:  <quote> his advice to the house republicans is to pass it , to send it to him , so he can sign it . <quote> <SEP> the president 's <quote> advice to the house republicans is to pass it , to send it to him , so he can sign it , <quote> said white house spokesman ari fleischer . 
| prd:  the the wife , the company , , to the the 's <quote> the the 's th

Epoch:  31 |batch:  0 | loss: 2.282 
| tgt:  overall , <NUM> percent of women washed up , compared with <NUM> percent of men . <SEP> but in san francisco , <NUM> percent of men and only <NUM> percent of women washed their hands . 
| prd:  the , <NUM> , of the has the to a with the percent of the . <SEP> the were the diego , <NUM> percent of the were <NUM> <NUM> percent of the . <NUM> lowest .
Epoch:  32 |batch:  0 | loss: 2.296 
| tgt:  the woman was exposed to the sars virus while in the hospital but was not a health care worker , said dr. colin d â cunha , ontario â s commissioner of public health . <SEP> the woman was exposed to the sars virus while in the hospital but was not a health-care worker , said dr colin d 'cunha , ontario 's commissioner of public health . 
| prd:  the dow was not to the u.s. of to the the case , he not a statement department and , but . melinda . 'archiac s , which 's s . . the . . <SEP> the us 's a to the time has to the the first and a not a stateme

Epoch:  48 |batch:  0 | loss: 1.386 
| tgt:  baer said he had concluded that lawyers for the two victims <quote> have shown , albeit barely ... that iraq provided material support to bin laden and al-qaeda <quote> . <SEP> judge harold baer concluded wednesday that lawyers for the two victims <quote> have shown , albeit barely ... that iraq provided material support to bin laden and al-qaida . <quote> 
| prd:  the said he would no that he for the the years <quote> he been , <quote> barely ... that they 's material support to be laden and exchange <quote> . <SEP> <quote> 's baer that that that the for the two years <quote> this been , <quote> barely ... the he 's material support to him laden and secular . <quote>
Epoch:  49 |batch:  0 | loss: 1.308 
| tgt:  <quote> we 've put a lot of effort and energy into improving our patching progress , probably later than we should have . <SEP> <quote> we have put a lot of energy into patching , later than we should have , <quote> he said . 
| prd:

Epoch:  65 |batch:  0 | loss: 0.793 
| tgt:  general motors corp. posted a record <NUM> percent improvement in <NUM> . <SEP> general motors corp. posted the best-ever improvement in <NUM> at <NUM> percent . 
| prd:  the wesley corp. posted a $ <NUM> percent of in <NUM> . <SEP> after motors corp. posted the best-ever improvement in <NUM> percent <NUM> percent .
Epoch:  66 |batch:  0 | loss: 0.792 
| tgt:  elena slough , considered to be the nation 's oldest person and the third oldest person in the world , died early sunday morning . <SEP> elena slough , considered to be the oldest person in the us and the third oldest person in the world , has died . 
| prd:  the slough , considered to be the program 's retail person and the third oldest person in the company , would of in , . <SEP> elena slough , could to be the number person in the nation and the world oldest person in the world , has be .
Epoch:  67 |batch:  0 | loss: 0.711 
| tgt:  iraq 's economy was ravaged during years of u.n. s

Epoch:  81 |batch:  0 | loss: 0.388 
| tgt:  he said that with the u.s.-backed peace plan , or road map , â in a coma , â the attack could easily widen conflict through the region . <SEP> mr jouejati said that with the us-backed road map <quote> in a coma <quote> the attack could easily widen through the region . 
| prd:  <quote> said that he the u.s.-backed peace plan , or road map , â in a coma , â the attack could easily widen conflict through the region . <SEP> he jouejati said that with the us-backed road map <quote> in a coma <quote> the attack could easily widen through the region .
Epoch:  82 |batch:  0 | loss: 0.390 
| tgt:  <quote> i love the catholic church with all my heart , mind , soul and strength , <quote> said troy , who spoke quickly but in a steady voice . <SEP> <quote> i love the catholic church with all my heart , mind , soul and strength , <quote> he said . 
| prd:  the i think the catholic church with some my heart , i , soul and strength , <quote> he tro

Epoch:  96 |batch:  0 | loss: 0.249 
| tgt:  the technology-laced nasdaq composite index .ixic was off <NUM> points , or <NUM> percent , at <NUM> . <SEP> the broader standard & poor 's <NUM> index edged up <NUM> points , or <NUM> percent , at <NUM> . 
| prd:  the company nasdaq composite index < was off <NUM> points , or <NUM> percent , to <NUM> . <SEP> the broader standard & poor 's <NUM> index edged up <NUM> points , or <NUM> percent , at <NUM> .
Epoch:  97 |batch:  0 | loss: 0.233 
| tgt:  negotiators talked with the boy for about an hour and a half , bragdon said . <SEP> negotiators talked with the boy for more than an hour , and swat officers surrounded the classroom , bragdon said . 
| prd:  the talked with the boy for about an hour and a half , bragdon said . <SEP> negotiators talked with the boy for more than an hour , and swat officers surrounded the classroom , bragdon said .
Epoch:  98 |batch:  0 | loss: 0.217 
| tgt:  in a statement later , he said it appeared his side had 