In [1]:
import numpy as np
import os
import json, time, copy
import math

import random
import pickle

from word2number import w2n
import string, re
from collections import Counter, defaultdict
from pprint import pprint
import spacy
nlp = spacy.load("en_core_web_sm", disable=["ner","textcat","parser"])
np.set_printoptions(precision=4)

In [2]:
import sys
sys.path.append("/home/yingshac/CYS/WebQnA/VLP/BARTScore")
from bart_score import BARTScorer
#bart_scorer = BARTScorer(device='cuda:0', checkpoint='facebook/bart-large-cnn')

In [3]:
# Build the scorer used by every BARTScore computation below: it is created on
# `cuda:0` from the `facebook/bart-large-cnn` checkpoint, after which `load` is
# called with a local `bart.pth` file.
# NOTE(review): the variable name suggests `bart.pth` holds the ParaBank-finetuned
# BARTScore weights and that `load` replaces the checkpoint weights — confirm
# against bart_score.py.
bart_scorer_ParaBank = BARTScorer(device='cuda:0', checkpoint='facebook/bart-large-cnn')
bart_scorer_ParaBank.load(path='/home/yingshac/CYS/WebQnA/VLP/BARTScore/bart.pth')

In [4]:
# Translation table that deletes every character listed in string.punctuation.
TABLE = str.maketrans("", "", string.punctuation)


def normalize_text_for_bart(x):
    """Strip punctuation from `x` and collapse all whitespace runs to single spaces.

    Args:
        x: input string.

    Returns:
        The punctuation-free string with tokens separated by exactly one space
        and no leading or trailing whitespace.
    """
    without_punct = x.translate(TABLE)
    tokens = without_punct.split()
    return " ".join(tokens)
#def compute_bartscore(c, a, switch=False):
    #if switch: score = np.exp(bart_scorer.score(c, a))
    #else: score = np.exp(bart_scorer.score(a, c))
    #return score
def compute_bartscore_ParaBank(c, a, switch=False):
    """Exponentiated scorer output for paired candidate / reference strings.

    Both lists are stripped of punctuation with `normalize_text_for_bart`
    before being handed to `bart_scorer_ParaBank.score`.

    Args:
        c: list of candidate strings.
        a: list of reference strings, paired element-wise with `c`.
        switch: when False (default) the references are passed as the first
            argument of `score` and the candidates as the second; when True
            the two are swapped.

    Returns:
        `np.exp` applied to whatever `bart_scorer_ParaBank.score` returns.
    """
    cands_clean = [normalize_text_for_bart(text) for text in c]
    refs_clean = [normalize_text_for_bart(text) for text in a]
    if switch:
        first, second = cands_clean, refs_clean
    else:
        first, second = refs_clean, cands_clean
    return np.exp(bart_scorer_ParaBank.score(first, second))

In [5]:
def detectNum(l):
    """Keep only the tokens of `l` that parse as integers.

    Args:
        l: iterable of tokens (normally strings).

    Returns:
        A list with the canonical decimal string (`str(int(w))`) of every token
        that `int()` accepts, in input order; all other tokens are dropped.
    """
    result = []
    for w in l:
        try:
            result.append(str(int(w)))
        # Only the failures int() itself signals are treated as "not a number";
        # a bare `except` would also swallow KeyboardInterrupt / SystemExit.
        except (ValueError, TypeError, OverflowError):
            continue
    return result
def toNum(word):
    """Convert a number word to its numeric value, leaving other tokens untouched.

    Args:
        word: a single token.

    Returns:
        The result of `w2n.word_to_num(word)` when the conversion succeeds,
        otherwise `word` itself. The token 'point' is always returned as is.
    """
    # 'point' is never sent to the converter.
    # NOTE(review): presumably because word2number treats it as a decimal
    # marker while it is also a legitimate shape keyword — confirm.
    if word == 'point':
        return word
    try:
        return w2n.word_to_num(word)
    # Any conversion failure means "not a number word"; `Exception` (rather than
    # a bare `except`) lets KeyboardInterrupt / SystemExit propagate.
    except Exception:
        return word

