In [None]:
# Detect whether the notebook runs inside Google Colab; later cells use
# IN_COLAB to decide whether to install packages, mount Drive and change dir.
try:
    import google.colab  # noqa: F401  (imported only to probe availability)
    IN_COLAB = True
except ImportError:
    # Catch only ImportError: a bare `except:` would also swallow
    # KeyboardInterrupt and unrelated errors raised during the import.
    IN_COLAB = False

'Process in Colab' if IN_COLAB else 'Process in Local'

In [None]:
if IN_COLAB:
    !pip install transformers
    !pip install datasets
    !pip install --upgrade accelerate
    !pip install torchmetrics

In [None]:
# Mount Google Drive (Colab only) so the project folder becomes reachable.
if IN_COLAB:
    import google.colab.drive as drive
    drive.mount('/content/drive/')

In [None]:
# 깃허브에서는 빼야됨
%cd drive/MyDrive/projects/ClauseSummary

In [None]:
from typing import Dict
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import datetime
import re
import os
import gc
from pprint import pprint
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
from torch.optim import Adam, AdamW, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchmetrics.regression import R2Score

from datasets import load_from_disk, load_dataset, Dataset, DatasetDict, concatenate_datasets
from transformers import AutoTokenizer, LongformerTokenizer, AutoModel, AutoModelForMaskedLM
from transformers import get_linear_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup
from transformers import Trainer, TrainingArguments

In [None]:
def add_newline_before_number(text: str) -> str:
    """Insert a newline in front of every "<digits>." marker (e.g. "1.", "12.").

    The value is passed through ``str()`` first, so non-string input (such as a
    NaN cell) is converted instead of raising.
    """
    numbered_marker = re.compile(r'(\d+)\.')
    return numbered_marker.sub(r'\n\1.', str(text))

def change_it_to_a_comma(text: str):
    """Replace "(1)", "(2)", ... enumerators with commas.

    Text before the first enumerator is glued onto the first item instead of
    becoming its own comma-separated piece, e.g. "A (1) B (2) C" -> "A  B , C".
    Text without any enumerator is returned unchanged.
    """
    head, *rest = re.split(r'\(\d+\)', text)
    if not rest:
        return head
    first_item = head + rest[0]
    return ','.join([first_item, *rest[1:]])

def remove_whitespace_after_str(text: str):
    """Remove the single whitespace character that follows a standalone
    갑 / 을 / 병 / 정 (one preceded by a word boundary).

    The substitutions run one character at a time in this fixed order. The order
    matters: deleting a space can glue the next character onto the previous one,
    which removes its word boundary, so a single combined regex is NOT
    equivalent.
    """
    for party in ('갑', '을', '병', '정'):
        text = re.sub(r"\b" + party + r"\s", party, text)
    return text

def change_number_point(text: str):
    """Delete every "<digits>." marker from *text*.

    The original comment described turning "1." / "2." into article/paragraph
    labels, but the implementation simply removes the markers; the surrounding
    pieces are concatenated with no separator.
    """
    return re.sub(r'\d+\.', '', text)

def summary_preprocessing_func(text: str):
    """Normalise a summary by applying these steps in order:
    newline before "N." markers, "(N)" enumerators -> commas,
    whitespace removal after 갑/을/병/정, and removal of "N." markers.
    """
    pipeline = (
        add_newline_before_number,
        change_it_to_a_comma,
        remove_whitespace_after_str,
        change_number_point,
    )
    for step in pipeline:
        text = step(text)
    return text

def text_preprocessing_func(text):
    """Collapse a newline followed by one or more further newlines/spaces into
    a single newline; a lone newline is left untouched."""
    blank_run = r'\n[\n ]+'
    return re.sub(blank_run, '\n', text)

In [None]:
def preprocessing(row: Dict[str, str]):
    """Clean a single record.

    Returns a new dict with only two keys: 'text' (blank-line runs collapsed)
    and 'summary' (normalised with summary_preprocessing_func).
    """
    return {
        'text': text_preprocessing_func(row['text']),
        'summary': summary_preprocessing_func(row['summary']),
    }

