In [1]:
import sys
sys.path.append('./KPConv-PyTorch/')

In [2]:
from utils.config import Config
from datasets.CellData import CellDataset, CellDataCollate, CellDataCustomBatch
from models.architectures import KPFCNN
from torch.utils.data import DataLoader
from utils.trainer import ModelTrainer
import numpy as np
import torch

In [3]:
class CellDataConfig(Config):
    """
    Override the parameters you want to modify for this dataset
    """

    ####################
    # Dataset parameters
    ####################

    # Dataset name
    dataset = 'Fluo-C3DH-A549'

    # Number of classes in the dataset (This value is overwritten by dataset class when Initializating dataset).
    num_classes = None

    # Type of task performed on this dataset (also overwritten)
    dataset_task = 'cloud_segmentation'

    # Number of CPU threads for the input pipeline
    input_threads = 0

    #########################
    # Architecture definition
    #########################

    # Define layers
    architecture = ['simple',
                    'resnetb',
                    'resnetb_strided',
                    'resnetb',
                    'resnetb',
                    'resnetb_strided',
                    'resnetb_deformable',
                    'resnetb_deformable',
                    'resnetb_deformable_strided',
                    'resnetb_deformable',
                    'resnetb_deformable',
                    'resnetb_deformable_strided',
                    'resnetb_deformable',
                    'resnetb_deformable',
                    'nearest_upsample',
                    'unary',
                    'nearest_upsample',
                    'unary',
                    'nearest_upsample',
                    'unary',
                    'nearest_upsample',
                    'unary']

    ###################
    # KPConv parameters
    ###################

    # Radius of the input sphere
    in_radius = 1.5

    # Number of kernel points
    num_kernel_points = 15

    # Size of the first subsampling grid in meter
    first_subsampling_dl = 0.08

    # Radius of convolution in "number grid cell". (2.5 is the standard value)
    conv_radius = 2.5

    # Radius of deformable convolution in "number grid cell". Larger so that deformed kernel can spread out
    deform_radius = 6.0

    # Radius of the area of influence of each kernel point in "number grid cell". (1.0 is the standard value)
    KP_extent = 1.2

    # Behavior of convolutions in ('constant', 'linear', 'gaussian')
    KP_influence = 'linear'

    # Aggregation function of KPConv in ('closest', 'sum')
    aggregation_mode = 'sum'

    # Choice of input features
    first_features_dim = 128
    in_features_dim = 1

    # Can the network learn modulations
    modulated = False

    # Batch normalization parameters
    use_batch_norm = True
    batch_norm_momentum = 0.02

    # Deformable offset loss
    # 'point2point' fitting geometry by penalizing distance from deform point to input points
    # 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented)
    deform_fitting_mode = 'point2point'
    deform_fitting_power = 1.0              # Multiplier for the fitting/repulsive loss
    deform_lr_factor = 0.1                  # Multiplier for learning rate applied to the deformations
    repulse_extent = 1.2                    # Distance of repulsion for deformed kernel points

    #####################
    # Training parameters
    #####################

    # Maximal number of epochs
    max_epoch = 500

    # Learning rate management
    learning_rate = 1e-4
    momentum = 0.98
    lr_decays = {i: 0.1 ** (1 / 150) for i in range(1, max_epoch)}
    grad_clip_norm = 100.0

    # Number of batch
    batch_num = 10

    # Number of steps per epochs
    epoch_steps = 500

    # Number of validation examples per epoch
    validation_size = 50

    # Number of epoch between each checkpoint
    checkpoint_gap = 15

    # Augmentations
    augment_scale_anisotropic = True
    augment_symmetries = [True, False, False]
    augment_rotation = 'vertical'
    augment_scale_min = 0.8
    augment_scale_max = 1.2
    augment_noise = 0.001
    augment_color = 0.8

    # The way we balance segmentation loss
    #   > 'none': Each point in the whole batch has the same contribution.
    #   > 'class': Each class has the same contribution (points are weighted according to class balance)
    #   > 'batch': Each cloud in the batch has the same contribution (points are weighted according cloud sizes)
    segloss_balance = 'none'

    # Do we need to save convergence
    saving = True
    saving_path = "./checkpoints/"

In [4]:
config = CellDataConfig()
train_dataset = CellDataset(config, set = 'training')
validation_dataset = CellDataset(config, set = 'validation')


Preparing KDTree for cloud t000, subsampled at 0.080
14.7 MB loaded in 4.6s

Preparing KDTree for cloud t002, subsampled at 0.080
14.7 MB loaded in 4.2s

Preparing KDTree for cloud t006, subsampled at 0.080
14.7 MB loaded in 5.0s

Preparing KDTree for cloud t008, subsampled at 0.080
14.7 MB loaded in 4.2s

Preparing KDTree for cloud t012, subsampled at 0.080
14.7 MB loaded in 4.5s

Preparing KDTree for cloud t013, subsampled at 0.080
14.7 MB loaded in 4.8s

Preparing KDTree for cloud t014, subsampled at 0.080
14.7 MB loaded in 4.5s

