In [91]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
import math

class AttentionLayer(nn.Module):
    """Additive (tanh) attention pooling over dimension 1 of a sequence.

    Every step is scored with ``V(tanh(W(h_t)))``; the scores are normalised
    with a softmax over dimension 1 and used to form a weighted sum of the
    input steps.

    Args:
        input_size: Size of the last dimension of the incoming sequence.
        hidden_size: Width of the intermediate ``tanh`` projection.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.W = nn.Linear(input_size, hidden_size)
        self.V = nn.Linear(hidden_size, 1)

        nn.init.xavier_uniform_(self.W.weight)
        nn.init.xavier_uniform_(self.V.weight)
        nn.init.zeros_(self.W.bias)
        nn.init.zeros_(self.V.bias)

    def forward(self, lstm_output, mask=None):
        """Pool ``lstm_output`` into one vector per batch element.

        Args:
            lstm_output: Sequence tensor; expected layout is
                ``(batch, seq_len, input_size)`` since the softmax and the
                weighted sum both run over dimension 1.
            mask: Optional ``(batch, seq_len)`` tensor. Positions equal to 0
                (or ``False``) are excluded from the attention distribution.

        Returns:
            Tuple ``(context, att_weights)``: the weighted sum over
            dimension 1 and the softmax weights that produced it.
        """
        att_scores = self.V(torch.tanh(self.W(lstm_output))).squeeze(-1)
        if mask is not None:
            # The mask may have been built from CPU-side lengths.
            mask = mask.to(att_scores.device)
            # Use the dtype's own minimum instead of a hard-coded -1e9, which
            # is not representable in float16 and breaks half-precision runs.
            fill_value = torch.finfo(att_scores.dtype).min
            att_scores = att_scores.masked_fill(mask == 0, fill_value)
        att_weights = F.softmax(att_scores, dim=1)
        context = (lstm_output * att_weights.unsqueeze(-1)).sum(1)
        return context, att_weights

class CNNLSTM(nn.Module):
    """Video classifier: per-frame CNN features -> LSTM -> linear head.

    Each frame is encoded by a pretrained torchvision backbone, the spatial
    feature map is average-pooled to a vector, the per-frame vectors are fed
    through an LSTM as a packed (variable-length) sequence, and the result is
    classified by a dropout + linear head.

    Args:
        num_classes: Number of output logits.
        lstm_hidden_size: LSTM hidden size (also the classifier input size).
        lstm_layers: Number of stacked LSTM layers.
        dropout: Dropout probability of the classifier head; also applied
            between LSTM layers when ``lstm_layers > 1``.
        freeze_cnn: If True, every backbone parameter gets
            ``requires_grad=False``.
        use_attention: If True, pool the LSTM outputs with ``AttentionLayer``;
            otherwise use the output at the last valid time step.
        cnn_model: One of ``'resnet18'``, ``'mobilenet_v3'``, ``'vgg16'``,
            ``'vgg19'``, ``'densenet121'``, ``'densenet201'``.

    Note:
        Freezing only toggles ``requires_grad``. It does not put the backbone
        in eval mode, so any BatchNorm layers in it still follow the module's
        train/eval state.
    """

    def __init__(self, num_classes,
                 lstm_hidden_size=512,
                 lstm_layers=1,
                 dropout=0.5,
                 freeze_cnn=True,
                 use_attention=False,
                 cnn_model='resnet18'):

        super().__init__()
        self.use_attention = use_attention
        self.cnn_model = cnn_model

        # CNN feature extractor
        self.cnn, self.cnn_feature_size = self._build_cnn(cnn_model)
        self._set_cnn_freeze(freeze_cnn)

        lstm_input_size = self.cnn_feature_size

        self.lstm = nn.LSTM(
            input_size=lstm_input_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between layers
            dropout=dropout if lstm_layers > 1 else 0
        )

        self._init_lstm_weights()

        if self.use_attention:
            self.attention = AttentionLayer(
                input_size=lstm_hidden_size,
                hidden_size=lstm_hidden_size
            )

        self.classifier = self._build_classifier(
            lstm_hidden_size,
            num_classes,
            dropout
        )

        self._initialize_weights()

    def _build_cnn(self, model_name):
        """Return ``(feature_extractor, channels)`` for a pretrained backbone.

        The classification head (and built-in pooling, where present) is
        dropped so the extractor yields a spatial feature map.

        Raises:
            ValueError: If ``model_name`` is not a supported backbone.
        """
        if model_name == 'resnet18':
            cnn = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
            return nn.Sequential(*list(cnn.children())[:-2]), 512
        elif model_name == 'mobilenet_v3':
            cnn = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
            return nn.Sequential(*list(cnn.children())[:-2]), 576
        elif model_name == 'vgg16':
            cnn = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
            return nn.Sequential(*list(cnn.features)), 512
        elif model_name == 'vgg19':
            cnn = models.vgg19(weights=models.VGG19_Weights.DEFAULT)
            return nn.Sequential(*list(cnn.features)), 512
        # NOTE(review): torchvision's DenseNet.forward applies a ReLU to the
        # output of `features` before pooling; that ReLU is not part of the
        # Sequential built here — confirm this is intended.
        elif model_name == 'densenet121':
            cnn = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
            return nn.Sequential(*list(cnn.children())[:-1]), 1024
        elif model_name == 'densenet201':
            cnn = models.densenet201(weights=models.DenseNet201_Weights.DEFAULT)
            return nn.Sequential(*list(cnn.children())[:-1]), 1920
        else:
            raise ValueError(f"Unsupported CNN model: {model_name}")

    def _build_classifier(self, input_size, num_classes, dropout):
        """Build the dropout + linear classification head."""
        return nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_size, num_classes)
        )

    def _init_lstm_weights(self):
        """Xavier input weights, orthogonal recurrent weights, zero biases."""
        for name, param in self.lstm.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def _initialize_weights(self):
        """Xavier-normal weights and zero biases for the classifier's linears."""
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _set_cnn_freeze(self, freeze):
        """Set ``requires_grad = not freeze`` on every backbone parameter."""
        for param in self.cnn.parameters():
            param.requires_grad = not freeze

    def forward(self, x, lengths):
        """Classify a batch of padded clips.

        Args:
            x: Frames shaped ``(batch, seq_len, *frame_dims)``; the first two
                dimensions are merged before the CNN is applied.
            lengths: 1-D integer tensor with the number of valid frames per
                clip. It may live on any device; a CPU copy is used for
                packing.

        Returns:
            Logits produced by the classifier head, one row per clip.
        """
        batch_size, seq_len = x.size(0), x.size(1)
        # reshape (not view) so non-contiguous inputs are accepted
        frames = x.reshape(batch_size * seq_len, *x.shape[2:])
        feats = self.cnn(frames)

        # Every supported backbone yields a spatial feature map; pool it to
        # one vector per frame.
        if feats.dim() == 4:
            feats = F.adaptive_avg_pool2d(feats, (1, 1))

        feats = feats.reshape(batch_size, seq_len, -1)

        packed_x = rnn_utils.pack_padded_sequence(
            feats, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed_x)
        lstm_out, _ = rnn_utils.pad_packed_sequence(packed_out, batch_first=True)

        # Keep index/mask tensors on the same device as the LSTM output so a
        # CPU `lengths` works with a model on another device.
        lengths = lengths.to(lstm_out.device)

        if self.use_attention:
            mask = self._create_attention_mask(lstm_out.size(1), lengths)
            context, _ = self.attention(lstm_out, mask)
        else:
            # Output at the last valid time step of each clip.
            last_step = torch.clamp(lengths - 1, min=0)
            batch_idx = torch.arange(batch_size, device=lstm_out.device)
            context = lstm_out[batch_idx, last_step, :]

        return self.classifier(context)

    def _create_attention_mask(self, max_len, lengths):
        """Return a ``(len(lengths), max_len)`` bool mask, True where ``t < length``."""
        device = lengths.device
        return torch.arange(max_len, device=device).expand(len(lengths), max_len) < lengths.unsqueeze(1)

    def unfreeze_cnn_layers(self, num_layers=3, start_from_end=True):
        """Enable gradients for some Conv2d and BatchNorm2d backbone layers.

        Takes up to ``num_layers`` Conv2d modules and, independently, up to
        ``num_layers`` BatchNorm2d modules, in module-traversal order.

        Args:
            num_layers: How many layers of each kind to unfreeze. Values
                ``<= 0`` unfreeze nothing.
            start_from_end: Take the layers from the end of the backbone if
                True, from the start otherwise.
        """
        # Guard first: a `[-0:]` slice would select *every* layer.
        if num_layers <= 0:
            return

        conv_layers = [m for m in self.cnn.modules() if isinstance(m, nn.Conv2d)]
        bn_layers = [m for m in self.cnn.modules() if isinstance(m, nn.BatchNorm2d)]

        if start_from_end:
            selected = conv_layers[-num_layers:] + bn_layers[-num_layers:]
        else:
            selected = conv_layers[:num_layers] + bn_layers[:num_layers]

        for layer in selected:
            for param in layer.parameters():
                param.requires_grad = True

    def count_parameters(self):
        """Print each trainable parameter's size and the total.

        Returns:
            The total number of trainable parameters.
        """
        total_params = 0
        for name, parameter in self.named_parameters():
            if parameter.requires_grad:
                params = parameter.numel()
                print(f"{name}: {params}")
                total_params += params
        print(f"Total Trainable Params: {total_params}")
        return total_params

