# 短縮モデル 学習ノートブック


## Mount


In [None]:
from google.colab import drive

# Mount Google Drive into the Colab filesystem; the dataset CSV paths and the
# checkpoint output directory used later in this notebook live under this mount point.
drive.mount("/content/drive")


## Libraries


In [None]:
!pip install transformers
!pip install wandb
!pip install pytorch-lightning
!pip install rich
!pip install python-box
!pip install sentencepiece
!pip install "sacrebleu[ja]<2.0.0"
!pip install janome
!pip install sumeval
!pip install unidic-lite


In [None]:
import datetime
import os
import random
import time

from box import Box
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import WandbLogger
import torch
import torchmetrics
import wandb
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration
from sumeval.metrics.rouge import RougeCalculator
from sumeval.metrics.bleu import BLEUCalculator


## Wandb Setup


In [None]:
# Authenticate this session with Weights & Biases before any run or sweep is created.
wandb.login()


## Data


In [None]:
# CSV files on the mounted Google Drive. The training CSV is later split into
# train/validation subsets; the test CSV is used as-is for the final test run.
# Both are read through CustomDataset, which expects "text" and "summary" columns.
TRAIN_DATASET_PATH = "/content/drive/MyDrive/MurataLab/summary/train_dataset.csv"
TEST_DATASET_PATH = "/content/drive/MyDrive/MurataLab/summary/test_dataset.csv"


### Dataset


In [None]:
class CustomDataset(Dataset):
    """Map-style dataset that tokenizes (text, summary) rows of a DataFrame.

    Each item is tokenized on access: both the source text and the target
    summary are padded/truncated to fixed lengths, and padding positions in
    the target ids are replaced with ``IGNORE_INDEX`` so they can be excluded
    from the loss.

    Args:
        data: DataFrame-like object supporting ``len()`` and ``.iloc[index]``,
            with ``TEXT_COLUMN`` and ``SUMMARY_COLUMN`` columns.
        tokenizer: Tokenizer exposing ``encode_plus`` (and, ideally,
            ``pad_token_id``).
        text_max_token_len: Fixed token length for the source text.
        summary_max_token_len: Fixed token length for the target summary.
    """

    TEXT_COLUMN = "text"
    SUMMARY_COLUMN = "summary"
    # Label value written over padding positions of the target ids.
    IGNORE_INDEX = -100

    def __init__(self, data, tokenizer, text_max_token_len, summary_max_token_len):
        self.data = data
        self.tokenizer = tokenizer
        self.text_max_token_len = text_max_token_len
        self.summary_max_token_len = summary_max_token_len

    def __len__(self):
        """Return the number of rows in the underlying data."""
        return len(self.data)

    def _encode(self, sentence, max_length):
        """Tokenize one sentence to exactly ``max_length`` tokens as ``pt`` tensors."""
        return self.tokenizer.encode_plus(
            sentence,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            add_special_tokens=True,
            return_attention_mask=True,
        )

    def __getitem__(self, index):
        """Return the tokenized sample at ``index``.

        Returns:
            dict with the raw ``text`` / ``summary`` strings plus flattened
            ``text_ids``, ``text_attention_mask``, ``summary_ids`` (padding
            replaced by ``IGNORE_INDEX``) and ``summary_attention_mask``.
        """
        data_row = self.data.iloc[index]
        text = data_row[self.TEXT_COLUMN]
        summary = data_row[self.SUMMARY_COLUMN]

        text_encoding = self._encode(text, self.text_max_token_len)
        summary_encoding = self._encode(summary, self.summary_max_token_len)

        # Ask the tokenizer for its pad id rather than assuming it is 0; keep 0
        # as the fallback so tokenizers without the attribute behave as before.
        pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = 0

        # input_ids include the padding, so overwrite those positions in place.
        summary_ids = summary_encoding["input_ids"]
        summary_ids[summary_ids == pad_token_id] = self.IGNORE_INDEX

        return dict(
            text=text,
            text_ids=text_encoding["input_ids"].flatten(),
            text_attention_mask=text_encoding["attention_mask"].flatten(),
            summary=summary,
            summary_ids=summary_ids.flatten(),
            summary_attention_mask=summary_encoding["attention_mask"].flatten(),
        )


