In [3]:
NUM_CLASSES = 13

In [43]:
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3,
                              stride=stride, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3,
                              padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)

        # Shortcut connection
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm1d(out_channels)
            )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += self.shortcut(residual)
        out = self.relu(out)

        return out

class CNNLSTM(nn.Module):
    def __init__(self, input_channels,hidden_size=256, num_classes=NUM_CLASSES, num_nodes=33):
        super(CNNLSTM, self).__init__()

        self.batch_norm = nn.BatchNorm1d(input_channels * num_nodes)
        # CNN layers
        self.conv1 = nn.Sequential(
            nn.Conv1d(input_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True)
        )

        self.blk1 = ResBlock(32, 64, stride=2)
        self.blk2 = ResBlock(64, 128, stride=2)
        self.blk3 = ResBlock(128, 256, stride=2)
        self.blk4 = ResBlock(256, 512, stride=2)

        self.V_downsampled = 3
        self.lstm = nn.LSTM(512 * self.V_downsampled, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        N, C, T, V = x.size()

        # BATCH NORM
        x = x.permute(0, 3, 1, 2).contiguous()
        x = x.view(N, V * C, T)
        x = self.batch_norm(x)
        x = x.view(N, V, C, T).permute(0, 2, 3, 1).contiguous()

        # CNN LAYER APPLICATION
        x = x.view(N*T, C, V)
        x = self.conv1(x)
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # Current shape: (N*T, 512, V_downsampled)
        _, C_out, V_down = x.size()
        x = x.view(N, T, C_out, V_down)

        # no interleaving (x0,y0,z0, x1,y1,z1,..) instead (x0,x1,..x32, y1, y2,...)
        x = x.view(N, T, -1)
        x, _ = self.lstm(x)
        x = x[:, -1, :]

        x = self.fc(x)
        return x
    def count_parameters(self):
        total_params = 0
        param_str = ""
        for name, parameter in self.named_parameters():
            if parameter.requires_grad:
                params = parameter.numel()
                print(f"{name}: {params}")
                param_str += f"{name}: {params}\n"
                total_params += params
        param_str += f"Total Trainable Params: {total_params}"
        return param_str, total_params

In [44]:
import torch
import torch.nn as nn
import torch.optim as optim

# Assuming the CNN_LSTM class is defined as in the previous cells

# Dummy input data
N = 2  # Batch size
C = 3  # Number of channels
T = 20  # Sequence length
V = 33  # Number of joints

dummy_input = torch.randn(N, C, T, V)
dummy_labels = torch.randint(0, NUM_CLASSES, (N,))

# Instantiate the model
model = CNNLSTM(3)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Forward pass
output = model(dummy_input)

# Calculate the loss
loss = criterion(output, dummy_labels)

# Print the loss
print("Loss:", loss.item())

# # Optional: Backpropagation (for a complete test)
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()

print("Output shape:", output.shape)
print("Labels shape:", dummy_labels.shape)


Loss: 2.7769935131073
Output shape: torch.Size([2, 13])
Labels shape: torch.Size([2])


In [45]:
model.count_parameters()

batch_norm.weight: 99
batch_norm.bias: 99
conv1.0.weight: 288
conv1.0.bias: 32
conv1.1.weight: 32
conv1.1.bias: 32
blk1.conv1.weight: 6144
blk1.conv1.bias: 64
blk1.bn1.weight: 64
blk1.bn1.bias: 64
blk1.conv2.weight: 12288
blk1.conv2.bias: 64
blk1.bn2.weight: 64
blk1.bn2.bias: 64
blk1.shortcut.0.weight: 2048
blk1.shortcut.0.bias: 64
blk1.shortcut.1.weight: 64
blk1.shortcut.1.bias: 64
blk2.conv1.weight: 24576
blk2.conv1.bias: 128
blk2.bn1.weight: 128
blk2.bn1.bias: 128
blk2.conv2.weight: 49152
blk2.conv2.bias: 128
blk2.bn2.weight: 128
blk2.bn2.bias: 128
blk2.shortcut.0.weight: 8192
blk2.shortcut.0.bias: 128
blk2.shortcut.1.weight: 128
blk2.shortcut.1.bias: 128
blk3.conv1.weight: 98304
blk3.conv1.bias: 256
blk3.bn1.weight: 256
blk3.bn1.bias: 256
blk3.conv2.weight: 196608
blk3.conv2.bias: 256
blk3.bn2.weight: 256
blk3.bn2.bias: 256
blk3.shortcut.0.weight: 32768
blk3.shortcut.0.bias: 256
blk3.shortcut.1.weight: 256
blk3.shortcut.1.bias: 256
blk4.conv1.weight: 393216
blk4.conv1.bias: 512
blk

('batch_norm.weight: 99\nbatch_norm.bias: 99\nconv1.0.weight: 288\nconv1.0.bias: 32\nconv1.1.weight: 32\nconv1.1.bias: 32\nblk1.conv1.weight: 6144\nblk1.conv1.bias: 64\nblk1.bn1.weight: 64\nblk1.bn1.bias: 64\nblk1.conv2.weight: 12288\nblk1.conv2.bias: 64\nblk1.bn2.weight: 64\nblk1.bn2.bias: 64\nblk1.shortcut.0.weight: 2048\nblk1.shortcut.0.bias: 64\nblk1.shortcut.1.weight: 64\nblk1.shortcut.1.bias: 64\nblk2.conv1.weight: 24576\nblk2.conv1.bias: 128\nblk2.bn1.weight: 128\nblk2.bn1.bias: 128\nblk2.conv2.weight: 49152\nblk2.conv2.bias: 128\nblk2.bn2.weight: 128\nblk2.bn2.bias: 128\nblk2.shortcut.0.weight: 8192\nblk2.shortcut.0.bias: 128\nblk2.shortcut.1.weight: 128\nblk2.shortcut.1.bias: 128\nblk3.conv1.weight: 98304\nblk3.conv1.bias: 256\nblk3.bn1.weight: 256\nblk3.bn1.bias: 256\nblk3.conv2.weight: 196608\nblk3.conv2.bias: 256\nblk3.bn2.weight: 256\nblk3.bn2.bias: 256\nblk3.shortcut.0.weight: 32768\nblk3.shortcut.0.bias: 256\nblk3.shortcut.1.weight: 256\nblk3.shortcut.1.bias: 256\nblk4.c

In [41]:
model = CNNLSTM(3, 128)
param_count, trainable_count = model.count_parameters()

with open("model_parameters_count.txt", "w") as f:
    f.write(f"Total Params: {param_count}\n")
    f.write(f"Total Trainable Params: {trainable_count}\n")

In [28]:
mean_logits = output.mean().item()
std_logits = output.std().item()

print(f"Mean logits: {mean_logits}")
print(f"Std logits: {std_logits}")


Mean logits: 0.011346329003572464
Std logits: 0.20061618089675903
