In [None]:
import table
import table.IO as tio
import torch
import os
from tqdm.auto import tqdm
import re
from collections import Counter

from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
import pickle

import mittens
 
# NLTK English stop-word list, held as a set for fast membership tests when filtering corpus tokens.
stopWords = set(stopwords.words('english'))

import numpy as np
import pandas as pd

In [None]:
PUNCTUATION = {
    'sep'   : u'\u200b' + "/-'´′‘…—−–",
    'keep'  : "&",
    'remove': '?!.,，"#$%\'()*+-/:;<=>@[\\]^_`{|}~“”’™•°'
}

def clean_text(x):
    """Lower-case ``x`` and normalise its punctuation in a single pass.

    - Characters in ``PUNCTUATION['sep']`` become a single space.
    - Characters in ``PUNCTUATION['keep']`` are padded with spaces so they
      form a token of their own.
    - Characters in ``PUNCTUATION['remove']`` are deleted.

    A character listed both as a separator and as a removal is treated as
    a separator.
    """
    mapping = {}
    # Fill the table lowest-precedence first so later rules overwrite
    # earlier ones: remove < keep < sep, mirroring a sep -> keep -> remove
    # sequence of str.replace calls.
    for ch in PUNCTUATION['remove']:
        mapping[ord(ch)] = None
    for ch in PUNCTUATION['keep']:
        mapping[ord(ch)] = " %s " % ch
    for ch in PUNCTUATION['sep']:
        mapping[ord(ch)] = " "
    return x.lower().translate(mapping)

In [None]:
base_dir = '../data_model/comp-sci-corpus/'

# Flat token stream of the whole corpus (file order, then line order).
# Files are visited in sorted order so the stream, and therefore the
# co-occurrence windows built from it later, is reproducible across runs.
all_words = []

for f in tqdm(sorted(os.listdir(base_dir))):
    path = os.path.join(base_dir, f)
    if not os.path.isfile(path):
        continue  # skip sub-directories and other non-regular entries
    # `with` closes every handle; explicit UTF-8 so the non-ASCII punctuation
    # in PUNCTUATION matches regardless of the machine's locale.
    with open(path, "rt", encoding="utf-8") as fh:
        for raw_line in fh:
            # clean_text lower-cases on its own; no separate .lower() needed.
            for w in clean_text(raw_line.strip()).split():
                # keep tokens that start with a word character and are not stop words
                if re.match(r'[\w]+', w) and w not in stopWords:
                    all_words.append(w)

print("len(all_words) = %d" % len(all_words))

# Corpus-wide token frequencies.
vocab = Counter(all_words)

print("len(vocab) = %d" % len(vocab))

### Build the word co-occurrence matrix

In [None]:
thr = 20000     # keep the `thr` most frequent words
window = 10     # symmetric context: up to `window` words on EACH side

top_words, top_freqs = zip(*vocab.most_common(thr))

# Index words by frequency rank. Enumerating the ordered tuple (not a set)
# makes word2idx identical across runs; set iteration order depends on the
# per-process string hash seed.
word2idx = {w: i for i, w in enumerate(top_words)}
top_words = set(top_words)
n_top = len(word2idx)  # < thr when the corpus has fewer distinct words

# uint32 rather than uint16: frequent pairs can exceed 65535 in a large
# corpus, and numpy integer addition wraps around silently.
# NOTE: this doubles the matrix memory (n_top**2 * 4 bytes, ~1.6 GB at 20k).
M = np.zeros((n_top, n_top), dtype=np.uint32)

# Row index of every token, or -1 for words outside the top-`thr` vocabulary.
token_idx = np.fromiter(
    (word2idx.get(w, -1) for w in all_words), dtype=np.int64, count=len(all_words)
)

# For each distance d = 1..window, count every pair (token i, token i + d)
# in both directions. Together this covers exactly the positions
# i-window .. i+window (j != i) around each token, on both sides.
# (The previous range(i - window, i + window) dropped the word at i + window.)
for d in tqdm(range(1, window + 1)):
    left, right = token_idx[:-d], token_idx[d:]
    both_top = (left >= 0) & (right >= 0)
    left, right = left[both_top], right[both_top]
    # np.add.at accumulates correctly when the same (row, col) repeats.
    np.add.at(M, (left, right), 1)   # `right` is in the right context of `left`
    np.add.at(M, (right, left), 1)   # `left` is in the left context of `right`

In [None]:
out_vocab_file = '../data_model/comp-sci-corpus-thr%d-window%d.vocab' % (thr, window)
out_mat_file = '../data_model/comp-sci-corpus-thr%d-window%d.mat' % (thr, window)

# Persist the word -> row-index map and the co-occurrence matrix.
# `with` guarantees each file is flushed and closed once written.
# (Only unpickle these files from trusted locations: pickle can execute code.)
with open(out_vocab_file, "wb") as fh:
    pickle.dump(word2idx, fh)
with open(out_mat_file, "wb") as fh:
    pickle.dump(M, fh)

## Fine-tune GloVe embeddings with Mittens

