In [8]:
!pip install torch

Defaulting to user installation because normal site-packages is not writeable


In [1]:
import torch

In [2]:
torch.cuda.is_available()

True

In [3]:
device=torch.device("cuda" if torch.cuda.is_available else "cpu")

In [4]:
device

device(type='cuda')

In [1]:
import logging
import os
import argparse
import random
import json
import numpy as np
import torch
import glob
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
# from utiles import read_race_examples,convert_examples_to_features,accuracy,select_field
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForMultipleChoice
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from tqdm import tqdm
from transformers import BertModel, BertTokenizer



In [2]:
class RaceSequence(object):
    def __init__(self,article,question,option1,option2,option3,option4,label = None):
        self.article = article
        self.question=question
        self.options=[option1,option2,option3,option4]
        self.label = label
        

class InputSequence(object):
    def __init__(self,sequence_list,label):
        self.choices_sequence = [
            {
                'input_sequence': input_sequence,
                'segment_id': segment_id,
                'mask_id': mask_id
            }
            for input_sequence,segment_id,mask_id in sequence_list
        ]
        self.label = label


In [3]:
def to_input(input_path):
    cloze_input_list=[]
    direct_input_list=[]
    for filename in os.listdir(input_path):
        file_path=input_path+"/"+filename
        # print(file_path)
        with open(file_path,'r',encoding='utf-8') as f:
            json_text=json.loads(f.read())
            correct_answers=json_text['answers']
            options_list=json_text['options']
            questions_list=json_text['questions']
            article=json_text['article']

            for i in range(len(questions_list)):
                correct_answer=correct_answers[i]
                label=ord(correct_answer)-ord('A')
                options=options_list[i]
                question=questions_list[i]
                if '_' in question:
                    case=RaceSequence(article,question,options[0],options[1],options[2],options[3],label)
                    cloze_input_list.append(case)
                else:
                    case=RaceSequence(article,question,options[0],options[1],options[2],options[3],label)
                    direct_input_list.append(case)
    return cloze_input_list,direct_input_list

In [4]:
def input_to_sequence_question(input_list,max_length,tokenizer,is_training):

    choice_list = []
    for case in input_list:
#         print(case.article)
        article_token = tokenizer.tokenize(case.article)
        question_token = tokenizer.tokenize(case.question)
        
        sequence_list = []
        for option in case.options:
            
            question_answer_token = question_token + tokenizer.tokenize(option)
            
            while len(article_token)+len(question_answer_token)>max_length-3:
                if  len(article_token)>len(question_answer_token):
                    article_token.pop()
                else:
                    question_answer_token.pop()
            
            sequence_token = ["[CLS]"] + article_token + ["[SEP]"] + question_answer_token + ["[SEP]"]
            segment_id = [0] * (len(article_token) + 2) + [1] * (len(question_answer_token) + 1)
            input_sequence = tokenizer.convert_tokens_to_ids(sequence_token)
            mask_id = [1] * len(input_sequence)

            # Zero-pad up to the sequence length.
            for i in range(max_length-len(input_sequence)):
                input_sequence+=[0]
                mask_id+=[0]
                segment_id+=[0]

            sequence_list.append((input_sequence,segment_id,mask_id))

        label = case.label

        choice_list.append(InputSequence(sequence_list,label))

    return choice_list






def input_to_sequence_cloze(input_list,max_length,tokenizer,is_training):

    choice_list = []
    for case in input_list:
#         print(case.article)
        article_token = tokenizer.tokenize(case.article)
#         question_token = tokenizer.tokenize(case.question)
        
        sequence_list = []
        for option in case.options:
            question_answer=case.question.replace("_",option)
            
            question_answer_token = tokenizer.tokenize(question_answer)
            
            while len(article_token)+len(question_answer_token)>max_length-3:
                if  len(article_token)>len(question_answer_token):
                    article_token.pop()
                else:
                    question_answer_token.pop()
            
            sequence_token = ["[CLS]"] + article_token + ["[SEP]"] + question_answer_token + ["[SEP]"]
            segment_id = [0] * (len(article_token) + 2) + [1] * (len(question_answer_token) + 1)
            input_sequence = tokenizer.convert_tokens_to_ids(sequence_token)
            mask_id = [1] * len(input_sequence)

            # Zero-pad up to the sequence length.
            for i in range(max_length-len(input_sequence)):
                input_sequence+=[0]
                mask_id+=[0]
                segment_id+=[0]

            sequence_list.append((input_sequence,segment_id,mask_id))

        label = case.label

        choice_list.append(InputSequence(sequence_list,label))

    return choice_list


In [5]:
train_middle_path='data/train/middle'
train_high_path='data/train/high'

max_seq_length=360
is_training=True
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


train_input_cloze,train_input_direct=to_input(train_middle_path)
train_input_high_cloze,train_input_high_direct=to_input(train_high_path)
train_input_cloze+=train_input_high_cloze
train_input_direct+=train_input_high_direct
print("train_input completed!")
input_sequences_question=input_to_sequence_question(train_input_direct, max_seq_length,tokenizer,is_training)
input_sequences_cloze=input_to_sequence_cloze(train_input_cloze, max_seq_length,tokenizer,is_training)

train_input completed!


In [6]:
train_input=train_input_cloze+train_input_direct

In [7]:
input_sequences=input_sequences_question+input_sequences_cloze

In [8]:
eval_middle_path='data/dev/middle'
max_seq_length=360
is_training=False
# tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

eval_input_middle_cloze,eval_input_middle_direct=to_input(eval_middle_path)

input_sequences_eval_middle_cloze=input_to_sequence_cloze(eval_input_middle_cloze, max_seq_length,tokenizer,is_training)
input_sequences_eval_middle_direct=input_to_sequence_question(eval_input_middle_direct, max_seq_length,tokenizer,is_training)
input_sequences_eval_middle=input_sequences_eval_middle_cloze+input_sequences_eval_middle_direct

In [9]:
def warmup_linear(x, warmup=0.002):
    if x < warmup:
        return x/warmup
    return 1.0 - x


def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)

In [10]:
def to_all_tensor(input_sequences,key):
    all_input_sequences=[]
    for input_sequence in input_sequences:
        lst=[]
        for case in input_sequence.choices_sequence:
            lst.append(case[key])
        all_input_sequences.append(lst)
    return torch.tensor(all_input_sequences,dtype=torch.long)

In [11]:
train_input_sequence=to_all_tensor(input_sequences,'input_sequence')
train_segment_id=to_all_tensor(input_sequences,'segment_id')
train_mask_id=to_all_tensor(input_sequences,'mask_id')
train_label=torch.tensor([case.label for case in input_sequences],dtype=torch.long)

In [12]:
eval_input_sequence_middle=to_all_tensor(input_sequences_eval_middle,'input_sequence')
eval_segment_id_middle=to_all_tensor(input_sequences_eval_middle,'segment_id')
eval_mask_id_middle=to_all_tensor(input_sequences_eval_middle,'mask_id')
eval_label_middle=torch.tensor([case.label for case in input_sequences_eval_middle],dtype=torch.long)

In [13]:
eval_input_sequence_middle.shape

torch.Size([1436, 4, 360])

In [14]:
train_input_sequence.shape

torch.Size([87866, 4, 360])

In [15]:
device=torch.device("cuda" if torch.cuda.is_available else "cpu")

model = BertForMultipleChoice.from_pretrained("MiniLM-L6-H384-distilled-from-BERT-Base",
        cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(-1)),
        num_choices=4).to(device)

In [16]:
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule


learning_rate=2.5e-5
warmup_proportion=0.05
batch_size=4
gradient_accumulation_steps=5 # not used
# train_batch_size=int(batch_size/gradient_accumulation_steps)
train_batch_size=batch_size
eval_batch_size=4

num_train_steps=len(train_input)

n_gpu=torch.cuda.device_count()
epoch_num=5


device=torch.device("cuda" if torch.cuda.is_available else "cpu")

model = BertForMultipleChoice.from_pretrained("MiniLM-L6-H384-distilled-from-BERT-Base",
        cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(-1)),
        num_choices=4).to(device)

param_optimizer=list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

optimizer = BertAdam(optimizer_grouped_parameters,
                    lr=learning_rate,
                    warmup=warmup_proportion,
                    t_total=num_train_steps)

# optimizer = BertAdam(model.parameters(),lr=learning_rate,schedule='warmup_linear',warmup=warmup_proportion,t_total=num_train_steps)

train_dataset=TensorDataset(train_input_sequence,train_segment_id,train_mask_id,train_label)
train_sampler=RandomSampler(train_dataset)
train_data=DataLoader(train_dataset,sampler=train_sampler, batch_size=train_batch_size)

eval_dataset=TensorDataset(eval_input_sequence_middle,eval_segment_id_middle,eval_mask_id_middle,eval_label_middle)
eval_sampler=RandomSampler(eval_dataset)
eval_data=DataLoader(eval_dataset,sampler=eval_sampler, batch_size=eval_batch_size)

In [17]:
num_train_steps

87866

In [18]:
def val_race(eval_dataloader,device,model,global_step):
    print("begin evaluation!!!")
    eval_loss, eval_accuracy = 0, 0
    eval_steps, eval_examples = 0, 0
    for step, batch in enumerate(eval_dataloader):
        batch = tuple(t.to(device) for t in batch)
        input_ids, segment_ids, input_mask, label_ids = batch

        with torch.no_grad():
            tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
            logits = model(input_ids, segment_ids, input_mask)

        logits = logits.detach().cpu().numpy()
        label_ids = label_ids.to('cpu').numpy()
        tmp_eval_accuracy = accuracy(logits, label_ids)

        eval_loss += tmp_eval_loss.mean().item()
        eval_accuracy += tmp_eval_accuracy

        eval_examples += input_ids.size(0)
        eval_steps += 1

    eval_loss = eval_loss / eval_steps
    eval_accuracy = eval_accuracy / eval_examples
    print("eval_loss:",eval_loss)
    print("eval_accuracy:",eval_accuracy)

    result = {'dev_eval_loss': eval_loss,
              'dev_eval_accuracy': eval_accuracy,
              'global_step': global_step}

    output_eval_file = os.path.join('output', "eval_results_nowarmpup.txt")
    with open(output_eval_file, "a+") as writer:
        for key in sorted(result.keys()):
            writer.write("%s = %s\n" % (key, str(result[key])))

In [19]:
def train_race(train_dataloader,device,n_gpu,model,optimizer,global_step,t_total,train_loss,train_examples,train_steps):
    for step, data in enumerate(tqdm(train_dataloader)):
        data = tuple(t.to(device) for t in data)
        input_ids, segment_ids,input_mask, label_ids = data
