In [1]:
import json

In [2]:
with open('wino_amrs.jsonl') as f:
    gt_amrs = [json.loads(l) for l in f]

In [13]:
# Change to your output parse file
model_name='blip2-flant5xl'
# model_name='blip2-opt-6.7b-coco'
with open(f'sample_outputs/wino_{model_name}_caps_parse.jsonl') as f:
    cand_amrs = [json.loads(l) for l in f]

In [4]:
gt_amrs = gt_amrs[:len(cand_amrs)]

In [5]:
SAMPLE_SIZE=len(cand_amrs[0]['parses_0'])

In [6]:
def full_score(candidate,ground_truth):
    full_score=!python smatch/smatch.py -f {candidate} {ground_truth}  --ms --significant 10 | cut -d ':' -f2
    return list(map(float,full_score))

def full_score_recall(candidate,ground_truth):
    full_score=!python smatch/smatch.py -f {candidate} {ground_truth}  --ms --pr --significant 10 | grep Recall | cut -d ':' -f2
    return list(map(float,full_score))

def rel_score(candidate,ground_truth):
    just_rel=!python smatch/smatch.py -f {candidate} {ground_truth} --ms --justrelation | cut -d ':' -f2
    return list(map(float,just_rel))

def attr_score(candidate,ground_truth):
    just_attr=!python smatch/smatch.py -f {candidate} {ground_truth} --ms --justattribute | cut -d ':' -f2
    return list(map(float,just_attr))

def inst_score(candidate,ground_truth):
    just_inst=!python smatch/smatch.py -f {candidate} {ground_truth} --ms --justinstance | cut -d ':' -f2
    return list(map(float,just_inst))

# try sema
# def full_score(candidate,ground_truth):
#     full_score=!python sema/sema.py -t {candidate} -g {ground_truth}  --ms | cut -d ':' -f2
#     return list(map(float,full_score))

smatch_score_map = {
    'full':full_score,
    'full_recall':full_score_recall,
    'rel':rel_score,
    'attr':attr_score,
    'inst':inst_score
}

def get_smatch_scores_from_file(ground_truth_path,candidate_path,method='full'):
    global smatch_score_map
    if method not in smatch_score_map:
        raise ValueError("invalid smatch score method")
    return smatch_score_map[method](candidate_path,ground_truth_path)

def save_tmp_amr(f1,f2,ground_truth_path,candidate_path):
    with open(ground_truth_path,'w') as f, open(candidate_path,'w') as g:
        for p0,p1 in zip(f1,f2):
            print(p0,file=f)
            print(p1,file=g)
            
def get_smatch_scores(f1,f2,method='full'):
    candidate_path="cand_test.amr"
    ground_truth_path="gt_test.amr"
    save_tmp_amr(f1,f2,ground_truth_path,candidate_path)
    return get_smatch_scores_from_file(ground_truth_path,candidate_path,method=method)

In [7]:
# C0,I0,C1
# s(C0,I0) vs s(C1,I0)
# s(C,I) => smatch(parse(C),parse(cap(I)))

In [8]:
import statistics

In [9]:
def get_handles(gt_amrs,cand_amrs):
    fc0=[gt_amr['parse_0'] for gt_amr in gt_amrs for _ in range(SAMPLE_SIZE)]
    fi0=[p for cand_amr in cand_amrs for p in cand_amr['parses_0']]
    fc1=[gt_amr['parse_1'] for gt_amr in gt_amrs for _ in range(SAMPLE_SIZE)]
    fi1=[p for cand_amr in cand_amrs for p in cand_amr['parses_1']]
    return fc0,fc1,fi0,fi1

def get_pair_scores(scoring_func,fc0,fc1,fi0,fi1):
    s00 = list(scoring_func(fc0,fi0))
    s01 = list(scoring_func(fc0,fi1))
    s10 = list(scoring_func(fc0,fi1))
    s11 = list(scoring_func(fc1,fi1))
    return s00,s01,s10,s11

def get_text_score(s00,s01,s10,s11):
    # s00 = score(C0,I0)
    # text score == given image, pick caption
    counts = [c0>x0 and c1>x1 for c0,x0,c1,x1 in zip(s00,s10,s11,s01)]
    return sum(counts)/len(counts),counts

