<a href="https://colab.research.google.com/github/SOL1archive/ClauseSummary/blob/main/model-train/main-model-train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Detect whether the notebook is running inside Google Colab by probing for
# the `google.colab` package.
try:
    import google.colab  # noqa: F401 -- imported only to test availability
    IN_COLAB = True
except ImportError:
    # Only a failed import means "not on Colab"; other exceptions
    # (e.g. KeyboardInterrupt) must not be swallowed here.
    IN_COLAB = False

# Cell output: shows which environment was detected.
'Process in Colab' if IN_COLAB else 'Process in Local'

In [None]:
# Install the third-party packages used by this notebook; skipped when not on Colab.
if IN_COLAB:
    !pip install transformers
    !pip install datasets
    !pip install evaluate
    !pip install rouge_score
    !pip install torchmetrics
    !pip install rouge
    !pip install --upgrade accelerate

In [None]:
# Mount Google Drive at /content/drive/ when running on Colab.
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive/')

In [None]:
# 깃허브에서는 빼야됨
%cd drive/MyDrive/projects/ClauseSummary

In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import datetime
import os
import gc
from pprint import pprint
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd

import tensorboard
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR, CyclicLR
import torchmetrics

from datasets import load_dataset, load_from_disk, concatenate_datasets, DatasetDict, Dataset
from transformers import get_linear_schedule_with_warmup
from transformers import DataCollatorForSeq2Seq
from transformers import BartConfig, T5Config, LongformerConfig
from transformers import AutoTokenizer, LongformerTokenizer, AutoModelForSeq2SeqLM, LongT5ForConditionalGeneration
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
import evaluate
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge import Rouge

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns

In [None]:
class TokenizeMapWrapper:
    """Callable that tokenizes one column of a dataset row (or batch of rows).

    Intended to be passed to ``Dataset.map``: calling the instance with a
    ``row`` runs ``tokenizer(row[feature], **option)`` and returns the result.

    Args:
        tokenizer: Callable tokenizer applied to the selected column.
        feature: Key of the column in ``row`` that holds the text to tokenize.
        option: Keyword arguments forwarded to the tokenizer on every call.
            When ``None``, inputs are truncated and padded to 4096 tokens.
    """

    def __init__(self, tokenizer, feature, option=None):
        if option is None:
            # The length limit for tokenization is `max_length`;
            # `max_new_tokens` is a generation argument and would leave
            # `truncation`/`padding='max_length'` without an explicit length.
            option = {
                'max_length': 4096,
                'truncation': True,
                'padding': 'max_length',
            }

        self.option = option
        self.feature = feature
        self.tokenizer = tokenizer

    def __call__(self, row):
        """Tokenize ``row[self.feature]`` with the configured options."""
        return self.tokenizer(row[self.feature], **self.option)

    def __repr__(self):
        return f'{self.__class__.__name__}(tokenizer={self.tokenizer})'

class Seq2SeqTokenizeMapWrapper(TokenizeMapWrapper):
    """Tokenize a source column and a target column for seq2seq training.

    Args:
        tokenizer: Tokenizer applied to both columns.
        feature: Key of the source-text column in ``row``.
        target: Key of the target-text column in ``row``.
        option: Keyword arguments forwarded to the tokenizer for both the
            source and the target; ``None`` uses the parent class default.
    """

    def __init__(self, tokenizer, feature, target, option=None):
        super().__init__(tokenizer, feature, option)
        self.target = target

    def seq2seq_tokenize(self, row):
        """Return ``input_ids``/``attention_mask`` for the source and ``labels`` for the target."""
        source_encoding = self.tokenizer(row[self.feature], **self.option)
        # `text_target=` tokenizes in target mode; it replaces the deprecated
        # `as_target_tokenizer()` context manager.
        target_encoding = self.tokenizer(text_target=row[self.target], **self.option)

        return {
            'input_ids': source_encoding['input_ids'],
            'attention_mask': source_encoding['attention_mask'],
            'labels': target_encoding['input_ids'],
        }

    def __call__(self, row):
        return self.seq2seq_tokenize(row)