def normalize_text(s):
    """Normalize an answer string before bag-of-words comparison.

    The full pipeline is: lowercase -> drop punctuation (a '.' survives only
    when directly followed by a digit) -> drop the articles a/an/the ->
    convert number words with `toNum` and re-join on single spaces ->
    replace every token with its spaCy lemma.

    Two short-input exceptions apply, based on the stripped input:
      * a single character skips punctuation/article removal and lemmatization;
      * a single word keeps articles (so that "a" or "the" alone survives).

    Args:
        s: raw answer string.

    Returns:
        The normalized, space-joined string.
    """
    def _numbers_to_digits(text):
        # Also collapses whitespace, since tokens are re-joined with one space.
        return " ".join(str(toNum(tok)) for tok in text.split())

    def _strip_punctuation(text):
        kept = "".join(ch for ch in text if ch == '.' or ch not in string.punctuation)
        # remove '.' if it's not a decimal point
        return re.sub(r"\.(?!\d)", "", kept)

    def _lemmatize(text):
        return " ".join(tok.lemma_ for tok in nlp(text))

    stripped = s.strip()
    lowered = s.lower()

    # accept article and punc if input is a single char
    if len(stripped) == 1:
        return _numbers_to_digits(lowered)

    no_punct = _strip_punctuation(lowered)

    # accept article if input is a single word
    if len(stripped.split()) == 1:
        return _lemmatize(_numbers_to_digits(no_punct))

    no_articles = re.sub(r"\b(a|an|the)\b", " ", no_punct, flags=re.UNICODE)
    return _lemmatize(_numbers_to_digits(no_articles))


In [6]:
# VQA Eval (SQuAD style EM, F1)
def compute_vqa_metrics(cands, a, exclude="", domain=None):
    """Score candidate answers against one reference with bag-of-words metrics.

    Args:
        cands: sequence of candidate answer strings.
        a: reference answer / keyword string.
        exclude: string whose normalized tokens are removed from every candidate.
        domain: controls token filtering.
            * None: compare all tokens.
            * {"NUMBER"}: keep only integer tokens of the candidate (via
              `detectNum`); the reference is left unfiltered.
            * any other set: restrict both candidate and reference to the
              words contained in the set.

    Returns:
        A 5-tuple `(F1_avg, F1_max, EM, RE_avg, PR_avg)`. With no candidates
        the result is `(0, 0, 0, 0, 0)`. A candidate that shares no token with
        the reference contributes zeros to the averages and never sets EM.
    """
    # Callers always unpack five values, so the empty case must return five too.
    if len(cands) == 0:
        return (0, 0, 0, 0, 0)

    numeric_only = domain == {"NUMBER"}
    word_domain = domain is not None and not numeric_only

    bow_a = normalize_text(a).split()
    if word_domain:
        # Filter the reference once; sorting makes the EM list comparison
        # independent of set iteration order.
        bow_a = sorted(domain.intersection(bow_a))
    e = normalize_text(exclude).split()

    F1, RE, PR = [], [], []
    EM = 0
    for c in cands:
        bow_c = [w for w in normalize_text(c).split() if w not in e]
        if numeric_only:
            bow_c = detectNum(bow_c)
        elif word_domain:
            bow_c = sorted(domain.intersection(bow_c))

        common = Counter(bow_a) & Counter(bow_c)
        num_same = sum(common.values())
        if num_same == 0:
            # Score this candidate as zero instead of discarding the scores of
            # every other candidate; this also guards the divisions below.
            F1.append(0.0)
            RE.append(0.0)
            PR.append(0.0)
            continue

        if bow_c == bow_a:
            EM = 1
        precision = 1.0 * num_same / len(bow_c)
        recall = 1.0 * num_same / len(bow_a)
        RE.append(recall)
        PR.append(precision)
        # The small epsilon keeps the original numerical behavior of the F1 term.
        F1.append(2*precision*recall / (precision + recall + 1e-5))

    return (np.mean(F1), np.max(F1), EM, np.mean(RE), np.mean(PR))

