In [1]:
import sys
sys.path.append('./KPConv-PyTorch/')

In [2]:
from utils.config import Config
from datasets.CellData import CellDataset, CellDataCollate, CellDataCustomBatch, CellDataSampler
from models.architectures import KPFCNN
from torch.utils.data import DataLoader
from utils.trainer import ModelTrainer
import numpy as np
import torch

In [3]:
class CellDataConfig(Config):
    """
    Override the parameters you want to modify for this dataset
    """

    ####################
    # Dataset parameters
    ####################

    # Dataset name
    dataset = 'Fluo-C3DH-A549'

    # Number of classes in the dataset (This value is overwritten by dataset class when Initializating dataset).
    num_classes = None

    # Type of task performed on this dataset (also overwritten)
    dataset_task = 'cloud_segmentation'

    # Number of CPU threads for the input pipeline
    input_threads = 0

    #########################
    # Architecture definition
    #########################

    # Define layers
    architecture = ['simple',
                    'resnetb',
                    'resnetb_strided',
                    'resnetb',
                    'resnetb',
                    'resnetb_strided',
                    'resnetb_deformable',
                    'resnetb_deformable',
                    'resnetb_deformable_strided',
                    'resnetb_deformable',
                    'resnetb_deformable',
                    'resnetb_deformable_strided',
                    'resnetb_deformable',
                    'resnetb_deformable',
                    'nearest_upsample',
                    'unary',
                    'nearest_upsample',
                    'unary',
                    'nearest_upsample',
                    'unary',
                    'nearest_upsample',
                    'unary']

    ###################
    # KPConv parameters
    ###################

    # Radius of the input sphere
    in_radius = 2.0

    # Number of kernel points
    num_kernel_points = 15

    # Size of the first subsampling grid in meter
    first_subsampling_dl = 0.08

    # Radius of convolution in "number grid cell". (2.5 is the standard value)
    conv_radius = 2.5

    # Radius of deformable convolution in "number grid cell". Larger so that deformed kernel can spread out
    deform_radius = 6.0

    # Radius of the area of influence of each kernel point in "number grid cell". (1.0 is the standard value)
    KP_extent = 1.2

    # Behavior of convolutions in ('constant', 'linear', 'gaussian')
    KP_influence = 'linear'

    # Aggregation function of KPConv in ('closest', 'sum')
    aggregation_mode = 'sum'

    # Choice of input features
    first_features_dim = 128
    in_features_dim = 1

    # Can the network learn modulations
    modulated = False

    # Batch normalization parameters
    use_batch_norm = True
    batch_norm_momentum = 0.02

    # Deformable offset loss
    # 'point2point' fitting geometry by penalizing distance from deform point to input points
    # 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented)
    deform_fitting_mode = 'point2point'
    deform_fitting_power = 1.0              # Multiplier for the fitting/repulsive loss
    deform_lr_factor = 0.1                  # Multiplier for learning rate applied to the deformations
    repulse_extent = 1.2                    # Distance of repulsion for deformed kernel points

    #####################
    # Training parameters
    #####################

    # Maximal number of epochs
    max_epoch = 500

    # Learning rate management
    learning_rate = 1e-4
    momentum = 0.98
    lr_decays = {i: 0.1 ** (1 / 150) for i in range(1, max_epoch)}
    grad_clip_norm = 100.0

    # Number of batch
    batch_num = 10

    # Number of steps per epochs
    epoch_steps = 200

    # Number of validation examples per epoch
    validation_size = 50

    # Number of epoch between each checkpoint
    checkpoint_gap = 10

    # Augmentations
    augment_scale_anisotropic = True
    augment_symmetries = [True, False, False]
    augment_rotation = 'vertical'
    augment_scale_min = 0.8
    augment_scale_max = 1.2
    augment_noise = 0.001
    augment_color = 0.8

    # The way we balance segmentation loss
    #   > 'none': Each point in the whole batch has the same contribution.
    #   > 'class': Each class has the same contribution (points are weighted according to class balance)
    #   > 'batch': Each cloud in the batch has the same contribution (points are weighted according cloud sizes)
    segloss_balance = 'class'
    class_w = [0.02, 0.98]

    # Do we need to save convergence
    saving = True
    saving_path = None

In [4]:
config = CellDataConfig()
train_dataset = CellDataset(config, set = 'training')
validation_dataset = CellDataset(config, set = 'validation')

training_sampler = CellDataSampler(train_dataset)
validation_sampler = CellDataSampler(validation_dataset)


Preparing KDTree for cloud t000, subsampled at 0.080
14.7 MB loaded in 3.0s

Preparing KDTree for cloud t001_2, subsampled at 0.080
14.7 MB loaded in 3.1s

Preparing KDTree for cloud t002, subsampled at 0.080
14.7 MB loaded in 3.0s

Preparing KDTree for cloud t006, subsampled at 0.080
14.7 MB loaded in 2.9s

