In [1]:
import pandas as pd
import numpy as np
import csv
import os


def save_csv(path,data):
    """Write ``data`` to ``path`` as a CSV file encoded as UTF-8 with a BOM.

    Args:
        path: Destination file path. An existing file is overwritten.
        data: Iterable of rows; each row is an iterable of field values.
    """
    # newline='' hands line-ending control to the csv module, as its
    # documentation requires. Without it, platforms that translate '\n'
    # write '\r\r\n' row terminators (blank lines between rows) and alter
    # newlines embedded in quoted fields.
    with open(path, 'w', encoding='utf_8_sig', newline='') as f:
        writer = csv.writer(f)
        writer.writerows(data)
        

def save_as_dataset(save_dir,fname,data):
    """Split ``data`` into train/val/test parts and write each part as CSV.

    Rows at positions 0, 10, 20, ... go to the test part, rows at positions
    5, 15, 25, ... go to the validation part, and every other row goes to the
    training part. Each part is written with ``save_csv`` to
    ``save_dir + fname + '_train'`` / ``'_val'`` / ``'_test'``.

    Args:
        save_dir: Path prefix. It is concatenated verbatim with ``fname``,
            so it must already end with a path separator.
        fname: Base name of the output files (no extension is appended).
        data: Sequence of rows; it is converted with ``np.array`` before
            being split.
    """
    # Position of every row within its block of ten.
    slot = np.arange(len(data)) % 10
    is_test = slot == 0
    is_val = slot == 5
    is_train = ~(is_test | is_val)

    rows = np.array(data)

    # Build all three parts before writing anything.
    parts = [
        ('_train', list(rows[is_train])),
        ('_val', list(rows[is_val])),
        ('_test', list(rows[is_test])),
    ]
    for suffix, part in parts:
        save_csv(save_dir+fname+suffix, part)

In [2]:
corpus = 'mpdd'
# NOTE(review): absolute, machine-specific location; the notebook only runs
# where this NFS path is mounted.
integrated_path = f'/nfs/nas-7.1/yamashita/LAB/dialogue_data/integrated_data_relative_relation/integrated_{corpus}.csv'

# Each situation prefix (req / apo / tha) carries the same seven annotation
# columns, so the 26 column names are generated instead of typed by hand.
annotation_fields = ['query_tag', 'query_unreadable', 'query_natural', 'query_rewrited',
                     'res_unreadable', 'res_natural', 'res_rewrited']
column_names = (
    ['original', 'translated', 'speakerid', 'conversationid']
    + [f'{prefix}_{field}' for prefix in ('req', 'apo', 'tha') for field in annotation_fields]
    + ['utteranceid']
)

# With names= given and no header=, pandas reads every line of the file
# (including the first) as data.
data_df = pd.read_csv(integrated_path, names=column_names, encoding='utf-8-sig')

# Strip full-width and ASCII spaces from the original utterances.
# regex=False is explicit because the default of Series.str.replace differs
# between pandas versions; both patterns are meant as literal characters.
data_df['original'] = (
    data_df['original']
    .str.replace('　', '', regex=False)
    .str.replace(' ', '', regex=False)
)


data_df[10:20]

Unnamed: 0,original,translated,speakerid,conversationid,req_query_tag,req_query_unreadable,req_query_natural,req_query_rewrited,req_res_unreadable,req_res_natural,...,apo_res_natural,apo_res_rewrited,tha_query_tag,tha_query_unreadable,tha_query_natural,tha_query_rewrited,tha_res_unreadable,tha_res_natural,tha_res_rewrited,utteranceid
10,えー矢張のやつか。,誒，八佰里的那個。,brothers and sisters,T010_003,,,,,,,...,,,,,,,,,,11
11,映画も出てくるけど。矢張に。矢張じゃねえよ。あれはナツミだよ。ナツミのカメラ。,不過有一部電影要上映了。。給亞哈里。。這不是Yabari。。那是+Natsumi。。夏美的相機。,brothers and sisters,T010_003,,,,,,,...,,,,,,,,,,12
12,矢張のが犯人。,Yahari是殺手。,child,T010_003,,,,,,,...,,,,,,,,,,13
13,ほら。,給你,parent,T010_003,,,,,,,...,,,,,,,,,,14
14,犯人じゃねえよ。,我不是兇手,child,T010_003,,,,,,,...,,,,,,,,,,15
15,これさ。ね。,給你。嘿嘿,parent,T010_003,,,,,,,...,,,,,,,,,,16
16,矢張は。,亞華里是。,brothers and sisters,T010_003,,,,,,,...,,,,,,,,,,17
17,あ。違うちゃう。容疑者。,啊。。不，它不是。。嫌疑人。,brothers and sisters,T010_003,,,,,,,...,,,,,,,,,,18
18,なんでお前矢張が犯人て。,爲什麼+你+亞華利是罪魁禍首。,brothers and sisters,T010_003,,,,,,,...,,,,,,,,,,19
19,矢張が矢張が容疑者。,Yahari是嫌疑人。,brothers and sisters,T010_003,,,,,,,...,,,,,,,,,,20


