In [1]:
import os

# Enumerate GPUs in PCI bus order and make only device "1" visible to this
# process. This cell runs before any of the GPU-using imports below.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

In [2]:
import numpy as np
import json, time, copy
import math

import random
import pickle

from word2number import w2n
import string, re
from collections import Counter, defaultdict
from pprint import pprint
import spacy
# Load the small English spaCy pipeline with the "ner", "textcat" and "parser"
# components disabled; in the cells below `nlp` is only used to read token lemmas.
nlp = spacy.load("en_core_web_sm", disable=["ner","textcat","parser"])
# Print NumPy floats with 4 digits of precision.
np.set_printoptions(precision=4)

In [3]:
import sys
# Make the local BARTScore checkout importable. The path is machine-specific
# and has to be adjusted when the notebook is run elsewhere.
sys.path.append("/home/yingshac/CYS/WebQnA/VLP/BARTScore")
from bart_score import BARTScorer
#bart_scorer = BARTScorer(device='cuda:0', checkpoint='facebook/bart-large-cnn')

In [4]:
# Build the scorer from the 'facebook/bart-large-cnn' checkpoint on device 'cuda:0',
# then call its load() with the local bart.pth file.
# NOTE(review): the "_ParaBank" suffix suggests bart.pth holds ParaBank-finetuned
# weights that replace the checkpoint's weights; confirm against the BARTScore repo.
bart_scorer_ParaBank = BARTScorer(device='cuda:0', checkpoint='facebook/bart-large-cnn')
bart_scorer_ParaBank.load(path='/home/yingshac/CYS/WebQnA/VLP/BARTScore/bart.pth')

In [5]:
# Translation table that deletes every character in string.punctuation.
TABLE = str.maketrans("", "", string.punctuation)


def normalize_text_for_bart(x):
    """Strip ASCII punctuation from ``x`` and collapse whitespace.

    Punctuation characters are deleted (not replaced by a space), then the
    remaining whitespace-separated tokens are re-joined with single spaces.
    """
    without_punct = x.translate(TABLE)
    tokens = without_punct.split()
    return " ".join(tokens)
#def compute_bartscore(c, a, switch=False):
    #if switch: score = np.exp(bart_scorer.score(c, a))
    #else: score = np.exp(bart_scorer.score(a, c))
    #return score
def compute_bartscore_ParaBank(c, a, switch=False):
    """Score two parallel lists of strings with ``bart_scorer_ParaBank``.

    Both lists are first cleaned with ``normalize_text_for_bart``. With the
    default ``switch=False`` the cleaned ``a`` list is passed to
    ``bart_scorer_ParaBank.score`` as the first argument and the cleaned ``c``
    list as the second; ``switch=True`` swaps that order.

    Returns ``np.exp`` applied to whatever ``score`` returns (one value per
    pair in the visible usages of this notebook).
    """
    cleaned_c = list(map(normalize_text_for_bart, c))
    cleaned_a = list(map(normalize_text_for_bart, a))
    if switch:
        first, second = cleaned_c, cleaned_a
    else:
        first, second = cleaned_a, cleaned_c
    return np.exp(bart_scorer_ParaBank.score(first, second))

In [6]:
def detectNum(l):
    """Return the integer-like items of ``l`` as canonical decimal strings.

    Every item is passed through ``int()``; items that parse are kept as
    ``str(int(item))`` (so ``"007"`` becomes ``"7"``), in their original order.
    Items that do not parse as an integer (e.g. ``"cat"``, ``"3.5"``, ``None``)
    are skipped.
    """
    result = []
    for w in l:
        try:
            result.append(str(int(w)))
        except (ValueError, TypeError, OverflowError):
            # Only conversion failures are ignored; anything else (for example
            # KeyboardInterrupt) must propagate instead of being swallowed.
            continue
    return result
