In [5]:
import torch
import math
import torch.nn.functional as F


def self_attention(q, k, v):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d)) @ v``.

    The last two dimensions are treated as (sequence, feature); any number of
    leading dimensions is batched over, so both ``b t d`` and ``b h t dh``
    layouts are accepted.

    Args:
        q: queries, shape ``(..., t, d)``.
        k: keys, shape ``(..., s, d)``.
        v: values, shape ``(..., s, dv)``.

    Returns:
        Tensor of shape ``(..., t, dv)``: for every query position, the
        attention-weighted sum of the value rows.
    """
    d = q.shape[-1]
    # (..., t, d) @ (..., d, s) -> (..., t, s): one score per query/key pair.
    scores = q @ k.transpose(-2, -1)
    # Normalise over the key axis so each query's weights sum to 1.
    weights = F.softmax(scores / math.sqrt(d), dim=-1)
    return weights @ v


# Smoke test on a "b h t dh" shaped input: the output keeps the input's shape.
# (The earlier duplicate call was dropped; its result was never displayed.)
x = torch.rand([2, 3, 4, 5])
self_attention(x, x, x).shape

torch.Size([2, 3, 4, 5])

In [6]:
from torch import nn

class MHSA(nn.Module):
    """Multi-head attention built on top of ``self_attention``.

    The model width ``d`` is split into ``h`` heads of width ``dh = d // h``.
    Queries, keys and values are projected, split per head, attended
    independently per head, merged back to width ``d`` and projected once more.

    Args:
        d: model width (size of the last input dimension).
        h: number of heads; must divide ``d``.
    """

    def __init__(self, d: int = 512, h: int = 8):
        super().__init__()
        assert d % h == 0, "d must be divisible by h"
        self.d = d
        self.dh = d // h
        self.h = h
        self.wq = nn.Linear(self.d, self.d)
        self.wk = nn.Linear(self.d, self.d)
        self.wv = nn.Linear(self.d, self.d)
        self.wo = nn.Linear(self.d, self.d)

    def _split_heads(self, x):
        """Reshape ``(b, n, d)`` into per-head layout ``(b, h, n, dh)``."""
        b, n, _ = x.size()
        # The transpose is essential: self_attention treats the second-to-last
        # axis as the sequence axis, so it must be time, not heads.
        return x.view(b, n, self.h, self.dh).transpose(1, 2)

    def forward(self, q, k, v):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: ``(b, t, d)`` queries.
            k: ``(b, s, d)`` keys (``s`` may differ from ``t``).
            v: ``(b, s, d)`` values.

        Returns:
            ``(b, t, d)`` tensor.
        """
        b, t, _ = q.size()
        wq = self._split_heads(self.wq(q))  # b, h, t, dh
        wk = self._split_heads(self.wk(k))  # b, h, s, dh
        wv = self._split_heads(self.wv(v))  # b, h, s, dh
        # Reshaping to 3 dims (b*h, t, dh) is not necessary: @ supports 4 dims.
        attn = self_attention(wq, wk, wv)  # b, h, t, dh
        # Merge heads: bring time back next to batch, then concatenate the
        # per-head features. reshape (not view) because the transpose makes
        # the tensor non-contiguous.
        attn = attn.transpose(1, 2).reshape(b, t, self.d)
        return self.wo(attn)

# Smoke test with default sizes: a (2, 3, 512) input used as query, key and
# value at once should come back with the same shape.
mhsa = MHSA()
x = torch.rand((2, 3, 512))
mhsa(q=x, k=x, v=x).size()

torch.Size([2, 3, 512])

