TODO: change the data structure and context window distribution so that we can have different numbers of context windows per word.

In [1]:
import edward as ed
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from scipy.stats import multivariate_normal as mvn

  from ._conv import register_converters as _register_converters


In [2]:
D = 3 # dimensionality of the embeddings
S = 2 # number of different senses per word. TODO: make this random for each word
V = 5 # vocabulary size
N = 10 # number of context windows for each word in the vocabulary
C = 3 # context window size

tf.reset_default_graph() 

In [3]:
def build_toy_dataset(N):
  mu0 = np.zeros(D) # prior mean of the sense means
  sigma0 = np.ones(D) # prior covariance of the sense means
  a0 = np.ones(D) # prior shape for the scale of the sense covariance
  b0 = np.ones(D) # prior scale for the scale of the sense covariance   
  alpha0 = np.ones(S) # priors for the sense probabilities
  beta0 = np.ones(V) # priors over the word frequencies
    
  mus_all = []
  Sigmas_all = []
  pis_all = []

  # draw word frequencies
  pword = np.random.dirichlet(beta0)
  
  # draw the word sense distributions
  for w in range(V):
    mus = []
    Sigmas = []
    pis = np.random.dirichlet(alpha0)
    
    # draw the means for each sense
    for s in range(S):
        mus.append(np.random.multivariate_normal(mu0, np.diag(sigma0)))
        Sigmas.append(np.diag(1.0 / np.random.gamma(a0, b0)))
        
    mus_all.append(mus)
    Sigmas_all.append(Sigmas)
    pis_all.append(pis)
    
  s_all = []
  z_all = []
  c_all = []
    
  # draw the context windows for each word
  for w in range(V):
    
    # draw the sense for each context
    print('Word %i. Sense distribution = %s' % (w, str(pis_all[w])))
    s_ws = np.argmax(np.random.multinomial(1, pis_all[w], N), 1)
    z_ws = []
    
    c_ws = []
    
    for n in range(N):
        
        #print('Sense for pair n = %i' % s_ws[n])
        #print('Mean for chosen sense = %s' % str(mus_all[w][s_ws[n]]))
        #print('Cov for chosen sense = %s' % str(Sigmas_all[w][s_ws[n]]))
        
        # draw the embedding for each context
        z_ws.append(np.random.multivariate_normal(mus_all[w][s_ws[n]], Sigmas_all[w][s_ws[n]]))
    
        # construct the categorical distribution over all words
        joint = []
        for w2 in range(V):
            
            pw2 = 0
            for s in range(S):
                pw2 += pis_all[w2][s] * mvn.pdf(z_ws[-1], mus_all[w2][s], Sigmas_all[w2][s])
            pw2 *= pword[w2]
            joint.append(pw2)
        
        pc_giv_z = joint / np.sum(joint)
        
        c = np.argmax(np.random.multinomial(1, pc_giv_z, C), 1)
        c_ws.append(c)
        
    s_all.append(s_ws)
    z_all.append(z_ws)
    c_all.append(c_ws)
        
  c_all = np.array(c_all, dtype=int).swapaxes(0, 2).swapaxes(0, 1) # so we get N x C x V from V x N x C
        
  return c_all, s_all, z_all, mus_all, Sigmas_all, pis_all

c_all, s_all, z_all, mus_all, Sigmas_all, pis_all = build_toy_dataset(N)

cw_train = c_all
print('Data shape: %s' % str(cw_train.shape))
print('Type of cw: %s' % str(cw_train.dtype))

Word 0. Sense distribution = [0.32518312 0.67481688]
Word 1. Sense distribution = [0.65653091 0.34346909]
Word 2. Sense distribution = [0.44007171 0.55992829]
Word 3. Sense distribution = [0.86733891 0.13266109]
Word 4. Sense distribution = [0.89365981 0.10634019]
Data shape: (10, 3, 5)
Type of cw: int64


