# RoBERTa를 이용한 한국어 자연어추론(NLI)
- 사전학습 모델 : KLUE-RoBERTa (MODU, CC-100-Kor, NAMUWIKI, NEWSCRAWL, PETITION)
- 데이터 : KLUE-NLI (WIKITREE, POLICY, WIKINEWS, WIKIPEDIA, NSMC and AIRBNB)

# 사전 준비

In [None]:
!pip install transformers
!pip install datasets

**KLUE-NLI 데이터 불러오기**

In [2]:
from datasets import load_dataset

datasets = load_dataset("klue", "nli")

Reusing dataset klue (/root/.cache/huggingface/datasets/klue/nli/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e)


  0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
datasets

DatasetDict({
    train: Dataset({
        features: ['guid', 'source', 'premise', 'hypothesis', 'label'],
        num_rows: 24998
    })
    validation: Dataset({
        features: ['guid', 'source', 'premise', 'hypothesis', 'label'],
        num_rows: 3000
    })
})

In [4]:
# label 0: entailment(함의) / 1: neutral(중립) / 2: contradiction(모순)
print(datasets["train"][0])
print(datasets["validation"][0])

{'guid': 'klue-nli-v1_train_00000', 'source': 'NSMC', 'premise': '힛걸 진심 최고다 그 어떤 히어로보다 멋지다', 'hypothesis': '힛걸 진심 최고로 멋지다.', 'label': 0}
{'guid': 'klue-nli-v1_dev_00000', 'source': 'airbnb', 'premise': '흡연자분들은 발코니가 있는 방이면 발코니에서 흡연이 가능합니다.', 'hypothesis': '어떤 방에서도 흡연은 금지됩니다.', 'label': 2}


**KLUE-RoBERTa 모델과 토크나이저 불러오기**

In [5]:
from transformers import AutoModel, AutoTokenizer

roberta_model = AutoModel.from_pretrained("klue/roberta-base")
tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")

Some weights of the model checkpoint at klue/roberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.decoder.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at klue/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for

In [6]:
roberta_model.config

RobertaConfig {
  "_name_or_path": "klue/roberta-base",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "tokenizer_class": "BertTokenizer",
  "transformers_version": "4.20.1",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 32000
}

# 토크나이징, 데이터 구축

**스페셜 토큰 확인**

In [7]:
for i in range (10):
    print("index : ",i," =  tokens : ",tokenizer.decode(i))

index :  0  =  tokens :  [CLS]
index :  1  =  tokens :  [PAD]
index :  2  =  tokens :  [SEP]
index :  3  =  tokens :  [UNK]
index :  4  =  tokens :  [MASK]
index :  5  =  tokens :  !
index :  6  =  tokens :  "
index :  7  =  tokens :  #
index :  8  =  tokens :  $
index :  9  =  tokens :  %


**[CLS] 전제 [SEP] 가설 [SEP] [PAD]...**

In [8]:
import torch
from torch.utils.data import Dataset