def get_image_score(s00,s01,s10,s11):
    counts = [c0>x0 and c1>x1 for c0,x0,c1,x1 in zip(s00,s01,s11,s10)]
    return sum(counts)/len(counts),counts

def get_group_score(s00,s01,s10,s11):
    counts = [c00>x01 and c00>x10 and c11>x01 and c11>x10 for c00,x01,c11,x10 in zip(s00,s01,s11,s10)]
    return sum(counts)/len(counts),counts

In [10]:
fc0,fc1,fi0,fi1 = get_handles(gt_amrs,cand_amrs)
score_func=lambda f1,f2:get_smatch_scores(f1,f2,method='full')
s00,s01,s10,s11 = get_pair_scores(score_func,fc0,fc1,fi0,fi1)
text_score,txt_cnt=get_text_score(s00,s01,s10,s11)
im_score,im_cnt=get_image_score(s00,s01,s10,s11)
gp_score,gp_cnt=get_group_score(s00,s01,s10,s11)

In [11]:
print(text_score,im_score,gp_score)

0.1005 0.102 0.099


In [12]:
s00[0],s10[0],s11[0],s01[0]

(0.5384615385, 0.3571428571, 0.5, 0.3571428571)

In [18]:
len(txt_cnt)

2000

In [19]:
def get_avg_cnt_first(cnt):
    cnt_avg_first = [ sum([cnt[i+idx] for idx in range(SAMPLE_SIZE)])/SAMPLE_SIZE for i in range(0,len(txt_cnt),SAMPLE_SIZE)]
    return cnt_avg_first

In [20]:
def get_filtered_list(cnt,filter_set_ids):
    cnt_filtered = [ cnt[i] for i in range(len(cnt)) if i in filter_set_ids]
    return cnt_filtered

def get_filtered_score(cnt_expanded,filter_set_ids):
    cnt = get_avg_cnt_first(cnt_expanded)
    filtered = get_filtered_list(cnt,filter_set_ids)
    return sum(filtered)/len(filtered) , filtered

In [21]:
compositionality_sample_ids = set([
    0,1,2,5,6,7,8,9,11,12,14,15,17,18,19,20,21,24,26,29,30,32,33,34,35,37,39,43,45,47,48,50,51,52,53,54,56,57,59,60,64,66,67,
    71,79,80,85,87,89,90,91,92,94,98,99,100,101,102,104,105,106,107,108,109,112,115,117,120,122,123,124,125,126,127,129,137,139,140,141,142,145,146,147,151,153,154,
    157,158,160,161,162,165,166,167,168,169,170,171,175,177,178,179,180,181,183,184,185,186,194,195,196,197,202,205,207,212,213,216,225,231,236,240,243,244,248,250,251,252,256,
    259,261,265,266,269,270,271,272,273,278,279,283,285,288,289,290,291,294,297,301,302,306,308,309,317,328,337,341,349,357,360,366,368,369,370,372,378,379,380,389,391,397
])
actor_recipient = set([0,1,2,5,6,8,9,66, 85, 98, 153, 161, 167, 168, 175, 178, 180, 181, 186, 194, 195, 196, 212, 225, 231, 248, 250])

placement_and_positioning = set([79, 91, 137, 154,290, 328, 337, 357, 360, 378, 379, 380, 389,
157, 158, 162, 165, 166, 177, 179, 183, 184, 185, 202, 244, 248,
79, 91, 137, 154, 56
])

action_swaps = set([270, 271, 272, 278, 279, 283, 285, 288, 289, 294, 301, 306, 308, 309, 317, 328, 341, 349, 376, 357, 360, 366, 368, 369, 370, 372, 378, 379, 380, 389, 397, 0, 1, 2, 5, 6, 8, 9, 66,21,24,26,48,142 , 146 , 14, 180,181, 194, 195, 196, 225, 231, 248, 250])

counting = set([265, 285, 15, 17, 18, 19, 20, 59, 60,90, 140, 141, 145,265, 285, 15, 17, 18, 19, 20, 59, 60])

