# Classification Task: Classification NCA

Neural Cellular Automata (NCA) offer a novel approach to image classification. Unlike conventional deep learning methods, an NCA processes information locally while still capturing global patterns through intercellular communication. Moreover, its local architecture ensures lightweight storage and fast inference while also maintaining robustness to domain shifts.

Instead of generating a target image as in Growing NCA, a Classification NCA extracts features of an image to put it into a category. It does not start from a seed, but rather from an input image whose state evolves over several NCA updates.

In this notebook, we demonstrate the use of NCAs for classifying blood cell images from the Matek-19 dataset, a medical dataset containing multiple classes of white blood cells.

## Implementation Overview

Our simplified NCA classification pipeline consists of:

1. **NCA Architecture**: Defining the architecture of the Classification NCA model
2. **Data Preparation**: Loading and preprocessing the Matek-19 dataset
3. **Loss Function**: Binary cross-entropy loss for multi-class classification
4. **Training and Evaluation Loop**: Iteratively performing update steps to extract features, which will then be used for classification.
5. **Visualization**: Evaluating and visualizing the performance of our model.

## 0. Imports & Select Device

We start by importing the necessary libraries for our next steps.

In [None]:
import torch
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torch.utils.data as data
import torch.nn as nn

import numpy as np
import kagglehub
import os
import matplotlib.pyplot as plt
from PIL import Image

from torchvision.transforms import v2
from collections import Counter

from shutil import copytree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from sklearn import metrics


In [None]:
# Select device: use the GPU ("cuda") when PyTorch reports that CUDA is
# available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 1. NCA Architecture 

It is time to define the NCA model that we will be using - the MaxNCA. The MaxNCA is a Neural Cellular Automata model adapted for image classification, which works by iteratively updating padded images and then extracting the channel-wise maximum, thus aggregating the features of an image.

