In [1]:
import sys
sys.path.append('./KPConv-PyTorch/')

In [2]:
from utils.config import Config
from datasets.CellData import CellDataset, CellDataCollate, CellDataCustomBatch, CellDataSampler
from models.architectures import KPFCNN
from torch.utils.data import DataLoader
from utils.trainer import ModelTrainer
import numpy as np
import torch

In [3]:
class CellDataConfig(Config):
    """
    Override the parameters you want to modify for this dataset
    """

    ####################
    # Dataset parameters
    ####################

    # Dataset name
    dataset = 'Fluo-C3DH-A549'

    # Number of classes in the dataset (This value is overwritten by dataset class when Initializating dataset).
    num_classes = None

    # Type of task performed on this dataset (also overwritten)
    dataset_task = 'cloud_segmentation'

    # Number of CPU threads for the input pipeline
    input_threads = 0

    #########################
    # Architecture definition
    #########################

    # Define layers
    architecture = ['simple',
                    'resnetb',
                    'resnetb_strided',
                    'resnetb',
                    'resnetb',
                    'resnetb_strided',
                    'resnetb_deformable',
                    'resnetb_deformable',
                    'resnetb_deformable_strided',
                    'resnetb_deformable',
                    'resnetb_deformable',
                    'resnetb_deformable_strided',
                    'resnetb_deformable',
                    'resnetb_deformable',
                    'nearest_upsample',
                    'unary',
                    'nearest_upsample',
                    'unary',
                    'nearest_upsample',
                    'unary',
                    'nearest_upsample',
                    'unary']

    ###################
    # KPConv parameters
    ###################

    # Radius of the input sphere
    in_radius = 1.5

    # Number of kernel points
    num_kernel_points = 15

    # Size of the first subsampling grid in meter
    first_subsampling_dl = 0.03

    # Radius of convolution in "number grid cell". (2.5 is the standard value)
    conv_radius = 2.5

    # Radius of deformable convolution in "number grid cell". Larger so that deformed kernel can spread out
    deform_radius = 6.0

    # Radius of the area of influence of each kernel point in "number grid cell". (1.0 is the standard value)
    KP_extent = 1.0

    # Behavior of convolutions in ('constant', 'linear', 'gaussian')
    KP_influence = 'linear'

    # Aggregation function of KPConv in ('closest', 'sum')
    aggregation_mode = 'sum'

    # Choice of input features
    first_features_dim = 128
    in_features_dim = 1

    # Can the network learn modulations
    modulated = False

    # Batch normalization parameters
    use_batch_norm = True
    batch_norm_momentum = 0.02

    # Deformable offset loss
    # 'point2point' fitting geometry by penalizing distance from deform point to input points
    # 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented)
    deform_fitting_mode = 'point2point'
    deform_fitting_power = 1.0              # Multiplier for the fitting/repulsive loss
    deform_lr_factor = 0.1                  # Multiplier for learning rate applied to the deformations
    repulse_extent = 1.2                    # Distance of repulsion for deformed kernel points

    #####################
    # Training parameters
    #####################

    # Maximal number of epochs
    max_epoch = 200

    # Learning rate management
    learning_rate = 1e-5
    momentum = 0.98
    lr_decays = {i: 0.1 ** (1 / 150) for i in range(1, max_epoch)}
    grad_clip_norm = 100.0

    # Number of batch
    batch_num = 10

    # Number of steps per epochs
    epoch_steps = 100

    # Number of validation examples per epoch
    validation_size = 50

    # Number of epoch between each checkpoint
    checkpoint_gap = 5

    # Augmentations
    augment_scale_anisotropic = True
    augment_symmetries = [True, False, False]
    augment_rotation = 'vertical'
    augment_scale_min = 0.8
    augment_scale_max = 1.2
    augment_noise = 0.001
    augment_color = 0.8

    # The way we balance segmentation loss
    #   > 'none': Each point in the whole batch has the same contribution.
    #   > 'class': Each class has the same contribution (points are weighted according to class balance)
    #   > 'batch': Each cloud in the batch has the same contribution (points are weighted according cloud sizes)
    segloss_balance = 'none'

    # Do we need to save convergence
    saving = True
    saving_path = None

In [4]:
config = CellDataConfig()
train_dataset = CellDataset(config, set = 'training')
validation_dataset = CellDataset(config, set = 'validation')

