# 短縮モデル 学習ノートブック


## Mount


In [None]:
# Mount Google Drive so the datasets and model output folder under MyDrive
# are reachable from this runtime.
from google.colab import drive

drive.mount(mountpoint="/content/drive")


## Libraries


In [None]:
!pip install transformers
!pip install wandb
!pip install pytorch-lightning
!pip install rich
!pip install python-box


In [None]:
import datetime
import os
import random
import time

from box import Box
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import WandbLogger
import torch
import torchmetrics
import wandb
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW, T5Tokenizer, T5ForConditionalGeneration


## Config


In [None]:
# Experiment settings: run identity, tokenization lengths, optimizer, and the
# keyword arguments forwarded to EarlyStopping / ModelCheckpoint.
config = {
    "wandb_project_name": "summary",
    "pretrained_model_name": "sonoisa/t5-base-japanese",
    "epoch": 4,
    "seed": 40,
    "data_module": {
        "batch_size": 4,
        "text_max_length": 30,
        "summary_max_length": 17,  # chosen so summaries do not exceed 20 words
    },
    "optimizer": {
        "name": "AdamW",
        "lr": 2e-5,
        "eps": 1e-8,
    },
    "early_stopping": {
        "monitor": "val/loss",
        "patience": 3,
        "mode": "min",
        "min_delta": 0.02,
    },
    "checkpoint": {
        "monitor": "val/loss",
        "mode": "min",
        "filename": "{epoch}",
        "verbose": True,
    },
    "data": {
        "train_rate": 0.8,
    },
}

# Wrap the plain dict in a Box so nested settings can be read with attribute
# access (e.g. config.data_module.batch_size), as the later cells do.
config = Box(config)


## Wandb Setup


In [None]:
# Authenticate with Weights & Biases before wandb.init() / WandbLogger are used
# in the training cell.
wandb.login()


## Tokenizer


In [None]:
# Load the tokenizer that matches the pretrained checkpoint. T5Tokenizer is the
# slow SentencePiece implementation (T5TokenizerFast is the fast one), so no
# `is_fast` flag is passed.
tokenizer = T5Tokenizer.from_pretrained(config.pretrained_model_name)


## Data


In [None]:
# CSV locations on the mounted Google Drive.
_DATA_DIR = "/content/drive/MyDrive/MurataLab/summary"
TRAIN_DATASET_PATH = f"{_DATA_DIR}/train_dataset.csv"
TEST_DATASET_PATH = f"{_DATA_DIR}/test_dataset.csv"


In [None]:
# Split the labelled CSV into train/validation (train_rate of the rows go to
# train, seeded for a reproducible split); the test CSV is used as-is.
labelled_df = pd.read_csv(TRAIN_DATASET_PATH)
train_df, val_df = train_test_split(
    labelled_df,
    random_state=config.seed,
    train_size=config.data.train_rate,
)
test_df = pd.read_csv(TEST_DATASET_PATH)
print(len(train_df), len(val_df), test_df.shape)


### Dataset


In [None]:
class CustomDataset(Dataset):
    """Map-style dataset turning ``(text, summary)`` rows into T5 inputs.

    Each item is a dict holding the raw strings plus fixed-length encoder
    inputs and decoder labels, using the key names ``MyModel`` reads.
    """

    TEXT_COLUMN = "text"
    SUMMARY_COLUMN = "summary"

    def __init__(self, data, tokenizer, text_max_token_len, summary_max_token_len):
        """
        Args:
            data: DataFrame with ``text`` and ``summary`` columns; rows are
                accessed positionally via ``iloc``.
            tokenizer: tokenizer exposing ``encode_plus`` (e.g. T5Tokenizer).
            text_max_token_len: pad/truncate length for the source text.
            summary_max_token_len: pad/truncate length for the target summary.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.text_max_token_len = text_max_token_len
        self.summary_max_token_len = summary_max_token_len

    def __len__(self):
        return len(self.data)

    def _encode(self, value, max_length):
        """Tokenize one string, padded/truncated to ``max_length``, as torch tensors."""
        return self.tokenizer.encode_plus(
            value,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            add_special_tokens=True,
            return_attention_mask=True,
        )

    def __getitem__(self, index):
        data_row = self.data.iloc[index]
        text = data_row[self.TEXT_COLUMN]
        summary = data_row[self.SUMMARY_COLUMN]

        text_encoding = self._encode(text, self.text_max_token_len)
        summary_encoding = self._encode(summary, self.summary_max_token_len)

        # Padding positions must not contribute to the loss: replace them with
        # -100, the index ignored by the loss in T5ForConditionalGeneration.
        # Selecting padding via the attention mask avoids hard-coding pad id 0.
        labels = summary_encoding["input_ids"].clone()
        labels[summary_encoding["attention_mask"] == 0] = -100

        return dict(
            text=text,
            text_input_ids=text_encoding["input_ids"].flatten(),
            text_attention_mask=text_encoding["attention_mask"].flatten(),
            summary=summary,
            labels=labels.flatten(),
            # Key must match what MyModel reads (batch["labels_attention_mask"]).
            labels_attention_mask=summary_encoding["attention_mask"].flatten(),
        )



### DataModule


In [None]:
class DataModuleGenerator(pl.LightningDataModule):
    """LightningDataModule wrapping the train/validation/test DataFrames.

    ``setup`` builds one ``CustomDataset`` per split; the ``*_dataloader``
    hooks batch them, shuffling only the training split.
    """

    def __init__(
        self,
        train_df,
        valid_df,
        test_df,
        tokenizer,
        batch_size,
        text_max_token_len,
        summary_max_token_len,
    ):
        super().__init__()
        self.train_df, self.valid_df, self.test_df = train_df, valid_df, test_df
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.text_max_token_len = text_max_token_len
        self.summary_max_token_len = summary_max_token_len

    def _build_dataset(self, frame):
        """Wrap one DataFrame split in a CustomDataset with the shared settings."""
        return CustomDataset(
            frame,
            self.tokenizer,
            self.text_max_token_len,
            self.summary_max_token_len,
        )

    def _build_loader(self, dataset, shuffle=False):
        """Create a DataLoader that uses every available CPU core as a worker."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=os.cpu_count() or 1,
            pin_memory=True,
        )

    def setup(self, stage=None):
        # All three splits are (re)built regardless of ``stage``.
        self.train_dataset = self._build_dataset(self.train_df)
        self.valid_dataset = self._build_dataset(self.valid_df)
        self.test_dataset = self._build_dataset(self.test_df)

    def train_dataloader(self):
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._build_loader(self.valid_dataset)

    def test_dataloader(self):
        return self._build_loader(self.test_dataset)


