In [None]:
import numpy as np
import pandas as pd
import os
import gc

import random
import cv2

import collections
import time 
from datetime import datetime
from timeit import default_timer as timer
import math
import logging
from tqdm import tqdm
from PIL import Image
from functools import partial
train_on_gpu = True
import shutil

import torch
from torch.utils.data import TensorDataset, DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim.optimizer import Optimizer, required
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR
from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler, RandomSampler
from torch.nn.parallel.data_parallel import data_parallel
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

#import segmentation_models_pytorch as smp

import albumentations as albu
from albumentations import pytorch as AT

from sklearn.model_selection import train_test_split, StratifiedKFold 

# Drawing
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.cm
from mpl_toolkits.mplot3d import Axes3D

# 1 初始设置

## 1.1 文件路径

一般来说，一个比赛会有如下几个文件/文件夹：  
1、train.csv（或类似文件）：该文件一般包括了图片名称和RLE信息（run-length encoding，游程编码），如果是多标签的话一般也会包含标签的信息；  
2、sample_submission.csv（或类似文件）：一个提交格式的样板；  
3、train_images（文件夹）：其中包含了所有训练集的图片；  
4、test_images（文件夹）：其中包含了所有测试集的图片。  
  
注意，train与test图片的大小未必是一样的，提交前要注意test的mask尺寸大小如何，这会影响最终的rle编码。

In [None]:
PROJ_FOLDER = '../' # project root folder (relative to the notebook)
TRAIN_IMAGES =  PROJ_FOLDER + 'train_images/' # folder with the training images
TEST_IMAGES = PROJ_FOLDER + 'test_images/' # folder with the test images
train = pd.read_csv(PROJ_FOLDER + 'train.csv') # train.csv
sub = pd.read_csv(PROJ_FOLDER + 'sample_submission.csv') # sample_submission.csv

In [None]:
CACHE = PROJ_FOLDER + 'img_cache/' # cache root
CACHE_IMG = CACHE + 'image/' # down-scaled images
CACHE_MASK = CACHE + 'mask/' # down-scaled masks
CACHE_SPLIT = CACHE + 'split/' # train/valid split files

# exist_ok=True makes the cell idempotent without a separate (racy) existence check.
for folder in [CACHE, CACHE_IMG, CACHE_MASK, CACHE_SPLIT]:
    os.makedirs(folder, exist_ok=True)

## 1.2 csv文件处理

In [None]:
CLASSNAME_TO_CLASSNO={
    'Fish'   : 0,
    'Flower' : 1,
    'Gravel' : 2,
    'Sugar'  : 3,
} # class name -> class index

CLASSNO_TO_CLASSNAME = {v: k for k, v in CLASSNAME_TO_CLASSNO.items()} # class index -> class name
NUM_CLASS = len(CLASSNAME_TO_CLASSNO) # total number of classes

CLASS_COLOR = [(255,255,0), (0,0,255), (0,255,0), (0,255,255)] # one color per class (used for drawing)

#### train

In [None]:
# Flatten train.csv into one row per (image, class) with explicit columns.
df = train.fillna('')  # missing values (e.g. an absent RLE) become ''
# Split on the LAST underscore only: an underscore inside the file name would
# otherwise yield more than two columns and break the assignment.
df[['image_id', 'class_name']] = df['Image_Label'].str.rsplit('_', n=1, expand=True)
df['class_no'] = df['class_name'].map(CLASSNAME_TO_CLASSNO)
df['encoded_pixel'] = df['EncodedPixels']
df['label'] = (df['EncodedPixels'] != '').astype(np.int32)  # 1 if the class has a mask
# .copy() detaches the result so later column assignments do not hit a slice.
df = df[['image_id', 'class_no', 'class_name', 'label', 'encoded_pixel']].copy()
df.head()

#### label透视表  (train)

In [None]:
# One row per image, one column per class, values taken from `label`.
pvt_df = df.pivot_table(index='image_id', columns='class_name', values='label').reset_index()
# Concatenate the four per-class values into a single string key per image.
pvt_df['mix_label'] = pvt_df[['Fish', 'Flower', 'Gravel', 'Sugar']].astype(str).agg(''.join, axis=1)
pvt_df.head()

#### test

In [None]:
# Flatten sample_submission.csv into one row per (image, class).
test_df = sub.fillna('')  # missing values become ''
# Split on the LAST underscore only, so underscores inside a file name cannot
# produce extra columns and break the two-column assignment.
test_df[['image_id', 'class_name']] = test_df['Image_Label'].str.rsplit('_', n=1, expand=True)
test_df['class_no'] = test_df['class_name'].map(CLASSNAME_TO_CLASSNO)
test_df['encoded_pixel'] = test_df['EncodedPixels']
test_df['label'] = (test_df['EncodedPixels'] != '').astype(np.int32)
# .copy() detaches the result; later cells assign columns on test_df and would
# otherwise write into a slice (SettingWithCopyWarning).
test_df = test_df[['image_id', 'class_no', 'class_name', 'label', 'encoded_pixel']].copy()
test_df.head()

#### label透视表  (test)

In [None]:
# Test labels are unknown: blank out the RLE and set every label to 0.
test_df['encoded_pixel'] = ''
test_df['label'] = 0
# One row per test image, one column per class.
pvt_test_df = test_df.pivot_table(index='image_id', columns='class_name', values='label').reset_index()
# Concatenate the four per-class values into a single string key per image.
pvt_test_df['mix_label'] = pvt_test_df[['Fish', 'Flower', 'Gravel', 'Sugar']].astype(str).agg(''.join, axis=1)
pvt_test_df.head()

## 1.3 常数设置

In [None]:
PI  = np.pi # pi
INF = np.inf # positive infinity
EPS = 1e-12 # small epsilon to guard against division by zero

In [None]:
IMAGE_RGB_MEAN = [0.485, 0.456, 0.406] # ImageNet standard per-channel mean
IMAGE_RGB_STD  = [0.229, 0.224, 0.225] # ImageNet standard per-channel std

In [None]:
NUM_TRAIN = df['image_id'].nunique(dropna=False) # number of distinct training images
NUM_TEST = test_df['image_id'].nunique(dropna=False) # number of distinct test images

In [None]:
IMAGE_TO_MASK_SCALE = 0.25 # image -> mask scale factor, here 1400x2100 -> 350x525
MASK_WIDTH = 525
MASK_HEIGHT = 350

In [None]:
# Per-class (positive count, positive ratio) in the training set.
NUM_TRAIN_POS = {
    name: (pvt_df[name].sum(), pvt_df[name].sum()/NUM_TRAIN)
    for name in ('Fish', 'Flower', 'Gravel', 'Sugar')
}

# Per-class (positive count, positive ratio) in the test set, estimated by probing.
NUM_TEST_POS = {
    name: (count, count/NUM_TEST)
    for name, count in (('Fish', 1864), ('Flower', 1508), ('Gravel', 1982), ('Sugar', 2382))
}

## 1.4 随机种子设置

In [None]:
# Seed every random number generator used in this notebook.
SEED = 666
random.seed(SEED) # Python's built-in RNG
np.random.seed(SEED) # NumPy global RNG
torch.manual_seed(SEED) # PyTorch RNG
torch.cuda.manual_seed_all(SEED) # PyTorch RNGs on all CUDA devices

## 1.5 显卡相关设置

In [None]:
# NOTE(review): this presumably has to run before CUDA is first initialised to take effect - confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' # multi-GPU: expose GPUs 0 and 1
#os.environ['CUDA_VISIBLE_DEVICES'] = '0' # single GPU: expose GPU 0 only

# 2 辅助函数

## 2.1 rle相关

In [None]:
def run_length_decode(rle, height=350, width=525, fill_value=1):
    '''
    Decode a run-length-encoded string into a 2-D mask.

    The string holds "start length" pairs; starts are 1-indexed positions in the
    column-major (top-to-bottom, then left-to-right) flattening of the mask.

    rle: RLE string. '' / whitespace-only / non-string values (None, NaN) decode
         to an all-zero mask.
    height: mask height
    width: mask width
    fill_value: value written into the encoded pixels

    Returns a float32 array of shape (height, width).
    Raises ValueError if the string has a non-integer token or an odd token count.
    '''
    flat = np.zeros(height * width, np.float32)
    # str.split() with no argument tolerates leading/trailing/repeated whitespace,
    # where split(' ') would produce '' tokens and make int() fail.
    tokens = rle.split() if isinstance(rle, str) else []
    if tokens:
        runs = np.array([int(t) for t in tokens]).reshape(-1, 2)
        for start, length in runs:
            start = start - 1  # RLE starts are 1-indexed
            flat[start:(start + length)] = fill_value
    # Column-major layout: fill as (width, height), then transpose.
    return flat.reshape(width, height).T

