# In [None]:
from dataclasses import dataclass
from typing import Union
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat, rearrange

@dataclass
class ModelArgs:
    """Model hyperparameters; derived values are filled in by ``__post_init__``.

    After construction:
      * ``d_inner`` is set to ``int(expand * d_model)`` (a plain attribute,
        not a dataclass field).
      * ``dt_rank`` is replaced by ``ceil(d_model / 16)`` when it was 'auto'.
      * ``vocab_size`` is rounded up to the next multiple of
        ``pad_vocab_size_multiple``.
    """

    d_model: int                           # width of the block input/output
    n_layer: int                           # presumably the number of stacked blocks (not used in this section)
    vocab_size: int                        # padded upwards in __post_init__
    d_state: int = 16                      # state size per inner channel
    expand: int = 2                        # d_inner = expand * d_model
    dt_rank: Union[int, str] = 'auto'      # 'auto' -> ceil(d_model / 16)
    d_conv: int = 4                        # kernel size of the depthwise conv
    pad_vocab_size_multiple: int = 8       # vocab_size is made a multiple of this
    conv_bias: bool = True                 # bias term for the conv layer
    bias: bool = False                     # bias term for the in/out projections

    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        # Round the vocabulary size up to the next multiple, if needed.
        remainder = self.vocab_size % self.pad_vocab_size_multiple
        if remainder:
            self.vocab_size += self.pad_vocab_size_multiple - remainder

class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Build the layers and parameters of the block.

        Args:
            args: configuration object; this constructor reads ``d_model``,
                ``d_inner``, ``d_conv``, ``d_state``, ``dt_rank``, ``bias``
                and ``conv_bias`` from it.
        """
        super().__init__()
        self.args = args

        # Single projection producing 2 * d_inner features; forward() splits
        # the result into two d_inner-wide halves.
        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        # Depthwise (groups == channels) 1-D convolution. Padding of
        # d_conv - 1 makes the output longer than the input; forward() keeps
        # only the first `l` positions.
        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # Used inside ssm.
        # Projects the input x to the vectors from which Δ, B and C are taken.
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)
        # Expands the dt_rank-dimensional vector to d_inner dimensions to produce Δ.
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        # A is initialised to the row [1, 2, ..., d_state] repeated for each
        # of the d_inner channels, giving shape (d_inner, d_state). The new
        # axis name in the pattern must match the keyword that supplies its
        # length, otherwise einops raises an EinopsError.
        A = repeat(
            torch.arange(1, args.d_state + 1),
            'd_state -> d_inner d_state',
            d_inner=args.d_inner,
        )
        self.A_log = nn.Parameter(torch.log(A))  # the log of A is stored
        self.D = nn.Parameter(torch.ones(args.d_inner))

        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)
    def forward(self, x):
        """Apply the block to ``x`` and return the output projection.

        Steps: project the input to two d_inner-wide branches, run the first
        branch through the depthwise conv, SiLU and ``self.ssm``, multiply by
        the SiLU of the second branch, then apply ``out_proj``.

        Args:
            x: tensor whose shape unpacks as (b, l, d_model).

        Returns:
            The result of ``self.out_proj`` on the gated ssm output.
        """
        _, seq_len, _ = x.shape
        d_inner = self.args.d_inner

        # One projection yields both branches: shape (b, l, 2 * d_inner).
        projected = self.in_proj(x)
        hidden, gate = projected.split(split_size=[d_inner, d_inner], dim=-1)

        # Conv1d expects channels before length. The layer pads by d_conv - 1,
        # so the output is trimmed back to the first `seq_len` positions.
        hidden = rearrange(hidden, 'b l d_inner -> b d_inner l')
        hidden = self.conv1d(hidden)[:, :, :seq_len]
        hidden = rearrange(hidden, 'b d_inner l -> b l d_inner')

        ssm_out = self.ssm(F.silu(hidden))

        # Gate the ssm output with the second branch before projecting out.
        return self.out_proj(ssm_out * F.silu(gate))