In [1]:
import os
import torch
import pandas as pd
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.optim import lr_scheduler

from sklearn import model_selection
from sklearn import metrics
import transformers
import tokenizers
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
from tqdm.autonotebook import tqdm

In [2]:
MAX_LEN = 510
TRAIN_BATCH_SIZE = 8
EPOCHS = 6
PATH = "bert_config" #！！！需要改 config所在为止 我用的预训练模型注意下！！！
TRAINING_FILE = "train_data_frame.csv" #！！！需要改
dev_file = 'dev_deleted.csv'
TOKENIZER = transformers.BertTokenizer.from_pretrained("hfl/chinese-bert-wwm-ext")

In [3]:
#model_config = transformers.RobertaConfig.from_pretrained(PATH)
#model_config.output_hidden_states = True
#MODEL = transformers.AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext", config=model_config)

In [4]:
def find_ans(text, answer):
    for index in (i for i, e in enumerate(text) if e == answer[0]):
            if text[index:index+len(answer)] == answer:
                idx0 = index
                idx1 = index + len(answer)
                return idx0, idx1
    return find_start(text, answer[:-1])

def find_text(ans_start, ans_end, text, posslen):
    text_l = ans_start
    text_r = ans_end
    count = 0
    for i in text[ans_end:]:
        if i=="。":
            text_r += 1
            break
        else:
            text_r += 1
    text_l = text_r - posslen
    return text_l, text_r

In [5]:
def process_data(text, question, answer, tokenizer, max_len):
    len_ans = len(answer)
    posslen = max_len - len(question) - 3
    
    ans_start, ans_end = find_ans(text, answer)
    text_l, text_r = find_text(ans_start, ans_end, text, posslen)
    if len(text) > posslen:
        text = text[text_l:text_r]
    token_text = tokenizer.encode_plus(text,question)
    input_ids_orig = token_text["input_ids"]
        
    input_ids = input_ids_orig
    token_type_ids = [0] * (len(input_ids_orig))
    mask = [1] * len(token_type_ids)
    if ans_start>=510:
        ans_start=0
    if ans_end>=510:
        ans_end=0
    if ans_start<=0:
        ans_start = 0
    if ans_end<=0:
        ans_end=0
    ans_start += 1   #
    ans_end += 1       #
    
    padding_length = max_len - len(input_ids)
    if padding_length > 0:
        input_ids = input_ids + ([1] * padding_length)
        mask = mask + [0] * padding_length
        token_type_ids = token_type_ids + [0] * padding_length
        
    return {
        'ids': input_ids,
        'mask': mask,
        'token_type_ids': token_type_ids,
        'ans_start': ans_start,
        'ans_end': ans_end,
        'orig_text': text,
        'orig_ans': answer
    }

In [6]:
class QAdataset:
    def __init__(self, text, question, answer):
        self.text = text
        self.question = question
        self.answer = answer
        self.tokenizer = TOKENIZER
        self.max_len = MAX_LEN
        
    def __len__(self):
        return len(self.text)
    
    def __getitem__(self, item):
        data = process_data(
            self.text[item],
            self.question[item],
            self.answer[item],
            self.tokenizer,
            self.max_len
            )
        return {
            'ids': torch.tensor(data["ids"], dtype=torch.long),
            'mask': torch.tensor(data["mask"], dtype=torch.long),
            'token_type_ids': torch.tensor(data["token_type_ids"], dtype=torch.long),
            'ans_start': torch.tensor(data["ans_start"], dtype=torch.long),
            'ans_end': torch.tensor(data["ans_end"], dtype=torch.long),
            'orig_text': data["orig_text"],
            'orig_ans': data["orig_ans"]
        }

In [7]:
class QAmodel(transformers.BertForQuestionAnswering):
    def __init__(self, conf):
        super(QAmodel,self).__init__(conf)
        self.bertwwm = transformers.BertModel.from_pretrained("hfl/chinese-bert-wwm-ext", config=conf)
        self.drop_out = nn.Dropout(0.2)
        self.out = nn.Linear(768 * 2, 2)
        torch.nn.init.normal_(self.out.weight, std=0.02)
        
    def forward(self, input_ids, mask, token_type_ids):
        _, _, out = self.bertwwm(input_ids, token_type_ids=token_type_ids, attention_mask=mask)
        out = torch.cat((out[-1], out[-2]), dim=-1)
        out = self.drop_out(out)
        #out  = out.view(510,8, 768*2)
        logits = self.out(out)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        return start_logits, end_logits