## Setting

- 학습 환경에 맞게 조정하기 (특히 **경로 설정**)

In [None]:
# Run-level switches and hyperparameters for the training cells below.
# MANUAL_TRAINING / MANUAL_VALIDATION are not referenced in the visible cells.
MANUAL_TRAINING = True
MANUAL_VALIDATION = True
# Divides the number of batches per epoch to get the checkpoint-save interval.
MID_CHECKPOINT_NUM = 2
# Divides the number of batches per epoch to get the evaluation/report interval.
MID_PROCESS_PRINT_NUM = 35

NUM_EPOCHS = 1
# Passed to both AdamW and Seq2SeqTrainingArguments.
learning_rate = 2e-5
# Weight decay, passed to both AdamW and Seq2SeqTrainingArguments.
decay = 0.01

In [None]:
# Candidate pretrained checkpoint identifiers; `checkpoint` selects the one
# used for the tokenizer, model and output file names in the rest of the notebook.
t5_large_summary_checkpoint = 'lcw99/t5-large-korean-text-summary'
t5_base_summary_checkpoint = 'eenzeenee/t5-base-korean-summarization'
kobart_summary_checkpoint = 'gogamza/kobart-summarization'
kolongformer = "psyche/kolongformer-4096"
longt5_checkpoint = 'KETI-AIR-Downstream/long-ke-t5-base-summarization'
checkpoint = longt5_checkpoint
print(f'Using Checkpoint: {checkpoint}')

In [None]:
# Raw JSON dataset, read with pandas when no tokenized cache exists.
original_dataset_path = './data/dataset-term-summary.json'
# Per-checkpoint cache directory for the tokenized train/test split;
# '/' in the checkpoint name is replaced so it forms a single path component.
tokenized_dataset_path = f'./data/{checkpoint.replace("/", "-")}-tokenized-dataset'

In [None]:
# Timestamp tag for this run's output location. The time is formatted without
# a colon so the resulting directory name is also valid on filesystems that
# reject ':' in paths (e.g. Windows when running locally).
SAVE_STR = datetime.datetime.now().strftime('%y%m%d-%H%M')
model_save_path = f"./model/{SAVE_STR}"

## Load Tokenizer & Model Checkpoint

In [None]:
# Choose the config class from the first family keyword contained in the
# lower-cased checkpoint name; `config` stays undefined when none matches.
# NOTE(review): 't5' also matches the LongT5 checkpoint name -- confirm that
# T5Config is the intended class for it.
_CONFIG_CLASS_BY_KEYWORD = (
    ('bart', BartConfig),
    ('t5', T5Config),
    ('longformer', LongformerConfig),
)
_checkpoint_lowered = checkpoint.lower()
for _keyword, _config_cls in _CONFIG_CLASS_BY_KEYWORD:
    if _keyword in _checkpoint_lowered:
        config = _config_cls.from_pretrained(checkpoint)
        break

In [None]:
# Load the tokenizer and the seq2seq model for the selected checkpoint.
# NOTE(review): `max_new_tokens` is normally a generation argument rather than
# a tokenizer setting -- confirm whether `model_max_length=4096` was intended,
# and whether `truncation`/`padding` have any effect when given at load time.
tokenizer = AutoTokenizer.from_pretrained(checkpoint,
                                          max_new_tokens=4096,
                                          truncation=False,
                                          padding='max_length',
                                          #vocab=config.vocab_size
                                          )
#tokenizer = LongformerTokenizer(vocab_file, merges_file, errors='replace', bos_token='<s>', eos_token='</s>', sep_token='</s>', cls_token='<s>', unk_token='<unk>', pad_token='<pad>', mask_token='<mask>', add_prefix_space=False, **kwargs)

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

