In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchinfo import summary

# 設定設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 超參數
input_size = 7  # 單個輸入特徵
hidden_size = 10  # RNN 隱藏層大小
output_size = 1  # 輸出特徵大小
num_layers = 2  # 雙層 RNN

# 簡單的 RNN 模型
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(SimpleRNN, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.rnn(x)  # RNN 輸出 (batch, seq, hidden)
        out = self.fc(out[:, -1, :])  # 只取最後時間步的輸出
        return out

# 初始化模型、損失函數與優化器
model = SimpleRNN(input_size, hidden_size, output_size, num_layers).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 假設的輸入 (batch_size=5, seq_length=13, input_size=7)
x = torch.rand(5, 13, 7).to(device)  # 隨機輸入
y_true = torch.rand(5, 1).to(device)  # 目標輸出

# 打印模型資訊
print(summary(model, input_size=(5, 13, 7), col_names=["input_size", "output_size", "num_params", "trainable"], device=device))

# 訓練一步
optimizer.zero_grad()
y_pred = model(x)
loss = criterion(y_pred, y_true)
loss.backward()
optimizer.step()

print("Predicted output:", y_pred)


Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Trainable
SimpleRNN                                [5, 13, 7]                [5, 1]                    --                        True
├─RNN: 1-1                               [5, 13, 7]                [5, 13, 10]               410                       True
├─Linear: 1-2                            [5, 10]                   [5, 1]                    11                        True
Total params: 421
Trainable params: 421
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.03
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.00
Estimated Total Size (MB): 0.01
Predicted output: tensor([[-0.1884],
        [-0.1023],
        [-0.1200],
        [ 0.0015],
        [-0.0869]], device='cuda:0', grad_fn=<AddmmBackward0>)


In [1]:
import torch
from mambapy.mamba import Mamba, MambaConfig
from torchinfo import summary
from c66 import pp

# B, L, D = 7, 64, 16
B, L, D = 7, 201, 286
pp(B,L,D)

config = MambaConfig(d_model=D, n_layers=1, use_cuda=False)
model = Mamba(config)
# .to("cuda")

x = torch.randn(B, L, D)
# .to("cuda")
y = model(x)

assert y.shape == x.shape
pp(x.shape)

B: 7
L: 201
D: 286
------------------------------
In ResidualBlock
x.shape: torch.Size([7, 201, 286])
self.norm(x).shape: torch.Size([7, 201, 286])
------------------------------
------------------------------
In MambaBlock
torch.Size([7, 572, 201])
torch.Size([7, 572, 201])
B, L, ED, N, dt_rank: 7 201 572 16 18
x.shape, delta.shape, A.shape, B.shape, C.shape, z.shape:
torch.Size([7, 201, 572]) torch.Size([7, 201, 572]) torch.Size([572, 16]) torch.Size([7, 201, 16]) torch.Size([7, 201, 16]) torch.Size([7, 201, 572])
self.selective_scan
------------------------------
x.shape: torch.Size([7, 201, 286])


In [4]:
print(model)

Mamba(
  (layers): ModuleList(
    (0): ResidualBlock(
      (mixer): MambaBlock(
        (in_proj): Linear(in_features=286, out_features=1144, bias=False)
        (conv1d): Conv1d(572, 572, kernel_size=(4,), stride=(1,), padding=(3,), groups=572)
        (x_proj): Linear(in_features=572, out_features=50, bias=False)
        (dt_proj): Linear(in_features=18, out_features=572, bias=True)
        (out_proj): Linear(in_features=572, out_features=286, bias=False)
      )
      (norm): RMSNorm()
    )
  )
)


In [None]:
print(model.layers[0].mixer)

MambaBlock(
  (in_proj): Linear(in_features=286, out_features=1144, bias=False)
  (conv1d): Conv1d(572, 572, kernel_size=(4,), stride=(1,), padding=(3,), groups=572)
  (x_proj): Linear(in_features=572, out_features=50, bias=False)
  (dt_proj): Linear(in_features=18, out_features=572, bias=True)
  (out_proj): Linear(in_features=572, out_features=286, bias=False)
)


In [4]:
summary_str = summary(model, input_size=[(B,L,D)], depth=5, col_names=("input_size", "output_size", "num_params"), verbose=0)
print(summary_str)

torch.Size([7, 201, 286])
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
Mamba                                    [7, 201, 286]             [7, 201, 286]             --
├─ModuleList: 1-1                        --                        --                        --
│    └─ResidualBlock: 2-1                [7, 201, 286]             [7, 201, 286]             --
│    │    └─RMSNorm: 3-1                 [7, 201, 286]             [7, 201, 286]             286
│    │    └─MambaBlock: 3-2              [7, 201, 286]             [7, 201, 286]             20,592
│    │    │    └─Linear: 4-1             [7, 201, 286]             [7, 201, 1144]            327,184
│    │    │    └─Conv1d: 4-2             [7, 572, 201]             [7, 572, 204]             2,860
│    │    │    └─Linear: 4-3             [7, 201, 572]             [7, 201, 50]              28,600
│    │    │    └─Linear: 4-4             [7, 201, 572]             [7, 201, 286]        