# Med-GrapherNCA: Evaluation & Comparison
### Dice, IoU, Pseudo-Ensemble Variance across all configurations

This notebook loads trained models and evaluates them on the test split defined by `data_split` in Section 3, which is drawn from the ISIC 2018 Task 1 training images.
It reproduces Tables 1-3 from the paper.

---

## 0. Google Colab Setup: Mount Drive, Download Dataset & Clone Repository

In [None]:
from google.colab import drive

# The dataset and trained checkpoints live under MyDrive; mounting Drive here
# exposes them at /content/drive for the rest of the notebook.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import os

# All persistent artifacts (dataset + trained models) sit under one Google
# Drive folder so they survive Colab runtime resets.
PROJECT_ROOT = '/content/drive/MyDrive/Experiments/Grapher_NCA'
DATASET_DIR = os.path.join(PROJECT_ROOT, 'datasets', 'ISIC2018')
MODELS_DIR = os.path.join(PROJECT_ROOT, 'Models')

# Create both folders up front; existing folders are left untouched.
for _required_dir in (DATASET_DIR, MODELS_DIR):
    os.makedirs(_required_dir, exist_ok=True)

# Echo the resolved locations so a wrong Drive path is obvious immediately.
for _label, _path in (
    ('Project root: ', PROJECT_ROOT),
    ('Dataset dir:  ', DATASET_DIR),
    ('Models dir:   ', MODELS_DIR),
):
    print(f'{_label}{_path}')

In [None]:
%%bash -s "$DATASET_DIR"
# Download and extract the ISIC 2018 Task 1 archives into DATASET_DIR.
# An archive is skipped when its extracted folder already exists, so the cell
# is cheap to re-run against a persistent Google Drive.
# Any failing command (wget, unzip, mv) aborts the cell instead of continuing.
set -euo pipefail

DATASET_DIR=$1
BASE_URL="https://isic-archive.s3.amazonaws.com/challenges/2018"

# fetch_archive NAME LABEL
#   Downloads $BASE_URL/NAME.zip and unpacks it to $DATASET_DIR/NAME.
#   Assumes the zip contains a top-level folder named NAME (the same
#   assumption the skip check relies on).
#   Extraction goes to a staging folder first, so an interrupted run never
#   leaves a half-filled NAME folder that later runs would skip.
fetch_archive() {
    local name=$1 label=$2
    local target="$DATASET_DIR/$name"
    local zip_path="$DATASET_DIR/$name.zip"
    local staging="$DATASET_DIR/.staging_$name"

    if [ -d "$target" ]; then
        echo "ISIC 2018 $label already present, skipping download."
        return 0
    fi

    echo "Downloading ISIC 2018 $label..."
    wget -q --show-progress -O "$zip_path" "$BASE_URL/$name.zip"
    echo "Extracting $label..."
    rm -rf "$staging"
    unzip -q "$zip_path" -d "$staging"
    mv "$staging/$name" "$target"
    rm -rf "$staging" "$zip_path"
}

# Task 1 training data (images + masks)
fetch_archive "ISIC2018_Task1-2_Training_Input"     "training images"
fetch_archive "ISIC2018_Task1_Training_GroundTruth" "training ground truth"

# Task 1 test data (images + masks)
fetch_archive "ISIC2018_Task1-2_Test_Input"         "test images"
fetch_archive "ISIC2018_Task1_Test_GroundTruth"     "test ground truth"

# bash's echo does not expand "\n"; printf does.
printf '\nDataset contents:\n'
ls -la "$DATASET_DIR"

In [None]:
import os
import subprocess
import sys

# Clone the project into Colab's local disk (/content) rather than Drive:
# local disk is faster, and the checkout is refreshed at the start of a session.
REPO_URL = 'https://github.com/AvniMittal13/grapher-nca.git'
REPO_ROOT = '/content/grapher-nca'
REPO_DIR = os.path.join(REPO_ROOT, 'M3D-NCA')  # folder that contains the `src` package

if not os.path.isdir(REPO_ROOT):
    # A failed clone must stop the notebook: every later `src.*` import needs it.
    subprocess.run(['git', 'clone', REPO_URL, REPO_ROOT], check=True)