def toNum(word):
    """Convert a number word to its numeric value, or return ``word`` unchanged.

    The literal token ``'point'`` is returned as-is without being handed to
    ``w2n.word_to_num``. For any other token the result of
    ``w2n.word_to_num(word)`` is returned; if that call fails, the original
    ``word`` is returned instead.
    """
    if word == 'point':
        return word
    try:
        return w2n.word_to_num(word)
    except Exception:
        # Best-effort conversion: any library error means "not a number word".
        # `Exception` (rather than a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        return word

def normalize_text(s):
    """Normalize a string into a space-separated bag of lemmas.

    Full pipeline, in order: lowercase; remove punctuation except decimal
    points; replace the articles a/an/the with a space; map every
    whitespace-separated token through ``toNum`` (number words become digits);
    lemmatize with the module-level spaCy pipeline ``nlp``.

    Short inputs skip part of the pipeline:
      * a single character (after stripping) is only lowercased and mapped
        through ``toNum``, so a lone article or punctuation mark is kept;
      * exactly one word skips article removal.
    """
    stripped = s.strip()
    text = s.lower()

    def number_words_to_digits(t):
        # Re-joining on single spaces also collapses runs of whitespace.
        return " ".join([str(toNum(tok)) for tok in t.split()])

    # Single character: accept article and punctuation as they are.
    if len(stripped) == 1:
        return number_words_to_digits(text)

    # Remove all punctuation but '.', then remove each '.' that is not
    # followed by a digit (i.e. keep decimal points only).
    removable = set(string.punctuation) - set(['.'])
    text = "".join(ch for ch in text if ch not in removable)
    text = re.sub(r"\.(?!\d)", "", text)

    # A single word keeps its articles; every other input has them blanked.
    if len(stripped.split()) != 1:
        article_re = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        text = re.sub(article_re, " ", text)

    text = number_words_to_digits(text)
    return " ".join([token.lemma_ for token in nlp(text)])


In [7]:
# VQA Eval (SQuAD style EM, F1)
def compute_vqa_metrics(cands, a, exclude="", domain=None):
    """Score candidate answers against a reference by bag-of-words overlap.

    Args:
        cands: sequence of candidate answer strings.
        a: reference answer string.
        exclude: string whose normalized tokens are removed from every
            candidate before scoring.
        domain: optional set of tokens. ``{"NUMBER"}`` is a special value:
            each candidate is reduced to its integer tokens via ``detectNum``
            and the reference is left untouched. For any other set, both the
            candidate and the reference are intersected with it.

    Returns:
        A 5-tuple ``(F1_avg, F1_max, EM, RE_avg, PR_avg)``. ``EM`` is 1 if at
        least one candidate's token list equals the reference's, else 0; the
        other entries aggregate per-candidate F1 / recall / precision. With no
        candidates every entry is 0.
    """
    # Always return 5 values: callers unpack all five.
    if len(cands) == 0:
        return (0, 0, 0, 0, 0)

    bow_a = normalize_text(a).split()
    restrict_to_domain = domain is not None and domain != {"NUMBER"}
    if restrict_to_domain:
        # Filter the reference once, outside the loop. Sorting gives the
        # set-derived list a deterministic order so the EM list comparison
        # below does not depend on set iteration order.
        bow_a = sorted(domain.intersection(bow_a))
    excluded = set(normalize_text(exclude).split())

    F1 = []
    RE = []
    PR = []
    EM = 0
    for c in cands:
        bow_c = [w for w in normalize_text(c).split() if w not in excluded]
        if domain == {"NUMBER"}:
            bow_c = detectNum(bow_c)
        elif restrict_to_domain:
            bow_c = sorted(domain.intersection(bow_c))

        common = Counter(bow_a) & Counter(bow_c)
        num_same = sum(common.values())
        if num_same == 0:
            # No overlap: this candidate scores 0, but the remaining
            # candidates must still be evaluated (no early return).
            PR.append(0.0)
            RE.append(0.0)
            F1.append(0.0)
            continue

        if bow_c == bow_a:
            EM = 1
        precision = 1.0 * num_same / len(bow_c)
        recall = 1.0 * num_same / len(bow_a)
        RE.append(recall)
        PR.append(precision)

        # The small epsilon keeps the original numeric behaviour of this metric.
        f1 = 2*precision*recall / (precision + recall + 1e-5)
        F1.append(f1)

    PR_avg = np.mean(PR)
    RE_avg = np.mean(RE)
    F1_avg = np.mean(F1)
    F1_max = np.max(F1)
    return (F1_avg, F1_max, EM, RE_avg, PR_avg)

