In [8]:
!pip install torch

Defaulting to user installation because normal site-packages is not writeable


In [1]:
import torch

In [2]:
torch.cuda.is_available()

True

In [3]:
device=torch.device("cuda" if torch.cuda.is_available else "cpu")

In [4]:
device

device(type='cuda')

In [1]:
import logging
import os
import argparse
import random
import json
import numpy as np
import torch
import glob
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
# from utiles import read_race_examples,convert_examples_to_features,accuracy,select_field
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForMultipleChoice
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from tqdm import tqdm
from transformers import BertModel, BertTokenizer



In [2]:
class RaceSequence(object):
    def __init__(self,article,question,option1,option2,option3,option4,label = None):
        self.article = article
        self.question=question
        self.options=[option1,option2,option3,option4]
        self.label = label
        

class InputSequence(object):
    def __init__(self,sequence_list,label):
        self.choices_sequence = [
            {
                'input_sequence': input_sequence,
                'segment_id': segment_id,
                'mask_id': mask_id
            }
            for input_sequence,segment_id,mask_id in sequence_list
        ]
        self.label = label


In [3]:
def to_input(input_path):
    cloze_input_list=[]
    direct_input_list=[]
    for filename in os.listdir(input_path):
        file_path=input_path+"/"+filename
        # print(file_path)
        with open(file_path,'r',encoding='utf-8') as f:
            json_text=json.loads(f.read())
            correct_answers=json_text['answers']
            options_list=json_text['options']
            questions_list=json_text['questions']
            article=json_text['article']

            for i in range(len(questions_list)):
                correct_answer=correct_answers[i]
                label=ord(correct_answer)-ord('A')
                options=options_list[i]
                question=questions_list[i]
                if '_' in question:
                    case=RaceSequence(article,question,options[0],options[1],options[2],options[3],label)
                    cloze_input_list.append(case)
                else:
                    case=RaceSequence(article,question,options[0],options[1],options[2],options[3],label)
                    direct_input_list.append(case)
    return cloze_input_list,direct_input_list

In [4]:
def input_to_sequence_question(input_list,max_length,tokenizer,is_training):

    choice_list = []
    for case in input_list:
#         print(case.article)
        article_token = tokenizer.tokenize(case.article)
        question_token = tokenizer.tokenize(case.question)
        
        sequence_list = []
        for option in case.options:
            
            question_answer_token = question_token + tokenizer.tokenize(option)
            
            while len(article_token)+len(question_answer_token)>max_length-3:
                if  len(article_token)>len(question_answer_token):
                    article_token.pop()
                else:
                    question_answer_token.pop()
            
            sequence_token = ["[CLS]"] + article_token + ["[SEP]"] + question_answer_token + ["[SEP]"]
            segment_id = [0] * (len(article_token) + 2) + [1] * (len(question_answer_token) + 1)
            input_sequence = tokenizer.convert_tokens_to_ids(sequence_token)
            mask_id = [1] * len(input_sequence)

            # Zero-pad up to the sequence length.
            for i in range(max_length-len(input_sequence)):
                input_sequence+=[0]
                mask_id+=[0]
                segment_id+=[0]

            sequence_list.append((input_sequence,segment_id,mask_id))

        label = case.label

        choice_list.append(InputSequence(sequence_list,label))

    return choice_list






def input_to_sequence_cloze(input_list,max_length,tokenizer,is_training):

    choice_list = []
    for case in input_list:
#         print(case.article)
        article_token = tokenizer.tokenize(case.article)
#         question_token = tokenizer.tokenize(case.question)
        
        sequence_list = []
        for option in case.options:
            question_answer=case.question.replace("_",option)
            
            question_answer_token = tokenizer.tokenize(question_answer)
            
            while len(article_token)+len(question_answer_token)>max_length-3:
                if  len(article_token)>len(question_answer_token):
                    article_token.pop()
                else:
                    question_answer_token.pop()
            
            sequence_token = ["[CLS]"] + article_token + ["[SEP]"] + question_answer_token + ["[SEP]"]
            segment_id = [0] * (len(article_token) + 2) + [1] * (len(question_answer_token) + 1)
            input_sequence = tokenizer.convert_tokens_to_ids(sequence_token)
            mask_id = [1] * len(input_sequence)

            # Zero-pad up to the sequence length.
            for i in range(max_length-len(input_sequence)):
                input_sequence+=[0]
                mask_id+=[0]
                segment_id+=[0]

            sequence_list.append((input_sequence,segment_id,mask_id))

        label = case.label

        choice_list.append(InputSequence(sequence_list,label))

    return choice_list


In [5]:
train_middle_path='data/train/middle'
train_high_path='data/train/high'

max_seq_length=360
is_training=True
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


train_input_cloze,train_input_direct=to_input(train_middle_path)
train_input_high_cloze,train_input_high_direct=to_input(train_high_path)
train_input_cloze+=train_input_high_cloze
train_input_direct+=train_input_high_direct
print("train_input completed!")
input_sequences_question=input_to_sequence_question(train_input_direct, max_seq_length,tokenizer,is_training)
input_sequences_cloze=input_to_sequence_cloze(train_input_cloze, max_seq_length,tokenizer,is_training)

train_input completed!


In [6]:
train_input=train_input_cloze+train_input_direct

In [7]:
input_sequences=input_sequences_question+input_sequences_cloze

In [8]:
eval_middle_path='data/dev/middle'
max_seq_length=360
is_training=False
# tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

eval_input_middle_cloze,eval_input_middle_direct=to_input(eval_middle_path)

input_sequences_eval_middle_cloze=input_to_sequence_cloze(eval_input_middle_cloze, max_seq_length,tokenizer,is_training)
input_sequences_eval_middle_direct=input_to_sequence_question(eval_input_middle_direct, max_seq_length,tokenizer,is_training)
input_sequences_eval_middle=input_sequences_eval_middle_cloze+input_sequences_eval_middle_direct

In [9]:
def warmup_linear(x, warmup=0.002):
    if x < warmup:
        return x/warmup
    return 1.0 - x


def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)

In [10]:
def to_all_tensor(input_sequences,key):
    all_input_sequences=[]
    for input_sequence in input_sequences:
        lst=[]
        for case in input_sequence.choices_sequence:
            lst.append(case[key])
        all_input_sequences.append(lst)
    return torch.tensor(all_input_sequences,dtype=torch.long)

In [11]:
train_input_sequence=to_all_tensor(input_sequences,'input_sequence')
train_segment_id=to_all_tensor(input_sequences,'segment_id')
train_mask_id=to_all_tensor(input_sequences,'mask_id')
train_label=torch.tensor([case.label for case in input_sequences],dtype=torch.long)

In [12]:
eval_input_sequence_middle=to_all_tensor(input_sequences_eval_middle,'input_sequence')
eval_segment_id_middle=to_all_tensor(input_sequences_eval_middle,'segment_id')
eval_mask_id_middle=to_all_tensor(input_sequences_eval_middle,'mask_id')
eval_label_middle=torch.tensor([case.label for case in input_sequences_eval_middle],dtype=torch.long)

In [13]:
eval_input_sequence_middle.shape

torch.Size([1436, 4, 360])

In [20]:
train_input_sequence.shape

torch.Size([87866, 4, 360])

In [21]:
device=torch.device("cuda" if torch.cuda.is_available else "cpu")

model = BertForMultipleChoice.from_pretrained("MiniLM-L6-H384-distilled-from-BERT-Base",
        cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(-1)),
        num_choices=4).to(device)

In [22]:
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule


learning_rate=2.5e-5
warmup_proportion=0.1
batch_size=4
gradient_accumulation_steps=5
# train_batch_size=int(batch_size/gradient_accumulation_steps)
train_batch_size=batch_size
eval_batch_size=4

num_train_steps=len(train_input)

n_gpu=torch.cuda.device_count()
epoch_num=5


device=torch.device("cuda" if torch.cuda.is_available else "cpu")

model = BertForMultipleChoice.from_pretrained("MiniLM-L6-H384-distilled-from-BERT-Base",
        cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(-1)),
        num_choices=4).to(device)

param_optimizer=list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

optimizer = BertAdam(optimizer_grouped_parameters,
                    lr=learning_rate,
                    warmup=warmup_proportion,
                    t_total=num_train_steps)

# optimizer = BertAdam(model.parameters(),lr=learning_rate,schedule='warmup_linear',warmup=warmup_proportion,t_total=num_train_steps)

train_dataset=TensorDataset(train_input_sequence,train_segment_id,train_mask_id,train_label)
train_sampler=RandomSampler(train_dataset)
train_data=DataLoader(train_dataset,sampler=train_sampler, batch_size=train_batch_size)

eval_dataset=TensorDataset(eval_input_sequence_middle,eval_segment_id_middle,eval_mask_id_middle,eval_label_middle)
eval_sampler=RandomSampler(eval_dataset)
eval_data=DataLoader(eval_dataset,sampler=eval_sampler, batch_size=eval_batch_size)

In [23]:
num_train_steps

87866

In [25]:
def val_race(eval_dataloader,device,model,global_step):
    print("begin evaluation!!!")
    eval_loss, eval_accuracy = 0, 0
    eval_steps, eval_examples = 0, 0
    for step, batch in enumerate(eval_dataloader):
        batch = tuple(t.to(device) for t in batch)
        input_ids, segment_ids, input_mask, label_ids = batch

        with torch.no_grad():
            tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
            logits = model(input_ids, segment_ids, input_mask)

        logits = logits.detach().cpu().numpy()
        label_ids = label_ids.to('cpu').numpy()
        tmp_eval_accuracy = accuracy(logits, label_ids)

        eval_loss += tmp_eval_loss.mean().item()
        eval_accuracy += tmp_eval_accuracy

        eval_examples += input_ids.size(0)
        eval_steps += 1

    eval_loss = eval_loss / eval_steps
    eval_accuracy = eval_accuracy / eval_examples
    print("eval_loss:",eval_loss)
    print("eval_accuracy:",eval_accuracy)

    result = {'dev_eval_loss': eval_loss,
              'dev_eval_accuracy': eval_accuracy,
              'global_step': global_step}

    output_eval_file = os.path.join('output', "eval_results_nowarmpup.txt")
    with open(output_eval_file, "a+") as writer:
        for key in sorted(result.keys()):
            writer.write("%s = %s\n" % (key, str(result[key])))

In [26]:
def train_race(train_dataloader,device,n_gpu,model,optimizer,global_step,t_total,train_loss,train_examples,train_steps):
    for step, data in enumerate(tqdm(train_dataloader)):
        data = tuple(t.to(device) for t in data)
        input_ids, segment_ids,input_mask, label_ids = data
#         print(segment_ids)
        loss = model(input_ids, segment_ids, input_mask, label_ids)
        
        train_loss += loss.item()
        train_examples += input_ids.size(0)
        train_steps += 1

        loss.backward()

        
         # modify learning rate with special warm up BERT uses
        if (step + 1) % gradient_accumulation_steps == 0:
#             tmp_lr= learning_rate * warmup_linear(global_step / t_total, warmup_proportion)
# we do not do linear warmup to see the effect
            tmp_lr=learning_rate
#             print(warmup_proportion)
            for param_group in optimizer.param_groups:
                param_group['lr'] = tmp_lr
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
#             print("Training loss: {}, global step: {}".format(tr_loss / nb_tr_steps, global_step))

        if (global_step+1) % 100 == 0:
            print("Training loss: {}, global step: {}".format(train_loss / train_steps, global_step))
            
            

In [27]:
n_gpu

1

In [27]:
global_step=0
model.train()

for epoch in range(epoch_num):
    train_loss=0
    training_case_num=0
    training_step_num=0
    
  
  
    print("Training Epoch: {}/{}".format(epoch+1, int(epoch_num)))
    train_race(train_data,device,n_gpu,model,optimizer,global_step,num_train_steps,train_loss,training_case_num,training_step_num)
#     if (global_step+1) % 100 == 0:
    model.eval()
    val_race(eval_data, device, model, global_step)

  

  0%|          | 0/21967 [00:00<?, ?it/s]

Training Epoch: 1/5


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1055.)
  next_m.mul_(beta1).add_(1 - beta1, grad)
  2%|▏         | 496/21967 [00:45<32:45, 10.93it/s] 

Training loss: 1.386563254847671, global step: 99
Training loss: 1.3865545988082886, global step: 99
Training loss: 1.3865566644630203, global step: 99


  2%|▏         | 500/21967 [00:45<33:00, 10.84it/s]

Training loss: 1.386559635281084, global step: 99
Training loss: 1.3865585812108072, global step: 99


  5%|▍         | 996/21967 [01:30<31:58, 10.93it/s]

Training loss: 1.3863436030383085, global step: 199
Training loss: 1.386341900590912, global step: 199
Training loss: 1.3863388496033526, global step: 199


  5%|▍         | 1000/21967 [01:31<32:12, 10.85it/s]