else:
    # Updating is best-effort: an existing checkout is still usable if pull fails.
    pull = subprocess.run(['git', '-C', REPO_ROOT, 'pull'])
    if pull.returncode != 0:
        print(f'WARNING: git pull failed (exit code {pull.returncode}); using the existing checkout.')

if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

# Install dependencies into the interpreter that runs this kernel; fail loudly
# instead of surfacing later as an ImportError.
subprocess.run(
    [sys.executable, '-m', 'pip', 'install', '-q', 'torchio==0.18.82', 'nibabel', 'tensorboard'],
    check=True,
)

print(f'Repo dir: {REPO_DIR}')
print(f'sys.path OK: {REPO_DIR in sys.path}')

## 1. Imports

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from src.datasets.ISIC_Dataset import ISIC2018_Dataset
from src.models.Model_BackboneNCA import BackboneNCA
from src.models.Model_GrapherNCA_M1 import GrapherNCA_M1
from src.models.Model_GrapherNCA_M2 import GrapherNCA_M2
from src.losses.LossFunctions import DiceLoss, IoULoss
from src.utils.Experiment import Experiment
from src.agents.Agent_Med_NCA import Agent_Med_NCA

## 2. Helper Functions

In [None]:
def create_model(model_code, config, device):
    """Build the NCA model named by ``model_code`` and move it to ``device``.

    Args:
        model_code: ``'b1'`` -> ``BackboneNCA``, ``'m1'`` -> ``GrapherNCA_M1``,
            ``'m2'`` -> ``GrapherNCA_M2``.
        config: dict with ``channel_n``, ``cell_fire_rate``, ``hidden_size`` and
            ``input_channels``. Optional ``k`` (m1/m2, default 9) and
            ``patch_size`` (m2 only, default 4) are forwarded to the Grapher
            models. They should match the values the checkpoint was trained
            with (presumably the graph neighbour count and patch size; confirm
            in the model modules).
        device: torch device passed to the constructor and used for ``.to``.

    Returns:
        The model instance, moved to ``device``.

    Raises:
        ValueError: if ``model_code`` is not one of the codes above.
    """
    # Validate first, so an unknown code is reported even with an incomplete config.
    if model_code not in ('b1', 'm1', 'm2'):
        raise ValueError(f'Unknown model code: {model_code}')

    # Constructor arguments shared by every model family.
    args = (config['channel_n'], config['cell_fire_rate'], device)
    kwargs = {
        'hidden_size': config['hidden_size'],
        'input_channels': config['input_channels'],
    }

    if model_code == 'b1':
        model = BackboneNCA(*args, **kwargs)
    elif model_code == 'm1':
        model = GrapherNCA_M1(*args, **kwargs, k=config.get('k', 9))
    else:  # 'm2'
        model = GrapherNCA_M2(
            *args, **kwargs,
            k=config.get('k', 9),
            patch_size=config.get('patch_size', 4),
        )
    return model.to(device)


