# Download model weights for the CNN semantic segmentation module.

In [None]:
import torch
from tardis_em.utils.aws import get_weights_aws

# Fnet model for 3D microtubule segmentation
cnn_weights = get_weights_aws(
    network="fnet_attn", subtype="32", model="microtubules_3d"
)

# Alternative CNN checkpoints -- uncomment exactly one to use it instead.
# If you switch, keep GeneralPredictor's `convolution_nn` ("unet" / "fnet_attn")
# consistent with the chosen `network`, and adjust `predict` accordingly.
#
# # Unet model for 3D microtubule segmentation
# cnn_weights = get_weights_aws(network="unet", subtype="32", model="microtubules_3d")
#
# # Fnet model for 3D membrane segmentation
# cnn_weights = get_weights_aws(network="fnet_attn", subtype="32", model="membrane_3d")
#
# # Unet model for 3D membrane segmentation
# cnn_weights = get_weights_aws(network="unet", subtype="32", model="membrane_3d")
#
# # Fnet model for 2D membrane segmentation
# cnn_weights = get_weights_aws(network="fnet_attn", subtype="32", model="membrane_2d")
#
# # Unet model for 2D membrane segmentation
# cnn_weights = get_weights_aws(network="unet", subtype="32", model="membrane_2d")

# The value returned by get_weights_aws is handed straight to torch.load
# (presumably a local checkpoint path -- confirm). Tensors are mapped onto the
# CPU so this cell also works without a GPU. torch.load unpickles the file, so
# only load checkpoints from a trusted source.
cnn_weights = torch.load(cnn_weights, map_location="cpu")

# Download model weights for the DIST instance segmentation module.

In [2]:
import torch
from tardis_em.utils.aws import get_weights_aws

# DIST checkpoint for the instance-segmentation step. Choose the variant with
# `model` (both use network="dist", subtype="triang"):
#   model="2d" -> 2D instance segmentation module for membranes or microtubules (used here)
#   model="3d" -> 3D instance segmentation module for membranes
dist_checkpoint = get_weights_aws(network="dist", subtype="triang", model="2d")

# Load the downloaded checkpoint with all tensors mapped onto the CPU.
dist_weights = torch.load(dist_checkpoint, map_location="cpu")

# Build the predictor (a GeneralPredictor instance)

In [3]:
from tardis_em.utils.predictor import GeneralPredictor
import numpy as np

# Random 128 x 128 x 128 volume used as a stand-in input image.
x = np.random.rand(128, 128, 128)

tardis = GeneralPredictor(
    # What to segment: one of ["Filament", "Membrane2D", "Membrane", "Microtubule"].
    predict="Microtubule",
    # Input: a directory path (str) or in-memory np.ndarray data; here a list
    # holding the single array created above.
    dir_s=[x],
    # If True, the input is assumed to already be a semantic mask and only
    # instance segmentation is run.
    binary_mask=False,
    # "formatS_formatI": semantic and instance output formats. "return" hands
    # the result back from tardis() instead of writing a file; file formats
    # include .tif, .tiff, .mrc, .rec, .am, .map, .npy.
    output_format="return_return",
    # Size of the patch the ML model predicts at a time.
    patch_size=128,
    # CNN architecture: one of ["unet", "fnet_attn"]; presumably must match the
    # CNN checkpoint passed in `checkpoint`.
    convolution_nn="fnet_attn",
    # Threshold for the CNN (semantic) model.
    cnn_threshold=0.25,
    # Threshold for the DIST (instance) model.
    dist_threshold=0.5,
    # Number of points the DIST model predicts at a time
    # (GPU constraint: 1000 points needs about 12 GB of GPU memory).
    points_in_patch=1000,
    # Optionally rotate image patches during CNN prediction to increase accuracy.
    predict_with_rotation=False,
    # If True, TARDIS predicts both semantic and instance segmentations.
    instances=True,
    # 0-9 select a GPU id; "cpu" or "mps" can also be used.
    device_s="cpu",
    # If True, enable debugging mode, which saves all intermediate files.
    debug=False,
    # Pre-trained weights: [CNN checkpoint, DIST checkpoint].
    checkpoint=[cnn_weights, dist_weights],
    # Pixel size of the input; must be given explicitly for numpy input.
    correct_px=25,
    # For microtubule prediction, TARDIS can read an Amira Spatial Graph to
    # compare both predictions and keep overlapping microtubules to increase
    # precision.
    amira_prefix=None,
    # Optional length filter for the predicted filaments.
    filter_by_length=None,
    # Optional: connect filaments facing the same direction whose distance is
    # equal to or smaller than this value [A].
    connect_splines=None,
    # Optional: connect filaments facing the same direction based on a cylinder
    # radius equal to or smaller than this value [A].
    connect_cylinder=None,
    # Optional length filter used when comparing Amira and TARDIS predictions.
    amira_compare_distance=None,
    # Optional likelihood [0-1] that two microtubules (Amira vs. TARDIS) are the
    # same one, used when comparing the predictions.
    amira_inter_probability=None,
    # If True, display prediction progress in the console; if False, TARDIS
    # runs silently.
    tardis_logo=False,
)

