In [1]:
import os
import pandas as pd
from os import listdir
from os.path import isfile, join
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score


def combine_results(root_folder_name, folds = 5, epochs = 5, seed_num = 0, train = False, supcon_svm = True):
    if train:
        folder_name = f'{root_folder_name}_train/{seed_num}/'
        file_name = f"K{folds}_epochs{epochs}_"
    else:
        folder_name = f'{root_folder_name}_apply/{seed_num}/'
        file_name = f"epochs{epochs}_"

    list_of_files = [f.split(file_name)[1] for f in listdir(folder_name) if ((isfile(join(folder_name, f))) and (file_name in f))]
    
    assert(len(list_of_files)>0)

    final_data = pd.DataFrame()

    for item in list_of_files:
        
        if ('SVM' in root_folder_name) \
            or (('SUPCON' in root_folder_name) and (supcon_svm))\
            or (('SUPCON' in root_folder_name) and ('ce' not in item)):
            winner_col = 'winner_svm'
        else:
            winner_col = 'winner'
        
        print(f'processing "{file_name + item}"...')
        dir_name = folder_name + file_name + item
        if ('removed' in item) and not train:
            data = pd.read_csv(
                dir_name, usecols=['sentence', 'pmid', winner_col], 
                converters = {'pmid': str, winner_col: str}).rename(
                columns={winner_col:item[:-4]}, errors="raise")
            if '2to1' in item:
                data['t_label'] = 1
            else:
                data['t_label'] = 4 if '5t' in item else 0
        else:
            data = pd.read_csv(
                dir_name, usecols=['sentence', 'pmid', 'label', winner_col], 
                converters = {'pmid': str, winner_col: str}).rename(
                columns={winner_col:item[:-4], 'label':'t_label'}, errors="raise")
        
        if winner_col=='winner_svm':
            data[item[:-4]] = data[item[:-4]].apply(lambda x: 'c'+str(int(float(x))))
            
        if len(final_data) == 0:
            final_data = data
        else:
            if len(data)> len(final_data):
                merge_how = 'right'
            else:
                merge_how = 'left'
            final_data = final_data.merge(data, on=['sentence', 'pmid'], how='outer')
        if 'label' in final_data.columns:
            final_data['label'] = final_data['label'].fillna(final_data['t_label'])
        else:
            final_data['label'] = final_data['t_label']
        final_data = final_data.drop(columns=['t_label'])

    final_data['pmid_base'] = final_data['pmid'].apply(lambda x: str(x).split('_')[0])
    final_data['source'] = final_data['pmid'].apply(lambda x: 'edits' if (('edt' in x) or ('alt' in x)) else 'base')
    final_data['label'] = [4 if label==0 and source=='edits' else label for label, source in zip(final_data['label'], final_data['source'])]
    final_data['label'] = final_data['label'].apply(lambda x: int(x))

#     assert(len(final_data[final_data['pmid'].duplicated()])==0)
    
    return final_data

In [2]:
def initialise(final_data):
    # get columns to loop
    info_columns = ['pmid', 'sentence', 'label', 'source', 'pmid_base']
    prediction_columns = [i for i in final_data.columns if i not in info_columns]

    # initialise
    results = []
    full_breakdown = None
    return prediction_columns, results, full_breakdown

def format_data(results, full_breakdown, train = False):
    results = pd.DataFrame(results)
    if train:
        results.columns = ['pcat', 'Edit_Source', 'Classifier', 'Base_Type', 'Edit_Type', 'Label', \
                           'n_c0', 'n_c1', 'n_c2', 'n_c3', 'n_c4', 'n',\
                           'P','R','F1', 'Acc',\
                           'P_wo','R_wo','F1_wo', 'Acc_wo']
    else:
        results.columns = ['pcat', 'Edit_Source', 'Classifier', 'Base_Type', 'Edit_Type', 'Label', \
                           'P','R','F1', 'Acc', 
                           'P_wo','R_wo','F1_wo', 'Acc_wo']
    results = results.sort_values(by=['Classifier', 'Base_Type', 'Edit_Type', 'Label'])
    full_breakdown.index.name = None
    full_breakdown.columns.name = None
    return results, full_breakdown

