In [None]:
import torch
from torchinfo import summary
from model_V0 import TremorNetGRU  # Update path if needed

# 1. Instantiate model
# Dummy-input dimensions, named once so the comments and the tensors below
# cannot drift apart.
# NOTE(review): values carried over from the original cell; confirm they
# match what TremorNetGRU in model_V0 expects if the architecture changes.
NUM_CLASSES = 3
BATCH_SIZE = 1
TIME_STEPS = 1024
NUM_CHANNELS = 6

model = TremorNetGRU(num_classes=NUM_CLASSES)

# 2. Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# 3. Dummy inputs
# x:     [batch, time_steps, channels] = [BATCH_SIZE, TIME_STEPS, NUM_CHANNELS]
# wrist: [batch] with integer dtype. Presumably a per-sample wrist index
#        consumed by the model (TODO confirm in model_V0).
x = torch.zeros(BATCH_SIZE, TIME_STEPS, NUM_CHANNELS, device=device)
wrist = torch.zeros(BATCH_SIZE, dtype=torch.long, device=device)

# 4. Print detailed model summary
# input_data is a tuple, so torchinfo unpacks it into positional arguments,
# i.e. the forward pass is model(x, wrist).
# mode="eval" explicitly runs that forward pass in inference mode. A freshly
# constructed module is in training mode, and the batch here has a single
# sample. torchinfo restores the model's original mode afterwards.
# Kept as the cell's final expression so Jupyter displays the returned stats.
summary(
    model=model,
    input_data=(x, wrist),
    col_names=[
        "input_size",
        "output_size",
        "num_params",
        "trainable",
    ],
    col_width=20,
    row_settings=["var_names"],  # show variable names next to layer types
    depth=4,                     # maximum submodule nesting depth to display
    mode="eval",
)
