In [1]:
import pickle
from tqdm import tqdm
import numpy as np
import numpy.random as rand


import snorkel
from snorkel.augmentation import transformation_function
from snorkel.augmentation import RandomPolicy, MeanFieldPolicy
from snorkel.augmentation import TFApplier

In [None]:
# Load the pickled e-mail objects and keep only each one's `senders` attribute.
# SECURITY: unpickling executes arbitrary code; only use this on a trusted file.
with open("../../../w3c-emails/emails.pkl", "rb") as handle:
    senders = [email.senders for email in pickle.Unpickler(handle).load()]

In [None]:
from transformers import DistilBertTokenizer

In [2]:
# Load the pickled token ids, one entry per e-mail (downstream cells index
# e-mails by position in this object).
# SECURITY: unpickling executes arbitrary code; only use this on a trusted file.
with open("emails_token_ids.pkl", "rb") as handle:
    email_ids = pickle.Unpickler(handle).load()

In [None]:
# Positional index of every entry in email_ids: 0 .. len(email_ids) - 1.
inds = [*range(len(email_ids))]

In [None]:
# Group e-mail indices by sender: by_author[sender] -> list of indices.
# NOTE(review): this pairs senders[i] with inds[i], i.e. it assumes the two
# pickles list the e-mails in the same order -- confirm.
by_author = {}

# zip() has no len(), so tqdm needs the total passed explicitly to show
# progress and an ETA. zip stops at the shorter input, hence min().
for sender, email_idx in tqdm(
    zip(senders, inds), total=min(len(senders), len(inds))
):
    by_author.setdefault(sender, []).append(email_idx)

# Snorkel

In [None]:
@transformation_function()
def positive_example(x):
    """Pair an e-mail with a randomly chosen e-mail by the same author.

    Args:
        x: ``(e_id, auth_label)`` -- an e-mail index and its author key
            in ``by_author``.

    Returns:
        ``((e_id, partner_id), 1)`` where ``partner_id`` is drawn uniformly
        from ``by_author[auth_label]`` and ``1`` marks a same-author pair.

    Raises:
        KeyError: if ``auth_label`` is not a key of ``by_author``.
    """
    e_id, auth_label = x
    auth_mails = by_author[auth_label]
    # NOTE(review): the partner is drawn from all of the author's e-mails,
    # so it can be e_id itself -- confirm that self-pairs are acceptable.
    return (e_id, rand.choice(auth_mails)), 1


@transformation_function()
def negative_example(x):
    """Pair an e-mail with a randomly chosen e-mail by a different author.

    Args:
        x: ``(e_id, auth_label)`` -- an e-mail index and its author key
            in ``by_author``.

    Returns:
        ``((e_id, partner_id), 0)`` where ``partner_id`` belongs to a
        uniformly chosen author other than ``auth_label`` and ``0`` marks a
        different-author pair.

    Raises:
        ValueError: if ``by_author`` holds no author other than
            ``auth_label``.
    """
    e_id, auth_label = x
    # Walk the dict in insertion order instead of building a set difference:
    # set iteration order for str keys varies between interpreter runs
    # (hash randomisation), which would make draws irreproducible even with
    # a seeded RNG.
    other_authors = [author for author in by_author if author != auth_label]
    if not other_authors:
        raise ValueError("negative_example needs at least two distinct authors")
    # Draw an index rather than calling rand.choice on the keys, so numpy
    # never coerces the keys into an array and the key is used as stored.
    other_label = other_authors[rand.randint(len(other_authors))]
    return (e_id, rand.choice(by_author[other_label])), 0

# Transformation functions handed to the augmentation policies and applier.
# Keep the order stable: index 0 is the positive TF, index 1 the negative TF.
tfs = [
    positive_example,
    negative_example,
]

In [None]:
# Policy that picks one TF per generated sample, two samples per original,
# dropping the original data point.
# NOTE(review): random_policy is not used in the cells visible here; only
# mf_policy is passed to the applier below.
random_policy = RandomPolicy(
    len(tfs), sequence_length=1, n_per_original=2, keep_original=False
)

# Same shape as above, but TF indices are sampled with probabilities `p`,
# given in the order of `tfs`: 0.1 for positive_example, 0.9 for
# negative_example.
mf_policy = MeanFieldPolicy(
    len(tfs), sequence_length=1, p=[0.1, 0.9], n_per_original=2, keep_original=False
)

tf_applier = TFApplier(tfs, mf_policy)
# Materialise the (index, sender) pairs: the applier is meant to receive a
# sized collection, and a one-shot zip iterator has no length (no progress
# total) and cannot be iterated twice.
samples = tf_applier.apply(list(zip(inds, senders)))

# Split Data: Train, Eval, Test

In [None]:
# 60/40 split. `ratio` is the index of the first test sample.
ratio = int(len(samples) * 0.6)
# Shuffle positions rather than the samples themselves. Each sample is a
# ragged ((id, id), label) tuple, and rand.permutation(samples) would first
# convert the list to an array -- an error on recent NumPy, and an object
# array that splits the tuples apart on older versions.
# NOTE(review): no RNG seed is set in the visible cells, so the split differs
# between runs; seed numpy before this cell if reproducibility matters.
permuted = [samples[i] for i in rand.permutation(len(samples))]
train, test = permuted[:ratio], permuted[ratio:]

# Experimental Snorkel Stuff

In [None]:
from snorkel.labeling import LabelModel

# Build a label model with two classes.
label_model = LabelModel(
    cardinality=2,
    verbose=True,
)
# NOTE(review): L_train is not assigned in any cell visible here; this call
# fails with NameError on a fresh kernel unless it is defined elsewhere.
label_model.fit(
    L_train=L_train,
    n_epochs=500,
    log_freq=100,
    seed=123,
)