def invert_dictionary(csci2sciente):
    sciente2csci = {}
    for k, v in csci2sciente.items():
        if v in sciente2csci.keys():
            sciente2csci[v] = sciente2csci[v] + [k]
        else:
            sciente2csci[v] = [k]
    return sciente2csci

def get_column_names(root_folder_name, pcat):
    classifier_name = root_folder_name.split('_')[1]
    if classifier_name=='SUPCON':
        if 'ce' in pcat:
            classifier_name+='+MLP'
        else:
            classifier_name+='+SVM'
        
    edit_type_name = \
        'Shorten' if 'shorten' in pcat else \
        ('Multiples' if 'multiples' in pcat else \
         ('Mask' if 'mask' in pcat else \
          ('Synonyms' if 'syn' in pcat else \
           ('T5Paraphrase' if 't5para' in pcat else \
            ('Regular' if (('edits' in pcat) or ('neg' in pcat)) else '-')))))
    if ('base' in pcat) and ('SUPCON' not in root_folder_name):
        base_type_name = edit_type_name if edit_type_name!='-' else 'Original'
        edit_type_name = '-'
    else:
        base_type_name = edit_type_name if 'oriedits' in pcat else 'Original'
    label_name = '4t_rs' if '4t_rs' in pcat else ('4t' if '4t' in pcat else ('5t' if (('5t' in pcat) or ('neg' in pcat)) else 'base'))
    edit_source_name = \
        '-' if label_name=='base' else \
        ('2to1' if '2to1' in pcat else \
         ('2to4' if '2to4' in pcat else \
          ('all' if 'all' in pcat else \
           ('mix' if 'mix' in pcat else '1to4'))))
    return classifier_name, edit_type_name, base_type_name, label_name, edit_source_name

def concat_to_main_frame(full_breakdown, breakdown):
    if full_breakdown is not None:
        full_breakdown = pd.concat([full_breakdown, breakdown], ignore_index=True)
    else:
        full_breakdown = breakdown.copy()
    return full_breakdown

##### TRAIN #####
def get_train_results(final_data):

    prediction_columns, results, full_breakdown = initialise(final_data)
    excl_index = final_data[final_data['source']=='edits'].index
    actual = final_data['label'].apply(lambda x: 'c'+str(x))
    label_set = ['c0','c1','c2','c3','c4']
    
    for pcat in prediction_columns:
        _y_true = actual.copy()
        main_index = final_data[[pcat, 'label']].dropna().index
        wo_index = [i for i in main_index if (i not in excl_index)]
        
        if '4t' or 'base' in pcat:
            _y_true = _y_true.apply(lambda x: 'c0' if x=='c4' else x)
        
        # full data
        y_true = np.array(_y_true)[main_index].copy()
        y_pred = np.array(final_data[pcat])[main_index].copy()
        
        # original data
        y_true_2 = np.array(_y_true)[wo_index].copy()
        y_pred_2 = np.array(final_data[pcat])[wo_index].copy()
        
        c0_count, c1_count, c2_count, c3_count, c4_count = \
            sum(y_true=='c0'), sum(y_true=='c1'), sum(y_true=='c2'), sum(y_true=='c3'), sum(y_true=='c4')
        total_count = len(y_true)
        
        # confusion matrix
        breakdown = pd.DataFrame(confusion_matrix(y_true, y_pred, labels=list(label_set)))
        breakdown.columns = list(label_set)
        breakdown['label'] = list(label_set)
        
        # metrics
        acc = accuracy_score(y_true, y_pred)
        acc_wo = accuracy_score(y_true_2, y_pred_2)

        macro = precision_recall_fscore_support(y_true, y_pred, average='macro')
        macro_wo = precision_recall_fscore_support(y_true_2, y_pred_2, average='macro')
        
        # append results
        classifier_name, edit_type_name, base_type_name, label_name, edit_source_name \
            = get_column_names(root_folder_name, pcat)
        breakdown.insert(loc=0, column='pcat', value=pcat)
        full_breakdown = concat_to_main_frame(full_breakdown, breakdown)
        results.append([pcat, edit_source_name, classifier_name, base_type_name, edit_type_name, label_name, 
                        c0_count, c1_count, c2_count, c3_count, c4_count, total_count,
                        macro[0], macro[1], macro[2], acc, 
                        macro_wo[0], macro_wo[1], macro_wo[2], acc_wo])
                
    results, full_breakdown = format_data(results, full_breakdown, train=True)
    return results, full_breakdown

