In [1]:
from collections import defaultdict
from pickle import load
from typing import Iterable, List, TypeVar

import numpy as np
import pandas as pd
from convokit import Corpus, Utterance

In [2]:
from final import *

%load_ext autoreload
%autoreload 2

In [3]:
# Seed for the single notebook-level numpy Generator used for resampling.
RANDOM_SEED: int = 375
# Number of bootstrap resamples of the corpus. Each token gets at most this
# many samples, so empirical p-values are coarse (multiples of 0.1 when all
# 10 draws are present); raise it for finer resolution at the cost of one
# table rebuild per draw.
BOOTSTRAP_SIZE: int = 10
# Key of the per-token statistic that is bootstrapped and tested.
STATISTIC: str = "PPMI"

In [4]:
def get_table_from_corpus(corpus: Corpus) -> defaultdict:
    """Build the per-token statistics table for ``corpus``.

    Composes ``get_counts`` and ``get_table`` (presumably provided by the
    ``final`` module via its star import): the corpus is first reduced to
    counts, and the counts are then turned into the table.
    """
    return get_table(get_counts(corpus))


def bootstrap_corpus(rng: np.random.Generator, corpus: Corpus) -> Corpus:
    """Return one bootstrap resample of ``corpus``.

    Draws as many utterances as ``corpus`` contains, uniformly at random and
    with replacement, using ``rng``, and wraps the draw in a new ``Corpus``.

    NOTE(review): repeated draws reuse the very same ``Utterance`` objects
    (same ids). If ``Corpus`` stores utterances keyed by id, duplicates may
    collapse into a single entry. Constructing the new corpus may also
    re-associate the shared utterances with it. Confirm against convokit's
    ``Corpus`` constructor before trusting the resample size.
    """
    population = list(corpus.iter_utterances())
    # rng.choice accepts the plain list of Utterance objects directly.
    resampled = rng.choice(population, size=len(population), replace=True)
    return Corpus(utterances=resampled)

In [5]:
# Corpus returned by `get_corpus` (presumably the full dataset); every
# bootstrap resample is drawn from it.
ground_truth_corpus = get_corpus()
# Single seeded Generator shared by later cells: `bootstrap` advances it on
# every call. Re-running the bootstrap cell without re-running this one
# therefore draws different resamples; Restart & Run All to reproduce.
rng = np.random.default_rng(RANDOM_SEED)

In [19]:
# Observed per-token statistics on the un-resampled corpus; the bootstrap
# distributions are later compared against these values.
ground_truth_table = get_table_from_corpus(ground_truth_corpus)
# Keys view of the table's vocabulary, used to pre-seed one entry per token
# in `bootstrap` so tokens missing from a resample still appear in its result.
ground_truth_tokens = ground_truth_table.keys()

In [20]:
def bootstrap(tokens: Iterable[str], statistic: str, size: int = BOOTSTRAP_SIZE) -> defaultdict:
    """Collect the bootstrap distribution of ``statistic`` for every token.

    Draws ``size`` resamples of the notebook-level ``ground_truth_corpus``
    (via ``bootstrap_corpus`` and the notebook-level ``rng``). For each
    resample it rebuilds the token table and records every token's
    ``statistic`` value.

    Args:
        tokens: Tokens to pre-seed in the result, so each of them has an entry
            even if no resample's table contains it.
        statistic: Key read from each token's table entry (e.g. ``"PPMI"``).
        size: Number of bootstrap resamples to draw. Defaults to
            ``BOOTSTRAP_SIZE``.

    Returns:
        ``defaultdict(list)`` mapping token -> list of ``statistic`` values,
        one per resample whose table contains the token. Values may be
        ``None`` where the table holds no value. If a resample's table omits a
        token (presumably because it does not occur in that resample; TODO
        confirm), that token ends up with fewer than ``size`` values, possibly
        none.
    """
    results = defaultdict(list, {token: [] for token in tokens})
    # Honour `size`: the loop previously ran BOOTSTRAP_SIZE times and
    # silently ignored the argument.
    for _ in range(size):
        corpus = bootstrap_corpus(rng, ground_truth_corpus)
        table = get_table_from_corpus(corpus)
        for token in table:
            results[token].append(table[token][statistic])
    return results

# def bootstrap(statistic: str, size: int = BOOTSTRAP_SIZE) -> np.array:
#     results = []
#     for _ in range(BOOTSTRAP_SIZE):
#         corpus = bootstrap_corpus(rng, ground_truth_corpus)
#         table = get_table_from_corpus(corpus)
#         column = np.asarray(table[statistic].array)
#         results.append(column)
#     # Shape: (v * b), where v is size of vocabulary and b is bootstrap iterations
#     return results

In [21]:
# Bootstrap distribution of STATISTIC for every ground-truth token; the
# resample count is passed explicitly so the config cell controls it.
results = bootstrap(ground_truth_tokens, STATISTIC, size=BOOTSTRAP_SIZE)

In [22]:
# Quick sanity check: show one token's samples, report the first token (if
# any) with a None sample, then show the samples for "infringe".
print(results["court"])

