In [274]:
import json
import nltk
from scipy import stats
from math import log
from tqdm import tqdm
from itertools import chain
from collections import Counter, defaultdict
nltk.download('punkt')



[nltk_data] Downloading package punkt to /Users/james/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [275]:
dist_cache = {}

def get_cache_dist_sum(distribution):
#     if id(distribution) not in dist_cache:
    dist_cache[id(distribution)] = sum(distribution.values())
    return dist_cache[id(distribution)]

def counter_prob(key,distribution:dict):
    return distribution[key]/get_cache_dist_sum(distribution)

def counter_prob_alt(key,distribution:dict,alt:dict):
    return distribution[key]/sum(alt.values())

def cond(a,b,joint,marginal):
    return joint[(a, b)] / marginal[a]

def condb(a,b,joint,marginal):
    return joint[(a, b)] / marginal[b]

def cpt(joint,marginal):
    probs = {}
    for a,b in joint.keys():
        probs[(a,b)] = cond(a,b,joint,marginal)
    return Counter(probs)

def cptb(joint,marginal):
    probs = {}
    for a,b in joint.keys():
        probs[(a,b)] = condb(a,b,joint,marginal)
    return Counter(probs)

def pmi(joint,marginal1,marginal2):
    pmis = {}
    print("Compute PMI for distribution of size {}".format(len(joint)))
    for a,b in tqdm(joint.keys()):
        m2 = counter_prob(b, marginal2)
        pmi = log((joint[(a, b)] / marginal1[a]) / m2, 2)
        pmis[(a, b)] = pmi
    return Counter(pmis)



def lmi(joint,marginal1,marginal2):
    lmis = {}
    print("Compute LMI for distribution of size {}".format(len(joint)))
    for a,b in tqdm(joint.keys()):
        j = counter_prob((a,b),joint)
        m2 = counter_prob(b,marginal2)

        lmi = j*log((joint[(a,b)]/marginal1[a])/m2,2)
        lmis[(a,b)] = lmi
    return Counter(lmis)

def strip_punct(text):
    return text.replace("."," ").replace(","," ").replace("?"," ").strip().lower()


def load_texts(file, is_question=False):
    with open(file) as f:
        for idx, line in enumerate(tqdm(f)):
            instance = json.loads(line)
            if is_question:
                texts = instance['source']['source_question']
                if isinstance(texts,str):
                    texts = [texts]
            else:
                texts = [instance['claim']]
                
#             
        
            labels = set(sec[1]['section_label'] for sec in instance['labels'].items() if sec[1]['section_label'] != "neutral")
            #labels = set(sec[1]['section_label'] for sec in instance['labels'].items())
            for text in texts:            
                yield text,labels
           
            


In [276]:
def compute_pmi_lmi(is_question):
    marginal_bigrams = Counter()
    marginal_labels = Counter()
    joint_bigrams_labels = Counter()
    bigram_counts = Counter()

    for text, labels in chain(
            load_texts("ncmace95/train.jsonl", is_question),
            load_texts("ncmace95/dev.jsonl", is_question),
            load_texts("ncmace95/test.jsonl", is_question)
        ):
        try:
            tokens = nltk.word_tokenize(strip_punct(text))
        except Exception as e:
            print(tokens)
            raise e
        bigrams = [" ".join(b) for b in nltk.bigrams(tokens)]

        bigram_counts.update(bigrams)

        for label in labels:
            marginal_bigrams.update(bigrams)
            marginal_labels.update([label])
            joint_bigrams_labels.update(zip(bigrams,[label]*len(bigrams)))

    bigrams_labels_pmi = pmi(joint_bigrams_labels, marginal_bigrams, marginal_labels)
    bigrams_labels_lmi = lmi(joint_bigrams_labels, marginal_bigrams, marginal_labels)
    
    label_mutations = cpt(joint_bigrams_labels,marginal_bigrams)
    label_mutations2 = cptb(joint_bigrams_labels,marginal_labels)

    prob_bg = []
    lmi_toks = [a[0][0] for a in bigrams_labels_lmi.most_common(5)]


    return bigrams_labels_pmi, bigrams_labels_lmi, label_mutations, label_mutations2


In [277]:
def normalize_distribution(dist):
    b = sum(dist)
    return [a/b for a in dist]

def entropy(dist_dict):
    ndist = normalize_distribution(dist_dict.values())
    return -sum(val*log(val,2) for val in ndist)

def entropy_bin(dist_dict):
    return entropy({k:v for k,v in dist_dict.items() if k!="neutral"})


