<a href="https://colab.research.google.com/github/SOL1archive/ClauseSummary/blob/main/main_model_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
try:
  import google.colab
  IN_COLAB = True
except:
  IN_COLAB = False

'Process in Colab' if IN_COLAB else 'Process in Local'

'Process in Colab'

In [2]:
if IN_COLAB:
    !pip install transformers
    !pip install datasets
    !pip install evaluate
    !pip install rouge_score
    !pip install torchmetrics
    !pip install rouge
    !pip install --upgrade accelerate

Collecting transformers
  Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m101.4 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m27.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m84.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m67.1 MB/s[0m eta [36m0:00:

In [3]:
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
# 깃허브에서는 빼야됨
%cd drive/MyDrive/projects/ClauseSummary

/content/drive/MyDrive/projects/ClauseSummary


In [4]:
import warnings
warnings.filterwarnings('ignore')
import datetime
import os
import gc
from pprint import pprint
from tqdm import tqdm

import numpy as np
import pandas as pd

import tensorboard
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR, CyclicLR
import torchmetrics

from datasets import load_dataset, load_from_disk, concatenate_datasets, DatasetDict, Dataset
from transformers import get_linear_schedule_with_warmup
from transformers import DataCollatorForSeq2Seq
from transformers import BartConfig, T5Config, LongformerConfig
from transformers import AutoTokenizer, LongformerTokenizer, AutoModelForSeq2SeqLM, LongT5ForConditionalGeneration
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
import evaluate
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge import Rouge

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns

In [5]:
class TokenizeMapWrapper:
    def __init__(self, tokenizer, feature, option=None):
        if option is None:
            option = {
                'max_length': 512,
                'truncation': True,
                'padding': 'max_length',
            }

        self.option = option
        self.feature = feature
        self.tokenizer = tokenizer

    def __call__(self, row):
        return self.tokenizer(row[self.feature], **self.option)

    def __repr__(self):
        return f'{self.__class__.__name__}(tokenizer={self.tokenizer})'

class Seq2SeqTokenizeMapWrapper(TokenizeMapWrapper):
    def __init__(self, tokenizer, feature, target, option=None):
        super().__init__(tokenizer, feature, option)
        self.target = target

    def seq2seq_tokenize(self, row):
        form_embeddings = self.tokenizer(row[self.feature], **self.option)
        with self.tokenizer.as_target_tokenizer():
            correct_form_embeddings = self.tokenizer(row[self.target], **self.option)

        return {
            'input_ids': form_embeddings['input_ids'],
            'attention_mask': form_embeddings['attention_mask'],
            'labels': correct_form_embeddings['input_ids'],
        }

    def __call__(self, row):
        return self.seq2seq_tokenize(row)

## Setting

- 학습 환경에 맞게 조정하기 (특히 **경로 설정**)

In [6]:
MANUAL_TRAINING = True
MANUAL_VALIDATION = True
NUM_EPOCHS = 1
MID_CHECKPOINT_NUM = 2
MID_PROCESS_PRINT_NUM = 50

In [7]:
t5_large_summary_checkpoint = 'lcw99/t5-large-korean-text-summary'
t5_base_summary_checkpoint = 'eenzeenee/t5-base-korean-summarization'
kobart_summary_checkpoint = 'gogamza/kobart-summarization'
kolongformer = "psyche/kolongformer-4096"
longt5_checkpoint = 'KETI-AIR-Downstream/long-ke-t5-base-summarization'
checkpoint = longt5_checkpoint
print(f'Using Checkpoint: {checkpoint}')

Using Checkpoint: KETI-AIR-Downstream/long-ke-t5-base-summarization


In [None]:
original_dataset_path = './data/dataset-term-summary.json'
tokenized_dataset_path = f'./data/{checkpoint.replace("/", "-")}-tokenized-dataset'

In [None]:
SAVE_STR = datetime.datetime.now().strftime('%y%m%d-%H:%M')
model_save_path = f"./model/{SAVE_STR}"

## Load Tokenizer & Model Checkpoint

In [None]:
if 'bart' in checkpoint.lower():
    config = BartConfig.from_pretrained(checkpoint)
    #config['vocab'] = 30000
elif "t5" in checkpoint.lower():
    config = T5Config.from_pretrained(checkpoint)
elif "longformer" in checkpoint.lower():
    config = LongformerConfig.from_pretrained(checkpoint)

In [9]:
tokenizer = AutoTokenizer.from_pretrained(checkpoint,
                                          max_length=4096,
                                          truncation=False,
                                          padding='max_length',
                                          #vocab=config.vocab_size
                                          )
#tokenizer = LongformerTokenizer(vocab_file, merges_file, errors='replace', bos_token='<s>', eos_token='</s>', sep_token='</s>', cls_token='<s>', unk_token='<unk>', pad_token='<pad>', mask_token='<mask>', add_prefix_space=False, **kwargs)


model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

Downloading (…)lve/main/config.json:   0%|          | 0.00/924 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.19G [00:00<?, ?B/s]

In [None]:
if len(tokenizer) != model.config.vocab_size:
    raise RuntimeError(f'Tokenizer vocab size and model vocab size do not match(Tokenizer:{len(tokenizer)} Model: {model.config.vocab_size}). Which would lead to further error in training.')

## Load Dataset

In [None]:
if not os.path.exists(tokenized_dataset_path):
    dataset = Dataset.from_pandas(pd.read_json(original_dataset_path, encoding='utf-8')[['text', 'summary']])
    tokenizer_wrapper = Seq2SeqTokenizeMapWrapper(tokenizer, 'text', 'summary')

    tokenized_dataset = (dataset
                         .map(tokenizer_wrapper,
                              batched=True,
                              batch_size=128,
                              num_proc=10
                              )
                         .remove_columns(['text', 'summary'])
                         )

    tokenized_dataset_dict = tokenized_dataset.train_test_split(test_size=0.2, shuffle=True)
    tokenized_dataset_dict.save_to_disk(tokenized_dataset_path)
else:
    tokenized_dataset_dict = load_from_disk(tokenized_dataset_path)

## Training

In [None]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors='pt')

In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    num_train_epochs=NUM_EPOCHS,
    weight_decay=0.01,
    report_to="tensorboard",
    push_to_hub=False,
)