# Build the data module from the loaded splits and the data_module settings,
# and create its datasets now (Lightning calls setup() again during fit/test).
dm_cfg = config.data_module
data_module = DataModuleGenerator(
    train_df=train_df,
    valid_df=val_df,
    test_df=test_df,
    tokenizer=tokenizer,
    batch_size=dm_cfg.batch_size,
    text_max_token_len=dm_cfg.text_max_length,
    summary_max_token_len=dm_cfg.summary_max_length,
)
data_module.setup()


## Model


In [None]:
class MyModel(pl.LightningModule):
    """LightningModule fine-tuning a pretrained T5 model for summarization.

    Args:
        tokenizer: tokenizer paired with the checkpoint (kept on the module).
        config: Box config; ``pretrained_model_name`` selects the checkpoint and
            ``optimizer`` configures ``configure_optimizers``.
    """

    def __init__(self, tokenizer, config):
        super().__init__()
        self.tokenizer = tokenizer
        self.config = config

        self.model = T5ForConditionalGeneration.from_pretrained(
            config.pretrained_model_name,
            return_dict=True,
        )

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        """Run the seq2seq model and return ``(loss, logits)``.

        ``loss`` is only computed by the underlying model when ``labels`` is given.
        """
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )
        return output.loss, output.logits

    def _shared_step(self, batch, stage):
        """Compute the loss for one batch and log it as ``"<stage>/loss"``."""
        loss, _ = self(
            input_ids=batch["text_input_ids"],
            attention_mask=batch["text_attention_mask"],
            decoder_attention_mask=batch["labels_attention_mask"],
            labels=batch["labels"],
        )
        # The batch also carries lists of raw strings, so pass the batch size
        # explicitly instead of letting Lightning infer it.
        self.log(
            f"{stage}/loss",
            loss,
            prog_bar=True,
            logger=True,
            batch_size=batch["labels"].size(0),
        )
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        # "val/loss" is the metric monitored by the EarlyStopping and
        # ModelCheckpoint callbacks (config.early_stopping / config.checkpoint);
        # without logging it those callbacks fail.
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Build the optimizer named in ``config.optimizer`` (only AdamW is supported).

        Raises:
            ValueError: if ``config.optimizer.name`` is not ``"AdamW"``.
        """
        name = self.config.optimizer.name
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if name != "AdamW":
            raise ValueError(f"Unsupported optimizer: {name!r}")
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.config.optimizer.lr,
            eps=self.config.optimizer.eps,
        )
        return [optimizer]


In [None]:
# Instantiate the LightningModule; this loads the pretrained T5 weights.
model = MyModel(tokenizer=tokenizer, config=config)


## Train


In [None]:
# Seed Python/NumPy/torch RNGs (including dataloader workers) from the config
# so shuffling and dropout are reproducible; previously config.seed only fed
# the train/validation split.
pl.seed_everything(config.seed, workers=True)

early_stop_callback = EarlyStopping(
    **config.early_stopping,
)

# Timestamp in UTC+9 (JST) used as the run name/id and output folder name.
# An explicit tz-aware clock is correct regardless of the runtime's local
# timezone, unlike adding 9 hours to a naive local time.
JST = datetime.timezone(datetime.timedelta(hours=9), name="JST")
current = datetime.datetime.now(JST).strftime("%Y%m%d_%H%M%S")
MODEL_OUTPUT_DIR = os.path.join(
    "/content/drive/MyDrive/MurataLab/summary/models", current
)
os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

checkpoint_callback = ModelCheckpoint(
    dirpath=MODEL_OUTPUT_DIR,
    **config.checkpoint,
)


progress_bar = RichProgressBar()

# Start the W&B run explicitly so it is named after the timestamp; the
# WandbLogger below reuses this already-active run.
wandb.init(
    project=config.wandb_project_name,
    name=current,
    config=config,
    id=current,
    save_code=True,
)

trainer = pl.Trainer(
    max_epochs=config.epoch,
    accelerator="auto",
    devices="auto",
    callbacks=[checkpoint_callback, early_stop_callback, progress_bar],
    logger=WandbLogger(
        log_model=False,
    ),
)

trainer.fit(model, data_module)


## Test


In [None]:
# Evaluate the best checkpoint (lowest val/loss, per the ModelCheckpoint
# settings) on the test split.
trainer.test(ckpt_path="best", datamodule=data_module)


In [None]:
# Mark the W&B run as finished so remaining logged data is uploaded.
wandb.finish()