# Run prediction

In [4]:
# Run the predictor; with output_format="return_return" the three results are
# handed back here instead of being written to disk.
(semantic, instance, instance_filter) = tardis()

In [5]:
import matplotlib.pyplot as plt

# Project the first semantic prediction by summing over axis 0, then show it.
semantic_projection = semantic[0].sum(0)
plt.imshow(semantic_projection)

# Create a dummy dataset (write the random volume to ./test.mrc)

In [6]:
from tardis_em.utils.export_data import to_mrc

# Save the random volume as float32 to ./test.mrc so the following cells can
# run TARDIS on a file in the current directory. The 25.0 argument presumably
# is the pixel size (it matches correct_px=25 used by the predictor) -- confirm.
volume_f32 = x.astype(np.float32)
to_mrc(volume_f32, 25.0, "./test.mrc")

# Run TARDIS on files in a directory

In [7]:
from tardis_em.utils.predictor import GeneralPredictor
import numpy as np

# Input comes from files in the current directory (e.g. ./test.mrc), so no
# in-memory array is created here.
tardis = GeneralPredictor(
    # What to segment: one of ["Filament", "Membrane2D", "Membrane", "Microtubule"].
    predict="Microtubule",
    # Input: a directory path (str) or in-memory np.ndarray data; here the
    # current directory.
    dir_s=".",
    # If True, the input is assumed to already be a semantic mask and only
    # instance segmentation is run.
    binary_mask=False,
    # "formatS_formatI": semantic and instance output formats. "return" hands
    # the result back from tardis() instead of writing a file; file formats
    # include .tif, .tiff, .mrc, .rec, .am, .map, .npy.
    output_format="return_return",
    # Size of the patch the ML model predicts at a time.
    patch_size=128,
    # CNN architecture: one of ["unet", "fnet_attn"]; presumably must match the
    # CNN checkpoint passed in `checkpoint`.
    convolution_nn="fnet_attn",
    # Threshold for the CNN (semantic) model.
    cnn_threshold=0.25,
    # Threshold for the DIST (instance) model.
    dist_threshold=0.5,
    # Number of points the DIST model predicts at a time
    # (GPU constraint: 1000 points needs about 12 GB of GPU memory).
    points_in_patch=1000,
    # Optionally rotate image patches during CNN prediction to increase accuracy.
    predict_with_rotation=False,
    # If True, TARDIS predicts both semantic and instance segmentations.
    instances=True,
    # 0-9 select a GPU id; "cpu" or "mps" can also be used.
    device_s="cpu",
    # If True, enable debugging mode, which saves all intermediate files.
    debug=False,
    # Pre-trained weights: [CNN checkpoint, DIST checkpoint].
    checkpoint=[cnn_weights, dist_weights],
    # Pixel size of the input; must be given explicitly for numpy input.
    correct_px=25,
    # For microtubule prediction, TARDIS can read an Amira Spatial Graph to
    # compare both predictions and keep overlapping microtubules to increase
    # precision.
    amira_prefix=None,
    # Optional length filter for the predicted filaments.
    filter_by_length=None,
    # Optional: connect filaments facing the same direction whose distance is
    # equal to or smaller than this value [A].
    connect_splines=None,
    # Optional: connect filaments facing the same direction based on a cylinder
    # radius equal to or smaller than this value [A].
    connect_cylinder=None,
    # Optional length filter used when comparing Amira and TARDIS predictions.
    amira_compare_distance=None,
    # Optional likelihood [0-1] that two microtubules (Amira vs. TARDIS) are the
    # same one, used when comparing the predictions.
    amira_inter_probability=None,
    # If True, display prediction progress in the console; if False, TARDIS
    # runs silently.
    tardis_logo=False,
)

