# Precalculate generalized 'word' vocab and context vocab

Each triple represents two 'words' and a dependency edge between them.<br/>
As per https://aclanthology.org/P14-2050.pdf, for each word of a triple we make <br/>
a context 'word' by concatenating the other word of the triple with the edge label, <br/>
marking the reverse direction of the edge with a '-1' suffix.<br/>

For example:<br/>

&emsp;*"Alice threw the ball."*<br/><br/>
would yield the triples:<br/>

&emsp;*(throw)-[nsubj]->(Alice)*<br/>
&emsp;*(throw)-[dobj]->(ball)*<br/><br/>
From which we want to get the (word, context) pairs:

&emsp;throw, Alice/nsubj<br/>
&emsp;throw, ball/dobj<br/>
&emsp;Alice, throw/nsubj-1<br/>
&emsp;ball, throw/dobj-1<br/>

Then we can construct a word vocabulary and a context vocabulary, each constrained<br/>
to the words (or contexts) that appear at least K times. Because the two vocabularies<br/>
are counted and thresholded independently, their sizes will generally differ.

### Imports

In [None]:
from pathlib import Path
from collections import Counter
import sys
import os

from tqdm.notebook import tqdm
import pandas as pd
import numpy as np
import torch

### Config

In [None]:
# Root directory for all pipeline artifacts. default: {project}/data
data_path = Path('..').resolve() / 'data'

# Name of the input triples folder under data/triples/
dataset = 'wikipedia_20220101'

# Minimum number of occurrences a word or context needs to enter its vocabulary
K = 100

### Load
In order to circumvent memory limits on a single machine, the word-context<br/>
pair dataset has to be constructed as pairs of mapped integers. This makes<br/>
it necessary to perform a first pass over the triples files to compute the<br/>
(un-pruned) vocabularies.<br/>

In [None]:
# First pass: count every word and every context over all triples files.
# TODO map-reduce this
wcounts = Counter()
ccounts = Counter()
# sorted() makes the file order deterministic. The counts themselves do not
# depend on it, but Counter.most_common() breaks ties by insertion order, so
# the id assigned to equally-frequent words/contexts does.
triples_files = sorted(data_path.joinpath('triples', dataset).glob('*.df'))

for fp in tqdm(triples_files):
    triples = pd.read_parquet(fp, engine='fastparquet')
    srcs = triples['src'].tolist()
    edges = triples['path'].tolist()
    dsts = triples['dst'].tolist()

    # Counter.update(iterable) counts the elements in C, which is much faster
    # than incrementing one key at a time from a Python loop.
    wcounts.update(srcs)
    wcounts.update(dsts)

    # Context of src is "dst/edge"; context of dst is "src/edge-1" (reverse
    # direction). Generators avoid materializing every context string at once.
    ccounts.update(d + '/' + e for d, e in zip(dsts, edges))
    ccounts.update(s + '/' + e + '-1' for s, e in zip(srcs, edges))


In [None]:
# Word vocabulary: each word seen at least K times maps to its rank in
# descending-frequency order (most frequent word -> 0). most_common() is
# sorted by count, so the first entry below K ends the vocabulary.
wvocab = {}
for rank, (w, n) in enumerate(wcounts.most_common()):
    if n >= K:
        wvocab[w] = rank
    else:
        break

In [None]:
# Context vocabulary: each context seen at least K times maps to its rank in
# descending-frequency order (most frequent context -> 0). most_common() is
# sorted by count, so the first entry below K ends the vocabulary.
cvocab = {}
for rank, (c, n) in enumerate(ccounts.most_common()):
    if n >= K:
        cvocab[c] = rank
    else:
        break

In [None]:
# Upper bound on the number of (word, context) pairs the second pass can emit.
# Every emitted pair corresponds to exactly one occurrence counted in
# wcounts[word] (with word in wvocab) AND one counted in ccounts[context]
# (with context in cvocab). So the pair count cannot exceed EITHER in-vocab
# total. The tighter bound (min) is therefore safe, and it allocates less
# memory than max.
max_table_size = min(
    sum(v for v in wcounts.values() if v >= K),
    sum(v for v in ccounts.values() if v >= K),
)

In [None]:
# Show the number of rows that will be preallocated for the pair table.
max_table_size

In [None]:
# Preallocate the pair table: one row per pair, columns = (word id, context id).
wc_pairs = np.zeros((max_table_size, 2), dtype='int32')
wc_pairs

In [None]:
# Second pass: map each (word, context) pair to integer ids, keeping only
# pairs where both the word is in wvocab and the context is in cvocab.
# sorted() makes the file order, and therefore the row order of the saved
# pairs, deterministic. Raw glob order depends on the filesystem.
triples_files = sorted(data_path.joinpath('triples', dataset).glob('*.df'))

