In [None]:
import numpy as np
from scipy.sparse import csr_matrix
import os
from heapq import nlargest
from random import sample


def load_cooccurrences(path, vocab_size=None):
  """
  Usage: load_cooccurrences("cooccurrence.bin")

  Reads a binary file of packed little-endian records
  (int32 i, int32 j, float64 x) and returns a scipy CSR matrix holding x at
  position [i-1, j-1], i.e. the indices stored in the file are treated as
  1-based.

  vocab_size (optional): when given, the matrix is created with the explicit
  shape (vocab_size, vocab_size). When omitted, scipy infers the shape from
  the largest row/column index present in the file, so trailing words that
  have no record at all make the matrix smaller than the vocabulary (and
  possibly non-square).
  """
  dt = np.dtype([('i', '<i4'), ('j', '<i4'), ('x', '<f8')])
  arr = np.fromfile(path, dtype=dt)
  # shift the 1-based file indices to 0-based matrix coordinates
  rows, cols = arr['i']-1, arr['j']-1
  if vocab_size is None:
    return csr_matrix((arr['x'], (rows, cols)))
  return csr_matrix((arr['x'], (rows, cols)), shape=(vocab_size, vocab_size))


def load_vocab(path):
  """
  Usage: load_vocab("vocab.txt")

  Returns a list of tuples of (word: str, freq: int)

  The file is read as UTF-8 with one "<word> <freq>" entry per line; the
  frequency is the text after the last space. Blank lines (for example a
  trailing empty line) are skipped. A non-blank line without a space, or
  whose frequency is not an integer, raises ValueError.
  """
  res = []
  # explicit encoding: do not depend on the platform's default locale
  with open(path, "r", encoding="utf-8") as f:
    for line in f:
      line = line.rstrip("\r\n")
      if not line.strip():
        continue
      word, freq = line.rsplit(' ', 1)
      res.append((word, int(freq)))
  return res


def load_vectors(path, vector_size, vocab_size):
  """
  Usage: load_vectors("vectors.bin", vector_size, vocab_size)

  Returns (word_vectors, context_vectors, word_biases, context_biases).

  word_vectors and context_vectors are (vocab_size, vector_size) matrices
  word_biases and context_biases are (vocab_size) arrays

  The file is read as a flat array of little-endian float64 values and
  interpreted as 2*vocab_size rows of vector_size+1 values: the first
  vocab_size rows are the word rows, the remaining rows the context rows,
  and the last value of each row is that row's bias. The returned arrays are
  views into the same loaded buffer.

  Raises ValueError when the number of values in the file does not equal
  2*vocab_size*(vector_size+1).
  """
  dt = np.dtype('<f8')
  arr = np.fromfile(path, dtype=dt)
  expected = 2*vocab_size*(vector_size+1)
  if arr.size != expected:
    # reshape would also raise ValueError; this message names the actual cause
    raise ValueError(
        f"{path}: expected {expected} float64 values for vocab_size="
        f"{vocab_size} and vector_size={vector_size}, found {arr.size}")
  vecs = arr.reshape((2*vocab_size, vector_size+1))
  word_mat, ctx_mat = np.split(vecs, 2)
  word, ctx = word_mat[:, :vector_size], ctx_mat[:, :vector_size]
  bias_word, bias_ctx = word_mat[:, vector_size], ctx_mat[:, vector_size]
  return word, ctx, bias_word, bias_ctx


home = os.path.expanduser('~')

# TODO: This sample data is from GloVe's `demo.sh`, need to train for Wikipedia
cooccur_path = "sample-data/cooccurrence.bin"
vocab_path = "sample-data/vocab.txt"
vector_path = "sample-data/vectors.bin"
vector_size = 50

print("Loading...")
# co-occurrence matrix, then the (word, freq) vocabulary list
C = load_cooccurrences(cooccur_path)
vocab = load_vocab(vocab_path)
# split the vocabulary into parallel lists: words and their frequencies
dictionary = [entry[0] for entry in vocab]
freq = [entry[1] for entry in vocab]
D = len(dictionary)
# trained vectors and biases, unpacked into word/context parts
vecs = load_vectors(vector_path, vector_size, D)
word, ctx, B, B_ctx = vecs
print("Loaded.")


