In [None]:
# Mount Google Drive (Colab only) so the project folder under MyDrive is reachable
# from this runtime; later cells `cd` into it and use relative paths.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Install Packages

In [None]:
# Installation for TPU use.
# Remove torchtext/torchsummary first; torchtext is reinstalled below at a pinned version.
!pip uninstall torchtext torchsummary -y 
# NOTE(review): env-setup.py is downloaded from the unpinned `master` branch of pytorch/xla,
# so this cell can break if that script moves or changes; consider pinning the URL to a
# specific commit/tag. In notebooks `%pip` is also preferable to `!pip`, because it installs
# into the running kernel's environment.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py 
# Run the setup script for version 1.8 (presumably the torch/torch_xla 1.8 builds) and
# apt-install the listed native libraries.
!python pytorch-xla-env-setup.py --version 1.8 --apt-packages libomp5 libopenblas-dev

# Installation for the Japanese tokenizer (pinned fugashi + ipadic dictionary).
!pip install fugashi==1.1.0 ipadic==1.0.0 

# Training stack; torchtext 0.9.0 presumably matches torch 1.8 installed above — TODO confirm.
!pip install pytorch-lightning==1.4.5 torchtext==0.9.0 transformers==4.10.0

In [None]:
cd /content/drive/MyDrive/grammer_correction_pytorch_lightning

/content/drive/MyDrive/grammer_correction_pytorch_lightning


## Import Libraries

In [None]:
import gc
import random
import warnings
import unicodedata
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch_xla.core.xla_model as xm
import pytorch_lightning as pl
from pytorch_lightning.utilities.seed import seed_everything
from transformers import BertJapaneseTokenizer, BertForMaskedLM, get_linear_schedule_with_warmup
from tqdm.notebook import tqdm

warnings.simplefilter('ignore')



In [None]:
class Cfg:
    """Run settings shared by the data module, the model and the trainer."""

    # Reproducibility
    seed = 42

    # Optimisation
    lr = 1e-5
    weight_decay = 1e-2
    epochs = 20

    # Data: both the wrong and the correct sentence are padded/truncated to max_length.
    max_length = 64
    # Used as-is by each process's DataLoader; the DistributedSampler shards the data
    # across processes (one per TPU core).
    train_batch_size = 8 * 8
    val_batch_size = 64 * 8

    # Model / task
    model_name = 'cl-tohoku/bert-base-japanese-char-whole-word-masking'
    # One of: 'kanji-conversion', 'substitution', 'deletion', 'insertion'
    task = 'kanji-conversion'

In [None]:
# Fix global seeds via Lightning's helper (workers=True also seeds DataLoader workers);
# the returned seed is the cell's displayed value.
seed_everything(seed=Cfg.seed, workers=True)

INFO:pytorch_lightning.utilities.seed:Global seed set to 42


42

In [None]:
# Load the train/test splits of the configured task from the local dataset folder.
DATASET_DIR = 'data/japanese-wikipedia-grammer-correct-dataset'
train = pd.read_csv(f'{DATASET_DIR}/train_{Cfg.task}.csv')
test = pd.read_csv(f'{DATASET_DIR}/test_{Cfg.task}.csv')

In [None]:
# Peek at a single row to check the expected columns (pre_text, post_text, category, ...).
train.head(n=1)

