# hyenaの実装

In [106]:
import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
from torchinfo import summary

# Projection  

入力を N+1 個（v, x1, …, xN）に Projection する

In [47]:
class Projection(nn.Module):
    """Project the input into ``order + 1`` streams for the Hyena operator.

    A linear layer widens every token from ``embed_dim`` to
    ``(order + 1) * embed_dim`` features, a depthwise 1-D convolution mixes
    neighbouring positions within each channel, and the result is split along
    the channel axis into ``order + 1`` tensors of ``embed_dim`` channels.

    Args:
        embed_dim: Model width E (number of channels per stream).
        order: Hyena order N; ``N + 1`` streams are returned.
        kernel_size: Kernel size of the depthwise short convolution.
        stride: Stride of the short convolution.
        padding: Zero padding applied to both ends of the sequence. With
            ``stride=1`` and ``padding = kernel_size - 1`` (the defaults),
            truncating the output to the input length makes position ``t``
            depend only on inputs at positions ``<= t``.
    """

    def __init__(
        self,
        embed_dim: int,
        order: int = 2,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 2,
    ):
        super().__init__()

        hidden_size = (order + 1) * embed_dim  # H = (N + 1) * E
        self.linear = nn.Linear(embed_dim, hidden_size)
        self.short_conv = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=hidden_size,  # one filter per channel (depthwise)
        )
        self.hidden_size = hidden_size
        self.order = order

    def forward(self, u: Tensor) -> tuple[Tensor, ...]:
        """Return the ``order + 1`` projected streams.

        Args:
            u: Input of shape (B, L, E).

        Returns:
            A tuple of ``order + 1`` tensors, each of shape (B, E, L).
        """
        seq_len = u.shape[1]
        u = self.linear(u)  # (B, L, E) -> (B, L, H)
        u = u.transpose(1, 2)  # (B, L, H) -> (B, H, L), the layout Conv1d expects
        # The padded convolution yields extra trailing positions; keep the first L.
        u = self.short_conv(u)[..., :seq_len]
        # Split into N + 1 groups of E channels. Chunking by ``hidden_size``
        # would instead give H single-channel tensors.
        return u.chunk(self.order + 1, dim=1)

In [48]:
enbed_dim = 1  # model width E (the variable name is a misspelling of "embed_dim")
order = 3  # Hyena order N = 3

proj = Projection(embed_dim=enbed_dim, order=order) 
u = torch.randn(1, 224, 1, requires_grad=True)  # (B=1, L=224, E=1)
u.shape

torch.Size([1, 224, 1])

In [49]:
# Stand-alone linear layer mapping E = 1 input feature to (3 + 1) * 1 = 4 features.
linear = nn.Linear(in_features=1, out_features=(3+1)*1)
linear(u).shape  # (B, L, H)

torch.Size([1, 224, 4])

In [50]:
# Swap the sequence and feature axes so channels come before length for Conv1d.
u_rea = rearrange(linear(u), "B L H -> B H L")  # (B, H, L)
u_rea.shape

torch.Size([1, 4, 224])

In [51]:
# Depthwise convolution (groups == channels) over the 4 projected channels.
short_conv = nn.Conv1d(
            in_channels=(3+1)*1,
            out_channels=(3+1)*1,
            kernel_size=3,
            stride=1,
            padding=2,
            groups=(3+1)*1)

# With kernel 3 and padding 2 the output is longer than the input; keep the first 224 steps.
sconv = short_conv(u_rea)[...,:224]
sconv.shape

torch.Size([1, 4, 224])

In [54]:
# Split the 4 channels into 4 chunks along dim 1 and inspect the shape of the first.
sconv.chunk((3+1)*1,dim=1)[0].shape

torch.Size([1, 1, 224])