Training loss: 1.3863340883790132, global step: 199
Training loss: 1.3863283385027636, global step: 199


  7%|▋         | 1496/21967 [02:15<31:17, 10.90it/s]

Training loss: 1.3862844731098034, global step: 299
Training loss: 1.3862851820686921, global step: 299
Training loss: 1.3862829367638272, global step: 299


  7%|▋         | 1500/21967 [02:16<31:31, 10.82it/s]

Training loss: 1.3862816328518541, global step: 299
Training loss: 1.3862828591253218, global step: 299


  9%|▉         | 1996/21967 [03:01<30:32, 10.90it/s]

Training loss: 1.3861727622517368, global step: 399
Training loss: 1.386173696341161, global step: 399
Training loss: 1.3861772644561592, global step: 399


  9%|▉         | 2000/21967 [03:01<30:46, 10.81it/s]

Training loss: 1.3861760621791606, global step: 399
Training loss: 1.3861823221514855, global step: 399


 11%|█▏        | 2496/21967 [03:46<29:44, 10.91it/s]

Training loss: 1.3859999123938336, global step: 499
Training loss: 1.3859988307246032, global step: 499
Training loss: 1.3859958623378907, global step: 499


 11%|█▏        | 2500/21967 [03:47<30:00, 10.81it/s]

Training loss: 1.3859948597973686, global step: 499
Training loss: 1.3859906458482594, global step: 499


 14%|█▎        | 2996/21967 [04:32<29:00, 10.90it/s]

Training loss: 1.3857760022996064, global step: 599
Training loss: 1.3857721037794655, global step: 599
Training loss: 1.3857703825454533, global step: 599


 14%|█▎        | 3000/21967 [04:32<29:13, 10.82it/s]

Training loss: 1.385765892454113, global step: 599
Training loss: 1.3857669388703324, global step: 599


 16%|█▌        | 3496/21967 [05:17<28:10, 10.93it/s]

Training loss: 1.3854515717606006, global step: 699
Training loss: 1.3854555835099187, global step: 699
Training loss: 1.3854490921843008, global step: 699


 16%|█▌        | 3500/21967 [05:17<28:23, 10.84it/s]

Training loss: 1.3854528421535024, global step: 699
Training loss: 1.385449013596911, global step: 699


 18%|█▊        | 3996/21967 [06:02<27:26, 10.91it/s]

Training loss: 1.3844379722549858, global step: 799
Training loss: 1.384439382467184, global step: 799
Training loss: 1.3844274735194253, global step: 799


 18%|█▊        | 4000/21967 [06:02<27:39, 10.83it/s]

Training loss: 1.3844178161244205, global step: 799
Training loss: 1.384429454505369, global step: 799


 20%|██        | 4496/21967 [06:47<26:40, 10.92it/s]

Training loss: 1.3826602386287905, global step: 899
Training loss: 1.3826344215254767, global step: 899
Training loss: 1.3826557327064377, global step: 899


 20%|██        | 4500/21967 [06:48<26:52, 10.83it/s]

Training loss: 1.382654721814296, global step: 899
Training loss: 1.3826459075747979, global step: 899


 23%|██▎       | 4996/21967 [07:33<25:52, 10.93it/s]

Training loss: 1.3796774795463493, global step: 999
Training loss: 1.3796839406005472, global step: 999
Training loss: 1.3796720088470547, global step: 999


 23%|██▎       | 5000/21967 [07:33<26:06, 10.83it/s]

Training loss: 1.379693322083434, global step: 999
Training loss: 1.3796565123237736, global step: 999


 25%|██▌       | 5496/21967 [08:18<25:10, 10.91it/s]

Training loss: 1.3768653676551077, global step: 1099
Training loss: 1.3768348979724374, global step: 1099
Training loss: 1.3768634824117834, global step: 1099


 25%|██▌       | 5500/21967 [08:18<25:20, 10.83it/s]

Training loss: 1.3769013416164353, global step: 1099
Training loss: 1.3768688510777538, global step: 1099


 27%|██▋       | 5996/21967 [09:03<24:21, 10.93it/s]

Training loss: 1.373748036739327, global step: 1199
Training loss: 1.373767292901148, global step: 1199
Training loss: 1.373799377806051, global step: 1199


 27%|██▋       | 6000/21967 [09:03<24:32, 10.84it/s]

Training loss: 1.3738043067971242, global step: 1199
Training loss: 1.373782568562923, global step: 1199


 30%|██▉       | 6496/21967 [09:48<23:35, 10.93it/s]

Training loss: 1.3704659338819696, global step: 1299
Training loss: 1.3704655299265984, global step: 1299
Training loss: 1.3704378295352098, global step: 1299


 30%|██▉       | 6500/21967 [09:49<23:46, 10.84it/s]

Training loss: 1.370424340691483, global step: 1299
Training loss: 1.3704459442140726, global step: 1299


 32%|███▏      | 6996/21967 [10:33<22:48, 10.94it/s]

Training loss: 1.3679373337934493, global step: 1399
Training loss: 1.3679390377048222, global step: 1399
Training loss: 1.3679227870550261, global step: 1399


 32%|███▏      | 7000/21967 [10:34<22:59, 10.85it/s]

Training loss: 1.3679327103164272, global step: 1399
Training loss: 1.3679330880717901, global step: 1399


 34%|███▍      | 7496/21967 [11:19<22:07, 10.90it/s]

Training loss: 1.3655729533196133, global step: 1499
Training loss: 1.365551265368085, global step: 1499
Training loss: 1.36550479929368, global step: 1499


 34%|███▍      | 7500/21967 [11:19<22:19, 10.80it/s]

Training loss: 1.3654719853789115, global step: 1499
Training loss: 1.3654806734895941, global step: 1499


 36%|███▋      | 7996/21967 [12:04<21:20, 10.91it/s]

Training loss: 1.3633770539657708, global step: 1599
Training loss: 1.36337571064492, global step: 1599
Training loss: 1.3633760818171623, global step: 1599


 36%|███▋      | 8000/21967 [12:05<21:30, 10.82it/s]

Training loss: 1.3633858156997163, global step: 1599
Training loss: 1.3633667084109353, global step: 1599


 39%|███▊      | 8496/21967 [12:49<20:35, 10.91it/s]

Training loss: 1.3617143832914544, global step: 1699
Training loss: 1.3617425926745275, global step: 1699
Training loss: 1.3617916370793428, global step: 1699


 39%|███▊      | 8500/21967 [12:50<20:45, 10.82it/s]

Training loss: 1.3618015937312515, global step: 1699
Training loss: 1.3617926988479376, global step: 1699


 41%|████      | 8996/21967 [13:35<19:47, 10.93it/s]

Training loss: 1.359236841317082, global step: 1799
Training loss: 1.3592240507269286, global step: 1799
Training loss: 1.3592081478904243, global step: 1799


 41%|████      | 9000/21967 [13:35<19:56, 10.84it/s]

Training loss: 1.359222328279993, global step: 1799
Training loss: 1.359193078907798, global step: 1799


 43%|████▎     | 9496/21967 [14:20<19:01, 10.93it/s]

Training loss: 1.3570675633467142, global step: 1899
Training loss: 1.3570513291929385, global step: 1899
Training loss: 1.3570388333350845, global step: 1899


 43%|████▎     | 9500/21967 [14:20<19:09, 10.85it/s]

Training loss: 1.3570387664030015, global step: 1899
Training loss: 1.3570143324360746, global step: 1899


 46%|████▌     | 9996/21967 [15:05<18:15, 10.92it/s]

Training loss: 1.3543751542838947, global step: 1999
Training loss: 1.35438558511755, global step: 1999
Training loss: 1.3544120588719493, global step: 1999


 46%|████▌     | 10000/21967 [15:06<18:24, 10.83it/s]

Training loss: 1.3543975612811505, global step: 1999
Training loss: 1.3543797378027864, global step: 1999


 48%|████▊     | 10496/21967 [15:51<17:30, 10.92it/s]

Training loss: 1.3526572785927262, global step: 2099
Training loss: 1.3526463558181847, global step: 2099
Training loss: 1.3526295949257339, global step: 2099


 48%|████▊     | 10500/21967 [15:51<17:38, 10.83it/s]

Training loss: 1.3526521282585309, global step: 2099
Training loss: 1.3526220935528046, global step: 2099


 50%|█████     | 10996/21967 [16:36<16:44, 10.92it/s]

Training loss: 1.3505768020026192, global step: 2199
Training loss: 1.3505364724584994, global step: 2199
Training loss: 1.3505516643197664, global step: 2199


 50%|█████     | 11000/21967 [16:36<16:52, 10.83it/s]

Training loss: 1.3505282789918591, global step: 2199
Training loss: 1.3505492011669906, global step: 2199


 52%|█████▏    | 11496/21967 [17:21<16:00, 10.90it/s]

Training loss: 1.348508754233891, global step: 2299
Training loss: 1.348538727133367, global step: 2299
Training loss: 1.3485198612834015, global step: 2299


 52%|█████▏    | 11500/21967 [17:21<16:06, 10.83it/s]

Training loss: 1.3485315969652207, global step: 2299
Training loss: 1.3485291399938333, global step: 2299


 55%|█████▍    | 11996/21967 [18:06<15:12, 10.93it/s]

Training loss: 1.3465360065110776, global step: 2399
Training loss: 1.3465411362995026, global step: 2399
Training loss: 1.3465079944501612, global step: 2399


 55%|█████▍    | 12000/21967 [18:07<15:19, 10.84it/s]

Training loss: 1.3465101399120043, global step: 2399
Training loss: 1.3465085418405747, global step: 2399


 57%|█████▋    | 12496/21967 [18:52<14:27, 10.92it/s]

Training loss: 1.3445644701610044, global step: 2499
Training loss: 1.3445482432708697, global step: 2499
Training loss: 1.3445319239116624, global step: 2499


 57%|█████▋    | 12500/21967 [18:52<14:33, 10.84it/s]

Training loss: 1.3445336294376407, global step: 2499
Training loss: 1.3445333061776015, global step: 2499


 59%|█████▉    | 12996/21967 [19:37<13:41, 10.93it/s]

Training loss: 1.342103927944568, global step: 2599
Training loss: 1.3420902547624596, global step: 2599
Training loss: 1.3421108259910381, global step: 2599


 59%|█████▉    | 13000/21967 [19:37<13:46, 10.84it/s]

Training loss: 1.342135421591257, global step: 2599
Training loss: 1.3421370648991118, global step: 2599


 61%|██████▏   | 13496/21967 [20:22<12:55, 10.92it/s]

Training loss: 1.3395970556726275, global step: 2699
Training loss: 1.3395760945741708, global step: 2699
Training loss: 1.3395789739970252, global step: 2699


 61%|██████▏   | 13500/21967 [20:22<13:01, 10.84it/s]

Training loss: 1.3395542964120672, global step: 2699
Training loss: 1.339544572559266, global step: 2699


 64%|██████▎   | 13996/21967 [21:07<12:10, 10.92it/s]

Training loss: 1.3380672806754799, global step: 2799
Training loss: 1.3380674391678313, global step: 2799
Training loss: 1.3380680865599972, global step: 2799


 64%|██████▎   | 14000/21967 [21:08<12:15, 10.83it/s]

Training loss: 1.3380733327936114, global step: 2799
Training loss: 1.3380736085898333, global step: 2799


 66%|██████▌   | 14496/21967 [21:53<11:24, 10.92it/s]

Training loss: 1.3359365316290985, global step: 2899
Training loss: 1.335918296260089, global step: 2899
Training loss: 1.335898671453144, global step: 2899


 66%|██████▌   | 14500/21967 [21:53<11:28, 10.84it/s]

Training loss: 1.335896983559764, global step: 2899
Training loss: 1.3359032226633274, global step: 2899


 68%|██████▊   | 14996/21967 [22:38<10:37, 10.93it/s]

Training loss: 1.333868890772187, global step: 2999
Training loss: 1.3338727382031337, global step: 2999
Training loss: 1.3338591473543095, global step: 2999


 68%|██████▊   | 15000/21967 [22:38<10:42, 10.84it/s]

Training loss: 1.3338832898781543, global step: 2999
Training loss: 1.333876253398341, global step: 2999


 71%|███████   | 15496/21967 [23:23<09:52, 10.92it/s]

Training loss: 1.3323512861558644, global step: 3099
Training loss: 1.332341045522019, global step: 3099
Training loss: 1.3323428780177011, global step: 3099


 71%|███████   | 15500/21967 [23:23<09:56, 10.84it/s]

Training loss: 1.332346666684565, global step: 3099
Training loss: 1.3323514307419773, global step: 3099


 73%|███████▎  | 15996/21967 [24:08<09:06, 10.92it/s]