In [None]:
# Fail fast if the tokenizer's vocabulary size differs from the model's
# configured vocab size, instead of hitting an error later during training.
if len(tokenizer) != model.config.vocab_size:
    raise RuntimeError(f'Tokenizer vocab size and model vocab size do not match(Tokenizer:{len(tokenizer)} Model: {model.config.vocab_size}). Which would lead to further error in training.')

## Load Dataset

In [None]:
# Reuse the cached tokenized split when it exists on disk; otherwise build it
# from the raw JSON file, split it 90/10, and cache the result.
if os.path.exists(tokenized_dataset_path):
    tokenized_dataset_dict = load_from_disk(tokenized_dataset_path)
else:
    # Only the source text and its reference summary are kept.
    dataset = Dataset.from_pandas(
        pd.read_json(original_dataset_path, encoding='utf-8')[['text', 'summary']]
    )
    tokenizer_wrapper = Seq2SeqTokenizeMapWrapper(tokenizer, 'text', 'summary')

    # Tokenize in batches across worker processes, then drop the raw text
    # columns so only the tokenized fields remain.
    tokenized_dataset = dataset.map(
        tokenizer_wrapper,
        batched=True,
        batch_size=128,
        num_proc=10,
    )
    tokenized_dataset = tokenized_dataset.remove_columns(['text', 'summary'])

    tokenized_dataset_dict = tokenized_dataset.train_test_split(test_size=0.1, shuffle=True)
    tokenized_dataset_dict.save_to_disk(tokenized_dataset_path)

## Training

In [None]:
# Show which checkpoint is about to be trained.
print(checkpoint)

In [None]:
def generate_seq(model, tokenizer, input):
    """Generate text for a single encoded example and decode it to a string.

    Args:
        model: Object exposing ``generate(**kwargs)``.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        input: Mapping of keyword arguments unpacked into ``model.generate``.

    Returns:
        The decoded generated text with special tokens skipped.
    """
    # Generation strategy.
    # NOTE(review): top_p/top_k look like sampling settings -- confirm whether
    # sampling (`do_sample=True`) was meant to be enabled.
    generation_kwargs = {
        'max_new_tokens': 300,
        'top_p': 0.92,
        'top_k': 0,
        'early_stopping': True,
    }
    output_ids = model.generate(**input, **generation_kwargs)

    # squeeze(0) drops the leading dimension before decoding.
    return tokenizer.decode(output_ids.squeeze(0), skip_special_tokens=True)

def generate_input_target(model, tokenizer, input, label):
    """Decode the source, the model's generation and the reference for one example.

    Args:
        model: Model passed through to ``generate_seq``.
        tokenizer: Tokenizer used for decoding.
        input: Mapping containing at least ``'input_ids'``; forwarded to
            ``generate_seq`` unchanged.
        label: Token ids of the reference text.

    Returns:
        Dict with keys ``'input_text'``, ``'generated_text'`` and
        ``'target_text'``.
    """
    def _decode(token_ids):
        # Drop the leading dimension and skip special tokens when decoding.
        return tokenizer.decode(token_ids.squeeze(0), skip_special_tokens=True)

    source_text = _decode(input['input_ids'])
    prediction_text = generate_seq(model, tokenizer, input)
    reference_text = _decode(label)

    return {
        'input_text': source_text,
        'generated_text': prediction_text,
        'target_text': reference_text
    }

def generate_from_data(model, tokenizer, data):
    """Split ``data`` into model inputs and labels, then generate and decode.

    Args:
        model: Model passed through to ``generate_input_target``.
        tokenizer: Tokenizer passed through to ``generate_input_target``.
        data: Mapping with ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.

    Returns:
        The dict produced by ``generate_input_target``.
    """
    # Only these two fields are forwarded as generation inputs.
    model_inputs = {name: data[name] for name in ('input_ids', 'attention_mask')}
    return generate_input_target(model, tokenizer, model_inputs, data['labels'])