In [69]:
class NLIDataset(Dataset):
    def __init__(self, data, max_len=128):  # 데이터셋의 전처리를 해주는 부분
        self._data = data
        self.max_len = max_len
        self.bos = tokenizer.bos_token      # [CLS]
        self.eos = tokenizer.eos_token      # [SEP]
        self.pad = tokenizer.pad_token      # [PAD]
        self.sep = tokenizer.sep_token      # [SEP]
        self.tokenizer = tokenizer
    
    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):  # 로드한 데이터를 차례차례 DataLoader로 넘겨주는 메서드
        index = self._data[idx]

        p = index["premise"]  # 전제
        p_toked = self.tokenizer.tokenize(self.bos + p + self.sep)      # [CLS] 전제 [SEP]
        p_len = len(p_toked)

        h = index["hypothesis"]  # 가설
        h_toked = self.tokenizer.tokenize(h + self.eos)      # 가설 [SEP]
        h_len = len(p_toked)

            # 전제 + 가설 길이가 최대길이보다 클때
        while (p_len + h_len) > self.max_len:    
            h_len = self.max_len - p_len        # 가설의 길이 = 최대길이 - 전제길이

            if h_len <= 0:       # 전제의 길이가 너무 길어 전제만으로 최대 길이를 초과 한다면
                p_toked = p_toked[-(int(self.max_len / 2)) :]   # 전제길이를 최대길이의 반으로 
                p_len = len(p_toked)
                h_len = self.max_len - p_len              # 가설의 길이를 최대길이 - 전제길이
                        
            h_toked = h_toked[:h_len]
            h_len = len(h_toked)

        # 전제 + 가설 토큰을 index로 변환   
        token_ids = self.tokenizer.convert_tokens_to_ids(p_toked + h_toked)

        # 최대 길이만큼 padding
        while len(token_ids) < self.max_len:
            token_ids += [self.tokenizer.pad_token_id]

        # attention_mask(어텐션마스크) = 전제 + 가설 길이 1 + 나머지(패딩) 0
        attention_mask = [1]*(p_len + h_len) + [0]*(self.max_len - p_len - h_len)

        # label = 0: entailment(함의) / 1: neutral(중립) / 2: contradiction(모순)
        label = index["label"]

        # 전제+가설 + 답변, 어텐션마스크, label
        return (token_ids, attention_mask, label)

**데이터셋 구축** <br>
구성 : (token_ids, attention_mask, token_type_ids, label)

In [70]:
# 훈련 데이터셋
train_dataset = NLIDataset(datasets["train"])

for n in range(5):
    print("train_dataset[",n,"]")
    print("token_ids      : ", train_dataset[n][0])
    print("attention_mask : ", train_dataset[n][1])
    print("label          : ", train_dataset[n][2], "\n")