In [None]:
def run_length_encode(mask):
    '''
    Encode a 2-D mask as a run-length string of "start length" pairs.

    Pixels are visited in column-major order (the mask is transposed before
    flattening) and starts are 1-indexed. A mask whose values sum to 0 encodes
    to the empty string.
    '''
    flat = mask.T.flatten()
    if flat.sum() == 0:
        return ''

    # Pad with zeros so runs touching either end still produce a start/end edge.
    padded = np.concatenate([[0], flat, [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Edges alternate start, end; turn each end into a run length.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

## 2.2 图片与mask压缩

In [None]:
def run_dump_image_to_png(input_path, output_path, out_w, out_h):
    '''
    Resize every image in a folder and save it as a PNG.

    input_path: folder containing the original images
    output_path: folder for the resized PNGs (created if missing)
    out_w: width of the resized PNG
    out_h: height of the resized PNG

    Files that OpenCV cannot decode are skipped with a warning instead of
    aborting the whole run.
    '''
    os.makedirs(output_path, exist_ok=True)

    # tqdm already reports progress, so no per-file print is needed.
    for image_id in tqdm(sorted(os.listdir(input_path))):
        image = cv2.imread(os.path.join(input_path, image_id), cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            logging.warning('Skipping unreadable file: %s', image_id)
            continue
        image = cv2.resize(image, dsize=(out_w, out_h), interpolation=cv2.INTER_LINEAR)
        # Drop the last 4 characters (the '.jpg' suffix) to build the PNG name.
        image_file = os.path.join(output_path, '%s.png'%image_id[:-4])
        cv2.imwrite(image_file, image)

In [None]:
def run_dump_mask_to_png(df, output_path, in_w, in_h, out_w, out_h):
    '''
    Decode the masks of every image, resize them and save them as PNGs.

    df: table with image name, class number and RLE per row (e.g. the train table)
    output_path: folder for the mask PNGs (created if missing)
    in_w: width of the original image
    in_h: height of the original image
    out_w: width of the saved mask PNG
    out_h: height of the saved mask PNG

    Each PNG holds one channel per class, ordered by class_no, with values in 0..255.
    '''
    # The documented contract is that the folder is created on demand.
    os.makedirs(output_path, exist_ok=True)

    gb = df.groupby('image_id')
    images = list(gb.groups.keys())

    # tqdm already reports progress, so no per-file print is needed.
    for image_id in tqdm(images):
        img_info = gb.get_group(image_id).sort_values('class_no') # one row per class for this image
        assert(len(img_info) == NUM_CLASS)

        rle = img_info['encoded_pixel'].values.tolist()
        mask = np.array([run_length_decode(r, height=in_h, width=in_w, fill_value=1) for r in rle]) # (NUM_CLASS, h, w)

        mask = mask.transpose(1, 2, 0) # -> (h, w, class)
        mask = cv2.resize(mask, dsize=(out_w, out_h), interpolation=cv2.INTER_LINEAR) # -> (out_h, out_w, class)
        mask = (mask * 255).astype(np.uint8)

        # Drop the last 4 characters (the '.jpg' suffix) to build the PNG name.
        mask_file = os.path.join(output_path, '%s.png'%image_id[:-4])
        cv2.imwrite(mask_file, mask)

## 2.3 tensor转图片/mask

In [None]:
def tensor_to_image(tensor):
    '''Convert a (batch, channel, h, w) tensor to a (batch, h, w, channel) numpy array with the channel order reversed.'''
    batch_chw = tensor.data.cpu().numpy()
    batch_hwc = np.transpose(batch_chw, (0, 2, 3, 1)) # (batch, h, w, channel)
    # Unlike tensor_to_mask, the last axis is flipped (presumably RGB <-> BGR for OpenCV - confirm).
    return batch_hwc[..., ::-1]

In [None]:
def tensor_to_mask(tensor):
    '''Convert a (batch, channel, h, w) tensor to a (batch, h, w, channel) numpy array.'''
    batch_chw = tensor.data.cpu().numpy()
    return np.transpose(batch_chw, (0, 2, 3, 1)) # (batch, h, w, channel)

## 2.4 各种增扩函数

In [None]:
def do_flip_lr(image, mask):
    '''Flip an image and its mask left-right (cv2.flip with flip code 1).'''
    flipped_image = cv2.flip(image, 1)
    flipped_mask = cv2.flip(mask, 1)
    return flipped_image, flipped_mask

In [None]:
def do_flip_ud(image, mask):
    '''Flip an image and its mask upside-down (cv2.flip with flip code 0).'''
    flipped_image = cv2.flip(image, 0)
    flipped_mask = cv2.flip(mask, 0)
    return flipped_image, flipped_mask

In [None]:
def do_random_crop(image, mask, w, h, mw, mh):
    '''
    Randomly crop an image together with its mask.

    The mask is first resized to the image size. Along each axis where the image
    is larger than the requested w x h window, a random offset is drawn; otherwise
    the offset stays 0. The same window is cut from image and mask, and the
    cropped mask is finally resized to mw x mh.
    '''
    img_h, img_w = image.shape[:2]
    mask = cv2.resize(mask, dsize=(img_w, img_h), interpolation=cv2.INTER_LINEAR) # mask -> image size

    # Draw the horizontal offset before the vertical one (keeps the RNG sequence).
    left = np.random.choice(img_w - w) if img_w > w else 0
    top = np.random.choice(img_h - h) if img_h > h else 0

    rows, cols = slice(top, top + h), slice(left, left + w)
    image = image[rows, cols]
    mask = mask[rows, cols]

    mask = cv2.resize(mask, dsize=(mw, mh), interpolation=cv2.INTER_LINEAR)
    return image, mask

In [None]:
def do_random_crop_rescale(image, mask, w, h):
    '''
    Randomly crop a w x h window from an image and its mask, then rescale the
    image back to its original size and the mask back to its original size.

    The mask is temporarily resized to the image size so both can be cropped with
    the same window. It is ALWAYS resized back to its original size afterwards,
    including when no crop happens (e.g. w x h equals the image size); previously
    that case returned the mask at image resolution.
    '''
    height, width = image.shape[:2]
    height_mask, width_mask = mask.shape[:2]
    mask = cv2.resize(mask, dsize=(width, height), interpolation=cv2.INTER_LINEAR) # mask -> image size

    # Random offset only along axes where the image is larger than the window.
    x, y = 0, 0
    if width > w:
        x = np.random.choice(width - w)
    if height > h:
        y = np.random.choice(height - h)
    image = image[y:y+h, x:x+w]
    mask = mask[y:y+h, x:x+w]

    # The image only needs resizing if the crop actually changed its size.
    if image.shape[:2] != (height, width):
        image = cv2.resize(image, dsize=(width, height), interpolation=cv2.INTER_LINEAR)
    # The mask was upscaled above, so it must go back to its own size unconditionally.
    mask = cv2.resize(mask, dsize=(width_mask, height_mask), interpolation=cv2.INTER_LINEAR)

    return image, mask

In [None]:
def do_random_crop_rotate_rescale(image, mask, mode=['rotate','scale','shift']):
    '''
    Apply a random rotate / scale / shift warp to an image and its mask.

    mode: any subset of 'rotate', 'scale', 'shift'; transforms not listed keep
          their neutral value (0 degrees, no scaling, no shift).

    The mask is resized to the image size, warped with the same transform as the
    image, then resized back to its original size. The returned image keeps its
    original size. Areas outside the source are filled with zeros.
    NOTE(review): the default `mode` is a mutable list; it is only read here.
    '''
    height, width = image.shape[:2]
    height_mask, width_mask = mask.shape[:2]
    mask = cv2.resize(mask, dsize=(width,height), interpolation=cv2.INTER_LINEAR) # mask -> image size

    dangle = 0 # rotation angle in degrees
    dscale_x, dscale_y = 0, 0 # relative scale change
    dshift_x, dshift_y = 0, 0 # relative shift

    if 'rotate' in mode:
        dangle = np.random.uniform(-30, 30) # random rotation within +/-30 degrees
    if 'scale' in mode:
        dscale_x, dscale_y = np.random.uniform(-1, 1, 2)*0.15 # random scale within +/-15% (independent for x and y)
    if 'shift' in mode:
        dshift_x, dshift_y = np.random.uniform(-1, 1, 2)*0.10 # random shift within +/-10% of width/height (independent for x and y)

    cos = np.cos(dangle / 180 * PI)
    sin = np.sin(dangle / 180 * PI)
    sx, sy = 1 + dscale_x, 1 + dscale_y 
    tx, ty = dshift_x * width, dshift_y * height

    # Corners of the image rectangle centred on the origin, then scaled, rotated
    # and translated back around the image centre (plus the shift).
    src = np.array([[-width/2,-height/2], [ width/2,-height/2], [ width/2, height/2], [-width/2, height/2]], np.float32)
    src = src * [sx,sy]
    x = (src * [cos,-sin]).sum(1) + width/2 + tx
    y = (src * [sin, cos]).sum(1) + height/2 + ty
    src = np.column_stack([x,y])

    # The transformed quadrilateral is mapped onto the full output rectangle.
    dst = np.array([[0,0], [width,0], [width, height], [0,height]])
    s = src.astype(np.float32)
    d = dst.astype(np.float32)
    transform = cv2.getPerspectiveTransform(s, d)

    image = cv2.warpPerspective(image, transform, (width, height), flags=cv2.INTER_LINEAR, 
                                borderMode=cv2.BORDER_CONSTANT, borderValue=(0,0,0))
    mask = cv2.warpPerspective(mask, transform, (width, height), flags=cv2.INTER_LINEAR, 
                               borderMode=cv2.BORDER_CONSTANT, borderValue=(0,0,0,0))
    # Bring the mask back to its original size.
    mask = cv2.resize(mask, dsize=(width_mask, height_mask), 
                      interpolation=cv2.INTER_LINEAR)

    return image, mask

## 2.5 激活函数

#### Mish

In [None]:
def mish(x):
    '''Mish activation: x * tanh(softplus(x)).'''
    gate = torch.tanh(F.softplus(x))
    return x * gate

In [None]:
class Mish(nn.Module):
    '''nn.Module wrapper around the functional `mish` activation.'''

    def __init__(self):
        super(Mish, self).__init__()

    def forward(self, inputs):
        '''Apply mish element-wise to `inputs`.'''
        activated = mish(inputs)
        return activated

#### Swish

In [None]:
def swish(x):
    '''Swish activation: x * sigmoid(x).'''
    # torch.sigmoid replaces F.sigmoid, which is deprecated and emits a warning
    # on every call in several PyTorch releases.
    return x * torch.sigmoid(x)

In [None]:
class Swish(nn.Module):
    '''nn.Module wrapper around the functional `swish` activation.'''

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, inputs):
        '''Apply swish element-wise to `inputs`.'''
        activated = swish(inputs)
        return activated

# 3 压缩图片与mask

In [None]:
# Down-scale the train and test images to 1050x700 and cache them as PNGs
run_dump_image_to_png(TRAIN_IMAGES, CACHE_IMG + 'train_1050x700/', 1050, 700)
run_dump_image_to_png(TEST_IMAGES, CACHE_IMG + 'test_1050x700/', 1050, 700)

# Decode the 2100x1400 training masks, down-scale them to 525x350 and cache them as PNGs
run_dump_mask_to_png(df, CACHE_MASK, 2100, 1400, 525, 350)

# 4 分割验证集

此处按照label透视表中的`mix_label`进行分层抽样，划分出3折；该划分方法是否合适仍有待商榷。

In [None]:
# 3-fold split stratified on mix_label; each fold's image ids are saved as .npy files.
skf = StratifiedKFold(n_splits=3, random_state=SEED, shuffle=True)
for i, (trn_index, val_index) in enumerate(skf.split(pvt_df, pvt_df['mix_label'])):
    train_ids = pvt_df['image_id'].iloc[trn_index].values # training image ids of fold i
    valid_ids = pvt_df['image_id'].iloc[val_index].values # validation image ids of fold i
    np.save(CACHE_SPLIT + 'train_fold_a%d.npy'%i, train_ids)
    np.save(CACHE_SPLIT + 'valid_fold_a%d.npy'%i, valid_ids) # persist the split

同样保存一个test的留作后续使用。

In [None]:
# Save the distinct test image ids in the same .npy format as the train/valid splits.
test_ids = test_df['image_id'].unique()
np.save(CACHE_SPLIT + 'test.npy', test_ids)

# 5 定义Dataset与数据加载方式

## 5.1 Dataset

Dataset主要用于读取数据，并每次输出单个数据。

In [None]:
class CloudDataset(Dataset): # subclass of torch's Dataset
    '''
    Custom Dataset yielding one (image, label, mask, image_id) sample at a time.

    A Dataset needs at least three methods:
    __init__    set-up;
    __len__     number of samples;
    __getitem__ how a single sample is fetched.

    df: label pivot table, used to look up the per-class labels of an image
    split: list of split file names (.npy) inside CACHE_SPLIT
    mode: 'train' reads cached train images and masks; anything else reads the
          cached test images and uses an all-zero mask
    augment: optional callable (image, label, mask, image_id) -> sample
    '''
    def __init__(self, df, split, mode='train', augment=None):
        '''Store the configuration and load the image ids of all split files.'''
        self.df = df
        self.split = split
        self.mode = mode
        self.augment = augment

        if mode == 'train': # train and test images live in different folders
            self.folder = CACHE_IMG + 'train_1050x700/'
        else:
            self.folder = CACHE_IMG + 'test_1050x700/'

        # NOTE: allow_pickle=True unpickles file content; only load split files
        # produced by this notebook, never untrusted ones.
        self.img_id = list(np.concatenate([np.load(CACHE_SPLIT + f , allow_pickle=True) for f in split]))
        self.num_img = len(self.img_id)

    def __len__(self):
        '''Number of samples in the dataset.'''
        return self.num_img

    def __getitem__(self, index):
        '''
        Fetch one sample.

        Returns (image, label, mask, image_id), or whatever `augment` returns for
        those four values. Image and mask are float32 scaled by 1/255.
        Raises FileNotFoundError if a cached image or mask cannot be read.
        '''
        image_id = self.img_id[index]
        # Cached files are named after the id without its last 4 characters ('.jpg').
        image_file = '%s%s.png'%(self.folder, image_id[:-4])
        image = cv2.imread(image_file, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread returns None on failure instead of raising.
            raise FileNotFoundError('Cannot read image: %s' % image_file)

        if self.mode == 'train':
            mask_file = CACHE_MASK + '%s.png'%(image_id[:-4])
            mask = cv2.imread(mask_file, cv2.IMREAD_UNCHANGED)
            if mask is None:
                raise FileNotFoundError('Cannot read mask: %s' % mask_file)
        else:
            # No ground truth for test images: all-zero mask with one channel per class.
            mask = np.zeros((int(MASK_HEIGHT), int(MASK_WIDTH), NUM_CLASS), np.uint8)

        image = image.astype(np.float32) / 255
        mask = mask.astype(np.float32) / 255
        # Rows of the pivot table for this image, restricted to the class columns.
        label = self.df.loc[self.df['image_id'] == image_id][list(CLASSNAME_TO_CLASSNO.keys())]

        if self.augment is None:
            return image, label, mask, image_id
        else:
            return self.augment(image, label, mask, image_id)

## 5.2 collate_fn

Dataset中的`__getitem__`定义了每个数据怎么取，而collate_fn定义了每个batch中的单个数据取出来之后有什么批量操作。collate_fn之后将作为DataLoader的参数。以下这个collate_fn对每个batch做数据拼接、翻转图片通道顺序、转为tensor，并根据mask重新计算每个类别的标签。

In [None]:
def null_collate(batch):
    '''
    collate_fn for the DataLoader: merge per-sample outputs of the Dataset's
    __getitem__ into batch tensors.

    batch: list of (image, label, mask, image_id) samples, with image and mask
           as (h, w, channel) arrays

    Returns (input_img, truth_label, truth_mask, img_id):
    input_img   float tensor (batch, channel, h, w), channel order reversed
    truth_label float tensor (batch, num_class); 1 where that class's mask has
                any positive pixel
    truth_mask  float tensor (batch, num_class, h, w)
    img_id      numpy array of image ids

    The per-sample label is ignored: labels are always recomputed from the masks,
    so it is no longer stacked/converted first (that result was discarded and
    failed on non-numeric labels).
    '''
    batch_size = len(batch)

    # Images: stack, reverse the channel axis, then move channels first.
    input_img = np.stack([sample[0] for sample in batch]) # (batch, h, w, channel)
    input_img = input_img[...,::-1].copy() # same shape, channel order reversed
    input_img = input_img.transpose(0,3,1,2) # (batch, channel, h, w)

    # Masks: stack and move channels first.
    truth_mask = np.stack([sample[2] for sample in batch])
    truth_mask = truth_mask.transpose(0,3,1,2) # (batch, num_class, h, w)

    img_id = np.stack([sample[3] for sample in batch])

    input_img = torch.from_numpy(input_img).float()
    truth_mask = torch.from_numpy(truth_mask).float()

    # A class is present in an image iff its mask channel has any positive pixel.
    # The class count comes from the mask itself, not from a global constant.
    num_class = truth_mask.shape[1]
    pixel_sum = truth_mask.reshape(batch_size, num_class, -1).sum(-1) # (batch, num_class)
    truth_label = (pixel_sum > 0).float()

    return input_img, truth_label, truth_mask, img_id # same order as __getitem__

## 5.3 Train的增扩（在线增扩）

In [None]:
def train_augment(image, label, mask, image_id):
    '''
    Online augmentation for training samples only (valid/test are not augmented;
    if a global resize is ever needed, valid/test need that resize as well).

    Arguments and return value follow the order of the Dataset's __getitem__
    output: (image, label, mask, image_id). Only image and mask are modified.
    '''
    # Independent 50% chances for a left-right and an up-down flip.
    if np.random.rand() > 0.5:
        image, mask = do_flip_lr(image, mask)
    if np.random.rand() > 0.5:
        image, mask = do_flip_ud(image, mask)

    def _identity(img, msk):
        return img, msk

    def _crop_rescale(img, msk):
        return do_random_crop_rescale(img, msk, w=925, h=630)

    def _rotate(img, msk):
        return do_random_crop_rotate_rescale(img, msk, mode=['rotate'])

    # Pick exactly one of the three geometric options (the first is a no-op).
    chosen = random.choice([_identity, _crop_rescale, _rotate])
    image, mask = chosen(image, mask)

    return image, label, mask, image_id

## 5.4 创建DataLoader

In [None]:
batch_size_train = 4 # batch size of the training loader
batch_size_test = 4 # batch size of the validation and test loaders

In [None]:
# Train
train_dataset = CloudDataset(
    df = pvt_df, 
    split = ['train_fold_a0.npy',],
    mode = 'train',
    augment = train_augment,
)

train_loader  = DataLoader(
    train_dataset,
    sampler = RandomSampler(train_dataset), # random order; with an explicit sampler, shuffle must stay False
    batch_size = batch_size_train,
    drop_last = True, # drop the last batch if it is smaller than batch_size
    num_workers = 4, # number of loader worker processes
    pin_memory = True,
    collate_fn = null_collate 
)

# Valid
valid_dataset = CloudDataset(
    df = pvt_df,
    split = ['valid_fold_a0.npy',],
    mode = 'train',
    augment = None,
)

valid_loader = DataLoader(
    valid_dataset,
    sampler = SequentialSampler(valid_dataset), # fixed, sequential order
    batch_size  = batch_size_test,
    drop_last   = False,
    num_workers = 4,
    pin_memory  = True,
    collate_fn  = null_collate
)

同样为test创建Dataset和DataLoader：

In [None]:
# Test
test_dataset = CloudDataset(
    df = pvt_test_df,
    split = ['test.npy',],
    mode = 'test',
    augment = None,
)

test_loader = DataLoader(
    test_dataset,
    sampler = SequentialSampler(test_dataset), # fixed, sequential order
    batch_size  = batch_size_test,
    drop_last   = False,
    num_workers = 4,
    pin_memory  = True,
    collate_fn  = null_collate
)

#### 测试一下

In [None]:
# Smoke test: build a separate dataset/loader and print the shapes of the first 5 batches.
verify_dataset = CloudDataset(
    df = pvt_df, 
    split = ['train_fold_a0.npy',],
    mode = 'train',
    augment = train_augment,
)

verify_loader  = DataLoader(
    verify_dataset,
    # The sampler must index the dataset it is attached to (it previously
    # sampled from train_dataset, a copy-paste slip).
    sampler = RandomSampler(verify_dataset), # random order; with an explicit sampler, shuffle must stay False
    batch_size = batch_size_train,
    drop_last = True, # drop the last batch if it is smaller than batch_size
    num_workers = 4, # number of loader worker processes
    pin_memory = True,
    collate_fn = null_collate 
)

for t, (input_img, truth_label, truth_mask, image_id) in enumerate(verify_loader):
    if t < 5:
        print('----t=%d---'%t)
        print('')
        print('input', input_img.shape)
        print('truth_label', truth_label.shape)
        print('truth_mask ', truth_mask.shape)
        print('image_id ', image_id)
        print('')
    else:
        break

# 6 模型

## 6.1 Encoder部分

#### 下载预训练权重文件

In [None]:
# pretrained=True downloads the weight file; the trailing ';' keeps the notebook
# from printing the whole model repr as the cell output.
torchvision.models.resnet34(pretrained=True);

In [None]:
# Location where the pretrained weights were saved by the download above.
# NOTE(review): hard-coded absolute path; it depends on the user and the torch cache layout - adjust per machine.
PRETRAIN_FILE = '/root/.cache/torch/checkpoints/resnet34-333f7ec4.pth'

#### Encoder组件

In [None]:
class ConvBn2d(nn.Module):
    '''Bias-free 2-D convolution followed by batch normalisation (no activation).'''

    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1, stride=1):
        super().__init__()
        # Attribute names `conv` and `bn` are part of the state_dict keys.
        self.conv = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channel, eps=1e-5)

    def forward(self, x):
        '''Return bn(conv(x)).'''
        return self.bn(self.conv(x))

In [None]:
class BasicBlock(nn.Module):
    '''
    Residual block: two 3x3 conv+BN layers with mish activations and a skip
    connection. With is_shortcut=True the skip path is a 1x1 conv+BN using the
    same stride, otherwise the input is added unchanged.
    '''

    def __init__(self, in_channel, channel, out_channel, stride=1, is_shortcut=False):
        super().__init__()
        self.is_shortcut = is_shortcut

        # Attribute names are part of the state_dict keys.
        self.conv_bn1 = ConvBn2d(in_channel, channel, kernel_size=3, padding=1, stride=stride)
        self.conv_bn2 = ConvBn2d(channel, out_channel, kernel_size=3, padding=1, stride=1)

        if is_shortcut:
            self.shortcut = ConvBn2d(in_channel, out_channel, kernel_size=1, padding=0, stride=stride)

    def forward(self, x):
        '''Return mish(conv_bn2(mish(conv_bn1(x))) + skip(x)).'''
        z = self.conv_bn2(mish(self.conv_bn1(x)))
        skip = self.shortcut(x) if self.is_shortcut else x
        z += skip
        return mish(z)

#### Backbone

In [None]:
class ResNet34(nn.Module):
    '''
    ResNet-34 style classifier built from BasicBlock with Mish activations.

    block0 is the 7x7 stem, block1..block4 hold 3/4/6/3 residual blocks with
    64/128/256/512 channels, followed by global average pooling and a linear
    layer with num_class outputs. Module names and Sequential indices are part
    of the state_dict keys.
    '''

    def __init__(self, num_class=1000 ):
        super().__init__()

        def same_width(width, count):
            # `count` stride-1 blocks that keep the channel count.
            return [BasicBlock(width, width, width, stride=1, is_shortcut=False) for _ in range(count)]

        def downsample(in_width, out_width):
            # Stride-2 block that changes the channel count via a projected shortcut.
            return BasicBlock(in_width, out_width, out_width, stride=2, is_shortcut=True)

        stem_conv = nn.Conv2d(3, 64, kernel_size=7, padding=3, stride=2, bias=True)
        self.block0 = nn.Sequential(stem_conv, nn.BatchNorm2d(64), Mish())
        self.block0[0].bias.data.fill_(0.0) # the stem conv bias starts at zero

        self.block1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=2),
            *same_width(64, 3),
        )
        self.block2 = nn.Sequential(downsample(64, 128), *same_width(128, 3))
        self.block3 = nn.Sequential(downsample(128, 256), *same_width(256, 5))
        self.block4 = nn.Sequential(downsample(256, 512), *same_width(512, 2))
        self.logit = nn.Linear(512, num_class)

    def forward(self, x):
        '''Return class logits of shape (batch, num_class).'''
        batch_size = len(x)

        for stage in (self.block0, self.block1, self.block2, self.block3, self.block4):
            x = stage(x)

        pooled = F.adaptive_avg_pool2d(x, 1).reshape(batch_size, -1)
        return self.logit(pooled)