In [7]:
# Closed vocabularies passed as `domain` to compute_vqa_metrics for the 'color',
# 'shape' and 'YesNo' question categories: only words found in the matching set
# are compared.
# NOTE(review): entries such as 'bellshaped' or 'crosse' look like they are stored
# in post-normalization form (punctuation stripped / lemmatized) — confirm.
# NOTE(review): 'yes' and 'bluere' inside color_set look unintended — verify.
color_set= {'orangebrown', 'spot', 'yellow', 'blue', 'rainbow', 'ivory', 'brown', 'gray', 'teal', 'bluewhite', 'orangepurple', 'black', 'white', 'gold', 'redorange', 'pink', 'blonde', 'tan', 'turquoise', 'grey', 'beige', 'golden', 'orange', 'bronze', 'maroon', 'purple', 'bluere', 'red', 'rust', 'violet', 'transparent', 'yes', 'silver', 'chrome', 'green', 'aqua'}
# Note that 'point' is a member here; toNum returns that token unchanged.
shape_set = {'globular', 'octogon', 'ring', 'hoop', 'octagon', 'concave', 'flat', 'wavy', 'shamrock', 'cross', 'cylinder', 'cylindrical', 'pentagon', 'point', 'pyramidal', 'crescent', 'rectangular', 'hook', 'tube', 'cone', 'bell', 'spiral', 'ball', 'convex', 'square', 'arch', 'h', 'cuboid', 'step', 'rectangle', 'dot', 'oval', 'circle', 'star', 'crosse', 'crest', 'octagonal', 'cube', 'triangle', 'semicircle', 'domeshape', 'obelisk', 'corkscrew', 'curve', 'circular', 'xs', 'slope', 'pyramid', 'round', 'bow', 'straight', 'triangular', 'heart', 'fork', 'teardrop', 'fold', 'curl', 'spherical', 'diamond', 'keyhole', 'conical', 'dome', 'sphere', 'bellshaped', 'rounded', 'hexagon', 'flower', 'globe', 'torus'}
yesno_set = {'yes', 'no'}

In [8]:
# Load the image-question dataset; the context manager closes the file handle
# as soon as parsing finishes instead of leaving it to the garbage collector.
with open("/home/yingshac/CYS/WebQnA/WebQnA_data_new/img_dataset_0823_clean_te.json", "r") as f:
    img_dataset = json.load(f)

# Sanity summary: samples per split, per question category (overall and for the
# test split only), and the number of distinct Guid values.
print(Counter([img_dataset[k]['split'] for k in img_dataset]))
print(Counter([img_dataset[k]['Qcate'] for k in img_dataset]))
print(Counter([img_dataset[k]['Qcate'] for k in img_dataset if img_dataset[k]['split'] == 'test']))
print(len(set([img_dataset[k]['Guid'] for k in img_dataset])))

Counter({'train': 16448, 'test': 3464, 'val': 2511})
Counter({'YesNo': 7430, 'Others': 5823, 'choose': 4693, 'number': 2084, 'color': 1832, 'shape': 561})
Counter({'Others': 1058, 'choose': 981, 'YesNo': 935, 'color': 228, 'number': 200, 'shape': 62})
22423


In [9]:
# Distribution of the number of answers per test sample.
x = [len(entry['A']) for entry in img_dataset.values() if 'test' in entry['split']]
print(Counter(x))

Counter({6: 1579, 5: 780, 4: 703, 3: 402})


