In [150]:
%load_ext autoreload
%autoreload 2
import os
from hydra import initialize, compose
from network_module import Net
from hydra.utils import instantiate
from datamodule import PolypGenDataset
import torch
from lightning.pytorch import loggers
import numpy as np
from monai import metrics as mm
from visualize_results import save_visualization_grid

# Set the NO_ALBUMENTATIONS_UPDATE flag for this kernel process.
# Presumably this silences Albumentations' online version check — TODO confirm.
# NOTE(review): this runs after the imports above; if one of those modules imports
# albumentations and the check happens at import time, the flag is set too late on a
# fresh kernel — consider moving it directly below `import os`.
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [151]:
print("Loading multiple models for ensemble...")

# One entry per ensemble member: the Hydra config to compose and the checkpoint to restore.
model_configs = [
    {
        "config_name": "config",
        "checkpoint": "../logs_drive/efficientnet-b4_32/version_2/model_checkpoint.ckpt",
    },
    {
        "config_name": "config_attention_unet",
        "checkpoint": "../logs_drive/attention_unet/version_1/model_checkpoint.ckpt",
    },
    {
        "config_name": "config_segresnet",
        "checkpoint": "../logs_drive/segresnet/version_1/model_checkpoint.ckpt",
    },
]

models = []
for member in model_configs:
    # Hydra is (re-)initialized per member so each one is built from its own config.
    with initialize(config_path="config", version_base=None):
        cfg = compose(config_name=member["config_name"])
        # Constructor arguments for Net, all taken from this member's config.
        load_kwargs = dict(
            model=instantiate(cfg.model.object),
            criterion=instantiate(cfg.criterion),
            optimizer=cfg.optimizer,
            lr=cfg.lr,
            scheduler=cfg.scheduler,
        )
        net = Net.load_from_checkpoint(member["checkpoint"], **load_kwargs)
        net.eval()  # inference only
        models.append(net)

# `cfg` still refers to the LAST composed config here; the datamodule and trainer
# are built from it and shared by every ensemble member.
dataset = PolypGenDataset(batch_size=cfg.batch_size, img_size=cfg.img_size)
trainer = instantiate(cfg.trainer)




Loading multiple models for ensemble...


💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [152]:
# Run inference with every ensemble member on the same datamodule.
all_models_prediction_outputs = []
for position, model in enumerate(models, start=1):
    print(f"Getting predictions from model {position}/{len(models)}")
    prediction_outputs = trainer.predict(model, dataset)
    all_models_prediction_outputs.append(prediction_outputs)

# Collapse each model's per-batch probability maps into a single tensor per model.
voting_batch_probs = []
for single_model_prediction_outputs in all_models_prediction_outputs:
    model_probs = torch.cat(
        [batch['probabilities'] for batch in single_model_prediction_outputs]
    )
    voting_batch_probs.append(model_probs)
    print(f"inside for: {model_probs.shape}, type: {type(model_probs)}")

# New leading axis indexes the ensemble member.
voting_probs = torch.stack(voting_batch_probs)
print(f"After stacking all models: {voting_probs.shape}, type: {type(voting_probs)}")

# Soft voting: average the members' probabilities over the model axis ...
ensemble_probs = torch.mean(voting_probs, dim=0)
print(f"After averaging: {ensemble_probs.shape}, type: {type(ensemble_probs)}")

# ... and binarize the averaged probabilities at 0.5 (kept as a float tensor).
ensemble_predictions = torch.gt(ensemble_probs, 0.5).float()
print(f"After final thresholding: {ensemble_predictions.shape}, type: {type(ensemble_predictions)}")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Getting predictions from model 1/3
Predicting DataLoader 0: 100%|██████████| 6/6 [00:16<00:00,  0.36it/s]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Getting predictions from model 2/3
Predicting DataLoader 0: 100%|██████████| 6/6 [02:40<00:00,  0.04it/s]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Getting predictions from model 3/3
Predicting DataLoader 0: 100%|██████████| 6/6 [00:05<00:00,  1.11it/s]
inside for: torch.Size([88, 1, 512, 512]), type: <class 'torch.Tensor'>
inside for: torch.Size([88, 1, 512, 512]), type: <class 'torch.Tensor'>
inside for: torch.Size([88, 1, 512, 512]), type: <class 'torch.Tensor'>
After stacking all models: torch.Size([3, 88, 1, 512, 512]), type: <class 'torch.Tensor'>
After averaging: torch.Size([88, 1, 512, 512]), type: <class 'torch.Tensor'>
After final thresholding: torch.Size([88, 1, 512, 512]), type: <class 'torch.Tensor'>


In [153]:
# Every model was run on the same datamodule, so the ground-truth masks are read
# from the first model's outputs only (assumes the predict loader yields batches in
# the same order on every run — TODO confirm).
single_model_prediction_outputs = all_models_prediction_outputs[0]