#         print(segment_ids)
        loss = model(input_ids, segment_ids, input_mask, label_ids)
        
        train_loss += loss.item()
        train_examples += input_ids.size(0)
        train_steps += 1

        loss.backward()

        
         # modify learning rate with special warm up BERT uses
        if (step + 1) % gradient_accumulation_steps == 0:
#             tmp_lr= learning_rate * warmup_linear(global_step / t_total, warmup_proportion)

            tmp_lr=learning_rate

            for param_group in optimizer.param_groups:
                param_group['lr'] = tmp_lr
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
#             print("Training loss: {}, global step: {}".format(tr_loss / nb_tr_steps, global_step))

        if (global_step+1) % 100 == 0:
            print("Training loss: {}, global step: {}".format(train_loss / train_steps, global_step))
            
            

In [20]:
n_gpu

1

In [None]:
global_step=0
model.train()

for epoch in range(epoch_num):
    train_loss=0
    training_case_num=0
    training_step_num=0
    
  
  
    print("Training Epoch: {}/{}".format(epoch+1, int(epoch_num)))
    train_race(train_data,device,n_gpu,model,optimizer,global_step,num_train_steps,train_loss,training_case_num,training_step_num)
#     if (global_step+1) % 100 == 0:
    model.eval()
    val_race(eval_data, device, model, global_step)

  

  0%|          | 0/21967 [00:00<?, ?it/s]

Training Epoch: 1/5


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1055.)
  next_m.mul_(beta1).add_(1 - beta1, grad)
  2%|▏         | 496/21967 [00:45<32:45, 10.92it/s] 

Training loss: 1.3864957927453399, global step: 99
Training loss: 1.386507928852112, global step: 99
Training loss: 1.3865080532173755, global step: 99


  2%|▏         | 500/21967 [00:45<33:00, 10.84it/s]

Training loss: 1.386503319184943, global step: 99
Training loss: 1.3864960386184508, global step: 99


  5%|▍         | 996/21967 [01:30<32:00, 10.92it/s]

Training loss: 1.3864060409105003, global step: 199
Training loss: 1.3864028445448742, global step: 199
Training loss: 1.386401098494305, global step: 199


  5%|▍         | 1000/21967 [01:31<32:14, 10.84it/s]

Training loss: 1.386402532786788, global step: 199
Training loss: 1.3864049391226247, global step: 199


  7%|▋         | 1496/21967 [02:15<31:17, 10.91it/s]

Training loss: 1.3862965438676917, global step: 299
Training loss: 1.3862937213902804, global step: 299
Training loss: 1.386291637449322, global step: 299


  7%|▋         | 1500/21967 [02:16<31:31, 10.82it/s]

Training loss: 1.3862951956540148, global step: 299
Training loss: 1.3862975228222472, global step: 299


  9%|▉         | 1996/21967 [03:01<30:29, 10.91it/s]

Training loss: 1.3860321865942244, global step: 399
Training loss: 1.386033145602576, global step: 399
Training loss: 1.3860317012579606, global step: 399


  9%|▉         | 2000/21967 [03:01<30:43, 10.83it/s]

Training loss: 1.3860321771036517, global step: 399
Training loss: 1.386034989428556, global step: 399


 11%|█▏        | 2496/21967 [03:46<30:31, 10.63it/s]

Training loss: 1.3853396341174782, global step: 499
Training loss: 1.3853265435840838, global step: 499
Training loss: 1.3853227918416344, global step: 499


 11%|█▏        | 2498/21967 [03:46<29:41, 10.93it/s]

Training loss: 1.3853299430220294, global step: 499
Training loss: 1.3853423830126228, global step: 499


 14%|█▎        | 2996/21967 [04:32<28:57, 10.92it/s]

Training loss: 1.3833083907431474, global step: 599
Training loss: 1.3832932993152909, global step: 599
Training loss: 1.3833195710524264, global step: 599


 14%|█▎        | 3000/21967 [04:32<29:10, 10.83it/s]

Training loss: 1.383319209224785, global step: 599
Training loss: 1.3832780625590406, global step: 599


 16%|█▌        | 3496/21967 [05:17<28:12, 10.91it/s]

Training loss: 1.379528128674443, global step: 699
Training loss: 1.379464959866395, global step: 699
Training loss: 1.3794551795640946, global step: 699


 16%|█▌        | 3500/21967 [05:18<28:25, 10.83it/s]

Training loss: 1.3794589149672485, global step: 699
Training loss: 1.3794702070718086, global step: 699


 18%|█▊        | 3996/21967 [06:03<27:27, 10.91it/s]

Training loss: 1.3749099454086027, global step: 799
Training loss: 1.3748947888075769, global step: 799
Training loss: 1.3748440508816222, global step: 799


 18%|█▊        | 4000/21967 [06:03<27:39, 10.83it/s]

Training loss: 1.3748035750698007, global step: 799
Training loss: 1.3747545134189516, global step: 799


 20%|██        | 4496/21967 [06:48<26:41, 10.91it/s]

Training loss: 1.3719094716401996, global step: 899
Training loss: 1.3719393311075043, global step: 899
Training loss: 1.3719258923145674, global step: 899


 20%|██        | 4500/21967 [06:48<26:53, 10.83it/s]

Training loss: 1.3719170737287743, global step: 899
Training loss: 1.3718960029492249, global step: 899


 23%|██▎       | 4996/21967 [07:33<25:54, 10.92it/s]

Training loss: 1.3688831755110213, global step: 999
Training loss: 1.3688651263188896, global step: 999
Training loss: 1.3688857571134478, global step: 999


 23%|██▎       | 5000/21967 [07:33<26:06, 10.83it/s]

Training loss: 1.3688504761245166, global step: 999
Training loss: 1.3688304660367498, global step: 999


 25%|██▌       | 5496/21967 [08:19<25:10, 10.91it/s]

Training loss: 1.3645467568355, global step: 1099
Training loss: 1.3645412597637163, global step: 1099
Training loss: 1.364537649551955, global step: 1099


 25%|██▌       | 5500/21967 [08:19<25:20, 10.83it/s]

Training loss: 1.3645516550727392, global step: 1099
Training loss: 1.36454870050919, global step: 1099


 27%|██▋       | 5996/21967 [09:05<24:23, 10.91it/s]

Training loss: 1.359968778747037, global step: 1199
Training loss: 1.3599143126076105, global step: 1199
Training loss: 1.359951574006875, global step: 1199


 27%|██▋       | 6000/21967 [09:05<24:34, 10.83it/s]

Training loss: 1.3599526200464622, global step: 1199
Training loss: 1.3599364058656243, global step: 1199


 30%|██▉       | 6496/21967 [09:50<23:38, 10.91it/s]

Training loss: 1.3570229707909511, global step: 1299
Training loss: 1.3570522198481072, global step: 1299
Training loss: 1.357047021260569, global step: 1299


 30%|██▉       | 6500/21967 [09:50<23:48, 10.83it/s]

Training loss: 1.3570414840000966, global step: 1299
Training loss: 1.357021274970557, global step: 1299


 32%|███▏      | 6996/21967 [10:35<22:53, 10.90it/s]

Training loss: 1.3549990655695907, global step: 1399
Training loss: 1.3549740419618193, global step: 1399
Training loss: 1.354982971923597, global step: 1399


 32%|███▏      | 7000/21967 [10:35<23:02, 10.82it/s]

Training loss: 1.3550157738638182, global step: 1399
Training loss: 1.355005533427132, global step: 1399


 34%|███▍      | 7496/21967 [11:21<22:08, 10.89it/s]

Training loss: 1.3518040594695169, global step: 1499
Training loss: 1.3518682268856683, global step: 1499
Training loss: 1.3518565590578475, global step: 1499


 34%|███▍      | 7500/21967 [11:21<22:17, 10.82it/s]

Training loss: 1.35184861502987, global step: 1499
Training loss: 1.3518381654969818, global step: 1499


 36%|███▋      | 7996/21967 [12:06<21:20, 10.91it/s]

Training loss: 1.347391414426132, global step: 1599
Training loss: 1.3473940256791093, global step: 1599
Training loss: 1.347334150211415, global step: 1599


 36%|███▋      | 8000/21967 [12:06<21:28, 10.84it/s]

Training loss: 1.34728753030166, global step: 1599
Training loss: 1.3472608456478699, global step: 1599


 39%|███▊      | 8496/21967 [12:51<20:34, 10.91it/s]

Training loss: 1.3443988224911367, global step: 1699
Training loss: 1.3443995281113856, global step: 1699
Training loss: 1.3443756131224034, global step: 1699


 39%|███▊      | 8500/21967 [12:52<20:44, 10.83it/s]

Training loss: 1.3443737308413428, global step: 1699
Training loss: 1.344350306177605, global step: 1699


 41%|████      | 8996/21967 [13:37<19:50, 10.90it/s]

Training loss: 1.3407817874901025, global step: 1799
Training loss: 1.340755079072441, global step: 1799
Training loss: 1.340726888575951, global step: 1799


 41%|████      | 9000/21967 [13:37<19:58, 10.82it/s]

Training loss: 1.3406840987034865, global step: 1799
Training loss: 1.3406685641479832, global step: 1799


 43%|████▎     | 9496/21967 [14:22<19:02, 10.91it/s]

Training loss: 1.3385567779852379, global step: 1899
Training loss: 1.3385573290131527, global step: 1899
Training loss: 1.3385383282207848, global step: 1899


 43%|████▎     | 9500/21967 [14:22<19:11, 10.82it/s]

Training loss: 1.3385270262899136, global step: 1899
Training loss: 1.3385301356391412, global step: 1899


 46%|████▌     | 9996/21967 [15:08<18:43, 10.65it/s]

Training loss: 1.3353596620228125, global step: 1999
Training loss: 1.3353682194425374, global step: 1999
Training loss: 1.3353584634490976, global step: 1999


 46%|████▌     | 9998/21967 [15:08<18:13, 10.94it/s]

Training loss: 1.3353412201474204, global step: 1999
Training loss: 1.3353044493208648, global step: 1999


 48%|████▊     | 10496/21967 [15:53<17:31, 10.91it/s]

Training loss: 1.3325038853800257, global step: 2099
Training loss: 1.3325264867417879, global step: 2099
Training loss: 1.3325304636060822, global step: 2099


 48%|████▊     | 10500/21967 [15:53<17:39, 10.82it/s]

Training loss: 1.3325107928161781, global step: 2099
Training loss: 1.3325010018412733, global step: 2099


 50%|█████     | 10996/21967 [16:38<17:11, 10.63it/s]

Training loss: 1.329676198032348, global step: 2199
Training loss: 1.3296827919843546, global step: 2199
Training loss: 1.3297129262032439, global step: 2199


 50%|█████     | 10998/21967 [16:39<16:43, 10.94it/s]