In [None]:
class MaxNCA(nn.Module):
    """Neural Cellular Automaton that classifies an image from its max-pooled cell state.

    The state is a channels-last tensor [B,H,W,C]. The first ``input_channels``
    channels hold the input image and are kept fixed by ``forward``; the
    remaining channels are updated at every step.
    """

    def __init__(
        self,
        channel_n=16,
        fire_rate=0.5,
        device=None,
        hidden_size=128,
        input_channels=3,
        init_method="standard",
        num_classes=13,
    ):
        """Build the perception, update and classification layers.

        Args:
            channel_n: Number of channels in NCA state
            fire_rate: Threshold for the stochastic update mask (see ``update``)
            device: Device stored on ``self.device``; defaults to 'cuda:0' when
                CUDA is available, else 'cpu'. The module is not moved here.
            hidden_size: Size of hidden layer in update network
            input_channels: Number of input channels
            init_method: Weight initialization method ('standard'); accepted for
                interface compatibility but not used by this implementation
            num_classes: Number of output classes of the classification head
        """
        super().__init__()

        # Device configuration
        if device is not None:
            self.device = device
        else:
            self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # Model parameters
        self.channel_n = channel_n
        self.input_channels = input_channels
        self.fire_rate = fire_rate
        self.num_classes = num_classes

        # Two depthwise (groups == channels) 3x3 convolutions to capture neighborhood patterns
        self.p0 = nn.Conv2d(channel_n, channel_n, kernel_size=3, stride=1, padding=1, groups=channel_n, padding_mode="reflect")
        self.p1 = nn.Conv2d(channel_n, channel_n, kernel_size=3, stride=1, padding=1, groups=channel_n, padding_mode="reflect")

        # Per-cell update network (1x1 convolutions) applied to the perception vector
        self.fc0 = nn.Conv2d(channel_n * 3, hidden_size, kernel_size=1)
        self.bn = nn.BatchNorm2d(hidden_size, track_running_stats=False)
        self.fc1 = nn.Conv2d(hidden_size, channel_n, kernel_size=1, bias=False)

        # Classification head: pooled NCA state -> class logits
        self.fc2 = nn.Linear(channel_n, 128)
        self.fc3 = nn.Linear(128, num_classes)

        # Zero the last update layer so the initial update is the identity
        with torch.no_grad():
            self.fc1.weight.zero_()

    def perceive(self, x):
        """Creates perception vector

        Args:
            x: current state, channels-first [B,C,H,W]

        Returns:
            The state concatenated with both depthwise-convolved copies
            along the channel axis ([B,3*C,H,W]).
        """
        z1 = self.p0(x)
        z2 = self.p1(x)
        return torch.cat((x, z1, z2), 1)

    def update(self, x_in, fire_rate):
        """Performs one NCA update step

        Args:
            x_in: Input state tensor [B,H,W,C]
            fire_rate: Mask threshold; ``None`` falls back to ``self.fire_rate``.
                A cell is updated where a uniform sample exceeds this value.

        Returns:
            Updated state tensor [B,H,W,C]
        """
        # Channels-last to channels-first for the conv layers: [B,H,W,C] -> [B,C,H,W]
        x = x_in.permute(0, 3, 1, 2)

        # Compute updates
        dx = self.perceive(x)
        dx = self.fc0(dx)
        dx = self.bn(dx)
        dx = F.relu(dx)
        dx = self.fc1(dx)

        # Stochastic updates: one mask value per cell, shared across channels
        if fire_rate is None:
            fire_rate = self.fire_rate
        mask_shape = (dx.size(0), 1, dx.size(2), dx.size(3))
        stochastic = (torch.rand(mask_shape, device=x.device) > fire_rate).float()
        dx = dx * stochastic

        # Apply updates and return to channels-last layout
        x = x + dx
        return x.permute(0, 2, 3, 1)

    def forward(self, x, steps=32, fire_rate=0.5):
        """Forward function, applies k NCA update steps leaving input channels unchanged

        Args:
            x: Input tensor [B,H,W,C]
            steps: Number of NCA updates
            fire_rate: Mask threshold passed to ``update``

        Returns:
            Tuple ``(logits, state)``: class logits [B,num_classes] and the
            final NCA state [B,H,W,C].
        """
        # NCA update steps: keep the image channels, take the rest from the update
        for _ in range(steps):
            updated = self.update(x, fire_rate)
            x = torch.cat((x[..., :self.input_channels], updated[..., self.input_channels:]), 3)

        # Feature aggregation: channel-wise maximum over all cells
        pooled = F.adaptive_max_pool2d(x.permute(0, 3, 1, 2), (1, 1))
        pooled = pooled.view(pooled.size(0), -1)

        # Classification
        out = self.fc2(pooled)
        out = F.relu(out)
        out = self.fc3(out)

        return out, x

## 2. Data Preparation

### 2.1 Dataloader

Now we need to define a dataset that handles our classification task. The WBC_Dataset is just what we need to manage the white blood cell (WBC) images of the Matek-19 dataset.

In [None]:

class WBC_Dataset(data.Dataset):
    """PyTorch Dataset for White Blood Cell (WBC) images with augmentation support"""

    def __init__(self, image_paths, labels, resize=None, augment=False, dataset="AML", num_classes=13):
        """Store the sample list and build the augmentation / normalization pipelines.

        Args:
            image_paths: Sequence of image file paths
            labels: Sequence of integer class labels, aligned with ``image_paths``
            resize: Number of pixels in one dimension (square image), or None to keep the size
            augment: Whether to apply random rotation and horizontal flip
            dataset: Dataset identifier; accepted for interface compatibility, not used here
            num_classes: Length of the one-hot label vector
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.image_paths = image_paths
        self.labels = labels
        self.resize = resize
        self.augment = augment
        self.num_classes = num_classes
        self.transforms = v2.Compose([
            v2.RandomRotation([0, 360]),
            v2.RandomHorizontalFlip(p=0.5),
        ])

        self.norm = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.82069695, 0.7281261, 0.836143], std=[0.16157213, 0.2490039, 0.09052657]),
        ])

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(image, label)``: the normalized image in channels-last
            layout [H,W,3] and a one-hot label vector of length ``num_classes``.
        """
        # Close the file handle deterministically; resize()/copy() return an
        # in-memory image that stays valid after the file is closed.
        with Image.open(self.image_paths[idx]) as raw:
            if self.resize is not None:
                image = raw.resize((self.resize, self.resize))
            else:
                image = raw.copy()

        # Augment while the sample is still a PIL image: torchvision v2
        # transforms pass NumPy arrays through unchanged, so applying them
        # after np.array(...) would silently skip the augmentation.
        if self.augment:
            image = self.transforms(image)

        # Keep the first three channels (drops e.g. an alpha channel)
        pixels = np.array(image)[:, :, 0:3]
        tensor = self.norm(pixels)

        label = torch.zeros(self.num_classes)
        label[self.labels[idx]] = 1

        # [C,H,W] -> [H,W,C]
        return tensor.permute(1, 2, 0), label

