In [13]:
import pandas as pd
import numpy as np
import ast
from sklearn.metrics import classification_report, accuracy_score
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics import precision_recall_fscore_support
from sklearn.preprocessing import LabelEncoder

In [14]:
import sys
import os
notebook_dir = os.getcwd()
codex_dir = os.path.join(notebook_dir, '..','..','Codex','codex')
sys.path.append(codex_dir)
from codex import Codex
codex = Codex(size='s', code='en')
all_relations = codex.split('train')['relation'].unique()
all_relations = [codex.relation_label(relation) for relation in all_relations]
all_relations.append('Relationship is ambiguous')
all_relations = [relation.lower() for relation in all_relations]

In [15]:
df = pd.read_csv('zero_cot_multi_relation_non_ambiguous/result.csv')
df.loc[df['Label'] == 0, 'true_relation'] = 'relationship is ambiguous'

In [16]:
df.verification = df.verification.apply(lambda x: ast.literal_eval(x))
df.pred_relation = df.pred_relation.apply(lambda x: ast.literal_eval(x))

In [17]:
import re
pred_verifications = []
for ind, row in df.iterrows():
    result = []
    for i, ver in enumerate(row['verification']):
        if ver == 'relationship is ambiguous':
            result.append('relationship is ambiguous')
        else:
            match = re.search(r'^\w+', ver)
            if match:
                first_word = match.group()
                if str.lower(first_word) not in ['true','false']:
                    print(row)
                    break
                else:
                    result.append(str.lower(first_word))
            else:
                print(f"faliure in row {ind}")
    pred_verifications.append(result)
df['pred_verification'] = pred_verifications

