In [None]:
'''
Main components of Vision-Llama

- A generic MLP class
- An MLP Transformer
- Multi-head attention
- Transformer Layer
- Transformer
- Transformer Mapper

Note: Class structure and code logic adopted from the official repo for Vision-Llama
'''
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as nnf

In [None]:
'''
Generic MLP Module with Tanh activation
'''
class MLP(nn.Module):
    """Feed-forward stack of linear layers with an activation in between.

    ``sizes`` gives the feature width at every stage, so ``len(sizes) - 1``
    linear layers are created. A fresh ``act()`` instance follows every
    linear layer except the last one.
    """

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sizes = sizes
        final_idx = len(sizes) - 2  # index of the last linear layer
        stack = []
        for idx, (width_in, width_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            stack.append(nn.Linear(width_in, width_out, bias=bias))
            if idx != final_idx:
                stack.append(act())
        # Plain Python list kept as an attribute; the parameters are
        # registered through the nn.Sequential built from it.
        self.layers = stack
        self.model = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through the sequential stack and return the result."""
        return self.model(x)

In [None]:
'''
Generic MLP based transformer
'''

class MLPTransformer(nn.Module):
    """Two-layer feed-forward block.

    The forward pass is: linear (``dim_in`` -> ``dim_h``), activation,
    dropout, linear (``dim_h`` -> ``dim_out``), dropout. The same dropout
    module is used in both places.
    """

    def __init__(self, dim_in, dim_h, dim_out, dropout_ratio, act=nnf.relu, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.fc1 = nn.Linear(dim_in, dim_h)
        self.fc2 = nn.Linear(dim_h, dim_out)
        self.act = act  # plain callable applied after fc1
        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, x):
        """Apply the block to ``x`` and return the projected tensor."""
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

In [None]:
'''
Generic Multi-head Attention Architecture
'''
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Queries are projected from ``X`` (feature size ``dim_self``); keys and
    values are projected from ``y`` (feature size ``dim_ref``) into
    ``dim_self`` features, which are then split across ``num_heads`` heads.

    Args:
        dim_self: feature size of the query input and of the output.
        dim_ref: feature size of the key/value input.
        num_heads: number of attention heads; must divide ``dim_self``.
        bias: whether the query/key/value projections carry a bias term.
        dropout: dropout ratio; stored on the module but not applied in
            ``forward``.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``dim_self``.
    """

    def __init__(self, dim_self, dim_ref, num_heads, bias, dropout, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if num_heads <= 0 or dim_self % num_heads != 0:
            raise ValueError(
                f"dim_self ({dim_self}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.dim_self = dim_self
        self.dim_ref = dim_ref
        self.num_heads = num_heads
        self.head_dim = self.dim_self // self.num_heads
        self.bias = bias
        self.dropout = dropout
        self.scale = self.head_dim ** -0.5
        self.to_q = nn.Linear(self.dim_self, self.dim_self, bias=bias)
        self.to_k = nn.Linear(self.dim_ref, self.dim_self, bias=bias)  # dim_ref -> dim_self mapping
        self.to_v = nn.Linear(self.dim_ref, self.dim_self, bias=bias)  # dim_ref -> dim_self mapping
        self.project = nn.Linear(self.dim_self, self.dim_self, bias=True)  # output projection, dim_self -> dim_self

    def forward(self, X, y=None, mask=None):
        """Attend from ``X`` to ``y`` (or to ``X`` itself when ``y`` is None).

        Args:
            X: query input of shape ``(b, n, dim_self)``.
            y: key/value input of shape ``(b, m, dim_ref)``; defaults to ``X``.
            mask: optional boolean tensor of shape ``(b, m)`` or ``(b, n, m)``;
                positions that are True are filled with ``-inf`` before the
                softmax and therefore receive zero attention weight.

        Returns:
            Tuple ``(out, attention)`` where ``out`` has shape
            ``(b, n, dim_self)`` and ``attention`` has shape ``(b, n, m, h)``.
        """
        y = X if y is None else y
        b, n, c = X.shape
        # Only the key length is needed from y; its feature size (dim_ref)
        # must not overwrite `c`, which is used to reshape the output.
        m = y.shape[1]
        q = self.to_q(X).view(b, n, self.num_heads, self.head_dim)  # split features into heads
        k = self.to_k(y).view(b, m, self.num_heads, self.head_dim)
        v = self.to_v(y).view(b, m, self.num_heads, self.head_dim)
        attention = torch.einsum("bnhd,bmhd->bnmh", q, k) * self.scale

        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)  # (b, m) -> (b, 1, m): same mask for every query
            # (b, n|1, m) -> (b, n|1, m, 1) so it broadcasts over the head axis
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))

        # Normalise over the key axis; -inf entries become exactly 0.
        attention = attention.softmax(dim=2)

        # Weighted sum of values, then merge the heads back into dim_self features.
        out = torch.einsum("bnmh,bmhd->bnhd", attention, v).reshape(b, n, c)

        # One last linear layer
        out = self.project(out)

        return out, attention
        

$$
\text{LayerNorm}(x) = \gamma \left( \frac{x - \mu}{\sigma + \epsilon} \right) + \beta
$$