In [None]:
print("Generating 100 samples of the top 100k pairs...")

# Draw 200 distinct word indices (without a fixed seed) from the up-to-100000
# most frequent words, then pair consecutive draws into 100 (s, t) tuples.
top_100k_sample = sample(
    nlargest(100000, range(D), key=freq.__getitem__), 200)
top_pairs = [
    (first, second)
    for first, second in zip(top_100k_sample[0::2], top_100k_sample[1::2])
]

print("Done.")


In [None]:
# TODO: Implement made-up words
# TODO: Implement bias for made-up words; see pp. 17-18

# Set to None to disable
# fake = "foobar123"
fake = None

if fake is None:
  # no made-up word in use
  fake_idx = -1
else:
  if fake in dictionary:
    raise RuntimeError("Fake word isn't fake!")

  # the made-up word takes the next free index; grow every per-word
  # structure touched here (word list, frequencies, biases, co-occurrences)
  fake_idx = D
  D = D + 1
  dictionary.append(fake)
  freq.append(0)
  B = np.append(B, 0)
  C.resize((D, D))


In [None]:
import torch
import numpy as np
import cupy as cp
from numba import cuda, njit
from math import exp, log, sqrt, inf
from bisect import bisect_left
from copy import deepcopy
from typing import List


@cuda.jit(device=True)
def cuda_model_f(u, v, c, epsilon, B):
  """ Device helper: log(c) minus the biases B[u] and B[v], floored at epsilon. """
  shifted = log(c)-B[u]-B[v]
  return max(shifted, epsilon)


@cuda.jit
def sim2_kernel(s, delta, Cps, Mdots_t, Ms_norm, M_t, M_t_norm, B, T_wgt, res):
  # For every candidate index i (one thread per entry of Cps), writes to res[i]
  # the T_wgt-weighted sum, over all rows t of M_t, of the change in
  #   dot(s, t) / sqrt(norm(s) * norm(t))
  # that results from adding `delta` to Cps[i].
  #
  # Arguments, as used below:
  #   s         index handed to cuda_model_f together with i (indexes B)
  #   delta     amount hypothetically added to Cps[i]
  #   Cps       1-D array, one current count per candidate i
  #   Mdots_t   1-D array, current dot value per t
  #   Ms_norm   scalar, current norm term for s
  #   M_t       2-D array indexed [t, i]
  #   M_t_norm  1-D array, norm term per t (treated as unchanged by the update)
  #   B         bias array indexed by s and i inside cuda_model_f
  #   T_wgt     1-D array, weight per t
  #   res       1-D output array, written at index i
  #
  # assumes i is not s or a target (otherwise, the result is incorrect but not undefined)
  i = cuda.grid(1)
  # guard: threads whose index falls beyond the end of Cps do nothing
  if i < Cps.shape[0]:
    sim2 = 0.
    # model value for the pair (s, i) before and after adding delta
    old_M_si = cuda_model_f(s, i, Cps[i], 0, B)
    new_M_si = cuda_model_f(s, i, Cps[i]+delta, 0, B)
    for t in range(M_t.shape[0]):
      # only the (s, i) entry changes, so the dot value moves by (change)*M_t[t, i]
      dot_id = Mdots_t[t] + (new_M_si-old_M_si)*M_t[t, i]
      # and the norm term of s moves by the change in the squared entry
      Ms_normid = Ms_norm + new_M_si*new_M_si - old_M_si*old_M_si
      Mt_normid = M_t_norm[t]
      # sim2 += T_wgt[t]*dot_id/sqrt(Ms_normid*Mt_normid)
      old_sim2 = Mdots_t[t]/sqrt(Ms_norm*M_t_norm[t])
      new_sim2 = dot_id/sqrt(Ms_normid*Mt_normid)
      # accumulate the weighted difference (new minus old), not the new value
      sim2 += T_wgt[t]*(new_sim2-old_sim2)
    res[i] = sim2