def eval_bleu_rouge(model, tokenizer, tokenized_testset):
    """Score generated summaries against references with BLEU and ROUGE-2.

    Each example is generated one at a time via ``generate_from_data``.
    Examples for which a metric raises ``ValueError`` are skipped, so the
    result can have fewer rows than ``tokenized_testset``.

    Args:
        model: Seq2seq model used for generation. It is switched to eval mode
            for the duration of the call and restored to its previous mode
            afterwards.
        tokenizer: Tokenizer used to decode inputs, generations and labels.
        tokenized_testset: Iterable of examples; each example is a mapping
            whose values support ``unsqueeze(0)``.

    Returns:
        ``pandas.DataFrame`` with one row per scored example and the columns
        ``BLEU``, ``ROUGE-Precision``, ``ROUGE-Recall`` and ``ROUGE-F1``
        (the ROUGE columns hold the ROUGE-2 values).
    """
    rouge = Rouge()
    # Loop-invariant: build the smoothing function once.
    smoothing = SmoothingFunction().method1
    score_dict = {
        'BLEU': [],
        'ROUGE-Precision': [],
        'ROUGE-Recall': [],
        'ROUGE-F1': [],
    }

    # Generate in eval mode so dropout does not perturb the outputs, then
    # restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        eval_tqdm_bar = tqdm(tokenized_testset, leave=False, desc='Evaluating')
        for example in eval_tqdm_bar:
            # Add a leading dimension of size 1 to every field of the example.
            data = {key: example[key].unsqueeze(0) for key in example}
            output = generate_from_data(model, tokenizer, data)
            try:
                # sentence_bleu expects token sequences; passing raw strings
                # would score character n-grams instead of word n-grams.
                bleu_score = sentence_bleu(
                    [output['target_text'].split()],
                    output['generated_text'].split(),
                    smoothing_function=smoothing,
                )
                rouge_score = rouge.get_scores(
                    output['generated_text'],
                    output['target_text'],
                )
            except ValueError:
                # Best-effort: skip examples the metrics cannot score
                # (presumably empty generated/reference text).
                continue

            rouge_2 = rouge_score[0]['rouge-2']
            score_dict['BLEU'].append(bleu_score)
            score_dict['ROUGE-Precision'].append(rouge_2['p'])
            score_dict['ROUGE-Recall'].append(rouge_2['r'])
            score_dict['ROUGE-F1'].append(rouge_2['f'])
    finally:
        model.train(was_training)

    return pd.DataFrame(score_dict)

In [None]:
# Utils
def dict_to_str(d):
    """Render a mapping as tab-separated ``key: value`` pairs, in iteration order."""
    rendered_pairs = (f'{key}: {value}' for key, value in d.items())
    return '\t'.join(rendered_pairs)

In [None]:
# Trainer configuration. In the visible cells the trainer is only used for
# `create_model_card` and `save_model`; the optimization loop itself is
# written by hand in the cells below.
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy = "epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=NUM_EPOCHS,
    weight_decay=decay,
    report_to="tensorboard",
    push_to_hub=False,
)

# No datasets are attached to the trainer here.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args
)

In [None]:
## finding the best parameters
gc.collect()
torch.cuda.empty_cache()

total_loss = []
epoch_loss = []
batch_loss = []

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainset = tokenized_dataset_dict['train'].with_format('torch', device=device)
testset = tokenized_dataset_dict['test'].with_format('torch', device=device)
dataloader = DataLoader(trainset, batch_size=12, shuffle=False) # TODO: Batch size 조절

# TODO: Minor Hyperparameter Tuning
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=decay)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=NUM_EPOCHS * len(dataloader))
training_stats = []

In [None]:
# Manual training loop: one optimizer/scheduler step per batch, periodic
# evaluation on the test split, and periodic checkpoint saves.
# Padded label positions are excluded from the loss via the -100 ignore index.
pad_token_id = tokenizer.pad_token_id