In [3]:
eos_token = ' '
sit_list = ['request','apology','thanksgiving']
# For every situation, collect (query, response) pairs whose query row is
# tagged 'rewrite' and save them as train/val/test CSV files under
# data/<corpus>/<situation>/.
for sit in sit_list:
    # Annotation columns are prefixed with the first three letters of the
    # situation name: req / apo / tha.
    tag_col = sit[:3] + '_query_tag'
    rewrited_query_col = sit[:3] + '_query_rewrited'
    rewrited_res_col = sit[:3] + '_res_rewrited'

    filtered_df = data_df[['original','translated','speakerid','conversationid',tag_col,rewrited_query_col,rewrited_res_col]]
    filtered_df = filtered_df.fillna('')
    # Row layout: [original, translated, speakerid, conversationid,
    #              query tag, rewritten query, rewritten response]
    filtered_list = filtered_df.values.tolist()

    ori_query, mt_query, ht_query = [],[],[]
    ori_res, mt_res, ht_res = [],[],[]
    relation_pair = []
    # Walk from the next-to-last row back to the first one, so pairs are
    # collected from the end of the file towards the beginning. The last row
    # is never a query because no row follows it; pairing it through a
    # negative index of 0 would wrap around to the first row of the file.
    # NOTE(review): the response is simply the next row; nothing here checks
    # that both rows share the same conversationid — confirm this is intended.
    for q_pos in range(len(filtered_list) - 2, -1, -1):
        query_row = filtered_list[q_pos]
        if query_row[4] != 'rewrite':
            continue
        res_row = filtered_list[q_pos + 1]

        #######
        #query#
        #######
        ori_query.append([query_row[0]])
        mt_query.append([query_row[1]])
        ht_query.append([query_row[5]])

        #######
        # res #
        #######
        ori_res.append([res_row[0]])
        mt_res.append([res_row[1]])
        # The rewritten response is read from the query row (column
        # '<sit>_res_rewrited'), not from the response row.
        ht_res.append([query_row[6]])

        # Speaker ids of the query row and the response row.
        relation_pair.append([query_row[2], res_row[2]])

    save_dir = f'data/{corpus}/{sit}/'
    os.makedirs(save_dir, exist_ok=True)

    save_as_dataset(save_dir,'original_query',ori_query)
    save_as_dataset(save_dir,'translated_query',mt_query)
    save_as_dataset(save_dir,'rewrited_query',ht_query)
    save_as_dataset(save_dir,'original_res',ori_res)
    save_as_dataset(save_dir,'translated_res',mt_res)
    save_as_dataset(save_dir,'rewrited_res',ht_res)
    save_as_dataset(save_dir,'relation_pair',relation_pair)

In [5]:
# Build the "negative" set: consecutive utterance pairs from the same
# conversation whose query row is not tagged 'rewrite' for any situation.
filtered_for_neg_df = data_df[['original','translated','speakerid','conversationid','req_query_tag','apo_query_tag','tha_query_tag']]
filtered_for_neg_df = filtered_for_neg_df.fillna('')
# From here on the name holds a plain list of rows, not a DataFrame.
# Row layout: [original, translated, speakerid, conversationid,
#              req_query_tag, apo_query_tag, tha_query_tag]
filtered_for_neg_df = filtered_for_neg_df.values.tolist()

ori_neg_query, ori_neg_res = [],[]
relation_pair = []
# Iterate over the rows built in this cell so that columns 4-6 really are the
# three tag columns. Walk from the next-to-last row back to the first one;
# the last row has no following response, and a negative index of 0 would
# wrap around to the first row of the file.
for q_pos in range(len(filtered_for_neg_df) - 2, -1, -1):
    query_row = filtered_for_neg_df[q_pos]
    # Skip rows selected as a request, apology or thanksgiving query.
    if 'rewrite' in query_row[4:7]:
        continue
    res_row = filtered_for_neg_df[q_pos + 1]
    # Keep the pair only when both rows belong to the same conversation.
    if query_row[3] != res_row[3]:
        continue

    ori_neg_query.append([query_row[0]])
    ori_neg_res.append([res_row[0]])

    # Speaker ids of the query row and the response row.
    relation_pair.append([query_row[2], res_row[2]])
save_dir = f'data/{corpus}/negative/'
os.makedirs(save_dir, exist_ok=True)

save_as_dataset(save_dir,'original_query',ori_neg_query)
save_as_dataset(save_dir,'original_res',ori_neg_res)
save_as_dataset(save_dir,'relation_pair',relation_pair)