### 2.2 Downloading the dataset

We will be using the Matek-19, a public dataset containing over 18,000 annotated blood cells from 200 individuals. Half of those subjects are affected by AML.

You can find the dataset we used here: [Matek-19 Dataset](https://www.kaggle.com/datasets/inhvnnhn/matek-19-dataset)

In [None]:
# Fetch the dataset through the Kaggle Hub API
download_path = kagglehub.dataset_download("inhvnnhn/matek-19-dataset")
# Working-directory location of the dataset, reused by later cells
matek19_path = "INSERT/input/Matek19"
# Mirror the downloaded folder into the working directory
copytree(os.path.join(download_path, "Matek-19 Dataset"), matek19_path, dirs_exist_ok=True)

### 2.3 Data Preparation

We prepare our dataset by creating iterable batches of processed images and labels from the raw dataset.

In [None]:
# White Blood Cell (WBC) class names; the list index is the integer class label
CLASSES = ['basophil','eosinophil','erythroblast','myeloblast','promyelocyte','myelocyte','metamyelocyte','neutrophil_banded','neutrophil_segmented','monocyte','lymphocyte_typical','lymphocyte_atypical','smudge_cell']


def plot_distribution(labels, output_path):
    """Visualize and save class distribution of the dataset

    Args:
        labels: List of integer class labels
        output_path: Path to save the distribution plot
    """
    # Count once, then read one entry per class (missing classes count as 0)
    label_counts = Counter(labels)
    counts = [label_counts[i] for i in range(len(CLASSES))]

    # savefig does not create missing directories
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    plt.figure(figsize=(10, 5))
    plt.bar(CLASSES, counts)
    plt.xticks(rotation=45)
    plt.xlabel("Class")
    plt.ylabel("Number of Samples")
    plt.savefig(output_path)
    plt.close()

def get_data_AML(data_path,show_distribution=True):
    """Loads WBC image paths and labels from AML dataset

    Every sub-directory of ``data_path`` is scanned for ``.jpg`` / ``.tiff``
    files. The label is derived from the first known cell-type code found in
    the file name. Files without a known code are skipped with a warning
    rather than being given a stale or undefined label.

    Args:
        data_path: Path to dataset directory
        show_distribution: Whether to plot class distribution

    Returns:
        Tuple ``(image_paths, labels)`` of equal length.
    """
    import warnings

    # Cell-type code in the file name -> class index. Order matters: the
    # first code contained in the name wins.
    code_to_label = {
        "BAS": 0,
        "EBO": 2,
        "EOS": 1,
        "KSC": 12,
        "LYA": 11,
        "LYT": 10,
        "MMZ": 6,
        "MOB": 9,
        "MON": 9,
        "MYB": 5,
        "MYO": 3,
        "NGB": 7,
        "NGS": 8,
        "PMB": 4,
        "PMO": 4,
    }

    image_paths = []
    labels = []

    # Sorted listings make the sample order (and any seeded split built on
    # it) independent of the file system's enumeration order.
    for entry in sorted(os.listdir(data_path)):
        folder_path = os.path.join(data_path, entry)
        if not os.path.isdir(folder_path):
            continue  # stray files next to the class folders

        for file in sorted(os.listdir(folder_path)):
            if not file.endswith(('.jpg', '.tiff')):
                continue

            label = next((idx for code, idx in code_to_label.items() if code in file), None)
            if label is None:
                warnings.warn(f"Skipping image with unknown cell-type code: {file}")
                continue

            labels.append(label)
            image_paths.append(os.path.join(folder_path, file))

    if show_distribution:
        plot_distribution(labels, "INSERT/output/data_distribution_AML.png")

    return image_paths, labels

def get_weights(y):
    """Calculate per-sample weights for imbalanced dataset handling

    Each sample is weighted by the inverse frequency of its class, so that
    sampling with these weights draws every class about equally often.

    Args:
        y: Iterable of integer class labels

    Returns:
        List with one weight per entry of ``y``, in the same order.
    """
    # Counting the labels that actually occur removes the need for a fixed
    # class-count bound.
    class_counts = Counter(y)
    # The small constant mirrors the original smoothing of the inverse count
    return [1 / (class_counts[label] + 0.001) for label in y]

In [None]:
def prepare_data(data_path):
    """Build the training and validation data loaders for the WBC images.

    Args:
        data_path: Path to dataset directory

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    # Collect every image path with its label (no distribution plot)
    paths, targets = get_data_AML(data_path, show_distribution=False)

    # 80/20 split with a fixed seed
    split = train_test_split(np.asarray(paths), np.asarray(targets), test_size=0.20, random_state=2)
    paths_train, paths_val, targets_train, targets_val = split

    # Only the training set is augmented; both are resized to 64x64
    training_set = WBC_Dataset(paths_train, targets_train, augment=True, resize=64, dataset="AML")
    validation_set = WBC_Dataset(paths_val, targets_val, resize=64, dataset="AML")

    # Draw training samples with inverse-frequency weights to balance the classes
    balanced_sampler = data.WeightedRandomSampler(
        weights=get_weights(targets_train),
        num_samples=len(training_set),
        replacement=True,
    )

    return (
        data.DataLoader(training_set, sampler=balanced_sampler, batch_size=32),
        data.DataLoader(validation_set, batch_size=1),
    )

## 3. Loss Function

It is time to implement the loss function. We implement the Binary Cross-Entropy loss function to classify the images into different categories.

In [None]:
class BCELoss(torch.nn.Module):
    """Binary Cross-Entropy Loss with optional sigmoid activation"""

    def __init__(self, useSigmoid = True):
        """
        Args:
            useSigmoid: If True, ``input`` is treated as raw logits and a
                sigmoid is applied; if False, ``input`` must already be
                probabilities in [0, 1].
        """
        # Initialise the Module machinery before assigning attributes
        super().__init__()
        self.useSigmoid = useSigmoid

    def forward(self, input, target, smooth=1):
        """Forward function

        Args:
            input: Predictions (logits or probabilities, see ``useSigmoid``)
            target: Target array with the same number of elements as ``input``
            smooth: Smoothing value; accepted for interface compatibility, not used

        Returns:
            Scalar tensor with the mean binary cross-entropy over all elements.
        """
        input = torch.flatten(input).float()
        target = torch.flatten(target).float()

        if self.useSigmoid:
            # Fused sigmoid + BCE: same value, numerically more stable
            return torch.nn.functional.binary_cross_entropy_with_logits(input, target, reduction='mean')
        return torch.nn.functional.binary_cross_entropy(input, target, reduction='mean')

## 4. Training and Evaluation Loop

![alt text](../assets/model_graphic.svg)[Source](https://github.com/marrlab/WBC-NCA/blob/main/src/images/model_graphic.svg)

The training process follows four main steps:
1. **Image Padding**: We pad the images to the desired amount of channels. These additional hidden channels are initialized to zero.

2. **NCA Update Steps**: The model performs k iterative NCA update steps to extract features.

3. **Feature Aggregation**: The model takes each channel's maximum to condense the evolved state into a compact feature vector.

4. **Classification**: Lastly, this feature vector is used as an input for our neural network, which will then predict the class of the image.

Here we have combined the training and evaluation of the model into a single loop. For the training we use an Adam optimizer with a learning rate of 0.0004, following the setup in the original paper.

Evaluation is performed after each epoch to compute validation loss, accuracy, and the F1 score.

In [None]:
def train_nca(epochs=5, steps=12, lr=0.0004):
    """Trains and evaluates MaxNCA model

    Args:
        epochs: Number of training epochs
        steps: Number of NCA update steps per forward pass
        lr: Learning rate of the Adam optimizer

    Returns:
        Tuple ``(model, train_loss_total, val_loss_total, accuracy_list, f1_list)``:
        the trained model, per-epoch average training and validation losses
        (NumPy arrays of length ``epochs``), and per-epoch validation accuracy
        and weighted F1 score (lists).
    """

    # 1. Model Initialization
    model = MaxNCA(channel_n=16, hidden_size=128, device=device)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)  # Adam optimizer
    loss_f = BCELoss()  # Binary Cross-Entropy

    # 2. Data Preparation
    train_loader, val_loader = prepare_data(matek19_path)

    def make_state(images):
        """Zero-pad image batch [B,H,W,3] with hidden channels up to the model's channel count."""
        state = torch.zeros(*images.shape[:3], model.channel_n, device=device)
        state[..., :model.input_channels] = images  # RGB Padding
        return state

    # Initialize metrics storage
    train_loss_total = np.zeros(epochs)
    val_loss_total = np.zeros(epochs)
    accuracy_list = []
    f1_list = []

    # 3. Training & Evaluation Loop
    for epoch in range(epochs):
        # 3.1 Training Phase
        model.train()
        train_loss = 0.0

        for inputs, targets in train_loader:
            # Load data to device
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass with `steps` NCA updates
            output, _ = model(make_state(inputs), steps=steps, fire_rate=0.5)

            # Compute and backpropagate loss
            optimizer.zero_grad()
            loss = loss_f(output, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Calculate average training loss
        train_loss_avg = train_loss / len(train_loader)
        train_loss_total[epoch] = train_loss_avg

        # 3.2 Validation Phase
        model.eval()
        val_loss = 0.0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for inputs, targets in val_loader:
                # Load data
                inputs, targets = inputs.to(device), targets.to(device)

                # Prepare and run NCA
                output, _ = model(make_state(inputs), steps=steps, fire_rate=0.5)

                # Accumulate validation loss as a Python float
                val_loss += loss_f(output, targets).item()

                # Store predictions and labels for metrics
                preds = torch.argmax(output, dim=1)
                labels = torch.argmax(targets, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        # Calculate validation metrics
        val_loss_avg = val_loss / len(val_loader)
        val_loss_total[epoch] = val_loss_avg

        # Compute accuracy and F1 score
        accuracy = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, average='weighted')
        accuracy_list.append(accuracy)
        f1_list.append(f1)

        # 3.3 Showcase training and validation results
        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {train_loss_avg:.4f} | "
              f"Val Loss: {val_loss_avg:.4f} | "
              f"Accuracy: {accuracy:.4f} | "
              f"F1 Score: {f1:.4f}")

    # 4. Save model (torch.save does not create missing directories)
    model_dir = "INSERT/models"
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, "MaxNCA.pth")
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    return model, train_loss_total, val_loss_total, accuracy_list, f1_list