Preparing KDTree for cloud t008, subsampled at 0.080
14.7 MB loaded in 3.0s

Preparing KDTree for cloud t008_2, subsampled at 0.080
14.7 MB loaded in 3.1s

Preparing KDTree for cloud t011_2, subsampled at 0.080
14.7 MB loaded in 2.9s

Preparing KDTree for cloud t013, subsampled at 0.080
14.7 MB loaded in 3.0s

Preparing KDTree for cloud t013_2, subsampled at 0.080
14.7 MB loaded in 3.0s

Preparing KDTree for cloud t014, subsampled at 0.080
14.7 MB loaded in 3.2s

Preparing KDTree for cloud t015, subsampled at 0.080
14.7 MB loaded in 3.1s

Preparing KDTree for cloud t015_2, subsampled at 0.080
14.7 MB loaded in 3.1s

Preparing KDTree for cloud t019, subsampled at 0.080
14.7 MB load

In [5]:
training_loader = DataLoader(train_dataset,
                            collate_fn=CellDataCollate,
                            sampler=training_sampler,)

validation_loader = DataLoader(validation_dataset,
                              collate_fn=CellDataCollate,
                              sampler=training_sampler,)

training_sampler.calibration(training_loader)
validation_sampler.calibration(validation_loader)

net = KPFCNN(config, train_dataset.label_values, train_dataset.ignored_labels)
trainer = ModelTrainer(net, config)


Starting Calibration (use verbose=True for more details)
Calibration done in 0.0s


Starting Calibration (use verbose=True for more details)
Calibration done in 0.0s



In [None]:
trainer.train(net, training_loader, validation_loader, config)

e000-i0000 => L=19.339 acc= 50% / t(ms):  20.9 1873.9 249.9)
e000-i0003 => L=19.406 acc= 49% / t(ms):  29.0 220.6 239.6)
e000-i0006 => L=19.344 acc= 43% / t(ms):  24.8 216.8 236.8)
e000-i0009 => L=19.172 acc= 56% / t(ms):  23.8 212.6 235.8)
e000-i0012 => L=19.151 acc= 56% / t(ms):  21.5 215.4 234.0)
e000-i0015 => L=18.605 acc= 51% / t(ms):  20.4 207.8 236.2)
e000-i0018 => L=18.601 acc= 73% / t(ms):  20.5 205.6 233.9)
e000-i0021 => L=18.443 acc= 63% / t(ms):  20.7 209.3 229.4)
e000-i0024 => L=18.512 acc= 70% / t(ms):  19.3 211.4 232.9)
e000-i0027 => L=18.290 acc= 71% / t(ms):  18.3 211.4 233.7)
e000-i0030 => L=18.235 acc= 70% / t(ms):  17.6 211.2 229.6)
e000-i0033 => L=18.110 acc= 79% / t(ms):  17.0 209.8 227.8)
e000-i0036 => L=18.522 acc= 73% / t(ms):  16.7 207.8 227.2)
e000-i0039 => L=17.721 acc= 88% / t(ms):  16.5 207.3 227.2)
e000-i0042 => L=17.910 acc= 91% / t(ms):  18.9 204.6 225.4)
e000-i0045 => L=17.927 acc= 96% / t(ms):  22.3 205.3 224.1)
e000-i0048 => L=17.577 acc= 94% / t(ms)

e001-i0138 => L=12.642 acc= 85% / t(ms):  18.7 232.2 249.6)
e001-i0140 => L=11.882 acc=100% / t(ms):  19.7 239.4 258.6)
e001-i0142 => L=13.047 acc= 72% / t(ms):  20.5 248.3 265.8)
e001-i0144 => L=12.354 acc= 99% / t(ms):  21.1 249.8 271.7)
e001-i0146 => L=12.712 acc= 86% / t(ms):  20.1 255.2 276.5)
e001-i0148 => L=11.990 acc=100% / t(ms):  22.2 256.6 284.7)
e001-i0150 => L=11.871 acc=100% / t(ms):  22.5 259.4 287.2)
e001-i0152 => L=11.835 acc=100% / t(ms):  22.6 265.0 287.4)
e001-i0154 => L=13.209 acc= 71% / t(ms):  22.7 270.9 286.3)
e001-i0156 => L=11.745 acc=100% / t(ms):  24.3 268.2 286.7)
e001-i0158 => L=11.921 acc=100% / t(ms):  25.6 262.4 286.8)
e001-i0161 => L=11.829 acc=100% / t(ms):  22.9 252.0 274.0)
e001-i0164 => L=12.620 acc= 89% / t(ms):  22.5 243.6 262.0)
e001-i0167 => L=12.012 acc=100% / t(ms):  20.6 235.6 256.2)
e001-i0170 => L=11.898 acc= 99% / t(ms):  20.8 232.7 249.4)
e001-i0173 => L=11.752 acc=100% / t(ms):  19.4 226.4 247.4)
e001-i0176 => L=12.002 acc=100% / t(ms):

