In [None]:
!pip install einops

In [117]:
import einops
import torch
import math
import torch.nn as nn
import numpy as np

In [108]:
# Toy input for the attention walkthrough: 16 sequences of 100 tokens, each a
# `dim`-wide feature vector.
dim = 64  # defined here so the cell runs on a fresh kernel; the width itself is arbitrary
x = torch.randn(16, 100, dim)

# One bias-free linear layer produces three stacked `dim`-wide projections per token.
to_qvk = torch.nn.Linear(dim, dim * 3, bias=False)
qvk = to_qvk(x)

# Split the last axis (3 * dim) into its three `dim`-wide chunks.
# Note the unpacking order: chunk 0 -> q, chunk 1 -> v, chunk 2 -> k.
q, v, k = tuple(einops.rearrange(qvk, "b t (n w) -> n b t w", n=3))

# Causal mask: True strictly above the diagonal, i.e. the (i, j) entries with j > i.
mask = torch.triu(torch.ones((q.shape[1], q.shape[1])), diagonal=1).bool()


In [110]:
# Scaled dot-product attention, step by step, on the q / k / v / mask tensors
# created in the earlier cell.

# 1 / sqrt(feature width of q); computed here so the cell does not depend on a
# value left over in the kernel.
scale_factor = 1 / math.sqrt(q.shape[-1])

# Similarity between every query position i and every key position j.
dot_prod = torch.einsum("b i e, b j e -> b i j", q, k)
scaled_dot_prod = dot_prod * scale_factor
if mask is not None:
    # The mask carries no batch axis; masked_fill broadcasts it over the batch.
    assert mask.shape == scaled_dot_prod.shape[1:]
    scaled_dot_prod = scaled_dot_prod.masked_fill(mask, -np.inf)
scaled_dot_prod_norm = torch.softmax(scaled_dot_prod, dim=-1)

# Weighted sum of the value vectors: contract the key-position axis j of the
# attention weights with the token axis of v.
attention = torch.einsum("b i j, b j e -> b i e", scaled_dot_prod_norm, v)