Preparing KDTree for cloud t019, subsampled at 0.080
14.7 MB loaded in 4.0s

Preparing KDTree for cloud t022, subsampled at 0.080
14.7 MB loaded in 4.4s

Preparing KDTree for cloud t025, subsampled at 0.080
14.7 MB loaded in 4.3s

Preparing KDTree for cloud t026, subsampled at 0.080
14.7 MB loaded in 4.1s

Preparing KDTree for cloud t027, subsampled at 0.080
14.7 MB loaded in 4.1s

Preparing KDTree for cloud t028, subsampled at 0.080
14.7 MB loaded in 4.0s

In [5]:
training_loader = DataLoader(train_dataset,
                                 batch_size=1,
                            collate_fn=CellDataCollate)

validation_loader = DataLoader(validation_dataset,
                                 batch_size=1,
                              collate_fn=CellDataCollate)

net = KPFCNN(config, train_dataset.label_values, train_dataset.ignored_labels)
trainer = ModelTrainer(net, config)

In [6]:
trainer.train(net, training_loader, validation_loader, config)

e000-i0000 => L=20.789 acc=  0% / t(ms):  28.9 1894.9 294.0)
e000-i0002 => L=20.719 acc= 14% / t(ms):  18.9 363.3 285.2)
e000-i0004 => L=20.396 acc=  0% / t(ms):  19.5 350.2 291.1)
e000-i0006 => L=21.117 acc=  0% / t(ms):  24.2 346.3 297.0)
e000-i0008 => L=20.820 acc=  0% / t(ms):  22.3 334.1 295.5)
e000-i0010 => L=20.747 acc=  6% / t(ms):  20.9 327.9 303.3)
e000-i0012 => L=22.182 acc=  0% / t(ms):  21.8 327.7 305.5)
Fluo-C3DH-A549 mean IoU = 0.0%
e001-i0000 => L=20.194 acc=  0% / t(ms): 619.8 308.4 300.0)
e001-i0002 => L=21.697 acc=  0% / t(ms):  15.6 348.4 407.9)
e001-i0004 => L=20.679 acc=  6% / t(ms):  21.0 333.5 382.4)
e001-i0006 => L=20.554 acc=  7% / t(ms):  19.7 319.5 357.1)
e001-i0008 => L=20.659 acc=  0% / t(ms):  39.9 302.7 337.0)
e001-i0010 => L=20.205 acc=  0% / t(ms):  36.6 298.0 319.3)
e001-i0012 => L=20.577 acc=  6% / t(ms):  32.4 287.9 307.7)
Fluo-C3DH-A549 mean IoU = 0.0%
e002-i0000 => L=22.471 acc=  0% / t(ms): 942.4 242.4 242.9)
e002-i0002 => L=20.089 acc= 12% / t(m

e018-i0000 => L=18.362 acc= 56% / t(ms): 519.8 287.7 287.0)
e018-i0002 => L=18.007 acc= 67% / t(ms): 107.8 281.1 273.2)
e018-i0004 => L=18.119 acc= 67% / t(ms):  90.7 281.0 275.2)
e018-i0006 => L=18.292 acc= 60% / t(ms):  76.7 281.6 275.7)
e018-i0008 => L=17.989 acc= 53% / t(ms):  66.0 293.0 274.5)
e018-i0010 => L=17.335 acc= 60% / t(ms):  56.2 293.2 279.6)
e018-i0012 => L=17.208 acc= 50% / t(ms):  48.5 292.0 277.6)
Fluo-C3DH-A549 mean IoU = 100.0%
e019-i0000 => L=18.131 acc= 67% / t(ms): 630.3 288.0 276.8)
e019-i0002 => L=17.382 acc= 67% / t(ms):  17.6 281.1 283.8)
e019-i0004 => L=17.798 acc= 58% / t(ms):  17.2 281.7 280.3)
e019-i0006 => L=18.030 acc= 33% / t(ms):  16.7 280.6 283.5)
e019-i0008 => L=18.162 acc= 64% / t(ms):  16.2 278.8 281.7)
e019-i0010 => L=18.828 acc= 67% / t(ms):  15.8 276.9 281.0)
e019-i0012 => L=17.425 acc= 38% / t(ms):  15.9 281.5 285.5)
Fluo-C3DH-A549 mean IoU = 62.5%
e020-i0000 => L=18.382 acc= 69% / t(ms): 546.1 287.8 258.3)
e020-i0002 => L=18.005 acc= 65% / t