In [None]:
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_dataset_dict['train'],
    data_collator=data_collator,
)

In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:

# model.train()

# if not MANUAL_TRAINING:
#     trainer.train()
# else:
#     total_loss = []
#     epoch_loss = []
#     batch_loss = []
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     trainset = tokenized_dataset_dict['train'].with_format('torch', device=device)
#     dataloader = DataLoader(trainset, batch_size=1, shuffle=False) # TODO: Batch size 조절

#     # TODO: Write a code for **Hyperparameter Tuning**
#     optimizer = AdamW(model.parameters(), lr = training_args.learning_rate, weight_decay = training_args.weight_decay)
#     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=NUM_EPOCHS * len(dataloader))

#     for epoch in range(NUM_EPOCHS):
#         total_steps = len(dataloader)
#         save_divisor = total_steps // MID_CHECKPOINT_NUM
#         print_divisor = total_steps // MID_PROCESS_PRINT_NUM
#         for i, batch in enumerate(tqdm(dataloader)):
#             X = {
#                     'input_ids': batch['input_ids'],
#                     'attention_mask': batch['attention_mask'],
#                 }
#             y = batch['labels']

#             outputs = model(**X, labels=y)
#             loss = outputs.loss
#             loss.backward()
#             optimizer.step()
#             optimizer.zero_grad()
#             scheduler.step()

#             batch_loss.append(loss.item())
#             if i % print_divisor == print_divisor - 1:
#                 epoch_loss += batch_loss
#                 batch_loss_series = pd.Series(batch_loss)
#                 print(f'\tbatch {i}\tloss: {loss.item()}\tmean: {batch_loss_series.mean()}')
#                 batch_loss = []

#             if i % save_divisor == save_divisor - 1:
#                 trainer.create_model_card(
#                     language='Korean',
#                     tags='Grammar',
#                     finetuned_from=checkpoint
#                 )
#                 trainer.save_model(model_save_path + f'-epoch-{epoch + 1}' + '-batch-{i + 1}')

#         total_loss += epoch_loss
#         batch_loss_series = pd.Series(epoch_loss)
#         epoch_loss = []
#         print(f'epoch {epoch + 1} loss: {loss.item()} mean: {batch_loss_series.mean()}')

In [None]:
# total_loss_series = pd.Series(total_loss)
# total_loss_series.plot.line()

In [None]:
print(checkpoint)

eenzeenee/t5-base-korean-summarization


In [None]:
def generate_seq(model, tokenizer, input):
    generated_ids = model.generate(**input)
    generated_text = tokenizer.decode(generated_ids.squeeze(0), skip_special_tokens=True)

    return generated_text

def generate_input_target(model, tokenizer, input, label):
    input_text = tokenizer.decode(input['input_ids'].squeeze(0), skip_special_tokens=True)
    generated_text = generate_seq(model, tokenizer, input)
    target_text = tokenizer.decode(label.squeeze(0), skip_special_tokens=True)

    return {
        'input_text': input_text,
        'generated_text': generated_text,
        'target_text': target_text
    }

def generate_from_data(model, tokenizer, data):
    label = data['labels']
    input_data = dict()
    input_data['input_ids'] = data['input_ids']
    input_data['attention_mask'] = data['attention_mask']

    return generate_input_target(model, tokenizer, input_data, label)

def eval_bleu(model, tokenizer, tokenized_testset):
    bleu_score_lt = []
    for example in tqdm(tokenized_testset):
        data = dict()
        for key in example:
            data[key] = example[key].unsqueeze(0)

        output = generate_from_data(model, tokenizer, data)
        try:
            bleu_score = sentence_bleu([output['target_text']],
                                       output['generated_text'],
                                       smoothing_function=SmoothingFunction().method1
            )
        except ValueError:
            continue
        bleu_score_lt.append(bleu_score)

    return pd.DataFrame({'BLEU': bleu_score_lt})

