In [1]:
from __future__ import unicode_literals

import argparse
from datetime import datetime
import logging
import os
import random
import re
import tarfile
import unicodedata

import pytorch_lightning as PL
import sentencepiece
import torch
from torch.utils.data import Dataset, DataLoader
import transformers as T
from tqdm import tqdm

logger = logging.getLogger(__name__)

def report(message):
    """Write ``message`` to standard output.

    Single funnel for the notebook's progress notices; returns ``None``.
    """
    print(message)
    
# Expose only GPU index 0 to CUDA-aware libraries in this process (set through the environment).
os.environ['CUDA_VISIBLE_DEVICES'] = "0"



In [2]:
class args:
    """Static experiment configuration, used as a plain attribute namespace.

    The class is never instantiated; its non-dunder attributes are later
    copied into a dict / ``argparse.Namespace`` and handed to the model.
    """

    # ---- model ----
    # Set to a checkpoint name (e.g. "fnakamura/t5-base-en2ja") to start from
    # pretrained weights; None makes T5FineTuner build the model from config only.
    model_name_or_path = None
    tokenizer_name_or_path = "fnakamura/t5-base-en2ja"
    max_input_length = 512
    max_target_length = 4
    train_batch_size = 8
    eval_batch_size = 8
    num_train_epochs = 4

    # ---- data ----
    data_dir = "../inputs"
    data_file = f"{data_dir}/ldcc-20140209.tar.gz"

    # ---- optimisation ----
    learning_rate = 3e-4
    weight_decay = 0.0
    adam_epsilon = 1e-8
    warmup_steps = 0
    gradient_accumulation_steps = 1

    # ---- trainer ----
    num_gpus = 1
    early_stopping_callback = False
    fp_16 = False
    opt_level = "O1"
    max_grad_norm = 1.0

    # ---- experiment bookkeeping ----
    # Timestamp is taken once, when this class body executes.
    datetime_id = f"{datetime.now():%Y-%m-%d_%H%M%S}"
    expid = f"en2ja_t5-base_classification_{datetime_id}"
    output_dir = f"../outputs/{expid}"
    random_state = 42

    # ---- Hugging Face ----
    # None when the HF_CACHE_DIR environment variable is unset.
    cache_dir = os.getenv("HF_CACHE_DIR")

# Seed the random number generators through PyTorch Lightning with the configured random_state.
PL.seed_everything(args.random_state)

Global seed set to 42


42

In [3]:
# Create the run's output directory on first use and say which case applied.
if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)
    report(f"Output dir created: {args.output_dir}")
else:
    report(f"Output dir exists: {args.output_dir}")

Output dir created: ../outputs/en2ja_t5-base_classification_2022-05-09_133156


In [4]:
def unicode_normalize(cls, s):
    """NFKC-normalize only the runs of ``s`` made of characters in ``cls``.

    Args:
        cls: Body of a regex character class (inserted between ``[`` and ``]``).
        s: Text to process.

    Returns:
        ``s`` with every maximal run of ``cls`` characters NFKC-normalized,
        everything else left as is, and full-width hyphen-minus replaced by
        an ASCII hyphen.
    """
    matcher = re.compile('([{}]+)'.format(cls))

    pieces = []
    # The capturing group makes split() return the matched runs alongside
    # the text between them, so the order of the original text is preserved.
    for chunk in matcher.split(s):
        if matcher.match(chunk):
            chunk = unicodedata.normalize('NFKC', chunk)
        pieces.append(chunk)

    return ''.join(pieces).replace('－', '-')

def remove_extra_spaces(s):
    """Collapse repeated spaces and drop single spaces that touch CJK text.

    Runs of ASCII / ideographic spaces become one ASCII space. A single space
    is then removed when it sits between two characters of the listed CJK
    blocks, or between such a character and a Basic Latin character (either
    order). Spaces between two non-CJK characters are kept.
    """
    s = re.sub('[ 　]+', ' ', s)

    cjk = ''.join(('\u4E00-\u9FFF',  # CJK UNIFIED IDEOGRAPHS
                   '\u3040-\u309F',  # HIRAGANA
                   '\u30A0-\u30FF',  # KATAKANA
                   '\u3000-\u303F',  # CJK SYMBOLS AND PUNCTUATION
                   '\uFF00-\uFFEF'   # HALFWIDTH AND FULLWIDTH FORMS
                   ))
    latin = '\u0000-\u007F'

    def strip_between(left_cls, right_cls, text):
        pattern = re.compile('([{}]) ([{}])'.format(left_cls, right_cls))
        # One pass can miss overlapping hits ("あ い う"), so repeat until a
        # pass performs no substitution.
        while True:
            text, hits = pattern.subn(r'\1\2', text)
            if hits == 0:
                return text

    for left_cls, right_cls in ((cjk, cjk), (cjk, latin), (latin, cjk)):
        s = strip_between(left_cls, right_cls, s)
    return s

