TODO: change the data structure and context window distribution so that we can have different numbers of context windows per word.

In [1]:
import edward as ed
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from scipy.stats import multivariate_normal as mvn

  from ._conv import register_converters as _register_converters


In [2]:
D = 3 # dimensionality of the embeddings
S = 2 # number of different senses per word. TODO: make this random for each word
V = 5 # vocabulary size
N = 10 # number of context windows for each word in the vocabulary
C = 3 # context window size

tf.reset_default_graph() 

In [3]:
def build_toy_dataset(N):
  mu0 = np.zeros(D) # prior mean of the sense means
  sigma0 = np.ones(D) # prior covariance of the sense means
  a0 = np.ones(D) # prior shape for the scale of the sense covariance
  b0 = np.ones(D) # prior scale for the scale of the sense covariance   
  alpha0 = np.ones(S) # priors for the sense probabilities
  beta0 = np.ones(V) # priors over the word frequencies
    
  mus_all = []
  Sigmas_all = []
  pis_all = []

  # draw word frequencies
  pword = np.random.dirichlet(beta0)
  
  # draw the word sense distributions
  for w in range(V):
    mus = []
    Sigmas = []
    pis = np.random.dirichlet(alpha0)
    
    # draw the means for each sense
    for s in range(S):
        mus.append(np.random.multivariate_normal(mu0, np.diag(sigma0)))
        Sigmas.append(np.diag(1.0 / np.random.gamma(a0, b0)))
        
    mus_all.append(mus)
    Sigmas_all.append(Sigmas)
    pis_all.append(pis)
    
  s_all = []
  z_all = []
  c_all = []
    
  # draw the context windows for each word
  for w in range(V):
    
    # draw the sense for each context
    print('Word %i. Sense distribution = %s' % (w, str(pis_all[w])))
    s_ws = np.argmax(np.random.multinomial(1, pis_all[w], N), 1)
    z_ws = []
    
    c_ws = []
    
    for n in range(N):
        
        #print('Sense for pair n = %i' % s_ws[n])
        #print('Mean for chosen sense = %s' % str(mus_all[w][s_ws[n]]))
        #print('Cov for chosen sense = %s' % str(Sigmas_all[w][s_ws[n]]))
        
        # draw the embedding for each context
        z_ws.append(np.random.multivariate_normal(mus_all[w][s_ws[n]], Sigmas_all[w][s_ws[n]]))
    
        # construct the categorical distribution over all words
        joint = []
        for w2 in range(V):
            
            pw2 = 0
            for s in range(S):
                pw2 += pis_all[w2][s] * mvn.pdf(z_ws[-1], mus_all[w2][s], Sigmas_all[w2][s])
            pw2 *= pword[w2]
            joint.append(pw2)
        
        pc_giv_z = joint / np.sum(joint)
        
        c = np.argmax(np.random.multinomial(1, pc_giv_z, C), 1)
        c_ws.append(c)
        
    s_all.append(s_ws)
    z_all.append(z_ws)
    c_all.append(c_ws)
        
  c_all = np.array(c_all, dtype=int).swapaxes(0, 2).swapaxes(0, 1) # so we get N x C x V from V x N x C
        
  return c_all, s_all, z_all, mus_all, Sigmas_all, pis_all

c_all, s_all, z_all, mus_all, Sigmas_all, pis_all = build_toy_dataset(N)

cw_train = c_all
print('Data shape: %s' % str(cw_train.shape))
print('Type of cw: %s' % str(cw_train.dtype))

Word 0. Sense distribution = [0.26194053 0.73805947]
Word 1. Sense distribution = [0.44208271 0.55791729]
Word 2. Sense distribution = [0.35466243 0.64533757]
Word 3. Sense distribution = [0.22277787 0.77722213]
Word 4. Sense distribution = [0.10985147 0.89014853]
Data shape: (10, 3, 5)
Type of cw: int64