for epoch in range(NUM_EPOCHS):
    total_steps = len(dataloader)
    # Clamp to 1 so that a run with fewer batches than the requested number
    # of checkpoints/reports does not hit a modulo-by-zero below.
    save_divisor = max(1, total_steps // MID_CHECKPOINT_NUM)
    print_divisor = max(1, total_steps // MID_PROCESS_PRINT_NUM)

    # from_pretrained returns the model in eval mode; enable training mode
    # (dropout etc.) for the optimization steps.
    model.train()
    with tqdm(dataloader, leave=False, desc='Batch', position=0, postfix={'Epoch': 1, 'Batch': 1, 'loss': 0, 'loss_mean': 0, 'BLEU': 0, 'ROUGE': 0}) as tqdm_bar:
        for i, batch in enumerate(tqdm_bar):
            tqdm_bar.set_description(f'Batch: {i + 1}')
            X = {
                    'input_ids': batch['input_ids'],
                    'attention_mask': batch['attention_mask'],
                }
            y = batch['labels']
            if pad_token_id is not None:
                # Do not train the model to predict padding.
                y = y.masked_fill(y == pad_token_id, -100)

            outputs = model(**X, labels=y)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            batch_loss.append(loss.item())

            # Periodic report: evaluate on the test split and record stats.
            if i % print_divisor == print_divisor - 1:
                epoch_loss += batch_loss
                batch_loss_series = pd.Series(batch_loss, dtype=np.float64)
                # Evaluate in eval mode, then resume training mode.
                model.eval()
                metric = eval_bleu_rouge(model, tokenizer, testset)
                model.train()
                training_stats.append(
                    {
                        'Epoch': epoch + 1,
                        'Batch': i + 1,
                        'loss': loss.item(),
                        'loss_mean': batch_loss_series.mean(),
                        'BLEU': metric['BLEU'].mean(),
                        'ROUGE': metric['ROUGE-F1'].mean()
                    }
                )
                tqdm_bar.set_postfix(training_stats[-1])
                batch_loss = []

            # Periodic checkpoint.
            if i % save_divisor == save_divisor - 1:
                trainer.create_model_card(
                    language='Korean',
                    finetuned_from=checkpoint
                )
                # Join path components explicitly and flatten the '/' in the
                # checkpoint name so the save lands in one directory under
                # this run's folder.
                trainer.save_model(
                    os.path.join(
                        model_save_path,
                        f'{checkpoint.replace("/", "-")}-epoch-{epoch + 1}-batch-{i + 1}',
                    )
                )

    # End of epoch: fold in the losses recorded since the last report so no
    # batch is missing from the history, then reset the per-epoch buffers.
    epoch_loss += batch_loss
    batch_loss = []
    total_loss += epoch_loss
    epoch_loss = []
    # Leave the model in eval mode for the evaluation that follows training.
    model.eval()

In [None]:
# Final evaluation on the test split; per-example scores are written to CSV.
metric = eval_bleu_rouge(model, tokenizer, testset)
metric.to_csv('./metric.csv')

# Persist the stats collected at each report step of the training loop.
training_stats_df = pd.DataFrame(training_stats)
training_stats_df.to_csv('./training_stats.csv')

# Write the model card and save the final model next to the run's
# timestamped path (suffix '-final').
trainer.create_model_card(
    language='Korean',
    finetuned_from=checkpoint
)
trainer.save_model(model_save_path + '-final')

## Analysis

In [None]:
# Summary statistics of the recorded training stats.
training_stats_df.describe()

In [None]:
# Convert the accumulated loss list to a Series and plot it as a line chart.
total_loss = pd.Series(total_loss)
total_loss.plot.line()

In [None]:
# Mean loss per report interval.
training_stats_df['loss_mean'].plot.line()

In [None]:
# Mean BLEU and ROUGE-F1 recorded at each report step.
training_stats_df[['BLEU', 'ROUGE']].plot.line()