In [None]:
import os
import pickle
from tqdm import tqdm
import numpy as np
import numpy.random as rand

from collections import Counter
import matplotlib.pyplot as plt

import pandas as pd

import snorkel
from snorkel.augmentation import transformation_function
from snorkel.augmentation import RandomPolicy, MeanFieldPolicy
from snorkel.augmentation import TFApplier, PandasTFApplier

from snorkel.labeling import labeling_function
from snorkel.labeling import PandasLFApplier


# Load Data

In [None]:
# Load the pickled W3C email objects (an older relative location was
# "../../../w3c-emails/emails.pkl").
# NOTE(review): pickle.load can execute arbitrary code stored in the file — only
# load this file if its provenance is trusted.
with open("../../../../w3c-emails/emails.pkl", "rb") as emails_file:
    emails = list(pickle.load(emails_file))

# Tag every email with its position in `emails`; this `index` attribute is what the
# pair DataFrame and the saved index matrix refer to.
for position, email in enumerate(emails):
    email.index = position


In [None]:
# Token-id data for the emails, presumably precomputed elsewhere
# (full-size variant of this file: "emails_token_ids.pkl").
# NOTE(review): `email_ids` is not read by any later cell shown here.
with open("1000_emails_token_ids.pkl", "rb") as token_ids_file:
    email_ids = pickle.load(token_ids_file)

# Put emails into a DataFrame $\rightarrow$ as pairwise combinations

In [None]:
# Email attributes copied into every pair row; the second email's copies are stored
# under the same name with a "_2" suffix.
names = ("index", "id", "sender", "body_raw")

def by_author(emails):
    """Group emails by their ``sender`` attribute.

    Returns a dict mapping each sender to the list of that sender's emails, in the
    order they appear in ``emails``; senders are keyed in first-seen order.
    """
    grouped = {}
    for email in emails:
        grouped.setdefault(email.sender, []).append(email)
    return grouped

# sender -> list of that sender's emails, for the whole corpus
by_auths = by_author(emails=emails)

        
def objs_to_frame(emails, min_size=1, rng=None, fields=None):
    """Build a DataFrame of same-sender email pairs from ``emails``.

    For every sender in ``emails`` with at least ``min_size`` emails, each of that
    sender's emails is paired with the email at the same position in a random
    permutation of the sender's emails (so an email may be paired with itself).
    Each row holds ``str`` versions of the first email's ``fields`` attributes plus
    the second email's under ``<field>_2``.

    Parameters
    ----------
    emails : iterable
        Email objects exposing ``sender`` and every attribute named in ``fields``.
    min_size : int
        Senders with fewer emails than this are skipped.
    rng : numpy ``Generator``/``RandomState`` or None
        Source of the permutations; ``None`` uses the global ``numpy.random`` state.
    fields : sequence of str or None
        Attribute names to copy; ``None`` uses the module-level ``names``.

    Returns
    -------
    pandas.DataFrame
        One row per paired email (empty, column-less frame if nothing qualifies).
    """
    if fields is None:
        fields = names
    shuffler = rand if rng is None else rng

    rows = []
    # Group the `emails` argument itself (previously the global `by_auths` was read,
    # so the argument was silently ignored).
    for sender_emails in by_author(emails).values():
        if len(sender_emails) < min_size:
            continue
        # Permute positions rather than the email objects, so numpy never tries to
        # build an array out of the objects themselves.
        order = shuffler.permutation(len(sender_emails))
        for first, pos in zip(sender_emails, order):
            second = sender_emails[pos]
            row = {n: str(getattr(first, n)) for n in fields}
            row.update({n + "_2": str(getattr(second, n)) for n in fields})
            rows.append(row)
    return pd.DataFrame(rows)

# Skip senders with a single email: they could only ever be paired with themselves.
mail_frame = objs_to_frame(emails, min_size=2)

# Define Transformation Functions & Apply

In [None]:
@transformation_function()
def do_nothing(row):
    """Identity transformation: hand the row back untouched, so the pair keeps the
    same-sender partner it was built with."""
    return row

# Every sender, in first-seen order — the candidate pool for drawing "other" senders.
all_auths = list(by_auths)

