In [85]:
import pyro 
import torch 
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt

In [86]:
import pyro.distributions as dist
from torch.nn import Embedding
import ast
from tqdm import tqdm
from collections import Counter 

In [87]:
dat = pd.read_csv('./data/cleaned_documents.csv')

In [88]:
E = 100 # size of embedding
sent = []
for each in dat.text:
    sent = sent + ast.literal_eval(each)
W = len(np.unique(sent))

In [89]:
print(W)

15673


In [90]:
words= np.unique(sent)
word2idx = {word:idx for idx, word in enumerate(words)}
idx2word = {idx:word for idx, word in enumerate(words)}

In [91]:
uni_freq = Counter(words)

In [92]:
for each in uni_freq:
    uni_freq[each] = uni_freq[each]**0.75
deno = sum(uni_freq.values())
for each in uni_freq:
    uni_freq[each] = uni_freq[each]/deno


### For each document in the corpus, we get x_n^+ and x_n^- by looking at context window and negative sampling. 
Ref : https://arxiv.org/pdf/1711.03946.pdf

In [142]:
# Hard coding context window size to 3.
c = 2
neg_samples = 2*c
def get_pos_neg_sample(dat):
    x_plus = []
    x_neg = []
    uni_keys = list(uni_freq.keys())
    uni_values = list(uni_freq.values())
    for each in dat.text:
        x_plus_n = []
        x_neg_n = []
        sentence = ast.literal_eval(each)
        random_sample = np.random.choice(uni_keys, len(sentence)*neg_samples, p = uni_values)
        r_i = 0
        for idx, token in enumerate(sentence):
            context = sentence[max(idx-c,0):min(idx+c+1, len(sentence))]
            context.remove(token)
            for every in context:
                x_plus_n.append((word2idx[token], word2idx[every]))
                
            for sample in random_sample[r_i:r_i+len(context)]:
                x_neg_n.append((word2idx[token], word2idx[sample]))
            r_i+=len(context)
            
        x_plus.append(x_plus_n)
        x_neg.append(x_neg_n)
    return x_plus, x_neg

In [143]:
x_plus, x_neg = get_pos_neg_sample(dat)

# Stage 1 : Update U,V using sgd


In [None]:
lamb = 0.01
phi = 0.01
annealed_rate = 0.0001
alpha = 0.999
n_iter = 100
loss = []
U = torch.distributions.MultivariateNormal(torch.zeros(W,E), lamb**2 *torch.eye(E)).rsample()
V = torch.distributions.MultivariateNormal(torch.zeros(W,E), lamb**2 *torch.eye(E)).rsample()
U.requires_grad=True
V.requires_grad=True
for n in range(len(x_plus)):
    annealed_rate = annealed_rate * alpha
    d_i = torch.distributions.MultivariateNormal( torch.zeros(E), phi*torch.eye(E)).rsample()
    d_i.requires_grad=True

    X_plus = torch.tensor(x_plus[n])
    X_neg = torch.tensor(x_neg[n])
    #SGD for d_i
    for iterat in range(n_iter):
        # Loss function
        logProb = -torch.sum(torch.log(1+torch.exp(-torch.bmm(U[X_plus[:,0]].view(U[X_plus[:,0]].shape[0], 
                                                                                  1,
                                                                                  E) ,
                                                              (V[X_plus[:,1]]+d_i).view((V[X_plus[:,1]]+d_i).shape[0],
                                                                                        E,
                                                                                        1)))))
        logProb -= torch.sum(torch.log(1+torch.exp(-torch.bmm(-U[X_neg[:,0]].view(U[X_neg[:,0]].shape[0], 
                                                                                  1,
                                                                                  E) ,
                                                              (V[X_neg[:,1]]+d_i).view((V[X_neg[:,1]]+d_i).shape[0],
                                                                                        E,
                                                                                        1)))))


        d_i.retain_grad()
        logProb.backward()
        d_i = d_i-0.0001*d_i.grad
        
    # SGD for U_i, V_j
    logProb = -torch.sum(torch.log(1+torch.exp(-torch.bmm(U[X_plus[:,0]].view(U[X_plus[:,0]].shape[0], 
                                                                              1,
                                                                              E) ,
                                                          (V[X_plus[:,1]]+d_i).view((V[X_plus[:,1]]+d_i).shape[0],
                                                                                    E,
                                                                                    1)))))
    logProb -= torch.sum(torch.log(1+torch.exp(-torch.bmm(-U[X_neg[:,0]].view(U[X_neg[:,0]].shape[0], 
                                                                              1,
                                                                              E) ,
                                                          (V[X_neg[:,1]]+d_i).view((V[X_neg[:,1]]+d_i).shape[0],
                                                                                    E,
                                                                                    1)))))

    loss.append(logProb)
    U.retain_grad()
    V.retain_grad()
    logProb.backward()
    if n%10==0:
        print("Iteration : ", n+1)
        print(logProb, torch.mean(U), torch.mean(annealed_rate * U.grad))
    U = U - annealed_rate * U.grad
    V = V - annealed_rate * V.grad