@njit
def _model_f(u: int, v: int, c: float, epsilon: float, B: np.ndarray) -> float:
  """
  Returns log(c) - B[u] - B[v], floored at epsilon.

  A count c that is not positive is treated as log(c) = -inf, so the result
  is epsilon in that case.
  """
  if c > 0:
    log_count = log(c)
  else:
    log_count = -inf
  shifted = log_count-B[u]-B[v]
  return max(shifted, epsilon)


@njit
def _M(u: int, v: int, C_uv: float, Dhat: np.ndarray):
  """
  Model value for the pair (u, v) given its co-occurrence count C_uv, with the
  pending additions in Dhat applied when the pair involves s.

  Dhat is indexed by word index (Dhat[v] / Dhat[u] below).

  NOTE(review): `s` and `B` are not parameters; they are read from module
  scope. Numba documents that jitted functions treat globals as compile-time
  constants, so rebinding `s` or `B` after this function is first compiled
  would presumably not be picked up here - TODO confirm / pass them explicitly.
  """
  if u == s:
    # pair (s, v): add the amount queued for v
    C_offset = Dhat[v]
  elif v == s:
    # pair (u, s): add the amount queued for u
    C_offset = Dhat[u]
  else:
    # pairs that do not involve s are unaffected by Dhat
    C_offset = 0
  return _model_f(u, v, C_uv+C_offset, 0, B)


