In [1]:
import torch
from torch import nn

In [29]:
class SEBlock(nn.Module):
    """
    Implementation of the Squeeze-and-Excitation (SE) block proposed in [1].

    The block global-average-pools every channel ("squeeze"), passes the pooled
    vector through a two-layer bottleneck that ends in a sigmoid ("excitation"),
    and rescales the input channel-wise by the resulting gates.

    Parameters
    ----------
    in_channels : int
        Number of channels in the input tensor.
    reduction : int, optional, default=16
        Reduction ratio to control the intermediate channel dimension. The
        bottleneck width is ``in_channels // reduction``, clamped to at least 1
        so that blocks with ``in_channels < reduction`` remain well-defined.

    References
    ----------
    1. "`Squeeze-and-Excitation Networks. <https://arxiv.org/abs/1709.01507>`_" Jie Hu, et al. CVPR 2018.
    """

    def __init__(
        self,
        in_channels: int,
        reduction: int = 16
    ) -> None:
        super().__init__()

        # Clamp to >= 1: a zero-width bottleneck would create empty 1x1 convs,
        # making the gates a constant (sigmoid of the second conv's bias) that
        # ignores the input entirely.
        hidden_channels = max(1, in_channels // reduction)
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        # 1x1 convolutions on the (b, c, 1, 1) pooled map play the role of the
        # fully connected layers in the original SE design.
        self.excitation = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, in_channels, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor (batch_size, in_channels, height, width)
            Input tensor.

        Returns
        -------
        out : torch.Tensor (batch_size, in_channels, height, width)
            Input rescaled channel-wise by the SE gates.
        """
        z = self.squeeze(x)      # (b, c, 1, 1) per-channel descriptors
        s = self.excitation(z)   # (b, c, 1, 1) sigmoid gates
        # Broadcasting over height/width applies one gate per channel.
        return x * s

In [30]:
# Smoke test: SE gating must preserve the input shape.
torch.manual_seed(0)
features = torch.rand(1, 32, 25, 25)
se_block = SEBlock(32)
se_out = se_block(features)
assert se_out.shape == features.shape
se_out.shape

torch.Size([1, 32, 25, 25])

In [103]:
import torch
from torch import nn
from typing import List, Optional
from einops import rearrange

class SKConv(nn.Module):
    """
    Implementation of the Selective Kernel (SK) Convolution proposed in [1].

    The layer runs several convolution branches with different receptive
    fields over the same input ("split"), sums them and squeezes the result
    into a compact descriptor ``z`` ("fuse"), then uses a per-channel softmax
    across branches to blend the branch outputs ("select").

    Parameters
    ----------
    in_channels : int
        Number of channels in the input tensor.
    out_channels : int, optional, default=None
        Number of channels produced by the convolution. ``None`` means
        ``in_channels``.
    kernels : List[int], optional, default=[3, 5]
        List of kernel sizes for each branch. Each branch is a 3x3 convolution
        with ``dilation = padding = k // 2``, i.e. an effective receptive field
        of ``2 * (k // 2) + 1`` (``k=3`` -> plain 3x3, ``k=5`` -> dilated 3x3
        covering 5x5). Every ``k`` must be >= 2 and the list must be non-empty.
    reduction : int, optional, default=16
        Reduction ratio to control the dimension of "compact feature" ``z`` (see eq.4).
    L : int, optional, default=32
        Minimal value of the dimension of "compact feature" ``z`` (see eq.4).
    groups : int, optional, default=32
        Hyperparameter for ``torch.nn.Conv2d``.
        NOTE(review): currently accepted but NOT forwarded to the branch
        convolutions, which therefore use ``groups=1``. Forwarding it would
        require ``in_channels`` and ``out_channels`` to be divisible by it.

    Raises
    ------
    ValueError
        If ``kernels`` is empty.

    Notes
    -----
    ``fc_z`` applies ``BatchNorm2d`` to a 1x1 pooled map, so in training mode
    the batch size must be greater than 1.

    References
    ----------
    1. "`Selective Kernel Networks. <https://arxiv.org/abs/1903.06586>`_" Xiang Li, et al. CVPR 2019.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        kernels: Optional[List[int]] = None,
        reduction: int = 16,
        L: int = 32,
        groups: int = 32
    ) -> None:
        super().__init__()

        # A None sentinel gives each instance its own list instead of sharing
        # one mutable default across calls.
        if kernels is None:
            kernels = [3, 5]
        if len(kernels) == 0:
            raise ValueError("SKConv needs at least one kernel size in `kernels`.")
        if out_channels is None:
            out_channels = in_channels
        self.out_channels = out_channels

        # Width of the compact feature z (eq.4): d = max(C / r, L).
        d = max(in_channels // reduction, L)

        self.M = len(kernels)

        # Dilated 3x3 convs emulate larger kernels; padding == dilation keeps
        # the spatial size unchanged for every branch.
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    dilation=k // 2,
                    padding=k // 2
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            )
            for k in kernels
        ])

        self.pool = nn.AdaptiveAvgPool2d(1)

        # 1x1 conv on the pooled (b, C, 1, 1) map acts as the fc layer of eq.4.
        self.fc_z = nn.Sequential(
            nn.Conv2d(out_channels, d, 1, 1),
            nn.BatchNorm2d(d),
            nn.ReLU()
        )
        # One logit per (branch, channel) pair; the softmax runs over branches.
        self.fc_attn = nn.Conv2d(d, out_channels * self.M, 1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor (batch_size, in_channels, height, width)
            Input tensor.

        Returns
        -------
        out : torch.Tensor (batch_size, out_channels, height, width)
            Output of the SK convolution layer.
        """
        # ----- split -----
        # One feature map per branch, stacked on a new branch axis.
        feats = torch.stack([conv(x) for conv in self.convs], dim=1)  # (b, M, C, h, w)

        # ----- fuse -----
        # eq.1: element-wise sum over branches
        U = feats.sum(dim=1)  # (b, C, h, w)
        # channel-wise statistics, eq.2
        s = self.pool(U)  # (b, C, 1, 1)
        # compact feature, eq.3
        z = self.fc_z(s)  # (b, d, 1, 1)

        # ----- select -----
        # attention map, eq.5: split the M*C logits into (M, C), branch-major,
        # then normalise across branches for every channel.
        score = self.fc_attn(z)  # (b, M * C, 1, 1)
        score = score.reshape(score.size(0), self.M, self.out_channels, 1, 1)  # (b, M, C, 1, 1)
        att = self.softmax(score)

        # fuse multiple branches, eq.6: attention-weighted sum over branches
        out = (feats * att).sum(dim=1)  # (b, C, h, w)
        return out

In [104]:
# Smoke test: SKConv with default out_channels must preserve the input shape.
# Batch of 2 plus eval(): fc_z's BatchNorm2d normalises a 1x1 map, and eval()
# makes it use running statistics rather than per-batch ones.
torch.manual_seed(0)
features = torch.rand(2, 34 * 16, 25, 25)
sk_conv = SKConv(34 * 16).eval()
with torch.no_grad():
    sk_out = sk_conv(features)
assert sk_out.shape == features.shape
sk_out.shape

torch.Size([2, 544, 25, 25])