In [8]:
# Load the text dataset; the context manager closes the file handle once it is parsed.
with open("/home/yingshac/CYS/WebQnA/WebQnA_data_new/txt_dataset_0823_clean_te.json", "r") as dataset_file:
    txt_dataset = json.load(dataset_file)

# Samples per split, and number of distinct Guids.
print(Counter([txt_dataset[k]['split'] for k in txt_dataset]))
print(len(set([txt_dataset[k]['Guid'] for k in txt_dataset])))

Counter({'train': 17812, 'test': 4076, 'val': 2455})
24343


In [9]:
# Histogram of how many answers ('A' entries) each test sample has.
x = []
for k, datum in txt_dataset.items():
    if datum['split'] != 'test':
        continue
    x.append(len(datum['A']))
print(Counter(x))

Counter({6: 1366, 5: 1343, 4: 874, 3: 493})


In [10]:
### Normalization, txt data
# Per-sample lists with one entry per answer of that sample.
# NOTE: despite its name, `bleu4_txt_clean` stores normalized BARTScores, not
# BLEU; the name is kept because the summary cell below reads it.
bleu4_txt_clean = {}
RE_txt_clean = {}
F1_txt_clean = {}
mul_txt_clean = {}

# Nothing in this cell increments `drop`; it is only printed at the end.
drop = defaultdict(int)
for k in txt_dataset:
    if not 'test' in txt_dataset[k]['split']: continue
    datum = txt_dataset[k]
    bleu4_txt_clean[k] = []
    RE_txt_clean[k] = []
    mul_txt_clean[k] = []
    F1_txt_clean[k] = []
    Keywords_A = datum['Keywords_A'].replace('"', "")
    all_A = [a.replace('"', "") for a in datum['A']]
    n_answers = len(all_A)

    # Score every ordered pair of answers in one batch. Pairs are laid out
    # row-major, so after the reshape scores[i][j] is the score for
    # c=all_A[i], a=all_A[j], and scores[j][j] is answer j scored against itself.
    index_pairs = [(i, j) for i in range(n_answers) for j in range(n_answers)]
    c_list = [all_A[i] for i, _ in index_pairs]
    a_list = [all_A[j] for _, j in index_pairs]
    scores = compute_bartscore_ParaBank(c_list, a_list).reshape((n_answers, n_answers))

    for i in range(n_answers):
        C = [all_A[i]]

        # Best score of answer i against any *other* answer j, divided by
        # j's self-score and capped at 1.
        score = np.max([min(1, scores[i][j]/scores[j][j]) for j in range(n_answers) if not i==j])

        F1_avg, F1_max, EM, RE_avg, PR_avg = compute_vqa_metrics(C, Keywords_A)
        bleu4_txt_clean[k].append(score)
        RE_txt_clean[k].append(RE_avg)
        F1_txt_clean[k].append(F1_avg)
        mul_txt_clean[k].append(RE_avg * score)

    # Progress report roughly every 500 processed samples.
    if len(RE_txt_clean) % 500 == 499: print(len(RE_txt_clean))
assert len(RE_txt_clean) == len(mul_txt_clean) == len(bleu4_txt_clean) == len(F1_txt_clean)
print(len(RE_txt_clean))
print(drop)
print(np.sum(list(drop.values())))

499
999
1499
1999
2499
2999
3499
3999
4076
defaultdict(<class 'int'>, {})
0.0


In [11]:
#  Normalization
# Distribution of answers per sample, then the mean over samples of each
# per-sample mean.
print(Counter(len(values) for values in RE_txt_clean.values()))
for label, per_sample in (
    ("mean RE: ", RE_txt_clean),
    ("mean bleu4: ", bleu4_txt_clean),
    ("mean mul: ", mul_txt_clean),
):
    print(label, np.mean([np.mean(values) for values in per_sample.values()]))