In [4]:
from edward.models import Dirichlet, InverseGamma, MultivariateNormalDiag, \
    Normal, ParamMixture, Categorical, OneHotCategorical

from edward.models import RandomVariable
from tensorflow.contrib.distributions import Distribution

class distributions_ContextWindow(Distribution):
  def __init__(self, senses, mus, Sigmas, pword, 
               validate_args=False,
               allow_nan_stats=True,
               name="ContextWindow"):
    
    self.senses = tf.identity(senses, name="senses")
    self.mus = tf.identity(mus, name="mus")
    self.Sigmas = tf.identity(Sigmas, name="Sigmas")
    self.pword = tf.identity(pword, name="pword")
        
    super(distributions_ContextWindow, self).__init__(
            dtype=tf.float32,
            reparameterization_type=tf.contrib.distributions.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name,
            graph_parents=[self.senses, self.mus, self.Sigmas, self.pword],
            parameters={'senses':senses, 'mus':mus, 'Sigmas':Sigmas, 'pword':pword},
    )
        
  def _log_prob(self, value=None):
    # value has shape C x V
    if value is None:
        value = self   
    
    value = tf.to_int32(value)
    
    print('Value in _log_prob: %s' % str(value))    
                
    C = value.shape[1]
    V = value.shape[2]
    
    logpc_giv_w = []
        
    for n in range(value.shape[0]):
    
        log_p_sample = []
        
        for w in range(V):
            
            wc = value[n, :, w] # the context words

            cw_sense = tf.gather(self.senses[n, :], wc)
            w_sense = self.senses[n, w]

            idxs = tf.concat([tf.expand_dims(cw_sense, 1), tf.expand_dims(wc, 1)], axis=1)
            mus_cw = tf.gather_nd(self.mus, idxs)
            sigma_cw = tf.gather_nd(self.Sigmas, idxs)               
                
            # here we integrate out zw. Second term penalises uncertainty in zw due to vagueness of w's distribution
            Elogpz_giv_c = tf.log(MultivariateNormalDiag(
                mus_cw, 
                sigma_cw
            ).log_prob(self.mus[w_sense, w])) 
            
            varterm = 0.5 * self.Sigmas[w_sense, w] / tf.diag_part(sigma_cw)
            
            #print('before varterm')
            #print(Elogpz_giv_c)
            
            #print(varterm)
            
            Elogpz_giv_c -= varterm
            
            pword_cw = tf.gather(self.pword, wc)
            
            #print('Elogpz_giv_c')
            #print(Elogpz_giv_c)
            #print('pword_cw')
            #print(pword_cw)
            
            log_p_joint = Elogpz_giv_c + tf.log(pword_cw)
                
            # now need to normalise over the sum_c{Elogpz_and_c} to get Elogp_c_giv_z
            # Do this using a sample of the negative words 
            neg_sample_size = 20
            sample = Categorical(probs=pword).sample(neg_sample_size)
                        
            cw_sense = tf.gather(self.senses[n, :], sample)
            
            idxs = tf.concat([tf.expand_dims(cw_sense, 1), tf.expand_dims(sample, 1)], axis=1)
            mus_sample = tf.gather_nd(self.mus, idxs)
            sigma_sample = tf.gather_nd(self.Sigmas, idxs)
            
            # here we integrate out zw. Second term penalises uncertainty in zw due to vagueness of w's distribution
            Epz_giv_c = MultivariateNormalDiag(
                    mus_sample, 
                    sigma_sample
            ).prob(self.mus[w_sense, w])
            
            pword_sample = tf.gather(self.pword, sample)
            
            Epz_and_c = Epz_giv_c * pword_sample
                        
            # now scale up log_neg_joint to represent all the negative words
            p_poswords = tf.reduce_sum(pword_cw)
            w = (1 - p_poswords) / tf.reduce_sum(pword_sample)
            denominator = tf.log(tf.reduce_sum(w * Epz_and_c) + tf.reduce_sum(tf.exp(log_p_joint)))
            #print('log p joint')
            #print(log_p_joint)
            
            #print('denominator')
            #print(denominator)
            log_p_joint -= denominator
            
            #print('conditional')
            #print(log_p_joint)
            
            log_p_sample.append(log_p_joint)
            
        log_p_sample = tf.stack(log_p_sample, axis=1)
        logpc_giv_w.append(log_p_sample)
        
    logpc_giv_w = tf.stack(logpc_giv_w, axis=0)
    
    print('Completed log_p for context window:')
    print(logpc_giv_w)
    
    return logpc_giv_w         
    
  def _sample_n(self, n, seed=None):
   
    c_all = []
    
    context_words = []
    
    pc_giv_z = []
            
    for x in range(n):
        
        pw2 = []
        sample_weights = []
        
        for w in range(V):
        
            sense_nw = self.senses[x, w]

            z_w = MultivariateNormalDiag(
                    self.mus[sense_nw, w, :], 
                    self.Sigmas[sense_nw, w, :]
            )
                                
            neg_sample_size = 20
            sample = Categorical(probs=self.pword).sample(neg_sample_size)
            sense_w2 = tf.gather(self.senses[x, :], sample)
                        
            idxs = tf.concat([tf.expand_dims(sense_w2, 1), tf.expand_dims(sample, 1)], axis=1)           
            mus_w2 = tf.gather_nd(self.mus, idxs)            
            sigmas_w2 = tf.gather_nd(self.Sigmas, idxs)
            
            pword_w2 = tf.gather(self.pword, sample)

            p_z_giv_c = MultivariateNormalDiag(
                    loc=mus_w2, # this assumes that the senses were instantiated at random from pi and were different for each word
                    scale_diag=sigmas_w2
            ).prob(z_w)            
            
            joint = p_z_giv_c * pword_w2                
            pw2.append(joint)
            
            sample_weights.append(1.0 / tf.reduce_sum(pword_w2))

        pw2 = tf.stack(pw2, axis=0)
        
        sample_weights = tf.expand_dims(tf.stack(sample_weights, axis=0), 1)
        
        pc_giv_z = pw2 / (sample_weights * tf.reduce_sum(pw2, keepdims=True, axis=1)) # shape VxV

        c_sample = Categorical(probs=pc_giv_z).sample(C, seed) # V different distributions in one line. Shape = n x C x N x V
        c_all.append(c_sample)
        
    c_all = tf.stack(c_all, axis=0)

    print('c_all: %s' % str(c_all))
    return c_all
   
               
