## Training

In [None]:
from omero.gateway import BlitzGateway
import ezomero
#load dotenv for OMERO login
from dotenv import load_dotenv

from tifffile import imsave
import torch

from tifffile import imwrite

import os
import tempfile
import pandas as pd
import zipfile
import numpy as np
import datetime
import shutil


from cellpose import io, models, train


### Set up the connection to OMERO

In [None]:
# Read OMERO credentials (HOST, USER_NAME, PASSWORD) from a .env file,
# overriding any values already present in the environment.
load_dotenv(override=True)

conn = BlitzGateway(
    host=os.environ.get("HOST"),
    username=os.environ.get("USER_NAME"),
    passwd=os.environ.get("PASSWORD"),
    secure=True,
)
connection_status = conn.connect()
if not connection_status:
    print("Connection to OMERO Server Failed")
    # Every following cell needs a live session, so stop here rather than
    # failing later with a less obvious error.
    raise ConnectionError(
        "Could not connect to OMERO; check HOST, USER_NAME and PASSWORD in .env"
    )
print("Connected to OMERO Server")
# Ping the server every 60 s; presumably to keep the session alive during
# the long-running download/training cells below.
conn.c.enableKeepAlive(60)

### Get information about the selected OMERO object (dataset, plate, or image)

In [None]:
datatype = "dataset"  # one of "plate", "dataset", "image"
data_id = 502
# NOTE(review): nucl_channel is not referenced by any visible cell; the
# channel actually exported comes from the table's "channel" column.
nucl_channel = 0

# Map the notebook's datatype names to OMERO object types.
OMERO_OBJECT_TYPES = {"plate": "Plate", "dataset": "Dataset", "image": "Image"}
if datatype not in OMERO_OBJECT_TYPES:
    raise ValueError(
        f"datatype must be one of {sorted(OMERO_OBJECT_TYPES)}, got {datatype!r}"
    )

# Validate that data_id refers to an existing object of the chosen type;
# getObject returns None when the ID does not exist or is not accessible.
omero_type = OMERO_OBJECT_TYPES[datatype]
omero_object = conn.getObject(omero_type, data_id)
if omero_object is None:
    raise LookupError(f"No {omero_type} with ID {data_id} found (or not accessible)")
print(f"{omero_type} Name: ", omero_object.getName())

# Keep the per-type variable names this cell has always defined.
if datatype == "plate":
    plate = omero_object
elif datatype == "dataset":
    dataset = omero_object
else:
    image = omero_object

### Define output folder for training

In [None]:
# Each run writes into its own timestamped folder under ~/cellpose_models.
home_dir = os.path.expanduser("~")
models_dir = os.path.join(home_dir, "cellpose_models")
os.makedirs(models_dir, exist_ok=True)

timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
folder_name = "cellpose-" + timestamp
output_directory = os.path.join(models_dir, folder_name)
os.makedirs(output_directory, exist_ok=True)

# Store the absolute form so later cells can use it regardless of cwd.
output_directory = os.path.abspath(output_directory)
print("Output directory: " + output_directory)

### Collecting data from OMERO using the attached table

In [None]:
def get_specific_table(conn, dataset_id, table_name="cellpose_training_data",
                       object_type="Dataset"):
    """
    Find and return a specific table attached to an OMERO object by its name.

    Args:
        conn: OMERO connection
        dataset_id: ID of the object (a Dataset by default) to search
        table_name: Name of the table file to find
        object_type: OMERO type of the object the table is attached to
            (e.g. "Dataset", "Plate", "Image"); defaults to "Dataset"

    Returns:
        table: Table data as returned by ezomero.get_table (pandas DataFrame
            or list of lists), or None if no matching readable table exists
        file_ann_id: ID of the file annotation containing the table, or None
    """
    # Get all file annotations on the object.
    file_ann_ids = ezomero.get_file_annotation_ids(conn, object_type, dataset_id)

    for ann_id in file_ann_ids:
        ann = conn.getObject("FileAnnotation", ann_id)
        if ann is None:
            continue

        # Only the attached file's name identifies the table.
        orig_file = ann.getFile()
        if orig_file is None or orig_file.getName() != table_name:
            continue

        try:
            table = ezomero.get_table(conn, ann_id)
        except Exception as e:
            # Best effort: a same-named file that isn't a readable table
            # should not stop the search of the remaining annotations.
            print(f"Found file {table_name} but failed to load as table: {e}")
            continue
        return table, ann_id

    return None, None

In [None]:
table_name = "cellpose_training_data"
table, file_ann_id = get_specific_table(conn, data_id, table_name)
if table is None:
    print(f"No table named {table_name} found")
    # The next cell indexes into `table`; stop here with a clear message
    # instead of a TypeError on None further down.
    raise LookupError(
        f"No readable table named {table_name!r} is attached to object {data_id}"
    )

print(f"Found table {table_name} in file annotation {file_ann_id}")
# NOTE(review): .head() assumes ezomero.get_table returned a pandas DataFrame.
print(table.head())

