In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

In [52]:
@dataclass
class MistralConfig:
    vocab_size = 1000
    d_model = 128
    d_ff = 1024
    layers = 6
    n_head = 4
    kv_head = 2
    max_pos_embed = 512
    sliding_window = 256
    hidden = 'silu'
    eps = 1e-6
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 16
    seq_len = 64

    @property
    def head_dim(self):
        return self.d_model // self.n_head

In [53]:
class RMSNorm(nn.Module):
    def __init__(self,d_model,eps=1e-6):
        super().__init__()
        self.d_model = d_model
        self.weights = nn.Parameter(torch.ones(d_model))
        self.eps = eps 

    def forward(self,x):
        mean = torch.mean(x**2,dim=-1,keepdim=True) / self.d_model
        rms = torch.sqrt(mean+self.eps)
        x = (x / rms) * self.weights
        return x
    

In [49]:
def precompute_freqs_cis(head_dim,max_pos_embed,theta=10000.0):
    freqs = 1.0 / theta ** (torch.arange(0,head_dim,2).float() / head_dim)
    pos = torch.arange(max_pos_embed)
    angles = torch.outer(pos,freqs)
    return torch.polar(torch.ones_like(angles),angles)

def apply_rotary_embed(xq,xk,freqs_cis):
    xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1],-1,2))
    xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1],-1,2))
    freq_cis = freqs_cis.unsqueeze(0).unsqueeze(0)
    print(xq_complex.shape)
    print(freq_cis.shape)
    xq_out = torch.view_as_real(xq_complex * freq_cis).flatten(2)
    xk_out = torch.view_as_real(xk_complex * freq_cis).flatten(2)
    xq_out = xq_out.reshape(*xq.shape)
    xk_out = xk_out.reshape(*xk.shape)
    return xq_out.type_as(xq),xk_out.type_as(xk)

def repeat_kv(x,n_rep):
    batch_size,seq_len,kv_head,head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:,:,:,None,:].expand(batch_size,seq_len,kv_head,n_rep,head_dim).reshape(batch_size,seq_len,kv_head * n_rep, head_dim)
    )

In [31]:
class InputEmbedding(nn.Module):
    def __init__(self,vocab_size,d_model):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size,d_model)
        
    def forward(self,x):
        return self.embedding(x)

In [None]:
'''class MistralAttention(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.n_head = config.n_head
        self.kv_head = config.kv_head
        self.head_dim = config.head_dim
        self.sliding_window = config.sliding_window
        self.n_rep = self.n_head // self.kv_head

        self.wq = nn.Linear(d_model,n_head*head_dim,bias=False)
        self.wk = nn.Linear(d_model,kv_head * head_dim,bias=False)
        self.wv = nn.Linear(d_model,kv_head * head_dim,bias=False)
        self.wo = nn.Linear(d_model,kv_head * head_dim,bias=False)

    @staticmethod
    def attention(q,k,v,mask=None):
        attn = q @ k.transpose(-1,-2) / config.head_dim ** 0.5
        if mask is not None:
            attn = attn.masked_fill(mask==0,-1e9)
        attn = torch.softmax(attn,dim=-1)
        return (attn @ v) , attn
        
    def forward(self,x,freqs_cis,cache=None,mask=None):
        assert mask is None or cache is None
        seq_len_sum , _ = x.shape
        xq,xk,xv = self.wq(x),self.wk(x),self.wv(x)
        xq = xq.view(seq_len_sum,self.n_head,self.head_dim)
        xk = xk.view(seq_len_sum,self.kv_head,self.head_dim)
        xv = xv.view(seq_len_sum,self.kv_head,self.head_dim)
        xq,xk = apply_rotary_embed(xq,xk,freqs_cis=freqs_cis)

        if cache is None:
            key,val = xk,xv
        elif cache.prefill:
            key,val = cache.interleave_kv(xk,xv)
            cache.update(xk,xv)
        else:
            cache.update(xk,xv)
            key,val = cache.key,cache.value
            key = key.view(seq_len_sum * cache.max_seq_len,self.kv_head,self.head_dim)
            val = val.view(seq_len_sum * cache.max_seq_len,self.kv_head,self.head_dim)

        key,val = repeat_kv(key,val,self.n_rep,dim=1)
        xq,key,val = xq[None,...],key[None,...],val[None,...]
        output = MistralAttention.attention(xq,key,val,mask if cache is None else cache.mask)
        output = output.view(seq_len_sum , self.n_head * self.head_dim)
        return self.wo(output)'''

