## Importing necessary libraries

In [1]:
# PyTorch Core Libraries
import torch               
import torch.nn as nn      
import torch.nn.functional as F  
import numpy as np
import time
import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"

# Model code

## Different Attention Mechanisms

### ============================= SE Layer =============================


In [2]:
# https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py
class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

### ============================= ECA Mechansim =============================

In [3]:
class eca_layer(nn.Module):
    """Constructs a ECA module.
    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
        source: https://github.com/BangguWu/ECANet
    """
    def __init__(self, channel, k_size=3):
        super(eca_layer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: input features with shape [b, c, h, w]
        b, c, h, w = x.size()

        # feature descriptor on the global spatial information
        y = self.avg_pool(x)

        # Two different branches of ECA module
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)

        # Multi-scale information fusion
        y = self.sigmoid(y)

        return x * y.expand_as(x)

### ============================= CBAM Module =============================

In [4]:
class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None
        
    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types
        
    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(max_pool)
            elif pool_type == 'lp':
                lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(lp_pool)
            elif pool_type == 'lse':
                # LSE pool
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp(lse_pool)
                
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
                
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

def logsumexp_2d(tensor):
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs

class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)

class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
        
    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # broadcasting
        return x * scale

class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()
            
    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out

## ============================= ConvLSTM Cell =============================

In [5]:
class ConvLSTMCell_layer(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size=(3, 3), bias=False):
        """
        Initialize ConvLSTM cell.
        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """
        super(ConvLSTMCell_layer, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        # cur_state is a tuple
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis

        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)  # Input gate
        f = torch.sigmoid(cc_f)  # Forget gate
        o = torch.sigmoid(cc_o)  # Output gate
        g = torch.tanh(cc_g)     # Cell gate

        c_next = f * c_cur + i * g  # Update cell state
        h_next = o * torch.tanh(c_next)  # Update hidden state

        return h_next, c_next

## ============================= ResNet Helper Functions =============================

In [6]:
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

##  ============================= RLA Bottleneck Block =============================