# Join the per-batch masks into one tensor along the sample axis.
targets = torch.cat([batch['masks'] for batch in single_model_prediction_outputs])

In [154]:
# Overlap metrics.
get_dice = mm.DiceMetric(include_background=False, reduction="mean")
get_iou = mm.MeanIoU(include_background=False, reduction="mean")

# Confusion-matrix metrics differ only in the metric name they report.
get_accuracy, get_recall, get_precision = (
    mm.ConfusionMatrixMetric(include_background=False, metric_name=cm_metric)
    for cm_metric in ("accuracy", "sensitivity", "precision")
)

# The metrics are fed integer label tensors.
predictions = ensemble_predictions.to(torch.long)
targets = targets.to(torch.long)


In [155]:
# Score the thresholded ensemble masks against the ground truth.
# Each metric object accumulates the results of every call in an internal buffer,
# so it is reset first: re-running this cell then scores the data exactly once
# instead of stacking duplicate copies in the buffers.
for metric in (get_dice, get_iou, get_accuracy, get_recall, get_precision):
    metric.reset()
    metric(predictions, targets)

# Reduce the accumulated per-sample results to one Python float per metric.
dice = get_dice.aggregate()[0].item()
iou = get_iou.aggregate()[0].item()
accuracy = get_accuracy.aggregate()[0].item()
recall = get_recall.aggregate()[0].item()
precision = get_precision.aggregate()[0].item()

# F-beta scores derived from the aggregated precision/recall; the small epsilon
# guards against division by zero when both are 0.
f1 = 2 * (precision * recall) / (precision + recall + 1e-8)
f2 = 5 * (precision * recall) / (4 * precision + recall + 1e-8)

# Keys follow the `test_*` naming used when logging the results.
ensemble_metrics = {
    'test_dice': dice,
    'test_iou': iou,
    'test_accuracy': accuracy,
    'test_recall': recall,
    'test_precision': precision,
    'test_f1': f1,
    'test_f2': f2
}

In [156]:
ensemble_metrics

{'test_dice': 0.7933390140533447,
 'test_iou': 0.7223753333091736,
 'test_accuracy': 0.9763011336326599,
 'test_recall': 0.8135696649551392,
 'test_precision': 0.8791412711143494,
 'test_f1': 0.8450854188064955,
 'test_f2': 0.8258896190632325}

In [157]:
# Each instantiation writes to a fresh `version_N` folder under ../logs_drive/ensemble.
logger = loggers.TensorBoardLogger("../logs_drive/", name='ensemble')

# Description of how this ensemble was assembled.
ensemble_hparams = {
    'ensemble_method': 'average',  # or 'majority'
    'num_models': len(models),
    'model_configs': [member['config_name'] for member in model_configs],
}

# Hyperparameters and metrics are logged together in a single call.
logger.log_hyperparams(params=ensemble_hparams, metrics=ensemble_metrics)

print("Metrics logged to TensorBoard!")
print(f"View logs at: {logger.log_dir}")


Metrics logged to TensorBoard!
View logs at: ../logs_drive/ensemble\version_4


In [158]:
%reload_ext tensorboard
%tensorboard --logdir ../logs_drive/

Reusing TensorBoard on port 6007 (pid 14760), started 2:25:04 ago. (Use '!kill 14760' to kill it.)

In [159]:
# Sanity check: display the shape of the thresholded ensemble mask tensor.
ensemble_predictions.shape

torch.Size([88, 1, 512, 512])

In [160]:
# Re-package the ensemble masks into per-batch dicts for visualization, using the
# first model's prediction outputs as the template for every other entry.
reference_batches = all_models_prediction_outputs[0]

# Split by the real size of each batch rather than cfg.batch_size: `cfg` is whatever
# config was composed last, and the final batch may be smaller, so the actual sizes
# are the only reliable way to keep chunks aligned with their batches.
batch_sizes = [batch['masks'].shape[0] for batch in reference_batches]
ensemble_batches = torch.split(ensemble_predictions, batch_sizes, dim=0)

# Build new dicts instead of mutating the stored model outputs, so this cell can be
# re-run and the original per-model probabilities stay available. Each dict keeps
# all entries except the single-model 'probabilities' and carries the ensemble
# mask under 'predictions'.
test_input = [
    {
        **{key: value for key, value in batch.items() if key != 'probabilities'},
        'predictions': ensemble_chunk,
    }
    for batch, ensemble_chunk in zip(reference_batches, ensemble_batches)
]


In [161]:
# Visualize the ENSEMBLE result: `test_input` holds the batches whose 'predictions'
# are the ensemble masks. (`prediction_outputs` is only the raw output of the last
# single model left over from the prediction loop, so passing it would plot that
# one model instead of the ensemble.)
save_visualization_grid(test_input, logger.log_dir)