def report_stats(ranking,maxrank=20):

    print("t1")
    t1 = []
    for (bigram,label), lmi in ranking.most_common(1000):
        if bigram_counts[bigram] < 50:
            continue
        if label == "false":
            t1.append(bigram)
            print(bigram.replace(" ","_") + "\t" + label + "\t" + str(lmi*1e6) + "\t" +  str(bigram_counts[bigram]))
            if len(t1) >= maxrank:
                break
    print()
    print()
    t2 = []
    print("t2")
    for (bigram,label), lmi in ranking.most_common(1000):
        if bigram_counts[bigram] < 50:
            continue
        if label == "true":
            t2.append(bigram)
            print(bigram.replace(" ","_") + "\t" + label + "\t" + str(lmi*1e6) + "\t" +  str(bigram_counts[bigram]))
            if len(t2) >= maxrank:
                break
    print()
    print()

    t3 = []
    print("t3")
    for (bigram,label), lmi in ranking.most_common(1000):
        if bigram_counts[bigram] < 50:
            continue
        if label == "neutral":
            t3.append(bigram)
            print(bigram.replace(" ","_") + "\t" + label + "\t" + str(lmi*1e6) + "\t" +  str(bigram_counts[bigram]))
            if len(t3) >= maxrank:
                break
    print()
    print()
    
    return t1,t2,t3
    
def report_dists(ranking,tt):
    
    table = defaultdict(dict)
    for bigram,prob in ranking.items():
        if bigram[0] in tt:
            prob_bg.append((bigram,prob))
            table[bigram[0]][bigram[1]] = prob

    for k in tt:
        v = table[k]
        print(k + "\t"+ str(entropy(v)) +"\t" + str(entropy_bin(v)) + "\t")
        print(str(v.get("true","")) +"\t" + str(v.get("false","")) + "\t" + str(v.get("neutral","")))
        print()


In [278]:
c_bigrams_labels_pmi, c_bigrams_labels_lmi, c_label_mutations, c_label_mutations2 = compute_pmi_lmi(False)

8723it [00:00, 8831.43it/s]
1088it [00:00, 8272.14it/s]
1086it [00:00, 8543.38it/s]
100%|██████████| 35187/35187 [00:00<00:00, 890267.44it/s]
  1%|          | 385/35187 [00:00<00:09, 3840.44it/s]

Compute PMI for distribution of size 35187
Compute LMI for distribution of size 35187


100%|██████████| 35187/35187 [00:09<00:00, 3897.90it/s]


In [287]:
c_t1,c_t2,c_t3 = report_stats(c_bigrams_labels_lmi)
print()
print()
report_dists(c_label_mutations,c_t1)
print()
report_dists(c_label_mutations,c_t2)
print()
report_dists(c_label_mutations,c_t3)

t1
the_same	false	3225.376375700174	1470
is_the	false	1823.3615081719472	930
same_as	false	1748.9058233550645	780
are_the	false	862.6139352246765	412
is_not	false	785.4844908847941	255
can_not	false	762.5608839768097	120
has_never	false	530.6452369946546	57
same_thing	false	445.4301576257112	167
not_the	false	403.988857605567	167
have_to	false	380.78345020125874	194
is_still	false	369.5805449208163	116
a_true	false	369.3116925702163	314
true_story	false	354.5491536851406	319
need_a	false	328.9350418897443	82
is_based	false	313.25637327311756	288
does_not	false	271.7239438644007	70
you_have	false	245.29568293110552	95
on_a	false	240.89065662049902	426
are_not	false	240.68588084399087	79
you_need	false	227.0012014710995	52


t2
there_is	true	1603.1574056717088	942
is_a	true	1206.8050620522959	1180
can_be	true	616.0656338854783	227
is_such	true	589.7651863632408	98
you_can	true	512.1406118236346	672
a_thing	true	458.227158865427	76
such_a	true	435.12574227494946	75
thing_as	true	350.39133

In [280]:
q_bigrams_labels_pmi, q_bigrams_labels_lmi, q_label_mutations, q_label_mutations2 = compute_pmi_lmi(True)

8723it [00:00, 8874.70it/s]
1088it [00:00, 8356.44it/s]
1086it [00:00, 8771.08it/s]
100%|██████████| 34946/34946 [00:00<00:00, 862891.18it/s]
  1%|          | 389/34946 [00:00<00:08, 3886.89it/s]

Compute PMI for distribution of size 34946
Compute LMI for distribution of size 34946


100%|██████████| 34946/34946 [00:08<00:00, 4035.33it/s]


In [281]:
q_t1,q_t2,q_t3 = report_stats(bigrams_labels_lmi)
#
print()
print()
report_dists(q_label_mutations,q_t1)
print()
report_dists(q_label_mutations,q_t2)
print()
report_dists(q_label_mutations,q_t3)

