In [1]:
import torch
import math
import torch.nn.functional as F

def self_attention(q, k, v):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d)) @ v``.

    The last two dimensions are treated as ``(time, feature)``; every leading
    dimension is a batch dimension. Both ``(b, t, d)`` and ``(b, h, t, dh)``
    inputs are therefore accepted.

    Args:
        q: Queries of shape ``(..., t, d)``.
        k: Keys of shape ``(..., s, d)``.
        v: Values of shape ``(..., s, dv)``.

    Returns:
        Tensor of shape ``(..., t, dv)``: for each query position, the
        softmax-weighted sum of the value vectors.
    """
    # (..., t, d) @ (..., d, s) -> (..., t, s). Same contraction as
    # einsum("bhtd, bhsd -> bhts") but independent of the number of batch dims.
    scores = q @ k.transpose(-2, -1)
    # Scale by sqrt of the feature size; a Python float avoids building a
    # fresh tensor on every call.
    scaled_scores = scores / math.sqrt(q.shape[-1])
    # Normalise over the key axis so each query's weights sum to 1.
    weights = F.softmax(scaled_scores, dim=-1)
    return weights @ v


# Smoke test: with q = k = v the attention output has the same shape as the
# input. Only the last expression of a cell is displayed, so the function is
# called once.
x = torch.rand([2, 3, 4, 5])
self_attention(x, x, x).shape

  cpu = _conversion_method_template(device=torch.device("cpu"))


torch.Size([2, 3, 4, 5])

In [2]:
from torch import nn

class MHSA(nn.Module):
    """Multi-head attention built on top of ``self_attention``.

    The inputs are linearly projected, split into ``h`` heads of width
    ``d // h``, attended per head over the time axis, merged back and passed
    through an output projection.

    Args:
        d: Model width; must be divisible by ``h``.
        h: Number of attention heads.
    """

    def __init__(self, d: int = 512, h: int = 8):
        super().__init__()
        assert d % h == 0, "d must be divisible by h"
        self.d = d
        self.dh = d // h
        self.h = h
        self.wq = nn.Linear(self.d, self.d)
        self.wk = nn.Linear(self.d, self.d)
        self.wv = nn.Linear(self.d, self.d)
        self.wo = nn.Linear(self.d, self.d)

    def _split_heads(self, x):
        """Reshape ``(b, t, d)`` into per-head layout ``(b, h, t, dh)``."""
        b, t, _ = x.size()
        # The transpose is essential: attention contracts over the last two
        # axes, which must be (time, head_dim) and not (head, head_dim).
        return x.view(b, t, self.h, self.dh).transpose(1, 2)

    def forward(self, q, k, v):
        """Apply multi-head attention.

        Args:
            q: Query input of shape ``(b, t, d)``.
            k: Key input of shape ``(b, s, d)``.
            v: Value input of shape ``(b, s, d)``.

        Returns:
            Tensor of shape ``(b, t, d)``.
        """
        b, t, _ = q.size()
        wq = self._split_heads(self.wq(q))  # b, h, t, dh
        wk = self._split_heads(self.wk(k))  # b, h, s, dh
        wv = self._split_heads(self.wv(v))  # b, h, s, dh
        # No flattening to (b * h, t, dh) is needed: the attention helper
        # works on 4-D tensors directly.
        attn = self_attention(wq, wk, wv)   # b, h, t, dh
        # Move heads back next to dh and concatenate them: b, t, h * dh.
        attn = attn.transpose(1, 2).reshape(b, t, self.d)
        return self.wo(attn)

# Smoke test: self-attention (q = k = v) on a (batch, time, d) input.
mhsa = MHSA()
x = torch.rand(size=(2, 3, 512))
mhsa(q=x, k=x, v=x).size()

torch.Size([2, 3, 512])

In [3]:
import torch.nn.functional as F

class EncoderLayer(nn.Module):
    """One encoder block: attention and feed-forward sub-layers.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d: Model width.
        h: Number of attention heads.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        hidden_width = d * 4  # feed-forward expansion
        self.mhsa = MHSA(d, h)
        self.norm1 = nn.LayerNorm(d)
        self.ff1 = nn.Linear(d, hidden_width)
        self.ff2 = nn.Linear(hidden_width, d)
        self.norm2 = nn.LayerNorm(d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run both sub-layers on ``x`` and return a tensor of the same shape."""
        # Unpacking fails early unless the input is exactly 3-D.
        _batch, _steps, _width = x.size()

        # Sub-layer 1: self-attention with residual connection, then norm.
        attended = self.mhsa(x, x, x)
        x = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: position-wise feed-forward with residual, then norm.
        expanded = F.relu(self.ff1(x))
        projected = self.ff2(expanded)
        return self.norm2(x + self.dropout(projected))


# Smoke test: run one encoder layer on a (batch, time, d) input.
encoder_layer = EncoderLayer()
x = torch.rand(size=(2, 3, 512))
encoder_layer(x).size()

torch.Size([2, 3, 512])

In [4]:
from torch import nn

class PE(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    Even feature indices hold ``sin(pos / 10000^(2i/d))`` and odd indices hold
    the matching ``cos``. The table is precomputed for ``max_len`` positions
    and stored as a buffer, so it is not a trainable parameter.

    Args:
        d: Feature width of the inputs (even or odd).
        max_len: Number of positions to precompute.
        dropout: Dropout probability applied after adding the encoding.
    """

    def __init__(self, d: int = 512, max_len: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.d = d
        self.dropout = nn.Dropout(p=dropout)

        twoi = torch.arange(0, self.d, 2)                 # (ceil(d/2),)
        pow_ = torch.pow(10000, twoi / self.d)            # (ceil(d/2),)
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        angles = position / pow_                          # (max_len, ceil(d/2))

        pe = torch.zeros(max_len, self.d)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d there is one fewer odd column than even columns, so the
        # cosine block is trimmed to floor(d/2) columns.
        pe[:, 1::2] = torch.cos(angles[:, : self.d // 2])
        # Leading axis of size 1 lets the table broadcast over the batch.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add the first ``t`` encodings to ``x`` of shape ``(b, t, d)``."""
        _, t, _ = x.size()
        x = x + self.pe[:, :t, :]
        return self.dropout(x)
# Smoke test: an integer (2, 3, 4) input keeps its shape after encoding.
print(PE(d=4)(torch.arange(24).reshape(2, 3, 4)).shape)  # torch.Size([2, 3, 4])

torch.Size([2, 3, 4])


In [5]:
from torch import nn

class PEEmbed(nn.Module):
    """Learned positional encoding followed by dropout.

    Args:
        d: Feature width of the inputs.
        max_len: Number of positions in the embedding table.
        dropout: Dropout probability applied after adding the encoding.
    """

    def __init__(self, d: int = 512, max_len: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.d = d
        self.pe = nn.Embedding(max_len, d)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Add the embeddings of positions ``0..t-1`` to ``x`` of shape ``(b, t, d)``."""
        _, t, _ = x.size()
        # Build the indices on the input's device; a default (CPU) arange
        # fails once the module and input live on an accelerator.
        positions = torch.arange(t, device=x.device)
        x = x + self.pe(positions)  # (t, d) broadcasts over the batch axis
        return self.dropout(x)
# Smoke test: an integer (2, 3, 4) input keeps its shape after encoding.
print(PEEmbed(d=4)(torch.arange(24).reshape(2, 3, 4)).shape)  # torch.Size([2, 3, 4])

torch.Size([2, 3, 4])


In [6]:
from torch import nn

class Encoder(nn.Module):
    """Token embedding + positional encoding + a stack of encoder layers.

    Args:
        vocab_size: Number of rows in the token embedding table.
        n: Number of stacked ``EncoderLayer`` blocks.
        d: Model width.
        h: Number of attention heads per layer.
    """

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        # nn.ModuleList registers every layer as a submodule. With a plain
        # Python list the layers are invisible to parameters(), state_dict(),
        # .to(device) and .train()/.eval().
        self.layers = nn.ModuleList(EncoderLayer(d, h) for _ in range(n))

    def forward(self, x):
        """Encode token ids of shape ``(b, t)`` into a ``(b, t, d)`` tensor."""
        x = self.embed(x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x)
        return x

# Smoke test: encode a (batch, time) grid of random token ids.
encoder = Encoder()
x = torch.randint(low=0, high=2**13, size=(2, 3))
encoder(x).size()

torch.Size([2, 3, 512])

In [7]:
# Always raises AssertionError (unless Python runs with -O).
# NOTE(review): presumably a deliberate "Run All" stopper so later cells do not execute — confirm.
assert False

AssertionError: 

# 