class ContextWindow(RandomVariable, distributions_ContextWindow):
               
  def __init__(self, *args, **kwargs):
    RandomVariable.__init__(self, *args, **kwargs)
    self.conjugate_log_prob = self._log_prob
    
pword = Dirichlet(tf.ones(V))

pi = Dirichlet(tf.ones(S), sample_shape=V) # sense distributions for each word. Needs replacing with CRP
mu = Normal(tf.zeros(D), tf.ones(D), sample_shape=(S, V))
sigmasq = InverseGamma(tf.ones(D), tf.ones(D), sample_shape=(S, V))

senses = Categorical(probs=pi, sample_shape=(N))

print(senses.shape)

cw = ContextWindow(senses, mu, sigmasq, pword, sample_shape=(N)) # result should be N x C x V

(10, 5)
c_all: Tensor("ContextWindow/sample/stack_20:0", shape=(10, 3, 5), dtype=int32)


In [5]:
# approximate distributions
q_mu = Normal(
    loc=tf.Variable(tf.zeros([S, V, D])),
    scale=tf.nn.softplus(tf.Variable(tf.zeros([S, V, D])))
)

q_sigmasq = InverseGamma(
    concentration=tf.nn.softplus(tf.Variable(tf.zeros([S, V, D])) ),
    rate=tf.nn.softplus(tf.Variable(tf.zeros([S, V, D])) )
)

