## Import Libraries

In [2]:
import os
import random
import warnings
import multiprocessing
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from simpletransformers.t5 import T5Model
from sumeval.metrics.rouge import RougeCalculator

pd.set_option('display.max_colwidth', 10000)
warnings.simplefilter('ignore')

In [3]:
class Cfg:
    seed = 42
    cuda = torch.cuda.is_available()
    prefix = 'headlines_generation'
    data = '../data/headlines_generation_corpus/preprocessed_livedoornews.csv' 
    aug_data = ['../data/headlines_generation_corpus/sentence_shuffle_aug_data.csv', 
                '../data/headlines_generation_corpus/back_translation_aug_data.csv']

In [4]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(Cfg.seed)

In [5]:
# SimpleTransformers mT5に入力可能な形へリネーム
data = pd.read_csv(Cfg.data).rename(columns={'title':'target_text','body':'input_text'})

data = data.reindex(columns=['input_text','target_text','media'])
data['prefix'] = Cfg.prefix

In [6]:
print(f'input_textの文字数統計量 \n {data.input_text.apply(lambda x: len(x)).describe()}\n')
print(f'target_textの文字数統計量 \n {data.target_text.apply(lambda x: len(x)).describe()}')

input_textの文字数統計量 
 count    7232.000000
mean      904.344718
std       641.194157
min       101.000000
25%       449.000000
50%       719.000000
75%      1224.250000
max      7878.000000
Name: input_text, dtype: float64

target_textの文字数統計量 
 count    7232.000000
mean       33.079923
std        13.577580
min         5.000000
25%        24.000000
50%        31.000000
75%        39.000000
max       131.000000
Name: target_text, dtype: float64


In [7]:
train, val = train_test_split(data, test_size=0.2, random_state=Cfg.seed, stratify=data['media'])
train = train[['input_text','target_text','prefix']]
val = val[['input_text','target_text','prefix']]

In [8]:
# train.to_csv('../data/headlines_generation_corpus/train.csv', index=False)

In [9]:
print(f'augdataと結合前のtrain: {len(train)}\n')

train = pd.concat(
    [train,                                           # 訓練データ
     pd.read_csv(Cfg.aug_data[0]).dropna(how='any'),  # センテンスシャッフルを適用した訓練データ
     pd.read_csv(Cfg.aug_data[1])]                    # 逆翻訳を適用した訓練データ
     ).reset_index(drop=True)

print(f'augdataと結合後のtrain: {len(train)}')

augdataと結合前のtrain: 5785

augdataと結合後のtrain: 17343


## FineTuning

In [None]:
train_params = {
    'repetition_penalty': 1.5,
    'learning_rate': 1e-4,  
    'max_seq_length': 512,
    'max_length': 256,
    'train_batch_size': 2,
    'eval_batch_size': 2,
    'num_train_epochs': 10,
    'evaluate_during_training': True,
    'evaluate_during_training_steps': 5000,
    'use_multiprocessing': False,
    'fp16': False,
    'save_steps': -1,
    'save_eval_checkpoints': False,
    'save_model_every_epoch': False,
    'no_cache': True,
    'overwrite_output_dir': True,
    'preprocess_inputs': False,
    'dataloader_num_workers': multiprocessing.cpu_count()
}

model = T5Model('mt5', 'google/mt5-small', args=train_params, use_cuda=Cfg.cuda)
model.train_model(train, eval_data=val)

  0%|          | 0/17343 [00:00<?, ?it/s]

Using Adafactor for T5


Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Running Epoch 0 of 10:   0%|          | 0/8672 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

Running Epoch 1 of 10:   0%|          | 0/8672 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

Running Epoch 2 of 10:   0%|          | 0/8672 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

Running Epoch 3 of 10:   0%|          | 0/8672 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

Running Epoch 4 of 10:   0%|          | 0/8672 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

Running Epoch 5 of 10:   0%|          | 0/8672 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

Running Epoch 6 of 10:   0%|          | 0/8672 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

Running Epoch 7 of 10:   0%|          | 0/8672 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

Running Epoch 8 of 10:   0%|          | 0/8672 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

Running Epoch 9 of 10:   0%|          | 0/8672 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

  0%|          | 0/1447 [00:00<?, ?it/s]

