# ResNet18 实现

![pic](resnet18.jpg)

In [1]:
%matplotlib inline
import torch
from torch import nn
from torch.nn import functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
from tqdm import tqdm
import os
print(torch.__version__)

1.10.1


In [2]:
#prepare dataset and preprocessing
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    # 至少要加上下面这句ToTensor
    transforms.ToTensor(),
    # ciaf10固有均值标准差
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    # 至少要加上下面这句ToTensor
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

## 加载数据集

In [2]:
# minist数据集
# 正常来说,train_set这个类是需要自己定义的,但是在官方数据集中已经给定义好了
data_dir="D:/data/image/"
if not os.path.exists(data_dir):
    # 尝试mac的文件夹
    data_dir="~/data"
    if not os.path.exists(data_dir):
            raise FileExistsError("data source not exist!")
    
print("data source",data_dir)
train_set=datasets.CIFAR10(root=data_dir,
                        transform=transform_train,
                        train=True,
                        download=True)

val_set=datasets.CIFAR10(root=data_dir,
                        transform=transform_test,
                        train=False,
                        download=True)
print('train data',len(train_set))
print('val data',len(train_set))

data source D:/data/image/


NameError: name 'transform_train' is not defined

In [15]:
epoch_total=30
batch=64
lr=0.01
device=torch.device('cuda'if torch.cuda.is_available() else 'cpu')
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print('use device:',device)

use device: cuda


In [5]:
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch,
                                         shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch,
                                         shuffle=False, num_workers=4)

In [6]:
# Residual 残差块
class Residual(nn.Module):
    def __init__(self,input_channel,out_channel,kersize=3,use_1x1=False,strides=1):
        super().__init__()
        # 第一层是输入->输出 ,可能会有深度变化
        self.conv1=nn.Conv2d(input_channel,out_channel,kersize,padding=1,stride=strides)
        # 第二层就是输出->输出,没有深度变化,也没有跨距变化
        self.conv2=nn.Conv2d(out_channel,out_channel,kersize,padding=1,stride=1)
        self.bn1=nn.BatchNorm2d(out_channel)
        self.bn2=nn.BatchNorm2d(out_channel)
        if(use_1x1):
            self.short=nn.Conv2d(input_channel,out_channel,kernel_size=(1,1),stride=strides)
        else:
            self.short=None
    def forward(self,x):
        y=self.conv1(x)
        y=self.bn1(y)
        y=F.relu(y)
        y=self.conv2(y)
        y=self.bn2(y)
        if(self.short):
            x=self.short(x)
        y=y+x
        F.relu(y)
        return y
        
        

标准resnet使用的是224x224的图片

CIFAR10图片为3x32x32能跑,但是因为图片太小,最后的卷积可能效果不好

可以尝试稍微改动下模型,比如修改第一层的输入跨距

In [7]:
# 标准reset 18
class ResetNet18(nn.Module):
    def __init__(self,input_channel,out_label):
        super().__init__()
        self.bn1=nn.Sequential(
            nn.Conv2d(input_channel,64,7,stride=2,padding=3),
            nn.BatchNorm2d(64),nn.ReLU(),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        )
        self.bn2=nn.Sequential(*self.res_block(64,64,2,bfirst=True))
        self.bn3=nn.Sequential(*self.res_block(64,128,2))
        self.bn4=nn.Sequential(*self.res_block(128,256,2))
        self.bn5=nn.Sequential(*self.res_block(256,512,2))
        self.full=nn.Sequential(nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(), nn.Linear(512, out_label))
        
    def res_block(self,input_channel,out_channel,num_block,bfirst=False):
        blk=[]
        for i in range(num_block):
            if(i==0 and not bfirst):
                # 一般第一个块需要降维
                # 降维则是wh缩减为一半,深度增加为2倍
                blk.append(Residual(input_channel,out_channel,use_1x1=True,strides=2))
            else:
                # 第一层比较特殊,输入输出相同宽度,所以直接可都用out_channel变量
                blk.append(Residual(out_channel,out_channel))
        # print("make block ",blk)
        return blk

    def forward(self,x):

        y=self.bn1(x)
        y=self.bn2(y)
        y=self.bn3(y)
        y=self.bn4(y)
        y=self.bn5(y)
        y=self.full(y)
        return y

In [8]:
# 送一个数据进去实验一下
X=torch.rand(4,3,32,32)
net=ResetNet18(3,10)
out=net(X)
print(out.shape)

torch.Size([4, 10])


In [9]:
# 送到设备上
net.to(device);

In [10]:
criterion=nn.CrossEntropyLoss()
# 0.9倍当前的梯度+0.1倍上次的梯度
optimizer=optim.SGD(net.parameters(),lr=lr,momentum=0.9)

In [11]:
# 测试一下 dataloader
iterator=iter(train_loader)
data,label=next(iterator)
print(data.size(),label.size())

torch.Size([64, 3, 32, 32]) torch.Size([64])