Training loss: 1.3299376578172992, global step: 3199
Training loss: 1.3299287721816526, global step: 3199
Training loss: 1.3299125397012346, global step: 3199


 73%|███████▎  | 16000/21967 [24:09<09:10, 10.84it/s]

Training loss: 1.329936453980615, global step: 3199
Training loss: 1.3299062597363418, global step: 3199


 75%|███████▌  | 16496/21967 [24:54<08:21, 10.92it/s]

Training loss: 1.3274706293503853, global step: 3299
Training loss: 1.3274704684963166, global step: 3299
Training loss: 1.3274422814074722, global step: 3299


 75%|███████▌  | 16500/21967 [24:54<08:24, 10.83it/s]

Training loss: 1.3274383586866956, global step: 3299
Training loss: 1.3274784574718776, global step: 3299


 77%|███████▋  | 16996/21967 [25:39<07:47, 10.63it/s]

Training loss: 1.3253170793636577, global step: 3399
Training loss: 1.3253248623116602, global step: 3399
Training loss: 1.3253205698322126, global step: 3399


 77%|███████▋  | 16998/21967 [25:39<07:34, 10.93it/s]

Training loss: 1.3253161852984503, global step: 3399
Training loss: 1.3252993950360494, global step: 3399


 80%|███████▉  | 17496/21967 [26:25<06:49, 10.92it/s]

Training loss: 1.3229370577228379, global step: 3499
Training loss: 1.3229834209355962, global step: 3499
Training loss: 1.322963471319864, global step: 3499


 80%|███████▉  | 17500/21967 [26:25<06:52, 10.84it/s]

Training loss: 1.32296913611179, global step: 3499
Training loss: 1.3229535533650356, global step: 3499


 82%|████████▏ | 17996/21967 [27:10<06:03, 10.92it/s]

Training loss: 1.3206146076209813, global step: 3599
Training loss: 1.3206140465353775, global step: 3599
Training loss: 1.32060039804506, global step: 3599


 82%|████████▏ | 18000/21967 [27:10<06:06, 10.83it/s]

Training loss: 1.3205981252193186, global step: 3599
Training loss: 1.3206056949198064, global step: 3599


 84%|████████▍ | 18496/21967 [27:55<05:18, 10.90it/s]

Training loss: 1.318613156204966, global step: 3699
Training loss: 1.3186048967151494, global step: 3699
Training loss: 1.3186363944390171, global step: 3699


 84%|████████▍ | 18500/21967 [27:55<05:20, 10.82it/s]

Training loss: 1.3186201299444535, global step: 3699
Training loss: 1.3186102880517292, global step: 3699


 86%|████████▋ | 18996/21967 [28:40<04:31, 10.92it/s]

Training loss: 1.3168223442733837, global step: 3799
Training loss: 1.3168008570377137, global step: 3799
Training loss: 1.3167785926763025, global step: 3799


 86%|████████▋ | 19000/21967 [28:41<04:33, 10.84it/s]

Training loss: 1.3167838161023544, global step: 3799
Training loss: 1.316792507624399, global step: 3799


 89%|████████▉ | 19496/21967 [29:25<03:46, 10.92it/s]

Training loss: 1.314475639352862, global step: 3899
Training loss: 1.314471568426161, global step: 3899
Training loss: 1.314471872054828, global step: 3899


 89%|████████▉ | 19500/21967 [29:26<03:47, 10.84it/s]

Training loss: 1.314466641695209, global step: 3899
Training loss: 1.3144687539391338, global step: 3899


 91%|█████████ | 19996/21967 [30:11<03:00, 10.92it/s]

Training loss: 1.3122571297498546, global step: 3999
Training loss: 1.3122709955819656, global step: 3999
Training loss: 1.3122623757615104, global step: 3999


 91%|█████████ | 20000/21967 [30:11<03:01, 10.83it/s]

Training loss: 1.3122495737600022, global step: 3999
Training loss: 1.3122312833910685, global step: 3999


 93%|█████████▎| 20496/21967 [30:56<02:14, 10.92it/s]

Training loss: 1.3103547592938427, global step: 4099
Training loss: 1.310352208098625, global step: 4099
Training loss: 1.3103497344228985, global step: 4099


 93%|█████████▎| 20500/21967 [30:56<02:15, 10.84it/s]

Training loss: 1.3103349229724526, global step: 4099
Training loss: 1.3103244549614945, global step: 4099


 96%|█████████▌| 20996/21967 [31:41<01:28, 10.92it/s]

Training loss: 1.308418719109594, global step: 4199
Training loss: 1.308414382200385, global step: 4199
Training loss: 1.3084285155647668, global step: 4199


 96%|█████████▌| 21000/21967 [31:41<01:29, 10.84it/s]

Training loss: 1.3084179450428615, global step: 4199
Training loss: 1.3084313522347881, global step: 4199


 98%|█████████▊| 21496/21967 [32:26<00:43, 10.93it/s]

Training loss: 1.3064139782994202, global step: 4299
Training loss: 1.306417208392302, global step: 4299
Training loss: 1.3064074049831884, global step: 4299


 98%|█████████▊| 21500/21967 [32:27<00:43, 10.84it/s]

Training loss: 1.3064185832656143, global step: 4299
Training loss: 1.306418799286338, global step: 4299


100%|██████████| 21967/21967 [33:09<00:00, 11.04it/s]


begin evaluation!!!


  0%|          | 2/21967 [00:00<30:56, 11.83it/s]

eval_loss: 1.0964473637863785
eval_accuracy: 0.5452646239554317
Training Epoch: 2/5


  2%|▏         | 496/21967 [00:42<31:17, 11.44it/s]

Training loss: 1.162025826808178, global step: 99
Training loss: 1.161285364639855, global step: 99
Training loss: 1.160969176102932, global step: 99


  2%|▏         | 500/21967 [00:43<31:32, 11.34it/s]

Training loss: 1.1606293189597416, global step: 99
Training loss: 1.1593462292679804, global step: 99


  5%|▍         | 996/21967 [01:25<30:32, 11.44it/s]

Training loss: 1.1489913005026142, global step: 199
Training loss: 1.1490753351983776, global step: 199
Training loss: 1.1490651508092642, global step: 199


  5%|▍         | 1000/21967 [01:26<30:47, 11.35it/s]

Training loss: 1.148803382544575, global step: 199
Training loss: 1.148848644218168, global step: 199


  7%|▋         | 1496/21967 [02:09<29:49, 11.44it/s]

Training loss: 1.1455937003411576, global step: 299
Training loss: 1.145746718813232, global step: 299
Training loss: 1.1454034833288542, global step: 299


  7%|▋         | 1500/21967 [02:09<30:04, 11.34it/s]

Training loss: 1.1451989074296085, global step: 299
Training loss: 1.1455941229084479, global step: 299


  9%|▉         | 1996/21967 [02:52<29:56, 11.12it/s]

Training loss: 1.1456881733914666, global step: 399
Training loss: 1.1455590855591522, global step: 399
Training loss: 1.1455269612457732, global step: 399


  9%|▉         | 2000/21967 [02:52<30:15, 11.00it/s]

Training loss: 1.145524741337822, global step: 399
Training loss: 1.1453234451213081, global step: 399


 11%|█▏        | 2496/21967 [03:36<28:25, 11.42it/s]

Training loss: 1.145292707543096, global step: 499
Training loss: 1.1455525746449637, global step: 499
Training loss: 1.1454648799809353, global step: 499


 11%|█▏        | 2500/21967 [03:36<28:39, 11.32it/s]

Training loss: 1.1454711803753344, global step: 499
Training loss: 1.1455671518814474, global step: 499


 14%|█▎        | 2996/21967 [04:19<28:25, 11.12it/s]

Training loss: 1.1461182317371559, global step: 599
Training loss: 1.1461366441384335, global step: 599
Training loss: 1.1461482336511601, global step: 599


 14%|█▎        | 3000/21967 [04:19<28:44, 11.00it/s]

Training loss: 1.1462861883334114, global step: 599
Training loss: 1.1461230070740431, global step: 599


 16%|█▌        | 3496/21967 [05:03<26:55, 11.44it/s]

Training loss: 1.1457232438358285, global step: 699
Training loss: 1.1457522068794996, global step: 699
Training loss: 1.1458037117992022, global step: 699


 16%|█▌        | 3500/21967 [05:03<27:07, 11.35it/s]

Training loss: 1.1459279322348164, global step: 699
Training loss: 1.1458990174843673, global step: 699


 18%|█▊        | 3996/21967 [05:46<26:10, 11.44it/s]

Training loss: 1.1425327185769851, global step: 799
Training loss: 1.1426965276504661, global step: 799
Training loss: 1.1426984316815965, global step: 799


 18%|█▊        | 4000/21967 [05:47<26:23, 11.35it/s]

Training loss: 1.1428178990629687, global step: 799
Training loss: 1.1428952725880384, global step: 799


 20%|██        | 4496/21967 [06:29<25:27, 11.44it/s]

Training loss: 1.1423069245359125, global step: 899
Training loss: 1.1422793931291302, global step: 899
Training loss: 1.1423059263798245, global step: 899


 20%|██        | 4500/21967 [06:30<25:39, 11.35it/s]

Training loss: 1.1423134748712547, global step: 899
Training loss: 1.14220380863234, global step: 899


 23%|██▎       | 4996/21967 [07:12<24:42, 11.45it/s]

Training loss: 1.140906653354118, global step: 999
Training loss: 1.1410087318994027, global step: 999
Training loss: 1.140926877737093, global step: 999


 23%|██▎       | 5000/21967 [07:13<24:54, 11.35it/s]

Training loss: 1.1408728079444745, global step: 999
Training loss: 1.1409053192970442, global step: 999


 25%|██▌       | 5496/21967 [07:56<23:59, 11.44it/s]

Training loss: 1.1396116617107739, global step: 1099
Training loss: 1.1396489951168518, global step: 1099
Training loss: 1.1396087083693784, global step: 1099


 25%|██▌       | 5500/21967 [07:56<24:10, 11.35it/s]

Training loss: 1.1396310212190215, global step: 1099
Training loss: 1.139584914345723, global step: 1099


 27%|██▋       | 5996/21967 [08:39<23:15, 11.44it/s]

Training loss: 1.1355819347304439, global step: 1199
Training loss: 1.1356665001283732, global step: 1199
Training loss: 1.1356537172809051, global step: 1199


 27%|██▋       | 6000/21967 [08:39<23:27, 11.35it/s]

Training loss: 1.13567721005995, global step: 1199
Training loss: 1.1355688613362065, global step: 1199


 30%|██▉       | 6496/21967 [09:22<23:10, 11.13it/s]

Training loss: 1.131697529096664, global step: 1299
Training loss: 1.1317722311328289, global step: 1299
Training loss: 1.1317869515885146, global step: 1299


 30%|██▉       | 6500/21967 [09:22<23:25, 11.01it/s]

Training loss: 1.131788738203777, global step: 1299
Training loss: 1.1318369014519676, global step: 1299


 32%|███▏      | 6996/21967 [10:06<22:26, 11.12it/s]

Training loss: 1.1303235356779164, global step: 1399
Training loss: 1.1302396730323343, global step: 1399
Training loss: 1.130197366801001, global step: 1399


 32%|███▏      | 7000/21967 [10:06<22:39, 11.01it/s]

Training loss: 1.1302140122425917, global step: 1399
Training loss: 1.1303205941309415, global step: 1399


 34%|███▍      | 7496/21967 [10:49<21:04, 11.45it/s]

Training loss: 1.1273206225041552, global step: 1499
Training loss: 1.1272690691254643, global step: 1499
Training loss: 1.1272328624906056, global step: 1499


 34%|███▍      | 7500/21967 [10:50<21:14, 11.35it/s]

Training loss: 1.1272318845450338, global step: 1499
Training loss: 1.1272066931072227, global step: 1499


 36%|███▋      | 7996/21967 [11:33<20:20, 11.45it/s]

Training loss: 1.1250055824549143, global step: 1599
Training loss: 1.1249651067814286, global step: 1599
Training loss: 1.1249987053570412, global step: 1599


 36%|███▋      | 8000/21967 [11:33<20:30, 11.35it/s]

Training loss: 1.1250138777552932, global step: 1599
Training loss: 1.1250017529169773, global step: 1599


 39%|███▊      | 8496/21967 [12:16<19:36, 11.45it/s]

Training loss: 1.1238431210488764, global step: 1699
Training loss: 1.1238389477360005, global step: 1699
Training loss: 1.1238178690814111, global step: 1699


 39%|███▊      | 8500/21967 [12:16<19:44, 11.36it/s]

Training loss: 1.1238462230969624, global step: 1699
Training loss: 1.123839306974603, global step: 1699


 41%|████      | 8996/21967 [12:59<19:25, 11.13it/s]

