# 改行・句読点挿入モデル 学習ノートブック


## Mount


In [None]:
from google.colab import drive

# Mount Google Drive; the dataset CSVs and model checkpoints used below live
# under /content/drive/MyDrive.
drive.mount("/content/drive")


## Libraries


In [None]:
!pip install transformers
!pip install torchinfo
!pip install wandb
!pip install pytorch-lightning
!pip install rich
!pip install python-box


In [None]:
import datetime
import os
import random
import time

from box import Box
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import WandbLogger
import torch
import torchmetrics
import wandb
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AdamW,
    BertConfig,
    BertForSequenceClassification,
    BertModel,
    BertTokenizer,
    get_linear_schedule_with_warmup,
)


## Config


In [None]:
# Hyperparameters and run settings for the whole notebook.
# Wrapped in a Box so nested values can be read with attribute access
# (e.g. config.optimizer.lr).
config = Box(
    {
        "wandb_project_name": "lf-punctuation",
        "pretrained_model_name": "cl-tohoku/bert-base-japanese-whole-word-masking",
        # Passed to the Trainer as max_epochs.
        "epoch": 4,
        # Used as random_state for the train/validation split.
        "seed": 40,
        "data_module": {
            "batch_size": 16,
            # Token length every sample is padded/truncated to.
            "max_length": 32,
        },
        "optimizer": {
            "name": "AdamW",
            "lr": 2e-5,
            "eps": 1e-8,
        },
        # Keyword arguments forwarded verbatim to EarlyStopping.
        "early_stopping": {
            "monitor": "val/loss",
            "patience": 3,
            "mode": "min",
            "min_delta": 0.02,
        },
        # Keyword arguments forwarded verbatim to ModelCheckpoint.
        "checkpoint": {
            "monitor": "val/loss",
            "mode": "min",
            "filename": "{epoch}",
            "verbose": True,
        },
        "data": {
            # Fraction of the training CSV kept for training (rest is validation).
            "train_rate": 0.8,
        },
        "model": {
            # Output width of the shared linear layer placed on top of BERT.
            "hidden_layer1_output": 64,
        },
    }
)


## Wandb Setup


In [None]:
# Authenticate with Weights & Biases before any run is created.
wandb.login()


## Tokenizer


In [None]:
# Load the tokenizer matching the pretrained BERT checkpoint and register the
# "[ANS]" marker as an extra token. This grows len(tokenizer), which is why the
# model later resizes its token embeddings to len(tokenizer).
tokenizer = BertTokenizer.from_pretrained(config.pretrained_model_name)
tokenizer.add_tokens(["[ANS]"])


## Dataset


### Load


In [None]:
# Locations of the labelled CSVs on the mounted Google Drive.
TRAIN_DATASET_PATH = "/content/drive/MyDrive/MurataLab/newline/train_dataset.csv"
TEST_DATASET_PATH = "/content/drive/MyDrive/MurataLab/newline/test_dataset.csv"


In [None]:
# Split the training CSV into train/validation parts (ratio and seed come from
# config); the test CSV is loaded as-is.
train_df, val_df = train_test_split(
    pd.read_csv(TRAIN_DATASET_PATH),
    train_size=config.data.train_rate,
    random_state=config.seed,
)
test_df = pd.read_csv(TEST_DATASET_PATH)
# Quick sanity check of the resulting sizes.
print(len(train_df), len(val_df), test_df.shape)


### Dataset


