In [4]:
import torch
from torch import nn
from d2l import torch as d2l

In [12]:
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Apply batch normalization to ``X`` (from-scratch implementation).

    Training vs. inference is decided by ``torch.is_grad_enabled()``:
    when gradients are disabled (e.g. inside ``torch.no_grad()``), ``X`` is
    normalized with the running statistics. Otherwise the current mini-batch
    statistics are used and the running statistics are updated.

    Args:
        X: input tensor. It must be 2-D ``(batch, features)`` for fully
            connected layers or 4-D ``(batch, channels, height, width)`` for
            convolutional layers.
        gamma, beta: learnable scale and shift, broadcastable against ``X``.
        moving_mean, moving_var: running statistics, broadcastable against ``X``.
        eps: small constant added to the variance for numerical stability.
        momentum: weight kept on the old running statistics in the update.

    Returns:
        ``(Y, moving_mean, moving_var)``, where ``Y = gamma * X_hat + beta``.
        The running statistics are returned detached (via ``.data``).

    Raises:
        AssertionError: in training mode, if ``X`` is neither 2-D nor 4-D.
    """
    if not torch.is_grad_enabled():
        # Inference: normalize with the accumulated running statistics.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        # Branch on the tensor rank. len(X) would be the batch size, not the
        # number of dimensions.
        assert X.dim() in (2, 4), "batch_norm expects a 2-D or 4-D input"
        if X.dim() == 2:
            # Fully connected layer: per-feature statistics over the batch axis.
            mean = torch.mean(X, dim=0, keepdim=True)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # Convolutional layer: per-channel statistics over batch, height
            # and width. keepdim preserves broadcasting against X.
            mean = torch.mean(X, dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)

        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the running mean/variance with an exponential moving average.
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # scale and shift
    return Y, moving_mean.data, moving_var.data

In [13]:
class BatchNorm(nn.Module):
    """Batch-normalization layer built on the from-scratch ``batch_norm`` function.

    Args:
        numfeature: number of features (fully connected input) or channels
            (convolutional input).
        numdim: 2 for fully connected inputs. Any other value selects the
            4-D convolutional layout ``(1, numfeature, 1, 1)``.
    """

    def __init__(self, numfeature, numdim):
        super().__init__()
        # Shape of gamma/beta and of the running statistics, chosen so that
        # they broadcast against the input.
        if numdim == 2:
            shape = (1, numfeature)
        else:
            shape = (1, numfeature, 1, 1)
        # Learnable scale and shift. nn.Parameter is the class; nn.parameter
        # is a module and is not callable.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Running statistics are plain tensors, not parameters, so the
        # optimizer does not train them.
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # Plain tensor attributes are not moved by Module.to(), so move them
        # to X's device here. Tensor.device is read-only: reassign the tensors.
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)

        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y


In [52]:
#Residual 
from torch.nn import functional as F
class Residual(nn.Module):
    """ResNet basic residual block: two 3x3 conv + BatchNorm stages plus a shortcut.

    Args:
        input_channels: channels of the incoming tensor.
        num_channels: channels produced by the block.
        use_1x1conv: if True, the shortcut goes through a 1x1 convolution so
            that its channel count and stride match the main path.
        strides: stride of the first 3x3 conv and of the 1x1 shortcut conv.
    """

    def __init__(self, input_channels, num_channels,
                 use_1x1conv=False, strides=1) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,
                               kernel_size=3, padding=1)
        self.conv3 = (nn.Conv2d(input_channels, num_channels,
                                kernel_size=1, stride=strides)
                      if use_1x1conv else None)
        self.norm1 = nn.BatchNorm2d(num_channels)
        self.norm2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        # Main path: conv -> BN -> ReLU -> conv -> BN.
        hidden = F.relu(self.norm1(self.conv1(X)))
        out = self.norm2(self.conv2(hidden))
        # Shortcut: identity, or the 1x1 projection when one was configured.
        shortcut = X if self.conv3 is None else self.conv3(X)
        out += shortcut
        return F.relu(out)



In [53]:
# ResNet stem: 7x7 stride-2 conv, BatchNorm, ReLU, then a 3x3 stride-2 max-pool.
b1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)

In [54]:
def resnet_block(input_channels, num_channels, num_residuals, isFirstBlock):
    """Build the list of ``Residual`` units that make up one ResNet stage.

    Args:
        input_channels: channels entering the stage.
        num_channels: channels produced by every residual unit in the stage.
        num_residuals: number of residual units in the stage.
        isFirstBlock: True for the stage right after the stem. That stage
            keeps the spatial size (stride 1). Every other stage uses stride 2
            in its first unit and a 1x1 projection on the shortcut.

    Returns:
        A list of ``Residual`` modules (wrap it in ``nn.Sequential(*...)``).
    """
    block = []
    for i in range(num_residuals):
        if i == 0 and not isFirstBlock:
            # Downsampling unit: stride 2 plus a 1x1 projection on the shortcut.
            block.append(Residual(input_channels, num_channels, True, strides=2))
        elif i == 0 and input_channels != num_channels:
            # First stage whose input width differs from its output width.
            # Project the shortcut with a 1x1 conv (no downsampling); a plain
            # Residual(num_channels, num_channels) would reject this input.
            block.append(Residual(input_channels, num_channels, True, strides=1))
        else:
            block.append(Residual(num_channels, num_channels))
    return block

In [55]:
# ResNet-18-style network: stem, four residual stages (2 units each),
# global average pooling and a 10-way linear classifier.
net = nn.Sequential(
    b1,
    *(nn.Sequential(*resnet_block(c_in, c_out, 2, isFirstBlock=first))
      for c_in, c_out, first in ((64, 64, True), (64, 128, False),
                                 (128, 256, False), (256, 512, False))),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(512, 10),
)

In [56]:
# Push a dummy 224x224 single-channel image through each top-level layer
# and report the shape after every step.
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(type(layer).__name__, "OutPutShape:", X.shape)

Sequential OutPutShape: torch.Size([1, 64, 56, 56])
Sequential OutPutShape: torch.Size([1, 64, 56, 56])
Sequential OutPutShape: torch.Size([1, 128, 28, 28])
Sequential OutPutShape: torch.Size([1, 256, 14, 14])
Sequential OutPutShape: torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d OutPutShape: torch.Size([1, 512, 1, 1])
Flatten OutPutShape: torch.Size([1, 512])
Linear OutPutShape: torch.Size([1, 10])


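DenseNet 的思想：与 ResNet 将块的输入与输出相加（$f(\mathbf{x}) = \mathbf{x} + g(\mathbf{x})$）不同，DenseNet 在通道维度上把每个卷积块的输入与输出拼接起来，使后续层能直接使用前面所有层的特征；稠密块之间再用过渡层（1×1 卷积 + 平均池化）减少通道数并减半空间尺寸。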


In [69]:
# Quick check of enumerate(): it yields (index, value) pairs.
A = [4, 4, 4, 4]
for i, num in enumerate(A):
    # Use an f-string. Without the f prefix, `{i}` is a one-element set literal.
    print(f"{i} {num}")

{0} {4}
{1} {4}
{2} {4}
{3} {4}


In [64]:
import torch
from torch import nn
from d2l import torch as d2l


def conv_block(input_channels, num_channels):
    """DenseNet pre-activation unit: BatchNorm -> ReLU -> 3x3 conv with padding 1.

    Spatial size is preserved. Channels go from ``input_channels`` to
    ``num_channels``.
    """
    layers = [
        nn.BatchNorm2d(input_channels),
        nn.ReLU(),
        nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1),
    ]
    return nn.Sequential(*layers)
class DenseBlock(nn.Module):
    """Stack of ``conv_block`` units whose outputs are concatenated onto the input.

    Unit ``k`` receives ``input_channels + k * num_channels`` channels. The
    block output therefore has ``input_channels + num_convs * num_channels``
    channels.
    """

    def __init__(self, num_convs, input_channels, num_channels):
        super().__init__()
        self.net = nn.Sequential(*[
            conv_block(input_channels + k * num_channels, num_channels)
            for k in range(num_convs)
        ])

    def forward(self, X):
        out = X
        for unit in self.net:
            # Append this unit's new feature maps along the channel axis.
            out = torch.cat((out, unit(out)), dim=1)
        return out

# Sanity check: 3 input channels + 2 convs x 10 new channels each = 23 channels.
blk = DenseBlock(2, 3, 10)
X = torch.randn(size=(4, 3, 8, 8))
Y = blk(X)
Y.shape

# DenseNet stem (same layout as the ResNet stem): 7x7 stride-2 conv,
# BatchNorm, ReLU, then a 3x3 stride-2 max-pool.
stem_layers = None
b1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
del stem_layers

def transition_block(input_channels, num_channels):
    """DenseNet transition layer: BN -> ReLU -> 1x1 conv -> 2x2 average pool (stride 2).

    Maps ``input_channels`` to ``num_channels`` and halves height and width
    (floor division for odd sizes).
    """
    norm = nn.BatchNorm2d(input_channels)
    act = nn.ReLU()
    squeeze = nn.Conv2d(input_channels, num_channels, kernel_size=1)
    pool = nn.AvgPool2d(kernel_size=2, stride=2)
    return nn.Sequential(norm, act, squeeze, pool)

# Build the dense blocks. A transition layer that halves the channel count is
# inserted between consecutive blocks (not after the last one).
# num_channels always holds the current channel count.
num_channels, growth_rate = 64, 32
num_convs_in_dense_blocks = [4, 4, 4, 4]
blks = []
for i, num_convs in enumerate(num_convs_in_dense_blocks):
    blks.append(DenseBlock(num_convs, num_channels, growth_rate))
    # Each conv in the dense block adds growth_rate channels.
    num_channels = num_channels + num_convs * growth_rate
    if i < len(num_convs_in_dense_blocks) - 1:
        blks.append(transition_block(num_channels, num_channels // 2))
        num_channels //= 2

# Full DenseNet: stem, dense blocks with transition layers, a final
# BatchNorm + ReLU, global average pooling and a 10-way linear classifier.
# The network is built once; constructing an identical copy immediately
# afterwards only wasted work and replaced the first one.
net = nn.Sequential(
    b1, *blks,
    nn.BatchNorm2d(num_channels), nn.ReLU(),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(num_channels, 10))

torch.Size([4, 23, 8, 8])