Training loss: 1.1219761846933913, global step: 1799
Training loss: 1.1221224125855669, global step: 1799
Training loss: 1.1221095450968628, global step: 1799


 41%|████      | 9000/21967 [13:00<19:38, 11.00it/s]

Training loss: 1.122060931612085, global step: 1799
Training loss: 1.1220197626354893, global step: 1799


 43%|████▎     | 9496/21967 [13:43<18:10, 11.44it/s]

Training loss: 1.1200471635984395, global step: 1899
Training loss: 1.1200146415806749, global step: 1899
Training loss: 1.119980711640936, global step: 1899


 43%|████▎     | 9500/21967 [13:43<18:18, 11.35it/s]

Training loss: 1.1199315096551516, global step: 1899
Training loss: 1.1199178575319557, global step: 1899


 46%|████▌     | 9996/21967 [14:26<17:26, 11.44it/s]

Training loss: 1.1189119806918517, global step: 1999
Training loss: 1.1188637063022004, global step: 1999
Training loss: 1.118830823122179, global step: 1999


 46%|████▌     | 10000/21967 [14:27<17:34, 11.35it/s]

Training loss: 1.1188304698797829, global step: 1999
Training loss: 1.118767033081011, global step: 1999


 48%|████▊     | 10496/21967 [15:09<16:43, 11.43it/s]

Training loss: 1.1170266065393362, global step: 2099
Training loss: 1.117067920413896, global step: 2099
Training loss: 1.1170909745400266, global step: 2099


 48%|████▊     | 10500/21967 [15:10<16:52, 11.33it/s]

Training loss: 1.1170859990648927, global step: 2099
Training loss: 1.1171250105740138, global step: 2099


 50%|█████     | 10996/21967 [15:53<15:59, 11.43it/s]

Training loss: 1.1149095689107906, global step: 2199
Training loss: 1.1149183660345854, global step: 2199
Training loss: 1.114864285872358, global step: 2199


 50%|█████     | 11000/21967 [15:53<16:07, 11.34it/s]

Training loss: 1.1148719228600876, global step: 2199
Training loss: 1.1148431541704515, global step: 2199


 52%|█████▏    | 11496/21967 [16:36<15:16, 11.43it/s]

Training loss: 1.1131561598534323, global step: 2299
Training loss: 1.1131658916735585, global step: 2299
Training loss: 1.1131593079415625, global step: 2299


 52%|█████▏    | 11500/21967 [16:36<15:23, 11.33it/s]

Training loss: 1.1131323669471882, global step: 2299
Training loss: 1.1131299004679203, global step: 2299


 55%|█████▍    | 11996/21967 [17:19<14:32, 11.42it/s]

Training loss: 1.1107433929604607, global step: 2399
Training loss: 1.1107177857704593, global step: 2399
Training loss: 1.1107318825635342, global step: 2399


 55%|█████▍    | 12000/21967 [17:19<14:39, 11.33it/s]

Training loss: 1.1107581146247716, global step: 2399
Training loss: 1.110751328543257, global step: 2399


 57%|█████▋    | 12496/21967 [18:02<13:49, 11.42it/s]

Training loss: 1.1098867179161789, global step: 2499
Training loss: 1.1099610628945356, global step: 2499
Training loss: 1.1099407498657612, global step: 2499


 57%|█████▋    | 12500/21967 [18:02<13:55, 11.33it/s]

Training loss: 1.10992643422216, global step: 2499
Training loss: 1.1099319220482182, global step: 2499


 59%|█████▉    | 12996/21967 [18:45<13:05, 11.42it/s]

Training loss: 1.1084834215370634, global step: 2599
Training loss: 1.1084645357310874, global step: 2599
Training loss: 1.108448209483949, global step: 2599


 59%|█████▉    | 13000/21967 [18:46<13:11, 11.33it/s]

Training loss: 1.1084265906755228, global step: 2599
Training loss: 1.1084408843804574, global step: 2599


 61%|██████▏   | 13496/21967 [19:29<12:21, 11.43it/s]

Training loss: 1.1068122781931846, global step: 2699
Training loss: 1.1068165224684798, global step: 2699
Training loss: 1.1068002166440551, global step: 2699


 61%|██████▏   | 13500/21967 [19:29<12:27, 11.33it/s]

Training loss: 1.106806246188172, global step: 2699
Training loss: 1.1067655851814877, global step: 2699


 64%|██████▎   | 13996/21967 [20:12<11:38, 11.42it/s]

Training loss: 1.1048419447781443, global step: 2799
Training loss: 1.1048487502779194, global step: 2799
Training loss: 1.104839336955033, global step: 2799


 64%|██████▎   | 14000/21967 [20:12<11:43, 11.32it/s]

Training loss: 1.104815292458357, global step: 2799
Training loss: 1.1047869982037282, global step: 2799


 66%|██████▌   | 14496/21967 [20:55<10:53, 11.44it/s]

Training loss: 1.103332652767369, global step: 2899
Training loss: 1.1033171162211746, global step: 2899
Training loss: 1.1033189666681744, global step: 2899


 66%|██████▌   | 14500/21967 [20:55<10:57, 11.35it/s]

Training loss: 1.1032921504709807, global step: 2899
Training loss: 1.103291332882551, global step: 2899


 68%|██████▊   | 14996/21967 [21:38<10:09, 11.44it/s]

Training loss: 1.1017183516637887, global step: 2999
Training loss: 1.1017010532052247, global step: 2999
Training loss: 1.1017093632121582, global step: 2999


 68%|██████▊   | 15000/21967 [21:39<10:14, 11.35it/s]

Training loss: 1.1017378136323863, global step: 2999
Training loss: 1.1017246436196309, global step: 2999


 71%|███████   | 15496/21967 [22:21<09:25, 11.44it/s]

Training loss: 1.0999311189698227, global step: 3099
Training loss: 1.0999426910640249, global step: 3099
Training loss: 1.0999169649765776, global step: 3099


 71%|███████   | 15500/21967 [22:22<09:29, 11.36it/s]

Training loss: 1.0999027109265151, global step: 3099
Training loss: 1.0998848741957623, global step: 3099


 73%|███████▎  | 15996/21967 [23:04<08:41, 11.45it/s]

Training loss: 1.0981728281424021, global step: 3199
Training loss: 1.0981713306754044, global step: 3199
Training loss: 1.0981402097716662, global step: 3199


 73%|███████▎  | 16000/21967 [23:05<08:45, 11.36it/s]

Training loss: 1.0981643843712665, global step: 3199
Training loss: 1.0981954340618383, global step: 3199


 75%|███████▌  | 16496/21967 [23:48<07:57, 11.45it/s]

Training loss: 1.0972321829739105, global step: 3299
Training loss: 1.097283266274264, global step: 3299
Training loss: 1.097282085669556, global step: 3299


 75%|███████▌  | 16500/21967 [23:48<08:01, 11.36it/s]

Training loss: 1.0972776081079294, global step: 3299
Training loss: 1.0972897794369285, global step: 3299


 77%|███████▋  | 16996/21967 [24:31<07:14, 11.44it/s]

Training loss: 1.0955176418889934, global step: 3399
Training loss: 1.095519159267996, global step: 3399
Training loss: 1.0955100468144934, global step: 3399


 77%|███████▋  | 17000/21967 [24:31<07:17, 11.36it/s]

Training loss: 1.0955175529647776, global step: 3399
Training loss: 1.0955440118638016, global step: 3399


 80%|███████▉  | 17496/21967 [25:14<06:30, 11.45it/s]

Training loss: 1.094312963769517, global step: 3499
Training loss: 1.0942944544356845, global step: 3499
Training loss: 1.0943643101172467, global step: 3499


 80%|███████▉  | 17500/21967 [25:14<06:33, 11.36it/s]

Training loss: 1.0943312246157846, global step: 3499
Training loss: 1.0943393958259346, global step: 3499


 82%|████████▏ | 17996/21967 [25:57<05:46, 11.45it/s]

Training loss: 1.0929934515841109, global step: 3599
Training loss: 1.093015119809327, global step: 3599
Training loss: 1.0929882905298043, global step: 3599


 82%|████████▏ | 18000/21967 [25:57<05:49, 11.35it/s]

Training loss: 1.0929868629173063, global step: 3599
Training loss: 1.0929756234805381, global step: 3599


 84%|████████▍ | 18496/21967 [26:40<05:03, 11.45it/s]

Training loss: 1.0913318963913248, global step: 3699
Training loss: 1.0913272817267885, global step: 3699
Training loss: 1.0913277037138422, global step: 3699


 84%|████████▍ | 18500/21967 [26:40<05:05, 11.36it/s]

Training loss: 1.0913427018941395, global step: 3699
Training loss: 1.0913095436625766, global step: 3699


 86%|████████▋ | 18996/21967 [27:24<04:27, 11.12it/s]

Training loss: 1.0896824922411654, global step: 3799
Training loss: 1.089673697077731, global step: 3799
Training loss: 1.0896589564697903, global step: 3799


 86%|████████▋ | 19000/21967 [27:24<04:29, 10.99it/s]

Training loss: 1.089725465540924, global step: 3799
Training loss: 1.0897254417558, global step: 3799


 89%|████████▉ | 19496/21967 [28:07<03:36, 11.44it/s]

Training loss: 1.08808143636024, global step: 3899
Training loss: 1.0880770859387647, global step: 3899
Training loss: 1.0880824969524958, global step: 3899


 89%|████████▉ | 19500/21967 [28:08<03:37, 11.34it/s]

Training loss: 1.0881188784468319, global step: 3899
Training loss: 1.0881064649031489, global step: 3899


 91%|█████████ | 19996/21967 [28:50<02:52, 11.44it/s]

Training loss: 1.0865733147293486, global step: 3999
Training loss: 1.086562294400461, global step: 3999
Training loss: 1.0865594486972776, global step: 3999


 91%|█████████ | 20000/21967 [28:51<02:53, 11.35it/s]

Training loss: 1.086529022789631, global step: 3999
Training loss: 1.0865384930461353, global step: 3999


 93%|█████████▎| 20496/21967 [29:33<02:08, 11.45it/s]

Training loss: 1.0855186369272813, global step: 4099
Training loss: 1.0855041334018762, global step: 4099
Training loss: 1.0855204274008878, global step: 4099


 93%|█████████▎| 20500/21967 [29:34<02:09, 11.35it/s]

Training loss: 1.0854974364294274, global step: 4099
Training loss: 1.0854943308673421, global step: 4099


 96%|█████████▌| 20996/21967 [30:17<01:24, 11.45it/s]

Training loss: 1.0838868175277712, global step: 4199
Training loss: 1.0838638731369146, global step: 4199
Training loss: 1.0838452531440421, global step: 4199


 96%|█████████▌| 21000/21967 [30:17<01:25, 11.35it/s]

Training loss: 1.083829206843236, global step: 4199
Training loss: 1.0838095747504588, global step: 4199


 98%|█████████▊| 21496/21967 [31:00<00:41, 11.45it/s]

Training loss: 1.0829145724564326, global step: 4299
Training loss: 1.082912020027571, global step: 4299
Training loss: 1.0829284584355332, global step: 4299


 98%|█████████▊| 21500/21967 [31:00<00:41, 11.36it/s]

Training loss: 1.082919774670769, global step: 4299
Training loss: 1.0829189673452202, global step: 4299


100%|██████████| 21967/21967 [31:40<00:00, 11.56it/s]


begin evaluation!!!


  0%|          | 2/21967 [00:00<32:37, 11.22it/s]

eval_loss: 0.9197493560383911
eval_accuracy: 0.6399721448467967
Training Epoch: 3/5


  2%|▏         | 496/21967 [00:42<32:01, 11.17it/s]

Training loss: 0.8813294225870961, global step: 99
Training loss: 0.8836481559781297, global step: 99
Training loss: 0.8841267783996805, global step: 99


  2%|▏         | 500/21967 [00:43<32:26, 11.03it/s]

Training loss: 0.8853889498485619, global step: 99
Training loss: 0.8852213839132466, global step: 99


  5%|▍         | 996/21967 [01:26<30:31, 11.45it/s]

Training loss: 0.8797054379205009, global step: 199
Training loss: 0.8797068832644138, global step: 199
Training loss: 0.8797813540729978, global step: 199


  5%|▍         | 1000/21967 [01:26<30:46, 11.35it/s]

Training loss: 0.8795911588282408, global step: 199
Training loss: 0.8796500840836817, global step: 199


  7%|▋         | 1496/21967 [02:09<29:48, 11.44it/s]

Training loss: 0.8817089493657434, global step: 299
Training loss: 0.8814707762775893, global step: 299
Training loss: 0.8818553044506129, global step: 299


  7%|▋         | 1500/21967 [02:09<30:04, 11.34it/s]