In [8]:
def loss_fn(start_logits, end_logits, start_positions, end_positions):
    loss_fct = nn.CrossEntropyLoss()
    start_loss = loss_fct(start_logits, start_positions)
    end_loss = loss_fct(end_logits, end_positions)
    total_loss = (start_loss + end_loss)
    return total_loss

In [9]:
def train_fn(data_loader, model, optimizer, device, scheduler=None):
    model.train()
    
    for bi, d in enumerate(data_loader):
        ids = d["ids"]
        token_type_ids = d["token_type_ids"]
        mask = d["mask"]
        ans_start = d["ans_start"]
        ans_end = d["ans_end"]
        orig_ans = d["orig_ans"]
        orig_text = d["orig_text"]
        
        ids = ids.to(device, dtype=torch.long)
        token_type_ids = token_type_ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        ans_start = ans_start.to(device, dtype=torch.long)
       # ans_start = ans_start.squeeze(-1)  #  修改了
        ans_end = ans_end.to(device, dtype=torch.long)
        #ans_end = ans_end.squeeze(-1) #  修改了
        
        model.zero_grad()
        outputs_start, outputs_end = model(ids, mask, token_type_ids)
        #print(outputs_start.shape)
        #outputs_s = outputs_start.view(8,510)
        #outputs_e= outputs_end.view(8,510)
        #print(outputs_s.shape)
        loss = loss_fn(outputs_start, outputs_end, ans_start, ans_end)
        print(loss)
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        outputs_start = torch.softmax(outputs_start, dim=1).cpu().detach().numpy()
        outputs_end = torch.softmax(outputs_end, dim=1).cpu().detach().numpy()
        
df_train = pd.read_csv(TRAINING_FILE)
train_dataset = QAdataset(
    text=df_train.context.values,
    question=df_train.question.values,
    answer=df_train.answer.values
)

train_data_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
)


device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
print(device)
device = torch.device('cuda:1') #!!!需要改成"cuda"
model_config = transformers.BertConfig.from_pretrained(PATH)
model_config.output_hidden_states = True
model = QAmodel(conf=model_config)
model.to(device)

cuda:1


QAmodel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    

In [11]:
num_train_steps = int(len(df_train) / TRAIN_BATCH_SIZE * EPOCHS)
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
#no_decay = ["weight_hh_l0", 'bias_ih_l0', 'weight_ih_l0','bias_hh_l0']
#optimizer_parameters = [
#    {'params': [p for n, p in param_optimizer if  any(nd in n for nd in no_decay)], 'weight_decay': 0.001}
    #{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],'weight_decay': 0.0},
    #{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],'weight_decay': 0.0},
#]
optimizer_parameters = [
    {'params': [p for n, p in param_optimizer if  any(nd in n for nd in no_decay)], 'weight_decay': 0.001}#if not any(nd in n for nd in no_decay)
    #{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}    #if any(nd in n for nd in no_decay)
]

optimizer = AdamW(optimizer_parameters, lr=2e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=0, 
    num_training_steps=num_train_steps
)

In [12]:
for epoch in range(8):
    train_fn(train_data_loader, model, optimizer, device, scheduler=scheduler)
    print("over epoch-------------------------", epoch)
#torch.save(model, 'QAdev.pt')   
torch.save(model.state_dict(), 'QAdev_params.pt')