e035-i0004 => L=17.166 acc= 93% / t(ms):  19.2 403.9 388.3)
e035-i0006 => L=16.423 acc= 88% / t(ms):  18.6 380.8 367.7)
e035-i0008 => L=17.372 acc= 78% / t(ms):  17.7 359.8 349.9)
e035-i0010 => L=16.859 acc= 80% / t(ms):  73.7 340.0 332.9)
e035-i0012 => L=16.722 acc= 93% / t(ms):  62.8 325.1 319.5)
Fluo-C3DH-A549 mean IoU = 100.0%
e036-i0000 => L=17.795 acc= 80% / t(ms): 916.2 291.3 288.5)
e036-i0002 => L=16.634 acc=100% / t(ms):  18.5 286.2 280.9)
e036-i0004 => L=16.570 acc= 93% / t(ms):  17.4 283.1 280.6)
e036-i0006 => L=16.659 acc= 88% / t(ms):  18.7 289.1 294.5)
e036-i0008 => L=16.971 acc=100% / t(ms):  18.4 284.0 289.4)
e036-i0010 => L=17.360 acc= 93% / t(ms):  18.5 285.6 285.3)
e036-i0012 => L=17.301 acc=100% / t(ms):  17.9 283.9 286.5)
Fluo-C3DH-A549 mean IoU = 0.0%
e037-i0000 => L=17.158 acc= 87% / t(ms): 679.2 369.3 372.6)
e037-i0002 => L=17.862 acc= 94% / t(ms):  22.4 376.3 562.9)
e037-i0004 => L=16.967 acc= 93% / t(ms):  23.2 389.9 536.8)
e037-i0006 => L=16.909 acc=100% / t(

e052-i0010 => L=16.303 acc=100% / t(ms):  13.7 272.5 275.2)
e052-i0012 => L=16.989 acc=100% / t(ms):  13.2 272.3 278.6)
Fluo-C3DH-A549 mean IoU = 100.0%
e053-i0000 => L=15.999 acc=100% / t(ms): 1326.0 240.6 246.9)
e053-i0002 => L=16.595 acc= 93% / t(ms):  10.7 279.2 259.8)
e053-i0004 => L=15.516 acc=100% / t(ms):  10.6 280.9 264.6)
e053-i0006 => L=16.358 acc= 79% / t(ms):  10.5 280.8 268.0)
e053-i0008 => L=16.583 acc=100% / t(ms):  11.4 283.7 269.2)
e053-i0010 => L=16.223 acc=100% / t(ms):  11.0 283.0 272.2)
e053-i0012 => L=16.528 acc=100% / t(ms):  10.8 280.6 272.5)
Fluo-C3DH-A549 mean IoU = 100.0%
e054-i0000 => L=16.731 acc=100% / t(ms): 524.9 280.9 264.1)
e054-i0002 => L=15.950 acc= 95% / t(ms):   9.2 265.9 276.0)
e054-i0004 => L=15.663 acc=100% / t(ms):  10.1 278.1 287.7)
e054-i0006 => L=16.625 acc=100% / t(ms):  10.3 285.3 283.5)
e054-i0008 => L=16.151 acc=100% / t(ms):  10.0 285.4 283.7)
e054-i0010 => L=16.182 acc=100% / t(ms):  15.0 291.1 282.9)
e054-i0012 => L=16.635 acc=100% /

Fluo-C3DH-A549 mean IoU = 100.0%
e070-i0000 => L=15.863 acc=100% / t(ms): 503.2 298.8 267.1)
e070-i0002 => L=16.014 acc=100% / t(ms):  11.1 325.7 315.1)


KeyboardInterrupt: 

In [None]:
####At time of prediction, variable train_dataset.label_values might not exist####
config = CellDataConfig()
test_dataset = CellDataset(config, set = 'test')
test_loader = DataLoader(test_dataset,
                                 batch_size=1,
                            collate_fn=CellDataCollate)
predictor_net = KPFCNN(config, test_dataset.label_values, test_dataset.ignored_labels)

#Prediction
trained_state_dict = torch.load(config.saving_path + "checkpoints/chkp_0060.tar")
predictor_net.load_state_dict(trained_state_dict["model_state_dict"])

In [15]:
for batch in test_dataset:
    print(batch)
    out = predictor_net.forward(CellDataCustomBatch(batch), config)
    print(out)
    break

[array([[-0.91076297,  0.3148507 ,  0.2296465 ],
       [-0.9128102 ,  0.31340718, -0.76485544],
       [-1.1024333 , -0.75365   ,  0.22992776],
       [-0.08197124,  0.07015906,  0.2277575 ],
       [-0.08148889,  0.07162248, -0.7669312 ],
       [-0.2713964 , -1.0003787 ,  0.22885692],
       [-0.2711405 , -0.99894017, -0.76679003],
       [ 0.1078346 ,  1.1416628 ,  0.22784108],
       [ 0.10880942,  1.141107  , -0.7660433 ],
       [ 0.56006885, -1.2435691 ,  0.22873339],
       [ 0.74919695, -0.17294763,  0.22986808],
       [ 0.74812144, -0.17233273, -0.76889205],
       [ 0.93842256,  0.8969245 ,  0.23059952],
       [-0.08114512,  0.07251382,  1.2234466 ]], dtype=float32), array([[-0.910763  ,  0.3148507 ,  0.22964653],
       [-0.9128103 ,  0.3134072 , -0.76485544],
       [-1.1024333 , -0.75364995,  0.22992766],
       [ 0.10783458,  1.1416628 ,  0.22784108],
       [-0.27139637, -1.0003787 ,  0.22885692],
       [-0.08114517,  0.07251382,  1.2234465 ],
       [-0.08197124,  

IndexError: too many indices for tensor of dimension 1