Training loss: 0.8818739579222071, global step: 299
Training loss: 0.8821187087541584, global step: 299


  9%|▉         | 1996/21967 [02:52<29:05, 11.44it/s]

Training loss: 0.8834155725394573, global step: 399
Training loss: 0.8832381958392017, global step: 399
Training loss: 0.8832406254823828, global step: 399


  9%|▉         | 2000/21967 [02:52<29:21, 11.34it/s]

Training loss: 0.8833402754763165, global step: 399
Training loss: 0.8839704352633663, global step: 399


 11%|█▏        | 2496/21967 [03:35<28:22, 11.44it/s]

Training loss: 0.8859447394156146, global step: 499
Training loss: 0.8859122324546871, global step: 499
Training loss: 0.8858170352225715, global step: 499


 11%|█▏        | 2500/21967 [03:36<28:35, 11.35it/s]

Training loss: 0.885624570840652, global step: 499
Training loss: 0.8857794797413943, global step: 499


 14%|█▎        | 2996/21967 [04:18<27:38, 11.44it/s]

Training loss: 0.8897354746414644, global step: 599
Training loss: 0.889813784317167, global step: 599
Training loss: 0.8897160976729892, global step: 599


 14%|█▎        | 3000/21967 [04:19<27:53, 11.33it/s]

Training loss: 0.8896399118003744, global step: 599
Training loss: 0.8896284299868769, global step: 599


 16%|█▌        | 3496/21967 [05:02<26:54, 11.44it/s]

Training loss: 0.8935316305005329, global step: 699
Training loss: 0.8935430240550971, global step: 699
Training loss: 0.89343516776604, global step: 699


 16%|█▌        | 3500/21967 [05:02<27:06, 11.35it/s]

Training loss: 0.8934167923289684, global step: 699
Training loss: 0.8934218357012387, global step: 699


 18%|█▊        | 3996/21967 [05:45<26:11, 11.43it/s]

Training loss: 0.8934170778575021, global step: 799
Training loss: 0.8932766751595983, global step: 799
Training loss: 0.8932915546274018, global step: 799


 18%|█▊        | 4000/21967 [05:45<26:24, 11.34it/s]

Training loss: 0.8933311136641641, global step: 799
Training loss: 0.8932625769219955, global step: 799


 20%|██        | 4496/21967 [06:28<25:27, 11.44it/s]

Training loss: 0.8912203667161197, global step: 899
Training loss: 0.8912220887998882, global step: 899
Training loss: 0.8911825909677124, global step: 899


 20%|██        | 4500/21967 [06:28<25:38, 11.35it/s]

Training loss: 0.8910738401886833, global step: 899
Training loss: 0.8910439455739708, global step: 899


 23%|██▎       | 4996/21967 [07:11<24:42, 11.44it/s]

Training loss: 0.8920658611693421, global step: 999
Training loss: 0.8920146369347699, global step: 999
Training loss: 0.892070540768792, global step: 999


 23%|██▎       | 5000/21967 [07:11<24:54, 11.35it/s]

Training loss: 0.8920416230333595, global step: 999
Training loss: 0.8920297497143076, global step: 999


 25%|██▌       | 5496/21967 [07:54<23:59, 11.45it/s]

Training loss: 0.8905937378817629, global step: 1099
Training loss: 0.8905975162315968, global step: 1099
Training loss: 0.89062179453713, global step: 1099


 25%|██▌       | 5500/21967 [07:54<24:10, 11.35it/s]

Training loss: 0.8906008218089508, global step: 1099
Training loss: 0.890586966537935, global step: 1099


 27%|██▋       | 5996/21967 [08:37<23:15, 11.45it/s]

Training loss: 0.8917209267491992, global step: 1199
Training loss: 0.8917229238069996, global step: 1199
Training loss: 0.8917438184293643, global step: 1199


 27%|██▋       | 6000/21967 [08:37<23:26, 11.35it/s]

Training loss: 0.8917582365558441, global step: 1199
Training loss: 0.8916852599775856, global step: 1199


 30%|██▉       | 6496/21967 [09:21<23:10, 11.13it/s]

Training loss: 0.8943489876410152, global step: 1299
Training loss: 0.894255601557825, global step: 1299
Training loss: 0.8941868565505297, global step: 1299


 30%|██▉       | 6500/21967 [09:21<23:25, 11.01it/s]

Training loss: 0.8942610113326862, global step: 1299
Training loss: 0.8942868980531932, global step: 1299


 32%|███▏      | 6996/21967 [10:05<22:25, 11.12it/s]

Training loss: 0.8942806083903387, global step: 1399
Training loss: 0.8941872264005631, global step: 1399
Training loss: 0.8941093575471859, global step: 1399


 32%|███▏      | 7000/21967 [10:05<22:39, 11.01it/s]

Training loss: 0.8940416015642018, global step: 1399
Training loss: 0.8940472905412966, global step: 1399


 34%|███▍      | 7496/21967 [10:49<21:39, 11.13it/s]

Training loss: 0.8933674368164872, global step: 1499
Training loss: 0.8933604568751925, global step: 1499
Training loss: 0.8933153686339623, global step: 1499


 34%|███▍      | 7500/21967 [10:49<21:53, 11.01it/s]

Training loss: 0.8934233939927494, global step: 1499
Training loss: 0.8934290351462628, global step: 1499


 36%|███▋      | 7996/21967 [11:33<20:22, 11.43it/s]

Training loss: 0.8916099468032929, global step: 1599
Training loss: 0.8915755005465776, global step: 1599
Training loss: 0.8915340803513284, global step: 1599


 36%|███▋      | 8000/21967 [11:33<20:31, 11.34it/s]

Training loss: 0.8916140936634583, global step: 1599
Training loss: 0.8915951739894679, global step: 1599


 39%|███▊      | 8496/21967 [12:16<19:38, 11.43it/s]

Training loss: 0.8929703378976444, global step: 1699
Training loss: 0.8930583927020073, global step: 1699
Training loss: 0.8930580557919263, global step: 1699


 39%|███▊      | 8500/21967 [12:16<19:47, 11.34it/s]

Training loss: 0.89308565105979, global step: 1699
Training loss: 0.8930679523173585, global step: 1699


 41%|████      | 8996/21967 [12:59<18:53, 11.44it/s]

Training loss: 0.8931294972011783, global step: 1799
Training loss: 0.8931899628767196, global step: 1799
Training loss: 0.8931329969221073, global step: 1799


 41%|████      | 9000/21967 [12:59<19:03, 11.34it/s]

Training loss: 0.8931295915599312, global step: 1799
Training loss: 0.8932284700153026, global step: 1799


 43%|████▎     | 9496/21967 [13:42<18:10, 11.44it/s]

Training loss: 0.8923253355342787, global step: 1899
Training loss: 0.8923065892464647, global step: 1899
Training loss: 0.892342876805338, global step: 1899


 43%|████▎     | 9500/21967 [13:42<18:18, 11.34it/s]

Training loss: 0.8923190442402411, global step: 1899
Training loss: 0.8923689059674714, global step: 1899


 46%|████▌     | 9996/21967 [14:25<17:26, 11.44it/s]

Training loss: 0.8924057591085943, global step: 1999
Training loss: 0.8923612487278416, global step: 1999
Training loss: 0.8923317048666244, global step: 1999


 46%|████▌     | 10000/21967 [14:26<17:35, 11.34it/s]

Training loss: 0.8924172440792297, global step: 1999
Training loss: 0.8923617805244297, global step: 1999


 48%|████▊     | 10496/21967 [15:09<16:42, 11.44it/s]

Training loss: 0.8907843008651627, global step: 2099
Training loss: 0.8907740931118918, global step: 2099
Training loss: 0.8909286574591078, global step: 2099


 48%|████▊     | 10500/21967 [15:09<16:50, 11.35it/s]

Training loss: 0.891005043618865, global step: 2099
Training loss: 0.8909939701665991, global step: 2099


 50%|█████     | 10996/21967 [15:52<15:58, 11.44it/s]

Training loss: 0.8908205958260303, global step: 2199
Training loss: 0.8907975793506838, global step: 2199
Training loss: 0.8907654188296259, global step: 2199


 50%|█████     | 11000/21967 [15:52<16:06, 11.34it/s]

Training loss: 0.890775368470515, global step: 2199
Training loss: 0.8908341770350028, global step: 2199


 52%|█████▏    | 11496/21967 [16:35<15:14, 11.45it/s]

Training loss: 0.8928204837928859, global step: 2299
Training loss: 0.8928139520907565, global step: 2299
Training loss: 0.892813665677804, global step: 2299


 52%|█████▏    | 11500/21967 [16:35<15:21, 11.36it/s]

Training loss: 0.8928024363622217, global step: 2299
Training loss: 0.8928221284465869, global step: 2299


 55%|█████▍    | 11996/21967 [17:18<14:31, 11.45it/s]

Training loss: 0.8908922379546833, global step: 2399
Training loss: 0.8908576943823728, global step: 2399
Training loss: 0.8908282680806641, global step: 2399


 55%|█████▍    | 12000/21967 [17:18<14:38, 11.35it/s]

Training loss: 0.8907944053581633, global step: 2399
Training loss: 0.8909161741545631, global step: 2399


 57%|█████▋    | 12496/21967 [18:01<13:48, 11.43it/s]

Training loss: 0.8913677471430124, global step: 2499
Training loss: 0.8913208105352143, global step: 2499
Training loss: 0.8913036970254888, global step: 2499


 57%|█████▋    | 12500/21967 [18:02<13:54, 11.34it/s]

Training loss: 0.8912774472761047, global step: 2499
Training loss: 0.8912213600409566, global step: 2499


 59%|█████▉    | 12996/21967 [18:44<13:05, 11.43it/s]

Training loss: 0.8903452223616181, global step: 2599
Training loss: 0.8903198883048651, global step: 2599
Training loss: 0.8903864864034956, global step: 2599


 59%|█████▉    | 13000/21967 [18:45<13:10, 11.34it/s]

Training loss: 0.8903937150491652, global step: 2599
Training loss: 0.8903592249093555, global step: 2599


 61%|██████▏   | 13496/21967 [19:28<12:21, 11.42it/s]

Training loss: 0.8899766945163159, global step: 2699
Training loss: 0.8899570511058901, global step: 2699
Training loss: 0.8899515566427885, global step: 2699


 61%|██████▏   | 13500/21967 [19:28<12:27, 11.32it/s]

Training loss: 0.8899064182369669, global step: 2699
Training loss: 0.8899724582396875, global step: 2699


 64%|██████▎   | 13996/21967 [20:11<11:44, 11.32it/s]

Training loss: 0.8910697699311845, global step: 2799
Training loss: 0.8910477292655016, global step: 2799
Training loss: 0.8910688640317032, global step: 2799


 64%|██████▎   | 14000/21967 [20:11<11:46, 11.28it/s]

Training loss: 0.8910664828061198, global step: 2799
Training loss: 0.8910568670604994, global step: 2799


 66%|██████▌   | 14496/21967 [20:54<10:53, 11.44it/s]

Training loss: 0.8903841386584834, global step: 2899
Training loss: 0.8903324697773172, global step: 2899
Training loss: 0.890367466226326, global step: 2899


 66%|██████▌   | 14500/21967 [20:54<10:58, 11.34it/s]

Training loss: 0.890354401728115, global step: 2899
Training loss: 0.8903907941309647, global step: 2899


 68%|██████▊   | 14996/21967 [21:37<10:09, 11.44it/s]

Training loss: 0.8902843691798678, global step: 2999
Training loss: 0.8902783160564677, global step: 2999
Training loss: 0.8902839679670166, global step: 2999


 68%|██████▊   | 15000/21967 [21:38<10:14, 11.34it/s]

Training loss: 0.8902915284158802, global step: 2999
Training loss: 0.8902859121264994, global step: 2999


 71%|███████   | 15496/21967 [22:20<09:25, 11.44it/s]

Training loss: 0.8902392050878257, global step: 3099
Training loss: 0.8902493449570837, global step: 3099
Training loss: 0.8902609034136659, global step: 3099


 71%|███████   | 15500/21967 [22:21<09:29, 11.35it/s]

Training loss: 0.8902627287531344, global step: 3099
Training loss: 0.8902625002945687, global step: 3099


 73%|███████▎  | 15996/21967 [23:04<08:41, 11.44it/s]

Training loss: 0.8907851696321837, global step: 3199
Training loss: 0.890790961365136, global step: 3199
Training loss: 0.8907943173582542, global step: 3199


 73%|███████▎  | 16000/21967 [23:04<08:46, 11.34it/s]

Training loss: 0.8907830098725167, global step: 3199
Training loss: 0.8907922589644409, global step: 3199


 75%|███████▌  | 16496/21967 [23:47<07:58, 11.44it/s]