def eval_rogue(model, tokenizer, tokenized_testset):
    rouge = Rouge()
    rouge_score_dict = dict()
    rouge_score_dict['Precision'] = []
    rouge_score_dict['Recall'] = []
    rouge_score_dict['F1'] = []

    for example in tqdm(tokenized_testset):
        data = dict()
        for key in example:
            data[key] = example[key].unsqueeze(0)
        output = generate_from_data(model, tokenizer, data)
        try:
            rouge_score = rouge.get_scores(output['generated_text'],
                                           output['target_text']
            )
        except ValueError:
            continue
        rouge_score_precision = rouge_score[0]['rouge-2']['p']
        rouge_score_recall = rouge_score[0]['rouge-2']['r']
        rouge_score_f = rouge_score[0]['rouge-2']['f']

        rouge_score_dict['Precision'].append(rouge_score_precision)
        rouge_score_dict['Recall'].append(rouge_score_recall)
        rouge_score_dict['F1'].append(rouge_score_f)

    return pd.DataFrame(rouge_score_dict)

In [None]:
## finding the best parameters
def mean(A):
    sum = 0
    for a in A:
        sum += a
    return sum / len(A)

# 개인적으로는 L2 Norm 계수보다는 Learning rate나 learning rate scheduling 최적화에 집중하는 것도 좋을듯??

learning_rate = 2e-5
decay = 0.05

gc.collect()
torch.cuda.empty_cache()

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy = "epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=NUM_EPOCHS,
    weight_decay=decay,
    report_to="tensorboard",
    push_to_hub=False,
)

total_loss = []
epoch_loss = []
batch_loss = []

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainset = tokenized_dataset_dict['train'].with_format('torch', device=device)
testset = tokenized_dataset_dict['test'].with_format('torch', device=device)
dataloader = DataLoader(trainset, batch_size=4, shuffle=False) # TODO: Batch size 조절

# TODO: Write a code for **Hyperparameter Tuning**
optimizer = AdamW(model.parameters(), lr = training_args.learning_rate, weight_decay = training_args.weight_decay)
optimizer_name = "AdamW"
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=NUM_EPOCHS * len(dataloader))
scheduler_name = "linear_schedule"

