In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as pyplot

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
#hyper parameter
epoch_num = 80
batch_size = 100
learning_rate = 0.001

In [4]:
#image preprocessing modules
transform = transforms.Compose([
    transforms.Pad(4), #填充
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32), #随即裁剪
    transforms.ToTensor()
])

In [5]:
#data download
train_dataset = torchvision.datasets.CIFAR10(root='../../data/',
                                             train = True,
                                             download=True,
                                             transform=transforms.ToTensor())

test_dataset = torchvision.datasets.CIFAR10(root='../../data/',
                                             train = False,
                                             transform=transforms.ToTensor())

#data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size = batch_size,
                                           shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                           batch_size = batch_size,
                                           shuffle = False)

Files already downloaded and verified


In [6]:
#3x3 Conv
def conv3x3(in_channels,out_channels,stride = 1): #无偏置的3x3卷积
    return nn.Conv2d(in_channels=in_channels,
                     out_channels=out_channels,
                     kernel_size=3,
                     stride=stride,
                     padding = 1,
                     bias = False)

In [7]:
# Residual Block
class ResidualBlock(nn.Module):
    def __init__(self,in_channels,out_channels,stride=1,downsample=None):
        super(ResidualBlock,self).__init__()
        self.conv1 = conv3x3(in_channels,out_channels,stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)  #inplace改变输入数据
        self.conv2 = conv3x3(out_channels,out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        
    def forward(self,x):
        residual = x    #skip connection
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample:       #维度变化就要做downsample来对齐
            residual = self.downsample(x)
        out +=residual
        out = self.relu(out)
        return out

In [37]:
#ResNet
class ResNet(nn.Module):
    def __init__(self,block,layers,output_size = 10):
        super(ResNet,self).__init__()
        self.in_channels = 16
        self.conv = conv3x3(3,16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block,16,layers[0])
        self.layer2 = self.make_layer(block,32,layers[1],2)
        self.layer3 = self.make_layer(block,64,layers[2],2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64,output_size)
    
    def make_layer(self,block,out_channels,blocks,stride=1):
        downsample = None
        if (stride !=1) or (self.in_channels !=out_channels):
            downsample = nn.Sequential(
            conv3x3(self.in_channels,out_channels,stride=stride),
            nn.BatchNorm2d(out_channels)
            )
        
        layers = []
        layers.append(block(self.in_channels,out_channels,stride,downsample))
        self.in_channels = out_channels
        for i in range(1,blocks):
            layers.append(block(out_channels,out_channels))
        return nn.Sequential(*layers)
    
    def forward(self,x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0),-1)
        out = self.fc(out)
        return out

        

In [38]:
model = ResNet(ResidualBlock,[2,2,2]).to(device)

In [39]:
#criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)

In [40]:
def update_lr(optimizer,lr): #动态调整学习率
    for para_group in optimizer.param_groups:
        para_group['lr'] = lr

In [41]:
#train model
total_step = len(train_loader)
cur_lr = learning_rate
for epoch in range(epoch_num):
    for i,(images,labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        #forward
        predict = model(images)
        loss = criterion(predict,labels)
        
        #backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print("Epoch={}/{},step={}/{},loss={:.4f}".format(epoch+1,epoch_num,i+1,total_step,loss.item()) )
    
    if(epoch+1) % 20 == 0:
        cur_lr /= 3
        update_lr(optimizer,cur_lr)
    

Epoch=1/80,step=100/500,loss=1.6379
Epoch=1/80,step=200/500,loss=1.5276
Epoch=1/80,step=300/500,loss=1.3089
Epoch=1/80,step=400/500,loss=1.2329
Epoch=1/80,step=500/500,loss=1.1730
Epoch=2/80,step=100/500,loss=0.9029
Epoch=2/80,step=200/500,loss=0.8060
Epoch=2/80,step=300/500,loss=1.1179
Epoch=2/80,step=400/500,loss=0.8464
Epoch=2/80,step=500/500,loss=0.9274
Epoch=3/80,step=100/500,loss=0.9064
Epoch=3/80,step=200/500,loss=0.7588
Epoch=3/80,step=300/500,loss=0.7585
Epoch=3/80,step=400/500,loss=0.7004
Epoch=3/80,step=500/500,loss=0.7804
Epoch=4/80,step=100/500,loss=0.7884
Epoch=4/80,step=200/500,loss=0.7157
Epoch=4/80,step=300/500,loss=0.7448
Epoch=4/80,step=400/500,loss=0.7118
Epoch=4/80,step=500/500,loss=0.5978
Epoch=5/80,step=100/500,loss=0.6569
Epoch=5/80,step=200/500,loss=0.8102
Epoch=5/80,step=300/500,loss=0.6100
Epoch=5/80,step=400/500,loss=0.7112
Epoch=5/80,step=500/500,loss=0.5863
Epoch=6/80,step=100/500,loss=0.4922
Epoch=6/80,step=200/500,loss=0.6558
Epoch=6/80,step=300/500,loss

Epoch=45/80,step=300/500,loss=0.0028
Epoch=45/80,step=400/500,loss=0.0012
Epoch=45/80,step=500/500,loss=0.0007
Epoch=46/80,step=100/500,loss=0.0004
Epoch=46/80,step=200/500,loss=0.0005
Epoch=46/80,step=300/500,loss=0.0011
Epoch=46/80,step=400/500,loss=0.0018
Epoch=46/80,step=500/500,loss=0.0006
Epoch=47/80,step=100/500,loss=0.0019
Epoch=47/80,step=200/500,loss=0.0008
Epoch=47/80,step=300/500,loss=0.0026
Epoch=47/80,step=400/500,loss=0.0015
Epoch=47/80,step=500/500,loss=0.0005
Epoch=48/80,step=100/500,loss=0.0009
Epoch=48/80,step=200/500,loss=0.0017
Epoch=48/80,step=300/500,loss=0.0041
Epoch=48/80,step=400/500,loss=0.0013
Epoch=48/80,step=500/500,loss=0.0043
Epoch=49/80,step=100/500,loss=0.0017
Epoch=49/80,step=200/500,loss=0.0016
Epoch=49/80,step=300/500,loss=0.0010
Epoch=49/80,step=400/500,loss=0.0006
Epoch=49/80,step=500/500,loss=0.0004
Epoch=50/80,step=100/500,loss=0.0008
Epoch=50/80,step=200/500,loss=0.0064
Epoch=50/80,step=300/500,loss=0.0015
Epoch=50/80,step=400/500,loss=0.0053
E

In [43]:
#test model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    
    for images,labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        
        pred = model(images)
        _,pred_pos = torch.max(pred.data,1)
        
        correct += (pred_pos == labels).sum().item()
        total += labels.shape[0]
    
    print('accuracy is:{:.4f}'.format( (correct/total)   ))
        
    

accuracy is:0.8162


In [44]:
#save model
torch.save(model.state_dict(),'resnet.ckpt')