Training loss: 0.891283755158883, global step: 3299
Training loss: 0.8912860415618465, global step: 3299
Training loss: 0.8912990070608932, global step: 3299


 75%|███████▌  | 16500/21967 [23:47<08:02, 11.34it/s]

Training loss: 0.8913064204510457, global step: 3299
Training loss: 0.8912751253596529, global step: 3299


 77%|███████▋  | 16996/21967 [24:30<07:15, 11.42it/s]

Training loss: 0.8909434310787213, global step: 3399
Training loss: 0.8909672047665825, global step: 3399
Training loss: 0.8909444608368081, global step: 3399


 77%|███████▋  | 17000/21967 [24:30<07:18, 11.33it/s]

Training loss: 0.8909570930672718, global step: 3399
Training loss: 0.8909224517057451, global step: 3399


 80%|███████▉  | 17496/21967 [25:13<06:31, 11.43it/s]

Training loss: 0.8907106256397256, global step: 3499
Training loss: 0.8906881110700103, global step: 3499
Training loss: 0.8906669272078462, global step: 3499


 80%|███████▉  | 17500/21967 [25:13<06:34, 11.34it/s]

Training loss: 0.8906581155327317, global step: 3499
Training loss: 0.8906624752158941, global step: 3499


 82%|████████▏ | 17996/21967 [25:57<05:56, 11.13it/s]

Training loss: 0.8899599344101373, global step: 3599
Training loss: 0.8900055731326891, global step: 3599
Training loss: 0.889987665413651, global step: 3599


 82%|████████▏ | 18000/21967 [25:57<06:00, 11.01it/s]

Training loss: 0.8900007999906051, global step: 3599
Training loss: 0.8899995981794218, global step: 3599


 84%|████████▍ | 18496/21967 [26:41<05:12, 11.12it/s]

Training loss: 0.8892065545563667, global step: 3699
Training loss: 0.8891635889346167, global step: 3699
Training loss: 0.8891883716108046, global step: 3699


 84%|████████▍ | 18500/21967 [26:41<05:15, 11.00it/s]

Training loss: 0.8891716524970694, global step: 3699
Training loss: 0.8891889094936667, global step: 3699


 86%|████████▋ | 18996/21967 [27:24<04:19, 11.45it/s]

Training loss: 0.8896390062220695, global step: 3799
Training loss: 0.8896519185480607, global step: 3799
Training loss: 0.8896412913997707, global step: 3799


 86%|████████▋ | 19000/21967 [27:25<04:21, 11.36it/s]

Training loss: 0.8896403178276169, global step: 3799
Training loss: 0.8896644703619944, global step: 3799


 89%|████████▉ | 19496/21967 [28:07<03:36, 11.44it/s]

Training loss: 0.8897345618466425, global step: 3899
Training loss: 0.8897106662496945, global step: 3899
Training loss: 0.8897120580396258, global step: 3899


 89%|████████▉ | 19500/21967 [28:08<03:37, 11.35it/s]

Training loss: 0.889697768902472, global step: 3899
Training loss: 0.8897018912132489, global step: 3899


 91%|█████████ | 19996/21967 [28:51<02:52, 11.45it/s]

Training loss: 0.8905231990397096, global step: 3999
Training loss: 0.8905665959127412, global step: 3999
Training loss: 0.8905565944782641, global step: 3999


 91%|█████████ | 20000/21967 [28:51<02:53, 11.35it/s]

Training loss: 0.8905340070872665, global step: 3999
Training loss: 0.8905819401257118, global step: 3999


 93%|█████████▎| 20496/21967 [29:34<02:08, 11.43it/s]

Training loss: 0.8905604302928589, global step: 4099
Training loss: 0.8905560495937134, global step: 4099
Training loss: 0.8905329224268009, global step: 4099


 93%|█████████▎| 20500/21967 [29:34<02:09, 11.34it/s]

Training loss: 0.8905395822028324, global step: 4099
Training loss: 0.8905112651671102, global step: 4099


 96%|█████████▌| 20996/21967 [30:17<01:24, 11.44it/s]

Training loss: 0.8905914684789086, global step: 4199
Training loss: 0.8905976503020393, global step: 4199
Training loss: 0.890580108195072, global step: 4199


 96%|█████████▌| 21000/21967 [30:17<01:25, 11.34it/s]

Training loss: 0.8905880023024644, global step: 4199
Training loss: 0.8906070673549534, global step: 4199


 98%|█████████▊| 21496/21967 [31:00<00:41, 11.44it/s]

Training loss: 0.8905650235667468, global step: 4299
Training loss: 0.8905805381273856, global step: 4299
Training loss: 0.8905592576678203, global step: 4299


 98%|█████████▊| 21500/21967 [31:00<00:41, 11.35it/s]

Training loss: 0.8905688981243232, global step: 4299
Training loss: 0.890580710127608, global step: 4299


100%|██████████| 21967/21967 [31:40<00:00, 11.56it/s]


begin evaluation!!!


  0%|          | 2/21967 [00:00<31:28, 11.63it/s]

eval_loss: 0.8819074710016463
eval_accuracy: 0.6483286908077994
Training Epoch: 4/5


  2%|▏         | 496/21967 [00:42<31:14, 11.45it/s]

Training loss: 0.6592722936523984, global step: 99
Training loss: 0.6595651204960661, global step: 99
Training loss: 0.6600545606795151, global step: 99


  2%|▏         | 500/21967 [00:43<31:30, 11.35it/s]

Training loss: 0.6595725273324483, global step: 99
Training loss: 0.661204112775788, global step: 99


  5%|▍         | 996/21967 [01:25<30:31, 11.45it/s]

Training loss: 0.6527325137775747, global step: 199
Training loss: 0.6525719547906075, global step: 199
Training loss: 0.6522316410465011, global step: 199


  5%|▍         | 1000/21967 [01:26<30:46, 11.35it/s]

Training loss: 0.651879647780277, global step: 199
Training loss: 0.6518143132343903, global step: 199


  7%|▋         | 1496/21967 [02:08<29:48, 11.44it/s]

Training loss: 0.6448061370705083, global step: 299
Training loss: 0.6445553781051329, global step: 299
Training loss: 0.6449099183654936, global step: 299


  7%|▋         | 1500/21967 [02:09<30:03, 11.35it/s]

Training loss: 0.6446237985333192, global step: 299
Training loss: 0.6448452063004242, global step: 299


  9%|▉         | 1996/21967 [02:52<29:03, 11.45it/s]

Training loss: 0.6391082035555017, global step: 399
Training loss: 0.6389421889159963, global step: 399
Training loss: 0.63886214651561, global step: 399


  9%|▉         | 2000/21967 [02:52<29:18, 11.35it/s]

Training loss: 0.6387596466863813, global step: 399
Training loss: 0.6389595262986748, global step: 399


 11%|█▏        | 2496/21967 [03:35<28:20, 11.45it/s]

Training loss: 0.6358344676853989, global step: 499
Training loss: 0.635637413835353, global step: 499
Training loss: 0.6354330731988472, global step: 499


 11%|█▏        | 2500/21967 [03:35<28:33, 11.36it/s]

Training loss: 0.6353311653761976, global step: 499
Training loss: 0.6352166402581878, global step: 499


 14%|█▎        | 2996/21967 [04:18<27:36, 11.45it/s]

Training loss: 0.636870241705919, global step: 599
Training loss: 0.6368109241500373, global step: 599
Training loss: 0.6367593620865567, global step: 599


 14%|█▎        | 3000/21967 [04:18<27:50, 11.36it/s]

Training loss: 0.636611939648016, global step: 599
Training loss: 0.6365473908780589, global step: 599


 16%|█▌        | 3496/21967 [05:01<26:53, 11.45it/s]

Training loss: 0.6318627722206951, global step: 699
Training loss: 0.631712714501857, global step: 699
Training loss: 0.6317382754234947, global step: 699


 16%|█▌        | 3500/21967 [05:01<27:05, 11.36it/s]

Training loss: 0.6318009282630566, global step: 699
Training loss: 0.6318902322270826, global step: 699


 18%|█▊        | 3996/21967 [05:44<26:10, 11.44it/s]

Training loss: 0.6372230815208133, global step: 799
Training loss: 0.6374821309547953, global step: 799
Training loss: 0.6375188875986761, global step: 799


 18%|█▊        | 4000/21967 [05:44<26:23, 11.35it/s]

Training loss: 0.637620437900308, global step: 799
Training loss: 0.6374669141383205, global step: 799


 20%|██        | 4496/21967 [06:27<25:27, 11.44it/s]

Training loss: 0.6365577679135742, global step: 899
Training loss: 0.6365380377663644, global step: 899
Training loss: 0.636455723783053, global step: 899


 20%|██        | 4500/21967 [06:27<25:38, 11.35it/s]

Training loss: 0.6364905054235012, global step: 899
Training loss: 0.6365639151652162, global step: 899


 23%|██▎       | 4996/21967 [07:10<24:43, 11.44it/s]

Training loss: 0.6379268124945708, global step: 999
Training loss: 0.638020569408645, global step: 999
Training loss: 0.6380316800519314, global step: 999


 23%|██▎       | 5000/21967 [07:11<24:54, 11.35it/s]

Training loss: 0.6381339882248912, global step: 999
Training loss: 0.6380706563969032, global step: 999


 25%|██▌       | 5496/21967 [07:53<23:58, 11.45it/s]

Training loss: 0.6397222966050646, global step: 1099
Training loss: 0.6396853857601856, global step: 1099
Training loss: 0.6396457967510196, global step: 1099


 25%|██▌       | 5500/21967 [07:54<24:10, 11.35it/s]

Training loss: 0.639540574834206, global step: 1099
Training loss: 0.6395080407421239, global step: 1099


 27%|██▋       | 5996/21967 [08:36<23:14, 11.45it/s]

Training loss: 0.6376438351278135, global step: 1199
Training loss: 0.6376804005360235, global step: 1199
Training loss: 0.6376739292644239, global step: 1199


 27%|██▋       | 6000/21967 [08:37<23:26, 11.35it/s]

Training loss: 0.6375952693983098, global step: 1199
Training loss: 0.6375949752147093, global step: 1199


 30%|██▉       | 6496/21967 [09:20<22:31, 11.45it/s]

Training loss: 0.6379826826616273, global step: 1299
Training loss: 0.637962140114674, global step: 1299
Training loss: 0.6380070092965577, global step: 1299


 30%|██▉       | 6500/21967 [09:20<22:42, 11.35it/s]

Training loss: 0.638022927310382, global step: 1299
Training loss: 0.6379797833383136, global step: 1299


 32%|███▏      | 6996/21967 [10:03<21:47, 11.45it/s]

Training loss: 0.6398119727099012, global step: 1399
Training loss: 0.6398032274947322, global step: 1399
Training loss: 0.6397761204962112, global step: 1399


 32%|███▏      | 7000/21967 [10:03<21:57, 11.36it/s]

Training loss: 0.6397614969662198, global step: 1399
Training loss: 0.6398113743371592, global step: 1399


 34%|███▍      | 7496/21967 [10:46<21:04, 11.45it/s]

Training loss: 0.6396725487868594, global step: 1499
Training loss: 0.6397209582184894, global step: 1499
Training loss: 0.6397765456461936, global step: 1499


 34%|███▍      | 7500/21967 [10:46<21:14, 11.35it/s]

Training loss: 0.6397390075360327, global step: 1499
Training loss: 0.6397567290364643, global step: 1499


 36%|███▋      | 7996/21967 [11:29<20:20, 11.45it/s]

Training loss: 0.6392694622888612, global step: 1599
Training loss: 0.6392970875221462, global step: 1599
Training loss: 0.6393340898551232, global step: 1599


 36%|███▋      | 8000/21967 [11:29<20:30, 11.35it/s]

Training loss: 0.6392789265454359, global step: 1599
Training loss: 0.6392946548793469, global step: 1599


 39%|███▊      | 8496/21967 [12:12<19:36, 11.45it/s]

Training loss: 0.6402861374739794, global step: 1699
Training loss: 0.6403159729042343, global step: 1699
Training loss: 0.6402839986748331, global step: 1699


 39%|███▊      | 8500/21967 [12:12<19:46, 11.35it/s]

Training loss: 0.6402278779142004, global step: 1699
Training loss: 0.6402467943245324, global step: 1699


 41%|████      | 8996/21967 [12:55<18:52, 11.45it/s]

Training loss: 0.6425000856971402, global step: 1799
Training loss: 0.6424611549544282, global step: 1799
Training loss: 0.6425115603204308, global step: 1799


 41%|████      | 9000/21967 [12:55<19:01, 11.36it/s]

Training loss: 0.6424911974896514, global step: 1799
Training loss: 0.6424708113724801, global step: 1799


 43%|████▎     | 9496/21967 [13:38<18:09, 11.45it/s]

