In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from auxiliary_file import *

### Dataset construction dataset

In [2]:
training_data, index_to_word = training_data_cnstructing([pre_processing('article1.txt')])

In [3]:
class MyDataset(Dataset):

    def __init__(self, data, index_to_word):
        self.data = data
        self.word_to_index = index_to_word

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        context, target_word = self.data[idx]
        context = torch.tensor(context, dtype=torch.long)
        target_word = torch.tensor(target_word, dtype=torch.float)
        return context, target_word

In [4]:
data_set = MyDataset(training_data, index_to_word)

In [5]:
data_set[0]

(tensor([ 710, 1080,    6, 1186,  876]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]))

### Dataloader construction

In [6]:
batch_size = 32
dataloader = torch.utils.data.DataLoader(data_set, batch_size=batch_size, shuffle=True, num_workers=0)

In [7]:
for context, target in dataloader:
    print(context.shape)
    print(target.shape)
    break

torch.Size([32, 5])
torch.Size([32, 1430])


### Architecure

In [8]:
class Word2vec(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(Word2vec, self).__init__()
        
        self.embed = nn.Embedding(vocab_size, embed_size)      
        self.linear = nn.Linear(embed_size, vocab_size)
        
    def forward(self, context_words):

        embeds = self.embed(context_words)
        
        embeds = torch.mean(embeds, dim=1)
        
        out = self.linear(embeds)
        
        log_probs = F.log_softmax(out, dim=1)
        
        return log_probs

### Training model

In [9]:
vocab_size = len(data_set[0][1])
embed_size = 100
num_epochs = 200

model = Word2vec(vocab_size, embed_size)

optimizer = optim.SGD(model.parameters(), lr=0.9)
loss_function = nn.NLLLoss()

for epoch in range(num_epochs):
    total_loss = 0
    for context, target_word in dataloader:

        log_probs = model(context)
        
        loss = loss_function(log_probs, torch.argmax(target_word, dim=1))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}, Loss: {total_loss}")


Epoch 1, Loss: 1127.374116897583
Epoch 2, Loss: 966.134340763092
Epoch 3, Loss: 868.01416015625
Epoch 4, Loss: 792.3959057331085
Epoch 5, Loss: 727.987423658371
Epoch 6, Loss: 671.593722820282
Epoch 7, Loss: 620.4016768932343
Epoch 8, Loss: 573.8380517959595
Epoch 9, Loss: 530.8325037956238
Epoch 10, Loss: 491.59458470344543
Epoch 11, Loss: 455.8370735645294
Epoch 12, Loss: 422.64382553100586
Epoch 13, Loss: 392.467813372612
Epoch 14, Loss: 365.31044363975525
Epoch 15, Loss: 340.1634204387665
Epoch 16, Loss: 317.73384296894073
Epoch 17, Loss: 297.3810614347458
Epoch 18, Loss: 278.67775189876556
Epoch 19, Loss: 261.91796839237213
Epoch 20, Loss: 246.4845970273018
Epoch 21, Loss: 232.20951998233795
Epoch 22, Loss: 218.97701162099838
Epoch 23, Loss: 207.40553033351898
Epoch 24, Loss: 195.84162068367004
Epoch 25, Loss: 185.61240375041962
Epoch 26, Loss: 175.93339723348618
Epoch 27, Loss: 166.8773576617241
Epoch 28, Loss: 158.21355086565018
Epoch 29, Loss: 150.66912174224854
Epoch 30, Loss:

### Collection of word embeddings

In [10]:
embeddings = model.embed.weight.data

word_vectors = {}

for index, word in data_set.word_to_index.items():

    word_vectors[word] = embeddings[index].numpy()

vector = word_vectors['whole']

### Using embeddings

In [11]:
document_vector(word_vectors, pre_processing('article1.txt'))

array([-0.08504745, -0.06208043, -0.13924137,  0.06137526,  0.03140182,
       -0.04903676, -0.06061382, -0.09932429,  0.05973937,  0.12183657,
       -0.03803857,  0.10438314, -0.0643212 , -0.04555834,  0.0886331 ,
       -0.02583544,  0.0176372 ,  0.06187674,  0.1003326 , -0.18175004,
        0.02577518,  0.01620501,  0.07904318,  0.02467192, -0.05901472,
        0.02251715,  0.05450393, -0.01737169, -0.15636475,  0.05712589,
        0.01381911,  0.04989778, -0.09927873,  0.21000224,  0.03257312,
       -0.13371833, -0.05527603,  0.00366283, -0.09176171, -0.02369864,
        0.05874272,  0.01988603, -0.08102527, -0.07479553, -0.02615019,
       -0.0096089 , -0.03722088,  0.00905608, -0.00557863, -0.06213749,
       -0.03447596,  0.1705859 , -0.05792461, -0.08617669, -0.0273452 ,
       -0.057104  , -0.05722999,  0.09842712, -0.1318504 , -0.05002528,
       -0.1026495 ,  0.07031095,  0.04980109,  0.03997704,  0.02542106,
       -0.00435334, -0.02236769, -0.04254798, -0.0405719 ,  0.00