In [47]:
class MistralAttention(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.n_head = config.n_head
        self.kv_head = config.kv_head
        self.head_dim = config.head_dim
        self.sliding_window = config.sliding_window
        self.n_rep = self.n_head // self.kv_head
        
        self.q_proj = nn.Linear(self.d_model,self.d_model,bias=False)
        self.k_proj = nn.Linear(self.d_model,self.kv_head * self.head_dim , bias=False)
        self.v_proj = nn.Linear(self.d_model,self.kv_head * self.head_dim ,bias=False)
        self.o_proj = nn.Linear(self.d_model,self.d_model,bias=False)

        self.register_buffer("cache_k",torch.zeros((config.batch_size,self.sliding_window,self.kv_head,self.head_dim),device=config.device))
        self.register_buffer("cache_v",torch.zeros((config.batch_size,self.sliding_window,self.kv_head,self.head_dim),device=config.device))

    @staticmethod
    def attention(q,k,v,mask=None):
        attn = q @ k.transpose(-1,-2) / config.head_dim ** 0.5
        if mask is not None:
            attn = attn.masked_fill(mask==0,-1e9)
        attn = torch.softmax(attn,dim=-1)
        return (attn @ v) , attn
    
    def forward(self, x, freqs_complex, mask=None):
        if mask is None:
            mask = torch.ones(x.shape[0], 1, x.shape[1], x.shape[1], device=x.device)
    
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        
        q = q.view(q.shape[0], q.shape[1], self.n_head, self.head_dim)
        k = k.view(k.shape[0], k.shape[1], self.kv_head, self.head_dim)
        v = v.view(v.shape[0], v.shape[1], self.kv_head, self.head_dim)
    
        q, _ = apply_rotary_embed(q, q, freqs_cis=freqs_complex)
        k, _ = apply_rotary_embed(k, k, freqs_cis=freqs_complex)
    
        self.cache_k = torch.cat((self.cache_k[:, 1:], k[:, :, -1:]), dim=1)
        self.cache_v = torch.cat((self.cache_v[:, 1:], v[:, :, -1:]), dim=1)
    
        keys = repeat_kv(self.cache_k, self.n_rep)
        values = repeat_kv(self.cache_v, self.n_rep)

        seq_len = q.shape[-3]
        sliding_mask = torch.zeros((seq_len, seq_len), device=x.device)
        for i in range(seq_len):
            start = max(0, i - self.sliding_window + 1)
            end = i + 1
            sliding_mask[i, start:end] = 1

        q = q.transpose(1,2)
        keys = keys.transpose(1,2)
        values = values.transpose(1,2)
        
        # Combine the input mask and sliding mask
        mask = mask.unsqueeze(0).unsqueeze(1) & sliding_mask.unsqueeze(0).unsqueeze(1)
        
        x, self.attn = MistralAttention.attention(q, keys, values, mask)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.n_head * self.head_dim)
        return self.o_proj(x)


In [33]:
class MistralMLP(nn.Module):
    def __init__(self,config):
        super().__init__()
        d_ff = int(2 * config.d_ff / 3)
        self.gate_proj = nn.Linear(config.d_model,d_ff,bias=False)
        self.layer1 = nn.Linear(config.d_model,d_ff,bias=False)
        self.layer2 = nn.Linear(d_ff,config.d_model,bias=False)
        self.act = F.silu

    def forward(self,x):
        return self.layer2(self.act(self.layer1(x)) * self.gate_proj(x))

In [34]:
class MistralBlock(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.attn = MistralAttention(config)
        self.mlp = MistralMLP(config)
        self.norm1 = RMSNorm(config.d_model,eps=config.eps)
        self.norm2 = RMSNorm(config.d_model,eps = config.eps)

    def forward(self,x,freqs_complex):
        print(f"x shape before attention: {x.shape}")
        h = x + self.attn(self.norm1(x),freqs_complex)
        out = h + self.mlp(self.norm2(h))
        return out


In [35]:
class Transformer(nn.Module):
        def __init__(self,config):
            super().__init__()
            assert config.vocab_size != -1
            self.device = config.device
            self.norm = RMSNorm(config.d_model,eps = config.eps)
            self.output = nn.Linear(config.d_model,config.vocab_size,bias=False)
            self.freq_complex = precompute_freqs_cis(config.d_model,config.max_pos_embed)
            self.embed = InputEmbedding(config.vocab_size,config.d_model)
            self.n_layers = nn.ModuleList()
            for layer in range(config.layers):
                self.n_layers.append(MistralBlock(config))

        def forward(self,x,start_pos=0):
            batch_size,seq_len = x.shape
            x = self.embed(x)
            freq_complex = self.freq_complex[start_pos:start_pos+seq_len]
            for layer in self.n_layers:
                x = layer(x,freq_complex)
            x = self.norm(x)
            logits = self.output(x)
            return logits 

In [50]:
if __name__ == "__main__":
    config = MistralConfig()
    model = Transformer(config).to(config.device)
    dummy_input = torch.randint(config.vocab_size, (config.batch_size, config.seq_len), device=config.device)
    logits = model(dummy_input)
    print(f"Logits shape: {logits.shape}")

x shape before attention: torch.Size([16, 64, 128])
torch.Size([16, 64, 4, 16])
torch.Size([1, 1, 64, 16])


RuntimeError: The size of tensor a (4) must match the size of tensor b (64) at non-singleton dimension 2

In [51]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

@dataclass
class MistralConfig:
    vocab_size: int = 1000
    d_model: int = 128
    d_ff: int = 1024
    layers: int = 6
    n_head: int = 4
    kv_head: int = 2
    max_pos_embed: int = 512
    sliding_window: int = 256
    hidden: str = "silu"
    eps: float = 1e-6
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size: int = 16
    seq_len: int = 64
    
    @property
    def head_dim(self):
        return self.d_model // self.n_head

class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x.pow(2), dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)