train_dataset[ 0 ]
token_ids      :  [0, 3, 7254, 3841, 2062, 636, 3711, 12717, 2178, 2062, 11980, 2062, 2, 3, 7254, 3841, 2200, 11980, 2062, 18, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
attention_mask :  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
label          :  0 

train_dataset[ 1 ]
token_ids      :  [0, 3911, 2377, 2366, 1521, 3061, 4785, 1282, 2955, 3308, 3515, 2170

In [71]:
# 검증 데이터셋
val_dataset = NLIDataset(datasets["validation"])

for n in range(5):
    print("val_dataset[",n,"]")
    print("token_ids      : ", val_dataset[n][0])
    print("attention_mask : ", val_dataset[n][1])
    print("label          : ", val_dataset[n][2],"\n")

val_dataset[ 0 ]
token_ids      :  [0, 25313, 2377, 2031, 2073, 20812, 2116, 1513, 2259, 1129, 24094, 20812, 27135, 9753, 2052, 3662, 11800, 18, 2, 3711, 1129, 27135, 2119, 9753, 2073, 5040, 3598, 3606, 18, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
attention_mask :  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
label          :  2 

val_dataset[ 1 ]
token_ids      :  [0, 3633, 2211, 2052, 3655, 3704, 31

**데이터로더 구축**

In [72]:
# collate_fn 구성
def collate_batch(batch):
    token_ids = [item[:][0] for item in batch]
    attention_mask = [item[:][1] for item in batch]
    label_ids = [item[:][2] for item in batch]

    return torch.cuda.LongTensor(token_ids), torch.cuda.LongTensor(attention_mask), torch.cuda.LongTensor(label_ids)

In [73]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn = collate_batch, batch_size=8)
val_dataloader = DataLoader(val_dataset, collate_fn = collate_batch, batch_size=16)

In [74]:
# 데이터로더 확인
sample_data = iter(train_dataloader)
sample_ids = next(sample_data)

token_ids, attention_mask, label_ids = sample_ids

print("first item of batch (train_dataloader)")
print("token_ids \n", token_ids[:][0],"batch size : ", token_ids.size(),"\n")
print("attention_mask \n", attention_mask[:][0], "batch size : ", attention_mask.size(),"\n")
print("label_ids \n", label_ids[:][0], "batch size : ", label_ids.size())

first item of batch (train_dataloader)
token_ids 
 tensor([    0,  3717,  2052,  8451, 31129,  2259,  6636,  2020,  2170,  2318,
        17940,  8481,  2170,  1176,  2069,  9265,  2259,  1590,  2069,  3655,
        20651,  2088,  4735,  4538,    18,     2, 31129,  2522,  6636,  2234,
         2073,  8481,  2170,  1176,  2069,  1583,  4280,    18,     2,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1

# 모델 학습

**모델 정의**

In [75]:
# RoBERTa를 포함한 신경망 모형
class NLIModel(torch.nn.Module):
    def __init__(self, pretrained_model, token_size, num_labels): 
        super(NLIModel, self).__init__()
        
        self.token_size = token_size
        self.num_labels = num_labels
        self.pretrained_model = pretrained_model

        # 분류기 정의
        self.classifier = torch.nn.Linear(self.token_size, self.num_labels)

    def forward(self, input_ids, attention_mask):
        # BERT 모형에 입력을 넣고 출력을 받음
        outputs = self.pretrained_model(input_ids, attention_mask)
        # BERT 출력에서 CLS 토큰에 해당하는 부분만 가져옴
        bert_clf_token = outputs.last_hidden_state[:,0,:]
        # 3개의 라벨로 분류
        outputs = self.classifier(bert_clf_token)

        return outputs

# token_size는 BERT 토큰과 동일
model = NLIModel(roberta_model, token_size=roberta_model.config.hidden_size, num_labels=3)

**파라미터 설정**

In [76]:
from transformers import get_linear_schedule_with_warmup
import torch.nn.functional as F
import time

# GPU 가속을 사용할 수 있으면 device를 cuda로 설정하고, 아니면 cpu로 설정
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 옵티마이저 AdamW로 설정
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01) # 가중치 감쇠 설정
criterion = torch.nn.CrossEntropyLoss()    # 멀티클래스이므로 크로스 엔트로피를 손실함수로 사용 -> RoBERTa 코드 내 포함되어있음

num_epochs = 3      # 학습 epoch를 3회로 설정

total_training_steps = num_epochs * len(train_dataloader)
# 학습 스케줄러 설정
scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,
                                            num_training_steps=total_training_steps,
                                            num_warmup_steps=200)

step = 0
eval_steps = 500

In [77]:
model.to(device)  
model.train()     # 학습모드

NLIModel(
  (pretrained_model): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(32000, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm

**학습 진행**

In [78]:
# GPU 캐시 비우기 (GPU 메모리 확보)
torch.cuda.empty_cache()

In [79]:
from tqdm import tqdm

for epoch in range(num_epochs):
    loss = 0
    train_loss = 0.0
    
    for batch_idx, samples in enumerate(tqdm(train_dataloader, mininterval=0.01)):
        optimizer.zero_grad()       # optimizer 초기화(Gradient)

        # 모델 입력 텐서 GPU에 올리기
        token_ids, attention_mask, label_ids = samples

        token_ids = token_ids.to(device)
        attention_mask = attention_mask.to(device)
        label_ids = label_ids.to(device)

        out = model(
            input_ids=token_ids,
            attention_mask=attention_mask,
            )

        out.argmax(dim=1)

        loss = criterion(out, label_ids)
        loss.backward()
        optimizer.step()
        scheduler.step()

        train_loss += loss.item()
        
        step += 1
        if step % eval_steps == 0:  # eval_steps 마다 loss를 출력
            with torch.no_grad():   # 학습 X (그래디언트 계산 X)
                val_loss = 0
                model.eval()        # 평가모드로 전환

                for val_batch_idx, val_samples in enumerate(tqdm(val_dataloader, mininterval=0.01)):

                    token_ids, attention_mask, label_ids = val_samples

                    token_ids = token_ids.to(device)
                    attention_mask = attention_mask.to(device)
                    label_ids = label_ids.to(device)
                    
                    out = model(
                        input_ids=token_ids,
                        attention_mask=attention_mask,
                        )

                    out.argmax(dim=1)

                    loss = criterion(out, label_ids)  
                    val_loss += loss

                avg_val_loss = val_loss / len(val_dataloader)

            avg_train_loss = train_loss / eval_steps    # eval_steps의 평균 loss 계산
            
            print('Step %d, train loss: %.4f, validation loss: %.4f' % (step, avg_train_loss, avg_val_loss))

 16%|█▌        | 499/3125 [01:39<08:53,  4.92it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:04, 44.15it/s][A
  1%|          | 2/188 [00:00<00:11, 16.28it/s][A
  2%|▏         | 3/188 [00:00<00:14, 12.61it/s][A
  2%|▏         | 4/188 [00:00<00:16, 11.41it/s][A
  3%|▎         | 5/188 [00:00<00:16, 10.85it/s][A
  3%|▎         | 6/188 [00:00<00:17, 10.39it/s][A
  4%|▎         | 7/188 [00:00<00:17, 10.33it/s][A
  4%|▍         | 8/188 [00:00<00:17, 10.19it/s][A
  5%|▍         | 9/188 [00:00<00:17, 10.14it/s][A
  5%|▌         | 10/188 [00:00<00:17, 10.00it/s][A
  6%|▌         | 11/188 [00:01<00:17,  9.94it/s][A
  6%|▋         | 12/188 [00:01<00:17,  9.99it/s][A
  7%|▋         | 13/188 [00:01<00:17, 10.02it/s][A
  7%|▋         | 14/188 [00:01<00:17,  9.98it/s][A
  8%|▊         | 15/188 [00:01<00:17, 10.04it/s][A
  9%|▊         | 16/188 [00:01<00:17, 10.04it/s][A
  9%|▉         | 17/188 [00:01<00:17, 10.02it/s][A
 10%|▉         | 18/188 [00:01<

Step 500, train loss: 1.0827, validation loss: 1.1162


 32%|███▏      | 999/3125 [03:38<07:06,  4.99it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:04, 44.09it/s][A
  1%|          | 2/188 [00:00<00:12, 15.38it/s][A
  2%|▏         | 3/188 [00:00<00:15, 12.27it/s][A
  2%|▏         | 4/188 [00:00<00:16, 11.06it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.54it/s][A
  3%|▎         | 6/188 [00:00<00:17, 10.17it/s][A
  4%|▎         | 7/188 [00:00<00:17, 10.06it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.92it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.91it/s][A
  5%|▌         | 10/188 [00:00<00:18,  9.65it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.67it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.64it/s][A
  7%|▋         | 13/188 [00:01<00:17,  9.78it/s][A
  7%|▋         | 14/188 [00:01<00:17,  9.80it/s][A
  8%|▊         | 15/188 [00:01<00:17,  9.82it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.78it/s][A
  9%|▉         | 17/188 [00:01<00:17,  9.74it/s][A
 10%|▉         | 18/188 [00:01<

Step 1000, train loss: 2.1858, validation loss: 1.0993


 48%|████▊     | 1499/3125 [05:39<05:31,  4.91it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:04, 44.54it/s][A
  1%|          | 2/188 [00:00<00:11, 15.51it/s][A
  2%|▏         | 3/188 [00:00<00:15, 12.02it/s][A
  2%|▏         | 4/188 [00:00<00:16, 10.97it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.51it/s][A
  3%|▎         | 6/188 [00:00<00:18, 10.06it/s][A
  4%|▎         | 7/188 [00:00<00:18,  9.98it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.79it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.77it/s][A
  5%|▌         | 10/188 [00:00<00:18,  9.50it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.51it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.60it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.59it/s][A
  7%|▋         | 14/188 [00:01<00:18,  9.65it/s][A
  8%|▊         | 15/188 [00:01<00:17,  9.68it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.61it/s][A
  9%|▉         | 17/188 [00:01<00:17,  9.74it/s][A
 10%|▉         | 18/188 [00:01

Step 1500, train loss: 3.2914, validation loss: 1.1079


 64%|██████▍   | 1999/3125 [07:40<03:49,  4.92it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:05, 33.96it/s][A
  1%|          | 2/188 [00:00<00:12, 15.40it/s][A
  2%|▏         | 3/188 [00:00<00:15, 12.10it/s][A
  2%|▏         | 4/188 [00:00<00:16, 10.91it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.39it/s][A
  3%|▎         | 6/188 [00:00<00:18, 10.02it/s][A
  4%|▎         | 7/188 [00:00<00:18,  9.93it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.65it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.72it/s][A
  5%|▌         | 10/188 [00:00<00:18,  9.38it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.46it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.51it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.58it/s][A
  7%|▋         | 14/188 [00:01<00:18,  9.63it/s][A
  8%|▊         | 15/188 [00:01<00:18,  9.60it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.62it/s][A
  9%|▉         | 17/188 [00:01<00:17,  9.63it/s][A
 10%|▉         | 18/188 [00:01

Step 2000, train loss: 4.3759, validation loss: 1.1017


 80%|███████▉  | 2499/3125 [09:41<02:07,  4.91it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:04, 38.41it/s][A
  1%|          | 2/188 [00:00<00:12, 15.46it/s][A
  2%|▏         | 3/188 [00:00<00:15, 12.12it/s][A
  2%|▏         | 4/188 [00:00<00:16, 10.88it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.41it/s][A
  3%|▎         | 6/188 [00:00<00:18,  9.95it/s][A
  4%|▎         | 7/188 [00:00<00:18,  9.87it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.62it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.68it/s][A
  5%|▌         | 10/188 [00:00<00:18,  9.39it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.56it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.59it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.59it/s][A
  7%|▋         | 14/188 [00:01<00:18,  9.65it/s][A
  8%|▊         | 15/188 [00:01<00:17,  9.71it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.66it/s][A
  9%|▉         | 17/188 [00:01<00:17,  9.65it/s][A
 10%|▉         | 18/188 [00:01

Step 2500, train loss: 5.4768, validation loss: 1.1169


 96%|█████████▌| 2999/3125 [11:43<00:25,  4.92it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:03, 48.74it/s][A
  1%|          | 2/188 [00:00<00:11, 15.53it/s][A
  2%|▏         | 3/188 [00:00<00:15, 12.28it/s][A
  2%|▏         | 4/188 [00:00<00:16, 11.03it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.47it/s][A
  3%|▎         | 6/188 [00:00<00:18, 10.11it/s][A
  4%|▎         | 7/188 [00:00<00:18,  9.94it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.73it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.75it/s][A
  5%|▌         | 10/188 [00:00<00:18,  9.50it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.55it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.56it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.64it/s][A
  7%|▋         | 14/188 [00:01<00:18,  9.66it/s][A
  8%|▊         | 15/188 [00:01<00:18,  9.56it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.70it/s][A
  9%|▉         | 17/188 [00:01<00:17,  9.70it/s][A
 10%|▉         | 18/188 [00:01

Step 3000, train loss: 6.5799, validation loss: 1.0991


100%|██████████| 3125/3125 [12:28<00:00,  4.18it/s]
 12%|█▏        | 374/3125 [01:16<09:19,  4.92it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:04, 44.10it/s][A
  1%|          | 2/188 [00:00<00:12, 15.47it/s][A
  2%|▏         | 3/188 [00:00<00:15, 12.14it/s][A
  2%|▏         | 4/188 [00:00<00:16, 10.89it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.43it/s][A
  3%|▎         | 6/188 [00:00<00:18,  9.92it/s][A
  4%|▎         | 7/188 [00:00<00:18,  9.73it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.69it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.66it/s][A
  5%|▌         | 10/188 [00:00<00:18,  9.38it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.54it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.63it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.65it/s][A
  7%|▋         | 14/188 [00:01<00:18,  9.65it/s][A
  8%|▊         | 15/188 [00:01<00:18,  9.59it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.67it/s][A
  9%|▉         | 17/188 [00:01<

Step 3500, train loss: 0.8275, validation loss: 1.0994


 28%|██▊       | 874/3125 [03:17<07:38,  4.91it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:04, 43.92it/s][A
  1%|          | 2/188 [00:00<00:12, 15.03it/s][A
  2%|▏         | 3/188 [00:00<00:15, 11.92it/s][A
  2%|▏         | 4/188 [00:00<00:16, 10.88it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.19it/s][A
  3%|▎         | 6/188 [00:00<00:18, 10.05it/s][A
  4%|▎         | 7/188 [00:00<00:18,  9.93it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.66it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.67it/s][A
  5%|▌         | 10/188 [00:00<00:18,  9.44it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.51it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.57it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.60it/s][A
  7%|▋         | 14/188 [00:01<00:18,  9.65it/s][A
  8%|▊         | 15/188 [00:01<00:17,  9.68it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.68it/s][A
  9%|▉         | 17/188 [00:01<00:17,  9.65it/s][A
 10%|▉         | 18/188 [00:01<

Step 4000, train loss: 1.9293, validation loss: 1.1028


 44%|████▍     | 1374/3125 [05:18<05:55,  4.92it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:04, 41.77it/s][A
  1%|          | 2/188 [00:00<00:12, 15.10it/s][A
  2%|▏         | 3/188 [00:00<00:15, 12.11it/s][A
  2%|▏         | 4/188 [00:00<00:16, 10.86it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.41it/s][A
  3%|▎         | 6/188 [00:00<00:18, 10.03it/s][A
  4%|▎         | 7/188 [00:00<00:18,  9.91it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.64it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.67it/s][A
  5%|▌         | 10/188 [00:00<00:18,  9.45it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.55it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.59it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.54it/s][A
  7%|▋         | 14/188 [00:01<00:17,  9.68it/s][A
  8%|▊         | 15/188 [00:01<00:17,  9.70it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.66it/s][A
  9%|▉         | 17/188 [00:01<00:17,  9.65it/s][A
 10%|▉         | 18/188 [00:01

Step 4500, train loss: 3.0312, validation loss: 1.1012


 60%|█████▉    | 1874/3125 [07:20<04:14,  4.91it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:04, 45.59it/s][A
  1%|          | 2/188 [00:00<00:12, 15.17it/s][A
  2%|▏         | 3/188 [00:00<00:15, 12.02it/s][A
  2%|▏         | 4/188 [00:00<00:16, 10.95it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.47it/s][A
  3%|▎         | 6/188 [00:00<00:18, 10.01it/s][A
  4%|▎         | 7/188 [00:00<00:18,  9.89it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.68it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.69it/s][A
  5%|▌         | 10/188 [00:00<00:18,  9.42it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.35it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.61it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.65it/s][A
  7%|▋         | 14/188 [00:01<00:18,  9.65it/s][A
  8%|▊         | 15/188 [00:01<00:17,  9.67it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.69it/s][A
  9%|▉         | 17/188 [00:01<00:17,  9.64it/s][A
 10%|▉         | 18/188 [00:01

Step 5000, train loss: 4.1307, validation loss: 1.0988


 76%|███████▌  | 2374/3125 [09:21<02:33,  4.90it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:05, 32.44it/s][A
  1%|          | 2/188 [00:00<00:12, 14.92it/s][A
  2%|▏         | 3/188 [00:00<00:15, 11.92it/s][A
  2%|▏         | 4/188 [00:00<00:17, 10.82it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.44it/s][A
  3%|▎         | 6/188 [00:00<00:18, 10.05it/s][A
  4%|▎         | 7/188 [00:00<00:18,  9.89it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.70it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.64it/s][A
  5%|▌         | 10/188 [00:00<00:18,  9.40it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.50it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.56it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.57it/s][A
  7%|▋         | 14/188 [00:01<00:18,  9.65it/s][A
  8%|▊         | 15/188 [00:01<00:17,  9.63it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.69it/s][A
  9%|▉         | 17/188 [00:01<00:17,  9.68it/s][A
 10%|▉         | 18/188 [00:01

Step 5500, train loss: 5.2309, validation loss: 1.0994


 92%|█████████▏| 2874/3125 [11:23<00:50,  4.94it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:04, 42.49it/s][A
  1%|          | 2/188 [00:00<00:12, 15.41it/s][A
  2%|▏         | 3/188 [00:00<00:15, 11.90it/s][A
  2%|▏         | 4/188 [00:00<00:16, 10.89it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.40it/s][A
  3%|▎         | 6/188 [00:00<00:18, 10.01it/s][A
  4%|▎         | 7/188 [00:00<00:18,  9.91it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.63it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.64it/s][A
  5%|▌         | 10/188 [00:00<00:18,  9.39it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.50it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.55it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.60it/s][A
  7%|▋         | 14/188 [00:01<00:18,  9.58it/s][A
  8%|▊         | 15/188 [00:01<00:17,  9.64it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.66it/s][A
  9%|▉         | 17/188 [00:01<00:18,  9.47it/s][A
 10%|▉         | 18/188 [00:01

Step 6000, train loss: 6.3312, validation loss: 1.0986


100%|██████████| 3125/3125 [12:34<00:00,  4.14it/s]
  8%|▊         | 249/3125 [00:50<09:46,  4.90it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:03, 46.95it/s][A
  1%|          | 2/188 [00:00<00:11, 15.63it/s][A
  2%|▏         | 3/188 [00:00<00:15, 12.21it/s][A
  2%|▏         | 4/188 [00:00<00:16, 10.87it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.40it/s][A
  3%|▎         | 6/188 [00:00<00:18, 10.07it/s][A
  4%|▎         | 7/188 [00:00<00:18,  9.95it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.68it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.72it/s][A
  5%|▌         | 10/188 [00:00<00:18,  9.50it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.56it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.51it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.64it/s][A
  7%|▋         | 14/188 [00:01<00:18,  9.51it/s][A
  8%|▊         | 15/188 [00:01<00:17,  9.71it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.72it/s][A
  9%|▉         | 17/188 [00:01<

Step 6500, train loss: 0.5507, validation loss: 1.0988


 24%|██▍       | 749/3125 [02:52<08:03,  4.92it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:04, 43.44it/s][A
  1%|          | 2/188 [00:00<00:12, 14.62it/s][A
  2%|▏         | 3/188 [00:00<00:15, 12.27it/s][A
  2%|▏         | 4/188 [00:00<00:16, 10.99it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.49it/s][A
  3%|▎         | 6/188 [00:00<00:18, 10.10it/s][A
  4%|▎         | 7/188 [00:00<00:18,  9.94it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.70it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.74it/s][A
  5%|▌         | 10/188 [00:00<00:18,  9.41it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.49it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.57it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.60it/s][A
  7%|▋         | 14/188 [00:01<00:18,  9.60it/s][A
  8%|▊         | 15/188 [00:01<00:17,  9.68it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.72it/s][A
  9%|▉         | 17/188 [00:01<00:17,  9.73it/s][A
 10%|▉         | 18/188 [00:01<

Step 7000, train loss: 1.6501, validation loss: 1.1009


 40%|███▉      | 1249/3125 [04:54<06:25,  4.87it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:05, 32.48it/s][A
  1%|          | 2/188 [00:00<00:12, 14.91it/s][A
  2%|▏         | 3/188 [00:00<00:15, 12.01it/s][A
  2%|▏         | 4/188 [00:00<00:17, 10.80it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.45it/s][A
  3%|▎         | 6/188 [00:00<00:17, 10.16it/s][A
  4%|▎         | 7/188 [00:00<00:18, 10.02it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.69it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.68it/s][A
  5%|▌         | 10/188 [00:00<00:19,  9.32it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.45it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.52it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.55it/s][A
  7%|▋         | 14/188 [00:01<00:18,  9.58it/s][A
  8%|▊         | 15/188 [00:01<00:17,  9.62it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.65it/s][A
  9%|▉         | 17/188 [00:01<00:17,  9.69it/s][A
 10%|▉         | 18/188 [00:01

Step 7500, train loss: 2.7509, validation loss: 1.0989


 56%|█████▌    | 1749/3125 [06:56<04:39,  4.92it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:04, 42.66it/s][A
  1%|          | 2/188 [00:00<00:12, 15.31it/s][A
  2%|▏         | 3/188 [00:00<00:15, 12.02it/s][A
  2%|▏         | 4/188 [00:00<00:17, 10.76it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.30it/s][A
  3%|▎         | 6/188 [00:00<00:18, 10.00it/s][A
  4%|▎         | 7/188 [00:00<00:18,  9.70it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.77it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.78it/s][A
  5%|▌         | 10/188 [00:00<00:18,  9.47it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.54it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.58it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.62it/s][A
  7%|▋         | 14/188 [00:01<00:17,  9.67it/s][A
  8%|▊         | 15/188 [00:01<00:17,  9.66it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.69it/s][A
  9%|▉         | 17/188 [00:01<00:17,  9.68it/s][A
 10%|▉         | 18/188 [00:01

Step 8000, train loss: 3.8506, validation loss: 1.1013


 72%|███████▏  | 2249/3125 [08:57<02:59,  4.88it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:04, 44.40it/s][A
  1%|          | 2/188 [00:00<00:11, 15.58it/s][A
  2%|▏         | 3/188 [00:00<00:15, 12.17it/s][A
  2%|▏         | 4/188 [00:00<00:16, 10.98it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.52it/s][A
  3%|▎         | 6/188 [00:00<00:18, 10.09it/s][A
  4%|▎         | 7/188 [00:00<00:18,  9.97it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.72it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.69it/s][A
  5%|▌         | 10/188 [00:00<00:19,  9.36it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.46it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.54it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.59it/s][A
  7%|▋         | 14/188 [00:01<00:18,  9.55it/s][A
  8%|▊         | 15/188 [00:01<00:17,  9.65it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.67it/s][A
  9%|▉         | 17/188 [00:01<00:17,  9.70it/s][A
 10%|▉         | 18/188 [00:01

Step 8500, train loss: 4.9490, validation loss: 1.1040


 88%|████████▊ | 2749/3125 [10:59<01:16,  4.90it/s]
  0%|          | 0/188 [00:00<?, ?it/s][A
  1%|          | 1/188 [00:00<00:08, 21.17it/s][A
  1%|          | 2/188 [00:00<00:13, 13.76it/s][A
  2%|▏         | 3/188 [00:00<00:16, 11.52it/s][A
  2%|▏         | 4/188 [00:00<00:17, 10.69it/s][A
  3%|▎         | 5/188 [00:00<00:17, 10.31it/s][A
  3%|▎         | 6/188 [00:00<00:18, 10.05it/s][A
  4%|▎         | 7/188 [00:00<00:18,  9.96it/s][A
  4%|▍         | 8/188 [00:00<00:18,  9.82it/s][A
  5%|▍         | 9/188 [00:00<00:18,  9.74it/s][A
  5%|▌         | 10/188 [00:00<00:19,  9.25it/s][A
  6%|▌         | 11/188 [00:01<00:18,  9.40it/s][A
  6%|▋         | 12/188 [00:01<00:18,  9.50it/s][A
  7%|▋         | 13/188 [00:01<00:18,  9.51it/s][A
  7%|▋         | 14/188 [00:01<00:18,  9.55it/s][A
  8%|▊         | 15/188 [00:01<00:18,  9.58it/s][A
  9%|▊         | 16/188 [00:01<00:17,  9.65it/s][A
  9%|▉         | 17/188 [00:01<00:17,  9.68it/s][A
 10%|▉         | 18/188 [00:01

Step 9000, train loss: 6.0481, validation loss: 1.0987


100%|██████████| 3125/3125 [12:35<00:00,  4.13it/s]