##### TEST #####
def get_apply_results(final_data):

    prediction_columns, results, full_breakdown = initialise(final_data)
    csci2sciente = {'c0': 0, 'c1': 1, 'c2': 1, 'c3': 0, 'c4': 0}
    sciente2csci = invert_dictionary(csci2sciente)
    actual = final_data['label'].copy()
    actual_2 = final_data['label'].apply(lambda x: 'c'+str(x))
    label_set = set(actual)
    
    for pcat in prediction_columns:

        main_index = final_data[[pcat, 'label']].dropna().index
        
        # 2 labels only
        y_true = np.array(actual)[main_index].copy()
        y_pred = final_data.loc[main_index][pcat].apply(lambda x: csci2sciente[x]).copy()
        
        # all original labels
        y_true_2 = np.array(actual_2)[main_index].copy()
        y_pred_2 = final_data.loc[main_index][pcat].copy()
        
        # confusion matrix
#         print(final_data[pcat])
        breakdown = pd.DataFrame(confusion_matrix(y_true_2, y_pred_2, labels=list(csci2sciente.keys())))
        breakdown.columns = list(csci2sciente.keys())
        breakdown['label'] = list(csci2sciente.keys())
        breakdown = breakdown.iloc[0:2,:]
        
        # metrics
        
        acc = accuracy_score(y_true, y_pred)
        acc_wo = accuracy_score(y_true_2, y_pred_2)

        macro = precision_recall_fscore_support(y_true, y_pred, average='macro')
        macro_wo = precision_recall_fscore_support(y_true_2, y_pred_2, average='macro')

        # append results
        classifier_name, edit_type_name, base_type_name, label_name, edit_source_name = get_column_names(root_folder_name, pcat)
        breakdown.insert(loc=0, column='pcat', value=pcat)
        full_breakdown = concat_to_main_frame(full_breakdown, breakdown)
        results.append([pcat, edit_source_name, classifier_name, base_type_name, edit_type_name, label_name, 
                        macro[0], macro[1], macro[2], acc, 
                        macro_wo[0], macro_wo[1], macro_wo[2], acc_wo])

    results, full_breakdown = format_data(results, full_breakdown)
    return results, full_breakdown

##### SAVE #####
from pandas import ExcelWriter

def save_xls(dict_df, path):
    writer = ExcelWriter(path)
    for key in dict_df:
        dict_df[key].to_excel(writer, key)
    writer.save()

In [11]:
# TRAIN VAL PUBMED & TEST SCITE
supcon_svm = False
root_folder_name = 'pubmed_SVM_biobert'
seed_num, folds, epochs = 0, 5, 5

train_results = pd.DataFrame()

for train in [True, False]:

    final_data = combine_results(root_folder_name, folds, epochs, seed_num, train, supcon_svm)
    ext = '' if train else '_apply'
    final_data.to_csv(f'{root_folder_name}_rowresults{ext}.csv')

    if train:
        train_results, train_breakdown = get_train_results(final_data)
    else:
        test_results, test_breakdown = get_apply_results(final_data)

