# In[16]:
import torch
from torch import nn as nn
from torch.nn import Linear,Conv2d,BatchNorm2d,MaxPool2d,Flatten,G
import torch.nn.functional as F
import warnings
from typing import Callable, Any, Optional, Tuple, List

# In[12]:
# Basic convolution block: Conv2d -> BatchNorm2d -> ReLU
class BasicConv2d(nn.Module):
    """Conv2d followed by BatchNorm2d (eps=0.001) and an in-place ReLU.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels (also the BatchNorm width).
        **kwargs: forwarded unchanged to ``nn.Conv2d``; must supply at least
            ``kernel_size``. ``bias`` defaults to ``True`` but may be overridden.
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        # Default to bias=True while still allowing callers to pass bias=...;
        # hard-coding it after **kwargs made such a call raise TypeError.
        # NOTE(review): BatchNorm removes the per-channel mean right after the
        # conv, so bias=False is usually sufficient here.
        kwargs.setdefault("bias", True)
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> batch-norm -> ReLU and return the activated tensor."""
        return F.relu(self.bn(self.conv(x)), inplace=True)

# Residual unit
class ResidualUnit(nn.Module):
    """Two stacked 3x3 conv blocks plus a 1x1 projection shortcut.

    Output is ``skip_block1(x) + conv2d_block2(conv2d_block1(x))``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both paths.
        **kwargs: forwarded to the first 3x3 conv and to the 1x1 skip conv
            (e.g. ``stride=2`` to downsample both paths identically).
            ``padding`` (default 1) applies only to the first 3x3 conv.
    """

    def __init__(self, in_channels, out_channels, **kwargs: Any):
        super().__init__()
        # 3x3 convs need padding=1 so the main path keeps the same H/W as the
        # 1x1 skip path; without it the final addition fails on a shape mismatch.
        # kwargs is a fresh dict per call, so popping from it is safe.
        padding = kwargs.pop("padding", 1)
        self.conv2d_block1 = BasicConv2d(
            in_channels, out_channels, kernel_size=3, padding=padding, **kwargs
        )
        # `stride` (not `strides`) is the nn.Conv2d keyword.
        self.conv2d_block2 = BasicConv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.skip_block1 = BasicConv2d(in_channels, out_channels, kernel_size=1, **kwargs)

    def forward(self, x):
        """Return the sum of the projected shortcut and the two-conv main path."""
        residual = self.conv2d_block2(self.conv2d_block1(x))
        shortcut = self.skip_block1(x)
        # NOTE(review): both paths end in ReLU (BasicConv2d) and no ReLU follows
        # the addition; canonical ResNet applies ReLU after the sum instead.
        return shortcut + residual
        

# Build a ResNet-34 style CNN
class ResidualNet(nn.Module):
    """ResNet-34 style image classifier.

    Layout: stem conv (7x7, stride 2, padding 3) -> 3x3 max-pool (stride 2)
    -> 16 residual units in four stages of 3/4/6/3 units with 64/128/256/512
    channels (stages 2-4 halve H/W in their first unit) -> global average pool
    -> flatten -> linear layer returning unnormalised logits.

    Args:
        in_channels: number of channels of the input images.
        conv_block: factory ``(in_ch, out_ch, **conv_kwargs) -> nn.Module`` used
            for the stem; defaults to ``BasicConv2d``.
        residual_block: factory ``(in_ch, out_ch, stride=...) -> nn.Module`` used
            for every residual unit; defaults to ``ResidualUnit``.
        num_classes: width of the final linear layer (default 10).
    """

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None,
        residual_block: Optional[Callable[..., nn.Module]] = None,
        num_classes: int = 10,
    ):
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        # Always bind the factory: a caller-supplied residual_block must be used
        # too, otherwise the name would be undefined below.
        residual_unit = residual_block if residual_block is not None else ResidualUnit

        # Stem: kernel_size is mandatory for nn.Conv2d, and the keyword is `stride`.
        self.conv1 = conv_block(in_channels, 64, kernel_size=7, stride=2, padding=3)

        # (out_channels, number of units, stride of the stage's first unit)
        stages = ((64, 3, 1), (128, 4, 2), (256, 6, 2), (512, 3, 2))
        units = []
        channels = 64
        for out_channels, depth, stride in stages:
            units.append(residual_unit(channels, out_channels, stride=stride))
            units.extend(residual_unit(out_channels, out_channels) for _ in range(depth - 1))
            channels = out_channels
        # nn.ModuleList (not a plain list) so the units' parameters are registered:
        # visible to .parameters()/optimizers, moved by .to(device), saved in state_dict.
        self.residual_block = nn.ModuleList(units)

        self.flatten = nn.Flatten()
        # Attribute name kept for compatibility; this is a plain linear layer that
        # outputs logits (no softmax is applied).
        self.softmax = nn.Linear(channels, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (batch, num_classes)."""
        x = self.conv1(x)
        # max_pool2d requires an explicit kernel_size.
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        for unit in self.residual_block:
            x = unit(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = self.flatten(x)
        return self.softmax(x)

