In [1]:
import os
import datetime
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader

from nets.deeplabv3_plus import DeepLab

from utils.utils_training import get_lr_scheduler, set_optimizer_lr
from utils.utils_logs import LossHistory
from utils.dataloader import DeeplabDataset
from utils.utils_fit import fit_one_epoch

## load model and logs

In [2]:
# Paths: pretrained checkpoint, dataset root, and output directory for logs
model_path = "model/deeplabv3+_model.pth"
dataset_path = 'weizmann_horse_db'
save_dir = 'logs'

# Model / device settings
Cuda = True          # request GPU training
num_classes = 2      # presumably background + horse — TODO confirm against the dataset
input_shape = [512] * 2  # network input size, presumably [height, width]
def weights_init(net, init_gain=0.02):
    """Re-initialise convolution and BatchNorm2d layers of ``net`` in place.

    Rules:
    - Any module whose class name contains ``'Conv'`` and which has a weight gets
      that weight drawn from N(0, init_gain). Its bias is left untouched.
    - Any module whose class name contains ``'BatchNorm2d'`` gets its weight drawn
      from N(1, init_gain) and its bias set to 0.
    - BatchNorm layers built with ``affine=False`` have no weight/bias tensors and
      are skipped instead of raising.
    - All other modules are left unchanged.

    Args:
        net: a ``torch.nn.Module``. The rule above is applied to it and every
            submodule via ``net.apply``.
        init_gain: standard deviation of both normal distributions. It defaults
            to 0.02, the value that was previously hard-coded.
    """
    def init_func(m):
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        if classname.find('Conv') != -1 and weight is not None:
            torch.nn.init.normal_(weight.data, 0.0, init_gain)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm2d(affine=False) stores weight/bias as None; guard each one.
            if weight is not None:
                torch.nn.init.normal_(weight.data, 1.0, init_gain)
            if getattr(m, 'bias', None) is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)
    net.apply(init_func)

# Use CUDA only when it was requested AND a GPU is actually present; otherwise the
# `.cuda()` call below would raise on a CPU-only machine. The flag is reassigned so
# later cells that receive `Cuda` (e.g. fit_one_epoch) see the same decision.
Cuda = Cuda and torch.cuda.is_available()
device = torch.device('cuda' if Cuda else 'cpu')
model = DeepLab(num_classes=num_classes)
weights_init(model)

# Load pretrained weights by matching checkpoint keys against the model's keys.
# Tensors whose key is missing from the model, or whose shape differs, are skipped;
# those layers keep the values set by weights_init.
model_dict = model.state_dict()
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary code.
# Only load trusted files (on torch >= 1.13, weights_only=True restricts this).
pretrained_dict = torch.load(model_path, map_location=device)

load_key, no_load_key, temp_dict = [], [], {}
for k, v in pretrained_dict.items():
    if k in model_dict and np.shape(model_dict[k]) == np.shape(v):
        temp_dict[k] = v
        load_key.append(k)
    else:
        no_load_key.append(k)
model_dict.update(temp_dict)
model.load_state_dict(model_dict)

# Per-run log directory named after the current timestamp
time_str = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
log_dir = os.path.join(save_dir, "loss_" + time_str)
loss_history = LossHistory(log_dir, model, input_shape=input_shape)
model_train = model.train()  # nn.Module.train() sets training mode and returns the module itself

if Cuda:
    model_train = torch.nn.DataParallel(model)
    cudnn.benchmark = True  # let cuDNN benchmark and cache the fastest convolution algorithms
    model_train = model_train.cuda()
print('The numbers of keys which fail loading:', len(no_load_key))
if no_load_key:
    # Show which keys were skipped so a partial load is visible, not silent
    print('Keys that failed to load:', no_load_key)

The numbers of keys which fail loading: 0


## dataset

In [3]:
# Training schedule: number of epochs and samples per mini-batch
Epoch, batch_size = 10, 4
def collate_function(batch):
    """Stack a list of ``(img, png, seg_labels)`` samples into batched tensors.

    Each sample must unpack into exactly three numpy-compatible items. Each
    position is collected across the batch, stacked with ``np.array``, and
    converted to a tensor:

    - images -> float32
    - pngs (presumably integer label masks) -> int64
    - seg_labels -> float32

    Returns:
        A tuple ``(images, pngs, seg_labels)`` of tensors with a leading batch dimension.
    """
    columns = ([], [], [])
    for sample in batch:
        img, png, labels = sample  # enforce the 3-item layout
        for column, item in zip(columns, (img, png, labels)):
            column.append(item)
    img_col, png_col, label_col = columns
    return (
        torch.from_numpy(np.array(img_col)).type(torch.FloatTensor),
        torch.from_numpy(np.array(png_col)).long(),
        torch.from_numpy(np.array(label_col)).type(torch.FloatTensor),
    )

