In [89]:
import pandas as pd
import numpy as np
import ast
from sklearn.metrics import classification_report, accuracy_score
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics import precision_recall_fscore_support
from sklearn.preprocessing import LabelEncoder

In [90]:
import sys
import os
# Put the local `codex` package (two directories up, under Codex/codex) on the
# import path. The location is resolved against the current working directory.
notebook_dir = os.getcwd()
codex_dir = os.path.join(notebook_dir, '..', '..', 'Codex', 'codex')
sys.path.append(codex_dir)
from codex import Codex  # must come after the sys.path tweak above

codex = Codex(size='s', code='en')

# Build the lower-cased label vocabulary: one human-readable label per distinct
# relation id seen in the training split, plus the explicit "ambiguous" option.
train_relation_ids = codex.split('train')['relation'].unique()
relation_labels = [codex.relation_label(rel_id) for rel_id in train_relation_ids]
relation_labels.append('Relationship is ambiguous')
all_relations = [label.lower() for label in relation_labels]

In [91]:
# Location of the saved results table, relative to the working directory.
RESULTS_CSV = 'zeroshot_multi_relation_with_ambiguous/results.csv'
df = pd.read_csv(RESULTS_CSV)

In [92]:
def _parse_list_literal(value):
    """Turn a stringified Python literal into the object it describes.

    Values that are not strings (for example lists produced by an earlier run
    of this cell) are returned unchanged, so the cell can be re-executed
    without `ast.literal_eval` raising on already-parsed data.
    """
    if isinstance(value, str):
        # literal_eval only accepts Python literals; it does not execute code.
        return ast.literal_eval(value)
    return value


# The CSV stores these list-valued columns as text; restore them to real lists.
# Explicit column indexing is used instead of attribute access so the
# assignment always targets the column.
df['verification'] = df['verification'].apply(_parse_list_literal)
df['pred_relation'] = df['pred_relation'].apply(_parse_list_literal)

In [93]:
import re
# Normalised verdict labels written to df['pred_verification'].
AMBIGUOUS_LABEL = 'relationship is ambiguous'
UNPARSED_LABEL = 'unparsed'


def _parse_verdict(ver):
    """Map one raw verification answer to a normalised verdict label.

    Returns 'true' or 'false' when the first word of the answer is one of
    those (case-insensitive, leading whitespace ignored), AMBIGUOUS_LABEL when
    the answer starts with that phrase, and None when no verdict can be
    recognised (including non-string input).
    """
    if not isinstance(ver, str):
        return None
    text = ver.strip().lower()
    if text.startswith(AMBIGUOUS_LABEL):
        return AMBIGUOUS_LABEL
    first_word = re.match(r'\w+', text)
    if first_word and first_word.group() in ('true', 'false'):
        return first_word.group()
    return None


pred_verifications = []
for ind, row in df.iterrows():
    result = []
    for i, ver in enumerate(row['verification']):
        verdict = _parse_verdict(ver)
        if verdict is None:
            # Emit a placeholder instead of skipping the entry or abandoning
            # the row: every answer must yield exactly one label so that
            # position i here still corresponds to row['pred_relation'][i]
            # when process_verification looks relations up by index.
            print(f"unparseable verification in row {ind}, position {i}: {ver!r}")
            verdict = UNPARSED_LABEL
        result.append(verdict)
    pred_verifications.append(result)
df['pred_verification'] = pred_verifications

In [94]:
def process_verification(row):
    """Derive the final prediction list for one row.

    Walks `row['pred_verification']` in order. A verdict equal to
    'relationship is ambiguous' is carried over as-is; a verdict equal to
    'true' is replaced by the entry of `row['pred_relation']` at the same
    position; every other verdict is dropped. If nothing survives, the result
    is the single-element list ['relationship is ambiguous'].

    The list is also stored under `row['final_pred']` before being returned.
    """
    ambiguous = 'relationship is ambiguous'
    kept = [
        ambiguous if verdict == ambiguous else row['pred_relation'][pos]
        for pos, verdict in enumerate(row['pred_verification'])
        if verdict == ambiguous or verdict == 'true'
    ]
    # Fall back to the ambiguous marker when no candidate was confirmed.
    row['final_pred'] = kept if kept else [ambiguous]
    return row['final_pred']

# Create the column up front with None placeholders; it is overwritten below.
df['final_pred'] = None
# Apply process_verification to every row (axis=1) and store the list it
# returns as that row's final prediction.
df['final_pred'] = df.apply(process_verification, axis=1)
# Bare last expression: show the resulting DataFrame as the cell output.
df