## 6.2 加载Encoder权重

In [None]:
CONVERSION=[
 'block0.0.weight',	(64, 3, 7, 7),	 'conv1.weight',	(64, 3, 7, 7),
 'block0.1.weight',	(64,),	 'bn1.weight',	(64,),
 'block0.1.bias',	(64,),	 'bn1.bias',	(64,),
 'block0.1.running_mean',	(64,),	 'bn1.running_mean',	(64,),
 'block0.1.running_var',	(64,),	 'bn1.running_var',	(64,),
 'block1.1.conv_bn1.conv.weight',	(64, 64, 3, 3),	 'layer1.0.conv1.weight',	(64, 64, 3, 3),
 'block1.1.conv_bn1.bn.weight',	(64,),	 'layer1.0.bn1.weight',	(64,),
 'block1.1.conv_bn1.bn.bias',	(64,),	 'layer1.0.bn1.bias',	(64,),
 'block1.1.conv_bn1.bn.running_mean',	(64,),	 'layer1.0.bn1.running_mean',	(64,),
 'block1.1.conv_bn1.bn.running_var',	(64,),	 'layer1.0.bn1.running_var',	(64,),
 'block1.1.conv_bn2.conv.weight',	(64, 64, 3, 3),	 'layer1.0.conv2.weight',	(64, 64, 3, 3),
 'block1.1.conv_bn2.bn.weight',	(64,),	 'layer1.0.bn2.weight',	(64,),
 'block1.1.conv_bn2.bn.bias',	(64,),	 'layer1.0.bn2.bias',	(64,),
 'block1.1.conv_bn2.bn.running_mean',	(64,),	 'layer1.0.bn2.running_mean',	(64,),
 'block1.1.conv_bn2.bn.running_var',	(64,),	 'layer1.0.bn2.running_var',	(64,),
 'block1.2.conv_bn1.conv.weight',	(64, 64, 3, 3),	 'layer1.1.conv1.weight',	(64, 64, 3, 3),
 'block1.2.conv_bn1.bn.weight',	(64,),	 'layer1.1.bn1.weight',	(64,),
 'block1.2.conv_bn1.bn.bias',	(64,),	 'layer1.1.bn1.bias',	(64,),
 'block1.2.conv_bn1.bn.running_mean',	(64,),	 'layer1.1.bn1.running_mean',	(64,),
 'block1.2.conv_bn1.bn.running_var',	(64,),	 'layer1.1.bn1.running_var',	(64,),
 'block1.2.conv_bn2.conv.weight',	(64, 64, 3, 3),	 'layer1.1.conv2.weight',	(64, 64, 3, 3),
 'block1.2.conv_bn2.bn.weight',	(64,),	 'layer1.1.bn2.weight',	(64,),
 'block1.2.conv_bn2.bn.bias',	(64,),	 'layer1.1.bn2.bias',	(64,),
 'block1.2.conv_bn2.bn.running_mean',	(64,),	 'layer1.1.bn2.running_mean',	(64,),
 'block1.2.conv_bn2.bn.running_var',	(64,),	 'layer1.1.bn2.running_var',	(64,),
 'block1.3.conv_bn1.conv.weight',	(64, 64, 3, 3),	 'layer1.2.conv1.weight',	(64, 64, 3, 3),
 'block1.3.conv_bn1.bn.weight',	(64,),	 'layer1.2.bn1.weight',	(64,),
 'block1.3.conv_bn1.bn.bias',	(64,),	 'layer1.2.bn1.bias',	(64,),
 'block1.3.conv_bn1.bn.running_mean',	(64,),	 'layer1.2.bn1.running_mean',	(64,),
 'block1.3.conv_bn1.bn.running_var',	(64,),	 'layer1.2.bn1.running_var',	(64,),
 'block1.3.conv_bn2.conv.weight',	(64, 64, 3, 3),	 'layer1.2.conv2.weight',	(64, 64, 3, 3),
 'block1.3.conv_bn2.bn.weight',	(64,),	 'layer1.2.bn2.weight',	(64,),
 'block1.3.conv_bn2.bn.bias',	(64,),	 'layer1.2.bn2.bias',	(64,),
 'block1.3.conv_bn2.bn.running_mean',	(64,),	 'layer1.2.bn2.running_mean',	(64,),
 'block1.3.conv_bn2.bn.running_var',	(64,),	 'layer1.2.bn2.running_var',	(64,),
 'block2.0.conv_bn1.conv.weight',	(128, 64, 3, 3),	 'layer2.0.conv1.weight',	(128, 64, 3, 3),
 'block2.0.conv_bn1.bn.weight',	(128,),	 'layer2.0.bn1.weight',	(128,),
 'block2.0.conv_bn1.bn.bias',	(128,),	 'layer2.0.bn1.bias',	(128,),
 'block2.0.conv_bn1.bn.running_mean',	(128,),	 'layer2.0.bn1.running_mean',	(128,),
 'block2.0.conv_bn1.bn.running_var',	(128,),	 'layer2.0.bn1.running_var',	(128,),
 'block2.0.conv_bn2.conv.weight',	(128, 128, 3, 3),	 'layer2.0.conv2.weight',	(128, 128, 3, 3),
 'block2.0.conv_bn2.bn.weight',	(128,),	 'layer2.0.bn2.weight',	(128,),
 'block2.0.conv_bn2.bn.bias',	(128,),	 'layer2.0.bn2.bias',	(128,),
 'block2.0.conv_bn2.bn.running_mean',	(128,),	 'layer2.0.bn2.running_mean',	(128,),
 'block2.0.conv_bn2.bn.running_var',	(128,),	 'layer2.0.bn2.running_var',	(128,),
 'block2.0.shortcut.conv.weight',	(128, 64, 1, 1),	 'layer2.0.downsample.0.weight',	(128, 64, 1, 1),
 'block2.0.shortcut.bn.weight',	(128,),	 'layer2.0.downsample.1.weight',	(128,),
 'block2.0.shortcut.bn.bias',	(128,),	 'layer2.0.downsample.1.bias',	(128,),
 'block2.0.shortcut.bn.running_mean',	(128,),	 'layer2.0.downsample.1.running_mean',	(128,),
 'block2.0.shortcut.bn.running_var',	(128,),	 'layer2.0.downsample.1.running_var',	(128,),
 'block2.1.conv_bn1.conv.weight',	(128, 128, 3, 3),	 'layer2.1.conv1.weight',	(128, 128, 3, 3),
 'block2.1.conv_bn1.bn.weight',	(128,),	 'layer2.1.bn1.weight',	(128,),
 'block2.1.conv_bn1.bn.bias',	(128,),	 'layer2.1.bn1.bias',	(128,),
 'block2.1.conv_bn1.bn.running_mean',	(128,),	 'layer2.1.bn1.running_mean',	(128,),
 'block2.1.conv_bn1.bn.running_var',	(128,),	 'layer2.1.bn1.running_var',	(128,),
 'block2.1.conv_bn2.conv.weight',	(128, 128, 3, 3),	 'layer2.1.conv2.weight',	(128, 128, 3, 3),
 'block2.1.conv_bn2.bn.weight',	(128,),	 'layer2.1.bn2.weight',	(128,),
 'block2.1.conv_bn2.bn.bias',	(128,),	 'layer2.1.bn2.bias',	(128,),
 'block2.1.conv_bn2.bn.running_mean',	(128,),	 'layer2.1.bn2.running_mean',	(128,),
 'block2.1.conv_bn2.bn.running_var',	(128,),	 'layer2.1.bn2.running_var',	(128,),
 'block2.2.conv_bn1.conv.weight',	(128, 128, 3, 3),	 'layer2.2.conv1.weight',	(128, 128, 3, 3),
 'block2.2.conv_bn1.bn.weight',	(128,),	 'layer2.2.bn1.weight',	(128,),
 'block2.2.conv_bn1.bn.bias',	(128,),	 'layer2.2.bn1.bias',	(128,),
 'block2.2.conv_bn1.bn.running_mean',	(128,),	 'layer2.2.bn1.running_mean',	(128,),
 'block2.2.conv_bn1.bn.running_var',	(128,),	 'layer2.2.bn1.running_var',	(128,),
 'block2.2.conv_bn2.conv.weight',	(128, 128, 3, 3),	 'layer2.2.conv2.weight',	(128, 128, 3, 3),
 'block2.2.conv_bn2.bn.weight',	(128,),	 'layer2.2.bn2.weight',	(128,),
 'block2.2.conv_bn2.bn.bias',	(128,),	 'layer2.2.bn2.bias',	(128,),
 'block2.2.conv_bn2.bn.running_mean',	(128,),	 'layer2.2.bn2.running_mean',	(128,),
 'block2.2.conv_bn2.bn.running_var',	(128,),	 'layer2.2.bn2.running_var',	(128,),
 'block2.3.conv_bn1.conv.weight',	(128, 128, 3, 3),	 'layer2.3.conv1.weight',	(128, 128, 3, 3),
 'block2.3.conv_bn1.bn.weight',	(128,),	 'layer2.3.bn1.weight',	(128,),
 'block2.3.conv_bn1.bn.bias',	(128,),	 'layer2.3.bn1.bias',	(128,),
 'block2.3.conv_bn1.bn.running_mean',	(128,),	 'layer2.3.bn1.running_mean',	(128,),
 'block2.3.conv_bn1.bn.running_var',	(128,),	 'layer2.3.bn1.running_var',	(128,),
 'block2.3.conv_bn2.conv.weight',	(128, 128, 3, 3),	 'layer2.3.conv2.weight',	(128, 128, 3, 3),
 'block2.3.conv_bn2.bn.weight',	(128,),	 'layer2.3.bn2.weight',	(128,),
 'block2.3.conv_bn2.bn.bias',	(128,),	 'layer2.3.bn2.bias',	(128,),
 'block2.3.conv_bn2.bn.running_mean',	(128,),	 'layer2.3.bn2.running_mean',	(128,),
 'block2.3.conv_bn2.bn.running_var',	(128,),	 'layer2.3.bn2.running_var',	(128,),
 'block3.0.conv_bn1.conv.weight',	(256, 128, 3, 3),	 'layer3.0.conv1.weight',	(256, 128, 3, 3),
 'block3.0.conv_bn1.bn.weight',	(256,),	 'layer3.0.bn1.weight',	(256,),
 'block3.0.conv_bn1.bn.bias',	(256,),	 'layer3.0.bn1.bias',	(256,),
 'block3.0.conv_bn1.bn.running_mean',	(256,),	 'layer3.0.bn1.running_mean',	(256,),
 'block3.0.conv_bn1.bn.running_var',	(256,),	 'layer3.0.bn1.running_var',	(256,),
 'block3.0.conv_bn2.conv.weight',	(256, 256, 3, 3),	 'layer3.0.conv2.weight',	(256, 256, 3, 3),
 'block3.0.conv_bn2.bn.weight',	(256,),	 'layer3.0.bn2.weight',	(256,),
 'block3.0.conv_bn2.bn.bias',	(256,),	 'layer3.0.bn2.bias',	(256,),
 'block3.0.conv_bn2.bn.running_mean',	(256,),	 'layer3.0.bn2.running_mean',	(256,),
 'block3.0.conv_bn2.bn.running_var',	(256,),	 'layer3.0.bn2.running_var',	(256,),
 'block3.0.shortcut.conv.weight',	(256, 128, 1, 1),	 'layer3.0.downsample.0.weight',	(256, 128, 1, 1),
 'block3.0.shortcut.bn.weight',	(256,),	 'layer3.0.downsample.1.weight',	(256,),
 'block3.0.shortcut.bn.bias',	(256,),	 'layer3.0.downsample.1.bias',	(256,),
 'block3.0.shortcut.bn.running_mean',	(256,),	 'layer3.0.downsample.1.running_mean',	(256,),
 'block3.0.shortcut.bn.running_var',	(256,),	 'layer3.0.downsample.1.running_var',	(256,),
 'block3.1.conv_bn1.conv.weight',	(256, 256, 3, 3),	 'layer3.1.conv1.weight',	(256, 256, 3, 3),
 'block3.1.conv_bn1.bn.weight',	(256,),	 'layer3.1.bn1.weight',	(256,),
 'block3.1.conv_bn1.bn.bias',	(256,),	 'layer3.1.bn1.bias',	(256,),
 'block3.1.conv_bn1.bn.running_mean',	(256,),	 'layer3.1.bn1.running_mean',	(256,),
 'block3.1.conv_bn1.bn.running_var',	(256,),	 'layer3.1.bn1.running_var',	(256,),
 'block3.1.conv_bn2.conv.weight',	(256, 256, 3, 3),	 'layer3.1.conv2.weight',	(256, 256, 3, 3),
 'block3.1.conv_bn2.bn.weight',	(256,),	 'layer3.1.bn2.weight',	(256,),
 'block3.1.conv_bn2.bn.bias',	(256,),	 'layer3.1.bn2.bias',	(256,),
 'block3.1.conv_bn2.bn.running_mean',	(256,),	 'layer3.1.bn2.running_mean',	(256,),
 'block3.1.conv_bn2.bn.running_var',	(256,),	 'layer3.1.bn2.running_var',	(256,),
 'block3.2.conv_bn1.conv.weight',	(256, 256, 3, 3),	 'layer3.2.conv1.weight',	(256, 256, 3, 3),
 'block3.2.conv_bn1.bn.weight',	(256,),	 'layer3.2.bn1.weight',	(256,),
 'block3.2.conv_bn1.bn.bias',	(256,),	 'layer3.2.bn1.bias',	(256,),
 'block3.2.conv_bn1.bn.running_mean',	(256,),	 'layer3.2.bn1.running_mean',	(256,),
 'block3.2.conv_bn1.bn.running_var',	(256,),	 'layer3.2.bn1.running_var',	(256,),
 'block3.2.conv_bn2.conv.weight',	(256, 256, 3, 3),	 'layer3.2.conv2.weight',	(256, 256, 3, 3),
 'block3.2.conv_bn2.bn.weight',	(256,),	 'layer3.2.bn2.weight',	(256,),
 'block3.2.conv_bn2.bn.bias',	(256,),	 'layer3.2.bn2.bias',	(256,),
 'block3.2.conv_bn2.bn.running_mean',	(256,),	 'layer3.2.bn2.running_mean',	(256,),
 'block3.2.conv_bn2.bn.running_var',	(256,),	 'layer3.2.bn2.running_var',	(256,),
 'block3.3.conv_bn1.conv.weight',	(256, 256, 3, 3),	 'layer3.3.conv1.weight',	(256, 256, 3, 3),
 'block3.3.conv_bn1.bn.weight',	(256,),	 'layer3.3.bn1.weight',	(256,),
 'block3.3.conv_bn1.bn.bias',	(256,),	 'layer3.3.bn1.bias',	(256,),
 'block3.3.conv_bn1.bn.running_mean',	(256,),	 'layer3.3.bn1.running_mean',	(256,),
 'block3.3.conv_bn1.bn.running_var',	(256,),	 'layer3.3.bn1.running_var',	(256,),
 'block3.3.conv_bn2.conv.weight',	(256, 256, 3, 3),	 'layer3.3.conv2.weight',	(256, 256, 3, 3),
 'block3.3.conv_bn2.bn.weight',	(256,),	 'layer3.3.bn2.weight',	(256,),
 'block3.3.conv_bn2.bn.bias',	(256,),	 'layer3.3.bn2.bias',	(256,),
 'block3.3.conv_bn2.bn.running_mean',	(256,),	 'layer3.3.bn2.running_mean',	(256,),
 'block3.3.conv_bn2.bn.running_var',	(256,),	 'layer3.3.bn2.running_var',	(256,),
 'block3.4.conv_bn1.conv.weight',	(256, 256, 3, 3),	 'layer3.4.conv1.weight',	(256, 256, 3, 3),
 'block3.4.conv_bn1.bn.weight',	(256,),	 'layer3.4.bn1.weight',	(256,),
 'block3.4.conv_bn1.bn.bias',	(256,),	 'layer3.4.bn1.bias',	(256,),
 'block3.4.conv_bn1.bn.running_mean',	(256,),	 'layer3.4.bn1.running_mean',	(256,),
 'block3.4.conv_bn1.bn.running_var',	(256,),	 'layer3.4.bn1.running_var',	(256,),
 'block3.4.conv_bn2.conv.weight',	(256, 256, 3, 3),	 'layer3.4.conv2.weight',	(256, 256, 3, 3),
 'block3.4.conv_bn2.bn.weight',	(256,),	 'layer3.4.bn2.weight',	(256,),
 'block3.4.conv_bn2.bn.bias',	(256,),	 'layer3.4.bn2.bias',	(256,),
 'block3.4.conv_bn2.bn.running_mean',	(256,),	 'layer3.4.bn2.running_mean',	(256,),
 'block3.4.conv_bn2.bn.running_var',	(256,),	 'layer3.4.bn2.running_var',	(256,),
 'block3.5.conv_bn1.conv.weight',	(256, 256, 3, 3),	 'layer3.5.conv1.weight',	(256, 256, 3, 3),
 'block3.5.conv_bn1.bn.weight',	(256,),	 'layer3.5.bn1.weight',	(256,),
 'block3.5.conv_bn1.bn.bias',	(256,),	 'layer3.5.bn1.bias',	(256,),
 'block3.5.conv_bn1.bn.running_mean',	(256,),	 'layer3.5.bn1.running_mean',	(256,),
 'block3.5.conv_bn1.bn.running_var',	(256,),	 'layer3.5.bn1.running_var',	(256,),
 'block3.5.conv_bn2.conv.weight',	(256, 256, 3, 3),	 'layer3.5.conv2.weight',	(256, 256, 3, 3),
 'block3.5.conv_bn2.bn.weight',	(256,),	 'layer3.5.bn2.weight',	(256,),
 'block3.5.conv_bn2.bn.bias',	(256,),	 'layer3.5.bn2.bias',	(256,),
 'block3.5.conv_bn2.bn.running_mean',	(256,),	 'layer3.5.bn2.running_mean',	(256,),
 'block3.5.conv_bn2.bn.running_var',	(256,),	 'layer3.5.bn2.running_var',	(256,),
 'block4.0.conv_bn1.conv.weight',	(512, 256, 3, 3),	 'layer4.0.conv1.weight',	(512, 256, 3, 3),
 'block4.0.conv_bn1.bn.weight',	(512,),	 'layer4.0.bn1.weight',	(512,),
 'block4.0.conv_bn1.bn.bias',	(512,),	 'layer4.0.bn1.bias',	(512,),
 'block4.0.conv_bn1.bn.running_mean',	(512,),	 'layer4.0.bn1.running_mean',	(512,),
 'block4.0.conv_bn1.bn.running_var',	(512,),	 'layer4.0.bn1.running_var',	(512,),
 'block4.0.conv_bn2.conv.weight',	(512, 512, 3, 3),	 'layer4.0.conv2.weight',	(512, 512, 3, 3),
 'block4.0.conv_bn2.bn.weight',	(512,),	 'layer4.0.bn2.weight',	(512,),
 'block4.0.conv_bn2.bn.bias',	(512,),	 'layer4.0.bn2.bias',	(512,),
 'block4.0.conv_bn2.bn.running_mean',	(512,),	 'layer4.0.bn2.running_mean',	(512,),
 'block4.0.conv_bn2.bn.running_var',	(512,),	 'layer4.0.bn2.running_var',	(512,),
 'block4.0.shortcut.conv.weight',	(512, 256, 1, 1),	 'layer4.0.downsample.0.weight',	(512, 256, 1, 1),
 'block4.0.shortcut.bn.weight',	(512,),	 'layer4.0.downsample.1.weight',	(512,),
 'block4.0.shortcut.bn.bias',	(512,),	 'layer4.0.downsample.1.bias',	(512,),
 'block4.0.shortcut.bn.running_mean',	(512,),	 'layer4.0.downsample.1.running_mean',	(512,),
 'block4.0.shortcut.bn.running_var',	(512,),	 'layer4.0.downsample.1.running_var',	(512,),
 'block4.1.conv_bn1.conv.weight',	(512, 512, 3, 3),	 'layer4.1.conv1.weight',	(512, 512, 3, 3),
 'block4.1.conv_bn1.bn.weight',	(512,),	 'layer4.1.bn1.weight',	(512,),
 'block4.1.conv_bn1.bn.bias',	(512,),	 'layer4.1.bn1.bias',	(512,),
 'block4.1.conv_bn1.bn.running_mean',	(512,),	 'layer4.1.bn1.running_mean',	(512,),
 'block4.1.conv_bn1.bn.running_var',	(512,),	 'layer4.1.bn1.running_var',	(512,),
 'block4.1.conv_bn2.conv.weight',	(512, 512, 3, 3),	 'layer4.1.conv2.weight',	(512, 512, 3, 3),
 'block4.1.conv_bn2.bn.weight',	(512,),	 'layer4.1.bn2.weight',	(512,),
 'block4.1.conv_bn2.bn.bias',	(512,),	 'layer4.1.bn2.bias',	(512,),
 'block4.1.conv_bn2.bn.running_mean',	(512,),	 'layer4.1.bn2.running_mean',	(512,),
 'block4.1.conv_bn2.bn.running_var',	(512,),	 'layer4.1.bn2.running_var',	(512,),
 'block4.2.conv_bn1.conv.weight',	(512, 512, 3, 3),	 'layer4.2.conv1.weight',	(512, 512, 3, 3),
 'block4.2.conv_bn1.bn.weight',	(512,),	 'layer4.2.bn1.weight',	(512,),
 'block4.2.conv_bn1.bn.bias',	(512,),	 'layer4.2.bn1.bias',	(512,),
 'block4.2.conv_bn1.bn.running_mean',	(512,),	 'layer4.2.bn1.running_mean',	(512,),
 'block4.2.conv_bn1.bn.running_var',	(512,),	 'layer4.2.bn1.running_var',	(512,),
 'block4.2.conv_bn2.conv.weight',	(512, 512, 3, 3),	 'layer4.2.conv2.weight',	(512, 512, 3, 3),
 'block4.2.conv_bn2.bn.weight',	(512,),	 'layer4.2.bn2.weight',	(512,),
 'block4.2.conv_bn2.bn.bias',	(512,),	 'layer4.2.bn2.bias',	(512,),
 'block4.2.conv_bn2.bn.running_mean',	(512,),	 'layer4.2.bn2.running_mean',	(512,),
 'block4.2.conv_bn2.bn.running_var',	(512,),	 'layer4.2.bn2.running_var',	(512,),
 'logit.weight',	(1000, 512),	 'fc.weight',	(1000, 512),
 'logit.bias',	(1000,),	 'fc.bias',	(1000,),
]