In [None]:
class CustomDataset(Dataset):
    """
    Dataset view over a DataFrame of labelled samples.

    Each item is a dict holding:
    text (the original string), input_ids (the tokenized text),
    attention_mask, and labels (line-feed, comma and period flags, in that order).
    """

    INPUT_COLUMN = "input"
    LF_COLUMN = "is_line_feed"
    COMMA_COLUMN = "is_comma"
    PERIOD_COLUMN = "is_period"

    def __init__(self, data, tokenizer, max_token_len):
        """Keep references to the source DataFrame, tokenizer and token length."""
        self.data = data
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.data)

    def __getitem__(self, index):
        """Tokenize the row at position ``index`` and return it with its labels."""
        row = self.data.iloc[index]
        sentence = row[self.INPUT_COLUMN]

        # Label order is fixed: line feed, comma, period.
        label_columns = (self.LF_COLUMN, self.COMMA_COLUMN, self.PERIOD_COLUMN)
        targets = tuple(row[column] for column in label_columns)

        # Pad/truncate to max_token_len so every sample has the same length.
        encoded = self.tokenizer.encode_plus(
            sentence,
            add_special_tokens=True,
            max_length=self.max_token_len,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

        return {
            "text": sentence,
            # flatten() drops the leading batch dimension added by return_tensors="pt".
            "input_ids": encoded["input_ids"].flatten(),
            "attention_mask": encoded["attention_mask"].flatten(),
            "labels": torch.tensor(targets),
        }


### DataModule


In [None]:
class DataModuleGenerator(pl.LightningDataModule):
    """
    LightningDataModule built from train/validation/test DataFrames.

    Each DataFrame is wrapped in a CustomDataset and served through a
    DataLoader; only the training loader shuffles.
    """

    def __init__(
        self,
        train_df,
        valid_df,
        test_df,
        tokenizer,
        batch_size=16,
        max_token_len=512,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_token_len = max_token_len
        self.train_df = train_df
        self.valid_df = valid_df
        self.test_df = test_df

    def _build_dataset(self, frame):
        """Wrap one DataFrame in a CustomDataset using the shared tokenizer settings."""
        return CustomDataset(frame, self.tokenizer, self.max_token_len)

    def _build_loader(self, dataset, shuffle=False):
        """Create a DataLoader with the settings common to all three splits."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            # os.cpu_count() may return None, hence the fallback to 1.
            num_workers=os.cpu_count() or 1,
            pin_memory=True,
        )

    def setup(self, stage=None):
        """Create the datasets for every split; ``stage`` is ignored."""
        self.train_dataset = self._build_dataset(self.train_df)
        # NOTE: the attribute name keeps its original spelling ("vaild").
        self.vaild_dataset = self._build_dataset(self.valid_df)
        self.test_dataset = self._build_dataset(self.test_df)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (no shuffling)."""
        return self._build_loader(self.vaild_dataset)

    def test_dataloader(self):
        """Return the test loader (no shuffling)."""
        return self._build_loader(self.test_dataset)


# Build the data module from the DataFrames loaded above, with batch size and
# token length taken from config.
data_module = DataModuleGenerator(
    train_df=train_df,
    valid_df=val_df,
    test_df=test_df,
    tokenizer=tokenizer,
    batch_size=config.data_module.batch_size,
    max_token_len=config.data_module.max_length,
)
# Create the train/validation/test datasets right away.
data_module.setup()


## Model


In [None]:
class MyModel(pl.LightningModule):
    """BERT encoder with three binary heads: line feed, comma and period.

    The pooled BERT output goes through one shared linear layer and then
    through one single-unit sigmoid head per label. The training loss is the
    unweighted sum of the three binary cross-entropy losses.

    Epoch-level validation/test metrics are computed in
    ``on_validation_epoch_end`` / ``on_test_epoch_end`` from step outputs
    collected by this module. The older ``validation_epoch_end(outputs)`` /
    ``test_epoch_end(outputs)`` hooks were removed in PyTorch Lightning 2.0,
    while the ``on_*_epoch_end`` hooks exist in both 1.x and 2.x.
    """

    THRESHOLD = 0.5
    # Column order of the prediction/label tensors; also used in metric names.
    HEAD_NAMES = ("lf", "comma", "period")

    def __init__(
        self,
        tokenizer,
        epochs=None,
        pretrained_model_name="cl-tohoku/bert-base-japanese-char-whole-word-masking",
        config=None,
    ):
        """Build the network.

        Args:
            tokenizer: tokenizer whose length defines the embedding table size.
            epochs: stored as ``n_epochs``; not used elsewhere in this class.
            pretrained_model_name: Hugging Face id of the BERT checkpoint.
            config: settings object exposing ``model.hidden_layer1_output`` and
                ``optimizer.{name, lr, eps}``.

        Raises:
            ValueError: if ``config`` is not given.
        """
        super().__init__()
        if config is None:
            raise ValueError(
                "config is required (model.hidden_layer1_output and optimizer settings)"
            )
        self.config = config

        self.bert = BertModel.from_pretrained(pretrained_model_name, return_dict=True)
        # Match the embedding table to the tokenizer, which may have extra tokens.
        self.bert.resize_token_embeddings(len(tokenizer))

        # NOTE(review): there is no non-linearity between hidden_layer1 and the
        # heads, so the two linear layers compose into one linear map.
        hidden_width = config.model.hidden_layer1_output
        self.hidden_layer1 = torch.nn.Linear(self.bert.config.hidden_size, hidden_width)
        self.lf_layer = torch.nn.Linear(hidden_width, 1)
        self.comma_layer = torch.nn.Linear(hidden_width, 1)
        self.period_layer = torch.nn.Linear(hidden_width, 1)

        self.n_epochs = epochs

        self.criterion = torch.nn.BCELoss()

        # Evaluation metrics
        self.metrics = torchmetrics.MetricCollection(
            [
                torchmetrics.classification.BinaryAccuracy(threshold=self.THRESHOLD),
                torchmetrics.classification.BinaryPrecision(threshold=self.THRESHOLD),
                torchmetrics.classification.BinaryRecall(threshold=self.THRESHOLD),
                torchmetrics.classification.BinaryF1Score(threshold=self.THRESHOLD),
                torchmetrics.classification.BinaryMatthewsCorrCoef(),
            ]
        )

        # Freeze BERT except for its last encoder layer.
        # NOTE(review): the embedding table stays frozen too, so rows added by
        # resize_token_embeddings keep their initial values -- confirm intended.
        for param in self.bert.parameters():
            param.requires_grad = False
        for param in self.bert.encoder.layer[-1].parameters():
            param.requires_grad = True

        # Per-mode buffers of step outputs, consumed by the on_*_epoch_end hooks.
        self._step_outputs = {"val": [], "test": []}

    def forward(self, input_ids, attention_mask, labels=None):
        """Return ``(loss, predictions)``.

        ``predictions`` has one row per sample and one column per head in
        ``HEAD_NAMES`` order, holding sigmoid probabilities. ``loss`` is the
        summed BCE loss, or ``0`` when ``labels`` is ``None``.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        hidden = self.hidden_layer1(outputs.pooler_output)

        heads = (self.lf_layer, self.comma_layer, self.period_layer)
        # Each head yields one column; concatenating gives (batch, 3).
        predictions = torch.cat([torch.sigmoid(head(hidden)) for head in heads], dim=1)

        loss = 0
        if labels is not None:
            loss = self.calculate_loss(predictions, labels)
        return loss, predictions

    def training_step(self, batch, batch_idx):
        """Run one training batch and log its loss."""
        loss, predictions = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        self.log("train/loss", loss, on_step=True, prog_bar=True)
        return {
            "loss": loss,
            # Detached so the returned dict does not keep the autograd graph alive.
            "batch_preds": predictions.detach(),
            "batch_labels": batch["labels"],
        }

    def _evaluation_step(self, batch, mode):
        """Shared validation/test step: run the model and buffer the outputs."""
        loss, preds = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        output = {
            "loss": loss,
            "batch_preds": preds.detach(),
            "batch_labels": batch["labels"],
        }
        self._step_outputs[mode].append(output)
        return output

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch."""
        return self._evaluation_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch."""
        return self._evaluation_step(batch, "test")

    def calculate_loss(self, preds, labels):
        """Sum the BCE losses of the three heads (unweighted).

        Individual terms could be weighted here if one label should matter more.
        """
        return sum(
            self.criterion(preds[:, index], labels[:, index].float())
            for index in range(len(self.HEAD_NAMES))
        )

    def epoch_end(self, outputs, mode):
        """Log epoch-level loss and per-head metrics under the ``mode/`` prefix.

        Args:
            outputs: list of step dicts with ``batch_preds`` and ``batch_labels``.
            mode: log prefix, ``"val"`` or ``"test"``.

        Returns:
            Tuple ``(epoch_preds, epoch_labels)`` concatenated over all steps.
        """
        epoch_preds = torch.cat([x["batch_preds"] for x in outputs])
        epoch_labels = torch.cat([x["batch_labels"] for x in outputs])

        epoch_loss = self.calculate_loss(epoch_preds, epoch_labels)
        self.log(f"{mode}/loss", epoch_loss, logger=True)

        for index, head in enumerate(self.HEAD_NAMES):
            scores = self.metrics(epoch_preds[:, index], epoch_labels[:, index].int())
            for metric_name, value in scores.items():
                self.log(f"{mode}/{head}/{metric_name.lower()}", value.item(), logger=True)

        return epoch_preds, epoch_labels

    def _drain_outputs(self, mode):
        """Return the buffered step outputs for ``mode`` and reset the buffer."""
        outputs = self._step_outputs[mode]
        self._step_outputs[mode] = []
        return outputs

    def on_validation_epoch_end(self):
        """Aggregate and log validation metrics for the finished epoch."""
        outputs = self._drain_outputs("val")
        if outputs:
            self.epoch_end(outputs, "val")

    def on_test_epoch_end(self):
        """Aggregate test metrics and log one confusion matrix per head to W&B."""
        outputs = self._drain_outputs("test")
        if not outputs:
            return
        preds, labels = self.epoch_end(outputs, "test")

        # Integer class ids (0/1) for both predictions and ground truth.
        pred_classes = (preds.cpu().numpy() > self.THRESHOLD).astype(int)
        true_classes = labels.cpu().numpy().astype(int)

        wandb.log(
            {
                f"test/{head}/confusion_matrix": wandb.plot.confusion_matrix(
                    probs=None,
                    y_true=true_classes[:, index],
                    preds=pred_classes[:, index],
                    class_names=["0", "1"],
                )
                for index, head in enumerate(self.HEAD_NAMES)
            }
        )

    def configure_optimizers(self):
        """Create the optimizer described by ``config.optimizer``.

        Raises:
            ValueError: if ``config.optimizer.name`` is not ``"AdamW"``.
        """
        optimizer_name = self.config.optimizer.name
        if optimizer_name != "AdamW":
            raise ValueError(
                f"Unsupported optimizer {optimizer_name!r}; only 'AdamW' is supported"
            )
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.config.optimizer.lr,
            eps=self.config.optimizer.eps,
        )
        return [optimizer]


In [None]:
# Instantiate the model for training with the checkpoint name and settings from config.
model = MyModel(
    tokenizer,
    epochs=config.epoch,
    pretrained_model_name=config.pretrained_model_name,
    config=config,
)


## Train


In [None]:
# Stop training early according to the early_stopping settings in config.
early_stop_callback = EarlyStopping(
    **config.early_stopping,
)

# Run identifier: current time in JST (UTC+9). A timezone-aware "now" gives the
# right wall-clock time whatever the host's local timezone is, instead of
# assuming the host clock is UTC and adding nine hours.
current = datetime.datetime.now(
    datetime.timezone(datetime.timedelta(hours=9))
).strftime("%Y%m%d_%H%M%S")
MODEL_OUTPUT_DIR = "/content/drive/MyDrive/MurataLab/newline/models/" + current
os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

# Checkpoints are written into the per-run folder created above.
checkpoint_callback = ModelCheckpoint(
    dirpath=MODEL_OUTPUT_DIR,
    **config.checkpoint,
)


progress_bar = RichProgressBar()

# The W&B run shares its name and id with the checkpoint folder so the two can
# be matched later.
wandb.init(
    project=config.wandb_project_name,
    name=current,
    config=config,
    id=current,
    save_code=True,
)

trainer = pl.Trainer(
    max_epochs=config.epoch,
    accelerator="auto",
    devices="auto",
    callbacks=[checkpoint_callback, early_stop_callback, progress_bar],
    logger=WandbLogger(
        log_model=False,
    ),
)

trainer.fit(model, data_module)


## Test


In [None]:
# Evaluate on the test split using the best checkpoint (ckpt_path="best").
trainer.test(datamodule=data_module, ckpt_path="best")


In [None]:
# Close the W&B run started in the training cell.
wandb.finish()


## Predict

In [None]:
MODEL_DIR = "/content/drive/MyDrive/MurataLab/newline/models"
# Named run_id rather than id so the builtin id() is not shadowed; surrounding
# whitespace from typed input is dropped.
run_id = input("id (2022XXXX_XXXXXX) : ").strip()
epoch = input("epoch: ").strip()

model = MyModel(
    tokenizer,
    pretrained_model_name=config.pretrained_model_name,
    config=config,
)

# map_location="cpu" lets a checkpoint saved on a GPU be loaded on a runtime
# without one; inference below runs on CPU tensors anyway.
checkpoint_path = os.path.join(MODEL_DIR, run_id, f"epoch={epoch}.ckpt")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])

# Inference only: switch to eval mode and freeze all parameters.
model.eval()
model.freeze()


In [None]:
# Interactive prediction: type a sentence containing the [ANS] marker; the
# predicted comma / period / line feed is printed at the marker position.
while True:
    text = input("Text [or exit]: ")
    if text == "exit":
        break

    # The output is rendered around the marker, so validate before running the
    # model instead of failing with an IndexError afterwards.
    parts = text.split("[ANS]")
    if len(parts) < 2:
        print("Input must contain the [ANS] marker.", end="\n\n")
        continue

    t0 = time.time()
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=config.data_module.max_length,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        predictions = model(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
            labels=None,
        )[1]
    print(f"[Time: {time.time() - t0:.2f} sec]")

    # Column order of the model output: line feed, comma, period.
    lf_prob, comma_prob, period_prob = predictions[0].tolist()
    # Same decision threshold the model uses for its metrics.
    threshold = model.THRESHOLD

    print(parts[0], end="")
    if comma_prob > threshold:
        print(",", end="")
    if period_prob > threshold:
        print(".", end="")
    if lf_prob > threshold:
        print("")
    print(parts[1], end="\n\n")
    
