In [None]:
#!pip install d2l
import torch
from torch import nn
from d2l import torch as d2l

# Hyperparameters for the PTB word2vec data pipeline.
batch_size = 512
max_window_size = 5
num_noise_words = 5

# Build the minibatch iterator and vocabulary from the PTB corpus (d2l helper).
data_iter, vocab = d2l.load_data_ptb(
    batch_size, max_window_size, num_noise_words)

In [12]:
# A small embedding table: 20 tokens, each mapped to a 4-dimensional vector.
embed = nn.Embedding(num_embeddings=20, embedding_dim=4)
# Both halves of the implicitly concatenated string need the `f` prefix;
# without it on the second half, `{embed.weight.dtype}` is printed literally.
weight_summary = (f'Parameter embedding_weight ({embed.weight.shape}, '
                  f'dtype={embed.weight.dtype})')
print(weight_summary)

Parameter embedding_weight (torch.Size([20, 4]), dtype={embed.weight.dtype})


In [13]:
# Look up a (2, 3) batch of token indices; the result has shape
# (2, 3, embedding_dim).
x = torch.arange(1, 7).reshape(2, 3)
embed(x)

tensor([[[ 1.4743, -1.8477, -0.0442,  0.6067],
         [-2.0036, -0.1232,  0.5672, -1.0323],
         [ 0.5882, -0.6870,  1.3224, -1.3498]],

        [[ 0.3389,  0.8951, -1.6349, -1.6618],
         [ 0.5121, -0.9549,  0.4177,  1.2872],
         [-0.6950, -0.4782,  0.6976,  1.2563]]], grad_fn=<EmbeddingBackward>)

**Skip-gram forward pass (adapted to CBOW below)**

In [14]:
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    """Score each center word against its context and noise words.

    Args:
        center: index tensor of center words; presumably (batch, 1).
        contexts_and_negatives: index tensor (batch, m) of context words
            followed by noise words.
        embed_v: embedding layer producing center-word vectors.
        embed_u: embedding layer producing context/noise-word vectors.

    Returns:
        Batched dot products of shape (batch, rows_of_center, m).
    """
    center_vecs = embed_v(center)
    target_vecs = embed_u(contexts_and_negatives)
    # Swap the last two axes so bmm computes <center, target> for every pair.
    return torch.bmm(center_vecs, target_vecs.transpose(1, 2))

In [15]:
# Shape check: batch of 2, one center word each, 4 context/noise words.
center_idx = torch.ones((2, 1), dtype=torch.long)
ctx_neg_idx = torch.ones((2, 4), dtype=torch.long)
skip_gram(center_idx, ctx_neg_idx, embed, embed).shape

torch.Size([2, 1, 4])

**CBOW (continuous bag-of-words) forward pass**

In [16]:
def CBOW(contexts, center_negative, masks, embed_v, embed_u):
    """Score center/noise words against the mean of the context embeddings.

    Args:
        contexts: index tensor (batch, n_ctx) of context words, possibly
            padded along dim 1.
        center_negative: index tensor (batch, m) of the center word followed
            by noise words.
        masks: 1 for a real context word, 0 for a padded slot; expected to
            have the same shape as ``contexts``.
        embed_v: embedding layer for context (input) vectors.
        embed_u: embedding layer for center/noise (output) vectors.

    Returns:
        Tensor of shape (batch, 1, m) with the dot product between the mean
        context vector and each center/noise vector.
    """
    ctx_embed = embed_v(contexts)                            # (batch, n_ctx, d)
    # Padded slots still have embeddings, so zero them out before summing;
    # otherwise the numerator includes padding while the denominator does not.
    weights = masks.unsqueeze(-1).to(ctx_embed.dtype)
    # clamp avoids 0/0 -> NaN for a row with no valid context word.
    counts = masks.sum(dim=-1, keepdim=True).clamp(min=1).to(ctx_embed.dtype)
    v = (ctx_embed * weights).sum(dim=1) / counts            # (batch, d)
    u = embed_u(center_negative)                             # (batch, m, d)
    pred = torch.bmm(v.unsqueeze(1), u.permute(0, 2, 1))     # (batch, 1, m)
    return pred

In [17]:
# Shape check: batch of 2, one context word each, center + 5 noise words.
# The mask marks valid *context* words, so it must match the shape of
# `ctx_idx` (2, 1); a (2, 6) mask would divide the single context vector by 6.
ctx_idx = torch.ones((2, 1), dtype=torch.long)
center_neg_idx = torch.ones((2, 6), dtype=torch.long)
ctx_mask = torch.ones((2, 1), dtype=torch.long)
CBOW(ctx_idx, center_neg_idx, ctx_mask, embed, embed).shape

torch.Size([2, 1, 6])