def df_preprocessing(df: pd.DataFrame):
    """Apply ``preprocessing`` row-wise to the 'text' and 'summary' columns.

    NOTE: *df* is modified in place (both columns are overwritten) and is also
    returned; the data-loading cell relies on the in-place update.
    """
    columns = ['text', 'summary']
    cleaned = df[columns].apply(preprocessing, axis=1, result_type='expand')
    df[columns] = cleaned[columns]
    return df

In [None]:
class TokenizeMapWrapper:
    """Callable for ``datasets.Dataset.map`` that tokenizes one field of a row.

    Args:
        tokenizer: Hugging Face-style tokenizer, called as ``tokenizer(text, **option)``.
        feature: name of the row field to tokenize.
        option: keyword arguments forwarded to the tokenizer. Defaults to
            truncating/padding to 4096 tokens.
    """

    def __init__(self, tokenizer, feature, option=None):
        if option is None:
            option = {
                'max_length': 4096,
                'truncation': True,
                'padding': 'max_length',
            }

        # The resolved options used to be computed but never stored, so
        # __call__ failed with AttributeError on `self.option`.
        self.option = option
        self.feature = feature
        self.tokenizer = tokenizer

    def __call__(self, row):
        return self.tokenizer(row[self.feature], **self.option)

    def __repr__(self):
        return f'{self.__class__.__name__}(tokenizer={self.tokenizer})'


class RewardTokenizeMapWrapper(TokenizeMapWrapper):
    """Tokenize a (text, summary) pair into one fixed-length sequence.

    Text and summary are tokenized separately (each truncated to ``max_length``
    with the default options) and concatenated as ``text + summary``. When the
    pair does not fit, the *text* is cut from the right so that the whole
    summary is kept and the result is exactly ``max_length`` tokens long.
    Shorter results are right-padded to ``max_length``.

    Args:
        tokenizer: Hugging Face-style tokenizer returning list-valued fields.
        text_feature: row field holding the source text.
        summary_feature: row field holding the summary.
        max_token: output length used when ``option`` has no ``'max_length'``.
        option: tokenizer keyword arguments.
    """

    # Fallback pad id, used only when the tokenizer defines none (this is the
    # value previously hard-coded for every field).
    _DEFAULT_PAD_ID = 1

    def __init__(self, tokenizer, text_feature, summary_feature, max_token=4096, option=None):
        if option is None:
            option = {
                'max_length': max_token,
                'truncation': True,
            }
        super().__init__(tokenizer, text_feature, option)
        self.max_token = option.get('max_length', max_token)
        self.text_feature = text_feature
        self.summary_feature = summary_feature

    def _pad_value(self, key):
        """Padding value for one tokenizer output field."""
        if key == 'input_ids':
            pad_id = getattr(self.tokenizer, 'pad_token_id', None)
            return self._DEFAULT_PAD_ID if pad_id is None else pad_id
        # attention_mask (and any other mask-like field) must be 0 on padding;
        # padding it with 1 made the encoder attend to pad tokens.
        return 0

    def __call__(self, row):
        tokenized_text = self.tokenizer(row[self.text_feature], **self.option)
        tokenized_summary = self.tokenizer(row[self.summary_feature], **self.option)

        summary_len = len(tokenized_summary['input_ids'])
        # Number of text tokens that fit next to the full summary. An explicit
        # end index also avoids the `[:-0]` pitfall (which yields an empty list).
        text_budget = max(self.max_token - summary_len, 0)

        tokenized_total_text = {}
        for key in tokenized_text:
            combined = list(tokenized_text[key][:text_budget]) + list(tokenized_summary[key])
            padding = [self._pad_value(key)] * (self.max_token - len(combined))
            tokenized_total_text[key] = combined + padding
        return tokenized_total_text