In [4]:
from edward.models import Dirichlet, InverseGamma, MultivariateNormalDiag, \
    Normal, ParamMixture, Categorical, OneHotCategorical

from edward.models import RandomVariable
from tensorflow.contrib.distributions import Distribution

class distributions_ContextWindow(Distribution):
  def __init__(self, senses, mus, Sigmas, 
               validate_args=False,
               allow_nan_stats=True,
               name="ContextWindow"):
    
    self.senses = tf.identity(senses, name="senses")
    self.mus = tf.identity(mus, name="mus")
    self.Sigmas = tf.sqrt(Sigmas, name="Sigmas")
    self.pword = tf.ones(self.mus.shape[1]) / self.mus.shape.as_list()[1]
        
    super(distributions_ContextWindow, self).__init__(
            dtype=tf.float32,
            reparameterization_type=tf.contrib.distributions.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name,
            graph_parents=[self.senses, self.mus, self.Sigmas],
            parameters={'senses':senses, 'mus':mus, 'Sigmas':Sigmas},
    )
        
  def _log_prob(self, value=None):
    # value has shape C x V
    if value is None:
        value = self   
    
    value = tf.to_int32(value)
    
    print('Value in _log_prob: %s' % str(value))    
                
    C = value.shape[1]
    V = value.shape[2]
    
    logpc_giv_w = []
        
    for n in range(value.shape[0]):
    
        print('computed log_p for sample %i' % n)
    
        log_p_sample = []
        
        for w in range(V):
            
            wc = value[n, :, w] # the context words

            cw_sense = tf.gather(self.senses[n, :], wc)
            w_sense = self.senses[n, w]

            idxs = tf.concat([tf.expand_dims(cw_sense, 1), tf.expand_dims(wc, 1)], axis=1)
            mus_cw = tf.gather_nd(self.mus, idxs)
            sigma_cw = tf.gather_nd(self.Sigmas, idxs)               
                
            # here we integrate out zw. Second term penalises uncertainty in zw due to vagueness of w's distribution
            logEpz_giv_c = tf.log(MultivariateNormalDiag(
                mus_cw, 
                sigma_cw
            ).log_prob(self.mus[w_sense, w])) 
            
            varterm = 0.5 * tf.reduce_sum(self.Sigmas[w_sense, w] / tf.diag_part(sigma_cw))
            
            # pword terms drop out because they factor out as constants independent of the other variables.
            
            log_p_joint = logEpz_giv_c - varterm
                         
            # now need to normalise over the sum_c{Elogpz_and_c} to get Elogp_c_giv_z
            # Do this using a sample of the negative words 
            neg_sample_size = 20
            sample = Categorical(probs=self.pword).sample(neg_sample_size)
                        
            cw_sense = tf.gather(self.senses[n, :], sample)
            
            idxs = tf.concat([tf.expand_dims(cw_sense, 1), tf.expand_dims(sample, 1)], axis=1)
            mus_sample = tf.gather_nd(self.mus, idxs)
            sigma_sample = tf.gather_nd(self.Sigmas, idxs)
            
            # here we integrate out zw. Second term penalises uncertainty in zw due to vagueness of w's distribution
            logEpz_neg = MultivariateNormalDiag(
                    mus_sample, 
                    sigma_sample
            ).log_prob(self.mus[w_sense, w])
            
            varterm = 0.5 * tf.reduce_sum(sigma_sample / tf.diag_part(sigma_cw), axis=1)
            
            Elogpz_neg = logEpz_neg - varterm
            
            pword_sample = tf.gather(self.pword, sample)
            
            log_p_joint_neg = logEpz_neg + tf.log(pword_sample)
                        
            # now scale up log_neg_joint to represent all the negative words
            w = 1.0 / ( tf.reduce_sum(pword_sample) )
            denominator = tf.log(w * tf.reduce_sum(tf.exp(log_p_joint_neg)))

            log_p_joint -= denominator
            
            log_p_sample.append(log_p_joint)
            
        log_p_sample = tf.stack(log_p_sample, axis=1)
        logpc_giv_w.append(log_p_sample)
        
    logpc_giv_w = tf.stack(logpc_giv_w, axis=0)
    
    print('Completed log_p for context window:')
    print(logpc_giv_w)
    #print(tf.is_nan(logpc_giv_w).eval())
    
    return logpc_giv_w         
    
  def _sample_n(self, n, seed=None):
   
    c_all = []
    
    context_words = []
    
    pc_giv_z = []
            
    for x in range(n):
        
        pw2 = []
        sample_weights = []
        
        for w in range(V):
        
            sense_nw = self.senses[x, w]

            z_w = MultivariateNormalDiag(
                    self.mus[sense_nw, w, :], 
                    self.Sigmas[sense_nw, w, :]
            )
                                
            neg_sample_size = 20
            sample = Categorical(probs=self.pword).sample(neg_sample_size)
            sense_w2 = tf.gather(self.senses[x, :], sample)
                        
            idxs = tf.concat([tf.expand_dims(sense_w2, 1), tf.expand_dims(sample, 1)], axis=1)           
            mus_w2 = tf.gather_nd(self.mus, idxs)            
            sigmas_w2 = tf.gather_nd(self.Sigmas, idxs)
            
            pword_w2 = tf.gather(self.pword, sample)

            p_z_giv_c = MultivariateNormalDiag(
                    loc=mus_w2, # this assumes that the senses were instantiated at random from pi and were different for each word
                    scale_diag=sigmas_w2
            ).prob(z_w)            
            
            joint = p_z_giv_c * pword_w2                
            pw2.append(joint)
            
            sample_weights.append(1.0 / tf.reduce_sum(pword_w2))

        pw2 = tf.stack(pw2, axis=0)
        
        sample_weights = tf.expand_dims(tf.stack(sample_weights, axis=0), 1)
        
        pc_giv_z = pw2 / (sample_weights * tf.reduce_sum(pw2, keepdims=True, axis=1)) # shape VxV

        c_sample = Categorical(probs=pc_giv_z).sample(C, seed) # V different distributions in one line. Shape = n x C x N x V
        c_all.append(c_sample)
        
    c_all = tf.stack(c_all, axis=0)

    print('c_all: %s' % str(c_all))
    return c_all
   
               