t1
the_same	false	1895.4476519071152	1470
same_as	false	1108.913713196593	780
have_to	false	410.5835091320015	194
is_a	false	357.0311540036411	1180
need_a	false	252.79684411709275	82
world_cup	false	208.68384302182722	293
same_thing	false	198.89512190601266	167
you_need	false	188.78594460465254	52
you_have	false	173.67915565969886	95
the_world	false	171.43390165879663	271
as_a	false	151.55775797787308	327
the_us	false	150.64148558868987	93
is_the	false	142.51065528538192	930
and_the	false	119.89361617245964	89
as_the	false	116.1585351065963	77
a_gun	false	102.03670690223984	73
to_be	false	99.63002798598757	292
xbox_360	false	81.66954074917683	61
xbox_one	false	68.55634661835924	66
is_an	false	53.09434204919919	88


t2
thing_as	true	601.239965071365	154
a_thing	true	354.82927305111144	76
such_a	true	337.521089728115	75
such_thing	true	249.93525647120123	50
world_cup	true	221.51911215059707	293
have_a	true	212.99739697248975	177
to_the	true	180.11929192414252	196
be_a	true	151.8783103672

In [282]:
list(bigrams_labels_lmi.keys())[:10]
print(bigrams_labels_lmi[('is a','true')],bigrams_labels_lmi[('is a','false')],bigrams_labels_lmi[('is a','neutral')])

-0.00021622788923510822 0.0003570311540036411 -9.815938898911844e-05


In [283]:
all_bigram_counts = Counter()

_, c_lm50,c_ct50,_ = compute_pmi_lmi(False)
all_bigram_counts.update(bigram_counts)
_, q_lm50,q_ct50,_ = compute_pmi_lmi(True)
all_bigram_counts.update(bigram_counts)

8723it [00:00, 8913.93it/s]
1088it [00:00, 8547.89it/s]
1086it [00:00, 8799.36it/s]
100%|██████████| 35187/35187 [00:00<00:00, 825267.01it/s]
  1%|          | 388/35187 [00:00<00:08, 3878.29it/s]

Compute PMI for distribution of size 35187
Compute LMI for distribution of size 35187


100%|██████████| 35187/35187 [00:08<00:00, 4006.26it/s]
8723it [00:00, 8875.88it/s]
1088it [00:00, 8991.40it/s]
1086it [00:00, 8800.87it/s]
100%|██████████| 34946/34946 [00:00<00:00, 870341.12it/s]
  1%|          | 397/34946 [00:00<00:08, 3968.25it/s]

Compute PMI for distribution of size 34946
Compute LMI for distribution of size 34946


100%|██████████| 34946/34946 [00:08<00:00, 4042.91it/s]


In [284]:
print(len(d1))
print(d1[:10])
print(d2[:10])

5236
[1.3921472236645345, 1.5001542009939985, 1.5588718484453605, 0.9709505944546686, -0.0, 1.3566695198333112, 1.5709505944546684, 1.5304930567574826, 1.4591479170272446, 1.5052408149441479]
[999, 1.5001542009939985, 1.5219280948873621, 0.9709505944546686, -0.0, 1.3566695198333112, 1.5709505944546684, 999, 999, 999]


In [285]:
all_keys = set(c_lm50.keys()).union(q_lm50.keys())

d1 = []
d2 = []

for k in all_keys:
    bigram,label = k
    if all_bigram_counts[bigram] < 5:
        continue
    d1.append(c_lm50[k])
    d2.append(q_lm50[k])
    
stats.spearmanr(d1,d2)

SpearmanrResult(correlation=0.8483446923108221, pvalue=0.0)

In [286]:
all_keys = set(c_ct50.keys()).union(q_ct50.keys())

d1 = []
d2 = []

table_c = defaultdict(dict)
for bigram,prob in c_ct50.items():
    table_c[bigram[0]][bigram[1]] = prob

table_q = defaultdict(dict)
for bigram,prob in q_ct50.items():
    table_q[bigram[0]][bigram[1]] = prob

all_keys = set(table_c.keys()).union(table_q.keys())

for k in all_keys:
    if all_bigram_counts[k] <5:
        continue
        
    v = table_c.get(k,{"true":0,"false":0,"neutral":0})
    if sum(v.values()) > 0:
        d1.append(entropy(v))
    else:
        d1.append(999)
    
    v = table_q.get(k,{"true":0,"false":0,"neutral":0})
    if sum(v.values()) > 0:
        d2.append(entropy(v))
    else:
        d2.append(999)

stats.spearmanr(d1,d2)

SpearmanrResult(correlation=0.6615945583719123, pvalue=0.0)