In [89]:
# Smoke test: VGG19 backbone with the class defaults (frozen CNN, no attention).
model_vgg19 = CNNLSTM(num_classes=13, cnn_model='vgg19', lstm_hidden_size=512)
model_vgg19.count_parameters()
print("\n")  # two blank lines between the parameter report and the results

# Random clip batch laid out as (batch, frames, channels, height, width).
batch_size, seq_length = 10, 5
dummy_input = torch.randn(batch_size, seq_length, 3, 224, 224)

# One valid-frame count per clip, drawn from 1..seq_length inclusive.
dummy_lengths = torch.randint(low=1, high=seq_length + 1, size=(batch_size,))

# One random class id per clip, in [0, 13).
dummy_target = torch.randint(low=0, high=13, size=(batch_size,))

# Run the model and score the logits against the random targets.
output = model_vgg19(dummy_input, dummy_lengths)
criterion = nn.CrossEntropyLoss()
loss = criterion(output, dummy_target)

print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Loss value: {loss.item():.4f}")

lstm.weight_ih_l0: 1048576
lstm.weight_hh_l0: 1048576
lstm.bias_ih_l0: 2048
lstm.bias_hh_l0: 2048
classifier.1.weight: 6656
classifier.1.bias: 13
Total Trainable Params: 2107917


