In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data

In [12]:
class Word2VecModel(nn.Module):
    """Minimal word2vec network: an embedding table followed by a linear layer.

    With ``skip_gram=True`` every input index is embedded directly. Otherwise
    (CBOW) every bag of context indices is embedded and averaged into a single
    vector before the linear projection onto the vocabulary.

    Args:
        vocab_size: number of rows of the embedding table and width of the output.
        embedding_dim: size of each embedding vector.
        skip_gram: True selects the skip-gram input handling in ``forward``,
            False selects the CBOW (averaged bag) handling.
    """

    def __init__(self, vocab_size, embedding_dim, skip_gram=True):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.skip_gram = skip_gram

        self.embedding_layer = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.linear_layer = nn.Linear(self.embedding_dim, self.vocab_size)

    def forward(self, inputs):
        """Return a probability distribution over the vocabulary per sample.

        Args:
            inputs: skip-gram: integer index tensor, size = [batch].
                CBOW: integer index tensor, size = [batch, bag_size], or an
                iterable of 1-D index tensors (bags may differ in length).

        Returns:
            Tensor, size = [batch, vocab_size]. These are softmax
            probabilities (each row sums to 1), not raw logits.
        """
        if self.skip_gram:
            # inputs size = [batch] (value is index)
            embedding_vectors = self.embedding_layer(inputs)           # [batch, embedding_dim]
        elif torch.is_tensor(inputs):
            # inputs size = [batch, bag_size]; average over the bag axis in one shot.
            embedding_vectors = self.embedding_layer(inputs).mean(dim=1)  # [batch, embedding_dim]
        else:
            # Iterable of bags: average each bag to [embedding_dim], then stack.
            # torch.stack (not torch.cat) is required here: concatenating 1-D
            # vectors would flatten the batch into [batch * embedding_dim].
            embedding_vectors = torch.stack(
                [self.embedding_layer(bag).mean(dim=0) for bag in inputs], dim=0
            )                                                          # [batch, embedding_dim]
        vocab_vectors = self.linear_layer(embedding_vectors)           # [batch, vocab_size]
        # Normalise explicitly over the vocabulary axis; relying on the
        # implicit dim of F.softmax is deprecated.
        preds = F.softmax(vocab_vectors, dim=-1)
        return preds

    def inference(self, inputs):
        """Look up the embedding vectors for whole sequences.

        Only the embedding table is used; the linear layer and softmax are skipped.

        Args:
            inputs: integer index tensor, size = [batch, seq_len], or an
                iterable of equal-length 1-D index tensors.

        Returns:
            Tensor, size = [batch, seq_len, embedding_dim].
        """
        if not torch.is_tensor(inputs):
            inputs = torch.stack(list(inputs), dim=0)
        # nn.Embedding keeps the input shape and appends embedding_dim, so a
        # single lookup replaces the per-sequence loop.
        return self.embedding_layer(inputs)

In [19]:
# Smoke-test the skip-gram path: vocabulary of 4 words, 3-dimensional embeddings.
model = Word2VecModel(vocab_size=4, embedding_dim=3)
# Batch of 2 random word indices drawn from {0, 1, 2} (the upper bound is exclusive).
test = torch.randint(low=0, high=3, size=(2,))
preds = model(test)
# One row per input index, one column per vocabulary word.
print(preds)

tensor([[0.1961, 0.0472, 0.6094, 0.1472],
        [0.3274, 0.1281, 0.3497, 0.1948]], grad_fn=<SoftmaxBackward>)




In [3]:
# Draw a 2 x 3 x 3 tensor of standard-normal samples and show it.
a = torch.randn((2, 3, 3))
print(a)

tensor([[[-0.7863, -0.4312, -1.2767],
         [-0.6403,  0.7165,  1.9938],
         [-0.2419,  0.4264, -0.1182]],

        [[-0.5699, -0.1546,  0.5419],
         [ 0.5353,  0.1165,  0.5651],
         [ 0.6662,  0.3728, -0.5938]]])