In [None]:
class ModelForRewardGeneration(nn.Module):
    """Pretrained encoder followed by an MLP that maps a sequence to a reward.

    ``head1`` projects the encoder's pooled output to a ``hidden_size``
    representation (returned on its own by ``representation_forward``), and
    ``head2`` maps that representation to a single value.

    Args:
        encoder_path: name or path passed to ``AutoModel.from_pretrained``.
        hidden_size: width of the representation produced by ``head1``.
    """

    def __init__(self, encoder_path, hidden_size=256):
        super(ModelForRewardGeneration, self).__init__()
        self.encoder = AutoModel.from_pretrained(encoder_path)
        self.hidden_size = hidden_size
        # Take the pooled-output width from the encoder config instead of a
        # hard-coded 768, so encoders of other sizes also work.
        encoder_dim = self.encoder.config.hidden_size
        # nn.Dropout, not nn.Dropout1d: the pooled output is 2-D (batch, feat),
        # which Dropout1d treats as an unbatched (channels, length) tensor and
        # therefore zeroes whole samples. Dropout layers have no parameters, so
        # the Sequential indices / state-dict keys are unchanged.
        self.head1 = nn.Sequential(
            nn.Linear(encoder_dim, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512, bias=False),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, hidden_size, bias=False),
            nn.BatchNorm1d(hidden_size),
            nn.GELU(),
        )
        self.head2 = nn.Sequential(
            nn.Linear(hidden_size, 1),
        )

    def _encode(self, input_ids, attention_mask):
        """Pooled encoder output for a batch."""
        return self.encoder(input_ids=input_ids, attention_mask=attention_mask).pooler_output

    def forward(self, input_ids=None, attention_mask=None):
        """Return the predicted reward (last dimension of size 1)."""
        return self.head2(self.representation_forward(input_ids, attention_mask))

    def representation_forward(self, input_ids=None, attention_mask=None):
        """Return the ``head1`` representation (last dimension ``hidden_size``)."""
        return self.head1(self._encode(input_ids, attention_mask))

    def load(self, model_path):
        """Restore weights saved as ``<model_path>-encoder`` (save_pretrained
        directory) plus ``<model_path>-head1.pt`` / ``-head2.pt`` (state dicts).

        Everything is loaded onto the device the model currently lives on, so a
        GPU-saved checkpoint can be restored on a CPU-only machine and the new
        encoder does not end up on a different device than the heads.
        """
        device = next(self.head1.parameters()).device
        self.encoder = AutoModel.from_pretrained(model_path + '-encoder').to(device)
        self.head1.load_state_dict(torch.load(model_path + '-head1.pt', map_location=device))
        self.head2.load_state_dict(torch.load(model_path + '-head2.pt', map_location=device))

def reference_reward_loss(reward, pred):
    """Element-wise ``-log10(1 + exp(-reward * pred))``.

    Computed as ``-softplus(-reward * pred) / ln(10)`` (softplus(x) = log(1 + e^x)),
    so a large ``-reward * pred`` no longer overflows ``exp`` to ``inf`` and
    turns the result into ``-inf``.

    NOTE(review): the usual logistic loss is ``+log(1 + exp(-y * f))``; the
    leading minus sign makes this a quantity to maximise, not minimise. The
    function is not called anywhere in this notebook; confirm the intended sign
    before training with it.
    """
    import math
    return -F.softplus(-reward * pred) / math.log(10)

