## Training

In [None]:
from tifffile import imsave
import torch

from tifffile import imwrite

import os
import tempfile
import pandas as pd
import zipfile
import numpy as np
import datetime
import shutil

from torch_em.util.debug import check_loader
from torch_em.data import MinInstanceSampler
from torch_em.util.util import get_random_colors

import micro_sam.training as sam_training
from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data
from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation

### Get info from the dataset

### Define in/output folder for training
Select a folder on your local computer. The folder needs to contain the following subfolders:
  
**root_folder/**  
→ training_input/  - images used for training  
→ training_label/  - matching label images, should have the same file names  
→ val_input/       - images used for validation  
→ val_label/       - matching label images for validation, should have the same file names  
→ models/          - folder which will contain the model we will train (checkpoints are saved here)  

In [None]:
# Root folder of this training run: holds the input/label subfolders and
# receives the trained model. Adjust the path to your own machine.
output_directory = os.path.abspath(
    r"C:\Users\mwpaul\micro-sam_models\micro-sam-20250207_095503"
)
print(f"Output directory: {output_directory}")

## Training parameters

In [None]:
# Base model whose pretrained weights are finetuned. vit_l is used here;
# the smaller vit_b trains faster, while vit_h usually yields higher quality results.
model_type = "vit_l"

# Name of the checkpoint/model. The training cell below saves under
# '<output_directory>/models', so checkpoints end up in a subfolder named after it.
checkpoint_name = "sam"

n_epochs = 100  # how long we train (in epochs)
n_objects_per_batch = 2  # the number of objects per batch that will be sampled

# Train on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print('running on: ', device)

### Prepare data loader for the training
This will display a few example samples from the loaders to check that the data is properly organized.

In [None]:
# Input/label folders inside `output_directory`, following the folder layout
# described above. The training folders were previously referenced without ever
# being defined (NameError); they are now derived the same way as the validation ones.
training_input_dir = os.path.join(output_directory, "training_input")
training_label_dir = os.path.join(output_directory, "training_label")
val_input_dir = os.path.join(output_directory, "val_input")
val_label_dir = os.path.join(output_directory, "val_label")

# Make sure all four folders exist (this does NOT populate them with images).
for data_dir in (training_input_dir, training_label_dir, val_input_dir, val_label_dir):
    os.makedirs(data_dir, exist_ok=True)

batch_size = 2  # training batch size
patch_shape = (1, 512, 512)  # the size of patches for training
# Load images from multiple files in folder via pattern (here: all tif files)
raw_key, label_key = "*.tif", "*.tif"

# Train an additional convolutional decoder for end-to-end automatic instance segmentation
# NOTE 1: It's important to have densely annotated labels while training the additional convolutional decoder.
# NOTE 2: In case you do not have labeled images, we recommend using `micro-sam` annotator tools to annotate
# as many objects as possible per image for best performance.
train_instance_segmentation = True

# NOTE: The dataloader internally takes care of adding label transforms: i.e. used to convert the ground-truth
# labels to the desired instances for finetuning Segment Anything, or, to learn the foreground and distances
# to the object centers and object boundaries for automatic segmentation.

# There are cases where our inputs are large and the labeled objects are not evenly distributed across the image.
# For this we use samplers, which ensure that valid inputs are chosen subject to the paired labels.
# The sampler chosen below makes sure that the chosen inputs have at least one foreground instance,
# and filters out small objects.
# NOTE: The choice of 'min_size' value is paired with the same value in 'min_size' filter in 'label_transform'.
sampler = MinInstanceSampler(min_size=25)

# Settings shared by the training and validation loaders; keeping them in one
# place guarantees both loaders are configured identically.
loader_kwargs = dict(
    raw_key=raw_key,
    label_key=label_key,
    with_segmentation_decoder=train_instance_segmentation,
    patch_shape=patch_shape,
    batch_size=batch_size,
    is_seg_dataset=True,
    shuffle=True,
    raw_transform=sam_training.identity,
    sampler=sampler,
)

train_loader = sam_training.default_sam_loader(
    raw_paths=training_input_dir,
    label_paths=training_label_dir,
    # rois=train_roi,  # left disabled, as in the original call
    **loader_kwargs,
)

val_loader = sam_training.default_sam_loader(
    raw_paths=val_input_dir,
    label_paths=val_label_dir,
    # rois=val_roi,  # left disabled, as in the original call
    **loader_kwargs,
)

# Visual sanity check: show one sample from each loader.
check_loader(train_loader, 1, plt=True)
check_loader(val_loader, 1, plt=True)

### Running the training

In [None]:
# Folder handed to the trainer as its save location (inside the run's root folder).
model_save_root = os.path.join(output_directory, "models")

# Finetune the model with the loaders and parameters defined in the cells above.
sam_training.train_sam(
    name=checkpoint_name,
    model_type=model_type,
    save_root=model_save_root,
    device=device,
    n_epochs=n_epochs,
    n_objects_per_batch=n_objects_per_batch,
    with_segmentation_decoder=train_instance_segmentation,
    train_loader=train_loader,
    val_loader=val_loader,
    # To train an existing model further, additionally pass e.g.
    # checkpoint_path='C:\\...\\models\\checkpoints\\sam\\best.pt'
)

### location of the model


In [None]:
# Show the run's root folder again; the training cell above saves under its "models" subfolder.
print(output_directory)