def compute_metrics(agent, useSigmoid=True):
    """Compute per-image Dice and IoU over the experiment's ``'test'`` state.

    The experiment is switched to ``'test'`` for the evaluation and always
    switched back to ``'train'`` afterwards, even if an exception is raised
    mid-loop.

    Args:
        agent: agent whose ``exp`` attribute is the ``Experiment`` to evaluate.
        useSigmoid: forwarded to ``DiceLoss`` / ``IoULoss``.

    Returns:
        dict with ``dice_mean``, ``dice_std``, ``iou_mean`` and ``iou_std``.
        The std values are ``np.std``, i.e. population std over images. All
        four are NaN when the test state yields no samples.
    """
    dice_loss_fn = DiceLoss(useSigmoid=useSigmoid)
    iou_loss_fn = IoULoss(useSigmoid=useSigmoid)
    dice_scores = []
    iou_scores = []

    agent.exp.set_model_state('test')
    try:
        dataloader = torch.utils.data.DataLoader(agent.exp.dataset, batch_size=1)
        with torch.no_grad():
            for data in dataloader:
                data = agent.prepare_data(data, eval=True)
                outputs, targets = agent.get_outputs(data, full_img=True, tag='eval')
                # Scores are 1 - loss on the first output channel. smooth=0
                # presumably disables the loss's additive smoothing constant.
                pred, gt = outputs[..., 0], targets[..., 0]
                dice_scores.append(1 - dice_loss_fn(pred, gt, smooth=0).item())
                iou_scores.append(1 - iou_loss_fn(pred, gt, smooth=0).item())
    finally:
        # Never leave the shared experiment stuck in the 'test' state.
        agent.exp.set_model_state('train')

    if not dice_scores:
        print('WARNING: evaluation split is empty; metrics are NaN.')
        nan = float('nan')
        return {'dice_mean': nan, 'dice_std': nan, 'iou_mean': nan, 'iou_std': nan}

    return {
        'dice_mean': np.mean(dice_scores),
        'dice_std': np.std(dice_scores),
        'iou_mean': np.mean(iou_scores),
        'iou_std': np.std(iou_scores),
    }

## 3. Configuration

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Settings shared by every evaluated configuration. Each experiment cell takes
# a shallow copy and adds model_path / input_size / train_model itself.
# NOTE(review): images and labels come from the ISIC 2018 *training* folders.
# The evaluated test split is carved out of them via `data_split` (presumably
# [train, val, test] fractions; confirm in the dataset/Experiment code). The
# official test archives downloaded in Section 0 are not referenced here.
base_config = dict(
    img_path=os.path.join(DATASET_DIR, 'ISIC2018_Task1-2_Training_Input'),
    label_path=os.path.join(DATASET_DIR, 'ISIC2018_Task1_Training_GroundTruth'),
    device=str(device),
    unlock_CPU=True,
    # optimiser / schedule
    lr=1e-4,
    lr_gamma=0.9999,
    betas=(0.5, 0.5),
    # bookkeeping
    save_interval=10,
    evaluate_interval=10,
    n_epoch=1000,
    batch_size=8,
    # NCA architecture
    channel_n=64,
    inference_steps=10,
    cell_fire_rate=0.5,
    input_channels=3,
    output_channels=1,
    hidden_size=512,
    data_split=[0.7, 0, 0.3],
)

# Single-level models: one NCA operating at full resolution.
single_experiments = 'b1 m1 m2'.split()

# Multi-level models: each code concatenates the level-0 and level-1 model codes.
multi_experiments = 'b1b1 m1b1 m1m1 m2b1 m2m2 m1m2'.split()

## 4. Evaluate Single-Level Models (Table 2)

In [None]:
single_results = {}

for model_type in single_experiments:
    print(f'\n=== Evaluating single-level: {model_type} ===')
    config = base_config.copy()
    config['model_path'] = os.path.join(MODELS_DIR, f'GrapherNCA_single_{model_type}')
    config['input_size'] = [(256, 256)]
    config['train_model'] = 0

    # Skip configurations that were never trained. Without this check the
    # Experiment is built on a non-existent path and presumably starts a fresh,
    # untrained run whose scores would land silently in the table. Confirm in
    # src/utils/Experiment.py.
    if not os.path.isdir(config['model_path']):
        print(f'  SKIPPED: no trained model at {config["model_path"]}')
        continue

    dataset = ISIC2018_Dataset(input_channels=config['input_channels'])
    ca = [create_model(model_type, config, device)]
    agent = Agent_Med_NCA(ca)
    exp = Experiment([config], dataset, ca, agent)
    dataset.set_experiment(exp)

    metrics = compute_metrics(agent)
    single_results[model_type] = metrics

    params = sum(p.numel() for p in ca[0].parameters())
    print(f'  Params: {params}')
    print(f'  Dice: {metrics["dice_mean"]:.4f} +/- {metrics["dice_std"]:.4f}')
    print(f'  IoU:  {metrics["iou_mean"]:.4f} +/- {metrics["iou_std"]:.4f}')

