In [18]:
import torch
from torch import nn

In [84]:
class SKConv(nn.Module):
    #                  64       32   2  8  2
    def __init__(self, features, WH, M, G, r, stride=1 ,L=32):
        """ Constructor
        Args:
            features: input channel dimensionality.
            WH: input spatial dimensionality, used for GAP kernel size.
            M: the number of branchs.
            G: num of convolution groups.
            r: the radio for compute d, the length of z.
            stride: stride, default 1.
            L: the minimum dim of the vector z in paper, default 32.
        """
        super(SKConv, self).__init__()
        d = max(int(features/r), L)
        self.M = M
        self.features = features
        self.convs = nn.ModuleList([])
        for i in range(M):
            self.convs.append(nn.Sequential(
                nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, groups=G),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=False)
            ))
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(features, d)
        self.fcs = nn.ModuleList([])
        for i in range(M):
            self.fcs.append(
                nn.Linear(d, features)
            )
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        for i, conv in enumerate(self.convs):
            fea = conv(x[i]).unsqueeze_(dim=1)
            if i == 0:
                feas = fea
            else:
                feas = torch.cat([feas, fea], dim=1)
            # print(i)
        fea_U = torch.sum(feas, dim=1)
        fea_s = self.gap(fea_U).squeeze_()
        # print(fea_U.shape, fea_s.shape)
        fea_z = self.fc(fea_s)
        for i, fc in enumerate(self.fcs):
            vector = fc(fea_z).unsqueeze_(dim=1)
            if i == 0:
                attention_vectors = vector
            else:
                attention_vectors = torch.cat([attention_vectors, vector], dim=1)
            # print(i)
        attention_vectors = self.softmax(attention_vectors)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
        # print( (feas * attention_vectors).shape, feas.shape )
        fea_v = (feas * attention_vectors).sum(dim=1)
        return fea_v

In [85]:
class Decode(nn.Module):
    def __init__(self, in_dim, out_dim, WH = 32):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample( scale_factor = 2),
            nn.Conv2d(in_dim, out_dim, kernel_size = 3, stride = 1, padding = 1, groups=8),
            nn.BatchNorm2d(out_dim),
            nn.ReLU(),
        )
        self.conv = nn.Sequential(
                    nn.Conv2d(in_dim, out_dim, kernel_size = 3, stride = 1, padding = 1, groups=8),
                    nn.BatchNorm2d(out_dim),
                    nn.ReLU(),
                )

        self.skconv = SKConv( out_dim, WH, 2, 8, 2)
    def forward(self, x, y):
        y_up = self.up(y)
        x_conv = self.conv(x)
        # print(y_up.shape, x_conv.shape)
        out = self.skconv([x_conv, y_up])
        return out

In [86]:
model = Decode(64, 48)

In [90]:
batch_image = torch.zeros((2,64,32, 32))
batch_image1 = torch.zeros((2,64,32, 32))

In [93]:
torch.cat( [batch_image, batch_image1], dim = 1).shape

torch.Size([2, 128, 32, 32])

In [88]:
model( batch_image, batch_image1).shape

torch.Size([2, 2, 48, 64, 64]) torch.Size([2, 2, 48, 64, 64])


torch.Size([2, 48, 64, 64])

In [21]:
model([ batch_image, batch_image1]).shape

0


RuntimeError: Given groups=8, weight of size [32, 4, 3, 3], expected input[2, 64, 32, 32] to have 32 channels, but got 64 channels instead