## Llama 2 modules

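- **Rotary positional embeddings (RoPE)** -> applied to the Q and K vectors computed inside self-attention
- **Grouped Query Attention (GQA)** -> a trade-off between Multi-Query Attention (MQA) and Multi-Head Attention (MHA) that balances memory-bandwidth requirements against speedup
- **KV Caching** -> speeds up autoregressive decoding by storing the keys and values of already-processed tokens instead of recomputing them
- **SwiGLU activation function** -> used in the feed-forward block
- **RMS Norm** -> normalization applied before the attention and feed-forward blocks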

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

### Rotary positional embeddings

In [2]:
def precompute_theta_pos_freq(head_dim, seq_len, theta=10000, device=None):
    """Precompute the rotary-embedding (RoPE) rotation factors for every position.

    For each pair index i in [0, head_dim / 2) the frequency is
    theta_i = theta ** (-2i / head_dim). Position m is rotated by the angle
    m * theta_i, stored as the unit complex number e^{i * m * theta_i}.

    Args:
        head_dim: Size of one attention head. It must be even, because
            features are rotated in pairs.
        seq_len: Number of positions (0 .. seq_len - 1) to precompute.
        theta: Base of the geometric frequency progression.
        device: Optional device on which to create the tensors. None uses
            torch's default device, as before.

    Returns:
        Complex tensor of shape (seq_len, head_dim // 2).

    Raises:
        AssertionError: If head_dim is odd. Like any assert, this check is
            skipped under ``python -O``.
    """
    assert head_dim % 2 == 0, "Dimension of head must be divisible by 2"

    # Even exponents 0, 2, ..., head_dim - 2  ->  theta_i = theta^(-2i / head_dim)
    theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
    inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))

    # Position indices m = 0 .. seq_len - 1 (float so the outer product stays float32)
    m = torch.arange(seq_len, device=device).float()

    # (seq_len,) outer (head_dim/2,) -> (seq_len, head_dim/2) angles m * theta_i
    freqs = torch.outer(m, inv_freq).float()

    # Magnitude 1, angle m * theta_i  ->  e^{i * m * theta_i}
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)

    return freqs_complex


In [3]:
def apply_rotary_embeds(x, freqs_complex):
    """Rotate consecutive feature pairs of ``x`` by precomputed RoPE factors.

    Each adjacent pair (x[..., 2k], x[..., 2k+1]) is viewed as one complex
    number and multiplied by the matching entry of ``freqs_complex``. This
    rotates the pair in its 2-D plane.

    Args:
        x: Tensor laid out as (bsz, seq_len, H, head_dim). H is the number of
            heads: num_heads for queries, num_kv_heads for keys.
        freqs_complex: Complex tensor of shape (seq_len, head_dim/2), as
            produced by ``precompute_theta_pos_freq``.

    Returns:
        Rotated tensor with the same shape and dtype as ``x``.
    """
    # (bsz, seq_len, H, head_dim) -> (bsz, seq_len, H, head_dim/2, 2) -> complex (bsz, seq_len, H, head_dim/2)
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(pairs)

    # (seq_len, head_dim/2) -> (1, seq_len, 1, head_dim/2) so it broadcasts over batch and heads
    rotation = freqs_complex[None].unsqueeze(2)

    # Rotate, split back into (real, imag) pairs, and flatten to (bsz, seq_len, H, head_dim)
    rotated = torch.view_as_real(as_complex * rotation).reshape(*x.shape)

    return rotated.type_as(x)

In [4]:
class SelfAttention(nn.Module):
    """Multi-head self-attention with rotary positional embeddings (RoPE).

    Q, K and V are projected from the same input, RoPE is applied to Q and K,
    and scaled dot-product attention is computed independently per head.
    Queries, keys and values all use ``n_heads`` heads. There is no KV cache.
    """

    def __init__(self, n_heads, embed_dim):
        super().__init__()

        # Number of attention heads, shared by queries, keys and values
        self.num_heads = n_heads

        # Dimension of each head, i.e. the slice of the embedding each head is responsible for
        self.head_dim = embed_dim // n_heads

        self.wq = nn.Linear(embed_dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(embed_dim, n_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(embed_dim, n_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, embed_dim, bias=False)

    def precompute_theta_pos_freq(self, head_dim, seq_len, theta=10000, device=None):
        """Return RoPE factors e^{i * m * theta_i} with shape (seq_len, head_dim // 2).

        theta_i = theta ** (-2i / head_dim) for i in [0, head_dim / 2), and m is
        the position index. ``device`` optionally selects where the tensors are
        created. Raises AssertionError if head_dim is odd.
        """
        assert head_dim % 2 == 0, "Dimension of head must be divisible by 2"

        theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
        inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))

        m = torch.arange(seq_len, device=device).float()

        # (seq_len, head_dim/2) angles m * theta_i
        freqs = torch.outer(m, inv_freq).float()

        # Unit magnitude, angle m * theta_i
        return torch.polar(torch.ones_like(freqs), freqs)

    def apply_rotary_embeds(self, x, freqs_complex):
        """Rotate consecutive feature pairs of x by freqs_complex; returns x's shape and dtype.

        x is laid out as (bsz, seq_len, H, head_dim), where H is num_heads for
        Q or num_kv_heads for K. freqs_complex has shape (seq_len, head_dim/2).
        """
        # (bsz, seq_len, H, head_dim) -> complex (bsz, seq_len, H, head_dim/2)
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))

        # (seq_len, head_dim/2) -> (1, seq_len, 1, head_dim/2)
        freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)

        x_rotated = x_complex * freqs_complex

        # complex (bsz, seq_len, H, head_dim/2) -> (bsz, seq_len, H, head_dim/2, 2) -> (bsz, seq_len, H, head_dim)
        x_out = torch.view_as_real(x_rotated).reshape(*x.shape)
        return x_out.type_as(x)

    def _prepare_attn_mask(self, attn_mask, bsz, seq_len, dtype):
        """Validate ``attn_mask`` and reshape it to broadcast against (bsz, H, seq_len, seq_len) scores.

        Accepted shapes follow torch.nn.MultiheadAttention: (seq_len, seq_len),
        shared by every batch element and head, or (bsz * num_heads, seq_len,
        seq_len), one mask per (batch, head). Boolean masks use the same
        convention as that module: True marks a position that may NOT be
        attended to, and it becomes -inf. Floating-point masks are returned
        unchanged in value and are added to the scores.

        Raises:
            RuntimeError: If the mask's rank or shape is not one of the above.
        """
        if attn_mask.dim() == 2:
            correct_2d_size = (seq_len, seq_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            # (seq_len, seq_len) -> (1, 1, seq_len, seq_len)
            attn_mask = attn_mask.reshape(1, 1, seq_len, seq_len)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * self.num_heads, seq_len, seq_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
            # (bsz * H, seq_len, seq_len) -> (bsz, H, seq_len, seq_len). Without this the
            # leading bsz * H axis would be broadcast against the H axis of the scores.
            attn_mask = attn_mask.reshape(bsz, self.num_heads, seq_len, seq_len)
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

        if attn_mask.dtype == torch.bool:
            # Adding a bool tensor would add 0/1; turn it into an additive -inf mask instead
            attn_mask = torch.zeros_like(attn_mask, dtype=dtype).masked_fill(attn_mask, float("-inf"))
        return attn_mask

    def forward(self, x, freqs_complex, attn_mask=None):
        """Apply RoPE self-attention to ``x``.

        Args:
            x: Input of shape (bsz, seq_len, embed_dim).
            freqs_complex: RoPE factors for x's positions, shape (seq_len, head_dim/2).
            attn_mask: Optional mask. It may be (seq_len, seq_len) or
                (bsz * num_heads, seq_len, seq_len). Float masks are added to
                the attention scores; bool masks mark disallowed positions
                with True.

        Returns:
            Tensor of shape (bsz, seq_len, embed_dim).

        Raises:
            RuntimeError: If attn_mask has an unsupported rank or shape.
        """
        bsz, seq_len, _ = x.shape

        # (bsz, seq_len, embed_dim) -> (bsz, seq_len, num_heads * head_dim)
        Q = self.wq(x)
        K = self.wk(x)
        V = self.wv(x)

        if attn_mask is not None:
            attn_mask = self._prepare_attn_mask(attn_mask, bsz, seq_len, Q.dtype)

        # (bsz, seq_len, H * head_dim) -> (bsz, seq_len, H, head_dim)
        Q = Q.view(bsz, seq_len, self.num_heads, self.head_dim)
        K = K.view(bsz, seq_len, self.num_heads, self.head_dim)
        V = V.view(bsz, seq_len, self.num_heads, self.head_dim)

        # Positional information is injected by rotating Q and K only
        Q = self.apply_rotary_embeds(Q, freqs_complex)
        K = self.apply_rotary_embeds(K, freqs_complex)

        # (bsz, seq_len, H, head_dim) -> (bsz, H, seq_len, head_dim)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # (bsz, H, seq_len, head_dim) @ (bsz, H, head_dim, seq_len) -> (bsz, H, seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_mask is not None:
            scores = scores + attn_mask

        # Softmax is computed in float32 and cast back to Q's dtype
        scores = F.softmax(scores.float(), dim=-1).type_as(Q)

        # (bsz, H, seq_len, seq_len) @ (bsz, H, seq_len, head_dim) -> (bsz, H, seq_len, head_dim)
        output = torch.matmul(scores, V)
        # (bsz, H, seq_len, head_dim) -> (bsz, seq_len, H, head_dim) -> (bsz, seq_len, H * head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        # (bsz, seq_len, H * head_dim) -> (bsz, seq_len, embed_dim)
        return self.wo(output)

        

NameError: name 'ModelArgs' is not defined

In [None]:


class FeedForward(nn.Module):
    """SwiGLU feed-forward block: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(
        self,
        embed_dim,
        ffn_dim_multiplier=None,
        multiple_of=256,
    ):
        """Build the three projections of the SwiGLU feed-forward block.

        Args:
            embed_dim: Model dimension of the input and output.
            ffn_dim_multiplier: Optional factor applied to the hidden size
                after the 2/3 scaling.
            multiple_of: The hidden size is rounded up to a multiple of this
                value. The default of 256 matches the Llama reference
                ModelArgs.

        Raises:
            ValueError: If multiple_of is not positive.
        """
        super().__init__()
        if multiple_of <= 0:
            raise ValueError(f"multiple_of must be positive, got {multiple_of}")

        # Start from the classic 4x expansion and scale by 2/3. Three matrices of
        # size d x (8d/3) hold as many parameters as two matrices of size d x 4d.
        hidden_dim = 4 * embed_dim
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # Round hidden_dim up to the next multiple of multiple_of
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, embed_dim, bias=False)
        self.w3 = nn.Linear(embed_dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor):
        """Map (B, Seq_Len, Dim) to (B, Seq_Len, Dim) through the gated hidden layer."""
        # (B, Seq_Len, Dim) --> (B, Seq_Len, Hidden_Dim)
        swish = F.silu(self.w1(x))
        # (B, Seq_Len, Dim) --> (B, Seq_Len, Hidden_Dim)
        x_V = self.w3(x)
        # Element-wise gate: (B, Seq_Len, Hidden_Dim) * (B, Seq_Len, Hidden_Dim)
        x = swish * x_V
        # (B, Seq_Len, Hidden_Dim) --> (B, Seq_Len, Dim)
        x = self.w2(x)
        return x


In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable per-feature gain."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to 1
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps) over the last dimension (no mean-centering)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32, cast back to the input dtype, then apply the gain
        return self._norm(x.float()).type_as(x) * self.weight


class TransformerBlock(nn.Module):
    """Pre-norm decoder block: ``h = x + Attn(Norm(x))``, ``out = h + FFN(Norm(h))``."""

    def __init__(self, n_heads, embed_dim, norm_eps=1e-6):
        """Build the attention and feed-forward sub-blocks and their RMSNorms.

        Args:
            n_heads: Number of attention heads.
            embed_dim: Model dimension.
            norm_eps: Epsilon used by both RMSNorm layers.
        """
        super().__init__()

        self.n_heads = n_heads
        self.embed_dim = embed_dim
        self.head_embed_dim = embed_dim // n_heads

        self.attention = SelfAttention(n_heads, embed_dim)
        self.feed_forward = FeedForward(embed_dim)

        # Normalization BEFORE the attention block
        self.attention_norm = RMSNorm(embed_dim, eps=norm_eps)
        # Normalization BEFORE the feed forward block
        self.ffn_norm = RMSNorm(embed_dim, eps=norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor, attn_mask=None):
        """Run one decoder block.

        Args:
            x: Input of shape (B, Seq_Len, embed_dim).
            start_pos: Position of the first token in x. It is currently
                unused: SelfAttention has no KV cache, so ``freqs_complex``
                must already cover exactly x's positions.
            freqs_complex: RoPE factors for x's positions, shape
                (Seq_Len, head_dim/2), as expected by SelfAttention.
            attn_mask: Optional mask forwarded unchanged to SelfAttention.

        Returns:
            Tensor of shape (B, Seq_Len, embed_dim).
        """
        # (B, Seq_Len, embed_dim) + (B, Seq_Len, embed_dim) --> (B, Seq_Len, embed_dim)
        h = x + self.attention(self.attention_norm(x), freqs_complex, attn_mask)
        # (B, Seq_Len, embed_dim) + (B, Seq_Len, embed_dim) --> (B, Seq_Len, embed_dim)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out