def sample_other_author(auth_label, num_tries=0, max_tries=5, authors=None):
    """Draw a sender uniformly at random, preferring one different from ``auth_label``.

    Rejection sampling: keeps redrawing while the drawn sender equals ``auth_label``,
    for at most ``max_tries`` redraws. If every attempt hits ``auth_label`` the last
    draw is returned anyway, so the result CAN equal ``auth_label`` (for example when
    it is the only candidate).

    Parameters
    ----------
    auth_label :
        The sender to avoid.
    num_tries : int
        Attempt counter to start from (kept for backward compatibility with the
        former recursive implementation, which passed it to itself).
    max_tries : int
        Counter value at which the current draw is accepted unconditionally.
    authors : sequence or None
        Candidate senders; ``None`` uses the module-level ``all_auths``.

    Side effects
    ------------
    Appends the attempt counter of every draw to ``sample_other_author.tries_collected``.

    Raises
    ------
    ValueError
        If the candidate pool is empty.
    """
    pool = all_auths if authors is None else authors
    if len(pool) == 0:
        raise ValueError("sample_other_author: no candidate authors to sample from")

    while True:
        sample_other_author.tries_collected.append(num_tries)
        drawn = pool[rand.choice(len(pool))]
        if num_tries >= max_tries or drawn != auth_label:
            return drawn
        num_tries += 1


sample_other_author.tries_collected = []


@transformation_function()
def do_noise(row):
    """Replace the row's second email with a random email from (usually) another sender.

    The replacement sender comes from ``sample_other_author``, which can fall back to
    the row's own sender, so the result is not guaranteed to be a different-sender
    pair. The replacement's attributes are written into the ``<name>_2`` columns as
    strings, matching how ``objs_to_frame`` fills those columns.
    """
    other_label = sample_other_author(row["sender"])
    candidates = by_auths[other_label]
    # Draw an index (as sample_other_author does) instead of handing the email
    # objects to numpy, which would first convert them into an array.
    other_email = candidates[rand.choice(len(candidates))]

    # NOTE(review): mutates `row` in place, as before; this relies on the TF applier
    # passing in a copy of the data point — confirm against snorkel's mapper docs.
    for n in names:
        # str(...) keeps every "<name>_2" column the same type as objs_to_frame
        # produces; otherwise noised and untouched rows would mix e.g. ints and strings.
        row[n + "_2"] = str(getattr(other_email, n))

    return row

# The policy's per-TF probabilities (`p` below) are indexed by position in this list.
tfs = [do_nothing, do_noise]

In [None]:
# Each original pair row yields 2 augmented rows; each augmented row gets a single TF,
# do_nothing with probability 0.3 or do_noise with probability 0.7. The original rows
# themselves are not kept.
mf_policy = MeanFieldPolicy(
    len(tfs),
    sequence_length=1,
    p=[0.3, 0.7],
    n_per_original=2,
    keep_original=False,
)

tf_applier = PandasTFApplier(tfs, mf_policy)
samples = tf_applier.apply(mail_frame)

In [None]:
# Fraction of augmented rows whose two emails share a sender.
int((samples.sender == samples.sender_2).sum()) / len(samples)

# Label Rows

In [None]:
@labeling_function()
def label(row):
    """Label a pair row: 1 when both emails come from the same sender, else 0.

    Returns an explicit int. Snorkel stores LF outputs in an integer label matrix
    (-1 meaning "abstain"), so the bool result of ``==`` would otherwise only work
    through implicit bool -> int coercion.
    """
    return int(row.sender == row.sender_2)

In [None]:
lf_applier = PandasLFApplier([label])
labelled = lf_applier.apply(samples)

# apply() returns a 2-D label matrix with one column per LF; take the single LF's
# column so `is_same` is assigned a 1-D array instead of an (n, 1) matrix.
samples["is_same"] = labelled[:, 0]

# Inspect Results

In [None]:
# Fraction of rows pairing an email with itself (same id on both sides).
int((samples["id"] == samples["id_2"]).sum()) / len(samples)

In [None]:
# Fraction of rows that exactly repeat an earlier row.
1 - len(samples.drop_duplicates()) / len(samples)

In [None]:
# Shape of the (index, index_2, is_same) columns.
samples.loc[:, ["index", "index_2", "is_same"]].shape

# Save everything

In [None]:
# Full augmented + labelled table; gzip compression is inferred from the ".gz" suffix.
samples.to_csv("samples.csv.gz", compression="infer")

In [None]:
# Integer matrix with one row per pair: [index, index_2, is_same].
sample_matrix = (
    samples[["index", "index_2", "is_same"]]
    .to_numpy()
    .astype("int")
)

In [None]:
# Persist the pair/label matrix (presumably split into train and test sets later).
with open("train_and_test_inds.pkl", "wb") as out_file:
    pickle.dump(sample_matrix, out_file)

In [None]:
# Inputs are (email index, paired email index); targets are the same-sender flags.
xs, ys = [], []
for first_idx, second_idx, same in sample_matrix:
    xs.append((first_idx, second_idx))
    ys.append(same)