In [None]:
# Download the images and label files listed in the table and lay them out
# as *_img.tif / *_masks.tif pairs for cellpose's io.load_train_test_data.
REQUIRED_TABLE_COLUMNS = {"train", "validate", "image_id", "label_id",
                          "z_slice", "channel", "timepoint"}


def _to_uint8(img):
    """Linearly rescale a plane to 0-255 (0 -> 0, plane max -> 255) as uint8.

    A plane whose maximum is <= 0 becomes all zeros instead of dividing by
    zero; out-of-range values (e.g. negatives in signed data) are clipped
    rather than wrapped around by the uint8 cast.
    """
    img_max = img.max()
    if img_max <= 0:
        return np.zeros(img.shape, dtype=np.uint8)
    return np.clip(img * (255.0 / img_max), 0, 255).astype(np.uint8)


def _export_split(conn, rows, dest_dir, prefix, tmp_dir):
    """Save every table row as <prefix>_NNNNN_img.tif plus <prefix>_NNNNN_masks.tif.

    Args:
        conn: connected OMERO gateway.
        rows: DataFrame rows with image_id, label_id, z_slice, channel and
            timepoint columns.
        dest_dir: folder that receives the image/mask pairs (created if needed).
        prefix: file-name prefix, e.g. "training" or "val".
        tmp_dir: scratch folder the label file annotations are downloaded into.
    """
    os.makedirs(dest_dir, exist_ok=True)
    for n, (_, row) in enumerate(rows.iterrows()):
        image_id = int(row["image_id"])
        image = conn.getObject("Image", image_id)
        if image is None:
            raise LookupError(f"Image {image_id} listed in the table was not found in OMERO")
        # getPlane takes (z, c, t) from THIS row; cast because pandas returns
        # numpy (possibly float) scalars.
        plane = image.getPrimaryPixels().getPlane(
            int(row["z_slice"]), int(row["channel"]), int(row["timepoint"])
        )
        base = os.path.join(dest_dir, f"{prefix}_{n:05d}")
        # Save as 8-bit tiff as required for cellpose training.
        imwrite(f"{base}_img.tif", _to_uint8(plane))
        # NOTE(review): the label attachment is renamed to .tif whatever its
        # original format — confirm labels are uploaded as TIFF masks.
        label_path = ezomero.get_file_annotation(conn, int(row["label_id"]), tmp_dir)
        shutil.move(label_path, f"{base}_masks.tif")


missing_columns = REQUIRED_TABLE_COLUMNS - set(table.columns)
if missing_columns:
    raise KeyError(f"Table {table_name} is missing columns: {sorted(missing_columns)}")

tmp_dir = os.path.join(output_directory, "tmp")
training_dir = os.path.join(output_directory, "training")
val_dir = os.path.join(output_directory, "validation")

# Clear the folders this cell writes to, so re-running it cannot leave
# stale image/mask pairs from a previous run in the training set.
for folder in (training_dir, val_dir, tmp_dir):
    if os.path.isdir(folder):
        shutil.rmtree(folder)
os.makedirs(tmp_dir, exist_ok=True)

# Rows flagged for training / validation.
train_images = table[table["train"].eq(True)]
val_images = table[table["validate"].eq(True)]

_export_split(conn, train_images, training_dir, "training", tmp_dir)
_export_split(conn, val_images, val_dir, "val", tmp_dir)

print(f"{len(train_images)} training and {len(val_images)} validation images exported")
print("Training data successfully saved to: ", output_directory)

### Load the training and validation data

In [None]:
# Load the exported training and validation folders, selecting images and
# label masks by the "_img" / "_masks" name filters (top level only).
loader_options = {
    "image_filter": "_img",
    "mask_filter": "_masks",
    "look_one_level_down": False,
}
output = io.load_train_test_data(training_dir, val_dir, **loader_options)

### Running the training

In [None]:
io.logger_setup()

images, labels, image_names, test_images, test_labels, image_names_test = output
channels = [1, 2]  # which channels to use for training
# NOTE(review): the planes exported above are single-channel 2D images;
# cellpose's grayscale convention is [0, 0] — confirm [1, 2] is intended.
n_epochs = 500

# Train on a CUDA GPU when PyTorch can see one instead of always on the CPU.
use_gpu = torch.cuda.is_available()

# Retrain a Cellpose model, starting from the built-in "cyto3" model.
model = models.CellposeModel(gpu=use_gpu, model_type="cyto3")
model_path, train_losses, test_losses = train.train_seg(
    model.net,
    train_data=images, train_labels=labels,
    channels=channels, normalize=True,
    test_data=test_images, test_labels=test_labels,
    weight_decay=1e-4, SGD=True, learning_rate=0.1,
    n_epochs=n_epochs, model_name="cellpose_model",
    # Keep the trained model with this run's data instead of the default
    # location; model_path reports the exact file written.
    save_path=output_directory,
)
print(f"Model saved to: {model_path}")

### Save model


In [None]:
# Show this run's output directory.
print(f"{output_directory}")