# Data Preparation

In [1]:
import os
import csv
import random
import numpy as np
import pandas as pd

In [2]:
# Experiment configuration shared by the cells below.

# get_df loads every (corpus, situation, sentence type) combination of these.
corpus_list = ["cejc", "mpdd"]
situation_list = ["apology", "request", "thanksgiving"]
sen_type_list = ["query", "res"]

# File-name stem of the model-input files; the alternative is "translated".
src_type = "original"

# Run identifier: part of the output path and used as the wandb project name.
ver_name = "000_translate_all_both_prefix_rel"

# Selects the dataset sub-directory to read and is also part of the output path.
context_len = 2

# Checkpoints, the best model and the prediction CSVs are written below this folder.
save_dir = "outputs/context/{}/{}/".format(ver_name, context_len)

In [3]:
def get_data_as_list(path):
    """Read the first column of a CSV file into a list of strings.

    Args:
        path: Path of the CSV file. It is decoded as ``utf-8-sig``, so a
            leading byte-order mark is dropped.

    Returns:
        The first field of every non-blank row, in file order.
    """
    with open(path, 'r', encoding='utf-8-sig') as f:
        # csv.reader yields an empty list for a blank line (for example a
        # trailing empty line); indexing it would raise IndexError, so blank
        # rows are skipped.
        return [row[0] for row in csv.reader(f) if row]

def get_df(corpus_list, situation_list, sen_type_list, src_type, context_len, train_type,
           data_root='/nfs/nas-7.1/yamashita/LAB/dialogue_data/data'):
    """Assemble one data split as a (prefix, input_text, target_text) DataFrame.

    For every (corpus, situation, sen_type) combination, three CSV files are
    read from ``{data_root}/{corpus}/{situation}/{context_len}/``:

    * ``rewrited_{sen_type}_{train_type}``   -> ``target_text`` (first column)
    * ``{src_type}_{sen_type}_{train_type}`` -> ``input_text`` (first column)
    * ``relation_pair_{train_type}``         -> relation label; column 0 is
      used for ``'query'`` and column 1 for ``'res'``

    Each row's prefix is ``'{corpus} {situation} {sen_type} {relation}'``.

    Args:
        corpus_list: Corpus directory names to load.
        situation_list: Situation directory names to load.
        sen_type_list: Sentence types to load (``'query'`` and/or ``'res'``).
        src_type: File-name stem of the input files.
        context_len: Name of the sub-directory below each situation.
        train_type: Split suffix of the file names (e.g. ``'train'``).
        data_root: Root directory of the dataset. Defaults to the location
            that was previously hard-coded.

    Returns:
        A DataFrame of strings with the columns ``prefix``, ``input_text``
        and ``target_text``, one row per example.

    Raises:
        ValueError: If the three files of one combination do not contain the
            same number of rows. Without this check the frame would be padded
            with missing values that ``astype(str)`` turns into literal
            ``'None'``/``'nan'`` strings.
    """
    # Column of relation_pair_* that belongs to each sentence type.
    relation_column = {'query': 0, 'res': 1}

    target_text = []
    input_text = []
    prefix = []
    for corpus in corpus_list:
        for situation in situation_list:
            base_dir = f'{data_root}/{corpus}/{situation}/{context_len}'
            for sen_type in sen_type_list:
                targets = get_data_as_list(f'{base_dir}/rewrited_{sen_type}_{train_type}')
                inputs = get_data_as_list(f'{base_dir}/{src_type}_{sen_type}_{train_type}')

                column = relation_column.get(sen_type)
                relations = []
                with open(f'{base_dir}/relation_pair_{train_type}', 'r', encoding='utf-8-sig') as f:
                    if column is not None:
                        # Blank lines come back from csv.reader as [] and are skipped.
                        relations = [row[column] for row in csv.reader(f) if row]

                if not (len(targets) == len(inputs) == len(relations)):
                    raise ValueError(
                        f'Row count mismatch for {corpus}/{situation}/{sen_type} ({train_type}): '
                        f'{len(targets)} targets, {len(inputs)} inputs, '
                        f'{len(relations)} relations'
                    )

                target_text += targets
                input_text += inputs
                prefix += [f'{corpus} {situation} {sen_type} {rel}' for rel in relations]

    # All three lists are guaranteed to have equal length at this point.
    df = pd.DataFrame(
        {'prefix': prefix, 'input_text': input_text, 'target_text': target_text}
    ).astype(str)
    return df

