# 改行・句読点挿入モデル 学習ノートブック


## Mount


In [None]:
# Mount Google Drive so the dataset CSVs and model checkpoints stored under
# /content/drive/MyDrive/... can be read and written by the cells below.
from google.colab import drive

drive.mount("/content/drive")


## Libraries


In [None]:
# NOTE(review): every package is installed unpinned, so each session picks up
# the latest release. The code below relies on version-sensitive APIs
# (e.g. torchmetrics' task= argument, Lightning epoch-end hooks) — consider
# pinning the versions this notebook was last verified with.
!pip install transformers
!pip install wandb
!pip install pytorch-lightning
!pip install rich
!pip install python-box


In [None]:
import datetime
import os
import random
import time

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torchmetrics
import wandb
from box import Box
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, RichProgressBar
from pytorch_lightning.loggers import WandbLogger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import (
    BertForSequenceClassification,
    BertModel,
    BertTokenizer,
    get_linear_schedule_with_warmup,
)


## Wandb Setup


In [None]:
# Authenticate with Weights & Biases; MyTrainer logs runs, metrics and
# confusion matrices there.
wandb.login()


## Dataset


### Dataset Paths


In [None]:
# Train/test CSV files on the mounted Google Drive; both live in one folder.
_DATASET_DIR = "/content/drive/MyDrive/MurataLab/newline"
TRAIN_DATASET_PATH = f"{_DATASET_DIR}/train_dataset.csv"
TEST_DATASET_PATH = f"{_DATASET_DIR}/test_dataset.csv"


### Dataset


In [None]:
class CustomDataset(Dataset):
    """Tokenized, row-by-row view of a DataFrame for the two-head model.

    Each row supplies the input text plus two integer targets: the line-feed
    flag and the comma/period class.
    """

    # Column names read from every DataFrame row.
    INPUT_COLUMN = "input"
    LF_COLUMN = "is_line_feed"
    COMMA_PERIOD_COLUMN = "comma_period"

    def __init__(self, data, tokenizer, max_token_len):
        """
        Args:
            data: DataFrame, accessed positionally through ``.iloc``.
            tokenizer: object providing ``encode_plus`` (e.g. a BertTokenizer).
            max_token_len: every sequence is padded/truncated to this length.
        """
        self.data, self.tokenizer, self.max_token_len = data, tokenizer, max_token_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the raw text, the flattened token tensors and the 2-value label tensor."""
        row = self.data.iloc[index]
        text = row[self.INPUT_COLUMN]

        tokens = self.tokenizer.encode_plus(
            text,
            max_length=self.max_token_len,
            add_special_tokens=True,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
            return_attention_mask=True,
        )
        # Column 0: line-feed target, column 1: comma/period class.
        target = torch.tensor([row[self.LF_COLUMN], row[self.COMMA_PERIOD_COLUMN]])

        return {
            "text": text,
            "input_ids": tokens["input_ids"].flatten(),
            "attention_mask": tokens["attention_mask"].flatten(),
            "labels": target,
        }


### DataModule


In [None]:
class DataModuleGenerator(pl.LightningDataModule):
    """Build the LightningDataModule used for training from DataFrames.

    ``setup`` wraps each split DataFrame in a ``CustomDataset``; the three
    ``*_dataloader`` hooks serve those datasets in batches. Only the training
    loader shuffles.
    """

    def __init__(
        self,
        train_df,
        valid_df,
        test_df,
        tokenizer,
        batch_size,
        max_token_len,
        num_workers=1,
    ):
        """
        Args:
            train_df, valid_df, test_df: DataFrames for the three splits.
            tokenizer: passed through to ``CustomDataset``.
            batch_size: batch size shared by all loaders.
            max_token_len: pad/truncate length passed to ``CustomDataset``.
            num_workers: DataLoader worker processes for every loader. Defaults
                to 1 because ``os.cpu_count()`` caused errors under wandb sweeps.
        """
        super().__init__()
        self.train_df = train_df
        self.valid_df = valid_df
        self.test_df = test_df
        self.batch_size = batch_size
        self.max_token_len = max_token_len
        self.tokenizer = tokenizer
        self.num_workers = num_workers

    def setup(self, stage=None):
        """(Re)create the three datasets; ``stage`` is ignored."""
        self.train_dataset = CustomDataset(
            self.train_df, self.tokenizer, self.max_token_len
        )
        self.valid_dataset = CustomDataset(
            self.valid_df, self.tokenizer, self.max_token_len
        )
        self.test_dataset = CustomDataset(
            self.test_df, self.tokenizer, self.max_token_len
        )

    def _make_loader(self, dataset, shuffle=False):
        """Wrap ``dataset`` in a DataLoader using this module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.valid_dataset)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset)