class AMSoftmaxLoss(nn.Module):
    """Additive-margin softmax (AM-Softmax) loss.

    Embeddings and class weight vectors are L2-normalised, so their products
    are cosine similarities ``cos_j``. For target class ``y`` the per-sample
    loss is

        -log( e^{s(cos_y - m)} / (e^{s(cos_y - m)} + sum_{j != y} e^{s cos_j}) )

    and the batch mean is returned. Here ``s`` is ``scale`` and ``m`` is ``margin``.

    Args:
        in_features: dimensionality of the embeddings passed to ``forward``.
        n_classes: number of target classes.
        scale: the scale ``s``.
        margin: the additive margin ``m``.
    """

    def __init__(self, in_features, n_classes, scale=30, margin=0.4):
        super(AMSoftmaxLoss, self).__init__()
        self.linear = nn.Linear(in_features, n_classes, bias=False)
        self.scale = scale
        self.margin = margin
        # logits_forward already returns per-class log-probabilities, so the
        # loss is a plain negative log-likelihood. CrossEntropyLoss would
        # apply a second log-softmax on top and distort the objective.
        self.nll_loss = nn.NLLLoss()

    def forward(self, output, target):
        """Mean AM-Softmax loss for embeddings ``output`` and int labels ``target``."""
        logits = self.logits_forward(output)
        return self.nll_loss(logits, target)

    def logits_forward(self, output):
        """Per-class AM-Softmax log-probabilities.

        Entry ``[i, j]`` is the log-probability of class ``j`` for sample ``i``
        when ``j`` is taken as the target (margin applied to ``j`` only). The
        arg-max over ``j`` is the arg-max of the cosine similarity.
        """
        x_vector = F.normalize(output, p=2, dim=-1)
        # Normalise the class weights functionally rather than overwriting
        # `self.linear.weight.data` on every call. This keeps forward free of
        # parameter side effects (also in eval / no_grad mode) and lets the
        # gradient flow through the normalisation.
        weight = F.normalize(self.linear.weight, p=2, dim=-1)
        cosine = F.linear(x_vector, weight)
        scaled_logits = (cosine - self.margin) * self.scale
        return scaled_logits - self._am_logsumexp(cosine)

    def _am_logsumexp(self, logits):
        """For each class j: log(e^{s(cos_j - m)} + sum_{k != j} e^{s cos_k}),
        computed with max subtraction for numerical stability."""
        max_x = torch.max(logits, dim=-1)[0].unsqueeze(-1)
        term1 = (self.scale * (logits - (max_x + self.margin))).exp()
        term2 = (self.scale * (logits - max_x)).exp().sum(-1).unsqueeze(-1) \
                - (self.scale * (logits - max_x)).exp()
        return self.scale * max_x + (term2 + term1).log()

## Setting

In [None]:
MANUAL_TRAINING = True
MANUAL_VALIDATION = True
MID_CHECKPOINT_NUM = 2        # intended number of mid-epoch checkpoints (not used yet)
MID_PROCESS_PRINT_NUM = 100   # evaluate/log roughly this many times per epoch

NUM_EPOCHS = 1
MAX_TOKEN = 4096              # maximum sequence length in tokens
learning_rate = 2e-5
decay = 0.01                  # AdamW weight decay

# Seed the global NumPy and PyTorch RNGs (DataLoader shuffling, dropout, head
# initialisation) so that runs are repeatable.
# NOTE(review): also pass an explicit seed to `Dataset.shuffle()` if the
# evaluation subsets must be exactly reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds all devices, CUDA included

In [None]:
# Hugging Face Hub id of the encoder checkpoint (a 4096-token Longformer
# variant, judging by its name).
checkpoint = kolongformer_checkpoint = "psyche/kolongformer-4096"
print(f'Using Checkpoint: {checkpoint}')

In [None]:
# Raw JSON data, plus an on-disk cache of its tokenized form. The cache name is
# built from the checkpoint id with '/' replaced so it is a single path segment.
original_dataset_path = './data/dataset-term-reward.json'
tokenized_dataset_path = './data/' + checkpoint.replace("/", "-") + '-tokenized-dataset'

In [None]:
# Timestamp that names this run's saved model files. The format contains no
# ':' because colons are invalid in Windows file names and the notebook also
# supports local runs.
SAVE_STR = datetime.datetime.now().strftime('%y%m%d-%H%M')
model_save_path = f"./model/{SAVE_STR}"

## Load Tokenizer & Model Checkpoint

In [None]:
# Tokenizer and encoder come from the same checkpoint so their vocabularies
# match (verified in the next cell).
tokenizer, model = (
    LongformerTokenizer.from_pretrained(checkpoint),
    ModelForRewardGeneration(checkpoint),
)