# Finetune

In [None]:
import logging
import pandas as pd
from simpletransformers.t5 import T5Model, T5Args

# Log at INFO level in general, but only warnings and above from the
# "transformers" logger.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Training split.
train_type = 'train'    
train_df = get_df(corpus_list, situation_list, sen_type_list, src_type, context_len, train_type)

# Validation split, used for evaluation during training.
# Note: train_type stays 'val' after this cell.
train_type = 'val'  
eval_df = get_df(corpus_list, situation_list, sen_type_list, src_type, context_len, train_type)

# Uncomment to blank out the prefix column of both frames
# (presumably for a run without prefixes - confirm before use).
# train_df["prefix"] = ""
# eval_df["prefix"] = ""

# Preview the first five rows of each split.
display(train_df.iloc[:5])
display(eval_df.iloc[:5])

In [None]:
# Training configuration for fine-tuning mT5 with simpletransformers.
model_args = T5Args()

# Sequence length and generation settings.
model_args.max_seq_length = 128
model_args.length_penalty = 20
model_args.num_return_sequences = 1
model_args.preprocess_inputs = False

# Optimisation.
model_args.learning_rate = 3e-5
model_args.train_batch_size = 2
model_args.eval_batch_size = 2
model_args.num_train_epochs = 20
model_args.fp16 = False
model_args.use_multiprocessing = False

# Evaluate every 500 steps and stop early when eval_loss stops improving
# (patience 3).
model_args.evaluate_during_training = True
model_args.evaluate_during_training_steps = 500
model_args.use_early_stopping = True
model_args.early_stopping_metric = 'eval_loss'
model_args.early_stopping_metric_minimize = True
model_args.early_stopping_patience = 3

# Checkpointing. save_eval_checkpoints used to be assigned True and then
# immediately False; only the effective value (False) is kept.
model_args.save_eval_checkpoints = False
model_args.save_model_every_epoch = True
model_args.save_steps = -1
model_args.best_model_dir = save_dir+'best_model/'
model_args.output_dir = save_dir+'ckpt/'
model_args.overwrite_output_dir = True

# Always rebuild the features from the DataFrames instead of using a cache.
model_args.no_cache = True
model_args.reprocess_input_data = True

# Experiment tracking: the wandb project is named after the run.
model_args.wandb_project = ver_name

model = T5Model("mt5", "google/mt5-base", args=model_args)
# Variant with an explicit device index:
# model = T5Model("mt5", "google/mt5-base", args=model_args, cuda_device=1)

# Train the model.
# WANDB_CONSOLE='off' is meant to stop wandb from capturing console output
# (see the wandb environment-variable documentation).
os.environ['WANDB_CONSOLE'] = 'off'
model.train_model(train_df, eval_data=eval_df)

# Test

In [None]:

import logging
import sacrebleu
import pandas as pd
from simpletransformers.t5 import T5Model, T5Args


# Log at INFO level in general, but only warnings and above from the
# "transformers" logger.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Decoding settings for inference, passed straight to the args dataclass.
model_args = T5Args(max_length=128, length_penalty=20, num_beams=10)

# Load the model stored in the best-model directory of this run.
best_model_path = f"{save_dir}best_model/"
model = T5Model("mt5", best_model_path, args=model_args)
# Variant with an explicit device index:
# model = T5Model("mt5", best_model_path, args=model_args, cuda_device=1)

In [None]:
# Load the test split (this rebinds eval_df, which held the validation split).
train_type = 'test'
eval_df = get_df(corpus_list, situation_list, sen_type_list, src_type, context_len, train_type)

