In [13]:
import numpy as np
from scipy.sparse import csr_matrix
import os
from heapq import nlargest
from random import sample


def load_cooccurrences(path, shape=None):
  """
  Load a binary co-occurrence file into a sparse matrix.

  Usage: load_cooccurrences("cooccurrence.bin")

  The file is read as a flat sequence of little-endian records
  (i: int32, j: int32, x: float64). The indices stored in the file are
  shifted down by one before being used as row/column positions.

  Args:
    path: file to read.
    shape: optional (n_rows, n_cols) of the result. When omitted, scipy
      infers the shape from the largest indices present in the file, so the
      matrix can be smaller than the vocabulary if the last words never
      appear in a record, and an empty file cannot be loaded. Pass e.g.
      (len(vocab), len(vocab)) to get a fixed size.

  Returns a scipy.sparse.csr_matrix holding x at position (i-1, j-1).
  """
  dt = np.dtype([('i', '<i4'), ('j', '<i4'), ('x', '<f8')])
  arr = np.fromfile(path, dtype=dt)
  # shape=None is csr_matrix's own default, so existing callers are unaffected.
  return csr_matrix((arr['x'], (arr['i']-1, arr['j']-1)), shape=shape)


def load_vocab(path):
  """
  Read a vocabulary file with one "<word> <count>" entry per line.

  Usage: load_vocab("vocab.txt")

  Returns a list of tuples of (word: str, freq: int), in file order.

  Lines that contain only whitespace (e.g. a trailing empty line) are
  skipped. Any other line that does not consist of exactly two
  space-separated fields, or whose second field is not an integer, raises
  ValueError.
  """
  res = []
  with open(path, "r") as f:
    for line in f:
      if not line.strip():
        # A blank line would otherwise fail the two-field unpacking below.
        continue
      word, freq = line.split(' ')
      # int() ignores the trailing newline left on the count field.
      res.append((word, int(freq)))
  return res


def load_vectors(path, vector_size, vocab_size):
  """
  Load word/context vectors and their biases from a flat binary file.

  Usage: load_vectors("vectors.bin", vector_size=50, vocab_size=len(vocab))

  The file is read as little-endian float64 values and interpreted as
  2*vocab_size rows of (vector_size + 1) numbers: the first vocab_size rows
  are the word rows, the remaining vocab_size rows are the context rows, and
  the last number of each row is that row's bias.

  Returns (word_vectors, context_vectors, word_biases, context_biases).

  word_vectors and context_vectors are (vocab_size, vector_size) matrices
  word_biases and context_biases are (vocab_size,) arrays

  Raises ValueError if the file does not hold exactly
  2 * vocab_size * (vector_size + 1) values.
  """
  dt = np.dtype('<f8')
  arr = np.fromfile(path, dtype=dt)

  # Check the size up front so a wrong vector_size/vocab_size is reported with
  # the numbers involved instead of numpy's generic reshape error.
  expected = 2 * vocab_size * (vector_size + 1)
  if arr.size != expected:
    raise ValueError(
        "%s holds %d float64 values, expected %d "
        "(2 * vocab_size=%d * (vector_size=%d + 1))"
        % (path, arr.size, expected, vocab_size, vector_size))

  vecs = arr.reshape((2*vocab_size, vector_size+1))
  # First half of the rows: word side; second half: context side.
  word_mat, ctx_mat = np.split(vecs, 2)
  word, ctx = word_mat[:, :vector_size], ctx_mat[:, :vector_size]
  # The bias is the extra, last column of each row.
  bias_word, bias_ctx = word_mat[:, vector_size], ctx_mat[:, vector_size]
  return word, ctx, bias_word, bias_ctx


# User's home directory; not used by the loading steps below.
home = os.path.expanduser('~')

# TODO: these inputs are the sample data from GloVe's `demo.sh`;
# they still need to be replaced by a model trained on Wikipedia.
cooccur_path = "cooccurrence4.bin"
vocab_path = "vocab.txt"
vector_path = "vectors4.bin"
vector_size = 50

print("Loading...")

# Co-occurrence matrix.
C = load_cooccurrences(cooccur_path)

# Vocabulary as (word, freq) pairs, split into two parallel lists.
vocab = load_vocab(vocab_path)
dictionary = [entry[0] for entry in vocab]
freq = [entry[1] for entry in vocab]
D = len(dictionary)

# Word/context vectors and their biases, one row per vocabulary entry.
vecs = load_vectors(vector_path, vector_size, D)
word, ctx, B, B_ctx = vecs

print("Loaded.")


Loading...
Loaded.


In [14]:
# Vocabulary indices of the two words being compared.
s, t = 22748, 48628  # officio, leverett
print("s", s, dictionary[s])
print("t", t, dictionary[t])

def SIM1(u, v):
  """
  Cross similarity of the words at indices u and v.

  Sum of two dot products: u's word vector with v's context vector, and
  u's context vector with v's word vector. Reads the module-level `word`
  and `ctx` matrices.
  """
  word_with_ctx = np.dot(word[u], ctx[v])
  ctx_with_word = np.dot(ctx[u], word[v])
  return word_with_ctx + ctx_with_word

def SIM2(u, v):
  """
  Same-side similarity of the words at indices u and v.

  Sum of two dot products: u's word vector with v's word vector, and
  u's context vector with v's context vector. Reads the module-level `word`
  and `ctx` matrices.
  """
  word_with_word = np.dot(word[u], word[v])
  ctx_with_ctx = np.dot(ctx[u], ctx[v])
  return word_with_word + ctx_with_ctx


# Evaluate each score once and reuse it for the average.
sim1_st = SIM1(s, t)
sim2_st = SIM2(s, t)

print("SIM1", sim1_st)
print("SIM2", sim2_st)
# Mean of the two scores.
print("SIM1+2", (sim1_st + sim2_st)/2)

s 22748 officio
t 48628 leverett
sim1 5.572523183980332
sim2 7.695977689508876
sim1+2 6.634250436744604
