In [1]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Toy training corpus: seven short, lowercase sentences whose words are
# separated by single spaces (so a plain str.split() tokenises them).
corpus = [
    'he is a king',
    'she is a queen',
    'he is a man',
    'she is a woman',
    'warsaw is poland capital',
    'berlin is germany capital',
    'paris is france capital',
]

In [3]:
def tokenize_corpus(corpus):
    """Split every sentence of ``corpus`` into whitespace-separated tokens.

    Args:
        corpus: Iterable of sentence strings.

    Returns:
        A list with one entry per sentence, each entry being the list of
        tokens produced by ``str.split()`` (runs of whitespace are treated
        as a single separator; an empty sentence yields an empty list).
    """
    tokenized = []
    for sentence in corpus:
        tokenized.append(sentence.split())
    return tokenized

In [4]:
# Tokenise the corpus: one list of word strings per sentence.
sentences = tokenize_corpus(corpus=corpus)

In [5]:
# Sanity check: show the tokenised sentences.
print(sentences)

[['he', 'is', 'a', 'king'], ['she', 'is', 'a', 'queen'], ['he', 'is', 'a', 'man'], ['she', 'is', 'a', 'woman'], ['warsaw', 'is', 'poland', 'capital'], ['berlin', 'is', 'germany', 'capital'], ['paris', 'is', 'france', 'capital']]


In [6]:
# Vocabulary list, starts empty; a word's position in this list is used as its integer id.
voc = []

In [7]:
# Walk every token in corpus order and keep only the first occurrence of each
# word, so vocabulary positions follow first-appearance order. Because of the
# membership check, re-running this cell does not add duplicates.
for word in (token for sentence in sentences for token in sentence):
    if word in voc:
        continue
    voc.append(word)

In [8]:
# Sanity check: show the vocabulary in first-appearance order.
print(voc)

['he', 'is', 'a', 'king', 'she', 'queen', 'man', 'woman', 'warsaw', 'poland', 'capital', 'berlin', 'germany', 'paris', 'france']


In [9]:
# Lookup tables between each word and its position in `voc`:
# word -> index, and index -> word.
word2idx = dict(zip(voc, range(len(voc))))
inx2word = dict(enumerate(voc))

In [10]:
# Number of consecutive tokens in each sliding window used to form training pairs.
window_size = 3
# Accumulates [input_word_index, target_word_index] training pairs.
idx_pair = []

In [11]:
# Slide a window of `window_size` consecutive tokens over every sentence and
# record each ordered pair of *different* word indices found inside a window.
# Overlapping windows can contribute the same pair more than once, and
# sentences shorter than the window contribute nothing.
for sentence in sentences:
    indices = [word2idx[word] for word in sentence]
    last_start = len(indices) - window_size
    for start in range(last_start + 1):
        window = indices[start : start + window_size]
        idx_pair.extend(
            [x_word, y_word]
            for x_word in window
            for y_word in window
            if x_word != y_word
        )

In [12]:
# Sanity check: show every [input_index, target_index] pair that was generated.
print(idx_pair)

[[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1], [1, 2], [1, 3], [2, 1], [2, 3], [3, 1], [3, 2], [4, 1], [4, 2], [1, 4], [1, 2], [2, 4], [2, 1], [1, 2], [1, 5], [2, 1], [2, 5], [5, 1], [5, 2], [0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1], [1, 2], [1, 6], [2, 1], [2, 6], [6, 1], [6, 2], [4, 1], [4, 2], [1, 4], [1, 2], [2, 4], [2, 1], [1, 2], [1, 7], [2, 1], [2, 7], [7, 1], [7, 2], [8, 1], [8, 9], [1, 8], [1, 9], [9, 8], [9, 1], [1, 9], [1, 10], [9, 1], [9, 10], [10, 1], [10, 9], [11, 1], [11, 12], [1, 11], [1, 12], [12, 11], [12, 1], [1, 12], [1, 10], [12, 1], [12, 10], [10, 1], [10, 12], [13, 1], [13, 14], [1, 13], [1, 14], [14, 13], [14, 1], [1, 14], [1, 10], [14, 1], [14, 10], [10, 1], [10, 14]]


In [13]:
# Square len(voc) x len(voc) all-zero matrix with autograd disabled; intended to
# hold one row per vocabulary word (the diagonal is filled in by a later cell).
one_hot_vector = torch.zeros(size=(len(voc), len(voc)),requires_grad= False)

In [14]:
# Write a 1 on the diagonal, in place, so that row i becomes the one-hot
# encoding of the word with index i.
for idx in range(len(one_hot_vector)):
    one_hot_vector[idx, idx] = 1.

