In [1]:
from nnunetv2.paths import nnUNet_results, nnUNet_raw
import torch
from batchgenerators.utilities.file_and_folder_operations import join
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.imageio.mrc_reader_writer import MRCIO
import numpy as np
import mrcfile

In [3]:
def save_tomo(data, path, voxel_size=17.14, dtype=np.uint8):
    """
    Save a 3D array as an MRC file, cast to ``dtype`` (uint8 by default).

    The cast is performed before the output file is opened, so a failed
    conversion leaves any existing file at ``path`` untouched.

    Parameters:
    - data: array_like
        The 3D data to save. With the default uint8 dtype, values outside
        [0, 255] are not clipped (integer inputs wrap around) and fractional
        values are truncated, so the default suits small integer label maps.
    - path: str
        Path where the MRC file will be saved; an existing file is overwritten.
    - voxel_size: float or 3-tuple
        The voxel size of the data: a single number for isotropic data, or an
        (x, y, z) tuple as accepted by mrcfile's ``voxel_size`` setter.
    - dtype: numpy dtype, optional
        Data type written to the file. Defaults to ``np.uint8``, which keeps
        the original behaviour of this function.
    """
    # Convert first: mrcfile.new(..., overwrite=True) creates/truncates the
    # target file as soon as it is opened, so an error raised by the cast
    # inside the ``with`` block would leave an empty/broken file behind.
    data = np.asarray(data).astype(dtype)
    with mrcfile.new(path, overwrite=True) as mrc:
        mrc.set_data(data)
        mrc.voxel_size = voxel_size
        
def save_tomo_float32(data, path, voxel_size=17.14):
    """
    Save a 3D array as a float32 MRC file.

    The cast to float32 is performed before the output file is opened, so a
    failed conversion leaves any existing file at ``path`` untouched.

    Parameters:
    - data: array_like
        The 3D data to save (e.g. continuous-valued maps); it is cast to
        ``np.float32`` before writing.
    - path: str
        Path where the MRC file will be saved; an existing file is overwritten.
    - voxel_size: float or 3-tuple
        The voxel size of the data: a single number for isotropic data, or an
        (x, y, z) tuple as accepted by mrcfile's ``voxel_size`` setter.
    """
    # Convert first: mrcfile.new(..., overwrite=True) creates/truncates the
    # target file as soon as it is opened, so an error raised by the cast
    # inside the ``with`` block would leave an empty/broken file behind.
    data = np.asarray(data).astype(np.float32)
    with mrcfile.new(path, overwrite=True) as mrc:
        mrc.set_data(data)
        mrc.voxel_size = voxel_size

In [4]:
# Instantiate the nnUNet sliding-window predictor.
predictor = nnUNetPredictor(
    tile_step_size=0.33,   # step between tiles as a fraction of the patch size; smaller = more overlap, slower
    use_gaussian=True,
    use_mirroring=True,
    perform_everything_on_device=False,
    device=torch.device('cuda', 0),
    verbose=False,
    verbose_preprocessing=False,
    allow_tqdm=True,
)

# Load dataset.json, plans.json and the fold-0 best checkpoint from a single
# training output directory, so switching models means editing one path.
# Alternative (when nnUNet_results points at the same location):
#   predictor.initialize_from_trained_model_folder(
#       join(nnUNet_results, 'Dataset009_10tomo_10classes/nnUNetTrainer__nnUNetPlans__3d_fullres'),
#       use_folds=(0,), checkpoint_name='checkpoint_best.pth')
MODEL_DIR = '/home/liushuo/Documents/data/nnUNet/nnUNet_results/Dataset009_10tomo_10classes/nnUNetTrainer__nnUNetPlans__3d_fullres'
predictor.initialize_from_trained_files(
    dataset_json_path=join(MODEL_DIR, 'dataset.json'),
    plans_json_path=join(MODEL_DIR, 'plans.json'),
    # NOTE(review): the parameter name is plural; confirm that a single path
    # string is accepted here rather than a list of checkpoint paths.
    checkpoint_paths=join(MODEL_DIR, 'fold_0', 'checkpoint_best.pth'),
)

In [4]:
# Read the input density map. MRCIO returns the image array together with a
# properties dict (its 'spacing' entry is reused when saving the result).
INPUT_MAP_PATH = '/media/liushuo/data1/data/CET-MAP/actin/emd_11870/emd_11870.map'
img, props = MRCIO().read_images([INPUT_MAP_PATH])

# Run inference. Passing True as the last positional argument makes the call
# return two values (presumably segmentation and class probabilities, as with
# nnUNetv2's save_or_return_probabilities flag).
ret1, ret2 = predictor.predict_single_npy_array(
    img,
    props,
    None,  # previous-stage segmentation: none
    None,  # output file: none, return arrays instead
    True,  # also return probabilities
)

100%|██████████| 392/392 [01:59<00:00,  3.28it/s]


In [5]:
# Write the predicted label map next to the input map.
# NOTE(review): mrcfile interprets a 3-tuple voxel_size as (x, y, z); confirm
# that MRCIO's props['spacing'] uses the same order (nnUNet readers usually
# report spacing in array-axis order, i.e. (z, y, x)). Irrelevant if isotropic.
save_tomo(ret1, '/media/liushuo/data1/data/CET-MAP/actin/emd_11870/result.rec', voxel_size=props['spacing'])