In [18]:
import torch
import torch.nn as nn
from torchinfo import summary

class SimpleRNN(nn.Module):
    """Stacked ``nn.RNN`` followed by a linear layer applied at every time step.

    Args:
        input_size: Number of features per time step of the input.
        hidden_size: Number of features in the RNN hidden state.
        output_size: Number of features the linear head emits per time step.
        num_layers: Number of stacked RNN layers (default 3).

    The RNN is built with ``batch_first=True``, so batched inputs are
    ``(batch, seq_len, input_size)`` and outputs are
    ``(batch, seq_len, output_size)``.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=3):
        super().__init__()
        # Kept as attributes so the configuration can be inspected on the module.
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run the RNN from a zero initial state and project every time step.

        Args:
            x: Input tensor, ``(batch, seq_len, input_size)`` or unbatched
               ``(seq_len, input_size)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and
            ``output_size`` features in the last dimension.
        """
        # No explicit h0: nn.RNN defaults the initial hidden state to zeros
        # created with the input's dtype and device. Building it by hand with
        # torch.zeros(...) used the default dtype (breaking .double()/.half()
        # models) and assumed a batch dimension.
        out, _ = self.rnn(x)
        out = self.fc(out)
        # Alternative: keep only the last time step of a batched input with
        # out = self.fc(out[:, -1, :])
        return out

# Model hyperparameters
input_size = 50
hidden_size = 20
output_size = 50
num_layers = 3

# Shape of the dummy batch that is fed through the model for the summary
batch_size = 8
seq_len = 15

# Instantiate the model
model = SimpleRNN(input_size, hidden_size, output_size, num_layers)

# Show per-layer input/output shapes and parameter counts
dummy_input = torch.randn(batch_size, seq_len, input_size)  # (batch_size, seq_len, input_size)
summary(model, input_data=dummy_input, depth=15, col_names=("input_size", "output_size", "num_params"))

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
SimpleRNN                                [8, 15, 50]               [8, 15, 50]               --
├─RNN: 1-1                               [8, 15, 50]               [8, 15, 20]               3,120
├─Linear: 1-2                            [8, 15, 20]               [8, 15, 50]               1,050
Total params: 4,170
Trainable params: 4,170
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.38
Input size (MB): 0.02
Forward/backward pass size (MB): 0.07
Params size (MB): 0.02
Estimated Total Size (MB): 0.11

In [3]:
import selective_scan_cuda

In [2]:
from mamba_ssm.modules.mamba_simple import Mamba

import torch
# from mamba_ssm import Mamba
from torchinfo import summary


# Problem size: batch size, sequence length, model width
batch, length, dim = 7, 137, 37
print("B,L,D:", batch, length, dim)

# Random input batch, moved to the GPU
x = torch.randn((batch, length, dim)).to("cuda")

# Build the Mamba block, then move its parameters to the GPU.
# This module uses roughly 3 * expand * d_model^2 parameters.
model = Mamba(
    d_model=dim,  # Model dimension d_model
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
)
model = model.to("cuda")

# Optional torchinfo summary (left disabled):
# summary_str = summary(model, input_size=[(7, 201, 286), (7, 201, 286)], depth=5, col_names=("input_size", "output_size", "num_params"), verbose=0)
# summary_str = summary(model, input_size=[(batch, length, dim)], depth=15, col_names=("input_size", "output_size", "num_params"), verbose=0)
# print(summary_str)

# Forward pass; the check below requires the output shape to equal the input shape
y = model(x)
print(x.shape, y.shape, sep="\n")
assert x.shape == y.shape

B,L,D: 7 137 37
Mamba
batch, seqlen, dim: 7 137 37
xz.shape: torch.Size([7, 148, 137])
1111111111111111111111111111111111111111
MambaInnerFn
xz.shape: torch.Size([7, 148, 137])
x.shape, z.shape: torch.Size([7, 74, 137]) torch.Size([7, 74, 137])
conv1d_out.shape: torch.Size([7, 74, 137])
x_dbl.shape: torch.Size([959, 35])
torch.Size([7, 137, 37])
torch.Size([7, 137, 37])


In [5]:
import torch

class SquareFunction(torch.autograd.Function):
    """Custom autograd function computing ``y = x ** 2`` element-wise.

    Demonstrates two ways of handing data from ``forward`` to ``backward``:
    tensors through ``ctx.save_for_backward`` and a plain Python attribute
    (``ctx.note``) set directly on the context object.
    """

    @staticmethod
    def forward(ctx, input_tensor):
        """Return ``input_tensor`` squared and stash what ``backward`` needs on ``ctx``."""
        # Example of a custom (non-tensor) attribute stored on the context
        ctx.note = "我是 forward 儲存的小抄"
        print(ctx.note)
        # Keep the input so backward can compute the derivative
        ctx.save_for_backward(input_tensor)
        return input_tensor.pow(2)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient with respect to ``input_tensor``."""
        # Recover the tensor saved during forward
        saved = ctx.saved_tensors
        input_tensor = saved[0]
        print("小抄內容：", ctx.note)
        # For y = x^2, dy/dx = 2x; multiply by the incoming gradient
        return grad_output * (2 * input_tensor)

# Alias the custom Function's apply entry point for convenient calling
square = SquareFunction.apply

# Leaf tensor that tracks gradients
x = torch.tensor([3.0]).requires_grad_(True)

# Forward pass through the custom Function
y = square(x)

# Backward pass
y.backward()

# Inspect the gradient accumulated on x
print("x 的 gradient：", x.grad)

我是 forward 儲存的小抄
小抄內容： 我是 forward 儲存的小抄
x 的 gradient： tensor([6.])