In [62]:
class PositionalEncoding(nn.Module):
    """Time axis and complex-exponential positional features for the filter MLP.

    For ``embed_dim = 2K + 1`` the feature matrix has, per position, one
    normalised time value followed by the real parts and then the imaginary
    parts of ``exp(2j * pi * k * pos / max_seq_len)`` for ``k = 0 .. K-1``.

    Args:
        embed_dim: Positional feature width Ep; must be odd.
        max_seq_len: Number of positions L that are precomputed.

    Attributes:
        t: Frozen parameter of shape (1, 1, L) holding time values in [0, 1].
        z: Parameter of shape (L, Ep) holding the features (created with the
            default ``requires_grad=True``).
    """

    def __init__(self, embed_dim: int, max_seq_len: int):
        assert embed_dim % 2 == 1, "`embed_dim` must be odd"
        super().__init__()
        n_bands = (embed_dim - 1) // 2  # K

        # Column vectors of shape (L, 1): normalised time and integer position.
        time = torch.linspace(0, 1, steps=max_seq_len)[:, None]
        positions = torch.arange(0, max_seq_len, dtype=torch.float)[:, None]
        # Row vector of shape (1, K) with the band indices 0 .. K-1.
        bands = torch.linspace(0, n_bands - 1, steps=n_bands)[None, :]

        phasor = torch.exp(1j * 2 * np.pi * bands * positions / max_seq_len)  # (L, K)
        features = torch.cat([time, phasor.real, phasor.imag], dim=-1)  # (L, Ep)

        self.t = nn.Parameter(time.view(1, 1, max_seq_len), requires_grad=False)
        self.z = nn.Parameter(features)

    def forward(self, seq_len: int) -> tuple[Tensor, Tensor]:
        """Return ``(t, z)`` truncated to the first ``seq_len`` positions."""
        time = self.t[..., :seq_len]  # (1, 1, seq_len)
        features = self.z[:seq_len, :]  # (seq_len, Ep)
        return time, features


In [66]:
embed_dim = 1  # positional feature width; PositionalEncoding requires an odd value
max_seq_len = 224  # number of precomputed positions
pe = PositionalEncoding(embed_dim=embed_dim, max_seq_len=max_seq_len)


2

In [69]:
# Shape of the first returned tensor (the time axis t) for seq_len = 224.
pe(224)[0].shape

torch.Size([1, 1, 224])

