In [2]:
import torch 
import torch.nn as nn 
from glove import *
import torch.optim as optim
import torch.nn.functional as F
from sklearn.manifold import TSNE
import numpy as np
from collections import Counter, defaultdict
import matplotlib.pyplot as plt


In [3]:
class GloveModel(nn.Module):
    def __init__(self, num_embedding, embedding_dim):
        super(GloveModel, self).__init__()
        self.wi = nn.Embedding(num_embedding, embedding_dim)
        self.wj = nn.Embedding(num_embedding, embedding_dim)
        self.bi = nn.Embedding(num_embedding, 1)
        self.bj = nn.Embedding(num_embedding, 1)

    def forward(self, i_indices, j_indices):
        w_i = self.wi(i_indices)
        w_j = self.wj(j_indices)
        b_i = self.bi(i_indices).squeeze()
        b_j = self.bj(j_indices).squeeze()
        x = torch.sum(w_i * w_j) + b_i + b_j
        return x

def weight_func(x, x_max, alpha):
    wx = (x/x_max)**alpha
    wx = torch.min(wx, torch.ones_like(wx))
    return wx

#     return wx.cuda()

def wmse_loss(weights, inputs, targets):
    loss = weights * F.mse_loss(inputs, targets, reduction='none')
    return torch.mean(loss)
#     return torch.mean(loss).cuda()

In [6]:

class GloveDataset:
    def __init__(self, text, n_words=200000, window_size=5):
        self._window_size = window_size
        self._tokens = text.split(' ')[:n_words]
        word_counter = Counter()
        word_counter.update(self._tokens)
        self._word2id = {w:i for i, (w,_) in enumerate(word_counter.most_common())}
        self._id2word = {i:w for w, i in self._word2id.items()}
        self._vocab_len = len(self._word2id)
        self._id_tokens = [self._word2id[w] for w in self._tokens]
        self._create_coocurrence_matrix()

        print('# of words: {}'.format(len(self._tokens)))
        print('Vocabulary length: {}'.format(self._vocab_len))

    def _create_coocurrence_matrix(self):
        cooc_mat = defaultdict(Counter)

        for i,w in enumerate(self._id_tokens):
            start_i = max(i - self._window_size, 0)
            end_i = min(i + self._window_size +1, len(self._id_tokens))

            for j in range(start_i, end_i):
                if i != j:
                    c = self._id_tokens[j]
                    cooc_mat[w][c] += 1 / abs(j-i)

        self._i_idx = list()
        self._j_idx = list()
        self._xij = list()

        for w, cnt in cooc_mat.items():
            for c, v in cnt.items():
                self._i_idx.append(w)
                self._j_idx.append(c)
                self._xij.append(v)
        self._i_idx = torch.LongTensor(self._i_idx)
        self._j_idx = torch.LongTensor(self._j_idx)
        self._xij = torch.LongTensor(self._xij)

#         self._i_idx = torch.LongTensor(self._i_idx).cuda()
#         self._j_idx = torch.LongTensor(self._j_idx).cuda()
#         self._xij = torch.LongTensor(self._xij).cuda()


    def get_batches(self, batch_size):
        rand_ids= torch.LongTensor(np.random.choice(len(self._xij), len(self._xij), replace=False))
        for p in range(0, len(rand_ids), batch_size):
            batch_ids = rand_ids[p:p+batch_size]
            yield self._xij[batch_ids], self._i_idx[batch_ids], self._j_idx[batch_ids]





In [7]:
dataset = GloveDataset(open("text.txt").read(), 1000000)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
EMBED_DIM = 300
glove = GloveModel(dataset._vocab_len, EMBED_DIM)
glove = glove.to(device)

# of words: 1000000
Vocabulary length: 52755


In [12]:
optimizer = optim.Adagrad(glove.parameters(), lr=0.001)
N_EPOCHS = 1
BATCH_SIZE = 2048
X_MAX = 100
ALPHA = 0.75
n_batches = int(len(dataset._xij) / BATCH_SIZE)
loss_values = list()
for e in range(1, N_EPOCHS+1):
    batch_i = 0
    loss_values = []
    for x_ij, i_idx, j_idx in dataset.get_batches(BATCH_SIZE):
#         print(batch_i)
        x_ij = x_ij.to(device)
        i_idx = i_idx.to(device)
        j_idx = j_idx.to(device)
        
        batch_i += 1
        
        optimizer.zero_grad()
        
        outputs = glove(i_idx, j_idx)
        weights_x = weight_func(x_ij, X_MAX, ALPHA)
        #print(torch.log(x_ij.float()+1))
        loss = wmse_loss(weights_x, outputs, torch.log(x_ij.float()+1))
        
        loss.backward()
        
        optimizer.step()
        
        loss_values.append(loss.item())
        
        if batch_i % 500 == 0:
            print("Epoch: {}/{} \t Batch: {}/{} \t Loss: {}".format(e, N_EPOCHS, batch_i, n_batches, np.mean(loss_values[-20:])))  
    print('Epoch loss: {}'.format(np.mean(loss_values)))
    print("Saving model...")
    torch.save(glove.state_dict(), "text.pt")

Epoch: 1/1 	 Batch: 500/1724 	 Loss: 736.8017051875592
Epoch: 1/1 	 Batch: 1000/1724 	 Loss: 205.43506680727006
Epoch: 1/1 	 Batch: 1500/1724 	 Loss: 373.33878472447395
Epoch loss: 390.70504263611014
Saving model...


In [None]:
i_idx.size()