# Rows are routed by the corpus name contained in their prefix: mpdd rows feed
# the to_ja_* lists and cejc rows feed the to_zh_* lists.
mpdd_rows = eval_df.loc[eval_df["prefix"].str.contains("mpdd")]
cejc_rows = eval_df.loc[eval_df["prefix"].str.contains("cejc")]

to_ja_truth = mpdd_rows["target_text"].tolist()
to_ja_prefix = mpdd_rows["prefix"].tolist()

to_zh_truth = cejc_rows["target_text"].tolist()
to_zh_prefix = cejc_rows["prefix"].tolist()

# The model input is "<prefix>: <input_text>".
# (A prefix-free variant would use just ": " + input_text.)
to_ja_input = [
    f"{row_prefix}: {source}"
    for row_prefix, source in zip(to_ja_prefix, mpdd_rows["input_text"].tolist())
]
to_zh_input = [
    f"{row_prefix}: {source}"
    for row_prefix, source in zip(to_zh_prefix, cejc_rows["input_text"].tolist())
]
to_zh_input[:5]

In [None]:
# Predict
# Generate outputs for the mpdd-derived inputs (to_ja_*) with the loaded model.
to_ja_preds = model.predict(to_ja_input)
# Optional corpus-level BLEU, currently disabled.
# NOTE(review): sacrebleu.corpus_bleu appears to expect the references as a
# list of reference streams (i.e. [to_ja_truth]) - confirm before re-enabling.
# to_ja_bleu = sacrebleu.corpus_bleu(to_ja_preds, to_ja_truth)
# print("--------------------------")
# print("to_ja_bleu: ", to_ja_bleu.score)

# Same for the cejc-derived inputs (to_zh_*).
to_zh_preds = model.predict(to_zh_input)
# NOTE(review): same caveat about the reference format as above.
# to_zh_bleu = sacrebleu.corpus_bleu(to_zh_preds, to_zh_truth)
# print("--------------------------")
# print("to_zh_bleu: ", to_zh_bleu.score)

In [None]:
# Store predictions next to their references, one CSV per target language.
# After the transpose, the prediction column is headed by the run name and the
# reference column by 'truth'.
result_labels = [str(ver_name), 'truth']

r_ja_df = pd.DataFrame([to_ja_preds, to_ja_truth], index=result_labels)
r_ja_df.T.to_csv(f'{save_dir}ja_preds_truth.csv', encoding='utf_8_sig')

r_zh_df = pd.DataFrame([to_zh_preds, to_zh_truth], index=result_labels)
r_zh_df.T.to_csv(f'{save_dir}zh_preds_truth.csv', encoding='utf_8_sig')

## For Thesis

In [18]:
import logging
import sacrebleu
import pandas as pd
from simpletransformers.t5 import T5Model, T5Args


# Log at INFO level in general, but only warnings and above from the
# "transformers" logger.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Decoding settings for the thesis examples; unlike the test run above,
# length_penalty is 0 here.
model_args = T5Args(max_length=128, length_penalty=0, num_beams=10)



In [19]:
# Reload the best model of this run with the decoding settings defined above.
model = T5Model("mt5", save_dir+"best_model/", args=model_args)

# Two hand-written examples in the "<prefix>: <input>" layout used for the test
# inputs, where the prefix is "<corpus> <situation> <sen_type> <relation>".
# The corpus token of the second example was misspelled "cecj"; the corpus is
# "cejc" in corpus_list and therefore in every prefix built by get_df.
# NOTE: the output saved below this cell predates the fix; re-run to refresh it.
text = [
    "mpdd request query spouse: query: 請你看在我倆同學的份上吧！你有什麼要求可以向我提提吧！ context: 正鵬，我倆還是…….怎麼，你不願意離婚是嗎？那不行！",
    "cejc apology query friend: query: あっ。えっと。ごめん。わたしもある。 context: すぐ決めるから。智希にゆうことあったらゆっていい?。",
]
pred = model.predict(text)
pred

HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Decoding outputs', max=2.0, style=ProgressStyle(descripti…




['お願いしたいことがあるなら教えてください。', '啊,不好意思。']