In [8]:
# Encoder
import torch
import math
import torch.nn.functional as F

def self_attention(q, k, v):
    # if 3 dim: b t d
    #prod = Q.bmm(K.permute(0, 2, 1))
    # or
    # prod = torch.einsum("btd, bsd -> bts", q, k)
    # if 4 dim: b h t dh
    prod = torch.einsum("bhtd, bhsd -> bhts", q, k)
    scaled_prod = prod/torch.sqrt(torch.tensor(q.shape[-1]))
    softmaxed_prod = F.softmax(scaled_prod, dim=-1)
    # print(softmaxed_prod.shape)
    # print(softmaxed_prod)
    return softmaxed_prod @ v


x = torch.rand([2, 3, 4, 5])
self_attention(x, x, x)
self_attention(x, x, x).shape

torch.Size([2, 3, 4, 5])

In [2]:
from torch import nn

class MHSA(nn.Module):
    def __init__(self, d: int = 512, h: int = 8):
        super().__init__()
        assert d % h == 0
        self.d = d
        self.dh = d // h
        self.h = h
        self.wq = nn.Linear(self.d, self.d)
        self.wk = nn.Linear(self.d, self.d)
        self.wv = nn.Linear(self.d, self.d)
        self.wo = nn.Linear(self.d, self.d)
 
    def forward(self, q, k, v):
        # b, t, d
        b, t, d = q.size()
        wq = self.wq(q)
        wk = self.wk(k)
        wv = self.wv(v)
        wq = wq.view(b, t, self.h, self.dh)
        wk = wk.view(b, t, self.h, self.dh)
        wv = wv.view(b, t, self.h, self.dh)
        # b, t, h, dh
        # if changing from 4 dim -> 3 dim: b*h, t, dh
        # wq = wq.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # wk = wk.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # wv = wv.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # another option 4 dim -> 3 dim
        # wq = wq.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # wk = wk.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # wv = wv.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # changing the number of dims is not necessary as @ supports 4 dims
        attn = self_attention(wq, wk, wv)
        # b * h, t, dh
        # attn = attn.view(b, self.h, t, self.dh).permute(0, 2, 1, 3).reshape(b, t, d)
        attn = attn.view(b, self.h, t, self.dh).transpose(1, 2).contiguous().view(b, t, d)
        wo = self.wo(attn)
        return wo
        # # 1 2 3 4
        # x = F.relu(self.conv1(x))
        # return F.relu(self.conv2(x))

mhsa = MHSA()
x = torch.rand(2, 3, 512)
mhsa(x, x, x).shape

torch.Size([2, 3, 512])

In [3]:
import torch.nn.functional as F