# Read the train / val split files; each line is handed to DeeplabDataset as one
# sample entry (presumably an image id — DeeplabDataset does the parsing).
with open(os.path.join(dataset_path, "datasets/train.txt"), "r") as f:
    train_lines = f.readlines()
with open(os.path.join(dataset_path, "datasets/val.txt"), "r") as f:
    val_lines = f.readlines()
num_train = len(train_lines)
num_val = len(val_lines)

train_dataset = DeeplabDataset(train_lines, input_shape, num_classes, True, dataset_path)
val_dataset = DeeplabDataset(val_lines, input_shape, num_classes, False, dataset_path)

train_set = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True,
                       drop_last=True, collate_fn=collate_function)
# Validation is NOT shuffled: with drop_last=True, shuffling would drop a different
# random subset every epoch, so val_loss would not be comparable between epochs.
val_set = DataLoader(val_dataset, shuffle=False, batch_size=batch_size, pin_memory=True,
                     drop_last=True, collate_fn=collate_function)

# Number of full batches per epoch (drop_last=True discards the remainder)
epoch_step = num_train // batch_size
epoch_step_val = num_val // batch_size
if epoch_step == 0 or epoch_step_val == 0:
    raise ValueError(
        f"Dataset too small for batch_size={batch_size}: "
        f"{num_train} train / {num_val} val samples give an empty epoch"
    )

## optimizer and learning rate

In [4]:
# SGD hyper-parameters
momentum, weight_decay = 0.9, 1e-4

# Learning-rate range: the minimum lr is 1% of the initial lr
Init_lr = 7e-3
Min_lr = Init_lr * 0.01

# Scale the learning rates linearly by batch_size relative to the reference batch
# size `nbs`, then clamp them into [lr_limit_min, lr_limit_max]. The minimum lr
# uses bounds scaled by 1e-2.
nbs = 16
lr_limit_max = 1e-1
lr_limit_min = 5e-4


def _clamp(value, lower, upper):
    """Return ``value`` limited to the closed interval [lower, upper]."""
    return min(max(value, lower), upper)


lr_scale = batch_size / nbs
Init_lr_fit = _clamp(lr_scale * Init_lr, lr_limit_min, lr_limit_max)
Min_lr_fit = _clamp(lr_scale * Min_lr, lr_limit_min * 1e-2, lr_limit_max * 1e-2)

# Nesterov-momentum SGD over all model parameters, starting at the scaled lr
optimizer = optim.SGD(model.parameters(), lr=Init_lr_fit, momentum=momentum,
                      nesterov=True, weight_decay=weight_decay)
# Per-epoch learning-rate decay function from Init_lr_fit down toward Min_lr_fit
lr_scheduler_func = get_lr_scheduler(Init_lr_fit, Min_lr_fit, Epoch)

## train

In [5]:
# Train for `Epoch` epochs. The writer owned by `loss_history` is closed in a
# `finally` so it is released even if training raises or is interrupted
# (e.g. KeyboardInterrupt from the notebook's stop button).
try:
    for epoch in range(Epoch):
        # Presumably sets the optimizer's lr to lr_scheduler_func(epoch) — see utils_training
        set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
        fit_one_epoch(model_train, model, loss_history, optimizer, epoch, Epoch, epoch_step, epoch_step_val,
                      train_set, val_set, Cuda, num_classes, save_dir)
finally:
    loss_history.writer.close()

Epoch(train) 1/10: 100%|█████████████████████████████████████████████| 69/69 [00:30<00:00,  2.27it/s, train_loss=0.118]
Epoch(valid) 1/10: 100%|██████████████████████████████████████████████| 12/12 [00:02<00:00,  5.09it/s, val_loss=0.0991]
Epoch(train) 2/10: 100%|█████████████████████████████████████████████| 69/69 [00:25<00:00,  2.72it/s, train_loss=0.128]
Epoch(valid) 2/10: 100%|██████████████████████████████████████████████| 12/12 [00:02<00:00,  4.95it/s, val_loss=0.0984]
Epoch(train) 3/10: 100%|█████████████████████████████████████████████| 69/69 [00:25<00:00,  2.70it/s, train_loss=0.118]
Epoch(valid) 3/10: 100%|██████████████████████████████████████████████| 12/12 [00:02<00:00,  5.00it/s, val_loss=0.0912]
Epoch(train) 4/10: 100%|█████████████████████████████████████████████| 69/69 [00:25<00:00,  2.70it/s, train_loss=0.126]
Epoch(valid) 4/10: 100%|██████████████████████████████████████████████| 12/12 [00:02<00:00,  4.88it/s, val_loss=0.0936]
Epoch(train) 5/10: 100%|████████████████