# 源码阅读

## 核心源码: 4大类U-net模型定义+预训练权重

In [56]:
from torch import nn
from torch.nn import functional as F
import torch
from torchvision import models
import torchvision

"""
This script has been taken (and modified) from :
https://github.com/ternaus/TernausNet

@ARTICLE{arXiv:1801.05746,
         author = {V. Iglovikov and A. Shvets},
          title = {TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation},
        journal = {ArXiv e-prints},
         eprint = {1801.05746}, 
           year = 2018
        }
"""

# --------------------------------- basic building blocks ---------------------------------
def conv3x3(in_, out):
    """Create a 3x3 2-D convolution with a padding of 1.

    :param in_: number of input channels
    :param out: number of output channels
    :return: the ``nn.Conv2d`` layer
    """
    return nn.Conv2d(in_channels=in_, out_channels=out, kernel_size=3, padding=1)


class ConvRelu(nn.Module):
    """A 3x3 convolution immediately followed by an in-place ReLU."""

    def __init__(self, in_, out):
        """
        :param in_: number of input channels
        :param out: number of output channels
        """
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the convolution, then the activation, and return the result."""
        return self.activation(self.conv(x))


class NoOperation(nn.Module):
    """Identity module: it hands its input back untouched."""

    def forward(self, x):
        """Return ``x`` itself (the same object, no copy is made)."""
        return x


class DecoderBlock(nn.Module):
    """Decoder stage: ConvRelu, a stride-2 transposed convolution, then ReLU.

    With kernel 3, stride 2, padding 1 and output_padding 1 the transposed
    convolution doubles the height and width of its input.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        """
        :param in_channels: channels of the incoming feature map
        :param middle_channels: channels produced by the ConvRelu step
        :param out_channels: channels of the upsampled output
        """
        super().__init__()

        refine = ConvRelu(in_channels, middle_channels)
        upsample = nn.ConvTranspose2d(
            middle_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.block = nn.Sequential(refine, upsample, nn.ReLU(inplace=True))

    def forward(self, x):
        """Run ``x`` through the three layers in order."""
        return self.block(x)


class UNet11(nn.Module):
    """U-Net whose encoder is the convolutional part of torchvision's VGG11."""

    def __init__(self, num_classes=1, num_filters=32, pretrained=False):
        """
        :param num_classes: number of channels produced by the final 1x1 convolution
        :param num_filters: base width of the decoder; any positive integer works
            because the skip-connection widths are read from the encoder layers
        :param pretrained:
            False - no pre-trained network is used
            True  - encoder is pre-trained with VGG11
        """
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)

        # Take the VGG11 feature extractor (optionally with pre-trained weights)
        # and expose its individual layers under the names used in forward().
        self.encoder = models.vgg11(pretrained=pretrained).features

        self.relu = self.encoder[1]
        self.conv1 = self.encoder[0]
        self.conv2 = self.encoder[3]
        self.conv3s = self.encoder[6]
        self.conv3 = self.encoder[8]
        self.conv4s = self.encoder[11]
        self.conv4 = self.encoder[13]
        self.conv5s = self.encoder[16]
        self.conv5 = self.encoder[18]

        # Channel counts of the encoder outputs that feed the decoder. They are
        # fixed by VGG11 and do not scale with num_filters, so they are read from
        # the layers instead of being written as multiples of num_filters (those
        # multiples are only correct when num_filters == 32).
        skip1 = self.conv1.out_channels
        skip2 = self.conv2.out_channels
        skip3 = self.conv3.out_channels
        skip4 = self.conv4.out_channels
        skip5 = self.conv5.out_channels

        self.center = DecoderBlock(skip5, num_filters * 8 * 2, num_filters * 8)
        self.dec5 = DecoderBlock(skip5 + num_filters * 8, num_filters * 8 * 2, num_filters * 8)
        self.dec4 = DecoderBlock(skip4 + num_filters * 8, num_filters * 8 * 2, num_filters * 4)
        self.dec3 = DecoderBlock(skip3 + num_filters * 4, num_filters * 4 * 2, num_filters * 2)
        self.dec2 = DecoderBlock(skip2 + num_filters * 2, num_filters * 2 * 2, num_filters)
        self.dec1 = ConvRelu(skip1 + num_filters, num_filters)

        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        """Encode ``x``, decode with skip connections and return the final 1x1 convolution output."""
        conv1 = self.relu(self.conv1(x))
        conv2 = self.relu(self.conv2(self.pool(conv1)))
        conv3s = self.relu(self.conv3s(self.pool(conv2)))
        conv3 = self.relu(self.conv3(conv3s))
        conv4s = self.relu(self.conv4s(self.pool(conv3)))
        conv4 = self.relu(self.conv4(conv4s))
        conv5s = self.relu(self.conv5s(self.pool(conv4)))
        conv5 = self.relu(self.conv5(conv5s))

        center = self.center(self.pool(conv5))

        # Concatenate each upsampled feature map with the encoder feature map
        # of the same resolution (skip connection) along the channel axis.
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(torch.cat([dec2, conv1], 1))
        return self.final(dec1)


def unet11(pretrained=False, **kwargs):
    """Build a UNet11 model, optionally initialised from stored weights.

    pretrained:
            False - no pre-trained network is used
            True  - encoder is pre-trained with VGG11
            carvana - all weights are pre-trained on
                Kaggle: Carvana dataset https://www.kaggle.com/c/carvana-image-masking-challenge
                (read from the file 'TernausNet.pt' in the working directory)
    kwargs: forwarded unchanged to ``UNet11``.
    """
    use_carvana = pretrained == 'carvana'

    # The strict load_state_dict below replaces every weight of the model, so
    # asking torchvision for the encoder weights first would only cost a
    # download that is thrown away. Pass a real boolean instead of the string.
    model = UNet11(pretrained=False if use_carvana else pretrained, **kwargs)

    if use_carvana:
        # map_location keeps the load working on machines without a GPU even
        # if the checkpoint was saved from CUDA tensors.
        state = torch.load('TernausNet.pt', map_location='cpu')
        model.load_state_dict(state['model'])
    return model


class DecoderBlockV2(nn.Module):
    """Decoder stage with a selectable upsampling method.

    Either a ConvRelu followed by a stride-2 transposed convolution, or a
    bilinear 2x upsampling followed by two ConvRelu layers.
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        """
        :param in_channels: channels of the incoming feature map
        :param middle_channels: channels between the two convolutional steps
        :param out_channels: channels of the output feature map
        :param is_deconv:
            True: transposed convolution is used for upsampling
            False: bilinear interpolation is used for upsampling
        """
        super().__init__()
        self.in_channels = in_channels

        if not is_deconv:
            layers = [
                nn.Upsample(scale_factor=2, mode='bilinear'),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            ]
        else:
            # Parameters for deconvolution were chosen to avoid artifacts,
            # following https://distill.pub/2016/deconv-checkerboard/
            layers = [
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels,
                                   kernel_size=4, stride=2, padding=1),
                nn.ReLU(inplace=True),
            ]

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the configured layer sequence."""
        return self.block(x)