tensor(12.2647, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.9421, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.4458, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.0831, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.4576, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.3247, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.7587, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.7790, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.8540, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.4349, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.5673, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.3290, device='cuda:1', grad_fn=<AddBackward0>)
tensor(13.1908, device='cuda:1', grad_fn=<AddBackward0>)
tensor(13.2021, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.8396, device='cuda:1', grad_fn=<AddBackward0>)
tensor(13.2587, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.6899, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.0951, device='cuda:1'

tensor(11.7770, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.5005, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.6823, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.2360, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.3823, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.6702, device='cuda:1', grad_fn=<AddBackward0>)
tensor(13.1822, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.5954, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.8671, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.7557, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.1722, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.2958, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.0368, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.4389, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.2602, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.3154, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.9223, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.1137, device='cuda:1'

tensor(11.8679, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.3348, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.2121, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.8333, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.1400, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.8456, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.3710, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.0177, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.7868, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.9948, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.0826, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.2347, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.2715, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.3638, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.4472, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.0081, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.1204, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.9010, device='cuda:1'

tensor(12.2780, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.3381, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.7563, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.4210, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.6006, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.2518, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.1842, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.2307, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.3678, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.3955, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.7450, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.3751, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.8419, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.7930, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.0755, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.7327, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.4313, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.6786, device='cuda:1'

tensor(11.2453, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.7571, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.8847, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.6130, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.3497, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.5143, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.3650, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.6656, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.5140, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.4003, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.1181, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.6797, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.8847, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.6890, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.5757, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.1875, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.8445, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.0936, device='cuda:1'

tensor(10.9713, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.9114, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.5048, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.9018, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.4011, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.1524, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.9335, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.1437, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.1927, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.6945, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.7140, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.4702, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.7533, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.1513, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.4479, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.3391, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.5362, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.6208, device='cuda:1'

tensor(11.3971, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.0930, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.5520, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.9025, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.1348, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.5268, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.7723, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.4212, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.9844, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.1729, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.2577, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.3909, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.1105, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.5989, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.3023, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.1861, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.0922, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.3083, device='cuda:1'

tensor(11.0545, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.5654, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.8460, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.8134, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.4773, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.9553, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.8508, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.6713, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.8514, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.4543, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.6550, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.9205, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.0091, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.6408, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.4704, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.4712, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.9191, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.3054, device='cuda:1'

tensor(11.3214, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.8151, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.7664, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.9603, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.4527, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.1300, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.0939, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.9547, device='cuda:1', grad_fn=<AddBackward0>)
tensor(12.0396, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.2668, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.7154, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.3409, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.4901, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.2586, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.1483, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.7724, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.7779, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.7900, device='cuda:1'

tensor(10.3364, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.8635, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.8070, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.6926, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.7700, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.7666, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.3195, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.1475, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.4828, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.7188, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.1674, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.4328, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.8349, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.5895, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.1228, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.9215, device='cuda:1', grad_fn=<AddBackward0>)
tensor(10.1564, device='cuda:1', grad_fn=<AddBackward0>)
tensor(11.3053, device='cuda:1'

RuntimeError: cuda runtime error (710) : device-side assert triggered at /pytorch/aten/src/THC/THCReduceAll.cuh:327

In [159]:
class QAdatasetvalid:
    def __init__(self, text,question):
        self.text = text
        self.question = question
        self.tokenizer = TOKENIZER
        self.max_len = MAX_LEN
        
    def __len__(self):
        return len(self.text)
    
    def __getitem__(self, item):
        data = process_datavalid(
            self.text[item],
            self.question[item],
            self.tokenizer,
            self.max_len
            )
        
        return{
            'ids': torch.tensor(data["ids"], dtype=torch.long),
            'mask': torch.tensor(data["mask"], dtype=torch.long),
            'token_type_ids': torch.tensor(data["token_type_ids"], dtype=torch.long),
            'orig_text': data["orig_text"],
        }

In [18]:
VALID_FILE = 'test_data_with_stp.csv'
df_valid = pd.read_csv(VALID_FILE)
valid_dataset = QAdatasetvalid(
    text=df_valid.context.values,#
    question=df_valid.id.values,
)

valid_data_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=TRAIN_BATCH_SIZE,
)

In [None]:
answer_list = []

In [None]:
model.eval()
with torch.no_grad():
    tk0 = tqdm(valid_data_loader, total=len(valid_data_loader))
    for bi, d in enumerate(tk0):
        ids = d["ids"]
        token_type_ids = d["token_type_ids"]
        mask = d["mask"]
        orig_text = d["orig_text"]

        ids = ids.to(device, dtype=torch.long)
        token_type_ids = token_type_ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)

        outputs_start, outputs_end = model(
            ids,
            mask,
            token_type_ids
        )

        outputs_start = torch.softmax(outputs_start, dim=1).cpu().detach().numpy()
        outputs_end = torch.softmax(outputs_end, dim=1).cpu().detach().numpy()
        idx_start=np.argmax(outputs_start)
        idx_end=np.argmax(outputs_end)
        out = orig_text[0][int(idx_start)-1:int(idx_end)-1]
        answer_list.append(out)

In [None]:
answer_dict = {}
for i in range(df_valid.shape[0]):
    key = df_valid.loc[i]['question']
    value = answer_list[i]
    answer_dict[key] = value 

In [None]:
i=0
for k,v in answer_dict.items():
    print(k,v)
    i=i+1n
    if i >50:
        break

In [None]:
import json
outfile = 'answer_json_27.json'
with open(outfile,'w') as f:
    json.dump(answer_dict,f, ensure_ascii=False) # 解决中文编码问题