Training loss: 1.3297136065298307, global step: 2199
Training loss: 1.329749769482334, global step: 2199


 52%|█████▏    | 11496/21967 [17:24<15:59, 10.91it/s]

Training loss: 1.3266802729342801, global step: 2299
Training loss: 1.3266754563568612, global step: 2299
Training loss: 1.3266616016768928, global step: 2299


 52%|█████▏    | 11500/21967 [17:24<16:06, 10.83it/s]

Training loss: 1.3266261460615627, global step: 2299
Training loss: 1.3266312478884312, global step: 2299


 55%|█████▍    | 11996/21967 [18:09<15:14, 10.90it/s]

Training loss: 1.324397279302892, global step: 2399
Training loss: 1.3243991210187185, global step: 2399
Training loss: 1.3243842962593637, global step: 2399


 55%|█████▍    | 12000/21967 [18:10<15:20, 10.82it/s]

Training loss: 1.3243707315407907, global step: 2399
Training loss: 1.3243711948573604, global step: 2399


 57%|█████▋    | 12496/21967 [18:55<14:28, 10.91it/s]

Training loss: 1.3215180907882944, global step: 2499
Training loss: 1.3215081093835235, global step: 2499
Training loss: 1.3215161945673526, global step: 2499


 57%|█████▋    | 12500/21967 [18:55<14:34, 10.83it/s]

Training loss: 1.3215078869034222, global step: 2499
Training loss: 1.321503064371851, global step: 2499


 59%|█████▉    | 12996/21967 [19:40<13:43, 10.89it/s]

Training loss: 1.3189259691392519, global step: 2599
Training loss: 1.318917573865064, global step: 2599
Training loss: 1.3188969102044872, global step: 2599


 59%|█████▉    | 13000/21967 [19:41<13:48, 10.82it/s]

Training loss: 1.3188698166497836, global step: 2599
Training loss: 1.3188628879306188, global step: 2599


 61%|██████▏   | 13496/21967 [20:26<12:56, 10.90it/s]

Training loss: 1.315925858814922, global step: 2699
Training loss: 1.3158944573940154, global step: 2699
Training loss: 1.3158877499767556, global step: 2699


 61%|██████▏   | 13500/21967 [20:26<13:02, 10.82it/s]

Training loss: 1.3158910418196172, global step: 2699
Training loss: 1.3158731103358476, global step: 2699


 64%|██████▎   | 13996/21967 [21:11<12:10, 10.91it/s]

Training loss: 1.3130640143920542, global step: 2799
Training loss: 1.3130346050293795, global step: 2799
Training loss: 1.3130374331214372, global step: 2799


 64%|██████▎   | 14000/21967 [21:11<12:16, 10.82it/s]

Training loss: 1.3130137277664602, global step: 2799
Training loss: 1.3130146091057067, global step: 2799


 66%|██████▌   | 14496/21967 [21:56<11:27, 10.87it/s]

Training loss: 1.309954258806749, global step: 2899
Training loss: 1.3099568116110616, global step: 2899
Training loss: 1.309948252107864, global step: 2899


 66%|██████▌   | 14500/21967 [21:57<11:31, 10.79it/s]

Training loss: 1.309949971652225, global step: 2899
Training loss: 1.3099288192360359, global step: 2899


 68%|██████▊   | 14996/21967 [22:42<10:39, 10.91it/s]

Training loss: 1.3069118433095646, global step: 2999
Training loss: 1.3069089416083606, global step: 2999
Training loss: 1.3068718963435009, global step: 2999


 68%|██████▊   | 15000/21967 [22:42<10:43, 10.82it/s]

Training loss: 1.3068743661644142, global step: 2999
Training loss: 1.3068566152918966, global step: 2999


 71%|███████   | 15496/21967 [23:27<09:54, 10.88it/s]

Training loss: 1.3039999244058467, global step: 3099
Training loss: 1.3040071608913149, global step: 3099
Training loss: 1.3040035129585765, global step: 3099


 71%|███████   | 15500/21967 [23:27<09:59, 10.79it/s]

Training loss: 1.303991818723409, global step: 3099
Training loss: 1.3039603859341924, global step: 3099


 73%|███████▎  | 15996/21967 [24:12<09:07, 10.91it/s]

Training loss: 1.3014520812701642, global step: 3199
Training loss: 1.3014304300443653, global step: 3199
Training loss: 1.3014320101415007, global step: 3199


 73%|███████▎  | 16000/21967 [24:13<09:10, 10.83it/s]

Training loss: 1.3013967436255478, global step: 3199
Training loss: 1.301439770369375, global step: 3199


 75%|███████▌  | 16496/21967 [24:58<08:22, 10.89it/s]

Training loss: 1.2991088215583524, global step: 3299
Training loss: 1.299116623677335, global step: 3299
Training loss: 1.2991078390423771, global step: 3299


 75%|███████▌  | 16500/21967 [24:58<08:26, 10.80it/s]

Training loss: 1.2991224252620543, global step: 3299
Training loss: 1.2991363238284743, global step: 3299


 77%|███████▋  | 16996/21967 [25:43<07:35, 10.91it/s]

Training loss: 1.295979218117873, global step: 3399
Training loss: 1.295991658475307, global step: 3399
Training loss: 1.2959749893177226, global step: 3399


 77%|███████▋  | 17000/21967 [25:44<07:38, 10.82it/s]

Training loss: 1.29596802847731, global step: 3399
Training loss: 1.2959551529920383, global step: 3399


 80%|███████▉  | 17496/21967 [26:29<06:50, 10.90it/s]

Training loss: 1.2932860517099811, global step: 3499
Training loss: 1.2932896097303528, global step: 3499
Training loss: 1.2933075037242112, global step: 3499


 80%|███████▉  | 17500/21967 [26:29<06:52, 10.82it/s]

Training loss: 1.2932956657902637, global step: 3499
Training loss: 1.2932828536060335, global step: 3499


 82%|████████▏ | 17996/21967 [27:14<06:03, 10.91it/s]

Training loss: 1.2900372294824498, global step: 3599
Training loss: 1.2900210691408518, global step: 3599
Training loss: 1.2900035970408552, global step: 3599


 82%|████████▏ | 18000/21967 [27:14<06:06, 10.83it/s]

Training loss: 1.2899966882969196, global step: 3599
Training loss: 1.2899722619032594, global step: 3599


 84%|████████▍ | 18496/21967 [27:59<05:18, 10.91it/s]

Training loss: 1.2876068832762146, global step: 3699
Training loss: 1.2875928879960417, global step: 3699
Training loss: 1.2875922027730218, global step: 3699


 84%|████████▍ | 18500/21967 [28:00<05:20, 10.82it/s]

Training loss: 1.287585714855108, global step: 3699
Training loss: 1.2875622876338477, global step: 3699


 86%|████████▋ | 18996/21967 [28:44<04:32, 10.91it/s]

Training loss: 1.2847443408779673, global step: 3799
Training loss: 1.2847514474695043, global step: 3799
Training loss: 1.2847528283113605, global step: 3799


 86%|████████▋ | 19000/21967 [28:45<04:34, 10.83it/s]

Training loss: 1.2847696931199133, global step: 3799
Training loss: 1.2847401657693167, global step: 3799


 89%|████████▉ | 19496/21967 [29:30<03:46, 10.91it/s]

Training loss: 1.2818069665531038, global step: 3899
Training loss: 1.2818213574355946, global step: 3899
Training loss: 1.2817951711412026, global step: 3899


 89%|████████▉ | 19500/21967 [29:30<03:47, 10.83it/s]

Training loss: 1.281785374952287, global step: 3899
Training loss: 1.2818042394934426, global step: 3899


 91%|█████████ | 19996/21967 [30:15<03:00, 10.90it/s]

Training loss: 1.2792499296573974, global step: 3999
Training loss: 1.2792414013392497, global step: 3999
Training loss: 1.279248375159273, global step: 3999


 91%|█████████ | 20000/21967 [30:16<03:01, 10.83it/s]

Training loss: 1.2792395970316837, global step: 3999
Training loss: 1.2792166455774714, global step: 3999


 93%|█████████▎| 20496/21967 [31:00<02:14, 10.92it/s]

Training loss: 1.2767586607181551, global step: 4099
Training loss: 1.276735201882232, global step: 4099
Training loss: 1.2767189318926595, global step: 4099


 93%|█████████▎| 20500/21967 [31:01<02:15, 10.83it/s]

Training loss: 1.276709046846181, global step: 4099
Training loss: 1.2767075950404225, global step: 4099


 96%|█████████▌| 20996/21967 [31:46<01:29, 10.90it/s]

Training loss: 1.2741043745733267, global step: 4199
Training loss: 1.2740983130560417, global step: 4199
Training loss: 1.2740992590811284, global step: 4199


 96%|█████████▌| 21000/21967 [31:46<01:29, 10.82it/s]

Training loss: 1.2741082079232404, global step: 4199
Training loss: 1.27408601676795, global step: 4199


 98%|█████████▊| 21496/21967 [32:31<00:44, 10.63it/s]

Training loss: 1.271347996073286, global step: 4299
Training loss: 1.2713663219285358, global step: 4299
Training loss: 1.2713578867015047, global step: 4299


 98%|█████████▊| 21498/21967 [32:31<00:42, 10.92it/s]

Training loss: 1.271334556216184, global step: 4299
Training loss: 1.271354060820898, global step: 4299


100%|██████████| 21967/21967 [33:14<00:00, 11.01it/s]


begin evaluation!!!


  0%|          | 2/21967 [00:00<30:39, 11.94it/s]

eval_loss: 1.0141123664412326
eval_accuracy: 0.5919220055710307
Training Epoch: 2/5


  2%|▏         | 496/21967 [00:43<32:12, 11.11it/s]

Training loss: 1.0838936944200535, global step: 99
Training loss: 1.0842012951691304, global step: 99
Training loss: 1.0833997739632604, global step: 99


  2%|▏         | 500/21967 [00:43<32:32, 11.00it/s]

Training loss: 1.0831121345121697, global step: 99
Training loss: 1.083428144454956, global step: 99


  5%|▍         | 996/21967 [01:27<30:41, 11.39it/s]

Training loss: 1.0800902007512712, global step: 199
Training loss: 1.0798073067782394, global step: 199
Training loss: 1.0796876639097839, global step: 199


  5%|▍         | 1000/21967 [01:27<30:54, 11.30it/s]

Training loss: 1.0793693914024052, global step: 199
Training loss: 1.0797560618625388, global step: 199


  7%|▋         | 1496/21967 [02:11<30:42, 11.11it/s]

Training loss: 1.0892772012530363, global step: 299
Training loss: 1.089699443389866, global step: 299
Training loss: 1.0898901098914837, global step: 299


  7%|▋         | 1500/21967 [02:11<30:31, 11.18it/s]

