<a href="https://colab.research.google.com/github/SOL1archive/ClauseSummary/blob/main/main-model-train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Detect whether the notebook runs on Google Colab: the `google.colab` package
# can only be imported there. Only ImportError is caught so that unrelated
# failures (or a keyboard interrupt) are not silently turned into "not Colab".
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

# Last expression of the cell: displayed as the cell output.
'Process in Colab' if IN_COLAB else 'Process in Local'

'Process in Colab'

In [None]:
# Install the third-party packages this notebook needs. This only runs on
# Colab (IN_COLAB is set by the detection cell); a local environment is
# expected to provide these packages already.
if IN_COLAB:
    !pip install transformers
    !pip install datasets
    !pip install evaluate
    !pip install rouge_score
    !pip install torchmetrics
    !pip install rouge
    !pip install --upgrade accelerate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# On Colab, mount Google Drive at /content/drive/ so the project files stored
# there can be reached through the file system.
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [None]:
# Must be removed from the version published on GitHub.
# Changes the working directory to the project folder on the mounted Drive;
# the path is relative, so it depends on the directory the kernel starts in.
%cd drive/MyDrive/projects/ClauseSummary

/content/drive/MyDrive/projects/ClauseSummary


In [None]:
import warnings
warnings.filterwarnings('ignore')
import datetime
import os
import gc
from pprint import pprint
from typing import Callable, Dict, List, Optional, Tuple, Union
from tqdm import tqdm

import numpy as np
import pandas as pd

import tensorboard
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR, CyclicLR
import torchmetrics

from datasets import load_dataset, load_from_disk, concatenate_datasets, DatasetDict, Dataset
from transformers import get_linear_schedule_with_warmup
from transformers import DataCollatorForSeq2Seq
from transformers import BartConfig, T5Config
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
import evaluate
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge import Rouge

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns

In [None]:
class TokenizeMapWrapper:
    """Callable that tokenises one column of a row (or batch of rows).

    An instance stores a tokenizer, the name of the column to read and the
    keyword options forwarded to the tokenizer on every call, so it can be
    handed directly to a mapping function such as ``Dataset.map``.

    Args:
        tokenizer: Callable invoked as ``tokenizer(value, **option)``.
        feature: Key of the column whose value is tokenised.
        option: Keyword arguments for the tokenizer call. When ``None``, the
            defaults are ``max_length=512``, ``truncation=True`` and
            ``padding='max_length'``.
    """

    def __init__(self, tokenizer, feature, option=None):
        self.tokenizer = tokenizer
        self.feature = feature
        # Only `None` selects the defaults; an explicitly passed empty dict is kept.
        self.option = option if option is not None else {
            'max_length': 512,
            'truncation': True,
            'padding': 'max_length',
        }

    def __call__(self, row):
        """Tokenise ``row[self.feature]`` with the stored options and return the result."""
        text = row[self.feature]
        return self.tokenizer(text, **self.option)

    def __repr__(self):
        """Return the class name followed by the wrapped tokenizer."""
        return f'{self.__class__.__name__}(tokenizer={self.tokenizer})'

class Seq2SeqTokenizeMapWrapper(TokenizeMapWrapper):
    """Callable that tokenises a source column and a target column for seq2seq training.

    Calling an instance on a row (or a batch of rows) returns a dict with
    ``input_ids`` and ``attention_mask`` from the source column and ``labels``
    from the target column. Both columns are tokenised with the same options.

    Padding positions in ``labels`` are replaced by ``IGNORE_INDEX`` (-100).
    The default options pad every target to ``max_length``; without the
    replacement the loss would also be computed on those padding tokens.
    A tokenised dataset that was cached on disk before this masking existed
    has to be regenerated to benefit from it.

    Args:
        tokenizer: Tokenizer used for both columns; must provide
            ``as_target_tokenizer()``.
        feature: Key of the source-text column.
        target: Key of the target-text column.
        option: Keyword arguments for the tokenizer calls (see
            ``TokenizeMapWrapper`` for the defaults).
    """

    # Label value that the loss of Hugging Face seq2seq models skips
    # (the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``).
    IGNORE_INDEX = -100

    def __init__(self, tokenizer, feature, target, option=None):
        super().__init__(tokenizer, feature, option)
        self.target = target

    def _mask_padding(self, label_ids):
        """Return ``label_ids`` with every pad token id replaced by ``IGNORE_INDEX``.

        Handles a flat list of ids (single row) and a list of lists (batched
        mapping). Anything else, e.g. tensors requested through the options,
        is returned unchanged, as is everything when the tokenizer has no
        pad token.
        """
        pad_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_id is None or not isinstance(label_ids, list):
            return label_ids
        if label_ids and isinstance(label_ids[0], list):
            return [
                [self.IGNORE_INDEX if token == pad_id else token for token in sequence]
                for sequence in label_ids
            ]
        return [self.IGNORE_INDEX if token == pad_id else token for token in label_ids]

    def seq2seq_tokenize(self, row):
        """Tokenise source and target text of ``row`` and assemble the model inputs.

        Returns:
            dict with ``input_ids``, ``attention_mask`` (source side) and
            ``labels`` (target ids with padding masked out).
        """
        form_embeddings = self.tokenizer(row[self.feature], **self.option)
        # Targets are encoded inside the tokenizer's target-side context.
        with self.tokenizer.as_target_tokenizer():
            correct_form_embeddings = self.tokenizer(row[self.target], **self.option)

        return {
            'input_ids': form_embeddings['input_ids'],
            'attention_mask': form_embeddings['attention_mask'],
            'labels': self._mask_padding(correct_form_embeddings['input_ids']),
        }

    def __call__(self, row):
        """Make the instance usable as a mapping function; see ``seq2seq_tokenize``."""
        return self.seq2seq_tokenize(row)

## Setting

- 학습 환경에 맞게 조정하기 (특히 **경로 설정**)

In [None]:
# Run-control switches and loop granularity used by the training cells.
MANUAL_TRAINING = True  # only referenced by the commented-out Trainer/manual-loop switch
MANUAL_VALIDATION = True  # not read anywhere in the visible part of this notebook
NUM_EPOCHS = 1  # passed as num_train_epochs and used as the manual loop's epoch count
MID_CHECKPOINT_NUM = 2  # steps per epoch are divided by this to get the checkpoint interval
MID_PROCESS_PRINT_NUM = 50  # steps per epoch are divided by this to get the loss-print interval

In [None]:
# Candidate pretrained checkpoints; the identifier assigned to `checkpoint`
# is the one later passed to the `from_pretrained` calls.
t5_large_summary_checkpoint = 'lcw99/t5-large-korean-text-summary'
t5_base_summary_checkpoint = 'eenzeenee/t5-base-korean-summarization'
kobart_summary_checkpoint = 'gogamza/kobart-summarization'
# Checkpoint actually used by this run.
checkpoint = t5_base_summary_checkpoint
print(f'Using Checkpoint: {checkpoint}')

Using Checkpoint: eenzeenee/t5-base-korean-summarization


In [None]:
# Raw JSON dataset (read for its 'text' and 'summary' columns).
original_dataset_path = './data/dataset-term-summary.json'
# On-disk cache of the tokenised dataset. The checkpoint name is part of the
# path ('/' replaced by '-') so each checkpoint gets its own cache.
tokenized_dataset_path = f'./data/{checkpoint.replace("/", "-")}-tokenized-dataset'

In [None]:
# Timestamp (yymmdd-HHMM) that tags the models saved by this run.
# The time part deliberately has no ':' separator: a colon is not a valid
# character in Windows file names, so a path containing one cannot be created
# when the notebook is run locally on Windows.
SAVE_STR = datetime.datetime.now().strftime('%y%m%d-%H%M')
# Base path under which the fine-tuned model is saved.
model_save_path = f"./model/{SAVE_STR}"

## Load Tokenizer & Model Checkpoint

In [None]:
# Pick the configuration class from the checkpoint name: a name containing
# "bart" (case-insensitive) gets a BartConfig, every other checkpoint is
# loaded with T5Config.
config = (BartConfig if 'bart' in checkpoint.lower() else T5Config).from_pretrained(checkpoint)

Downloading (…)lve/main/config.json:   0%|          | 0.00/782 [00:00<?, ?B/s]

In [None]:
# Load the tokenizer and the seq2seq model for the selected checkpoint.
# NOTE(review): max_length / truncation / padding are usually per-call encoding
# options; confirm they have any effect when given to from_pretrained. The
# tokenisation wrappers in this notebook pass their own options on every call.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, 
                                          max_length=512, 
                                          truncation=False, 
                                          padding='max_length',
                                          #vocab=config.vocab_size
                                          )
# The model is built with the configuration object loaded in the previous cell.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, config=config)

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.41k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.92M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.10G [00:00<?, ?B/s]

In [None]:
# Sanity check: stop early when the tokenizer length and the model's
# configured vocabulary size differ, instead of failing later during training.
if len(tokenizer) != model.config.vocab_size:
    raise RuntimeError(f'Tokenizer vocab size and model vocab size do not match(Tokenizer:{len(tokenizer)} Model: {model.config.vocab_size}). Which would lead to further error in training.')

## Load Dataset

In [None]:
# Seed for the train/test shuffle so that rebuilding the cache (for example
# after deleting it or switching checkpoint) yields the same split again.
SPLIT_SEED = 42