In [None]:
# Fail fast when the tokenizer vocabulary and the encoder's embedding table
# disagree in size.
encoder_vocab_size = model.encoder.config.vocab_size
if len(tokenizer) != encoder_vocab_size:
    # The message used to read `model.config`, which ModelForRewardGeneration
    # does not have, so this check raised AttributeError instead.
    raise RuntimeError(f'Tokenizer vocab size and model vocab size do not match(Tokenizer:{len(tokenizer)} Model: {encoder_vocab_size}). Which would lead to further error in training.')

## Loading Dataset

In [None]:
# Load the raw data. Bucket the reward into 3 classes (<=3 -> 0, <=7 -> 1,
# otherwise 2) for stage 1, then divide it by 10 for stage-2 regression.
df = pd.read_json(original_dataset_path)
df['reward_class'] = df['reward'].apply(lambda score: 0 if score <= 3 else (1 if score <= 7 else 2))
df['reward'] = df['reward'] / 10

# NOTE(review): the cache path depends only on the checkpoint id; delete the
# cached directory after changing preprocessing or the tokenizer wrapper.
if os.path.exists(tokenized_dataset_path):
    tokenized_dataset = load_from_disk(tokenized_dataset_path)
else:
    # df_preprocessing cleans 'text'/'summary' of `df` in place (and returns df).
    text_df = df_preprocessing(df)

    dataset = Dataset.from_pandas(df[['text', 'summary', 'reward', 'reward_class']])
    tokenizer_wrapper = RewardTokenizeMapWrapper(tokenizer, 'text', 'summary')

    tokenized_dataset = dataset.map(tokenizer_wrapper)
    tokenized_dataset = tokenized_dataset.remove_columns(['text', 'summary'])
    tokenized_dataset.save_to_disk(tokenized_dataset_path)

## Training

In [None]:
print(f'{checkpoint}')  # reminder of which checkpoint is being trained

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Fixed-seed 90/10 split; both splits yield torch tensors placed on `device`.
dataset_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
trainset = dataset_dict['train'].with_format('torch', device=device)
testset = dataset_dict['test'].with_format('torch', device=device)

BATCH_SIZE = 5  # TODO: tune the batch size
# drop_last=True: the model heads contain BatchNorm1d, which raises in training
# mode on a batch of a single sample. Without it, a trailing batch of size 1
# (whenever len(trainset) % BATCH_SIZE == 1) would crash the epoch.
dataloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

### Training (1)
#### Representation Learning
Loss: AM-Softmax

In [None]:
# Stage 1 setup: free memory left by earlier cells, then build the AM-Softmax
# criterion (it owns a learnable class-weight matrix), one AdamW optimizer
# over model + criterion parameters, a linear LR schedule without warm-up and
# a 3-class accuracy metric.
gc.collect()
torch.cuda.empty_cache()

representation_total_loss, representation_epoch_loss, representation_batch_loss = [], [], []
representation_training_stats = []

model.train()

# TODO: Minor Hyperparameter Tuning
# The embedding size is head1's output width (model.hidden_size); the 3 classes
# are the reward buckets created when loading the data.
am_softmax_criterion = AMSoftmaxLoss(model.hidden_size, 3, margin=0.4).to(device)
am_softmax_optimizer = AdamW(
    [*model.parameters(), *am_softmax_criterion.parameters()],
    lr=learning_rate,
    weight_decay=decay,
)
am_softmax_scheduler = get_linear_schedule_with_warmup(
    am_softmax_optimizer,
    num_warmup_steps=0,
    num_training_steps=NUM_EPOCHS * len(dataloader),
)
accuracy = Accuracy(task="multiclass", num_classes=3).to(device)