In [7]:
class RLA_Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, 
                 rla_channel=32, attention_type=None, attention_param=None, 
                 groups=1, base_width=64, dilation=1, norm_layer=None):
        super(RLA_Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        
        # `planes * base_width / 64 * cardinality`
        width = int(planes * (base_width / 64.)) * groups
        
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes + rla_channel, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        
        self.averagePooling = None
        if downsample is not None and stride != 1:
            self.averagePooling = nn.AvgPool2d((2, 2), stride=(2, 2))
        
        # Initialize attention module based on specified type
        self.attention_module = None
        channels = planes * self.expansion
        
        if attention_type == 'SE':
            reduction = 16 if attention_param is None else attention_param
            self.attention_module = SELayer(channels, reduction)
        elif attention_type == 'ECA':
            k_size = 3 if attention_param is None else attention_param
            self.attention_module = eca_layer(channels, k_size)
        elif attention_type == 'CBAM':
            reduction = 16 if attention_param is None else attention_param
            self.attention_module = CBAM(channels, reduction)

    def forward(self, x, h, c):
        identity = x
        
        x = torch.cat((x, h), dim=1)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        
        # Apply attention module if specified
        if self.attention_module is not None:
            out = self.attention_module(out)
        
        y = out
        
        if self.downsample is not None:
            identity = self.downsample(identity)
        if self.averagePooling is not None:
            h = self.averagePooling(h)
            c = self.averagePooling(c)
        
        out += identity
        out = self.relu(out)

        return out, y, h, c

### ============================= RLAlstm_ResNet Network =============================


In [8]:
class RLAlstm_ResNet(nn.Module):
    
    def __init__(self, block, layers,artist_classes=1000,genre_classes=1000,style_classes=1000,
                 rla_channel=32, attention_type=None, attention_params=None,
                 zero_init_last_bn=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        """
        Initialize RLAlstm_ResNet model with flexible attention mechanisms.
        
        Parameters:
        ----------
        block : nn.Module
            The block module to use (e.g., RLA_Bottleneck)
        layers : list
            Number of blocks in each layer
        num_classes : int
            Number of output classes
        rla_channel : int
            Number of filters in RLA
        attention_type : str or None
            Type of attention to use: 'SE', 'ECA', 'CBAM', or None
        attention_params : list or None
            Parameters for attention modules for each stage
            For SE: reduction ratios
            For ECA: kernel sizes
            For CBAM: reduction ratios
        """
        super(RLAlstm_ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        
        # Default attention parameters if None provided
        if attention_params is None and attention_type is not None:
            if attention_type == 'SE':
                attention_params = [16, 16, 16, 16]  # Default reduction ratios
            elif attention_type == 'ECA':
                attention_params = [3, 5, 7, 9]  # Default kernel sizes
            elif attention_type == 'CBAM':
                attention_params = [16, 16, 16, 16]  # Default reduction ratios
        
        self.rla_channel = rla_channel
        self.flops = False
        self.attention_type = attention_type
        
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        conv_outs = [None] * 4
        recurrent_convs = [None] * 4
        stages = [None] * 4
        stage_bns = [None] * 4

        # Create the four stages of the network with appropriate attention mechanisms
        # For each stage, we use the corresponding attention parameter if provided
        stages[0], stage_bns[0], conv_outs[0], recurrent_convs[0] = self._make_layer(
            block, 64, layers[0], 
            rla_channel=rla_channel, 
            attention_type=attention_type, 
            attention_param=None if attention_params is None else attention_params[0]
        )
        
        stages[1], stage_bns[1], conv_outs[1], recurrent_convs[1] = self._make_layer(
            block, 128, layers[1], 
            rla_channel=rla_channel, 
            attention_type=attention_type, 
            attention_param=None if attention_params is None else attention_params[1],
            stride=2, 
            dilate=replace_stride_with_dilation[0]
        )
        
        stages[2], stage_bns[2], conv_outs[2], recurrent_convs[2] = self._make_layer(
            block, 256, layers[2], 
            rla_channel=rla_channel, 
            attention_type=attention_type, 
            attention_param=None if attention_params is None else attention_params[2],
            stride=2, 
            dilate=replace_stride_with_dilation[1]
        )
        
        stages[3], stage_bns[3], conv_outs[3], recurrent_convs[3] = self._make_layer(
            block, 512, layers[3], 
            rla_channel=rla_channel, 
            attention_type=attention_type, 
            attention_param=None if attention_params is None else attention_params[3],
            stride=2, 
            dilate=replace_stride_with_dilation[2]
        )
        
        self.conv_outs = nn.ModuleList(conv_outs)
        self.recurrent_convs = nn.ModuleList(recurrent_convs)
        self.stages = nn.ModuleList(stages)
        self.stage_bns = nn.ModuleList(stage_bns)
        
        self.tanh = nn.Tanh()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512 * block.expansion + rla_channel, artist_classes)
        self.fc2 = nn.Linear(512 * block.expansion + rla_channel, genre_classes)
        self.fc3 = nn.Linear(512 * block.expansion + rla_channel, style_classes)

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch
        if zero_init_last_bn:
            for m in self.modules():
                if isinstance(m, RLA_Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                    
    def _make_layer(self, block, planes, blocks, 
                    rla_channel, attention_type, attention_param, stride=1, dilate=False):
        
        conv_out = conv1x1(planes * block.expansion, rla_channel)
        recurrent_convlstm = ConvLSTMCell_layer(rla_channel, rla_channel, (3, 3))
        
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, 
                            rla_channel=rla_channel, 
                            attention_type=attention_type, 
                            attention_param=attention_param,
                            groups=self.groups,
                            base_width=self.base_width, 
                            dilation=previous_dilation, 
                            norm_layer=norm_layer))
        
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, 
                                rla_channel=rla_channel, 
                                attention_type=attention_type, 
                                attention_param=attention_param,
                                groups=self.groups,
                                base_width=self.base_width, 
                                dilation=self.dilation,
                                norm_layer=norm_layer))

        bns = [norm_layer(rla_channel) for _ in range(blocks)]

        return nn.ModuleList(layers), nn.ModuleList(bns), conv_out, recurrent_convlstm
    
    def _forward_impl(self, x):
        # Initial feature extraction
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Initialize hidden states
        batch, _, height, width = x.size()
        if self.flops:  # For computing FLOPs and params
            h = torch.zeros(batch, self.rla_channel, height, width)
            c = torch.zeros(batch, self.rla_channel, height, width)
        else:
            h = torch.zeros(batch, self.rla_channel, height, width, device='cuda')
            c = torch.zeros(batch, self.rla_channel, height, width, device='cuda')

        # Process through stages with RLA
        for layers, bns, conv_out, recurrent_convlstm in zip(self.stages, self.stage_bns, self.conv_outs, self.recurrent_convs):    
            for layer, bn in zip(layers, bns):
                # Forward through RLA bottleneck block
                x, y, h, c = layer(x, h, c)
                
                # RLA module updates
                y_out = conv_out(y)
                y_out = bn(y_out)
                y_out = self.tanh(y_out)
                
                # Update hidden states using ConvLSTM
                h, c = recurrent_convlstm(y_out, (h, c))
        
        # Concatenate final feature maps with hidden state
        x = torch.cat((x, h), dim=1)
        
        # Global average pooling and classification
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        artist = self.fc1(x)
        genre = self.fc2(x)
        style = self.fc3(x)

        return artist,genre,style

    def forward(self, x):
        return self._forward_impl(x)    