class EncoderLayer(nn.Module): 

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        self.mhsa = MHSA(d, h)
        self.norm1 = nn.LayerNorm(d)
        self.ff1 = nn.Linear(d, d * 4)
        self.ff2 = nn.Linear(d * 4, d)
        self.norm2 = nn.LayerNorm(d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        b, t, d = x.size()
        x = x + self.dropout(self.mhsa(x, x, x))
        x = self.norm1(x)
        x = x + self.dropout(self.ff2(F.relu(self.ff1(x))))
        x = self.norm2(x)
        return x


encoder_layer = EncoderLayer()
x = torch.rand(2, 3, 512)
encoder_layer(x).shape

torch.Size([2, 3, 512])

In [4]:
from torch import nn

class PE(nn.Module):

    def __init__(self, d: int = 512, max_len: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.d = d
        self.dropout = nn.Dropout(p=dropout)

        twoi = torch.arange(0, self.d, 2)
        pow_ = torch.pow(10000, twoi / self.d)
        position = torch.arange(0, max_len).unsqueeze(1)
        sin_p = torch.sin(position / pow_)
        cos_p = torch.cos(position / pow_)
        pe = torch.zeros(max_len, self.d, requires_grad=False) # Explicit, register buffer insures requires grad = False
        pe[:, 0::2] = sin_p
        pe[:, 1::2] = cos_p
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe) 

    def forward(self, x):
        b, t, d = x.size()
        x = x + self.pe[:, : t, :]
        return self.dropout(x)
print(PE(d=4)(torch.arange(24).view(-1, 3, 4)).size()) # torch.Size([2, 3, 4])

torch.Size([2, 3, 4])


In [5]:
from torch import nn

class PEEmbed(nn.Module):

    def __init__(self, d: int = 512, max_len: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.d = d
        self.pe = nn.Embedding(max_len, d)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        b, t, d = x.size()
        pos = self.pe(torch.arange(t))
        x = x + pos
        return self.dropout(x)
print(PEEmbed(d=4)(torch.arange(24).view(-1, 3, 4)).size()) # torch.Size([2, 3, 4])

torch.Size([2, 3, 4])


In [6]:
from torch import nn

class Encoder(nn.Module): 

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        self.layers = [EncoderLayer(d, h) for _ in range(n)]

    def forward(self, x):
        b, t = x.size()
        x = self.embed(x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x)
        return x

encoder = Encoder()
x = torch.randint(0, 2**13, (2, 3))
encoder(x).shape

torch.Size([2, 3, 512])

In [62]:
# Decoder
import torch
import math
import torch.nn.functional as F

def self_attention_masked(q, k, v, mask=None):
    # if 3 dim: b t d
    #prod = Q.bmm(K.permute(0, 2, 1))
    # or
    # prod = torch.einsum("btd, bsd -> bts", q, k)
    # if 4 dim: b h t dh
    prod = torch.einsum("bhtd, bhsd -> bhts", q, k)
    scaled_prod = prod/torch.sqrt(torch.tensor(q.shape[-1]))
    print(f"scaled_prod.shape: \n {scaled_prod.shape}")
    if mask is not None:
        scaled_prod = scaled_prod.masked_fill(mask == 0, float("-inf"))
    print(f"scaled_prod: \n {scaled_prod}")
    softmaxed_prod = F.softmax(scaled_prod, dim=-1)
    # print(softmaxed_prod.shape)
    print(f"softmaxed_prod: \n {softmaxed_prod}")
    return softmaxed_prod @ v


x = torch.rand([2, 2, 3, 4])
print(x)
mask = torch.ones([2, 3])
mask[0, 2] = 0
mask[1, 2] = 0
mask[1, 1] = 0
print(f"mask: \n {mask}")
# print(f"mask == 0: \n {mask == 0}")
# add head dim to make mask broatcastable to q x k.T prod: mask = b t -> b h(broadcasted) t d (broadcasted) 
# mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.unsqueeze(1)
print(f"mask unsqueeze(1): \n {mask}")
print(mask.shape)
mask = mask.unsqueeze(1)
print(f"mask unsqueeze(2): \n {mask}")
print(mask.shape)
# mask = mask.broadcast_to(2, 2, 3)

# mask = mask.permute(0, 2, 1)
# is the mask that I need? keys are ignored?
print(f"mask: \n {mask}")
print(f"mask broadcast: \n {mask.broadcast_to([2, 2, 3, 3])}")
a = self_attention_masked(x, x, x, mask=mask)
print(f"a: \n {a}" )
print(f"a.shape: \n {a.shape}")

tensor([[[[0.5812, 0.3549, 0.6686, 0.1257],
          [0.8588, 0.8205, 0.8803, 0.0413],
          [0.7683, 0.0840, 0.9481, 0.7191]],

         [[0.5729, 0.7094, 0.9740, 0.7036],
          [0.4616, 0.8769, 0.7034, 0.9517],
          [0.7903, 0.1736, 0.5303, 0.5040]]],


        [[[0.8184, 0.9197, 0.0484, 0.7075],
          [0.8625, 0.6552, 0.3668, 0.9257],
          [0.5766, 0.9511, 0.7558, 0.4701]],

         [[0.7811, 0.9651, 0.1196, 0.6709],
          [0.0086, 0.7451, 0.2597, 0.4614],
          [0.0048, 0.6826, 0.3323, 0.7069]]]])
mask: 
 tensor([[1., 1., 0.],
        [1., 0., 0.]])
mask unsqueeze(1): 
 tensor([[[1., 1., 0.]],

        [[1., 0., 0.]]])
torch.Size([2, 1, 3])
mask unsqueeze(2): 
 tensor([[[[1., 1., 0.]]],


        [[[1., 0., 0.]]]])
torch.Size([2, 1, 1, 3])
mask: 
 tensor([[[[1., 1., 0.]]],


        [[[1., 0., 0.]]]])
mask broadcast: 
 tensor([[[[1., 1., 0.],
          [1., 1., 0.],
          [1., 1., 0.]],

         [[1., 1., 0.],
          [1., 1., 0.],
          [

In [23]:
# mask is equal to making keys on masked places 0:
# the result in terms of masked symbols is the same
k = x.clone()
k[0, 0, 2, :] = float("-inf")
k[0, 1, 2, :] = float("-inf")
k[1, 0, 2, :] = float("-inf")
k[1, 0, 1, :] = float("-inf")
k[1, 1, 2, :] = float("-inf")
k[1, 1, 1, :] = float("-inf")
print(f"k: \n {k}")
a = self_attention_masked(x, k, x)
print(f"a: \n {a}" )
print(f"a.shape: \n {a.shape}")


k: 
 tensor([[[[0.2800, 0.5668, 0.9794, 0.8760],
          [0.3606, 0.5417, 0.6145, 0.4127],
          [  -inf,   -inf,   -inf,   -inf]],

         [[0.2931, 0.6796, 0.0210, 0.2225],
          [0.1959, 0.5985, 0.7252, 0.9275],
          [  -inf,   -inf,   -inf,   -inf]]],


        [[[0.6324, 0.6074, 0.7130, 0.1096],
          [  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]],

         [[0.0270, 0.5604, 0.2235, 0.5859],
          [  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]]]])
scaled_prod.shape: 
 torch.Size([2, 2, 3, 3])
scaled_prod: 
 tensor([[[[1.0632, 0.6857,   -inf],
          [0.6857, 0.4856,   -inf],
          [0.5562, 0.3466,   -inf]],

         [[0.2989, 0.3429,   -inf],
          [0.3429, 0.8914,   -inf],
          [0.3723, 0.9237,   -inf]]],


        [[[0.6447,   -inf,   -inf],
          [0.5792,   -inf,   -inf],
          [0.5897,   -inf,   -inf]],

         [[0.3540,   -inf,   -inf],
          [0.1137,   -inf, 

In [32]:
a = torch.tensor([[[1], [2]]])
b = torch.tensor([[[1], [2]]])
a.bmm(b.permute(0, 2, 1))

tensor([[[1, 2],
         [2, 4]]])

In [32]:
test = torch.rand([2, 3, 4])
test[0, 2, :] = 0
test[1, 1, :] = 0
test[1, 2, :] = 0
print(test)
test_v = test.view(2, 3, 2, 2)
print(test_v)
test_perm = test_v.permute(0, 2, 1, 3)
test_perm

tensor([[[0.5413, 0.1433, 0.3837, 0.5224],
         [0.4675, 0.5409, 0.0295, 0.1846],
         [0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.6739, 0.4627, 0.0566, 0.4800],
         [0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000]]])
tensor([[[[0.5413, 0.1433],
          [0.3837, 0.5224]],

         [[0.4675, 0.5409],
          [0.0295, 0.1846]],

         [[0.0000, 0.0000],
          [0.0000, 0.0000]]],


        [[[0.6739, 0.4627],
          [0.0566, 0.4800]],

         [[0.0000, 0.0000],
          [0.0000, 0.0000]],

         [[0.0000, 0.0000],
          [0.0000, 0.0000]]]])


tensor([[[[0.5413, 0.1433],
          [0.4675, 0.5409],
          [0.0000, 0.0000]],

         [[0.3837, 0.5224],
          [0.0295, 0.1846],
          [0.0000, 0.0000]]],


        [[[0.6739, 0.4627],
          [0.0000, 0.0000],
          [0.0000, 0.0000]],

         [[0.0566, 0.4800],
          [0.0000, 0.0000],
          [0.0000, 0.0000]]]])

In [38]:
test_q = torch.rand([2, 3, 4])
test_k = test_q.clone()
# test_k[0, 2, :] = 0
# test_k[1, 1, :] = 0
# test_k[1, 2, :] = 0
test_k[0, 2, :] = float("-inf")
test_k[1, 1, :] = float("-inf")
test_k[1, 2, :] = float("-inf")
print(test_k)

test_q_view = test_q.view(2, 3, 2, 2)
test_k_view = test_k.view(2, 3, 2, 2)
print(test_k_view)
test_q_perm = test_q_view.permute(0, 2, 1, 3)
test_k_perm = test_k_view.permute(0, 2, 1, 3)
print(test_k_perm)
print(torch.einsum("bhtd, bhsd -> bhts", test_q_perm, test_k_perm))

tensor([[[0.4194, 0.0101, 0.7188, 0.4208],
         [0.9474, 0.8979, 0.8474, 0.5824],
         [  -inf,   -inf,   -inf,   -inf]],

        [[0.9691, 0.3886, 0.9165, 0.3982],
         [  -inf,   -inf,   -inf,   -inf],
         [  -inf,   -inf,   -inf,   -inf]]])
tensor([[[[0.4194, 0.0101],
          [0.7188, 0.4208]],

         [[0.9474, 0.8979],
          [0.8474, 0.5824]],

         [[  -inf,   -inf],
          [  -inf,   -inf]]],


        [[[0.9691, 0.3886],
          [0.9165, 0.3982]],

         [[  -inf,   -inf],
          [  -inf,   -inf]],

         [[  -inf,   -inf],
          [  -inf,   -inf]]]])
tensor([[[[0.4194, 0.0101],
          [0.9474, 0.8979],
          [  -inf,   -inf]],

         [[0.7188, 0.4208],
          [0.8474, 0.5824],
          [  -inf,   -inf]]],


        [[[0.9691, 0.3886],
          [  -inf,   -inf],
          [  -inf,   -inf]],

         [[0.9165, 0.3982],
          [  -inf,   -inf],
          [  -inf,   -inf]]]])
tensor([[[[0.1760, 0.4065,   -inf],
    

In [59]:
torch.tensor([1,2,3]).dim()

1

In [7]:
assert False

AssertionError: 

In [9]:
x = torch.rand([1, 2, 3])
mask = torch.ones([1, 2])
mask[0, 1] = 0
mask = mask.unsqueeze(1)
print(mask == 0)
x.masked_fill(mask == 0, float("-inf"))

tensor([[[False,  True]]])


RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 2