Counter({6: 1366, 5: 1343, 4: 874, 3: 493})
mean RE:  0.9400107298311824
mean bleu4:  0.4834490687569886
mean mul:  0.4562010155711136


In [12]:
# Cache normalization, txt data
# Work on a deep copy so `txt_dataset` itself is left unmodified.
txt_dataset_0825_bartscorePB_te = copy.deepcopy(txt_dataset)
count = 0
for k in txt_dataset_0825_bartscorePB_te:
    if not txt_dataset_0825_bartscorePB_te[k]['split'] == 'test': continue
    a_list = [a.replace('"', "") for a in txt_dataset_0825_bartscorePB_te[k]['A']]
    # Self-score of every answer (each answer scored against itself).
    scores = compute_bartscore_ParaBank(a_list, a_list)
    # Stored as a JSON string. .tolist() yields plain Python floats, so the
    # serialization does not depend on the array's dtype.
    txt_dataset_0825_bartscorePB_te[k]['bartscore_normalizer'] = json.dumps(scores.tolist())
    count += 1
    if count % 500 == 499: print(count)
# The context manager flushes and closes the output file deterministically.
with open("/home/yingshac/CYS/WebQnA/WebQnA_data_new/txt_dataset_0825_bartscorePB_te.json", "w") as out_file:
    json.dump(txt_dataset_0825_bartscorePB_te, out_file, indent=4)

499
999
1499
1999
2499
2999
3499
3999


### Draft Zone

In [16]:
'''
from datasets import load_metric
metric = load_metric("bertscore")
def compute_bertscore(cands, a, model_type):
    metric.add_batch(predictions = cands, references = [a]*len(cands))
    score = metric.compute(model_type=model_type)
    return np.mean(score['f1']), np.max(score['f1'])
'''

In [17]:
# Scratch check: a one-word candidate repeated once per reference sentence.
cands = ["Meter"]*5
a = ["Iambic foot is a sub part of a meter, the meter is the overall patterns in a verse, foot helps make that up.",
    "Iambic foot can make up meter more than meter can make up iambic foot.",
    "Meter makes up lambic foot more than lambic foot makes up meter.",
    "An Iambic foot helps make up a Meter more than a meter helps make up an Iambic foot.",
    "An iambic foot helps to make up a meter in poetry."]
# Each reference scored against itself; used below as the per-reference divisor.
deno = compute_bartscore_ParaBank(a, a)
print("deno = ", deno)
# Normalized scores for both argument orders of compute_bartscore_ParaBank.
print(compute_bartscore_ParaBank(a, cands)/deno)
print(compute_bartscore_ParaBank(cands, a)/deno)

deno =  [0.3545 0.3307 0.3259 0.4168 0.35  ]
[0.0041 0.0136 0.0089 0.0207 0.0121]
[0.0108 0.014  0.0217 0.014  0.0091]


In [16]:
# Scratch check: a short candidate sentence repeated once per reference sentence.
cands = ["They are native to lakes ."]*5
a = ["The Janaria and Labidochirus splendescens are both native to the Pacific Ocean",
    "Janaria and Labidochirus splendescens are both native to the Pacific Ocean.",
    "Janaria and Labidochirus splendescens are native to the Pacific Ocean.",
    "The Janaria and Labidochirus splendescens native to the Pacific Ocean",
    "The Janaria and Labidochirus splendescens are native to the Pacific Ocean."]
# Each reference scored against itself; used below as the per-reference divisor.
deno = compute_bartscore_ParaBank(a, a)
print("deno = ", deno)
# Normalized scores for both argument orders of compute_bartscore_ParaBank.
print(compute_bartscore_ParaBank(a, cands)/deno)
print(compute_bartscore_ParaBank(cands, a)/deno)

deno =  [0.5157 0.5782 0.5817 0.4787 0.5187]
[0.0038 0.0033 0.0033 0.0026 0.0039]
[0.018  0.0142 0.022  0.0119 0.0246]