### ============================= Model Loading  =============================


In [9]:
def rlalstm_resnet50(artist_classes=1000,genre_classes=1000,style_classes=1000, rla_channel=32, attention_type=None, attention_params=None):
    """
    Constructs a RLAlstm_ResNet-50 model with flexible attention mechanisms.
    
    Parameters:
    -----------
    num_classes : int
        Number of output classes
    rla_channel : int
        Number of channels in RLA module
    attention_type : str or None
        Type of attention to use: 'SE', 'ECA', 'CBAM', or None
    attention_params : list or None
        Parameters for attention modules for each stage
        - For SE: reduction ratios (default: [16, 16, 16, 16])
        - For ECA: kernel sizes (default: [3, 5, 7, 9])
        - For CBAM: reduction ratios (default: [16, 16, 16, 16])
    
    Returns:
    --------
    model : RLAlstm_ResNet
        The initialized model
    """
    print(f"Constructing rlalstm_resnet50 with {attention_type} attention...")
    model = RLAlstm_ResNet(
        RLA_Bottleneck, 
        [3, 4, 6, 3], 
        artist_classes=artist_classes,
        genre_classes=genre_classes,
        style_classes=style_classes,
        rla_channel=rla_channel,
        attention_type=attention_type,
        attention_params=attention_params
    )
    return model

## Model testing script

In [10]:
# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Print model summary
def model_summary(model):
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model has {num_params:,} trainable parameters")

# Create dummy data for testing
def create_dummy_data(batch_size=8, img_size=224, num_classes=(23, 10, 16)):
    # Create random images
    inputs = torch.randn(batch_size, 3, img_size, img_size)
    
    # Create random labels for artist, style, and genre
    artist_labels = torch.randint(0, num_classes[0], (batch_size,))
    genre_labels = torch.randint(0, num_classes[1], (batch_size,))
    style_labels = torch.randint(0, num_classes[2], (batch_size,))
    
    return inputs, (artist_labels, genre_labels, style_labels)

# Test forward pass
def test_forward_pass(model, batch_size=8, img_size=224):
    print("\n=== Testing Forward Pass ===")
    model.eval()  # Set to evaluation mode
    
    # Generate dummy data
    inputs, _ = create_dummy_data(batch_size, img_size)
    inputs = inputs.to(device)
    
    # Measure inference time
    start_time = time.time()
    with torch.no_grad():
        artist_outputs, genre_outputs, style_outputs = model(inputs)
    inference_time = time.time() - start_time
    
    # Print shape information
    print(f"Input shape: {inputs.shape}")
    print(f"Artist output shape: {artist_outputs.shape}")
    print(f"Genre output shape: {genre_outputs.shape}")
    print(f"Style output shape: {style_outputs.shape}")
    print(f"Inference time for batch: {inference_time:.4f} seconds")
    print(f"Inference time per image: {inference_time/batch_size:.4f} seconds")
    
    return artist_outputs, genre_outputs, style_outputs

# Test backward pass (training)
def test_backward_pass(model, num_iterations=5, batch_size=8, img_size=224, lr=0.001):
    print("\n=== Testing Backward Pass (Training) ===")
    model.train()  # Set to training mode
    
    # Define optimizer and loss functions
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    # Training loop
    total_train_time = 0
    losses = []
    
    for i in range(num_iterations):
        # Generate dummy data
        inputs, (artist_labels, genre_labels, style_labels) = create_dummy_data(batch_size, img_size)
        inputs = inputs.to(device)
        artist_labels = artist_labels.to(device)
        genre_labels = genre_labels.to(device)
        style_labels = style_labels.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        start_time = time.time()
        artist_outputs, genre_outputs, style_outputs = model(inputs)
        
        # Calculate loss
        artist_loss = criterion(artist_outputs, artist_labels)
        genre_loss = criterion(genre_outputs, genre_labels)
        style_loss = criterion(style_outputs, style_labels)
        total_loss = artist_loss + genre_loss + style_loss
        
        # Backward pass and optimize
        total_loss.backward()
        optimizer.step()
        
        iteration_time = time.time() - start_time
        total_train_time += iteration_time
        
        # Print iteration details
        losses.append(total_loss.item())
        print(f"Iteration {i+1}/{num_iterations}, Loss: {total_loss.item():.4f}, Time: {iteration_time:.4f}s")
    
    avg_train_time = total_train_time / num_iterations
    print(f"\nAverage training time per iteration: {avg_train_time:.4f} seconds")
    print(f"Final loss: {losses[-1]:.4f}")
    
    return losses