# Display as table. Columns are mapped by metric key rather than by position,
# which also keeps this working when every model was skipped.
metric_labels = {
    'dice_mean': 'Dice (mean)',
    'dice_std': 'Dice (std)',
    'iou_mean': 'IoU (mean)',
    'iou_std': 'IoU (std)',
}
df_single = pd.DataFrame.from_dict(single_results, orient='index').rename(columns=metric_labels)
print('\n--- Table 2: Single-Level Results ---')
print(df_single.round(4).to_string())

## 5. Evaluate Multi-Level Models (Table 3)

In [None]:
multi_results = {}

for combination in multi_experiments:
    print(f'\n=== Evaluating multi-level: {combination} ===')
    config = base_config.copy()
    config['model_path'] = os.path.join(MODELS_DIR, f'GrapherNCA_multi_{combination}')
    config['input_size'] = [(64, 64), (256, 256)]
    config['train_model'] = 1

    # Skip combinations that were never trained. Without this check the
    # Experiment is built on a non-existent path and presumably starts a fresh,
    # untrained run whose scores would land silently in the table. Confirm in
    # src/utils/Experiment.py.
    if not os.path.isdir(config['model_path']):
        print(f'  SKIPPED: no trained model at {config["model_path"]}')
        continue

    dataset = ISIC2018_Dataset(input_channels=config['input_channels'])
    # A combination code is <level-0 code><level-1 code>, two characters each.
    level0_code = combination[:2]
    level1_code = combination[2:]
    ca1 = create_model(level0_code, config, device)
    ca2 = create_model(level1_code, config, device)
    ca = [ca1, ca2]

    agent = Agent_Med_NCA(ca)
    exp = Experiment([config], dataset, ca, agent)
    dataset.set_experiment(exp)

    metrics = compute_metrics(agent)
    multi_results[combination] = metrics

    total_params = sum(p.numel() for m in ca for p in m.parameters())
    print(f'  Total Params: {total_params}')
    print(f'  Dice: {metrics["dice_mean"]:.4f} +/- {metrics["dice_std"]:.4f}')
    print(f'  IoU:  {metrics["iou_mean"]:.4f} +/- {metrics["iou_std"]:.4f}')

# Display as table. Columns are mapped by metric key rather than by position,
# which also keeps this working when every combination was skipped.
metric_labels = {
    'dice_mean': 'Dice (mean)',
    'dice_std': 'Dice (std)',
    'iou_mean': 'IoU (mean)',
    'iou_std': 'IoU (std)',
}
df_multi = pd.DataFrame.from_dict(multi_results, orient='index').rename(columns=metric_labels)
print('\n--- Table 3: Multi-Level Results ---')
print(df_multi.round(4).to_string())

## 6. Pseudo-Ensemble Evaluation

In [None]:
# Run pseudo-ensemble on the best multi-level model
combination = 'm1b1'

config = base_config.copy()
config['model_path'] = os.path.join(MODELS_DIR, f'GrapherNCA_multi_{combination}')
config['input_size'] = [(64, 64), (256, 256)]
config['train_model'] = 1

# Fail loudly if the checkpoint is missing. `agent` is reused by the
# visualization cell, so a (presumably) untrained fallback model here would
# corrupt both outputs without any error.
if not os.path.isdir(config['model_path']):
    raise FileNotFoundError(f'No trained model for {combination} at {config["model_path"]}')

dataset = ISIC2018_Dataset(input_channels=config['input_channels'])
# A combination code is <level-0 code><level-1 code>, two characters each.
level0_code = combination[:2]
level1_code = combination[2:]
ca1 = create_model(level0_code, config, device)
ca2 = create_model(level1_code, config, device)
ca = [ca1, ca2]

agent = Agent_Med_NCA(ca)
exp = Experiment([config], dataset, ca, agent)
dataset.set_experiment(exp)

print(f'Pseudo-ensemble evaluation for {combination}:')
agent.getAverageDiceScore(pseudo_ensemble=True, showResults=True)

## 7. Visualization: Input / Prediction / Variance