Input shape: torch.Size([10, 5, 3, 224, 224])
Output shape: torch.Size([10, 13])
Loss value: 2.5886


In [90]:
# Smoke test: DenseNet121 backbone, last-valid-step pooling (no attention).
model_densenet121 = CNNLSTM(
    num_classes=13, cnn_model='densenet121', lstm_hidden_size=512, use_attention=False
)
model_densenet121.count_parameters()
print("\n")  # two blank lines between the parameter report and the results

# Random clip batch laid out as (batch, frames, channels, height, width).
batch_size, seq_length = 10, 5
dummy_input = torch.randn(batch_size, seq_length, 3, 224, 224)

# One valid-frame count per clip, drawn from 1..seq_length inclusive.
dummy_lengths = torch.randint(low=1, high=seq_length + 1, size=(batch_size,))
print(dummy_lengths)

# One random class id per clip, in [0, 13).
dummy_target = torch.randint(low=0, high=13, size=(batch_size,))

# Run the model and score the logits against the random targets.
output = model_densenet121(dummy_input, dummy_lengths)
criterion = nn.CrossEntropyLoss()
loss = criterion(output, dummy_target)

print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Loss value: {loss.item():.4f}")

lstm.weight_ih_l0: 2097152
lstm.weight_hh_l0: 1048576
lstm.bias_ih_l0: 2048
lstm.bias_hh_l0: 2048
classifier.1.weight: 6656
classifier.1.bias: 13
Total Trainable Params: 3156493


tensor([2, 5, 2, 3, 2, 2, 3, 2, 2, 1])
Input shape: torch.Size([10, 5, 3, 224, 224])
Output shape: torch.Size([10, 13])
Loss value: 2.5215


In [79]:
# Smoke test: ResNet18 backbone, last-valid-step pooling (no attention).
res_model = CNNLSTM(
    num_classes=13,
    cnn_model='resnet18',
    lstm_hidden_size=512,
    use_attention=False
)
res_model.count_parameters()
print()
print()
# Dummy clip batch: (batch_size, seq_length, channels, height, width)
batch_size = 10
seq_length = 5
dummy_input = torch.randn(batch_size, seq_length, 3, 224, 224)

