In [14]:
import torch
from torch import nn 
import torch.nn.functional as F
import numpy as np

In [3]:
class WeightConv(nn.Conv2d):
    """
    Conv2d whose kernel is recalibrated by a small gating branch before use.

    On every forward pass the layer's own weight tensor, of shape
    (out_channels, in_channels // groups, kH, kW), is passed through
    1x1 conv -> ReLU -> 1x1 conv -> Sigmoid. The out_channels axis acts as
    the batch axis and each kH x kW kernel as a tiny "image". The sigmoid
    output has the same shape as the weight and multiplies it element-wise.
    The rescaled weight is then used for an ordinary 2-D convolution.

    NOTE(review): the gate is computed at every kernel position (there is no
    spatial pooling over kH x kW before the bottleneck) — confirm intended.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel: kernel size, forwarded to nn.Conv2d.
        stride: forwarded to nn.Conv2d.
        pad: forwarded to nn.Conv2d as `padding`.
        dilation: forwarded to nn.Conv2d.
        group: forwarded to nn.Conv2d as `groups`.
        ratio: reduction factor of the gating bottleneck. The hidden width is
            max(1, (in_channels // group) // ratio).
        bias: whether the convolution has a learnable bias (default True,
            same as nn.Conv2d).
    """

    def __init__(self, in_channels,
                 out_channels,
                 kernel,
                 stride=1,
                 pad=0,
                 dilation=1,
                 group=1,
                 ratio=16,
                 bias=True):
        super().__init__(in_channels,
                         out_channels,
                         kernel,
                         stride=stride,
                         padding=pad,
                         dilation=dilation,
                         groups=group,
                         bias=bias)
        # The weight's channel axis holds in_channels // groups entries, so the
        # gating branch must be sized on that value, not on in_channels.
        weight_channels = in_channels // group
        # Keep at least one hidden channel when weight_channels < ratio.
        hidden = max(1, weight_channels // ratio)
        self.conv1 = nn.Conv2d(weight_channels, hidden, 1, 1)
        self.act1 = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden, weight_channels, 1, 1)
        self.act2 = nn.Sigmoid()

    def forward(self, x):
        """
        Rescale the weight with the gating branch, then convolve `x`.

        Args:
            x: input tensor for F.conv2d.

        Returns:
            F.conv2d output computed with the gated weight and `self.bias`
            (which is None when the layer was built with bias=False).
        """
        scale = self.act1(self.conv1(self.weight))
        scale = self.act2(self.conv2(scale))
        weight = self.weight * scale
        return F.conv2d(x,
                        weight,
                        self.bias,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        groups=self.groups)

In [4]:
# Seed torch's global RNG so the sample input and the layer's initial
# weights are reproducible across Restart & Run All.
torch.manual_seed(0)
a = torch.randn(1, 32, 5, 5)  # one sample, 32 channels, 5x5
test = WeightConv(32, 20, 3)  # 32 -> 20 channels, 3x3 kernel, no padding

In [5]:
out = test(x=a)  # one forward pass; `out` is inspected in a later cell

In [7]:
test.weight.size()  # (out_channels, in_channels // groups, kH, kW)

torch.Size([20, 32, 3, 3])

In [33]:
out.size()  # 5x5 input, 3x3 kernel, no padding -> 3x3 spatial output

torch.Size([1, 20, 3, 3])

# MWE: Magnitude-based Weight Excitation

In [8]:
def magnitude_based_weight_excitation(inputs,
                                     weight,
                                     stride=1,
                                     pad=0,
                                     groups=1,
                                     dilation=1,
                                     bias=False, eps=0.1):
    """
    Convolve `inputs` with a magnitude-excited copy of `weight` (MWE).

    Steps:
      1. Standardize each output kernel to zero mean and unit (unbiased) std.
         The statistics are taken over dims 1, 2, 3.
      2. Let M be the largest |w| in the whole standardized tensor, and set
         M_A = (1 + eps) * M.
      3. Map w -> M_A * 0.5 * ln((1 + w/M_A) / (1 - w/M_A)) = M_A * atanh(w/M_A).
         Because |w/M_A| <= 1/(1+eps) < 1, the result is finite. Small weights
         are almost unchanged, while large-magnitude ones are amplified.
      4. Run F.conv2d with the excited weight.

    Args:
        inputs: input tensor for F.conv2d.
        weight: conv weight; assumed (out_channels, in_channels // groups, kH, kW).
        stride: forwarded to F.conv2d.
        pad: forwarded to F.conv2d as `padding`.
        groups: forwarded to F.conv2d.
        dilation: forwarded to F.conv2d.
        bias: optional bias tensor of shape (out_channels,). Any non-tensor
            value (the default False, or None) means "no bias".
        eps: margin eps_A > 0 that keeps w/M_A strictly inside (-1, 1).

    Returns:
        The F.conv2d output.
    """
    mean = weight.mean(dim=(1, 2, 3), keepdim=True)
    std = weight.std(dim=(1, 2, 3), keepdim=True)
    weight = (weight - mean) / (std + 1e-5)
    # Use the largest magnitude, not the signed max. Otherwise a kernel whose
    # most negative value is bigger in size than its max makes 1 + w/M_A < 0,
    # and the log yields NaN.
    ma = (1 + eps) * weight.abs().max()
    # Avoid 0/0 when every standardized weight is zero (constant kernels).
    ma = ma.clamp_min(torch.finfo(weight.dtype).tiny)
    # The whole ratio belongs inside the log: 0.5 * ln((1+r)/(1-r)) == atanh(r).
    weight = ma * torch.atanh(weight / ma)
    conv_bias = bias if isinstance(bias, torch.Tensor) else None
    return F.conv2d(inputs, weight, conv_bias,
                    stride=stride, padding=pad, groups=groups, dilation=dilation)

In [9]:
class MWEConv(nn.Conv2d):
    """
    Conv2d whose weight is passed through `magnitude_based_weight_excitation`
    on every forward pass.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel: kernel size, forwarded to nn.Conv2d.
        stride: forwarded to nn.Conv2d.
        pad: forwarded to nn.Conv2d as `padding`.
        groups: forwarded to nn.Conv2d.
        dilation: forwarded to nn.Conv2d.
        bias: whether nn.Conv2d creates a learnable bias (default False).
        eps: excitation margin, stored as `self.eps` and used in forward.
    """

    def __init__(self, in_channels,
                 out_channels,
                 kernel,
                 stride=1,
                 pad=0,
                 groups=1,
                 dilation=1,
                 bias=False, eps=0.1):
        # Forward the caller's convolution settings to nn.Conv2d unchanged.
        super().__init__(
            in_channels,
            out_channels,
            kernel,
            stride=stride,
            padding=pad,
            groups=groups,
            dilation=dilation,
            bias=bias)
        self.eps = eps

    def forward(self, x):
        """
        Apply MWE to self.weight, then convolve `x` with it.

        `self.bias` (None when bias=False) and `self.eps` are forwarded to
        `magnitude_based_weight_excitation`.
        """
        return magnitude_based_weight_excitation(x, self.weight,
                                                 stride=self.stride,
                                                 pad=self.padding,
                                                 groups=self.groups,
                                                 dilation=self.dilation,
                                                 bias=self.bias,
                                                 eps=self.eps)