In [None]:
def absorb_rgb_normalisation_to_conv(weight, bias, rgb_mean=IMAGE_RGB_MEAN, rgb_std=IMAGE_RGB_STD ):
    '''
    Fold the input normalisation (x - mean) / std into a convolution's parameters.

    Because conv(W, (x - mean) / std) + b == conv(W / std, x) + (b - sum(W * mean / std)),
    a layer loaded with the returned parameters can be fed un-normalised RGB input.

    weight: conv kernel; must be 4-D, and its channel axis must broadcast against 3
    bias: conv bias, one value per output channel
    rgb_mean, rgb_std: three per-channel statistics
    returns: (rescaled weight, shifted bias), both on weight's device
    '''
    out, c, h, w = weight.shape  # unpacking doubles as a check that the kernel is 4-D
    device = weight.device

    def as_channel_tensor(values):
        # three statistics -> (1, 3, 1, 1) float32 tensor that broadcasts over the kernel
        return torch.from_numpy(np.array(values, np.float32).reshape(1,3,1,1)).to(device)

    mean = as_channel_tensor(rgb_mean)
    std = as_channel_tensor(rgb_std)

    scaled_weight = weight/std
    # constant contribution of the subtracted mean, summed over everything but the output axis
    shift = (-mean*weight/std).sum(dim=[1,2,3])

    return scaled_weight, shift + bias.to(device)