### DataModule


In [None]:
class DataModuleGenerator(pl.LightningDataModule):
    """LightningDataModule serving train/validation/test loaders over CustomDataset.

    Args:
        train_df: Rows used for training.
        valid_df: Rows used for validation.
        test_df: Rows used for testing.
        tokenizer: Tokenizer handed to every CustomDataset.
        batch_size: Batch size shared by all three loaders.
        text_max_token_len: Fixed token length for source texts.
        summary_max_token_len: Fixed token length for target summaries.
    """

    def __init__(
        self,
        train_df,
        valid_df,
        test_df,
        tokenizer,
        batch_size,
        text_max_token_len,
        summary_max_token_len,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.text_max_token_len = text_max_token_len
        self.summary_max_token_len = summary_max_token_len
        self.train_df, self.valid_df, self.test_df = train_df, valid_df, test_df

    def _as_dataset(self, frame):
        """Wrap one DataFrame in a CustomDataset using the shared tokenizer settings."""
        return CustomDataset(
            frame,
            self.tokenizer,
            self.text_max_token_len,
            self.summary_max_token_len,
        )

    def _loader_options(self):
        """Keyword arguments common to every DataLoader built by this module."""
        return dict(batch_size=self.batch_size, num_workers=1, pin_memory=True)

    def setup(self, stage=None):
        """Build all three datasets; ``stage`` is accepted but not consulted."""
        self.train_dataset = self._as_dataset(self.train_df)
        self.valid_dataset = self._as_dataset(self.valid_df)
        self.test_dataset = self._as_dataset(self.test_df)

    def train_dataloader(self):
        """Return the training loader (the only one that shuffles)."""
        return DataLoader(self.train_dataset, shuffle=True, **self._loader_options())

    def val_dataloader(self):
        """Return the validation loader."""
        return DataLoader(self.valid_dataset, **self._loader_options())

    def test_dataloader(self):
        """Return the test loader."""
        return DataLoader(self.test_dataset, **self._loader_options())


## Model


In [None]:
class MyModel(pl.LightningModule):
    """Fine-tunes a pretrained T5 model and scores generated text with ROUGE.

    Args:
        tokenizer: Tokenizer used to decode generated token ids back to text.
        config: Attribute-style config. This class reads
            ``pretrained_model_name``, ``data_module.summary_max_length``,
            ``optimizer.name`` and ``optimizer.lr``.

    NOTE(review): the ``validation_epoch_end`` / ``test_epoch_end`` hooks used
    here belong to the PyTorch Lightning 1.x API -- confirm the installed
    version still supports them.
    """

    # Optimizer names accepted by ``configure_optimizers``; each is looked up
    # as a class of the same name in ``torch.optim``.
    SUPPORTED_OPTIMIZERS = ("AdamW", "RAdam")

    def __init__(self, tokenizer, config):
        super().__init__()
        self.tokenizer = tokenizer
        self.config = config

        self.model = T5ForConditionalGeneration.from_pretrained(
            config.pretrained_model_name,
            return_dict=True,
        )

        # Metrics. BLEU scoring is currently disabled; only ROUGE is computed.
        self.rouge_ja = RougeCalculator(stopwords=True, lang="ja")

    def forward(
        self,
        text_ids,
        text_attention_mask,
        summary_ids=None,
        summary_attention_mask=None,
    ):
        """Run the wrapped T5 model and return ``(loss, logits)``.

        ``summary_ids`` are passed as ``labels`` and ``summary_attention_mask``
        as the decoder attention mask.
        """
        output = self.model(
            text_ids,
            attention_mask=text_attention_mask,
            labels=summary_ids,
            decoder_attention_mask=summary_attention_mask,
        )  # loss func is cross entropy
        return output.loss, output.logits

    def predict(self, text_ids, text_attention_mask):
        """Generate output text for a batch of encoded inputs.

        Returns:
            list[str]: one decoded string per input row, special tokens removed.
        """
        output = self.model.generate(
            text_ids,
            attention_mask=text_attention_mask,
            max_length=self.config.data_module.summary_max_length,
            num_beams=1,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
        )
        return self.tokenizer.batch_decode(
            output, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

    def _step(self, batch, return_text=False):
        """Forward one batch and return loss, logits and the ids that produced them.

        ``return_text`` is not used; it is kept so existing call sites stay valid.
        """
        loss, logits = self(
            text_ids=batch["text_ids"],
            text_attention_mask=batch["text_attention_mask"],
            summary_ids=batch["summary_ids"],
            summary_attention_mask=batch["summary_attention_mask"],
        )
        return {
            "loss": loss,
            "logits": logits,
            "text_ids": batch["text_ids"],
            "summary_ids": batch["summary_ids"],
        }

    def training_step(self, batch, batch_size):
        """Compute and log the training loss for one batch.

        NOTE(review): Lightning passes the batch index as the second positional
        argument; the ``batch_size`` name is kept unchanged and the value is unused.
        """
        results = self._step(batch)
        self.log("train/loss", results["loss"], prog_bar=True)
        return results

    def _rouge_scores(self, summary, predicted_text):
        """Return (ROUGE-1, ROUGE-2, ROUGE-L) for one summary/prediction pair."""
        return (
            self.rouge_ja.rouge_n(summary, predicted_text, n=1),
            self.rouge_ja.rouge_n(summary, predicted_text, n=2),
            self.rouge_ja.rouge_l(summary, predicted_text),
        )

    def _val_test_step(self, batch, batch_size, mode="val"):
        """Shared validation/test step: loss plus batch-averaged ROUGE scores.

        ``batch_size`` and ``mode`` are accepted for signature compatibility but
        are not used here.

        Returns:
            dict with ``loss``, the batch means ``rouge_1`` / ``rouge_2`` /
            ``rouge_l``, and the ``text``, ``summary`` and ``predicted_text``
            lists for later reporting.
        """
        results = self._step(batch)

        predicted_texts = self.predict(batch["text_ids"], batch["text_attention_mask"])

        metrics = {"rouge_1": [], "rouge_2": [], "rouge_l": []}
        for summary, predicted_text in zip(batch["summary"], predicted_texts):
            rouge_1, rouge_2, rouge_l = self._rouge_scores(summary, predicted_text)
            metrics["rouge_1"].append(rouge_1)
            metrics["rouge_2"].append(rouge_2)
            metrics["rouge_l"].append(rouge_l)

        return {
            "loss": results["loss"],
            "rouge_1": np.mean(metrics["rouge_1"]),
            "rouge_2": np.mean(metrics["rouge_2"]),
            "rouge_l": np.mean(metrics["rouge_l"]),
            "text": batch["text"],
            "summary": batch["summary"],
            "predicted_text": predicted_texts,
        }

    def validation_step(self, batch, batch_size):
        """Validation step; see ``_val_test_step``."""
        return self._val_test_step(batch, batch_size, mode="val")

    def test_step(self, batch, batch_size):
        """Test step; see ``_val_test_step``."""
        return self._val_test_step(batch, batch_size, mode="test")

    def _epoch_end(self, outputs, mode):
        """Log the mean of the per-batch loss and ROUGE values under ``{mode}/...``."""
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log(f"{mode}/loss", avg_loss)

        for metric_name in ("rouge_1", "rouge_2", "rouge_l"):
            self.log(
                f"{mode}/{metric_name}",
                np.mean([x[metric_name] for x in outputs]),
            )

    def validation_epoch_end(self, outputs):
        """Aggregate and log validation metrics for the epoch."""
        self._epoch_end(outputs, mode="val")

    def test_epoch_end(self, outputs):
        """Aggregate test metrics and upload every prediction as a W&B table."""
        self._epoch_end(outputs, mode="test")
        results = [
            [text, summary, predicted_text]
            for step_output in outputs
            for text, summary, predicted_text in zip(
                step_output["text"],
                step_output["summary"],
                step_output["predicted_text"],
            )
        ]
        wandb.log(
            {
                "test/results": wandb.Table(
                    data=results, columns=["text", "summary", "predicted_text"]
                )
            }
        )

    def configure_optimizers(self):
        """Build the optimizer named by ``config.optimizer.name``.

        Raises:
            ValueError: if the name is not one of ``SUPPORTED_OPTIMIZERS``.
        """
        name = self.config.optimizer.name
        # Raise explicitly instead of using ``assert``: assertions are stripped
        # under ``python -O``, which would leave the optimizer undefined.
        if name not in self.SUPPORTED_OPTIMIZERS:
            raise ValueError(
                f"Unsupported optimizer {name!r}; "
                f"expected one of {self.SUPPORTED_OPTIMIZERS}"
            )
        optimizer_cls = getattr(torch.optim, name)
        optimizer = optimizer_cls(
            self.parameters(),
            lr=self.config.optimizer.lr,
        )
        return [optimizer]


## Trainer


In [None]:
class MyTrainer:
    """Runs one complete train + test cycle tracked as a single W&B run.

    Args:
        config: Base configuration passed to ``wandb.init``. The effective
            configuration is read back from ``wandb.config`` afterwards.
    """

    def __init__(self, config):
        self.config = config

    def execute(self):
        """Train the model, evaluate it on the test set, and close the W&B run.

        Checkpoints are written to a directory named after the run's start
        timestamp. The W&B run is always finished, including when training
        raises, so a failed run does not stay open in the session.
        """
        # Timestamp in UTC+9, taken from a timezone-aware clock so the value
        # does not depend on the host's local timezone.
        jst = datetime.timezone(datetime.timedelta(hours=9))
        current = datetime.datetime.now(jst).strftime("%Y%m%d_%H%M%S")
        model_output_dir = "/content/drive/MyDrive/MurataLab/summary/models/" + current
        os.makedirs(model_output_dir, exist_ok=True)

        wandb.init(
            project=self.config.wandb_project_name,
            name=current,
            config=self.config,
            id=current,
            save_code=True,
        )
        exit_code = 1  # reported to W&B unless the whole run completes
        try:
            # Read the config back from W&B and re-wrap it for attribute access.
            config = Box(dict(wandb.config))

            # Seed the Python/NumPy/PyTorch RNGs; ``deterministic=True`` below
            # does not make runs repeatable without a fixed seed.
            pl.seed_everything(config.seed, workers=True)

            tokenizer = T5Tokenizer.from_pretrained(config.pretrained_model_name)

            train_df, val_df = train_test_split(
                pd.read_csv(TRAIN_DATASET_PATH),
                train_size=config.data.train_rate,
                random_state=config.seed,
            )
            test_df = pd.read_csv(TEST_DATASET_PATH)

            data_module = DataModuleGenerator(
                train_df=train_df,
                valid_df=val_df,
                test_df=test_df,
                tokenizer=tokenizer,
                batch_size=config.data_module.batch_size,
                text_max_token_len=config.data_module.text_max_length,
                summary_max_token_len=config.data_module.summary_max_length,
            )
            data_module.setup()

            model = MyModel(
                tokenizer,
                config=config,
            )

            early_stop_callback = EarlyStopping(
                **config.early_stopping,
            )

            wandb_logger = WandbLogger(
                log_model=False,
            )
            wandb_logger.watch(model, log="all")

            checkpoint_callback = ModelCheckpoint(
                dirpath=model_output_dir,
                **config.checkpoint,
            )

            progress_bar = RichProgressBar()

            trainer = pl.Trainer(
                max_epochs=config.epoch,
                accelerator="auto",
                devices="auto",
                callbacks=[checkpoint_callback, early_stop_callback, progress_bar],
                logger=wandb_logger,
                deterministic=True,
                # precision=16,
                # accumulate_grad_batches=config.accumulate_grad_batches,
            )

            trainer.fit(model, data_module)

            trainer.test(model, data_module)
            exit_code = 0
        finally:
            wandb.finish(exit_code=exit_code)


## Config


In [None]:
# Switch for the Execute cell: True launches a W&B sweep over `sweep_config`,
# False performs a single training run with `config`.
DO_SWEEP = False


In [None]:
# Base configuration for a single run, wrapped in a Box for attribute-style access.
config = Box(
    {
        "wandb_project_name": "summary",
        "pretrained_model_name": "sonoisa/t5-base-japanese",
        "epoch": 10,
        "seed": 40,
        "data_module": {
            "batch_size": 2,
            # Input texts in the dataset are 21-25 characters long.
            "text_max_length": 30,
            # Token budget chosen so the output does not exceed 20 characters.
            "summary_max_length": 17,
        },
        "optimizer": {
            "name": "RAdam",
            "lr": 1e-5,
        },
        "early_stopping": {
            "monitor": "val/loss",
            "patience": 3,
            "mode": "min",
            "min_delta": 0.02,
        },
        "checkpoint": {
            "monitor": "val/loss",
            "mode": "min",
            "filename": "{epoch}",
            "verbose": True,
        },
        "data": {
            "train_rate": 0.9,
        },
        "accumulate_grad_batches": 4,
    }
)


In [None]:
# Search space handed to `wandb.sweep`: random search minimizing "val/loss".
# Every entry under a `parameters` mapping is itself a mapping: `values` lists
# the candidates to sample from, `value` pins a constant.
sweep_config = dict(
    method="random",
    metric=dict(
        goal="minimize",
        name="val/loss",
    ),
    parameters=dict(
        data_module=dict(
            parameters=dict(
                batch_size=dict(
                    values=[1, 2, 3, 4],
                ),
                # Constants are wrapped in `value=` rather than given as bare
                # scalars, matching the parameter format used for the others.
                text_max_length=dict(
                    value=25,  # input texts in the dataset are 21-25 characters long
                ),
                summary_max_length=dict(
                    value=17,
                ),
            )
        ),
        optimizer=dict(
            parameters=dict(
                name=dict(
                    values=["AdamW", "RAdam"],
                ),
                lr=dict(
                    values=[1e-5, 5e-5, 9e-5, 1e-6],
                ),
            ),
        ),
    ),
)


## Execute


In [None]:
# The same trainer object serves both modes; constructing it only stores the config.
trainer = MyTrainer(config)
if not DO_SWEEP:
    # Single run with the base config.
    trainer.execute()
else:
    # Register the sweep, then let the agent invoke execute() for 10 trials.
    sweep_id = wandb.sweep(sweep_config, project=config.wandb_project_name)
    wandb.agent(sweep_id, trainer.execute, count=10)


## Predict


In [None]:
# Load a trained checkpoint for interactive inference on CPU.
MODEL_DIR = "/content/drive/MyDrive/MurataLab/summary/models"
# Named `run_id` rather than `id` so the builtin `id()` is not shadowed for the
# rest of the notebook session.
run_id = input("id (2023XXXX_XXXXXX) : ")
epoch = input("epoch: ")

tokenizer = T5Tokenizer.from_pretrained(config.pretrained_model_name)

trained_model = MyModel(
    tokenizer,
    config=config,
)
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# were produced by this notebook or another trusted source.
trained_model.load_state_dict(
    torch.load(
        os.path.join(MODEL_DIR, run_id, f"epoch={epoch}.ckpt"),
        map_location=torch.device("cpu"),
    )["state_dict"]
)
# Inference only: switch to eval mode and freeze the module.
trained_model.eval()
trained_model.freeze()


In [None]:
# Interactive loop: read a line, generate its shortened form, and print it along
# with the elapsed time. Typing "exit" ends the loop.
while (text := input("Text (exit): ")) != "exit":
    t0 = time.time()  # timing covers tokenization and generation

    # Encode to the same fixed length that was used during training.
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=config.data_module.text_max_length,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    # Beam search with 4 beams; length_penalty / early_stopping stay at their defaults.
    generated_ids = trained_model.model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=config.data_module.summary_max_length,
        num_beams=4,
        repetition_penalty=2.5,
    )

    elapsed = time.time() - t0
    print("    Time: ", elapsed)
    print(f"    {tokenizer.batch_decode(generated_ids, skip_special_tokens=True)}")