results = train_results.merge(test_results, how='outer', suffixes=['_Train', '_Test'], 
                              on=['pcat', 'Edit_Source', 'Classifier', 'Base_Type', 'Edit_Type', 'Label'])
results

processing "K5_epochs5_2to1_edits_4t.csv"...
processing "K5_epochs5_2to1_edits_4t_multiples.csv"...
processing "K5_epochs5_2to1_edits_4t_rs.csv"...
processing "K5_epochs5_2to1_edits_4t_rs_multiples.csv"...
processing "K5_epochs5_2to1_edits_4t_rs_shorten.csv"...
processing "K5_epochs5_2to1_edits_4t_rs_synonyms.csv"...
processing "K5_epochs5_2to1_edits_4t_rs_t5para.csv"...
processing "K5_epochs5_2to1_edits_4t_shorten.csv"...
processing "K5_epochs5_2to1_edits_4t_synonyms.csv"...
processing "K5_epochs5_2to1_edits_4t_t5para.csv"...
processing "K5_epochs5_2to1_oriedits_4t_mask.csv"...
processing "K5_epochs5_2to1_oriedits_4t_rs_mask.csv"...
processing "K5_epochs5_2to1_oriedits_4t_rs_shorten.csv"...
processing "K5_epochs5_2to1_oriedits_4t_shorten.csv"...
processing "K5_epochs5_all_edits_4t.csv"...
processing "K5_epochs5_all_edits_4t_rs.csv"...
processing "K5_epochs5_all_edits_5t.csv"...
processing "K5_epochs5_base.csv"...
processing "K5_epochs5_base_mask.csv"...
processing "K5_epochs5_base_sho

  _warn_prf(average, modifier, msg_start, len(result))


processing "epochs5_2to1_edits_4t.csv"...
processing "epochs5_2to1_edits_4t_multiples.csv"...
processing "epochs5_2to1_edits_4t_rs.csv"...
processing "epochs5_2to1_edits_4t_rs_multiples.csv"...
processing "epochs5_2to1_edits_4t_rs_shorten.csv"...
processing "epochs5_2to1_edits_4t_rs_synonyms.csv"...
processing "epochs5_2to1_edits_4t_rs_t5para.csv"...
processing "epochs5_2to1_edits_4t_shorten.csv"...
processing "epochs5_2to1_edits_4t_synonyms.csv"...
processing "epochs5_2to1_edits_4t_t5para.csv"...
processing "epochs5_2to1_oriedits_4t_mask.csv"...
processing "epochs5_2to1_oriedits_4t_rs_mask.csv"...
processing "epochs5_2to1_oriedits_4t_rs_shorten.csv"...
processing "epochs5_2to1_oriedits_4t_shorten.csv"...
processing "epochs5_all_edits_4t.csv"...
processing "epochs5_all_edits_4t_rs.csv"...
processing "epochs5_all_edits_5t.csv"...
processing "epochs5_base.csv"...
processing "epochs5_base_mask.csv"...
processing "epochs5_base_shorten.csv"...
processing "epochs5_edits_4t.csv"...
processing