class ContextWindow(RandomVariable, distributions_ContextWindow):
               
  def __init__(self, *args, **kwargs):
    RandomVariable.__init__(self, *args, **kwargs)
    self.conjugate_log_prob = self._log_prob
    
pi = Dirichlet(tf.ones(S), sample_shape=V, validate_args=True) # sense distributions for each word. Needs replacing with CRP
mu = Normal(tf.zeros(D), tf.ones(D), sample_shape=(S, V), validate_args=True)
sigmasq = InverseGamma(tf.ones(D), tf.ones(D), sample_shape=(S, V), validate_args=True)

senses = Categorical(probs=pi, sample_shape=(N), validate_args=True)

print(senses.shape)

cw = ContextWindow(senses, mu, sigmasq, sample_shape=(N), validate_args=True) # result should be N x C x V

(10, 5)
c_all: Tensor("ContextWindow/sample/stack_20:0", shape=(10, 3, 5), dtype=int32)


In [5]:
sigmasq.shape

TensorShape([Dimension(2), Dimension(5), Dimension(3)])

In [6]:
# approximate distributions
q_mu = Normal(
    loc=tf.Variable(tf.zeros([S, V, D])),
    scale=tf.nn.softplus(tf.Variable(tf.ones([S, V, D]))), validate_args=True
)

q_sigmasq = InverseGamma(
    concentration=tf.nn.softplus(tf.Variable(tf.ones([S, V, D])) ),
    rate=tf.nn.softplus(tf.Variable(tf.ones([S, V, D])) ), validate_args=True
)