q_pi = Dirichlet(
    concentration=tf.nn.softplus( tf.Variable(tf.zeros([V, S])) )
)

q_pword = Dirichlet(
    concentration=tf.nn.softplus( tf.Variable(tf.zeros([V])) )
)

q_senses = Categorical(
    probs=tf.nn.softmax(tf.Variable(tf.zeros([N, V, S])))
)

In [6]:
print(pword.shape)
print(q_pword.shape)

print(mu.shape)
print(q_mu.shape)

print(sigmasq.shape)
print(q_sigmasq.shape)

print(pi.shape)
print(q_pi.shape)

print(cw.shape)
print(cw_train.shape)

print(senses.shape)
print(q_senses.shape)

latent_vars = {
    senses: q_senses,
    mu: q_mu,
    sigmasq: q_sigmasq,
    pi: q_pi, 
    pword: q_pword
}

data = {
    cw: cw_train,
}

(5,)
(5,)
(2, 5, 3)
(2, 5, 3)
(2, 5, 3)
(2, 5, 3)
(5, 2)
(5, 2)
(10, 3, 5)
(10, 3, 5)
(10, 5)
(10, 5)


In [8]:
inference = ed.KLpq(latent_vars, data)

n_iter = 100

#inference.run(n_iter=n_iter)

inference.initialize()
tf.global_variables_initializer().run()

for i in range(n_iter):
    print('VB iteration %i' % i)
    info_dict = inference.update()
    inference.print_progress(info_dict)
    print(q_pi.mean().eval())
    
inference.finalize()


  not np.issubdtype(value.dtype, np.float) and \
  not np.issubdtype(value.dtype, np.int) and \