# Memory usage tracking
def print_memory_usage():
    if torch.cuda.is_available():
        print("\n=== GPU Memory Usage ===")
        print(f"Allocated: {torch.cuda.memory_allocated(device) / 1024**3:.2f} GB")
        print(f"Cached: {torch.cuda.memory_reserved(device) / 1024**3:.2f} GB")
        torch.cuda.empty_cache()

# Test model on different input resolutions
def test_resolutions(model, resolutions=[224, 384, 512]):
    print("\n=== Testing Different Input Resolutions ===")
    for res in resolutions:
        print(f"\nTesting with resolution {res}x{res}")
        try:
            # Reduce batch size for higher resolutions to avoid OOM
            batch_size = 8 if res <= 224 else (4 if res <= 384 else 2)
            test_forward_pass(model, batch_size=batch_size, img_size=res)
            print_memory_usage()
        except RuntimeError as e:
            print(f"Error with resolution {res}x{res}: {e}")

# Full test suite
def run_tests(model):
    print("\n==== Running Full Test Suite ====")
    
    model_summary(model)
    
    # Basic forward pass test
    outputs = test_forward_pass(model)
    print_memory_usage()
    
    # Test backward pass (training)
    losses = test_backward_pass(model)
    print_memory_usage()
    
    # Test different resolutions
    test_resolutions(model)
    
    print("\n==== Tests Completed ====")

### ============================= Base Model Testing  =============================

In [11]:
model = rlalstm_resnet50(artist_classes=23,genre_classes=10,style_classes=16, rla_channel=32)
model.to(device)

print("Base Model testing started")

# Run the tests
try:
    run_tests(model)
except Exception as e:
    print(f"Error during testing: {e}")
finally:
    # Clean up
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("GPU cache cleared")

print("Base Model testing completed")

Constructing rlalstm_resnet50 with None attention...
Base Model testing started

==== Running Full Test Suite ====
Model has 24,149,649 trainable parameters

=== Testing Forward Pass ===
Input shape: torch.Size([8, 3, 224, 224])
Artist output shape: torch.Size([8, 23])
Genre output shape: torch.Size([8, 10])
Style output shape: torch.Size([8, 16])
Inference time for batch: 0.7825 seconds
Inference time per image: 0.0978 seconds

=== GPU Memory Usage ===
Allocated: 0.10 GB
Cached: 0.28 GB

=== Testing Backward Pass (Training) ===
Iteration 1/5, Loss: 8.4728, Time: 0.5036s
Iteration 2/5, Loss: 9.3464, Time: 0.0437s
Iteration 3/5, Loss: 9.1736, Time: 0.0407s
Iteration 4/5, Loss: 10.6813, Time: 0.0422s
Iteration 5/5, Loss: 10.7163, Time: 0.0353s

Average training time per iteration: 0.1331 seconds
Final loss: 10.7163

=== GPU Memory Usage ===
Allocated: 0.20 GB
Cached: 1.42 GB

=== Testing Different Input Resolutions ===

Testing with resolution 224x224

=== Testing Forward Pass ===
Input 

### ============================= Advance Model Testing  =============================

In [12]:
# With SE attention
model= rlalstm_resnet50(
    artist_classes=23,
    genre_classes=10,
    style_classes=16,
    rla_channel=32,
    attention_type='SE',
    attention_params=[16, 16, 16, 16]  # Reduction ratios for each stage
)
model.to(device)

print("SE attention Model testing started")

# Run the tests
try:
    run_tests(model)
except Exception as e:
    print(f"Error during testing: {e}")
finally:
    # Clean up
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("GPU cache cleared")

print("SE attention Model testing completed")

Constructing rlalstm_resnet50 with SE attention...
SE attention Model testing started

==== Running Full Test Suite ====
Model has 26,664,593 trainable parameters

=== Testing Forward Pass ===
Input shape: torch.Size([8, 3, 224, 224])
Artist output shape: torch.Size([8, 23])
Genre output shape: torch.Size([8, 10])
Style output shape: torch.Size([8, 16])
Inference time for batch: 0.0258 seconds
Inference time per image: 0.0032 seconds

