In [11]:
import torch 
import torch.nn as nn 
from glove import *
import torch.optim as optim
import torch.nn.functional as F
from sklearn.manifold import TSNE
import numpy as np

In [12]:
# Read the corpus inside a context manager so the file handle is closed as
# soon as the text is in memory (the bare open(...).read() left it open until
# garbage collection).
with open("text8") as corpus_file:
    # NOTE(review): judging by the printed "# of words" line, the second
    # argument looks like the number of words taken from the corpus --
    # confirm against GloveDataset in glove.py.
    dataset = GloveDataset(corpus_file.read(), 1000000)

# Train on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Second constructor argument of GloveModel; presumably the embedding width
# (verify against glove.py).
EMBED_DIM = 300
glove = GloveModel(dataset._vocab_len, EMBED_DIM)
glove = glove.to(device)

# of words: 1000000
Vocabulary length: 52755


In [None]:
# Training configuration.
optimizer = optim.Adagrad(glove.parameters(), lr=0.001)
N_EPOCHS = 30
BATCH_SIZE = 2048
# X_MAX and ALPHA are forwarded to weight_func below; presumably the cutoff
# and exponent of the GloVe weighting function (verify in glove.py).
X_MAX = 100
ALPHA = 0.75
# Number of *full* batches, used only for the progress message. Integer floor
# division avoids the float round-trip of int(a / b).
# NOTE(review): if get_batches also yields a trailing partial batch, the
# displayed total is one short -- confirm against glove.py.
n_batches = len(dataset._xij) // BATCH_SIZE

for e in range(1, N_EPOCHS + 1):
    # Per-epoch loss history: feeds the running average in the progress line
    # and the epoch summary printed after the inner loop.
    loss_values = []
    # enumerate(start=1) yields the same 1-based batch number the manual
    # counter used to maintain.
    for batch_i, (x_ij, i_idx, j_idx) in enumerate(dataset.get_batches(BATCH_SIZE), start=1):
        # Move the batch onto the same device as the model.
        x_ij = x_ij.to(device)
        i_idx = i_idx.to(device)
        j_idx = j_idx.to(device)

        optimizer.zero_grad()

        outputs = glove(i_idx, j_idx)
        weights_x = weight_func(x_ij, X_MAX, ALPHA)
        # Regression target is log(x_ij + 1).
        loss = wmse_loss(weights_x, outputs, torch.log(x_ij.float() + 1))

        loss.backward()
        optimizer.step()

        loss_values.append(loss.item())

        # Every 500 batches, report the mean loss over the last 20 batches.
        if batch_i % 500 == 0:
            print("Epoch: {}/{} \t Batch: {}/{} \t Loss: {}".format(e, N_EPOCHS, batch_i, n_batches, np.mean(loss_values[-20:])))
    print('Epoch loss: {}'.format(np.mean(loss_values)))
    # Checkpoint after every epoch; each save overwrites the previous file.
    print("Saving model...")
    torch.save(glove.state_dict(), "text8.pt")

Epoch: 1/30 	 Batch: 500/1724 	 Loss: 742.6070358276368
Epoch: 1/30 	 Batch: 1000/1724 	 Loss: 757.286794090271
Epoch: 1/30 	 Batch: 1500/1724 	 Loss: 336.4401229262352
Epoch loss: 661.716105884636
Saving model...
Epoch: 2/30 	 Batch: 500/1724 	 Loss: 180.10235762000084
Epoch: 2/30 	 Batch: 1000/1724 	 Loss: 490.7532871246338
Epoch: 2/30 	 Batch: 1500/1724 	 Loss: 503.46504499949515
Epoch loss: 654.019794600087
Saving model...
Epoch: 3/30 	 Batch: 500/1724 	 Loss: 873.729664605856
Epoch: 3/30 	 Batch: 1000/1724 	 Loss: 536.1990437984466
Epoch: 3/30 	 Batch: 1500/1724 	 Loss: 436.698803416267
Epoch loss: 585.3253359740819
Saving model...
Epoch: 4/30 	 Batch: 500/1724 	 Loss: 1052.7831892296672
Epoch: 4/30 	 Batch: 1000/1724 	 Loss: 868.1717390924692
Epoch: 4/30 	 Batch: 1500/1724 	 Loss: 752.0615254402161
Epoch loss: 640.8059787456027
Saving model...
Epoch: 5/30 	 Batch: 500/1724 	 Loss: 503.694340762496
Epoch: 5/30 	 Batch: 1000/1724 	 Loss: 749.0832456469536
Epoch: 5/30 	 Batch: 1500/

In [None]:
# Scratch check: dimensions of the last index batch left over from the
# training loop (only defined after that cell has run).
i_idx.shape