In [10]:
### With normalization, img_data
# Leave-one-out evaluation of the human answers on the test split: each answer
# of a sample is scored as a candidate against the sample's other answers
# (fluency) and against its keywords (accuracy).
# Result tables map sample id -> list with one value per answer.
bleu4_img_clean = {}  # holds the normalized BARTScore fluency values despite the name
RE_img_clean = {}
F1_img_clean = {}
mul_img_clean = {}
acc_img_clean = {}

# Question categories that are scored against a closed vocabulary; every other
# category falls back to open-vocabulary scoring (domain=None).
domain_by_qcate = {
    'color': color_set,
    'shape': shape_set,
    'YesNo': yesno_set,
    'number': {"NUMBER"},
}

drop = defaultdict(int)  # never incremented in this cell; kept because it is printed below
for k in img_dataset:
    datum = img_dataset[k]
    if not 'test' in datum['split']: continue
    Keywords_A = datum['Keywords_A'].replace('"', "")
    all_A = [a.replace('"', "") for a in datum['A']]
    Qcate = datum['Qcate']
    domain = domain_by_qcate.get(Qcate)
    n_answers = len(all_A)

    # Score every ordered (candidate i, reference j) pair in one batch; the
    # row-major pair order makes the reshape yield scores[i][j].
    c_list = []
    a_list = []
    for i in range(n_answers):
        for j in range(n_answers):
            c_list.append(all_A[i])
            a_list.append(all_A[j])
    scores = compute_bartscore_ParaBank(c_list, a_list).reshape((n_answers, n_answers))

    bleu4_img_clean[k] = []
    RE_img_clean[k] = []
    mul_img_clean[k] = []
    F1_img_clean[k] = []
    acc_img_clean[k] = []
    for i in range(n_answers):
        # Fluency: best score against any *other* answer, normalized by that
        # reference's self-score and capped at 1. np.max raises on an empty
        # list, so every test sample needs at least two answers.
        score = np.max([min(1.0, scores[i][j]/scores[j][j]) for j in range(n_answers) if not i==j])

        F1_avg, F1_max, EM, RE_avg, PR_avg = compute_vqa_metrics([all_A[i]], Keywords_A, "", domain)
        # Closed-vocabulary categories use F1 as accuracy, the others use recall.
        accuracy = RE_avg if domain is None else F1_avg

        bleu4_img_clean[k].append(score)
        RE_img_clean[k].append(RE_avg)
        F1_img_clean[k].append(F1_avg)
        acc_img_clean[k].append(accuracy)
        mul_img_clean[k].append(accuracy * score)

    # Progress indicator, printed once per 500 processed samples.
    if len(RE_img_clean) % 500 == 499: print(len(RE_img_clean))
assert len(RE_img_clean) == len(mul_img_clean) == len(bleu4_img_clean) == len(acc_img_clean) == len(F1_img_clean)
print(len(RE_img_clean))
print(drop)
print(np.sum(list(drop.values())))

499
999
1499
1999
2499
2999
3464
defaultdict(<class 'int'>, {})
0.0


In [11]:
# Number of scored answers per sample.
print(Counter([len(RE_img_clean[k]) for k in RE_img_clean]))


def _print_category_means(selected):
    """Print the mean of per-sample means of every metric table.

    Only samples whose 'Qcate' is contained in `selected` are included.
    """
    tables = (
        ("mean RE: ", RE_img_clean),
        ("mean F1: ", F1_img_clean),
        ("mean accuracy: ", acc_img_clean),
        ("mean fluency: ", bleu4_img_clean),
        ("mean mul: ", mul_img_clean),
    )
    for label, table in tables:
        per_sample = [np.mean(table[k]) for k in table if img_dataset[k]['Qcate'] in selected]
        print(label, np.mean(per_sample))


# Overall numbers across all question categories ...
Qcate = ['choose', 'YesNo', 'Others', 'color', 'shape', 'number']
_print_category_means(Qcate)