In [None]:
# Stage 1 loop: train encoder + head1 (+ the AM-Softmax class weights) to
# separate the reward classes. About MID_PROCESS_PRINT_NUM times per epoch,
# log the mean recent train loss and the accuracy on a random held-out subset.
for epoch in range(NUM_EPOCHS):
    total_steps = len(dataloader)
    # max(1, ...) avoids `i % 0` (ZeroDivisionError) when an epoch has fewer
    # than MID_PROCESS_PRINT_NUM batches.
    print_divisor = max(1, total_steps // MID_PROCESS_PRINT_NUM)
    # TODO: mid-epoch checkpointing (MID_CHECKPOINT_NUM) is not implemented.
    with tqdm(dataloader, leave=False, desc='Batch', position=0, postfix={'Epoch': 1, 'Batch': 1, 'train_loss': 0, 'eval accuracy': 0}) as tqdm_bar:
        for i, batch in enumerate(tqdm_bar):
            tqdm_bar.set_description(f'Batch: {i + 1}')
            X = {
                'input_ids': batch['input_ids'],
                'attention_mask': batch['attention_mask'],
            }
            labels = batch['reward_class']

            representation = model.representation_forward(**X)
            loss = am_softmax_criterion(representation, labels)
            loss.backward()
            am_softmax_optimizer.step()
            am_softmax_optimizer.zero_grad()
            am_softmax_scheduler.step()
            representation_batch_loss.append(loss.item())

            if i % print_divisor == print_divisor - 1:
                # Move the losses recorded since the last log into the epoch buffer.
                representation_epoch_loss += representation_batch_loss
                representation_batch_loss_series = pd.Series(representation_batch_loss, dtype=np.float64)

                # Quick held-out check on a random subset (at most 50 samples).
                model.eval()
                am_softmax_criterion.eval()
                test_samples = testset.shuffle().select(range(min(50, len(testset))))
                with torch.no_grad():
                    X = {
                        'input_ids': test_samples['input_ids'],
                        'attention_mask': test_samples['attention_mask'],
                    }
                    y = test_samples['reward_class']
                    representation = model.representation_forward(**X)
                    logits = am_softmax_criterion.logits_forward(representation)
                    preds = torch.argmax(logits, dim=1)
                model.train()
                am_softmax_criterion.train()

                representation_training_stats.append(
                    {
                        'Epoch': epoch + 1,
                        'Batch': i + 1,
                        'train_loss': representation_batch_loss_series.mean(),
                        'eval accuracy': accuracy(preds, y).item(),
                    }
                )
                tqdm_bar.set_postfix(representation_training_stats[-1])
                # Reset THIS loop's buffer. The unrelated name `batch_loss` used
                # to be reset, so every log re-counted all earlier losses.
                representation_batch_loss = []

    # End of epoch. This used to sit inside the batch loop, where it re-added
    # the whole epoch buffer on every batch. Flush the losses recorded after
    # the last log, then add this epoch's losses to the total exactly once.
    representation_epoch_loss += representation_batch_loss
    representation_batch_loss = []
    representation_total_loss += representation_epoch_loss
    representation_epoch_loss = []

representation_total_loss = pd.Series(representation_total_loss)
ax = representation_total_loss.plot.line()
ax.set(title='Stage 1 (AM-Softmax) training loss', xlabel='training step', ylabel='loss');

### Training (2)
#### Reward Generation
Loss: MSE

In [None]:
# Stage 2 setup: MSE regression of the rescaled reward, using a fresh AdamW
# optimizer over the model parameters and linear decay after 30 warm-up steps.
gc.collect()
torch.cuda.empty_cache()

total_loss, epoch_loss, batch_loss = [], [], []
training_stats = []

# TODO: Minor Hyperparameter Tuning
criterion = nn.MSELoss()
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=decay)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=30,
    num_training_steps=NUM_EPOCHS * len(dataloader),
)
r2_score = R2Score().to(device)

