Check GPU availability (Runtime -> Change runtime type -> Hardware accelerator -> T4 GPU)

In [None]:
import torch

# Report whether CUDA is usable. The device name is only queried when a GPU is
# actually present: torch.cuda.get_device_name(0) raises on a CPU-only runtime,
# which would hide the one piece of information this cell exists to show.
gpu_available = torch.cuda.is_available()
print(gpu_available)
if gpu_available:
    print(torch.cuda.get_device_name(0))
else:
    print("No GPU detected. In Colab: Runtime -> Change runtime type -> Hardware accelerator -> T4 GPU")

Set up the environment: import packages, configure paths, load the data and prepare the files

In [None]:
# Mount Google Drive under /content/drive. The image, mask and model folders
# configured later in this notebook all live below /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Adjust the paths to your image, mask, model folders (/content/drive/MyDrive/...)

In [None]:
# Folder with the training images; passed as image_dir to train_cellpose_model.
drive_data_image_dir = "/content/drive/MyDrive/cellpose/images_to_train/tiff"
# Folder with the masks; passed as mask_dir to train_cellpose_model.
drive_data_mask_dir = "/content/drive/MyDrive/cellpose/training_material"
# Folder the trained model file (.pth) is written to by the final cell.
drive_model_dir = "/content/drive/MyDrive/cellpose/model"

In [None]:
# Clone/update repository
! rm -rf /content/Image_analysis
! git clone https://github.com/Nerita21/Image_analysis.git

In [None]:
# Install packages from requirements only if missing
import subprocess
import sys
import importlib.util
import os

def _parse_requirement(line):
    """Turn one requirements-file line into ``(pip_spec, dist_name)``.

    Returns ``None`` for blank lines and comments. ``dist_name`` is ``None``
    when the line is not a plain ``name[extras]<specifier>`` requirement
    (pip options, URLs, VCS links, local paths); such lines are handed to pip
    unchanged and are never considered "already installed".
    """
    import re

    # pip treats '#' as a comment at line start or when preceded by whitespace
    # (so the '#egg=' fragment of a URL is left alone).
    spec = re.sub(r"(^|\s+)#.*$", "", line).strip()
    if not spec:
        return None

    # Options, URLs and paths: do not try to interpret, pass through verbatim.
    if spec.startswith("-") or "/" in spec or "git+" in spec or "#egg=" in spec:
        return spec, None

    match = re.match(
        r"([A-Za-z0-9](?:[A-Za-z0-9._-]*[A-Za-z0-9])?)\s*(\[[^\]]*\])?\s*(.*)$",
        spec,
    )
    if match is None:
        return spec, None
    name = match.group(1)
    extras = match.group(2) or ""
    rest = match.group(3)

    # After the name only a version specifier or an environment marker may follow.
    if rest and rest[0] not in "=<>!~;":
        return spec, None

    # Convert a lone leading '=' (conda-style "pkg=1.2") into pip's '=='.
    # Only the operator directly after the name is inspected, so '>=', '<=',
    # '~=', '!=' and '===' are left exactly as written.
    if re.match(r"=(?!=)\s*\S", rest):
        rest = "==" + rest[1:].lstrip()

    return f"{name}{extras}{rest}", name


def _is_installed(dist_name):
    """Report whether ``dist_name`` is already available in this environment."""
    from importlib import metadata

    # Primary check: installed distribution metadata. This is keyed by the
    # distribution name, so packages whose import name differs from their
    # distribution name (e.g. scikit-image / skimage) are detected correctly.
    try:
        metadata.version(dist_name)
        return True
    except metadata.PackageNotFoundError:
        pass

    # Fallback (previous behaviour): an importable module of that name counts
    # as installed. find_spec imports parent packages for dotted names and can
    # raise, so failures are treated as "not installed".
    try:
        return importlib.util.find_spec(dist_name.replace("-", "_")) is not None
    except (ImportError, ValueError, AttributeError):
        return False


def ensure_requirements(req_file="requirements.txt"):
    """Install the packages listed in a requirements file that are missing.

    Args:
        req_file: Path to the requirements file. If it does not exist as
            given, it is looked up inside the cloned repository at
            ``/content/Image_analysis``.

    Returns:
        None. Progress is printed; if no requirements file is found an error
        message is printed and nothing is installed.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` fails for a package.

    Note:
        Only presence is checked, not whether the installed version satisfies
        the requested specifier.
    """
    if not os.path.exists(req_file):
        # Try to find the requirements file in the cloned repository directory
        repo_dir = "/content/Image_analysis"
        fallback = os.path.join(repo_dir, req_file)
        if not os.path.exists(fallback):
            print(f"Error: requirements.txt not found at {req_file} or {fallback}")
            return
        req_file = fallback

    # Read everything up front so the file is closed before pip runs.
    with open(req_file, encoding="utf-8") as f:
        lines = f.readlines()

    for line in lines:
        parsed = _parse_requirement(line)
        if parsed is None:
            continue
        pip_spec, dist_name = parsed

        if dist_name is not None and _is_installed(dist_name):
            print(f"{dist_name} already installed.")
            continue

        print(f"Installing {pip_spec}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", pip_spec])

# Use the default "requirements.txt"; when it is not in the working directory
# the function falls back to the copy inside /content/Image_analysis.
ensure_requirements()


In [None]:
# Import required packages:
from cellpose import models, utils
import napari
import matplotlib.pyplot as plt
from skimage import io
import os
import tifffile
import numpy as np

In [None]:
# Now you can import your packages
# Put the cloned repository's src/ folder first on sys.path so the imports
# below are looked up there before anywhere else. This must happen before
# the `from utils ...` / `from package ...` lines.
import sys
sys.path.insert(0, '/content/Image_analysis/src')   # adjust path if you cloned elsewhere
# NOTE(review): `utils` here is the top-level module found via sys.path, not
# the `cellpose.utils` bound to the name `utils` in the imports cell — the two
# share a name in this notebook; presumably this one lives in src/, verify.
from utils import load_config
# load_config() returns a (config, base_dir) pair; neither value is used by
# the cells visible in this notebook.
config, base_dir = load_config()

# Training entry point used by the final cell.
from package import (train_cellpose_model)

Run model training manually after adjusting the inputs

In [None]:
if __name__ == "__main__":
    # Single source of truth for the model name, so the training run and the
    # saved file name cannot drift apart.
    model_name = "noisyFISH_cyto"  # rename for your choice
    save_model_path = os.path.join(drive_model_dir, f"{model_name}.pth")

    # Reset first: if training fails, a `model` left in the kernel by an
    # earlier run must not be saved as if it were the new result.
    model = None

    # Train model
    try:
        model = train_cellpose_model(
            image_dir=drive_data_image_dir,
            mask_dir=drive_data_mask_dir,
            model_name=model_name,
            channels=[1, 0],  # Only cyto channel (single channel)
            n_epochs=100,
        )
        print("\n Training complete!")
    except Exception as e:
        print(f"\n Training failed: {e}")

    # Save the trained model, but only when training actually produced one.
    if model is None:
        print("No trained model available; skipping save.")
    else:
        try:
            # Create the target folder on Drive if it does not exist yet.
            os.makedirs(drive_model_dir, exist_ok=True)
            # NOTE(review): assumes the object returned by train_cellpose_model
            # exposes save_model(path) — confirm against the package.
            model.save_model(save_model_path)
            print(f"Model successfully saved to {save_model_path}")
        except Exception as e:
            print(f"Failed to save model: {e}")