In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
import math

In [28]:
class RMSNorm(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.eps = float(config.eps)
        self.weight = nn.Parameter(torch.ones(config.d_model))

    def _norm(self,x):
        return x * torch.rsqrt(torch.mean(x**2,dim=-1,keepdim=True)+self.eps)

    def forward(self,x):
        return self._norm(x.float()).type_as(x) * self.weight

In [3]:
class ParallelGatedMLP(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.d_ff = config.d_ff
        self.act = F.silu
        self.gate = nn.Linear(config.d_model,config.d_ff,bias=False)
        self.fc1 = nn.Linear(config.d_model,config.d_ff,bias=False)
        self.fc2 = nn.Linear(config.d_ff,config.d_model,bias=False)

    def forward(self,x):
        y = self.fc2(self.act(self.gate(x) * self.fc1(x)))
        return y

In [4]:
class LinearlyScaledRotaryEmbedding(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.device = config.device
        self._linear_scaling_factor = config.scaling_factor
        self.inv_freq = 1.0 / (config.base ** (torch.arange(0,self.d_model,2,device=self.device).float() / self.d_model))
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def update_cache(self,seq_len,device=None,dtype=None):
        if seq_len > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
            self._seq_len_cached = seq_len
            t = torch.arange(seq_len,device=device,dtype=dtype) / self._linear_scaling_factor
            freqs = torch.outer(t,self.inv_freq.to(dtype))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def forward(self,x):
        seq_len = x.shape[1]
        device = x.device
        dtype = x.dtype
        self.update_cache(seq_len,device=device,dtype=dtype)
        cos,sin = self._cos_cached,self._sin_cached

        d_model = x.size(-1)
        split_dim = d_model // 2
        x1 = x[...,:split_dim]
        x2 = x[...,split_dim:]

        cos = cos[:seq_len,:].unsqueeze(0)
        sin = sin[:seq_len,:].unsqueeze(0)

        x1_rot = x1 * cos - x2 * sin
        x2_rot = x1 * sin - x2 * cos

        return torch.cat([x1_rot,x2_rot],dim=-1)

In [5]:
class MHA(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.d_head = config.d_head

        self.wq = nn.Linear(config.d_model,config.d_model,bias=False)
        self.wk = nn.Linear(config.d_model,config.d_model,bias=False)
        self.wv = nn.Linear(config.d_model,config.d_model,bias=False)
        self.wo = nn.Linear(config.d_model,config.d_model,bias=False)
        self.rotary_embedding = LinearlyScaledRotaryEmbedding(config)

    def forward(self,x):
        batch_size,seq_len,_ = x.shape
        q = self.wq(x).view(batch_size,seq_len,self.n_heads,self.d_head)
        k = self.wk(x).view(batch_size,seq_len,self.n_heads,self.d_head)
        v = self.wv(x).view(batch_size,seq_len,self.n_heads,self.d_head)

        q = self.roatry_embedding(q)
        k = self.rotary_embedding(k)

        q = q.transpose(1,2)
        k = k.transpose(1,2)
        v = v.transpose(1,2)

        attn_score = torch.matmul(q,k.transpose(-2,-1)) / math.sqrt(self.d_model)
        attn_weight = F.softmax(attn_score,dim=-1)
        attn_output = torch.matmul(attn_weight,v)
        attn_output = attn_output.transpose(1,2).contiguous().view(batch_size,seq_len,self.d_model)
        attn_output = self.wo(attn_output)
        return attn_output

In [6]:
class InputEmbedding(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.vocab_size = config.vocab_size
        self.embedding = nn.Embedding(self.vocab_size,self.d_model)

    def forward(self,x):
        return self.embedding(x)

In [17]:
class AttentionBlock(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.mha = MHA(config)
        self.norm1 = RMSNorm(config)
        self.norm2 = RMSNorm(config)
        self.mlp = ParallelGatedMLP(config)

    def forward(self,x,padding_mask = None):
        if isinstance(padding_mask,torch.Tensor):
            x = x * padding_mask.unsqueeze(-1)
        attn_out = self.mha(self.norm1(x))
        x = x + attn_out
        if isinstance(padding_mask,torch.Tensor):
            x = x * padding_mask.unsqueeze(-1)
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out

In [33]:
class ParallelHyenaFilter(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.short_filter_length = config.short_filter_length
        self.short_filter_weight = nn.Parameter(torch.randn(config.d_model,1,config.short_filter_length))
        self.short_filter_bias = (nn.Parameter(torch.randn(config.d_model)) if config.short_filter_bias else None)
        self.D = nn.Parameter(torch.zeros(config.d_model))
        self.state_size = config.state_size

        poles = torch.randn(self.d_model,config.state_size,1,2)
        poles[...,0] = 1e-2 * torch.randn(self.d_model,config.state_size,1)
        poles[...,1] = 1e-3 * torch.randn(self.d_model,config.state_size,1)
        self.poles = nn.Parameter(poles)
        self.residues = nn.Parameter(torch.randn(self.d_model,config.state_size,1,2))
        self.h = None
        self.t = None

    def update_time(self,L,device):
        self.t = torch.arange(L,device=device).unsqueeze(0).unsqueeze(0)

    def compute_filter(self,L,device):
        self.update_tim(L,device)
        residues = torch.view_as_complex(self.residues.float())
        log_poles = torch.view_as_complex(self.poles.float()).low()
        h = (residues * (log_poles * self.t).exp()).real.sum(dim=1)
        h = h.mean(dim=0,keepdim=True)
        h = h.unsqueeze(0).transpose(1,2)
        return h

    def forward(self,x,padding_mask=None):
        seq_len = x.shape[1]
        device = x.device
        x_t = x.transpose(1,2)
        z_pre = F.conv1d(x_t,self.short_filter_weight,bias=self.short_filter_bias,padding=(self.short_filter_length -1)//2,groups=3*self.d_model)
        z_pre = z_pre.transpose(1,2)
        z_pre = z_pre[...,:self.d_model]
        if self.h is None or self.h.shape[1]<seq_len:
            self.h = self.compute_filter(seq_len,device)
        y = z_pre * self.h + x[...,:self.d_model] * self.D
        if padding_mask is not None:
            y = y.masked_fill(padding_mask.unsqueeze(-1),0)
        return y,None

In [34]:
class GatedConvBlock(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.norm1 = RMSNorm(config)
        self.norm2 = RMSNorm(config)
        self.hyena = ParallelHyenaFilter(config)
        self.mlp = ParallelGatedMLP(config)
        self.proj = nn.Linear(config.d_model,3*config.d_model,bias=True)
        self.out_filter_dense = nn.Linear(config.d_model,config.d_model,bias=True)

    def forward(self,x,padding_mask=None):
        x = self.proj(self.norm1(x))
        if isinstance(padding_mask,torch.Tensor):
            x = x.masked_fill(padding_mask.unsqueeze(-1),0)
        x,_ = self.hyena(x,padding_mask=padding_mask)
        x_in = self.out_filter_dense(x) + x
        if isinsatnce(padding_mask,torch.Tensor):
            x_in = x_in.masked_fill(padding_mask.unsqueeze(-1),0)
        y = self.mlp(self.norm2(x_in)) + x_in
        return y

In [35]:
class Block(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.attn = AttentionBlock(config)
        self.conv = GatedConvBlock(config)

    def forward(self,x,padding_mask=None):
        x = self.attn(x,padding_mask=padding_mask)
        x = self.conv(x,padding_mask=padding_mask)
        return x

In [36]:
class Projection(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.d_model = config.d_model
        self.unembed = nn.Embedding(self.d_model,self.vocab_size)

    def forward(self,x):
        return self.unembed(x)

In [37]:
class StripedHyena(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.embed = InputEmbedding(config)
        self.norm = RMSNorm(config)
        self.unembed = Projection(config)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layers)])

    def forward(self,x,padding_mask=None):
        x = self.embed(x)
        if padding_mask is not None:
            x = x * padding_mask.unsqueeze(-1)
        for block in self.blocks:
            x = block(x,padding_mask=padding_mask)
        x = self.norm(x)
        x = self.unembed(x)
        return x

In [38]:
@dataclass
class HyenaConfig:
    d_model = 265
    seq_len = 128
    n_heads = 8
    vocab_size = 10000
    eps = 1e-5
    short_filter_length = 21
    state_size = 16
    n_layers = 4
    tie_embed = False
    n_heads_kv = 8
    short_filter_bias = True
    device = "cuda"
    d_ff = 1024
    d_head = d_model // n_heads
    scaling_factor=1.0
    base = 10000.0

In [39]:
config = HyenaConfig()
inputs = torch.randint(0,config.vocab_size,(32,128))
padding_mask=torch.ones(32,128,dtype=torch.bool)
model = StripedHyena(config)
output = model(inputs,padding_mask=padding_mask)
print("output shape",output.shape)

RuntimeError: shape '[32, 128, 8, 33]' is invalid for input of size 1085440