training_sampler = CellDataSampler(train_dataset)
validation_sampler = CellDataSampler(validation_dataset)


Preparing KDTree for cloud t000, subsampled at 0.030
14.7 MB loaded in 3.4s

Preparing KDTree for cloud t002, subsampled at 0.030
14.7 MB loaded in 3.2s

Preparing KDTree for cloud t006, subsampled at 0.030
14.7 MB loaded in 3.1s

Preparing KDTree for cloud t008, subsampled at 0.030
14.7 MB loaded in 3.4s

Preparing KDTree for cloud t012, subsampled at 0.030
14.7 MB loaded in 3.6s

Preparing KDTree for cloud t013, subsampled at 0.030
14.7 MB loaded in 3.3s

Preparing KDTree for cloud t014, subsampled at 0.030
14.7 MB loaded in 3.4s

Preparing KDTree for cloud t019, subsampled at 0.030
14.7 MB loaded in 3.4s

Preparing KDTree for cloud t022, subsampled at 0.030
14.7 MB loaded in 3.1s

Preparing KDTree for cloud t025, subsampled at 0.030
14.7 MB loaded in 3.1s

Preparing KDTree for cloud t026, subsampled at 0.030
14.7 MB loaded in 3.2s

Preparing KDTree for cloud t027, subsampled at 0.030
14.7 MB loaded in 2.9s

Preparing KDTree for cloud t028, subsampled at 0.030
14.7 MB loaded in 2.9s

In [5]:
training_loader = DataLoader(train_dataset,
                                 batch_size=1,
                            collate_fn=CellDataCollate,
                            sampler=training_sampler,
                             pin_memory=True)

validation_loader = DataLoader(validation_dataset,
                                 batch_size=1,
                              collate_fn=CellDataCollate,
                              sampler=training_sampler,
                             pin_memory=True)

training_sampler.calibration(training_loader, verbose=True)
validation_sampler.calibration(validation_loader, verbose=True)

net = KPFCNN(config, train_dataset.label_values, train_dataset.ignored_labels)
trainer = ModelTrainer(net, config)


Starting Calibration (use verbose=True for more details)

Previous calibration found:
Check batch limit dictionary
[92m"potentials_1.500_0.030_10": -1399[0m
Check neighbors limit dictionary
[92m"0.030_0.075": 1[0m
[92m"0.060_0.150": 1[0m
[92m"0.120_0.720": 1[0m
[92m"0.240_1.440": 10[0m
[92m"0.480_2.880": 18[0m
Calibration done in 0.0s


Starting Calibration (use verbose=True for more details)

Previous calibration found:
Check batch limit dictionary
[92m"potentials_1.500_0.030_10": -1399[0m
Check neighbors limit dictionary
[92m"0.030_0.075": 1[0m
[92m"0.060_0.150": 1[0m
[92m"0.120_0.720": 1[0m
[92m"0.240_1.440": 10[0m
[92m"0.480_2.880": 18[0m
Calibration done in 0.0s



In [6]:
trainer.train(net, training_loader, validation_loader, config)

e000-i0000 => L=38.635 acc=  0% / t(ms):  18.0 1706.3 363.2)
e000-i0003 => L=38.313 acc= 63% / t(ms):   9.5 225.6 256.5)
e000-i0006 => L=39.096 acc= 59% / t(ms):   9.2 221.5 256.5)
e000-i0009 => L=37.565 acc= 47% / t(ms):   8.7 218.7 250.8)
e000-i0012 => L=38.205 acc= 62% / t(ms):   8.5 218.3 253.6)
e000-i0015 => L=37.679 acc= 50% / t(ms):   8.2 220.2 257.8)
e000-i0018 => L=38.028 acc= 67% / t(ms):   8.2 221.1 254.3)
e000-i0021 => L=37.059 acc= 53% / t(ms):   8.0 220.4 250.8)
e000-i0023 => L=37.192 acc= 88% / t(ms):   8.0 223.2 253.7)
e000-i0026 => L=37.744 acc= 88% / t(ms):   8.1 220.5 251.7)
e000-i0029 => L=37.421 acc= 79% / t(ms):   8.3 219.0 253.3)
e000-i0032 => L=36.747 acc= 73% / t(ms):   8.0 220.2 250.8)
e000-i0035 => L=35.682 acc=100% / t(ms):   7.9 219.5 251.2)
e000-i0038 => L=35.932 acc=100% / t(ms):   8.4 221.2 256.8)
e000-i0041 => L=35.809 acc=100% / t(ms):   8.4 217.0 252.9)
e000-i0044 => L=35.575 acc=100% / t(ms):   8.1 217.7 249.9)
e000-i0046 => L=34.482 acc=100% / t(ms)

Validation : 114.0% (timings : 8.66 98.36)
Validation : 132.0% (timings : 8.64 100.45)
Validation : 150.0% (timings : 8.69 102.07)
Validation : 168.0% (timings : 8.83 102.50)
Validation : 188.0% (timings : 9.32 102.27)
Fluo-C3DH-A549 mean IoU = 48.8%
e003-i0000 => L=19.920 acc=100% / t(ms): 11699.0 316.3 275.8)
e003-i0002 => L=19.840 acc=100% / t(ms):  15.1 266.4 270.6)
e003-i0004 => L=20.093 acc=100% / t(ms):  23.6 259.8 271.6)
e003-i0006 => L=19.758 acc=100% / t(ms):  20.6 255.9 268.6)
e003-i0009 => L=19.961 acc=100% / t(ms):  16.9 247.8 266.1)
e003-i0011 => L=19.733 acc=100% / t(ms):  15.0 247.7 264.1)
e003-i0013 => L=19.456 acc=100% / t(ms):  13.7 245.9 266.4)
e003-i0015 => L=19.760 acc=100% / t(ms):  12.5 246.2 262.7)
e003-i0017 => L=19.790 acc=100% / t(ms):  11.7 245.7 261.5)
e003-i0019 => L=19.606 acc=100% / t(ms):  11.1 244.8 261.8)
e003-i0022 => L=19.564 acc=100% / t(ms):  10.4 242.1 265.9)
e003-i0024 => L=19.495 acc=100% / t(ms):   9.8 243.8 261.8)
e003-i0026 => L=19.486 acc=