q_pi = Dirichlet(
    concentration=tf.nn.softplus(tf.Variable(tf.ones([V, S])) ), validate_args=True
)

q_senses = Categorical(
    probs=tf.nn.softmax(tf.Variable(tf.ones([N, V, S])) ), validate_args=True
)

In [7]:
print(mu.shape)
print(q_mu.shape)

print(sigmasq.shape)
print(q_sigmasq.shape)

print(pi.shape)
print(q_pi.shape)

print(cw.shape)
print(cw_train.shape)

print(senses.shape)
print(q_senses.shape)

latent_vars = {
    senses: q_senses,
    mu: q_mu,
    sigmasq: q_sigmasq,
    pi: q_pi, 
}

data = {
    cw: cw_train,
    #senses: np.array(s_all).T
}

(2, 5, 3)
(2, 5, 3)
(2, 5, 3)
(2, 5, 3)
(5, 2)
(5, 2)
(10, 3, 5)
(10, 3, 5)
(10, 5)
(10, 5)


In [8]:
#from tensorflow.python import debug as tf_debug

#sess = tf_debug.LocalCLIDebugWrapperSession(ed.util.get_session(), ui_type="curses")
#sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

#with sess.as_default():

inference = ed.KLqp(latent_vars, data)

n_iter = 100

#inference.run(n_iter=n_iter)

inference.initialize()
tf.global_variables_initializer().run()

for i in range(n_iter):
    print('\n')
    print(q_pi.mean().eval())
    print('\n')
    print(q_mu.mean().eval())
    print('\n')
    print(q_sigmasq.mean().eval())
    print('\n')
    print(q_senses.mode().eval())

    print('VB iteration %i' % i)
    info_dict = inference.update()
    inference.print_progress(info_dict)
    
inference.finalize()


  not np.issubdtype(value.dtype, np.float) and \
  not np.issubdtype(value.dtype, np.int) and \


c_all: Tensor("inference/sample/ContextWindow/sample_1/stack_20:0", shape=(10, 3, 5), dtype=int32)
Value in _log_prob: Tensor("inference/sample/ContextWindow/log_prob/ToInt32:0", shape=(10, 3, 5), dtype=int32)
computed log_p for sample 0
computed log_p for sample 1
computed log_p for sample 2
computed log_p for sample 3
computed log_p for sample 4
computed log_p for sample 5
computed log_p for sample 6
computed log_p for sample 7
computed log_p for sample 8
computed log_p for sample 9
Completed log_p for context window:
Tensor("inference/sample/ContextWindow/log_prob/stack_10:0", shape=(10, 3, 5), dtype=float32)


[[0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]]


[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]]