c_all: Tensor("inference_1/sample_1/ContextWindow/sample_1/stack_20:0", shape=(10, 3, 5), dtype=int32)
Value in _log_prob: Tensor("inference_1/sample_1/ContextWindow/log_prob/ToInt32:0", shape=(10, 3, 5), dtype=int32)
before varterm
Tensor("inference_1/sample_1/ContextWindow/log_prob/Log:0", shape=(3,), dtype=float32)
Tensor("inference_1/sample_1/ContextWindow/log_prob/truediv:0", shape=(3,), dtype=float32)
Elogpz_giv_c
Tensor("inference_1/sample_1/ContextWindow/log_prob/sub:0", shape=(3,), dtype=float32)
pword_cw
Tensor("inference_1/sample_1/ContextWindow/log_prob/Gather_1:0", shape=(3,), dtype=float32)
log p joint
Tensor("inference_1/sample_1/ContextWindow/log_prob/add_2:0", shape=(3,), dtype=float32)
denominator
Tensor("inference_1/sample_1/ContextWindow/log_prob/Log_2:0", shape=(), dtype=float32)
conditional
Tensor("inference_1/sample_1/ContextWindow/log_prob/sub_2:0", shape=(3,), dtype=float32)
before varterm
Tensor("inference_1/sample_1/ContextWindow/log_prob/Log_3:0", shape=(3,)

log p joint
Tensor("inference_1/sample_1/ContextWindow/log_prob/add_57:0", shape=(3,), dtype=float32)
denominator
Tensor("inference_1/sample_1/ContextWindow/log_prob/Log_35:0", shape=(), dtype=float32)
conditional
Tensor("inference_1/sample_1/ContextWindow/log_prob/sub_35:0", shape=(3,), dtype=float32)
before varterm
Tensor("inference_1/sample_1/ContextWindow/log_prob/Log_36:0", shape=(3,), dtype=float32)
Tensor("inference_1/sample_1/ContextWindow/log_prob/truediv_24:0", shape=(3,), dtype=float32)
Elogpz_giv_c
Tensor("inference_1/sample_1/ContextWindow/log_prob/sub_36:0", shape=(3,), dtype=float32)
pword_cw
Tensor("inference_1/sample_1/ContextWindow/log_prob/Gather_49:0", shape=(3,), dtype=float32)
log p joint
Tensor("inference_1/sample_1/ContextWindow/log_prob/add_62:0", shape=(3,), dtype=float32)
denominator
Tensor("inference_1/sample_1/ContextWindow/log_prob/Log_38:0", shape=(), dtype=float32)
conditional
Tensor("inference_1/sample_1/ContextWindow/log_prob/sub_38:0", shape=(3,), dty

log p joint
Tensor("inference_1/sample_1/ContextWindow/log_prob/add_117:0", shape=(3,), dtype=float32)
denominator
Tensor("inference_1/sample_1/ContextWindow/log_prob/Log_71:0", shape=(), dtype=float32)
conditional
Tensor("inference_1/sample_1/ContextWindow/log_prob/sub_71:0", shape=(3,), dtype=float32)
before varterm
Tensor("inference_1/sample_1/ContextWindow/log_prob/Log_72:0", shape=(3,), dtype=float32)
Tensor("inference_1/sample_1/ContextWindow/log_prob/truediv_48:0", shape=(3,), dtype=float32)
Elogpz_giv_c
Tensor("inference_1/sample_1/ContextWindow/log_prob/sub_72:0", shape=(3,), dtype=float32)
pword_cw
Tensor("inference_1/sample_1/ContextWindow/log_prob/Gather_97:0", shape=(3,), dtype=float32)
log p joint
Tensor("inference_1/sample_1/ContextWindow/log_prob/add_122:0", shape=(3,), dtype=float32)
denominator
Tensor("inference_1/sample_1/ContextWindow/log_prob/Log_74:0", shape=(), dtype=float32)
conditional
Tensor("inference_1/sample_1/ContextWindow/log_prob/sub_74:0", shape=(3,), d

log p joint
Tensor("inference_1/sample_1/ContextWindow/log_prob/add_177:0", shape=(3,), dtype=float32)
denominator
Tensor("inference_1/sample_1/ContextWindow/log_prob/Log_107:0", shape=(), dtype=float32)
conditional
Tensor("inference_1/sample_1/ContextWindow/log_prob/sub_107:0", shape=(3,), dtype=float32)
before varterm
Tensor("inference_1/sample_1/ContextWindow/log_prob/Log_108:0", shape=(3,), dtype=float32)
Tensor("inference_1/sample_1/ContextWindow/log_prob/truediv_72:0", shape=(3,), dtype=float32)
Elogpz_giv_c
Tensor("inference_1/sample_1/ContextWindow/log_prob/sub_108:0", shape=(3,), dtype=float32)
pword_cw
Tensor("inference_1/sample_1/ContextWindow/log_prob/Gather_145:0", shape=(3,), dtype=float32)
log p joint
Tensor("inference_1/sample_1/ContextWindow/log_prob/add_182:0", shape=(3,), dtype=float32)
denominator
Tensor("inference_1/sample_1/ContextWindow/log_prob/Log_110:0", shape=(), dtype=float32)
conditional
Tensor("inference_1/sample_1/ContextWindow/log_prob/sub_110:0", shape=

log p joint
Tensor("inference_1/sample_1/ContextWindow/log_prob/add_247:0", shape=(3,), dtype=float32)
denominator
Tensor("inference_1/sample_1/ContextWindow/log_prob/Log_149:0", shape=(), dtype=float32)
conditional
Tensor("inference_1/sample_1/ContextWindow/log_prob/sub_149:0", shape=(3,), dtype=float32)
Completed log_p for context window:
Tensor("inference_1/sample_1/ContextWindow/log_prob/stack_10:0", shape=(10, 3, 5), dtype=float32)
VB iteration 0
   1/1000 [  0%]                                ETA: 89681s | Loss: nan[[nan nan]
 [nan nan]
 [nan nan]
 [nan nan]
 [nan nan]]
VB iteration 1


InvalidArgumentError: Received a label value of 2 which is outside the valid range of [0, 2).  Label values: 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
	 [[Node: inference_1/sample_1/Categorical_1/log_prob/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits = SparseSoftmaxCrossEntropyWithLogits[T=DT_FLOAT, Tlabels=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](inference_1/sample_1/Categorical_1/log_prob/SparseSoftmaxCrossEntropyWithLogits/Reshape, inference_1/sample_1/Categorical_1/log_prob/SparseSoftmaxCrossEntropyWithLogits/Reshape_1)]]

Caused by op 'inference_1/sample_1/Categorical_1/log_prob/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits', defined at:
  File "/Users/edwin/anaconda3/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/edwin/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 478, in start
    self.io_loop.start()
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tornado/ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/ipykernel/zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2728, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2850, in run_ast_nodes
    if self.run_code(code, result):
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2910, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-8-ad99bb909555>", line 7, in <module>
    inference.initialize()
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/edward/inferences/klpq.py", line 95, in initialize
    return super(KLpq, self).initialize(*args, **kwargs)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/edward/inferences/variational_inference.py", line 68, in initialize
    self.loss, grads_and_vars = self.build_loss_and_gradients(var_list)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/edward/inferences/klpq.py", line 144, in build_loss_and_gradients
    qz_copy.log_prob(tf.stop_gradient(dict_swap[z])))
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/distributions/distribution.py", line 716, in log_prob
    return self._call_log_prob(value, name)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/distributions/distribution.py", line 698, in _call_log_prob
    return self._log_prob(value, **kwargs)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/distributions/categorical.py", line 307, in _log_prob
    logits=logits)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/nn_ops.py", line 2055, in sparse_softmax_cross_entropy_with_logits
    precise_logits, labels, name=name)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_nn_ops.py", line 4753, in _sparse_softmax_cross_entropy_with_logits
    labels=labels, name=name)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3271, in create_op
    op_def=op_def)
  File "/Users/edwin/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1650, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Received a label value of 2 which is outside the valid range of [0, 2).  Label values: 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
	 [[Node: inference_1/sample_1/Categorical_1/log_prob/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits = SparseSoftmaxCrossEntropyWithLogits[T=DT_FLOAT, Tlabels=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](inference_1/sample_1/Categorical_1/log_prob/SparseSoftmaxCrossEntropyWithLogits/Reshape, inference_1/sample_1/Categorical_1/log_prob/SparseSoftmaxCrossEntropyWithLogits/Reshape_1)]]