def normalize_neologd(s):
    """Normalize Japanese text with a fixed, order-dependent sequence of passes.

    Steps, in order: strip; NFKC-normalize full-width alphanumerics and
    half-width kana; unify hyphen, long-vowel-mark and tilde variants; map
    ASCII symbols to full-width forms; remove spaces around CJK text; convert
    the full-width symbols back (``＝`` is left full-width because it is not
    in the final class); finally turn curly quotes into ASCII quotes.
    """
    s = unicode_normalize('０-９Ａ-Ｚａ-ｚ｡-ﾟ', s.strip())

    variant_rules = (
        ('[˗֊‐‑‒–⁃⁻₋−]+', '-'),     # normalize hyphens
        ('[﹣－ｰ—―─━ー]+', 'ー'),    # normalize choonpus
        ('[~∼∾〜〰～]+', '〜'),       # normalize tildes (modified by Isao Sonobe)
    )
    for pattern, replacement in variant_rules:
        s = re.sub(pattern, replacement, s)

    half_width = '!"#$%&\'()*+,-./:;<=>?@[¥]^_`{|}~｡､･｢｣'
    full_width = '！”＃＄％＆’（）＊＋，－．／：；＜＝＞？＠［￥］＾＿｀｛｜｝〜。、・「」'
    s = s.translate({ord(src): ord(dst) for src, dst in zip(half_width, full_width)})

    s = remove_extra_spaces(s)
    s = unicode_normalize('！”＃＄％＆’（）＊＋，－．／：；＜＞？＠［￥］＾＿｀｛｜｝〜', s)  # keep ＝,・,「,」
    return s.replace('’', '\'').replace('”', '"')
    

In [5]:
# Genre names; a genre's position in this list is its integer class id.
target_genres = [
    "dokujo-tsushin",
    "it-life-hack",
    "kaden-channel",
    "livedoor-homme",
    "movie-enter",
    "peachy",
    "smax",
    "sports-watch",
    "topic-news",
]

def remove_brackets(text):
    """Strip a leading and/or trailing ``【...】`` tag from ``text``.

    Only a tag anchored at the very start or very end is removed; a tag in
    the middle of the string is kept.
    """
    return re.sub(r"(^【[^】]*】)|(【[^】]*】$)", "", text)

def normalize_text(text):
    """Normalize one single-line text field for the TSV files.

    Tabs become spaces (tab is the TSV separator), the text is stripped,
    passed through ``normalize_neologd`` and lower-cased.

    Raises:
        AssertionError: if ``text`` contains a line feed or carriage return.
    """
    assert not ("\n" in text or "\r" in text)
    cleaned = text.replace("\t", " ").strip()
    return normalize_neologd(cleaned).lower()

def read_title_body(file):
    """Read one article from a binary line iterator and return ``(title, body)``.

    The first two lines are discarded (presumably metadata header lines of the
    corpus format — confirm against the data). The third line is the title,
    with edge ``【...】`` tags removed; all remaining lines are joined with
    single spaces to form the body. Both values go through ``normalize_text``.
    """
    for _ in range(2):
        next(file)

    raw_title = next(file).decode("utf-8").strip()
    title = normalize_text(remove_brackets(raw_title))

    body_lines = [raw.decode("utf-8").strip() for raw in file.readlines()]
    body = normalize_text(" ".join(body_lines))
    return title, body

# One list of archive member names per genre, index-aligned with target_genres.
genre_files_list = [[] for _ in target_genres]

# Parsed articles as {"title", "body", "genre_id"} dicts, appended genre by genre.
all_data = []