e006-i0009 => L=18.640 acc=100% / t(ms):  10.3 247.1 249.7)
e006-i0011 => L=18.731 acc=100% / t(ms):  10.2 243.4 254.5)
e006-i0013 => L=18.964 acc=100% / t(ms):  10.1 254.8 258.8)
e006-i0016 => L=18.891 acc=100% / t(ms):   9.9 251.9 255.3)
e006-i0019 => L=18.618 acc=100% / t(ms):   9.3 245.5 253.5)
e006-i0022 => L=18.630 acc=100% / t(ms):   8.6 242.5 250.1)
e006-i0025 => L=18.573 acc=100% / t(ms):   8.4 234.2 248.0)
e006-i0028 => L=18.547 acc=100% / t(ms):   8.0 232.5 245.0)
e006-i0031 => L=18.843 acc=100% / t(ms):   7.9 229.7 248.0)
e006-i0033 => L=18.476 acc=100% / t(ms):   8.0 230.3 257.2)
e006-i0035 => L=18.642 acc=100% / t(ms):   8.2 235.6 266.7)
e006-i0037 => L=18.781 acc=100% / t(ms):   8.3 242.3 276.4)
e006-i0039 => L=18.453 acc=100% / t(ms):   8.4 252.1 282.7)
e006-i0041 => L=18.638 acc=100% / t(ms):   8.6 256.8 289.9)
e006-i0043 => L=18.607 acc=100% / t(ms):   8.4 253.0 286.3)
e006-i0046 => L=18.646 acc=100% / t(ms):   8.2 245.2 274.4)
e006-i0049 => L=18.862 acc=100% / t(ms):

