# Precalculate generalized 'word' vocab and context vocab

Each triple represents two 'words' and a dependency edge between them.<br/>
As per https://aclanthology.org/P14-2050.pdf, take each word and make <br/>
a context 'word' from the concatenation of the other word of the triple <br/>
with the edge label, marking the inverse direction of the edge with a '-1' suffix.<br/>

For example:<br/>

&emsp;*"Alice threw the ball."*<br/><br/>
would yield the triples:<br/>

&emsp;*(throw)-[nsubj]->(Alice)*<br/>
&emsp;*(throw)-[dobj]->(ball)*<br/><br/>
From which we want to get the (word, context) pairs:

&emsp;throw, Alice/nsubj<br/>
&emsp;throw, ball/dobj<br/>
&emsp;Alice, throw/nsubj-1<br/>
&emsp;ball, throw/dobj-1<br/>

Then we can construct a word vocabulary and a context vocabulary constrained<br/>
to the words and contexts that appear at least K times. Because the two vocabularies<br/>
are separate (words vs. contexts), their sizes will generally differ.

### Imports

In [None]:
from pathlib import Path
import os

import pandas as pd
import numpy as np

import pathvecs

In [None]:
# Locate the 'data' folder one level above the directory that contains
# pathvecs' module file.
# TODO: make this location configurable
data_path = Path(pathvecs.__file__).parents[1] / 'data'

### Config

In [None]:
# Dataset to process: the name of the sub-folder of data/triples that holds
# the input triples
dataset = "wikipedia_20220101"

# Frequency cut-off: a word or context must occur at least K times to be
# included in its vocabulary
K = 50

### Load

In [None]:
# Expand every (src, path, dst) triple into two (word, context) pairs, kept
# adjacent in the output:
#   (src, dst/path)     - forward direction
#   (dst, src/path-1)   - inverse direction, marked by the '-1' suffix
# TODO: may need to do this with integer representations for larger datasets
triples_dir = data_path.joinpath('triples', dataset)

# Path.glob yields files in a filesystem-dependent order; sort so the row
# order of wc_pairs (and anything derived from it) is reproducible across runs.
triples_files = sorted(triples_dir.glob('*.df'))
if not triples_files:
    # Fail loudly instead of continuing with an empty set of pairs, which
    # usually means `dataset` or `data_path` is wrong.
    raise FileNotFoundError(
        "No '*.df' triples files found in {}".format(triples_dir))

wc_pairs = []
for fp in triples_files:
    # Each file is read as parquet and must provide 'src', 'path' and 'dst'
    # columns holding strings.
    triples = pd.read_parquet(fp, engine='fastparquet')
    for src, edge, dst in zip(triples['src'], triples['path'], triples['dst']):
        wc_pairs.append((src, dst + '/' + edge))
        wc_pairs.append((dst, src + '/' + edge + '-1'))

wc_pairs = pd.DataFrame(wc_pairs, columns=['word', 'context'])
wc_pairs

### Prune
The arrangement of word,context pairs makes this a big bipartite graph, and <br/>
removing a word 'node' or a context 'node' from that graph could result in <br/>
its neighbors falling below the frequency threshold. So prune iteratively until <br/>
this doesn't happen.<br/>

In [None]:
# Iteratively drop pairs whose word or context occurs fewer than K times.
# Dropping rows lowers the counts of the remaining words/contexts, so repeat
# until a full pass removes nothing. Because the final pass removes no rows,
# 'w_count' and 'c_count' end up holding the occurrence counts of each row's
# word and context within the final set of pairs.
pruning = True
while pruning:

    len_before = len(wc_pairs)

    # Occurrence count of each row's word / context within the current pairs
    wc_pairs['w_count'] = wc_pairs['word'].map(wc_pairs['word'].value_counts())
    wc_pairs['c_count'] = wc_pairs['context'].map(wc_pairs['context'].value_counts())

    # .copy() detaches the filtered frame from its parent so that the column
    # assignments on the next pass do not trigger SettingWithCopyWarning on
    # pandas versions without copy-on-write.
    keep = (wc_pairs['w_count'] >= K) & (wc_pairs['c_count'] >= K)
    wc_pairs = wc_pairs.loc[keep].copy()

    num_removed = len_before - len(wc_pairs)
    print("Removed {:,} pairs below frequency threshold.".format(num_removed))
    pruning = (num_removed != 0)

### Save Vocab Files

In [None]:
# Make sure data/vocab/<dataset> exists before the vocab files are written.
# parents=True also creates a missing intermediate folder (e.g. data/vocab);
# exist_ok=True makes this a no-op when the folder is already there.
output_folder = data_path.joinpath('vocab', dataset)
output_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Word vocabulary: the distinct words, in order of first appearance after
# sorting the pairs by descending word count.
# NOTE: replaces any wvocab.npy previously saved for this dataset.
wc_pairs.sort_values(by='w_count', ascending=False, inplace=True)
wvocab = wc_pairs['word'].unique()
wvocab_file = data_path / 'vocab' / dataset / 'wvocab.npy'
np.save(wvocab_file, wvocab)

In [None]:
# Context vocabulary: the distinct contexts, in order of first appearance
# after sorting the pairs by descending context count.
# NOTE: replaces any cvocab.npy previously saved for this dataset.
wc_pairs.sort_values(by='c_count', ascending=False, inplace=True)
cvocab = wc_pairs['context'].unique()
cvocab_file = data_path / 'vocab' / dataset / 'cvocab.npy'
np.save(cvocab_file, cvocab)