In [47]:
!pip install d2l
import torch
from torch import nn
from d2l import torch as d2l

batch_size, max_window_size, num_noise_words = 512, 5, 5
data_iter, vocab = d2l.load_data_ptb(batch_size, max_window_size,
                                     num_noise_words)



In [48]:
embed = nn.Embedding(num_embeddings=20, embedding_dim=4)
print(f'Parameter embedding_weight ({embed.weight.shape}, '
      'dtype={embed.weight.dtype})')

Parameter embedding_weight (torch.Size([20, 4]), dtype={embed.weight.dtype})


In [49]:
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
embed(x)

tensor([[[ 0.2080, -0.6522,  0.6920,  0.3447],
         [ 0.6020, -0.2797, -0.1708,  0.2582],
         [-0.6875,  0.5881,  0.1166, -1.5285]],

        [[-1.1217, -0.0549, -1.3185, -0.4098],
         [-1.3660, -0.6081, -1.8877, -0.0732],
         [-0.4024,  1.7526, -0.2805, -1.9552]]], grad_fn=<EmbeddingBackward>)

In [50]:
x.shape

torch.Size([2, 3])

**Implemented skip-gram to CBOW**

In [51]:
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    v = embed_v(center)
    u = embed_u(contexts_and_negatives)
    pred = torch.bmm(v, u.permute(0, 2, 1))
    return pred

In [52]:
skip_gram(torch.ones((2, 1), dtype=torch.long),
          torch.ones((2, 4), dtype=torch.long), embed, embed).shape

torch.Size([2, 1, 4])

**CBOW**

In [53]:
def CBOW(contexts, center_negative, masks, embed_v, embed_u):
    v = embed_v(contexts).sum(dim=1) /  masks.sum(dim=-1).view(-1, 1)
    u = embed_u(center_negative)
    pred = torch.bmm(v[:, None, :], u.permute(0, 2, 1))
    return pred

In [54]:
CBOW(torch.ones((2, 1), dtype=torch.long),torch.ones((2, 6), dtype=torch.long), torch.ones((2, 6), dtype=torch.long), embed, embed).shape

torch.Size([2, 1, 6])