In [38]:
# Number of output classes; used as the default `num_classes` of CNNLSTM.
NUM_CLASSES = 13

In [42]:
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """1-D residual block: two kernel-3 convolutions plus a skip connection.

    The main path is conv -> batch norm -> ReLU -> conv -> batch norm. The
    skip path is the identity unless the block changes the tensor shape
    (``stride != 1`` or ``in_channels != out_channels``), in which case a
    strided 1x1 convolution followed by batch norm projects the input so
    the two paths can be summed.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first convolution (and of the projection).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main path. Only the first convolution is strided.
        self.conv1 = nn.Conv1d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)

        # Skip path: an empty Sequential acts as the identity; otherwise
        # project the input to the main path's channel count and length.
        if stride != 1 or in_channels != out_channels:
            projection = [
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm1d(out_channels),
            ]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        skip = self.shortcut(x)

        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        branch += skip
        return self.relu(branch)

class CNNLSTM(nn.Module):
    """Per-frame 1-D residual CNN followed by an LSTM classifier.

    The input is a 4-D tensor ``(N, C, T, V)``. It is batch-normalised over
    the ``V * C`` features, every one of the ``N * T`` frames ``(C, V)`` is
    passed through a convolutional stem and four stride-2 ``ResBlock``s, and
    the flattened per-frame features are fed as a length-``T`` sequence to an
    LSTM. The LSTM output at the last time step is mapped to class logits.

    Args:
        input_channels: Size ``C`` of the channel axis of the input.
        hidden_size: Hidden size of the LSTM.
        num_classes: Number of output logits.
        num_nodes: Size ``V`` of the last input axis. It sizes both the
            input batch norm and the LSTM input.
    """

    # Number of stride-2 residual blocks applied along the node axis.
    _NUM_DOWNSAMPLING_BLOCKS = 4

    def __init__(self, input_channels, hidden_size=256, num_classes=NUM_CLASSES, num_nodes=33):
        super().__init__()

        self.batch_norm = nn.BatchNorm1d(input_channels * num_nodes)
        # CNN layers
        self.conv1 = nn.Sequential(
            nn.Conv1d(input_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True)
        )

        self.blk1 = ResBlock(32, 64, stride=2)
        self.blk2 = ResBlock(64, 128, stride=2)
        self.blk3 = ResBlock(128, 256, stride=2)
        self.blk4 = ResBlock(256, 512, stride=2)

        # Derive the node-axis length left after the residual blocks from
        # num_nodes instead of hard-coding it (33 nodes -> 3), so the LSTM
        # input size matches for any num_nodes.
        self.V_downsampled = self._downsampled_length(
            num_nodes, self._NUM_DOWNSAMPLING_BLOCKS
        )
        self.lstm = nn.LSTM(512 * self.V_downsampled, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    @staticmethod
    def _downsampled_length(length, num_blocks):
        """Return the length left after ``num_blocks`` stride-2 blocks.

        Each block's strided convolution (kernel 3, padding 1, stride 2) and
        its 1x1 stride-2 shortcut both map a length ``L`` to
        ``(L - 1) // 2 + 1``.
        """
        for _ in range(num_blocks):
            length = (length - 1) // 2 + 1
        return length

    def forward(self, x):
        """Compute class logits for a batch of sequences.

        Args:
            x: Tensor of shape ``(N, C, T, V)``.

        Returns:
            Tensor of shape ``(N, num_classes)``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input (N, C, T, V), got a {x.dim()}-D tensor"
            )
        N, C, T, V = x.size()

        # BATCH NORM over the V * C features, along the time axis.
        x = x.permute(0, 3, 1, 2).contiguous()
        x = x.view(N, V * C, T)
        x = self.batch_norm(x)

        # CNN LAYER APPLICATION
        # Bring time next to the batch axis *before* flattening so that each
        # row of the (N*T, C, V) batch is exactly one frame. Viewing a
        # contiguous (N, C, T, V) tensor as (N*T, C, V) directly would mix
        # the channel and time axes.
        x = x.view(N, V, C, T).permute(0, 3, 2, 1).contiguous()  # (N, T, C, V)
        x = x.view(N * T, C, V)
        x = self.conv1(x)
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # Current shape: (N*T, 512, V_downsampled)
        _, C_out, V_down = x.size()

        # Flatten each frame channel-major (all positions of channel 0, then
        # channel 1, ...), i.e. features are not interleaved per position.
        x = x.view(N, T, C_out * V_down)
        x, _ = self.lstm(x)
        # Keep only the LSTM output at the last time step.
        x = x[:, -1, :]

        x = self.fc(x)
        return x

    def count_parameters(self):
        """Print the size of every trainable parameter and their total.

        Returns:
            The total number of trainable parameters.
        """
        total_params = 0
        for name, parameter in self.named_parameters():
            if not parameter.requires_grad:
                continue
            params = parameter.numel()
            print(f"{name}: {params}")
            total_params += params
        print(f"Total Trainable Params: {total_params}")
        return total_params

   

# Example usage:
def main():
    """Run CNNLSTM on one random batch and print shapes and parameter counts."""
    # Example parameters
    batch_size = 32
    input_channels = 3
    temporal_length = 20
    num_vertices = 33  # Example for skeleton data
    num_classes = 13   # Example for action recognition

    # Create model. num_vertices is forwarded explicitly so the model and
    # the example input below cannot drift apart when it is changed.
    model = CNNLSTM(
        input_channels=input_channels,
        num_classes=num_classes,
        num_nodes=num_vertices,
    )

    # Example input
    x = torch.randn(batch_size, input_channels, temporal_length, num_vertices)

    # Forward pass
    output = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    model.count_parameters()
    

# Run the example only when this code is executed directly, not when imported.
if __name__ == "__main__":
    main()

Input shape: torch.Size([32, 3, 20, 33])
Output shape: torch.Size([32, 13])
batch_norm.weight: 99
batch_norm.bias: 99
conv1.0.weight: 288
conv1.0.bias: 32
conv1.1.weight: 32
conv1.1.bias: 32
blk1.conv1.weight: 6144
blk1.conv1.bias: 64
blk1.bn1.weight: 64
blk1.bn1.bias: 64
blk1.conv2.weight: 12288
blk1.conv2.bias: 64
blk1.bn2.weight: 64
blk1.bn2.bias: 64
blk1.shortcut.0.weight: 2048
blk1.shortcut.0.bias: 64
blk1.shortcut.1.weight: 64
blk1.shortcut.1.bias: 64
blk2.conv1.weight: 24576
blk2.conv1.bias: 128
blk2.bn1.weight: 128
blk2.bn1.bias: 128
blk2.conv2.weight: 49152
blk2.conv2.bias: 128
blk2.bn2.weight: 128
blk2.bn2.bias: 128
blk2.shortcut.0.weight: 8192
blk2.shortcut.0.bias: 128
blk2.shortcut.1.weight: 128
blk2.shortcut.1.bias: 128
blk3.conv1.weight: 98304
blk3.conv1.bias: 256
blk3.bn1.weight: 256
blk3.bn1.bias: 256
blk3.conv2.weight: 196608
blk3.conv2.bias: 256
blk3.bn2.weight: 256
blk3.bn2.bias: 256
blk3.shortcut.0.weight: 32768
blk3.shortcut.0.bias: 256
blk3.shortcut.1.weight: 256