Unnamed: 0,pcat,Edit_Source,Classifier,Base_Type,Edit_Type,Label,n_c0,n_c1,n_c2,n_c3,...,F1_wo_Train,Acc_wo_Train,P_Test,R_Test,F1_Test,Acc_Test,P_wo_Test,R_wo_Test,F1_wo_Test,Acc_wo_Test
0,base_mask,-,SVM,Mask,-,base,1232,462,207,962,...,0.81966,0.841425,0.809949,0.715495,0.744051,0.838252,0.406558,0.353612,0.370566,0.828114
1,2to1_oriedits_4t_mask,2to1,SVM,Mask,Mask,4t,1210,642,200,947,...,0.825801,0.842573,0.808537,0.758849,0.778373,0.849516,0.405794,0.37606,0.388262,0.840279
2,oriedits_4t_mask,1to4,SVM,Mask,Mask,4t,1591,454,203,947,...,0.810623,0.82978,0.780901,0.706201,0.730125,0.826087,0.391846,0.349123,0.363627,0.81595
3,2to1_oriedits_4t_rs_mask,2to1,SVM,Mask,Mask,4t_rs,1210,462,200,947,...,0.816475,0.841163,0.804495,0.733015,0.757707,0.841406,0.404153,0.362819,0.377849,0.83262
4,oriedits_4t_rs_mask,1to4,SVM,Mask,Mask,4t_rs,1232,454,203,947,...,0.798831,0.814684,0.753593,0.719023,0.732894,0.8164,0.377507,0.355079,0.364571,0.805362
5,oriedits_5t_mask,1to4,SVM,Mask,Mask,5t,1591,454,203,947,...,0.647761,0.828003,0.792051,0.714629,0.739672,0.832169,0.318111,0.282901,0.294879,0.823384
6,base,-,SVM,Original,-,base,1356,494,213,998,...,0.869519,0.888598,0.837578,0.728884,0.761227,0.850417,0.420071,0.362027,0.379755,0.84501
7,2to1_edits_4t_multiples,2to1,SVM,Original,Multiples,4t,1353,665,209,995,...,0.875785,0.897083,0.826587,0.738651,0.767491,0.850417,0.414551,0.36705,0.382908,0.845911
8,edits_4t_multiples,1to4,SVM,Original,Multiples,4t,1706,491,212,995,...,0.868081,0.891511,0.859494,0.699496,0.735681,0.84456,0.430712,0.347547,0.366671,0.840279
9,2to1_edits_4t_rs_multiples,2to1,SVM,Original,Multiples,4t_rs,1353,494,209,995,...,0.86455,0.892429,0.820219,0.716002,0.746397,0.841406,0.410936,0.355586,0.37219,0.835999


In [5]:
save_xls({
    'pubmed_train': train_breakdown,
    'pubmed_test': test_breakdown,
    'summary_'+str(seed_num): results
}, f'{root_folder_name}_results.xlsx')

In [12]:
# TEST ALTLEX
supcon_svm = False
root_folder_name = 'pubmedxaltlex_SVM_biobert'
seed_num, folds, epochs = 0, 5, 5

final_data = combine_results(root_folder_name, folds, epochs, seed_num, train, supcon_svm)
test_results, test_breakdown = get_apply_results(final_data)
save_xls({
    'pubmed_test': test_breakdown,
    'summary_'+str(seed_num): test_results
}, f'{root_folder_name}_results.xlsx')

processing "epochs5_2to1_edits_4t_altlex.csv"...
processing "epochs5_2to1_edits_4t_multiples_altlex.csv"...
processing "epochs5_2to1_edits_4t_rs_altlex.csv"...
processing "epochs5_2to1_edits_4t_rs_multiples_altlex.csv"...
processing "epochs5_2to1_edits_4t_rs_shorten_altlex.csv"...
processing "epochs5_2to1_edits_4t_shorten_altlex.csv"...
processing "epochs5_all_edits_4t_altlex.csv"...
processing "epochs5_all_edits_4t_rs_altlex.csv"...
processing "epochs5_all_edits_5t_altlex.csv"...
processing "epochs5_base_altlex.csv"...
processing "epochs5_edits_4t_altlex.csv"...
processing "epochs5_edits_4t_multiples_altlex.csv"...
processing "epochs5_edits_4t_rs_altlex.csv"...
processing "epochs5_edits_4t_rs_multiples_altlex.csv"...
processing "epochs5_edits_4t_rs_shorten_altlex.csv"...
processing "epochs5_edits_4t_shorten_altlex.csv"...
processing "epochs5_edits_5t_altlex.csv"...
processing "epochs5_edits_5t_multiples_altlex.csv"...
processing "epochs5_edits_5t_shorten_altlex.csv"...
processing "epo

  _warn_prf(average, modifier, msg_start, len(result))