with tarfile.open(args.data_file) as archive_file:
    # Pass 1: bucket every regular ".txt" member under each genre whose name
    # appears in the member path.
    # NOTE(review): the substring match keeps *any* .txt file under a genre
    # directory, not only articles — confirm the archive holds no other
    # .txt files (license/readme) there.
    for archive_item in archive_file:
        if not (archive_item.isfile() and archive_item.name.endswith(".txt")):
            continue
        for i, genre in enumerate(target_genres):
            if genre in archive_item.name:
                genre_files_list[i].append(archive_item.name)

    # Pass 2: parse each member; the list index is the genre id.
    for i, genre_files in enumerate(genre_files_list):
        for name in genre_files:
            # Close each extracted member as soon as it has been parsed.
            with archive_file.extractfile(name) as file:
                # read_title_body() already returns normalized text, so no
                # second normalize_text() pass is applied here.
                title, body = read_title_body(file)

            # Skip articles that end up with an empty title or body.
            if len(title) > 0 and len(body) > 0:
                all_data.append({
                    "title": title,
                    "body": body,
                    "genre_id": i
                    })

In [6]:
# Shuffle the articles with a fixed seed before they are split into train/dev/test.
# NOTE(review): the shuffle is in place, so re-running this cell without first
# rebuilding all_data yields a different order (and therefore a different split).
random.seed(1234)
random.shuffle(all_data)

def to_line(data):
    """Serialize one article dict as a ``title<TAB>body<TAB>genre_id`` TSV line.

    Args:
        data: Mapping with ``"title"``, ``"body"`` and ``"genre_id"`` keys.

    Returns:
        The three fields joined by tabs and terminated by a newline.

    Raises:
        AssertionError: if the title or the body is empty.
    """
    fields = (data["title"], data["body"], data["genre_id"])
    title, body = fields[0], fields[1]

    assert not (len(title) == 0 or len(body) == 0)
    return "{}\t{}\t{}\n".format(*fields)

data_size = len(all_data)
train_ratio, dev_ratio, test_ratio = 0.7, 0.15, 0.15

# Index boundaries of the contiguous train / dev slices of the shuffled data;
# everything after dev_end goes to the test split (test_ratio is implied).
train_end = train_ratio * data_size
dev_end = (train_ratio + dev_ratio) * data_size

# Write the splits into args.data_dir — the directory the datasets read the
# TSV files from — rather than into a second hard-coded copy of that path.
with open(os.path.join(args.data_dir, "train.tsv"), "w", encoding="utf-8") as f_train, \
    open(os.path.join(args.data_dir, "dev.tsv"), "w", encoding="utf-8") as f_dev, \
    open(os.path.join(args.data_dir, "test.tsv"), "w", encoding="utf-8") as f_test:

    # Wrap the list (not the enumerate object) so tqdm knows the total.
    for i, data in enumerate(tqdm(all_data)):
        line = to_line(data)
        if i < train_end:
            f_train.write(line)
        elif i < dev_end:
            f_dev.write(line)
        else:
            f_test.write(line)

7334it [00:00, 64440.43it/s]


In [7]:
# !head -1 ../inputs/test.tsv