In [None]:
# print out the sense labels inferred for all the central word occurrences

Esenses = q_senses.eval()

print(Esenses)
print(s_all)

print(Esenses[0, 0])

print(s_all[0][0])

for w in range(V):
    for n in range(N):    
        print('word %i, sample %i, -- probability of senses = %s, true sense is %i' % (n, w, str(Esenses[n, w]), s_all[w][n]) )
            

### What makes our method novel?

A fully Bayesian approach to learning word embeddings with multiple, potentially infinite numbers of distinct senses per token.

The Bayesian treatment is intended to help with:
* Rare words in the training corpus, whose embeddings cannot be confidently estimated -- variance means we don't put too much weight onto these uncertain cases during learning
* Inferring the number of senses -- priors effectively regularise the model toward fewer senses
* Domain adaptation/Transfer learning -- we can inflate variances to indicate uncertainty in new domains
* (As in Barkan, Brazinskas et al) context-specific embeddings for each word instance
* (As in Barkan, Brazinskas et al) composition of sentence or document embeddings -- word occurrences with more confident or precise embeddings will have stronger influence in the combined sentence embedding. I think this will push the sentence embeddings away from generic words and toward the extremes.

### How do we test our model?

* Look at the tasks tried by Brazinskas et al, Barkan, and the ACL 2018 paper and ty to reuse their code where possible
* Compute the context-specific embeddings for each word (posterior means)
* Test what happens if we concatenate the variances to the embedding vector as a vagueness or uncertainty feature