Validation : 212.0% (timings : 29.10 113.81)
Validation : 226.0% (timings : 29.75 114.99)
Validation : 240.0% (timings : 31.08 115.88)
Validation : 256.0% (timings : 29.65 111.07)
Validation : 276.0% (timings : 26.61 101.35)
Validation : 296.0% (timings : 24.91 92.32)
Validation : 316.0% (timings : 22.60 90.84)
Validation : 336.0% (timings : 21.38 88.59)
Validation : 356.0% (timings : 21.68 87.10)
Validation : 376.0% (timings : 22.25 85.01)
Validation : 396.0% (timings : 21.29 84.49)
Confusion Matrix
****************
[[5.0752119e+05 4.3738286e+03]
 [1.1915121e+04 4.7787918e+02]]
[0.9689029  0.02850147]
Fluo-C3DH-A549 mean IoU = 49.9%
e003-i0000 => L=11.704 acc=100% / t(ms): 25753.8 248.8 218.7)
e003-i0003 => L=11.708 acc= 97% / t(ms):  15.6 232.3 232.8)
e003-i0006 => L=11.654 acc=100% / t(ms):  15.6 227.5 232.2)
e003-i0009 => L=11.790 acc=100% / t(ms):  15.6 226.0 231.2)
e003-i0012 => L=11.509 acc=100% / t(ms):  17.2 225.8 229.6)
e003-i0015 => L=11.621 acc=100% / t(ms):  18.3 223.8 227

e004-i0062 => L=11.545 acc= 99% / t(ms):  23.1 271.2 290.7)
e004-i0064 => L=11.969 acc= 99% / t(ms):  24.7 271.0 293.5)
e004-i0066 => L=11.406 acc= 92% / t(ms):  29.2 271.1 295.1)
e004-i0068 => L=11.459 acc= 94% / t(ms):  28.2 283.6 301.6)
e004-i0070 => L=11.516 acc=100% / t(ms):  27.2 282.7 293.0)
e004-i0073 => L=11.345 acc= 89% / t(ms):  24.1 267.2 277.1)
e004-i0076 => L=11.347 acc= 91% / t(ms):  23.2 255.6 262.5)
e004-i0079 => L=11.303 acc=100% / t(ms):  21.1 244.7 256.2)
e004-i0082 => L=11.531 acc= 91% / t(ms):  20.9 242.0 250.2)
e004-i0085 => L=11.483 acc= 99% / t(ms):  20.7 236.9 245.9)
e004-i0088 => L=11.450 acc= 94% / t(ms):  20.7 231.2 241.5)
e004-i0091 => L=11.228 acc= 96% / t(ms):  20.8 229.8 241.4)
e004-i0094 => L=11.538 acc=100% / t(ms):  19.4 225.9 237.2)
e004-i0097 => L=11.530 acc=100% / t(ms):  19.6 224.5 235.2)
e004-i0100 => L=12.415 acc= 84% / t(ms):  21.3 223.9 237.9)
e004-i0102 => L=11.516 acc= 99% / t(ms):  20.3 225.7 243.2)
e004-i0104 => L=11.881 acc= 94% / t(ms):

e005-i0140 => L=11.405 acc= 89% / t(ms):  24.2 274.4 292.2)
e005-i0142 => L=11.196 acc= 90% / t(ms):  24.0 276.0 293.1)
e005-i0144 => L=11.183 acc= 94% / t(ms):  22.4 278.2 293.8)
e005-i0146 => L=11.337 acc= 93% / t(ms):  21.1 278.3 292.8)
e005-i0148 => L=11.415 acc= 94% / t(ms):  21.5 270.3 283.8)
e005-i0151 => L=11.032 acc=100% / t(ms):  23.0 255.8 267.8)
e005-i0154 => L=11.747 acc= 92% / t(ms):  22.4 250.1 255.2)
e005-i0157 => L=11.355 acc= 90% / t(ms):  20.6 240.0 249.3)
e005-i0160 => L=11.335 acc= 87% / t(ms):  20.8 240.0 244.4)
e005-i0163 => L=11.384 acc= 96% / t(ms):  20.8 231.1 240.6)
e005-i0166 => L=13.095 acc= 81% / t(ms):  19.4 229.7 237.9)
e005-i0169 => L=11.398 acc= 97% / t(ms):  22.6 229.1 233.9)
e005-i0172 => L=11.392 acc= 97% / t(ms):  20.7 230.3 234.0)
e005-i0175 => L=11.368 acc= 92% / t(ms):  22.0 228.0 231.4)
e005-i0178 => L=11.358 acc= 97% / t(ms):  20.3 227.3 232.9)
e005-i0181 => L=11.290 acc= 94% / t(ms):  19.0 227.5 233.5)
e005-i0184 => L=11.248 acc= 94% / t(ms):