# Build the tokenised dataset once and cache it on disk; later runs load the cache.
if not os.path.exists(tokenized_dataset_path):
    # Keep only the source text and its reference summary.
    dataset = Dataset.from_pandas(pd.read_json(original_dataset_path, encoding='utf-8')[['text', 'summary']])
    tokenizer_wrapper = Seq2SeqTokenizeMapWrapper(tokenizer, 'text', 'summary')

    # Tokenise in batches across worker processes, then drop the raw string
    # columns so only the wrapper's outputs remain.
    tokenized_dataset = (dataset
                         .map(tokenizer_wrapper,
                              batched=True,
                              batch_size=128,
                              num_proc=10
                              )
                         .remove_columns(['text', 'summary'])
                         )

    # 80/20 split into 'train' and 'test', shuffled with a fixed seed.
    tokenized_dataset_dict = tokenized_dataset.train_test_split(
        test_size=0.2,
        shuffle=True,
        seed=SPLIT_SEED,
    )
    tokenized_dataset_dict.save_to_disk(tokenized_dataset_path)
else:
    tokenized_dataset_dict = load_from_disk(tokenized_dataset_path)

Map (num_proc=10):   0%|          | 0/7943 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6354 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1589 [00:00<?, ? examples/s]

## Training

In [None]:
# Seq2seq batch collator returning PyTorch tensors; it is handed to the
# Seq2SeqTrainer created further down.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors='pt')

In [None]:
# Baseline training arguments. Besides configuring the Seq2SeqTrainer, the
# learning rate and weight decay stored here are the values the manual
# training loop reads when it builds its optimizer.
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=NUM_EPOCHS,
    weight_decay=0.01,
    report_to="tensorboard",
    push_to_hub=False,
)

In [None]:
# Trainer around the model. In this notebook it is used for writing the model
# card and saving checkpoints.
# The training arguments request an evaluation after every epoch, so the
# held-out 'test' split produced by train_test_split is supplied as the
# evaluation set; without it there is nothing to evaluate on.
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_dataset_dict['train'],
    eval_dataset=tokenized_dataset_dict['test'],
    data_collator=data_collator,
)

In [None]:
# Run Python garbage collection and release PyTorch's cached CUDA memory
# before the training cells.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# model.train()

# if not MANUAL_TRAINING:
#     trainer.train()
# else:
#     total_loss = []
#     epoch_loss = []
#     batch_loss = []
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     trainset = tokenized_dataset_dict['train'].with_format('torch', device=device)
#     dataloader = DataLoader(trainset, batch_size=1, shuffle=False) # TODO: Batch size 조절
    
#     # TODO: Write a code for **Hyperparameter Tuning**
#     optimizer = AdamW(model.parameters(), lr = training_args.learning_rate, weight_decay = training_args.weight_decay)
#     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=NUM_EPOCHS * len(dataloader))

#     for epoch in range(NUM_EPOCHS):
#         total_steps = len(dataloader)
#         save_divisor = total_steps // MID_CHECKPOINT_NUM
#         print_divisor = total_steps // MID_PROCESS_PRINT_NUM
#         for i, batch in enumerate(tqdm(dataloader)):
#             X = {
#                     'input_ids': batch['input_ids'],
#                     'attention_mask': batch['attention_mask'],
#                 }
#             y = batch['labels']
            
#             outputs = model(**X, labels=y)
#             loss = outputs.loss
#             loss.backward()
#             optimizer.step()
#             optimizer.zero_grad()
#             scheduler.step()

#             batch_loss.append(loss.item())
#             if i % print_divisor == print_divisor - 1:
#                 epoch_loss += batch_loss
#                 batch_loss_series = pd.Series(batch_loss)
#                 print(f'\tbatch {i}\tloss: {loss.item()}\tmean: {batch_loss_series.mean()}')
#                 batch_loss = []

#             if i % save_divisor == save_divisor - 1:
#                 trainer.create_model_card(
#                     language='Korean',
#                     tags='Grammar',
#                     finetuned_from=checkpoint
#                 )
#                 trainer.save_model(model_save_path + f'-epoch-{epoch + 1}' + '-batch-{i + 1}')

#         total_loss += epoch_loss
#         batch_loss_series = pd.Series(epoch_loss)
#         epoch_loss = []
#         print(f'epoch {epoch + 1} loss: {loss.item()} mean: {batch_loss_series.mean()}')

In [None]:
# total_loss_series = pd.Series(total_loss)
# total_loss_series.plot.line()

In [None]:
# Write a model card for the run, then save the model under the timestamped path.
# NOTE(review): tags='Grammar' and the commented-out 'KoGrammar' name look
# carried over from another project, while the checkpoints used here are
# summarisation models; confirm the intended tag.
trainer.create_model_card(
    language='Korean',
    tags='Grammar',
    #model='KoGrammar',
    finetuned_from=checkpoint
)
trainer.save_model(model_save_path)

In [None]:
# Show which checkpoint this run is based on.
print(checkpoint)

eenzeenee/t5-base-korean-summarization


In [None]:
## finding the best parameters
def mean(A):
    """Return the arithmetic mean of the values in ``A``.

    ``A`` must be a sized iterable of numbers. An empty ``A`` raises
    ``ZeroDivisionError``.
    """
    total = 0
    for value in A:
        total = total + value
    return total / len(A)

# Grid to explore: every (learning rate, weight decay) pair is trained once and
# one line per pair is appended to ./results/experiments.csv.
learning_rates = [1e-5, 5e-5]
weight_decays = [0.03, 0.05, 0.07]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The batches below are materialised on `device`, so the model must live there too.
model.to(device)

# CPU copy of the weights the search starts from. They are restored before
# every configuration; otherwise each configuration would continue training
# the model left behind by the previous one and the recorded losses could not
# be compared.
initial_state = {
    name: tensor.detach().cpu().clone()
    for name, tensor in model.state_dict().items()
}

# The experiment log is appended under ./results; make sure the folder exists.
os.makedirs('./results', exist_ok=True)

