# In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# In [65]:
model_width = 256  # NOTE(review): an earlier note suggested 4096 for a full-size model
seq_len = 50  # tokens per sequence (an earlier note said "8k context"; 50 is what is used)
# Prefer Apple-silicon MPS, then CUDA, and fall back to CPU so the notebook
# also runs on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
batch_size = 1

class RoPE(nn.Module):
    """Rotary position embedding applied over the last axis of a (batch, seq, width) tensor.

    Consecutive feature pairs ``(x[2i], x[2i+1])`` at position ``pos`` are rotated by
    the angle ``pos * base ** (-2i / model_width)``:

        out[2i]   = x[2i]   * cos - x[2i+1] * sin
        out[2i+1] = x[2i+1] * cos + x[2i]   * sin

    Buffers are created on the CPU and follow the module when it is moved with ``.to()``.
    """

    def __init__(self, model_width, seq_len, base=10000.0):
        super(RoPE, self).__init__()
        if model_width % 2:
            raise ValueError(f"RoPE needs an even model_width, got {model_width}")
        self.model_width = model_width
        # Sign pattern (-1, +1, -1, +1, ...) applied to the pair-swapped input; this
        # turns "cos * x + sin * swap(x) * mask" into a proper 2-D rotation per pair.
        mask = torch.tensor([-1.0, 1.0]).repeat(model_width // 2)
        self.register_buffer("mask", mask)
        # Pair i uses frequency base**(-2i/d), i = 0 .. d/2 - 1 (pair 0 has frequency 1).
        pair_idx = torch.arange(model_width // 2, dtype=torch.float32)
        inv_freq = base ** (-2.0 * pair_idx / model_width)
        positions = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
        thetas_repeated = positions * inv_freq.repeat_interleave(2)  # (seq_len, model_width)
        self.register_buffer("thetas_repeated", thetas_repeated)

    def forward(self, x):
        """Rotate ``x`` of shape (b, l, model_width); requires ``l <= seq_len``."""
        b, l, w = x.shape
        angles = self.thetas_repeated[:l]  # (l, w), broadcast over the batch
        # Swap each (x[2i], x[2i+1]) pair to (x[2i+1], x[2i]); reshape tolerates
        # non-contiguous inputs.
        swapped = x.reshape(b, l, w // 2, 2).flip(-1).reshape(b, l, w)
        return angles.cos() * x + angles.sin() * swapped * self.mask


class Sine(nn.Module):
    """Elementwise sine activation, used as the nonlinearity in the Hyena filter MLP."""

    def forward(self, x):
        return x.sin()

class Attention(nn.Module):
    """Multi-head causal self-attention with rotary embeddings on queries and keys.

    Input and output are (batch, length, model_width). Each of ``num_heads`` heads has
    width ``inner_dim``. A lower-triangular mask built for ``seq_len`` positions
    blocks attention to future positions, so ``length`` must not exceed ``seq_len``.
    """

    def __init__(self, inner_dim, num_heads):
        super(Attention, self).__init__()
        self.inner_dim = inner_dim
        self.to_qkv = nn.Linear(model_width, 3 * num_heads * inner_dim)
        self.to_out = nn.Linear(num_heads * inner_dim, model_width)
        self.num_heads = num_heads
        self.register_buffer("mask", torch.tril(torch.ones(seq_len, seq_len)))
        self.rope = RoPE(model_width=inner_dim, seq_len=seq_len)

    def forward(self, x):
        # Take sizes from the input rather than module-level globals so any batch
        # size (and any length up to seq_len) works.
        b, l, _ = x.shape
        h, d = self.num_heads, self.inner_dim
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # (b, l, h*d) -> (b, h, l, d)
        q, k, v = (t.reshape(b, l, h, d).transpose(1, 2) for t in (q, k, v))

        # RoPE works on 3-D (batch, seq, width) input: fold heads into the batch axis.
        # reshape (not view) because the transposed tensor is non-contiguous for b > 1.
        q = self.rope(q.reshape(b * h, l, d).contiguous()).reshape(b, h, l, d)
        k = self.rope(k.reshape(b * h, l, d).contiguous()).reshape(b, h, l, d)

        attn = (q @ k.transpose(-2, -1)) * d ** -0.5
        attn = attn.masked_fill(self.mask[:l, :l] == 0, float("-inf"))
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, l, h * d)
        return self.to_out(out)

class MoeFFN(nn.Module):
    """Two-layer GELU feed-forward network (dim -> 3*dim -> dim) used as one expert."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 3),
            nn.GELU(),
            nn.Linear(dim * 3, dim),
        )

    def forward(self, x):
        return self.net(x)


class MoE(nn.Module):
    """Token-level top-k mixture of experts.

    A linear router scores every token against each expert. The ``top_k``
    highest-scoring experts process the token, and their outputs are summed with
    softmax weights taken over those ``top_k`` scores.

    Args:
        inner_dim: feature width of tokens (and of each expert's input/output).
        num_experts: number of ``MoeFFN`` experts.
        top_k: experts used per token; clamped to ``num_experts``.
    """

    def __init__(self, inner_dim, num_experts=8, top_k=2):
        super().__init__()
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        self.top_k = min(top_k, num_experts)
        # The router must emit exactly one logit per expert so that top-k indices
        # always name a real expert.
        self.gating_function = nn.Linear(inner_dim, num_experts)
        self.experts = nn.ModuleList([MoeFFN(inner_dim) for _ in range(num_experts)])

    def forward(self, vector_in):
        """Route (b, l, d) tokens through their top-k experts; returns (b, l, d)."""
        b, l, d = vector_in.shape
        tokens = vector_in.reshape(b * l, d)

        logits = self.gating_function(tokens)  # (b*l, num_experts)
        topk_values, topk_indices = logits.topk(k=self.top_k, dim=-1)  # (b*l, top_k)
        topk_weights = topk_values.softmax(dim=-1)  # (b*l, top_k)

        output = torch.zeros_like(tokens)  # (b*l, d)
        for expert_idx, expert in enumerate(self.experts):
            # (token, slot) pairs where this expert was chosen; a token picks each
            # expert at most once because topk indices are distinct.
            token_idx, slot_idx = (topk_indices == expert_idx).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            weights = topk_weights[token_idx, slot_idx].unsqueeze(-1)
            output = output.index_add(0, token_idx, expert(tokens[token_idx]) * weights)

        # Return only after every expert has contributed.
        return output.reshape(b, l, d)


class AttentionBlock(nn.Module):
    """Pre-norm block: causal attention, then a mixture-of-experts FFN.

    Each sub-layer is applied to a LayerNorm of its input and added back
    through a residual connection. Widths come from the module-level ``model_width``.
    """

    def __init__(self, inner_dim, num_heads):
        super().__init__()
        self.attn = Attention(inner_dim, num_heads)
        self.norm1 = nn.LayerNorm(model_width)
        self.norm2 = nn.LayerNorm(model_width)
        self.num_experts = 8
        self.moe = MoE(model_width, self.num_experts)

    def forward(self, x):
        after_attn = x + self.attn(self.norm1(x))
        return after_attn + self.moe(self.norm2(after_attn))
        

class ConvProjection(nn.Module):
    """Project inputs into ``order + 1`` streams and apply a short causal depthwise conv.

    Args:
        inner_dim: width of each output stream.
        order: number of streams is ``order + 1``.
        kernel_size: length of each channel's convolution filter.
        in_dim: input feature width; defaults to the module-level ``model_width``.

    Input ``u`` is (B, L, in_dim). Returns a tuple of ``order + 1`` tensors, each
    (B, inner_dim, L).
    """

    def __init__(self, inner_dim, order=2, kernel_size=10, in_dim=None):
        super(ConvProjection, self).__init__()
        in_dim = model_width if in_dim is None else in_dim
        self.inner_dim = inner_dim
        self.order = order
        self.kernel_size = kernel_size
        channels = inner_dim * (order + 1)
        self.linear_proj = nn.Linear(in_dim, channels)
        # Depthwise conv (groups == channels): each channel is convolved with its own
        # filter along time. Padding kernel_size - 1 and keeping only the first L
        # outputs makes step t depend on steps <= t only (causal). 'same' padding
        # would let future tokens leak in.
        self.conv_filter = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=kernel_size - 1,
            groups=channels,
        )

    def forward(self, u):
        length = u.shape[1]
        x = self.linear_proj(u)  # (B, L, inner_dim * (order + 1))
        x = x.transpose(1, 2)  # (B, inner_dim * (order + 1), L)
        x = self.conv_filter(x)[..., :length]  # drop the trailing (future) outputs
        return x.chunk(self.order + 1, dim=-2)  # order + 1 tensors of (B, inner_dim, L)


class HyenaFilter(nn.Module):
    """Implicitly parameterised long-convolution filters for a Hyena operator.

    Each of the ``order`` filters is produced by an MLP (Linear -> Sine -> Linear)
    applied to learned per-position embeddings. The result is multiplied by a decay
    window ``exp(-alpha_n * t) + bias_n`` over time step ``t``.
    """

    def __init__(self, inner_dim, order):
        super(HyenaFilter, self).__init__()
        self.inner_dim = inner_dim
        # One learned embedding per position; lengths above seq_len are unsupported.
        self.t_embed = nn.Embedding(seq_len, model_width)
        self.order = order
        self.ffn = nn.Sequential(
            nn.Linear(model_width, inner_dim * self.order),
            Sine(),
            nn.Linear(inner_dim * self.order, inner_dim * self.order),
        )
        # One decay rate and one bias per filter. The rates are meant to stay positive
        # so the window decays over time; NOTE(review): nothing constrains them during
        # training (torch.rand only makes them non-negative at initialisation).
        self.alphas = nn.Parameter(torch.rand(self.order))
        self.biases = nn.Parameter(torch.zeros(self.order))

    def forward(self, batch_size, L):
        """Return a list of ``order`` filters, each shaped (batch_size, inner_dim, L).

        The filters do not depend on the input sequence. They are computed once and
        expanded (broadcast view) across the batch.
        """
        t = torch.arange(L, device=self.alphas.device)  # (L,)
        x = self.ffn(self.t_embed(t))  # (L, order * inner_dim)
        # Split the feature axis into (order, inner_dim) and move time last.
        # A plain reshape to (order, inner_dim, L) would interleave time and features.
        h_hat = x.view(L, self.order, self.inner_dim).permute(1, 2, 0)  # (order, inner_dim, L)
        window = torch.exp(-self.alphas[:, None] * t[None, :]) + self.biases[:, None]  # (order, L)
        h = h_hat * window[:, None, :]  # (order, inner_dim, L)
        return [filt.unsqueeze(0).expand(batch_size, -1, -1) for filt in h.unbind(0)]

class HyenaBlock(nn.Module):
    """Hyena operator of a given order over (B, L, model_width) inputs.

    Starting from the first projection ``v``, each step ``n`` computes
    ``v <- proj_{n+1} * causal_conv(v, h_n)``, where ``h_n`` is the n-th implicit
    long filter. Returns (B, L, model_width).
    """

    def __init__(self, order):
        super(HyenaBlock, self).__init__()
        self.conv_proj = ConvProjection(inner_dim=model_width, order=order)
        self.filters = HyenaFilter(inner_dim=model_width, order=order)
        self.order = order

    @staticmethod
    def _causal_fft_conv(v, h):
        """Linear (non-circular) causal convolution of ``v`` with ``h`` along the last axis.

        Zero-padding both signals to 2L before the FFT prevents wrap-around, so
        output step t only sees v[0..t]. Returns a tensor shaped like ``v``.
        """
        length = v.shape[-1]
        n = 2 * length
        y = torch.fft.irfft(torch.fft.rfft(v, n=n) * torch.fft.rfft(h, n=n), n=n)
        return y[..., :length]

    def forward(self, u):
        b, l, _ = u.shape
        projs = self.conv_proj(u)
        filters = self.filters(b, l)
        assert len(projs) == len(filters) + 1
        v = projs[0]  # (B, model_width, L)
        for proj, h in zip(projs[1:], filters):
            v = proj * self._causal_fft_conv(v, h)
        return v.transpose(1, 2)  # (B, L, model_width)

class StripedHyenaBlock(nn.Module):
    """Order-2 Hyena operator: ``q * causal_conv(k * v, h)`` with one implicit filter.

    Input is (B, L, model_width); output is (B, L, inner_dim).
    """

    def __init__(self, inner_dim):
        super(StripedHyenaBlock, self).__init__()
        self.conv_proj = ConvProjection(inner_dim=inner_dim, order=2)
        self.filters = HyenaFilter(inner_dim=inner_dim, order=1)

    @staticmethod
    def _causal_fft_conv(v, h):
        """Linear (non-circular) causal convolution of ``v`` with ``h`` along the last axis.

        Zero-padding both signals to 2L before the FFT prevents wrap-around, so
        output step t only sees v[0..t]. Returns a tensor shaped like ``v``.
        """
        length = v.shape[-1]
        n = 2 * length
        y = torch.fft.irfft(torch.fft.rfft(v, n=n) * torch.fft.rfft(h, n=n), n=n)
        return y[..., :length]

    def forward(self, u):
        # Sizes come from the input instead of module-level globals.
        b, l, _ = u.shape
        q, k, v = self.conv_proj(u)  # each (B, inner_dim, L)
        (h,) = self.filters(b, l)
        out = q * self._causal_fft_conv(k * v, h)
        return out.transpose(1, 2)  # (B, L, inner_dim)

class StripedHyena(nn.Module):
    """Hybrid sequence model over a 4-token vocabulary.

    Blocks are striped: every ``spacing``-th block (counting from 1) is an
    ``AttentionBlock``. All other blocks are ``StripedHyenaBlock -> Linear -> GLU``.
    A final LayerNorm and linear head produce per-position logits.
    """

    def __init__(self, num_layers=32, spacing=10):
        super(StripedHyena, self).__init__()
        self.base_embed = nn.Embedding(4, model_width)
        # NOTE(review): the Hyena stripes have no residual connection or norm, unlike
        # AttentionBlock; confirm this is intended for deep stacks.
        self.blocks = nn.ModuleList([
            nn.Sequential(
                StripedHyenaBlock(inner_dim=model_width),
                nn.Linear(model_width, model_width * 2),
                nn.GLU(),
            )
            if (i + 1) % spacing != 0
            else AttentionBlock(inner_dim=128, num_heads=4)
            for i in range(num_layers)
        ])
        self.norm = nn.LayerNorm(model_width)
        self.to_logits = nn.Linear(model_width, 4)

    def forward(self, x, target=None):
        """Return the cross-entropy loss if ``target`` is given, otherwise the logits.

        ``x`` holds token ids (B, L); logits are (B, L, 4).
        """
        x = self.base_embed(x)
        for block in self.blocks:
            x = block(x)
        logits = self.to_logits(self.norm(x))
        if target is None:
            return logits
        # reshape (not view) so non-contiguous targets are accepted.
        return F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))

# Smoke test: a random (batch_size, seq_len) batch of token ids in [0, 4), the
# vocabulary size of StripedHyena.base_embed, run through a freshly built model.
u = torch.randint(0, 4, (batch_size, seq_len)).to(device)
model = StripedHyena(num_layers=32, spacing=10).to(device)
# Forward pass without a target; the return value is not used here.
model(u)