In [1]:
# 复现RESNET 18（普通）以及RESNET 50 （有bottleneck）
# RESNET 18：https://towardsdev.com/implement-resnet-with-pytorch-a9fb40a77448
# RESNET 50：https://github.com/liao2000/ML-Notebook/blob/main/ResNet/ResNet_PyTorch.ipynb
# 训练：15.【代码】ResNet代码详解（下载见附件）.pdf

In [2]:
import torch
import torch.nn as nn
from torchsummary import summary

In [3]:
# ! pip install torchsummary

In [4]:
# summary(model, (3, 224, 224))

In [5]:
# 残差块
class ResBlock(nn.Module):
    def __init__(self,need_downsample,in_channels,out_channels):
        super().__init__()
        # 是否需要下采样
        self.need_downsample = need_downsample
        
        if need_downsample:
            # 需要下采样
            self.conv1 = nn.Conv2d(in_channels,out_channels,3,2,1)    
            # 短路连接
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels,out_channels,1,2,0),
                nn.BatchNorm2d(out_channels),
            )

            
        else:
            # 不需要下采样
            self.conv1 = nn.Conv2d(in_channels,out_channels,3,1,1)
            self.shortcut = nn.Sequential()

        
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU()
            
        # conv2
        self.conv2 = nn.Conv2d(out_channels,out_channels,3,1,1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU()
        
        self.relu3 = nn.ReLU()
        
    def forward(self,x):
        shortcut = self.shortcut(x)
        
        # conv1
        
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        
        # conv2
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        
        # shortcut
        x= x + shortcut
        
        x = self.relu3(x)
        
        return x
            
            
        

In [6]:
resblock = ResBlock(True,128,256)

In [7]:
# resblock

In [8]:
summary(resblock, (128, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 14, 14]          33,024
       BatchNorm2d-2          [-1, 256, 14, 14]             512
            Conv2d-3          [-1, 256, 14, 14]         295,168
       BatchNorm2d-4          [-1, 256, 14, 14]             512
              ReLU-5          [-1, 256, 14, 14]               0
            Conv2d-6          [-1, 256, 14, 14]         590,080
       BatchNorm2d-7          [-1, 256, 14, 14]             512
              ReLU-8          [-1, 256, 14, 14]               0
              ReLU-9          [-1, 256, 14, 14]               0
Total params: 919,808
Trainable params: 919,808
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.38
Forward/backward pass size (MB): 3.45
Params size (MB): 3.51
Estimated Total Size (MB): 7.34
-------------------------------------------

In [9]:
class Resnet18(nn.Module):
    """
    搭建一个简单的残差网络：RESNET18
    输入：224 x 224 x 3 RGB 彩图
    输出：1000类
    
    """
    def __init__(self,num_classes):
        super().__init__()
        # layer 0
        self.layer_0 = nn.Sequential(
            
            nn.Conv2d(in_channels=3,out_channels=64,kernel_size=7,stride=2,padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=1),
        
        )
        # 每层通道数量
        channels_list = [64,128,256,512]
        
        # layer 1
        self.layer_1 = nn.Sequential(
            # 不做下采样
            ResBlock(False,channels_list[0],channels_list[0]),
            # 不做下采样            
            ResBlock(False,channels_list[0],channels_list[0]),            
        )
        
        
        
        # layer 2        
        self.layer_2 = nn.Sequential(
            # 做下采样
            ResBlock(True,channels_list[0],channels_list[1]),
            # 不做下采样            
            ResBlock(False,channels_list[1],channels_list[1]),            
        )        
        # layer 3
        self.layer_3 = nn.Sequential(
            # 做下采样
            ResBlock(True,channels_list[1],channels_list[2]),
            # 不做下采样            
            ResBlock(False,channels_list[2],channels_list[2]),            
        )          
        # layer 4
        self.layer_4 = nn.Sequential(
            # 做下采样
            ResBlock(True,channels_list[2],channels_list[3]),
            # 不做下采样            
            ResBlock(False,channels_list[3],channels_list[3]),            
        )              
        
        # AAP
        self.aap = nn.AdaptiveAvgPool2d((1, 1))
        # flatten
        self.flatten = nn.Flatten(start_dim=1)
        # FC
        self.fc = nn.Linear(channels_list[3],num_classes)

    def forward(self,x):
        x = self.layer_0(x)
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        x = self.layer_4(x)  
        
        x = self.aap(x)
        
        x = self.flatten(x)

        x = self.fc(x)
        
        return x
        

In [10]:
res18 = Resnet18(10)

In [11]:
# 128, 28, 28

In [12]:
summary(res18, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,472
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,928
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,928
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
             ReLU-11           [-1, 64, 56, 56]               0
         ResBlock-12           [-1, 64, 56, 56]               0
           Conv2d-13           [-1, 64, 56, 56]          36,928
      BatchNorm2d-14           [-1, 64,

In [13]:
import torch
import torchvision
import torchvision.transforms as transforms

In [14]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [15]:
# net =  resnet18()
optimizer = torch.optim.SGD(res18.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

In [16]:
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = res18(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        print(loss.item())
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')

2.4621968269348145
2.001394510269165
2.814085006713867
4.27172327041626
4.138913631439209
3.9370455741882324
3.2597336769104004
3.7382962703704834
5.162281513214111
1.8991966247558594
3.9916529655456543
4.0868425369262695
3.115840435028076
3.5323524475097656
4.601861476898193
4.268013000488281
4.621629238128662
4.277301788330078
3.0212669372558594
1.6083602905273438
6.310171127319336
4.5327558517456055
3.5432825088500977
3.9358792304992676
2.2979254722595215
3.536844253540039
5.200863838195801
3.5441975593566895
4.058243751525879
1.8265321254730225
2.1588070392608643
3.75004243850708
3.500304698944092
6.4565348625183105
2.3028018474578857
4.4429168701171875
5.207690238952637
4.159211158752441
3.9497222900390625
3.0078885555267334
3.938462257385254
3.743095636367798
3.2609426975250244
3.691668748855591
3.9230716228485107
4.083228588104248
3.1736409664154053
2.2232346534729004
3.015775680541992
3.6325669288635254
3.5292906761169434
2.296161413192749
4.413114070892334
4.186079025268555
4.

KeyboardInterrupt: 