for token, samples in results.items():
    if None in samples:
        print(token)
        break

print(results["infringe"])

[0.2037752275409371, 0.17548793088627715, 0.22923094936937127, 0.17190794071575805, 0.14912721562696998, 0.21555508104626742, 0.1959674047697289, 0.12556064918270576, 0.17826097326511603, 0.22722256328746163]
infringe
[None, 0, None, 0, 0, 0, 0, 0, 0, 0]


In [23]:
# Keep only tokens with a usable bootstrap distribution. Drop any token for
# which some resample produced a None value for the statistic (it is
# undefined there). Also drop any token with no samples at all (never present
# in a resample table), whose empirical p-value would be undefined.
results = {
    token: samples
    for token, samples in results.items()
    if len(samples) > 0 and None not in samples
}

In [14]:
def get_p_values(results: defaultdict, ground_truth: defaultdict, statistic: str) -> defaultdict:
    """Compute an empirical p-value per token with ``significance_test``.

    Args:
        results: Mapping token -> list of bootstrap values of ``statistic``.
        ground_truth: Mapping token -> entry indexable by ``statistic``,
            holding the observed value.
        statistic: Key of the statistic to test (e.g. ``"PPMI"``).

    Returns:
        ``defaultdict(float)`` mapping token -> p-value. A token is skipped
        when its observed value is ``None`` or it has no bootstrap samples.
        Reading a skipped token from the returned defaultdict yields 0.0,
        which is *not* a real p-value, so check membership first.

    NOTE(review): in this notebook the resamples come from the same corpus as
    the observed values. The result therefore locates the observed value
    within its own bootstrap distribution rather than testing against a null
    hypothesis. Confirm this is the intended test.
    """
    p_values = defaultdict(float)
    for token, samples in results.items():
        observed = ground_truth[token][statistic]
        # No observed value or no bootstrap draws: the p-value is undefined.
        if observed is None or len(samples) == 0:
            continue
        p_values[token] = significance_test(samples, observed)
    return p_values


Comparable = TypeVar('Comparable')  # any type whose values support `>=`


def significance_test(results: List[Comparable], ground_truth: Comparable) -> float:
    """Return the fraction of bootstrap samples that are >= ``ground_truth``.

    Args:
        results: Bootstrap samples of the statistic for one token.
        ground_truth: Observed value of the statistic.

    Returns:
        Upper-tail proportion in [0, 1], or ``nan`` when ``results`` is empty.
    """
    if len(results) == 0:
        # Avoid ZeroDivisionError: with no samples the proportion is undefined.
        return float("nan")
    return sum(1 for sample in results if sample >= ground_truth) / len(results)

# empirical_pvalues = np.divide(results, ground_truth)
# empirical_pvalues = np.apply_along_axis(
#     lambda a: sum(1 if x > 1 else 0 for x in a) / len(a), 0, empirical_pvalues
# )
# # Shape: (v * 1), where v is the size of the vocabulary
# return empirical_pvalues

In [24]:
# Empirical p-value per surviving token, against the observed statistics.
p_values = get_p_values(results=results, ground_truth=ground_truth_table, statistic=STATISTIC)

In [25]:
# Show the per-token p-values, then how many tokens received one.
# NOTE(review): an earlier note said this count grows with the number of
# bootstrap iterations. More iterations give more chances for a None sample,
# which the filter cell removes, so the count should shrink or stay flat.
# Confirm.
print(p_values, len(p_values), sep="\n")

defaultdict(<class 'float'>, {'mr.': 0.2, 'chief': 1.0, 'justice': 1.0, ',': 0.2, 'and': 0.7, 'may': 0.4, 'it': 1.0, 'please': 0.7, 'the': 1.0, 'court': 0.4, ':': 0.8, 'when': 0.8, 'state': 1.0, 'exclusive': 0.7, 'federal': 1.0, 'right': 0.5, 'that': 0.3, 'congress': 1.0, 'is': 1.0, 'charged': 0.3, 'with': 1.0, 'can': 0.6, 'make': 1.0, 'pay': 1.0, 'for': 0.6, 'doing': 1.0, 'so': 0.2, '.': 0.1, "'s": 0.5, 'our': 1.0, 'today': 1.0, 'one': 1.0, 'follows': 1.0, 'from': 1.0, 'constitution': 1.0, 'text': 1.0, 'basis': 0.5, 'this': 0.4, 'to': 1.0, 'work': 1.0, 'did': 0.4, 'in': 1.0, 'enacting': 0.4, 'article': 0.7, 'i': 1.0, 'section': 0.7, '<': 1.0, 'number': 1.0, '>': 1.0, 'clause': 1.0, 'what': 1.0, 'we': 0.6, "'re": 1.0, 'calling': 1.0, 'property': 0.7, 'unique': 0.4, 'within': 1.0, 'down': 1.0, 'an': 0.4, 'express': 0.5, 'constitutional': 0.2, 'mandate': 1.0, 'protect': 0.3, 'private': 1.0, 'against': 0.6, 'any': 0.6, 'all': 1.0, 'intrusion': 0.5, 'consider': 0.2, 'just': 0.2, 'how': 0.2