In [None]:
def load_pretrain(net, skip=None, pretrain_file=PRETRAIN_FILE, conversion=CONVERSION, is_print=True):
    '''
    Copy pretrained weights into net, renaming keys through a conversion table.

    net: module to initialise
    skip: substrings; any key of net containing one of them is left untouched
          ('.num_batches_tracked' entries are always skipped). None means no extra skips.
    pretrain_file: checkpoint path readable by torch.load
    conversion: flat sequence read four items at a time as
                (key in net, shape, key in checkpoint, shape); only the two keys are used
    is_print: print one line per copied tensor

    After copying, the RGB normalisation is folded into 'block0.0.weight' / 'block0.0.bias'
    via absorb_rgb_normalisation_to_conv, so net must own both of those entries.
    '''
    print('Load pretrain_file: %s'%pretrain_file)

    # A fresh list on every call; the default is None rather than a shared mutable list.
    skip_patterns = ['.num_batches_tracked', ] + list(skip or [])

    # map_location keeps every tensor on the CPU regardless of where it was saved
    pretrain_state_dict = torch.load(pretrain_file, map_location=lambda storage, loc: storage)
    state_dict = net.state_dict()

    num_loaded = 0
    conversion = np.array(conversion, dtype=object).reshape(-1,4)
    for key, _ , pretrain_key, _ in conversion:
        if any(s in key for s in skip_patterns):
            continue

        if is_print:
            print('\t\t', '%-48s  %-24s  <---  %-32s  %-24s'%(
                key, str(state_dict[key].shape),
                pretrain_key,
                str(pretrain_state_dict[pretrain_key].shape),
            ))
        num_loaded += 1

        state_dict[key] = pretrain_state_dict[pretrain_key]

    # Let the first conv consume un-normalised input
    state_dict['block0.0.weight'], state_dict['block0.0.bias'] = (
        absorb_rgb_normalisation_to_conv(state_dict['block0.0.weight'], state_dict['block0.0.bias']))

    net.load_state_dict(state_dict)
    print('')
    print('len(pretrain_state_dict.keys()) = %d'%len(pretrain_state_dict.keys()))
    print('len(state_dict.keys()) = %d'%len(state_dict.keys()))
    print('loaded = %d'%num_loaded)
    print('')