In [None]:
'''
gamma: scale parameter (scaling according to the importance of features)
beta: shift parameter (allows it to find the optimal mean values)
'''
class LayerNorm(torch.nn.Module):
    """Normalise over the last dimension, then apply a learned scale and shift.

    Computes ``gamma * (x - mean) / (std + epsilon) + beta`` where ``mean``
    and ``std`` are taken along the last dimension of ``x`` (``x.std`` is
    called with its default settings). ``gamma`` starts at ones and ``beta``
    at zeros, each of length ``features``.
    """

    def __init__(self, features, epsilon=1e-5):
        super().__init__()
        self.epsilon = epsilon  # added to the std to avoid division by zero
        self.gamma = torch.nn.Parameter(torch.ones(features))
        self.beta = torch.nn.Parameter(torch.zeros(features))

    def forward(self, x):
        """Return the normalised, scaled and shifted version of ``x``."""
        centered = x - x.mean(dim=-1, keepdim=True)
        denom = x.std(dim=-1, keepdim=True) + self.epsilon
        return self.gamma * centered / denom + self.beta

In [None]:
'''
Transformer: a stack of TransformerLayer blocks
'''
class Transformer(nn.Module):
    """Stack of ``TransformerLayer`` blocks.

    Args:
        dim_self: feature size of the sequence being transformed.
        num_heads: number of attention heads in every layer.
        num_layers: number of layers; doubled when ``enc_dec`` is True.
        dim_ref: feature size of the reference sequence ``y``; defaults to
            ``dim_self``.
        mlp_ratio: hidden-width multiplier forwarded to each layer.
        act: activation callable forwarded to each layer.
        norm_layer: normalisation layer factory forwarded to each layer.
        enc_dec: when True, layers alternate between a layer built with
            ``dim_ref`` (even index) and one built with ``dim_self``
            (odd index).
    """

    def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
                 mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
        super().__init__()
        dim_ref = dim_self if dim_ref is None else dim_ref
        self.enc_dec = enc_dec
        if enc_dec:
            num_layers = num_layers * 2
        # Only the odd-indexed layers of an enc_dec stack attend to x itself
        # (reference size dim_self); every other layer uses dim_ref.
        self.layers = nn.ModuleList([
            TransformerLayer(
                dim_self,
                dim_self if (enc_dec and i % 2 == 1) else dim_ref,
                num_heads,
                mlp_ratio,
                act=act,
                norm_layer=norm_layer,
            )
            for i in range(num_layers)
        ])

    def forward_with_attention(self, x, y=None, mask=None):
        """Run every layer with ``(x, y, mask)`` and collect each layer's attention.

        Returns:
            Tuple ``(x, attentions)`` where ``attentions`` is a list with one
            entry per layer.
        """
        attentions = []
        for layer in self.layers:
            x, att = layer.forward_with_attention(x, y, mask)
            attentions.append(att)
        return x, attentions

    def forward(self, x, y=None, mask=None):
        """Run ``x`` through the stack and return the transformed tensor.

        Without ``enc_dec`` every layer receives ``(x, y, mask)``. With
        ``enc_dec`` even-indexed layers receive ``(x, y)`` with no mask and
        odd-indexed layers receive ``(x, x, mask)``.
        """
        for i, layer in enumerate(self.layers):
            if not self.enc_dec:  # self or cross
                x = layer(x, y, mask)
            elif i % 2 == 0:  # cross
                x = layer(x, y)
            else:  # self
                x = layer(x, x, mask)
        return x

In [None]:
class TransformerLayer(nn.Module):
    """Single transformer block with pre-normalisation and residual connections.

    The input is normalised and passed through attention, the result is
    added back to the input, and the same normalise / transform / add
    pattern is repeated with the MLP.

    Args:
        dim_self: feature size of the input ``x`` and of the output.
        dim_ref: feature size of the reference sequence ``y``.
        num_heads: number of attention heads.
        mlp_ratio: the MLP hidden width is ``int(dim_self * mlp_ratio)``.
        bias: forwarded to ``MultiHeadAttention``.
        dropout: forwarded to both the attention and the MLP.
        act: activation callable used inside the MLP.
        norm_layer: factory called with ``dim_self`` to build each norm.
    """

    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
                 norm_layer: nn.Module = nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        # MLPTransformer takes (dim_in, dim_h, dim_out, dropout_ratio); the
        # output width must equal dim_self so the residual sum is valid.
        self.mlp = MLPTransformer(dim_self, int(dim_self * mlp_ratio), dim_self, dropout, act=act)

    def forward_with_attention(self, x, y=None, mask=None):
        """Apply the block and also return the attention weights.

        Returns:
            Tuple ``(x, attention)``: the transformed input and the second
            value returned by the attention module.
        """
        attended, attention = self.attn(self.norm1(x), y, mask)
        x = x + attended
        x = x + self.mlp(self.norm2(x))
        return x, attention

    def forward(self, x, y=None, mask=None):
        """Apply the block and return only the transformed tensor."""
        return self.forward_with_attention(x, y, mask)[0]