Training loss: 1.0900322355439411, global step: 299
Training loss: 1.0901399127756302, global step: 299


  9%|▉         | 1996/21967 [02:54<29:08, 11.42it/s]

Training loss: 1.086550830241134, global step: 399
Training loss: 1.086527549014063, global step: 399
Training loss: 1.086463194670651, global step: 399


  9%|▉         | 2000/21967 [02:54<29:22, 11.33it/s]

Training loss: 1.0863838242935586, global step: 399
Training loss: 1.086382847240175, global step: 399


 11%|█▏        | 2496/21967 [03:37<28:24, 11.42it/s]

Training loss: 1.0788121788439626, global step: 499
Training loss: 1.0786615098134065, global step: 499
Training loss: 1.0785017954525014, global step: 499


 11%|█▏        | 2500/21967 [03:37<28:38, 11.33it/s]

Training loss: 1.0783746385450264, global step: 499
Training loss: 1.0782791210823701, global step: 499


 14%|█▎        | 2996/21967 [04:20<27:44, 11.40it/s]

Training loss: 1.0792047566682945, global step: 599
Training loss: 1.0789776551866723, global step: 599
Training loss: 1.0789403614498274, global step: 599


 14%|█▎        | 3000/21967 [04:21<27:56, 11.31it/s]

Training loss: 1.0789097080913046, global step: 599
Training loss: 1.0788572948747732, global step: 599


 16%|█▌        | 3496/21967 [05:03<26:58, 11.41it/s]

Training loss: 1.0780209097355733, global step: 699
Training loss: 1.0779687676970895, global step: 699
Training loss: 1.0780272726155977, global step: 699


 16%|█▌        | 3500/21967 [05:04<27:11, 11.32it/s]

Training loss: 1.0781627617363183, global step: 699
Training loss: 1.07821405556976, global step: 699


 18%|█▊        | 3996/21967 [05:47<26:57, 11.11it/s]

Training loss: 1.077168748756225, global step: 799
Training loss: 1.077065038945045, global step: 799
Training loss: 1.0770731163242622, global step: 799


 18%|█▊        | 4000/21967 [05:48<27:14, 10.99it/s]

Training loss: 1.0769866200358704, global step: 799
Training loss: 1.0769499586906752, global step: 799


 20%|██        | 4496/21967 [06:31<26:12, 11.11it/s]

Training loss: 1.0739925128234773, global step: 899
Training loss: 1.0740441857502394, global step: 899
Training loss: 1.0741024043478116, global step: 899


 20%|██        | 4500/21967 [06:32<26:29, 10.99it/s]

Training loss: 1.074143007540104, global step: 899
Training loss: 1.074145763936931, global step: 899


 23%|██▎       | 4996/21967 [07:15<24:45, 11.43it/s]

Training loss: 1.0724368129257444, global step: 999
Training loss: 1.072517831123071, global step: 999
Training loss: 1.0725722227969454, global step: 999


 23%|██▎       | 5000/21967 [07:15<24:57, 11.33it/s]

Training loss: 1.0726110222263068, global step: 999
Training loss: 1.0725649545033662, global step: 999


 25%|██▌       | 5496/21967 [07:58<24:01, 11.42it/s]

Training loss: 1.0723417206636008, global step: 1099
Training loss: 1.072246075373311, global step: 1099
Training loss: 1.072402015740446, global step: 1099


 25%|██▌       | 5500/21967 [07:58<24:13, 11.33it/s]

Training loss: 1.07250096955194, global step: 1099
Training loss: 1.0725010469173775, global step: 1099


 27%|██▋       | 5996/21967 [08:41<23:21, 11.40it/s]

Training loss: 1.072742886996647, global step: 1199
Training loss: 1.0727097708992834, global step: 1199
Training loss: 1.0726879650460734, global step: 1199


 27%|██▋       | 6000/21967 [08:42<23:32, 11.30it/s]

Training loss: 1.072654256923233, global step: 1199
Training loss: 1.0725828588038688, global step: 1199


 30%|██▉       | 6496/21967 [09:25<22:34, 11.43it/s]

Training loss: 1.0726756382988818, global step: 1299
Training loss: 1.07263378242462, global step: 1299
Training loss: 1.0726279150903895, global step: 1299


 30%|██▉       | 6500/21967 [09:25<22:45, 11.33it/s]

Training loss: 1.0726670241742986, global step: 1299
Training loss: 1.0727368108555617, global step: 1299


 32%|███▏      | 6996/21967 [10:08<21:52, 11.41it/s]

Training loss: 1.071702695046262, global step: 1399
Training loss: 1.0716567255552323, global step: 1399
Training loss: 1.0716900191035155, global step: 1399


 32%|███▏      | 7000/21967 [10:08<22:02, 11.32it/s]

Training loss: 1.0716949377989353, global step: 1399
Training loss: 1.0716757799986687, global step: 1399


 34%|███▍      | 7496/21967 [10:51<21:08, 11.41it/s]

Training loss: 1.0688309838586525, global step: 1499
Training loss: 1.068841291372858, global step: 1499
Training loss: 1.068854165537144, global step: 1499


 34%|███▍      | 7500/21967 [10:52<21:17, 11.33it/s]

Training loss: 1.0689372780597919, global step: 1499
Training loss: 1.068985089111255, global step: 1499


 36%|███▋      | 7996/21967 [11:35<20:23, 11.42it/s]

Training loss: 1.0661351806134116, global step: 1599
Training loss: 1.0662152682034567, global step: 1599
Training loss: 1.0661806273295669, global step: 1599


 36%|███▋      | 8000/21967 [11:35<20:33, 11.32it/s]

Training loss: 1.0661417224606773, global step: 1599
Training loss: 1.0661360630375518, global step: 1599


 39%|███▊      | 8496/21967 [12:18<19:40, 11.41it/s]

Training loss: 1.0635298680520675, global step: 1699
Training loss: 1.0635327393296061, global step: 1699
Training loss: 1.0635378617761317, global step: 1699


 39%|███▊      | 8500/21967 [12:18<19:50, 11.32it/s]

Training loss: 1.0635439658133765, global step: 1699
Training loss: 1.0634954514628243, global step: 1699


 41%|████      | 8996/21967 [13:01<18:55, 11.42it/s]

Training loss: 1.0613328050074478, global step: 1799
Training loss: 1.0613316290830905, global step: 1799
Training loss: 1.061307694383253, global step: 1799


 41%|████      | 9000/21967 [13:01<19:04, 11.33it/s]

Training loss: 1.061360835915613, global step: 1799
Training loss: 1.0613613041820082, global step: 1799


 43%|████▎     | 9496/21967 [13:44<18:12, 11.41it/s]

Training loss: 1.060869403037509, global step: 1899
Training loss: 1.0609071264937897, global step: 1899
Training loss: 1.060976161136776, global step: 1899


 43%|████▎     | 9500/21967 [13:45<18:21, 11.32it/s]

Training loss: 1.060988954232508, global step: 1899
Training loss: 1.0609520027734176, global step: 1899


 46%|████▌     | 9996/21967 [14:28<17:29, 11.41it/s]

Training loss: 1.0605706760813558, global step: 1999
Training loss: 1.0605375494743858, global step: 1999
Training loss: 1.060493369902672, global step: 1999


 46%|████▌     | 10000/21967 [14:28<17:37, 11.32it/s]

Training loss: 1.060472000580319, global step: 1999
Training loss: 1.060452569872573, global step: 1999


 48%|████▊     | 10496/21967 [15:11<16:44, 11.42it/s]

Training loss: 1.0601014521489431, global step: 2099
Training loss: 1.060113824979209, global step: 2099
Training loss: 1.060063202393683, global step: 2099


 48%|████▊     | 10500/21967 [15:12<16:52, 11.32it/s]

Training loss: 1.0600367183585033, global step: 2099
Training loss: 1.0601034307053503, global step: 2099


 50%|█████     | 10996/21967 [15:54<16:01, 11.41it/s]

Training loss: 1.058334429119597, global step: 2199
Training loss: 1.058339144774364, global step: 2199
Training loss: 1.0584096758256905, global step: 2199


 50%|█████     | 11000/21967 [15:55<16:08, 11.33it/s]

Training loss: 1.0583534388176137, global step: 2199
Training loss: 1.058416676445163, global step: 2199


 52%|█████▏    | 11496/21967 [16:38<15:16, 11.42it/s]

Training loss: 1.0552264397535287, global step: 2299
Training loss: 1.0552681376101163, global step: 2299
Training loss: 1.0552747768765587, global step: 2299


 52%|█████▏    | 11500/21967 [16:38<15:24, 11.32it/s]

Training loss: 1.0552504106404823, global step: 2299
Training loss: 1.0553280420537223, global step: 2299


 55%|█████▍    | 11996/21967 [17:21<14:33, 11.42it/s]

Training loss: 1.053237140981194, global step: 2399
Training loss: 1.0532509739541458, global step: 2399
Training loss: 1.0532456048704282, global step: 2399


 55%|█████▍    | 12000/21967 [17:21<14:39, 11.33it/s]

Training loss: 1.0532199001959075, global step: 2399
Training loss: 1.0531683983589393, global step: 2399


 57%|█████▋    | 12496/21967 [18:04<13:50, 11.40it/s]

Training loss: 1.0506425459905355, global step: 2499
Training loss: 1.050634871329993, global step: 2499
Training loss: 1.0506293628049714, global step: 2499


 57%|█████▋    | 12500/21967 [18:05<13:57, 11.31it/s]

Training loss: 1.0505953692060734, global step: 2499
Training loss: 1.0506036367959735, global step: 2499


 59%|█████▉    | 12996/21967 [18:48<13:06, 11.41it/s]

Training loss: 1.0482085623985165, global step: 2599
Training loss: 1.04819553109931, global step: 2599
Training loss: 1.0481917244421113, global step: 2599


 59%|█████▉    | 13000/21967 [18:48<13:11, 11.32it/s]

Training loss: 1.0481928416005757, global step: 2599
Training loss: 1.0481851156770894, global step: 2599


 61%|██████▏   | 13496/21967 [19:31<12:21, 11.42it/s]

Training loss: 1.046799309052066, global step: 2699
Training loss: 1.0468153298137435, global step: 2699
Training loss: 1.046758562097896, global step: 2699


 61%|██████▏   | 13500/21967 [19:31<12:27, 11.33it/s]

Training loss: 1.0467282826623787, global step: 2699
Training loss: 1.046736349258496, global step: 2699


 64%|██████▎   | 13996/21967 [20:14<11:37, 11.42it/s]

