In [11]:
import torch.nn as nn
import torch
from torchsummary import summary

In [5]:
class depthwise_separable_conv(nn.Module):
    def init(self, nin, nout):
        super(depthwise_separable_conv, self).init()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size=3, padding=0, groups=nin)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)
    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out

In [7]:
model = depthwise_separable_conv()

TypeError: __init__() got an unexpected keyword argument 'nin'

In [61]:
# 1569
class Conv_block_7(nn.Module):
    def __init__(self, ch_out=1,kernel_size=9, stride=3) -> None:
        super().__init__()

        self.kernel_size = (3,kernel_size)
        self.stride = (3,stride)
        # self.ln_in = int((300-kernel_size)/stride+1)
        self.ln_in = 84

        self.ac_func = nn.Softplus()

        self.conv1 = nn.LazyConv1d(out_channels=2*3,kernel_size=9,stride=3,groups=3)
        self.conv2 = nn.LazyConv1d(out_channels=4*3,kernel_size=6,stride=3,groups=2*3)
        self.conv3 = nn.LazyConv1d(out_channels=8*3,kernel_size=4,stride=2,groups=4*3)
        # 相当于对len=14的维度，用一个kernel去提取
        self.conv4 = nn.LazyConv1d(out_channels=3,kernel_size=1,stride=1)

        self.BN_aff1 = nn.BatchNorm1d(num_features=2*3,affine=True)
        self.BN_aff2 = nn.BatchNorm1d(num_features=4*3,affine=True)
        self.BN_aff3 = nn.BatchNorm1d(num_features=8*3,affine=True)
        self.BN_aff4 = nn.BatchNorm1d(num_features=3,affine=True)

    def forward(self, x):
        # Conv=>BN=>AC
        x = self.conv1(x)
        x = self.ac_func(self.BN_aff1(x))

        x = self.conv2(x)
        x = self.ac_func(self.BN_aff2(x))

        x = self.conv3(x)
        x = self.ac_func(self.BN_aff3(x))

        x = self.conv4(x)
        x = self.ac_func(self.BN_aff4(x))
        x = torch.flatten(x,start_dim=1)

        return x

class Conv_1_7(nn.Module):
    # code->generate->override methods
    def __init__(self, n_gaussians, ch_out=1, kernel_size=9, stride=3) -> None:
        super().__init__()

        self.kernel_size = (3,kernel_size)
        self.stride = (3,stride)
        # self.ln_in = int((300-self.kernel_size[1])/self.stride[1]+1)
        self.ln_in = 42

        self.BN1 = nn.BatchNorm1d(num_features=3,affine=True)
        # out_channels == K * in_channels
        # 有in_channels个group,每一个group有K个卷积核，能产生K张特征图，所以最后的输出是Cin*K.

        self.layer_pi = Conv_block_7(ch_out=1,kernel_size=kernel_size,stride=stride)
        self.layer_scale = Conv_block_7(ch_out=1,kernel_size=kernel_size,stride=stride)
        self.layer_shape = Conv_block_7(ch_out=1,kernel_size=kernel_size,stride=stride)

        self.ac_func = nn.Softplus()

        self.z_pi = nn.Sequential(
            nn.Linear(self.ln_in, n_gaussians),
            nn.Softmax(dim=1)           # dim=0是B, dim=1才是feature
        )
        self.z_scale = nn.Linear(self.ln_in, n_gaussians)
        self.z_shape = nn.Linear(self.ln_in, n_gaussians)

    def forward(self, x):

        x = self.BN1(x)
        # x = torch.unsqueeze(x,dim=1)                     # torch.Size([B, 1, 3, 300])


        x_pi = self.layer_pi(x)
        x_scale = self.layer_scale(x)
        x_shape = self.layer_shape(x)

        pi = self.z_pi(x_pi)
        scale = torch.exp(self.z_scale(x_scale))
        scale = torch.clamp(scale,1e-4)
        shape = torch.exp(self.z_shape(x_shape))
        shape = torch.clamp(shape,1e-4)

        return pi,scale,shape
        # return pi

In [62]:
mlp = Conv_1_7(2)
mlp.to(device="cuda")
summary(mlp, (3,300))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
       BatchNorm1d-1               [-1, 3, 300]               6
            Conv1d-2                [-1, 6, 98]              60
       BatchNorm1d-3                [-1, 6, 98]              12
          Softplus-4                [-1, 6, 98]               0
            Conv1d-5               [-1, 12, 31]              84
       BatchNorm1d-6               [-1, 12, 31]              24
          Softplus-7               [-1, 12, 31]               0
            Conv1d-8               [-1, 24, 14]             120
       BatchNorm1d-9               [-1, 24, 14]              48
         Softplus-10               [-1, 24, 14]               0
           Conv1d-11                [-1, 3, 14]              75
      BatchNorm1d-12                [-1, 3, 14]               6
         Softplus-13                [-1, 3, 14]               0
     Conv_block_7-14                   