In [None]:
def _plot_sample(agent, index, data, n_passes):
    """Run ``n_passes`` forward passes on one loader batch and draw a 4-panel figure."""
    data = agent.prepare_data(data, eval=True)
    _, inputs, targets = data

    # Repeated passes on the same input. Any spread between them comes from
    # stochasticity in the model's update (e.g. cell_fire_rate).
    preds = np.stack(
        [
            torch.sigmoid(agent.get_outputs(data, full_img=True, tag='viz')[0]).detach().cpu().numpy()
            for _ in range(n_passes)
        ],
        axis=0,
    )
    mean_pred = preds.mean(axis=0)
    variance = preds.var(axis=0)  # true per-pixel variance, matching the panel title

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    # Input image
    img = inputs[0].detach().cpu().numpy()
    if img.shape[-1] >= 3:
        rgb = img[..., :3]
        # imshow clips float RGB data to [0, 1]; min-max rescale for display so
        # normalized or 0-255 inputs stay visible.
        lo, hi = rgb.min(), rgb.max()
        if hi > lo:
            rgb = (rgb - lo) / (hi - lo)
        axes[0].imshow(rgb)
    else:
        axes[0].imshow(img[..., 0], cmap='gray')
    axes[0].set_title('Input')

    # Ground truth
    axes[1].imshow(targets[0, ..., 0].detach().cpu().numpy(), cmap='gray')
    axes[1].set_title('Ground Truth')

    # Mean prediction
    axes[2].imshow(mean_pred[0, ..., 0], cmap='Purples')
    axes[2].set_title('Mean Prediction')

    # Variance map
    im = axes[3].imshow(variance[0, ..., 0], cmap='hot')
    axes[3].set_title('Variance Map')
    plt.colorbar(im, ax=axes[3], fraction=0.046, pad=0.04)

    for ax in axes:
        ax.axis('off')

    plt.suptitle(f'Sample {index}')
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure; notebooks otherwise keep every one alive


def visualize_samples(agent, sample_indices=(0, 5, 10), n_passes=10):
    """Plot input, ground truth, mean prediction and variance for selected samples.

    Samples are taken from the experiment's ``'test'`` state in dataloader
    order. The experiment is restored to ``'train'`` afterwards, even on error.

    Args:
        agent: agent whose ``exp`` attribute is the ``Experiment`` to visualize.
        sample_indices: dataloader positions to plot; any iterable of ints.
        n_passes: forward passes per sample used for the mean/variance maps.

    Raises:
        ValueError: if ``n_passes`` is smaller than 1.
    """
    if n_passes < 1:
        raise ValueError(f'n_passes must be >= 1, got {n_passes}')
    wanted = set(sample_indices)
    last_wanted = max(wanted, default=-1)

    agent.exp.set_model_state('test')
    try:
        dataloader = torch.utils.data.DataLoader(agent.exp.dataset, batch_size=1)
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                if i > last_wanted:
                    break  # every requested sample has been plotted
                if i in wanted:
                    _plot_sample(agent, i, data, n_passes)
    finally:
        agent.exp.set_model_state('train')


visualize_samples(agent, sample_indices=[0, 5, 10, 20])

## 8. Summary Table (Table 1: Parameter Comparison)

In [None]:
# Table 1: total parameter count of each single-level model, plus the
# per-combination totals for the two-level models (sum of both levels).
cfg = base_config.copy()

# Each model is built only long enough to count its parameters.
param_table = {
    code: sum(p.numel() for p in create_model(code, cfg, device).parameters())
    for code in ('b1', 'm1', 'm2')
}

BYTES_PER_PARAM = 4  # float32

print('--- Table 1: Parameter Counts ---')
for name, n_params in param_table.items():
    print(f'{name}: {n_params:,} params ({n_params * BYTES_PER_PARAM / 1024:.1f} KB)')

print('\n--- Multi-level totals ---')
for combo in multi_experiments:
    # Combination code = <level-0 code><level-1 code>, two characters each.
    total = sum(param_table[part] for part in (combo[:2], combo[2:]))
    print(f'{combo}: {total:,} params ({total * BYTES_PER_PARAM / 1024:.1f} KB)')