def precompute_freqs_cis(head_dim, max_position_embeddings, theta=10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_position_embeddings)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)

def apply_rotary_embed(x, freqs_cis):
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.reshape(1, x.shape[1], 1, -1)
    x_rotated = torch.view_as_real(x_complex * freqs_cis).flatten(3)
    return x_rotated.type_as(x)

class MistralAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.kv_head * config.head_dim, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.kv_head * config.head_dim, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        
        self.register_buffer("cache_k", torch.zeros(
            (config.batch_size, config.sliding_window, config.kv_head, config.head_dim),
            device=config.device))
        self.register_buffer("cache_v", torch.zeros(
            (config.batch_size, config.sliding_window, config.kv_head, config.head_dim),
            device=config.device))

    def forward(self, x, freqs_cis):
        batch_size, seq_len, _ = x.shape
        
        q = self.q_proj(x).view(batch_size, seq_len, self.config.n_head, self.config.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.config.kv_head, self.config.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.config.kv_head, self.config.head_dim)

        q = apply_rotary_embed(q, freqs_cis)
        k = apply_rotary_embed(k, freqs_cis)

        self.cache_k = torch.roll(self.cache_k, shifts=-seq_len, dims=1)
        self.cache_v = torch.roll(self.cache_v, shifts=-seq_len, dims=1)
        
        self.cache_k[:, -seq_len:] = k
        self.cache_v[:, -seq_len:] = v
        
        keys = self.cache_k.repeat_interleave(self.config.n_head // self.config.kv_head, dim=2)
        values = self.cache_v.repeat_interleave(self.config.n_head // self.config.kv_head, dim=2)

        attn_weights = (q @ keys.transpose(-2, -1)) / math.sqrt(self.config.head_dim)
        attn_weights = F.softmax(attn_weights, dim=-1)
        
        output = (attn_weights @ values).transpose(1, 2).reshape(batch_size, seq_len, -1)
        return self.o_proj(output)

class MistralMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        hidden_dim = int(2 * config.d_ff / 3)
        self.gate_proj = nn.Linear(config.d_model, hidden_dim, bias=False)
        self.up_proj = nn.Linear(config.d_model, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, config.d_model, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

class MistralBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn = MistralAttention(config)
        self.mlp = MistralMLP(config)
        self.norm1 = RMSNorm(config.d_model, config.eps)
        self.norm2 = RMSNorm(config.d_model, config.eps)

    def forward(self, x, freqs_cis):
        x = x + self.attn(self.norm1(x), freqs_cis)
        x = x + self.mlp(self.norm2(x))
        return x

class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([MistralBlock(config) for _ in range(config.layers)])
        self.norm = RMSNorm(config.d_model, config.eps)
        self.output = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.freqs_cis = precompute_freqs_cis(
            config.head_dim, 
            config.max_pos_embed
        ).to(config.device)

    def forward(self, x):
        batch_size, seq_len = x.shape
        x = self.embed(x)
        freqs_cis = self.freqs_cis[:seq_len]
        
        for layer in self.layers:
            x = layer(x, freqs_cis)
            
        x = self.norm(x)
        return self.output(x)

if __name__ == "__main__":
    config = MistralConfig()
    model = Transformer(config).to(config.device)
    dummy_input = torch.randint(config.vocab_size, (config.batch_size, config.seq_len), device=config.device)
    logits = model(dummy_input)
    print(f"Logits shape: {logits.shape}")  # Should output: Logits shape: torch.Size([16, 64, 1000])

RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 1