head                                                       Thomas Mann
relation                                                    occupation
tail                                                          essayist
Label                                                                1
head_id                                                         Q37030
relation_id                                                       P106
tail_id                                                      Q11774202
prompt               What is the relationship between 'Thomas Mann'...
response             Reasoning process: Thomas Mann is a notable Ge...
pred_relation                                             [occupation]
true_relation                                               occupation
verification         [Solution: False. Reason: Thomas Mann was a re...
pred_verification                                        ['not found']
Name: 676, dtype: object


In [18]:
def process_verification(row):
    result = []
    for i, ver in enumerate(row['pred_verification']):
        if ver == 'relationship is ambiguous':
            result.append('relationship is ambiguous')
        else:
            if ver == 'true':
                result.append(row['pred_relation'][i])
    if not result:
        result.append('relationship is ambiguous')
    row['final_pred'] = result
    return row['final_pred']

df['final_pred'] = None
df['final_pred'] = df.apply(process_verification, axis=1)
df

Unnamed: 0,head,relation,tail,Label,head_id,relation_id,tail_id,prompt,response,pred_relation,true_relation,verification,pred_verification,final_pred
0,Herbert Grönemeyer,country of citizenship,Australia,0,Q275247,P27,Q408,What is the relationship between 'Herbert Grön...,Reasoning process: Herbert Grönemeyer is a Ger...,"[performed in, has toured in, performed in]",relationship is ambiguous,[False. Reason: Herbert Grönemeyer is a German...,"[false, true, false]",[has toured in]
1,"""Weird Al"" Yankovic",record label,Atlantic Records,0,Q8349,P264,Q202440,"What is the relationship between '""Weird Al"" Y...","Reasoning process: ""Weird Al"" Yankovic is a mu...","[record label, record label]",relationship is ambiguous,"[True. Reason: ""Weird Al"" Yankovic has been si...","[true, false]",[record label]
2,Walt Disney Records,founded by,Debby Ryan,0,Q54860,P112,Q204005,What is the relationship between 'Walt Disney ...,Reasoning process: Walt Disney Records is a re...,[record label],relationship is ambiguous,[False. Reason: Walt Disney Records is a recor...,[false],[relationship is ambiguous]
3,Johnny Hallyday,country of citizenship,Switzerland,0,Q212015,P27,Q39,What is the relationship between 'Johnny Hally...,Reasoning process: Johnny Hallyday is a promin...,[residence],relationship is ambiguous,"[False. Reason: Johnny Hallyday, the French si...",[false],[relationship is ambiguous]
4,Madonna,"languages spoken, written, or signed",French,1,Q1744,P1412,Q150,What is the relationship between 'Madonna' and...,Reasoning process: Madonna is a singer who has...,[influenced by],"languages spoken, written, or signed",[False. Reason: Madonna is a pop singer known ...,[false],[relationship is ambiguous]
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,Chord Overstreet,occupation,singer,1,Q53714,P106,Q177220,What is the relationship between 'Chord Overst...,Reasoning process: Chord Overstreet is an indi...,[occupation],occupation,[False. Reason: Chord Overstreet is an actor a...,[false],[relationship is ambiguous]
996,Eleanor Roosevelt,residence,"Washington, D.C.",1,Q83396,P551,Q61,What is the relationship between 'Eleanor Roos...,Reasoning process: Eleanor Roosevelt was a pro...,[residence],residence,[True. Reason: Eleanor Roosevelt lived in the ...,[true],[residence]
997,Harry Connick Jr.,genre,rhythm and blues,0,Q313755,P136,Q45981,What is the relationship between 'Harry Connic...,Reasoning process: Harry Connick Jr. is a musi...,[genre],relationship is ambiguous,[False. Reason: Harry Connick Jr. is known for...,[false],[relationship is ambiguous]
998,Saint Petersburg,named after,Vladimir Lenin,1,Q656,P138,Q1394,What is the relationship between 'Saint Peters...,Reasoning process: Saint Petersburg is a city ...,[place of birth],named after,"[True. Reason: Vladimir Lenin, the Russian rev...",[true],[place of birth]


<h3>Hit at 1</h3>

In [19]:
hit = 1
hit_at_1_first = 0
hit_at_1_second = 0
for index, row in df.iterrows():
    if len(row['pred_relation']) < hit:
        if row['true_relation'] in row['pred_relation']:
            hit_at_1_first += 1
    else:
        if row['true_relation'] in row['pred_relation'][:hit]:
            hit_at_1_first += 1

    if len(row['final_pred']) < hit:
        if row['true_relation'] in row['final_pred']:
            hit_at_1_second += 1
    else:
        if row['true_relation'] in row['final_pred'][:hit]:
            hit_at_1_second += 1
hit_at_1_first = hit_at_1_first/len(df)
hit_at_1_second = hit_at_1_second/len(df)
print(f"hit at {hit} ask first: {hit_at_1_first}")
print(f"hit at {hit} ask twice: {hit_at_1_second}")

hit at 1 ask first: 0.289
hit at 1 ask twice: 0.588


<h3>The possibility of the right answer included in the prediction</h3>

In [20]:
included_first = 0
included_second = 0
for index, row in df.iterrows():
    if row['true_relation'] in row['pred_relation']:
        included_first += 1

    if row['true_relation'] in row['final_pred']:
        included_second += 1
        
included_first = included_first/len(df)
included_second = included_second/len(df)
print(f"true relation is included in the predictions, ask first: {included_first}")
print(f"true relation is included in the predictions, ask twice: {included_second}")

true relation is included in the predictions, ask first: 0.303
true relation is included in the predictions, ask twice: 0.594


<h3>Negative sample correction rate</h3>

In [21]:
hit = 1
hit_at_1_first_neg= 0
hit_at_1_second_neg = 0
for index, row in df.iterrows():
    if row['true_relation'] != 'relationship is ambiguous':
        continue
    
    if len(row['pred_relation']) < hit:
        if row['true_relation'] in row['pred_relation']:
            hit_at_1_first += 1
    else:
        if row['true_relation'] in row['pred_relation'][:hit]:
            hit_at_1_first += 1

    if len(row['final_pred']) < hit:
        if row['true_relation'] in row['final_pred']:
            hit_at_1_second += 1
    else:
        if row['true_relation'] in row['final_pred'][:hit]:
            hit_at_1_second += 1
hit_at_1_first = hit_at_1_first/(len(df)/2)
hit_at_1_second = hit_at_1_second/(len(df)/2)
print(f"hit at {hit} ask first: {hit_at_1_first}")
print(f"hit at {hit} ask twice: {hit_at_1_second}")

hit at 1 ask first: 0.000578
hit at 1 ask twice: 0.8411759999999999