for learning_rate in learning_rates:
    for decay in weight_decays:
        gc.collect()
        torch.cuda.empty_cache()

        # Same starting point for every configuration (same model object, so
        # `trainer` keeps saving the model that is being trained here).
        model.load_state_dict(initial_state)
        # Models returned by from_pretrained() are in evaluation mode; switch
        # to training mode so that dropout is active.
        model.train()

        training_args = Seq2SeqTrainingArguments(
            output_dir="./results",
            evaluation_strategy = "epoch",
            learning_rate=learning_rate,
            per_device_train_batch_size=64,
            per_device_eval_batch_size=64,
            num_train_epochs=NUM_EPOCHS,
            weight_decay=decay,
            report_to="tensorboard",
            push_to_hub=False,
        )
        total_loss = []   # per-batch losses of the whole configuration
        epoch_loss = []   # per-batch losses of the current epoch
        batch_loss = []   # per-batch losses since the last progress print
        trainset = tokenized_dataset_dict['train'].with_format('torch', device=device)
        dataloader = DataLoader(trainset, batch_size=1, shuffle=False) # TODO: tune the batch size

        optimizer = AdamW(model.parameters(), lr = training_args.learning_rate, weight_decay = training_args.weight_decay)
        optimizer_name = "AdamW"
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=NUM_EPOCHS * len(dataloader))
        scheduler_name = "linear_schedule"

        # Part of the checkpoint path that identifies the configuration, so
        # checkpoints of different configurations do not overwrite each other.
        run_tag = f'-lr-{learning_rate}-wd-{decay}'

        for epoch in range(NUM_EPOCHS):
            total_steps = len(dataloader)
            # max(1, ...) keeps the modulo below valid when there are fewer
            # steps than requested checkpoints / progress prints.
            save_divisor = max(1, total_steps // MID_CHECKPOINT_NUM)
            print_divisor = max(1, total_steps // MID_PROCESS_PRINT_NUM)
            for i, batch in enumerate(tqdm(dataloader)):
                X = {
                        'input_ids': batch['input_ids'],
                        'attention_mask': batch['attention_mask'],
                    }
                y = batch['labels']

                outputs = model(**X, labels=y)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

                batch_loss.append(loss.item())
                if i % print_divisor == print_divisor - 1:
                    epoch_loss += batch_loss
                    batch_loss_series = pd.Series(batch_loss)
                    print(f'\tbatch {i}\tloss: {loss.item()}\tmean: {batch_loss_series.mean()}')
                    batch_loss = []

                if i % save_divisor == save_divisor - 1:
                    trainer.create_model_card(
                        language='Korean',
                        tags='Grammar',
                        finetuned_from=checkpoint
                    )
                    trainer.save_model(model_save_path + run_tag + f'-epoch-{epoch + 1}-batch-{i + 1}')

            # Losses recorded after the last progress print also belong to this epoch.
            epoch_loss += batch_loss
            batch_loss = []
            total_loss += epoch_loss
            epoch_loss = []

        # Record the configuration together with the mean of its five lowest
        # per-batch losses.
        total_loss.sort()
        top5_loss = mean(total_loss[:5]) if total_loss else float('nan')
        text = "%s, %s, %s, %f, %f, %f\n"%(checkpoint, optimizer_name, scheduler_name, training_args.learning_rate, training_args.weight_decay, top5_loss)
        with open('./results/experiments.csv', 'a') as f:
            f.write(text)

  2%|▏         | 128/6354 [00:17<13:32,  7.66it/s]

	batch 126	loss: 13.751532554626465	mean: 16.080567757914384


  4%|▍         | 255/6354 [00:32<12:24,  8.19it/s]

	batch 253	loss: 6.717682838439941	mean: 10.380187503934845


  6%|▌         | 382/6354 [00:48<12:25,  8.01it/s]

	batch 380	loss: 2.4800286293029785	mean: 4.321600343298725


  8%|▊         | 509/6354 [01:04<12:05,  8.06it/s]

	batch 507	loss: 2.362123489379883	mean: 2.9382144325361477


 10%|█         | 636/6354 [01:19<11:42,  8.14it/s]

	batch 634	loss: 2.1617205142974854	mean: 2.3437385840678777


 12%|█▏        | 763/6354 [01:35<11:24,  8.17it/s]

	batch 761	loss: 1.7215794324874878	mean: 1.793405591502903


 14%|█▍        | 890/6354 [01:50<11:08,  8.17it/s]

	batch 888	loss: 1.020443320274353	mean: 1.4374880389435085


 16%|█▌        | 1017/6354 [02:06<10:48,  8.23it/s]

	batch 1015	loss: 0.7492220997810364	mean: 1.2063487507696227


 18%|█▊        | 1144/6354 [02:21<10:39,  8.15it/s]

	batch 1142	loss: 0.6558165550231934	mean: 1.1270766119788012


 20%|██        | 1271/6354 [02:37<10:25,  8.13it/s]

	batch 1269	loss: 0.5463663339614868	mean: 0.9300842538593322


 22%|██▏       | 1398/6354 [02:53<10:06,  8.17it/s]

	batch 1396	loss: 0.27867409586906433	mean: 0.8312757294478379


 24%|██▍       | 1525/6354 [03:08<09:57,  8.08it/s]

	batch 1523	loss: 0.9976701736450195	mean: 0.8074663216908147


 26%|██▌       | 1652/6354 [03:24<09:35,  8.17it/s]

	batch 1650	loss: 0.7904521822929382	mean: 0.7732847375780578


 28%|██▊       | 1779/6354 [03:39<09:16,  8.23it/s]

	batch 1777	loss: 0.48571449518203735	mean: 0.7711524000083367


 30%|██▉       | 1906/6354 [03:55<09:00,  8.23it/s]

	batch 1904	loss: 0.9187382459640503	mean: 0.7401962201074352


 32%|███▏      | 2033/6354 [04:10<08:49,  8.15it/s]

	batch 2031	loss: 0.49817028641700745	mean: 0.7793671584974123


 34%|███▍      | 2160/6354 [04:26<08:30,  8.21it/s]

	batch 2158	loss: 0.96700519323349	mean: 0.70357345496341


 36%|███▌      | 2287/6354 [04:41<08:16,  8.19it/s]

	batch 2285	loss: 0.9379640817642212	mean: 0.6586447923084883


 38%|███▊      | 2414/6354 [04:57<08:07,  8.09it/s]

	batch 2412	loss: 1.0398740768432617	mean: 0.7197405291119898


 40%|███▉      | 2541/6354 [05:13<07:46,  8.17it/s]

	batch 2539	loss: 0.38692110776901245	mean: 0.696605899437206


 42%|████▏     | 2668/6354 [05:28<07:27,  8.24it/s]

	batch 2666	loss: 0.8394633531570435	mean: 0.6554479582572547


 44%|████▍     | 2795/6354 [05:44<07:17,  8.14it/s]

	batch 2793	loss: 0.36133497953414917	mean: 0.614935295406993


 46%|████▌     | 2922/6354 [05:59<07:01,  8.15it/s]

	batch 2920	loss: 0.8821706771850586	mean: 0.6299187551567874


 48%|████▊     | 3049/6354 [06:15<06:43,  8.19it/s]

	batch 3047	loss: 0.5726723074913025	mean: 0.5855250848002556


 50%|████▉     | 3176/6354 [06:30<06:31,  8.11it/s]

	batch 3174	loss: 0.3263731002807617	mean: 0.5535333219358302


 52%|█████▏    | 3303/6354 [06:51<06:14,  8.14it/s]

	batch 3301	loss: 0.21036015450954437	mean: 0.5350075742508483


 54%|█████▍    | 3430/6354 [07:07<05:57,  8.19it/s]

	batch 3428	loss: 0.48231637477874756	mean: 0.5566219132921593


 56%|█████▌    | 3557/6354 [07:23<05:41,  8.20it/s]

	batch 3555	loss: 0.1752159148454666	mean: 0.542735813669567


 58%|█████▊    | 3684/6354 [07:39<05:30,  8.09it/s]

	batch 3682	loss: 0.4644537568092346	mean: 0.5943337521684451


 60%|█████▉    | 3811/6354 [07:54<05:15,  8.06it/s]

	batch 3809	loss: 1.2437366247177124	mean: 0.6021079229498942


 62%|██████▏   | 3938/6354 [08:10<04:57,  8.11it/s]

	batch 3936	loss: 0.17485378682613373	mean: 0.5488242652765879


 64%|██████▍   | 4065/6354 [08:26<04:42,  8.10it/s]

	batch 4063	loss: 0.9848126173019409	mean: 0.573572233685945


 66%|██████▌   | 4192/6354 [08:41<04:24,  8.17it/s]

	batch 4190	loss: 0.6630785465240479	mean: 0.5659499753589236


 68%|██████▊   | 4319/6354 [08:57<04:09,  8.16it/s]

	batch 4317	loss: 0.3572733998298645	mean: 0.5311779539709486


 70%|██████▉   | 4446/6354 [09:12<03:52,  8.22it/s]

	batch 4444	loss: 0.10711797326803207	mean: 0.555830450119876


 72%|███████▏  | 4573/6354 [09:28<03:36,  8.24it/s]

	batch 4571	loss: 0.04161600023508072	mean: 0.6204931688839643


 74%|███████▍  | 4700/6354 [09:43<03:24,  8.09it/s]

	batch 4698	loss: 0.8523080348968506	mean: 0.569334582990314


 76%|███████▌  | 4827/6354 [09:59<03:09,  8.05it/s]

	batch 4825	loss: 0.34108978509902954	mean: 0.5415310146448415


 78%|███████▊  | 4954/6354 [10:15<02:51,  8.18it/s]

	batch 4952	loss: 0.6875087022781372	mean: 0.5752286333884076


 80%|███████▉  | 5081/6354 [10:30<02:35,  8.18it/s]

	batch 5079	loss: 0.9345957040786743	mean: 0.5767003358598417


 82%|████████▏ | 5208/6354 [10:46<02:19,  8.23it/s]

	batch 5206	loss: 0.12736128270626068	mean: 0.5111968024802489


 84%|████████▍ | 5335/6354 [11:01<02:04,  8.18it/s]

	batch 5333	loss: 0.8130489587783813	mean: 0.5534209752323356


 86%|████████▌ | 5462/6354 [11:17<01:49,  8.13it/s]

	batch 5460	loss: 0.3088817298412323	mean: 0.5978064663033551


 88%|████████▊ | 5589/6354 [11:32<01:34,  8.11it/s]

	batch 5587	loss: 0.5751851797103882	mean: 0.5497966996709428


 90%|████████▉ | 5716/6354 [11:48<01:19,  8.06it/s]

	batch 5714	loss: 1.3224751949310303	mean: 0.5322777555199472


 92%|█████████▏| 5843/6354 [12:04<01:03,  8.08it/s]

	batch 5841	loss: 0.8687528967857361	mean: 0.5385439294705711


 94%|█████████▍| 5970/6354 [12:19<00:47,  8.15it/s]

	batch 5968	loss: 0.3254709541797638	mean: 0.5938929997560546


 96%|█████████▌| 6097/6354 [12:35<00:46,  5.49it/s]

	batch 6095	loss: 0.5801292061805725	mean: 0.5605558182853531


 98%|█████████▊| 6224/6354 [12:51<00:15,  8.17it/s]

	batch 6222	loss: 0.10527849942445755	mean: 0.5294364023455015


100%|█████████▉| 6351/6354 [13:06<00:00,  8.19it/s]

	batch 6349	loss: 0.6471816897392273	mean: 0.5086451947439726


100%|██████████| 6354/6354 [13:10<00:00,  8.04it/s]
  2%|▏         | 128/6354 [00:16<13:18,  7.80it/s]

	batch 126	loss: 0.5000577569007874	mean: 0.4913713218878925


  4%|▍         | 255/6354 [00:31<12:42,  8.00it/s]

	batch 253	loss: 0.45761188864707947	mean: 0.4915395588357383


  6%|▌         | 382/6354 [00:47<12:11,  8.16it/s]

	batch 380	loss: 0.02775469794869423	mean: 0.5726354861499992


  8%|▊         | 509/6354 [01:03<11:52,  8.20it/s]

	batch 507	loss: 0.05438016727566719	mean: 0.4522490083701967


 10%|█         | 636/6354 [01:18<11:42,  8.14it/s]

	batch 634	loss: 0.6598687767982483	mean: 0.4937180933344552


 12%|█▏        | 763/6354 [01:34<11:24,  8.17it/s]

	batch 761	loss: 0.5086763501167297	mean: 0.49553589776202334


 14%|█▍        | 890/6354 [01:49<11:13,  8.11it/s]

	batch 888	loss: 0.377505362033844	mean: 0.5127988309694791


 16%|█▌        | 1017/6354 [02:05<11:06,  8.01it/s]

	batch 1015	loss: 0.2770121395587921	mean: 0.48521793312503125


 18%|█▊        | 1144/6354 [02:21<10:38,  8.16it/s]

	batch 1142	loss: 0.34440815448760986	mean: 0.49632825840477224


 20%|██        | 1271/6354 [02:36<10:23,  8.16it/s]

	batch 1269	loss: 0.3728826344013214	mean: 0.47190391560869777


 22%|██▏       | 1398/6354 [02:52<10:10,  8.11it/s]

	batch 1396	loss: 0.022926565259695053	mean: 0.4321971142877216


 24%|██▍       | 1525/6354 [03:08<09:51,  8.16it/s]

	batch 1523	loss: 0.6653112769126892	mean: 0.41346859227220606


 26%|██▌       | 1652/6354 [03:23<09:40,  8.10it/s]

	batch 1650	loss: 0.6258612275123596	mean: 0.42062243434424534


 28%|██▊       | 1779/6354 [03:39<09:32,  8.00it/s]

	batch 1777	loss: 0.17994342744350433	mean: 0.4221279158029617


 30%|██▉       | 1906/6354 [03:55<09:09,  8.10it/s]

	batch 1904	loss: 0.44193586707115173	mean: 0.41759560282496133


 32%|███▏      | 2033/6354 [04:10<08:49,  8.16it/s]

	batch 2031	loss: 0.34206119179725647	mean: 0.46720601120545696


 34%|███▍      | 2160/6354 [04:26<08:35,  8.14it/s]

	batch 2158	loss: 0.5575352311134338	mean: 0.4090633198589556


 36%|███▌      | 2287/6354 [04:41<08:18,  8.15it/s]

	batch 2285	loss: 0.42844846844673157	mean: 0.3984703451695698


 38%|███▊      | 2414/6354 [04:57<08:04,  8.13it/s]

	batch 2412	loss: 0.7459433078765869	mean: 0.43845650272839887


 40%|███▉      | 2541/6354 [05:13<07:51,  8.10it/s]

	batch 2539	loss: 0.29327303171157837	mean: 0.43228151399356235


 42%|████▏     | 2668/6354 [05:29<07:36,  8.07it/s]

	batch 2666	loss: 0.7077004313468933	mean: 0.3820666972452437


 44%|████▍     | 2795/6354 [05:44<07:22,  8.04it/s]

	batch 2793	loss: 0.009950620122253895	mean: 0.3855963000463043


 46%|████▌     | 2922/6354 [06:00<07:00,  8.16it/s]

	batch 2920	loss: 0.20070718228816986	mean: 0.37102592952446795


 48%|████▊     | 3049/6354 [06:15<06:45,  8.15it/s]

	batch 3047	loss: 0.4057093858718872	mean: 0.3506973303623468


 50%|████▉     | 3176/6354 [06:31<06:28,  8.17it/s]

	batch 3174	loss: 0.1507745385169983	mean: 0.33789531404599843


 52%|█████▏    | 3303/6354 [06:50<06:20,  8.03it/s]

	batch 3301	loss: 0.0849456787109375	mean: 0.32766915093025734


 54%|█████▍    | 3430/6354 [07:05<05:57,  8.18it/s]

	batch 3428	loss: 0.448911190032959	mean: 0.3228274609391233


 56%|█████▌    | 3557/6354 [07:21<05:48,  8.03it/s]

	batch 3555	loss: 0.11275146156549454	mean: 0.3349272213128256


 58%|█████▊    | 3684/6354 [07:37<05:27,  8.15it/s]

	batch 3682	loss: 0.30745381116867065	mean: 0.39315003901170054


 60%|█████▉    | 3811/6354 [07:53<05:12,  8.15it/s]

	batch 3809	loss: 1.1128559112548828	mean: 0.3880810772259463


 62%|██████▏   | 3938/6354 [08:08<04:58,  8.09it/s]

	batch 3936	loss: 0.04300559312105179	mean: 0.3429891206232435


 64%|██████▍   | 4065/6354 [08:24<04:41,  8.12it/s]

	batch 4063	loss: 0.6856505274772644	mean: 0.35141687450827813


 66%|██████▌   | 4192/6354 [08:39<04:24,  8.18it/s]

	batch 4190	loss: 0.5952544212341309	mean: 0.3472338649590947


 68%|██████▊   | 4319/6354 [08:55<04:10,  8.11it/s]

	batch 4317	loss: 0.021611355245113373	mean: 0.3461113113224301


 70%|██████▉   | 4446/6354 [09:11<03:57,  8.02it/s]

	batch 4444	loss: 0.0154332984238863	mean: 0.34115646332025706


 72%|███████▏  | 4573/6354 [09:26<03:40,  8.09it/s]

	batch 4571	loss: 0.015256439335644245	mean: 0.43283434290306894


 74%|███████▍  | 4700/6354 [09:42<03:24,  8.09it/s]

	batch 4698	loss: 0.40756121277809143	mean: 0.369297241828749


 76%|███████▌  | 4827/6354 [09:58<03:07,  8.13it/s]

	batch 4825	loss: 0.23866169154644012	mean: 0.3558803781405033


 78%|███████▊  | 4954/6354 [10:13<03:19,  7.03it/s]

	batch 4952	loss: 0.5829697251319885	mean: 0.385908942416633


 80%|███████▉  | 5081/6354 [10:29<02:36,  8.13it/s]

	batch 5079	loss: 0.9197076559066772	mean: 0.3655974678117709


 82%|████████▏ | 5208/6354 [10:45<02:21,  8.11it/s]

	batch 5206	loss: 0.036077626049518585	mean: 0.32955589482018094


 84%|████████▍ | 5335/6354 [11:00<02:06,  8.03it/s]

	batch 5333	loss: 0.6804947853088379	mean: 0.3790291195779335


 86%|████████▌ | 5462/6354 [11:16<01:49,  8.16it/s]

	batch 5460	loss: 0.09772387146949768	mean: 0.37170128804800195


 88%|████████▊ | 5589/6354 [11:31<01:33,  8.17it/s]

	batch 5587	loss: 0.18312197923660278	mean: 0.35334117883563365


 90%|████████▉ | 5716/6354 [11:47<01:18,  8.15it/s]

	batch 5714	loss: 1.2767384052276611	mean: 0.3355013104861485


 92%|█████████▏| 5843/6354 [12:03<01:02,  8.13it/s]

	batch 5841	loss: 0.518688440322876	mean: 0.35854708327083135


 94%|█████████▍| 5970/6354 [12:18<00:47,  8.08it/s]

	batch 5968	loss: 0.04183601215481758	mean: 0.39442688908691953


 96%|█████████▌| 6097/6354 [12:34<00:31,  8.09it/s]

	batch 6095	loss: 0.33555591106414795	mean: 0.38124579851750257


 98%|█████████▊| 6224/6354 [12:50<00:16,  8.03it/s]

	batch 6222	loss: 0.025081630796194077	mean: 0.33934690068864565


100%|█████████▉| 6351/6354 [13:05<00:00,  8.14it/s]

	batch 6349	loss: 0.5768541693687439	mean: 0.33556912050996796


100%|██████████| 6354/6354 [13:12<00:00,  8.02it/s]
  2%|▏         | 128/6354 [00:15<12:42,  8.17it/s]

	batch 126	loss: 0.24087749421596527	mean: 0.32266246217511974


  4%|▍         | 255/6354 [00:31<12:36,  8.06it/s]

	batch 253	loss: 0.18344524502754211	mean: 0.31486757722954584


  6%|▌         | 382/6354 [00:47<12:20,  8.07it/s]

	batch 380	loss: 0.012552020139992237	mean: 0.3865716504428066


  8%|▊         | 509/6354 [01:03<12:00,  8.11it/s]

	batch 507	loss: 0.005726885516196489	mean: 0.2898724392426794


 10%|█         | 636/6354 [01:18<11:47,  8.09it/s]

	batch 634	loss: 0.5854669213294983	mean: 0.31741898983540967


 12%|█▏        | 763/6354 [01:34<11:26,  8.14it/s]

	batch 761	loss: 0.4514678120613098	mean: 0.3196263312477633


 14%|█▍        | 890/6354 [01:50<11:11,  8.13it/s]

	batch 888	loss: 0.1743694692850113	mean: 0.33158540838090045


 16%|█▌        | 1017/6354 [02:05<10:57,  8.12it/s]

	batch 1015	loss: 0.24498431384563446	mean: 0.3166687012811404


 18%|█▊        | 1144/6354 [02:21<10:46,  8.06it/s]

	batch 1142	loss: 0.2419348508119583	mean: 0.32664099930763596


 20%|██        | 1271/6354 [02:37<10:33,  8.02it/s]

	batch 1269	loss: 0.337546706199646	mean: 0.3189275203094869


 22%|██▏       | 1398/6354 [02:52<10:11,  8.11it/s]

	batch 1396	loss: 0.009956092573702335	mean: 0.28473034638207495


 24%|██▍       | 1525/6354 [03:08<09:54,  8.12it/s]

	batch 1523	loss: 0.5220857858657837	mean: 0.2557221972970219


 26%|██▌       | 1652/6354 [03:24<09:36,  8.16it/s]

	batch 1650	loss: 0.5374595522880554	mean: 0.2698925998076562


 28%|██▊       | 1779/6354 [03:39<09:19,  8.18it/s]

	batch 1777	loss: 0.13938893377780914	mean: 0.281972053009782


 30%|██▉       | 1906/6354 [03:55<09:04,  8.16it/s]

	batch 1904	loss: 0.14077098667621613	mean: 0.2829836768824167


 32%|███▏      | 2033/6354 [04:10<08:57,  8.04it/s]

	batch 2031	loss: 0.19998499751091003	mean: 0.32872389302603666


 34%|███▍      | 2160/6354 [04:26<08:39,  8.07it/s]

	batch 2158	loss: 0.10997865349054337	mean: 0.27930560040720337


 36%|███▌      | 2287/6354 [04:42<08:22,  8.10it/s]

	batch 2285	loss: 0.195010706782341	mean: 0.2897196668035197


 38%|███▊      | 2414/6354 [04:57<08:07,  8.08it/s]

	batch 2412	loss: 0.6311658024787903	mean: 0.31629403917793564


 40%|███▉      | 2541/6354 [05:13<07:49,  8.11it/s]

	batch 2539	loss: 0.26174411177635193	mean: 0.2933462286454178


 42%|████▏     | 2668/6354 [05:29<07:36,  8.08it/s]

	batch 2666	loss: 0.5744233727455139	mean: 0.2704791902769797


 44%|████▍     | 2795/6354 [05:45<07:20,  8.07it/s]

	batch 2793	loss: 0.0038862675428390503	mean: 0.2722079743444186


 46%|████▌     | 2922/6354 [06:00<07:05,  8.06it/s]

	batch 2920	loss: 0.07982796430587769	mean: 0.26264473907447855


 48%|████▊     | 3049/6354 [06:16<06:53,  7.99it/s]

	batch 3047	loss: 0.3024846017360687	mean: 0.2540582374327699


 50%|████▉     | 3176/6354 [06:32<06:32,  8.11it/s]

	batch 3174	loss: 0.045419152826070786	mean: 0.24656769649701296


 52%|█████▏    | 3303/6354 [06:51<06:29,  7.83it/s]

	batch 3301	loss: 0.025819415226578712	mean: 0.24059100770006528


 54%|█████▍    | 3430/6354 [07:07<06:01,  8.08it/s]

	batch 3428	loss: 0.4264802038669586	mean: 0.229314308302334


 56%|█████▌    | 3557/6354 [07:22<05:45,  8.10it/s]

	batch 3555	loss: 0.07772540301084518	mean: 0.24616253752019404


 58%|█████▊    | 3684/6354 [07:38<05:29,  8.12it/s]

	batch 3682	loss: 0.18027415871620178	mean: 0.29652365703432726


 60%|█████▉    | 3811/6354 [07:54<05:14,  8.10it/s]

	batch 3809	loss: 1.0401889085769653	mean: 0.2946665134532084


 62%|██████▏   | 3938/6354 [08:09<04:56,  8.16it/s]

	batch 3936	loss: 0.026074763387441635	mean: 0.25238554529922713


 64%|██████▍   | 4065/6354 [08:25<04:40,  8.16it/s]

	batch 4063	loss: 0.5760518908500671	mean: 0.259233861354979


 66%|██████▌   | 4192/6354 [08:41<04:25,  8.14it/s]

	batch 4190	loss: 0.5580374002456665	mean: 0.26489611251892037


 68%|██████▊   | 4319/6354 [08:56<04:09,  8.15it/s]

	batch 4317	loss: 0.00554668391123414	mean: 0.2732638605887723


 70%|██████▉   | 4446/6354 [09:12<03:54,  8.15it/s]

	batch 4444	loss: 0.010561895556747913	mean: 0.24705632407251218


 72%|███████▏  | 4573/6354 [09:27<03:40,  8.07it/s]

	batch 4571	loss: 0.010918710380792618	mean: 0.3486469954492407


 74%|███████▍  | 4700/6354 [09:43<03:24,  8.10it/s]

	batch 4698	loss: 0.20482486486434937	mean: 0.29059811032883176


 76%|███████▌  | 4827/6354 [09:59<03:06,  8.18it/s]

	batch 4825	loss: 0.15669141709804535	mean: 0.27791749234732943


 78%|███████▊  | 4954/6354 [10:14<02:52,  8.13it/s]

	batch 4952	loss: 0.5284811854362488	mean: 0.30609874541900584


 80%|███████▉  | 5081/6354 [10:30<02:36,  8.15it/s]

	batch 5079	loss: 0.930010199546814	mean: 0.2813653807072372


 82%|████████▏ | 5208/6354 [10:46<02:20,  8.18it/s]

	batch 5206	loss: 0.036176618188619614	mean: 0.26045917600070634


 84%|████████▍ | 5335/6354 [11:02<02:05,  8.13it/s]

	batch 5333	loss: 0.5897416472434998	mean: 0.29930638012432265


 86%|████████▌ | 5462/6354 [11:17<01:50,  8.08it/s]

	batch 5460	loss: 0.034745801240205765	mean: 0.2816359350626452


 88%|████████▊ | 5589/6354 [11:33<01:34,  8.11it/s]

	batch 5587	loss: 0.049054570496082306	mean: 0.27466194444440495


 90%|████████▉ | 5716/6354 [11:48<01:18,  8.13it/s]

	batch 5714	loss: 1.2757110595703125	mean: 0.26013088168202364


 92%|█████████▏| 5843/6354 [12:04<01:02,  8.15it/s]

	batch 5841	loss: 0.20909473299980164	mean: 0.2826757872803503


 94%|█████████▍| 5970/6354 [12:20<00:47,  8.15it/s]

	batch 5968	loss: 0.024171754717826843	mean: 0.3205018640261173


 96%|█████████▌| 6097/6354 [12:35<00:31,  8.16it/s]

	batch 6095	loss: 0.1286323219537735	mean: 0.31020362665244255


 98%|█████████▊| 6224/6354 [12:51<00:16,  8.10it/s]

	batch 6222	loss: 0.02272375486791134	mean: 0.262612447923592


100%|█████████▉| 6351/6354 [13:07<00:00,  8.06it/s]

	batch 6349	loss: 0.5410460829734802	mean: 0.2709995036006091


100%|██████████| 6354/6354 [13:10<00:00,  8.04it/s]
  2%|▏         | 128/6354 [00:16<12:53,  8.05it/s]

	batch 126	loss: 0.09990286827087402	mean: 0.2535267758850507


  4%|▍         | 255/6354 [00:31<12:30,  8.13it/s]

	batch 253	loss: 0.11066937446594238	mean: 0.24341619378835785


  6%|▌         | 382/6354 [00:47<12:09,  8.18it/s]

	batch 380	loss: 0.00760326674208045	mean: 0.30729546701154253


  8%|▊         | 509/6354 [01:03<11:54,  8.18it/s]

	batch 507	loss: 0.006081649102270603	mean: 0.22718283780997134


 10%|█         | 636/6354 [01:18<11:54,  8.00it/s]

	batch 634	loss: 0.578802764415741	mean: 0.2504379965298934


 12%|█▏        | 763/6354 [01:34<11:30,  8.10it/s]

	batch 761	loss: 0.41664403676986694	mean: 0.2486122079215592


 14%|█▍        | 890/6354 [01:49<11:09,  8.16it/s]

	batch 888	loss: 0.09426850080490112	mean: 0.2663577969532504


 16%|█▌        | 1017/6354 [02:05<10:55,  8.14it/s]

	batch 1015	loss: 0.22888337075710297	mean: 0.25210146974555886


 18%|█▊        | 1144/6354 [02:21<10:45,  8.07it/s]

	batch 1142	loss: 0.13579557836055756	mean: 0.2601000748118239


 20%|██        | 1271/6354 [02:37<10:23,  8.15it/s]

	batch 1269	loss: 0.2938583791255951	mean: 0.2572113723726943


 22%|██▏       | 1398/6354 [02:52<10:03,  8.21it/s]

	batch 1396	loss: 0.008995811454951763	mean: 0.22801279067795166


 24%|██▍       | 1525/6354 [03:08<10:04,  7.98it/s]

	batch 1523	loss: 0.386629194021225	mean: 0.2028767409089102


 26%|██▌       | 1652/6354 [03:24<09:48,  7.98it/s]

	batch 1650	loss: 0.5225715637207031	mean: 0.21065476341732212


 28%|██▊       | 1779/6354 [03:39<09:22,  8.14it/s]

	batch 1777	loss: 0.11672435700893402	mean: 0.24019490431455207


 30%|██▉       | 1906/6354 [03:55<09:08,  8.10it/s]

	batch 1904	loss: 0.08429321646690369	mean: 0.23645819397093215


 32%|███▏      | 2033/6354 [04:11<09:02,  7.96it/s]

	batch 2031	loss: 0.18063214421272278	mean: 0.2632570172376806


 34%|███▍      | 2160/6354 [04:26<08:37,  8.11it/s]

	batch 2158	loss: 0.056448813527822495	mean: 0.21351807512064266


 36%|███▌      | 2287/6354 [04:42<08:18,  8.15it/s]

	batch 2285	loss: 0.12801675498485565	mean: 0.22769676269156727


 38%|███▊      | 2414/6354 [04:58<08:06,  8.10it/s]

	batch 2412	loss: 0.6154978275299072	mean: 0.24618429597100522


 40%|███▉      | 2541/6354 [05:13<08:06,  7.83it/s]

	batch 2539	loss: 0.2377723902463913	mean: 0.21689550313313938


 42%|████▏     | 2668/6354 [05:29<07:36,  8.07it/s]

	batch 2666	loss: 0.43676435947418213	mean: 0.20213077900065915


 44%|████▍     | 2795/6354 [05:45<07:19,  8.09it/s]

	batch 2793	loss: 0.004333846271038055	mean: 0.21147378717959808


 46%|████▌     | 2922/6354 [06:00<07:02,  8.12it/s]

	batch 2920	loss: 0.0383325032889843	mean: 0.20812376631940824


 48%|████▊     | 3049/6354 [06:16<06:46,  8.14it/s]

	batch 3047	loss: 0.17889314889907837	mean: 0.19325475152957758


 50%|████▉     | 3176/6354 [06:32<06:29,  8.16it/s]

	batch 3174	loss: 0.03532997518777847	mean: 0.19127180485199682


 52%|█████▏    | 3303/6354 [06:51<06:28,  7.86it/s]

	batch 3301	loss: 0.00868442002683878	mean: 0.1827440853258155


 54%|█████▍    | 3430/6354 [07:06<06:00,  8.12it/s]

	batch 3428	loss: 0.4171626567840576	mean: 0.17382236332201773


 56%|█████▌    | 3557/6354 [07:22<05:43,  8.14it/s]

	batch 3555	loss: 0.050122324377298355	mean: 0.18724346290045513


 58%|█████▊    | 3684/6354 [07:38<05:26,  8.18it/s]

	batch 3682	loss: 0.07369337975978851	mean: 0.21382545203329834


 60%|█████▉    | 3811/6354 [07:54<05:13,  8.12it/s]

	batch 3809	loss: 0.8838328719139099	mean: 0.22253910412665412


 62%|██████▏   | 3938/6354 [08:09<04:58,  8.09it/s]

	batch 3936	loss: 0.014994729310274124	mean: 0.1726062069793679


 64%|██████▍   | 4065/6354 [08:25<04:41,  8.13it/s]

	batch 4063	loss: 0.46568581461906433	mean: 0.18975540913465455


 66%|██████▌   | 4192/6354 [08:41<04:28,  8.05it/s]

	batch 4190	loss: 0.512062132358551	mean: 0.1894521399787957


 68%|██████▊   | 4319/6354 [08:56<04:10,  8.13it/s]

	batch 4317	loss: 0.0022404957562685013	mean: 0.20421693318956127


 70%|██████▉   | 4446/6354 [09:12<03:53,  8.18it/s]

	batch 4444	loss: 0.010508371517062187	mean: 0.17029601717679432


 72%|███████▏  | 4573/6354 [09:27<03:39,  8.12it/s]

	batch 4571	loss: 0.0090936329215765	mean: 0.25686039412476663


 74%|███████▍  | 4700/6354 [09:43<03:24,  8.07it/s]

	batch 4698	loss: 0.09042289853096008	mean: 0.21543080896949468


 76%|███████▌  | 4827/6354 [09:59<03:07,  8.12it/s]

	batch 4825	loss: 0.049213748425245285	mean: 0.19660453862825952


 78%|███████▊  | 4954/6354 [10:14<02:51,  8.14it/s]

	batch 4952	loss: 0.382320374250412	mean: 0.22020676120287028


 80%|███████▉  | 5081/6354 [10:30<02:38,  8.01it/s]

	batch 5079	loss: 1.0159562826156616	mean: 0.1920445129251489


 82%|████████▏ | 5208/6354 [10:46<02:20,  8.17it/s]

	batch 5206	loss: 0.038934629410505295	mean: 0.18597908100428603


 84%|████████▍ | 5335/6354 [11:02<02:05,  8.14it/s]

	batch 5333	loss: 0.36707866191864014	mean: 0.20603003877391085


 86%|████████▌ | 5462/6354 [11:17<01:49,  8.11it/s]

	batch 5460	loss: 0.013089396990835667	mean: 0.1848646920159694


 88%|████████▊ | 5589/6354 [11:33<01:33,  8.15it/s]

	batch 5587	loss: 0.006270775105804205	mean: 0.19187508482237595


 90%|████████▉ | 5716/6354 [11:49<01:18,  8.15it/s]

	batch 5714	loss: 1.2485097646713257	mean: 0.17787227332069516


 92%|█████████▏| 5843/6354 [12:04<01:03,  8.06it/s]

	batch 5841	loss: 0.022288497537374496	mean: 0.1934718679857631


 94%|█████████▍| 5970/6354 [12:20<00:47,  8.04it/s]

	batch 5968	loss: 0.017645033076405525	mean: 0.22803335148459813


 96%|█████████▌| 6097/6354 [12:35<00:31,  8.07it/s]

	batch 6095	loss: 0.03171816095709801	mean: 0.21432917160105774


 98%|█████████▊| 6224/6354 [12:51<00:16,  7.75it/s]

	batch 6222	loss: 0.017868755385279655	mean: 0.1638234952807529


100%|█████████▉| 6351/6354 [13:07<00:00,  8.21it/s]

	batch 6349	loss: 0.43520626425743103	mean: 0.18746485313108213


100%|██████████| 6354/6354 [13:10<00:00,  8.03it/s]
  2%|▏         | 128/6354 [00:16<13:48,  7.52it/s]

	batch 126	loss: 0.017446987330913544	mean: 0.1555910002445714


  4%|▍         | 255/6354 [00:31<12:35,  8.07it/s]

	batch 253	loss: 0.05093598738312721	mean: 0.14655199310445322


  6%|▌         | 382/6354 [00:47<12:20,  8.07it/s]

	batch 380	loss: 0.004964746534824371	mean: 0.17795002444804917


  8%|▊         | 509/6354 [01:03<11:57,  8.15it/s]

	batch 507	loss: 0.003943790215998888	mean: 0.12539718168609826


 10%|█         | 636/6354 [01:18<11:38,  8.18it/s]

	batch 634	loss: 0.433058500289917	mean: 0.13399801206844675


 12%|█▏        | 763/6354 [01:34<11:24,  8.16it/s]

	batch 761	loss: 0.306403785943985	mean: 0.11759680469967718


 14%|█▍        | 890/6354 [01:50<11:14,  8.11it/s]

	batch 888	loss: 0.012045982293784618	mean: 0.13885933194468972


 16%|█▌        | 1017/6354 [02:05<10:59,  8.09it/s]

	batch 1015	loss: 0.17901356518268585	mean: 0.12461212296180947


 18%|█▊        | 1144/6354 [02:21<10:41,  8.12it/s]

	batch 1142	loss: 0.02944967709481716	mean: 0.1293412920473259


 20%|██        | 1271/6354 [02:37<10:23,  8.15it/s]

	batch 1269	loss: 0.11314880847930908	mean: 0.13993282517242725


 22%|██▏       | 1398/6354 [02:52<10:11,  8.11it/s]

	batch 1396	loss: 0.004363238345831633	mean: 0.11282531234186406


 24%|██▍       | 1525/6354 [03:08<09:58,  8.07it/s]

	batch 1523	loss: 0.17880268394947052	mean: 0.10973043712479334


 26%|██▌       | 1652/6354 [03:24<09:36,  8.15it/s]

	batch 1650	loss: 0.37380316853523254	mean: 0.11141078571461721


 28%|██▊       | 1779/6354 [03:39<09:28,  8.05it/s]

	batch 1777	loss: 0.05980861186981201	mean: 0.13441624769970306


 30%|██▉       | 1906/6354 [03:55<09:10,  8.07it/s]

	batch 1904	loss: 0.015630120411515236	mean: 0.1331665471565587


 32%|███▏      | 2033/6354 [04:11<08:57,  8.03it/s]

	batch 2031	loss: 0.05506349354982376	mean: 0.13877735682167333


 34%|███▍      | 2160/6354 [04:26<08:42,  8.02it/s]

	batch 2158	loss: 0.035583607852458954	mean: 0.1178395485965623


 36%|███▌      | 2287/6354 [04:42<08:19,  8.14it/s]

	batch 2285	loss: 0.08942911028862	mean: 0.127043817371064


 38%|███▊      | 2414/6354 [04:58<08:04,  8.13it/s]

	batch 2412	loss: 0.3955729603767395	mean: 0.13116575078488746


 40%|███▉      | 2541/6354 [05:14<07:51,  8.09it/s]

	batch 2539	loss: 0.1266201138496399	mean: 0.11927172495143884


 42%|████▏     | 2668/6354 [05:29<07:36,  8.08it/s]

	batch 2666	loss: 0.1233818531036377	mean: 0.11438720799718093


 44%|████▍     | 2795/6354 [05:45<07:23,  8.02it/s]

	batch 2793	loss: 0.003214568132534623	mean: 0.1286789606873947


 46%|████▌     | 2922/6354 [06:01<07:05,  8.06it/s]

	batch 2920	loss: 0.01857970841228962	mean: 0.12918960257647852


 48%|████▊     | 3049/6354 [06:16<06:48,  8.09it/s]

	batch 3047	loss: 0.11073777079582214	mean: 0.11278969284232274


 50%|████▉     | 3176/6354 [06:32<06:32,  8.09it/s]

	batch 3174	loss: 0.007878158241510391	mean: 0.11546611052987764


 52%|█████▏    | 3303/6354 [06:51<06:20,  8.02it/s]

	batch 3301	loss: 0.002812072867527604	mean: 0.11112228546641752


 54%|█████▍    | 3430/6354 [07:07<06:11,  7.87it/s]

	batch 3428	loss: 0.33567196130752563	mean: 0.10771503829599842


 56%|█████▌    | 3557/6354 [07:22<05:46,  8.07it/s]

	batch 3555	loss: 0.031801801174879074	mean: 0.11703497807630045


 58%|█████▊    | 3684/6354 [07:38<05:32,  8.03it/s]

	batch 3682	loss: 0.028558844700455666	mean: 0.12473280015415292


 60%|█████▉    | 3811/6354 [07:54<05:11,  8.17it/s]

	batch 3809	loss: 0.6027608513832092	mean: 0.13981706030917992


 62%|██████▏   | 3938/6354 [08:10<04:56,  8.14it/s]

	batch 3936	loss: 0.009574683383107185	mean: 0.09846792116213048


 64%|██████▍   | 4065/6354 [08:25<04:42,  8.11it/s]

	batch 4063	loss: 0.33745551109313965	mean: 0.12438971573563923


 66%|██████▌   | 4192/6354 [08:41<04:25,  8.15it/s]

	batch 4190	loss: 0.41060155630111694	mean: 0.11700251036851601


 68%|██████▊   | 4319/6354 [08:57<04:09,  8.16it/s]

	batch 4317	loss: 0.0013806659262627363	mean: 0.1343055499924299


 70%|██████▉   | 4446/6354 [09:12<03:54,  8.12it/s]

	batch 4444	loss: 0.006936390418559313	mean: 0.10859458729204004


 72%|███████▏  | 4573/6354 [09:28<03:42,  8.01it/s]

	batch 4571	loss: 0.004850903060287237	mean: 0.17255034053805432


 74%|███████▍  | 4700/6354 [09:44<03:27,  7.97it/s]

	batch 4698	loss: 0.06837040930986404	mean: 0.14748093015770977


 76%|███████▌  | 4827/6354 [09:59<03:08,  8.08it/s]

	batch 4825	loss: 0.010362484492361546	mean: 0.12934783785892687


 78%|███████▊  | 4954/6354 [10:15<02:52,  8.11it/s]

	batch 4952	loss: 0.21699805557727814	mean: 0.15709425433168525


 80%|███████▉  | 5081/6354 [10:31<02:35,  8.20it/s]

	batch 5079	loss: 0.9478895664215088	mean: 0.13073589051759701


 82%|████████▏ | 5208/6354 [10:46<02:21,  8.12it/s]

	batch 5206	loss: 0.02726845256984234	mean: 0.13431404816013334


 84%|████████▍ | 5335/6354 [11:02<02:06,  8.06it/s]

	batch 5333	loss: 0.14846083521842957	mean: 0.1456722066865814


 86%|████████▌ | 5462/6354 [11:18<01:53,  7.87it/s]

	batch 5460	loss: 0.00642323587089777	mean: 0.13142642107156038


 88%|████████▊ | 5589/6354 [11:33<01:34,  8.13it/s]

	batch 5587	loss: 0.00408911844715476	mean: 0.1449972476926403


 90%|████████▉ | 5716/6354 [11:49<01:17,  8.18it/s]

	batch 5714	loss: 1.2489014863967896	mean: 0.13649863971195159


 92%|█████████▏| 5843/6354 [12:05<01:02,  8.18it/s]

	batch 5841	loss: 0.019725147634744644	mean: 0.14659682392416792


 94%|█████████▍| 5970/6354 [12:20<00:47,  8.16it/s]

	batch 5968	loss: 0.016608664765954018	mean: 0.1812812038896675


 96%|█████████▌| 6097/6354 [12:36<00:31,  8.12it/s]

	batch 6095	loss: 0.017175965011119843	mean: 0.16537951142618829


 98%|█████████▊| 6224/6354 [12:52<00:15,  8.13it/s]

	batch 6222	loss: 0.016849368810653687	mean: 0.12035265183727088


100%|█████████▉| 6351/6354 [13:07<00:00,  8.00it/s]

	batch 6349	loss: 0.31103816628456116	mean: 0.15689795177157667


100%|██████████| 6354/6354 [13:11<00:00,  8.03it/s]
  2%|▏         | 128/6354 [00:15<12:49,  8.10it/s]

	batch 126	loss: 0.012808226980268955	mean: 0.11494166506972259


  4%|▍         | 255/6354 [00:31<12:29,  8.14it/s]

	batch 253	loss: 0.03565111756324768	mean: 0.10358994549523497


  6%|▌         | 382/6354 [00:47<12:19,  8.08it/s]

	batch 380	loss: 0.005247966852039099	mean: 0.10591963909641568


  8%|▊         | 509/6354 [01:03<12:00,  8.11it/s]

	batch 507	loss: 0.003318089758977294	mean: 0.07160516952642966


 10%|█         | 636/6354 [01:18<11:47,  8.08it/s]

	batch 634	loss: 0.34404340386390686	mean: 0.0685896236948141


 12%|█▏        | 763/6354 [01:34<11:29,  8.11it/s]

	batch 761	loss: 0.185492604970932	mean: 0.05369147724622318


 14%|█▍        | 890/6354 [01:50<11:17,  8.07it/s]

	batch 888	loss: 0.00970473513007164	mean: 0.07010683384481048


 16%|█▌        | 1017/6354 [02:05<10:58,  8.10it/s]

	batch 1015	loss: 0.10918033123016357	mean: 0.06107157335433775


 18%|█▊        | 1144/6354 [02:21<10:38,  8.16it/s]

	batch 1142	loss: 0.013695722445845604	mean: 0.06928268284743146


 20%|██        | 1271/6354 [02:37<10:28,  8.08it/s]

	batch 1269	loss: 0.06353049725294113	mean: 0.06822382345361136


 22%|██▏       | 1398/6354 [02:53<10:14,  8.06it/s]

	batch 1396	loss: 0.004358469508588314	mean: 0.05516730010196453


 24%|██▍       | 1525/6354 [03:08<10:00,  8.04it/s]

	batch 1523	loss: 0.08072049170732498	mean: 0.059175470832911296


 26%|██▌       | 1652/6354 [03:24<09:40,  8.09it/s]

	batch 1650	loss: 0.19649861752986908	mean: 0.06087583209253668


 28%|██▊       | 1779/6354 [03:40<09:24,  8.10it/s]

	batch 1777	loss: 0.03410150483250618	mean: 0.07044130904915354


 30%|██▉       | 1906/6354 [03:55<09:04,  8.16it/s]

	batch 1904	loss: 0.0067740874364972115	mean: 0.07006674842527649


 32%|███▏      | 2033/6354 [04:11<08:51,  8.13it/s]

	batch 2031	loss: 0.03378874436020851	mean: 0.07281110030149777


 34%|███▍      | 2160/6354 [04:27<08:36,  8.11it/s]

	batch 2158	loss: 0.027744410559535027	mean: 0.06500848793464999


 36%|███▌      | 2287/6354 [04:42<08:35,  7.88it/s]

	batch 2285	loss: 0.07497526705265045	mean: 0.06883898610936653


 38%|███▊      | 2414/6354 [04:58<08:09,  8.05it/s]

	batch 2412	loss: 0.20008760690689087	mean: 0.06857231695307639


 40%|███▉      | 2541/6354 [05:14<07:48,  8.14it/s]

	batch 2539	loss: 0.057229556143283844	mean: 0.06440976476406782


 42%|████▏     | 2668/6354 [05:29<07:32,  8.15it/s]

	batch 2666	loss: 0.04394172877073288	mean: 0.06550038693182621


 44%|████▍     | 2795/6354 [05:45<07:18,  8.12it/s]

	batch 2793	loss: 0.0037758450489491224	mean: 0.07643125499951595


 46%|████▌     | 2922/6354 [06:01<07:00,  8.16it/s]

	batch 2920	loss: 0.013420362025499344	mean: 0.07654551707238406


 48%|████▊     | 3049/6354 [06:16<06:42,  8.20it/s]

	batch 3047	loss: 0.057479217648506165	mean: 0.06351529908738664


 50%|████▉     | 3176/6354 [06:32<06:32,  8.10it/s]

	batch 3174	loss: 0.004910241346806288	mean: 0.06920958431272614


 52%|█████▏    | 3303/6354 [06:51<06:28,  7.85it/s]

	batch 3301	loss: 0.002110116183757782	mean: 0.06626397384033633


 54%|█████▍    | 3430/6354 [07:07<05:58,  8.15it/s]

	batch 3428	loss: 0.20083582401275635	mean: 0.0688181861877115


 56%|█████▌    | 3557/6354 [07:23<05:48,  8.04it/s]

	batch 3555	loss: 0.014966574497520924	mean: 0.07197001131432952


 58%|█████▊    | 3684/6354 [07:38<05:29,  8.10it/s]

	batch 3682	loss: 0.02576354146003723	mean: 0.07530091690540307


 60%|█████▉    | 3811/6354 [07:54<05:18,  8.00it/s]

	batch 3809	loss: 0.40471184253692627	mean: 0.0872272333872961


 62%|██████▏   | 3938/6354 [08:10<05:09,  7.80it/s]

	batch 3936	loss: 0.012646346352994442	mean: 0.059619575411660644


 64%|██████▍   | 4065/6354 [08:26<04:43,  8.06it/s]

	batch 4063	loss: 0.21302959322929382	mean: 0.07718165591427888


 66%|██████▌   | 4192/6354 [08:42<04:24,  8.18it/s]

	batch 4190	loss: 0.23980040848255157	mean: 0.07575891037676924


 68%|██████▊   | 4319/6354 [08:57<04:09,  8.15it/s]

	batch 4317	loss: 0.0007679835543967783	mean: 0.09130595504055551


 70%|██████▉   | 4446/6354 [09:13<03:53,  8.16it/s]

	batch 4444	loss: 0.0031602552626281977	mean: 0.0752370756726316


 72%|███████▏  | 4573/6354 [09:29<03:38,  8.14it/s]

	batch 4571	loss: 0.0061734821647405624	mean: 0.11757850424752307


 74%|███████▍  | 4700/6354 [09:44<03:26,  7.99it/s]

	batch 4698	loss: 0.052563995122909546	mean: 0.10007464137995307


 76%|███████▌  | 4827/6354 [10:00<03:13,  7.88it/s]

	batch 4825	loss: 0.007183557841926813	mean: 0.0877063639510577


 78%|███████▊  | 4954/6354 [10:16<02:51,  8.17it/s]

	batch 4952	loss: 0.1372106373310089	mean: 0.11593388972252124


 80%|███████▉  | 5081/6354 [10:32<02:37,  8.07it/s]

	batch 5079	loss: 0.8126410841941833	mean: 0.09464949173896342


 82%|████████▏ | 5208/6354 [10:47<02:21,  8.09it/s]

	batch 5206	loss: 0.016151590272784233	mean: 0.1024579423178156


 84%|████████▍ | 5335/6354 [11:03<02:04,  8.15it/s]

	batch 5333	loss: 0.0678819864988327	mean: 0.11057709535152109


 86%|████████▌ | 5462/6354 [11:19<01:49,  8.16it/s]

	batch 5460	loss: 0.004362063482403755	mean: 0.09971778611854218


 88%|████████▊ | 5589/6354 [11:34<01:35,  8.03it/s]

	batch 5587	loss: 0.0027560449671000242	mean: 0.11908913147074306


 90%|████████▉ | 5716/6354 [11:50<01:20,  7.90it/s]

	batch 5714	loss: 1.1492063999176025	mean: 0.10987313517094792


 92%|█████████▏| 5843/6354 [12:06<01:03,  8.08it/s]

	batch 5841	loss: 0.010394172742962837	mean: 0.11872000493270558


 94%|█████████▍| 5970/6354 [12:21<00:47,  8.10it/s]

	batch 5968	loss: 0.010570787824690342	mean: 0.1550660570301977


 96%|█████████▌| 6097/6354 [12:37<00:31,  8.10it/s]

	batch 6095	loss: 0.01418267097324133	mean: 0.1401692632756262


 98%|█████████▊| 6224/6354 [12:53<00:15,  8.15it/s]

	batch 6222	loss: 0.02029329724609852	mean: 0.09868330685522657


100%|█████████▉| 6351/6354 [13:09<00:00,  8.09it/s]

	batch 6349	loss: 0.2271687090396881	mean: 0.14466812614123264


100%|██████████| 6354/6354 [13:15<00:00,  7.99it/s]


In [None]:
# text = "Checkpoint, optimizer, scheduler, learning_rate, weight_decay, top5_loss\n"
# with open('./results/experiments.csv', 'w') as f:
#     f.write(text)

## Validation

In [None]:
def generate_seq(model, tokenizer, input):
    """Generate a sequence with ``model`` and decode it into plain text.

    Args:
        model: Object exposing ``generate(**kwargs)``.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        input: Mapping that is unpacked as keyword arguments into
            ``model.generate``.

    Returns:
        The decoded text with special tokens stripped.
    """
    output_ids = model.generate(**input)
    # squeeze(0) removes the leading axis before decoding
    # (assumes a single generated sequence per call -- TODO confirm).
    return tokenizer.decode(output_ids.squeeze(0), skip_special_tokens=True)

def generate_input_target(model, tokenizer, input, label):
    """Decode the source, the model's generation and the reference into text.

    Args:
        model: Model handed to ``generate_seq``.
        tokenizer: Tokenizer used for every decoding step.
        input: Mapping holding at least ``'input_ids'``; it is also passed
            unchanged to ``generate_seq``.
        label: Reference token ids; ``squeeze(0)`` is applied before decoding.

    Returns:
        dict with the keys ``'input_text'``, ``'generated_text'`` and
        ``'target_text'``.
    """
    def _to_text(token_ids):
        # Same decoding settings for the source and the reference.
        return tokenizer.decode(token_ids.squeeze(0), skip_special_tokens=True)

    return {
        'input_text': _to_text(input['input_ids']),
        'generated_text': generate_seq(model, tokenizer, input),
        'target_text': _to_text(label),
    }

def generate_from_data(model, tokenizer, data):
    """Run ``generate_input_target`` on a single dataset example.

    Args:
        model: Model used for generation.
        tokenizer: Tokenizer used for decoding.
        data: Mapping with the keys ``'labels'``, ``'input_ids'`` and
            ``'attention_mask'``.

    Returns:
        The dict produced by ``generate_input_target``.
    """
    target = data['labels']
    # Only these two fields are forwarded as model inputs.
    model_inputs = {key: data[key] for key in ('input_ids', 'attention_mask')}

    return generate_input_target(model, tokenizer, model_inputs, target)

def eval(model, tokenizer, input_seq, label, metric: Callable, options = dict()):
    """Score one generated text against its reference with ``metric``.

    Note: this definition shadows Python's built-in ``eval`` within the
    notebook namespace.

    Args:
        model: Model used for generation.
        tokenizer: Tokenizer used for decoding.
        input_seq: Model inputs forwarded to ``generate_input_target``.
        label: Reference token ids forwarded to ``generate_input_target``.
        metric: Callable invoked as
            ``metric(generated_text, target_text, **options)``.
        options: Extra keyword arguments for ``metric``.

    Returns:
        Whatever ``metric`` returns.
    """
    texts = generate_input_target(model, tokenizer, input_seq, label)

    return metric(texts['generated_text'], texts['target_text'], **options)

def eval_from_data(model, tokenizer, dataset, metric: Callable, options = dict()):
    """Apply ``eval`` to every example of ``dataset``.

    Args:
        model: Model used for generation.
        tokenizer: Tokenizer used for decoding.
        dataset: Iterable of mappings with the keys ``'labels'``,
            ``'input_ids'`` and ``'attention_mask'``.
        metric: Callable forwarded to ``eval``.
        options: Extra keyword arguments forwarded to ``eval``.

    Returns:
        ``pd.Series`` holding one score per example, in iteration order.
    """
    def _score_example(example):
        reference = example['labels']
        model_inputs = {
            'input_ids': example['input_ids'],
            'attention_mask': example['attention_mask'],
        }
        # ``eval`` here is the scoring helper defined in this notebook.
        return eval(model, tokenizer, model_inputs, reference, metric, options)

    return pd.Series([_score_example(example) for example in dataset])

def eval_bleu(model, tokenizer, tokenized_testset):
    """Compute a sentence-level BLEU score for every example of a test set.

    Each example is run through ``generate_from_data`` and the generated
    text is scored against the target text with ``sentence_bleu`` using
    ``SmoothingFunction().method1``.

    Args:
        model: Model used for generation.
        tokenizer: Tokenizer used for decoding.
        tokenized_testset: Iterable of examples accepted by
            ``generate_from_data``.

    Returns:
        ``pd.DataFrame`` with a single ``'BLEU'`` column. Examples for which
        scoring raised ``ValueError`` are left out, so the frame can have
        fewer rows than the test set; the number left out is printed.
    """
    # The smoothing function does not depend on the example: build it once.
    smoothing = SmoothingFunction().method1
    bleu_score_lt = []
    skipped = 0

    for example in tqdm(tokenized_testset):
        output = generate_from_data(model, tokenizer, example)
        try:
            # NOTE(review): both texts are passed as plain strings rather than
            # token lists, so n-grams are presumably formed over characters --
            # confirm that character-level BLEU is intended.
            bleu_score = sentence_bleu([output['target_text']],
                                       output['generated_text'],
                                       smoothing_function=smoothing
            )
        except ValueError:
            # Best effort: keep evaluating, but remember that one was dropped.
            skipped += 1
            continue
        bleu_score_lt.append(bleu_score)

    if skipped:
        # Dropped examples change what the summary statistics cover.
        print(f'eval_bleu: skipped {skipped} example(s) that raised ValueError')

    return pd.DataFrame({'BLEU': bleu_score_lt})

def eval_rogue(model, tokenizer, tokenized_testset):
    """Compute ROUGE-2 precision, recall and F1 for every example of a test set.

    Each example is run through ``generate_from_data`` and the generated
    text is scored against the target text with ``Rouge().get_scores``.
    Only the ``'rouge-2'`` entry of the result is recorded.

    Args:
        model: Model used for generation.
        tokenizer: Tokenizer used for decoding.
        tokenized_testset: Iterable of examples accepted by
            ``generate_from_data``.

    Returns:
        ``pd.DataFrame`` with the columns ``'Precision'``, ``'Recall'`` and
        ``'F1'``. Examples for which scoring raised ``ValueError`` are left
        out, so the frame can have fewer rows than the test set; the number
        left out is printed.
    """
    rouge = Rouge()
    rouge_score_dict = {'Precision': [], 'Recall': [], 'F1': []}
    skipped = 0

    for example in tqdm(tokenized_testset):
        output = generate_from_data(model, tokenizer, example)
        try:
            rouge_score = rouge.get_scores(output['generated_text'],
                                           output['target_text']
            )
        except ValueError:
            # Presumably raised for an empty generated/target text -- confirm
            # against the rouge package. Keep going, but count the drop.
            skipped += 1
            continue

        rouge_2 = rouge_score[0]['rouge-2']
        rouge_score_dict['Precision'].append(rouge_2['p'])
        rouge_score_dict['Recall'].append(rouge_2['r'])
        rouge_score_dict['F1'].append(rouge_2['f'])

    if skipped:
        # Dropped examples change what the summary statistics cover.
        print(f'eval_rogue: skipped {skipped} example(s) that raised ValueError')

    return pd.DataFrame(rouge_score_dict)

In [None]:
bleu_df = eval_bleu(model, tokenizer, tokenized_dataset['test'])
# Name the file after the last '/'-separated component of the checkpoint.
# rsplit keeps the whole string when it contains no '/', whereas slicing from
# rfind('/') (-1 in that case) kept only its last character, and otherwise
# carried the '/' itself into the path ("./results//...").
bleu_df.to_csv(f"./results/{checkpoint.rsplit('/', 1)[-1]}bleu.csv", index=False)
bleu_df.describe()

In [None]:
rouge_df = eval_rogue(model, tokenizer, tokenized_dataset['test'])
# Name the file after the last '/'-separated component of the checkpoint.
# rsplit keeps the whole string when it contains no '/', whereas slicing from
# rfind('/') (-1 in that case) kept only its last character, and otherwise
# carried the '/' itself into the path ("./results//...").
rouge_df.to_csv(f"./results/{checkpoint.rsplit('/', 1)[-1]}rouge.csv", index=False)
rouge_df.describe()