# Unsupervised Training of Neural Cellular Automata on Edge Devices
### John Kalkhof, Amin Ranem, Anirban Mukhopadhyay
__https://arxiv.org/pdf/2407.18114__



***

## _1. Imports_

In [None]:
import torch
from src.datasets.Nii_Gz_Dataset import Nii_Gz_Dataset
from src.models.Model_MedNCA import MedNCA
from src.agents.Agent_MedNCA_Simple import MedNCAAgent
from src.losses.LossFunctions import DiceFocalLoss, WeightedDiceBCELoss
from src.utils.Experiment import Experiment
from src.datasets.Nii_Gz_Dataset_customPath import Dataset_NiiGz_customPath
from src.models.Model_MedNCA_finetune import MedNCA_finetune
from src.agents.Agent_Med_NCA_Simple_finetuning import Agent_Med_NCA_finetuning

## _2. Configure experiment_
- __AutoReload__
    - If an experiment already exists in _model\_path_, the config defined below will __always__ be overwritten with the existing one
    - Additionally, if the model has been saved previously, that saved state will be reloaded

In [None]:

# Settings for the supervised pretraining run. The experiment setup in a later
# cell receives this as a single-element list.
config = [
    {
        "img_path": r"image_path",
        "label_path": r"label_path",
        # NOTE(review): the original carried the remark "12 or 13, 54 opt"
        # next to the name; its meaning is not clear from this notebook.
        "name": r"Med_NCA_Run1_pretraining",
        "device": "cuda:0",
        "unlock_CPU": True,

        # Optimizer
        "lr": 16e-4,
        "lr_gamma": 0.9999,
        "betas": (0.9, 0.99),

        # Training
        "save_interval": 500,
        "evaluate_interval": 1501,
        "n_epoch": 1500,
        "batch_duplication": 1,

        # Model
        "channel_n": 16,  # Number of CA state channels
        "inference_steps": [20, 20],
        "cell_fire_rate": 0.5,
        "batch_size": 16,
        "input_channels": 1,
        "output_channels": 1,
        "hidden_size": 128,
        "train_model": 1,

        # Data
        "input_size": [(64, 64), (256, 256)],
        "scale_factor": 4,
        "data_split": [1.0, 0, 0.0],
        "keep_original_scale": False,
        "rescale": True,
    },
]


## _3. Choose architecture, dataset and training agent_

- Nii_Gz_Dataset loads 2D files. If you pass `store=True`, the data is loaded into RAM for faster training.

In [None]:

# Build the dataset, model, agent and experiment for supervised pretraining.
#
# The model's device and hyperparameters are read from the config defined
# above instead of being repeated as literals, so editing the config actually
# changes the network that gets built.
dataset = Nii_Gz_Dataset()  # store=True can be passed here (see the note above this cell)
cfg = config[0]
device = torch.device(cfg['device'])
ca = MedNCA(
    channel_n=cfg['channel_n'],
    fire_rate=cfg['cell_fire_rate'],
    steps=50,  # no matching key in the config above; kept at its original value
    device=cfg['device'],  # same string form the constructor received before
    hidden_size=cfg['hidden_size'],
    input_channels=cfg['input_channels'],
    output_channels=cfg['output_channels'],
    batch_duplication=cfg['batch_duplication'],
).to(device)
agent = MedNCAAgent(ca)
exp = Experiment(config, dataset, ca, agent)
dataset.set_experiment(exp)
exp.set_model_state('train')
# Batch size is read back through the experiment rather than from `cfg`.
data_loader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=exp.get_from_config('batch_size'),
)

loss_function = DiceFocalLoss()
# Report the number of trainable parameters.
print(sum(p.numel() for p in ca.parameters() if p.requires_grad))


## _4. Run training_

In [None]:
# Start training with the data loader and loss function created in the previous cell.
agent.train(data_loader, loss_function)

# Optional: evaluate once training has finished.
#agent.getAverageDiceScore(pseudo_ensemble=False)