for epoch in range(NUM_EPOCHS):
    total_steps = len(dataloader)
    save_divisor = total_steps // MID_CHECKPOINT_NUM
    print_divisor = total_steps // MID_PROCESS_PRINT_NUM
    with tqdm(dataloader) as tqdm_bar:
        for i, batch in enumerate(tqdm_bar):
            X = {
                    'input_ids': batch['input_ids'],
                    'attention_mask': batch['attention_mask'],
                }
            y = batch['labels']

            outputs = model(**X, labels=y)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            batch_loss.append(loss.item())
            if i % print_divisor == print_divisor - 1:
                epoch_loss += batch_loss
                batch_loss_series = pd.Series(batch_loss)
                tqdm_bar.set_description(f'\tbatch {i}\tloss: {loss.item()}\tmean: {batch_loss_series.mean()}')
                batch_loss = []

            if i % save_divisor == save_divisor - 1:
                trainer.create_model_card(
                    language='Korean',
                    tags='Grammar',
                    finetuned_from=checkpoint
                )
                trainer.save_model(model_save_path + checkpoint + f'-epoch-{epoch + 1}' + f'-batch-{i + 1}')

            total_loss += epoch_loss
            batch_loss_series = pd.Series(epoch_loss)
            epoch_loss = []

    # for recording
    total_loss.sort()
    top5_loss = mean(total_loss[:5])

    bleu_metric = eval_bleu(model, tokenizer, testset)
    rouge_metric = eval_rogue(model, tokenizer, testset)

    Xtext = "%s, %s, %s, %f, %f, %f, %f, %f, %f, %f\n"%(checkpoint, optimizer_name, scheduler_name, training_args.learning_rate, training_args.weight_decay, top5_loss, mean(bleu_metric["BLEU"]), mean(rouge_metric["Precision"]), mean(rouge_metric["Recall"]), mean(rouge_metric["F1"]))
    with open('./results/experiments.csv', 'a') as f:
        f.write(text)
    trainer.create_model_card(
        language='Korean',
        finetuned_from=checkpoint
    )
    trainer.save_model(model_save_path + f'lr={learning_rate}-decay={weight_decays}')

  2%|▏         | 31/1589 [00:11<07:34,  3.43it/s]

	batch 30	loss: 16.92367935180664	mean: 16.911983613044985


  4%|▍         | 62/1589 [00:20<07:24,  3.43it/s]

	batch 61	loss: 14.953643798828125	mean: 15.960837118087277


  6%|▌         | 93/1589 [00:30<07:16,  3.42it/s]

	batch 92	loss: 12.003929138183594	mean: 13.780669150813933


  8%|▊         | 124/1589 [00:39<07:07,  3.43it/s]

	batch 123	loss: 9.859634399414062	mean: 11.247279382521107


 10%|▉         | 155/1589 [00:48<06:56,  3.44it/s]

	batch 154	loss: 6.663162708282471	mean: 9.100438579436272


 12%|█▏        | 186/1589 [00:57<06:48,  3.44it/s]

	batch 185	loss: 4.104167938232422	mean: 5.85140492839198


 14%|█▎        | 217/1589 [01:06<06:38,  3.44it/s]

	batch 216	loss: 3.343761444091797	mean: 3.9187292668127243


 16%|█▌        | 248/1589 [01:15<06:29,  3.44it/s]

	batch 247	loss: 2.902214527130127	mean: 3.271225767750894


 18%|█▊        | 279/1589 [01:24<06:21,  3.43it/s]

	batch 278	loss: 3.134453535079956	mean: 3.3182084560394287


 20%|█▉        | 310/1589 [01:33<06:11,  3.44it/s]

	batch 309	loss: 2.682809829711914	mean: 2.876289621476204


 21%|██▏       | 341/1589 [01:42<06:02,  3.44it/s]

	batch 340	loss: 2.4664788246154785	mean: 2.580979577956661


 23%|██▎       | 372/1589 [01:51<05:53,  3.44it/s]

	batch 371	loss: 2.4505326747894287	mean: 2.547951582939394


 25%|██▌       | 403/1589 [02:00<05:45,  3.44it/s]

	batch 402	loss: 2.380064010620117	mean: 2.3612971613484044


 27%|██▋       | 434/1589 [02:09<05:35,  3.45it/s]

	batch 433	loss: 2.182300567626953	mean: 1.9960220782987532


 29%|██▉       | 465/1589 [02:18<05:27,  3.43it/s]

	batch 464	loss: 1.5539177656173706	mean: 1.7635007173784318


 31%|███       | 496/1589 [02:27<05:16,  3.45it/s]

	batch 495	loss: 1.2173610925674438	mean: 1.613778425801185


 33%|███▎      | 527/1589 [02:36<05:09,  3.44it/s]

	batch 526	loss: 1.3508192300796509	mean: 1.4424317382997083


 35%|███▌      | 558/1589 [02:45<04:59,  3.44it/s]

	batch 557	loss: 1.401520848274231	mean: 1.226391000132407


 37%|███▋      | 589/1589 [02:54<04:51,  3.43it/s]

	batch 588	loss: 1.103131651878357	mean: 1.2682881989786703


 39%|███▉      | 620/1589 [03:03<04:41,  3.45it/s]

	batch 619	loss: 1.1610910892486572	mean: 1.0952721449636644


 41%|████      | 651/1589 [03:12<04:32,  3.44it/s]

	batch 650	loss: 1.0382884740829468	mean: 1.110636863016313


 43%|████▎     | 682/1589 [03:21<04:23,  3.44it/s]

	batch 681	loss: 0.9867775440216064	mean: 0.9565802216529846


 45%|████▍     | 713/1589 [03:30<04:14,  3.44it/s]

	batch 712	loss: 0.9431935548782349	mean: 0.9078159082320428


 47%|████▋     | 744/1589 [03:39<04:06,  3.43it/s]

	batch 743	loss: 0.7224128246307373	mean: 0.8712287852841039


 49%|████▉     | 775/1589 [03:48<03:57,  3.43it/s]

	batch 774	loss: 0.782162070274353	mean: 0.7971179908321749


 51%|█████     | 806/1589 [04:00<03:57,  3.29it/s]

	batch 805	loss: 0.7728719115257263	mean: 0.7768261394193096


 53%|█████▎    | 837/1589 [04:09<03:39,  3.42it/s]

	batch 836	loss: 0.7901128530502319	mean: 0.7264368197610301


 55%|█████▍    | 868/1589 [04:18<03:29,  3.43it/s]

	batch 867	loss: 0.6856735348701477	mean: 0.7418919023006193


 57%|█████▋    | 899/1589 [04:27<03:20,  3.45it/s]

	batch 898	loss: 0.7568831443786621	mean: 0.7367464621220866


 59%|█████▊    | 930/1589 [04:36<03:12,  3.43it/s]

	batch 929	loss: 0.7210502624511719	mean: 0.7613629121934214


 60%|██████    | 961/1589 [04:45<03:02,  3.44it/s]

	batch 960	loss: 0.566251277923584	mean: 0.7412481807893322


 62%|██████▏   | 992/1589 [04:54<02:53,  3.44it/s]

	batch 991	loss: 0.7340598106384277	mean: 0.6679426075950745


 64%|██████▍   | 1023/1589 [05:03<02:44,  3.44it/s]

	batch 1022	loss: 0.5339837074279785	mean: 0.7157699523433563


 66%|██████▋   | 1054/1589 [05:12<02:35,  3.44it/s]

	batch 1053	loss: 0.7593030333518982	mean: 0.7027488110526916


 68%|██████▊   | 1085/1589 [05:21<02:26,  3.44it/s]

	batch 1084	loss: 0.8361758589744568	mean: 0.6688715027224633


 70%|███████   | 1116/1589 [05:30<02:17,  3.43it/s]

	batch 1115	loss: 0.26518434286117554	mean: 0.6669100119221595


 72%|███████▏  | 1147/1589 [05:39<02:08,  3.44it/s]

	batch 1146	loss: 0.694835901260376	mean: 0.727138138586475


 74%|███████▍  | 1178/1589 [05:48<01:59,  3.44it/s]

	batch 1177	loss: 0.7840239405632019	mean: 0.6698144981938023


 76%|███████▌  | 1209/1589 [05:57<01:50,  3.44it/s]

	batch 1208	loss: 1.0093846321105957	mean: 0.6597334305124898


 78%|███████▊  | 1240/1589 [06:06<01:41,  3.44it/s]

	batch 1239	loss: 0.4905136525630951	mean: 0.6543002176669336


 80%|███████▉  | 1271/1589 [06:15<01:32,  3.44it/s]

	batch 1270	loss: 0.24555236101150513	mean: 0.6630955021227559


 82%|████████▏ | 1302/1589 [06:24<01:23,  3.45it/s]

	batch 1301	loss: 0.6724848747253418	mean: 0.6092787226361613


 84%|████████▍ | 1333/1589 [06:33<01:14,  3.44it/s]

	batch 1332	loss: 0.5097284317016602	mean: 0.6307515559657928


 86%|████████▌ | 1364/1589 [06:42<01:05,  3.44it/s]

	batch 1363	loss: 0.769942581653595	mean: 0.6829516185868171


 88%|████████▊ | 1395/1589 [06:51<00:56,  3.44it/s]

	batch 1394	loss: 0.4865851104259491	mean: 0.6082874661491763


 90%|████████▉ | 1426/1589 [07:00<00:47,  3.45it/s]

	batch 1425	loss: 0.5464605689048767	mean: 0.6012102956733396


 92%|█████████▏| 1457/1589 [07:09<00:38,  3.44it/s]

	batch 1456	loss: 0.4918573498725891	mean: 0.614593748123415


 94%|█████████▎| 1488/1589 [07:18<00:29,  3.43it/s]

	batch 1487	loss: 0.6434123516082764	mean: 0.642075894340392


 96%|█████████▌| 1519/1589 [07:27<00:20,  3.44it/s]

	batch 1518	loss: 0.7224025726318359	mean: 0.6112200444744479


 98%|█████████▊| 1550/1589 [07:36<00:11,  3.44it/s]

	batch 1549	loss: 0.6401070952415466	mean: 0.6080994932882248


 99%|█████████▉| 1581/1589 [07:45<00:02,  3.44it/s]

	batch 1580	loss: 0.5171568393707275	mean: 0.5518172368887932


