# ResNet-18 on CIFAR-10

In [1]:
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn as nn
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

In [2]:
# Image preprocessing / augmentation applied to training images:
# resize (shorter side -> 40 px), random horizontal flip, random 32x32 crop,
# then conversion to a float tensor.
# NOTE: transforms.Scale was deprecated and later removed from torchvision;
# transforms.Resize is its direct replacement with the same semantics.
transform = transforms.Compose([
    transforms.Resize(40),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor()])

# CIFAR-10 Dataset
# download=False: the data must already be present under root.
train_dataset = CIFAR10(root='data/cifar10',
                        train=True,
                        transform=transform,
                        download=False)

# The test split gets no augmentation, only tensor conversion.
test_dataset = CIFAR10(root='data/cifar10',
                       train=False,
                       transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=100,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=100,
                                          shuffle=False)

In [3]:
# Define the residual block used to build the network
class ResidualBlock(nn.Module):
    def __init__(self,in_channels,out_channels,stride=1):
        super(ResidualBlock,self).__init__()
        self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=(3,3),stride=stride,padding=1,bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels,out_channels,kernel_size=(3,3),stride=1,padding=1,bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # 下采样
        self.shortcut = nn.Sequential()
        if in_channels!=out_channels or stride!=1:
            self.shortcut = nn.Sequential(
                                nn.Conv2d(in_channels,out_channels,stride=stride,kernel_size=(1,1),bias=False),
                                nn.BatchNorm2d(out_channels),
                                         )
    def forward(self,x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

In [4]:
class CIFARResNet18(nn.Module):
    """ResNet-18-style image classifier assembled from ``ResidualBlock``s.

    Layout: a 3x3 stem convolution (3 -> 64 channels) with batch norm,
    four stages of two residual blocks each (64, 128, 256, 512 channels;
    stages 2-4 use stride 2 in their first block), global average pooling
    and a linear classifier.

    Args:
        num_classes: number of output logits produced by the final layer.
    """

    def __init__(self, num_classes=10):
        super(CIFARResNet18, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, stride=1,
                               kernel_size=(3, 3), padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.stage1 = self._create_stage(64, 64, stride=1)
        self.stage2 = self._create_stage(64, 128, stride=2)
        self.stage3 = self._create_stage(128, 256, stride=2)
        self.stage4 = self._create_stage(256, 512, stride=2)

        self.linear = nn.Linear(512, num_classes)

    def _create_stage(self, in_channels, out_channels, stride):
        """Build one stage: a (possibly strided) block followed by a stride-1 block."""
        return nn.Sequential(
            ResidualBlock(in_channels, out_channels, stride=stride),
            ResidualBlock(out_channels, out_channels, 1)
        )

    def forward(self, x):
        """Map a batch of 3-channel images to ``(batch, num_classes)`` logits."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.stage4(out)
        # Global average pooling down to 1x1. For 32x32 inputs the final
        # feature map is 4x4, so this matches the previous fixed
        # avg_pool2d(out, 4); unlike the fixed kernel it also yields the 512
        # features the linear layer expects for other input resolutions.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [5]:
# Build the model, then move its parameters and buffers to the GPU.
net = CIFARResNet18()
net = net.cuda()

In [6]:
# Classification loss over the raw logits returned by the network.
loss_fn = nn.CrossEntropyLoss()

# SGD with momentum and L2 weight decay over all network parameters.
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4
)

In [7]:
# Train for 10 epochs, evaluating on the test set after every epoch.
for epoch in range(10):
    # Switch back to training mode at the start of every epoch: the
    # evaluation pass below puts the network in eval mode, and without this
    # reset every epoch after the first would train with BatchNorm using its
    # running statistics instead of batch statistics.
    net.train()
    losses = []
    start = time.time()
    for batch_index, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.cuda(), targets.cuda()

        optimizer.zero_grad()

        pred = net(inputs)
        loss = loss_fn(pred, targets)

        loss.backward()
        optimizer.step()
        # .item() extracts the Python float from the 0-dim loss tensor
        # (indexing it with [0] raises on current PyTorch).
        losses.append(loss.item())
        if batch_index % 10 == 0:
            print("Epoch: %d [%d/%d] Loss : %.3f" % (epoch, batch_index, len(train_loader), np.mean(losses)))
    end = time.time()
    print('Epoch: %d Loss : %.3f Time : %.3f seconds' % (epoch, np.mean(losses), end - start))

    # Evaluate on the test set.
    net.eval()
    total = 0
    correct = 0

    # no_grad avoids building autograd graphs during inference, which
    # saves memory and time.
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            pred = net(inputs)
            _, predicted = torch.max(pred, 1)
            total += targets.size(0)
            # Accumulate as a plain Python int so the percentage below is
            # computed with ordinary float arithmetic.
            correct += predicted.eq(targets).sum().item()
    print('accuracy: %.3f' % (100.0 * correct / total))

Epoch: 0 [0/500] Loss : 2.280
Epoch: 0 [10/500] Loss : 2.180
Epoch: 0 [20/500] Loss : 2.083


KeyboardInterrupt: 

In [9]:
# Train for 200 epochs, reporting test accuracy after every epoch.
for epoch in range(200):
    # Ensure training mode before the first batch, even if a previous run
    # was interrupted while the network was in eval mode.
    net.train()
    losses = []
    # Train
    start = time.time()
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.cuda(), targets.cuda()

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        # .item() extracts the Python float from the 0-dim loss tensor
        # (indexing it with [0] raises on current PyTorch).
        losses.append(loss.item())
        if batch_idx % 10 == 0:
            print('Epoch : %d Loss : %.3f' % (epoch, np.mean(losses)))
    end = time.time()

    print('Epoch : %d Loss : %.3f Time : %.3f seconds ' % (epoch, np.mean(losses), end - start))
    # Evaluate
    net.eval()
    total = 0
    correct = 0
    # torch.no_grad() replaces the removed Variable(..., volatile=True)
    # mechanism for disabling gradient tracking during inference.
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            # Accumulate as a plain Python int so the percentage below is
            # computed with ordinary float arithmetic.
            correct += predicted.eq(targets).sum().item()

    print('Epoch : %d Test Acc : %.3f' % (epoch, 100. * correct / total))
    print('--------------------------------------------------------------')
    # Leave the network in training mode after each epoch, as before.
    net.train()

Epoch : 0 Loss : 1.096
Epoch : 0 Loss : 0.985
Epoch : 0 Loss : 0.991
Epoch : 0 Loss : 1.024
Epoch : 0 Loss : 1.028
Epoch : 0 Loss : 1.020
Epoch : 0 Loss : 1.021
Epoch : 0 Loss : 1.017
Epoch : 0 Loss : 1.020
Epoch : 0 Loss : 1.014
Epoch : 0 Loss : 1.018
Epoch : 0 Loss : 1.011
Epoch : 0 Loss : 1.005
Epoch : 0 Loss : 1.001
Epoch : 0 Loss : 0.998
Epoch : 0 Loss : 0.995
Epoch : 0 Loss : 0.987
Epoch : 0 Loss : 0.987
Epoch : 0 Loss : 0.989
Epoch : 0 Loss : 0.985
Epoch : 0 Loss : 0.982
Epoch : 0 Loss : 0.979
Epoch : 0 Loss : 0.978
Epoch : 0 Loss : 0.974
Epoch : 0 Loss : 0.971
Epoch : 0 Loss : 0.970
Epoch : 0 Loss : 0.967
Epoch : 0 Loss : 0.965
Epoch : 0 Loss : 0.964
Epoch : 0 Loss : 0.963
Epoch : 0 Loss : 0.962
Epoch : 0 Loss : 0.962
Epoch : 0 Loss : 0.960
Epoch : 0 Loss : 0.958
Epoch : 0 Loss : 0.955
Epoch : 0 Loss : 0.953
Epoch : 0 Loss : 0.949
Epoch : 0 Loss : 0.946
Epoch : 0 Loss : 0.945
Epoch : 0 Loss : 0.943
Epoch : 0 Loss : 0.939
Epoch : 0 Loss : 0.936
Epoch : 0 Loss : 0.935
Epoch : 0 L

KeyboardInterrupt: 