# Example: Segment Images with a Trained StarDist Model

This example notebook shows how to segment 3D images with a trained StarDist model in a Google Colab runtime. Colab works most conveniently with files stored on Google Drive. Upload the images to be segmented, and the trained model, to Google Drive, then open this notebook in a Colab runtime.

## Setting up

*   Install dependencies
*   Mount Google Drive
*   Import dependencies

In [None]:
# Install StarDist (provides the StarDist3D model used below) and czifile
# (used by segment() to read .czi images). Comments must stay on their own
# line: anything after a %pip magic is passed to pip as arguments.
%pip install stardist
%pip install czifile

In [None]:
# Mount Google Drive at /content/drive so the /content/drive/... input, output
# and model paths used in this notebook are accessible from the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import numpy as np
import os
from skimage import io
from stardist.models import StarDist3D
from csbdeep.utils import normalize
import glob
import czifile

## Config

Set up the input and output paths, load the model, and define the segmentation function. You must also set the model's base directory and the file extension of the images to be segmented.

In [None]:
# Edit these placeholder paths to match your Google Drive layout.
input_path = "/content/drive/path/to/images"   # folder containing the images to segment
output_path = "/content/drive/path/to/save"    # folder where *_mask.tif files are written
model_path = "/content/drive/path/to/models"   # base directory containing the trained model folder
file_extension = ".czi"                        # only files with this extension are segmented

# Collect the images from input_path. The folder is glob-escaped so characters
# such as "[", "]", "*" or "?" in Drive folder names are matched literally, and
# the result is sorted so files are processed in a stable order.
image_files = sorted(glob.glob(os.path.join(glob.escape(input_path), f"*{file_extension}")))
print(f"Found {len(image_files)} '{file_extension}' file(s) in {input_path}")
print(image_files)

# config=None tells StarDist to load an existing model named 'Hippocampus9.1'
# from basedir=model_path instead of creating a new one.
model = StarDist3D(None, name='Hippocampus9.1', basedir=model_path)

def segment(img_path, n_tiles=(10, 10, 4), channel=None, out_dir=None):
    """Segment one 3D image with the global StarDist3D ``model`` and save its label mask.

    ``.czi`` files are read with ``czifile``, squeezed, and transposed so the
    leading axis moves to the end. Any other format is read with
    ``skimage.io.imread`` and used as-is. One channel (the last axis) is
    normalized and passed to ``model.predict_instances``.

    Args:
        img_path: Path to the image file.
        n_tiles: Tiles per axis passed to ``model.predict_instances``.
        channel: Index along the last axis to segment. ``None`` keeps the
            original rule: channel 3 for 4-channel images, otherwise channel 2.
        out_dir: Directory for the mask file. ``None`` uses the global
            ``output_path``. The directory is created if it does not exist.

    Returns:
        Path of the written ``<image name>_mask.tif`` file.

    Raises:
        ValueError: If the loaded image does not have exactly 4 axes, or if
            ``channel`` is out of range for its last axis.
    """
    if out_dir is None:
        out_dir = output_path

    stem, ext = os.path.splitext(os.path.basename(img_path))

    # Check the real extension (case-insensitively) rather than searching for
    # ".czi" anywhere in the path, which could misroute files whose folder
    # names merely contain ".czi".
    if ext.lower() == '.czi':
        image = np.squeeze(czifile.imread(img_path))
        if image.ndim != 4:
            raise ValueError(
                f"{img_path}: expected 4 axes after squeezing, got shape {image.shape}"
            )
        # Move the leading axis to the end. Presumably the squeezed CZI is
        # (C, Z, Y, X) and the channel indexing below expects channels last
        # -- TODO confirm against the acquisition settings.
        image = np.transpose(image, (1, 2, 3, 0))
    else:
        image = io.imread(img_path)

    if image.ndim != 4:
        raise ValueError(
            f"{img_path}: expected a 4-axis image with channels last, got shape {image.shape}"
        )
    print(f"{os.path.basename(img_path)}: image shape {image.shape}")

    n_channels = image.shape[-1]
    if channel is None:
        channel = 3 if n_channels == 4 else 2
    if not -n_channels <= channel < n_channels:
        raise ValueError(
            f"{img_path}: channel {channel} out of range for {n_channels} channel(s)"
        )
    normalized = normalize(image[..., channel])

    labels, _ = model.predict_instances(normalized, n_tiles=n_tiles)

    # imsave does not create missing folders, so make sure the target exists.
    os.makedirs(out_dir, exist_ok=True)
    mask_path = os.path.join(out_dir, f"{stem}_mask.tif")
    io.imsave(mask_path, labels)
    return mask_path


## Segment

In [None]:
# Segment every collected image. A failure on one file is reported and the
# batch continues, so a single unreadable image does not abort a long run.
if not image_files:
    print(f"No '{file_extension}' files found in {input_path}; nothing to segment.")

failed = []
for image_path in image_files:
    try:
        segment(image_path)
    except Exception as err:  # top-level batch boundary: report, then move on
        print(f"Failed to segment {image_path}: {type(err).__name__}: {err}")
        failed.append(image_path)

print(f"Segmented {len(image_files) - len(failed)} of {len(image_files)} image(s).")
if failed:
    print("Failed files:", failed)