#### 测试一下

In [None]:
# Smoke test: build a ResNet34 and load the pretrained weights into it,
# printing one line per copied tensor (is_print=True).
net = ResNet34()
load_pretrain(net, is_print=True)

## 6.3 Decoder部分

#### Decoder组件

In [None]:
def resize_like(x, reference, mode='bilinear'):
    '''
    Resize x to the spatial size of reference (used to align feature maps).

    x: tensor whose dimensions from index 2 onwards are spatial
    reference: tensor providing the target spatial size (reference.shape[2:])
    mode: any interpolation mode accepted by F.interpolate; 'bilinear' and 'nearest'
          behave exactly as before, other modes are now forwarded instead of
          silently returning x unresized
    returns: x itself when the sizes already match, otherwise the resized tensor
    '''
    target_size = reference.shape[2:]
    if x.shape[2:] == target_size:
        return x

    # align_corners is only accepted by the interpolating modes; the others require None
    align_corners = False if mode in ('linear', 'bilinear', 'bicubic', 'trilinear') else None
    return F.interpolate(x, size=target_size, mode=mode, align_corners=align_corners)

In [None]:
class SCSEModule(nn.Module):
    '''
    scSE attention block (spatial / channel squeeze & excitation).

    ch: number of input (and output) channels
    re: reduction ratio of the channel-gate bottleneck (ch // re hidden channels)

    The output has the same shape as the input: x gated per channel plus x gated per position.
    '''
    def __init__(self, ch, re=16):
        super().__init__()
        hidden = ch//re

        # Channel gate: global average pool -> 1x1 bottleneck -> one weight in (0, 1) per channel
        channel_gate = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ch, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, ch, 1),
            nn.Sigmoid(),
        ]
        self.cSE = nn.Sequential(*channel_gate)

        # Spatial gate: 1x1 conv -> sigmoid.
        # NOTE(review): this emits ch gate maps rather than a single shared map; changing it
        # would alter the parameter shapes of saved checkpoints, so it is left as is.
        spatial_gate = [nn.Conv2d(ch, ch, 1), nn.Sigmoid()]
        self.sSE = nn.Sequential(*spatial_gate)

    def forward(self, x):
        channel_weighted = x * self.cSE(x)
        spatial_weighted = x * self.sSE(x)
        return channel_weighted + spatial_weighted

#### Decoder

In [None]:
class Decode(nn.Module):
    '''
    One decoder stage: concatenate inputs -> scSE -> two (3x3 conv, BN, ReLU) -> scSE.

    in_channel: channel count of the concatenated input
    channel: channel count after the first convolution
    out_channel: channel count of the stage output
    '''
    def __init__(self, in_channel, channel, out_channel):
        super().__init__()

        self.attention1 = SCSEModule(in_channel)
        self.attention2 = SCSEModule(out_channel)

        def conv_bn_relu(n_in, n_out):
            # 3x3 convolution that keeps the spatial size, followed by BN and ReLU
            return [
                nn.Conv2d(n_in, n_out, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(n_out),
                nn.ReLU(inplace=True),
            ]

        self.top = nn.Sequential(
            *conv_bn_relu(in_channel, channel),
            *conv_bn_relu(channel, out_channel),
        )

    def forward(self, x):
        # x is a sequence of tensors that are joined along the channel axis
        merged = torch.cat(x, 1)
        attended = self.attention1(merged)  # optional attention on the merged input
        features = self.top(attended)
        return self.attention2(features)    # optional attention on the output

## 6.4 完整模型

In [None]:
class Net(nn.Module):
    '''
    U-Net style segmentation model that uses ResNet34 as its backbone.

    forward returns (probability_label, probability_mask): the per-pixel sigmoid
    probabilities and, for each class, the maximum of that map over all positions.
    '''
    def __init__(self, num_class=4 ):
        super().__init__()

        # Only the five encoder stages of the backbone are kept; the rest is discarded.
        backbone = ResNet34()
        self.block0 = backbone.block0
        self.block1 = backbone.block1
        self.block2 = backbone.block2
        self.block3 = backbone.block3
        self.block4 = backbone.block4
        del backbone

        # Bottleneck below the deepest encoder stage
        self.center = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBn2d(512, 1024),
            nn.ELU(inplace=True),
            ConvBn2d(1024, 512),
        )

        # Each stage receives [skip feature, resized previous output], hence the summed widths
        self.decode1 = Decode(512+512, 512, 256)
        self.decode2 = Decode(256+256, 256, 128)
        self.decode3 = Decode(128+128, 128, 64)
        self.decode4 = Decode(64+64, 128, 64)
        self.decode5 = Decode(64+64, 64, 64)

        self.logit = nn.Conv2d(64, num_class, kernel_size=3, stride=1, padding=1)

    def load_pretrain(self, skip=['logit.'], is_print=True):
        '''
        Load the pretrained backbone weights through the module-level load_pretrain;
        keys containing 'logit.' are skipped by default.
        '''
        load_pretrain(self, skip, pretrain_file=PRETRAIN_FILE, conversion=CONVERSION, is_print=is_print)

    def forward(self, x):
        batch_size, _, _, _ = x.shape  # also enforces a 4-D input

        # Encoder: keep every stage output for the skip connections
        x0 = self.block0(x)
        x1 = self.block1(x0)
        x2 = self.block2(x1)
        x3 = self.block3(x2)
        x4 = self.block4(x3)

        x = self.center(x4)

        # Decoder: walk back up, pairing each stage with its skip feature (deepest first)
        stages = (
            (self.decode1, x4),
            (self.decode2, x3),
            (self.decode3, x2),
            (self.decode4, x1),
            (self.decode5, x0),
        )
        for decode, skip_feature in stages:
            x = decode([skip_feature, resize_like(x, skip_feature)])

        probability_mask = torch.sigmoid(self.logit(x))
        # Label probability of a class = highest pixel probability in that class's mask
        pooled = F.adaptive_max_pool2d(probability_mask, 1)
        probability_label = pooled.view(batch_size, -1)
        return probability_label, probability_mask

#### 测试一下

In [None]:
def run_check_net():
    '''
    Smoke-test Net: push one random batch through the model in eval mode and print
    the shapes of the input and of both outputs.

    Runs on the GPU when one is available and falls back to the CPU otherwise.
    '''
    batch_size = 1
    C, H, W = 3, 384, 576
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # One random batch is enough (the input used to be generated twice).
    input_img = np.random.uniform(-1, 1, (batch_size,C, H, W ))
    input_img = torch.from_numpy(input_img).float().to(device)

    net = Net().to(device)
    net.eval()

    with torch.no_grad():
        probability_label, probability_mask = net(input_img)

    print('')
    print('input: ',input_img.shape)
    print('probability_label: ',probability_label.shape)
    print('probability_mask: ',probability_mask.shape)

In [None]:
# Run the forward-pass smoke test and print the resulting tensor shapes.
run_check_net()

## 6.5 损失函数

此处采用了BCE Loss

In [None]:
def criterion(probability_label, probability_mask, truth_label, truth_mask):
    '''
    Binary cross-entropy losses for the label output and the mask output.

    probability_label: predicted label probabilities, same shape as truth_label
    probability_mask: predicted per-pixel probabilities, same shape as truth_mask
    truth_label, truth_mask: float targets in [0, 1]

    returns: (loss_label, loss_mask), two scalar tensors. They are returned
    separately so the caller decides which of them to back-propagate.

    Works for any number of classes (the unused 4-class reshape was removed).
    '''
    # label: BCE in which the negative term is weighted twice as heavily as the positive one
    p = torch.clamp(probability_label, 1e-7, 1-1e-7) # keep log() finite
    t = truth_label
    loss_label = (- t*torch.log(p) - 2*(1-t)*torch.log(1-p)).mean()

    # mask: plain per-pixel BCE on the raw probabilities, exactly as before
    loss_mask = F.binary_cross_entropy(probability_mask, truth_mask, reduction='mean')

    return loss_label, loss_mask

## 6.6 评价指标

In [None]:
def metric (probability_label, probability_mask, truth_label, truth_mask, use_reject=True):
    '''
    Per-class classification and segmentation statistics for one batch.

    probability_label: predicted probability of each label, viewed as (batch, num_class)
    probability_mask: predicted per-pixel probability, viewed as (batch, num_class, H*W)
    truth_label: ground-truth labels, shape (batch, num_class)
    truth_mask: ground-truth masks, shape (batch, num_class, H, W)
    use_reject: if True (default), the mask statistics only cover the (sample, class)
        pairs whose predicted label is positive, i.e. the label output acts as a gate
        in front of the mask output; if False every pair is counted.

    Returns eight numpy arrays, each with one entry per class:
        tn, tp         : number of true-negative / true-positive label predictions
        num_tn, num_tp : number of actually negative / actually positive labels
        dn             : number of selected empty-truth masks whose prediction is empty too
        dp             : sum of Dice scores over the selected non-empty-truth masks
        num_dn, num_dp : number of selected empty-truth / non-empty-truth masks
    All values are sums over the batch; dividing by the matching num_* gives a rate.
    '''

    threshold_label = 0.50 # label probability threshold (above it the label is predicted positive)
    threshold_mask = 0.50 # pixel probability threshold (above it the pixel is predicted positive)
    threshold_size = 1 # minimum number of predicted positive pixels (below it the predicted mask counts as empty)

    with torch.no_grad():
        # label
        batch_size, num_class = truth_label.shape

        probability = probability_label.view(batch_size, num_class)
        truth = truth_label.view(batch_size, num_class)

        p = (probability > threshold_label).float() # labels predicted positive, (batch, class)
        t = (truth > 0.5).float() # labels that are actually positive
        num_tp = t.sum(0) # number of actual positives per class
        num_tn = batch_size - num_tp # number of actual negatives per class

        tp = ((p + t) == 2).float()  # True positives
        tn = ((p + t) == 0).float()  # True negatives
        tn = tn.sum(0) # true-negative count per class
        tp = tp.sum(0) # true-positive count per class

        select = p # predicted-positive labels, reused below as the gate for use_reject
        
        # mask
        batch_size, num_class, H, W = truth_mask.shape

        probability = probability_mask.view(batch_size, num_class, -1) # flatten each mask, (batch, class, H*W)
        truth = truth_mask.view(batch_size, num_class, -1)

        p = (probability > threshold_mask).float() # pixels predicted positive
        t = (truth > 0.5).float() # pixels that are actually positive

        t_sum = t.sum(-1) # actual positive-pixel count per (sample, class)
        p_sum = p.sum(-1) # predicted positive-pixel count per (sample, class)

        neg_index = (t_sum == 0).float() # 1 where the ground-truth mask is empty
        pos_index = 1 - neg_index # 1 where the ground-truth mask has at least one positive pixel

        # get subset
        if use_reject:
            neg_index = neg_index * select # truth empty AND label predicted positive
            pos_index = pos_index * select # truth non-empty AND label predicted positive

        num_dn = neg_index.sum(0) # number of selected empty-truth masks per class
        num_dp = pos_index.sum(0) # number of selected non-empty-truth masks per class

        # dn marks predicted masks that are empty; dp is the Dice score of every mask.
        # Multiplying by the index tensors keeps only the selected pairs before summing.
        dn = (p_sum < threshold_size).float() # 1 where the predicted mask has fewer than threshold_size positive pixels
        dp = 2 * (p*t).sum(-1) / ((p+t).sum(-1) + EPS) # Dice = 2*|P & T| / (|P| + |T|); EPS avoids 0/0
        dn = (dn * neg_index).sum(0) # selected empty-truth masks that were also predicted empty
        dp = (dp * pos_index).sum(0) # summed Dice over the selected non-empty-truth masks

        # Stack all statistics, move them to the CPU in one go, then split them into 8 rows
        all_metrics = torch.cat([
            tn, tp, num_tn, num_tp,
            dn, dp, num_dn, num_dp,
        ])
        all_metrics = all_metrics.data.cpu().numpy().reshape(-1, num_class)
        tn, tp, num_tn, num_tp, dn, dp, num_dn, num_dp = all_metrics

    return tn, tp, num_tn, num_tp, dn, dp, num_dn, num_dp