Training loss: 0.6449656805930597, global step: 1899
Training loss: 0.6449204938038358, global step: 1899
Training loss: 0.6448794657783372, global step: 1899


 43%|████▎     | 9500/21967 [13:39<18:17, 11.36it/s]

Training loss: 0.6448881982535798, global step: 1899
Training loss: 0.6449270074056783, global step: 1899


 50%|█████     | 10996/21967 [15:47<15:57, 11.45it/s]

Training loss: 0.6464520212449429, global step: 2199
Training loss: 0.6464290313621759, global step: 2199
Training loss: 0.6463876637414674, global step: 2199


 50%|█████     | 11000/21967 [15:48<16:05, 11.36it/s]

Training loss: 0.64633011103065, global step: 2199
Training loss: 0.6463079676077155, global step: 2199


 52%|█████▏    | 11496/21967 [16:31<15:14, 11.45it/s]

Training loss: 0.6465866273774823, global step: 2299
Training loss: 0.6465819661578472, global step: 2299
Training loss: 0.6465504788296856, global step: 2299


 52%|█████▏    | 11500/21967 [16:31<15:21, 11.36it/s]

Training loss: 0.6465619615258928, global step: 2299
Training loss: 0.6465594846043876, global step: 2299


 55%|█████▍    | 11996/21967 [17:14<14:32, 11.43it/s]

Training loss: 0.6463944787898913, global step: 2399
Training loss: 0.6463871490443863, global step: 2399
Training loss: 0.646381046581945, global step: 2399


 55%|█████▍    | 12000/21967 [17:14<14:38, 11.34it/s]

Training loss: 0.6463500293544628, global step: 2399
Training loss: 0.6463264295286958, global step: 2399


 57%|█████▋    | 12496/21967 [17:57<13:47, 11.44it/s]

Training loss: 0.6474533430296739, global step: 2499
Training loss: 0.6474638807522075, global step: 2499
Training loss: 0.64746109306353, global step: 2499


 57%|█████▋    | 12500/21967 [17:57<13:53, 11.35it/s]

Training loss: 0.6474418067421213, global step: 2499
Training loss: 0.6474217668367975, global step: 2499


 59%|█████▉    | 12996/21967 [18:40<13:04, 11.44it/s]

Training loss: 0.6468690403522304, global step: 2599
Training loss: 0.6468577662106595, global step: 2599
Training loss: 0.6469044967820106, global step: 2599


 59%|█████▉    | 13000/21967 [18:40<13:10, 11.35it/s]

Training loss: 0.6469602051206186, global step: 2599
Training loss: 0.6469434517442644, global step: 2599


 61%|██████▏   | 13496/21967 [19:23<12:19, 11.45it/s]

Training loss: 0.6471248931169514, global step: 2699
Training loss: 0.6470954629554333, global step: 2699
Training loss: 0.6470900634761001, global step: 2699


 61%|██████▏   | 13500/21967 [19:23<12:25, 11.35it/s]

Training loss: 0.6471512851749889, global step: 2699
Training loss: 0.6471225158670357, global step: 2699


 64%|██████▎   | 13996/21967 [20:06<11:36, 11.44it/s]

Training loss: 0.6468217901664408, global step: 2799
Training loss: 0.6468273277987063, global step: 2799
Training loss: 0.6467949854610541, global step: 2799


 64%|██████▎   | 14000/21967 [20:06<11:41, 11.35it/s]

Training loss: 0.6468751379738992, global step: 2799
Training loss: 0.6468825942560246, global step: 2799


 66%|██████▌   | 14496/21967 [20:49<10:52, 11.45it/s]

Training loss: 0.6490846905611724, global step: 2899
Training loss: 0.6490574333951268, global step: 2899
Training loss: 0.6491039286906365, global step: 2899


 66%|██████▌   | 14500/21967 [20:50<10:57, 11.36it/s]

Training loss: 0.6490898146749344, global step: 2899
Training loss: 0.6490716744622405, global step: 2899


 68%|██████▊   | 14996/21967 [21:32<10:08, 11.46it/s]

Training loss: 0.64929221859694, global step: 2999
Training loss: 0.6492992213326626, global step: 2999
Training loss: 0.6492725029304909, global step: 2999


 68%|██████▊   | 15000/21967 [21:33<10:13, 11.36it/s]

Training loss: 0.6492620588458073, global step: 2999
Training loss: 0.6492819413252, global step: 2999


 71%|███████   | 15496/21967 [22:15<09:25, 11.45it/s]

Training loss: 0.6510367859840585, global step: 3099
Training loss: 0.6510080653891105, global step: 3099
Training loss: 0.6509703530028318, global step: 3099


 71%|███████   | 15500/21967 [22:16<09:30, 11.34it/s]

Training loss: 0.6509648834624002, global step: 3099
Training loss: 0.650972365537592, global step: 3099


 73%|███████▎  | 15996/21967 [22:59<08:41, 11.44it/s]

Training loss: 0.6514732973181816, global step: 3199
Training loss: 0.6514574641907928, global step: 3199
Training loss: 0.6514641349085264, global step: 3199


 73%|███████▎  | 16000/21967 [22:59<08:45, 11.35it/s]

Training loss: 0.6514431695428761, global step: 3199
Training loss: 0.6514607384663736, global step: 3199


 75%|███████▌  | 16496/21967 [23:42<07:57, 11.45it/s]

Training loss: 0.6521720608139496, global step: 3299
Training loss: 0.6521820461401185, global step: 3299
Training loss: 0.6521790324832504, global step: 3299


 75%|███████▌  | 16500/21967 [23:42<08:01, 11.36it/s]

