# Attention from scratch using tinyai


In [2]:
import math,torch
from torch import nn
from tinyai.activations import *
import matplotlib.pyplot as plt

In [4]:
# Fix the RNG seed so the random input below is reproducible.
# NOTE(review): set_seed is presumably provided by the star import from tinyai.activations — confirm.
set_seed(42)
# Dimensions: (batch_size, channels, height, width), i.e. BCHW.
x = torch.randn(64,32,16,16)

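In this step, we flatten the spatial dimensions of the 'image' (height x width) into a single sequence dimension. Additionally, since attention was originally designed for NLP tasks, the sequence dimension (height x width) is expected to come before the channels.\
Therefore, we also perform a transpose operation here to rearrange the dimensions from (batch, channels, sequence) to (batch, sequence, channels).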


In [5]:
# Keep the first two dims (batch, channels), flatten height*width into one axis,
# then swap the last two axes so the layout is (batch, sequence, channels).
t = x.view(*x.shape[:2], -1).transpose(1, 2)
t.shape

torch.Size([64, 256, 32])

In [6]:
# Feature size fed to the linear layers below; equals the channel dim of x (32).
ni = 32

In [7]:
# Three independent ni -> ni projections, used below to produce keys, queries and values.
# Creation order matters for reproducibility: each layer draws its initial weights from the RNG.
sk = nn.Linear(ni, ni)
sq = nn.Linear(ni, ni)
sv = nn.Linear(ni, ni)

In [8]:
# Project the same sequence tensor t three times; nn.Linear acts on the last
# (channel) axis, so each result keeps t's (batch, sequence, channels) shape.
k = sk(t)
q = sq(t)
v = sv(t)

In [14]:
q.shape, k.shape, v.shape

(torch.Size([64, 256, 32]),
 torch.Size([64, 256, 32]),
 torch.Size([64, 256, 32]))

In [17]:
k.transpose(1, 2).shape

torch.Size([64, 32, 256])

In [18]:
# Batched matmul of q (batch, seq, ch) with k transposed to (batch, ch, seq):
# one (seq x seq) score matrix per batch item.
(q@k.transpose(1,2)).shape

torch.Size([64, 256, 256])

In [25]:
# Shape after flattening height*width only (before the transpose applied to t).
print(x.view(64,32,-1).shape)
# IPython-only syntax: a trailing `?` displays the docstring of Tensor.view
# (interactive help; this line is not valid plain Python).
t.view?

torch.Size([64, 32, 256])


[0;31mDocstring:[0m
view(*shape) -> Tensor

Returns a new tensor with the same data as the :attr:`self` tensor but of a
different :attr:`shape`.

The returned tensor shares the same data and must have the same number
of elements, but may have a different size. For a tensor to be viewed, the new
view size must be compatible with its original size and stride, i.e., each new
view dimension must either be a subspace of an original dimension, or only span
across original dimensions :math:`d, d+1, \dots, d+k` that satisfy the following
contiguity-like condition that :math:`\forall i = d, \dots, d+k-1`,

.. math::

  \text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1]

Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape`
without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a
:meth:`view` can be performed, it is advisable to use :meth:`reshape`, which
returns a view if the shapes are compatible, and copies (equivalent to calling
:

In [26]:
class SelfAttention(nn.Module):
    """Single-head self-attention across the spatial positions of a BCHW tensor.

    The input is group-normalised (one group), flattened to
    (batch, height*width, channels), passed through separate query/key/value
    linear layers, mixed with softmax(q @ k^T / sqrt(ni)) @ v, projected by a
    final linear layer, reshaped back to BCHW and added to the original input.

    Args:
        ni: number of input channels; also the width of every linear layer.
    """
    def __init__(self, ni):
        super().__init__()
        # Divisor applied to the raw attention scores before the softmax.
        self.scale = math.sqrt(ni)
        self.norm = nn.GroupNorm(1, ni)
        # Registered in the order q, k, v, proj so that parameter order (and the
        # order in which initial weights are drawn from the RNG) is fixed.
        for layer_name in ("q", "k", "v", "proj"):
            setattr(self, layer_name, nn.Linear(ni, ni))

    def forward(self, x):
        """Apply attention to `x` of shape (n, c, h, w); returns the same shape."""
        bs, ch, height, width = x.shape
        # (n, c, h, w) -> (n, c, h*w) -> (n, h*w, c): spatial positions become the sequence.
        seq = self.norm(x).view(bs, ch, -1).transpose(1, 2)
        queries, keys, values = self.q(seq), self.k(seq), self.v(seq)
        # (n, h*w, h*w) similarity scores, scaled before the row-wise softmax.
        scores = (queries @ keys.transpose(1, 2)) / self.scale
        attended = scores.softmax(dim=-1) @ values
        projected = self.proj(attended)
        # Back to channels-first and the original spatial layout.
        restored = projected.transpose(1, 2).reshape(bs, ch, height, width)
        # Residual connection with the un-normalised input.
        return restored + x

In [27]:
# Instantiate with 32 channels to match the channel dimension of x.
sa = SelfAttention(32)

In [28]:
# Run attention on the BCHW input; the block reshapes back to (n, c, h, w),
# so the output shape should equal x's shape.
ra = sa(x)
ra.shape

torch.Size([64, 32, 16, 16])

In [29]:
ra[0,0,0]

tensor([ 1.9104,  1.4186,  0.8385, -2.1584,  0.6318, -1.2443, -0.0789, -1.6844,
        -0.7939,  1.6117, -0.3852, -1.4307, -0.7494, -0.6010, -0.8335,  0.7477],
       grad_fn=<SelectBackward0>)

In [30]:
def cp_parms(a,b):
    """Copy the `weight` and `bias` of module `a` into module `b`.

    Each parameter of `b` is replaced by an independent clone of the
    corresponding parameter of `a` (same values, same `requires_grad`), so
    later changes to one module do not leak into the other. Plain attribute
    assignment (`b.weight = a.weight`) would instead make both modules share
    the very same Parameter object.

    Args:
        a: source module exposing `weight` and `bias` attributes.
        b: destination module whose `weight` and `bias` are overwritten.
    """
    for name in ("weight", "bias"):
        src = getattr(a, name)
        if src is None:
            # e.g. a layer built with bias=False: mirror the absence on `b`.
            setattr(b, name, None)
        else:
            # detach().clone() yields fresh storage that is not tied to `a`.
            setattr(b, name, nn.Parameter(src.detach().clone(), requires_grad=src.requires_grad))

The code in the cell below is copied from an old version of Hugging Face's `diffusers` repository. We use it as a reference implementation to check our `SelfAttention` against.


In [31]:
class AttentionBlock(nn.Module):
    """Multi-head self-attention over the spatial positions of a BCHW tensor.

    Forward pipeline: GroupNorm -> flatten to (batch, height*width, channels)
    -> query/key/value linear layers -> fold heads into the batch dimension
    -> softmax(scale * q @ k^T) @ v -> unfold heads -> output linear layer
    -> reshape to BCHW -> add the residual and divide by `rescale_output_factor`.

    Args:
        channels: number of input (and output) channels.
        num_head_channels: channels per attention head. `None` means a single
            head. Must be a positive divisor of `channels`.
        norm_num_groups: number of groups for the GroupNorm layer.
        rescale_output_factor: the sum (attention output + residual) is
            divided by this value.
        eps: epsilon passed to GroupNorm.

    Raises:
        ValueError: if `num_head_channels` is given but is not a positive
            divisor of `channels`. Without this check the floor division below
            would silently use a different head size than the one requested.
    """
    def __init__(
        self,
        channels: int,
        num_head_channels= None,
        norm_num_groups: int = 32,
        rescale_output_factor: float = 1.0,
        eps: float = 1e-5,
    ):
        super().__init__()
        if num_head_channels is not None and (
            num_head_channels <= 0 or channels % num_head_channels != 0
        ):
            raise ValueError(
                f"num_head_channels ({num_head_channels}) must be a positive divisor "
                f"of channels ({channels})"
            )
        self.channels = channels

        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
        self.num_head_size = num_head_channels
        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)

        # define q,k,v as linear layers
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        # The third positional argument of nn.Linear is `bias`; spell it out
        # instead of passing a bare `1`.
        self.proj_attn = nn.Linear(channels, channels, bias=True)

        # Not read anywhere in this class; kept so the attribute set is unchanged.
        self._use_memory_efficient_attention_xformers = False
        self._attention_op = None

    def reshape_heads_to_batch_dim(self, tensor):
        """(batch, seq, dim) -> (batch * num_heads, seq, dim // num_heads)."""
        batch_size, seq_len, dim = tensor.shape
        num_heads = self.num_heads
        tensor = tensor.reshape(batch_size, seq_len, num_heads, dim // num_heads)
        # Move the head axis next to batch before merging the two.
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * num_heads, seq_len, dim // num_heads)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        """(batch * num_heads, seq, dim) -> (batch, seq, dim * num_heads); inverse of the above."""
        batch_size, seq_len, dim = tensor.shape
        num_heads = self.num_heads
        tensor = tensor.reshape(batch_size // num_heads, num_heads, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // num_heads, seq_len, dim * num_heads)
        return tensor

    def forward(self, hidden_states):
        """Apply attention to `hidden_states` of shape (batch, channel, height, width).

        Returns a tensor of the same shape.
        """
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        # (batch, channel, h, w) -> (batch, h*w, channel)
        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        # 1 / sqrt(channels per head)
        scale = 1 / math.sqrt(self.channels / self.num_heads)

        query_proj = self.reshape_heads_to_batch_dim(query_proj)
        key_proj = self.reshape_heads_to_batch_dim(key_proj)
        value_proj = self.reshape_heads_to_batch_dim(value_proj)

        # baddbmm computes beta * input + alpha * (q @ k^T). With beta=0 the
        # `input` tensor only supplies shape/dtype/device, so an uninitialised
        # buffer is sufficient.
        attention_scores = torch.baddbmm(
            torch.empty(
                query_proj.shape[0],
                query_proj.shape[1],
                key_proj.shape[1],
                dtype=query_proj.dtype,
                device=query_proj.device,
            ),
            query_proj,
            key_proj.transpose(-1, -2),
            beta=0,
            alpha=scale,
        )
        # Softmax is evaluated in float32, then cast back to the scores' dtype.
        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
        hidden_states = torch.bmm(attention_probs, value_proj)

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)

        # (batch, h*w, channel) -> (batch, channel, h, w)
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

In [32]:
# Build the reference block with a single GroupNorm group, matching sa.norm = GroupNorm(1, ni).
at = AttentionBlock(32, norm_num_groups=1)
# Corresponding layers, listed in the same order in both tuples.
src = sa.q,sa.k,sa.v,sa.proj,sa.norm
dst = at.query,at.key,at.value,at.proj_attn,at.group_norm
# Give every reference layer the weight and bias of our matching layer so both
# implementations can be compared on the same input.
for s,d in zip(src,dst): cp_parms(s,d)

# Ours

In [34]:
ra[0,0,0]

tensor([ 1.9104,  1.4186,  0.8385, -2.1584,  0.6318, -1.2443, -0.0789, -1.6844,
        -0.7939,  1.6117, -0.3852, -1.4307, -0.7494, -0.6010, -0.8335,  0.7477],
       grad_fn=<SelectBackward0>)

# Diffusers

In [33]:
# Run the reference block on the same input; with identical parameters this
# row should match the one printed for our implementation (ra[0,0,0]).
rb = at(x)
rb[0,0,0]

tensor([ 1.9104,  1.4186,  0.8385, -2.1584,  0.6318, -1.2443, -0.0789, -1.6844,
        -0.7939,  1.6117, -0.3852, -1.4307, -0.7494, -0.6010, -0.8335,  0.7477],
       grad_fn=<SelectBackward0>)

In [35]:
# One fused projection producing 3*ni features per position instead of three
# separate ni -> ni layers; it is split into q, k, v in the next cell.
sqkv = nn.Linear(ni, ni*3)
st = sqkv(t)
st.shape

torch.Size([64, 256, 96])

In [36]:
# Split the fused output into three equal pieces along the last (feature) axis.
# Note: this rebinds the q, k, v names defined earlier in the notebook.
q,k,v = torch.chunk(st, 3, dim=-1)
q.shape

torch.Size([64, 256, 32])

In [37]:
# Shape check only. The operands here are k @ q^T (the attention classes use
# q @ k^T); the resulting shape (batch, seq, seq) is the same either way.
(k@q.transpose(1,2)).shape

torch.Size([64, 256, 256])

In [38]:
class SelfAttention(nn.Module):
    """Single-head self-attention for BCHW tensors using one fused q/k/v layer.

    The input is batch-normalised, flattened to (batch, height*width, channels),
    and projected once to 3*ni features that are chunked into query, key and
    value. The result of softmax(q @ k^T / sqrt(ni)) @ v goes through an output
    projection, is reshaped back to BCHW and added to the original input.

    Args:
        ni: number of input channels.
    """
    def __init__(self, ni):
        super().__init__()
        # Divisor applied to the attention scores before the softmax.
        self.scale = math.sqrt(ni)
        self.norm = nn.BatchNorm2d(ni)
        # Single projection whose output holds q, k and v side by side.
        self.qkv = nn.Linear(ni, 3 * ni)
        self.proj = nn.Linear(ni, ni)

    def forward(self, inp):
        """Apply attention to `inp` of shape (n, c, h, w); returns the same shape."""
        bs, ch, height, width = inp.shape
        # Channels-last sequence of spatial positions: (n, h*w, c).
        seq = self.norm(inp).view(bs, ch, -1).transpose(1, 2)
        queries, keys, values = torch.chunk(self.qkv(seq), 3, dim=-1)
        scores = (queries @ keys.transpose(1, 2)) / self.scale
        attended = scores.softmax(dim=-1) @ values
        # Project, then restore the channels-first spatial layout.
        restored = self.proj(attended).transpose(1, 2).reshape(bs, ch, height, width)
        # Residual connection with the un-normalised input.
        return restored + inp

In [40]:
# `sa` is rebound here: SelfAttention now refers to the fused-qkv class defined
# just above, which replaces the earlier class of the same name.
sa = SelfAttention(32)
sa(x).shape

torch.Size([64, 32, 16, 16])

In [41]:
sa(x).std()

tensor(1.0085, grad_fn=<StdBackward0>)

In [52]:
def heads_to_batch(x, heads):
    """Fold attention heads into the batch axis.

    Takes `x` of shape (n, seq, d), splits the last axis into `heads` equal
    groups and returns a tensor of shape (n * heads, seq, d // heads) in which
    the entries for one batch item's heads are adjacent along axis 0.
    """
    bs, seq_len, _ = x.shape
    # (n, seq, heads, d/heads) -> (n, heads, seq, d/heads)
    per_head = x.reshape(bs, seq_len, heads, -1).permute(0, 2, 1, 3)
    return per_head.reshape(bs * heads, seq_len, -1)

def batch_to_heads(x, heads):
    """Unfold attention heads from the batch axis back into the feature axis.

    Takes `x` of shape (n * heads, seq, d) and returns (n, seq, d * heads),
    concatenating the features of the `heads` consecutive batch entries that
    belong to the same item.
    """
    _, seq_len, head_dim = x.shape
    # (n*heads, seq, d) -> (n, heads, seq, d) -> (n, seq, heads, d)
    grouped = x.reshape(-1, heads, seq_len, head_dim).permute(0, 2, 1, 3)
    return grouped.reshape(-1, seq_len, head_dim * heads)

In [53]:
from einops import rearrange

In [54]:
# Split the last axis into h=8 groups and fold the group axis into the batch
# axis: (n, s, h*d) -> (n*h, s, d).
t2 = rearrange(t , 'n s (h d) -> (n h) s d', h=8)
t.shape, t2.shape

(torch.Size([64, 256, 32]), torch.Size([512, 256, 4]))

In [55]:
# Inverse pattern of the previous cell: (n*h, s, d) -> (n, s, h*d).
t3 = rearrange(t2, '(n h) s d -> n s (h d)', h=8)

In [56]:
t2.shape,t3.shape

(torch.Size([512, 256, 4]), torch.Size([64, 256, 32]))

In [57]:
# Round-trip check: folding heads into the batch and back must reproduce t exactly.
(t==t3).all()

tensor(True)

In [59]:
class SelfAttentionMultiHead(nn.Module):
    """Multi-head self-attention for BCHW tensors using one fused q/k/v layer.

    The input is batch-normalised, flattened to (batch, height*width, channels)
    and projected to 3*ni features. Those features are split into `nheads`
    groups that are folded into the batch axis; each group is then chunked
    into query, key and value. After softmax(q @ k^T / sqrt(ni/nheads)) @ v the
    heads are unfolded again, projected, reshaped to BCHW and added to the input.

    Args:
        ni: number of input channels.
        nheads: number of attention heads.
    """
    def __init__(self, ni, nheads):
        super().__init__()
        self.nheads = nheads
        # Divisor for the scores: square root of the per-head width.
        self.scale = math.sqrt(ni / nheads)
        self.norm = nn.BatchNorm2d(ni)
        # Single projection whose output holds q, k and v for every head.
        self.qkv = nn.Linear(ni, 3 * ni)
        self.proj = nn.Linear(ni, ni)

    def forward(self, inp):
        """Apply attention to `inp` of shape (n, c, h, w); returns the same shape."""
        bs, ch, height, width = inp.shape
        seq = self.norm(inp).view(bs, ch, -1).transpose(1, 2)
        fused = self.qkv(seq)

        # Fold heads into the batch axis: the batch grows by a factor of nheads
        # while the feature axis shrinks to (3*ni) / nheads.
        fused = rearrange(fused, 'n s (h d) -> (n h) s d', h=self.nheads)

        # Three chunks because the fused layer emitted q, k and v together.
        queries, keys, values = torch.chunk(fused, 3, dim=-1)
        scores = (queries @ keys.transpose(1, 2)) / self.scale
        attended = scores.softmax(dim=-1) @ values
        # Unfold the heads from the batch axis back into the feature axis.
        merged = rearrange(attended, '(n h) s d -> n s (h d)', h=self.nheads)
        restored = self.proj(merged).transpose(1, 2).reshape(bs, ch, height, width)
        # Residual connection with the un-normalised input.
        return restored + inp

In [60]:
# 32 channels split across 4 heads; `sa` is rebound to the multi-head module.
sa = SelfAttentionMultiHead(32, 4)
sx = sa(x)
sx.shape

torch.Size([64, 32, 16, 16])

In [49]:
sx.mean(),sx.std()

(tensor(-0.0093, grad_fn=<MeanBackward0>),
 tensor(1.0066, grad_fn=<StdBackward0>))

In [61]:
# PyTorch's built-in multi-head attention: feature size 32, 8 heads.
# batch_first=True makes it accept (batch, seq, feature) input, which is t's layout.
nm = nn.MultiheadAttention(32, num_heads=8, batch_first=True)
# Self-attention: t is passed as query, key and value. The call returns the
# attention output and the attention weights.
nmx,nmw = nm(t,t,t)
# Add the residual connection manually (no normalisation layer is applied in this cell).
nmx = nmx+t

In [62]:
nmx.mean(),nmx.std()

(tensor(-0.0015, grad_fn=<MeanBackward0>),
 tensor(1.0050, grad_fn=<StdBackward0>))