Unnamed: 0,pre_text,post_text,category,match
0,個人の言語能力は、全体的な知的能力とは乖離することがあり(例として読字障害、ウィリアムズ症候...,個人の言語能力は、全体的な知的能力とは乖離することがあり(例として読字障害、ウィリアムズ症候...,kanji-conversion_a,False


## Create Dataset

In [None]:
# 'kanji-conversion' and 'insertion' mix '_a' and '_b' sub-categories, so the validation
# split is stratified on `category` to keep their proportions equal in train and val.
# NOTE(review): this rebinds `train`, so re-running the cell re-splits an already-split frame.
if Cfg.task in ('kanji-conversion', 'insertion'):
    train, val = train_test_split(
        train, test_size=0.2, random_state=Cfg.seed, stratify=train['category']
    )
else:
    train, val = train_test_split(train, test_size=0.2, random_state=Cfg.seed)

In [None]:
# Data augmentation helpers
# Used for the 'deletion' task

def random_char_deletion(text, random_state=42):
    """Return ``text`` with one randomly chosen character removed.

    The punctuation marks '、' and '。' are never deleted.

    Args:
        text: source sentence (the correct ``post_text`` when used for augmentation).
        random_state: seed for a private RNG, so the same (text, random_state) pair always
            gives the same result. The global ``random`` module state is left untouched.

    Returns:
        ``text`` with one non-punctuation character removed, or ``text`` unchanged when it
        has no deletable character (empty or punctuation-only input).
    """
    protected = ('、', '。')
    chars = list(text)
    # Guard: the rejection loop below could never terminate if nothing is deletable.
    if all(ch in protected for ch in chars):
        return text
    rng = random.Random(random_state)
    while True:
        # Draw a position (not a character value), so that the drawn occurrence is the one
        # deleted instead of always the first occurrence of a repeated character.
        idx = rng.randrange(len(chars))
        if chars[idx] not in protected:
            break
    del chars[idx]
    return ''.join(chars)

def str_comparison(input, target):
    """Return True when ``input`` equals ``target``, otherwise False."""
    return bool(input == target)

def apply_randomCharDeletion(df):
    """Create augmented rows by deleting one character from each ``post_text``.

    The generated sentence replaces ``pre_text``. Rows whose generated ``pre_text`` equals
    the row's original ``pre_text`` (i.e. an exact duplicate of an existing sample) are
    dropped. The ``match`` column is overwritten with that equality flag, so every returned
    row has ``match == False``.

    Args:
        df: frame with ``pre_text`` and ``post_text`` columns; it is not modified.

    Returns:
        A new DataFrame, keeping the original index, with only the non-duplicate rows.
    """
    augmented = df.copy()
    augmented['pre_text'] = df['post_text'].map(random_char_deletion)
    # Flag rows where regeneration reproduced the existing erroneous sentence.
    augmented['match'] = [
        str_comparison(original, generated)
        for original, generated in zip(df['pre_text'], augmented['pre_text'])
    ]
    return augmented.query('match == False')


# aug_data = apply_randomCharDeletion(train)
# train = pd.concat([train,aug_data])

In [None]:
# reference: https://github.com/stockmarkteam/bert-book/blob/master/Chapter9.ipynb
class GrammerTokenizer(BertJapaneseTokenizer):
    """BertJapaneseTokenizer with encode/decode helpers for grammar correction.

    ``encode_plus_tagged`` builds training examples, ``encode_plus_untagged`` encodes raw
    text for inference while tracking each token's character span, and
    ``convert_bert_output_to_text`` maps predicted token ids back onto the input text.
    """

    def encode_plus_tagged(self, wrong_text, correct_text, max_length=128):
        """Encode a (wrong, correct) sentence pair for supervised fine-tuning.

        Both sentences are tokenized independently with special tokens, padded/truncated to
        ``max_length`` and returned as PyTorch tensors (``return_tensors='pt'``).

        Args:
            wrong_text: sentence containing the error (model input).
            correct_text: corrected sentence whose token ids become the labels.
            max_length: fixed sequence length used for both encodings.

        Returns:
            The tokenizer's encoding of ``wrong_text`` with an added ``labels`` entry holding
            the ``input_ids`` of ``correct_text``.
        """
        encoding = self(
            wrong_text,
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        encoding_correct = self(
            correct_text,
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        # NOTE(review): labels keep the ids at [CLS]/[SEP]/[PAD] positions rather than -100,
        # so those positions are presumably included in the MLM loss — confirm intended.
        encoding['labels'] = encoding_correct['input_ids']

        return encoding

    def encode_plus_untagged(self, text, max_length=None, return_tensors=None):
        """Encode ``text`` for inference and record each token's character span in ``text``.

        Args:
            text: raw input sentence.
            max_length: if given, pad/truncate the encoding to this length.
            return_tensors: 'pt' to return torch tensors with a leading batch dimension of 1.

        Returns:
            ``(encoding, spans)``: ``spans[i]`` is ``[start, end]`` of position ``i`` in
            ``text``, or ``[-1, -1]`` for [CLS], [SEP] and padding positions; there is one
            entry per position of ``encoding['input_ids']``.

        Raises:
            ValueError: if a token's surface string cannot be found in ``text`` (for
                example because the word tokenizer altered the characters).
        """
        # Tokenize word by word, keeping the surface string of each token alongside it:
        # the whole word for [UNK], otherwise the word piece with '##' stripped.
        tokens = []
        tokens_original = []
        words = self.word_tokenizer.tokenize(text)
        for word in words:
            tokens_word = self.subword_tokenizer.tokenize(word)
            tokens.extend(tokens_word)
            if tokens_word[0] == '[UNK]':
                tokens_original.append(word)
            else:
                tokens_original.extend([
                    token.replace('##', '') for token in tokens_word
                ])

        # Locate each surface string in `text`, scanning left to right.
        position = 0
        spans = []
        for token in tokens_original:
            start = text.find(token, position)
            if start == -1:
                # A character-by-character scan would never terminate here; fail loudly.
                raise ValueError(
                    f'could not locate token {token!r} in text at or after index '
                    f'{position}; the word tokenizer may have altered the characters '
                    '(e.g. by normalization)'
                )
            end = start + len(token)
            spans.append([start, end])
            position = end

        input_ids = self.convert_tokens_to_ids(tokens)
        encoding = self.prepare_for_model(
            input_ids,
            max_length=max_length,
            padding='max_length' if max_length else False,
            truncation=True if max_length else False
        )
        # Align spans with encoded positions: [-1, -1] for [CLS], keep at most
        # sequence_length - 2 token spans (room for [CLS]/[SEP]), then pad with [-1, -1].
        sequence_length = len(encoding['input_ids'])
        spans = [[-1, -1]] + spans[:sequence_length - 2]
        spans = spans + [[-1, -1]] * (sequence_length - len(spans))

        if return_tensors == 'pt':
            encoding = {k: torch.tensor([v]) for k, v in encoding.items()}

        return encoding, spans

    def convert_bert_output_to_text(self, text, labels, spans):
        """Rebuild a sentence from per-position predicted token ids and character spans.

        Positions with span ``[-1, -1]`` (special tokens / padding) are skipped. Characters
        of ``text`` lying between consecutive spans are copied through unchanged, and each
        predicted id is converted to its token string with '##' removed and NFKC applied.

        Args:
            text: the original input sentence the spans refer to.
            labels: predicted token ids, one per position (same length as ``spans``).
            spans: ``[start, end]`` per position, as returned by ``encode_plus_untagged``.

        Returns:
            The predicted sentence.
        """
        assert len(spans) == len(labels)

        # Drop special-token / padding positions.
        labels = [label for label, span in zip(labels, spans) if span[0] != -1]
        spans = [span for span in spans if span[0] != -1]

        predicted_text = ''
        position = 0
        for label, span in zip(labels, spans):
            start, end = span
            if position != start:
                # Copy untokenized characters (e.g. whitespace) between spans verbatim.
                predicted_text += text[position:start]
            predicted_token = self.convert_ids_to_tokens(label)
            predicted_token = predicted_token.replace('##', '')
            predicted_token = unicodedata.normalize(
                'NFKC', predicted_token
            )
            predicted_text += predicted_token
            position = end
        # NOTE(review): characters after the last span (e.g. text cut off by max_length
        # truncation) are not appended to the result — confirm this is intended.

        return predicted_text

In [None]:
class GrammerDataset(Dataset):
    """Map-style dataset turning (pre_text, post_text) rows into BERT MLM examples.

    Each item holds the encoded erroneous sentence as model input plus the encoded
    corrected sentence as ``labels``; every returned tensor is flattened to 1-D.
    """

    def __init__(self, data, tokenizer, config):
        """
        Args:
            data: DataFrame with ``pre_text`` (wrong) and ``post_text`` (correct) columns,
                accessed positionally via ``iloc``.
            tokenizer: object providing ``encode_plus_tagged(wrong, correct, max_length=...)``.
            config: settings object; only ``config.max_length`` is read here.
        """
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
        self.config = config

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        encoded = self.tokenizer.encode_plus_tagged(
            row['pre_text'],
            row['post_text'],
            max_length=self.config.max_length,
        )
        # GrammerTokenizer.encode_plus_tagged already returns tensors (return_tensors='pt'),
        # so torch.as_tensor avoids the redundant copy (and the copy-construct UserWarning)
        # of torch.tensor(tensor); flatten drops the leading batch dimension.
        keys = ('input_ids', 'token_type_ids', 'attention_mask', 'labels')
        return {key: torch.as_tensor(encoded[key]).flatten() for key in keys}


class GrammerDataModule(pl.LightningDataModule):
    """Serves train/val/test GrammerDatasets, each sharded per process by a DistributedSampler."""

    def __init__(self, train_data, val_data, test_data, config):
        super().__init__()
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.config = config
        self.tokenizer = GrammerTokenizer.from_pretrained(self.config.model_name)

    def create_dataset(self, mode):
        """Wrap the split selected by ``mode`` in a GrammerDataset.

        'train' and 'val' select those splits; any other value falls back to the test split.
        """
        if mode == 'train':
            data = self.train_data
        elif mode == 'val':
            data = self.val_data
        else:
            data = self.test_data
        return GrammerDataset(data, self.tokenizer, self.config)

    def _make_dataloader(self, mode, batch_size, shuffle, drop_last):
        """Build a DataLoader over this process's DistributedSampler shard of ``mode``."""
        dataset = self.create_dataset(mode=mode)
        sampler = DistributedSampler(
            dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            seed=self.config.seed,  # only affects ordering when shuffle=True
            shuffle=shuffle,
        )
        return DataLoader(
            dataset,
            sampler=sampler,
            batch_size=batch_size,
            num_workers=0,
            pin_memory=True,
            drop_last=drop_last,
        )

    def train_dataloader(self):
        return self._make_dataloader(
            'train', self.config.train_batch_size, shuffle=True, drop_last=True
        )

    def val_dataloader(self):
        return self._make_dataloader(
            'val', self.config.val_batch_size, shuffle=False, drop_last=False
        )

    def test_dataloader(self):
        return self._make_dataloader(
            'test', self.config.val_batch_size, shuffle=False, drop_last=False
        )

## Create Model

In [None]:
def get_num_training_steps(config, train_data=None, num_replicas=None):
    """Number of optimizer steps each training process takes over ``config.epochs`` epochs.

    The train DataLoader shards the data with a DistributedSampler (which pads so every
    replica gets ``ceil(len(train_data) / num_replicas)`` samples) and uses
    ``drop_last=True``, so each process takes ``samples_per_replica // train_batch_size``
    steps per epoch.

    Args:
        config: settings object providing ``train_batch_size`` and ``epochs``.
        train_data: training examples (anything with ``len``); defaults to the
            notebook-level ``train`` frame.
        num_replicas: number of data-parallel processes; defaults to
            ``xm.xrt_world_size()``, the value the train DistributedSampler uses.

    Returns:
        The total step count, as a float.
    """
    if train_data is None:
        train_data = train
    if num_replicas is None:
        num_replicas = xm.xrt_world_size()
    # Integer ceil-division: samples each replica sees per epoch after sampler padding.
    samples_per_replica = -(-len(train_data) // num_replicas)
    steps_per_epoch = samples_per_replica // config.train_batch_size
    return steps_per_epoch * float(config.epochs)

In [None]:
class GrammerModel(pl.LightningModule):
    """LightningModule fine-tuning BertForMaskedLM for token-level grammar correction.

    Batches are dicts of BertForMaskedLM keyword arguments whose ``labels`` are the token
    ids of the corrected sentence.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.bert_mlm = BertForMaskedLM.from_pretrained(self.config.model_name)

    def training_step(self, batch, batch_idx):
        # The batch dict is passed straight through as forward() kwargs; since it contains
        # `labels`, the model output carries the loss.
        output = self.bert_mlm(**batch)
        loss = output.loss
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        output = self.bert_mlm(**batch)
        val_loss = output.loss
        # Each TPU process only sees its DistributedSampler shard; sync_dist reduces the
        # logged value across processes, so the monitored 'val_loss' (checkpointing,
        # early stopping) and the reported test metric cover all shards.
        self.log('val_loss', val_loss, sync_dist=True)

    def test_step(self, batch, batch_idx):
        # Same computation as validation, so test results are reported under 'val_loss'.
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """AdamW (no decay on biases / LayerNorm weights) with a linear LR decay.

        The linear schedule is defined in optimizer steps, so it is registered with
        ``interval='step'``; a bare scheduler is stepped only once per epoch by Lightning,
        which would leave the learning rate almost constant.
        """
        # Apply weight decay to every BERT parameter except biases and LayerNorm weights.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.bert_mlm.named_parameters()
                            if not any(nd in n for nd in no_decay)],
                'weight_decay': self.config.weight_decay,
            },
            {
                'params': [p for n, p in self.bert_mlm.named_parameters()
                            if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
            },
        ]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.config.lr)

        num_training_steps = get_num_training_steps(self.config)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=0,
            num_training_steps=num_training_steps,
        )
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

## Fine-Tuning

In [None]:
# Keep the best epoch by validation loss, stop after one validation check without
# improvement, and log to TensorBoard; artifacts are namespaced by task and seed.
run_name = f'{Cfg.task}_model_seed_{Cfg.seed}'

checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
    dirpath=f'{run_name}/',
)
es_callback = pl.callbacks.EarlyStopping(monitor='val_loss', patience=1)
tb_logger = pl.loggers.TensorBoardLogger(f'{run_name}_logs/')

trainer = pl.Trainer(
    tpu_cores=8,
    max_epochs=Cfg.epochs,
    logger=tb_logger,
    callbacks=[checkpoint, es_callback],
)

model = GrammerModel(Cfg)
datamodule = GrammerDataModule(train, val, test, Cfg)
trainer.fit(model, datamodule=datamodule)

# Drop the notebook's reference to the trained model (the trainer may still hold one)
# and collect garbage; the displayed value is gc.collect()'s return value.
del model
gc.collect()

INFO:pytorch_lightning.utilities.distributed:GPU available: False, used: False
INFO:pytorch_lightning.utilities.distributed:TPU available: True, using: 8 TPU cores
INFO:pytorch_lightning.utilities.distributed:IPU available: False, using: 0 IPUs
Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-char-whole-word-masking were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class thi

Validation sanity check: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.seed:Global seed set to 42
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
INFO:pytorch_lightning.utilities.seed:Global seed set to 42


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

30

In [None]:
# Launch TensorBoard inside the notebook, pointed at the working directory that contains
# the `<task>_model_seed_<seed>_logs/` folder written by the training cell.
%load_ext tensorboard
%tensorboard --logdir ./

## Evaluate

In [None]:
# Rebuild the model, load the best checkpoint's weights (mapped to CPU), then run the
# test loop; the metric is reported under 'val_loss' because test_step reuses validation.
model = GrammerModel(Cfg)
best_state_dict = torch.load(checkpoint.best_model_path, map_location='cpu')['state_dict']
model.load_state_dict(best_state_dict)
trainer.test(model, datamodule=datamodule)

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-char-whole-word-masking were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
INFO:pytorch_lightning.utilities.seed:Global seed set to 42
INFO:pytorch_lightning.utilities

Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'val_loss': 0.11656306684017181}
--------------------------------------------------------------------------------


[{'val_loss': 0.11656306684017181}]

In [None]:
# substitution: 'val_loss': 0.019384516403079033
# deletion:     'val_loss': 0.2072163075208664
# insertion:    'val_loss': 0.138065904378891