# **HybridNet Scratch Implementation**

_____
### **Imports**

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

_____
### **Conv layers**

In [21]:
class Conv_i(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU unit.

    Defaults to a 3x3 kernel with stride 2 and padding 1, so each spatial
    dimension ``H`` becomes ``(H - 1) // 2 + 1``.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding added to both sides of each spatial dim.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        super().__init__()

        # bias=False: the BatchNorm2d that follows has its own learnable shift.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu(self.bn(self.conv(x)))
    



class Conv_ab(nn.Module):
    """Pointwise Conv2d -> BatchNorm2d -> ReLU unit.

    Defaults to a 1x1 kernel with stride 1 and no padding, so only the
    channel count changes and the spatial size is preserved.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding added to both sides of each spatial dim.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super().__init__()

        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        # No conv bias; the BatchNorm2d below carries a learnable shift.
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU in that order."""
        features = self.conv(x)
        normalised = self.bn(features)
        return self.relu(normalised)
    


    
class Conv_c(nn.Module):
    """Downsampling Conv2d -> BatchNorm2d -> ReLU unit.

    Defaults to a 3x3 kernel with stride 2 and padding 1, so each spatial
    dimension ``H`` becomes ``(H - 1) // 2 + 1`` (e.g. 625 -> 313).

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding added to both sides of each spatial dim.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        super().__init__()

        # The convolution has no bias term; BatchNorm2d provides the shift.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run the input through conv, then BN, then ReLU."""
        out = x
        for layer in (self.conv, self.bn, self.relu):
            out = layer(out)
        return out

_____
### **Dummy input to conv layers**

In [22]:
# Shape smoke test for Conv_i: 3 -> 128 channels using the class defaults
# (3x3 kernel, stride 2, padding 1).
conv_i = Conv_i(in_channels=3, out_channels=128)

# Dummy input of shape [batch_size, channels, height, width] = (1, 3, 5000, 5000)
dummy_input = torch.randn(1, 3, 5000, 5000)
output = conv_i(dummy_input)
print(f"Output shape: {output.shape}")

Output shape: torch.Size([1, 128, 2500, 2500])


In [23]:
# Shape smoke test for Conv_ab: 256 -> 256 channels using the class defaults
# (1x1 kernel, stride 1, padding 0).
conv_a1 = Conv_ab(in_channels=256, out_channels=256)
# Dummy input of shape [batch_size, channels, height, width] = (1, 256, 1250, 1250)
dummy_input = torch.randn(1, 256, 1250, 1250)
output = conv_a1(dummy_input)
print(f"Conv_a1 Output shape: {output.shape}")

Conv_a1 Output shape: torch.Size([1, 256, 1250, 1250])


In [24]:
# Shape smoke test for Conv_ab used as a channel reducer: 512 -> 256 channels.
conv_b1 = Conv_ab(in_channels=512, out_channels=256)
# Dummy input of shape [batch_size, channels, height, width] = (1, 512, 625, 625)
dummy_input_ = torch.randn(1, 512, 625, 625)
output_ = conv_b1(dummy_input_)
print(f"Conv_b1 Output shape: {output_.shape}")

Conv_b1 Output shape: torch.Size([1, 256, 625, 625])


In [25]:
# Shape smoke test for Conv_c: 256 -> 256 channels using the class defaults
# (3x3 kernel, stride 2, padding 1).
conv_c1 = Conv_c(in_channels=256, out_channels=256)
# Dummy input of shape [batch_size, channels, height, width] = (1, 256, 625, 625)
dummy_input = torch.randn(1, 256, 625, 625)
output = conv_c1(dummy_input)
print(f"Conv_c1 Output shape: {output.shape}")


Conv_c1 Output shape: torch.Size([1, 256, 313, 313])


_____
### **EMB_a block**
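- Uses a placeholder MBConv branch (1x1 Conv + BatchNorm + ReLU) instead of the full MBConv block
- Output-channel scaling (`block_type`): "single" gives C_out = C_in, "double" gives C_out = 2 x C_in
- Merge-cardinality convolution: kernel K=1, stride S=1, padding P=0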

In [26]:
class EMB_a(nn.Module):
    """ELAN-style block with a placeholder MBConv branch and a 1x1 merge conv.

    The input is split along the channel axis into a "retained" slice, which
    is passed through untouched, and a "non-retained" slice, which feeds three
    branches:

    * the placeholder MBConv branch (1x1 conv + BN + ReLU),
    * ``conv_block1`` (two stride-2 convs) followed by 4x bilinear upsampling,
    * ``conv_block1`` applied a second time (same weights) followed by 16x
      bilinear upsampling.

    The upsampled maps are centre-cropped to the input height/width, all four
    tensors are concatenated, passed through two 1x1 convolutions and finally
    merged by a 1x1 / stride-1 convolution.

    Args:
        in_channels: Channel count of the input tensor.
        retainment_factor: Percentage (0-100) of input channels that bypass
            the branches unchanged.
        out_channels: Accepted for interface compatibility but not used; the
            output width is ``in_channels`` ("single") or ``2 * in_channels``
            (any other ``block_type``).
        block_type: ``"single"`` keeps the channel count, anything else
            doubles it.
        verbose: When true (the default, matching the previous behaviour),
            ``forward`` prints the input and split shapes.
    """

    def __init__(self, in_channels, retainment_factor, out_channels, block_type="single", verbose=True):
        super(EMB_a, self).__init__()

        self.retainment_factor = retainment_factor
        self.verbose = verbose

        # Multiply before dividing: ``in_channels * (factor / 100)`` can land
        # just below an integer (100 * 0.29 == 28.999999999999996) and would
        # then be truncated one channel short.
        retain_channels = int(in_channels * retainment_factor / 100)
        non_retain_channels = in_channels - retain_channels

        # Remembered so that forward() splits exactly as the layers were sized.
        self.retain_channels = retain_channels
        self.non_retain_channels = non_retain_channels

        # Two stride-2 3x3 convs: one pass reduces H and W by roughly 4x.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(non_retain_channels, non_retain_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(non_retain_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(non_retain_channels, non_retain_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(non_retain_channels),
            nn.ReLU(inplace=True)
        )

        # Upsampling paths (4x and 16x)
        self.upsample_4x = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        self.upsample_16x = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False)

        # MBConv Block placeholder (define separately)
        self.mbconv_block = self._define_mbconv_block(non_retain_channels, non_retain_channels)

        # Retained slice + MBConv + 4x + 16x branch outputs.
        total_channels = retain_channels + 3 * non_retain_channels
        # NOTE: despite the attribute names these are dense 1x1 convolutions
        # (no ``groups`` argument is passed, so the default groups=1 applies).
        self.group_conv1 = nn.Conv2d(total_channels, total_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.group_conv2 = nn.Conv2d(total_channels, total_channels, kernel_size=1, stride=1, padding=0, bias=False)

        # Merge cardinality conv (K=1, S=1, P=0): spatial size is preserved.
        factor = 1 if block_type == "single" else 2
        self.merge_conv = nn.Conv2d(total_channels, factor * in_channels, kernel_size=1, stride=1, padding=0, bias=False)

    def _define_mbconv_block(self, in_channels, out_channels):
        """Return the placeholder MBConv branch: 1x1 conv -> BN -> ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Run the block on ``x`` and return the merged feature map.

        ``x`` must have the ``in_channels`` channels the block was built for;
        otherwise ``torch.split`` raises.
        """
        retain_channels = self.retain_channels
        non_retain_channels = self.non_retain_channels

        if self.verbose:
            print(f"Input shape: {x.shape}")
            print(f"Retain Channels: {retain_channels}, Non-Retain Channels: {non_retain_channels}")

        retain_part, non_retain_part = torch.split(x, [retain_channels, non_retain_channels], dim=1)

        if self.verbose:
            print(f"Retain part shape: {retain_part.shape}")
            print(f"Non-retain part shape: {non_retain_part.shape}")

        # The same conv_block1 (shared weights) is applied twice in a row.
        conv1_output = self.conv_block1(non_retain_part)
        conv2_output = self.conv_block1(conv1_output)

        upsample_4x_output = self.upsample_4x(conv1_output)
        upsample_16x_output = self.upsample_16x(conv2_output)

        # Crop the upsampled outputs to match the input size
        upsample_4x_output = self._crop_to_input_size(upsample_4x_output, x.size(2), x.size(3))
        upsample_16x_output = self._crop_to_input_size(upsample_16x_output, x.size(2), x.size(3))

        mbconv_output = self.mbconv_block(non_retain_part)

        combined = torch.cat([retain_part, mbconv_output, upsample_4x_output, upsample_16x_output], dim=1)

        combined = self.group_conv1(combined)
        combined = self.group_conv2(combined)

        output = self.merge_conv(combined)

        return output

    def _crop_to_input_size(self, x, target_height, target_width):
        """Centre-crop tensor `x` to the specified target height and width."""
        _, _, h, w = x.size()
        crop_h = (h - target_height) // 2
        crop_w = (w - target_width) // 2
        return x[:, :, crop_h:crop_h + target_height, crop_w:crop_w + target_width]


____
### **EMB_b block**
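- Uses a placeholder MBConv branch (1x1 Conv + BatchNorm + ReLU) instead of the full MBConv block
- Output-channel scaling (`block_type`): "single" gives C_out = C_in, "double" gives C_out = 2 x C_in
- Merge-cardinality convolution: kernel K=3, stride S=2, padding P=1 (halves the spatial size)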

In [27]:
class EMB_b(nn.Module):
    """ELAN-style block with a placeholder MBConv branch and a strided merge conv.

    The input is split along the channel axis into a "retained" slice, which
    is passed through untouched, and a "non-retained" slice, which feeds three
    branches:

    * the placeholder MBConv branch (1x1 conv + BN + ReLU),
    * ``conv_block1`` (two stride-2 convs) followed by 4x bilinear upsampling,
    * ``conv_block1`` applied a second time (same weights) followed by 16x
      bilinear upsampling.

    The upsampled maps are centre-cropped to the input height/width, all four
    tensors are concatenated, passed through two 1x1 convolutions and finally
    merged by a 3x3 / stride-2 convolution that downsamples the result.

    Args:
        in_channels: Channel count of the input tensor.
        retainment_factor: Percentage (0-100) of input channels that bypass
            the branches unchanged.
        out_channels: Accepted for interface compatibility but not used; the
            output width is ``in_channels`` ("single") or ``2 * in_channels``
            (any other ``block_type``).
        block_type: ``"single"`` keeps the channel count, anything else
            doubles it.
        verbose: When true (the default, matching the previous behaviour),
            ``forward`` prints the input and split shapes.
    """

    def __init__(self, in_channels, retainment_factor, out_channels, block_type="single", verbose=True):
        super(EMB_b, self).__init__()

        # Store retainment factor as a percentage
        self.retainment_factor = retainment_factor
        self.verbose = verbose

        # Multiply before dividing: ``in_channels * (factor / 100)`` can land
        # just below an integer (100 * 0.29 == 28.999999999999996) and would
        # then be truncated one channel short.
        retain_channels = int(in_channels * retainment_factor / 100)
        non_retain_channels = in_channels - retain_channels

        # Remembered so that forward() splits exactly as the layers were sized.
        self.retain_channels = retain_channels
        self.non_retain_channels = non_retain_channels

        # Two stride-2 3x3 convs: one pass reduces H and W by roughly 4x.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(non_retain_channels, non_retain_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(non_retain_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(non_retain_channels, non_retain_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(non_retain_channels),
            nn.ReLU(inplace=True)
        )

        # Upsampling paths (4x and 16x)
        self.upsample_4x = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        self.upsample_16x = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False)

        # MBConv Block placeholder (define separately)
        self.mbconv_block = self._define_mbconv_block(non_retain_channels, non_retain_channels)

        # Retained slice + MBConv + 4x + 16x branch outputs.
        total_channels = retain_channels + 3 * non_retain_channels
        # NOTE: despite the attribute names these are dense 1x1 convolutions
        # (no ``groups`` argument is passed, so the default groups=1 applies).
        self.group_conv1 = nn.Conv2d(total_channels, total_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.group_conv2 = nn.Conv2d(total_channels, total_channels, kernel_size=1, stride=1, padding=0, bias=False)

        # Merge cardinality conv (K=3, S=2, P=1): halves the spatial size.
        factor = 1 if block_type == "single" else 2
        self.merge_conv = nn.Conv2d(total_channels, factor * in_channels, kernel_size=3, stride=2, padding=1, bias=False)

    def _define_mbconv_block(self, in_channels, out_channels):
        """Return the placeholder MBConv branch: 1x1 conv -> BN -> ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Run the block on ``x`` and return the merged, downsampled feature map.

        ``x`` must have the ``in_channels`` channels the block was built for;
        otherwise ``torch.split`` raises.
        """
        retain_channels = self.retain_channels
        non_retain_channels = self.non_retain_channels

        if self.verbose:
            print(f"Input shape: {x.shape}")
            print(f"Retain Channels: {retain_channels}, Non-Retain Channels: {non_retain_channels}")

        retain_part, non_retain_part = torch.split(x, [retain_channels, non_retain_channels], dim=1)

        if self.verbose:
            print(f"Retain part shape: {retain_part.shape}")
            print(f"Non-retain part shape: {non_retain_part.shape}")

        # The same conv_block1 (shared weights) is applied twice in a row.
        conv1_output = self.conv_block1(non_retain_part)
        conv2_output = self.conv_block1(conv1_output)

        upsample_4x_output = self.upsample_4x(conv1_output)
        upsample_16x_output = self.upsample_16x(conv2_output)

        # Crop the upsampled outputs to match the input size
        upsample_4x_output = self._crop_to_input_size(upsample_4x_output, x.size(2), x.size(3))
        upsample_16x_output = self._crop_to_input_size(upsample_16x_output, x.size(2), x.size(3))

        mbconv_output = self.mbconv_block(non_retain_part)

        combined = torch.cat([retain_part, mbconv_output, upsample_4x_output, upsample_16x_output], dim=1)

        combined = self.group_conv1(combined)
        combined = self.group_conv2(combined)

        output = self.merge_conv(combined)

        return output

    def _crop_to_input_size(self, x, target_height, target_width):
        """Centre-crop tensor `x` to the specified target height and width."""
        _, _, h, w = x.size()
        crop_h = (h - target_height) // 2
        crop_w = (w - target_width) // 2
        return x[:, :, crop_h:crop_h + target_height, crop_w:crop_w + target_width]


_____
### **Dummy input testing**

In [28]:
# Shape smoke test for EMB_a: 20 channels with 30% retained, block_type="single".
channels = 20

elan_mbconv_block = EMB_a(in_channels=channels, retainment_factor=30, out_channels=channels, block_type="single")

# Dummy image tensor with shape [batch_size, channels, height, width]
dummy_image = torch.randn(1, channels, 2500, 2500) 

# Pass the dummy image through the EMB_a block
output = elan_mbconv_block(dummy_image)

# Print the output shape
print(f"Output shape: {output.shape}")


Input shape: torch.Size([1, 20, 2500, 2500])
Retain Channels: 6, Non-Retain Channels: 14
Retain part shape: torch.Size([1, 6, 2500, 2500])
Non-retain part shape: torch.Size([1, 14, 2500, 2500])
Output shape: torch.Size([1, 20, 2500, 2500])


In [29]:
# Shape smoke test for EMB_b: 20 channels with 30% retained, block_type="double".
channels = 20

elan_mbconv_block = EMB_b(in_channels=channels, retainment_factor=30, out_channels=channels, block_type="double")

# Dummy image tensor with shape [batch_size, channels, height, width]
dummy_image = torch.randn(1, channels, 2500, 2500) 

# Pass the dummy image through the EMB_b block
output = elan_mbconv_block(dummy_image)

# Print the output shape
print(f"Output shape: {output.shape}")


Input shape: torch.Size([1, 20, 2500, 2500])
Retain Channels: 6, Non-Retain Channels: 14
Retain part shape: torch.Size([1, 6, 2500, 2500])
Non-retain part shape: torch.Size([1, 14, 2500, 2500])
Output shape: torch.Size([1, 40, 1250, 1250])


_____
### **MBConv Block**

In [30]:
class ConvBlockMB(nn.Module):
    """Expansion + depthwise stage used at the start of an MBConv block.

    Pipeline: optional 1x1 expansion conv -> BN -> activation -> depthwise
    conv -> BN. The block works on ``in_channels * exp`` channels throughout.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Accepted but not used; the output has
            ``in_channels * exp`` channels.
        kernel_size: Kernel size of the depthwise convolution (padding is
            ``kernel_size // 2``).
        stride: Stride of the depthwise convolution.
        exp: Channel expansion ratio; the 1x1 expansion conv is only created
            when ``exp > 1``.
        act: Use ReLU after the first BatchNorm when true, identity otherwise.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, exp=1, act=True):
        super().__init__()

        hidden = in_channels * exp

        # Pointwise expansion; skipped (identity) when there is nothing to expand.
        if exp > 1:
            self.expansion = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1, bias=False)
        else:
            self.expansion = nn.Identity()
        self.bn1 = nn.BatchNorm2d(hidden)
        if act:
            self.act = nn.ReLU()
        else:
            self.act = nn.Identity()

        # Depthwise: one filter per channel (groups == channel count).
        self.depthwise = nn.Conv2d(
            hidden,
            hidden,
            kernel_size,
            stride,
            padding=(kernel_size // 2),
            groups=hidden,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(hidden)

    def forward(self, x):
        """Return ``bn2(depthwise(act(bn1(expansion(x)))))``."""
        expanded = self.act(self.bn1(self.expansion(x)))
        return self.bn2(self.depthwise(expanded))

class SeBlock(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Global-average-pools the input to one value per channel, passes it through
    a two-layer 1x1-conv bottleneck (SiLU then Sigmoid) and rescales the input
    channel-wise with the resulting gate.

    Args:
        in_channels: Channel count of the input tensor.
        reduction: Bottleneck ratio; the hidden width is
            ``in_channels // reduction``, clamped to at least one channel.
    """

    def __init__(self, in_channels, reduction=4):
        super(SeBlock, self).__init__()
        # Clamp to >= 1: with in_channels < reduction the floor division is 0,
        # which would build a zero-width bottleneck.
        reduced_channels = max(1, in_channels // reduction)

        self.squeeze = nn.AdaptiveAvgPool2d(1)  # Global average pooling
        self.fc1 = nn.Conv2d(in_channels, reduced_channels, kernel_size=1)
        self.fc2 = nn.Conv2d(reduced_channels, in_channels, kernel_size=1)
        self.silu = nn.SiLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel by the learned gate in (0, 1)."""
        scale = self.squeeze(x)
        scale = self.silu(self.fc1(scale))
        scale = self.sigmoid(self.fc2(scale))
        # The gate has shape (N, C, 1, 1) and broadcasts over H and W.
        return x * scale

class StochasticDepth(nn.Module):
    """Per-sample stochastic depth ("drop path").

    During training each sample in the batch is zeroed with probability ``p``
    and the surviving samples are scaled by ``1 / (1 - p)``. In eval mode, or
    when ``p == 0``, the input is returned unchanged.

    Args:
        p: Drop probability in ``[0, 1]``.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]``.
    """

    def __init__(self, p: float = 0.5):
        super(StochasticDepth, self).__init__()
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
        self.p = p

    def forward(self, x):
        """Return ``x`` with whole samples randomly dropped (training only)."""
        if not self.training or self.p == 0.0:
            return x
        keep_prob = 1 - self.p
        if keep_prob == 0.0:
            # Everything is dropped; dividing by keep_prob would produce NaN.
            return torch.zeros_like(x)
        # One Bernoulli draw per sample, broadcast over all remaining dims
        # regardless of the input rank.
        mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        mask = torch.rand(mask_shape, device=x.device) < keep_prob
        return x / keep_prob * mask

class MBConv(nn.Module):
    """MBConv-style block: expansion, depthwise conv, SE gate and projection.

    Pipeline: ``ConvBlockMB`` (expansion stage) -> depthwise conv -> BN ->
    Squeeze-and-Excitation -> 1x1 projection -> BN. When the block preserves
    shape (``in_channels == out_channels`` and ``stride == 1``) the input is
    added back as a residual, with stochastic depth applied to the residual
    branch only.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count of the output tensor.
        kernel_size: Kernel size of the depthwise convolution (padding is
            ``kernel_size // 2``).
        stride: Stride of the depthwise convolution.
        exp: Channel expansion ratio for the hidden width.
        reduction: Bottleneck ratio passed to the SE block.
        sd_prob: Stochastic-depth drop probability for the residual branch.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, exp=1, reduction=4, sd_prob=0.0):
        super(MBConv, self).__init__()
        exp_channels = in_channels * exp
        # A residual connection is only possible when the block preserves shape.
        self.add_skip = (in_channels == out_channels) and (stride == 1)

        # Expansion phase (1x1 conv if exp > 1)
        self.conv1 = ConvBlockMB(in_channels, exp_channels, 1, 1, exp=exp)

        # Depthwise convolution
        self.depthwise_conv = nn.Conv2d(exp_channels, exp_channels, kernel_size, stride=stride, padding=(kernel_size // 2), groups=exp_channels, bias=False)
        self.bn = nn.BatchNorm2d(exp_channels)

        # Squeeze-and-Excitation block
        self.se = SeBlock(exp_channels, reduction=reduction)

        # Projection phase (1x1 convolution to project back to output channels)
        self.conv2 = nn.Conv2d(exp_channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Stochastic depth
        self.sd = StochasticDepth(sd_prob)

        # NOTE(review): skip_connection is never called in forward(); it is
        # kept so the module's parameters and state_dict keys stay unchanged.
        # When add_skip is False its conv weights receive no gradient.
        if self.add_skip:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)

    def forward(self, x):
        """Return the block output, plus the input when the residual applies."""
        shortcut = x

        # Expansion + Depthwise conv + SE block
        x = self.conv1(x)
        x = self.depthwise_conv(x)
        x = self.bn(x)
        x = self.se(x)

        # Project to out_channels
        x = self.conv2(x)
        x = self.bn2(x)

        # Stochastic depth is applied to the residual branch *before* the
        # shortcut is added, so a dropped sample falls back to the identity
        # path instead of being zeroed. Without a residual there is nothing to
        # fall back to, so no drop is applied.
        if self.add_skip:
            x = self.sd(x) + shortcut

        return x

_____
### **ELAN-MBConv block with Original MBConv**

In [31]:
class EMB_A(nn.Module):
    """ELAN-style block with a real MBConv branch and a 1x1 merge conv.

    The input is split along the channel axis into a "retained" slice, which
    is passed through untouched, and a "non-retained" slice, which feeds three
    branches:

    * an ``MBConv`` block (3x3 kernel, stride 1, same channel count),
    * ``conv_block1`` (two stride-2 convs) followed by 4x bilinear upsampling,
    * ``conv_block1`` applied a second time (same weights) followed by 16x
      bilinear upsampling.

    The upsampled maps are resized to the input height/width, all four tensors
    are concatenated, passed through two 1x1 convolutions and finally merged
    by a 1x1 / stride-1 convolution.

    Args:
        in_channels: Channel count of the input tensor.
        retainment_factor: Percentage (0-100) of input channels that bypass
            the branches unchanged.
        out_channels: Accepted for interface compatibility but not used; the
            output width is ``in_channels`` ("single") or ``2 * in_channels``
            (any other ``block_type``).
        block_type: ``"single"`` keeps the channel count, anything else
            doubles it.
        exp: Expansion ratio forwarded to ``MBConv``.
        sd_prob: Stochastic-depth probability forwarded to ``MBConv``.
    """

    def __init__(self, in_channels, retainment_factor, out_channels, block_type="single", exp=4, sd_prob=0.1):
        super(EMB_A, self).__init__()

        # Store retainment factor as a percentage
        self.retainment_factor = retainment_factor

        # Multiply before dividing: ``in_channels * (factor / 100)`` can land
        # just below an integer (100 * 0.29 == 28.999999999999996) and would
        # then be truncated one channel short.
        retain_channels = int(in_channels * retainment_factor / 100)
        non_retain_channels = in_channels - retain_channels

        # Remembered so that forward() splits exactly as the layers were sized.
        self.retain_channels = retain_channels
        self.non_retain_channels = non_retain_channels

        # Two stride-2 3x3 convs: one pass reduces H and W by roughly 4x.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(non_retain_channels, non_retain_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(non_retain_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(non_retain_channels, non_retain_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(non_retain_channels),
            nn.ReLU(inplace=True)
        )

        # Upsampling paths (4x and 16x)
        self.upsample_4x = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        self.upsample_16x = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False)

        # MBConv branch; in == out channels and stride 1, so it keeps its
        # residual connection.
        self.mbconv_block = MBConv(
            in_channels=non_retain_channels,
            out_channels=non_retain_channels,
            kernel_size=3,
            stride=1,
            exp=exp,
            sd_prob=sd_prob
        )

        # Retained slice + MBConv + 4x + 16x branch outputs.
        total_channels = retain_channels + 3 * non_retain_channels
        # NOTE: despite the attribute names these are dense 1x1 convolutions
        # (no ``groups`` argument is passed, so the default groups=1 applies).
        self.group_conv1 = nn.Conv2d(total_channels, total_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.group_conv2 = nn.Conv2d(total_channels, total_channels, kernel_size=1, stride=1, padding=0, bias=False)

        # Merge cardinality conv (K=1, S=1, P=0): spatial size is preserved.
        factor = 1 if block_type == "single" else 2
        self.merge_conv = nn.Conv2d(total_channels, factor * in_channels, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        """Run the block on ``x`` and return the merged feature map.

        ``x`` must have the ``in_channels`` channels the block was built for;
        otherwise ``torch.split`` raises.
        """
        # Split with the same sizes the layers were built for.
        retain_part, non_retain_part = torch.split(x, [self.retain_channels, self.non_retain_channels], dim=1)

        # The same conv_block1 (shared weights) is applied twice in a row.
        conv1_output = self.conv_block1(non_retain_part)
        conv2_output = self.conv_block1(conv1_output)

        # Upsample outputs
        upsample_4x_output = self.upsample_4x(conv1_output)
        upsample_16x_output = self.upsample_16x(conv2_output)

        # Resize upsampled outputs to match retain_part size (height and width)
        target_size = (retain_part.size(2), retain_part.size(3))
        upsample_4x_output = F.interpolate(upsample_4x_output, size=target_size, mode='bilinear', align_corners=False)
        upsample_16x_output = F.interpolate(upsample_16x_output, size=target_size, mode='bilinear', align_corners=False)

        # MBConv processing
        mbconv_output = self.mbconv_block(non_retain_part)

        # Combine outputs: retain + mbconv + 4x upsample + 16x upsample
        combined = torch.cat([retain_part, mbconv_output, upsample_4x_output, upsample_16x_output], dim=1)

        # Two dense 1x1 mixing convs
        combined = self.group_conv1(combined)
        combined = self.group_conv2(combined)

        # Merge cardinality conv
        output = self.merge_conv(combined)

        return output


In [33]:
class EMB_B(nn.Module):
    """ELAN-style block with a real MBConv branch and a strided merge conv.

    The input is split along the channel axis into a "retained" slice, which
    is passed through untouched, and a "non-retained" slice, which feeds three
    branches:

    * an ``MBConv`` block (3x3 kernel, stride 1, same channel count),
    * ``conv_block1`` (two stride-2 convs) followed by 4x bilinear upsampling,
    * ``conv_block1`` applied a second time (same weights) followed by 16x
      bilinear upsampling.

    The upsampled maps are resized to the input height/width, all four tensors
    are concatenated, passed through two 1x1 convolutions and finally merged
    by a 3x3 / stride-2 convolution that downsamples the result.

    Args:
        in_channels: Channel count of the input tensor.
        retainment_factor: Percentage (0-100) of input channels that bypass
            the branches unchanged.
        out_channels: Accepted for interface compatibility but not used; the
            output width is ``in_channels`` ("single") or ``2 * in_channels``
            (any other ``block_type``).
        block_type: ``"single"`` keeps the channel count, anything else
            doubles it.
        exp: Expansion ratio forwarded to ``MBConv``.
        sd_prob: Stochastic-depth probability forwarded to ``MBConv``.
    """

    def __init__(self, in_channels, retainment_factor, out_channels, block_type="single", exp=4, sd_prob=0.1):
        super(EMB_B, self).__init__()

        # Store retainment factor as a percentage
        self.retainment_factor = retainment_factor

        # Multiply before dividing: ``in_channels * (factor / 100)`` can land
        # just below an integer (100 * 0.29 == 28.999999999999996) and would
        # then be truncated one channel short.
        retain_channels = int(in_channels * retainment_factor / 100)
        non_retain_channels = in_channels - retain_channels

        # Remembered so that forward() splits exactly as the layers were sized.
        self.retain_channels = retain_channels
        self.non_retain_channels = non_retain_channels

        # Two stride-2 3x3 convs: one pass reduces H and W by roughly 4x.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(non_retain_channels, non_retain_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(non_retain_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(non_retain_channels, non_retain_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(non_retain_channels),
            nn.ReLU(inplace=True)
        )

        # Upsampling paths (4x and 16x)
        self.upsample_4x = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        self.upsample_16x = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False)

        # MBConv branch; in == out channels and stride 1, so it keeps its
        # residual connection.
        self.mbconv_block = MBConv(
            in_channels=non_retain_channels,
            out_channels=non_retain_channels,
            kernel_size=3,
            stride=1,
            exp=exp,
            sd_prob=sd_prob
        )

        # Retained slice + MBConv + 4x + 16x branch outputs.
        total_channels = retain_channels + 3 * non_retain_channels
        # NOTE: despite the attribute names these are dense 1x1 convolutions
        # (no ``groups`` argument is passed, so the default groups=1 applies).
        self.group_conv1 = nn.Conv2d(total_channels, total_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.group_conv2 = nn.Conv2d(total_channels, total_channels, kernel_size=1, stride=1, padding=0, bias=False)

        # Merge cardinality conv (K=3, S=2, P=1): halves the spatial size.
        factor = 1 if block_type == "single" else 2
        self.merge_conv = nn.Conv2d(total_channels, factor * in_channels, kernel_size=3, stride=2, padding=1, bias=False)

    def forward(self, x):
        """Run the block on ``x`` and return the merged, downsampled feature map.

        ``x`` must have the ``in_channels`` channels the block was built for;
        otherwise ``torch.split`` raises.
        """
        # Split with the same sizes the layers were built for.
        retain_part, non_retain_part = torch.split(x, [self.retain_channels, self.non_retain_channels], dim=1)

        # The same conv_block1 (shared weights) is applied twice in a row.
        conv1_output = self.conv_block1(non_retain_part)
        conv2_output = self.conv_block1(conv1_output)

        # Upsample outputs
        upsample_4x_output = self.upsample_4x(conv1_output)
        upsample_16x_output = self.upsample_16x(conv2_output)

        # Resize upsampled outputs to match retain_part size (height and width)
        target_size = (retain_part.size(2), retain_part.size(3))
        upsample_4x_output = F.interpolate(upsample_4x_output, size=target_size, mode='bilinear', align_corners=False)
        upsample_16x_output = F.interpolate(upsample_16x_output, size=target_size, mode='bilinear', align_corners=False)

        # MBConv processing
        mbconv_output = self.mbconv_block(non_retain_part)

        # Combine outputs: retain + mbconv + 4x upsample + 16x upsample
        combined = torch.cat([retain_part, mbconv_output, upsample_4x_output, upsample_16x_output], dim=1)

        # Two dense 1x1 mixing convs
        combined = self.group_conv1(combined)
        combined = self.group_conv2(combined)

        # Merge cardinality conv
        output = self.merge_conv(combined)

        return output


_____
### **Dummy input testing**

In [32]:
# Shape smoke test for EMB_A: 100 channels with 50% retained, block_type="single".
channels = 100

elan_mbconv_block = EMB_A(in_channels=channels, retainment_factor=50, out_channels=channels, block_type="single")

# Dummy image tensor with shape [batch_size, channels, height, width]
dummy_image = torch.randn(1, channels, 1250, 1250) 

# Pass the dummy image through the EMB_A block
output = elan_mbconv_block(dummy_image)

# Print the output shape
print(f"Output shape: {output.shape}")

Output shape: torch.Size([1, 100, 1250, 1250])


In [34]:
# Shape smoke test for EMB_B: 50 channels with 50% retained, block_type="double".
channels = 50

elan_mbconv_block = EMB_B(in_channels=channels, retainment_factor=50, out_channels=channels, block_type="double")

# Dummy image tensor with shape [batch_size, channels, height, width]
dummy_image = torch.randn(1, channels, 1250, 1250) 

# Pass the dummy image through the EMB_B block
output = elan_mbconv_block(dummy_image)

# Print the output shape
print(f"Output shape: {output.shape}")

Output shape: torch.Size([1, 100, 625, 625])


## **Backbone**

In [None]:
class Backbone(nn.Module):
    def __init__():