class AlbuNet(nn.Module):
    """
        UNet (https://arxiv.org/abs/1505.04597) with Resnet34(https://arxiv.org/abs/1512.03385) encoder

        Proposed by Alexander Buslaev: https://www.linkedin.com/in/al-buslaev/

        """

    def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
        """
        :param num_classes: number of channels produced by the final 1x1 convolution
        :param num_filters: base width of the decoder
        :param pretrained:
            False - no pre-trained network is used
            True  - encoder is pre-trained with resnet34
        :is_deconv:
            False: bilinear interpolation is used in decoder
            True: deconvolution is used in decoder
        """
        super().__init__()
        self.num_classes = num_classes

        self.pool = nn.MaxPool2d(2, 2)

        self.encoder = torchvision.models.resnet34(pretrained=pretrained)

        self.relu = nn.ReLU(inplace=True)

        # Stem of the ResNet, followed by this model's own 2x2 max-pool.
        stem = [self.encoder.conv1, self.encoder.bn1, self.encoder.relu, self.pool]
        self.conv1 = nn.Sequential(*stem)

        # The four residual stages are reused as encoder levels 2..5.
        self.conv2 = self.encoder.layer1
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4

        wide = num_filters * 8
        self.center = DecoderBlockV2(512, wide * 2, wide, is_deconv)

        # Input widths are the skip-connection channels plus the previous decoder output.
        self.dec5 = DecoderBlockV2(512 + wide, wide * 2, wide, is_deconv)
        self.dec4 = DecoderBlockV2(256 + wide, wide * 2, wide, is_deconv)
        self.dec3 = DecoderBlockV2(128 + wide, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlockV2(64 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2, is_deconv)
        self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        """Encode ``x``, decode with skip connections from levels 2..5 and return the final output."""
        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        enc4 = self.conv4(enc3)
        enc5 = self.conv5(enc4)

        bottleneck = self.center(self.pool(enc5))

        up5 = self.dec5(torch.cat([bottleneck, enc5], 1))
        up4 = self.dec4(torch.cat([up5, enc4], 1))
        up3 = self.dec3(torch.cat([up4, enc3], 1))
        up2 = self.dec2(torch.cat([up3, enc2], 1))
        # The last two stages have no skip connection.
        up1 = self.dec1(up2)
        up0 = self.dec0(up1)

        return self.final(up0)


class UNetVGG16(nn.Module):
    """PyTorch U-Net model using VGG16 encoder.

    UNet: https://arxiv.org/abs/1505.04597
    VGG: https://arxiv.org/abs/1409.1556
    Proposed by Vladimir Iglovikov and Alexey Shvets: https://github.com/ternaus/TernausNet

    Args:
            num_classes (int): Number of output classes.
            num_filters (int, optional): Number of filters in the last layer of decoder. Defaults to 32.
            dropout_2d (float, optional): Probability factor of dropout layer before output layer.
                Dropout is only applied while the module is in training mode. Defaults to 0.2.
            pretrained (bool, optional):
                False - no pre-trained weights are being used.
                True  - VGG encoder is pre-trained on ImageNet.
                Defaults to False.
            is_deconv (bool, optional):
                False: bilinear interpolation is used in decoder.
                True: deconvolution is used in decoder.
                Defaults to False.

    """

    def __init__(self, num_classes=1, num_filters=32, dropout_2d=0.2, pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        self.pool = nn.MaxPool2d(2, 2)

        self.encoder = torchvision.models.vgg16(pretrained=pretrained).features

        self.relu = nn.ReLU(inplace=True)

        def _stage(*conv_indices):
            # Interleave the selected VGG16 convolutions with the shared ReLU:
            # conv, relu, conv, relu, ... (same layout as spelling it out by hand).
            layers = []
            for index in conv_indices:
                layers += [self.encoder[index], self.relu]
            return nn.Sequential(*layers)

        self.conv1 = _stage(0, 2)
        self.conv2 = _stage(5, 7)
        self.conv3 = _stage(10, 12, 14)
        self.conv4 = _stage(17, 19, 21)
        self.conv5 = _stage(24, 26, 28)

        self.center = DecoderBlockV2(512, num_filters * 8 * 2, num_filters * 8, is_deconv)

        # Input widths are the skip-connection channels plus the previous decoder output.
        self.dec5 = DecoderBlockV2(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlockV2(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlockV2(256 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlockV2(128 + num_filters * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec1 = ConvRelu(64 + num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        """Encode ``x``, decode with skip connections and return the final 1x1 convolution output."""
        conv1 = self.conv1(x)
        conv2 = self.conv2(self.pool(conv1))
        conv3 = self.conv3(self.pool(conv2))
        conv4 = self.conv4(self.pool(conv3))
        conv5 = self.conv5(self.pool(conv4))

        center = self.center(self.pool(conv5))

        dec5 = self.dec5(torch.cat([center, conv5], 1))

        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(torch.cat([dec2, conv1], 1))

        # F.dropout2d defaults to training=True; without passing self.training
        # channels would still be dropped after model.eval().
        return self.final(F.dropout2d(dec1, p=self.dropout_2d, training=self.training))


class UNetResNet(nn.Module):
    """PyTorch U-Net model using ResNet(34, 101 or 152) encoder.

    UNet: https://arxiv.org/abs/1505.04597
    ResNet: https://arxiv.org/abs/1512.03385
    Proposed by Alexander Buslaev: https://www.linkedin.com/in/al-buslaev/

    Args:
            encoder_depth (int): Depth of a ResNet encoder (34, 101 or 152).
            num_classes (int): Number of output classes.
            num_filters (int, optional): Number of filters in the last layer of decoder. Defaults to 32.
            dropout_2d (float, optional): Probability factor of dropout layer before output layer.
                Dropout is only applied while the module is in training mode. Defaults to 0.2.
            pretrained (bool, optional):
                False - no pre-trained weights are being used.
                True  - ResNet encoder is pre-trained on ImageNet.
                Defaults to False.
            is_deconv (bool, optional):
                False: bilinear interpolation is used in decoder.
                True: deconvolution is used in decoder.
                Defaults to False.

    Raises:
            NotImplementedError: If ``encoder_depth`` is not 34, 101 or 152.

    """

    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        # bottom_channel_nr is the channel count of the deepest encoder stage
        # (layer4); the shallower stages are derived from it by halving below.
        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152:
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        else:
            raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')

        self.pool = nn.MaxPool2d(2, 2)

        self.relu = nn.ReLU(inplace=True)

        # Stem of the ResNet, followed by this model's own 2x2 max-pool.
        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)

        self.conv2 = self.encoder.layer1

        self.conv3 = self.encoder.layer2

        self.conv4 = self.encoder.layer3

        self.conv5 = self.encoder.layer4

        self.center = DecoderBlockV2(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec5 = DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8,
                                   is_deconv)
        self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2,
                                   is_deconv)
        self.dec2 = DecoderBlockV2(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                   is_deconv)
        self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        """Encode ``x``, decode with skip connections from levels 2..5 and return the final output."""
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        pool = self.pool(conv5)
        center = self.center(pool)

        dec5 = self.dec5(torch.cat([center, conv5], 1))

        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        # The last two stages have no skip connection.
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)

        # F.dropout2d defaults to training=True; without passing self.training
        # channels would still be dropped after model.eval().
        return self.final(F.dropout2d(dec0, p=self.dropout_2d, training=self.training))


### Step 1：处理原始数据并保存（只运行1次）

In [19]:
# ------------------ 处理原始数据并保存 ----------------------
import os
import numpy as np
import pandas as pd
import tqdm

class ArgParser(object):
    """Plain attribute holder with the paths used by the metadata-generation step."""

    def __init__(self):
        # Input locations
        self.train_images_dir = "TrainData"
        self.test_images_dir = "TestData"
        self.depths_filepath = "depths.csv"
        # Output location of the generated metadata table
        self.metadata_filepath = "Results/metadata.csv"
        
PARAMS = ArgParser()

def generate_metadata(train_images_dir, test_images_dir, depths_filepath):
    """Collect paths, ids and depths of all train and test images into one table.

    Every file found in ``<train_images_dir>/images`` and
    ``<test_images_dir>/images`` becomes one row. The image id is the part of
    the file name before the first dot, and the depth is looked up by that id
    in the CSV at ``depths_filepath`` (columns ``id`` and ``z``).

    Files are processed in sorted name order (train first, then test) so the
    resulting table does not depend on the directory listing order.

    :param train_images_dir: folder containing ``images`` and ``masks`` sub-folders
    :param test_images_dir: folder containing an ``images`` sub-folder
    :param depths_filepath: path of the depths CSV
    :return: ``pd.DataFrame`` with columns ``file_path_image``,
        ``file_path_mask`` (``None`` for test images), ``is_train`` (1/0),
        ``id`` and ``z``; the columns are present even when no file was found
    :raises IndexError: if an image id has no row in the depths CSV (the
        exception type matches the former per-file lookup)
    """
    depths = pd.read_csv(depths_filepath)
    # Build the id -> depth lookup once instead of scanning the whole table
    # for every file. For a duplicated id the first row wins, which is the
    # row a boolean-mask lookup followed by ``.values[0]`` returns.
    depth_by_id = depths.drop_duplicates(subset='id', keep='first').set_index('id')['z'].to_dict()

    columns = ('file_path_image', 'file_path_mask', 'is_train', 'id', 'z')
    metadata = {column: [] for column in columns}

    def _add_images(images_root, masks_root, is_train):
        # Append one row per file in images_root; masks_root is None for test data.
        for filename in sorted(os.listdir(images_root)):
            image_id = filename.split('.')[0]
            if image_id not in depth_by_id:
                raise IndexError(
                    "no depth entry for image id %r in %s" % (image_id, depths_filepath))

            mask_filepath = None if masks_root is None else os.path.join(masks_root, filename)
            metadata['file_path_image'].append(os.path.join(images_root, filename))
            metadata['file_path_mask'].append(mask_filepath)
            metadata['is_train'].append(is_train)
            metadata['id'].append(image_id)
            metadata['z'].append(depth_by_id[image_id])

    train_images_root = os.path.join(train_images_dir, 'images')
    print(train_images_root)

    _add_images(train_images_root, os.path.join(train_images_dir, 'masks'), 1)
    _add_images(os.path.join(test_images_dir, 'images'), None, 0)

    print("here .... ")
    # The caller is responsible for writing the table to disk.
    return pd.DataFrame(metadata)

def prepare_metadata():
    """Generate the metadata table from the paths in PARAMS and save it as CSV.

    The table is written to ``PARAMS.metadata_filepath`` without the index
    column. The target folder is created first if it does not exist.
    """
    meta = generate_metadata(train_images_dir=PARAMS.train_images_dir,   # folder with the training data
                             test_images_dir=PARAMS.test_images_dir,     # folder with the test data
                             depths_filepath=PARAMS.depths_filepath      # CSV with the depth values
                             )

    # to_csv does not create missing folders, so make sure the parent exists.
    output_dir = os.path.dirname(PARAMS.metadata_filepath)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    meta.to_csv(PARAMS.metadata_filepath, index=False)

# 只运行一次就ok
# prepare_metadata()

### Step 2：训练数据阶段

In [55]:
from attrdict import AttrDict
from src.augmentation import intensity_seq, affine_seq, pad_to_fit_net, crop_seq, test_time_augmentation_transform, \
    test_time_augmentation_inverse_transform

class ArgParser(object):
    """Plain attribute holder with every setting of the training notebook cell."""

    def __init__(self):
        # ---- Pre-processing: data locations ----
        self.train_images_dir = "trainData"
        self.test_images_dir = "testData"
        self.depths_filepath = "depths.csv"
        self.metadata_filepath = "Results/metadata.csv"

        # ---- Training: experiment bookkeeping ----
        self.experiment_dir = "Results"
        self.clone_experiment_dir_from = ''
        self.overwrite = 1
        self.SEED = 1234
        self.shuffle = 1
        self.n_cv_splits = 10  # 10-fold cross-validation
        self.dev_mode_size = 20
        self.num_workers = 2
        self.num_threads = 4

        # ---- General parameters: input geometry ----
        self.image_h = 128
        self.image_w = 128
        self.image_channels = 3

        # ---- General parameters: optimisation ----
        self.epochs_nr = 10000
        self.batch_size_train = 64
        self.batch_size_inference = 64
        self.lr = 0.0001
        self.momentum = 0.9
        self.gamma = 0.95
        self.patience = 20
        self.validation_metric_name = 'iout'
        self.minimize_validation_metric = 0

        # ---- General parameters: data loading ----
        self.loader_mode = "resize"
        self.pad_method = "symmetric"

        self.image_source = "disk"
        self.target_format = 'png'
        self.pin_memory = 1

        # ---- U-Net parameters ----
        self.unet_output_channels = 2  # the network really does output 2 channels
        self.unet_activation = 'sigmoid'
        self.encoder = "ResNet152"

        # ---- U-Net from scratch parameters ----
        self.nr_unet_outputs = 1
        self.n_filters = 16
        self.conv_kernel = 3
        self.pool_kernel = 3
        self.pool_stride = 2
        self.repeat_blocks = 4

        # ---- Loss ----
        self.dice_weight = 5.0
        self.bce_weight = 1.0

        # ---- Regularization ----
        self.use_batch_norm = 1
        self.l2_reg_conv = 0.0001
        self.l2_reg_dense = 0.0
        self.dropout_conv = 0.1
        self.dropout_dense = 0.0

        # ---- Postprocessing ----
        self.threshold_masks = 0.5
        self.tta_aggregation_method = "mean"

        
PARAMS = ArgParser()

# Per-channel statistics passed to the loaders under the keys 'MEAN' and 'STD'
# (presumably used for input normalisation -- confirm in the loader code).
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

SEED = 1234

# Names of the metadata columns used by the pipeline
ID_COLUMNS = ['id']
X_COLUMNS = ['file_path_image']
Y_COLUMNS = ['file_path_mask']
DEPTH_COLUMN = ['z']
ORIGINAL_SIZE = (101, 101)

# Execution-wide settings, taken from PARAMS except for the fixed class count
GLOBAL_CONFIG = {
    'exp_root': PARAMS.experiment_dir,
    'num_workers': PARAMS.num_workers,
    'num_classes': 2,
    'img_H-W': (PARAMS.image_h, PARAMS.image_w),
    'batch_size_train': PARAMS.batch_size_train,
    'batch_size_inference': PARAMS.batch_size_inference,
    'loader_mode': PARAMS.loader_mode,
}

# Settings stored under the model's 'training_config' key
TRAINING_CONFIG = {
    'epochs': PARAMS.epochs_nr,
    'shuffle': True,
    'batch_size': PARAMS.batch_size_train,
}

def _dataset_params():
    """Return a fresh dict with the dataset settings shared by every loader mode."""
    return {
        'h': PARAMS.image_h,
        'w': PARAMS.image_w,
        'pad_method': PARAMS.pad_method,
        'image_source': PARAMS.image_source,
        'divisor': 64,
        'target_format': PARAMS.target_format,
        'MEAN': MEAN,
        'STD': STD,
    }


def _loader_params():
    """Return a fresh dict with the training / inference data-loader settings."""
    return {
        'training': {
            'batch_size': PARAMS.batch_size_train,
            'shuffle': True,
            'num_workers': PARAMS.num_workers,
            'pin_memory': PARAMS.pin_memory,
        },
        'inference': {
            'batch_size': PARAMS.batch_size_inference,
            'shuffle': False,
            'num_workers': PARAMS.num_workers,
            'pin_memory': PARAMS.pin_memory,
        },
    }


def _pad_augmenter():
    """Return a new ``pad_to_fit_net`` augmenter for divisor 64 and the configured pad method."""
    return pad_to_fit_net(64, PARAMS.pad_method)


# Complete pipeline configuration, grouped by pipeline step.
SOLUTION_CONFIG = AttrDict({
    'env': {'experiment_dir': PARAMS.experiment_dir},
    'execution': GLOBAL_CONFIG,
    'xy_splitter': {
        'unet': {
            'x_columns': X_COLUMNS,
            'y_columns': Y_COLUMNS,
        },
    },
    'reader': {
        'unet': {
            'x_columns': X_COLUMNS,
            'y_columns': Y_COLUMNS,
        },
    },
    # One entry per loader mode; the *_tta variants add the test-time-augmentation transform.
    'loaders': {
        'crop_and_pad': {
            'dataset_params': _dataset_params(),
            'loader_params': _loader_params(),
            'augmentation_params': {
                'image_augment_train': intensity_seq,
                'image_augment_with_target_train': crop_seq(
                    crop_size=(PARAMS.image_h, PARAMS.image_w)),
                'image_augment_inference': _pad_augmenter(),
                'image_augment_with_target_inference': _pad_augmenter(),
            },
        },
        'crop_and_pad_tta': {
            'dataset_params': _dataset_params(),
            'loader_params': _loader_params(),
            'augmentation_params': {
                'image_augment_inference': _pad_augmenter(),
                'image_augment_with_target_inference': _pad_augmenter(),
                'tta_transform': test_time_augmentation_transform,
            },
        },
        'resize': {
            'dataset_params': _dataset_params(),
            'loader_params': _loader_params(),
            'augmentation_params': {
                'image_augment_train': intensity_seq,
                'image_augment_with_target_train': affine_seq,
            },
        },
        'resize_tta': {
            'dataset_params': _dataset_params(),
            'loader_params': _loader_params(),
            'augmentation_params': {
                'tta_transform': test_time_augmentation_transform,
            },
        },
    },
    'model': {
        'unet': {
            'architecture_config': {
                'model_params': {
                    'n_filters': PARAMS.n_filters,
                    'conv_kernel': PARAMS.conv_kernel,
                    'pool_kernel': PARAMS.pool_kernel,
                    'pool_stride': PARAMS.pool_stride,
                    'repeat_blocks': PARAMS.repeat_blocks,
                    'batch_norm': PARAMS.use_batch_norm,
                    'dropout': PARAMS.dropout_conv,
                    'in_channels': PARAMS.image_channels,
                    'out_channels': PARAMS.unet_output_channels,
                    'nr_outputs': PARAMS.nr_unet_outputs,
                    'encoder': PARAMS.encoder,
                    'activation': PARAMS.unet_activation,
                    'dice_weight': PARAMS.dice_weight,
                    'bce_weight': PARAMS.bce_weight,
                },
                'optimizer_params': {
                    'lr': PARAMS.lr,
                },
                'regularizer_params': {
                    'regularize': True,
                    'weight_decay_conv2d': PARAMS.l2_reg_conv,
                },
                'weights_init': {
                    'function': 'xavier',
                },
            },
            'training_config': TRAINING_CONFIG,
            'callbacks_config': {
                'model_checkpoint': {
                    'filepath': os.path.join(GLOBAL_CONFIG['exp_root'], 'checkpoints', 'unet', 'best.torch'),
                    'epoch_every': 1,
                    'metric_name': PARAMS.validation_metric_name,
                    'minimize': PARAMS.minimize_validation_metric,
                },
                'lr_scheduler': {
                    'gamma': PARAMS.gamma,
                    'epoch_every': 1,
                },
                'training_monitor': {
                    'batch_every': 0,
                    'epoch_every': 1,
                },
                'experiment_timing': {
                    'batch_every': 0,
                    'epoch_every': 1,
                },
                'validation_monitor': {
                    'epoch_every': 1,
                    'data_dir': PARAMS.train_images_dir,
                    'loader_mode': PARAMS.loader_mode,
                },
                'neptune_monitor': {
                    'model_name': 'unet',
                    'image_nr': 4,
                    'image_resize': 0.2,
                },
                'early_stopping': {
                    'patience': PARAMS.patience,
                    'metric_name': PARAMS.validation_metric_name,
                    'minimize': PARAMS.minimize_validation_metric,
                },
            },
        },
    },
    'tta_generator': {
        'flip_ud': False,
        'flip_lr': True,
        'rotation': False,
        'color_shift_runs': 4,
    },
    'tta_aggregator': {
        'tta_inverse_transform': test_time_augmentation_inverse_transform,
        'method': PARAMS.tta_aggregation_method,
        'nthreads': PARAMS.num_threads,
    },
    'thresholder': {
        'threshold_masks': PARAMS.threshold_masks,
    },
})

print("Done!")
type(SOLUTION_CONFIG)

Done!


attrdict.dictionary.AttrDict

In [None]:
def _resnet_entry(depth, dropout_2d):
    """Return the registry entry for a UNetResNet with the given encoder depth and dropout."""
    return {
        'model': UNetResNet,
        'model_config': {
            'encoder_depth': depth,
            'num_filters': 32,
            'dropout_2d': dropout_2d,
            'pretrained': True,
            'is_deconv': True,
        },
        'init_weights': False,
    }


# Registry of the available encoders: model class, constructor keyword
# arguments and an 'init_weights' flag (False for every entry).
PRETRAINED_NETWORKS = {
    'VGG11': {
        'model': UNet11,
        'model_config': {'pretrained': True},
        'init_weights': False,
    },
    'VGG16': {
        'model': UNetVGG16,
        'model_config': {'pretrained': True, 'dropout_2d': 0.0, 'is_deconv': True},
        'init_weights': False,
    },
    'AlbuNet': {
        'model': AlbuNet,
        'model_config': {'pretrained': True, 'is_deconv': True},
        'init_weights': False,
    },
    'ResNet34': _resnet_entry(34, 0.0),
    'ResNet101': _resnet_entry(101, 0.0),
    # ResNet152 is the only entry configured with a non-zero dropout.
    'ResNet152': _resnet_entry(152, 0.2),
}

In [4]:
#####################################
#    src/pipeline_manager.py        #
#####################################

import os
import shutil

import numpy as np
import pandas as pd
from sklearn.externals import joblib

from .metrics import intersection_over_union, intersection_over_union_thresholds
from . import pipeline_config as cfg

from .utils import NeptuneContext, init_logger, read_masks, create_submission, \
    generate_metadata, set_seed, KFoldBySortedValue, clean_memory

class ArgParser(object):
    """Hard-coded replacement for externally supplied run parameters.

    Every setting is a plain instance attribute, so the rest of the module
    reads it as ``PARAMS.<name>``.
    """

    def __init__(self):
        # Pre-processing: raw data locations and the generated metadata table.
        self.train_images_dir = "trainData"
        self.test_images_dir = "testData"
        self.depths_filepath = "depths.csv"
        self.metadata_filepath = "Results/metadata.csv"

        # Training: experiment output directory and run switches.
        self.experiment_dir = "Results"
        self.clone_experiment_dir_from = ''
        self.overwrite = 1
        self.SEED = 1234
        self.shuffle = 1
        self.n_cv_splits = 10  # 10-fold cross-validation
        self.dev_mode_size = 20
        
        
# Module-level singletons used by every entry point defined below.
LOGGER = init_logger()
CTX = NeptuneContext()
PARAMS = ArgParser()
# NOTE(review): seeding uses cfg.SEED, while train() passes PARAMS.SEED to its
# splitter and sampler — confirm the two values are meant to agree.
set_seed(cfg.SEED)


class PipelineManager:
    """Thin facade that forwards each call to the module-level function of the same name."""

    def prepare_metadata(self):
        # NOTE(review): no module-level prepare_metadata() is visible in the
        # pipeline_manager section shown here — confirm where it is defined.
        prepare_metadata()

    def train(self, pipeline_name, dev_mode):
        train(pipeline_name, dev_mode)

    def evaluate(self, pipeline_name, dev_mode):
        evaluate(pipeline_name, dev_mode)

    def predict(self, pipeline_name, dev_mode, submit_predictions):
        # NOTE(review): arguments are forwarded positionally, but the
        # module-level predict() is declared as
        # (pipeline_name, submit_predictions, dev_mode). The 2nd and 3rd
        # parameter names are therefore swapped relative to it; keyword callers
        # of this method would get the flags exchanged — verify against callers.
        predict(pipeline_name, dev_mode, submit_predictions)

    def train_evaluate_cv(self, pipeline_name, dev_mode):
        train_evaluate_cv(pipeline_name, dev_mode)


def train(pipeline_name, dev_mode):
    """Fit the pipeline registered under ``pipeline_name`` on the first CV fold.

    :param pipeline_name: key into the local ``PIPELINES`` mapping
        ('unet' or 'unet_tta').
    :param dev_mode: when truthy, train on ``PARAMS.dev_mode_size`` sampled rows
        and validate on half as many.
    """
    # LOGGER.info('training')
    print("training ...")

    clone_source = PARAMS.clone_experiment_dir_from
    if clone_source != '':
        # Start from a copy of an earlier experiment: wipe the target, then copy.
        if os.path.exists(PARAMS.experiment_dir):
            shutil.rmtree(PARAMS.experiment_dir)
        shutil.copytree(clone_source, PARAMS.experiment_dir)

    # Overwrite mode: discard whatever the experiment directory holds.
    # NOTE(review): with the ArgParser defaults the metadata CSV lives inside
    # experiment_dir, so this removal also deletes the file read just below —
    # confirm this is intended.
    if bool(PARAMS.overwrite) and os.path.isdir(PARAMS.experiment_dir):
        shutil.rmtree(PARAMS.experiment_dir)

    # Load the metadata table and keep only the training rows.
    meta = pd.read_csv(PARAMS.metadata_filepath)
    meta_train = meta[meta['is_train'] == 1]

    cv = KFoldBySortedValue(n_splits=PARAMS.n_cv_splits,
                            shuffle=PARAMS.shuffle,
                            random_state=PARAMS.SEED)
    depth_values = meta_train[['z']].values.reshape(-1)
    for train_idx, valid_idx in cv.split(depth_values):
        break  # only the first fold is used

    meta_train_split = meta_train.iloc[train_idx]
    meta_valid_split = meta_train.iloc[valid_idx]

    if dev_mode:
        valid_sample_size = int(PARAMS.dev_mode_size / 2)
        meta_train_split = meta_train_split.sample(PARAMS.dev_mode_size, random_state=PARAMS.SEED)
        meta_valid_split = meta_valid_split.sample(valid_sample_size, random_state=PARAMS.SEED)

    data = {
        'input': {'meta': meta_train_split},
        'callback_input': {'meta_valid': meta_valid_split},
    }

    # Choose which pipeline factory to build.
    PIPELINES = {'unet': unet, 'unet_tta': unet_tta}
    build_pipeline = PIPELINES[pipeline_name]

    pipeline = build_pipeline(config=cfg.SOLUTION_CONFIG, train_mode=True)
    pipeline.clean_cache()
    pipeline.fit_transform(data)
    pipeline.clean_cache()


def evaluate(pipeline_name, dev_mode):
    """Score a trained pipeline on the validation part of the first CV fold.

    Computes the IOU and IOUT metrics, sends them to the Neptune context and
    dumps ``(meta_valid_split, y_true, y_pred)`` to ``validation_results.pkl``
    inside the experiment directory.

    :param pipeline_name: key into ``PIPELINES``.
    :param dev_mode: when truthy, evaluate on ``PARAMS.dev_mode_size`` sampled
        validation rows only.
    """
    LOGGER.info('evaluating')

    if PARAMS.clone_experiment_dir_from != '':
        if os.path.exists(PARAMS.experiment_dir):
            shutil.rmtree(PARAMS.experiment_dir)
        shutil.copytree(PARAMS.clone_experiment_dir_from, PARAMS.experiment_dir)

    meta = pd.read_csv(PARAMS.metadata_filepath)
    meta_train = meta[meta['is_train'] == 1]

    cv = KFoldBySortedValue(n_splits=PARAMS.n_cv_splits, shuffle=PARAMS.shuffle, random_state=cfg.SEED)
    for train_idx, valid_idx in cv.split(meta_train[cfg.DEPTH_COLUMN].values.reshape(-1)):
        break  # only the first fold is evaluated

    meta_valid_split = meta_train.iloc[valid_idx]

    if dev_mode:
        meta_valid_split = meta_valid_split.sample(PARAMS.dev_mode_size, random_state=cfg.SEED)

    # Read the ground truth only after the optional dev-mode subsample, so that
    # y_true covers exactly the rows (and row order) the pipeline predicts on.
    y_true = read_masks(meta_valid_split[cfg.Y_COLUMNS[0]].values)

    data = {'input': {'meta': meta_valid_split},
            'callback_input': {'meta_valid': None}
            }

    # NOTE(review): PIPELINES is read as a module-level name here, but the only
    # visible assignment is local to train() — confirm it exists at module scope.
    pipeline = PIPELINES[pipeline_name](config=cfg.SOLUTION_CONFIG, train_mode=False)
    pipeline.clean_cache()
    output = pipeline.transform(data)
    pipeline.clean_cache()
    y_pred = output['y_pred']

    LOGGER.info('Calculating IOU and IOUT Scores')
    iou_score = intersection_over_union(y_true, y_pred)
    LOGGER.info('IOU score on validation is {}'.format(iou_score))
    CTX.channel_send('IOU', 0, iou_score)

    iout_score = intersection_over_union_thresholds(y_true, y_pred)
    LOGGER.info('IOUT score on validation is {}'.format(iout_score))
    CTX.channel_send('IOUT', 0, iout_score)

    results_filepath = os.path.join(PARAMS.experiment_dir, 'validation_results.pkl')
    LOGGER.info('Saving validation results to {}'.format(results_filepath))
    joblib.dump((meta_valid_split, y_true, y_pred), results_filepath)


def make_submission(submission_filepath):
    """Upload ``submission_filepath`` to the TGS salt competition via the kaggle CLI.

    The command is passed as an argument list, so a path or message containing
    spaces or shell metacharacters reaches the CLI as a single argument instead
    of being split or interpreted by a shell. A failure to launch the CLI or a
    non-zero exit status is logged as an error; nothing is raised.

    :param submission_filepath: path of the CSV file to submit.
    """
    import subprocess  # local import keeps this block self-contained

    LOGGER.info('Making Kaggle submit...')
    # NOTE(review): PARAMS.kaggle_message is not set by the ArgParser visible in
    # this file — confirm it is provided before this function is called.
    command = ['kaggle', 'competitions', 'submit',
               '-c', 'tgs-salt-identification-challenge',
               '-f', str(submission_filepath),
               '-m', str(PARAMS.kaggle_message)]
    try:
        exit_code = subprocess.call(command)
    except OSError as error:
        # e.g. the kaggle executable is not installed / not on PATH
        LOGGER.error('Kaggle submit could not be started: {}'.format(error))
        return

    if exit_code == 0:
        LOGGER.info('Kaggle submit completed')
    else:
        LOGGER.error('Kaggle submit failed with exit code {}'.format(exit_code))


def predict(pipeline_name, submit_predictions, dev_mode):
    """Run inference on the test rows and write ``submission.csv``.

    :param pipeline_name: key into ``PIPELINES``.
    :param submit_predictions: when truthy, hand the written file to
        ``make_submission``.
    :param dev_mode: when truthy, predict on ``PARAMS.dev_mode_size`` sampled
        test rows only.
    """
    LOGGER.info('predicting')

    clone_source = PARAMS.clone_experiment_dir_from
    if clone_source != '':
        # Replace the experiment directory with a copy of the cloned one.
        if os.path.exists(PARAMS.experiment_dir):
            shutil.rmtree(PARAMS.experiment_dir)
        shutil.copytree(clone_source, PARAMS.experiment_dir)

    meta = pd.read_csv(PARAMS.metadata_filepath)
    meta_test = meta[meta['is_train'] == 0]
    if dev_mode:
        meta_test = meta_test.sample(PARAMS.dev_mode_size, random_state=cfg.SEED)

    data = {
        'input': {'meta': meta_test},
        'callback_input': {'meta_valid': None},
    }

    build_pipeline = PIPELINES[pipeline_name]
    pipeline = build_pipeline(config=cfg.SOLUTION_CONFIG, train_mode=False)
    pipeline.clean_cache()
    output = pipeline.transform(data)
    pipeline.clean_cache()

    submission = create_submission(meta_test, output['y_pred'])
    submission_filepath = os.path.join(PARAMS.experiment_dir, 'submission.csv')
    submission.to_csv(submission_filepath, index=None, encoding='utf-8')

    LOGGER.info('submission saved to {}'.format(submission_filepath))
    LOGGER.info('submission head \n\n{}'.format(submission.head()))

    if submit_predictions:
        make_submission(submission_filepath)


def train_evaluate_cv(pipeline_name, dev_mode):
    """Train and evaluate one pipeline per cross-validation fold.

    Per-fold IOU / IOUT values, followed by their mean and standard deviation
    over all folds, are logged and sent to the Neptune context.

    :param pipeline_name: key into ``PIPELINES``.
    :param dev_mode: when truthy, only the first ``PARAMS.dev_mode_size`` rows
        of the metadata file are read.
    """
    LOGGER.info('training')

    clone_source = PARAMS.clone_experiment_dir_from
    if clone_source != '':
        if os.path.exists(PARAMS.experiment_dir):
            shutil.rmtree(PARAMS.experiment_dir)
        shutil.copytree(clone_source, PARAMS.experiment_dir)

    if bool(PARAMS.overwrite) and os.path.isdir(PARAMS.experiment_dir):
        shutil.rmtree(PARAMS.experiment_dir)

    # In dev mode restrict the read itself to the first dev_mode_size rows.
    read_options = {'nrows': PARAMS.dev_mode_size} if dev_mode else {}
    meta = pd.read_csv(PARAMS.metadata_filepath, **read_options)
    meta_train = meta[meta['is_train'] == 1]

    cv = KFoldBySortedValue(n_splits=PARAMS.n_cv_splits, shuffle=PARAMS.shuffle, random_state=cfg.SEED)
    depth_values = meta_train[cfg.DEPTH_COLUMN].values.reshape(-1)

    fold_iou = []
    fold_iout = []
    for fold_id, (train_idx, valid_idx) in enumerate(cv.split(depth_values)):
        train_data_split = meta_train.iloc[train_idx]
        valid_data_split = meta_train.iloc[valid_idx]

        LOGGER.info('Started fold {}'.format(fold_id))
        iou, iout, _y_pred, _pipeline = _fold_fit_evaluate_loop(train_data_split,
                                                                valid_data_split,
                                                                fold_id,
                                                                pipeline_name)

        LOGGER.info('Fold {} IOU {}'.format(fold_id, iou))
        CTX.channel_send('Fold {} IOU'.format(fold_id), 0, iou)
        LOGGER.info('Fold {} IOUT {}'.format(fold_id, iout))
        CTX.channel_send('Fold {} IOUT'.format(fold_id), 0, iout)

        fold_iou.append(iou)
        fold_iout.append(iout)

    # Aggregate statistics over all folds.
    iou_mean = np.mean(fold_iou)
    iou_std = np.std(fold_iou)
    iout_mean = np.mean(fold_iout)
    iout_std = np.std(fold_iout)

    LOGGER.info('IOU mean {}, IOU std {}'.format(iou_mean, iou_std))
    CTX.channel_send('IOU', 0, iou_mean)
    CTX.channel_send('IOU STD', 0, iou_std)

    LOGGER.info('IOUT mean {}, IOUT std {}'.format(iout_mean, iout_std))
    CTX.channel_send('IOUT', 0, iout_mean)
    CTX.channel_send('IOUT STD', 0, iout_std)


def _fold_fit_evaluate_loop(train_data_split, valid_data_split, fold_id, pipeline_name):
    """Fit a pipeline on one fold's training rows and score it on its validation rows.

    :param train_data_split: metadata rows used for fitting.
    :param valid_data_split: metadata rows used as validation callback input
        during fitting and as input for the final scoring pass.
    :param fold_id: fold number; used in the pipeline suffix and monitor name.
    :param pipeline_name: key into ``PIPELINES`` and into ``config['model']``.
    :return: ``(iou, iout, y_pred, pipeline)`` where ``pipeline`` is the
        inference-mode pipeline.
    """
    fold_suffix = '_fold_{}'.format(fold_id)
    build_pipeline = PIPELINES[pipeline_name]

    # NOTE: cfg.SOLUTION_CONFIG is modified in place (no copy is taken), so the
    # monitor name set here stays visible through cfg after this call.
    config = cfg.SOLUTION_CONFIG
    monitor_config = config['model'][pipeline_name]['callbacks_config']['neptune_monitor']
    monitor_config['model_name'] = '{}_{}'.format(pipeline_name, fold_id)

    # --- fit on the training split ---
    train_pipe_input = {
        'input': {'meta': train_data_split},
        'callback_input': {'meta_valid': valid_data_split},
    }
    pipeline = build_pipeline(config=config, train_mode=True, suffix=fold_suffix)
    LOGGER.info('Start pipeline fit and transform on train')
    pipeline.clean_cache()
    pipeline.fit_transform(train_pipe_input)
    pipeline.clean_cache()

    # --- predict on the validation split ---
    valid_pipe_input = {
        'input': {'meta': valid_data_split},
        'callback_input': {'meta_valid': None},
    }
    pipeline = build_pipeline(config=config, train_mode=False, suffix=fold_suffix)
    LOGGER.info('Start pipeline transform on valid')
    pipeline.clean_cache()
    output_valid = pipeline.transform(valid_pipe_input)
    pipeline.clean_cache()

    y_pred = output_valid['y_pred']
    y_true = read_masks(valid_data_split[cfg.Y_COLUMNS[0]].values)

    iou = intersection_over_union(y_true, y_pred)
    iout = intersection_over_union_thresholds(y_true, y_pred)
    return iou, iout, y_pred, pipeline


ModuleNotFoundError: No module named '__main__.pipelines'; '__main__' is not a package

In [None]:
#####################################
#            src/utils.py           #
#####################################

import logging
import os
import pathlib
import random
import sys
import time
from itertools import chain
from collections import Iterable

from deepsense import neptune
import numpy as np
import pandas as pd
import torch
from PIL import Image
import matplotlib.pyplot as plt
from attrdict import AttrDict
from tqdm import tqdm
from pycocotools import mask as cocomask
from sklearn.model_selection import BaseCrossValidator
from steppy.base import BaseTransformer
import yaml
from imgaug import augmenters as iaa
import imgaug as ia

# Path to configs/neptune.yaml, located under the parent of the directory that
# contains this file; used as NeptuneContext's default fallback parameter file.
NEPTUNE_CONFIG_PATH = str(pathlib.Path(__file__).resolve().parents[1] / 'configs' / 'neptune.yaml')


# Alex Martelli's 'Borg'
# http://python-3-patterns-idioms-test.readthedocs.io/en/latest/Singleton.html
class _Borg:
    _shared_state = {}

    def __init__(self):
        self.__dict__ = self._shared_state


class NeptuneContext(_Borg):
    """Shared-state wrapper around ``neptune.Context``.

    When the Neptune context reports offline parameters, the parameters are
    read from a YAML fallback file instead.
    """

    def __init__(self, fallback_file=NEPTUNE_CONFIG_PATH):
        """
        :param fallback_file: YAML file whose ``parameters`` entry is used when
            the Neptune context is offline.
        """
        _Borg.__init__(self)

        self.ctx = neptune.Context()
        self.fallback_file = fallback_file
        self.params = self._read_params()
        self.numeric_channel = neptune.ChannelType.NUMERIC
        self.image_channel = neptune.ChannelType.IMAGE
        self.text_channel = neptune.ChannelType.TEXT

    def channel_send(self, *args, **kwargs):
        """Forward all arguments to ``self.ctx.channel_send``."""
        self.ctx.channel_send(*args, **kwargs)

    def _read_params(self):
        """Return the context parameters, or the YAML fallback when offline."""
        # The offline case is detected by the class name of ctx.params.
        if self.ctx.params.__class__.__name__ == 'OfflineContextParams':
            params = self._read_yaml().parameters
        else:
            params = self.ctx.params
        return params

    def _read_yaml(self):
        """Load the fallback file and wrap it in an ``AttrDict``."""
        with open(self.fallback_file) as f:
            # safe_load: yaml.load() without an explicit Loader can construct
            # arbitrary Python objects and is rejected by current PyYAML.
            config = yaml.safe_load(f)
        return AttrDict(config)


def init_logger():
    """Return the 'salt-detection' logger, configured to print INFO to stdout.

    The stdout handler is attached only if the logger has no handler yet, so
    calling this more than once (e.g. re-running a notebook cell or importing
    from several modules) does not duplicate every log line.

    :return: the configured ``logging.Logger``.
    """
    logger = logging.getLogger('salt-detection')
    logger.setLevel(logging.INFO)

    if logger.handlers:
        # Already configured by an earlier call.
        return logger

    message_format = logging.Formatter(fmt='%(asctime)s %(name)s >>> %(message)s',
                                       datefmt='%Y-%m-%d %H-%M-%S')

    # console handler for validation info
    ch_va = logging.StreamHandler(sys.stdout)
    ch_va.setLevel(logging.INFO)
    ch_va.setFormatter(fmt=message_format)

    # add the handler to the logger
    logger.addHandler(ch_va)

    return logger


def get_logger():
    """Return the 'salt-detection' logger without touching its configuration."""
    logger_name = 'salt-detection'
    return logging.getLogger(logger_name)


def decompose(labeled):
    """Split a labelled array into one mask per label value.

    For every label ``i`` in ``1..labeled.max()`` a copy of ``labeled`` is made
    in which entries equal to ``i`` become 255 and all others 0.

    :param labeled: integer-labelled NumPy array.
    :return: list of per-label masks; ``[labeled]`` itself when there is no
        label >= 1.
    """
    def _mask_for(label):
        msk = labeled.copy()
        msk[msk != label] = 0.
        msk[msk == label] = 255.
        return msk

    highest_label = labeled.max()
    masks = [_mask_for(label) for label in range(1, highest_label + 1)]
    return masks if masks else [labeled]


def create_submission(meta, predictions):
    """Build the submission table from image ids and predicted masks.

    :param meta: table whose ``'id'`` column supplies the image ids.
    :param predictions: iterable of masks, paired with the ids in order.
    :return: string-typed DataFrame with columns ``id`` and ``rle_mask``, where
        ``rle_mask`` is the space-joined run-length encoding of the mask.
    """
    rows = []
    for image_id, mask in zip(meta['id'].values, predictions):
        runs = run_length_encoding(mask)
        rows.append([image_id, ' '.join(map(str, runs))])

    return pd.DataFrame(rows, columns=['id', 'rle_mask']).astype(str)


def encode_rle(predictions):
    return [run_length_encoding(mask) for mask in predictions]


def read_masks(masks_filepaths):
    """Load mask images as binary ``uint8`` arrays, showing a progress bar.

    Each image is converted to 8-bit greyscale; pixels below 128 become 0 and
    all others 1.

    :param masks_filepaths: iterable of image file paths.
    :return: list of NumPy ``uint8`` arrays, one per path.
    """
    def _binarize(pixel):
        return 0 if pixel < 128 else 1

    masks = []
    for mask_filepath in tqdm(masks_filepaths):
        greyscale = Image.open(mask_filepath).convert('L')
        binary = np.asarray(greyscale.point(_binarize)).astype(np.uint8)
        masks.append(binary)
    return masks


def read_images(filepaths):
    """Open every path in ``filepaths`` with PIL and return the images as NumPy arrays."""
    return [np.array(Image.open(filepath)) for filepath in filepaths]


def run_length_encoding(x):
    """Run-length encode the non-zero pixels of ``x`` in column-major order.

    Adapted from
    https://www.kaggle.com/c/data-science-bowl-2018/discussion/48561#

    :param x: 2-D array; non-zero entries are foreground.
    :return: flat list ``[start_1, length_1, start_2, length_2, ...]`` with
        1-based start positions into the transposed, flattened array.
    """
    foreground = np.where(x.T.flatten())[0]

    rle = []
    last_seen = -2  # sentinel: never adjacent to position 0
    for position in foreground:
        starts_new_run = position > last_seen + 1
        if starts_new_run:
            rle.extend((position + 1, 0))
        rle[-1] += 1
        last_seen = position

    # When start + length of the final run equals the pixel count, the start of
    # that run is moved back by one.
    # NOTE(review): the reason for this adjustment is not evident from the code.
    if rle and rle[-1] + rle[-2] == x.size:
        rle[-2] = rle[-2] - 1

    return rle


def run_length_decoding(mask_rle, shape):
    """Decode a run-length string into a binary mask.

    Based on https://www.kaggle.com/msl23518/visualize-the-stage1-test-solution
    and modified.

    Args:
        mask_rle: whitespace-separated "start length start length ..." string
            with 1-based starts
        shape: (height, width) of the array to return

    Returns:
        numpy uint8 array, 1 - mask, 0 - background

    """
    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1  # to 0-based
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths

    height, width = shape[0], shape[1]
    flat = np.zeros(width * height, dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = 1

    # Runs index the column-major layout, so fill (width, height) and transpose.
    return flat.reshape((width, height)).T


def _collect_image_rows(metadata, images_dir, depth_by_id, is_train):
    """Append one metadata row per file found in ``<images_dir>/images``."""
    for filename in tqdm(os.listdir(os.path.join(images_dir, 'images'))):
        image_id = filename.split('.')[0]  # 'xxx.png' -> 'xxx'
        try:
            depth = depth_by_id[image_id]
        except KeyError:
            # Same exception type as the former ``.values[0]`` lookup raised.
            raise IndexError('no depth entry for image id {!r}'.format(image_id))

        # Only training images have a mask file next to them.
        mask_filepath = os.path.join(images_dir, 'masks', filename) if is_train else None

        metadata['file_path_image'].append(os.path.join(images_dir, 'images', filename))
        metadata['file_path_mask'].append(mask_filepath)
        metadata['is_train'].append(is_train)
        metadata['id'].append(image_id)
        metadata['z'].append(depth)


def generate_metadata(train_images_dir, test_images_dir, depths_filepath):
    """Build the metadata table describing every training and test image.

    :param train_images_dir: directory with ``images`` and ``masks`` sub-folders.
    :param test_images_dir: directory with an ``images`` sub-folder.
    :param depths_filepath: CSV file with columns ``id`` and ``z``.
    :return: DataFrame with columns ``file_path_image``, ``file_path_mask``
        (``None`` for test rows), ``is_train`` (1 / 0), ``id`` and ``z``. The
        columns are present even when both image folders are empty.
    :raises IndexError: when an image id has no row in the depths file.
    """
    depths = pd.read_csv(depths_filepath)

    # Build the id -> depth lookup once instead of scanning the whole table for
    # every file. setdefault keeps the first row of a duplicated id, which is
    # what the former ``.values[0]`` lookup returned.
    depth_by_id = {}
    for image_id, depth in zip(depths['id'].values, depths['z'].values):
        depth_by_id.setdefault(image_id, depth)

    # One big dict of columns; the is_train flag separates train from test rows.
    columns = ('file_path_image', 'file_path_mask', 'is_train', 'id', 'z')
    metadata = {column: [] for column in columns}

    _collect_image_rows(metadata, train_images_dir, depth_by_id, is_train=1)
    _collect_image_rows(metadata, test_images_dir, depth_by_id, is_train=0)

    return pd.DataFrame(metadata)


def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))``, evaluated with NumPy."""
    denominator = 1 + np.exp(-x)
    return 1. / denominator


def softmax(X, theta=1.0, axis=None):
    """
    Softmax of ``X`` along one axis.

    Adapted from https://nolanbconaway.github.io/blog/2017/softmax-numpy

    Parameters
    ----------
    X: ND-Array. Probably should be floats.
    theta (optional): float multiplier applied before exponentiation.
        Default = 1.0
    axis (optional): axis to normalise over. Default is the first
        non-singleton axis of the (at least 2-D) input.

    Returns an array the same size as X whose values sum to 1 along the
    chosen axis.
    """
    # Work on an at-least-2-D view of the input.
    scores = np.atleast_2d(X)

    # Default axis: first dimension with more than one entry.
    if axis is None:
        axis = next(dim for dim, extent in enumerate(scores.shape) if extent > 1)

    # Scale by theta.
    scores = scores * float(theta)

    # Shift by the per-axis maximum for numerical stability.
    scores = scores - np.max(scores, axis=axis, keepdims=True)

    # Exponentiate, then normalise by the per-axis sum.
    exponentials = np.exp(scores)
    normaliser = np.sum(exponentials, axis=axis, keepdims=True)
    p = exponentials / normaliser

    # A 1-D input gets a 1-D result back.
    if X.ndim == 1:
        p = p.flatten()

    return p


def from_pil(*images):
    """Convert each argument with ``np.array``.

    :return: the single array when exactly one argument is given, otherwise a
        list of arrays (empty list for no arguments).
    """
    converted = list(map(np.array, images))
    return converted[0] if len(converted) == 1 else converted


def to_pil(*images):
    """Cast each array to ``uint8`` and wrap it with ``Image.fromarray``.

    :return: the single image when exactly one argument is given, otherwise a
        list of images (empty list for no arguments).
    """
    converted = []
    for image in images:
        converted.append(Image.fromarray(image.astype(np.uint8)))

    if len(converted) != 1:
        return converted
    return converted[0]


def make_apply_transformer(func, output_name='output', apply_on=None):
    """Wrap ``func`` in a transformer that applies it element-wise.

    The returned object is an instance of a ``BaseTransformer`` subclass whose
    ``transform`` zips its inputs together, calls ``func`` once per zipped
    tuple and returns ``{output_name: [results...]}``.

    :param func: callable applied to each tuple of zipped elements.
    :param output_name: key under which the list of results is returned.
    :param apply_on: optional sequence of keyword names; when given, only those
        keyword inputs (after any positional inputs) are zipped and passed to
        ``func``. Otherwise all keyword values are used.
    :return: a ready-to-use transformer instance.
    """
    # Imported locally: the ABC lives in collections.abc, and the bare
    # ``collections.Iterable`` alias no longer exists on Python 3.10+.
    from collections.abc import Iterable

    def _length_or_none(arg):
        """Return ``len(arg)``, or None for iterables that have no length."""
        try:
            return len(arg)
        except TypeError:
            return None

    class StaticApplyTransformer(BaseTransformer):
        def transform(self, *args, **kwargs):
            self.check_input(*args, **kwargs)

            if not apply_on:
                iterator = zip(*args, *kwargs.values())
            else:
                iterator = zip(*args, *[kwargs[key] for key in apply_on])

            # total is only used for the progress bar; it may be None.
            total = self.get_arg_length(*args, **kwargs)
            output = [func(*func_args) for func_args in tqdm(iterator, total=total)]
            return {output_name: output}

        @staticmethod
        def check_input(*args, **kwargs):
            """Raise when no input is given, an input is not iterable, or sized inputs differ in length."""
            # The old test `len(args) and len(kwargs) == 0` fired for any
            # positional-only call and never for a truly empty one.
            if not args and not kwargs:
                raise Exception('Input must not be empty')

            expected_length = None
            for arg in chain(args, kwargs.values()):
                if not isinstance(arg, Iterable):
                    raise Exception('All inputs must be iterable')
                current_length = _length_or_none(arg)
                if current_length is None:
                    continue  # unsized iterables (e.g. generators) are not compared
                if expected_length is None:
                    expected_length = current_length
                elif current_length != expected_length:
                    raise Exception('All inputs must be the same length')

        @staticmethod
        def get_arg_length(*args, **kwargs):
            """Return the length of the first sized input, or None if there is none."""
            for arg in chain(args, kwargs.values()):
                length = _length_or_none(arg)
                if length is not None:
                    return length
            return None

    return StaticApplyTransformer()


def rle_from_binary(prediction):
    """Encode a binary mask with ``pycocotools.mask.encode``.

    The array is first converted to Fortran (column-major) memory order.
    """
    column_major = np.asfortranarray(prediction)
    return cocomask.encode(column_major)


def binary_from_rle(rle):
    """Decode ``rle`` with ``pycocotools.mask.decode`` and return the result."""
    decoded = cocomask.decode(rle)
    return decoded


def get_segmentations(labeled):
    """RLE-encode every labelled region of ``labeled``.

    :param labeled: integer-labelled NumPy array; labels run from 1 to
        ``labeled.max()``.
    :return: list with one encoded segmentation per label, whose ``'counts'``
        entry is decoded from bytes to a UTF-8 string.
    """
    highest_label = labeled.max()
    segmentations = []
    for label in range(1, highest_label + 1):
        region = (labeled == label).astype('uint8')
        encoded = rle_from_binary(region)
        encoded['counts'] = encoded['counts'].decode("UTF-8")
        segmentations.append(encoded)
    return segmentations


def get_crop_pad_sequence(vertical, horizontal):
    """Split a vertical and a horizontal amount into two sides each.

    The first side of each pair gets ``int(amount / 2)`` and the second gets the
    remainder.

    :return: ``(top, right, bottom, left)``.
    """
    top, right = int(vertical / 2), int(horizontal / 2)
    bottom, left = vertical - top, horizontal - right
    return (top, right, bottom, left)


def get_list_of_image_predictions(batch_predictions):
    """Flatten an iterable of batches into one list of per-image predictions."""
    return [image_pred
            for batch_pred in batch_predictions
            for image_pred in batch_pred]


def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (all CUDA devices too, when available)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


class ImgAug:
    """Callable that applies the same freshly reseeded augmentation to every image of one call."""

    def __init__(self, augmenters):
        """
        :param augmenters: a single augmenter or a list of augmenters; a
            non-list value is wrapped in a one-element list.
        """
        self.augmenters = augmenters if isinstance(augmenters, list) else [augmenters]
        self.seq_det = None  # built by _pre_call_hook before each call

    def _pre_call_hook(self):
        """Build a reseeded ``iaa.Sequential`` and store it in ``seq_det``."""
        self.seq_det = reseed(iaa.Sequential(self.augmenters), deterministic=True)

    def transform(self, *images):
        """Augment every image with ``seq_det``.

        :return: the single result for one image, otherwise a list of results.
        """
        augmented = [self.seq_det.augment_image(image) for image in images]
        return augmented[0] if len(augmented) == 1 else augmented

    def __call__(self, *args):
        self._pre_call_hook()
        return self.transform(*args)


def get_seed():
    """Return a seed built from the current Unix time (whole seconds) plus the process id."""
    now_seconds = int(time.time())
    process_id = int(os.getpid())
    return now_seconds + process_id


def reseed(augmenter, deterministic=True):
    """Give ``augmenter`` and all of its children a new random state.

    :param augmenter: augmenter exposing ``random_state``, ``deterministic``
        and ``get_children_lists()``.
    :param deterministic: when truthy, mark ``augmenter`` as deterministic.
        Children are always reseeded with ``deterministic=True``.
    :return: the same ``augmenter`` object.
    """
    augmenter.random_state = ia.new_random_state(get_seed())
    if deterministic:
        augmenter.deterministic = True

    for children in augmenter.get_children_lists():
        for child in children:
            reseed(child, deterministic=True)

    return augmenter


# BaseCrossValidator comes from scikit-learn
class KFoldBySortedValue(BaseCrossValidator):
    """K-fold splitter that deals samples into folds by the rank of their value.

    Samples are sorted by the value passed as ``X``; fold ``k`` receives every
    ``n_splits``-th sample of that ordering, starting at position ``k``.
    ``shuffle`` and ``random_state`` are stored but not read by this class.
    """

    def __init__(self, n_splits=3, shuffle=False, random_state=None):
        self.n_splits = n_splits
        self.shuffle = shuffle
        self.random_state = random_state

    def _iter_test_indices(self, X, y=None, groups=None):
        """Yield, for each fold, the list of test indices into ``X``."""
        sample_count = X.shape[0]
        positions = np.arange(sample_count)

        # Stable sort of (index, value) pairs by value.
        ranked = sorted(zip(positions, X), key=lambda pair: pair[1])
        ordered = [position for position, _ in ranked]

        for offset in range(self.n_splits):
            yield ordered[offset::self.n_splits]

    def get_n_splits(self, X=None, y=None, groups=None):
        """Return the configured number of folds."""
        return self.n_splits


def plot_list(images=(), labels=()):
    """Show images and label maps side by side in a single row.

    :param images: sequence of images, drawn with the default colour map.
    :param labels: sequence of label maps, drawn with the ``nipy_spectral``
        colour map to the right of the images.

    Nothing is drawn when both sequences are empty.
    """
    panels = [(image, {}) for image in images]
    panels += [(label, {'cmap': 'nipy_spectral'}) for label in labels]
    if not panels:
        # plt.subplots cannot create a grid with zero columns.
        return

    # squeeze=False always yields a 2-D array of axes, so a single panel is
    # indexable in the same way as several.
    _, axes_grid = plt.subplots(1, len(panels), figsize=(16, 12), squeeze=False)
    for ax, (picture, imshow_kwargs) in zip(axes_grid[0], panels):
        ax.imshow(picture, **imshow_kwargs)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()


def clean_memory():
    """Call ``torch.cuda.empty_cache()`` when CUDA is available; otherwise do nothing."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()


In [3]:
import click
from src.pipeline_manager import PipelineManager

# Single PipelineManager instance shared by all CLI commands below.
pipeline_manager = PipelineManager()


# Root click command group; the sub-commands below register via @main.command().
@click.group()
def main():
    pass


# Sub-command: delegates to PipelineManager.prepare_metadata().
@main.command()
def prepare_metadata():
    pipeline_manager.prepare_metadata()


# Sub-command: train the pipeline selected with -p; -d enables dev mode.
@main.command()
@click.option('-p', '--pipeline_name', help='pipeline to be trained', required=True)
@click.option('-d', '--dev_mode', help='if true only a small sample of data will be used', is_flag=True, required=False)
def train(pipeline_name, dev_mode):
    pipeline_manager.train(pipeline_name, dev_mode)


# Sub-command: evaluate the pipeline selected with -p; -d enables dev mode.
@main.command()
@click.option('-p', '--pipeline_name', help='pipeline to be trained', required=True)
@click.option('-d', '--dev_mode', help='if true only a small sample of data will be used', is_flag=True, required=False)
def evaluate(pipeline_name, dev_mode):
    pipeline_manager.evaluate(pipeline_name, dev_mode)


# Sub-command: predict with the pipeline selected with -p; -s also submits the
# predictions, -d enables dev mode.
@main.command()
@click.option('-p', '--pipeline_name', help='pipeline to be trained', required=True)
@click.option('-s', '--submit_predictions', help='submit predictions if true', is_flag=True, required=False)
@click.option('-d', '--dev_mode', help='if true only a small sample of data will be used', is_flag=True, required=False)
def predict(pipeline_name, submit_predictions, dev_mode):
    # NOTE(review): passed positionally as (pipeline_name, submit_predictions,
    # dev_mode); PipelineManager.predict names its parameters
    # (pipeline_name, dev_mode, submit_predictions) — verify the order end to end.
    pipeline_manager.predict(pipeline_name, submit_predictions, dev_mode)


# Sub-command: run train, then evaluate, then predict for the same pipeline.
@main.command()
@click.option('-p', '--pipeline_name', help='pipeline to be trained', required=True)
@click.option('-s', '--submit_predictions', help='submit predictions if true', is_flag=True, required=False)
@click.option('-d', '--dev_mode', help='if true only a small sample of data will be used', is_flag=True, required=False)
def train_evaluate_predict(pipeline_name, submit_predictions, dev_mode):
    pipeline_manager.train(pipeline_name, dev_mode)
    pipeline_manager.evaluate(pipeline_name, dev_mode)
    # Same positional order as the stand-alone predict command.
    pipeline_manager.predict(pipeline_name, submit_predictions, dev_mode)


# Sub-command: run train, then evaluate for the same pipeline.
@main.command()
@click.option('-p', '--pipeline_name', help='pipeline to be trained', required=True)
@click.option('-d', '--dev_mode', help='if true only a small sample of data will be used', is_flag=True, required=False)
def train_evaluate(pipeline_name, dev_mode):
    pipeline_manager.train(pipeline_name, dev_mode)
    pipeline_manager.evaluate(pipeline_name, dev_mode)


# Sub-command: run evaluate, then predict for the same pipeline.
@main.command()
@click.option('-p', '--pipeline_name', help='pipeline to be trained', required=True)
@click.option('-s', '--submit_predictions', help='submit predictions if true', is_flag=True, required=False)
@click.option('-d', '--dev_mode', help='if true only a small sample of data will be used', is_flag=True, required=False)
def evaluate_predict(pipeline_name, submit_predictions, dev_mode):
    pipeline_manager.evaluate(pipeline_name, dev_mode)
    # Same positional order as the stand-alone predict command.
    pipeline_manager.predict(pipeline_name, submit_predictions, dev_mode)


# Sub-command: delegates to PipelineManager.train_evaluate_cv() for the
# pipeline selected with -p; -d enables dev mode.
@main.command()
@click.option('-p', '--pipeline_name', help='pipeline to be trained', required=True)
@click.option('-d', '--dev_mode', help='if true only a small sample of data will be used', is_flag=True, required=False)
def train_evaluate_cv(pipeline_name, dev_mode):
    pipeline_manager.train_evaluate_cv(pipeline_name, dev_mode)


# Run the click command group only when this file is executed as a script.
if __name__ == "__main__":
    main()

ModuleNotFoundError: No module named 'pycocotools'