In [12]:
def train():
    net.train()
    train_total_len=len(train_set)
    # 需要注意这里写的是train_set的长度,如果写错成train_loader,返回的是数据集一共有多少个batch
    with tqdm(total=train_total_len,desc=f'Train:') as pbar:
        for idx,data in enumerate(train_loader):
            data,label=data
            data,label=data.to(device),label.to(device)
            optimizer.zero_grad()
            # forward
            outputs=net(data)
            loss=criterion(outputs,label)
            loss.backward()
            optimizer.step()
            # 更新进度条
            pbar.update(batch)


In [42]:
def test():
    # eval 模式下,dropout失效,bn层参数采用之前训练的,不更新
    net.eval()
    val_loss=0
    correct=0
    total_num=len(val_set)
    with tqdm(total=total_num,desc=f'Validation:') as pbar:
        # 不计算损失,这样速度更快
        with torch.no_grad():
            for idx,data in enumerate(val_loader):
                x,y=data
                x,y=x.to(device),y.to(device)
                optimizer.zero_grad()
                y_pre=net(x)
                val_loss+=criterion(y_pre,y).item()
                # max 第一个返回的是元素值,第二个为索引值
                # 求第一个维度的max,因此结果返回的是batch维度的max
                # 返回是一个第一个元素为值,第二个元素为idx的tuple
                pred=torch.max(y_pre,dim=1)[1]
                
                # pred维度为batch,每个元素为索引
                correct+=pred.eq(y).sum().item()
                # 更新进度条
                pbar.update(batch)
        # 格式化打印直接有% 带f%这种{:.2f%}是错的格式,format这种有点坑
        print("test loss {},accuracy {:.2%}".format(val_loss,correct/total_num))
        
            
        

In [43]:
%%time
# 跑一个epoch大概十分钟
# 实在太慢,就不在mac上运行了
for epoch in range(epoch_total):
    print('epoch:',epoch)
    train()
    test()

epoch: 0


Train:: 50048it [00:17, 2844.88it/s]                                                                                                 
Validation:: 10048it [00:05, 1764.95it/s]                                                                                            


test loss 131.11734211444855,accuracy 71.03%
epoch: 1


Train:: 50048it [00:17, 2832.30it/s]                                                                                                 
Validation:: 10048it [00:05, 1765.54it/s]                                                                                            


test loss 118.028868496418,accuracy 73.67%
epoch: 2


Train:: 50048it [00:17, 2877.27it/s]                                                                                                 
Validation:: 10048it [00:05, 1772.82it/s]                                                                                            


test loss 111.20981541275978,accuracy 75.50%
epoch: 3


Train:: 50048it [00:17, 2877.41it/s]                                                                                                 
Validation:: 10048it [00:05, 1776.84it/s]                                                                                            


test loss 109.04911902546883,accuracy 75.77%
epoch: 4


Train:: 50048it [00:17, 2872.27it/s]                                                                                                 
Validation:: 10048it [00:05, 1757.99it/s]                                                                                            


test loss 105.23417779803276,accuracy 76.97%
epoch: 5


Train:: 50048it [00:17, 2795.65it/s]                                                                                                 
Validation:: 10048it [00:05, 1774.23it/s]                                                                                            


test loss 108.92613685131073,accuracy 75.85%
epoch: 6


Train:: 50048it [00:17, 2857.32it/s]                                                                                                 
Validation:: 10048it [00:05, 1776.47it/s]                                                                                            


test loss 101.0473915040493,accuracy 78.03%
epoch: 7


Train:: 50048it [00:17, 2880.31it/s]                                                                                                 
Validation:: 10048it [00:05, 1779.78it/s]                                                                                            


test loss 99.22572460770607,accuracy 78.08%
epoch: 8


Train:: 50048it [00:17, 2856.65it/s]                                                                                                 
Validation:: 10048it [00:05, 1778.98it/s]                                                                                            


test loss 103.96464204788208,accuracy 76.93%
epoch: 9


Train:: 50048it [00:17, 2880.21it/s]                                                                                                 
Validation:: 10048it [00:05, 1774.17it/s]                                                                                            


test loss 94.17370548844337,accuracy 79.40%
epoch: 10


Train:: 50048it [00:17, 2869.93it/s]                                                                                                 
Validation:: 10048it [00:05, 1773.21it/s]                                                                                            


test loss 97.02327623963356,accuracy 78.67%
epoch: 11


Train:: 50048it [00:17, 2888.36it/s]                                                                                                 
Validation:: 10048it [00:05, 1757.90it/s]                                                                                            


test loss 92.05853220820427,accuracy 80.33%
epoch: 12


Train:: 50048it [00:17, 2881.15it/s]                                                                                                 
Validation:: 10048it [00:05, 1775.94it/s]                                                                                            


test loss 89.88909649848938,accuracy 80.33%
epoch: 13


Train:: 50048it [00:17, 2877.28it/s]                                                                                                 
Validation:: 10048it [00:05, 1770.70it/s]                                                                                            


test loss 93.46342650055885,accuracy 79.04%
epoch: 14