## 6.7 Scheduler

In [None]:
class NullScheduler():
    '''
    Constant learning-rate schedule (0.01 unless another rate is given).
    '''
    def __init__(self, lr=0.01 ):
        super().__init__()
        self.cycle = 0
        self.lr = lr

    def __call__(self, time):
        # The argument is ignored: the rate never changes.
        return self.lr

    def __str__(self):
        rate_text = 'lr=%0.5f '%(self.lr)
        return 'NullScheduler\n' + rate_text

In [None]:
def adjust_learning_rate(optimizer, lr):
    '''
    Write lr into every parameter group of optimizer.
    '''
    for group in optimizer.param_groups:
        group.update(lr=lr)

In [None]:
def get_learning_rate(optimizer):
    '''
    Return the learning rate of optimizer, which must have exactly one parameter group.
    '''
    rates = [group['lr'] for group in optimizer.param_groups]

    assert(len(rates)==1) # a single param_group is assumed
    return rates[0]

## 6.8 验证 (validation)

In [None]:
def do_valid(net, valid_loader, out_dir=None):
    '''
    Evaluate net on the whole validation set (requires CUDA).

    net: the model; it is switched to eval mode and left that way
    valid_loader: validation DataLoader yielding (input_img, truth_label, truth_mask, image_id)
    out_dir: unused; kept so existing callers keep working

    Returns (valid_loss, (kaggle, kaggle1)):
        valid_loss: array of 2 + 4*NUM_CLASS values laid out as
            [loss_label, loss_mask, tn per class, tp per class, dn per class, dp per class],
            each divided by its own sample count (so tn/tp/dn become rates, dp a mean Dice)
        kaggle: estimate of the competition score combining the label gate and the mask quality
        kaggle1: estimate that only uses the label predictions
    '''
    # Running sums and their denominators, same layout as the returned valid_loss
    valid_loss = np.zeros(2 + 4*NUM_CLASS, np.float32)
    valid_num  = np.zeros_like(valid_loss)

    net.eval() # loop-invariant, so set once instead of once per batch
    for input_img, truth_label, truth_mask, image_id in valid_loader:
        batch_size = len(input_img)

        input_img = input_img.cuda()
        truth_label = truth_label.cuda()
        truth_mask  = truth_mask.cuda()

        with torch.no_grad():
            probability_label, probability_mask = data_parallel(net, input_img)
            # bring the predicted mask to the resolution of the ground truth
            probability_mask = resize_like(probability_mask, truth_mask, mode='bilinear')

            loss_label, loss_mask = criterion(probability_label, probability_mask, truth_label, truth_mask)
            tn, tp, num_tn, num_tp, dn, dp, num_dn, num_dp = metric(probability_label, probability_mask,
                                                                    truth_label, truth_mask)

        # Losses are batch means, so weight them by the batch size before accumulating
        batch_sum = np.array([loss_label.item()*batch_size, loss_mask.item()*batch_size, *tn, *tp, *dn, *dp ])
        batch_num = np.array([batch_size, batch_size, *num_tn, *num_tp, *num_dn, *num_dp])
        valid_loss += batch_sum
        valid_num += batch_num

        print('\r %8d /%d'%(valid_num[0], len(valid_loader.dataset)), end='', flush=True)

    assert(valid_num[0] == len(valid_loader.dataset))
    valid_loss = valid_loss / (valid_num+1e-8) # sums -> averages; 1e-8 guards empty denominators

    #------
    # Share of positive samples per class in the test set, taken from NUM_TEST_POS / NUM_TEST
    # (defined elsewhere; according to the original note these were inferred by probing).
    test_pos_ratio = np.array(
        [NUM_TEST_POS[c][0]/NUM_TEST for c in list(CLASSNAME_TO_CLASSNO.keys())]
    )
    test_neg_ratio = 1 - test_pos_ratio

    tn, tp, dn, dp = valid_loss[2:].reshape(-1, NUM_CLASS)

    # Estimated score per class, weighted by the test-set class balance:
    #   an empty image scores when its label is rejected (tn) or, if it is let through (1 - tn),
    #   when the predicted mask is still empty (dn);
    #   a non-empty image contributes its Dice (dp) only when its label is accepted (tp).
    kaggle = test_neg_ratio*tn + test_neg_ratio*(1 - tn)*dn + test_pos_ratio*tp*dp
    kaggle = kaggle.mean()

    # Same weighting using the label predictions alone
    kaggle1 = test_neg_ratio*tn + test_pos_ratio*tp
    kaggle1 = kaggle1.mean()

    return valid_loss, (kaggle, kaggle1)

# 7 设置Logger

创建一个Logger用于记录训练的情况，方便比对；万一训练中途崩溃，也能留下一些信息。

In [None]:
# Training logger: every record is appended to <PROJ_FOLDER>/log/local_cv.log
logger = logging.getLogger('torch_training')
logger.setLevel(logging.DEBUG)

logger_path = PROJ_FOLDER + 'log/'
os.makedirs(logger_path, exist_ok=True) # create the log folder if it is missing

log_file = logger_path + 'local_cv.log'

# getLogger returns the same object every time this cell runs, so detach any handler a
# previous run attached for this file; otherwise each record is written once per re-run.
for stale_handler in list(logger.handlers):
    if isinstance(stale_handler, logging.FileHandler) and \
            stale_handler.baseFilename == os.path.abspath(log_file):
        logger.removeHandler(stale_handler)
        stale_handler.close()

fh = logging.FileHandler(log_file)  # opens (or creates) the log file
fh.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') # record layout
fh.setFormatter(formatter)
logger.addHandler(fh)

In [None]:
# Taken once, when this cell runs; run_train writes it into the '[START ...]' log line.
IDENTIFIER   = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') # timestamp

In [None]:
def time_to_str(t, mode='min'):
    '''
    Format an elapsed time given in seconds.

    t: elapsed seconds; the fractional part is dropped
    mode: 'min' -> '<h> hr <mm> min', 'sec' -> '<m> min <ss> sec'
    raises NotImplementedError for any other mode
    '''
    if mode not in ('min', 'sec'):
        raise NotImplementedError('unknown mode: %r'%(mode,))

    seconds = int(t)
    if mode=='min':
        # integer arithmetic throughout; no local shadows the builtin min
        hours, minutes = divmod(seconds//60, 60)
        return '%2d hr %02d min'%(hours, minutes)

    minutes, secs = divmod(seconds, 60)
    return '%2d min %02d sec'%(minutes, secs)

# 8 训练模型 

#### 设置模型保存路径

In [None]:
# Output folder of this experiment (checkpoints are written below it)
checkpoint_dir = PROJ_FOLDER + 'models/' + 'resnet34-unet-fold_a0_attention/'

# exist_ok keeps the cell safe to re-run and matches how run_train creates its sub-folders
os.makedirs(checkpoint_dir, exist_ok=True)

炼丹本体

In [None]:
def run_train():
    '''
    Train Net with SGD at a fixed learning rate, validating and checkpointing periodically.

    Reads the notebook globals train_dataset, train_loader, valid_loader, logger,
    checkpoint_dir, IDENTIFIER, SEED, PROJ_FOLDER and EPS, and requires CUDA.

    Schedule, counted in iterations (one iteration = one batch):
        every iter_valid        : run do_valid and refresh the validation numbers
        every iter_log          : write a status line to the logger
        iterations in iter_save : save '<iter>_optimizer.pth' (iter / epoch bookkeeping only) and,
                                  except at the starting iteration, '<iter>_model.pth'
        every iter_accum        : optimizer step (gradients accumulate in between)
        every iter_smooth       : refresh the smoothed train_loss

    Only the mask loss is back-propagated; the label loss is computed for reporting.
    A negative learning rate returned by the scheduler ends training.
    '''

    # Starting checkpoint (point it at a '*_model.pth' file to resume a half-trained model)
    initial_checkpoint = None
    #initial_checkpoint = '/resnet34-fpn1-fold_a2/checkpoint/00024000_model.pth'

    # Hyper-parameters
    scheduler = NullScheduler(lr=0.01) # learning-rate schedule
    iter_accum = 4 # gradient accumulation steps
    batch_size = 6

    # Output folders
    for f in ['checkpoint', 'train', 'valid']:
        os.makedirs(checkpoint_dir + f, exist_ok=True)

    logger.info('\n--- [START %s] %s\n\n' % (IDENTIFIER, '-' * 64))
    logger.info('\n')
    logger.info('\tSEED         = %u\n' % SEED)
    # PROJ_FOLDER is the project root used by the rest of the notebook
    logger.info('\tPROJECT_PATH = %s\n' % PROJ_FOLDER)
    logger.info('\tout_dir      = %s\n' % checkpoint_dir)
    logger.info('\n')


    # Dataset
    logger.info('** Dataset setting **\n')
    assert(len(train_dataset) >= batch_size)
    logger.info('batch_size = %d\n'%(batch_size))
    logger.info('\n')

    # Net
    logger.info('** Net setting **\n')
    net = Net().cuda()
    logger.info('\tinitial_checkpoint = %s\n' % initial_checkpoint)

    if initial_checkpoint is not None: # resume: load the saved model parameters
        state_dict = torch.load(initial_checkpoint, map_location=lambda storage, loc: storage)
        net.load_state_dict(state_dict,strict=True)

    else: # fresh start: load the pretrained backbone
        net.load_pretrain(is_print=False)

    logger.info('net = %s\n'%(type(net)))
    logger.info('\n')

    # optimiser
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=scheduler(0),
                                momentum=0.0, weight_decay=0.0)
    #optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()),lr=scheduler(0))
    #optimizer = torch.optim.RMSprop(net.parameters(), lr =0.0005, alpha = 0.95)

    num_iters   = 3000*1000 # one epoch is n/batch_size iterations, i.e. num_iters*batch_size/n epochs in total
    iter_smooth = 50
    iter_log    = 100 # refresh the log line
    iter_valid  = 250 # refresh the validation loss and metrics
    # Save every 500 iterations plus the last one; a set keeps the per-iteration lookup O(1)
    iter_save   = set(range(0, num_iters, 500)) | {0, num_iters-1}

    start_iter  = 0
    start_epoch = 0
    rate        = 0

    if initial_checkpoint is not None:
        initial_optimizer = initial_checkpoint.replace('_model.pth','_optimizer.pth')
        if os.path.exists(initial_optimizer):
            checkpoint  = torch.load(initial_optimizer)
            start_iter  = checkpoint['iter' ]
            start_epoch = checkpoint['epoch']
            #optimizer.load_state_dict(checkpoint['optimizer'])

    logger.info('optimizer\n  %s\n'%(optimizer))
    logger.info('scheduler\n  %s\n'%(scheduler))
    logger.info('\n')

    ########################### training starts here ###########################
    logger.info('** Start training here! **\n')
    logger.info('   batch_size=%d,  iter_accum=%d\n'%(batch_size,iter_accum))
    logger.info('                    |------------------------------------------------- VALID------------------------------------------------------|---------------------- TRAIN/BATCH -----------------\n')
    logger.info('rate    iter  epoch | kaggle      | loss                tn0,1,2,3 : tp0,1,2,3                     dn0,1,2,3 : dp0,1,2,3           | loss        dn : dp0,1,2,3           | time        \n')
    logger.info('---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n')
              #0.00000  28.0* 32.6 | 0.604,0.750 | 0.85,0.29  0.73 0.92 0.83 0.68: 0.54 0.70 0.60 0.84  0.00 0.00 0.00 0.00: 0.53 0.62 0.57 0.64 | 0.00,0.00  0.00: 0.00 0.00 0.00 0.00 | 0 hr 00 min

    def message(rate, iters, epoch, kaggle, valid_loss, train_loss, batch_loss, mode='print'):
        '''
        Build one status line.
        mode='print' shows the current batch loss; mode='log' shows the smoothed
        train loss and marks iterations at which a checkpoint is saved with '*'.
        '''
        if mode==('print'):
            asterisk = ' '
            loss = batch_loss
        elif mode==('log'):
            asterisk = '*' if iters in iter_save else ' '
            loss = train_loss
        else:
            raise ValueError('unknown mode: %r'%(mode,))

        text = \
            '%0.5f %5.1f%s %4.1f | '%(rate, iters/1000, asterisk, epoch,) +\
            '%0.3f,%0.3f | '%(*kaggle,) +\
            '%4.2f,%4.2f  %0.2f %0.2f %0.2f %0.2f: %0.2f %0.2f %0.2f %0.2f  %0.2f %0.2f %0.2f %0.2f: %0.2f %0.2f %0.2f %0.2f | '%(*valid_loss,) +\
            '%4.2f,%4.2f  %0.2f: %0.2f %0.2f %0.2f %0.2f |'%(*loss,) +\
            '%s' % (time_to_str((timer() - start_timer),'min'))

        return text

    #----
    kaggle = (0, 0)
    valid_loss = np.zeros(18,np.float32)
    train_loss = np.zeros( 7,np.float32)
    batch_loss = np.zeros_like(train_loss) # 7 entries, the layout message() formats
    iters = 0
    i = 0
    stop_training = False # set when the scheduler signals the end with a negative rate

    start_timer = timer()
    while iters < num_iters and not stop_training: # keep cycling through the loader
        sum_train_loss = np.zeros_like(train_loss)
        sum_train = np.zeros_like(train_loss)

        optimizer.zero_grad()
        for input_img, truth_label, truth_mask, image_id in train_loader:
            batch_size = len(input_img)
            iters  = i + start_iter
            epoch = (iters-start_iter) * batch_size / len(train_dataset) + start_epoch

            if (iters % iter_valid == 0): # refresh valid_loss and kaggle
                valid_loss, kaggle = do_valid(net, valid_loader, checkpoint_dir) # last argument is currently unused
                pass

            if (iters % iter_log == 0): # refresh the log
                print('\r', end='', flush=True)
                logger.info(message(rate, iters, epoch, kaggle, valid_loss, train_loss, batch_loss, mode='log'))
                logger.info('\n')

            if iters in iter_save: # checkpoint
                torch.save({
                    #'optimizer': optimizer.state_dict(),
                    'iter': iters,
                    'epoch': epoch,
                }, checkpoint_dir +'checkpoint/%08d_optimizer.pth'%(iters))

                if iters != start_iter: # the model is not saved at the very first iteration
                    torch.save(net.state_dict(), checkpoint_dir +'checkpoint/%08d_model.pth'%(iters))
                    pass

            # learning rate
            lr = scheduler(iters)
            if lr < 0:
                # leave both loops; a bare break would only restart the data loader
                stop_training = True
                break
            adjust_learning_rate(optimizer, lr)
            rate = get_learning_rate(optimizer)

            # one iteration update  -------------
            #net.set_mode('train',is_freeze_bn=True)

            net.train()
            input_img = input_img.cuda()
            truth_label = truth_label.cuda()
            truth_mask = truth_mask.cuda()

            probability_label, probability_mask = data_parallel(net, input_img)
            probability_mask = resize_like(probability_mask, truth_mask, mode='bilinear')

            loss_label, loss_mask = criterion(probability_label, probability_mask, truth_label, truth_mask)

            ((loss_mask )/iter_accum).backward() # only the mask loss is back-propagated

            if (iters % iter_accum) == 0: # gradient accumulation
                optimizer.step()
                optimizer.zero_grad()

            # statistics for the progress line
            tn, tp, num_tn, num_tp, dn, dp, num_dn, num_dp = metric(probability_label, probability_mask,
                                                                    truth_label, truth_mask, False) # use_reject=False here, not the default

            l = np.array([loss_label.item()*batch_size, loss_mask.item()*batch_size, dn.sum(), *dp])
            n = np.array([batch_size, batch_size, num_dn.sum(), *num_dp ])
            batch_loss = l / (n + 1e-8)
            sum_train_loss += l
            sum_train += n

            if iters%iter_smooth == 0:
                # train_loss is the average over the batches seen since the previous refresh
                train_loss = sum_train_loss / (sum_train+EPS)
                sum_train_loss[...] = 0
                sum_train[...]      = 0

            print('\r', end='', flush=True)
            print(message(rate, iters, epoch, kaggle, valid_loss, train_loss, batch_loss, mode='print'),
                  end='',flush=True)
            i += 1 # next iteration

            # The former debug block converted tensors to images here without using the
            # results, so it has been removed.

    logger.info('\n')