Iteration :  1
tensor(-1111.8203, grad_fn=<SubBackward0>) tensor(-1.7039e-06, grad_fn=<MeanBackward0>) tensor(-2.2608e-08)
Iteration :  11
tensor(-474.1290, grad_fn=<SubBackward0>) tensor(-1.6813e-06, grad_fn=<MeanBackward0>) tensor(-1.7723e-10)
Iteration :  21
tensor(-1716.2916, grad_fn=<SubBackward0>) tensor(-1.6793e-06, grad_fn=<MeanBackward0>) tensor(1.9188e-10)
Iteration :  31
tensor(-867.8243, grad_fn=<SubBackward0>) tensor(-1.6800e-06, grad_fn=<MeanBackward0>) tensor(-1.1165e-10)
Iteration :  41
tensor(-463.0487, grad_fn=<SubBackward0>) tensor(-1.6802e-06, grad_fn=<MeanBackward0>) tensor(-1.8893e-10)
Iteration :  51
tensor(-629.3931, grad_fn=<SubBackward0>) tensor(-1.6795e-06, grad_fn=<MeanBackward0>) tensor(7.1556e-11)
Iteration :  61
tensor(-707.0316, grad_fn=<SubBackward0>) tensor(-1.6789e-06, grad_fn=<MeanBackward0>) tensor(5.9516e-11)
Iteration :  71
tensor(-529.5614, grad_fn=<SubBackward0>) tensor(-1.6779e-06, grad_fn=<MeanBackward0>) tensor(-9.3129e-11)
Iteration :  81
te

In [373]:
torch.sum(torch.log(1+torch.exp(-torch.bmm(-U[X_neg[:,0]].view(U[X_neg[:,0]].shape[0], 
                                                                              1,
                                                                              E) ,
                                                          (V[X_neg[:,1]]+d_i).view((V[X_neg[:,1]]+d_i).shape[0],
                                                                                    E,
                                                                                    1)))))


tensor(345.2450, grad_fn=<SumBackward0>)

In [374]:
torch.trace(torch.log(1+torch.exp(-torch.mm(-U[X_neg[:,0]], (V[X_neg[:,1]]+d_i).t()))))

tensor(345.2450, grad_fn=<TraceBackward>)

In [None]:
def model(X_plus,X_neg):
    with pyro.plate('component', W):
        lamb = 5
        U = pyro.sample('U', dist.MultivariateNormal(loc = torch.zeros(E), scale = lamb*torch.eye(E)), requires_gradient=True)
        V = pyro.sample('V', dist.MultivariateNormal(loc = torch.zeros(E), scale = lamb*torch.eye(E)), requires_gradient=True)
    phi = 10
    d_i = pyro.sample('d_i', dist.MultivariateNormal(loc = torch.zeros(E), scale = phi*torch.eye(E)),requires_gradient=True)
                                          
    dfx, = grad(logProb,d_i,create_graph=True)
    for i in range(100):
        # Loss function
        logProb = torch.log(torch.sigmoid(torch.dot(U[X_plus[0][0]],V[X_plus[0][1]]+d_i))
        for iteration in range(1,len(X_plus)):
            i,j = X_plus[iteration]
            logProb += torch.log(torch.sigmoid(torch.dot(U[i],V[j]+d_i))
        for i,j in X_neg:
            logProb += torch.log(torch.sigmoid(torch.dot(-U[i],V[j]+d_i))
        logProb.backward()
        with torch.no_grad():
            d_i = d_i-0.001*d_i.grad 
    pred 
    
    with pyro.plate('data', len(data)):
        pass


In [None]:
# logProb = torch.Value(0,requires_gradient=True)


# Stage 2 : Fix U,V and update d (doc vector) using SVI 