In [86]:
import torch
from torch import nn
import timm
import math
from torch.nn import functional as F
class Conv2d_cd(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, adaptive_type='learnable', theta_init=0.5):

        super(Conv2d_cd, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)

        self.adaptive_type = adaptive_type
        if self.adaptive_type == 'learnable':
            self.theta = torch.nn.parameter.Parameter(torch.tensor(0.0,requires_grad=True))  # init: 0 - after sigmoid - 0.5

        elif self.adaptive_type == 'layer_attn':
            #self.down_scale = torch.nn.Linear(out_channels, out_channels//2)

            self.channel_pooling = torch.nn.Conv2d(out_channels, 1, kernel_size=(1,1), stride=1, padding=0)
            self.attn_layer1 = torch.nn.Linear(2, 4)
            self.attn_layer2 = torch.nn.Linear(4, 2)

        elif self.adaptive_type == 'channel_attn':
            pass

        else:
            raise NotImplementedError

    def calculate_layer_attn_for_theta(self, conv_feat, cdc_feat):

        # conv_feat (b, c_out, h, w) ->  (b, c_out)

        conv_feat_d = torch.nn.functional.adaptive_avg_pool2d(conv_feat, (1,1)) # (b, 1)
        conv_feat_d = self.channel_pooling(conv_feat_d)
        cdc_feat_d = torch.nn.functional.adaptive_avg_pool2d(cdc_feat, (1,1)) # (b, 1)
        cdc_feat_d = self.channel_pooling(cdc_feat_d)
        cat_feat = torch.cat([conv_feat_d, cdc_feat_d], 1).squeeze(dim=3).squeeze(dim=2)
        print(cat_feat.size())

        atten_logit = self.attn_layer2(F.relu(self.attn_layer1(cat_feat)))
        theta_attention = torch.softmax(atten_logit,1)[:,1].view(-1,1,1,1)
        print(theta_attention.size())
        return theta_attention

    def calculate_channel_attn_for_theta(self, conv_feat, cdc_feat):

        # conv_feat (b, c_out, h, w) ->  (b, c_out)

        conv_feat_d = torch.nn.functional.adaptive_avg_pool3d(conv_feat, (1,1,1)).view(conv_feat.size(0),1) # (b, 1)
        cdc_feat_d = torch.nn.functional.adaptive_avg_pool3d(cdc_feat, (1,1,1)).view(conv_feat.size(0),1) # (b, 1)

        cat_feat = torch.cat( [conv_feat_d, cdc_feat_d], 1) # (b,2)

        atten_logit = self.attn_layer(cat_feat)
        theta_attention = torch.sigmoid(atten_logit,1)[:,1].view(-1,1,1,1)
        print(theta_attention.size())
        return theta_attention


    def forward(self, x):
        out_normal = self.conv(x)

        if self.adaptive_type == 'learnable':
            # theta = (1+torch.tanh(self.theta)) # constrain to 0~1
            theta = torch.sigmoid(self.theta)  # constrain to 0~1

        if self.adaptive_type == 'layer_attn':
            theta = 0.5

        if math.fabs(theta - 0.0) < 1e-8:
            return out_normal
        else:
            [C_out,C_in, kernel_size,kernel_size] = self.conv.weight.shape
            kernel_diff = self.conv.weight.sum(2).sum(2)
            kernel_diff = kernel_diff[:, :, None, None]
            out_diff = F.conv2d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride, padding=0, groups=self.conv.groups)



            if self.adaptive_type == 'layer_attn':
                layer_attn_theta = self.calculate_layer_attn_for_theta(out_normal, out_diff)
                print(layer_attn_theta.size())
                print(out_diff.size())
                return out_normal - layer_attn_theta * out_diff

            if self.adaptive_type == 'channel_attn':
                layer_attn_theta = self.calculate_layer_attn_for_theta(out_normal, out_diff)
                return out_normal - layer_attn_theta * out_diff


In [87]:
model = Conv2d_cd(4,8, adaptive_type='layer_attn')
x = torch.rand(2,4,16,16)
model(x)

#model.state_dict()

torch.Size([2, 2])
torch.Size([2, 1, 1, 1])
torch.Size([2, 1, 1, 1])
torch.Size([2, 8, 16, 16])


tensor([[[[-4.6084e-02, -2.2194e-01,  1.6116e-01,  ..., -1.0496e-01,
           -1.4926e-02,  4.3239e-02],
          [ 1.7068e-02,  5.5577e-02, -1.5669e-01,  ..., -2.8768e-01,
            1.1116e-01, -3.6284e-03],
          [-1.5329e-01,  1.6896e-01, -2.8063e-01,  ...,  2.0391e-02,
            2.8652e-02, -1.5282e-01],
          ...,
          [-2.1883e-01,  1.7996e-01, -1.6431e-01,  ...,  1.7076e-01,
            9.3656e-02,  1.8757e-01],
          [-2.0693e-02,  3.1210e-02,  1.5915e-01,  ...,  1.8637e-01,
           -4.5829e-01,  3.8566e-03],
          [ 1.6681e-01, -1.6647e-01,  4.8291e-01,  ...,  6.9550e-02,
            2.1436e-01,  1.9593e-02]],

         [[-1.1945e-01,  5.6504e-01, -6.8770e-02,  ..., -2.3456e-01,
            2.2841e-01, -1.6089e-02],
          [ 6.9756e-02,  5.1129e-01,  6.5183e-02,  ...,  3.0130e-01,
            2.6899e-01,  4.4122e-01],
          [ 1.8591e-01,  2.7597e-01, -1.2126e-01,  ...,  4.6521e-01,
            2.1639e-01, -1.1583e-01],
          ...,
     

In [None]:
class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention,self).__init__()

        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=5, stride=1, padding=2)
        self.sigmoid = nn.Sigmoid()

    def forward(self,x):
        avgout = torch.mean(x,dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)

        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out