attribute_binding = set([71, 102, 104, 105, 106, 107, 108, 109, 112, 115,  122, 125, 127,259, 261, 266, 269, 273, 291, 297, 391,225, 231, 248])


categories = {
    'attribute_binding':attribute_binding,
    'counting':counting,
    'action_swaps':action_swaps,
    'placement_and_positioning':placement_and_positioning,
    'actor_recipient':actor_recipient
}

In [22]:
txt_score_avg_first = get_avg_cnt_first(txt_cnt)

In [23]:
txt_filtered_score_og, txt_filtered_cnt_og = get_filtered_score(txt_cnt,compositionality_sample_ids)

In [24]:
txt_filtered_score_actor, txt_filtered_cnt_actor = get_filtered_score(txt_cnt,actor_recipient)

In [26]:
for cat,cat_set in categories.items():
    tscore,_=get_filtered_score(txt_cnt,cat_set)
    iscore,_=get_filtered_score(im_cnt,cat_set)
    gscore,_=get_filtered_score(gp_cnt,cat_set)
    print(f"{cat},{tscore:.2%},{iscore:.2%},{gscore:.2%}")

attribute_binding,10.00%,9.17%,9.17%
counting,4.62%,4.62%,4.62%
action_swaps,12.00%,12.00%,12.00%
placement_and_positioning,13.33%,13.33%,13.33%
actor_recipient,12.59%,12.59%,12.59%


In [227]:
# 1. what's the baseline smatch score between C0,C1 (not cap(I0))
# 2. what's the breakdown of errors
# 3. subset category performance

In [228]:
def get_error_ids(cnt):
    # only zero text scores
    assert len(cnt)==400
    return [i for i in range(len(cnt)) if not cnt[i]]

In [229]:
txt_err_idx = get_error_ids(txt_score_avg_first)

In [261]:
txt_corr_idx = set(range(len(txt_score_avg_first)))-set(get_error_ids(txt_score_avg_first))

In [264]:
print(list(sorted(txt_corr_idx)))

[0, 5, 12, 14, 21, 29, 30, 38, 39, 41, 46, 47, 49, 50, 51, 58, 60, 70, 76, 94, 100, 111, 112, 113, 114, 117, 123, 158, 163, 165, 173, 175, 180, 189, 203, 209, 222, 225, 247, 255, 261, 269, 300, 325, 332, 337, 341, 342, 348, 349, 351, 357, 363, 365, 367, 374, 375, 376, 378, 387]


In [230]:
list(set(txt_err_idx)&compositionality_sample_ids)[:10]

[1, 2, 6, 7, 8, 9, 11, 15, 17, 18]

In [231]:
"""
s(Ii0,Ci0) 1
S(Ii0,Ci1) 0
s(Ii0,Ci0)>s(Ii0,Ci1) =>1 else 0


I0,C0,C1
"""

'\ns(Ii0,Ci0) 1\nS(Ii0,Ci1) 0\ns(Ii0,Ci0)>s(Ii0,Ci1) =>1 else 0\n\n\nI0,C0,C1\n'

In [257]:
idx=7*SAMPLE_SIZE
s00[idx],s10[idx],s11[idx],s01[idx]

(0.3333333333, 0.3333333333, 0.1666666667, 0.3333333333)

In [271]:
# s00[35:40]

In [21]:
idx=6
idx_scaled=idx*SAMPLE_SIZE
print(s00[idx_scaled],s10[idx_scaled],s11[idx_scaled],s01[idx_scaled])
print("ground truth")
print(gt_amrs[idx]['parse_0'])
print(gt_amrs[idx]['parse_1'])
print("generated")
print(cand_amrs[idx]['parses_0'][0])
print(cand_amrs[idx]['parses_1'][0])

0.24 0.1666666667 0.0833333333 0.1666666667
ground truth
# ::tok a plant was harmed by another organism , and that organism broke the plant into pieces
(a / and~8
    :op1 (h / harm-01~3
        :ARG0 (o / organism~6
            :mod (a2 / another~5))
        :ARG1 (p2 / plant~1))
    :op2 (b / break-01~11
        :ARG0 o
        :ARG1 p2
        :ARG2 (p / piece~15)))