e009-i0026 => L=18.522 acc=100% / t(ms):   7.3 223.9 247.6)
e009-i0029 => L=18.480 acc=100% / t(ms):   7.4 224.4 248.9)
e009-i0031 => L=18.469 acc=100% / t(ms):   7.3 227.2 250.4)
e009-i0033 => L=18.699 acc=100% / t(ms):   7.3 231.6 255.6)
e009-i0035 => L=18.688 acc=100% / t(ms):   7.5 237.5 265.5)
e009-i0037 => L=18.517 acc=100% / t(ms):   7.8 248.4 277.8)
e009-i0039 => L=18.535 acc=100% / t(ms):   8.4 259.6 288.3)
e009-i0041 => L=18.622 acc=100% / t(ms):   8.7 268.8 298.3)
e009-i0043 => L=18.647 acc=100% / t(ms):   9.0 275.4 302.9)
e009-i0045 => L=18.579 acc=100% / t(ms):   9.2 276.8 306.1)
e009-i0047 => L=18.498 acc=100% / t(ms):   9.1 271.7 301.0)
e009-i0050 => L=18.525 acc=100% / t(ms):   8.9 261.7 287.8)
e009-i0053 => L=18.646 acc=100% / t(ms):   8.6 252.4 282.1)
e009-i0056 => L=18.502 acc=100% / t(ms):   8.2 245.4 276.3)
e009-i0059 => L=18.441 acc=100% / t(ms):   8.1 240.1 269.1)
e009-i0061 => L=18.439 acc=100% / t(ms):   8.0 239.6 267.5)
e009-i0063 => L=18.852 acc=100% / t(ms):

e012-i0034 => L=18.726 acc=100% / t(ms):   7.7 237.0 250.0)
e012-i0037 => L=18.448 acc=100% / t(ms):   7.6 232.5 248.3)
e012-i0040 => L=18.423 acc=100% / t(ms):   7.5 231.2 248.1)
e012-i0043 => L=18.539 acc=100% / t(ms):   7.4 227.9 246.5)
e012-i0046 => L=18.327 acc=100% / t(ms):  10.1 222.9 247.3)
e012-i0049 => L=18.495 acc=100% / t(ms):   9.3 220.3 249.7)
e012-i0052 => L=18.448 acc=100% / t(ms):   8.9 222.4 247.4)
e012-i0055 => L=18.367 acc=100% / t(ms):   8.5 224.5 246.2)
e012-i0058 => L=18.505 acc=100% / t(ms):  12.5 221.4 249.5)
e012-i0060 => L=18.634 acc=100% / t(ms):  11.5 223.4 253.9)
e012-i0062 => L=19.199 acc=100% / t(ms):  10.7 226.5 259.0)
e012-i0064 => L=18.503 acc=100% / t(ms):  10.5 235.9 267.5)
e012-i0066 => L=18.641 acc=100% / t(ms):  10.3 244.0 278.8)
e012-i0068 => L=18.442 acc=100% / t(ms):  10.1 248.8 283.6)
e012-i0070 => L=18.466 acc=100% / t(ms):  10.0 253.4 293.4)
e012-i0072 => L=18.573 acc=100% / t(ms):   9.8 262.3 301.7)
e012-i0074 => L=18.695 acc=100% / t(ms):

e015-i0043 => L=18.743 acc=100% / t(ms):   9.8 217.4 248.7)
e015-i0045 => L=18.551 acc=100% / t(ms):   9.3 221.3 250.1)
e015-i0048 => L=18.642 acc=100% / t(ms):   9.0 221.0 247.0)
e015-i0051 => L=18.276 acc=100% / t(ms):   8.6 221.6 245.6)
e015-i0054 => L=18.489 acc=100% / t(ms):   8.4 222.9 246.3)
e015-i0057 => L=18.608 acc=100% / t(ms):   8.0 222.6 247.4)
e015-i0060 => L=18.305 acc=100% / t(ms):   7.8 222.5 245.5)
e015-i0063 => L=18.450 acc=100% / t(ms):   7.8 219.5 245.3)
e015-i0066 => L=18.523 acc=100% / t(ms):   7.6 217.8 244.3)
e015-i0069 => L=18.653 acc=100% / t(ms):   7.5 217.8 244.4)
e015-i0072 => L=18.314 acc=100% / t(ms):   7.5 218.7 245.8)
e015-i0075 => L=18.772 acc=100% / t(ms):   7.6 220.5 249.2)
e015-i0077 => L=18.871 acc=100% / t(ms):   7.5 226.4 255.5)
e015-i0079 => L=18.582 acc=100% / t(ms):   7.6 233.5 264.4)
e015-i0081 => L=18.483 acc=100% / t(ms):   7.9 241.8 273.5)
e015-i0083 => L=18.710 acc=100% / t(ms):   8.2 249.6 281.4)
e015-i0085 => L=18.545 acc=100% / t(ms):

KeyboardInterrupt: 