In [1]:
from typing import List, TypeVar

import pandas as pd
import numpy as np
from convokit import Corpus, Utterance
from pickle import load

In [2]:
from final import *

%load_ext autoreload
%autoreload 2

In [3]:
RANDOM_SEED = 375
BOOTSTRAP_SIZE = 10
STATISTIC = "PPMI"

In [4]:
def get_table_from_corpus(corpus: Corpus) -> defaultdict:
    counts = get_counts(corpus)
    table = get_table(counts)
    return table


def bootstrap_corpus(rng: np.random.Generator, corpus: Corpus) -> Corpus:
    # Miraculously no type casting needed! :)
    utts = [utt for utt in corpus.iter_utterances()]
    chosen_utts = rng.choice(utts, size=len(utts), replace=True)
    new_corpus = Corpus(utterances=chosen_utts)
    return new_corpus

In [5]:
ground_truth_corpus = get_corpus()
rng = np.random.default_rng(RANDOM_SEED)

In [6]:
# Testing bootstrap_corpus
bootstrap_corpus(rng, ground_truth_corpus)

<convokit.model.corpus.Corpus at 0x7f0ba0195d60>

In [7]:
ground_truth_table = get_table_from_corpus(ground_truth_corpus)
ground_truth_tokens = ground_truth_table.keys()

In [8]:
def bootstrap(tokens: np.array, statistic: str, size: int = BOOTSTRAP_SIZE) -> defaultdict:
    results = defaultdict(list, {k:[] for k in tokens})
    for _ in range(BOOTSTRAP_SIZE):
        corpus = bootstrap_corpus(rng, ground_truth_corpus)
        table = get_table_from_corpus(corpus)
        for index in table:
            results[index].append(table[index][statistic])
    return results

# def bootstrap(statistic: str, size: int = BOOTSTRAP_SIZE) -> np.array:
#     results = []
#     for _ in range(BOOTSTRAP_SIZE):
#         corpus = bootstrap_corpus(rng, ground_truth_corpus)
#         table = get_table_from_corpus(corpus)
#         column = np.asarray(table[statistic].array)
#         results.append(column)
#     # Shape: (v * b), where v is size of vocabulary and b is bootstrap iterations
#     return results

In [9]:
results = bootstrap(tokens=ground_truth_tokens, statistic=STATISTIC)

In [10]:
# Quick test
print(results["court"])

for token in results:
    if None in results[token]:
        print(token)
        break

print(results["infringe"])

[0.17548793088627715, 0.22923094936937127, 0.17190794071575805, 0.14912721562696998, 0.21555508104626742, 0.1959674047697289, 0.12556064918270576, 0.17826097326511603, 0.22722256328746163, 0.18309177283104977]
infringe
[0, None, 0, 0, 0, 0, 0, 0, 0, None]


In [11]:
# We quickly filter out any token where a bootstrap iteration leads to a NaN value for PPMI
# Taking advantage of falsiness of None
results = {token: results[token] for token in results if all(results[token])}

In [14]:
def get_p_values(results: defaultdict, ground_truth: defaultdict, statistic: str) -> defaultdict:
    p_values = defaultdict(float)
    for token in results:
        if ground_truth[token][statistic] == None:
            continue

        p = significance_test(results[token], ground_truth[token][statistic])
        p_values[token] = p
    return p_values

Comparable = TypeVar('Comparable')
def significance_test(results: List[Comparable], ground_truth: Comparable) -> float:
    return sum(1 if x >= ground_truth else 0 for x in results) / len(results)

# empirical_pvalues = np.divide(results, ground_truth)
# empirical_pvalues = np.apply_along_axis(
#     lambda a: sum(1 if x > 1 else 0 for x in a) / len(a), 0, empirical_pvalues
# )
# # Shape: (v * 1), where v is the size of the vocabulary
# return empirical_pvalues

In [13]:
p_values = get_p_values(results, ground_truth_table, STATISTIC)

In [15]:
print(p_values)

defaultdict(<class 'float'>, {'and': 0.6, 'may': 0.4, 'court': 0.3, 'when': 0.7, 'that': 0.4, 'so': 0.3, '.': 0.1, "'s": 0.5, 'basis': 0.6, 'enacting': 0.4, 'article': 0.7, 'section': 0.7, 'property': 0.6, 'unique': 0.5, 'constitutional': 0.2, 'consider': 0.2, 'just': 0.2, 'how': 0.3, 'clear': 0.5, 'not': 0.4, 'be': 0.4, 'your': 0.4, 'honor': 0.2, '?': 0.8, 'exercise': 0.3, 'thus': 0.4, 'wanting': 0.5, 'plan': 0.5, 'no': 0.5, 'immunity': 0.4, 'ha': 0.4, 'already': 0.2, 'patent': 0.7, 'talking': 0.5, 'could': 0.6, 'had': 0.4, 'reasoning': 0.5, 'again': 0.9, 'correct': 0.4, 'wa': 0.3, 'fact': 0.4, 'whether': 0.3, 'valid': 0.4, 'held': 0.4, 'erroneous': 0.6, 'bankruptcy': 0.7, 'find': 0.5, 'methodology': 0.8, 'relevant': 0.5, 'majority': 0.5, 'case': 0.2, 'sentence': 0.3, 'he': 0.4, 'noted': 0.7, 'term': 0.6, 'yes': 0.4, 'essentially': 0.5, 'certainly': 0.4, 'where': 0.7, 'precedential': 0.2, 'undermined': 0.3, 'friendly': 0.4, 'because': 0.3, 'been': 0.5, 'decide': 0.2, 'really': 0.6, 'a