# ::tok another organism was harmed by a plant , and that plant broke the organism into pieces
(a / and~8
    :op1 (h / harm-01~3
        :ARG0 (p2 / plant~6)
        :ARG1 (o / organism~1
            :mod (a2 / another~0)))
    :op2 (b / break-01~11
        :ARG0 p2
        :ARG1 o
        :ARG2 (p / piece~15)))

generated
# ::tok a man is removing a tree from the top of a tree
(r / remove-01~3
    :ARG0 (m / man~1)
    :ARG1 (t2 / tree~5)
    :ARG2 (t / top~8
        :part-of t2))

# ::tok a carnivorous plant on a black background
(p / plant~2
    :location (b / background~6
        :ARG1-of (b2 / black-04~5))
    :mo

In [None]:
# Eval metric sensitivity
# ground truth
# # ::tok a bottle is in water
# (b / be-located-at-91~3
#     :ARG1 (b2 / bottle~1)
#     :ARG2 (w / water~4))

# # ::tok water is in a bottle
# (b / be-located-at-91~2
#     :ARG1 (w / water~0)
#     :ARG2 (b2 / bottle~4))

# generated
# # ::tok a message in a bottle floating in the water
# (m / message-01~1
#     :location (b / bottle~4
#         :ARG1-of (f / float-01~5
#             :ARG2 (w / water~8))))

# # ::tok a bottle of water on a white background
# (w / water~3
#     :quant (b2 / bottle~1)
#     :prep-on (b / background~7
#         :ARG1-of (w2 / white-03~6)))


# ground truth
# # ::tok there is a table below someone
# (t / table~3
#     :location (b / below~4
#         :op1 (s / someone~5)))

# # ::tok there is someone below a table
# (s / someone~2
#     :location (b / below~3
#         :op1 (t / table~5)))

# generated
# # ::tok a woman standing on top of a table in an office
# (s / stand-01~2
#     :ARG1 (w / woman~1)
#     :ARG2 (t2 / top~4
#         :part-of (t / table~7
#             :location (o / office~10))))

# # ::tok a woman in a pink dress crouching under a table
# (w / woman~1
#     :ARG0-of (c / crouch-01~6
#         :location (u / under~7
#             :op1 (t / table~9)))
#     :mod (d / dress~5
#         :ARG1-of (p / pink-04~4)))


## Complete caption failure

# ground truth
# # ::tok a tree smashed into a car
# (s / smash-01~2
#     :ARG1 (t / tree~1)
#     :destination (c / car~5))

# # ::tok a car smashed into a tree
# (s / smash-01~2
#     :ARG1 (c / car~1)
#     :destination (t / tree~5))

# generated
# # ::tok a red car is parked under a fallen tree
# (p / park-01~4
#     :ARG1 (c / car~2
#         :ARG1-of (r / red-02~1))
#     :ARG2 (u / under~5
#         :op1 (t / tree~8
#             :ARG1-of (f / fall-01~7))))

# # ::tok a white car is parked next to a tree
# (p / park-01~4
#     :ARG1 (c / car~2
#         :ARG1-of (w / white-03~1))
#     :ARG2 (n / next-to~5
#         :op1 (t / tree~8)))

# idx 15
# ground truth
# # ::tok there are more skiers than snowboarders
# (h / have-quant-91~2
#     :ARG1 (p / person~3
#         :ARG0-of (s / ski-01~3))
#     :ARG3 (m / more~2)
#     :ARG4 (p2 / person~5
#         :ARG0-of (s2 / ski-01~5)))

# # ::tok there are more snowboarders than skiers
# (h / have-quant-91~2
#     :ARG1 (p / person~3
#         :ARG0-of (s / ski-01~3))
#     :ARG3 (m / more~2)
#     :ARG4 (p2 / person~5
#         :ARG0-of (s2 / ski-01~5)))

# generated
# # ::tok a group of people on skis
# (g / group~1
#     :consist-of (p / person~3
#         :ARG0-of (s / ski-01~5)))

# # ::tok a group of people on skis
# (g / group~1
#     :consist-of (p / person~3
#         :ARG0-of (s / ski-01~5)))
