In [1]:
import torch
from torch import nn

In [2]:
# Notebook configuration.
# `%matplotlib inline` renders matplotlib figures below the producing cell;
# no visible cell in this notebook plots anything, so it is optional here.
%matplotlib inline
# Disable Jedi-based tab completion in IPython (commonly done as a workaround
# for slow or broken autocompletion in some IPython versions).
%config Completer.use_jedi = False
# NOTE(review): no random seed is set before the first embedding cells, so the
# randomly initialised weights printed below differ on every run; consider
# adding `torch.manual_seed(...)` here for reproducible outputs.

In [25]:
# A two-item, rank-1 tensor of integer indices (shape [2], dtype int64).
# `torch.tensor(..., dtype=torch.long)` replaces the legacy `torch.LongTensor(...)`
# type constructor, which PyTorch no longer recommends for creating tensors from data.
x = torch.tensor([1, 2], dtype=torch.long)

- Above, `x` is a rank-1 tensor with two items (shape `[2]`, integer dtype `int64`).

In [26]:
# Lookup table with a single row; each row is a 4-dimensional vector.
embedding = nn.Embedding(num_embeddings=1, embedding_dim=4)

- Here, you have an embedding table with 1 entry (row), where each entry is a vector of size 4.
    - You might think of this as a vocabulary containing a single word, where that word is represented by one learned 4-dimensional vector (not mapped to four other words).

In [38]:
# Look up row 0 of the one-row table twice: the result stacks one 4-vector per index.
# `torch.tensor(..., dtype=torch.long)` replaces the legacy `torch.LongTensor(...)` constructor.
embedding(torch.tensor([0, 0], dtype=torch.long))

tensor([[ 0.1799, -0.1115, -1.1305,  1.4224],
        [ 0.1799, -0.1115, -1.1305,  1.4224]], grad_fn=<EmbeddingBackward>)

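So, the entire embedding is a 1x4 weight matrix whose only row is `[ 0.1799, -0.1115, -1.1305,  1.4224]`. These values are randomly initialized, so they will differ each time the cell is run. When you pass in a two-item, rank-1 tensor of 0s, it returns the row at index 0 of the embedding twice, giving a 2x4 result.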

In [39]:
# Lookup table with three rows; each row is a 5-dimensional vector.
embedding = nn.Embedding(num_embeddings=3, embedding_dim=5)

In [40]:
# Fetch rows 0, 1, 2 and then row 0 again; the result has one 5-vector per index.
# `torch.tensor(..., dtype=torch.long)` replaces the legacy `torch.LongTensor(...)` constructor.
embedding(torch.tensor([0, 1, 2, 0], dtype=torch.long))

tensor([[-6.9557e-01, -6.8639e-01,  1.2521e-01,  1.0239e-04, -1.2211e+00],
        [-7.5383e-01, -4.4092e-01, -1.2775e+00,  9.2140e-01,  9.1968e-01],
        [-1.0855e-01,  7.3577e-01, -2.1150e-01,  4.4531e-01, -2.2939e+00],
        [-6.9557e-01, -6.8639e-01,  1.2521e-01,  1.0239e-04, -1.2211e+00]],
       grad_fn=<EmbeddingBackward>)

If you pass a tensor of indices like this as the lookup list, you get back a tensor containing the rows at indices 0, 1, and 2, and then row 0 again, from the randomly initialized embedding table.

In [41]:
# A 200-row table of length-1 vectors. Valid indices are 0 .. num_embeddings - 1,
# so the last row that can be looked up is num_embeddings - 1 (= 199).
num_embeddings = 200
embedding = nn.Embedding(num_embeddings, 1)
# `torch.tensor(..., dtype=torch.long)` replaces the legacy `torch.LongTensor(...)` constructor.
embedding(torch.tensor([num_embeddings - 1], dtype=torch.long))

tensor([[0.4102]], grad_fn=<EmbeddingBackward>)

You can look up the item at index 199 (a length-1 vector, returned with shape `[1, 1]`) in the 200-row embedding table, as above. If you try to look up index 200, the lookup fails because there is no such row: valid indices run from 0 to 199. For index 200 to exist, you would have had to pass `201` (or more) as the first argument to `nn.Embedding`.

In [43]:
# Index 200 is one past the last valid row of the 200-row table (valid: 0..199),
# so the lookup raises IndexError. Catch and show it so that
# "Restart & Run All" does not stop at this demonstration cell.
try:
    embedding(torch.tensor([200], dtype=torch.long))
except IndexError as err:
    print(f"IndexError: {err}")

IndexError: index out of range in self

In [3]:
# A rank-2 tensor of indices: 2 rows of 5 indices each, shape (2, 5), dtype int64.
# `torch.tensor(..., dtype=torch.long)` replaces the legacy `torch.LongTensor(...)` constructor.
x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], dtype=torch.long)

In [6]:
# 11 rows cover every index 0..10 appearing in `x`; each row is a 5-vector.
embedding = nn.Embedding(num_embeddings=11, embedding_dim=5)
# Every index in `x` is replaced by its row, so the output shape is x.shape + (5,).
embedding(x)

tensor([[[-1.1566, -1.0602,  1.6966, -0.4835,  0.1418],
         [-0.9794, -1.3156, -0.3406,  1.7615, -0.7826],
         [ 1.0896, -0.8429, -0.1067,  0.2787, -0.8359],
         [ 1.8928, -0.4486, -0.7178, -0.4341, -0.0655],
         [ 1.9307,  0.7189, -0.6727, -1.2131, -1.0864]],

        [[-0.6851, -0.5309, -0.1807,  1.1145, -0.1963],
         [-0.0842, -0.1603,  0.4753, -0.5600, -0.3373],
         [ 1.2985, -0.0272,  0.4433,  0.5025,  0.1340],
         [ 0.5890,  0.0357,  1.9312,  0.3142,  0.2670],
         [ 0.6460,  0.0627, -0.1356,  0.8055,  1.2612]]],
       grad_fn=<EmbeddingBackward>)

In [16]:
# A 3x3 grid of indices; `torch.tensor(..., dtype=torch.long)` replaces the
# legacy `torch.LongTensor(...)` constructor.
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.long)
print(x.shape)
# The table needs at least (largest index + 1) rows; deriving it from the data
# (= 10 here) keeps the lookup valid if the indices above are edited.
num_embeddings = int(x.max()) + 1
embedding = nn.Embedding(num_embeddings, 5)
# Output shape is x.shape + (embedding_dim,) = (3, 3, 5).
embedding(x)

torch.Size([3, 3])


tensor([[[ 0.7285,  0.0778, -0.6906, -0.0529,  1.1316],
         [ 1.3995, -0.8523,  0.2807, -2.6537, -1.0933],
         [-0.0457, -0.0781,  3.8702,  0.3604,  0.0418]],

        [[ 1.6352,  1.7193, -0.3228, -0.5870, -0.5407],
         [-0.6589,  0.9963, -0.2757,  0.4144, -0.6856],
         [-0.2298, -0.1375,  0.6844, -0.3869,  0.1136]],

        [[ 0.7640, -1.6889, -1.1883, -0.4289, -0.6095],
         [ 1.2039,  1.9125,  0.1949,  1.2838,  1.1462],
         [-0.7507, -0.8444,  0.8797, -0.2231, -0.6771]]],
       grad_fn=<EmbeddingBackward>)

In [15]:
# A 2x3x3 block of indices; `torch.tensor(..., dtype=torch.long)` replaces the
# legacy `torch.LongTensor(...)` constructor.
x = torch.tensor(
    [
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
    ],
    dtype=torch.long,
)
print(x.shape)
# The table needs at least (largest index + 1) rows; deriving it from the data
# (= 19 here) keeps the lookup valid if the indices above are edited.
num_embeddings = int(x.max()) + 1
embedding = nn.Embedding(num_embeddings, 5)
# Output shape is x.shape + (embedding_dim,) = (2, 3, 3, 5).
embedding(x)

torch.Size([2, 3, 3])


tensor([[[[ 1.5321, -0.4328,  1.5068, -0.4174, -0.9384],
          [ 0.5735,  1.0236, -0.7790,  0.3301,  0.3866],
          [-0.4201, -1.3785,  0.9226,  1.1392,  0.9528]],

         [[-0.3818, -0.4232,  1.1798,  0.9941,  0.0115],
          [ 2.1712, -0.5416, -1.1216, -0.8111,  0.8903],
          [-0.8523, -0.6758, -0.3092, -0.8634, -1.0632]],

         [[-0.2722, -1.5317,  0.8587,  0.3555,  0.4001],
          [-0.5981,  0.3267, -1.3805,  1.7650,  1.6163],
          [-0.4526,  0.3024, -1.1152, -1.2691,  0.7107]]],


        [[[-0.2419,  0.1115,  0.7022, -0.8589,  0.9142],
          [-0.4763,  1.2867, -0.5896, -1.1737, -1.9199],
          [ 2.1772,  0.3886,  1.3080, -0.0141, -1.2014]],

         [[-0.9749,  0.5032, -1.4533,  0.6109,  0.0718],
          [ 0.2362,  0.9619, -1.3447, -0.5027,  0.0739],
          [-0.2146, -2.2915,  0.5675,  1.0459, -1.5595]],

         [[ 0.4355, -1.8622,  1.0779,  1.0574, -0.3239],
          [-1.4404, -1.2105, -0.4782,  1.7985,  0.5934],
          [ 1.8132,

In [23]:
# Seed PyTorch's global RNG so both the sampled indices and the randomly
# initialised embedding weights are the same on every run.
torch.manual_seed(0)

# One shared vocabulary size for the sampler and the table: randint draws from
# [0, vocab_size), which are exactly the valid rows of the table. Two separate
# literals could drift apart and make the lookup raise IndexError.
vocab_size = 10
x = torch.randint(vocab_size, (3, 3, 4, 5))
print(x.shape)
embedding = nn.Embedding(vocab_size, 5)
# Output shape is x.shape + (embedding_dim,) = (3, 3, 4, 5, 5).
embedding(x)

torch.Size([3, 3, 4, 5])


tensor([[[[[-1.6401, -0.6372,  2.3019, -0.8344, -0.0973],
           [-1.6401, -0.6372,  2.3019, -0.8344, -0.0973],
           [ 1.8128,  0.7059,  1.3513, -0.7154,  0.5635],
           [ 1.4438,  1.4082, -0.8478, -1.6787, -0.8331],
           [ 1.8128,  0.7059,  1.3513, -0.7154,  0.5635]],

          [[-1.2930,  1.0422,  1.3527,  0.3613, -3.5503],
           [ 0.1310,  0.8138, -1.6016, -2.4117, -0.4255],
           [ 1.8128,  0.7059,  1.3513, -0.7154,  0.5635],
           [-1.2213,  2.5747,  0.1518,  0.7915, -0.9062],
           [ 0.1310,  0.8138, -1.6016, -2.4117, -0.4255]],

          [[-1.0844,  0.4452,  1.8177,  1.6652,  0.0757],
           [ 1.8128,  0.7059,  1.3513, -0.7154,  0.5635],
           [ 0.5062,  0.9103, -1.9796,  0.4233,  0.3958],
           [ 1.8128,  0.7059,  1.3513, -0.7154,  0.5635],
           [-1.6401, -0.6372,  2.3019, -0.8344, -0.0973]],

          [[-1.2930,  1.0422,  1.3527,  0.3613, -3.5503],
           [-1.2930,  1.0422,  1.3527,  0.3613, -3.5503],
        