In [None]:
# Start training; progress is printed in place and also written to the logger.
run_train()

#### 清理显存和内存

In [None]:
# Release cached GPU memory held by PyTorch, then run Python's garbage collector.
torch.cuda.empty_cache()
gc.collect()

# 9 预测

In [None]:
def do_evaluate_segmentation(net, test_dataset, augment=[], out_dir=None):
    """Predict labels and masks for every batch, with optional flip-based TTA.

    For each batch the network is always run on the original image. For every
    recognised name in ``augment`` it is run again on a flipped copy, the mask
    prediction is flipped back, and all predictions are averaged.

    Args:
        net: Model called through ``data_parallel``; must return
            ``(label_output, mask_output)``.
        test_dataset: Currently unused.
            NOTE(review): batches are read from the module-level ``test_loader``
            instead of this argument -- confirm this is intended.
        augment: Iterable of augmentation names. Recognised names are
            ``'flip_lr'`` (flip dim 3), ``'flip_ud'`` (flip dim 2) and
            ``'rotate'`` (flip dims 2 and 3, i.e. a 180 degree rotation).
            Any other name is ignored.
        out_dir: Currently unused.

    Returns:
        Tuple ``(test_id, truth_label, truth_mask, probability_label,
        probability_mask)``. ``test_id`` is a list of the image ids yielded by
        the loader; the other four are ``np.uint8`` arrays concatenated over
        all batches. The two probability arrays are the averaged network
        outputs multiplied by 255 before the cast (this assumes the network
        outputs lie in [0, 1] -- TODO confirm).

    Raises:
        AssertionError: If the number of processed samples differs from
            ``len(test_loader.dataset)``.
    """
    # (augmentation name, dims to flip), listed in accumulation order.
    # Dims index a (batch, channel, height, width) tensor.
    flip_augments = (
        ('flip_lr', [3]),
        ('flip_ud', [2]),
        ('rotate', [2, 3]),
    )

    test_num  = 0
    test_id   = []
    test_probability_label = [] # 8bit
    test_probability_mask  = [] # 8bit
    test_truth_label = [] # 8bit
    test_truth_mask  = [] # 8bit

    # Switching to eval mode once is enough; nothing below changes the mode.
    net.eval()

    start_timer = timer()
    for input_img, truth_label, truth_mask, image_id in test_loader:
        input_img = input_img.cuda()

        with torch.no_grad():
            # Prediction on the un-augmented image is always included.
            probability_label, probability_mask = data_parallel(net, input_img)
            probability_mask = resize_like(probability_mask, truth_mask, mode='bilinear')
            num_augment = 1

            for name, dims in flip_augments:
                if name not in augment:
                    continue
                p_label, p_mask = data_parallel(net, torch.flip(input_img, dims=dims))
                # Undo the flip on the predicted mask so it lines up with the
                # un-augmented prediction before accumulating.
                p_mask = resize_like(torch.flip(p_mask, dims=dims), truth_mask, mode='bilinear')

                probability_mask  += p_mask
                probability_label += p_label
                num_augment += 1

            # Average over all predictions that were accumulated.
            probability_mask  = probability_mask / num_augment
            probability_label = probability_label / num_augment

        batch_size = len(input_img)
        truth_label = truth_label.detach().cpu().numpy().astype(np.uint8)
        truth_mask = truth_mask.detach().cpu().numpy().astype(np.uint8)
        # Quantise to 8 bit to keep the accumulated results small.
        probability_mask = (probability_mask.detach().cpu().numpy()*255).astype(np.uint8)
        probability_label = (probability_label.detach().cpu().numpy()*255).astype(np.uint8)

        test_id.extend(image_id)
        test_truth_label.append(truth_label)
        test_truth_mask.append(truth_mask)
        test_probability_label.append(probability_label)
        test_probability_mask.append(probability_mask)
        test_num += batch_size

        print('\r %4d / %4d  %s'%(
             test_num, len(test_loader.dataset), time_to_str((timer() - start_timer), 'min')
        ), end='', flush=True)

    assert test_num == len(test_loader.dataset)
    print('')

    # Merge the per-batch arrays and report how long the merge took.
    start_timer = timer()
    test_truth_label = np.concatenate(test_truth_label)
    test_truth_mask  = np.concatenate(test_truth_mask)
    test_probability_label = np.concatenate(test_probability_label)
    test_probability_mask = np.concatenate(test_probability_mask)
    print(time_to_str((timer() - start_timer), 'sec'))

    return test_id, test_truth_label, test_truth_mask, test_probability_label, test_probability_mask

#### 加载模型

In [None]:
# Build the checkpoint path, create the network on the GPU and restore its weights.
# NOTE(review): the file name ends in '_optimizer.pth'; confirm that this file really
# holds the model's state dict, since strict=True rejects any key mismatch.
initial_checkpoint = ''.join((checkpoint_dir, 'checkpoint/', '00001000_optimizer.pth'))
net = Net().cuda()
# Deserialise the checkpoint onto the CPU; load_state_dict then copies the values
# into the network's own parameters.
net.load_state_dict(
    torch.load(initial_checkpoint, map_location='cpu'),
    strict=True,
)

#### 预测

In [None]:
# Run inference with left-right flip test-time augmentation.
# 'null' matches none of the augmentation names checked inside
# do_evaluate_segmentation; the un-flipped image is always predicted anyway.
image_id, truth_label, truth_mask, probability_label, probability_mask = do_evaluate_segmentation(
    net,
    test_dataset,
    augment=['null', 'flip_lr'],
)

In [None]:
# Threshold settings, one entry per class (four entries each).
# Cut-off applied to the per-class label probability.
threshold_label = [
    0.50,
    0.50,
    0.50,
    0.50,
]
# Presumably the cut-off for the per-pixel mask probability -- confirm against
# the cell that consumes it.
threshold_mask = [
    0.50,
    0.50,
    0.50,
    0.50,
]
# NOTE(review): not referenced by the prediction cells that follow.
threshold_size = [
    1,
    1,
    1,
    1,
]

In [None]:
# The probabilities returned by do_evaluate_segmentation are quantised to uint8
# (scaled by 255), so the [0, 1] thresholds are scaled and cast the same way
# before comparing. reshape(1, -1[, 1, 1]) lines the per-class thresholds up with
# the class axis (axis 1) for broadcasting, whatever the number of classes.
predict_label = probability_label > (np.array(threshold_label)*255).astype(np.uint8).reshape(1, -1)
predict_mask  = probability_mask > (np.array(threshold_mask)*255).astype(np.uint8).reshape(1, -1, 1, 1)

In [None]:
# mask转成rle
image_id_class_id = []
encoded_pixel = []
for b in range(len(image_id)):
    for c in range(NUM_CLASS):
        image_id_class_id.append(image_id[b]+'_%s'%(CLASSNO_TO_CLASSNAME[c]))

        if predict_label[b,c]==0:
            rle=''
        else:
            rle = run_length_encode(predict_mask[b,c])
        encoded_pixel.append(rle)

#### 写提交文件

In [None]:
# Timestamped file name. strftime yields only digits and an underscore, so the
# name is valid on every common file system (time.ctime() contains ':' and,
# for single-digit days, two consecutive spaces).
csv_file = 'sub_' + datetime.now().strftime('%Y%m%d_%H%M%S') + '.csv'
# Build the frame from named columns; `columns` pins the column order.
submit = pd.DataFrame(
    {'Image_Label': image_id_class_id, 'EncodedPixels': encoded_pixel},
    columns=['Image_Label', 'EncodedPixels'],
)
submit.to_csv(csv_file, index=False)