(86720,
 {'global_step': [5000,
   8672,
   10000,
   15000,
   17344,
   20000,
   25000,
   26016,
   30000,
   34688,
   35000,
   40000,
   43360,
   45000,
   50000,
   52032,
   55000,
   60000,
   60704,
   65000,
   69376,
   70000,
   75000,
   78048,
   80000,
   85000,
   86720],
  'eval_loss': [3.3763466170314924,
   3.1116890624906475,
   3.044511700162242,
   2.939817726817908,
   2.9213592198095926,
   2.919580186969338,
   2.9108467747081708,
   2.896219839172139,
   2.938028741689677,
   2.971024427336553,
   2.984746217439517,
   3.0310392794688106,
   3.0562639526240734,
   3.1030994726149417,
   3.154257061402442,
   3.2114194389535577,
   3.2941372157048785,
   3.3068574933715946,
   3.3553057588760367,
   3.4678541043186715,
   3.4316690981552744,
   3.528649355290015,
   3.6848636344075203,
   3.6694261651095106,
   3.7863412937406675,
   3.840989904509065,
   3.8847450858897927],
  'train_loss': [4.813566207885742,
   3.6370949745178223,
   3.585890769958496,
  

In [None]:
pred_params = {
    'max_seq_length': 512,
    'use_multiprocessed_decoding': False,
    'num_beams': 4
}

model = T5Model('mt5', 'outputs/best_model', args=pred_params, use_cuda=Cfg.cuda)
preds = model.predict(list(val['input_text']))

Generating outputs:   0%|          | 0/724 [00:00<?, ?it/s]

## Evaluate

In [None]:
def rouge_calc(preds, targets):
    rouge = RougeCalculator(stopwords=True, lang='ja')
    rouge_1 = [rouge.rouge_n(summary=preds[i],references=targets[i],n=1) for i in range(len(preds))]
    rouge_2 = [rouge.rouge_n(summary=preds[i],references=targets[i],n=2) for i in range(len(preds))]
    rouge_l = [rouge.rouge_l(summary=preds[i],references=targets[i]) for i in range(len(preds))]

    return {'Rouge_1': np.array(rouge_1).mean(),
            'Rouge_2': np.array(rouge_2).mean(),
            'Rouge_L': np.array(rouge_l).mean()}

In [None]:
rouge_calc(preds, list(val['target_text']))

{'Rouge_1': 0.3446141958018871,
 'Rouge_2': 0.16201712695139747,
 'Rouge_L': 0.3064431681368891}

In [None]:
# 訓練データのみ
# {'Rouge_1': 0.3366569193887046,
#  'Rouge_2': 0.15542767293242557,
#  'Rouge_L': 0.3003578772701939}
# 
# 訓練データ + 逆翻訳データ
# {'Rouge_1': 0.34069566637291754,
#  'Rouge_2': 0.16059334773532494,
#  'Rouge_L': 0.30380623914367033}

In [None]:
pd.concat([val['target_text'].reset_index(drop=True),pd.Series(preds, name='preds')], axis=1).sample(5, random_state=Cfg.seed)

Unnamed: 0,target_text,preds
1023,壇れいがミラ・ジョヴォヴィッチとの初対面に感激,ミラ・ジョヴォヴィッチがセクシーなフェラガモのドレスに身を包んで登場
381,マット・デイモンとスカーレット・ヨハンソンが共演、世界を感動で包んだ実話を映画化,マット・デイモンとスカーレット・ヨハンソンが共演、『幸せへのキセキ』予告映像が公開
843,KDDI、au向け「MIRACHIS11PT」「EIS01PT」に音声着信できない不具合などでソフトウェア更新を提供開始,パンテック、auスマートフォン「minichikas11PLと「IS01PR」に不具合でソフトウェア更新を提供開始
427,なでしこ初戦のスタジアムに意外な事実、みのもんたは「なんで?なんでないの?」,なでしこジャパン、ロンドン五輪・女子サッカー1次リーグ初戦に「時計が設置されていない」
192,「男はアクション、女は本音トーク」どちらも破壊力抜群『Black&White』,恋愛映画で女子は楽しめない!『恋人編』