Train:: 50048it [00:17, 2890.68it/s]                                                                                                 
Validation:: 10048it [00:05, 1780.89it/s]                                                                                            


test loss 91.83916383981705,accuracy 80.36%
epoch: 15


Train:: 50048it [00:17, 2849.86it/s]                                                                                                 
Validation:: 10048it [00:05, 1786.70it/s]                                                                                            


test loss 89.07499961555004,accuracy 81.10%
epoch: 16


Train:: 50048it [00:17, 2893.26it/s]                                                                                                 
Validation:: 10048it [00:05, 1788.82it/s]                                                                                            


test loss 87.15135458111763,accuracy 81.01%
epoch: 17


Train:: 50048it [00:17, 2881.57it/s]                                                                                                 
Validation:: 10048it [00:05, 1783.84it/s]                                                                                            


test loss 87.75441151857376,accuracy 81.05%
epoch: 18


Train:: 50048it [00:17, 2893.99it/s]                                                                                                 
Validation:: 10048it [00:05, 1790.73it/s]                                                                                            


test loss 86.96920646727085,accuracy 81.44%
epoch: 19


Train:: 50048it [00:17, 2881.26it/s]                                                                                                 
Validation:: 10048it [00:05, 1782.51it/s]                                                                                            


test loss 87.10530969500542,accuracy 81.41%
epoch: 20


Train:: 50048it [00:17, 2898.15it/s]                                                                                                 
Validation:: 10048it [00:05, 1767.80it/s]                                                                                            


test loss 87.43953198194504,accuracy 81.17%
epoch: 21


Train:: 50048it [00:17, 2878.69it/s]                                                                                                 
Validation:: 10048it [00:05, 1773.96it/s]                                                                                            


test loss 82.48810589313507,accuracy 82.27%
epoch: 22


Train:: 50048it [00:17, 2866.43it/s]                                                                                                 
Validation:: 10048it [00:05, 1770.92it/s]                                                                                            


test loss 79.35491527616978,accuracy 83.31%
epoch: 23


Train:: 50048it [00:17, 2877.74it/s]                                                                                                 
Validation:: 10048it [00:05, 1782.78it/s]                                                                                            


test loss 85.7786630988121,accuracy 81.40%
epoch: 24


Train:: 50048it [00:17, 2860.76it/s]                                                                                                 
Validation:: 10048it [00:05, 1779.94it/s]                                                                                            


test loss 85.35498857498169,accuracy 82.25%
epoch: 25


Train:: 50048it [00:17, 2879.91it/s]                                                                                                 
Validation:: 10048it [00:05, 1779.74it/s]                                                                                            


test loss 85.08455780148506,accuracy 82.41%
epoch: 26


Train:: 50048it [00:17, 2864.42it/s]                                                                                                 
Validation:: 10048it [00:05, 1761.22it/s]                                                                                            


test loss 79.68626298010349,accuracy 83.10%
epoch: 27


Train:: 50048it [00:17, 2862.10it/s]                                                                                                 
Validation:: 10048it [00:05, 1761.51it/s]                                                                                            


test loss 80.57792191207409,accuracy 83.06%
epoch: 28


Train:: 50048it [00:17, 2812.04it/s]                                                                                                 
Validation:: 10048it [00:05, 1777.58it/s]                                                                                            


test loss 81.04605334997177,accuracy 82.99%
epoch: 29


Train:: 50048it [00:17, 2867.09it/s]                                                                                                 
Validation:: 10048it [00:05, 1769.03it/s]                                                                                            

test loss 84.70116354525089,accuracy 82.19%
Wall time: 11min 33s





In [44]:
%%time
# 跑一个epoch大概十分钟
# 实在太慢,就不在mac上运行了
for epoch in range(5):
    print('epoch:',epoch)
    train()
    test()

epoch: 0


Train:: 50048it [00:17, 2884.26it/s]                                                                                                 
Validation:: 10048it [00:05, 1799.54it/s]                                                                                            


test loss 80.20117615163326,accuracy 83.44%
epoch: 1


Train:: 50048it [00:17, 2861.28it/s]                                                                                                 
Validation:: 10048it [00:05, 1796.50it/s]                                                                                            


test loss 84.87753988802433,accuracy 82.93%
epoch: 2


Train:: 50048it [00:17, 2896.24it/s]                                                                                                 
Validation:: 10048it [00:05, 1797.15it/s]                                                                                            


test loss 79.7335135936737,accuracy 83.75%
epoch: 3


Train:: 50048it [00:17, 2844.85it/s]                                                                                                 
Validation:: 10048it [00:05, 1790.54it/s]                                                                                            


test loss 84.12095852196217,accuracy 83.32%
epoch: 4


Train:: 50048it [00:17, 2884.75it/s]                                                                                                 
Validation:: 10048it [00:05, 1798.75it/s]                                                                                            

test loss 80.80921545624733,accuracy 83.31%
Wall time: 1min 55s





In [46]:
torch.save(net.state_dict(),"loss80_accuracy83.pth")