# Med-NCA: Robust and Lightweight Segmentation with Neural Cellular Automata 
### John Kalkhof, Camila González, Anirban Mukhopadhyay
__https://arxiv.org/pdf/2302.03473.pdf__



***

## __The Backbone Model__
<div>
<img src="src/images/backbone_model_MedNCA.png" width="200"/>
</div>

## _1. Imports_

In [None]:
import torch
from src.datasets.Nii_Gz_Dataset_3D import Dataset_NiiGz_3D
from src.models.Model_BackboneNCA import BackboneNCA
from src.losses.LossFunctions import DiceBCELoss
from src.utils.Experiment import Experiment
from src.agents.Agent_NCA import Agent_NCA

## _2. Configure experiment_
- __AutoReload__
    - If an experiment already exists in _model\_path_, the config below will __always__ be overwritten by the stored one
    - Additionally, if the model has been saved previously, its saved state will be reloaded
- Download the _hippocampus_ data from http://medicaldecathlon.com/ and set 'img_path' and 'label_path' accordingly

In [None]:
# Hyperparameters for one Med-NCA backbone run.
# Kept as a list of dicts because the whole list is handed to Experiment,
# while the model itself is built from config[0].
config = [
    dict(
        # --- Basic ---
        img_path=r"image_path",         # TODO: point at the downloaded images
        label_path=r"label_path",       # TODO: point at the downloaded labels
        model_path=r'Models/Backbone2D_Run1',
        device="cuda:0",
        unlock_CPU=True,
        # --- Optimizer ---
        lr=16e-4,
        lr_gamma=0.9999,                # presumably a multiplicative LR decay factor
        betas=(0.5, 0.5),               # presumably Adam-style betas
        # --- Training ---
        save_interval=10,
        evaluate_interval=10,
        n_epoch=1000,
        batch_size=48,
        # --- Model ---
        channel_n=16,                   # number of CA state channels
        inference_steps=64,
        cell_fire_rate=0.5,
        input_channels=1,
        output_channels=1,
        hidden_size=128,
        # --- Data ---
        input_size=(64, 64),
        data_split=[0.7, 0, 0.3],       # presumably train / val / test fractions
    ),
]

## _3. Choose architecture, dataset and training agent_

- _Dataset\_Nii\_Gz\_3D_ loads 3D volumes. If you pass a _slice_ axis, each volume is split into 2D slices along the corresponding axis.

In [None]:
# Wire up the dataset, the Med-NCA backbone, the training agent, the
# experiment wrapper, the data loader and the loss.
cfg = config[0]  # the single run configuration defined above

dataset = Dataset_NiiGz_3D(slice=2)  # slice volumes along axis 2
device = torch.device(cfg['device'])

# NOTE(review): the model is built from the local config *before* Experiment
# exists; if AutoReload swaps in a stored config, these hyperparameters may
# not match it — confirm.
ca = BackboneNCA(
    cfg['channel_n'],
    cfg['cell_fire_rate'],
    device,
    hidden_size=cfg['hidden_size'],
    input_channels=cfg['input_channels'],
).to(device)
agent = Agent_NCA(ca)

exp = Experiment(config, dataset, ca, agent)
dataset.set_experiment(exp)
exp.set_model_state('train')

# Batch size is read back through the experiment (presumably reflecting any
# reloaded config) rather than from the local dict.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=exp.get_from_config('batch_size'),
    shuffle=True,
)

loss_function = DiceBCELoss()

## _4. Run training_

In [None]:
# Run the agent's training loop over the shuffled slice loader with the
# Dice+BCE loss; save/evaluation cadence is presumably driven by
# 'save_interval' / 'evaluate_interval' in the config — confirm in Agent_NCA.
agent.train(data_loader, loss_function)

## _5. Evaluate test data_

In [None]:
# Compute the average Dice score of the trained model (per the section heading,
# on the test data — presumably the third 'data_split' fraction). Left as the
# cell's last expression so the notebook displays the returned value.
agent.getAverageDiceScore()