In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
from typing import Optional
class args:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1 
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    theta: float = 10000.0
    
    # Needed for KV cache
    max_batch_size: int = 32
    max_seq_len: int = 2048

    device: str = "cuda"

In [3]:
import RMSNorm
import RoPE
rope = RoPE.rope(args.dim, args.max_seq_len,args.device, args.theta) 

In [4]:
rms = RMSNorm.RMSNorm(args.dim)

In [5]:
class MHA(nn.Module):
    def __init__(self, args, qkv_bias = False):
        super(MHA, self).__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_heads_q = args.n_heads
        self.n_rep = self.n_heads_q // self.n_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.dim = args.dim
        
        self.W_query = nn.Linear(self.dim, self.head_dim*self.n_heads_q , bias=qkv_bias).to("cuda")
        self.W_key = nn.Linear(self.dim, self.n_kv_heads*self.head_dim, bias=qkv_bias).to("cuda")
        self.W_value = nn.Linear(self.dim, self.n_kv_heads*self.head_dim, bias=qkv_bias).to("cuda")
        self.out_proj = nn.Linear(args.n_heads*self.head_dim, self.dim).to("cuda")
        
        self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)).to("cuda")
        self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)).to("cuda")
        
    @staticmethod
    def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        batch_size, seq_len, n_kv_heads, head_dim = x.shape
        if n_rep == 1:
            return x
        return (
        x[:, :, :, None, :]
        .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
        .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
        )
        
    def forward(self, x, start_pos=0):
        batch_size, seq_len, _ = x.shape
        # Linear transformation
        xq , xk, xv = self.W_query(x), self.W_key(x), self.W_value(x)
        
        # Change shape to (batch_size, seq_len, n_heads, head_dim)
        xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)
        xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)   

        # Apply RoPE with position offset
        xq = rope(xq)
        xk = rope(xk)
        
        # Update KV cache
        self.cache_k[:batch_size, start_pos:start_pos+seq_len] = xk
        self.cache_v[:batch_size, start_pos:start_pos+seq_len] = xv
        
        # Retrieve cached keys/values including current sequence
        keys = self.cache_k[:batch_size, :start_pos+seq_len]
        values = self.cache_v[:batch_size, :start_pos+seq_len]
        
        # Repeat KV heads to match Q heads
        keys = self.repeat_kv(keys, self.n_rep)
        values = self.repeat_kv(values, self.n_rep)
        
        # Transpose for attention computation
        xq = xq.transpose(1, 2)  # (bs, n_heads_q, seq_len, hd)
        keys = keys.transpose(1, 2)  # (bs, n_heads_kv*rep, cache_len, hd)
        values = values.transpose(1, 2)  # (bs, n_heads_kv*rep, cache_len, hd)
        
        # Compute attention scores
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        
        # Weighted sum of values
        output = torch.matmul(scores, values)  # (bs, n_heads_q, seq_len, hd)
        
        # Combine heads and project
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.out_proj(output).to(self.device)


In [6]:
class FeedForward(nn.Module):
    def __init__(self, args):
        super().__init__()
        hidden_dim = 4 * args.dim
        hidden_dim = int(2 * hidden_dim / 3)
        if args.ffn_dim_multiplier is not None:
            hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
        # Round the hidden_dim to the nearest multiple of the multiple_of parameter
        hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)

        self.w1 = nn.Linear(args.dim, hidden_dim, bias=False).to(args.device)
        self.w2 = nn.Linear(hidden_dim, args.dim, bias=False).to(args.device)
        self.w3 = nn.Linear(args.dim, hidden_dim, bias=False).to(args.device)
    def forward(self, x: torch.Tensor):
        # (B, Seq_Len, Dim) --> (B, Seq_Len, Hidden_Dim)
        swish = F.silu(self.w1(x))
        # (B, Seq_Len, Dim) --> (B, Seq_Len, Hidden_Dim)
        x_V = self.w3(x)
        # (B, Seq_Len, Hidden_Dim) * (B, Seq_Len, Hidden_Dim) --> (B, Seq_Len, Hidden_Dim)
        x = swish * x_V
        # (B, Seq_Len, Hidden_Dim) --> (B, Seq_Len, Dim)
        x = self.w2(x)
        return x
        

In [7]:
class Encoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.mha = MHA(args)
        self.ffn = FeedForward(args)
        self.rmsnorm = RMSNorm.RMSNorm(args.dim)
        self.swish = nn.SiLU()
        
    def forward(self, x: torch.Tensor):
        x_norm = self.rmsnorm(x)
        x_mha = self.mha(x_norm)
        x = x + x_mha
        
        x_norm_2 = self.rmsnorm(x)
        x_ffn = self.ffn(x_norm_2)
        x = x + x_ffn
        return x
        

In [8]:
class Model(nn.Module):
    def __init__(self, args):
        super().__init__()
        
        self.encoder = nn.Sequential(*[Encoder(args) for _ in range(args.n_layers)])
        
        self.embd = nn.Embedding(args.vocab_size, args.dim).to(args.device)
        self.norm = RMSNorm.RMSNorm(args.dim)
        self.linear = nn.Linear(args.dim, args.vocab_size).to(args.device)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x: torch.Tensor):
        x = self.embd(x)
        x = self.encoder(x)
        x = self.norm(x)
        x = self.linear(x)
        return x

In [9]:
x = torch.randn(1, 2048, 4096).to("cuda")

In [10]:
model = Model(args)

OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of 4.00 GiB of which 0 bytes is free. Of the allocated memory 6.82 GiB is allocated by PyTorch, and 1.95 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)