[[[4.1922197 4.1922197 4.1922197]
  [4.1922197 4.1922197 4.1922197]
  [4.1922197 4.1922197 4.1922197]
  [4.1922197 4.1922197 4.1922197]
  [4.1922197 4.1922197 4.1922197]]

 [[4.1922197

InvalidArgumentError: assertion failed: [] [Condition x > 0 did not hold element-wise:] [x (Softplus:0) = ] [[[nan nan nan]]...]
	 [[Node: Normal_1/assert_positive/assert_less/Assert/AssertGuard/Assert = Assert[T=[DT_STRING, DT_STRING, DT_STRING, DT_FLOAT], summarize=3, _device="/job:localhost/replica:0/task:0/device:CPU:0"](Normal_1/assert_positive/assert_less/Assert/AssertGuard/Assert/Switch, Normal_1/assert_positive/assert_less/Assert/AssertGuard/Assert/data_0, Normal_1/assert_positive/assert_less/Assert/AssertGuard/Assert/data_1, Normal_1/assert_positive/assert_less/Assert/AssertGuard/Assert/data_2, Normal_1/assert_positive/assert_less/Assert/AssertGuard/Assert/Switch_1)]]

Caused by op 'Normal_1/assert_positive/assert_less/Assert/AssertGuard/Assert', defined at:
  File "/Users/edwin/anaconda3/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/edwin/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 478, in start
    self.io_loop.start()
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tornado/ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/ipykernel/zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2728, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2850, in run_ast_nodes
    if self.run_code(code, result):
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2910, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-6-31b12f73c286>", line 4, in <module>
    scale=tf.nn.softplus(tf.Variable(tf.ones([S, V, D]))), validate_args=True
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/edward/models/random_variables.py", line 21, in __init__
    _RandomVariable.__init__(self, *args, **kwargs)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/edward/models/random_variable.py", line 112, in __init__
    super(RandomVariable, self).__init__(*args, **kwargs)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/distributions/normal.py", line 137, in __init__
    validate_args else []):
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/check_ops.py", line 221, in assert_positive
    return assert_less(zero, x, data=data, summarize=summarize)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/check_ops.py", line 579, in assert_less
    return control_flow_ops.Assert(condition, data, summarize=summarize)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/tf_should_use.py", line 118, in wrapped
    return _add_should_use_warning(fn(*args, **kwargs))
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 177, in Assert
    guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 432, in new_func
    return func(*args, **kwargs)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2027, in cond
    orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1868, in BuildCondBranch
    original_result = fn()
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 175, in true_assert
    condition, data, summarize, name="Assert")
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_logging_ops.py", line 48, in _assert
    name=name)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3271, in create_op
    op_def=op_def)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1650, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): assertion failed: [] [Condition x > 0 did not hold element-wise:] [x (Softplus:0) = ] [[[nan nan nan]]...]
	 [[Node: Normal_1/assert_positive/assert_less/Assert/AssertGuard/Assert = Assert[T=[DT_STRING, DT_STRING, DT_STRING, DT_FLOAT], summarize=3, _device="/job:localhost/replica:0/task:0/device:CPU:0"](Normal_1/assert_positive/assert_less/Assert/AssertGuard/Assert/Switch, Normal_1/assert_positive/assert_less/Assert/AssertGuard/Assert/data_0, Normal_1/assert_positive/assert_less/Assert/AssertGuard/Assert/data_1, Normal_1/assert_positive/assert_less/Assert/AssertGuard/Assert/data_2, Normal_1/assert_positive/assert_less/Assert/AssertGuard/Assert/Switch_1)]]


In [None]:
# print out the sense labels inferred for all the central word occurrences

Esenses = q_senses.eval()

print(Esenses)
print(s_all)

print(Esenses[0, 0])

print(s_all[0][0])

for w in range(V):
    for n in range(N):    
        print('word %i, sample %i, -- probability of senses = %s, true sense is %i' % (n, w, str(Esenses[n, w]), s_all[w][n]) )
            

### What makes our method novel?

A fully Bayesian approach to learning word embeddings with multiple, potentially infinite numbers of distinct senses per token.

The Bayesian treatment is intended to help with:
* Rare words in the training corpus, whose embeddings cannot be confidently estimated -- variance means we don't put too much weight onto these uncertain cases during learning
* Inferring the number of senses -- priors effectively regularise the model toward fewer senses
* Domain adaptation/Transfer learning -- we can inflate variances to indicate uncertainty in new domains
* (As in Barkan, Brazinskas et al) context-specific embeddings for each word instance
* (As in Barkan, Brazinskas et al) composition of sentence or document embeddings -- word occurrences with more confident or precise embeddings will have stronger influence in the combined sentence embedding. I think this will push the sentence embeddings away from generic words and toward the extremes.

### How do we test our model?

* Look at the tasks tried by Brazinskas et al, Barkan, and the ACL 2018 paper and ty to reuse their code where possible
* Compute the context-specific embeddings for each word (posterior means)
* Test what happens if we concatenate the variances to the embedding vector as a vagueness or uncertainty feature

In [None]:
q_senses.mode().eval()