This repo provides the PyTorch source code for our paper: mnDINO: Accurate and robust segmentation of micronuclei with vision transformer networks. The pre-trained model is publicly available on Hugging Face, and the dataset can be downloaded through the BioImage Archive.
The mnDINO model is designed for efficient and accurate micronuclei segmentation in DNA-stained images across diverse experimental conditions, and outputs both micronuclei and nuclei segmentation masks simultaneously. The dataset, code, and pre-trained model are made publicly available to facilitate future research in micronucleus (MN) biology.
Refer to the tutorial notebook for example usage of the mnDINO model.
Install mndino:

```shell
pip install mndino
```

Download and load the pre-trained model:

```python
import torch
from mndino import mnmodel
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(repo_id="CaicedoLab/mnDINO", filename="mnDINO_v1.pth")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = mnmodel.MicronucleiModel(device=device)
model.load(model_path)
```

Run inference on an image:

```python
import skimage
import numpy as np

STEP = 32  # recommended value
PREDICTION_BATCH = 4
THRESHOLD = 0.5

im = skimage.io.imread(your_image_path)
im = np.array((im - np.min(im)) / (np.max(im) - np.min(im)), dtype="float32")  # min-max normalize to [0, 1]
probabilities = model.predict(im, stride=1, step=STEP, batch_size=PREDICTION_BATCH)
mn_predictions = probabilities[0, :, :] > THRESHOLD
nuclei_predictions = probabilities[1, :, :] > THRESHOLD
```

Evaluate predictions against annotations:

```python
import skimage
from mndino import evaluation

mn_gt = skimage.io.imread(your_annotated_image_path)
precision, recall = evaluation.segmentation_report(predictions=mn_predictions, gt=mn_gt, intersection_ratio=0.1, wandb_mode=False)
```

Environment files for reproducing the experiments in the manuscript are under the environments folder.
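Since the thresholded outputs are boolean masks, individual micronuclei can be counted by labeling connected components. A minimal sketch using `skimage.measure` (the mask below is synthetic, standing in for `mn_predictions`):

```python
import numpy as np
from skimage import measure

# Synthetic boolean mask standing in for mn_predictions (two small blobs)
mask = np.zeros((16, 16), dtype=bool)
mask[2:4, 2:4] = True      # a 2x2 blob
mask[10:13, 9:12] = True   # a 3x3 blob

labels = measure.label(mask)         # assign an integer id to each connected blob
props = measure.regionprops(labels)  # per-object measurements (area, centroid, ...)
print("micronuclei detected:", labels.max())
print("areas:", sorted(p.area for p in props))
```

This yields per-object statistics (counts, areas, centroids) that are often more useful downstream than the raw pixel mask.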
## Train mnDINO

```shell
python3 training_model.py --path '/scr/data/annotated_mn_datasets/' --gpu 0 --epochs 20 --batch_size 4 --loss_fn 'combined' --lr 1e-5 --scale 1.0 --gaussian
```

## Making predictions on the test set

```shell
python3 prediction.py --path '/scr/data/annotated_mn_datasets/' --test_set --gpu 0 --step 32 --batch_size 4 --prob_threshold 0.5 --iou_threshold 0.1 --scale 1
```

Turn on `--test_set` to evaluate on the test set; turn it off to evaluate on the validation set.
Turn on `--wandb_mode` to log the loss to Weights and Biases.
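The `intersection_ratio` / `--iou_threshold` matching criterion pairs predicted and ground-truth objects by overlap. A minimal sketch of that idea, assuming a simple greedy one-to-one matching on binary masks (`match_report` and its logic are illustrative only, not the repo's implementation):

```python
import numpy as np
from skimage import measure

def match_report(pred_mask, gt_mask, intersection_ratio=0.1):
    """Greedy object matching: a predicted object counts as a true positive
    if it overlaps an unmatched ground-truth object by at least
    `intersection_ratio` of that object's area. Illustrative only."""
    pred_labels = measure.label(pred_mask)
    gt_labels = measure.label(gt_mask)
    matched_gt = set()
    tp = 0
    for p in range(1, pred_labels.max() + 1):
        pred_obj = pred_labels == p
        for g in range(1, gt_labels.max() + 1):
            if g in matched_gt:
                continue
            gt_obj = gt_labels == g
            overlap = np.logical_and(pred_obj, gt_obj).sum()
            if overlap / gt_obj.sum() >= intersection_ratio:
                matched_gt.add(g)
                tp += 1
                break
    n_pred, n_gt = pred_labels.max(), gt_labels.max()
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gt if n_gt else 0.0
    return precision, recall

# Two ground-truth objects, one of which is found, plus one spurious prediction.
gt = np.zeros((12, 12), dtype=bool); gt[1:3, 1:3] = True; gt[8:10, 8:10] = True
pred = np.zeros((12, 12), dtype=bool); pred[1:3, 1:3] = True; pred[5:7, 1:3] = True
precision, recall = match_report(pred, gt)  # 0.5, 0.5
```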
## MNFinder Evaluation

```shell
python3 mnfinder_prediction.py --test_path '/scr/data/annotated_mn_datasets/test/images/' --save_path '/scr/data/annotated_mn_datasets/mnfinder_predictions/' --wandb_mode
```

## Cellpose Finetuned Evaluation

```shell
python3 cellpose_prediction.py --gpu 0 --train_path '/scr/data/annotated_mn_datasets/train/images/' --save_path '/scr/data/annotated_mn_datasets/cellpose_predictions/' --finetune --wandb_mode
```

## Frozen microSAM backbone (better performance)

```shell
python3 microsam_prediction.py --gpu 0 --train_path '/scr/data/microsam_data/train/' --pred_path '/scr/data/annotated_mn_datasets/test/images/' --save_path '/scr/yren/microsam_data/microsam_predictions/' --frozen --wandb_mode
```

## Retrain microSAM

```shell
python3 reformat_microsam_images.py --load_path '/scr/yren/annotated_mn_datasets/' --save_path '/scr/yren/microsam_data/'
python3 microsam_prediction.py --gpu 0 --train_path '/scr/data/microsam_data/train/' --pred_path '/scr/data/annotated_mn_datasets/test/images/' --save_path '/scr/yren/microsam_data/microsam_predictions/' --wandb_mode
```

Turn on `--frozen` to make predictions with the frozen microSAM backbone.
- Expected file extension for training images and nuclei masks is `.tif`; the corresponding training masks use `.png`.
- The following values are tunable when retraining on non-micronucleus subcellular datasets.
- Combined loss = 0.8 * subcellular loss + 0.2 * nuclei loss.
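The combined loss above can be sketched as a weighted sum of the two per-channel task losses. A minimal illustration, assuming BCE-with-logits terms and a (B, 2, H, W) layout with the subcellular channel first; `combined_loss` is a placeholder, not the repo's function:

```python
import torch
import torch.nn.functional as F

def combined_loss(pred, target, w_sub=0.8, w_nuc=0.2):
    """Weighted sum of the subcellular (channel 0) and nuclei (channel 1) losses.
    Assumes pred holds raw logits and target holds binary masks, shape (B, 2, H, W)."""
    subcellular = F.binary_cross_entropy_with_logits(pred[:, 0], target[:, 0])
    nuclei = F.binary_cross_entropy_with_logits(pred[:, 1], target[:, 1])
    return w_sub * subcellular + w_nuc * nuclei

pred = torch.zeros(1, 2, 4, 4)   # logits of 0 -> probability 0.5 everywhere
target = torch.ones(1, 2, 4, 4)
loss = combined_loss(pred, target)  # ln(2) ~ 0.693, since both terms equal ln(2)
```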
```python
import torch
from mndino import mnmodel

device = f"cuda:{gpu}" if torch.cuda.is_available() else "cpu"
model = mnmodel.MicronucleiModel(
    device=device,
    data_dir=DIRECTORY,
    patch_size=256,
    scale_factor=1.0,
    gaussian=True,
    oversample=False,  # oversampling applies only to the micronuclei dataset presented in the manuscript
)
model.train(
    epochs=20,
    batch_size=4,
    learning_rate=1e-5,
    loss_fn='combined',
    weight_decay=1e-6,
    wandb_mode=False,
)
model.save(outdir=OUTPUT_DIR, model_name=MODEL_NAME)
```