In [None]:
# Stage 2 loop: regress the rescaled reward with MSE. About
# MID_PROCESS_PRINT_NUM times per epoch, log the mean recent train loss and the
# R^2 on a random held-out subset.
for epoch in range(NUM_EPOCHS):
    total_steps = len(dataloader)
    # max(1, ...) avoids `i % 0` when an epoch has fewer than
    # MID_PROCESS_PRINT_NUM batches.
    print_divisor = max(1, total_steps // MID_PROCESS_PRINT_NUM)
    # TODO: mid-epoch checkpointing (MID_CHECKPOINT_NUM) is not implemented.
    with tqdm(dataloader, leave=False, desc='Batch', position=0, postfix={'epoch': 1, 'batch': 1, 'train_loss': 0, 'eval R2': 0}) as tqdm_bar:
        for i, batch in enumerate(tqdm_bar):
            tqdm_bar.set_description(f'Batch: {i + 1}')
            X = {
                'input_ids': batch['input_ids'],
                'attention_mask': batch['attention_mask'],
            }
            output = model(**X)
            # Match the target dtype to the prediction. The reward column comes
            # from a pandas float column (presumably float64), while the model
            # computes in float32; MSELoss's backward rejects mixed dtypes.
            y = batch['reward'].unsqueeze(1).to(output.dtype)

            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            batch_loss.append(loss.item())

            if i % print_divisor == print_divisor - 1:
                epoch_loss += batch_loss
                batch_loss_series = pd.Series(batch_loss, dtype=np.float64)

                # Quick held-out check on a random subset (at most 30 samples).
                model.eval()
                test_samples = testset.shuffle().select(range(min(30, len(testset))))
                with torch.no_grad():
                    X = {
                        'input_ids': test_samples['input_ids'],
                        'attention_mask': test_samples['attention_mask'],
                    }
                    output = model(**X)
                    y = test_samples['reward'].unsqueeze(1).to(output.dtype)
                model.train()

                training_stats.append(
                    {
                        'epoch': epoch + 1,
                        'batch': i + 1,
                        'train_loss': batch_loss_series.mean(),
                        'eval R2': r2_score(output, y).item(),
                    }
                )
                tqdm_bar.set_postfix(training_stats[-1])
                batch_loss = []

    # End of epoch: flush losses recorded after the last log (they used to be
    # dropped), then add this epoch's losses to the running total.
    epoch_loss += batch_loss
    batch_loss = []
    total_loss += epoch_loss
    epoch_loss = []

total_loss = pd.Series(total_loss)
ax = total_loss.plot.line()
ax.set(title='Stage 2 (MSE) training loss', xlabel='training step', ylabel='loss');

In [None]:
# Persist the logged training statistics (one row per logging step) and the
# trained weights: the encoder via save_pretrained, each head as a state_dict.
representation_training_stats_df = pd.DataFrame(representation_training_stats)
training_stats_df = pd.DataFrame(training_stats)
for stats_df, csv_path in (
    (representation_training_stats_df, './representation_training_stats.csv'),
    (training_stats_df, './training_stats.csv'),
):
    stats_df.to_csv(csv_path, index=False)

model.encoder.save_pretrained(f'{model_save_path}-encoder')
for head_name in ('head1', 'head2'):
    torch.save(getattr(model, head_name).state_dict(), f'{model_save_path}-{head_name}.pt')

## Analysis

In [None]:
representation_training_stats_df.describe(percentiles=[0.25, 0.5, 0.75])  # stage-1 log summary

In [None]:
training_stats_df.describe(percentiles=[0.25, 0.5, 0.75])  # stage-2 log summary

In [None]:
# Recorded stage-1 (AM-Softmax) training losses; the trailing ';' hides the
# Axes repr so only the figure is shown.
representation_total_loss = pd.Series(representation_total_loss)
ax = representation_total_loss.plot.line()
ax.set(title='Stage 1 (AM-Softmax) training loss', ylabel='loss');

In [None]:
# Recorded stage-2 (MSE) training losses; the trailing ';' hides the Axes repr
# so only the figure is shown.
total_loss = pd.Series(total_loss)
ax = total_loss.plot.line()
ax.set(title='Stage 2 (MSE) training loss', ylabel='loss');

In [None]:
representation_total_loss.to_csv(path_or_buf='./representation_total_loss.csv', index=False)

In [None]:
total_loss.to_csv(path_or_buf='./total_loss.csv', index=False)