In [1]:
import os
import sys

# KPConv-PyTorch is used from a local checkout, not an installed package.
# Put its root at the FRONT of sys.path (append would put it last) so its
# generic top-level packages `utils`, `datasets` and `models` take precedence
# over installed packages with the same names (e.g. Hugging Face `datasets`).
# The absolute path keeps imports working if the working directory changes
# later, and the membership check keeps re-runs of this cell idempotent.
KPCONV_ROOT = os.path.abspath('./KPConv-PyTorch/')
if KPCONV_ROOT not in sys.path:
    sys.path.insert(0, KPCONV_ROOT)

In [2]:
from utils.config import Config
from datasets.CellData import CellDataset, CellDataCollate
from models.architectures import KPFCNN
from torch.utils.data import DataLoader
from utils.trainer import ModelTrainer
import numpy as np
import torch

In [3]:
class CellDataConfig(Config):
    """Class-level overrides of the base ``Config`` for the Fluo-C3DH-A549 data.

    Any attribute not assigned here keeps the value defined on ``Config``.
    """

    # ------------------------------------------------------------------ dataset

    dataset = 'Fluo-C3DH-A549'

    # Placeholders only: the dataset class overwrites both of these when it is
    # initialized.
    num_classes = None
    dataset_task = ''

    # Number of CPU threads for the input pipeline.
    input_threads = 0

    # ------------------------------------------------------------- architecture

    # Layer sequence, written one row per group: every row after the first
    # starts with a '*_strided' block. The final row is the decoder: four
    # (nearest_upsample, unary) pairs.
    architecture = (
        ['simple', 'resnetb']
        + ['resnetb_strided', 'resnetb', 'resnetb']
        + ['resnetb_strided', 'resnetb_deformable', 'resnetb_deformable']
        + ['resnetb_deformable_strided', 'resnetb_deformable', 'resnetb_deformable']
        + ['resnetb_deformable_strided', 'resnetb_deformable', 'resnetb_deformable']
        + ['nearest_upsample', 'unary'] * 4
    )

    # -------------------------------------------------------- KPConv parameters

    in_radius = 1.5                # radius of the input sphere
    num_kernel_points = 15         # number of kernel points

    # Size of the first subsampling grid, in input coordinate units.
    # NOTE(review): the inherited comment said "meter"; confirm the unit used
    # by this microscopy data.
    first_subsampling_dl = 0.08

    # The next three radii are expressed in "number of grid cells".
    conv_radius = 2.5              # convolution radius (2.5 is the standard value)
    deform_radius = 6.0            # larger so a deformed kernel can spread out
    KP_extent = 1.2                # influence radius of each kernel point (1.0 is standard)

    KP_influence = 'linear'        # one of 'constant', 'linear', 'gaussian'
    aggregation_mode = 'sum'       # one of 'closest', 'sum'

    # Input feature dimensions
    first_features_dim = 128
    in_features_dim = 1

    modulated = False              # whether the network may learn modulations

    # Batch normalization
    use_batch_norm = True
    batch_norm_momentum = 0.02

    # Deformable offset loss:
    #   'point2point' fits geometry by penalizing the distance from each
    #                 deformed point to the input points;
    #   'point2plane' would use input point triplets instead (not implemented).
    deform_fitting_mode = 'point2point'
    deform_fitting_power = 1.0     # multiplier for the fitting/repulsive loss
    deform_lr_factor = 0.1         # learning-rate multiplier for the deformations
    repulse_extent = 1.2           # repulsion distance for deformed kernel points

    # ---------------------------------------------------------------- training

    max_epoch = 500

    learning_rate = 1e-2
    momentum = 0.98
    # Same multiplier 0.1 ** (1/150) for every epoch key 1 .. max_epoch-1.
    # Compounded, that is a 10x decay every 150 epochs, presumably applied once
    # per epoch by the trainer (TODO confirm).
    lr_decays = dict.fromkeys(range(1, max_epoch), 0.1 ** (1 / 150))
    grad_clip_norm = 100.0

    # Number of elements per batch (exact use is up to the dataset/trainer).
    batch_num = 10

    # NOTE(review): in the recorded training log each epoch ends after about
    # 14 iterations, not 500; confirm whether CellDataset honors epoch_steps.
    epoch_steps = 500              # number of steps per epoch
    validation_size = 50           # number of validation examples per epoch
    checkpoint_gap = 15            # number of epochs between checkpoints

    # Data augmentation
    augment_scale_anisotropic = True
    augment_symmetries = [True, False, False]
    augment_rotation = 'vertical'
    augment_scale_min = 0.8
    augment_scale_max = 1.2
    augment_noise = 0.001
    augment_color = 0.8

    # How the segmentation loss is balanced:
    #   'none'  - every point in the batch contributes equally;
    #   'class' - every class contributes equally (points weighted by class balance);
    #   'batch' - every cloud contributes equally (points weighted by cloud size).
    segloss_balance = 'none'

    # Whether to save training convergence, and the directory to save it in.
    saving = True
    saving_path = "./checkpoints/"