100%|██████████| 1589/1589 [07:50<00:00,  3.38it/s]
  2%|▏         | 31/1589 [00:09<07:35,  3.42it/s]

	batch 30	loss: 0.32601702213287354	mean: 0.5339850596128366


  4%|▍         | 62/1589 [00:18<07:25,  3.43it/s]

	batch 61	loss: 0.4555245041847229	mean: 0.5339438713365986


  6%|▌         | 93/1589 [00:27<07:14,  3.45it/s]

	batch 92	loss: 0.5091833472251892	mean: 0.5970194224388369


  8%|▊         | 124/1589 [00:36<07:06,  3.44it/s]

	batch 123	loss: 0.7219702005386353	mean: 0.504913030372512


 10%|▉         | 155/1589 [00:45<06:56,  3.45it/s]

	batch 154	loss: 0.5782828330993652	mean: 0.5203181245634633


 12%|█▏        | 186/1589 [00:54<06:47,  3.44it/s]

	batch 185	loss: 0.2861728072166443	mean: 0.5337592488335025


 14%|█▎        | 217/1589 [01:03<06:38,  3.44it/s]

	batch 216	loss: 0.6086846590042114	mean: 0.5641077199289876


 16%|█▌        | 248/1589 [01:12<06:29,  3.44it/s]

	batch 247	loss: 0.3990538716316223	mean: 0.5328845833578417


 18%|█▊        | 279/1589 [01:21<06:20,  3.44it/s]

	batch 278	loss: 0.8436741828918457	mean: 0.5494557758492808


 20%|█▉        | 310/1589 [01:30<06:11,  3.45it/s]

	batch 309	loss: 0.4139828681945801	mean: 0.5131621048335107


 21%|██▏       | 341/1589 [01:39<06:02,  3.44it/s]

	batch 340	loss: 0.5421274900436401	mean: 0.48748728969404775


 23%|██▎       | 372/1589 [01:48<05:53,  3.44it/s]

	batch 371	loss: 0.6711204051971436	mean: 0.5251594060851682


 25%|██▌       | 403/1589 [01:57<05:44,  3.44it/s]

	batch 402	loss: 0.33340102434158325	mean: 0.48046646964165474


 27%|██▋       | 434/1589 [02:06<05:35,  3.44it/s]

	batch 433	loss: 0.8019870519638062	mean: 0.5249789559072063


 29%|██▉       | 465/1589 [02:15<05:27,  3.43it/s]

	batch 464	loss: 0.3940044939517975	mean: 0.48272118020442223


 31%|███       | 496/1589 [02:24<05:17,  3.44it/s]

	batch 495	loss: 0.3786020874977112	mean: 0.5374620768331713


 33%|███▎      | 527/1589 [02:33<05:08,  3.44it/s]

	batch 526	loss: 0.4361882209777832	mean: 0.5379798921846575


 35%|███▌      | 558/1589 [02:42<05:01,  3.42it/s]

	batch 557	loss: 0.610393226146698	mean: 0.47312195454874345


 37%|███▋      | 589/1589 [02:51<04:51,  3.43it/s]

	batch 588	loss: 0.437760591506958	mean: 0.5209087479499078


 39%|███▉      | 620/1589 [03:00<04:41,  3.44it/s]

	batch 619	loss: 0.5455291271209717	mean: 0.5081312118038055


 41%|████      | 651/1589 [03:09<04:33,  3.43it/s]

	batch 650	loss: 0.6212541460990906	mean: 0.5295386612415314


 43%|████▎     | 682/1589 [03:18<04:23,  3.44it/s]

	batch 681	loss: 0.5216072201728821	mean: 0.47881382920088306


 45%|████▍     | 713/1589 [03:27<04:14,  3.44it/s]

	batch 712	loss: 0.4994892179965973	mean: 0.4610320719019059


 47%|████▋     | 744/1589 [03:36<04:05,  3.44it/s]

	batch 743	loss: 0.43423739075660706	mean: 0.4486990470078684


 49%|████▉     | 775/1589 [03:45<04:07,  3.28it/s]

	batch 774	loss: 0.4414691627025604	mean: 0.4326236839255979


 51%|█████     | 806/1589 [03:57<03:57,  3.30it/s]

	batch 805	loss: 0.40958264470100403	mean: 0.43413193956498175


 53%|█████▎    | 837/1589 [04:06<03:39,  3.43it/s]

	batch 836	loss: 0.43182387948036194	mean: 0.4105671504812856


 55%|█████▍    | 868/1589 [04:15<03:30,  3.43it/s]

	batch 867	loss: 0.4427337050437927	mean: 0.43448848445569316


 57%|█████▋    | 899/1589 [04:24<03:20,  3.45it/s]

	batch 898	loss: 0.5161148905754089	mean: 0.43996573936554695


 59%|█████▊    | 930/1589 [04:33<03:11,  3.44it/s]

	batch 929	loss: 0.4637368321418762	mean: 0.48601101867614255


 60%|██████    | 961/1589 [04:42<03:02,  3.44it/s]

	batch 960	loss: 0.2927154302597046	mean: 0.4752792992418812


 62%|██████▏   | 992/1589 [04:51<02:53,  3.44it/s]

	batch 991	loss: 0.4024507403373718	mean: 0.41398445108244497


 64%|██████▍   | 1023/1589 [05:00<02:44,  3.44it/s]

	batch 1022	loss: 0.34760019183158875	mean: 0.4629283143628028


 66%|██████▋   | 1054/1589 [05:09<02:35,  3.44it/s]

	batch 1053	loss: 0.6313028931617737	mean: 0.44122645018562195


 68%|██████▊   | 1085/1589 [05:18<02:26,  3.44it/s]

	batch 1084	loss: 0.5726311206817627	mean: 0.4389336094740898


 70%|███████   | 1116/1589 [05:27<02:17,  3.44it/s]

	batch 1115	loss: 0.18817508220672607	mean: 0.43070606935408806


 72%|███████▏  | 1147/1589 [05:36<02:08,  3.44it/s]

	batch 1146	loss: 0.3478165566921234	mean: 0.5105220646627487


 74%|███████▍  | 1178/1589 [05:45<01:59,  3.44it/s]

	batch 1177	loss: 0.615344762802124	mean: 0.4484971430032484


 76%|███████▌  | 1209/1589 [05:54<01:50,  3.44it/s]

	batch 1208	loss: 0.7544234991073608	mean: 0.45052677152618287


 78%|███████▊  | 1240/1589 [06:03<01:41,  3.44it/s]

	batch 1239	loss: 0.25639984011650085	mean: 0.4486138123658396


 80%|███████▉  | 1271/1589 [06:12<01:32,  3.44it/s]

	batch 1270	loss: 0.1349284052848816	mean: 0.4568359414415975


 82%|████████▏ | 1302/1589 [06:21<01:23,  3.43it/s]

	batch 1301	loss: 0.48126092553138733	mean: 0.4168379525503805


 84%|████████▍ | 1333/1589 [06:30<01:14,  3.44it/s]

	batch 1332	loss: 0.4083367586135864	mean: 0.46092002478338057


 86%|████████▌ | 1364/1589 [06:39<01:05,  3.45it/s]

	batch 1363	loss: 0.6620029807090759	mean: 0.48081231934408986


 88%|████████▊ | 1395/1589 [06:48<00:56,  3.44it/s]

	batch 1394	loss: 0.36247360706329346	mean: 0.4265322026706511


 90%|████████▉ | 1426/1589 [06:57<00:47,  3.45it/s]

	batch 1425	loss: 0.3781534731388092	mean: 0.4138075375268536


 92%|█████████▏| 1457/1589 [07:06<00:38,  3.44it/s]

	batch 1456	loss: 0.33838504552841187	mean: 0.44193796957692794


 94%|█████████▎| 1488/1589 [07:15<00:29,  3.44it/s]

	batch 1487	loss: 0.4276389181613922	mean: 0.45757076672969327


 96%|█████████▌| 1519/1589 [07:24<00:20,  3.43it/s]

	batch 1518	loss: 0.5425999164581299	mean: 0.4567383186471078


 98%|█████████▊| 1550/1589 [07:33<00:11,  3.45it/s]

	batch 1549	loss: 0.47950616478919983	mean: 0.4328150653070019


 99%|█████████▉| 1581/1589 [07:42<00:02,  3.43it/s]

	batch 1580	loss: 0.30934351682662964	mean: 0.3976958768983041