Unnamed: 0,head,relation,tail,Label,head_id,relation_id,tail_id,prompt,response,pred_relation,true_relation,verification,pred_verification,final_pred
0,Herbert Grönemeyer,country of citizenship,Australia,0,Q275247,P27,Q408,What is the relationship between 'Herbert Grön...,1. Herbert Grönemeyer [place of birth] Austral...,"[place of birth, nationality, tour, concert, p...",relationship is ambiguous,[False. Reason: Herbert Grönemeyer was born in...,"[false, false, false, false, false]",[relationship is ambiguous]
1,"""Weird Al"" Yankovic",record label,Atlantic Records,0,Q8349,P264,Q202440,"What is the relationship between '""Weird Al"" Y...","1. ""Weird Al"" Yankovic [record label] Atlantic...","[record label, employed by, signed to]",relationship is ambiguous,"[True. Reason: ""Weird Al"" Yankovic was signed ...","[true, false, false]",[record label]
2,Walt Disney Records,founded by,Debby Ryan,0,Q54860,P112,Q204005,What is the relationship between 'Walt Disney ...,1. Walt Disney Records [record label] Debby Ry...,"[record label, cast member]",relationship is ambiguous,[False. Reason: Walt Disney Records is a recor...,"[false, true]",[cast member]
3,Johnny Hallyday,country of citizenship,Switzerland,0,Q212015,P27,Q39,What is the relationship between 'Johnny Hally...,1. Johnny Hallyday [residence] Switzerland; 2....,"[residence, place of death, performed in, has ...",relationship is ambiguous,"[False. Reason: Johnny Hallyday, a well-known ...","[false, true, true, false]","[place of death, performed in]"
4,Madonna,"languages spoken, written, or signed",French,1,Q1744,P1412,Q150,What is the relationship between 'Madonna' and...,"1. Madonna [languages spoken, written, or sign...","[languages spoken, written, or signed, residen...","languages spoken, written, or signed",[True. Reason: Madonna is known to speak Frenc...,"[true, false, false, false, false, false, fals...","[languages spoken, written, or signed]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,Chord Overstreet,occupation,singer,1,Q53714,P106,Q177220,What is the relationship between 'Chord Overst...,1. Chord Overstreet [occupation] singer; 2. Ch...,"[occupation, genre, influenced by, practiced b...",occupation,[True. Reason: Chord Overstreet is known for b...,"[true, false, false, false, false]",[occupation]
996,Eleanor Roosevelt,residence,"Washington, D.C.",1,Q83396,P551,Q61,What is the relationship between 'Eleanor Roos...,"1. Eleanor Roosevelt [residence] Washington, D...","[residence, place of death, place of birth, fo...",residence,[False. Reason: Eleanor Roosevelt resided in t...,"[false, false, true, false, true, false, relat...","[place of birth, headquarters location, relati..."
997,Harry Connick Jr.,genre,rhythm and blues,0,Q313755,P136,Q45981,What is the relationship between 'Harry Connic...,1. Harry Connick Jr. [genre] Rhythm and Blues.,[genre],relationship is ambiguous,[False. Reason: Harry Connick Jr. is known for...,[false],[relationship is ambiguous]
998,Saint Petersburg,named after,Vladimir Lenin,1,Q656,P138,Q1394,What is the relationship between 'Saint Peters...,1. Saint Petersburg [place of birth] Vladimir ...,"[place of birth, residence, narrative location...",named after,[False. Reason: Vladimir Lenin was born in Sim...,"[false, false, false, false]",[relationship is ambiguous]


<h3>Hit at 1</h3>

In [95]:
# Hit@k: share of rows whose true relation appears among the first `hit`
# entries of a prediction list. Slicing with [:hit] already returns the whole
# list when it holds fewer than `hit` items, so no separate length check is
# needed.
hit = 1
n_rows = len(df)

# "ask first": the raw candidate list from the first prompt.
first_hits = sum(
    truth in candidates[:hit]
    for truth, candidates in zip(df['true_relation'], df['pred_relation'])
)
# "ask twice": the candidates that survived the verification step.
second_hits = sum(
    truth in candidates[:hit]
    for truth, candidates in zip(df['true_relation'], df['final_pred'])
)

hit_at_1_first = first_hits / n_rows
hit_at_1_second = second_hits / n_rows
print(f"hit at {hit} ask first: {hit_at_1_first}")
print(f"hit at {hit} ask twice: {hit_at_1_second}")

hit at 1 ask first: 0.383
hit at 1 ask twice: 0.56


<h3>Probability that the correct relation is included in the predictions</h3>

In [96]:
# Share of rows whose true relation appears anywhere in the prediction list,
# before ("ask first") and after ("ask twice") the verification step.
total_rows = len(df)

included_first = sum(
    truth in candidates
    for truth, candidates in zip(df['true_relation'], df['pred_relation'])
)
included_second = sum(
    truth in candidates
    for truth, candidates in zip(df['true_relation'], df['final_pred'])
)

# Convert the raw counts to fractions of all rows.
included_first = included_first / total_rows
included_second = included_second / total_rows
print(f"true relation is included in the predictions, ask first: {included_first}")
print(f"true relation is included in the predictions, ask twice: {included_second}")

true relation is included in the predictions, ask first: 0.595
true relation is included in the predictions, ask twice: 0.616