## Model


In [None]:
class MyModel(pl.LightningModule):
    """BERT with two classification heads for the ``[ANS]`` insertion point.

    * Line-feed head: binary; outputs a sigmoid probability of inserting a
      line break.
    * Comma/period head: 3 classes (0 = nothing, 1 = 読点 "、", 2 = 句点 "。"),
      output as softmax probabilities.

    Both heads read BERT's pooled output. All BERT weights are frozen except
    the last encoder layer.
    """

    THRESHOLD = 0.5  # probability cutoff for the binary line-feed decision
    _LOG_EPS = 1e-12  # floor applied to probabilities before log() in the loss

    def __init__(
        self,
        tokenizer,
        config,
    ):
        """Build BERT, the two heads, the losses and the metrics.

        Args:
            tokenizer: tokenizer used with this model. BERT's embedding matrix
                is resized to ``len(tokenizer)`` so added tokens (e.g.
                ``[ANS]``) get embeddings.
            config: Box config; reads ``pretrained_model_name``,
                ``model.hidden_lf_layer``, ``model.hidden_comma_period_layer``
                and ``optimizer.name`` / ``optimizer.lr``.
        """
        super().__init__()
        self.save_hyperparameters()

        self.config = config

        self.bert = BertModel.from_pretrained(
            config.pretrained_model_name, return_dict=True
        )
        self.bert.resize_token_embeddings(len(tokenizer))

        # Line-feed head: binary classification.
        self.hidden_lf_layer = torch.nn.Linear(
            self.bert.config.hidden_size, config.model.hidden_lf_layer
        )
        self.lf_layer = torch.nn.Linear(config.model.hidden_lf_layer, 1)

        # Comma/period head: 3-class classification (nothing / comma / period).
        self.hidden_comma_period_layer = torch.nn.Linear(
            self.bert.config.hidden_size, config.model.hidden_comma_period_layer
        )
        self.comma_period_layer = torch.nn.Linear(
            config.model.hidden_comma_period_layer, 3
        )

        self.lf_criterion = torch.nn.BCELoss()
        # forward() returns softmax *probabilities* for this head, so the loss
        # is NLL on their log (equal to cross-entropy on the logits).
        # CrossEntropyLoss, used previously, expects raw logits. Feeding it
        # probabilities applied softmax twice and flattened the gradients.
        self.comma_period_criterion = torch.nn.NLLLoss()

        self.lf_metrics = torchmetrics.MetricCollection(
            [
                torchmetrics.Accuracy(task="binary", threshold=self.THRESHOLD),
                torchmetrics.Precision(task="binary", threshold=self.THRESHOLD),
                torchmetrics.Recall(task="binary", threshold=self.THRESHOLD),
                torchmetrics.F1Score(task="binary", threshold=self.THRESHOLD),
                torchmetrics.MatthewsCorrCoef(task="binary", threshold=self.THRESHOLD),
            ]
        )
        self.comma_period_metrics = torchmetrics.MetricCollection(
            [
                torchmetrics.Accuracy(task="multiclass", num_classes=3),
                torchmetrics.Precision(task="multiclass", num_classes=3),
                torchmetrics.Recall(task="multiclass", num_classes=3),
                torchmetrics.F1Score(task="multiclass", num_classes=3),
                torchmetrics.MatthewsCorrCoef(task="multiclass", num_classes=3),
            ]
        )

        # Step outputs buffered for the epoch-end hooks. Lightning >= 2.0 no
        # longer passes step outputs to epoch-end hooks (and rejects modules
        # that define validation_epoch_end/test_epoch_end), so we collect them
        # ourselves. This also works on Lightning 1.x.
        self._val_step_outputs = []
        self._test_step_outputs = []

        # Freeze all of BERT, then re-enable gradients for its last encoder layer.
        for param in self.bert.parameters():
            param.requires_grad = False
        for param in self.bert.encoder.layer[-1].parameters():
            param.requires_grad = True

    def forward(self, input_ids, attention_mask, labels=None):
        """Return ``(loss, [lf_probs, comma_period_probs])``.

        ``lf_probs`` is flattened to one probability per example.
        ``comma_period_probs`` holds 3 class probabilities per example (softmax
        over dim 1). ``loss`` is 0 when ``labels`` is None. Otherwise
        ``labels[:, 0]`` is the line-feed target and ``labels[:, 1]`` is the
        comma/period class.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled = outputs.pooler_output

        lf_hidden = torch.relu(self.hidden_lf_layer(pooled))
        lf_predictions = torch.sigmoid(self.lf_layer(lf_hidden)).flatten()

        comma_period_hidden = torch.relu(self.hidden_comma_period_layer(pooled))
        comma_period_predictions = torch.softmax(
            self.comma_period_layer(comma_period_hidden), dim=1  # row
        )

        if labels is None:
            loss = 0
        else:
            loss = self.compute_loss(
                lf_predictions,
                labels[:, 0].float(),
                comma_period_predictions,
                labels[:, 1].long(),
            )
        return loss, [lf_predictions, comma_period_predictions]

    def _shared_step(self, batch):
        """Run ``forward`` on a batch and package loss, predictions and labels."""
        loss, predictions = self.forward(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        return {
            "loss": loss,
            "batch_preds": predictions,
            "batch_labels": batch["labels"],
        }

    def training_step(self, batch, batch_idx):
        output = self._shared_step(batch)
        self.log("train/loss", output["loss"], on_step=True, prog_bar=True)
        return output

    def validation_step(self, batch, batch_idx):
        output = self._shared_step(batch)
        self._val_step_outputs.append(output)
        return output

    def test_step(self, batch, batch_idx):
        output = self._shared_step(batch)
        self._test_step_outputs.append(output)
        return output

    def compute_loss(
        self, lf_preds, lf_labels, comma_period_preds, comma_period_labels
    ):
        """Return the sum of the two heads' losses.

        Args:
            lf_preds: line-feed probabilities (after sigmoid).
            lf_labels: 0/1 line-feed targets.
            comma_period_preds: class probabilities (after softmax), one row
                per example.
            comma_period_labels: class indices in {0, 1, 2}.
        """
        lf_loss = self.lf_criterion(lf_preds, lf_labels.float())
        # log(softmax) + NLL == cross-entropy on logits; the clamp avoids log(0).
        comma_period_log_probs = torch.log(
            comma_period_preds.clamp(min=self._LOG_EPS)
        )
        comma_period_loss = self.comma_period_criterion(
            comma_period_log_probs, comma_period_labels.long()
        )
        return lf_loss + comma_period_loss  # the two terms could be weighted

    def epoch_end(self, outputs, mode):
        """Concatenate step outputs, log epoch loss and metrics under ``mode/``.

        Returns the concatenated ``(lf_preds, lf_labels, comma_period_preds,
        comma_period_labels)`` tensors.
        """
        epoch_lf_preds = torch.cat([x["batch_preds"][0] for x in outputs])
        epoch_lf_labels = torch.cat([x["batch_labels"][:, 0] for x in outputs])
        epoch_comma_period_preds = torch.cat([x["batch_preds"][1] for x in outputs])
        epoch_comma_period_labels = torch.cat(
            [x["batch_labels"][:, 1] for x in outputs]
        )

        epoch_loss = self.compute_loss(
            epoch_lf_preds,
            epoch_lf_labels,
            epoch_comma_period_preds,
            epoch_comma_period_labels,
        )
        self.log(f"{mode}/loss", epoch_loss, logger=True)

        lf_metrics = self.lf_metrics(epoch_lf_preds, epoch_lf_labels.int())
        for metric in lf_metrics.keys():
            self.log(
                f"{mode}/lf/{metric.lower()}", lf_metrics[metric].item(), logger=True
            )

        comma_period_metrics = self.comma_period_metrics(
            epoch_comma_period_preds, epoch_comma_period_labels.int()
        )
        for metric in comma_period_metrics.keys():
            self.log(
                f"{mode}/comma_period/{metric.lower()}",
                comma_period_metrics[metric].item(),
                logger=True,
            )

        return (
            epoch_lf_preds,
            epoch_lf_labels,
            epoch_comma_period_preds,
            epoch_comma_period_labels,
        )

    def on_validation_epoch_end(self):
        """Log epoch-level validation loss/metrics and reset the step buffer."""
        outputs, self._val_step_outputs = self._val_step_outputs, []
        if outputs:
            self.epoch_end(outputs, "val")

    def on_test_epoch_end(self):
        """Log epoch-level test metrics and both confusion matrices to W&B."""
        outputs, self._test_step_outputs = self._test_step_outputs, []
        if not outputs:
            return
        lf_preds, lf_labels, comma_period_preds, comma_period_labels = self.epoch_end(
            outputs, "test"
        )
        lf_preds, lf_labels, comma_period_preds, comma_period_labels = (
            lf_preds.cpu().numpy(),
            lf_labels.cpu().numpy(),
            comma_period_preds.cpu().numpy(),
            comma_period_labels.cpu().numpy(),
        )
        # Turn probabilities into hard class predictions for the matrices.
        lf_preds, comma_period_preds = (
            np.where(lf_preds > self.THRESHOLD, 1, 0),
            np.argmax(comma_period_preds, axis=1),
        )

        wandb.log(
            {
                "test/lf/confusion_matrix": wandb.plot.confusion_matrix(
                    probs=None,
                    y_true=lf_labels,
                    preds=lf_preds,
                    class_names=["-", "改行"],
                ),
                "test/comma_period/confusion_matrix": wandb.plot.confusion_matrix(
                    probs=None,
                    y_true=comma_period_labels,
                    preds=comma_period_preds,
                    class_names=["挿入なし", "読点", "句点"],
                ),
            }
        )

    def configure_optimizers(self):
        """Create the optimizer named by ``config.optimizer.name``.

        Raises:
            ValueError: if the name is not "AdamW" or "RAdam".
        """
        optimizer_classes = {
            "AdamW": torch.optim.AdamW,
            "RAdam": torch.optim.RAdam,
        }
        name = self.config.optimizer.name
        if name not in optimizer_classes:
            # Raise instead of assert so the check survives `python -O`.
            raise ValueError(
                f"Unsupported optimizer {name!r}; expected one of {list(optimizer_classes)}"
            )
        optimizer = optimizer_classes[name](
            self.parameters(),
            lr=self.config.optimizer.lr,
        )
        return [optimizer]


## Train Runner


In [None]:
class MyTrainer:
    """Run one train/validate/test cycle and log it to Weights & Biases.

    ``execute`` is also the function handed to ``wandb.agent`` for each sweep
    trial.
    """

    # Run ids and checkpoint directories are stamped in Japan Standard Time.
    _JST = datetime.timezone(datetime.timedelta(hours=9), "JST")

    def __init__(self, config):
        """Keep the base Box config used to initialise each W&B run."""
        self.config = config

    def execute(self):
        """Train, checkpoint and test one model.

        The run name, the W&B run id and the checkpoint sub-directory are all
        the JST start timestamp (``%Y%m%d_%H%M%S``).
        """
        # Previously "now() + 9h", which is JST only when the host clock is UTC
        # (as on Colab). An aware JST clock gives the same stamp on any host.
        current = datetime.datetime.now(self._JST).strftime("%Y%m%d_%H%M%S")
        model_output_dir = "/content/drive/MyDrive/MurataLab/newline/models/" + current
        os.makedirs(model_output_dir, exist_ok=True)

        wandb.init(
            project=self.config.wandb_project_name,
            name=current,
            config=self.config,
            id=current,
            save_code=True,
        )
        # Re-read the config from wandb so values chosen by a sweep take effect.
        config = Box(dict(wandb.config))

        # Seed Python/NumPy/torch (and DataLoader workers) so head initialisation,
        # shuffling and dropout are reproducible. Before this, config.seed only
        # controlled the train/validation split.
        pl.seed_everything(config.seed, workers=True)

        tokenizer = BertTokenizer.from_pretrained(config.pretrained_model_name)
        # Marker token for the position where a line feed / comma / period may
        # be inserted.
        tokenizer.add_tokens(["[ANS]"])

        train_df, val_df = train_test_split(
            pd.read_csv(TRAIN_DATASET_PATH),
            train_size=config.data.train_rate,
            random_state=config.seed,
        )
        test_df = pd.read_csv(TEST_DATASET_PATH)

        data_module = DataModuleGenerator(
            train_df=train_df,
            valid_df=val_df,
            test_df=test_df,
            tokenizer=tokenizer,
            batch_size=config.data_module.batch_size,
            max_token_len=config.data_module.max_length,
        )
        data_module.setup()

        model = MyModel(
            tokenizer,
            config=config,
        )

        early_stop_callback = EarlyStopping(
            **config.early_stopping,
        )

        wandb_logger = WandbLogger(
            log_model=False,
        )
        wandb_logger.watch(model, log="all")

        checkpoint_callback = ModelCheckpoint(
            dirpath=model_output_dir,
            **config.checkpoint,
        )

        progress_bar = RichProgressBar()

        trainer = pl.Trainer(
            max_epochs=config.epoch,
            accelerator="auto",
            devices="auto",
            callbacks=[checkpoint_callback, early_stop_callback, progress_bar],
            logger=wandb_logger,
        )

        trainer.fit(model, data_module)

        # NOTE(review): this evaluates the weights from the final epoch, not the
        # best checkpoint; pass ckpt_path="best" if the best model is intended.
        trainer.test(model, data_module)

        wandb.finish()


## Config


In [None]:
# True: run a W&B hyper-parameter sweep over sweep_config (10 runs via
# wandb.agent); False: a single training run with the default config.
DO_SWEEP = False


In [None]:
# Default hyper-parameters for a single (non-sweep) run. A Box lets nested
# values be read as attributes, e.g. config.data_module.batch_size.
config = Box(
    {
        "wandb_project_name": "lf-comma-period-v2",
        "pretrained_model_name": "cl-tohoku/bert-base-japanese-whole-word-masking",
        "epoch": 4,
        "seed": 40,
        "data_module": {
            "batch_size": 16,
            "max_length": 32,  # tokens per example after padding/truncation
        },
        "optimizer": {
            "name": "AdamW",  # "AdamW" or "RAdam"
            "lr": 2e-5,
        },
        "data": {
            "train_rate": 0.8,  # train share of the train/validation split
        },
        "model": {
            "hidden_lf_layer": 64,
            "hidden_comma_period_layer": 64,
        },
        # Passed as keyword arguments to EarlyStopping.
        "early_stopping": {
            "monitor": "val/loss",
            "patience": 3,
            "mode": "min",
            "min_delta": 0.02,
        },
        # Passed as keyword arguments to ModelCheckpoint; files are named
        # "epoch=N.ckpt", which the predict cell loads.
        "checkpoint": {
            "monitor": "val/loss",
            "mode": "min",
            "filename": "{epoch}",
            "verbose": True,
        },
    }
)


In [None]:
# W&B sweep definition: random search minimising the logged "val/loss".
# Nested "parameters" mirror the nested structure of `config`.
sweep_config = {
    "method": "random",
    "metric": {
        "goal": "minimize",
        "name": "val/loss",
    },
    "parameters": {
        "data_module": {
            "parameters": {
                "batch_size": {"values": [16, 32, 64]},
                "max_length": {"value": 32},
            }
        },
        "optimizer": {
            "parameters": {
                "name": {"values": ["AdamW", "RAdam"]},
                "lr": {"values": [1e-5, 5e-5, 9e-5, 1e-6]},
            },
        },
        "model": {
            "parameters": {
                "hidden_lf_layer": {"values": [128, 256, 512]},
                "hidden_comma_period_layer": {"values": [64, 128, 256]},
            }
        },
    },
}


## Execute

In [None]:
# Either hand MyTrainer.execute to a W&B sweep agent (10 sampled runs) or run
# it once with the default config. Constructing MyTrainer only stores config.
trainer = MyTrainer(config)
if not DO_SWEEP:
    trainer.execute()
else:
    sweep_id = wandb.sweep(sweep_config, project=config.wandb_project_name)
    wandb.agent(sweep_id, trainer.execute, count=10)


## Predict


In [None]:
MODEL_DIR = "/content/drive/MyDrive/MurataLab/newline/models"
# Run id = the timestamp directory written by MyTrainer; epoch = the N in the
# "epoch=N.ckpt" file inside it. (Named model_id to avoid shadowing id().)
model_id = input("id (2022XXXX_XXXXXX) : ")
epoch = input("epoch: ")

tokenizer = BertTokenizer.from_pretrained(config.pretrained_model_name)
tokenizer.add_tokens(["[ANS]"])

# map_location="cpu": the model below is built on CPU, and this lets
# GPU-saved checkpoints load on CPU-only runtimes.
# weights_only=False: Lightning also pickles the hyper-parameters (tokenizer,
# Box config), which newer PyTorch's weights-only loader rejects. Only load
# checkpoints you trust.
checkpoint = torch.load(
    os.path.join(MODEL_DIR, model_id, f"epoch={epoch}.ckpt"),
    map_location="cpu",
    weights_only=False,
)
# Prefer the config saved by save_hyperparameters() (sweep runs may use other
# hidden sizes than the default config); fall back to the notebook config.
model_config = Box(checkpoint.get("hyper_parameters", {}).get("config", config))

model = MyModel(
    tokenizer,
    config=model_config,
)
model.load_state_dict(checkpoint["state_dict"])
model.eval()
model.freeze()


In [None]:
# Same line-feed cutoff that MyModel uses for its metrics (0.5).
threshold = MyModel.THRESHOLD

# Interactive demo: the user types text containing an [ANS] marker, and the
# model decides whether to insert 、/。 and/or a line break at that position.
while True:
    text = input("Text [or exit]: ")
    if text == "exit":
        break
    # Without a marker there is nowhere to insert; previously split(...)[1]
    # raised IndexError here and ended the session.
    if "[ANS]" not in text:
        print("Input must contain the [ANS] marker.")
        continue

    t0 = time.perf_counter()
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=config.data_module.max_length,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        predictions = model(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
        )[1]
    print(f"[Time: {time.perf_counter() - t0:.2f} sec]")
    print(predictions)

    # partition keeps everything after the first marker (split(...)[1]
    # silently dropped text following a second marker).
    before, _, after = text.partition("[ANS]")
    # Class 0 = insert nothing, 1 = 読点 "、", 2 = 句点 "。".
    comma_period_class = int(predictions[1].argmax())

    print(before, end="")
    if comma_period_class == 1:
        print("、", end="")
    elif comma_period_class == 2:
        print("。", end="")
    if predictions[0].item() > threshold:
        print("")
    print(after, end="\n\n")