100%|██████████| 1589/1589 [07:47<00:00,  3.40it/s]
  2%|▏         | 31/1589 [00:09<07:34,  3.43it/s]

	batch 30	loss: 0.1839081346988678	mean: 0.38575956034354675


  4%|▍         | 62/1589 [00:18<07:25,  3.43it/s]

	batch 61	loss: 0.4132869839668274	mean: 0.38815581990826514


  6%|▌         | 93/1589 [00:27<07:14,  3.44it/s]

	batch 92	loss: 0.34636345505714417	mean: 0.4527426924436323


  8%|▊         | 124/1589 [00:36<07:05,  3.44it/s]

	batch 123	loss: 0.5482177138328552	mean: 0.36123455171623536


 10%|▉         | 155/1589 [00:45<06:56,  3.44it/s]

	batch 154	loss: 0.42890429496765137	mean: 0.38951643147776205


 12%|█▏        | 186/1589 [00:54<06:47,  3.45it/s]

	batch 185	loss: 0.23902621865272522	mean: 0.39590703912319675


 14%|█▎        | 217/1589 [01:03<06:38,  3.44it/s]

	batch 216	loss: 0.5230270624160767	mean: 0.4290880060965015


 16%|█▌        | 248/1589 [01:12<06:29,  3.44it/s]

	batch 247	loss: 0.2603543996810913	mean: 0.3974076134543265


 18%|█▊        | 279/1589 [01:21<06:21,  3.44it/s]

	batch 278	loss: 0.7084031105041504	mean: 0.41698260653403496


 20%|█▉        | 310/1589 [01:30<06:11,  3.44it/s]

	batch 309	loss: 0.24374964833259583	mean: 0.3945306121341644


 21%|██▏       | 341/1589 [01:39<06:02,  3.44it/s]

	batch 340	loss: 0.4409730136394501	mean: 0.36609483582358204


 23%|██▎       | 372/1589 [01:48<05:53,  3.45it/s]

	batch 371	loss: 0.527509331703186	mean: 0.3857167887110864


 25%|██▌       | 403/1589 [01:57<05:46,  3.43it/s]

	batch 402	loss: 0.261179655790329	mean: 0.3548866457516147


 27%|██▋       | 434/1589 [02:06<05:34,  3.45it/s]

	batch 433	loss: 0.6512542963027954	mean: 0.38413309233803905


 29%|██▉       | 465/1589 [02:15<05:26,  3.44it/s]

	batch 464	loss: 0.24251273274421692	mean: 0.36117141405420916


 31%|███       | 496/1589 [02:24<05:16,  3.45it/s]

	batch 495	loss: 0.326738566160202	mean: 0.41869457114127373


 33%|███▎      | 527/1589 [02:33<05:08,  3.44it/s]

	batch 526	loss: 0.3120814263820648	mean: 0.4067671250912451


 35%|███▌      | 558/1589 [02:42<04:59,  3.44it/s]

	batch 557	loss: 0.5306137800216675	mean: 0.3732436712711088


 37%|███▋      | 589/1589 [02:51<04:50,  3.44it/s]

	batch 588	loss: 0.35177862644195557	mean: 0.4057744900065084


 39%|███▉      | 620/1589 [03:00<04:41,  3.44it/s]

	batch 619	loss: 0.41854268312454224	mean: 0.4020659899519336


 41%|████      | 651/1589 [03:09<04:32,  3.44it/s]

	batch 650	loss: 0.5540550351142883	mean: 0.41175798735310953


 43%|████▎     | 682/1589 [03:18<04:23,  3.45it/s]

	batch 681	loss: 0.42199161648750305	mean: 0.37662394873557553


 45%|████▍     | 713/1589 [03:27<04:14,  3.45it/s]

	batch 712	loss: 0.3893264830112457	mean: 0.3615239368331048


 47%|████▋     | 744/1589 [03:36<04:05,  3.44it/s]

	batch 743	loss: 0.39720404148101807	mean: 0.3367120335178991


 49%|████▉     | 775/1589 [03:45<03:56,  3.44it/s]

	batch 774	loss: 0.4161026179790497	mean: 0.3387000491061518


 51%|█████     | 806/1589 [03:56<03:56,  3.31it/s]

	batch 805	loss: 0.32676273584365845	mean: 0.33295984998826056


 53%|█████▎    | 837/1589 [04:05<03:38,  3.44it/s]

	batch 836	loss: 0.32053276896476746	mean: 0.3297981773653338


 55%|█████▍    | 868/1589 [04:14<03:29,  3.44it/s]

	batch 867	loss: 0.37067165970802307	mean: 0.3422891331776496


 57%|█████▋    | 899/1589 [04:23<03:20,  3.45it/s]

	batch 898	loss: 0.45251670479774475	mean: 0.3535865052573143


 59%|█████▊    | 930/1589 [04:32<03:11,  3.45it/s]

	batch 929	loss: 0.38699766993522644	mean: 0.39686151762162486


 60%|██████    | 961/1589 [04:41<03:02,  3.44it/s]

	batch 960	loss: 0.2170812487602234	mean: 0.3914541131065738


 62%|██████▏   | 992/1589 [04:50<02:53,  3.45it/s]

	batch 991	loss: 0.2623901069164276	mean: 0.32822566455410374


 64%|██████▍   | 1023/1589 [04:59<02:44,  3.44it/s]

	batch 1022	loss: 0.2797622084617615	mean: 0.37953066249047557


 66%|██████▋   | 1054/1589 [05:08<02:35,  3.45it/s]

	batch 1053	loss: 0.5941824913024902	mean: 0.3588371255224751


 68%|██████▊   | 1085/1589 [05:17<02:26,  3.45it/s]

	batch 1084	loss: 0.4720971882343292	mean: 0.3624364786571072


 70%|███████   | 1116/1589 [05:26<02:17,  3.44it/s]

	batch 1115	loss: 0.16912366449832916	mean: 0.34991886034127206


 72%|███████▏  | 1147/1589 [05:35<02:08,  3.44it/s]

	batch 1146	loss: 0.2589033842086792	mean: 0.4353662056307639


 74%|███████▍  | 1178/1589 [05:44<01:59,  3.45it/s]

	batch 1177	loss: 0.5493484735488892	mean: 0.3739223316792519


 76%|███████▌  | 1209/1589 [05:53<01:50,  3.44it/s]

	batch 1208	loss: 0.665447473526001	mean: 0.3810908892943013


 78%|███████▊  | 1240/1589 [06:02<01:41,  3.44it/s]

	batch 1239	loss: 0.18401087820529938	mean: 0.37599065515302843


 80%|███████▉  | 1271/1589 [06:11<01:32,  3.45it/s]

	batch 1270	loss: 0.0900358110666275	mean: 0.38391020269163195


 82%|████████▏ | 1302/1589 [06:20<01:23,  3.44it/s]

	batch 1301	loss: 0.4012112021446228	mean: 0.3501630403822468


 84%|████████▍ | 1333/1589 [06:29<01:14,  3.44it/s]

	batch 1332	loss: 0.36213165521621704	mean: 0.39571645303118613


 86%|████████▌ | 1364/1589 [06:38<01:05,  3.45it/s]

	batch 1363	loss: 0.6164274215698242	mean: 0.4013125809930986


 88%|████████▊ | 1395/1589 [06:47<00:56,  3.44it/s]

	batch 1394	loss: 0.3103308379650116	mean: 0.36137945517416925


 90%|████████▉ | 1426/1589 [06:56<00:47,  3.45it/s]

	batch 1425	loss: 0.29222017526626587	mean: 0.3496612815126296


 92%|█████████▏| 1457/1589 [07:05<00:38,  3.44it/s]

	batch 1456	loss: 0.2963719069957733	mean: 0.37894584430802253


 94%|█████████▎| 1488/1589 [07:14<00:29,  3.44it/s]

	batch 1487	loss: 0.38583746552467346	mean: 0.3930140257843079


 96%|█████████▌| 1519/1589 [07:23<00:20,  3.44it/s]

	batch 1518	loss: 0.47697699069976807	mean: 0.4019725289075605


 98%|█████████▊| 1550/1589 [07:32<00:11,  3.44it/s]

	batch 1549	loss: 0.4499339759349823	mean: 0.37141025595126614


 99%|█████████▉| 1581/1589 [07:41<00:02,  3.43it/s]

	batch 1580	loss: 0.25347501039505005	mean: 0.34683261331050624


100%|██████████| 1589/1589 [07:46<00:00,  3.41it/s]
  1%|          | 8/1589 [00:21<1:10:03,  2.66s/it]


## Validation

In [None]:
testset = tokenized_dataset_dict['test'].with_format('torch', device=device)

In [None]:
bleu_df = eval_bleu(model, tokenizer, testset)
bleu_df.to_csv(f"./results/{checkpoint[checkpoint.rfind('/'):]}_bleu.csv", index=False)
bleu_df.describe()

100%|██████████| 1589/1589 [58:08<00:00,  2.20s/it]


Unnamed: 0,BLEU
count,1589.0
mean,0.222803
std,0.130583
min,0.0
25%,0.1185
50%,0.192381
75%,0.301876
max,0.885217


In [None]:
rouge_df = eval_rogue(model, tokenizer, testset)
rouge_df.to_csv(f"./results/{checkpoint[checkpoint.rfind('/'):]}_rouge.csv", index=False)
rouge_df.describe()

NameError: ignored