# One valid-frame count per clip, drawn from 1..seq_length inclusive
dummy_lengths = torch.randint(1, seq_length + 1, (batch_size,))

# One random class id per clip, in [0, 13)
dummy_target = torch.randint(0, 13, (batch_size,))

# Forward pass through the ResNet18 model built in this cell (this call
# previously went to model_densenet121, so res_model was never exercised).
output = res_model(dummy_input, lengths=dummy_lengths)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Compute loss
loss = criterion(output, dummy_target)

print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Loss value: {loss.item():.4f}")

lstm.weight_ih_l0: 1048576
lstm.weight_hh_l0: 1048576
lstm.bias_ih_l0: 2048
lstm.bias_hh_l0: 2048
classifier.1.weight: 6656
classifier.1.bias: 13
Total Trainable Params: 2107917


Input shape: torch.Size([10, 5, 3, 224, 224])
Output shape: torch.Size([10, 13])
Loss value: 2.5736


In [80]:
# Smoke test: VGG16 backbone, last-valid-step pooling (no attention).
vgg_model = CNNLSTM(
    num_classes=13, cnn_model='vgg16', lstm_hidden_size=512, use_attention=False
)
vgg_model.count_parameters()
print("\n")  # two blank lines between the parameter report and the results

# Random clip batch laid out as (batch, frames, channels, height, width).
batch_size, seq_length = 10, 5
dummy_input = torch.randn(batch_size, seq_length, 3, 224, 224)

# One valid-frame count per clip, drawn from 1..seq_length inclusive.
dummy_lengths = torch.randint(low=1, high=seq_length + 1, size=(batch_size,))

# One random class id per clip, in [0, 13).
dummy_target = torch.randint(low=0, high=13, size=(batch_size,))

# Run the model and score the logits against the random targets.
output = vgg_model(dummy_input, dummy_lengths)
criterion = nn.CrossEntropyLoss()
loss = criterion(output, dummy_target)

print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Loss value: {loss.item():.4f}")

lstm.weight_ih_l0: 1048576
lstm.weight_hh_l0: 1048576
lstm.bias_ih_l0: 2048
lstm.bias_hh_l0: 2048
classifier.1.weight: 6656
classifier.1.bias: 13
Total Trainable Params: 2107917


Input shape: torch.Size([10, 5, 3, 224, 224])
Output shape: torch.Size([10, 13])
Loss value: 2.6226


In [94]:
# Smoke test: MobileNetV3-Small backbone, last-valid-step pooling (no attention).
mobilenet_model = CNNLSTM(
    num_classes=13, cnn_model='mobilenet_v3', lstm_hidden_size=512, use_attention=False
)
mobilenet_model.count_parameters()
print("\n")  # two blank lines between the parameter report and the results

# Random clip batch laid out as (batch, frames, channels, height, width).
batch_size, seq_length = 10, 5
dummy_input = torch.randn(batch_size, seq_length, 3, 224, 224)

# One valid-frame count per clip, drawn from 1..seq_length inclusive.
dummy_lengths = torch.randint(low=1, high=seq_length + 1, size=(batch_size,))
print(dummy_lengths)

# One random class id per clip, in [0, 13).
dummy_target = torch.randint(low=0, high=13, size=(batch_size,))

# Run the model and score the logits against the random targets.
output = mobilenet_model(dummy_input, dummy_lengths)
criterion = nn.CrossEntropyLoss()
loss = criterion(output, dummy_target)

print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Loss value: {loss.item():.4f}")

lstm.weight_ih_l0: 1179648
lstm.weight_hh_l0: 1048576
lstm.bias_ih_l0: 2048
lstm.bias_hh_l0: 2048
classifier.1.weight: 6656
classifier.1.bias: 13
Total Trainable Params: 2238989


tensor([4, 5, 4, 1, 3, 3, 1, 1, 3, 4])
tensor([3, 4, 3, 0, 2, 2, 0, 0, 2, 3])
Input shape: torch.Size([10, 5, 3, 224, 224])
Output shape: torch.Size([10, 13])
Loss value: 2.5443