row = 0
for fp in tqdm(triples_files):
    triples = pd.read_parquet(fp, engine='fastparquet')
    for src, edge, dst in zip(triples['src'], triples['path'], triples['dst']):
        # Forward pair (src, "dst/edge") first, then reverse pair
        # (dst, "src/edge-1"), matching how the first pass counted them.
        for word, context in ((src, dst + '/' + edge),
                              (dst, src + '/' + edge + '-1')):
            # One dict lookup each; ids can be 0, so compare against None.
            wi = wvocab.get(word)
            ci = cvocab.get(context)
            if wi is not None and ci is not None:
                wc_pairs[row, 0] = wi
                wc_pairs[row, 1] = ci
                row += 1

In [None]:
# Keep only the rows actually filled by the second pass (a view, not a copy).
wc_pairs = wc_pairs[:row, :]

In [None]:
# Name the columns: wi = word id, ci = context id.
wc_pairs = pd.DataFrame(data=wc_pairs, columns=['wi', 'ci'])
wc_pairs

### Prune
The arrangement of word-context pairs makes this a big bipartite graph, and<br/>
removing a word 'node' or a context 'node' from that graph could cause<br/>
its neighbors to fall below the frequency threshold. So prune iteratively until<br/>
this no longer happens.<br/>
This treatment of the threshold isn't *explicit*. Note that the pruning cell below is<br/>
currently commented out, so no iterative pruning is applied.

In [None]:
# NOTE: iterative pruning is currently DISABLED (everything below is commented
# out). The saved pairs are filtered only by wvocab/cvocab membership.
# NOTE(review): if re-enabled, this loop does not update wvocab/cvocab, so the
# vocab files written later would still list pruned entries. The stale column
# names 'word'/'context' have been corrected to the DataFrame's 'wi'/'ci'.
# pruning = True
# while pruning:

#     len_before = len(wc_pairs)

#     wc_pairs['w_count'] = wc_pairs['wi'].map(wc_pairs['wi'].value_counts())
#     wc_pairs['c_count'] = wc_pairs['ci'].map(wc_pairs['ci'].value_counts())
#     wc_pairs = wc_pairs.loc[(wc_pairs['w_count'] >= K) & (wc_pairs['c_count'] >= K)]

#     num_removed = len_before - len(wc_pairs)
#     print("Removed {:,} pairs below frequency threshold.".format(num_removed))
#     pruning = (num_removed != 0)

### Save Vocab Files And Mapped Pairs Data

In [None]:
# Will overwrite existing pairs dataset
# parents=True also creates data/pairs when it is missing (os.mkdir would raise
# FileNotFoundError). exist_ok=True makes re-running this cell a no-op without
# a check-then-create race.
output_folder = data_path.joinpath('pairs', dataset)
output_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Serialize the (wi, ci) id pairs as one torch tensor (overwrites any existing file).
pairs_file = data_path / 'pairs' / dataset / 'pairs.pt'
pairs = torch.tensor(wc_pairs[['wi', 'ci']].to_numpy())
torch.save(pairs, pairs_file)

In [None]:
# Create data/vocab/<dataset>. parents=True also creates data/vocab when it is
# missing (os.mkdir would raise FileNotFoundError). exist_ok=True makes
# re-running this cell a no-op without a check-then-create race.
output_folder = data_path.joinpath('vocab', dataset)
output_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Will overwrite any existing vocab for the dataset
# One word per line. Line i holds the word with id i, because wvocab's ids
# were assigned in insertion order starting at 0.
# NOTE(review): a word containing '\n' would shift all later lines; assumed
# not to occur in the parser output. Confirm.
wvocab_file = data_path.joinpath('vocab', dataset, 'wvocab.txt')
# Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows) cannot
# represent arbitrary Wikipedia tokens. Write-only mode; no read-back needed.
with open(wvocab_file, 'w', encoding='utf-8') as outfile:
    outfile.writelines(word + '\n' for word in wvocab)

In [None]:
# Will overwrite any existing vocab for the dataset
# One context per line. Line i holds the context with id i, because cvocab's
# ids were assigned in insertion order starting at 0.
# NOTE(review): a context containing '\n' would shift all later lines; assumed
# not to occur in the parser output. Confirm.
cvocab_file = data_path.joinpath('vocab', dataset, 'cvocab.txt')
# Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows) cannot
# represent arbitrary Wikipedia tokens. Write-only mode; no read-back needed.
with open(cvocab_file, 'w', encoding='utf-8') as outfile:
    outfile.writelines(context + '\n' for context in cvocab)