class CorpusPoison:
  """
  Greedy search for a vector Dhat of co-occurrence additions between a source
  word s and other words, chosen to raise an objective Jhat built from two
  similarity terms (sim1, sim2) between s and the target words.

  All per-run inputs (s, POS, NEG, Dhat, ...) are stored on the instance by
  solve_greedy; the methods read them from `self` rather than from notebook
  globals.
  """

  class CompDiffState:
    """
    Bundle of the incrementally maintained quantities.

    solve_greedy uses it for the running state; comp_diff_naive returns one
    describing a candidate step, in which Jhat holds the *change* of the
    objective and Delta_size the *increment* of the budget used.
    """

    def __init__(self, Jhat: float, Csum, M_norms, t_dots, Delta_size: int):
      self.Jhat = Jhat              # objective value (or its change)
      self.Csum = Csum              # word index -> co-occurrence row sum
      self.M_norms = M_norms        # word index -> squared norm of its M row
      self.t_dots = t_dots          # word index t -> dot(M row of s, M row of t)
      self.Delta_size = Delta_size  # budget consumed (or its increment)

  class TensorPrime:
    """
    Copy-on-write view over an indexable container: reads fall through to
    `tensor` unless the key was written, writes only land in `overrides`.
    In this class it wraps the sparse co-occurrence matrix and plain dicts.
    """

    def __init__(self, tensor, overrides=None):
      self.tensor = tensor
      self.overrides = overrides if overrides is not None else {}

    def __getitem__(self, key):
      return self.overrides[key] if key in self.overrides else self.tensor[key]

    def __setitem__(self, key, value):
      self.overrides[key] = value

    def apply(self, other):
      """ Writes every recorded override into `other`. """
      for key in self.overrides:
        other[key] = self.overrides[key]

    def dbg(self):
      """ Returns the overrides formatted as "{key: old->new, ...}". """
      return "{" + ", ".join(f"{k}: {self.tensor[k]}->{v}" for k, v in self.overrides.items()) + "}"

  def __init__(self, dictionary, cooccur, bias):
    self.dictionary = dictionary
    self.C = cooccur
    self.B = bias
    self.e60 = exp(-60)  # positive floor used for the sim1 denominators

  def model_f(self, u: int, v: int, c: float, epsilon: float, B: np.ndarray) -> float:
    """ Thin wrapper around the jitted _model_f. """
    return _model_f(u, v, c, epsilon, B)

  def M(self, u, v, C):
    """
    Model value for the pair (u, v) read from C, with the additions queued in
    self.Dhat applied when the pair involves self.s.

    Computed from instance state (self.s, self.Dhat, self.B) so the result
    cannot drift from notebook-level variables of the same name.
    """
    s = self.s
    if u == s:
      C_offset = self.Dhat[v]
    elif v == s:
      C_offset = self.Dhat[u]
    else:
      C_offset = 0
    return self.model_f(u, v, C[u, v]+C_offset, 0, self.B)

  def sim1(self, t: int, C, Csum):
    """ First similarity term between self.s and t: M[s, t] over the root of two row-sum terms. """
    s = self.s
    num = self.M(s, t, C)
    den1 = self.model_f(s, t, Csum[s], self.e60, self.B)
    den2 = self.model_f(s, t, Csum[t], self.e60, self.B)
    return num/sqrt(den1*den2)

  def sim2(self, t, t_dots, M_norms):
    """ Second similarity term: dot of the M rows of s and t over the root of their squared norms. """
    s = self.s
    return t_dots[t]/sqrt(M_norms[s]*M_norms[t])

  def comp_diff_naive(self, i: int, delta: float, state: CompDiffState, dbg=False) -> CompDiffState:
    """
    Evaluates the hypothetical step C[s, i] += delta against `state` without
    modifying `state` or self.C.

    Returns a CompDiffState whose Jhat is the change of the objective, whose
    Csum / M_norms / t_dots are TensorPrime views holding the updated entries,
    and whose Delta_size is the budget this step consumes. `dbg` is unused.
    """
    # note: we assume that C is symmetric and s is not in POS or NEG.
    s = self.s

    # calculate the new state after hypothetically executing C[s, i] += delta.
    # TensorPrime is essentially a tensor view that records our modifications
    # without applying them to the original tensor. this makes things much cleaner.
    C = self.TensorPrime(self.C)
    C[s, i] += delta
    if i != s:
      C[i, s] += delta

    # side effect: Csum[s] += delta and Csum[i] += delta
    Csum = self.TensorPrime(state.Csum)
    Csum[s] += delta
    if i != s:
      Csum[i] += delta

    # side effect: t_dots[t] += (change in M[s, i])*M[t, i]
    t_dots = self.TensorPrime(state.t_dots)
    def old_M(u, v): return self.M(u, v, self.C)
    def new_M(u, v): return self.M(u, v, C)
    for t in self.T:
      # account for the change in M_si
      t_dots[t] += (new_M(s, i)-old_M(s, i)) * old_M(t, i)
      if t == i:
        # account for the change in M_ts
        t_dots[t] += (new_M(t, s)-old_M(t, s)) * old_M(s, s)

    # side effect: M_norms[s] += change in (M_si)^2
    M_norms = self.TensorPrime(state.M_norms)
    M_norms[s] += new_M(s, i)**2 - old_M(s, i)**2
    if i in self.T:
      M_norms[i] += new_M(i, s)**2 - old_M(i, s)**2

    # ok, the new state is almost ready!
    # calculate the changes in sim and the objective
    change_Jhat_total = 0.
    for t in self.T:
      old_sim1 = self.sim1(t, self.C, state.Csum)
      new_sim1 = self.sim1(t, C, Csum)
      old_sim2 = self.sim2(t, state.t_dots, state.M_norms)
      new_sim2 = self.sim2(t, t_dots, M_norms)
      change_sim12 = ((new_sim1-old_sim1)+(new_sim2-old_sim2))/2
      sign = 1 if t in self.POS else -1
      change_Jhat_total += sign*change_sim12

    dJhat = change_Jhat_total/len(self.T)
    # adjust weight for second-order sequences
    # (137/30 equals 2*(1 + 1/2 + 1/3 + 1/4 + 1/5))
    omega_i = 1 if i in self.POS else 137/30

    return self.CompDiffState(dJhat, Csum, M_norms, t_dots, delta/omega_i)

  def solve_greedy(self, s: int, POS, NEG, t_rank: float, alpha: float, max_Delta: float):
    """
    Greedily builds Dhat, the per-word amounts to add to the co-occurrence of
    s with each word.

    s: source word index. POS / NEG: lists of target word indices weighted
    +1 / -1 in the objective. Iteration continues while the objective is
    below t_rank + alpha and the consumed budget is below max_Delta.

    Returns Dhat (float32 array with one entry per dictionary word), or None
    when the chosen index cannot be located in the keep list.
    """
    self.s = s
    self.POS = POS
    self.NEG = NEG
    self.t_rank = t_rank
    self.alpha = alpha
    self.max_Delta = max_Delta
    self.T = T = POS + NEG
    A = T + [s]
    C = self.C
    B = self.B
    # vocabulary size taken from the instance, not from a notebook global
    D = len(self.dictionary)
    print("Generating keep list...")
    # candidates: s, the targets, and every word with a positive model value
    # against at least one of them (ascending order, relied on by bisect below)
    keep = np.fromiter(
        (j for j in range(D)
         if j in A or any(self.model_f(u, j, C[u, j], 0, B) > 0 for u in A)),
        dtype=np.int_
    )
    A_keep_idx = {x: i for i, x in enumerate(keep) if x in A}
    pct = 100*keep.shape[0]/D
    print(f"Keep list length: {keep.shape[0]} ({pct:.2f}%)")
    print("Preparing state...")
    self.Dhat = Dhat = np.zeros(D, dtype=np.float32)
    K = len(keep)
    Csum = {j: C[j].sum() for j in keep}

    def M_row(j):
      return np.fromiter((self.M(j, d, C) for d in range(D)), dtype=np.float32)

    def M_norm2(j):
      return np.linalg.norm(M_row(j)) ** 2

    M_norms = {u: M_norm2(u) for u in A}
    t_dots = {t: M_row(s).dot(M_row(t)).item() for t in A}
    T_wgt = np.array([1]*len(POS) + [-1]*len(NEG), dtype=np.float32)

    total_sim12 = 0.
    for t, w in zip(T, T_wgt):
      sim1 = self.sim1(t, C, Csum)
      sim2 = self.sim2(t, t_dots, M_norms)
      sim12 = (sim1+sim2)/2
      total_sim12 += w*sim12
    init_Jhat = total_sim12/len(T)
    print("Initial Jhat:", init_Jhat)

    state = self.CompDiffState(init_Jhat, Csum, M_norms, t_dots, 0)
    orig_state = deepcopy(state)  # used in sanity checks later

    Ms_norm = M_norm2(s)
    # rounds 4*K up to the next multiple of 32; pad is the matching number of
    # extra 4-byte columns appended to each M_t row
    Kstride = (4*K+31) & (~31)  # TODO: should this be larger?
    pad = Kstride//4 - K

    Mdots_t = np.array([t_dots[t] for t in T], dtype=np.float32)
    M_t = np.array([[self.M(t, j, C) for j in keep]+[0]
                   * pad for t in T], dtype=np.float32)

    print("Preparing CUDA...")

    sim2_Cps = cuda.to_device(C[s, keep].todense().A1)
    sim2_Mdots_t = cuda.to_device(Mdots_t)
    sim2_M_t = cuda.to_device(M_t)
    sim2_M_t_norm = cuda.to_device(
        np.array([M_norms[t] for t in T], dtype=np.float32))
    sim2_B = cuda.to_device(B[keep])
    sim2_T_wgt = cuda.to_device(T_wgt)
    sim2_res = cuda.device_array(K, dtype=np.float32)

    threadsperblock = 32
    blockspergrid = (K + (threadsperblock - 1)) // threadsperblock

    print("Starting iteration...")
    iters = 0
    while state.Jhat < t_rank + alpha and state.Delta_size < max_Delta:
      dmap = {}
      iters += 1

      # try the step sizes 0.2, 0.4, ..., 6.0
      for l in range(1, 31):
        delta = l/5
        new_Csum = self.TensorPrime(state.Csum)
        new_Csum[s] += delta
        old_sim1 = sum(w*self.sim1(t, self.C, state.Csum)
                       for t, w in zip(T, T_wgt))
        new_sim1 = sum(w*self.sim1(t, self.C, new_Csum)
                       for t, w in zip(T, T_wgt))
        dsim1 = new_sim1-old_sim1

        # call the kernel
        sim2_kernel[blockspergrid, threadsperblock](
            A_keep_idx[s], delta, sim2_Cps, sim2_Mdots_t, Ms_norm, sim2_M_t, sim2_M_t_norm, sim2_B, sim2_T_wgt, sim2_res)

        # kernel complete! create cupy array view and use it to find sim12 vector
        sim2_res_cp = cp.asarray(sim2_res)
        sim12 = (dsim1+sim2_res_cp)/(2*len(T))

        # argmax the values that CUDA calculates correctly (ie. everything not in A)
        for j in A:
          sim12[A_keep_idx[j]] = -100
        i = keep[sim12.argmax().item()]

        # calculate naively for i and everything in A
        for j in A + [i]:
          dmap[(j, delta)] = self.comp_diff_naive(j, delta, state, dbg=False)

      # pick the candidate with the best objective gain per unit of budget
      i_star, delta_star = -1, -1
      best = -inf  # any finite score must be able to win, including scores <= -1
      for i, delta in dmap:
        # index 0 and s itself are never selected
        # (NOTE(review): reason for excluding index 0 is not visible here - confirm)
        if i == 0 or i == s:
          continue
        cd_state = dmap[(i, delta)]
        score = cd_state.Jhat / cd_state.Delta_size
        if score > best:
          best = score
          i_star, delta_star = i, delta
      if i_star < 0:
        # every candidate was excluded (or scored NaN): nothing to apply
        print("No admissible candidate found; stopping.")
        break
      cd_star = dmap[(i_star, delta_star)]

      # Update naive state
      state.Jhat += cd_star.Jhat
      cd_star.Csum.apply(state.Csum)
      cd_star.M_norms.apply(state.M_norms)
      cd_star.t_dots.apply(state.t_dots)
      state.Delta_size += cd_star.Delta_size
      Dhat[i_star] += delta_star

      # Update CUDA state
      i_star_keep = bisect_left(keep, i_star)
      if i_star_keep >= K or keep[i_star_keep] != i_star:
        print("UNABLE TO FIND i_star in keep!")
        return None
      sim2_Cps[i_star_keep] += delta_star
      for t, x in cd_star.t_dots.overrides.items():
        if t in T:
          sim2_Mdots_t[T.index(t)] = x
      Ms_norm = cd_star.M_norms[s]
      for t, x in cd_star.M_norms.overrides.items():
        if t in T:
          sim2_M_t_norm[T.index(t)] = x

      if iters % 20 == 0:
        # report the step that was actually applied (delta_star), not the
        # leftover loop variable
        print("Iterated!", iters, i_star, delta_star, state.Jhat,
              state.Delta_size, self.dictionary[i_star])

      if iters % 500 == 0:
        print("Calculating sanity check...")
        # Sanity check: freshly recalculate the state, and compare the recalculated objective to our stored Jhat.
        # TODO: Should we actually update the state with our recalculated values to help avoid compounding rounding errors?

        Csum_dbg = {t: orig_state.Csum[t]+self.Dhat[t] for t in T}
        Csum_dbg[s] = orig_state.Csum[s] + self.Dhat.sum()
        M_norms_dbg = {u: np.linalg.norm(M_row(u)).item() ** 2 for u in A}
        t_dots_dbg = {t: M_row(s).dot(M_row(t)).item() for t in A}

        total_sim12 = 0.
        for t, w in zip(T, T_wgt):
          sim1 = self.sim1(t, C, Csum_dbg)
          sim2 = self.sim2(t, t_dots_dbg, M_norms_dbg)
          sim12 = (sim1+sim2)/2
          total_sim12 += w*sim12
        real_Jhat = total_sim12/len(T)
        print("state vs. real Jhat:", state.Jhat, real_Jhat)
        error = abs(state.Jhat - real_Jhat)
        print("Measured error:", error)

    print("Done!")
    return Dhat


