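# Merge PIO predictions

Combine the participant (`POP`), intervention (`INT`) and outcome (`OUT`) predictions — three seeds per entity — into a single label per row: `O` when every model predicts `O`, otherwise a majority vote over the non-`O` labels, with ties broken at random.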

In [1]:
import pandas as pd
import os
from collections import Counter
import random
from sklearn.metrics import (accuracy_score, classification_report,
                             confusion_matrix, f1_score, 
                             precision_score, recall_score)

In [2]:
# Root of the prediction CSVs; files are read from
# <base>/<entity>/<seed>/<embedding>/.
base = ('/mnt/nas2/results/Results/systematicReview/SemEval2023/'
        'predictions_test/without_Dropout/')

# Sub-directory names used to build those paths.
entities = ['participant', 'intervention', 'outcome']
seeds = [str(seed) for seed in (0, 1, 42)]
embedding = ['roberta', 'biomedroberta']

In [3]:
def rename_cols(df):
    """Rename the first four columns by position and keep only the id columns.

    Columns 0-3 of ``df`` are renamed to ``post_id``, ``subredit_id``,
    ``words`` and ``labels``; the returned frame contains just
    ``["subredit_id", "post_id"]``, in that order. ``df`` itself is left
    unchanged.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with at least four columns.

    Returns
    -------
    pandas.DataFrame
        New frame with columns ``subredit_id`` and ``post_id``.

    Raises
    ------
    ValueError
        If ``df`` has fewer than four columns.
    """
    positional_names = ['post_id', 'subredit_id', 'words', 'labels']
    n_named = len(positional_names)
    if df.shape[1] < n_named:
        raise ValueError(
            f'expected at least {n_named} columns, got {df.shape[1]}')
    # Rename strictly by position on a new object: the caller's frame is not
    # mutated, and columns that share an old name cannot be renamed together.
    new_columns = positional_names + list(df.columns[n_named:])
    renamed = df.set_axis(new_columns, axis=1)
    return renamed.reindex(columns=["subredit_id", "post_id"])

In [4]:
def get_predictions(picos, embed, base_dir=None, seed_list=None):
    """Load every prediction CSV for one entity and one embedding, per seed.

    Files are read from ``<base_dir>/<picos>/<seed>/<embed>/``. Each file is
    keyed by the part of its name before the first ``_`` with the substring
    ``ensemble`` removed (e.g. ``ensemble14_preds.csv`` -> ``'14'``). Files
    are visited in sorted order, so if two names map to the same key the
    later one in that order wins. Non-file entries (sub-directories) are
    skipped.

    Parameters
    ----------
    picos : str
        Entity sub-directory, e.g. ``'participant'``.
    embed : str
        Embedding sub-directory, e.g. ``'roberta'``.
    base_dir : str, optional
        Root prediction directory; defaults to the notebook-level ``base``.
    seed_list : sequence of str, optional
        Seed sub-directories to read, in order; defaults to the
        notebook-level ``seeds``.

    Returns
    -------
    tuple of dict
        One ``{key: DataFrame}`` dict per seed, in ``seed_list`` order.
    """
    if base_dir is None:
        base_dir = base
    if seed_list is None:
        seed_list = seeds

    per_seed = []
    for seed in seed_list:
        dir_path = os.path.join(base_dir, picos, seed, embed)
        frames = {}
        # sorted(): os.listdir order is arbitrary, which would make key
        # collisions resolve nondeterministically.
        for fname in sorted(os.listdir(dir_path)):
            file_path = os.path.join(dir_path, fname)
            # Skip directories such as .ipynb_checkpoints.
            if not os.path.isfile(file_path):
                continue
            key = fname.split('_')[0].replace('ensemble', '')
            frames[key] = pd.read_csv(file_path, sep=',')
        per_seed.append(frames)
    return tuple(per_seed)

In [5]:
# Per-seed prediction dicts for each entity, all from the first embedding.
p1, p2, p3 = get_predictions(picos=entities[0], embed=embedding[0])
i1, i2, i3 = get_predictions(picos=entities[1], embed=embedding[0])
o1, o2, o3 = get_predictions(picos=entities[2], embed=embedding[0])

In [6]:
# Number of prediction files loaded per seed, one line per entity.
print(*map(len, (p1, p2, p3)))
print(*map(len, (i1, i2, i3)))
print(*map(len, (o1, o2, o3)))

14 14 14
14 14 14
14 14 14


In [7]:
def get_ensembles(d1, d2, d3, ensembles):
    """Pick one ensemble's ``labels`` column from each of three dicts.

    ``ensembles`` must hold exactly three ensemble ids; each is converted with
    ``str`` and used to index ``d1``, ``d2`` and ``d3`` respectively.

    Returns
    -------
    tuple
        ``(d1[k1]['labels'], d2[k2]['labels'], d3[k3]['labels'])``.
    """
    first, second, third = ensembles
    lookups = ((d1, first), (d2, second), (d3, third))
    return tuple(per_seed[str(ens_id)]['labels'] for per_seed, ens_id in lookups)

# Use ensemble 14 from every seed for each entity.
p1l, p2l, p3l = get_ensembles(d1=p1, d2=p2, d3=p3, ensembles=[14] * 3)
i1l, i2l, i3l = get_ensembles(d1=i1, d2=i2, d3=i3, ensembles=[14] * 3)
o1l, o2l, o3l = get_ensembles(d1=o1, d2=o2, d3=o3, ensembles=[14] * 3)

In [8]:
def rand_choice(inp):
    """Return one element of ``inp`` chosen uniformly at random.

    Uses the module-level ``random`` state, so it follows any prior
    ``random.seed`` call.

    Raises
    ------
    ValueError
        If ``inp`` is empty.
    """
    candidates = list(inp)
    if not candidates:
        raise ValueError('rand_choice() needs at least one candidate')
    return random.choice(candidates)

In [21]:
def _vote_label(labels, rng):
    """Majority-vote one row's labels, ignoring ``'O'``; return (label, outcome).

    ``outcome`` is ``'none'`` when every label is ``'O'`` (result ``'O'``),
    ``'majority'`` when one non-``'O'`` label is strictly most frequent, and
    ``'tie'`` when several labels share the top count and ``rng.choice``
    picked one of them.
    """
    votes = Counter(lab for lab in labels if lab != 'O')
    if not votes:
        return 'O', 'none'
    top = max(votes.values())
    leaders = [lab for lab, n in votes.items() if n == top]
    if len(leaders) == 1:
        return leaders[0], 'majority'
    return rng.choice(leaders), 'tie'


def merge_preds(p_s, i_s, o_s, seed=None):
    """Merge the P, I and O label sequences into one label per row.

    For each row, the labels from every sequence in ``p_s``, ``i_s`` and
    ``o_s`` (typically three runs each) are combined:

    * if every label is ``'O'`` the merged label is ``'O'``;
    * otherwise ``'O'`` votes are ignored and the most frequent remaining
      label wins;
    * ties for the top count are broken uniformly at random among all tied
      labels.

    The number of rows that needed a random tie-break and the number decided
    without one are printed; all-``'O'`` rows are counted in neither.

    Parameters
    ----------
    p_s, i_s, o_s : sequence of label sequences
        Label sequences (e.g. pandas Series), all of the same length.
    seed : int, optional
        Seed for a private ``random.Random`` used for tie-breaks. If ``None``,
        the module-level ``random`` state is used, so a prior ``random.seed``
        call controls the result.

    Returns
    -------
    list
        One merged label per row.

    Raises
    ------
    ValueError
        If the label sequences differ in length (``zip`` would otherwise
        silently truncate the result).
    """
    rng = random if seed is None else random.Random(seed)

    label_columns = [*p_s, *i_s, *o_s]
    lengths = {len(col) for col in label_columns}
    if len(lengths) > 1:
        raise ValueError(
            f'all label sequences must have the same length, got {sorted(lengths)}')

    outcomes = Counter()
    merged = []
    for all_labels in zip(*label_columns):
        entity_label, outcome = _vote_label(all_labels, rng)
        merged.append(entity_label)
        outcomes[outcome] += 1

    print('randomness involved in: ', outcomes['tie'])
    print('randomness NOT involved in: ', outcomes['majority'])

    return merged


# Seed the global RNG right before merging so the random tie-breaks in
# merge_preds are reproducible when this cell is re-run.
random.seed(42)
predictions_merged = merge_preds(p_s=[p1l, p2l, p3l], i_s=[i1l, i2l, i3l], o_s=[o1l, o2l, o3l])

randomness involved in:  3135
randomness NOT involved in:  2536


In [22]:
def write_preds(df, ensemb, l,
                out_dir="/mnt/nas2/results/Results/systematicReview/SemEval2023/submission/random"):
    """Write merged labels to ``<out_dir>/<ensemb>.csv``.

    ``df[str(ensemb)]`` is used as a row template: its ``labels`` column is set
    to ``l`` and the result (index included) is written as UTF-8 CSV.

    Parameters
    ----------
    df : dict of pandas.DataFrame
        Frames keyed by ensemble id (string keys).
    ensemb : str or int
        Ensemble id; selects the template frame and names the output file.
    l : sequence
        Merged labels; must have one entry per row of the template frame.
    out_dir : str, optional
        Output directory; created if it does not exist.

    Returns
    -------
    str
        Path of the written CSV file.
    """
    # Previously hard-coded to df['14'], ignoring `ensemb`.
    template = df[str(ensemb)]
    out = template.assign(labels=pd.Series(l).values)

    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"{ensemb}.csv")
    out.to_csv(path, encoding='utf-8')
    return path
    
# Write the merged labels, using ensemble '14' of the first participant seed
# as the row template.
write_preds(df=p1, ensemb='14', l=predictions_merged);