=== GPU Memory Usage ===
Allocated: 0.12 GB
Cached: 0.49 GB

=== Testing Backward Pass (Training) ===
Iteration 1/5, Loss: 8.3382, Time: 0.1570s
Iteration 2/5, Loss: 8.8769, Time: 0.1051s
Iteration 3/5, Loss: 9.1787, Time: 0.0726s
Iteration 4/5, Loss: 11.4937, Time: 0.0731s
Iteration 5/5, Loss: 11.5075, Time: 0.0535s

Average training time per iteration: 0.0923 seconds
Final loss: 11.5075

=== GPU Memory Usage ===
Allocated: 0.22 GB
Cached: 1.63 GB

=== Testing Different Input Resolutions ===

Testing with resolution 224x224

=== Testing Forward Pass ===


In [13]:
#  With ECA attention
model = rlalstm_resnet50(
    artist_classes=23,
    genre_classes=10,
    style_classes=16,
    rla_channel=32,
    attention_type='ECA',
    attention_params=[3, 5, 7, 9]  # Kernel sizes for each stage
)
model.to(device)

print("ECA attention Model testing started")

# Run the tests
try:
    run_tests(model)
except Exception as e:
    print(f"Error during testing: {e}")
finally:
    # Clean up
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("GPU cache cleared")

print("ECA attention Model testing completed")

Constructing rlalstm_resnet50 with ECA attention...
ECA attention Model testing started

==== Running Full Test Suite ====
Model has 24,149,747 trainable parameters

=== Testing Forward Pass ===
Input shape: torch.Size([8, 3, 224, 224])
Artist output shape: torch.Size([8, 23])
Genre output shape: torch.Size([8, 10])
Style output shape: torch.Size([8, 16])
Inference time for batch: 0.0470 seconds
Inference time per image: 0.0059 seconds

=== GPU Memory Usage ===
Allocated: 0.11 GB
Cached: 0.38 GB

=== Testing Backward Pass (Training) ===
Iteration 1/5, Loss: 8.4708, Time: 0.1470s
Iteration 2/5, Loss: 8.6740, Time: 0.0535s
Iteration 3/5, Loss: 9.2979, Time: 0.0440s
Iteration 4/5, Loss: 10.7206, Time: 0.0475s
Iteration 5/5, Loss: 13.0363, Time: 0.0433s

Average training time per iteration: 0.0671 seconds
Final loss: 13.0363

=== GPU Memory Usage ===
Allocated: 0.20 GB
Cached: 1.60 GB

=== Testing Different Input Resolutions ===

Testing with resolution 224x224

=== Testing Forward Pass ==

In [14]:
#  With CBAM attention
model = rlalstm_resnet50(
    artist_classes=23,
    genre_classes=10,
    style_classes=16,
    rla_channel=32,
    attention_type='CBAM',
    attention_params=[16, 16, 16, 16]  # Reduction ratios for each stage
)
model.to(device)

print("CBAM attention Model testing started")

# Run the tests
try:
    run_tests(model)
except Exception as e:
    print(f"Error during testing: {e}")
finally:
    # Clean up
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("GPU cache cleared")

print("CBAM attention Model testing completed")

Constructing rlalstm_resnet50 with CBAM attention...
CBAM attention Model testing started

==== Running Full Test Suite ====
Model has 26,682,241 trainable parameters

=== Testing Forward Pass ===
Input shape: torch.Size([8, 3, 224, 224])
Artist output shape: torch.Size([8, 23])
Genre output shape: torch.Size([8, 10])
Style output shape: torch.Size([8, 16])
Inference time for batch: 0.0508 seconds
Inference time per image: 0.0064 seconds

=== GPU Memory Usage ===
Allocated: 0.12 GB
Cached: 0.44 GB

=== Testing Backward Pass (Training) ===
Iteration 1/5, Loss: 8.5306, Time: 0.1107s
Iteration 2/5, Loss: 8.5557, Time: 0.0749s
Iteration 3/5, Loss: 9.5770, Time: 0.0747s
Iteration 4/5, Loss: 11.7860, Time: 0.0743s
Iteration 5/5, Loss: 10.5449, Time: 0.0747s

Average training time per iteration: 0.0819 seconds
Final loss: 10.5449

=== GPU Memory Usage ===
Allocated: 0.23 GB
Cached: 1.76 GB

=== Testing Different Input Resolutions ===

Testing with resolution 224x224

=== Testing Forward Pass 