model = CorpusPoison(dictionary, C, B)
# Run the greedy solver on the first sampled pair only (the slice holds at
# most one pair). `s`, `t` and `Dhat` stay bound at notebook scope afterwards;
# `_M` reads a module-level `s`.
for s, t in top_pairs[:1]:
  print()
  print("Trying pair:", s, t, dictionary[s], dictionary[t])
  Dhat = model.solve_greedy(
      s, POS=[t], NEG=[], t_rank=inf, alpha=0, max_Delta=1250)

In [None]:
import random
from math import ceil

Gamma = 5
lambda_ = 5


def gamma(d):
  """ Weight of a co-occurrence at distance d: 1/d up to distance Gamma, 0 beyond. """
  if d <= Gamma:
    return 1/d
  return 0


def place_additions(Dhat, s, POS, window=None, weight=None):
  """ Algorithm 2, Placement into corpus: finding the change set ∆

  Dhat: per-word amounts to add to the co-occurrence with s, indexed by word
        index (entries are expected to be non-negative).
  s: index of the source word.
  POS: indices of words that get direct "s t" sequences.
  window: number of slots on each side of s in a sequence (must be >= 1);
          defaults to the module-level lambda_.
  weight: function mapping a distance d >= 1 to the co-occurrence weight it
          contributes; defaults to the module-level gamma.

  Returns a list of sequences, each a list of word indices: first the
  two-word [s, t] sequences for the words in POS, then sequences of length
  2*window+1 with s in the middle. Slots left empty at the end are filled
  with randomly chosen words that have a nonzero Dhat entry.
  """
  if window is None:
    window = lambda_
  if weight is None:
    weight = gamma

  pos_set = set(POS)
  Delta = []

  # first-order sequences: s directly next to t, repeated often enough to
  # cover Dhat[t]; stored as index lists like every other entry of Delta
  for t in POS:
    Delta += [[s, t] for _ in range(ceil(Dhat[t]/weight(1)))]

  # remaining demand per word, excluding the words handled above
  change_map = {u: Du for u, Du in enumerate(Dhat)
                if Du != 0 and u not in pos_set}

  sum_coocur = sum(weight(d) for d in range(1, window+1))
  min_sequences_required = ceil(sum(change_map.values())/sum_coocur)
  # -1 marks an empty slot; s sits at position `window`
  default_seq = [-1]*window + [s] + [-1]*window
  live = [default_seq[:] for _ in range(min_sequences_required)]
  indices = list(range(window)) + list(range(window+1, 2*window+1))
  # weight contributed by each slot, by its distance from s
  slot_weight = {i: weight(abs(window-i)) for i in indices}

  cm_keys = list(change_map.keys())
  for u in cm_keys:
    while change_map[u] > 0:
      # choose the empty slot whose weight is closest to the remaining demand
      cm_u = change_map[u]
      best_seq_i = best_i = -1
      best = float("inf")
      for seq_i, seq in enumerate(live):
        for i in indices:
          if seq[i] < 0:
            score = abs(slot_weight[i]-cm_u)
            if score < best:
              best_seq_i, best_i = seq_i, i
              best = score
      seq, i = live[best_seq_i], best_i

      seq[i] = u
      change_map[u] -= slot_weight[i]
      # a sequence is complete when no slot is empty; word index 0 is a
      # valid filler, so compare against the -1 marker with >= 0
      if all(seq[k] >= 0 for k in indices):
        Delta.append(seq)
        # remove seq from live (swap with the last entry, then drop it)
        if best_seq_i < len(live)-1:
          live[best_seq_i] = live.pop()
        else:
          live.pop()
      if not live:
        live = [default_seq[:]]

  # fill whatever is still empty with random words from the change map
  for seq in live:
    for i in indices:
      if seq[i] < 0:
        seq[i] = random.choice(cm_keys)
    Delta.append(seq)

  return Delta


s, t = top_pairs[0]
print("s/t:", dictionary[s], dictionary[t])
Delta = place_additions(Dhat, s, [t])

# print(Delta)

# print a few example sequences (about 20, evenly spaced)
# the step is at least 1: a step of 0 is not a valid slice step, which is what
# len(Delta)//20 yields for fewer than 20 sequences
step = max(1, len(Delta)//20)
for seq in Delta[::step]:
  if isinstance(seq, str):
    # entry already holds words separated by single spaces
    print(seq.split(' '))
  else:
    # entry holds word indices
    print([dictionary[w] for w in seq])
