# M3D-NCA: Robust 3D Segmentation with Built-in Quality Control
### John Kalkhof, Anirban Mukhopadhyay
__TBD__



***

## __The Backbone Model__
<div>
<img src="src/images/model_M3DNCA.png" width="600"/>
</div>

## _1. Imports_

In [1]:
import torch
from src.datasets.Nii_Gz_Dataset_3D import Dataset_NiiGz_3D
from src.models.Model_BasicNCA3D import BasicNCA3D
from src.losses.LossFunctions import DiceFocalLoss
from src.utils.Experiment import Experiment
from src.agents.Agent_M3D_NCA import Agent_M3D_NCA

## _2. Configure experiment_
- __AutoReload__
    - If an experiment already exists in _model\_path_, the config will __always__ be overwritten with the existing one.
    - Additionally, if the model has been saved previously, its saved state will be reloaded.
- Download the _hippocampus_ data (Task04_Hippocampus) from 'http://medicaldecathlon.com/' and adapt 'img_path' and 'label_path'.
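The AutoReload rule can be pictured with the hypothetical sketch below. It is not the code of the `Experiment` class: the function name `autoreload` and the file names `config.json` and `model.pth` are assumptions chosen for illustration. It only shows the two decisions described above: a stored config takes precedence over the one passed in, and an existing checkpoint is picked up for reloading.

```python
import json
import os


def autoreload(model_path, config, checkpoint_name="model.pth"):
    """Hypothetical sketch of the AutoReload rule.

    Returns the config to use and the checkpoint path to restore (or None).
    """
    os.makedirs(model_path, exist_ok=True)
    config_file = os.path.join(model_path, "config.json")  # assumed file name
    if os.path.isfile(config_file):
        # Existing experiment: the stored config always wins over the passed one
        with open(config_file) as f:
            config = json.load(f)
    else:
        # New experiment: store the passed config for future runs
        with open(config_file, "w") as f:
            json.dump(config, f, indent=2)
    checkpoint = os.path.join(model_path, checkpoint_name)  # assumed file name
    return config, (checkpoint if os.path.isfile(checkpoint) else None)
```

With this JSON-based sketch, tuples such as `'betas': (0.9, 0.99)` come back as lists after a reload.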

In [2]:
config = [{
    'img_path': r"/home/jkalkhof_locale/Documents/Data/Task04_Hippocampus/train/imagesTr/",
    'label_path': r"/home/jkalkhof_locale/Documents/Data/Task04_Hippocampus/train/labelsTr/",
    'name': r'M3D_NCA_Run6',
    'device':"cuda:0",
    'unlock_CPU': True,
    # Optimizer
    'lr': 16e-4,
    'lr_gamma': 0.9999,
    'betas': (0.9, 0.99),
    # Training
    'save_interval': 10,
    'evaluate_interval': 10,
    'n_epoch': 3000,
    'batch_duplication': 1,
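    # Model
    'channel_n': 16,        # Number of CA state channels
    'inference_steps': [10, 10],  # NCA steps per level (one entry per model in ca)
    'cell_fire_rate': 0.5,  # Probability that a cell updates in a given step
    'batch_size': 4,
    'input_channels': 1,
    'output_channels': 1,
    'hidden_size': 64,
    'train_model': 1,
    # Data
    'input_size': [(16, 16, 13), (64, 64, 52)],  # One size per level, coarse to fine
    'scale_factor': 4,      # Ratio between consecutive input sizes (64/16 = 52/13 = 4)
    'data_split': [0.7, 0, 0.3],  # Train / validation / test fractions
    'keep_original_scale': True,
    'rescale': True,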
}
]
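The multi-scale entries of the config have to agree with each other: each level in `input_size` is `scale_factor` times the previous one along every axis (16 × 4 = 64, 13 × 4 = 52). The sketch below checks this and also shows how the `data_split` fractions could turn into the counts reported when the dataset is created (139 train, 0 validation, 59 test, i.e. 198 volumes). Rounding fraction × total is an assumption about the splitting rule that happens to reproduce those numbers; `split_counts` is a hypothetical helper, not part of the repository.

```python
# Values copied from the config above
input_size = [(16, 16, 13), (64, 64, 52)]
scale_factor = 4
data_split = [0.7, 0, 0.3]

# Each finer level should be scale_factor times the previous one on every axis
for coarse, fine in zip(input_size, input_size[1:]):
    ratios = [f / c for c, f in zip(coarse, fine)]
    assert all(r == scale_factor for r in ratios), ratios


def split_counts(n_total, fractions):
    """Hypothetical split rule: round fraction * total for train/val/test."""
    return [round(f * n_total) for f in fractions]


print(split_counts(198, data_split))  # [139, 0, 59]
```

Plain rounding does not guarantee in general that the counts sum to the total, so a real splitter may assign the remainder differently.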

## _3. Choose architecture, dataset and training agent_

- _Dataset\_Nii\_Gz\_3D_ loads 3D files. If you pass a _slice_, each volume will be split into 2D slices along the corresponding axis.
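Splitting a volume "along an axis" means taking every 2D plane perpendicular to that axis. The minimal NumPy sketch below is independent of `Dataset_NiiGz_3D` (how the class receives the slice axis is not shown here); `split_into_slices` is a hypothetical helper, and the random volume stands in for one scan at the 64 × 64 × 52 size of the finest `input_size`.

```python
import numpy as np

# Random stand-in for one loaded volume at the finest input size
volume = np.random.rand(64, 64, 52).astype(np.float32)


def split_into_slices(vol, axis):
    """Return the 2D slices of a 3D volume along the given axis."""
    return [np.take(vol, i, axis=axis) for i in range(vol.shape[axis])]


axial = split_into_slices(volume, axis=2)
print(len(axial), axial[0].shape)  # 52 (64, 64)
```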

In [3]:
dataset = Dataset_NiiGz_3D()
device = torch.device(config[0]['device'])

# Two NCA levels that share all hyperparameters except the kernel size
ca1 = BasicNCA3D(config[0]['channel_n'], config[0]['cell_fire_rate'], device,
                 hidden_size=config[0]['hidden_size'], kernel_size=7,
                 input_channels=config[0]['input_channels']).to(device)
ca2 = BasicNCA3D(config[0]['channel_n'], config[0]['cell_fire_rate'], device,
                 hidden_size=config[0]['hidden_size'], kernel_size=3,
                 input_channels=config[0]['input_channels']).to(device)
ca = [ca1, ca2]

agent = Agent_M3D_NCA(ca)
exp = Experiment(config, dataset, ca, agent)
dataset.set_experiment(exp)
exp.set_model_state('train')
data_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=exp.get_from_config('batch_size'))

loss_function = DiceFocalLoss()

Datasplit-> train entries: 139, val entries: 0, test entries: 59
Datasplit-> train entries: 139, val entries: 0, test entries: 59
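The two `BasicNCA3D` instances above form the two levels of M3D-NCA. To make the idea concrete, here is a minimal, hypothetical sketch of a generic NCA update and a two-level pass in plain PyTorch. `TinyNCA3D` and `two_level_forward` are illustrative names, not the repository's API; the learned depthwise perception filter, the zero-initialised output layer and the nearest-neighbour upsampling are assumptions chosen for brevity. The sketch works channels-first, whereas the logged tensors are channels-last (e.g. `torch.Size([1, 64, 64, 52, 1])`).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNCA3D(nn.Module):
    """Hypothetical minimal 3D NCA (channels-first); not the repository's BasicNCA3D."""

    def __init__(self, channel_n=16, fire_rate=0.5, hidden_size=64,
                 kernel_size=3, input_channels=1):
        super().__init__()
        self.channel_n = channel_n
        self.fire_rate = fire_rate
        self.input_channels = input_channels
        # Learned depthwise filter: each cell perceives its 3D neighbourhood
        self.perceive = nn.Conv3d(channel_n, channel_n, kernel_size,
                                  padding=kernel_size // 2, groups=channel_n)
        # Per-cell MLP written as 1x1x1 convolutions
        self.fc0 = nn.Conv3d(2 * channel_n, hidden_size, 1)
        self.fc1 = nn.Conv3d(hidden_size, channel_n, 1, bias=False)
        nn.init.zeros_(self.fc1.weight)  # start with "do nothing" updates

    def step(self, x):
        # x: (batch, channel_n, D, H, W)
        dx = self.fc1(torch.relu(self.fc0(torch.cat([x, self.perceive(x)], dim=1))))
        # Stochastic update: each cell fires with probability fire_rate
        fire = (torch.rand_like(x[:, :1]) < self.fire_rate).to(x.dtype)
        x_new = x + dx * fire
        # The image channels stay fixed; only the remaining channels evolve
        return torch.cat([x[:, :self.input_channels], x_new[:, self.input_channels:]], dim=1)

    def forward(self, x, steps=10):
        for _ in range(steps):
            x = self.step(x)
        return x


def two_level_forward(coarse, fine, image, scale_factor=4, steps=(10, 10)):
    """image: (batch, input_channels, D, H, W) at full resolution."""
    b, c_in = image.shape[:2]
    small_size = tuple(s // scale_factor for s in image.shape[2:])
    small = F.interpolate(image, size=small_size, mode="trilinear", align_corners=False)
    # Seed the coarse state: image channels followed by zero-initialised hidden channels
    state = torch.cat([small, small.new_zeros(b, coarse.channel_n - c_in, *small_size)], dim=1)
    state = coarse(state, steps=steps[0])
    # Upscale the coarse state and swap in the full-resolution image channels
    up = F.interpolate(state, size=image.shape[2:], mode="nearest")
    state = torch.cat([image, up[:, c_in:]], dim=1)
    return fine(state, steps=steps[1])


if __name__ == "__main__":
    with torch.no_grad():  # inference only; avoids storing the graph over all steps
        out = two_level_forward(TinyNCA3D(kernel_size=7), TinyNCA3D(kernel_size=3),
                                torch.rand(1, 1, 64, 64, 52))
    print(out.shape)  # torch.Size([1, 16, 64, 64, 52])
```

Running the coarse level on a 16 × 16 × 13 grid lets information travel across the whole volume in few steps, while the fine level only has to refine local detail; this mirrors the roles of `input_size` and `inference_steps` in the config.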


## _4. Run training_

In [4]:
agent.train(data_loader, loss_function)

Epoch: 1
Dataset size: 35
1 loss = 1.1589499459122166
Epoch: 2
Dataset size: 35
2 loss = 0.8889974101501352
Epoch: 3
Dataset size: 35
3 loss = 0.7027766764163971
Epoch: 4
Dataset size: 35
4 loss = 0.653698081479353
Epoch: 5
Dataset size: 35
5 loss = 0.7054048614842551
Epoch: 6
Dataset size: 35
6 loss = 0.6348320245742798
Epoch: 7
Dataset size: 35
7 loss = 0.5549518768276487
Epoch: 8
Dataset size: 35
8 loss = 0.4955295609576362
Epoch: 9
Dataset size: 35
9 loss = 0.49494366475514
Epoch: 10
Dataset size: 35
10 loss = 0.494265129461008
Evaluate model
_hippocampus_374.nii.gz_0
torch.Size([1, 64, 64, 52, 1]) torch.Size([1, 64, 64, 52, 1])
, 0.6673606038093567
(64, 64, 3)
_hippocampus_326.nii.gz_0
torch.Size([1, 64, 64, 52, 1]) torch.Size([1, 64, 64, 52, 1])
, 0.6256745457649231
(64, 64, 3)
_hippocampus_373.nii.gz_0
torch.Size([1, 64, 64, 52, 1]) torch.Size([1, 64, 64, 52, 1])
, 0.668007493019104
(64, 64, 3)
_hippocampus_181.nii.gz_0
torch.Size([1, 64, 64, 52, 1]) torch.Size([1, 64, 64, 52, 1



Model saved
Epoch: 11
Dataset size: 35
11 loss = 0.47621300731386457
Epoch: 12
Dataset size: 35
12 loss = 0.5227104829890388
Epoch: 13
Dataset size: 35
13 loss = 0.511947950019556
Epoch: 14
Dataset size: 35
14 loss = 0.4589899072752279
Epoch: 15
Dataset size: 35
15 loss = 0.4113287572349821
Epoch: 16
Dataset size: 35
16 loss = 0.5775810431908158
Epoch: 17
Dataset size: 35
17 loss = 0.42763129051993876
Epoch: 18
Dataset size: 35
18 loss = 0.44337648153305054
Epoch: 19
Dataset size: 35
19 loss = 0.4939519938300638
Epoch: 20
Dataset size: 35
20 loss = 0.44622260638896155
Evaluate model
_hippocampus_374.nii.gz_0
torch.Size([1, 64, 64, 52, 1]) torch.Size([1, 64, 64, 52, 1])
, 0.808322012424469
(64, 64, 3)
_hippocampus_326.nii.gz_0
torch.Size([1, 64, 64, 52, 1]) torch.Size([1, 64, 64, 52, 1])
, 0.7724165916442871
(64, 64, 3)
_hippocampus_373.nii.gz_0
torch.Size([1, 64, 64, 52, 1]) torch.Size([1, 64, 64, 52, 1])
, 0.805558443069458
(64, 64, 3)
_hippocampus_181.nii.gz_0
torch.Size([1, 64, 64, 



_hippocampus_109.nii.gz_0
torch.Size([1, 64, 64, 52, 1]) torch.Size([1, 64, 64, 52, 1])
, 0.7968423366546631
(64, 64, 3)
_hippocampus_288.nii.gz_0
torch.Size([1, 64, 64, 52, 1]) torch.Size([1, 64, 64, 52, 1])
, 0.7959915995597839
(64, 64, 3)
_hippocampus_320.nii.gz_0
torch.Size([1, 64, 64, 52, 1]) torch.Size([1, 64, 64, 52, 1])
, 0.7304364442825317
(64, 64, 3)
Average Dice Loss 3d: 0, 0.7697630918632119
Standard Deviation 3d: 0, 0.03320077213154725
{'_hippocampus_374.nii.gz_0': 0.808322012424469, '_hippocampus_326.nii.gz_0': 0.7724165916442871, '_hippocampus_373.nii.gz_0': 0.805558443069458, '_hippocampus_181.nii.gz_0': 0.7795662879943848, '_hippocampus_330.nii.gz_0': 0.7047051191329956, '_hippocampus_052.nii.gz_0': 0.8047559261322021, '_hippocampus_300.nii.gz_0': 0.7989926934242249, '_hippocampus_046.nii.gz_0': 0.7803579568862915, '_hippocampus_146.nii.gz_0': 0.8156078457832336, '_hippocampus_358.nii.gz_0': 0.7698864340782166, '_hippocampus_006.nii.gz_0': 0.8276785016059875, '_hippoca



_hippocampus_320.nii.gz_0
torch.Size([1, 64, 64, 52, 1]) torch.Size([1, 64, 64, 52, 1])
, 0.7185304760932922
(64, 64, 3)
Average Dice Loss 3d: 0, 0.7778012752532959
Standard Deviation 3d: 0, 0.036669244191744076
{'_hippocampus_374.nii.gz_0': 0.8028759956359863, '_hippocampus_326.nii.gz_0': 0.801069974899292, '_hippocampus_373.nii.gz_0': 0.8177461624145508, '_hippocampus_181.nii.gz_0': 0.789074718952179, '_hippocampus_330.nii.gz_0': 0.7631024122238159, '_hippocampus_052.nii.gz_0': 0.7571545839309692, '_hippocampus_300.nii.gz_0': 0.7805323600769043, '_hippocampus_046.nii.gz_0': 0.8046528697013855, '_hippocampus_146.nii.gz_0': 0.8211931586265564, '_hippocampus_358.nii.gz_0': 0.7614007592201233, '_hippocampus_006.nii.gz_0': 0.8242071866989136, '_hippocampus_166.nii.gz_0': 0.8032206892967224, '_hippocampus_232.nii.gz_0': 0.7715351581573486, '_hippocampus_386.nii.gz_0': 0.7827970385551453, '_hippocampus_057.nii.gz_0': 0.7541376948356628, '_hippocampus_095.nii.gz_0': 0.8268132209777832, '_hip
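Each epoch above reports `Dataset size: 35`, which matches ⌈139 / 4⌉ = 35 mini-batches for the 139 training volumes at `batch_size` 4 (assuming that line counts batches). The loss driving training is `DiceFocalLoss` from `src.losses.LossFunctions`, whose internals are not shown here. As a hypothetical stand-in, a common way to combine the two terms is a soft Dice loss plus a binary focal loss on sigmoid outputs; `dice_focal_loss`, `gamma=2.0` and the use of raw logits are assumptions, not the repository's definition.

```python
import torch
import torch.nn.functional as F


def dice_focal_loss(logits, target, gamma=2.0, eps=1e-6):
    """Hypothetical stand-in: soft Dice loss plus binary focal loss.

    logits, target: tensors of the same shape; target holds 0/1 labels.
    """
    prob = torch.sigmoid(logits)
    dims = tuple(range(1, logits.dim()))  # reduce over everything except the batch
    inter = (prob * target).sum(dims)
    dice = (2 * inter + eps) / (prob.sum(dims) + target.sum(dims) + eps)
    dice_loss = 1 - dice.mean()
    # Focal term: down-weight voxels that are already classified confidently
    bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    p_t = torch.exp(-bce)  # probability the model assigns to the true label
    focal = ((1 - p_t) ** gamma * bce).mean()
    return dice_loss + focal


target = (torch.rand(2, 16, 16, 13, 1) > 0.7).float()
good = (2 * target - 1) * 10   # confident, correct logits
bad = -good                    # confident, wrong logits
print(dice_focal_loss(good, target).item() < dice_focal_loss(bad, target).item())  # True
```

The Dice term rewards overlap regardless of how small the hippocampus is relative to the background, while the focal term keeps a per-voxel gradient that concentrates on uncertain voxels.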

## _5. Evaluate test data_

In [None]:
agent.getAverageDiceScore()
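The training log labels its evaluation results "Average Dice Loss 3d", but the values rise during training (per-case values around 0.63 to 0.67 at the first evaluation, averages of about 0.77 to 0.78 later), so they appear to be Dice scores where higher is better. The sketch below shows the standard Dice coefficient, Dice = 2|A ∩ B| / (|A| + |B|), and a mean/standard-deviation summary over a per-case dictionary like the one printed above. `dice_score` and `summarize` are hypothetical helpers; whether `getAverageDiceScore` thresholds predictions, and whether it reports a population or sample standard deviation, is not stated here, so this version takes pre-binarised masks and uses NumPy's default population standard deviation.

```python
import numpy as np


def dice_score(pred, target):
    """Dice = 2|A ∩ B| / (|A| + |B|) for binary masks (1.0 if both are empty)."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    denom = pred.sum() + target.sum()
    if denom == 0:
        return 1.0
    return float(2.0 * np.logical_and(pred, target).sum() / denom)


def summarize(scores):
    """Mean and (population) standard deviation over per-case Dice scores."""
    values = np.array(list(scores.values()), dtype=float)
    return values.mean(), values.std()


print(dice_score([1, 1, 0, 0], [1, 0, 1, 0]))  # 0.5
```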