In [15]:
class Word2Vec(torch.nn.Module):
    """Two-layer linear network mapping a one-hot word vector to vocabulary scores.

    The encoder projects a ``voc_size``-wide input down to ``embedding_dim``
    values and the decoder projects that back up to one raw score (logit) per
    vocabulary word. No non-linearity is applied between the two layers.

    Args:
        voc_size: Number of words in the vocabulary (input and output width).
        embedding_dim: Width of the intermediate embedding.
    """

    def __init__(self, voc_size, embedding_dim) -> None:
        super().__init__()
        # Encoder: weight is (embedding_dim, voc_size), bias is (embedding_dim,).
        # torch.nn.Parameter enables gradients by default, so randn needs no
        # explicit requires_grad flag. Creation order (w1, b1, w2, b2) is kept
        # so a seeded initialisation draws the same values as before.
        self.w1 = torch.nn.Parameter(torch.randn(embedding_dim, voc_size))
        self.b1 = torch.nn.Parameter(torch.randn(embedding_dim))

        # Decoder: weight is (voc_size, embedding_dim), bias is (voc_size,).
        self.w2 = torch.nn.Parameter(torch.randn(voc_size, embedding_dim))
        self.b2 = torch.nn.Parameter(torch.randn(voc_size))

    def forward(self, x):
        """Compute vocabulary logits for one input vector or a batch of them.

        Args:
            x: Float tensor of shape ``(voc_size,)`` or ``(..., voc_size)``.
                The last dimension is the vocabulary axis; any leading
                dimensions are treated as batch dimensions.

        Returns:
            Tensor of raw scores with the same leading dimensions as ``x``
            and a last dimension of size ``voc_size``.
        """
        # functional.linear computes x @ W.T + b over the last dimension, which
        # equals W @ x + b for a single vector and also handles batches.
        hidden = torch.nn.functional.linear(x, self.w1, self.b1)
        return torch.nn.functional.linear(hidden, self.w2, self.b2)

In [16]:
# Instantiate the model over the full vocabulary with 5-dimensional embeddings.
word2vec = Word2Vec(voc_size=len(voc), embedding_dim=5)

In [17]:
# Cross-entropy over the model's raw output scores.
loss_fn = torch.nn.CrossEntropyLoss()
# Adam over all four model parameters with learning rate 1e-3.
optimizer = torch.optim.Adam(word2vec.parameters(), lr=1e-3)

In [18]:
# Train for 1000 epochs, taking one optimiser step per (input, target) pair
# (i.e. no mini-batching). Two lines are printed per epoch.
for epch in range(1000):
    # Running count of correct top-1 predictions in this epoch; starts as a
    # Python float but becomes a 0-dim tensor after the first `+=` below.
    correct = 0.
    for x,y in idx_pair:
        # Clear gradients left over from the previous pair.
        optimizer.zero_grad()
        # Replace the integer word indices with their rows of one_hot_vector.
        y = one_hot_vector[y]
        x = one_hot_vector[x]
        pred = word2vec(x)
        # The target is a float vector rather than a class index, so the loss
        # presumably treats it as class probabilities -- NOTE(review): this
        # needs a PyTorch version that accepts probability targets; confirm.
        loss = loss_fn(pred, y)
        loss.backward()
        # Score the prediction made *before* this pair's parameter update.
        correct += (pred.argmax(0) == y.argmax(0)).float()
        optimizer.step()
    print(f'Acc = {correct/len(idx_pair)}')
    # NOTE: `loss` here is the loss of the last pair in the epoch only, not an
    # average over the epoch.
    print(f'Current loss {loss.item()}')
        

Acc = 0.0357142873108387
Current loss 4.223329067230225
Acc = 0.0595238097012043
Current loss 3.671424627304077
Acc = 0.0595238097012043
Current loss 3.244203567504883
Acc = 0.0833333358168602
Current loss 2.923527717590332
Acc = 0.0833333358168602
Current loss 2.6887552738189697
Acc = 0.130952388048172
Current loss 2.5193915367126465
Acc = 0.095238097012043
Current loss 2.398077964782715
Acc = 0.0714285746216774
Current loss 2.311722755432129
Acc = 0.0714285746216774
Current loss 2.2509994506835938
Acc = 0.1071428582072258
Current loss 2.2092771530151367
Acc = 0.1071428582072258
Current loss 2.181671142578125
Acc = 0.1071428582072258
Current loss 2.164421319961548
Acc = 0.1190476194024086
Current loss 2.1545562744140625
Acc = 0.25
Current loss 2.149714469909668
Acc = 0.3095238208770752
Current loss 2.148059368133545
Acc = 0.3452380895614624
Current loss 2.1482081413269043
Acc = 0.3571428656578064
Current loss 2.149169445037842
Acc = 0.3571428656578064
Current loss 2.150270462036133
Ac