In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models

In [2]:
import numpy as np
import os
import json
import time
import copy

# 1. Define the network

In [34]:
class ResBlock(nn.Module):
    """
    Residual block: two 3x3 conv + batch-norm layers plus a shortcut branch.

    The shortcut is the identity when the block keeps both the channel count
    and the spatial resolution; otherwise it is a 1x1 convolution followed by
    batch-norm so that both branches have matching shapes before the addition.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        """
        :param ch_in: number of channels of the incoming feature map
        :param ch_out: number of channels produced by the block
        :param stride: stride of the first 3x3 convolution (and of the 1x1
                       shortcut convolution, when one is created)
        """
        super().__init__()

        # Main branch: conv -> bn -> relu -> conv -> bn (relu comes after the add).
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Shortcut branch: an empty Sequential acts as the identity.
        shape_preserved = (ch_in == ch_out) and (stride == 1)
        if shape_preserved:
            self.extra = nn.Sequential()
        else:
            # [b, ch_in, h, w] => [b, ch_out, h', w'] so it can be added to the main branch
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """
        :param x: input feature map, [b, ch_in, h, w]
        :return: relu(shortcut(x) + main(x)), with ch_out channels
        """
        main = self.conv1(x)
        main = self.bn1(main)
        main = F.relu(main)
        main = self.conv2(main)
        main = self.bn2(main)

        # Element-wise residual addition, then the final non-linearity.
        shortcut = self.extra(x)
        return F.relu(shortcut + main)
    
class ResNet(nn.Module):
    """
    Small ResNet-style image classifier.

    Layout: a 3x3 conv stem (3 -> 18 channels), four stride-2 ``ResBlock``
    stages (18 -> 36 -> 72 -> 144 -> 144 channels), global average pooling and
    a linear classifier followed by ``LogSoftmax``.

    The network therefore returns log-probabilities. ``nn.NLLLoss`` is the
    directly matching criterion; ``nn.CrossEntropyLoss`` yields the same loss
    because applying log-softmax to log-probabilities leaves them unchanged.
    """

    def __init__(self, num_classes=10):
        """
        :param num_classes: size of the output layer (defaults to 10, the
                            value that was previously hard-coded)
        """
        super(ResNet, self).__init__()

        # Stem: [b, 3, h, w] => [b, 18, h, w]
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 18, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(18)
        )
        # Four residual stages; each uses stride 2, so the spatial size is
        # halved (rounded up) at every stage.
        ## [b, 18, h, w] => [b, 36, h/2, w/2]
        self.blk1 = ResBlock(18, 36, stride=2)
        ## [b, 36, h/2, w/2] => [b, 72, h/4, w/4]
        self.blk2 = ResBlock(36, 72, stride=2)
        ## [b, 72, h/4, w/4] => [b, 144, h/8, w/8]
        self.blk3 = ResBlock(72, 144, stride=2)
        ## [b, 144, h/8, w/8] => [b, 144, h/16, w/16]
        self.blk4 = ResBlock(144, 144, stride=2)

        # Classifier head: [b, 144] => [b, num_classes] log-probabilities
        self.outlayer = nn.Sequential(
            nn.Linear(144, num_classes),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        """
        :param x: batch of 3-channel images, [b, 3, h, w]
        :return: per-class log-probabilities, [b, num_classes]
        """
        # [b, 3, h, w] => [b, 18, h, w]
        x = F.relu(self.conv1(x))

        ## [b, 18, h, w] => [b, 144, h/16, w/16]
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # Global average pooling: [b, 144, h', w'] => [b, 144, 1, 1]
        x = F.adaptive_avg_pool2d(x, [1, 1])

        # Flatten to [b, 144] and classify.
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x

## 1.1 Parameters

In [4]:
# Build a fresh (untrained, CPU) network for inspection only; training uses a
# separate `model_ft` instance.
model_ = ResNet()
# Displaying the bound `parameters` method prints its repr, which embeds the
# full module tree of the network.
# NOTE(review): the saved output below lists 64/128/256-channel layers, which
# does not match the 18/36/72-channel definition above -- it appears to come
# from an earlier version of the class. Re-run the cell to refresh it.
model_.parameters

<bound method Module.parameters of ResNet(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (blk1): ResBlock(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (extra): Sequential(
      (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (blk2): ResBlock(
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(

In [77]:
# Total number of scalar entries across every parameter tensor of the model.
_params = list(model_.parameters())
k = sum(tensor.numel() for tensor in _params)
print('total params:', k)

total params: 780562


# 2. Train

## 2.1 Preparation

In [5]:
# Select the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
train_on_gpu = torch.cuda.is_available()
if train_on_gpu:
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cpu')

In [39]:
# data preparation

# Dataset locations (one sub-folder per split).
data_dir = '../data/'
train_dir = data_dir + 'train/'
valid_dir = data_dir + 'valid/'

# Pre-processing settings shared by both splits.
IMAGE_SIZE = (32, 32)
# Per-channel normalisation constants; these equal the commonly used ImageNet
# statistics -- presumably chosen for that reason, confirm against the dataset.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Train and validation use the same pipeline (no augmentation):
# resize -> tensor -> normalise. Each split gets its own Compose object.
data_transforms = {
    split: transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ])
    for split in ('train', 'valid')
}

In [40]:
batch_size = 16

# Build, for each split, an ImageFolder dataset, a DataLoader over it and the
# sample count. Both loaders shuffle, including the validation one.
image_datasets = {}
dataloaders = {}
dataset_sizes = {}
for split in ['train', 'valid']:
    split_dataset = datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    image_datasets[split] = split_dataset
    dataloaders[split] = torch.utils.data.DataLoader(
        split_dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    dataset_sizes[split] = len(split_dataset)

# Class labels as discovered by ImageFolder for the training split.
class_names = image_datasets['train'].classes

In [41]:
# Network to be trained, moved to the selected device.
model_ft = ResNet()
model_ft = model_ft.to(device)

# ResNet's output layer already ends in LogSoftmax. CrossEntropyLoss applies
# log-softmax once more, which leaves log-probabilities unchanged, so
# nn.NLLLoss() would be the equivalent, more direct criterion.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

# Adam over all model parameters; the StepLR schedule multiplies the learning
# rate by 0.1 every 7 scheduler steps.
optimizer_ft = optim.Adam(params=model_ft.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer_ft, step_size=7, gamma=0.1)


## 2.2 The training function

In [46]:
def train_model(model, device, dataloaders, criterion, optimizer, num_epochs=25, filename='outupt1.pth',
                scheduler=None):
    """
    Train `model`, validating after every epoch and keeping the best weights.

    Each epoch runs a 'train' phase (gradients enabled, optimizer stepped per
    batch) and a 'valid' phase (eval mode, no gradients). Whenever validation
    accuracy improves, the weights are remembered and a checkpoint containing
    the model state, the best accuracy and the optimizer state is written to
    `filename`. The learning-rate scheduler, if any, is stepped once per epoch
    after validation.

    :param model: the network to train; moved to `device` in place
    :param device: torch.device used for the model and every batch
    :param dataloaders: dict with 'train' and 'valid' DataLoaders yielding
                        (inputs, labels) batches
    :param criterion: loss callable, called as criterion(outputs, labels)
    :param optimizer: optimizer over the model's parameters
    :param num_epochs: number of epochs to run
    :param filename: checkpoint path (the default spelling is kept unchanged
                     so existing checkpoints keep their location)
    :param scheduler: optional learning-rate scheduler. When omitted, a
                      module-level variable named `scheduler` is used if one
                      exists (this is what the function relied on before the
                      parameter was added); otherwise no scheduling is done.
    :return: (model with best weights loaded, val_acc_history,
              train_acc_history, valid_losses, train_losses, LRs) where the
              histories are lists of Python floats, one entry per epoch, and
              LRs holds the initial learning rate followed by the rate after
              each epoch
    """
    if scheduler is None:
        # Backward compatibility with callers that only define a global scheduler.
        scheduler = globals().get('scheduler')

    since = time.time()
    best_acc = 0.0
    model.to(device)

    # process records
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]

    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs-1))
        print('-' *10)

        for phase in ['train', 'valid']:
            if phase == 'train':
                print('###training###')
                model.train()
            else:
                print('###validating###')
                model.eval()

            running_loss = 0.0
            running_correct = 0

            for inputs, labels in dataloaders[phase]:

                inputs, labels = inputs.to(device), labels.to(device)

                with torch.set_grad_enabled(phase == 'train'):
                    # outputs: [b, num_classes]
                    # labels: [b]
                    # loss: scalar tensor
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    if phase == 'train':
                        # backprop
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # Accumulate as plain Python numbers: the criterion returns a
                # batch mean, so weight it by the batch size.
                running_loss += loss.item() * inputs.size(0)
                running_correct += (preds == labels).sum().item()

            num_samples = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / num_samples
            # Python float (not a tensor), so the histories can be plotted or
            # serialised directly regardless of the training device.
            epoch_acc = running_correct / num_samples

            time_elapsed = time.time()- since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict()
                }
                torch.save(state, filename)

            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                if scheduler is not None:
                    # Only ReduceLROnPlateau takes the monitored metric; for
                    # other schedulers a positional argument is interpreted as
                    # an epoch index, which would stop the LR from decaying.
                    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                        scheduler.step(epoch_loss)
                    else:
                        scheduler.step()
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate: {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time()- since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # Restore the weights that achieved the best validation accuracy.
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs

## 2.3 Run training

In [49]:
# Short 2-epoch run; keeps the returned model together with the per-epoch
# accuracy/loss histories and the learning-rate trace for the plots below.
model1, valid_acc, train_acc, valid_loss, train_loss, LRs1 = train_model(
    model_ft,
    device,
    dataloaders,
    criterion,
    optimizer_ft,
    num_epochs=2,
)

Epoch 0/1
----------
###training###
Time elapsed 2m 6s
train Loss: 0.0603 Acc: 0.9796
###validating###
Time elapsed 2m 10s
valid Loss: 1.0598 Acc: 0.7870
Optimizer learning rate: 0.0010000

Epoch 1/1
----------
###training###




Time elapsed 4m 24s
train Loss: 0.0583 Acc: 0.9791
###validating###
Time elapsed 4m 28s
valid Loss: 1.0319 Acc: 0.7884
Optimizer learning rate: 0.0010000

Training complete in 4m 28s
Best val Acc: 0.788365


# 3. Results display

In [None]:
# Loss per epoch for the training and validation phases.
epoch_num = len(train_loss)
epoch_axis = np.arange(epoch_num)
tick_positions = np.arange(0, epoch_num, 2)  # label every second epoch

for curve, curve_label in ((train_loss, 'Train'), (valid_loss, 'Valid')):
    plt.plot(epoch_axis, curve, marker='o', label=curve_label)
plt.xticks(tick_positions, [int(pos) for pos in tick_positions])
plt.xlabel('Epoch', fontsize=14)
plt.ylabel('Loss', fontsize=14)
plt.legend()

In [None]:
# Accuracy per epoch for the training and validation phases.
epoch_num = len(train_loss)
epoch_axis = np.arange(epoch_num)
tick_positions = np.arange(0, epoch_num, 2)  # label every second epoch

for curve, curve_label in ((train_acc, 'Train'), (valid_acc, 'Valid')):
    plt.plot(epoch_axis, curve, marker='o', label=curve_label)
plt.xticks(tick_positions, [int(pos) for pos in tick_positions])
plt.xlabel('Epoch', fontsize=14)
plt.ylabel('Accuracy', fontsize=14)
plt.legend()