In [78]:
class Sin(nn.Module):
    """Sine activation with a per-feature frequency: ``sin(freq * x)``.

    Args:
        embed_dim: Number of features E; one frequency is stored per feature.
        freq: Initial value shared by all frequencies.
        learn: Whether the frequencies receive gradients.
    """

    def __init__(self, embed_dim: int, freq: float = 8.0, learn: bool = True):
        super().__init__()
        initial = torch.ones(1, embed_dim) * freq  # (1, E)
        self.freq = nn.Parameter(initial, requires_grad=learn)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the activation; ``x`` must broadcast against shape (1, E)."""
        scaled = self.freq * x
        return scaled.sin()


In [79]:
class ExponentialDecayWindow(nn.Module):
    """Multiply a signal by a per-channel exponential window ``exp(alpha * t)``.

    The rates ``alpha`` are spaced linearly between
    ``log(target) / slow_decay_t`` and ``log(target) / fast_decay_t``, one per
    channel, so that a channel's window equals ``target`` at its decay time.

    Args:
        embed_dim: Number of channels E.
        fast_decay_t: Time at which the fastest channel's window equals ``target``.
        slow_decay_t: Time at which the slowest channel's window equals ``target``.
        target: Window value at the decay time.
        shift: Constant added to the window before it is applied.
    """

    def __init__(
        self,
        embed_dim: int,
        fast_decay_t: float = 0.3,
        slow_decay_t: float = 1.5,
        target: float = 1e-2,
        shift: float = 0.0
    ):
        super().__init__()
        max_decay = np.log(target) / fast_decay_t
        min_decay = np.log(target) / slow_decay_t
        self.alphas = nn.Parameter(
            torch.linspace(min_decay, max_decay, steps=embed_dim).view(1, embed_dim, 1)
        )
        self.shift = shift

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        """Return ``x`` scaled by the window; ``x`` itself is left unchanged.

        Args:
            x: Signal whose last axis is the sequence axis, e.g. (N, E, L).
            t: Time values broadcastable against (1, E, L), e.g. (1, 1, L).

        Returns:
            A new tensor with the shape of ``x`` broadcast against the window.
        """
        seq_len = x.shape[-1]
        decay = torch.exp(self.alphas * t)[..., :seq_len]  # -> (1, E, L)
        # Out-of-place: an in-place ``x *= ...`` would silently modify the
        # caller's tensor and raises when ``x`` is a leaf that requires grad.
        return x * (decay + self.shift)


In [80]:
class HyenaFilter(nn.Module):
    """Generate the ``order`` implicit long-convolution filters of a Hyena block.

    Positional features are passed through an MLP with sine activations, the
    output is reshaped to one filter per (order, channel) pair, and each
    filter is modulated by an exponential decay window.

    Args:
        pos_embed_dim: Width Ep of the positional features (must be odd, as
            required by ``PositionalEncoding``).
        max_seq_len: Longest sequence length the positional table covers.
        seq_embed_dim: Number of channels E of the sequence being filtered.
        order: Number of filters N to generate.
        fnn_depth: Number of linear layers in the MLP; must exceed 2.
        fnn_hidden_size: Hidden width of the MLP.
        freq: Initial frequency of every ``Sin`` activation.
        learn: Whether the ``Sin`` frequencies are trainable.
        fast_decay_t: Passed to ``ExponentialDecayWindow``.
        slow_decay_t: Passed to ``ExponentialDecayWindow``.
        target: Passed to ``ExponentialDecayWindow``.
        shift: Passed to ``ExponentialDecayWindow``.
    """

    def __init__(
        self,
        pos_embed_dim: int,
        max_seq_len: int,
        seq_embed_dim: int,
        order: int = 2,
        fnn_depth: int = 4,
        fnn_hidden_size: int = 64,
        freq: float = 10.0,
        learn: bool = True,
        fast_decay_t: float = 0.3,
        slow_decay_t: float = 1.5,
        target: float = 1e-2,
        shift: float = 0.0
    ):
        super().__init__()

        assert fnn_depth > 2, "fnn_depth must be grater than 2"
        self.pos = PositionalEncoding(pos_embed_dim, max_seq_len)

        # Input layer, (fnn_depth - 2) hidden layers, then a bias-free output
        # layer emitting N * E values per position.
        layers: list[nn.Module] = [
            nn.Linear(pos_embed_dim, fnn_hidden_size),
            Sin(fnn_hidden_size, freq, learn),
        ]
        for _ in range(fnn_depth - 2):
            layers.append(nn.Linear(fnn_hidden_size, fnn_hidden_size))
            layers.append(Sin(fnn_hidden_size, freq, learn))
        layers.append(nn.Linear(fnn_hidden_size, order * seq_embed_dim, bias=False))
        self.fnn = nn.Sequential(*layers)

        self.embed_dim = seq_embed_dim
        self.order = order
        self.window = ExponentialDecayWindow(
            seq_embed_dim,
            fast_decay_t=fast_decay_t,
            slow_decay_t=slow_decay_t,
            target=target,
            shift=shift
        )

    def forward(self, seq_len: int) -> list[Tensor]:
        """Return ``order`` filters, each of shape (1, E, seq_len).

        The value comes from ``Tensor.chunk`` and is therefore a tuple at
        runtime, despite the ``list`` annotation.
        """
        t, z = self.pos(seq_len)  # (1, 1, L), (L, Ep)
        per_position = self.fnn(z)  # (L, N*E)
        filters = per_position.t().reshape(
            self.order, self.embed_dim, seq_len
        )  # (N*E, L) -> (N, E, L)
        windowed = self.window(filters, t)  # (N, E, L)
        return windowed.chunk(self.order, dim=0)
        

In [81]:
class HyenaBlock(nn.Module):
    def __init__(
        self,
        *,
        embed_dim: int,
        max_seq_len: int,
        order: int,
        pos_dim: int = 65,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 2,
        fnn_depth: int = 4,
        fnn_hidden_size: int = 64,
        freq: float = 8.0,
        learn_filter: bool = True,
        fast_decay_t: float = 0.3,
        slow_decay_t: float = 1.5,
        target: float = 1e-2,
        shift: float = 0.0,
        activation: str = "identity"
    ):
        super().__init__()
        self.proj = Projection(
            embed_dim,
            order,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding
        )
        self.hyena_filter = HyenaFilter(
            pos_dim,
            max_seq_len,
            seq_embed_dim=embed_dim,
            order=order,
            fnn_depth=fnn_depth,
            fnn_hidden_size=fnn_hidden_size,
            freq=freq,
            learn=learn_filter,
            fast_decay_t=fast_decay_t,
            slow_decay_t=slow_decay_t,
            target=target,
            shift=shift
        )
        self.bias = nn.Parameter(torch.randn(order, 1, embed_dim, 1))
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, embed_dim)

        act: nn.Module
        match name := activation.lower():
            case "identity": act = nn.Identity()
            case "relu": act = nn.ReLU()
            case "leaky-relu": act = nn.LeakyReLU()
            case "gelu": act = nn.GELU()
            case "silu": act = nn.SiLU()
            case "tanh": act = nn.Tanh()
            case _: raise NotImplementedError(f"activation `{name}` is invalid")
        self.act = act

    @staticmethod
    def fftconv(x: Tensor, h: Tensor, d: Tensor) -> Tensor:
        # B: batch size, L: seq len, E: embed dim
        L = x.shape[-1]
        h_fft = torch.fft.rfft(h, n=2*L, norm="forward") # (1, E, L) -> (1, E, 2*L)
        x_fft = torch.fft.rfft(x.to(dtype=h.dtype), n=2*L) # (B, E, L) -> (B, E, 2*L)
        y = torch.fft.irfft(x_fft * h_fft, n=2*L, norm="forward")[..., :L] # -> (B, E, L)
        y += x * d
        return y.to(dtype=x.dtype)

    def forward(self, u: Tensor) -> Tensor:
        # B: batch size, L: seq len, E: embed dim, N: order of hyena
        L = u.shape[1]
        x = self.norm1(u) # (B, L, E) -> (B, L, E)
        ### hyena 
        x = self.proj(x) # (B, L, E) -> [(B, E, L)] * (N+1)
        h = self.hyena_filter(L) # -> [(1, E, L)] * N
        v = x[-1] # -> (B, E, L)
        for x_i, h_i, d_i in zip(x[:-1], h, self.bias):
            v = x_i * self.fftconv(v, h_i, d_i)
        ###
        y = u + v.transpose(1, 2) # -> (B, L, E)
        out = self.norm2(y) # (B, L, E) -> (B, L, E)
        out = self.fc(out) # (B, L, E) -> (B, L, E)
        out = self.act(out) # (B, L, E) -> (B, L, E)
        out += y
        return out


In [93]:
hb = HyenaBlock(embed_dim=3, max_seq_len=224*224, order=3)  # order = 3 -> streams v, x1, x2, x3

In [94]:
# Random input with one token per pixel of a 224x224 image: (B=1, L=224*224, E=3).
u = torch.randn(1,224*224,3, requires_grad=True)

In [107]:
class Patching(nn.Module):
    """Cut an image batch into square patches and flatten each patch.

    Args:
        patch_size (int): Height (= width) of one patch in pixels.
    """

    def __init__(self, patch_size):
        super().__init__()
        # (b, c, H, W) -> (b, number of patches, values per patch)
        pattern = "b c (h ph) (w pw) -> b (h w) (ph pw c)"
        self.net = Rearrange(pattern, ph=patch_size, pw=patch_size)

    def forward(self, x):
        """Return the flattened patches.

        Args:
            x (torch.Tensor): Images of shape
                [batch_size, channels, image_height, image_width].

        Returns:
            torch.Tensor: Shape [batch_size, n_patches, patch_size**2 * channels].
        """
        return self.net(x)


In [140]:
class LinearProjection(nn.Module):
    """Embed flattened patches with a single linear layer.

    Args:
        patch_dim (int): Size of one flattened patch
            (= channels * patch_size ** 2).
        embed_dim (int): Size of the vector each patch is mapped to.
    """

    def __init__(self, patch_dim, embed_dim):
        super().__init__()
        self.net = nn.Linear(in_features=patch_dim, out_features=embed_dim)

    def forward(self, x):
        """Map patches to embeddings.

        Args:
            x (torch.Tensor): Shape [batch_size, n_patches, patch_dim].

        Returns:
            torch.Tensor: Shape [batch_size, n_patches, embed_dim].
        """
        return self.net(x)


In [141]:
from PIL import Image
# Load the sample image and resize it to 224x224 pixels.
img = np.array(Image.open("../images/hyena.jpg").resize((224,224)))
# Convert to a tensor and add a leading batch axis.
img = torch.Tensor(img).unsqueeze(0)
# Channels-last -> channels-first.
img = rearrange(img, "b h w c -> b c h w")
img.shape

torch.Size([1, 3, 224, 224])

In [142]:
# 8x8 patches: a 224x224 image gives 28 * 28 = 784 patches of 8 * 8 * 3 = 192 values.
patching = Patching(patch_size=8)
patching(img).shape  # b (h w) (ph pw c) <=> b N C*P*P


torch.Size([1, 784, 192])

In [143]:
# Embed every 192-value patch into 3 dimensions.
lp = LinearProjection(patch_dim=192, embed_dim=3)
lp(patching(img)).shape

torch.Size([1, 784, 3])

In [135]:
# Number of scalar values in one 224x224 image with 3 channels.
224*224*3

150528