In [7]:
import torch.nn.functional as F

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention, then a two-layer ReLU feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``,
    i.e. the normalisation is applied after the residual sum.

    Args:
        d: model width.
        h: number of attention heads (forwarded to ``MHSA``).
        dropout: dropout probability applied to each sub-layer's output.
        ff_mult: hidden width of the feed-forward block as a multiple of ``d``.
    """

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1, ff_mult: int = 4):
        super().__init__()
        self.mhsa = MHSA(d, h)
        self.norm1 = nn.LayerNorm(d)
        self.ff1 = nn.Linear(d, d * ff_mult)
        self.ff2 = nn.Linear(d * ff_mult, d)
        self.norm2 = nn.LayerNorm(d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the layer to ``x``; the output has the same shape as ``x``."""
        # Self-attention sub-layer: x serves as query, key and value.
        x = self.norm1(x + self.dropout(self.mhsa(x, x, x)))
        # Position-wise feed-forward sub-layer.
        x = self.norm2(x + self.dropout(self.ff2(F.relu(self.ff1(x)))))
        return x


# Smoke test with default sizes: the layer maps a (2, 3, 512) input to an
# output of the same shape.
encoder_layer = EncoderLayer()
x = torch.rand((2, 3, 512))
encoder_layer(x).size()

torch.Size([2, 3, 512])

In [41]:
class PE1():
    """First-draft positional encoding: ``sin(pos / 10000 ** (i / d))`` for ``i`` in ``[0, d)``.

    Sine only; there is no cosine component in this draft.
    """

    def __init__(self, d: int = 512):
        self.d = d

    # pos -> d vector
    def __call__(self, pos):
        """Return the length-``d`` encoding of position ``pos``."""
        exponent = torch.arange(0, self.d) / self.d
        wavelength = torch.pow(10000, exponent)
        # The angle has to depend on `pos`; previously the feature index was
        # used here, so every position produced the same vector.
        return torch.sin(pos / wavelength)

# Encoding of a single position is a length-d vector (d defaults to 512).
PE1()(pos=1).shape

torch.Size([512])

In [69]:
class PEScalar():
    """Sinusoidal positional encoding of one position.

    Features come in (sin, cos) pairs that share a frequency: pair ``j`` uses
    the angle ``pos / 10000 ** (2j / d)``. The pairs are interleaved, so even
    slots hold sines and odd slots hold cosines.
    """

    def __init__(self, d: int = 512):
        self.d = d

    # pos -> d vector
    def __call__(self, pos):
        """Return the flattened, interleaved sin/cos encoding of ``pos``."""
        even_idx = torch.arange(0, self.d, 2)
        denom = torch.pow(10000, even_idx / self.d)
        angle = pos / denom
        # Stacking on a new last axis puts each sin next to its cos; flattening
        # then reads them out as sin0, cos0, sin1, cos1, ...
        pairs = torch.stack([angle.sin(), angle.cos()], dim=-1)
        return pairs.reshape(-1)

# Encoding of a single scalar position, flattened to one vector.
PEScalar()(pos=1).shape

torch.Size([512])

In [68]:
class PE():
    """Sinusoidal positional encoding for a batch of positions.

    Features come in (sin, cos) pairs that share a frequency: pair ``j`` uses
    the angle ``pos / 10000 ** (2j / d)``. Pairs are interleaved, so even
    columns hold sines and odd columns hold cosines.
    """

    def __init__(self, d: int = 512):
        self.d = d

    # t -> t d
    def __call__(self, pos):
        """Encode positions.

        Args:
            pos: a scalar, a 1-D tensor of ``t`` positions, or a ``(t, 1)``
                column; anything else is flattened to one position per row.

        Returns:
            ``(t, d)`` tensor, one encoding per position.
        """
        # Force a column so the division below broadcasts positions against
        # frequencies, whatever layout the caller used.
        pos = torch.as_tensor(pos).reshape(-1, 1)
        exponent = torch.arange(0, self.d, 2, device=pos.device) / self.d
        angle = pos / torch.pow(10000, exponent)  # t, ceil(d / 2)
        pairs = torch.stack((torch.sin(angle), torch.cos(angle)), dim=-1)
        # Interleave sin/cos per row; for odd d the pairs yield d + 1 columns,
        # so drop the surplus trailing cosine.
        return pairs.flatten(start_dim=1)[:, :self.d]

# Three positions passed as a (3, 1) column -> one d-wide row per position.
PE()(torch.arange(3).unsqueeze(1)).shape

torch.Size([3, 512])

