In [1]:
!pip install transformers
!pip install tensorflow

Looking in indexes: http://mirrors.aliyun.com/pypi/simple
Collecting transformers
  Downloading http://mirrors.aliyun.com/pypi/packages/70/10/4f0924b0301042f226ed07c3273d7ae4577744033ef1309e770da8f7e03e/transformers-4.19.1-py3-none-any.whl (4.2 MB)
[K     |████████████████████████████████| 4.2 MB 8.9 MB/s eta 0:00:01
[?25hCollecting huggingface-hub<1.0,>=0.1.0
  Downloading http://mirrors.aliyun.com/pypi/packages/c1/f2/d6542f1e29b803442e058f7a1b52313bea37da46517b1e840ff2f166450c/huggingface_hub-0.6.0-py3-none-any.whl (84 kB)
[K     |████████████████████████████████| 84 kB 11.2 MB/s eta 0:00:01
[?25hCollecting filelock
  Downloading http://mirrors.aliyun.com/pypi/packages/e8/74/48523f5206b0930f7c6b312890c7ab285dba55cea3f0a303999c5425df08/filelock-3.7.0-py3-none-any.whl (10 kB)
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading http://mirrors.aliyun.com/pypi/packages/36/fa/e22ebbcaeecd9bd04efa30f7ec43ccf1501c97615c9af3bbf13a77ce0b81/tokenizers-0.12.1-cp38-cp38-manylinux_2_12_

In [1]:
import torch
torch.cuda.empty_cache()
torch.cuda.memory_summary(device=None, abbreviated=False)



In [2]:
import torch
import csv
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from PIL import Image, ImageFile
import torchvision.transforms as transforms
from torchvision import models
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
import torch.optim as optim
import os
import json
from transformers import BertTokenizer, BertModel
import operator

In [3]:
device = torch.device('cuda')
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
pre_trained_model = BertModel.from_pretrained("bert-base-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
class ContextQuType:
    def __init__(self, tokenizer, length, ids, context, question):
        self.tokenizer = tokenizer
        self.codeLength = length
        self.ids = ids
        self.context = context
        self.question = question
        self.after_encode = 0
        
    def convert(self):
        if self.after_encode==0:
            after_encode = tokenizer(self.question,self.context,max_length=self.codeLength,padding="max_length")
        return torch.stack([torch.tensor(after_encode['input_ids'][:512]),torch.tensor(after_encode['token_type_ids'][:512]),torch.tensor(after_encode['attention_mask'][:512])])
            
class LabelType:
    def __init__(self, tokenizer, length, CQu, has_ans, text, start):
        self.tokenizer = tokenizer
        self.codeLength = length
        self.has_ans = has_ans
        self.text = text.lower()
        self.start = start
        self.CQu = CQu
    
    def convert(self):
        # make groundtruth
        S = [0.0] * self.codeLength
        E = [0.0] * self.codeLength
        background = tokenizer.encode(self.CQu.question,self.CQu.context,max_length=self.codeLength,padding="max_length")
        if self.has_ans:
            after_encode = tokenizer.encode(self.text)[1:-1]
            ans_length = len(after_encode)
            before_answer = tokenizer.encode(self.CQu.question,self.CQu.context[:self.start])
            start_id = len(before_answer) - 1
            end_id = start_id + ans_length - 1
            if tokenizer.decode(background[start_id:end_id+1])!=self.text:
                    lower = max(start_id-3,0)
                    upper = min(end_id+4,len(background))
                    ans_found = False
                    for i in range(lower,upper):
                        for j in range(i+1,upper):
                            candidate = tokenizer.decode(background[i:j])
                            #if len(candidate)>0:
                            #    if candidate[0]=='$' and self.text[0]!='$':
                            #        candidate=candidate[1:]
                            #    if candidate[0]!='$' and self.text[0]=='$':
                            #        candidate='$'+candidate
                            if candidate==self.text:
                                start_id = i
                                end_id = j-1
                                ans_found = True
                                break
                            if ans_found:
                                break
            if start_id >= self.codeLength:
                start_id = 0
                end_id = 0
            end_id = min(end_id,self.codeLength - 1)
            S[start_id] = 1.0
            E[end_id] = 1.0
        else:
            S[0] = 1.0
            E[0] = 1.0
        return torch.stack([torch.tensor(S),torch.tensor(E)])

In [5]:
def load_data(load_path,tokenizer,length):
    with open(load_path) as json_data:
        dev = json.load(json_data)['data']
    CQu = []
    Label = []
    for data in dev:
        for paragraphs in data['paragraphs']:
            context = paragraphs['context']
            for qas in paragraphs['qas']:
                question = qas['question']
                ids = qas['id']
                ctx = ContextQuType(tokenizer,length,ids,context,question)
                CQu.append(ctx)
                if qas['is_impossible']:
                    Label.append(LabelType(tokenizer,length,ctx,False,'',0))
                else:
                    answer = qas['answers'][0]
                    Label.append(LabelType(tokenizer,length,ctx,True,answer['text'],answer['answer_start']))
    return CQu,Label

In [7]:
class OutputPair:
    def __init__(self,logits_start,logits_end):
        self.logits_start = logits_start
        self.logits_end = logits_end

class CrossEntropy(nn.Module):
    def __init__(self):
        super(CrossEntropy, self).__init__()
    def forward(self,outputs,labels):
        log_soft = F.log_softmax(outputs.logits_start,dim=1)
        loss_start = F.cross_entropy(log_soft, labels[:,0])
        loss_end   = F.cross_entropy(F.log_softmax(outputs.logits_end,dim=1), labels[:,1])
        return loss_start + loss_end / 2.0

class BERTQuA(nn.Module):
    def __init__(self, bert, tokenizer,code_length,device,epsilon = 1.0):
        super(BERTQuA, self).__init__()
        self.tokenizer = tokenizer
        self.device = device
        self.encoder = bert.to(device)
        self.code_length = code_length
        self.output_start = nn.Linear(768,1).to(device)
        self.output_end = nn.Linear(768,1).to(device)
        self.epsilon = 1.0

    def forward(self, inputs):
        tokens_X, segments_X, masks = inputs[:,0],inputs[:,1],inputs[:,2]
        encoded_X = self.encoder(tokens_X, token_type_ids=segments_X, attention_mask=masks).last_hidden_state
        return OutputPair(self.output_start(encoded_X),self.output_end(encoded_X))
        
    def prediction(self,question,answer_text):
        ctx = ContextQuType(self.tokenizer,self.code_length,'',answer_text,question)
        inputs = torch.stack([ctx.convert()]).to(device)
        token_ids = tokenizer(question,answer_text,max_length=self.code_length,padding="max_length")['input_ids']
        outputs = self.forward(inputs)
        score_len = outputs.logits_start.shape[1]
        start_scores = outputs.logits_start.reshape(score_len)
        end_scores = outputs.logits_end.reshape(score_len)
        no_ans_scores = start_scores[0] + end_scores[0]
        answer = ''
        flag = False
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        if (torch.max(start_scores) + torch.max(end_scores) <= no_ans_scores +  self.epsilon):
            answer = ''
        else:
            answer_start = torch.argmax(start_scores)
            answer_end = torch.argmax(end_scores)
            answer = self.tokenizer.decode(token_ids[answer_start:answer_end+1])
        return answer
    
    def predictionLabel(self,label):
        #ctx = ContextQuType(self.tokenizer,self.code_length,'',answer_text,question)
        #inputs = torch.stack([ctx.convert()]).to(device)
        #token_ids = tokenizer(question,answer_text,max_length=self.code_length,padding="max_length")['input_ids']
        #outputs = self.forward(inputs)
        outputs = label.convert()
        #print(outputs)
        token_ids = tokenizer(label.CQu.question,label.CQu.context,max_length=self.code_length,padding="max_length")['input_ids']
        start_scores = outputs[0,:]
        end_scores = outputs[1,:]
        #print(start_scores,end_scores)
        no_ans_scores = start_scores[0] + end_scores[0]
        answer = ''
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        #print(outputs.logits_start.reshape(-1))
        #print(outputs.logits_end.reshape(-1))
        if (torch.max(start_scores) + torch.max(end_scores) <= no_ans_scores +  self.epsilon):
            answer = ''
        else:
            answer_start = torch.argmax(start_scores)
            answer_end = torch.argmax(end_scores)
            answer = self.tokenizer.decode(token_ids[answer_start:answer_end+1])
        return answer
    def load_data(self,file_path):
        CQu, La = load_data(file_path,self.tokenizer,self.code_length)
        self.inputs = torch.stack([ctx.convert() for ctx in CQu])
        self.labels = torch.stack([ctx.convert() for ctx in La])
    
    def train(self, epochs=3, batch_size = 48, lr = 3e-5, T=10):
        optimizer = optim.Adam(self.parameters(), lr=lr)
        criterion = CrossEntropy().to(device)
        self.train_loader = DataLoader(TensorDataset(self.inputs,self.labels),batch_size = batch_size, shuffle = True, num_workers = 2)
        epoch_loss = []
        for epoch in range(epochs):      
            running_loss = 0.0
            epoch_running_loss = 0.0
            batch_count = 0
            for batchidx, (x, label) in enumerate(self.train_loader):
                x, label = x.to(device), label.to(device)
                output = self.forward(x)
                output.logits_start=output.logits_start.resize(len(output.logits_start),self.code_length)
                output.logits_end=output.logits_end.resize(len(output.logits_end),self.code_length)
                loss = criterion(output, label)
                # backprop
                optimizer.zero_grad()  #梯度清0
                loss.backward()   #梯度反传
                optimizer.step()   #保留梯度
                
                running_loss += loss.item()
                epoch_running_loss += loss.item()
                batch_count += 1
                if batchidx % T == T-1:
                    print(epoch,' batchidx: ', batchidx, ' loss: ', running_loss/T)
                    running_loss = 0.0
            epoch_loss.append(epoch_running_loss/batch_count)
            print(epoch, 'loss:', epoch_running_loss/batch_count)
        return epoch_loss
    def printLabel(self,file_path):
        CQu, La = load_data(file_path,self.tokenizer,self.code_length)
        count = 0
        precise = 0
        for label in La:
            count += 1
            predict = self.predictionLabel(label)
            if predict==label.text:
                precise += 1
            else:
                print(predict,label.text,label.has_ans)
        print(precise/count)

In [8]:
device = torch.device('cuda')
net = BERTQuA(pre_trained_model,tokenizer,512,device)
net.load_data('train-v2.0-Tag1.0.json')

In [9]:
epoch_loss = net.train(batch_size=16,epochs=3)
f = open("loss.txt", "w")
f.write(str(epoch_loss))



0  batchidx:  9  loss:  8.375053596496581
0  batchidx:  19  loss:  6.674959278106689
0  batchidx:  29  loss:  5.4409748077392575
0  batchidx:  39  loss:  5.341647148132324
0  batchidx:  49  loss:  4.824629902839661
0  batchidx:  59  loss:  4.046962070465088
0  batchidx:  69  loss:  4.170640397071838
0  batchidx:  79  loss:  3.8895325660705566
0  batchidx:  89  loss:  3.5661971807479858
0  batchidx:  99  loss:  3.508423352241516
0  batchidx:  109  loss:  3.3395523548126222
0  batchidx:  119  loss:  3.0631514310836794
0  batchidx:  129  loss:  3.2824429750442503
0  batchidx:  139  loss:  3.0704012870788575
0  batchidx:  149  loss:  2.967134141921997
0  batchidx:  159  loss:  3.3613591313362123
0  batchidx:  169  loss:  3.0499977588653566
0  batchidx:  179  loss:  2.4612540125846865
0  batchidx:  189  loss:  2.8389235615730284
0  batchidx:  199  loss:  2.5803094625473024
0  batchidx:  209  loss:  2.9079527258872986
0  batchidx:  219  loss:  2.6807554841041563
0  batchidx:  229  loss:  2.3

60

In [10]:
question = """What kind of device can access the Twilight Princess manga?"""
context = """A Japan-exclusive manga series based on Twilight Princess, penned and illustrated by Akira Himekawa, was first released on February 8, 2016. The series is available solely via publisher Shogakukan's MangaOne mobile application. While the manga adaptation began almost ten years after the initial release of the game on which it is based, it launched only a month before the release of the high-definition remake."""
net.prediction(question,context)

'mangaone mobile application'

In [11]:
torch.save({'state_dict': net.state_dict(), 'epoch': 3},'bertCluster0-'+str(3) + '.pkl')