In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class MultiStream3DCNN(nn.Module):
    """Three-stream 3D CNN whose per-stream features are fused by an MLP head.

    Each stream is a stack of ``Conv3d -> ReLU -> MaxPool3d`` stages followed
    by a global average pool, so every stream contributes exactly one value
    per output channel (64 + 32 + 64 = 160 features) regardless of the clip
    length or frame size it was fed.

    Args:
        num_classes: Width of the final linear layer, i.e. the number of
            class probabilities returned per sample. Defaults to 20.
    """

    def __init__(self, num_classes=20):
        super().__init__()

        # Stream for Segmented Hands and Face (3 input channels).
        # Pooling shrinks depth by 2**5 and height/width by 2**6 overall.
        self.hand_face_stream = self._make_stream(
            3,
            [
                (16, (1, 2, 2)),
                (16, (2, 2, 2)),
                (32, (2, 2, 2)),
                (32, (2, 2, 2)),
                (64, (2, 2, 2)),
                (64, (2, 2, 2)),
            ],
        )

        # Stream for Distances and Speed Maps (3 input channels).
        # Pooling shrinks depth by 2 and height/width by 4 overall.
        self.dist_speed_stream = self._make_stream(
            3,
            [
                (16, (1, 2, 2)),
                (32, (2, 2, 2)),
            ],
        )

        # Stream for RGB and Depth Maps (4 input channels).
        # Pooling shrinks depth by 4 and height/width by 8 overall.
        self.rgb_depth_stream = self._make_stream(
            4,
            [
                (16, (1, 2, 2)),
                (32, (2, 2, 2)),
                (64, (2, 2, 2)),
            ],
        )

        # Fully Connected Layers after Concatenation.
        # 64 + 32 + 64 are the final channel counts of the three streams; the
        # global pooling in each stream guarantees the flattened size matches.
        self.fc_layers = nn.Sequential(
            nn.Linear(64 + 32 + 64, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
            # NOTE: the model returns probabilities, not logits. Losses that
            # apply their own log-softmax (e.g. nn.CrossEntropyLoss) should
            # not be fed this output directly.
            nn.Softmax(dim=1),
        )

    @staticmethod
    def _make_stream(in_channels, stages):
        """Build one stream as a flat ``nn.Sequential``.

        Args:
            in_channels: Channel count of the stream's input tensor.
            stages: Iterable of ``(out_channels, pool_kernel)`` pairs; each
                pair becomes ``Conv3d(3x3x3, padding=1) -> ReLU -> MaxPool3d``.

        Returns:
            An ``nn.Sequential`` whose convolution layers sit at indices
            0, 3, 6, ... (so parameter names stay ``<stream>.<3*i>.weight``),
            terminated by a parameter-free ``AdaptiveAvgPool3d(1)``.
        """
        layers = []
        channels = in_channels
        for out_channels, pool_kernel in stages:
            layers.append(
                nn.Conv3d(channels, out_channels, kernel_size=(3, 3, 3), padding=1)
            )
            layers.append(nn.ReLU())
            layers.append(nn.MaxPool3d(pool_kernel))
            channels = out_channels
        # Collapse whatever depth/height/width remains to 1x1x1. This is an
        # identity when the feature map is already 1x1x1.
        layers.append(nn.AdaptiveAvgPool3d(1))
        return nn.Sequential(*layers)

    def forward(self, hand_face_input, dist_speed_input, rgb_depth_input):
        """Run the three streams and classify the fused features.

        Args:
            hand_face_input: Batched tensor with 3 channels for the hand/face
                stream (assumed layout ``(N, 3, D, H, W)`` as used by Conv3d).
            dist_speed_input: Batched tensor with 3 channels for the
                distance/speed stream (assumed ``(N, 3, D, H, W)``).
            rgb_depth_input: Batched tensor with 4 channels for the RGB/depth
                stream (assumed ``(N, 4, D, H, W)``).

        Returns:
            Tensor of shape ``(N, num_classes)`` holding softmax probabilities
            along dimension 1.
        """
        # Pass inputs through respective streams
        hand_face_features = self.hand_face_stream(hand_face_input)
        dist_speed_features = self.dist_speed_stream(dist_speed_input)
        rgb_depth_features = self.rgb_depth_stream(rgb_depth_input)

        # Flatten the outputs for concatenation: (N, C, 1, 1, 1) -> (N, C)
        hand_face_features = torch.flatten(hand_face_features, start_dim=1)
        dist_speed_features = torch.flatten(dist_speed_features, start_dim=1)
        rgb_depth_features = torch.flatten(rgb_depth_features, start_dim=1)

        # Concatenate all features along the channel axis -> (N, 160)
        concatenated_features = torch.cat(
            (hand_face_features, dist_speed_features, rgb_depth_features), dim=1
        )

        # Pass concatenated features through fully connected layers
        output = self.fc_layers(concatenated_features)

        return output

In [3]:
# Build the multi-stream network; the following cell inspects this instance's parameters.
model = MultiStream3DCNN()

In [4]:
# Collect the element count of every trainable tensor, keyed by its
# qualified parameter name (frozen parameters are skipped).
trainable_sizes = {
    param_name: tensor.numel()
    for param_name, tensor in model.named_parameters()
    if tensor.requires_grad
}

# Per-tensor report, in registration order.
for name, param in trainable_sizes.items():
    print(f"{name} has {param} parameters")

# Grand total across all trainable tensors.
total_params = sum(trainable_sizes.values())
print(f"Total number of parameters: {total_params}")

hand_face_stream.0.weight has 1296 parameters
hand_face_stream.0.bias has 16 parameters
hand_face_stream.3.weight has 6912 parameters
hand_face_stream.3.bias has 16 parameters
hand_face_stream.6.weight has 13824 parameters
hand_face_stream.6.bias has 32 parameters
hand_face_stream.9.weight has 27648 parameters
hand_face_stream.9.bias has 32 parameters
hand_face_stream.12.weight has 55296 parameters
hand_face_stream.12.bias has 64 parameters
hand_face_stream.15.weight has 110592 parameters
hand_face_stream.15.bias has 64 parameters
dist_speed_stream.0.weight has 1296 parameters
dist_speed_stream.0.bias has 16 parameters
dist_speed_stream.3.weight has 13824 parameters
dist_speed_stream.3.bias has 32 parameters
rgb_depth_stream.0.weight has 1728 parameters
rgb_depth_stream.0.bias has 16 parameters
rgb_depth_stream.3.weight has 13824 parameters
rgb_depth_stream.3.bias has 32 parameters
rgb_depth_stream.6.weight has 55296 parameters
rgb_depth_stream.6.bias has 64 parameters
fc_layers.0.weig