# ... followed by the same breakdown for each category on its own.
for cate in Qcate:
    print("\n", cate)
    _print_category_means([cate])

Counter({6: 1579, 5: 780, 4: 703, 3: 402})
mean RE:  0.9520481520712423
mean F1:  0.5651965763804441
mean accuracy:  0.9468847284155235
mean fluency:  0.630391549716049
mean mul:  0.6025058408240133

 choose
mean RE:  0.9681807354975492
mean F1:  0.30618078530459153
mean accuracy:  0.9681807354975492
mean fluency:  0.6839234503683429
mean mul:  0.6665757875949492

 YesNo
mean RE:  1.0
mean F1:  0.9996384943803229
mean accuracy:  0.9996384943803229
mean fluency:  0.5743393083148296
mean mul:  0.5741045999547485

 Others
mean RE:  0.8780125472868293
mean F1:  0.24214694472518142
mean accuracy:  0.8780125472868293
mean fluency:  0.6177855181773121
mean mul:  0.5554930805609485

 color
mean RE:  0.9836866471734893
mean F1:  0.958286424484597
mean accuracy:  0.958286424484597
mean fluency:  0.6244174217447634
mean mul:  0.6059464311760433

 shape
mean RE:  0.9821236559139784
mean F1:  0.9475757166031306
mean accuracy:  0.9475757166031306
mean fluency:  0.5919534839302913
mean mul:  0.563981

In [146]:
# Cache normalization, img data
# For every test sample, store the self-score of each answer (answer scored
# against itself) as a JSON string under 'bartscore_normalizer' in a deep copy
# of the dataset, leaving `img_dataset` itself untouched.
img_dataset_0825_bartscorePB_te = copy.deepcopy(img_dataset)
count = 0
for k, entry in img_dataset_0825_bartscorePB_te.items():
    if entry['split'] != 'test':
        continue
    a_list = [a.replace('"', "") for a in entry['A']]
    scores = compute_bartscore_ParaBank(a_list, a_list)
    entry['bartscore_normalizer'] = json.dumps(list(scores))
    count += 1
    # Progress indicator, printed once per 500 processed samples.
    if count % 500 == 499:
        print(count)
#json.dump(img_dataset_0825_bartscorePB_te, open("/home/yingshac/CYS/WebQnA/WebQnA_data_new/img_dataset_0825_bartscorePB_te.json", "w"),indent=4)

499
999
1499
1999
2499
2999


### Draft Zone

In [16]:
# Disabled draft: a BERTScore helper wrapped in a string literal so that none of
# it is executed (the cell only evaluates to the string itself).
'''
from datasets import load_metric
metric = load_metric("bertscore")
def compute_bertscore(cands, a, model_type):
    metric.add_batch(predictions = cands, references = [a]*len(cands))
    score = metric.compute(model_type=model_type)
    return np.mean(score['f1']), np.max(score['f1'])
'''

In [9]:
# Sanity check: a short, wrong candidate scored against five paraphrased
# references, normalized by each reference's self-score, in both argument orders.
cands = ["5 years." for _ in range(5)]
a = [
    "The last direct Capetian died some 577 years before the coat of arms of France was created.",
    "The last direct Capetian died 577 years before the coat of arms of France was created.",
    "The coat of arms of France was created 577 years after the last direct Capetian died.",
    "The last Capetian died 572 years before the coat of arms was created.",
    "The last direct Capetian died about 577 years before the coat of arms of France was created.",
]
deno = compute_bartscore_ParaBank(a, a)
print("deno = ", deno)
for first_arg, second_arg in ((a, cands), (cands, a)):
    print(compute_bartscore_ParaBank(first_arg, second_arg)/deno)

deno =  [0.5113 0.512  0.5519 0.5394 0.5117]
[0.0049 0.0054 0.0074 0.0063 0.0052]
[0.0005 0.0007 0.001  0.0008 0.0006]
