# M3D-NCA: Robust 3D Segmentation with Built-in Quality Control
### John Kalkhof, Anirban Mukhopadhyay
__TBD__



***

## __The Backbone Model__
<div>
<img src="src/images/model_M3DNCA.png" width="600"/>
</div>

## _1. Imports_

In [None]:
import torch
from src.datasets.Nii_Gz_Dataset_3D import Dataset_NiiGz_3D
from src.models.Model_BasicNCA3D import BasicNCA3D
from src.losses.LossFunctions import DiceFocalLoss
from src.utils.Experiment import Experiment
from src.agents.Agent_M3D_NCA import Agent_M3D_NCA

## _2. Configure experiment_
- __AutoReload__
    - If an experiment already exists in _model\_path_, the config will __always__ be overwritten with the existing one
    - Additionally, if the model has been saved previously, its saved state will be reloaded
- Download _prostate_ data from 'http://medicaldecathlon.com/' and adapt 'img_path' and 'label_path'

In [None]:
# Experiment settings, written as one small dict per concern and merged at
# the bottom (in this order) into the single-entry list used as `config`.

# Paths, run name and hardware.
_general = {
    'img_path': r"image_path",      # placeholder: adapt to the downloaded images
    'label_path': r"label_path",    # placeholder: adapt to the matching labels
    'name': r'M3D_NCA_Run1',
    'device': "cuda:0",             # converted to a torch.device when the models are built
    'unlock_CPU': True,
}

# Optimizer
_optimizer = {
    'lr': 16e-4,
    'lr_gamma': 0.9999,
    'betas': (0.9, 0.99),
}

# Training
_training = {
    'save_interval': 10,
    'evaluate_interval': 10,
    'n_epoch': 3000,
    'batch_duplication': 1,
}

# Model
_model = {
    'channel_n': 16,                # Number of CA state channels
    'inference_steps': [10, 10],    # one entry per model -- presumably steps per level, TODO confirm
    'cell_fire_rate': 0.5,
    'batch_size': 4,
    'input_channels': 1,
    'output_channels': 1,
    'hidden_size': 64,
    'train_model': 1,
}

# Data
_data = {
    # The second size equals the first multiplied by 4 along every axis,
    # i.e. by the value of 'scale_factor' below.
    'input_size': [(16, 16, 13), (64, 64, 52)],
    'scale_factor': 4,
    'data_split': [0.7, 0, 0.3],    # presumably train / validation / test fractions -- TODO confirm
    'keep_original_scale': True,
    'rescale': True,
}

config = [{**_general, **_optimizer, **_training, **_model, **_data}]

## _3. Choose architecture, dataset and training agent_

- _Dataset\_NiiGz\_3D_ loads 3D files. If you pass a _slice_, the data will be split along the corresponding axis.

In [None]:
# The dataset is created without arguments; it is tied to the experiment
# below through set_experiment().
dataset = Dataset_NiiGz_3D()

cfg = config[0]
device = torch.device(cfg['device'])

# Two BasicNCA3D models that share every setting except `kernel_size`:
# 7 for the first model, 3 for the second. They are built in that order.
ca = [
    BasicNCA3D(
        cfg['channel_n'],
        cfg['cell_fire_rate'],
        device,
        hidden_size=cfg['hidden_size'],
        kernel_size=kernel_size,
        input_channels=cfg['input_channels'],
    ).to(device)
    for kernel_size in (7, 3)
]
ca1, ca2 = ca

# Wire agent, experiment and dataset together, then switch to training mode.
agent = Agent_M3D_NCA(ca)
exp = Experiment(config, dataset, ca, agent)
dataset.set_experiment(exp)
exp.set_model_state('train')

# The batch size is read back through the experiment instead of `cfg`.
# NOTE(review): presumably because an existing experiment's stored config
# replaces the one defined in this notebook -- confirm in Experiment.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=exp.get_from_config('batch_size'),
    shuffle=True,
)

loss_function = DiceFocalLoss()

## _4. Run training_

In [None]:
# Start training through the agent, using the data loader and loss function
# created in the previous cell.
# NOTE(review): epoch count and save/evaluate cadence presumably come from
# the experiment config ('n_epoch', 'save_interval', 'evaluate_interval') --
# confirm in Agent_M3D_NCA.train.
agent.train(data_loader, loss_function)

## _5. Evaluate test data_

In [None]:
# Ask the agent for the average Dice score; as the last expression of the
# cell, any returned value is displayed by the notebook.
# NOTE(review): per the section heading this evaluates the test data,
# presumably the 0.3 share from 'data_split' -- confirm in Agent_M3D_NCA.
agent.getAverageDiceScore()