# Med-NCA: Robust and Lightweight Segmentation with Neural Cellular Automata 
### John Kalkhof, Camila González, Anirban Mukhopadhyay
__https://arxiv.org/pdf/2302.03473.pdf__



***

## __The Backbone Model__
<div>
<img src="src/images/model_MedNCA.png" width="600"/>
</div>

## _1. Imports_

In [1]:
import torch
from src.datasets.Nii_Gz_Dataset_3D import Dataset_NiiGz_3D
from src.models.Model_BackboneNCA import BackboneNCA
from src.losses.LossFunctions import DiceBCELoss
from src.utils.Experiment import Experiment
from src.agents.Agent_Med_NCA import Agent_Med_NCA

## _2. Configure experiment_
- __AutoReload__
    - If an experiment already exists in _model\_path_, the config defined below will __always__ be overwritten with the stored one.
    - Additionally, if a model state has been saved previously, that state will be reloaded.
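- Download the data and adapt 'img_path' and 'label_path' accordingly. The original example uses the _prostate_ data from 'http://medicaldecathlon.com/'; the configuration below points to a local knee-cartilage dataset instead.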

In [5]:
import shutil

# Directory of this experiment; it is also stored in the config entry below.
model_path = r'Models/Med_NCA_knee'

# The experiment configuration is a list holding a single dictionary.
config = [
    {
        # --- Paths -----------------------------------------------------
        'img_path': r"/home/alvin/UltrAi/ai-pocus/state_of_the_art_models/nnunet/knee_cartilage/phase_info_experiments/nnUNet_raw/Dataset070_Clarius_L15/imagesTr/",
        'label_path': r"/home/alvin/UltrAi/ai-pocus/state_of_the_art_models/nnunet/knee_cartilage/phase_info_experiments/nnUNet_raw/Dataset070_Clarius_L15/labelsTr/",
        # Alternative dataset (Hippocampus, tiny subset), kept for quick switching:
        # 'img_path': r"/home/alvin/UltrAi/Datasets/raw_datasets/other_datasets/Task04_Hippocampus/Task04_Hippocampus_tiny/imagesTr/",
        # 'label_path': r"/home/alvin/UltrAi/Datasets/raw_datasets/other_datasets/Task04_Hippocampus/Task04_Hippocampus_tiny/labelsTr/",
        'model_path': model_path,

        # --- Runtime ---------------------------------------------------
        'device': "cuda:0",     # handed to torch.device(...) when the models are built
        'unlock_CPU': True,

        # --- Optimizer -------------------------------------------------
        'lr': 1e-4,
        'lr_gamma': 0.9999,
        'betas': (0.5, 0.5),

        # --- Training --------------------------------------------------
        'save_interval': 10,
        'evaluate_interval': 10,
        'n_epoch': 1000,
        'batch_size': 10,       # read back via exp.get_from_config('batch_size') for the DataLoader

        # --- Model -----------------------------------------------------
        'channel_n': 32,        # Number of CA state channels
        'inference_steps': 64,
        'cell_fire_rate': 0.5,
        'input_channels': 1,
        'output_channels': 1,
        'hidden_size': 128,
        'train_model': 1,

        # --- Data ------------------------------------------------------
        # NOTE(review): two resolutions, presumably one per backbone model - confirm.
        'input_size': [(64, 64), (256, 256)],
        # NOTE(review): presumably train / validation / test fractions - confirm.
        'data_split': [0.7, 0, 0.3],
    },
]

## _3. Choose architecture, dataset and training agent_

- _Dataset\_Nii\_Gz\_3D_ loads 3D files. If you pass a _slice_, the volumes will be split along the corresponding axis.

In [6]:
# 3D NIfTI dataset; a slice axis is passed so the volumes are split along it.
dataset = Dataset_NiiGz_3D(slice=2)

# Target device taken from the first (and only) config entry.
device = torch.device(config[0]['device'])


def _new_backbone():
    """Create one BackboneNCA from the first config entry and move it to `device`."""
    settings = config[0]
    model = BackboneNCA(
        settings['channel_n'],
        settings['cell_fire_rate'],
        device,
        hidden_size=settings['hidden_size'],
        input_channels=settings['input_channels'],
    )
    return model.to(device)


# Two identically configured backbone models, handed over as a list.
ca1 = _new_backbone()
ca2 = _new_backbone()
ca = [ca1, ca2]

# Wire agent, experiment and dataset together, then switch to training mode.
agent = Agent_Med_NCA(ca)
exp = Experiment(config, dataset, ca, agent)
dataset.set_experiment(exp)
exp.set_model_state('train')

# Shuffled loader; the batch size is read back from the experiment's config.
data_loader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=exp.get_from_config('batch_size'),
)

loss_function = DiceBCELoss()

Models/Med_NCA_knee/data_split.dt
Models/Med_NCA_knee/models/epoch_200
Reload State 200


## _4. Run training_

In [4]:
# Start training with the shuffled data loader and the DiceBCELoss instance created above.
# NOTE: in the recorded output the run continues at epoch 201 because a saved
# state ("Reload State 200") was reloaded when the experiment was set up.
# NOTE(review): the "Evaluate model" sections appear after epochs 210 and 220,
# which presumably corresponds to 'evaluate_interval': 10 in the config - confirm.
agent.train(data_loader, loss_function)

Epoch: 201


100%|██████████| 12/12 [00:10<00:00,  1.12it/s]


201 loss = 0.17912832523385683
Epoch: 202


100%|██████████| 12/12 [00:10<00:00,  1.17it/s]


202 loss = 0.1306118443608284
Epoch: 203


100%|██████████| 12/12 [00:10<00:00,  1.16it/s]


203 loss = 0.1578909351179997
Epoch: 204


100%|██████████| 12/12 [00:10<00:00,  1.16it/s]


204 loss = 0.19274791392187277
Epoch: 205


100%|██████████| 12/12 [00:10<00:00,  1.17it/s]


205 loss = 0.14738769177347422
Epoch: 206


100%|██████████| 12/12 [00:10<00:00,  1.18it/s]


206 loss = 0.1548586953431368
Epoch: 207


100%|██████████| 12/12 [00:10<00:00,  1.15it/s]


207 loss = 0.17553421296179295
Epoch: 208


100%|██████████| 12/12 [00:10<00:00,  1.18it/s]


208 loss = 0.16236415691673756
Epoch: 209


100%|██████████| 12/12 [00:10<00:00,  1.18it/s]


209 loss = 0.17442767570416132
Epoch: 210


100%|██████████| 12/12 [00:10<00:00,  1.18it/s]


210 loss = 0.21009765937924385
Evaluate model
__________________________ CASE 0 __________________________
__________________________ CASE 1 __________________________
115, 0.9012390375137329, 
__________________________ CASE 2 __________________________
116, 0.8847361207008362, 
__________________________ CASE 3 __________________________
117, 0.8872445821762085, 
__________________________ CASE 4 __________________________
118, 0.8918919563293457, 
__________________________ CASE 5 __________________________
119, 0.9192392826080322, 
__________________________ CASE 6 __________________________
120, 0.9149386882781982, 
__________________________ CASE 7 __________________________
121, 0.8771764039993286, 
__________________________ CASE 8 __________________________
122, 0.9382771253585815, 
__________________________ CASE 9 __________________________
123, 0.9241727590560913, 
__________________________ CASE 10 __________________________
124, 0.8707868456840515, 
______________________

100%|██████████| 12/12 [00:10<00:00,  1.18it/s]


211 loss = 0.16312946379184723
Epoch: 212


100%|██████████| 12/12 [00:10<00:00,  1.18it/s]


212 loss = 0.14592345990240574
Epoch: 213


100%|██████████| 12/12 [00:10<00:00,  1.16it/s]


213 loss = 0.18202736725409827
Epoch: 214


100%|██████████| 12/12 [00:10<00:00,  1.17it/s]


214 loss = 0.11450925779839356
Epoch: 215


100%|██████████| 12/12 [00:10<00:00,  1.19it/s]


215 loss = 0.16057221094767252
Epoch: 216


100%|██████████| 12/12 [00:10<00:00,  1.17it/s]


216 loss = 0.20543734667201838
Epoch: 217


100%|██████████| 12/12 [00:10<00:00,  1.19it/s]


217 loss = 0.15263236251970133
Epoch: 218


100%|██████████| 12/12 [00:10<00:00,  1.19it/s]


218 loss = 0.12565883497397104
Epoch: 219


100%|██████████| 12/12 [00:10<00:00,  1.18it/s]


219 loss = 0.19325709342956543
Epoch: 220


100%|██████████| 12/12 [00:10<00:00,  1.18it/s]


220 loss = 0.19485810647408167
Evaluate model
__________________________ CASE 0 __________________________
__________________________ CASE 1 __________________________
115, 0.9077740907669067, 
__________________________ CASE 2 __________________________
116, 0.9085399508476257, 
__________________________ CASE 3 __________________________
117, 0.8822962641716003, 
__________________________ CASE 4 __________________________
118, 0.8859429955482483, 
__________________________ CASE 5 __________________________
119, 0.9435458183288574, 
__________________________ CASE 6 __________________________
120, 0.8823243379592896, 
__________________________ CASE 7 __________________________
121, 0.8277417421340942, 
__________________________ CASE 8 __________________________
122, 0.9242832660675049, 
__________________________ CASE 9 __________________________
123, 0.9436987042427063, 
__________________________ CASE 10 __________________________
124, 0.8731076717376709, 
______________________

100%|██████████| 12/12 [00:10<00:00,  1.19it/s]


221 loss = 0.17588034893075624
Epoch: 222


100%|██████████| 12/12 [00:10<00:00,  1.19it/s]


222 loss = 0.23978903206686178
Epoch: 223


100%|██████████| 12/12 [00:10<00:00,  1.17it/s]


223 loss = 0.18629958791037401
Epoch: 224


 83%|████████▎ | 10/12 [00:09<00:01,  1.09it/s]

## _5. Evaluate test data_

In [5]:
# Compute the Dice scores through the agent. The recorded output prints one line
# per case ("<case id>, <dice>") and the call returns a nested dict of the form
# {0: {'<case id>': <dice>, ...}}.
# NOTE(review): presumably evaluated on the test split of 'data_split' - confirm.
agent.getAverageDiceScore()

__________________________ CASE 0 __________________________
__________________________ CASE 1 __________________________
115, 0.9128252863883972, 
__________________________ CASE 2 __________________________
116, 0.848017692565918, 
__________________________ CASE 3 __________________________
117, 0.90008145570755, 
__________________________ CASE 4 __________________________
118, 0.8997448682785034, 
__________________________ CASE 5 __________________________
119, 0.9263120293617249, 
__________________________ CASE 6 __________________________
120, 0.8876810669898987, 
__________________________ CASE 7 __________________________
121, 0.8779782652854919, 
__________________________ CASE 8 __________________________
122, 0.9372177720069885, 
__________________________ CASE 9 __________________________
123, 0.9387393593788147, 
__________________________ CASE 10 __________________________
124, 0.8587614297866821, 
__________________________ CASE 11 __________________________
125, 0.89

{0: {'115': 0.9128252863883972,
  '116': 0.848017692565918,
  '117': 0.90008145570755,
  '118': 0.8997448682785034,
  '119': 0.9263120293617249,
  '120': 0.8876810669898987,
  '121': 0.8779782652854919,
  '122': 0.9372177720069885,
  '123': 0.9387393593788147,
  '124': 0.8587614297866821,
  '125': 0.8988624215126038,
  '126': 0.9269989132881165,
  '127': 0.6640500426292419,
  '128': 0.9495993256568909,
  '129': 0.93318110704422,
  '130': 0.9504987001419067,
  '131': 0.9143036007881165,
  '132': 0.9162552952766418,
  '133': 0.9279139041900635,
  '134': 0.9305689334869385,
  '135': 0.9309068918228149,
  '136': 0.8607228398323059,
  '137': 0.8529163002967834,
  '138': 0.7913591265678406,
  '139': 0.848166286945343,
  '140': 0.9257413744926453,
  '141': 0.8127416968345642,
  '142': 0.7649997472763062,
  '143': 0.8837792277336121,
  '144': 0.8422486186027527,
  '145': 0.9533175826072693,
  '146': 0.8722153306007385,
  '147': 0.924808919429779,
  '148': 0.9406454563140869,
  '149': 0.9118192