План семинара **"Реализация собственных операторов на языке Python во фреймворке PyTorch"**
1. [Squeeze-and-Excitation (SE) Block](https://arxiv.org/abs/1709.01507)
2. [Selective Kernel (SK) Convolution](https://arxiv.org/abs/1903.06586)

# Squeeze-and-Excitation (SE) Block

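The “Squeeze-and-Excitation” (SE) block adaptively recalibrates
channel-wise feature responses by explicitly modelling interdependencies between channels.
It squeezes every channel to a single number with global average pooling, passes these descriptors
through a small bottleneck network ending in a sigmoid, and rescales each channel by the resulting gate.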

In [50]:
import torch
from torch import nn

class SEBlock(nn.Module):
    """
    Implementation of the Squeeze-and-Excitation (SE) block proposed in [1].

    The block squeezes each channel of the input to a scalar by global
    average pooling, feeds the resulting channel descriptor through a
    bottleneck (1x1 conv -> ReLU -> 1x1 conv -> sigmoid), and multiplies
    every input channel by the corresponding gate in ``(0, 1)``.

    Parameters
    ----------
    in_channels : int
        Number of channels in the input tensor.
    reduction : int, optional, default=16
        Reduction ratio to control the intermediate channel dimension,
        which is ``max(4, in_channels // reduction)``.

    References
    ----------
    1. "`Squeeze-and-Excitation Networks. <https://arxiv.org/abs/1709.01507>`_" Jie Hu, et al. CVPR 2018.
    """

    def __init__(
        self,
        in_channels: int,
        reduction: int = 16
    ) -> None:
        super().__init__()

        # Bottleneck width. The lower bound of 4 is this implementation's
        # choice; it keeps the bottleneck non-trivial for small in_channels.
        middle_channels = max(4, in_channels // reduction)
        # [b, c, h, w] -> [b, c, 1, 1]
        self.squeeze = nn.AdaptiveAvgPool2d((1, 1))
        # 1x1 convolutions on a [b, c, 1, 1] tensor play the role of the
        # paper's fully connected layers without any flattening/reshaping.
        self.excitation = nn.Sequential(
            nn.Conv2d(in_channels, middle_channels, 1, 1),
            nn.ReLU(),
            nn.Conv2d(middle_channels, in_channels, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor (batch_size, in_channels, height, width)
            Input tensor.

        Returns
        -------
        out : torch.Tensor (batch_size, in_channels, height, width)
            Input rescaled channel-wise by the SE gates.
        """
        z = self.squeeze(x)      # [b, c, 1, 1] channel descriptors
        s = self.excitation(z)   # [b, c, 1, 1] gates in (0, 1)
        # The gates broadcast over the spatial dimensions.
        return s * x

In [51]:
# Smoke test: the SE block must return a tensor with the same shape as its input.
features = torch.rand((1, 32, 30, 28))
out = SEBlock(in_channels=32, reduction=16)
out(features).shape

torch.Size([1, 32, 30, 28])

# Selective Kernel (SK) Convolution

To let neurons adaptively adjust their receptive field (RF) sizes,
the authors propose an automatic selection operation, “Selective
Kernel” (SK) convolution, which chooses among multiple kernels with different kernel sizes.

In [125]:
import torch
from torch import nn
from typing import List, Optional
from einops.layers.torch import Rearrange, Reduce

class SKConv(nn.Module):
    """
    Implementation of the Selective Kernel (SK) Convolution proposed in [1].

    The layer runs ``M = len(kernels)`` parallel convolution branches
    (split), sums them and squeezes the sum into a compact feature ``z``
    (fuse), then mixes the branches with per-channel attention weights
    obtained by a softmax across branches (select).

    Parameters
    ----------
    in_channels : int
        Number of channels in the input tensor.
    out_channels : int, optional, default=None
        Number of channels produced by the convolution; ``None`` means
        ``in_channels``.
    kernels : List[int], optional, default=None
        Kernel sizes for each branch; ``None`` means ``[3, 5]``. Each branch
        is a 3x3 convolution with dilation ``k // 2`` and padding ``k // 2``,
        so its receptive field is ``k`` for odd ``k`` and the spatial size
        is preserved. Every size must be at least 2.
    reduction : int, optional, default=16
        Reduction ratio to control the dimension of "compact feature" ``z`` (see eq.4).
    L : int, optional, default=32
        Minimal value of the dimension of "compact feature" ``z`` (see eq.4).
    groups : int, optional, default=32
        ``groups`` argument of the branch ``torch.nn.Conv2d`` layers; it must
        divide both ``in_channels`` and ``out_channels``.

    Raises
    ------
    ValueError
        If ``kernels`` is empty or contains a size below 2, or (raised by
        ``torch.nn.Conv2d``) if ``groups`` does not divide the channel counts.

    Notes
    -----
    ``fc_z`` contains ``BatchNorm1d``, which cannot normalize a batch of a
    single sample in training mode; use ``.eval()`` for such inputs.

    References
    ----------
    1. "`Selective Kernel Networks. <https://arxiv.org/abs/1903.06586>`_" Xiang Li, et al. CVPR 2019.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        kernels: Optional[List[int]] = None,
        reduction: int = 16,
        L: int = 32,
        groups: int = 32
    ) -> None:
        super().__init__()

        if out_channels is None:
            out_channels = in_channels
        # None sentinel instead of a shared mutable list default.
        kernels = [3, 5] if kernels is None else list(kernels)
        if not kernels or min(kernels) < 2:
            # k < 2 would give dilation 0, which Conv2d cannot run.
            raise ValueError("kernels must be a non-empty list of sizes >= 2")

        self.out_channels = out_channels
        self.M = len(kernels)

        d = max(out_channels // reduction, L)  # eq.4

        # One branch per kernel size: a dilated 3x3 conv emulates a k x k kernel.
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=k // 2,
                    dilation=k // 2,
                    groups=groups,
                    padding_mode="zeros",
                ),
                # Normalizes the conv output, hence out_channels features.
                nn.BatchNorm2d(num_features=out_channels),
                nn.ReLU(),
            )
            for k in kernels
        ])

        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc_z = nn.Sequential(
            nn.Linear(in_features=out_channels, out_features=d),
            nn.BatchNorm1d(num_features=d),
            nn.ReLU(),
        )
        # Produces one logit per (branch, output channel) pair.
        self.fc_attn = nn.Linear(in_features=d, out_features=out_channels * self.M)
        # Softmax over the branch axis (dim=1 of [b, M, out_channels, 1, 1]).
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor (batch_size, in_channels, height, width)
            Input tensor.

        Returns
        -------
        out : torch.Tensor (batch_size, out_channels, height, width)
            Output of the SK convolution layer.
        """
        b = x.size(0)

        # ----- split -----
        feats = torch.stack([conv(x) for conv in self.convs], dim=1)  # [b, M, out_channels, h, w]

        # ----- fuse -----
        # eq.1: element-wise sum of all branches
        U = feats.sum(dim=1)  # [b, out_channels, h, w]
        # eq.2: channel-wise statistics via global average pooling
        s = self.pool(U).flatten(1)  # [b, out_channels]
        # eq.3: compact feature
        z = self.fc_z(s)  # [b, d]

        # ----- select -----
        # eq.5: attention map, normalized across branches
        score = self.fc_attn(z).view(b, self.M, self.out_channels, 1, 1)
        att = self.softmax(score)

        # eq.6: attention-weighted sum of the branches
        return (feats * att).sum(dim=1)  # [b, out_channels, h, w]

In [126]:
# 34 * 16 = 544 channels, so with the default reduction=16 the compact
# feature size out_channels // reduction = 34 exceeds the default L = 32.
features = torch.rand(1, 34 * 16, 25, 25)
print(features.shape)
# eval(): fc_z contains BatchNorm1d, which cannot normalize a batch of one
# sample in training mode.
out = SKConv(in_channels=34 * 16).eval()
out(features).shape

torch.Size([1, 544, 25, 25])


torch.Size([1, 544, 25, 25])

In [None]:
# Ungrouped 3x3 convolution: weight has shape [out_channels, in_channels, kH, kW].
n = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=(3, 3))
n.weight.shape

torch.Size([3, 3, 3, 3])

In [None]:
# groups == in_channels: each filter sees only in_channels // groups = 1
# input channel, so the weight shrinks to [3, 1, 3, 3].
n = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=(3, 3), groups=3)
n.weight.shape

torch.Size([3, 1, 3, 3])

In [None]:
# `n` is the grouped conv from the previous cell; it has no padding, so each
# spatial dimension shrinks by kernel_size - 1 = 2.
features = torch.rand((1, 3, 25, 25))
n(features).shape

torch.Size([1, 3, 23, 23])