# Run Prediction

In [8]:
# Predict on the directory input; output_format="return_return" makes tardis()
# return the three results rather than saving them.
(semantic, instance, instance_filter) = tardis()

In [9]:
import matplotlib.pyplot as plt

# Sum the first semantic prediction over axis 0 and display the projection.
semantic_projection = semantic[0].sum(0)
plt.imshow(semantic_projection)

# Run TARDIS on files in a directory and save the output as .mrc and .csv

In [10]:
from tardis_em.utils.predictor import GeneralPredictor
import numpy as np

# Input comes from files in the current directory (e.g. ./test.mrc), so no
# in-memory array is created here.
tardis = GeneralPredictor(
    # What to segment: one of ["Filament", "Membrane2D", "Membrane", "Microtubule"].
    predict="Microtubule",
    # Input: a directory path (str) or in-memory np.ndarray data; here the
    # current directory.
    dir_s=".",
    # If True, the input is assumed to already be a semantic mask and only
    # instance segmentation is run.
    binary_mask=False,
    # "formatS_formatI": semantic output saved as .mrc, instance output as .csv.
    # Other file formats include .tif, .tiff, .rec, .am, .map, .npy; "return"
    # hands the result back from tardis() instead of writing a file.
    output_format="mrc_csv",
    # Size of the patch the ML model predicts at a time.
    patch_size=128,
    # CNN architecture: one of ["unet", "fnet_attn"]; presumably must match the
    # CNN checkpoint passed in `checkpoint`.
    convolution_nn="fnet_attn",
    # Threshold for the CNN (semantic) model.
    cnn_threshold=0.25,
    # Threshold for the DIST (instance) model.
    dist_threshold=0.5,
    # Number of points the DIST model predicts at a time
    # (GPU constraint: 1000 points needs about 12 GB of GPU memory).
    points_in_patch=1000,
    # Optionally rotate image patches during CNN prediction to increase accuracy.
    predict_with_rotation=False,
    # If True, TARDIS predicts both semantic and instance segmentations.
    instances=True,
    # 0-9 select a GPU id; "cpu" or "mps" can also be used.
    device_s="cpu",
    # If True, enable debugging mode, which saves all intermediate files.
    debug=False,
    # Pre-trained weights: [CNN checkpoint, DIST checkpoint].
    checkpoint=[cnn_weights, dist_weights],
    # Pixel size of the input; must be given explicitly for numpy input.
    correct_px=25,
    # For microtubule prediction, TARDIS can read an Amira Spatial Graph to
    # compare both predictions and keep overlapping microtubules to increase
    # precision.
    amira_prefix=None,
    # Optional length filter for the predicted filaments.
    filter_by_length=None,
    # Optional: connect filaments facing the same direction whose distance is
    # equal to or smaller than this value [A].
    connect_splines=None,
    # Optional: connect filaments facing the same direction based on a cylinder
    # radius equal to or smaller than this value [A].
    connect_cylinder=None,
    # Optional length filter used when comparing Amira and TARDIS predictions.
    amira_compare_distance=None,
    # Optional likelihood [0-1] that two microtubules (Amira vs. TARDIS) are the
    # same one, used when comparing the predictions.
    amira_inter_probability=None,
    # If True, display prediction progress in the console; if False, TARDIS
    # runs silently.
    tardis_logo=False,
)
tardis()  # Saves the output in the './Prediction/' folder, only if the prediction succeeds