Training loss: 1.0458678000545707, global step: 2799
Training loss: 1.0458667621178332, global step: 2799
Training loss: 1.0458616849170495, global step: 2799


 64%|██████▎   | 14000/21967 [20:14<11:43, 11.33it/s]

Training loss: 1.045840534847163, global step: 2799
Training loss: 1.0458741208507942, global step: 2799


 66%|██████▌   | 14496/21967 [20:57<10:54, 11.42it/s]

Training loss: 1.044142604541392, global step: 2899
Training loss: 1.0441142557565097, global step: 2899
Training loss: 1.0441281389705377, global step: 2899


 66%|██████▌   | 14500/21967 [20:58<10:59, 11.33it/s]

Training loss: 1.0441535878461432, global step: 2899
Training loss: 1.0441396175341258, global step: 2899


 68%|██████▊   | 14996/21967 [21:40<10:10, 11.42it/s]

Training loss: 1.0431288781474932, global step: 2999
Training loss: 1.0431270234934662, global step: 2999
Training loss: 1.0431499180385906, global step: 2999


 68%|██████▊   | 15000/21967 [21:41<10:15, 11.32it/s]

Training loss: 1.0431462371179638, global step: 2999
Training loss: 1.0431764555969703, global step: 2999


 71%|███████   | 15496/21967 [22:24<09:27, 11.41it/s]

Training loss: 1.0425192586305874, global step: 3099
Training loss: 1.0424875693427467, global step: 3099
Training loss: 1.042465810185326, global step: 3099


 71%|███████   | 15500/21967 [22:24<09:31, 11.33it/s]

Training loss: 1.0424953222530537, global step: 3099
Training loss: 1.0424793407314246, global step: 3099


 73%|███████▎  | 15996/21967 [23:07<08:56, 11.12it/s]

Training loss: 1.0417011570738568, global step: 3199
Training loss: 1.0416620057583914, global step: 3199
Training loss: 1.0416657047172393, global step: 3199


 73%|███████▎  | 16000/21967 [23:08<09:02, 11.00it/s]

Training loss: 1.0416401557234067, global step: 3199
Training loss: 1.0416117689623907, global step: 3199


 75%|███████▌  | 16496/21967 [23:52<08:12, 11.12it/s]

Training loss: 1.0401342860557123, global step: 3299
Training loss: 1.0401653982436954, global step: 3299
Training loss: 1.040152647526595, global step: 3299


 75%|███████▌  | 16500/21967 [23:52<08:16, 11.00it/s]

Training loss: 1.0401676754347626, global step: 3299
Training loss: 1.0401643777189489, global step: 3299


 77%|███████▋  | 16996/21967 [24:36<07:27, 11.12it/s]

Training loss: 1.0394953216532112, global step: 3399
Training loss: 1.0394785151088988, global step: 3399
Training loss: 1.0394936498648393, global step: 3399


 77%|███████▋  | 17000/21967 [24:36<07:31, 11.00it/s]

Training loss: 1.0395042330328192, global step: 3399
Training loss: 1.0395445536165295, global step: 3399


 80%|███████▉  | 17496/21967 [25:19<06:30, 11.44it/s]

Training loss: 1.0379362760568966, global step: 3499
Training loss: 1.0379555722937894, global step: 3499
Training loss: 1.0379543160865217, global step: 3499


 80%|███████▉  | 17500/21967 [25:19<06:33, 11.34it/s]

Training loss: 1.037920425801955, global step: 3499
Training loss: 1.0379190021124531, global step: 3499


 82%|████████▏ | 17996/21967 [26:02<05:47, 11.43it/s]

Training loss: 1.0373050577226415, global step: 3599
Training loss: 1.0373111825767978, global step: 3599
Training loss: 1.0373128032595484, global step: 3599


 82%|████████▏ | 18000/21967 [26:03<05:49, 11.34it/s]

Training loss: 1.0373186524729736, global step: 3599
Training loss: 1.0373099927976583, global step: 3599


 84%|████████▍ | 18496/21967 [26:45<05:03, 11.44it/s]

Training loss: 1.036086225019585, global step: 3699
Training loss: 1.0360750872768691, global step: 3699
Training loss: 1.0360875122672872, global step: 3699


 84%|████████▍ | 18500/21967 [26:46<05:05, 11.34it/s]

Training loss: 1.0361394170218174, global step: 3699
Training loss: 1.0361336324301722, global step: 3699


 86%|████████▋ | 18996/21967 [27:29<04:20, 11.42it/s]

Training loss: 1.0342455891450555, global step: 3799
Training loss: 1.0342204065791043, global step: 3799
Training loss: 1.0341856894976307, global step: 3799


 86%|████████▋ | 19000/21967 [27:29<04:21, 11.33it/s]

Training loss: 1.0341850479627277, global step: 3799
Training loss: 1.0341803643724434, global step: 3799


 89%|████████▉ | 19496/21967 [28:12<03:36, 11.42it/s]

Training loss: 1.033582728793336, global step: 3899
Training loss: 1.0335686878128467, global step: 3899
Training loss: 1.0335473995077789, global step: 3899


 89%|████████▉ | 19500/21967 [28:12<03:37, 11.33it/s]

Training loss: 1.0335171566008436, global step: 3899
Training loss: 1.033532980805405, global step: 3899


 91%|█████████ | 19996/21967 [28:55<02:52, 11.42it/s]

Training loss: 1.0328444987440457, global step: 3999
Training loss: 1.0328737360943772, global step: 3999
Training loss: 1.0328597057751028, global step: 3999


 91%|█████████ | 20000/21967 [28:56<02:53, 11.33it/s]

Training loss: 1.0328607459373103, global step: 3999
Training loss: 1.032835829723985, global step: 3999


 93%|█████████▎| 20496/21967 [29:39<02:08, 11.42it/s]

Training loss: 1.031471170480112, global step: 4099
Training loss: 1.0314758264862807, global step: 4099
Training loss: 1.0314952924997727, global step: 4099


 93%|█████████▎| 20500/21967 [29:39<02:09, 11.34it/s]

Training loss: 1.0315058448908314, global step: 4099
Training loss: 1.0314746846986143, global step: 4099


 96%|█████████▌| 20996/21967 [30:22<01:24, 11.43it/s]

Training loss: 1.0307130432710545, global step: 4199
Training loss: 1.0307370412452113, global step: 4199
Training loss: 1.0307495639094693, global step: 4199


 96%|█████████▌| 21000/21967 [30:22<01:25, 11.34it/s]

Training loss: 1.0307962005498996, global step: 4199
Training loss: 1.0308038228071201, global step: 4199


 98%|█████████▊| 21496/21967 [31:05<00:41, 11.41it/s]

Training loss: 1.0300991007682543, global step: 4299
Training loss: 1.030085592122196, global step: 4299
Training loss: 1.030067841088512, global step: 4299


 98%|█████████▊| 21500/21967 [31:05<00:41, 11.33it/s]

Training loss: 1.030053050933512, global step: 4299
Training loss: 1.0300450183945362, global step: 4299


100%|██████████| 21967/21967 [31:46<00:00, 11.52it/s]


begin evaluation!!!


  0%|          | 2/21967 [00:00<31:42, 11.54it/s]

eval_loss: 0.8692831839940675
eval_accuracy: 0.649025069637883
Training Epoch: 3/5


  2%|▏         | 496/21967 [00:42<31:17, 11.43it/s]

Training loss: 0.8104287292018081, global step: 99
Training loss: 0.8103547534875332, global step: 99
Training loss: 0.8102035669973439, global step: 99


  2%|▏         | 500/21967 [00:43<31:33, 11.33it/s]

Training loss: 0.8096296887799918, global step: 99
Training loss: 0.8094937943504426, global step: 99


  5%|▍         | 996/21967 [01:26<30:35, 11.42it/s]

Training loss: 0.8096932424297884, global step: 199
Training loss: 0.8095931736266159, global step: 199
Training loss: 0.8093352333279168, global step: 199


  5%|▍         | 1000/21967 [01:26<30:51, 11.33it/s]

Training loss: 0.8097634861294277, global step: 199
Training loss: 0.8095124473010337, global step: 199


  7%|▋         | 1496/21967 [02:09<29:52, 11.42it/s]

Training loss: 0.816641020675168, global step: 299
Training loss: 0.8166517328251811, global step: 299
Training loss: 0.8165085771797336, global step: 299


  7%|▋         | 1500/21967 [02:09<30:07, 11.33it/s]

Training loss: 0.8166134201278674, global step: 299
Training loss: 0.8166996276601304, global step: 299


  9%|▉         | 1996/21967 [02:52<29:08, 11.42it/s]

Training loss: 0.8028947139574322, global step: 399
Training loss: 0.8025955129582771, global step: 399
Training loss: 0.8025382134557875, global step: 399


  9%|▉         | 2000/21967 [02:53<29:22, 11.33it/s]

Training loss: 0.8025371119346436, global step: 399
Training loss: 0.8031835428692061, global step: 399


 11%|█▏        | 2496/21967 [03:36<29:12, 11.11it/s]

Training loss: 0.7995683599117852, global step: 499
Training loss: 0.7996843531515109, global step: 499
Training loss: 0.7998008768872011, global step: 499


 11%|█▏        | 2500/21967 [03:36<29:31, 10.99it/s]

Training loss: 0.799597193283052, global step: 499
Training loss: 0.7998784799132498, global step: 499


 14%|█▎        | 2996/21967 [04:19<27:42, 11.41it/s]

Training loss: 0.7986479297351832, global step: 599
Training loss: 0.7987787915970036, global step: 599
Training loss: 0.7986585864738813, global step: 599


 14%|█▎        | 3000/21967 [04:20<27:55, 11.32it/s]

Training loss: 0.7985459244878387, global step: 599
Training loss: 0.7984252920512156, global step: 599


 16%|█▌        | 3496/21967 [05:03<27:42, 11.11it/s]

Training loss: 0.800928757265903, global step: 699
Training loss: 0.8010268172255058, global step: 699
Training loss: 0.8010662168342687, global step: 699


 16%|█▌        | 3500/21967 [05:03<28:01, 10.98it/s]

Training loss: 0.8010347961051961, global step: 699
Training loss: 0.8012200073129039, global step: 699


 18%|█▊        | 3996/21967 [05:46<26:14, 11.41it/s]

Training loss: 0.8010696269695335, global step: 799
Training loss: 0.8009859184560658, global step: 799
Training loss: 0.8010953306065524, global step: 799


 18%|█▊        | 4000/21967 [05:46<26:27, 11.32it/s]

Training loss: 0.80106062149506, global step: 799
Training loss: 0.8009752807141963, global step: 799


 20%|██        | 4496/21967 [06:30<25:30, 11.42it/s]

