In [5]:
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from sklearn.metrics import confusion_matrix
import pandas as pd
import os
import csv
import numpy as np

def compute_metrics(pred, labels):
#     pred = np.argmax(pred, axis=1)
    if (0 in pred) or (1 in pred) or (2 in pred) or (3 in pred):
        try: # Multi Label
            accuracy = accuracy_score(y_true=labels, y_pred=pred, average='weighted')
        except: # Bin Label
            accuracy = accuracy_score(y_true=labels, y_pred=pred)
        try: # Multi Label
            recall = recall_score(y_true=labels, y_pred=pred, average='weighted')
        except: # Bin Label
            recall = recall_score(y_true=labels, y_pred=pred)
        try: # Multi Label
            precision = precision_score(y_true=labels, y_pred=pred, average='weighted')
        except: # Bin Label
            precision = precision_score(y_true=labels, y_pred=pred)
        try: # Multi Label
            f1_macro = f1_score(y_true=labels, y_pred=pred, average='macro')
            f1_micro = f1_score(y_true=labels, y_pred=pred, average='micro')
        except: # Bin Label
            f1_macro = f1_score(y_true=labels, y_pred=pred, average='macro')
            f1_micro = f1_score(y_true=labels, y_pred=pred, average='micro')
    else:
        accuracy, precision, recall, f1_micro, f1_macro = 0,0,0,0,0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1_micro": f1_micro, "f1_macro": f1_macro}

def convert_result(save_dir,ver_name,labels):
    src_type_list = ['original','translated','rewrited']
    tgt_lang_list = ['ja','zh']
    mt_ja_metric,ht_ja_metric,mt_zh_metric,ht_zh_metric = pd.DataFrame(),pd.DataFrame(),pd.DataFrame(),pd.DataFrame()
    mt_ja_confusion,ht_ja_confusion,mt_zh_confusion,ht_zh_confusion = pd.DataFrame(),pd.DataFrame(),pd.DataFrame(),pd.DataFrame()
    for src_type in src_type_list:
        for tgt_lang in tgt_lang_list:
            path = save_dir+f'{src_type}/{tgt_lang}_preds_truth.csv'
            df = pd.read_csv(path)
            preds = df[ver_name].to_list()
            truth = df['truth'].to_list()

            metric = compute_metrics(preds,truth)
            print(preds)
            metric_ = pd.DataFrame.from_dict(metric, orient='index').T
            metric_.to_csv(f'{save_dir}{src_type}/{tgt_lang}_metric.csv',header=True, index=True)
            
            try:
                confusion_matrix_ = confusion_matrix(truth, preds, labels=labels)
                confusion_matrix_ = pd.DataFrame(confusion_matrix_)
                confusion_matrix_.to_csv(f'{save_dir}{src_type}/{tgt_lang}_confusion_matrix.csv',header=True, index=True)
            except:
                pass

In [6]:
situation = 'all'
sentence_type = 'all'
class_type = 'multi' if situation == 'all' else 'binary'

# for multi
# ver_name_list_multi = [ 
#                         '000_multi_all_all',
#                         '010_multi_all_query',
#                         '020_multi_all_res',
#                         '030_multi_all_res_context',
#                         '001_multi_all_all_prefix',
#                         '011_multi_all_query_prefix',
#                         '021_multi_all_res_prefix',
#                         '031_multi_all_res_context_prefix',
#                         '002_multi_all_all_prefix_rel',
#                         '012_multi_all_query_prefix_rel',
#                         '022_multi_all_res_prefix_rel']
# for ver_name in ver_name_list_multi:
#     save_dir = f'outputs/situation_classification/{ver_name}/'
#     convert_result(save_dir,ver_name,['negative','apology','request','thanksgiving'])
    
situation_dict = {1:'apology',2:'request',3:'thanksgiving'}
for i,situation in situation_dict.items():
    ver_name_list_binary = [
                            f'{i}00_multi_{situation}_all',
                            f'{i}10_multi_{situation}_query',
                            f'{i}20_multi_{situation}_res',
#                             f'{i}30_multi_{situation}_res_context',
                            f'{i}01_multi_{situation}_all_prefix',
                            f'{i}11_multi_{situation}_query_prefix',
                            f'{i}21_multi_{situation}_res_prefix',
#                             f'{i}31_multi_{situation}_res_context_prefix',
                            f'{i}02_multi_{situation}_all_prefix_rel',
                            f'{i}12_multi_{situation}_query_prefix_rel',
                            f'{i}22_multi_{situation}_res_prefix_rel',
                           ]
    for ver_name in ver_name_list_binary:
#         if ver_name[:3] in ['100','210','220','320']:
#             continue
        save_dir = f'outputs/situation_classification/{ver_name}/'
        convert_result(save_dir,ver_name,['negative',situation])

[1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1]
[1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0]
[0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0]
[1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 

  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):