## _5. Evaluate and generate variance and mean predictions for finetuning_

In [None]:

# Evaluate on a dataset read from explicit image/label paths and write out
# mean and variance maps. Per the original note, these are saved in a new
# parent folder under /variance and /mean.
dataset_varMean = Dataset_NiiGz_customPath(
    resize=True,
    slice=2,
    size=(256, 256),
    imagePath=r"image_path",
    labelPath=r"label_path",
)
agent.getAverageDiceScore(
    pseudo_ensemble=True,
    dataset=dataset_varMean,
    save_meanVariance=True,
)


## _6. Unsupervised Adaptation_

Define the config of the unsupervised model.


In [None]:
# Settings for the unsupervised adaptation run. The fine-tuning setup in a
# later cell receives this as a single-element list.
config = [
    {
        "img_path": r"image_path",
        "label_path": r"label_path",
        # NOTE(review): the original carried the remark "12 or 13, 54 opt"
        # next to both names below; its meaning is not clear from this notebook.
        "name": r"Med_NCA_Run1_unsupervised",
        # Matches the 'name' of the pretraining config; presumably selects the
        # run whose weights are adapted - confirm against the agent.
        "pretrained": r"Med_NCA_Run1_pretraining",
        "device": "cuda:0",
        "unlock_CPU": True,

        # Optimizer
        "lr": 3e-6,
        "lr_gamma": 0.9999,
        "betas": (0.9, 0.99),

        # Training
        "save_interval": 50,
        "evaluate_interval": 501,
        "n_epoch": 100,
        "batch_duplication": 1,

        # Model
        "channel_n": 16,  # Number of CA state channels
        "inference_steps": [20, 20],
        "cell_fire_rate": 0.5,
        "batch_size": 8,
        "input_channels": 1,
        "output_channels": 1,
        "hidden_size": 128,
        "train_model": 1,

        # Data
        "input_size": [(64, 64), (256, 256)],
        "scale_factor": 4,
        "data_split": [1.0, 0, 0.0],
        "keep_original_scale": False,
        "rescale": True,
    },
]

## _7. Choose architecture, dataset and training agent_

In [None]:
# Build the dataset, fine-tuning model, agent and experiment for the
# unsupervised adaptation stage.
#
# The model's device and hyperparameters are read from the config defined
# above instead of being repeated as literals, so editing the config actually
# changes the network that gets built.
dataset = Nii_Gz_Dataset()  # store=True can be passed here, as in the pretraining setup
cfg = config[0]
device = torch.device(cfg['device'])
ca = MedNCA_finetune(
    channel_n=cfg['channel_n'],
    fire_rate=cfg['cell_fire_rate'],
    steps=50,  # no matching key in the config above; kept at its original value
    device=cfg['device'],  # same string form the constructor received before
    hidden_size=cfg['hidden_size'],
    input_channels=cfg['input_channels'],
    output_channels=cfg['output_channels'],
    batch_duplication=cfg['batch_duplication'],
).to(device)
agent = Agent_Med_NCA_finetuning(ca)
exp = Experiment(config, dataset, ca, agent)
dataset.set_experiment(exp)
exp.set_model_state('train')
# Batch size is read back through the experiment rather than from `cfg`.
data_loader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=exp.get_from_config('batch_size'),
)

loss_function = WeightedDiceBCELoss()

## _8. Run unsupervised training_

In [None]:
# Run the unsupervised adaptation with the fine-tuning agent, data loader and
# loss function created in the previous cell.
agent.train(data_loader, loss_function)

## _9. Test_

In [None]:
# Score the adapted model on a dataset read from explicit image/label paths.
dataset = Dataset_NiiGz_customPath(
    resize=True,
    slice=2,
    size=(256, 256),
    imagePath=r"image_path",
    labelPath=r"label_path",
)
# Attach the current experiment to the dataset before evaluating.
dataset.exp = exp
agent.getAverageDiceScore(
    pseudo_ensemble=True,
    dataset=dataset,
)