In [8]:
class TsvDataset(Dataset):
    """Map-style dataset over a ``title<TAB>body<TAB>genre_id`` TSV file.

    The whole file is read and tokenized eagerly in the constructor. Each
    item is a dict of four 1-D tensors: ``source_ids`` / ``source_mask``
    (length ``input_max_len``) and ``target_ids`` / ``target_mask``
    (length ``target_max_len``).

    Args:
        tokenizer: Object providing ``batch_encode_plus`` (a Hugging Face
            tokenizer in this notebook).
        data_dir: Directory containing the TSV file.
        type_path: File name inside ``data_dir`` (e.g. ``"train.tsv"``).
        input_max_len: Padded/truncated length of the encoder input.
        target_max_len: Padded/truncated length of the decoder target.
    """

    def __init__(self, tokenizer, data_dir, type_path, input_max_len=512, target_max_len=512):
        self.file_path = os.path.join(data_dir, type_path)

        self.input_max_len = input_max_len
        self.target_max_len = target_max_len
        self.tokenizer = tokenizer
        self.inputs = []
        self.targets = []

        self._build()

    def __len__(self):
        """Number of examples read from the file."""
        return len(self.inputs)

    def __getitem__(self, index):
        """Return the tokenized example at ``index`` as a dict of 1-D tensors."""
        source = self.inputs[index]
        target = self.targets[index]

        # Each stored encoding has a leading batch dimension of size 1.
        # squeeze(0) removes exactly that dimension; a bare squeeze() would
        # also collapse the sequence dimension whenever max_len == 1.
        return {
            "source_ids": source["input_ids"].squeeze(0),
            "source_mask": source["attention_mask"].squeeze(0),
            "target_ids": target["input_ids"].squeeze(0),
            "target_mask": target["attention_mask"].squeeze(0),
        }

    def _make_record(self, title, body, genre_id):
        """Build the (input text, target text) pair for the news classification task."""
        source_text = f"{title} {body}"
        target_text = f"{genre_id}"
        return source_text, target_text

    def _build(self):
        """Read the TSV file and tokenize every row into ``inputs`` / ``targets``.

        Raises:
            AssertionError: if a row does not have exactly three non-empty fields.
        """
        with open(self.file_path, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.strip().split("\t")
                assert len(fields) == 3
                title, body, genre_id = fields
                assert len(title) > 0
                assert len(body) > 0
                assert len(genre_id) > 0

                source_text, target_text = self._make_record(title, body, genre_id)

                tokenized_inputs = self.tokenizer.batch_encode_plus(
                    [source_text], max_length=self.input_max_len, truncation=True,
                    padding="max_length", return_tensors="pt"
                )

                tokenized_targets = self.tokenizer.batch_encode_plus(
                    [target_text], max_length=self.target_max_len, truncation=True,
                    padding="max_length", return_tensors="pt"
                )

                self.inputs.append(tokenized_inputs)
                self.targets.append(tokenized_targets)

In [9]:
# Load the T5 tokenizer named in the config, using args.cache_dir as the download cache.
tokenizer = T.T5Tokenizer.from_pretrained(args.tokenizer_name_or_path, cache_dir=args.cache_dir)

In [10]:
# Standalone copy of the training set for inspection. Lengths come from args
# so they cannot drift from the ones the model and the test cell use.
train_dataset = TsvDataset(
    tokenizer, args.data_dir, "train.tsv",
    input_max_len=args.max_input_length, target_max_len=args.max_target_length)

In [11]:
# Disabled inspection snippet, kept inside a string literal so it does not run:
# it would print the decoded text and the token IDs of the first training
# example's input and target.
'''
for data in train_dataset:
    print("A. 入力データの元になる文字列")
    print(tokenizer.decode(data["source_ids"]))
    print()
    print("B. 入力データ（Aの文字列がトークナイズされたトークンID列）")
    print(data["source_ids"])
    print()
    print("C. 出力データの元になる文字列")
    print(tokenizer.decode(data["target_ids"]))
    print()
    print("D. 出力データ（Cの文字列がトークナイズされたトークンID列）")
    print(data["target_ids"])
    break
'''

'\nfor data in train_dataset:\n    print("A. 入力データの元になる文字列")\n    print(tokenizer.decode(data["source_ids"]))\n    print()\n    print("B. 入力データ（Aの文字列がトークナイズされたトークンID列）")\n    print(data["source_ids"])\n    print()\n    print("C. 出力データの元になる文字列")\n    print(tokenizer.decode(data["target_ids"]))\n    print()\n    print("D. 出力データ（Cの文字列がトークナイズされたトークンID列）")\n    print(data["target_ids"])\n    break\n'

In [12]:
class T5FineTuner(PL.LightningModule):
    """LightningModule that trains a T5 seq2seq model on the TSV splits.

    Args:
        hparams: Attribute-style settings object (an ``argparse.Namespace``
            in this notebook). Attributes read by this class:
            ``model_name_or_path``, ``tokenizer_name_or_path``, ``data_dir``,
            ``max_input_length``, ``max_target_length``, ``train_batch_size``,
            ``eval_batch_size``, ``num_train_epochs``, ``num_gpus``,
            ``gradient_accumulation_steps``, ``learning_rate``,
            ``weight_decay``, ``adam_epsilon`` and ``warmup_steps``.
    """

    def __init__(self, hparams):
        super().__init__()
        # Registers the settings so they are available as self.hparams.
        self.save_hyperparameters(hparams)

        # Build the model.
        if hparams.model_name_or_path is None:
            # No checkpoint given: only the config is fetched and the model is
            # constructed from it, i.e. no pretrained weights are loaded.
            self.config = T.T5Config.from_pretrained(hparams.tokenizer_name_or_path, vocab_size=32100)
            self.model = T.T5ForConditionalGeneration(config=self.config)
        else:
            # Load the pretrained checkpoint.
            self.config = T.T5Config.from_pretrained(hparams.model_name_or_path, vocab_size=32100)
            self.model = T.T5ForConditionalGeneration.from_pretrained(hparams.model_name_or_path, config=self.config)
            report(f"From pretrained: {hparams.model_name_or_path}")

        # Load the tokenizer.
        self.tokenizer = T.T5Tokenizer.from_pretrained(hparams.tokenizer_name_or_path, is_fast=True, config=self.config)

    def forward(
        self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, labels=None):
        """Forward pass; delegates directly to the wrapped T5 model."""
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels
        )

    def _step(self, batch):
        """Compute the loss for one batch.

        Args:
            batch: Dict with ``source_ids``, ``source_mask``, ``target_ids``
                and ``target_mask`` tensors, as produced by ``TsvDataset``.

        Returns:
            The loss (first element of the model output).
        """
        target_ids = batch["target_ids"]

        # Labels set to -100 are ignored by the loss, so padding positions
        # are masked out. masked_fill returns a new tensor, leaving the
        # caller's batch["target_ids"] untouched.
        labels = target_ids.masked_fill(target_ids == self.tokenizer.pad_token_id, -100)

        outputs = self(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            decoder_attention_mask=batch['target_mask'],
            labels=labels
        )

        loss = outputs[0]
        return loss

    def training_step(self, batch, batch_idx):
        """Training step: compute and log ``train_loss``."""
        loss = self._step(batch)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Validation step: compute and log ``val_loss``."""
        loss = self._step(batch)
        self.log("val_loss", loss)
        return {"val_loss": loss}

    # def validation_epoch_end(self, outputs):
    #     """End-of-validation hook (disabled)."""
    #     avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
    #     self.log("val_loss", avg_loss, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Test step: compute and log ``test_loss``."""
        loss = self._step(batch)
        self.log("test_loss", loss)
        return {"test_loss": loss}

    # def test_epoch_end(self, outputs):
    #     """End-of-test hook (disabled)."""
    #     avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
    #     self.log("test_loss", avg_loss, prog_bar=True)

    def configure_optimizers(self):
        """Create the AdamW optimizer and a per-step linear warmup/decay schedule.

        Relies on ``self.t_total``, which ``setup`` computes for the fit stage.
        """
        model = self.model
        # Parameters exempt from weight decay. T5 names its normalization
        # weights "...layer_norm.weight", which the BERT-style
        # "LayerNorm.weight" entry alone never matches.
        no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters()
                            if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters()
                            if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        # torch.optim.AdamW replaces the deprecated transformers.AdamW; weight
        # decay is given explicitly per parameter group above.
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
                                      lr=self.hparams.learning_rate,
                                      eps=self.hparams.adam_epsilon)

        scheduler = T.get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.t_total
        )

        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

    def get_dataset(self, tokenizer, type_path, args):
        """Create the ``TsvDataset`` for the file ``type_path`` under ``args.data_dir``."""
        return TsvDataset(
            tokenizer=tokenizer,
            data_dir=args.data_dir,
            type_path=type_path,
            input_max_len=args.max_input_length,
            target_max_len=args.max_target_length)

    def setup(self, stage=None):
        """Load the train/dev datasets and compute the total number of optimizer steps.

        Runs only when ``stage`` is ``'fit'`` or ``None``.
        """
        if stage == 'fit' or stage is None:
            train_dataset = self.get_dataset(tokenizer=self.tokenizer, type_path="train.tsv", args=self.hparams)
            self.train_dataset = train_dataset

            val_dataset = self.get_dataset(tokenizer=self.tokenizer, type_path="dev.tsv", args=self.hparams)
            self.val_dataset = val_dataset

            # Optimizer steps per epoch: full batches across all GPUs, divided
            # by the gradient accumulation factor.
            steps_per_epoch = (
                (len(train_dataset) // (self.hparams.train_batch_size * max(1, self.hparams.num_gpus)))
                // self.hparams.gradient_accumulation_steps
            )
            # Kept as an int: the scheduler expects a step count.
            self.t_total = int(steps_per_epoch * self.hparams.num_train_epochs)

    def train_dataloader(self):
        """Create the training DataLoader (shuffled, incomplete last batch dropped)."""
        return DataLoader(
            self.train_dataset,  batch_size=self.hparams.train_batch_size, drop_last=True, shuffle=True, num_workers=4)

    def val_dataloader(self):
        """Create the validation DataLoader."""
        return DataLoader(
            self.val_dataset, batch_size=self.hparams.eval_batch_size, num_workers=4)

In [13]:
# Snapshot of the public (non-dunder) settings defined on the args class.
args_dict = {
    name: value
    for name, value in vars(args).items()
    if not name.startswith("__")
}

In [14]:
# Wrap the settings dict in a Namespace so T5FineTuner can read them as attributes (hparams.xxx).
_args = argparse.Namespace(**args_dict)

In [15]:
# Build the LightningModule. With model_name_or_path=None (the current setting)
# T5FineTuner constructs the model from its config only, without pretrained weights.
model = T5FineTuner(_args)

In [16]:
# main
train_params = dict(
    accumulate_grad_batches=args.gradient_accumulation_steps,
    gpus=args.num_gpus,
    max_epochs=args.num_train_epochs,
    precision= 16 if args.fp_16 else 32,
    amp_level=args.opt_level,
    gradient_clip_val=args.max_grad_norm,
    # checkpoint_callback=checkpoint_callback,
)
trainer = PL.Trainer(**train_params)
trainer.fit(model)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M 
-----------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
891.528   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

1

In [17]:
# Display the output directory of this run.
args.output_dir

'../outputs/en2ja_t5-base_classification_2022-05-09_133156'

In [19]:
# Relabel the run directory from the "en2ja_t5-base" prefix to
# "t5-base-ja_no-PT". The new path is derived from args.output_dir instead of
# a pasted timestamp, so this works for any run; the directory move and the
# variable update happen together, and re-running the cell is a no-op.
renamed_output_dir = args.output_dir.replace("en2ja_t5-base", "t5-base-ja_no-PT")
if renamed_output_dir != args.output_dir:
    os.rename(args.output_dir, renamed_output_dir)
    args.output_dir = renamed_output_dir

In [21]:
# Persist the wrapped T5 model with save_pretrained so the test cells can reload it
# via from_pretrained(args.output_dir). The tokenizer is not saved here.
model.model.save_pretrained(args.output_dir)

In [22]:
# testing
from sklearn import metrics
import textwrap
from tqdm.auto import tqdm

In [23]:
# testing
config = T.T5Config.from_pretrained(args.tokenizer_name_or_path, vocab_size=32100)
# tokenizer = T.T5Tokenizer.from_pretrained(args.output_dir, is_fast=True, config=config)
tokenizer = T.T5Tokenizer.from_pretrained(args.tokenizer_name_or_path, is_fast=True, config=config)
trained_model = T.T5ForConditionalGeneration.from_pretrained(args.output_dir, config=config)

if torch.cuda.is_available():
    trained_model.cuda()

In [24]:
# Tokenize the held-out split with the same length limits used for training.
test_dataset = TsvDataset(
    tokenizer,
    args_dict["data_dir"],
    "test.tsv",
    input_max_len=args.max_input_length,
    target_max_len=args.max_target_length,
)

test_loader = DataLoader(test_dataset, batch_size=32, num_workers=4)

trained_model.eval()

# Decoded predictions and gold labels, aligned by position.
outputs = []
confidences = []
targets = []

for batch in tqdm(test_loader):
    input_ids, input_mask = batch['source_ids'], batch['source_mask']
    if torch.cuda.is_available():
        input_ids, input_mask = input_ids.cuda(), input_mask.cuda()

    outs = trained_model.generate(
        input_ids=input_ids,
        attention_mask=input_mask,
        max_length=args.max_target_length,
        return_dict_in_generate=True,
        output_scores=True,
    )

    # Turn generated and reference token IDs back into label strings.
    dec = [
        tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for ids in outs.sequences
    ]
    # conf = [s.cpu().item() for s in torch.exp(outs.sequences_scores)]
    target = [
        tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for ids in batch["target_ids"]
    ]

    outputs += dec
    # confidences.extend(conf)
    targets += target

  0%|          | 0/35 [00:00<?, ?it/s]

In [25]:
# Exact-match accuracy between the decoded gold labels and the decoded predictions (lists of strings).
metrics.accuracy_score(targets, outputs)

0.28454545454545455

In [26]:
# Per-class precision / recall / F1 over the decoded label strings.
print(metrics.classification_report(targets, outputs))

              precision    recall  f1-score   support

           0       0.20      0.79      0.32       130
           1       0.18      0.50      0.26       121
           2       0.61      0.35      0.44       123
           3       0.00      0.00      0.00        82
           4       0.39      0.12      0.19       129
           5       0.00      0.00      0.00       141
           6       0.00      0.00      0.00       127
           7       0.00      0.00      0.00       127
           8       0.66      0.76      0.71       120

    accuracy                           0.28      1100
   macro avg       0.23      0.28      0.21      1100
weighted avg       0.23      0.28      0.22      1100



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
