In [1]:
import sys
sys.path.append('../')

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import cv2

import torch

from dataset.dataset_unet import prepare_trainset
from model.model_unet import UNetResNet34

%matplotlib inline

## image size vs batch size

In [3]:
device = 'cuda:0'
def test_hyperparameters(BATCH_SIZE=16, IMG_SIZE=256, device=device):
    train_dl, val_dl = prepare_trainset(BATCH_SIZE=BATCH_SIZE, 
                                        NUM_WORKERS=8, 
                                        SEED=2019, 
                                        IMG_SIZE=IMG_SIZE, debug=True)

    for i, (images, masks) in enumerate(train_dl):  # 一次读取一个Bath
        images = images.to(device=device, dtype=torch.float)
        masks = masks.to(device=device, dtype=torch.float)
        if i==0:
            break

    print(images.size(), masks.size())
    
    ##
    net = UNetResNet34(debug=False)
    net = net.to(device=device)
    
    logit = net(images)
    print('Pass')

In [4]:
test_hyperparameters(BATCH_SIZE=8, IMG_SIZE=256, device='cuda:0')#'cpu'#'cuda:0'

Count of trainset (for training):  900
Count of validset (for training):  200
torch.Size([8, 1, 256, 256]) torch.Size([8, 1, 256, 256])
Pass


## 作业：自己尝试各种超参数组合
- epoch, 10? 200?
- early stop round, 1? 20?
- learning rate, 10? 0.00001?
- batch size
- image size, 128, 256, 512, 768, 1024, ...
- 使用BCE、Focal Loss、Dice Loss

## 学习率方案

In [1]:
import torch
import torchvision

In [6]:
net = torchvision.models.resnet34(pretrained=False)
train_params = filter(lambda p: p.requires_grad, net.parameters())

optimizer = torch.optim.SGD(train_params, momentum=0.9, weight_decay=0.0001, lr=0.02)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.3)
#scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 30], gamma=0.1)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

In [7]:
for epoch in range(30):
    #train
    print('lr: %.4f, epoch: %d'%(scheduler.get_lr()[0], epoch))
    scheduler.step()

lr: 0.0200, epoch: 0
lr: 0.0196, epoch: 1
lr: 0.0194, epoch: 2
lr: 0.0192, epoch: 3
lr: 0.0190, epoch: 4
lr: 0.0188, epoch: 5
lr: 0.0186, epoch: 6
lr: 0.0185, epoch: 7
lr: 0.0183, epoch: 8
lr: 0.0181, epoch: 9
lr: 0.0179, epoch: 10
lr: 0.0177, epoch: 11
lr: 0.0176, epoch: 12
lr: 0.0174, epoch: 13
lr: 0.0172, epoch: 14
lr: 0.0170, epoch: 15
lr: 0.0169, epoch: 16
lr: 0.0167, epoch: 17
lr: 0.0165, epoch: 18
lr: 0.0164, epoch: 19
lr: 0.0162, epoch: 20
lr: 0.0160, epoch: 21
lr: 0.0159, epoch: 22
lr: 0.0157, epoch: 23
lr: 0.0156, epoch: 24
lr: 0.0154, epoch: 25
lr: 0.0152, epoch: 26
lr: 0.0151, epoch: 27
lr: 0.0149, epoch: 28
lr: 0.0148, epoch: 29


In [66]:
#ReduceLROnPlateau
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', 
                                                       factor=0.5, patience=4,
                                                       verbose=True, threshold=0.0001, 
                                                       threshold_mode='rel', cooldown=0, 
                                                       min_lr=0, eps=1e-08)

metric = 0.5
for epoch in range(30):
    #train
    #print('lr: %.4f, epoch: %d'%(scheduler.get_lr()[0], epoch))
    print('metric: ', metric)
    if np.random.rand()<0.2:
        metric += 0.02
    scheduler.step(metric)

metric:  0.5
metric:  0.52
metric:  0.52
metric:  0.52
metric:  0.54
Epoch    34: reducing learning rate of group 0 to 6.2500e-04.
metric:  0.56
metric:  0.5800000000000001
metric:  0.5800000000000001
metric:  0.6000000000000001
metric:  0.6000000000000001
metric:  0.6200000000000001
metric:  0.6200000000000001
metric:  0.6200000000000001
metric:  0.6200000000000001
metric:  0.6200000000000001
Epoch    44: reducing learning rate of group 0 to 3.1250e-04.
metric:  0.6200000000000001
metric:  0.6200000000000001
metric:  0.6200000000000001
metric:  0.6200000000000001
metric:  0.6200000000000001
metric:  0.6400000000000001
metric:  0.6400000000000001
metric:  0.6600000000000001
metric:  0.6600000000000001
metric:  0.6800000000000002
metric:  0.6800000000000002
metric:  0.6800000000000002
metric:  0.7000000000000002
metric:  0.7000000000000002
metric:  0.7000000000000002