### Finally, run the model!

In [None]:
# Run training and keep the returned model and per-epoch metrics as
# module-level names; the visualization cells below read them.
if __name__ == "__main__":

    model, train_loss_total, val_loss_total, accuracy_list, f1_list= train_nca()

## 5. Visualization

This step is optional; however, it may help in understanding and analyzing the results of the model.
We plot the following metrics:
* Training loss & validation loss
* Accuracy score & F1 score
* Confusion matrix

In [None]:
def plot_loss(train,val,dataset):
    """Plot training and validation loss over epochs

    Args:
        train: Sequence of per-epoch training losses
        val: Sequence of per-epoch validation losses
        dataset: Name embedded verbatim in the output file name
    """
    output_path = "INSERT/output/loss_plot_"+dataset+".png"
    # savefig does not create missing directories
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Size this figure directly; setting rcParams after figure creation
    # would only affect figures created later.
    plt.figure(figsize=(5, 5))
    plt.plot(train, label="Train")
    plt.plot(val, label="Val")
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.savefig(output_path)

In [None]:
# Visualize training and validation loss.
# The third argument is embedded verbatim in the saved PNG's file name.
plot_loss(train_loss_total,val_loss_total,"MaxNCA - Train & Val Loss")

In [None]:
def plot_metrics(accuracy_list,f1_list,dataset):
    """Plot accuracy and F1 score over epochs

    Args:
        accuracy_list: Sequence of per-epoch accuracy values
        f1_list: Sequence of per-epoch F1 scores
        dataset: Name embedded verbatim in the output file name
    """
    # Use a dedicated prefix so a metrics plot can never overwrite the loss
    # plot written for the same dataset name.
    output_path = "INSERT/output/metrics_plot_"+dataset+".png"
    # savefig does not create missing directories
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Size this figure directly; setting rcParams after figure creation
    # would only affect figures created later.
    plt.figure(figsize=(5, 5))
    plt.plot(accuracy_list, label="Accuracy")
    plt.plot(f1_list, label="F1 Score")
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Score')

    plt.savefig(output_path)