Training loss: 0.6521966909770658, global step: 3299
Training loss: 0.6521998283845057, global step: 3299


 76%|███████▋  | 16784/21967 [24:06<07:17, 11.85it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 27%|██▋       | 5996/21967 [08:37<23:14, 11.45it/s]

Training loss: 0.3975631569711085, global step: 1199
Training loss: 0.39759215942264187, global step: 1199
Training loss: 0.3975275014507002, global step: 1199


 27%|██▋       | 6000/21967 [08:37<23:26, 11.36it/s]

Training loss: 0.39784418944716776, global step: 1199
Training loss: 0.39778541619515, global step: 1199


 30%|██▉       | 6496/21967 [09:20<22:33, 11.43it/s]

Training loss: 0.39666879191136384, global step: 1299
Training loss: 0.39669171364089983, global step: 1299
Training loss: 0.3967378827678789, global step: 1299


 30%|██▉       | 6500/21967 [09:20<22:44, 11.33it/s]

Training loss: 0.3966791675100925, global step: 1299
Training loss: 0.39663978024148755, global step: 1299


 32%|███▏      | 6996/21967 [10:03<21:49, 11.44it/s]

Training loss: 0.3968676707036647, global step: 1399
Training loss: 0.3968110003934427, global step: 1399
Training loss: 0.39695131649216436, global step: 1399


 32%|███▏      | 7000/21967 [10:03<21:58, 11.35it/s]

Training loss: 0.3969370843758325, global step: 1399
Training loss: 0.3969038721554909, global step: 1399


 34%|███▍      | 7496/21967 [10:46<21:03, 11.45it/s]

Training loss: 0.39576297235650343, global step: 1499
Training loss: 0.39576478363978296, global step: 1499
Training loss: 0.395736602705567, global step: 1499


 34%|███▍      | 7500/21967 [10:46<21:14, 11.35it/s]

Training loss: 0.39578094548913106, global step: 1499
Training loss: 0.39577749893610165, global step: 1499


 36%|███▋      | 7996/21967 [11:29<20:20, 11.45it/s]

Training loss: 0.3981649060108821, global step: 1599
Training loss: 0.39819814178271046, global step: 1599
Training loss: 0.39820885600839867, global step: 1599


 36%|███▋      | 8000/21967 [11:29<20:31, 11.35it/s]

Training loss: 0.39817639102296465, global step: 1599
Training loss: 0.39812954867735106, global step: 1599


 39%|███▊      | 8496/21967 [12:12<19:37, 11.44it/s]

Training loss: 0.39814250301787163, global step: 1699
Training loss: 0.3982122790958533, global step: 1699
Training loss: 0.3981857051946875, global step: 1699


 39%|███▊      | 8500/21967 [12:13<19:46, 11.35it/s]

Training loss: 0.3981394054556788, global step: 1699
Training loss: 0.39824575487793573, global step: 1699


 41%|████      | 8996/21967 [12:55<18:53, 11.44it/s]

Training loss: 0.39849282418857096, global step: 1799
Training loss: 0.39855694311537737, global step: 1799
Training loss: 0.3985430299727337, global step: 1799


 41%|████      | 9000/21967 [12:56<19:02, 11.35it/s]

Training loss: 0.39849984135591465, global step: 1799
Training loss: 0.39849966274456333, global step: 1799


 43%|████▎     | 9496/21967 [13:38<18:09, 11.44it/s]

Training loss: 0.398262738935395, global step: 1899
Training loss: 0.39825047364778626, global step: 1899
Training loss: 0.39823737427206357, global step: 1899


 43%|████▎     | 9500/21967 [13:39<18:18, 11.35it/s]

Training loss: 0.3981989748048322, global step: 1899
Training loss: 0.39823667985217887, global step: 1899


 46%|████▌     | 9996/21967 [14:22<17:26, 11.44it/s]

Training loss: 0.4005269856173514, global step: 1999
Training loss: 0.40052734163971365, global step: 1999
Training loss: 0.40067389304877543, global step: 1999


 46%|████▌     | 10000/21967 [14:22<17:33, 11.35it/s]

Training loss: 0.4006547784124138, global step: 1999
Training loss: 0.40063848962058185, global step: 1999


 48%|████▊     | 10496/21967 [15:05<16:42, 11.44it/s]

Training loss: 0.4001904287173218, global step: 2099
Training loss: 0.40024915221018925, global step: 2099
Training loss: 0.4002332503900674, global step: 2099


 48%|████▊     | 10500/21967 [15:05<16:50, 11.35it/s]

Training loss: 0.40020115502088704, global step: 2099
Training loss: 0.4001652989067748, global step: 2099


 50%|█████     | 10996/21967 [15:48<15:58, 11.45it/s]

Training loss: 0.40038018025643585, global step: 2199
Training loss: 0.4003976174512604, global step: 2199
Training loss: 0.4004334746650978, global step: 2199


 50%|█████     | 11000/21967 [15:48<16:06, 11.35it/s]

Training loss: 0.40045697612456654, global step: 2199
Training loss: 0.400438964669089, global step: 2199


 52%|█████▏    | 11496/21967 [16:31<15:14, 11.45it/s]

Training loss: 0.40258181204346905, global step: 2299
Training loss: 0.4025759006376255, global step: 2299
Training loss: 0.4026072374914101, global step: 2299


 52%|█████▏    | 11500/21967 [16:31<15:21, 11.35it/s]

Training loss: 0.40261427761701263, global step: 2299
Training loss: 0.40262599390957277, global step: 2299


 55%|█████▍    | 11996/21967 [17:14<14:31, 11.45it/s]

Training loss: 0.4031844820873085, global step: 2399
Training loss: 0.40318044764067684, global step: 2399
Training loss: 0.40320163373850826, global step: 2399


 55%|█████▍    | 12000/21967 [17:14<14:38, 11.35it/s]

Training loss: 0.40318863100039865, global step: 2399
Training loss: 0.4031794564067355, global step: 2399


 57%|█████▋    | 12496/21967 [17:57<13:47, 11.45it/s]

Training loss: 0.4026529907341647, global step: 2499
Training loss: 0.40265488117256204, global step: 2499
Training loss: 0.40263425982002515, global step: 2499


 57%|█████▋    | 12500/21967 [17:58<13:53, 11.36it/s]

Training loss: 0.40264715704355974, global step: 2499
Training loss: 0.40262367452394837, global step: 2499


 59%|█████▉    | 12996/21967 [18:40<13:03, 11.45it/s]

Training loss: 0.4033259397669963, global step: 2599
Training loss: 0.403298526723419, global step: 2599
Training loss: 0.4033015059329943, global step: 2599


 59%|█████▉    | 13000/21967 [18:41<13:09, 11.35it/s]

Training loss: 0.4033252386908425, global step: 2599
Training loss: 0.4033101956337842, global step: 2599


 61%|██████▏   | 13496/21967 [19:23<12:19, 11.46it/s]

Training loss: 0.40408892484399145, global step: 2699
Training loss: 0.40407491551033775, global step: 2699
Training loss: 0.40406425259339535, global step: 2699


 61%|██████▏   | 13500/21967 [19:24<12:25, 11.36it/s]

Training loss: 0.4041002852869268, global step: 2699
Training loss: 0.40408797611081904, global step: 2699


 64%|██████▎   | 13996/21967 [20:06<11:36, 11.45it/s]

Training loss: 0.40461663412684606, global step: 2799
Training loss: 0.40459006066903236, global step: 2799
Training loss: 0.4046147057360921, global step: 2799


 64%|██████▎   | 14000/21967 [20:07<11:41, 11.35it/s]

Training loss: 0.40458772526866754, global step: 2799
Training loss: 0.40456516123229186, global step: 2799


 66%|██████▌   | 14496/21967 [20:50<10:52, 11.45it/s]

Training loss: 0.4044721085752831, global step: 2899
Training loss: 0.40444549136057734, global step: 2899
Training loss: 0.4044347429275172, global step: 2899


 66%|██████▌   | 14500/21967 [20:50<10:57, 11.35it/s]

Training loss: 0.40443587913010176, global step: 2899
Training loss: 0.40441015879895553, global step: 2899


 68%|██████▊   | 14996/21967 [21:33<10:08, 11.45it/s]

Training loss: 0.4052768026581092, global step: 2999
Training loss: 0.40526747645909983, global step: 2999
Training loss: 0.4052714332703789, global step: 2999


 68%|██████▊   | 15000/21967 [21:33<10:13, 11.36it/s]

Training loss: 0.4053454154139964, global step: 2999
Training loss: 0.4053667108966712, global step: 2999


 71%|███████   | 15496/21967 [22:16<09:25, 11.44it/s]

Training loss: 0.4050178677818866, global step: 3099
Training loss: 0.4050144492876047, global step: 3099
Training loss: 0.40498966183039475, global step: 3099


 71%|███████   | 15500/21967 [22:16<09:30, 11.34it/s]

Training loss: 0.40500146488832145, global step: 3099
Training loss: 0.4050032363562977, global step: 3099


 73%|███████▎  | 15996/21967 [22:59<08:41, 11.46it/s]

Training loss: 0.40539235018686903, global step: 3199
Training loss: 0.40540157907322094, global step: 3199
Training loss: 0.4053797855039593, global step: 3199


 73%|███████▎  | 16000/21967 [22:59<08:45, 11.35it/s]

Training loss: 0.4053701958002039, global step: 3199
Training loss: 0.4054100629760412, global step: 3199


 75%|███████▌  | 16496/21967 [23:42<07:58, 11.43it/s]

Training loss: 0.40620500438110196, global step: 3299
Training loss: 0.40619170828942114, global step: 3299
Training loss: 0.4061730998502871, global step: 3299


 75%|███████▌  | 16500/21967 [23:42<08:02, 11.33it/s]

Training loss: 0.4061650740627229, global step: 3299
Training loss: 0.4061493818324725, global step: 3299


 77%|███████▋  | 16996/21967 [24:25<07:14, 11.45it/s]

Training loss: 0.4063329789735167, global step: 3399
Training loss: 0.4063375941947019, global step: 3399
Training loss: 0.4063622582909657, global step: 3399


 77%|███████▋  | 17000/21967 [24:26<07:17, 11.35it/s]

Training loss: 0.40634325203692157, global step: 3399
Training loss: 0.406333132730114, global step: 3399


 80%|███████▉  | 17496/21967 [25:08<06:30, 11.44it/s]

Training loss: 0.40793823619188296, global step: 3499
Training loss: 0.4079195926307021, global step: 3499
Training loss: 0.40790031787269854, global step: 3499


 80%|███████▉  | 17500/21967 [25:09<06:33, 11.35it/s]

Training loss: 0.407897971220872, global step: 3499
Training loss: 0.4079286198707344, global step: 3499


 82%|████████▏ | 17996/21967 [25:51<05:46, 11.45it/s]

Training loss: 0.4088371828228568, global step: 3599
Training loss: 0.4088768269673271, global step: 3599
Training loss: 0.4088586442160846, global step: 3599


 82%|████████▏ | 18000/21967 [25:52<05:49, 11.35it/s]

Training loss: 0.40884996509597304, global step: 3599
Training loss: 0.4088555668027594, global step: 3599


 84%|████████▍ | 18496/21967 [26:35<05:03, 11.43it/s]

Training loss: 0.40983190152179166, global step: 3699
Training loss: 0.4098418796108752, global step: 3699
Training loss: 0.40985511372795413, global step: 3699


 84%|████████▍ | 18500/21967 [26:35<05:05, 11.34it/s]

Training loss: 0.4098427471705927, global step: 3699
Training loss: 0.40984078809265717, global step: 3699


 86%|████████▋ | 18996/21967 [27:18<04:26, 11.14it/s]

Training loss: 0.4101474221173474, global step: 3799
Training loss: 0.4101372982612384, global step: 3799
Training loss: 0.4101162782469616, global step: 3799


 86%|████████▋ | 19000/21967 [27:18<04:29, 11.01it/s]

Training loss: 0.4101202087748788, global step: 3799
Training loss: 0.4101216420163002, global step: 3799


 89%|████████▉ | 19496/21967 [28:02<03:41, 11.13it/s]

Training loss: 0.4100696024432999, global step: 3899
Training loss: 0.41006871084453295, global step: 3899
Training loss: 0.41013476785233904, global step: 3899


 89%|████████▉ | 19500/21967 [28:02<03:44, 11.01it/s]

Training loss: 0.41012804479632314, global step: 3899
Training loss: 0.4101301643612268, global step: 3899


 91%|█████████ | 19996/21967 [28:46<02:57, 11.12it/s]

Training loss: 0.4106850923407118, global step: 3999
Training loss: 0.4106709209086611, global step: 3999
Training loss: 0.41065997059621867, global step: 3999


 91%|█████████ | 20000/21967 [28:46<02:58, 11.00it/s]

Training loss: 0.410775857634612, global step: 3999
Training loss: 0.410768311296478, global step: 3999


 93%|█████████▎| 20496/21967 [29:30<02:12, 11.13it/s]

Training loss: 0.4108771085425295, global step: 4099
Training loss: 0.41087927450075823, global step: 4099
Training loss: 0.4108922491110151, global step: 4099


 93%|█████████▎| 20500/21967 [29:31<02:13, 11.01it/s]

Training loss: 0.41088436378910265, global step: 4099
Training loss: 0.41090093857997734, global step: 4099


 96%|█████████▌| 20996/21967 [30:14<01:27, 11.12it/s]

Training loss: 0.4109100283790165, global step: 4199
Training loss: 0.4109112495523054, global step: 4199
Training loss: 0.41091799314127914, global step: 4199


 96%|█████████▌| 21000/21967 [30:15<01:27, 11.01it/s]

Training loss: 0.4109201523210144, global step: 4199
Training loss: 0.4109011698737688, global step: 4199


 98%|█████████▊| 21496/21967 [30:59<00:42, 11.13it/s]

Training loss: 0.4115273549134466, global step: 4299
Training loss: 0.4115624372343126, global step: 4299
Training loss: 0.411551205558317, global step: 4299


 98%|█████████▊| 21500/21967 [30:59<00:42, 11.01it/s]

Training loss: 0.41154156364206484, global step: 4299
Training loss: 0.41152430292398123, global step: 4299


100%|██████████| 21967/21967 [31:39<00:00, 11.56it/s]


begin evaluation!!!
eval_loss: 1.2231793509367244
eval_accuracy: 0.6455431754874652


In [28]:
test_path='data/test/middle'
is_training_test=False
test_input_cloze,test_input_direct=to_input(test_path)
test_sequences_cloze=input_to_sequence_cloze(test_input_cloze,max_seq_length,tokenizer,is_training_test)
test_sequences_direct=input_to_sequence_question(test_input_direct,max_seq_length,tokenizer,is_training_test)
test_sequences=test_sequences_cloze+test_sequences_direct
test_batch_size=1

test_input_sequence=to_all_tensor(test_sequences,'input_sequence')
test_segment_id=to_all_tensor(test_sequences,'segment_id')
test_mask_id=to_all_tensor(test_sequences,'mask_id')
test_label=torch.tensor([case.label for case in test_sequences],dtype=torch.long)

test_dataset=TensorDataset(test_input_sequence,test_segment_id,test_mask_id,test_label)
test_sampler=RandomSampler(test_dataset)
test_data=DataLoader(test_dataset,sampler=test_sampler, batch_size=batch_size)





In [29]:
test_path2='data/test/high'
is_training_test=False
test_input2_cloze,test_input2_direct=to_input(test_path2)
test_sequences2_cloze=input_to_sequence_cloze(test_input2_cloze,max_seq_length,tokenizer,is_training_test)
test_sequences2_direct=input_to_sequence_question(test_input2_direct,max_seq_length,tokenizer,is_training_test)
test_sequences2=test_sequences2_cloze+test_sequences2_direct
test_batch_size=1

test_input_sequence2=to_all_tensor(test_sequences2,'input_sequence')
test_segment_id2=to_all_tensor(test_sequences2,'segment_id')
test_mask_id2=to_all_tensor(test_sequences2,'mask_id')
test_label2=torch.tensor([case.label for case in test_sequences2],dtype=torch.long)

test_dataset2=TensorDataset(test_input_sequence2,test_segment_id2,test_mask_id2,test_label2)
test_sampler2=RandomSampler(test_dataset2)
test_data2=DataLoader(test_dataset2,sampler=test_sampler2, batch_size=batch_size)

In [30]:
def test_race(test_dataloader,device,model,global_step):
    print("begin testing!!!")
    test_loss, test_accuracy = 0, 0
    nb_test_steps, nb_test_examples = 0, 0
    for step, batch in enumerate(test_dataloader):
        batch = tuple(t.to(device) for t in batch)
        input_ids, segment_ids, input_mask, label_ids = batch

        with torch.no_grad():
            tmp_test_loss = model(input_ids, segment_ids, input_mask, label_ids)
            logits = model(input_ids, segment_ids, input_mask)

        logits = logits.detach().cpu().numpy()
        label_ids = label_ids.to('cpu').numpy()
        tmp_test_accuracy = accuracy(logits, label_ids)

        test_loss += tmp_test_loss.mean().item()
        test_accuracy += tmp_test_accuracy

        nb_test_examples += input_ids.size(0)
        nb_test_steps += 1

    test_loss = test_loss / nb_test_steps
    test_accuracy = test_accuracy / nb_test_examples
    print("test_loss:",test_loss)
    print("test_accuracy:",test_accuracy)

    result = {'dev_test_loss': test_loss,
              'dev_test_accuracy': test_accuracy,
              'global_step': global_step}

    output_test_file = os.path.join('output', "test_results.txt")
    with open(output_test_file, "a+") as writer:
        for key in sorted(result.keys()):
            writer.write("%s = %s\n" % (key, str(result[key])))

In [31]:
model.eval()
test_race(test_data, device, model, global_step)

begin testing!!!
test_loss: 1.1780953962829155
test_accuracy: 0.6622562674094707


In [32]:
model.eval()
test_race(test_data2, device, model, global_step)

begin testing!!!
test_loss: 1.5713081693594078
test_accuracy: 0.5520297312750143


In [33]:
test_cloze=test_sequences_cloze+test_sequences2_cloze
test_direct=test_sequences_direct+test_sequences2_direct

test_input_sequence_cloze=to_all_tensor(test_cloze,'input_sequence')
test_segment_id_cloze=to_all_tensor(test_cloze,'segment_id')
test_mask_id_cloze=to_all_tensor(test_cloze,'mask_id')
test_label_cloze=torch.tensor([case.label for case in test_cloze],dtype=torch.long)

test_dataset_cloze=TensorDataset(test_input_sequence_cloze,test_segment_id_cloze,test_mask_id_cloze,test_label_cloze)
test_sampler_cloze=RandomSampler(test_dataset_cloze)
test_data_cloze=DataLoader(test_dataset_cloze,sampler=test_sampler_cloze, batch_size=batch_size)

test_input_sequence_direct=to_all_tensor(test_direct,'input_sequence')
test_segment_id_direct=to_all_tensor(test_direct,'segment_id')
test_mask_id_direct=to_all_tensor(test_direct,'mask_id')
test_label_direct=torch.tensor([case.label for case in test_direct],dtype=torch.long)

test_dataset_direct=TensorDataset(test_input_sequence_direct,test_segment_id_direct,test_mask_id_direct,test_label_direct)
test_sampler_direct=RandomSampler(test_dataset_direct)
test_data_direct=DataLoader(test_dataset_direct,sampler=test_sampler_direct, batch_size=batch_size)

In [34]:
model.eval()
test_race(test_data_cloze,device,model,global_step)

begin testing!!!
test_loss: 1.447483023742803
test_accuracy: 0.5973484848484848


In [35]:
model.eval()
test_race(test_data_direct,device,model,global_step)

begin testing!!!
test_loss: 1.4676962107669835
test_accuracy: 0.5688753269398431