Training loss: 0.805171064092872, global step: 899
Training loss: 0.8052467671506773, global step: 899
Training loss: 0.8053541973994104, global step: 899


 20%|██        | 4500/21967 [06:30<25:42, 11.32it/s]

Training loss: 0.8055022442413153, global step: 899
Training loss: 0.8056591963055986, global step: 899


 23%|██▎       | 4996/21967 [07:13<24:45, 11.42it/s]

Training loss: 0.8077677097371048, global step: 999
Training loss: 0.8076888134293303, global step: 999
Training loss: 0.8076292595304049, global step: 999


 23%|██▎       | 5000/21967 [07:13<24:57, 11.33it/s]

Training loss: 0.807735433237215, global step: 999
Training loss: 0.8077885965999438, global step: 999


 25%|██▌       | 5496/21967 [07:56<24:03, 11.41it/s]

Training loss: 0.8062497440906914, global step: 1099
Training loss: 0.8062966999909983, global step: 1099
Training loss: 0.806225160680161, global step: 1099


 25%|██▌       | 5500/21967 [07:56<24:14, 11.32it/s]

Training loss: 0.8062103657441864, global step: 1099
Training loss: 0.8061878478437714, global step: 1099


 27%|██▋       | 5996/21967 [08:39<23:19, 11.41it/s]

Training loss: 0.8036008735503489, global step: 1199
Training loss: 0.8035434412400613, global step: 1199
Training loss: 0.8035553619118023, global step: 1199


 27%|██▋       | 6000/21967 [08:40<23:29, 11.32it/s]

Training loss: 0.8034705722189958, global step: 1199
Training loss: 0.8034653495610627, global step: 1199


 30%|██▉       | 6496/21967 [09:23<22:34, 11.42it/s]

Training loss: 0.8044361163950234, global step: 1299
Training loss: 0.8043866573537303, global step: 1299
Training loss: 0.8043938339318767, global step: 1299


 30%|██▉       | 6500/21967 [09:23<22:45, 11.33it/s]

Training loss: 0.8044506617863031, global step: 1299
Training loss: 0.8044613294171865, global step: 1299


 32%|███▏      | 6996/21967 [10:06<21:51, 11.42it/s]

Training loss: 0.8063804116482188, global step: 1399
Training loss: 0.8064644232334004, global step: 1399
Training loss: 0.8063935911220604, global step: 1399


 32%|███▏      | 7000/21967 [10:06<22:02, 11.32it/s]

Training loss: 0.8064969532463511, global step: 1399
Training loss: 0.8065293414038138, global step: 1399


 34%|███▍      | 7496/21967 [10:49<21:08, 11.41it/s]

Training loss: 0.8069067335124063, global step: 1499
Training loss: 0.8071375997269152, global step: 1499
Training loss: 0.807153836543746, global step: 1499


 34%|███▍      | 7500/21967 [10:50<21:18, 11.32it/s]

Training loss: 0.8070790222967696, global step: 1499
Training loss: 0.8070985149304779, global step: 1499


 36%|███▋      | 7996/21967 [11:33<20:23, 11.42it/s]

Training loss: 0.8085085504082001, global step: 1599
Training loss: 0.808533250833985, global step: 1599
Training loss: 0.8085773271235186, global step: 1599


 36%|███▋      | 8000/21967 [11:33<20:33, 11.33it/s]

Training loss: 0.8086403157406215, global step: 1599
Training loss: 0.808611963734013, global step: 1599


 39%|███▊      | 8496/21967 [12:16<19:40, 11.41it/s]

Training loss: 0.8063653475991404, global step: 1699
Training loss: 0.806317824939956, global step: 1699
Training loss: 0.8063178153064532, global step: 1699


 39%|███▊      | 8500/21967 [12:17<19:50, 11.31it/s]

Training loss: 0.8063943230923573, global step: 1699
Training loss: 0.8063795405251226, global step: 1699


 41%|████      | 8996/21967 [13:00<18:55, 11.42it/s]

Training loss: 0.8058297185948934, global step: 1799
Training loss: 0.8058202012037058, global step: 1799
Training loss: 0.8058397370611391, global step: 1799


 41%|████      | 9000/21967 [13:00<19:04, 11.33it/s]

Training loss: 0.8058228170901627, global step: 1799
Training loss: 0.8057923927343906, global step: 1799


 43%|████▎     | 9496/21967 [13:43<18:12, 11.42it/s]

Training loss: 0.8062946694584788, global step: 1899
Training loss: 0.8062527681851529, global step: 1899
Training loss: 0.8061905157194051, global step: 1899


 43%|████▎     | 9500/21967 [13:43<18:21, 11.32it/s]

Training loss: 0.8062121550770032, global step: 1899
Training loss: 0.8062078586891992, global step: 1899


 46%|████▌     | 9996/21967 [14:26<17:28, 11.42it/s]

Training loss: 0.803909046318067, global step: 1999
Training loss: 0.8038812930412105, global step: 1999
Training loss: 0.8038994903381174, global step: 1999


 46%|████▌     | 10000/21967 [14:26<17:36, 11.33it/s]

Training loss: 0.8039427578529732, global step: 1999
Training loss: 0.8038791639499157, global step: 1999


 48%|████▊     | 10496/21967 [15:10<16:44, 11.42it/s]

Training loss: 0.802614205367881, global step: 2099
Training loss: 0.8026126305304365, global step: 2099
Training loss: 0.802666225642281, global step: 2099


 48%|████▊     | 10500/21967 [15:10<16:52, 11.32it/s]

Training loss: 0.8027156851277292, global step: 2099
Training loss: 0.8026861439727183, global step: 2099


 50%|█████     | 10996/21967 [15:53<16:01, 11.41it/s]

Training loss: 0.8032708255678294, global step: 2199
Training loss: 0.8032420659464878, global step: 2199
Training loss: 0.8032305275309467, global step: 2199


 50%|█████     | 11000/21967 [15:53<16:08, 11.32it/s]

Training loss: 0.8032217888111748, global step: 2199
Training loss: 0.8032831373567857, global step: 2199


 52%|█████▏    | 11496/21967 [16:36<15:16, 11.42it/s]

Training loss: 0.8038168242981429, global step: 2299
Training loss: 0.8037797831820043, global step: 2299
Training loss: 0.8037668504306931, global step: 2299


 52%|█████▏    | 11500/21967 [16:37<15:24, 11.33it/s]

Training loss: 0.8037414228183033, global step: 2299
Training loss: 0.8037312654774311, global step: 2299


 55%|█████▍    | 11996/21967 [17:20<14:33, 11.42it/s]

Training loss: 0.8030990498203345, global step: 2399
Training loss: 0.8030685539419162, global step: 2399
Training loss: 0.8030858543398948, global step: 2399


 55%|█████▍    | 12000/21967 [17:20<14:40, 11.32it/s]

Training loss: 0.803068178741733, global step: 2399
Training loss: 0.8031071373971316, global step: 2399


 57%|█████▋    | 12496/21967 [18:03<13:49, 11.42it/s]

Training loss: 0.8031744187129062, global step: 2499
Training loss: 0.8031842701228589, global step: 2499
Training loss: 0.8031971916820768, global step: 2499


 57%|█████▋    | 12500/21967 [18:03<13:56, 11.32it/s]

Training loss: 0.8032079619725256, global step: 2499
Training loss: 0.8032111608021599, global step: 2499


 59%|█████▉    | 12996/21967 [18:46<13:06, 11.41it/s]

Training loss: 0.8029099540193696, global step: 2599
Training loss: 0.8028690475010519, global step: 2599
Training loss: 0.8028384612726587, global step: 2599


 59%|█████▉    | 13000/21967 [18:46<13:12, 11.32it/s]

Training loss: 0.8028147194615076, global step: 2599
Training loss: 0.8028219988400104, global step: 2599


 61%|██████▏   | 13496/21967 [19:29<12:23, 11.39it/s]

Training loss: 0.8034319153197024, global step: 2699
Training loss: 0.8033900874144898, global step: 2699
Training loss: 0.8034200407025043, global step: 2699


 61%|██████▏   | 13500/21967 [19:30<12:29, 11.30it/s]

Training loss: 0.8034191712043525, global step: 2699
Training loss: 0.8034454691278649, global step: 2699


 64%|██████▎   | 13996/21967 [20:13<11:37, 11.43it/s]

Training loss: 0.8038489294938793, global step: 2799
Training loss: 0.8038447494557376, global step: 2799
Training loss: 0.8038241084869463, global step: 2799


 64%|██████▎   | 14000/21967 [20:13<11:43, 11.33it/s]

Training loss: 0.8037983145718351, global step: 2799
Training loss: 0.8037665904735698, global step: 2799


 66%|██████▌   | 14496/21967 [20:56<10:54, 11.42it/s]

Training loss: 0.8035599146763048, global step: 2899
Training loss: 0.8035465356816485, global step: 2899
Training loss: 0.8035411103646884, global step: 2899


 66%|██████▌   | 14500/21967 [20:56<10:59, 11.33it/s]

Training loss: 0.803527682123261, global step: 2899
Training loss: 0.8035149554669275, global step: 2899


 68%|██████▊   | 14996/21967 [21:39<10:12, 11.38it/s]

Training loss: 0.8029301096741439, global step: 2999
Training loss: 0.8029010037698545, global step: 2999
Training loss: 0.802899376982487, global step: 2999


 68%|██████▊   | 15000/21967 [21:40<10:16, 11.30it/s]

Training loss: 0.8028880619230465, global step: 2999
Training loss: 0.8028417093320043, global step: 2999


 71%|███████   | 15496/21967 [22:23<09:42, 11.12it/s]

Training loss: 0.8039612426562881, global step: 3099
Training loss: 0.8039628273550715, global step: 3099
Training loss: 0.8039335477684936, global step: 3099


 71%|███████   | 15500/21967 [22:23<09:48, 10.99it/s]

Training loss: 0.8039076710270077, global step: 3099
Training loss: 0.8038933560224789, global step: 3099


 73%|███████▎  | 15996/21967 [23:07<08:42, 11.43it/s]

Training loss: 0.8043352818651641, global step: 3199
Training loss: 0.8042977641181283, global step: 3199
Training loss: 0.8042854438364998, global step: 3199


 73%|███████▎  | 16000/21967 [23:07<08:46, 11.33it/s]

Training loss: 0.8042705700051537, global step: 3199
Training loss: 0.8042854446813225, global step: 3199


 75%|███████▌  | 16496/21967 [23:50<07:58, 11.43it/s]

Training loss: 0.8045524091650931, global step: 3299
Training loss: 0.8046227921928242, global step: 3299
Training loss: 0.8046272692203453, global step: 3299


 75%|███████▌  | 16500/21967 [23:50<08:02, 11.34it/s]