In [4]:
# Seed NumPy's and torch's global RNGs before building the datasets and the
# model so a restart-and-run-all repeats the same sampling and weight init.
# Assumes CellDataset / KPFCNN draw from these global generators (TODO confirm);
# bit-exact GPU determinism needs additional settings beyond seeding.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

config = CellDataConfig()

# Dataset construction logs one KDTree preparation per cloud (about 3.5 s each
# in the recorded run), so this cell takes tens of seconds.
train_dataset = CellDataset(config, set='training')
validation_dataset = CellDataset(config, set='validation')


Preparing KDTree for cloud t000, subsampled at 0.080
14.7 MB loaded in 3.4s

Preparing KDTree for cloud t002, subsampled at 0.080
14.7 MB loaded in 3.4s

Preparing KDTree for cloud t006, subsampled at 0.080
14.7 MB loaded in 3.5s

Preparing KDTree for cloud t008, subsampled at 0.080
14.7 MB loaded in 3.5s

Preparing KDTree for cloud t012, subsampled at 0.080
14.7 MB loaded in 3.4s

Preparing KDTree for cloud t013, subsampled at 0.080
14.7 MB loaded in 3.5s

Preparing KDTree for cloud t014, subsampled at 0.080
14.7 MB loaded in 3.5s

Preparing KDTree for cloud t019, subsampled at 0.080
14.7 MB loaded in 3.5s

Preparing KDTree for cloud t022, subsampled at 0.080
14.7 MB loaded in 3.5s

Preparing KDTree for cloud t025, subsampled at 0.080
14.7 MB loaded in 3.5s

Preparing KDTree for cloud t026, subsampled at 0.080
14.7 MB loaded in 3.5s

Preparing KDTree for cloud t027, subsampled at 0.080
14.7 MB loaded in 3.4s

Preparing KDTree for cloud t028, subsampled at 0.080
14.7 MB loaded in 3.5s

In [5]:
def make_cell_loader(dataset, cfg):
    """Wrap one CellDataset split in a DataLoader.

    The worker count is taken from ``cfg.input_threads`` so that config setting
    actually controls the input pipeline (0, the current value, loads data in
    the main process, which is also DataLoader's default).
    """
    # batch_size=1: presumably each dataset item is already a full batch that
    # CellDataCollate unpacks, as in other KPConv datasets (TODO confirm).
    return DataLoader(dataset,
                      batch_size=1,
                      collate_fn=CellDataCollate,
                      num_workers=cfg.input_threads)


training_loader = make_cell_loader(train_dataset, config)
validation_loader = make_cell_loader(validation_dataset, config)

net = KPFCNN(config, train_dataset.label_values, train_dataset.ignored_labels)
trainer = ModelTrainer(net, config)

In [6]:
# Run training with the loaders, network and config built in the cells above.
# NOTE(review): the recorded run below is suspicious and was stopped manually
# (KeyboardInterrupt):
#   - training accuracy is already 100% at iteration 4 of epoch 0;
#   - validation mean IoU drops from 100.0% (epoch 0) to 0.0% (epoch 1);
#   - each epoch ends after about 14 iterations although config.epoch_steps is 500.
# Together these suggest a degenerate label setup (e.g. a single effective
# class) or a dataset length that ignores epoch_steps. Check CellDataset's
# labels and __len__ before launching a long run.
trainer.train(net, training_loader, validation_loader, config)

e000-i0000 => L=21.103 acc= 36% / t(ms):  18.9 1811.4 270.2)
e000-i0002 => L=19.251 acc= 83% / t(ms):  16.0 266.3 274.8)
e000-i0004 => L=18.893 acc=100% / t(ms):  16.0 264.2 273.7)
e000-i0006 => L=19.475 acc=100% / t(ms):  16.3 260.0 271.4)
e000-i0008 => L=15.744 acc=100% / t(ms):  16.5 256.7 271.9)
e000-i0011 => L=14.789 acc=100% / t(ms):  15.3 252.3 266.0)
e000-i0013 => L=13.887 acc=100% / t(ms):  15.4 250.3 263.8)
Fluo-C3DH-A549 mean IoU = 100.0%
e001-i0001 => L=14.891 acc=100% / t(ms):  11.0 260.3 266.8)
e001-i0003 => L=16.447 acc=100% / t(ms):  11.5 261.3 266.9)
e001-i0005 => L=18.472 acc=100% / t(ms):  11.6 257.9 265.1)
e001-i0007 => L=20.174 acc=100% / t(ms):  12.1 254.1 265.4)
e001-i0009 => L=20.821 acc=100% / t(ms):  12.6 253.3 265.6)
e001-i0011 => L=19.270 acc=100% / t(ms):  12.8 252.8 263.6)
e001-i0013 => L=18.095 acc=100% / t(ms):  11.7 252.0 266.7)
Fluo-C3DH-A549 mean IoU = 0.0%
e002-i0000 => L=17.835 acc=100% / t(ms): 478.6 296.2 265.3)
e002-i0002 => L=18.213 acc=100% / t

KeyboardInterrupt: 