In [None]:
# Visualize accuracy and F1 score.
# The third argument is embedded verbatim in the saved PNG's file name.
plot_metrics(accuracy_list, f1_list,"MaxNCA - Accuracy & F1-Score")

In [None]:
def confusionMatrix(model, test_loader, steps, dataset_name):
    """Compute and display confusion matrix

    Args:
        model: Trained model exposing ``device``, ``channel_n`` and a forward
            pass that returns ``(logits, state)``
        test_loader: Loader yielding ``(inputs, one_hot_targets)`` batches with
            channels-last inputs [B,H,W,C]
        steps: Number of NCA update steps per forward pass
        dataset_name: Name used in the plot title and output file name
    """
    model.eval()
    all_preds = []
    all_labels = []
    num_classes = 13  # fallback when the loader yields no batches

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(model.device)

            # Zero-pad the image with hidden channels up to the model's channel count
            seed = torch.zeros(*inputs.shape[:3], model.channel_n, device=model.device)
            seed[..., :inputs.shape[-1]] = inputs

            outputs, _ = model(seed, steps=steps, fire_rate=0.5)
            num_classes = outputs.shape[1]
            preds = torch.argmax(outputs, dim=1)
            labels = torch.argmax(targets, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Fix the label set so the matrix keeps one row/column per class even
    # when a class is absent from the evaluated data; otherwise it would not
    # match the display labels.
    class_ids = list(range(num_classes))
    confusion_matrix = metrics.confusion_matrix(all_labels, all_preds, labels=class_ids)
    cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=[str(i) for i in class_ids])

    # Same relative output directory as the other plots; savefig does not
    # create missing directories.
    output_dir = "INSERT/output"
    os.makedirs(output_dir, exist_ok=True)

    cm_display.plot()
    plt.title(f"Confusion Matrix - {dataset_name}")
    plt.savefig(f"{output_dir}/confusion_matrix_{dataset_name}.png")
    plt.show()

In [None]:
# Rebuild the loaders and keep only the validation loader. prepare_data uses
# a fixed random_state for the split, so this is expected to match the
# validation split seen during training.
# NOTE(review): that relies on the directory listing order being stable — confirm.
_, val_loader = prepare_data(matek19_path)

# Visualize confusion matrix
confusionMatrix(model, val_loader, steps=12, dataset_name="matek19")