Training loss: 0.8046325251971744, global step: 3299
Training loss: 0.8046186503179381, global step: 3299


 77%|███████▋  | 16996/21967 [24:33<07:22, 11.23it/s]

Training loss: 0.8045325372831703, global step: 3399
Training loss: 0.8045461648697273, global step: 3399
Training loss: 0.8045729626761998, global step: 3399


 77%|███████▋  | 17000/21967 [24:33<07:29, 11.05it/s]

Training loss: 0.8045379096726075, global step: 3399
Training loss: 0.8045069609652377, global step: 3399


 80%|███████▉  | 17496/21967 [25:17<06:31, 11.41it/s]

Training loss: 0.8037238264020408, global step: 3499
Training loss: 0.8037352676381568, global step: 3499
Training loss: 0.8037189658335848, global step: 3499


 80%|███████▉  | 17500/21967 [25:17<06:34, 11.32it/s]

Training loss: 0.8037031526624192, global step: 3499
Training loss: 0.8037129579674698, global step: 3499


 82%|████████▏ | 17996/21967 [26:00<05:47, 11.41it/s]

Training loss: 0.804868738170352, global step: 3599
Training loss: 0.8048969639748191, global step: 3599
Training loss: 0.8049252259020762, global step: 3599


 82%|████████▏ | 18000/21967 [26:00<05:50, 11.33it/s]

Training loss: 0.8048971147931667, global step: 3599
Training loss: 0.8049531239563781, global step: 3599


 84%|████████▍ | 18496/21967 [26:43<05:03, 11.42it/s]

Training loss: 0.8053380473486417, global step: 3699
Training loss: 0.8053152324644096, global step: 3699
Training loss: 0.8053135581961258, global step: 3699


 84%|████████▍ | 18500/21967 [26:44<05:06, 11.33it/s]

Training loss: 0.8053094675171956, global step: 3699
Training loss: 0.8052951752587921, global step: 3699


 86%|████████▋ | 18996/21967 [27:26<04:20, 11.42it/s]

Training loss: 0.8050855069369015, global step: 3799
Training loss: 0.8050581383847162, global step: 3799
Training loss: 0.805036864059458, global step: 3799


 86%|████████▋ | 19000/21967 [27:27<04:21, 11.33it/s]

Training loss: 0.8050160392941791, global step: 3799
Training loss: 0.8049977622431922, global step: 3799


 89%|████████▉ | 19496/21967 [28:10<03:36, 11.41it/s]

Training loss: 0.8056521426331205, global step: 3899
Training loss: 0.8056239997861073, global step: 3899
Training loss: 0.8056118009502004, global step: 3899


 89%|████████▉ | 19500/21967 [28:10<03:38, 11.30it/s]

Training loss: 0.8055869291645128, global step: 3899
Training loss: 0.80559934461144, global step: 3899


 91%|█████████ | 19996/21967 [28:53<02:52, 11.42it/s]

Training loss: 0.8056261837284963, global step: 3999
Training loss: 0.8056272506072244, global step: 3999
Training loss: 0.8056035770223774, global step: 3999


 91%|█████████ | 20000/21967 [28:53<02:53, 11.33it/s]

Training loss: 0.8055707634494553, global step: 3999
Training loss: 0.8055518028593682, global step: 3999


 93%|█████████▎| 20496/21967 [29:36<02:08, 11.40it/s]

Training loss: 0.8050719444525696, global step: 4099
Training loss: 0.8050418909285421, global step: 4099
Training loss: 0.8050501748176865, global step: 4099


 93%|█████████▎| 20500/21967 [29:37<02:09, 11.31it/s]

Training loss: 0.8050242497743191, global step: 4099
Training loss: 0.8050156750687782, global step: 4099


 96%|█████████▌| 20996/21967 [30:19<01:25, 11.42it/s]

Training loss: 0.8051061276640326, global step: 4199
Training loss: 0.8051078341566894, global step: 4199
Training loss: 0.8051113632386799, global step: 4199


 96%|█████████▌| 21000/21967 [30:20<01:25, 11.32it/s]

Training loss: 0.8051279668841658, global step: 4199
Training loss: 0.8051415056970062, global step: 4199


 98%|█████████▊| 21496/21967 [31:03<00:41, 11.42it/s]

Training loss: 0.8048511686545081, global step: 4299
Training loss: 0.8048574665276801, global step: 4299
Training loss: 0.8048529376433707, global step: 4299


 98%|█████████▊| 21500/21967 [31:03<00:41, 11.32it/s]

Training loss: 0.8048411937190197, global step: 4299
Training loss: 0.8048165082358014, global step: 4299


100%|██████████| 21967/21967 [31:43<00:00, 11.54it/s]


begin evaluation!!!


  0%|          | 2/21967 [00:00<32:40, 11.20it/s]

eval_loss: 0.9019797664436623
eval_accuracy: 0.6608635097493036
Training Epoch: 4/5


  2%|▏         | 496/21967 [00:42<31:22, 11.41it/s]

Training loss: 0.5759044480829905, global step: 99
Training loss: 0.5751140156184368, global step: 99
Training loss: 0.5749174207833107, global step: 99


  2%|▏         | 500/21967 [00:43<31:36, 11.32it/s]

Training loss: 0.573934441971888, global step: 99
Training loss: 0.5738435174032434, global step: 99


  5%|▍         | 996/21967 [01:26<30:37, 11.42it/s]

Training loss: 0.5608805350879741, global step: 199
Training loss: 0.5604209195271725, global step: 199
Training loss: 0.5603847128550224, global step: 199


  5%|▍         | 1000/21967 [01:26<30:51, 11.32it/s]

Training loss: 0.5600586681396239, global step: 199
Training loss: 0.5607704120805604, global step: 199


  7%|▋         | 1496/21967 [02:09<29:56, 11.40it/s]

Training loss: 0.5560381052383288, global step: 299
Training loss: 0.5563271459547509, global step: 299
Training loss: 0.5565656188551309, global step: 299


  7%|▋         | 1500/21967 [02:09<30:11, 11.30it/s]

Training loss: 0.5563572868403266, global step: 299
Training loss: 0.5562040916354513, global step: 299


  9%|▉         | 1996/21967 [02:52<29:08, 11.42it/s]

Training loss: 0.5562175229268806, global step: 399
Training loss: 0.5563440674839855, global step: 399
Training loss: 0.5562488858328744, global step: 399


  9%|▉         | 2000/21967 [02:52<29:21, 11.34it/s]

Training loss: 0.556038860627147, global step: 399
Training loss: 0.5562796032784405, global step: 399


 11%|█▏        | 2496/21967 [03:36<29:12, 11.11it/s]

Training loss: 0.5490636726727481, global step: 499
Training loss: 0.5488531312493694, global step: 499
Training loss: 0.5487654575185524, global step: 499


 11%|█▏        | 2500/21967 [03:36<29:32, 10.98it/s]

Training loss: 0.5485922935703251, global step: 499
Training loss: 0.5486571956010369, global step: 499


 14%|█▎        | 2996/21967 [04:20<28:28, 11.10it/s]

Training loss: 0.5510401433643158, global step: 599
Training loss: 0.5509554213797301, global step: 599
Training loss: 0.550799783666189, global step: 599


 14%|█▎        | 3000/21967 [04:20<28:45, 10.99it/s]

Training loss: 0.5506971100790412, global step: 599
Training loss: 0.5505660522970451, global step: 599


 16%|█▌        | 3496/21967 [05:03<26:57, 11.42it/s]

Training loss: 0.546422093036386, global step: 699
Training loss: 0.5464826124405661, global step: 699
Training loss: 0.5466774678677224, global step: 699


 16%|█▌        | 3500/21967 [05:04<27:10, 11.32it/s]

Training loss: 0.546545639550931, global step: 699
Training loss: 0.5464032603824093, global step: 699


 18%|█▊        | 3996/21967 [05:46<26:15, 11.41it/s]

Training loss: 0.5469509089207661, global step: 799
Training loss: 0.546884933961856, global step: 799
Training loss: 0.5473167192741827, global step: 799


 18%|█▊        | 4000/21967 [05:47<26:27, 11.32it/s]

Training loss: 0.5473980394979052, global step: 799
Training loss: 0.5476650433088172, global step: 799


 20%|██        | 4496/21967 [06:30<25:32, 11.40it/s]

Training loss: 0.5538570875493091, global step: 899
Training loss: 0.5538217114304624, global step: 899
Training loss: 0.5538184165875167, global step: 899


 20%|██        | 4500/21967 [06:30<25:44, 11.31it/s]

Training loss: 0.5537423445866868, global step: 899
Training loss: 0.5536366582079206, global step: 899


 23%|██▎       | 4996/21967 [07:14<24:45, 11.42it/s]

Training loss: 0.5553472417719509, global step: 999
Training loss: 0.5552648736214398, global step: 999
Training loss: 0.5552247121654006, global step: 999


 23%|██▎       | 5000/21967 [07:14<24:57, 11.33it/s]

Training loss: 0.5553544298692119, global step: 999
Training loss: 0.5552465583072111, global step: 999


 25%|██▌       | 5496/21967 [07:57<24:02, 11.42it/s]

Training loss: 0.5558018298909064, global step: 1099
Training loss: 0.5557256224368217, global step: 1099
Training loss: 0.5557570797254426, global step: 1099


 25%|██▌       | 5500/21967 [07:57<24:14, 11.32it/s]

Training loss: 0.555684560868493, global step: 1099
Training loss: 0.5555851185945586, global step: 1099


 27%|██▋       | 5996/21967 [08:40<23:19, 11.41it/s]

Training loss: 0.5527323851635578, global step: 1199
Training loss: 0.5528499825568658, global step: 1199
Training loss: 0.5527942728716473, global step: 1199


 27%|██▋       | 6000/21967 [08:40<23:30, 11.32it/s]

Training loss: 0.5528516949517444, global step: 1199
Training loss: 0.5529140511839683, global step: 1199


 30%|██▉       | 6496/21967 [09:23<22:35, 11.41it/s]

Training loss: 0.5550814693455521, global step: 1299
Training loss: 0.5550701054832146, global step: 1299
Training loss: 0.5549938932139732, global step: 1299


 30%|██▉       | 6500/21967 [09:24<22:46, 11.32it/s]

Training loss: 0.5551599256997574, global step: 1299
Training loss: 0.5552097868332503, global step: 1299


 32%|███▏      | 6996/21967 [10:06<21:51, 11.42it/s]

Training loss: 0.5547302712428649, global step: 1399
Training loss: 0.5547665513566201, global step: 1399
Training loss: 0.554794755314808, global step: 1399


 32%|███▏      | 7000/21967 [10:07<22:02, 11.32it/s]