In [None]:
# Root of the Quora input directory. It can be overridden with the
# GLOVE_INPUT_DIR environment variable so the notebook is not tied to one
# machine; the default is the original author's local path.
_base_dir = os.environ.get("GLOVE_INPUT_DIR", "/home/alex/workspace/git/kaggle.git/quora/input/")
EMB_GLOVE_FILE = "%s/embeddings/glove.840B.300d/glove.840B.300d.txt" % _base_dir

def load_glove(path=None, dim=300, words=None, encoding='latin'):
    """Load GloVe vectors from a text file into a ``{word: vector}`` dict.

    Each line is ``<word> <v1> ... <v_dim>`` separated by single spaces.
    The LAST ``dim`` fields are parsed as the vector and everything before
    them is the word, so tokens that themselves contain spaces stay whole
    instead of breaking the float conversion.

    Args:
        path: file to read; defaults to ``EMB_GLOVE_FILE``.
        dim: number of vector components per line (300 for glove.840B.300d).
        words: optional iterable of words to keep; ``None`` keeps every
            word. Filtering greatly reduces memory for large files.
        encoding: text encoding used to open the file.
            NOTE(review): if the file is UTF-8 (presumably), 'latin'
            mis-decodes non-ASCII words so they will not match tokens read
            as UTF-8 elsewhere — confirm intent.

    Returns:
        dict mapping each word to a ``float32`` numpy array. For duplicate
        words, the last occurrence wins.
    """
    if path is None:
        path = EMB_GLOVE_FILE
    wanted = None if words is None else set(words)

    embeddings_index = {}
    with open(path, encoding=encoding) as fh:
        for line in fh:
            line = line.rstrip()  # drop the newline / trailing whitespace
            if not line:
                continue
            word, *coefs = line.rsplit(" ", dim)
            if wanted is not None and word not in wanted:
                continue
            embeddings_index[word] = np.asarray(coefs, dtype='float32')
    return embeddings_index

# Load the whole pre-trained GloVe vocabulary as {word: float32 vector}. This dict
# is the initial embedding passed to Mittens below and the search space of closest_to.
emb_glove = load_glove()

In [None]:
# Mittens pairs row/column k of the co-occurrence matrix with the k-th
# vocabulary word, so the matrix must be square and exactly vocabulary-sized.
# Fail loudly here rather than fit a silently misaligned model (e.g. if M
# was allocated thr x thr for a corpus with fewer than thr distinct words).
assert M.shape == (len(word2idx), len(word2idx)), (
    "co-occurrence matrix shape %s does not match vocabulary size %d"
    % (M.shape, len(word2idx))
)

# n must equal the dimension of the initial GloVe vectors (300d file).
mittens_model = mittens.Mittens(n=300, max_iter=1000)

new_emb_glove = mittens_model.fit(
    M, # co-occurrence
    vocab=list(word2idx),
    initial_embedding_dict=emb_glove
)

# Fine-tuned vectors keyed by word. Assumes fit returns one row per vocab
# word in the order passed above (per the Mittens API) — TODO confirm.
new_emb_glove_dict = dict(zip(word2idx, new_emb_glove))

In [None]:
def closest_to(w, n=1, emb=None, progress=True):
    """Return the ``n`` words whose vectors are most cosine-similar to ``w``.

    Args:
        w: query word; must be a key of the embedding dict.
        n: number of neighbours to return.
        emb: ``{word: vector}`` dict to search; defaults to ``emb_glove``.
            Pass another dict (e.g. fine-tuned vectors) to compare spaces.
        progress: show a tqdm progress bar while scanning the vocabulary.

    Returns:
        List of up to ``n`` words, most similar first; equal scores keep
        the dict's iteration order. The query word itself and zero vectors
        (whose cosine is undefined) are excluded.

    Raises:
        KeyError: if ``w`` is not in ``emb``.
        ValueError: if ``w``'s vector is all zeros.
    """
    import heapq

    if emb is None:
        emb = emb_glove
    if w not in emb:
        raise KeyError("%r is not in the embedding vocabulary" % (w,))

    query = emb[w]
    # The query norm is loop-invariant; compute it once instead of per word.
    query_norm = np.linalg.norm(query)
    if query_norm == 0:
        raise ValueError("%r has a zero vector; cosine similarity is undefined" % (w,))

    def _scored():
        # Yield (word, cosine similarity) for every other non-zero vector.
        candidates = tqdm(emb) if progress else emb
        for other in candidates:
            if other == w:
                continue
            vec = emb[other]
            other_norm = np.linalg.norm(vec)
            if other_norm == 0:
                continue  # would produce NaN, which breaks the ordering
            yield other, np.dot(query, vec) / (query_norm * other_norm)

    # nlargest(n, ..., key) == sorted(..., key, reverse=True)[:n] (same tie
    # order) without sorting a multi-million-entry list.
    return [word for word, _ in heapq.nlargest(n, _scored(), key=lambda pair: pair[1])]

# Spot-check: the 10 nearest neighbours of "function" in the original
# (not fine-tuned) GloVe space held in emb_glove.
closest_to("function", n=10)