In [49]:
# Interleaving two 1-D tensors: stacking on a new last axis pairs a[i] with
# b[i]; flattening then reads the pairs out in order.
a = torch.arange(0, 12, step=2)
b = torch.arange(1, 12, step=2)
torch.stack([a, b], dim=1).flatten()

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

In [67]:
# Same interleaving trick, row by row: pair up matching entries on a new last
# axis, then merge the last two axes of every row.
a = torch.arange(0, 12, 2).reshape(-1, 2)
b = torch.arange(1, 12, 2).reshape(-1, 2)
torch.stack([a, b], dim=-1).flatten(start_dim=1)

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

In [54]:
# Turn a 1-D range into a column vector (one entry per row).
torch.arange(1, 12, 2).unsqueeze(-1)

tensor([[ 1],
        [ 3],
        [ 5],
        [ 7],
        [ 9],
        [11]])

In [8]:
class Encoder(nn.Module):
    """Token embedding followed by a stack of ``n`` ``EncoderLayer`` blocks.

    NOTE: no positional encoding is added to the embeddings here.

    Args:
        n: number of encoder layers.
        d: model width (embedding size).
        h: number of attention heads per layer.
        vocab_size: number of rows in the embedding table.
    """

    def __init__(self, n: int = 6, d: int = 512, h: int = 8, vocab_size: int = 10000):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        # ModuleList (not a plain list) so the layers' parameters are
        # registered: they then show up in .parameters() / state_dict() and
        # follow .to() / .train() / .eval().
        self.layers = nn.ModuleList([EncoderLayer(d, h) for _ in range(n)])

    def forward(self, x):
        """Encode integer token ids.

        Args:
            x: integer tensor of token ids, shape ``(b, t)``.

        Returns:
            ``(b, t, d)`` tensor.
        """
        x = self.embed(x)
        for layer in self.layers:
            x = layer(x)
        return x

# Smoke test: the embedding needs integer token ids, not floats.
encoder = Encoder()
x = torch.randint(0, 10000, (2, 3))
encoder(x).shape

torch.Size([2, 3, 512])

In [26]:
# Random integers drawn from [0, 10) with shape (2, 3, 4).
torch.randint(low=0, high=10, size=(2, 3, 4))

tensor([[[5, 0, 7, 5],
         [1, 1, 5, 5],
         [9, 5, 5, 3]],

        [[0, 3, 6, 2],
         [4, 7, 6, 3],
         [8, 4, 6, 0]]])

In [27]:
# An embedding is a lookup table: every integer id in the (2, 3) input is
# replaced by its 4-dim row, so the result has shape (2, 3, 4).
e = nn.Embedding(num_embeddings=10, embedding_dim=4)
e(torch.randint(0, 10, (2, 3)))

tensor([[[-0.3452, -0.2566,  0.8572, -0.7027],
         [-0.5786,  1.9277, -1.6954, -0.9788],
         [-0.5786,  1.9277, -1.6954, -0.9788]],

        [[-0.5786,  1.9277, -1.6954, -0.9788],
         [ 0.8099,  0.6575, -1.7057, -0.9658],
         [-0.2256, -0.3173,  0.4518, -0.4014]]], grad_fn=<EmbeddingBackward0>)

In [2]:
# 0..11 laid out as a (2, 3, 2) tensor, used below to inspect permute.
a = torch.arange(12).reshape(2, 3, 2)
a

tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]]])

In [9]:
# Move the last axis to the front: result[i, j, k] == a[j, k, i].
a.permute((2, 0, 1))

tensor([[[ 0,  2,  4],
         [ 6,  8, 10]],

        [[ 1,  3,  5],
         [ 7,  9, 11]]])

In [3]:
import torch
# Matrix multiplication batches over the leading dims:
# (2, 3, 4, 5) @ (2, 3, 5, 6) -> (2, 3, 4, 6).
a = torch.rand((2, 3, 4, 5))
b = torch.rand((2, 3, 5, 6))
torch.matmul(a, b).shape

torch.Size([2, 3, 4, 6])