Training loss: 0.5547722199940708, global step: 1399
Training loss: 0.5547118059314565, global step: 1399


 34%|███▍      | 7496/21967 [10:50<21:07, 11.42it/s]

Training loss: 0.5552179007766674, global step: 1499
Training loss: 0.5552149756751596, global step: 1499
Training loss: 0.5552382186516979, global step: 1499


 34%|███▍      | 7500/21967 [10:50<21:17, 11.33it/s]

Training loss: 0.5552238701409405, global step: 1499
Training loss: 0.5551525421771125, global step: 1499


 36%|███▋      | 7996/21967 [11:33<20:23, 11.41it/s]

Training loss: 0.5564662321867248, global step: 1599
Training loss: 0.5564324406231576, global step: 1599
Training loss: 0.5564339866303091, global step: 1599


 36%|███▋      | 8000/21967 [11:33<20:33, 11.32it/s]

Training loss: 0.5563650021597737, global step: 1599
Training loss: 0.5563263574163068, global step: 1599


 39%|███▊      | 8496/21967 [12:17<20:12, 11.11it/s]

Training loss: 0.5559451256842975, global step: 1699
Training loss: 0.5559304120787877, global step: 1699
Training loss: 0.5558911244554436, global step: 1699


 39%|███▊      | 8500/21967 [12:17<20:25, 10.99it/s]

Training loss: 0.5558691026471227, global step: 1699
Training loss: 0.5558355080028606, global step: 1699


 41%|████      | 8996/21967 [13:00<18:56, 11.41it/s]

Training loss: 0.5544170998902891, global step: 1799
Training loss: 0.5543866656537229, global step: 1799
Training loss: 0.5543730548079617, global step: 1799


 41%|████      | 9000/21967 [13:00<19:05, 11.32it/s]

Training loss: 0.5543934691902335, global step: 1799
Training loss: 0.5543611423903686, global step: 1799


 43%|████▎     | 9496/21967 [13:43<18:12, 11.42it/s]

Training loss: 0.5546478416629883, global step: 1899
Training loss: 0.5546369983682881, global step: 1899
Training loss: 0.5546590748145496, global step: 1899


 43%|████▎     | 9500/21967 [13:44<18:19, 11.34it/s]

Training loss: 0.5546195809397422, global step: 1899
Training loss: 0.5546324616138696, global step: 1899


 46%|████▌     | 9996/21967 [14:27<17:56, 11.12it/s]

Training loss: 0.555222361371439, global step: 1999
Training loss: 0.555325683784198, global step: 1999
Training loss: 0.5553526281485494, global step: 1999


 46%|████▌     | 10000/21967 [14:27<17:51, 11.17it/s]

Training loss: 0.555298753767452, global step: 1999
Training loss: 0.5552942977307868, global step: 1999


 48%|████▊     | 10496/21967 [15:10<16:45, 11.40it/s]

Training loss: 0.5570324169407468, global step: 2099
Training loss: 0.5570392725869291, global step: 2099
Training loss: 0.5570408016150146, global step: 2099


 48%|████▊     | 10500/21967 [15:10<16:53, 11.32it/s]

Training loss: 0.557051077330583, global step: 2099
Training loss: 0.5570509256237477, global step: 2099


 50%|█████     | 10996/21967 [15:53<16:01, 11.40it/s]

Training loss: 0.5569176853035082, global step: 2199
Training loss: 0.5568916299900345, global step: 2199
Training loss: 0.5569028622858025, global step: 2199


 50%|█████     | 11000/21967 [15:54<16:08, 11.32it/s]

Training loss: 0.5569516253149669, global step: 2199
Training loss: 0.5569484732338533, global step: 2199


 52%|█████▏    | 11496/21967 [16:37<15:15, 11.44it/s]

Training loss: 0.5588993639806288, global step: 2299
Training loss: 0.5589177984365606, global step: 2299
Training loss: 0.5589066941973597, global step: 2299


 52%|█████▏    | 11500/21967 [16:37<15:23, 11.34it/s]

Training loss: 0.5589333305565097, global step: 2299
Training loss: 0.5589722200954501, global step: 2299


 55%|█████▍    | 11996/21967 [17:20<14:33, 11.42it/s]

Training loss: 0.5592238387624945, global step: 2399
Training loss: 0.5592003220513239, global step: 2399
Training loss: 0.5592854334071277, global step: 2399


 55%|█████▍    | 12000/21967 [17:21<14:39, 11.33it/s]

Training loss: 0.5593065910364701, global step: 2399
Training loss: 0.5595132608630585, global step: 2399


 57%|█████▋    | 12496/21967 [18:03<13:49, 11.42it/s]

Training loss: 0.5602668838199217, global step: 2499
Training loss: 0.5603038450220511, global step: 2499
Training loss: 0.5602980814497222, global step: 2499


 57%|█████▋    | 12500/21967 [18:04<13:56, 11.32it/s]

Training loss: 0.560279730656898, global step: 2499
Training loss: 0.5603142740971498, global step: 2499


 59%|█████▉    | 12996/21967 [18:47<13:05, 11.42it/s]

Training loss: 0.5607151099142551, global step: 2599
Training loss: 0.5607147132309728, global step: 2599
Training loss: 0.5607049779706422, global step: 2599


 59%|█████▉    | 13000/21967 [18:47<13:11, 11.33it/s]

Training loss: 0.5606727951801751, global step: 2599
Training loss: 0.5606598330724335, global step: 2599


 60%|██████    | 13182/21967 [19:03<12:36, 11.62it/s]

In [None]:
test_path='data/test/middle'
is_training_test=False
test_input_cloze,test_input_direct=to_input(test_path)
test_sequences_cloze=input_to_sequence_cloze(test_input_cloze,max_seq_length,tokenizer,is_training_test)
test_sequences_direct=input_to_sequence_question(test_input_direct,max_seq_length,tokenizer,is_training_test)
test_sequences=test_sequences_cloze+test_sequences_direct
test_batch_size=1

test_input_sequence=to_all_tensor(test_sequences,'input_sequence')
test_segment_id=to_all_tensor(test_sequences,'segment_id')
test_mask_id=to_all_tensor(test_sequences,'mask_id')
test_label=torch.tensor([case.label for case in test_sequences],dtype=torch.long)

test_dataset=TensorDataset(test_input_sequence,test_segment_id,test_mask_id,test_label)
test_sampler=RandomSampler(test_dataset)
test_data=DataLoader(test_dataset,sampler=test_sampler, batch_size=batch_size)





In [None]:
test_path2='data/test/high'
is_training_test=False
test_input2_cloze,test_input2_direct=to_input(test_path2)
test_sequences2_cloze=input_to_sequence_cloze(test_input2_cloze,max_seq_length,tokenizer,is_training_test)
test_sequences2_direct=input_to_sequence_question(test_input2_direct,max_seq_length,tokenizer,is_training_test)
test_sequences2=test_sequences2_cloze+test_sequences2_direct
test_batch_size=1

test_input_sequence2=to_all_tensor(test_sequences2,'input_sequence')
test_segment_id2=to_all_tensor(test_sequences2,'segment_id')
test_mask_id2=to_all_tensor(test_sequences2,'mask_id')
test_label2=torch.tensor([case.label for case in test_sequences2],dtype=torch.long)

test_dataset2=TensorDataset(test_input_sequence2,test_segment_id2,test_mask_id2,test_label2)
test_sampler2=RandomSampler(test_dataset2)
test_data2=DataLoader(test_dataset2,sampler=test_sampler2, batch_size=batch_size)

In [None]:
def test_race(test_dataloader,device,model,global_step):
    print("begin testing!!!")
    test_loss, test_accuracy = 0, 0
    nb_test_steps, nb_test_examples = 0, 0
    for step, batch in enumerate(test_dataloader):
        batch = tuple(t.to(device) for t in batch)
        input_ids, segment_ids, input_mask, label_ids = batch

        with torch.no_grad():
            tmp_test_loss = model(input_ids, segment_ids, input_mask, label_ids)
            logits = model(input_ids, segment_ids, input_mask)

        logits = logits.detach().cpu().numpy()
        label_ids = label_ids.to('cpu').numpy()
        tmp_test_accuracy = accuracy(logits, label_ids)

        test_loss += tmp_test_loss.mean().item()
        test_accuracy += tmp_test_accuracy

        nb_test_examples += input_ids.size(0)
        nb_test_steps += 1

    test_loss = test_loss / nb_test_steps
    test_accuracy = test_accuracy / nb_test_examples
    print("test_loss:",test_loss)
    print("test_accuracy:",test_accuracy)

    result = {'dev_test_loss': test_loss,
              'dev_test_accuracy': test_accuracy,
              'global_step': global_step}

    output_test_file = os.path.join('output', "test_results.txt")
    with open(output_test_file, "a+") as writer:
        for key in sorted(result.keys()):
            writer.write("%s = %s\n" % (key, str(result[key])))

In [None]:
model.eval()
test_race(test_data, device, model, global_step)

In [None]:
model.eval()
test_race(test_data2, device, model, global_step)

In [None]:
test_cloze=test_sequences_cloze+test_sequences2_cloze
test_direct=test_sequences_direct+test_sequences2_direct

test_input_sequence_cloze=to_all_tensor(test_cloze,'input_sequence')
test_segment_id_cloze=to_all_tensor(test_cloze,'segment_id')
test_mask_id_cloze=to_all_tensor(test_cloze,'mask_id')
test_label_cloze=torch.tensor([case.label for case in test_cloze],dtype=torch.long)

test_dataset_cloze=TensorDataset(test_input_sequence_cloze,test_segment_id_cloze,test_mask_id_cloze,test_label_cloze)
test_sampler_cloze=RandomSampler(test_dataset_cloze)
test_data_cloze=DataLoader(test_dataset_cloze,sampler=test_sampler_cloze, batch_size=batch_size)

test_input_sequence_direct=to_all_tensor(test_direct,'input_sequence')
test_segment_id_direct=to_all_tensor(test_direct,'segment_id')
test_mask_id_direct=to_all_tensor(test_direct,'mask_id')
test_label_direct=torch.tensor([case.label for case in test_direct],dtype=torch.long)

test_dataset_direct=TensorDataset(test_input_sequence_direct,test_segment_id_direct,test_mask_id_direct,test_label_direct)
test_sampler_direct=RandomSampler(test_dataset_direct)
test_data_direct=DataLoader(test_dataset_direct,sampler=test_sampler_direct, batch_size=batch_size)

In [34]:
model.eval()
test_race(test_data_cloze,device,model,global_step)

begin testing!!!
test_loss: 1.447483023742803
test_accuracy: 0.5973484848484848


In [35]:
model.eval()
test_race(test_data_direct,device,model,global_step)

begin testing!!!
test_loss: 1.4676962107669835
test_accuracy: 0.5688753269398431