In [144]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Args:
        dim: Size of the last axis of the input. Each of the three projections
            produced by ``to_qkv`` has this width as well.
    """

    def __init__(self, dim):
        super().__init__()
        self.to_qkv = torch.nn.Linear(dim, dim * 3, bias=False)
        self.scale_factor = 1 / math.sqrt(dim)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, tokens, dim)``.
            mask: Controls which query/key pairs are blocked.

                * ``None`` or any falsy non-tensor value: no masking.
                * any truthy non-tensor value (e.g. ``True``): causal mask,
                  position ``i`` cannot attend to positions ``j > i``.
                * a tensor: used as an explicit mask (cast to bool); ``True``
                  marks blocked pairs. Must be broadcastable to
                  ``(batch, tokens, tokens)``.

        Returns:
            Tensor of shape ``(batch, tokens, dim)``.
        """
        assert x.dim() == 3

        # Use this module's own projection and scale. Reading same-named
        # globals instead would silently bypass the module's parameters.
        qkv = self.to_qkv(x)
        # Chunk order along the last axis is kept as-is: 0 -> q, 1 -> v, 2 -> k.
        q, v, k = tuple(einops.rearrange(qkv, "b t (n e) -> n b t e", n=3))
        scores = torch.einsum("b i e, b j e -> b i j", q, k) * self.scale_factor

        if isinstance(mask, torch.Tensor):
            blocked = mask.to(device=scores.device, dtype=torch.bool)
        elif mask:
            tokens = scores.shape[-1]
            # Build the causal mask on the same device as the scores.
            blocked = torch.triu(
                torch.ones((tokens, tokens), device=scores.device), diagonal=1
            ).bool()
        else:
            blocked = None

        if blocked is not None:
            scores = scores.masked_fill(blocked, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        # Contract the key-position axis j of the weights with the token axis of v.
        return torch.einsum("b i j, b j e -> b i e", weights, v)
    
    
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an output projection.

    Args:
        dim: Size of the last axis of the input; must be divisible by ``heads``.
        heads: Number of attention heads. Each head works on ``dim // heads``
            features.
    """

    def __init__(self, dim, heads):
        super().__init__()
        assert dim % heads == 0
        self.to_qkv = torch.nn.Linear(dim, dim * 3, bias=False)
        self.scale_factor = 1 / math.sqrt(dim / heads)
        self.heads = heads
        # Output projection applied after the heads are concatenated.
        self.W_0 = torch.nn.Linear(dim, dim, bias=False)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, tokens, dim)``.
            mask: Controls which query/key pairs are blocked.

                * ``None`` or any falsy non-tensor value: no masking.
                * any truthy non-tensor value (e.g. ``True``): causal mask,
                  position ``i`` cannot attend to positions ``j > i``.
                * a tensor: used as an explicit mask (cast to bool); ``True``
                  marks blocked pairs. Must be broadcastable to
                  ``(batch, heads, tokens, tokens)``.

        Returns:
            Tensor of shape ``(batch, tokens, dim)``.
        """
        assert x.dim() == 3

        # Use this module's own projection and scale rather than same-named globals.
        qkv = self.to_qkv(x)
        # Chunk order along the last axis is kept as-is: 0 -> q, 1 -> v, 2 -> k.
        q, v, k = tuple(
            einops.rearrange(qkv, "b t (n h e) -> n b h t e", n=3, h=self.heads)
        )
        scores = torch.einsum("b h i e, b h j e -> b h i j", q, k) * self.scale_factor

        if isinstance(mask, torch.Tensor):
            blocked = mask.to(device=scores.device, dtype=torch.bool)
        elif mask:
            tokens = scores.shape[-1]
            # Build the causal mask on the same device as the scores; it is
            # broadcast over the batch and head axes by masked_fill.
            blocked = torch.triu(
                torch.ones((tokens, tokens), device=scores.device), diagonal=1
            ).bool()
        else:
            blocked = None

        if blocked is not None:
            scores = scores.masked_fill(blocked, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        # Contract the key-position axis j of the weights with the token axis of v.
        attention = torch.einsum("b h i j, b h j e -> b h i e", weights, v)
        # Concatenate the heads back into a single feature axis.
        out = einops.rearrange(attention, "b h t d -> b t (h d)")

        return self.W_0(out)
        

In [None]:
class TransformerBlock(nn.Module):
    """One Transformer layer: self-attention, then a position-wise MLP.

    Each sub-layer is wrapped as ``LayerNorm(sublayer(x) + x)``, i.e. the
    residual is added first and normalisation is applied afterwards.

    Args:
        dim: Feature width of the input and output.
        heads: Number of attention heads passed to ``MultiHeadAttention``.
        dim_linear_block: Hidden width of the two-layer MLP.
        dropout: Drop probability used after attention and inside the MLP.
    """

    def __init__(self, dim, heads, dim_linear_block=1024, dropout=0.1):
        super().__init__()

        self.mha = MultiHeadAttention(dim, heads)
        self.dropout = nn.Dropout(dropout)
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)

        self.linear = nn.Sequential(
            nn.Linear(dim, dim_linear_block),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_linear_block, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded unchanged to the attention layer."""
        # Attention sub-layer: dropout on the attention output, residual, then norm.
        x = self.ln1(self.dropout(self.mha(x, mask)) + x)
        # MLP sub-layer: residual, then norm (dropout lives inside self.linear).
        return self.ln2(self.linear(x) + x)


# Backward-compatible alias for the original (misspelled) class name.
TranformerBlock = TransformerBlock
        
        
class TransformerEncoder(nn.Module):
    """A stack of ``TransformerBlock`` layers applied one after another.

    Args:
        dim: Feature width passed to every block.
        blocks: Number of blocks in the stack.
        heads: Number of attention heads passed to every block.
        dim_linear_block: MLP hidden width passed to every block.
        dropout: Drop probability passed to every block.
    """

    def __init__(self, dim, blocks=6, heads=8, dim_linear_block=1024, dropout=0.1):
        super().__init__()
        # Plain-list view of the blocks, kept for callers that read this attribute.
        # Only ``self.layers`` registers the blocks as sub-modules.
        self.block_list = [
            TransformerBlock(dim, heads, dim_linear_block=dim_linear_block, dropout=dropout)
            for _ in range(blocks)
        ]
        self.layers = nn.ModuleList(self.block_list)

    def forward(self, x, mask=None):
        """Feed ``x`` through every block in order, passing the same ``mask`` to each."""
        for layer in self.layers:
            x = layer(x, mask)
        return x