In [1]:
import sys
sys.path.append("..")

import torch
import torch.nn as nn
from ptflops import get_model_complexity_info

from models.common import PatchAttn
from torchvision.ops import DeformConv2d

In [2]:
class Sq_DFS_Attention(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            groups: int = 1,
            bias: bool = True
    ):

        super(Sq_DFS_Attention, self).__init__()

        self.kernel_size = kernel_size
        self.groups = groups

        self.ref_conv = nn.Sequential(
            nn.Conv2d(3 * 4, in_channels, 3, stride=1, padding=1, bias=bias),
            nn.ReLU(inplace=True)
        )

        self.offset_conv = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels, 3, stride=1, padding=1, bias=bias),
            PatchAttn(in_channels, 8),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, 3, stride=1, padding=1, bias=bias),
            PatchAttn(in_channels, 8),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, groups * 2, 3, stride=1, padding=1, bias=bias),
        )
        self.deform_conv = DeformConv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=kernel_size // 2,
            padding=kernel_size % 2,
            groups=groups,
            bias=bias)
        self.relu = nn.ReLU(inplace=True)

        self.rgb_conv = nn.Conv2d(in_channels, 3, 3, stride=1, padding=1, bias=bias)

        self.offset_conv.apply(self.init_0)

    def init_0(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, 0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        ref_x = x
        n, c, h, w = x.size()

        ref_x = self.ref_conv(torch.cat([ref_x,
                                         torch.flip(ref_x, [2]),
                                         torch.flip(ref_x, [3]),
                                         torch.flip(ref_x, [2, 3])], dim=1))
        offset = self.offset_conv(torch.cat([x, ref_x], dim=1))

        with torch.no_grad():
            point_x = torch.linspace(0, h - 1, h).reshape(1, 1, h, 1).repeat(n, 1, 1, w)
            point_y = torch.linspace(0, w - 1, w).reshape(1, 1, 1, w).repeat(n, 1, h, 1)
            point = torch.cat((point_x, point_y), dim=1).type_as(x)
            point = point + offset
            tmp_x = point[:, 0:1, :, :]
            tmp_y = point[:, 1:2, :, :]
            mask = (tmp_x >= 0) * (tmp_x < h) * (tmp_y >= 0) * (tmp_y < w)
            mask = mask.float()

        conv_offset = offset.repeat(1, self.kernel_size * self.kernel_size, 1, 1)
        x = self.relu(self.deform_conv(x, conv_offset))
        rgb = self.rgb_conv(x)
        return x, rgb, mask, offset


In [3]:
dcn = Sq_DFS_Attention(64, 64, 3, True)

In [5]:
macs, params = get_model_complexity_info(dcn, (64, 64, 64), as_strings=True,
                                         print_per_layer_stat=True, verbose=True)



RuntimeError: Given groups=1, weight of size [64, 12, 3, 3], expected input[1, 256, 64, 64] to have 12 channels, but got 256 channels instead