[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1]
[0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1]
[0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y

[1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not

  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modi

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):
  _warn_prf(average, modifier, msg_start, len(result))
  elif np.all([l not in y_true for l in labels]):


In [7]:
dir_path = 'outputs/situation_classification/'

df = pd.DataFrame()
# situation_dict = {0:'all',1:'apology',2:'request',3:'thanksgiving'}
situation_dict = {1:'apology',2:'request',3:'thanksgiving'}
# sen_type_dict = {0:'all',1:'query',2:'res',3:'res_context'}
sen_type_dict = {0:'all',1:'query',2:'res'}
prefix_dict = {0:"",1:"_prefix",2:f'_prefix_rel'}

src_type_list = ['original','translated','rewrited']
tgt_lang_list = ['ja','zh']
index_names = []
scores = []
for s, situation in situation_dict.items():
    for t, sen_type in sen_type_dict.items():
        for p, prefix in prefix_dict.items():
            # "_prefix_rel"は、現状all_allのみしかやってないので、それ以外の場合はcontinue
            ver_name  = f'{s}{t}{p}_multi_{situation}_{sen_type}{prefix}'
#             if (p==2) and (s==0) and (t==0):
#                 pass
#             elif (p==2):
#                 continue
# #             else:
# #                 continue
            for src_type in src_type_list:
                for tgt_lang in tgt_lang_list:
                    score = []
                    with open(f'{dir_path}{ver_name}/{src_type}/{tgt_lang}_metric.csv')as f:
                        reader = csv.reader(f)
                        for i,row in enumerate(reader):
                            if i == 1:
                                score = row[1:]
                        scores.append(score)
                    index_names.append([ver_name,tgt_lang,situation,sen_type,prefix,src_type])

for_table = []
for score, index in zip(scores,index_names):
    for_table.append(index+score)
df = pd.DataFrame(for_table, columns=['ver_name','tgt_lang','situation','sentence_type','prefix','src_type','accuracy','precision','recall','f1_micro','f1_macro'])

save_metric_dir = 'for_thesis/situation_classification_t5/'
os.makedirs(save_metric_dir,exist_ok=True)
df.to_csv(save_metric_dir+'scores.csv',encoding='utf-8-sig')
df

Unnamed: 0,ver_name,tgt_lang,situation,sentence_type,prefix,src_type,accuracy,precision,recall,f1_micro,f1_macro
0,100_multi_apology_all,ja,apology,all,,original,0.4787234042553192,0.606428376865107,0.4787234042553192,0.47872340425531923,0.47390062821245005
1,100_multi_apology_all,zh,apology,all,,original,0.5769230769230769,0.6305170239596469,0.5769230769230769,0.5769230769230769,0.5712143928035982
2,100_multi_apology_all,ja,apology,all,,translated,0.5212765957446809,0.5691216584833606,0.5212765957446809,0.5212765957446809,0.4782286912544715
3,100_multi_apology_all,zh,apology,all,,translated,0.46153846153846156,0.5325443786982248,0.46153846153846156,0.46153846153846156,0.4603889943074004
4,100_multi_apology_all,ja,apology,all,,rewrited,0.6595744680851063,0.6814142678347935,0.6595744680851063,0.6595744680851063,0.6149513568868408
...,...,...,...,...,...,...,...,...,...,...,...
157,322_multi_thanksgiving_res_prefix_rel,zh,thanksgiving,res,_prefix_rel,original,0.8,0.64,0.8,0.8000000000000002,0.4444444444444445
158,322_multi_thanksgiving_res_prefix_rel,ja,thanksgiving,res,_prefix_rel,translated,0.3114754098360656,0.8016115587663241,0.3114754098360656,0.3114754098360656,0.267162471395881
159,322_multi_thanksgiving_res_prefix_rel,zh,thanksgiving,res,_prefix_rel,translated,0.8,0.64,0.8,0.8000000000000002,0.4444444444444445
160,322_multi_thanksgiving_res_prefix_rel,ja,thanksgiving,res,_prefix_rel,rewrited,0.2786885245901639,0.07766729373824241,0.2786885245901639,0.2786885245901639,0.21794871794871792


In [29]:
scores = pd.read_csv(save_metric_dir+'scores.csv')
scores=scores.fillna("_")
pd.options.display.max_rows = 150
scores=scores.groupby(['tgt_lang','sentence_type','prefix','src_type']).mean().reset_index().drop('Unnamed: 0', axis=1)
scores.to_csv(save_metric_dir+'overall_scores.csv',encoding='utf-8-sig')
scores

Unnamed: 0,tgt_lang,sentence_type,prefix,src_type,accuracy,precision,recall,f1_micro,f1_macro
0,ja,all,_,original,0.619345,0.602193,0.619345,0.619345,0.454765
1,ja,all,_,rewrited,0.41526,0.62477,0.41526,0.41526,0.39595
2,ja,all,_,translated,0.449621,0.635532,0.449621,0.449621,0.433911
3,ja,all,_prefix,original,0.654683,0.58633,0.654683,0.654683,0.436434
4,ja,all,_prefix,rewrited,0.625092,0.674924,0.625092,0.625092,0.50388
5,ja,all,_prefix,translated,0.645207,0.613365,0.645207,0.645207,0.489507
6,ja,all,_prefix_rel,original,0.648936,0.624605,0.648936,0.648936,0.43253
7,ja,all,_prefix_rel,rewrited,0.621729,0.631512,0.621729,0.621729,0.46635
8,ja,all,_prefix_rel,translated,0.635241,0.61071,0.635241,0.635241,0.461965
9,ja,query,_,original,0.491196,0.541703,0.491196,0.491196,0.432519
