![title](./transformer-basics.png)

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.functional as F

In [36]:
class BasicTransformer:
    """
    Basic Transformer: a single-head, scaled dot-product self-attention layer
    whose output is averaged over the sequence axis.

    The layer owns three square projection matrices (query, key, value) created
    as plain tensors with ``requires_grad=True``. It is not an ``nn.Module``, so
    an optimizer has to be given ``W_q``, ``W_k`` and ``W_v`` explicitly.
    """
    def __init__(self, dim):
        """
        dim: feature dimensionality of every timestep. Each projection matrix
            is (dim, dim) and is initialised by ``torch.rand`` (uniform [0, 1)).
        """
        self.dim = dim
        self.W_q = torch.rand(dim, dim, requires_grad=True)
        self.W_k = torch.rand(dim, dim, requires_grad=True)
        self.W_v = torch.rand(dim, dim, requires_grad=True)
    
    def forward(self, x):
        """
        Apply self-attention to ``x`` and average the result over time.

        x: tensor of shape (batch, dim, timesteps)
            dim0: batch dimension
            dim1 (rows): dimensionality (dim)
            dim2 (cols): timesteps (or sequence)

        Returns a (batch, dim) tensor. For every timestep i the value vectors
        are mixed with softmax_j(q_i . k_j / sqrt(dim)); the per-timestep
        outputs are then averaged over the sequence.
        """
        # The 2-D weights broadcast over the batch axis:
        # (dim, dim) @ (batch, dim, timesteps) -> (batch, dim, timesteps)
        q = torch.matmul(self.W_q, x)
        k = torch.matmul(self.W_k, x)
        v = torch.matmul(self.W_v, x)

        # scores[b, i, j] = <q[b, :, i], k[b, :, j]> / sqrt(dim).
        # All timesteps are handled in one batched matmul, so no squeeze() is
        # needed and a batch of one or a single timestep keeps its axes.
        scores = torch.matmul(q.transpose(1, 2), k) / self.dim ** 0.5

        # Normalise over the key axis j so each query's weights sum to 1.
        weights = torch.softmax(scores, dim=2)

        # y[b, :, i] = sum_j weights[b, i, j] * v[b, :, j]
        y = torch.matmul(v, weights.transpose(1, 2))

        return y.mean(dim=2)  # average out the sequence

In [37]:
# Attention layer with 8 features per timestep (three 8 x 8 projections).
model = BasicTransformer(dim=8)

In [38]:
# Random sample batch laid out as (batch=32, dim=8, timesteps=17).
x_ = torch.rand((32, 8, 17))

In [39]:
# Run the attention layer on the sample batch. forward() as defined above ends
# with y.mean(dim=2), so y is expected to be (32, 8) here.
# NOTE(review): the RuntimeError recorded under this cell is not what the
# forward() shown above produces for a (32, 8, 17) input; it looks like stale
# output from an earlier revision. Re-run to confirm.
y = model.forward(x_)

RuntimeError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

In [40]:
# NOTE(review): the recorded torch.Size([32, 8, 17]) is the pre-averaging shape.
# With forward() as shown above (which returns y.mean(dim=2)) a fresh run should
# report torch.Size([32, 8]). Re-run to confirm.
y.shape

torch.Size([32, 8, 17])

In [43]:
# NOTE(review): forward() already averages over dim=2 before returning, so on a
# fresh run y is 2-D and this call raises "Dimension out of range". The recorded
# output presumably comes from a run where y was still (32, 8, 17).
y.mean(dim=2).shape

torch.Size([32, 8])

weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])
weights.shape: torch.Size([32, 17])


In [82]:
# NOTE(review): in the cells visible here `weights` exists only as a local inside
# BasicTransformer.forward. This cell relies on a notebook-level `weights` left
# over from a cell that is not shown. TODO: define it explicitly.
weights.shape

torch.Size([32, 17])

In [85]:
# Turn each row of `weights` into a probability distribution over its second
# axis, then display the result.
# NOTE(review): relies on a notebook-level `weights` that no visible cell defines.
softmaxed_weights = weights.softmax(dim=1)
softmaxed_weights

tensor([[2.0819e-02, 1.4097e-03, 5.4635e-03, 6.4457e-03, 4.0511e-03, 1.1580e-03,
         7.0004e-01, 1.4089e-03, 6.2278e-04, 9.4176e-04, 5.6552e-02, 1.8018e-01,
         5.4823e-03, 1.5045e-03, 3.4460e-03, 8.7194e-03, 1.7479e-03],
        [1.1504e-03, 2.4111e-03, 9.0321e-02, 9.4621e-02, 6.8310e-03, 3.0312e-02,
         1.2864e-02, 1.1005e-01, 1.4621e-02, 1.5085e-03, 2.3201e-03, 1.0446e-03,
         8.6780e-02, 2.7091e-03, 1.3246e-04, 2.3951e-02, 5.1837e-01],
        [1.1220e-02, 5.5838e-03, 1.7391e-01, 3.0684e-02, 3.6276e-04, 1.4706e-02,
         1.4670e-01, 9.0861e-02, 5.8181e-02, 1.8195e-02, 3.1755e-01, 1.8231e-03,
         4.3236e-03, 3.2926e-02, 5.2640e-02, 1.7787e-02, 2.2547e-02],
        [3.8560e-03, 2.9235e-01, 2.4944e-03, 2.7203e-03, 1.8404e-02, 8.3394e-03,
         5.0762e-02, 1.7592e-03, 3.6783e-02, 1.2011e-02, 1.3927e-02, 1.0578e-02,
         1.5957e-01, 7.6049e-02, 1.4316e-02, 2.9308e-01, 3.0050e-03],
        [1.6832e-04, 1.0130e-03, 5.0239e-02, 1.4779e-04, 1.8104e-01, 7.1

In [89]:
# Scratch 2 x 3 buffer with uninitialised contents.
# NOTE(review): this rebinds `y`, which earlier held the output of model.forward.
y = torch.empty((2, 3))

In [31]:
# Operands for the shape experiments that follow: a (32, 8) batch of vectors
# and a (32, 8, 17) batch of matrices.
x1, x2 = torch.rand(32, 8), torch.rand(32, 8, 17)

In [57]:
# Batched row-vector times matrix: (32, 1, 8) @ (32, 8, 17) -> (32, 1, 17).
o = x1.unsqueeze(1) @ x2

In [60]:
# squeeze() with no argument drops every size-1 axis; here that is only the
# middle one, giving (32, 1, 17) -> (32, 17). It would also drop the batch or
# time axis if either had size 1.
o.squeeze().shape

torch.Size([32, 17])

In [38]:
# Broadcasting check: (32, 8, 1) * (32, 8, 17) -> (32, 8, 17).
(x1.unsqueeze(-1) * x2).shape

torch.Size([32, 8, 17])

In [13]:
# A 2-D (3, 8) weight matrix and a (32, 8, 17) batch to multiply it with.
W_, X_ = torch.rand(3, 8), torch.rand(32, 8, 17)

In [15]:
# The 2-D left operand is broadcast over the batch axis of the 3-D right operand.
O_ = W_ @ X_

In [16]:
# (3, 8) @ (32, 8, 17) -> (32, 3, 17): the matrix is applied to every batch item.
O_.shape

torch.Size([32, 3, 17])