# **BioImage Model Zoo Example notebook**

---

This notebook provides examples of how to load pretrained deep learning models from [BioImage Model Zoo](https://bioimage.io), use them to process new images, and finetune them.

## **1. Install key dependencies**
---
<font size = 4>


### **1.1. Install required dependencies**

---

In [None]:
#@markdown ##Play to install dependencies
#@markdown #### DO NOT RESTART THE SESSION UNTIL THE CELL FINISHES RUNNING
#@markdown #### This may take a few minutes

!pip install -q bioimageio.core==0.6.8
!pip install -q matplotlib==3.9.0
!pip install -q imageio==2.31.2
!pip install -q numpy==1.23.5
!pip install -q torch==2.2.0
!pip install -q onnxruntime==1.18.0
!pip install -q pooch==1.8
!pip install -q marshmallow==3.21.3

### **1.2. Connect to your Google Drive to access training data**

---

In [None]:
#@markdown ## Run this cell to connect Google Drive

from google.colab import drive
drive.mount('/content/drive')

### **1.3. Load BioImageIO dependencies**

---

In [None]:
#@markdown ##Play to load the dependencies and functions

# If you'd rather read the warning messages, please comment out the following two lines.
import warnings
warnings.filterwarnings("ignore")

# Load BioImage Model Zoo packages
from bioimageio.core.digest_spec import create_sample_for_model
from bioimageio.core import predict, create_prediction_pipeline, load_description, test_model

import bioimageio.spec
from bioimageio.spec import save_bioimageio_package
from bioimageio.spec.utils import load_array, download
from bioimageio.spec.model.v0_5 import (ModelDescr, ArchitectureFromFileDescr, Author, CiteEntry,
                                        Version, Doi, HttpUrl, LicenseId,
                                        WeightsDescr, PytorchStateDictWeightsDescr, TorchscriptWeightsDescr,
                                        InputTensorDescr, OutputTensorDescr, TensorId, LinkedDataset,
                                        FileDescr, IntervalOrRatioDataDescr, Identifier, SizeReference,
                                        BatchAxis, ChannelAxis, SpaceInputAxis, SpaceOutputAxis)


# Load other packages
import matplotlib.pyplot as plt
import numpy as np
import pooch
import json
import os

from IPython.display import display, Markdown
from imageio import imwrite as imsave
from imageio import imread
from pathlib import Path
from ruyaml import YAML

# URL pointing to the file with the collection from the BioImage Model Zoo
COLLECTION_URL = "https://raw.githubusercontent.com/bioimage-io/collection-bioimage-io/gh-pages/collection.json"

# Download the collection
collection_path = Path(pooch.retrieve(COLLECTION_URL, known_hash=None))
with collection_path.open() as f:
    collection = json.load(f)

# Get all the URLs of the models in the downloaded collection
model_urls = [entry["rdf_source"] for entry in collection["collection"] if entry["type"] == "model"]

# Download the rdf.yaml files from all the folders
yaml = YAML(typ="safe")
model_rdfs = [yaml.load(Path(pooch.retrieve(mu, known_hash=None))) for mu in model_urls]

# Get only the models that have "pytorch_state_dict" weights
pytorch_models = [rdf for rdf in model_rdfs if "pytorch_state_dict" in rdf["weights"]]


## **2. Inspect a model from the BioImage Model Zoo**

---

Here we will guide you through the basic functionalities of the BioImageIO Python package to interact with the content in the BioImage Model Zoo.

First, you can obtain a list of the available BioImage Model Zoo models that provide PyTorch (`pytorch_state_dict`) weights.
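
As a rough, offline illustration of that filtering step, the sketch below applies the same `"pytorch_state_dict" in rdf["weights"]` check to a tiny hand-written list. The two entries and their names are made up for the example, and `has_weight_format` is a hypothetical helper; real entries come from the downloaded `rdf.yaml` files.

```python
# Hypothetical, hand-written stand-ins for parsed rdf.yaml files.
example_rdfs = [
    {"name": "model-a", "weights": {"pytorch_state_dict": {}, "torchscript": {}}},
    {"name": "model-b", "weights": {"onnx": {}}},
]


def has_weight_format(rdf, weight_format="pytorch_state_dict"):
    """Return True if the description lists the given weight format."""
    return weight_format in rdf.get("weights", {})


# Keep only the names of the entries that ship PyTorch state dict weights.
pytorch_names = [rdf["name"] for rdf in example_rdfs if has_weight_format(rdf)]
print(pytorch_names)
```

The same helper can be pointed at another key of `weights` to list models for a different framework.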

In [None]:
#@markdown ##Check the models that can be loaded for PyTorch

# Print all the PyTorch models ("pytorch_state_dict" weights) on the Bioimage Model Zoo
print('List of models for PyTorch:\n')
for model in pytorch_models:
    print(f"{model['name']}\n - {model['config']['bioimageio']['nickname']}\n - {model['config']['bioimageio']['doi']}")

### **2.1. Load the resource description specifications of the model**

---

To load the model of your choice, you only need to fill one of the fields in the cell below and leave the rest empty. If more than one is filled the first one will be used.

<font size = 3>**`BMZ_MODEL_ID`**: Unique identifier of the model to load in the BioImage Model Zoo, e.g., `impartial-shrimp`. These identifiers are available in each model card in the zoo.

OR

<font size = 3>**`BMZ_MODEL_DOI`**: Model DOIs can also be used to load the models.

OR

<font size = 3>**`BMZ_MODEL_URL`**: The URL of the main Zenodo repository, or of the `rdf.yaml` file containing the resource description specifications, can also be used to load models.
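
The "first filled field wins" rule can be sketched as a small helper. This is only an illustration: `pick_model_source` is a hypothetical name, the `example.org` URL is a placeholder, and stripping whitespace is an extra convenience that the notebook cell itself does not apply.

```python
def pick_model_source(model_id="", model_doi="", model_url=""):
    """Return the first non-empty option (ID, then DOI, then URL), or None."""
    for candidate in (model_id, model_doi, model_url):
        if candidate.strip():
            return candidate.strip()
    return None


# The ID is empty here, so the DOI takes precedence over the URL.
source = pick_model_source(
    model_doi="10.5281/zenodo.5764892",
    model_url="https://example.org/rdf.yaml",  # placeholder URL
)
print(source)
```

Whatever string is selected is then the single argument passed to `load_description`.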

In [None]:
#@markdown ##Load the model description with one of these options

# "affable-shark"
BMZ_MODEL_ID = "affable-shark" #@param {type:"string"}
# "10.5281/zenodo.5764892"
BMZ_MODEL_DOI = "" #@param {type:"string"}
# "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/affable-shark/1.1/files/rdf.yaml"
BMZ_MODEL_URL = "" #@param {type:"string"}

#####
# Load the model description from one of the provided options

if BMZ_MODEL_ID != "":
    model = load_description(BMZ_MODEL_ID)  # TODO: load from bioimageio id
    print(f"The model '{model.name}' with ID '{BMZ_MODEL_ID}' has been correctly loaded.")
elif BMZ_MODEL_DOI != "":
    model = load_description(BMZ_MODEL_DOI)
    print(f"The model '{model.name}' with DOI '{BMZ_MODEL_DOI}' has been correctly loaded.")
elif BMZ_MODEL_URL != "":
    model = load_description(BMZ_MODEL_URL)
    print(f"The model '{model.name}' with URL '{BMZ_MODEL_URL}' has been correctly loaded.")
else:
    print('Please specify a model ID, DOI or URL')

### **2.2. Discover the different components and features of the model**

---

#### Print information about the model

In [None]:
#@markdown ##Print information about the model

print(f"The model '{model.name} {model.id_emoji}' has the following properties and metadata:")
print()
print(f" Description: {model.description}")
print(f" Tags: {', '.join(model.tags)}")
print(f" Model ID: {model.id}")
print()
print(f" The authors of the model are:")
for author in model.authors:
    print(f"  - {author.name}, with GitHub user: @{author.github_user}")
print(f" The maintainers of the model are:")
for maintainer in model.maintainers:
    print(f"  - {maintainer.name}, with GitHub user: @{maintainer.github_user}")
print()
print(f" License: {model.license}")
print()
print(f" If you use this model, you are expected to cite:")
for citation in model.cite:
    doi_text = ""
    if citation.doi is not None:
        doi_text = f" from DOI (https://doi.org/{citation.doi})"

    url_text = ""
    if citation.url is not None:
        if citation.doi is not None:
            url_text = f" or URL ({citation.url})"
        else:
            url_text = f" from URL ({citation.url})"
    print(f"  - For the {citation.text},{doi_text}{url_text}")
if model.git_repo is not None:
    print()
    print(f" GitHub repository: {model.git_repo}")
print()
print(f" Covers of the model '{model.name}' are: ")
for cover in model.covers:
    cover_data = imread(download(cover).path)
    plt.figure(figsize=(10, 10))
    plt.imshow(cover_data)
    plt.xticks([])
    plt.yticks([])
    plt.show()
print()
print(f" Further documentation (taken from {model.documentation.absolute()}):")
print(' -'*60)
display(Markdown(open(download(model.documentation).path).read()))
print(' -'*60)

#### Inspect the weights, and expected inputs and outputs

In [None]:
#@markdown ##Inspect the weights, and expected inputs and outputs

print("Available weight formats for this model:", ", ".join(model.weights.model_fields_set))
print("Pytorch state dict weights are stored at:", model.weights.pytorch_state_dict.source.absolute())
print()

# or what inputs the model expects
print(f"The model requires {len(model.inputs)} input(s) with the following features:")
for inp in model.inputs:
    print(" - Input with axes:", ([i.id for i in inp.axes]))
    print(" - Minimum shape:", ([s.min if type(s) is bioimageio.spec.model.v0_5.ParameterizedSize else s for s in inp.shape]))
    print(" - Step:", ([s.step if type(s) is bioimageio.spec.model.v0_5.ParameterizedSize else s for s in inp.shape]))
    print()
    print(f"It is expected to be pre-processed with:")
    for prep in inp.preprocessing:
        print(f" - '{prep.id}' with arguments:")
        for prep_arg in prep.kwargs:
            print(f"    - {prep_arg[0]}={prep_arg[1]}")
print()

# and what the model outputs are
print(f"The model gives {len(model.outputs)} output(s) with the following features:")
for out in model.outputs:
    print(" - Output with axes:", ([o.id for o in out.axes]) )
    print(" - Minimum shape:", ([s if type(s) is bioimageio.spec.model.v0_5.SizeReference else s for s in out.shape]))
    print(" - Step:", ([s.step if type(s) is bioimageio.spec.model.v0_5.ParameterizedSize else s for s in out.shape]))
    print()
    print(f"It is expected to be post-processed with:")
    for postp in out.postprocessing:
        print(f" - '{postp.id}' with arguments:")
        for postp_arg in postp.kwargs:
           print(f"    - {postp_arg[0]}={postp_arg[1]}")
    #print(f"The output image has a halo of : {out.halo}")


#### Inspect the test images

In [None]:
#@markdown ##Inspect the test images

# Inspect the test input images provided with the model
print(f"The model provides {len(model.inputs)} test input image(s) :")
for test_im in model.get_input_test_arrays():
    test_input = np.squeeze(test_im)
    if len(test_input.shape)>2:
        print(f"The test input image has shape {test_input.shape}, so it will not be displayed")
    else:
        fig, ax = plt.subplots()
        im = ax.imshow(test_input, cmap="afmhot")
        fig.colorbar(im)
        ax.set_axis_off()
        plt.show()

# Inspect the test output images provided with the model
print(f"The model provides {len(model.outputs)} test output image(s) :")
for test_im in model.get_output_test_arrays():
    test_output = np.squeeze(test_im)
    if len(test_output.shape)>2:
        print(f"The test output image has shape {test_output.shape}, so it will not be displayed")
    else:
        fig, ax = plt.subplots()
        im = ax.imshow(test_output, cmap="afmhot")
        fig.colorbar(im)
        ax.set_axis_off()
        plt.show()

## **3. Test the model**

---

Both the model format and the deployment of the model can be tested.

By running the following cell you can check that
- The model follows the format of the BioImage Model Zoo correctly (static validation)
- It actually produces the output that it is expected to produce (dynamic validation). This is done by running a prediction on the test input images and checking that the results agree with the given test output(s).

The running time depends on the resources available (e.g., GPU acceleration).
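
To make the idea of dynamic validation concrete, here is a minimal sketch of comparing a prediction with a stored test output. It is not the code used by `test_model`: the helper name `outputs_agree` and the tolerance `atol=1e-4` are assumptions chosen for the example, and the library may compare arrays differently.

```python
import numpy as np


def outputs_agree(predicted, expected, atol=1e-4):
    """Return True if both arrays have the same shape and close values."""
    predicted = np.asarray(predicted)
    expected = np.asarray(expected)
    # Check the shape first so that mismatched arrays are never broadcast.
    return predicted.shape == expected.shape and bool(
        np.allclose(predicted, expected, atol=atol)
    )


expected = np.array([[0.1, 0.9], [0.5, 0.5]])
print(outputs_agree(expected + 1e-6, expected))  # tiny numerical noise
print(outputs_agree(expected + 0.1, expected))   # clearly different output
```

Small numerical differences are expected between hardware and library versions, which is why a tolerance is used instead of exact equality.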

In [None]:
#@markdown ##Check if the model passes the test

# Test the description of the model
test_result = test_model(model)

# 'test_model()' returns a ValidationSummary object with an attribute status which can be 'passed'/'failed' and more detailed information
if test_result.status == "failed":
    print("model test:", test_result.name)
    if len(test_result.errors) > 1:
        print("The model test failed with many errors. We will only show the first one:")
    else:
        print("The model test failed with:")
    for error in test_result.errors[:1]:
        # Keep the indentation of multi-line error messages
        error_msg = error.msg.split('\n')
        print(f" - {error_msg[0]}")
        for line in error_msg[1:]:
            print(f"   {line}")
else:
    print("model test:", test_result.name)
    assert test_result.status == "passed", f"Something went wrong, test_result.status is {test_result.status} and should be 'passed'."
    print("The model passed the test.")
    print()

# Get the versions of the Bioimage.IO packages used for this test
bioimageio_spec_version = ""
bioimageio_core_version = ""
for package in test_result.env:
    if package['name'] == 'bioimageio.spec':
        bioimageio_spec_version = package['version']
    if package['name'] == 'bioimageio.core':
        bioimageio_core_version = package['version']

print()
print(f"The model was tested using:")
print(f" - 'bioimageio_spec_version': '{bioimageio_spec_version}'")
print(f" - 'bioimageio_core_version': '{bioimageio_core_version}'")

## **4. Use the model with new images**

---

### **4.1. Process the example input array**

---

In [None]:
#@markdown ##Process the example input within the model

# Load the example image for this model
input_paths = {ipt.id: download(ipt.test_tensor).path for ipt in model.inputs}
# The prediction pipeline expects a Sample object from bioimageio.core
input_sample = create_sample_for_model(
    model=model, inputs=input_paths, sample_id="my_demo_sample"
)

# "devices" can be used to run prediction on a gpu instead of the cpu
devices = None
# "weight_format" to specify which weight format to use in case the model contains different weight formats
weight_format = None

# The prediction pipeline combines preprocessing, prediction and postprocessing.
# It should always be used for prediction with a bioimageio model.
prediction_pipeline = create_prediction_pipeline(
    model, devices=devices, weight_format=weight_format
)

# The prediction pipeline takes a Sample that holds one tensor per model input.
# In this case, the model expects a single input ("input0"). For a model with several inputs,
# add one entry per input ID to the `inputs` dictionary given to `create_sample_for_model`.
# The call returns a Sample whose members are the output tensors of the model.
prediction = prediction_pipeline.predict_sample_without_blocking(input_sample)

# Convert both input and output sample tensors into a NumPy format to plot them
input_sample_tensor = input_sample.members["input0"].data
input_sample_tensor = np.squeeze(input_sample_tensor)
prediction_tensor = prediction.members["output0"].data
prediction_tensor = np.squeeze(prediction_tensor)

# Plot the input and output images
if len(prediction_tensor.shape)>2:
    subplot_n = prediction_tensor.shape[0]
    plt.figure(figsize=(15,10))
    plt.subplot(1,1+subplot_n,1)
    plt.imshow(input_sample_tensor, cmap="afmhot")
    plt.axis('off')
    plt.title("Input image to process")

    for i in range(subplot_n):
        plt.subplot(1,1+subplot_n,i+2)
        plt.imshow(prediction_tensor[i], cmap="afmhot")
        plt.axis('off')
        plt.title("Processed image")
    plt.show()
else:
    plt.figure(figsize=(5,5))
    plt.subplot(1,2,1)
    plt.imshow(input_sample_tensor, cmap="afmhot")
    plt.axis('off')
    plt.title("Input image to process")
    plt.subplot(1,2,2)
    plt.imshow(prediction_tensor, cmap="afmhot")
    plt.axis('off')
    plt.title("Processed image")
    plt.show()

### **4.2. Process a single image and save the result**

---

The BioImageIO core library provides the utility function `predict` to run predictions on an image stored on disk. It accepts the most common image formats (`.tif`, `.png`) as well as the `.npy` file format as inputs, and the output prediction can be stored in a local `Results_folder` directory.

Provide the path to the image to be processed in `Image_path` or run it with the example image.
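
The next cell saves one file per output channel, named after the input image. The sketch below shows that naming scheme on its own. The paths are hypothetical, `output_paths` is an illustrative helper rather than part of `bioimageio.core`, and `PurePosixPath` is used on the assumption of Linux-style paths such as those in Colab.

```python
from pathlib import PurePosixPath


def output_paths(image_path, results_folder, n_outputs):
    """Build one '<name>_<index><suffix>' path per output channel."""
    image_path = PurePosixPath(image_path)
    folder = PurePosixPath(results_folder)
    # `stem` and `suffix` split at the last dot only, so "cells.v2.tif" is safe.
    return [
        str(folder / f"{image_path.stem}_{i}{image_path.suffix}")
        for i in range(n_outputs)
    ]


print(output_paths("/content/data/cells.v2.tif", "/content/results", 2))
```

Splitting at the last dot keeps file names that contain several dots intact.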

In [None]:
#@markdown ##Indicate the path to the image

# You might want to use the test images (as in Section 4.1.)
use_test_image = False #@param {type:"boolean"}

#@markdown ### If you have an image in a folder to segment, copy the path to it here:
Image_path = ""  #@param {type:"string"}
#@markdown ### Indicate where to save the output of the model:
Results_folder = ""  #@param {type:"string"}

# Create the result/output folder
os.makedirs(Results_folder, exist_ok=True)

if use_test_image:
  # Download and take the input paths from the model description
  input_paths = {ipt.id: download(ipt.test_tensor).path for ipt in model.inputs}
else:
  # Load the paths to the input images
  input_paths = {"input0": Path(Image_path)}

# The prediction pipeline expects a Sample object from bioimageio.core
input_sample = create_sample_for_model(
    model=model, inputs=input_paths, sample_id="my_demo_sample"
)

# Use the predict function with the defined Samples
prediction = predict(model=model, inputs=input_sample)

# Convert both input and output sample tensors into a NumPy format to save and plot them
input_sample_tensor = input_sample.members["input0"].data
input_sample_tensor = np.squeeze(input_sample_tensor)
prediction_tensor = prediction.members["output0"].data
prediction_tensor = np.squeeze(prediction_tensor)

# Save the output results
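filename = os.path.basename(Image_path)
# splitext splits at the last dot only, so names such as "img.v2.tif" are handled
name, extension = os.path.splitext(filename)
extension = extension.lstrip('.')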
for i, f in enumerate(prediction_tensor):
  imsave(uri=os.path.join(Results_folder, f"{name}_{i}.{extension}"), im=f)
print(f"Predicted images correctly saved on: {Results_folder}")

# Plot the input image and the predicted results
plt.figure(figsize=(15,15))
plt.subplot(1,1+prediction_tensor.shape[0],1)
plt.imshow(np.squeeze(input_sample_tensor), cmap="afmhot")
plt.axis('off')
plt.title("Input image to process")

for pred_idx, pred in enumerate(prediction_tensor):
    plt.subplot(1,1+prediction_tensor.shape[0],pred_idx+2)
    plt.imshow(pred, cmap="afmhot")
    plt.axis('off')
    plt.title(f"Processed image {pred_idx+1}")
plt.show()



### **4.3. Process all images stored in a directory**

---
It is possible to provide a list of images to analyse and run the prediction for each automatically. 

In this example `tiling`, or `blocking`, strategies can be enabled. This divides the image into smaller patches, each of which is processed independently, and then rejoined.
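
A minimal NumPy sketch of the split, process and rejoin idea is shown below. It is a simplification under stated assumptions: `process_in_tiles` is a hypothetical helper, the tiles do not overlap, and the processing function must return an array with the same shape as the tile it receives. Real tiling for convolutional networks usually adds overlapping borders (a halo), which this sketch omits.

```python
import numpy as np


def process_in_tiles(image, tile_size, fn):
    """Apply `fn` to non-overlapping tiles of a 2D image and stitch the results."""
    out = np.empty_like(image, dtype=np.float32)
    height, width = image.shape
    for top in range(0, height, tile_size):
        for left in range(0, width, tile_size):
            # Slices past the border are clipped, so edge tiles may be smaller.
            tile = image[top:top + tile_size, left:left + tile_size]
            out[top:top + tile_size, left:left + tile_size] = fn(tile)
    return out


image = np.arange(30).reshape(5, 6)
doubled = process_in_tiles(image, tile_size=4, fn=lambda tile: tile * 2)
print(doubled.shape)
```

For a pixel-wise operation like the doubling above, the tiled result matches processing the whole image at once.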

`Tile_size` must be less than or equal to the input image size.
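
The cell below converts the requested tile size into a block size parameter `n`, where a parameterized axis has size `min + n * step`. The sketch that follows illustrates that arithmetic. The values `min_size=64` and `step=16` are made up and do not come from any particular model, and clamping negative results to zero is an addition of this sketch.

```python
def blocksize_parameter(tile_size, min_size, step):
    """Largest n such that min_size + n * step <= tile_size (never below 0)."""
    if step <= 0:
        return 0
    return max(0, (tile_size - min_size) // step)


def tile_size_for(n, min_size, step):
    """Axis size described by a parameterized size: min_size + n * step."""
    return min_size + n * step


# A requested tile of 250 pixels is rounded down to the closest valid size.
n = blocksize_parameter(tile_size=250, min_size=64, step=16)
print(n, tile_size_for(n, min_size=64, step=16))
```

In other words, the tile actually used can be slightly smaller than the requested `Tile_size` when that value is not reachable in whole steps.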

In [None]:
#@markdown ##Indicate a directory with images to analyse and a directory to save the images

Image_path = "" #@param {type:"string"}
Results_folder = "" #@param {type:"string"}

Tiling = True #@param {type:"boolean"}
Tile_size = 256 #@param {type:"integer"}

# Create the result/output folder
os.makedirs(Results_folder, exist_ok=True)

# Calculate the blocksize parameter from the Tile_size
if Tiling:
    inp = model.inputs[0]
    # Get the minimum block size and the block step values from the model description RDF
    min_block_size = max([s.min for s in inp.shape if type(s) is bioimageio.spec.model.v0_5.ParameterizedSize])
    block_step = max([s.step for s in inp.shape if type(s) is bioimageio.spec.model.v0_5.ParameterizedSize])
    blocksize_parameter = (Tile_size - min_block_size) // block_step
        

# Get the list of images to analyse and the same list to save the images
input_path_list = [os.path.join(Image_path,f) for f in sorted(os.listdir(Image_path)) if not f.startswith('.')]

print(f"Provided directory contains {len(input_path_list)} images to predict.")

for in_idx, in_path in enumerate(input_path_list):

    # The implementation requires an input sample with the same shape and the same
    # ID as the one declared by the model, so the images are processed
    # one by one.
    input_paths_dict = {"input0": Path(in_path)}

    # The prediction pipeline expects a Sample object from bioimageio.core
    input_sample = create_sample_for_model(
        model=model, inputs=input_paths_dict, sample_id="my_demo_sample"
    )

    # Use the predict function with the defined Samples, with or without tiling
    if not Tiling:
        prediction = predict(model=model, inputs=input_sample)
    else:
        prediction = predict(model=model, inputs=input_sample, blocksize_parameter=blocksize_parameter)

    # Convert both input and output sample tensors into a NumPy format to save and plot them
    input_sample_tensor = input_sample.members["input0"].data
    input_sample_tensor = np.squeeze(input_sample_tensor)
    prediction_tensor = prediction.members["output0"].data
    prediction_tensor = np.squeeze(prediction_tensor)

    # Save the output results
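    filename = os.path.basename(in_path)
    # splitext splits at the last dot only, so names such as "img.v2.tif" are handled
    name, extension = os.path.splitext(filename)
    extension = extension.lstrip('.')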
    for pred_idx, pred in enumerate(prediction_tensor):
        imsave(uri=os.path.join(Results_folder, f"{name}_{pred_idx}.{extension}"), im=pred)

    # Plot the input image and the predicted results
    plt.figure(figsize=(15,15))
    plt.subplot(1,1+prediction_tensor.shape[0],1)
    plt.imshow(np.squeeze(input_sample_tensor), cmap="afmhot")
    plt.axis('off')
    plt.title("Input image to process")

    for pred_idx, pred in enumerate(prediction_tensor):
        plt.subplot(1,1+prediction_tensor.shape[0],pred_idx+2)
        plt.imshow(pred, cmap="afmhot")
        plt.axis('off')
        plt.title(f"Processed image {pred_idx+1}")
    plt.show()

print(f"Predicted images correctly saved on: {Results_folder}")

## **5. Fine-tune an existing model (only for segmentation)**

---

### **5.1. Load the training functions**

---
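
The datasets defined in the next cell normalize each image to zero mean and unit variance and turn label masks into one-hot channels without the background class. The sketch below reproduces those two steps with NumPy only, so it can be tried without PyTorch. The helper names and the toy mask are illustrative assumptions, not part of the notebook's training code.

```python
import numpy as np


def normalize(im, eps=1.0e-6):
    """Zero-mean, unit-variance normalization of a single image."""
    im = np.asarray(im, dtype=np.float32)
    return (im - im.mean()) / (im.std() + eps)


def one_hot_without_background(mask, classes):
    """Turn an (H, W) label mask into (classes - 1, H, W) binary channels."""
    # Indexing the identity matrix with the labels yields (H, W, classes).
    hot = np.eye(classes, dtype=np.float32)[mask]
    # Drop label 0 (background) and move the channels to the first axis.
    return np.transpose(hot[..., 1:], (2, 0, 1))


mask = np.array([[0, 1], [2, 1]])
channels = one_hot_without_background(mask, classes=3)
print(channels.shape)
```

Each remaining channel is a binary map for one foreground label, which is the layout the training loop receives as its target.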

In [None]:
#@markdown ## Run the following cell to load the training functions

from marshmallow import missing
import torch
import numpy as np
import random
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MultiLabelSoftMarginLoss
from torch.nn.functional import one_hot
from torch.optim import Adam
from torch.utils.data import Dataset
from tqdm import tqdm
import time

import importlib.util  # `import importlib` alone does not guarantee that `importlib.util` is loaded
import sys

def load_pytorch_model(model):
    weight_spec = model.weights.pytorch_state_dict
    model_kwargs = weight_spec.architecture.kwargs
    joined_kwargs = {} if model_kwargs is missing else dict(model_kwargs)

    # Download the Python file with the model
    model_fullpath = str(download(model.weights.pytorch_state_dict.architecture.source).path)
    model_path, model_filename = os.path.split(model_fullpath)

    # Add it to the sys.path list
    sys.path.insert(0, str(model_path))

    # Get the name of the model's class
    module_name = str(model.weights.pytorch_state_dict.architecture.callable)

    # Import the Python file
    spec = importlib.util.spec_from_file_location(module_name, model_fullpath)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Add the module to sys.modules
    sys.modules[module_name] = module

    # Import the model
    model_class = getattr(importlib.import_module(module_name), module_name)
    # Initialize the model
    model_instance = model_class(**joined_kwargs)

    print(f"Model {module_name} successfully initialized!")

    _devices = ["cuda" if torch.cuda.is_available() else "cpu"]

    print(_devices)
    if len(_devices) > 1:
        warnings.warn("Multiple devices for single pytorch model not yet implemented")
    model_instance.to(_devices[0])

    weights= model.weights.pytorch_state_dict

    if weights is not None and weights.source:
        weights_fullpath = str(download(weights.source).path)
        state = torch.load(weights_fullpath, map_location=_devices[0])
        model_instance.load_state_dict(state)
    model_instance.eval()

    return model_instance


class SegmentationTrainDataset(Dataset):
    def __init__(self,  INPUT_IMAGE_WIDTH, INPUT_IMAGE_HEIGHT, imagePaths, maskPaths, classes):
        # store the patch size, the image and mask filepaths,
        # and the number of classes
        self.input_width = INPUT_IMAGE_WIDTH
        self.input_height = INPUT_IMAGE_HEIGHT
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.classes = classes
    def __len__(self):
        # return the number of total samples contained in the dataset
        return len(self.imagePaths)

    def preprocess(self, im):
        norm_im = np.float32(im)
        return (norm_im - np.mean(norm_im)) / (np.std(norm_im) + 1.0e-6)

    def __getitem__(self, idx):
        # load the image and its mask for the current index
        image = imread(self.imagePaths[idx])
        mask = imread(self.maskPaths[idx])
        # we want to ensure that there are cells in the patch
        # without getting in an infinite loop
        # TODO: define a sampling function to remove the loop
        num_labels = 0
        k = 0
        while num_labels<(self.classes-1) and k<5:
          # Choose a random coordinate to crop a patch
          h = random.randint(1, image.shape[0]-self.input_height-1)
          w = random.randint(1, image.shape[1]-self.input_width-1)
          mask_patch = mask[h:h+self.input_height, w:w+self.input_width]
          num_labels = len(np.unique(mask_patch))
          # If the mask contains more than one label for semantic segmentation
          # we will transform it into one-hot encoding
          mask_torch = torch.tensor(mask_patch).to(torch.int64)
          mask_hot = one_hot(mask_torch, self.classes)
          mask_hot = mask_hot[:,:,1:]
          if len(mask_hot.shape)==2:
            # add a dimension
            mask_hot = np.expand_dims(mask_hot, -1)
          # first axis goes to the channels
          mask_hot = np.transpose(mask_hot, [-1, 0, 1])
          k += 1
        # normalize the image, crop the same patch and
        # return a tuple of the image patch and its mask
        norm_image = self.preprocess(image)
        norm_image = np.expand_dims(norm_image[h:h+self.input_height, w:w+self.input_width], 0)
        return (torch.tensor(norm_image).float(), torch.tensor(mask_hot).float())

class SegmentationTestDataset(Dataset):
    def __init__(self,  INPUT_IMAGE_WIDTH, INPUT_IMAGE_HEIGHT, imagePaths, maskPaths, classes):
        # store the image and mask filepaths, and augmentation
        # transforms
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.classes = classes
    def __len__(self):
        # return the number of total samples contained in the dataset
        return len(self.imagePaths)

    def preprocess(self, im):
        norm_im = np.float32(im)
        return (norm_im - np.mean(norm_im)) / (np.std(norm_im) + 1.0e-6)

    def __getitem__(self, idx):
        # grab the image path from the current index
        image = imread(self.imagePaths[idx])
        mask = imread(self.maskPaths[idx])
        # no patches are cropped, so the full image must fit in memory
        mask_torch = torch.tensor(mask).to(torch.int64)
        mask_hot = one_hot(mask_torch, self.classes)
        mask_hot = mask_hot[:,:,1:]
        if len(mask_hot.shape)==2:
          # add a dimension
          mask_hot = np.expand_dims(mask_hot, -1)
        # first axis goes to the channels
        mask_hot = np.transpose(mask_hot, [-1, 0, 1])
        # return a tuple of the image and its mask
        norm_image = self.preprocess(image)
        norm_image = np.expand_dims(norm_image, 0)
        return (torch.tensor(norm_image).float(), torch.tensor(mask_hot).float())


def visualize_results(input_image, gt, prediction):
    if len(prediction.shape)>2:
        subplot_n = prediction.shape[0]
        plt.figure(figsize=(20,15))
        plt.subplot(subplot_n,3,1)
        plt.imshow(np.squeeze(input_image), cmap="gray")
        plt.axis('off')
        plt.title("Input test image")

        for i in range(subplot_n):
            plt.subplot(subplot_n,3,i*3+2)
            plt.imshow(np.squeeze(gt[i]), cmap="afmhot")
            plt.axis('off')
            plt.title("Ground truth image")

            plt.subplot(subplot_n,3,i*3+3)
            plt.imshow(np.squeeze(prediction[i]), cmap="afmhot")
            plt.axis('off')
            plt.title("Processed image")
        plt.show()
    else:
        plt.figure(figsize=(10,10))
        plt.subplot(1,3,1)
        plt.imshow(np.squeeze(input_image), cmap="afmhot")
        plt.axis('off')
        plt.title("Input test image")
        plt.subplot(1,3,2)
        plt.imshow(np.squeeze(gt), cmap="afmhot")
        plt.axis('off')
        plt.title("Ground truth image")

        plt.subplot(1,3,3)
        plt.imshow(np.squeeze(prediction), cmap="afmhot")
        plt.axis('off')
        plt.title("Processed image")
        plt.show()


def finetune_bioimageio_model(model, TRAIN_IM, TRAIN_MASK, TEST_IM, TEST_MASK,
                              BASE_OUTPUT, NUM_EPOCHS=100, INIT_LR=0.0001, BATCH_SIZE=10,
                              INPUT_IMAGE_WIDTH=512, INPUT_IMAGE_HEIGHT=512, CLASSES=3):

    model_instance = load_pytorch_model(model)
    # create the train and test datasets
    trainImages = sorted([os.path.join(TRAIN_IM, i) for i in os.listdir(TRAIN_IM) if i.endswith(".tif") ])
    trainMasks = sorted([os.path.join(TRAIN_MASK, i) for i in os.listdir(TRAIN_MASK) if i.endswith(".tif") ])
    testImages = sorted([os.path.join(TEST_IM, i) for i in os.listdir(TEST_IM) if i.endswith(".tif") ])
    testMasks = sorted([os.path.join(TEST_MASK, i) for i in os.listdir(TEST_MASK) if i.endswith(".tif") ])

    trainDS = SegmentationTrainDataset( INPUT_IMAGE_WIDTH, INPUT_IMAGE_HEIGHT, imagePaths=trainImages, maskPaths=trainMasks, classes = CLASSES)
    testDS = SegmentationTestDataset( INPUT_IMAGE_WIDTH, INPUT_IMAGE_HEIGHT, imagePaths=testImages, maskPaths=testMasks, classes = CLASSES)
    print(f"[INFO] found {len(trainDS)} examples in the training set...")
    print(f"[INFO] found {len(testDS)} examples in the test set...")

    # create the training and test data loaders
    from torch.utils.data import DataLoader

    # determine the device to be used for training and evaluation
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # determine if we will be pinning memory during data loading
    PIN_MEMORY = True if DEVICE == "cuda" else False
    trainLoader = DataLoader(trainDS, shuffle=True, batch_size=BATCH_SIZE, pin_memory=PIN_MEMORY, num_workers=os.cpu_count())
    testLoader = DataLoader(testDS, shuffle=False, batch_size=1, pin_memory=PIN_MEMORY, num_workers=os.cpu_count())

    # initialize loss function and optimizer
    lossFunc = CrossEntropyLoss()
    opt = Adam(model_instance.parameters(), lr=INIT_LR)
    # calculate steps per epoch for training and test set
    trainSteps = len(trainDS) // BATCH_SIZE
    testSteps = len(testDS)
    # initialize a dictionary to store training history
    H = {"train_loss": [], "test_loss": []}

    # Save the initial prediction
    x, y = testDS.__getitem__(-1)
    x = x.unsqueeze(0)
    with torch.no_grad():
        # set the model in evaluation mode
        model_instance.eval()
        # send the input to the device
        (x, y) = (x.to(DEVICE), y.to(DEVICE))
        # make the predictions
        pred = model_instance(x)

    input_image = x.to("cpu")
    input_image = np.squeeze(input_image.numpy())

    gt = y.to("cpu")
    gt = np.squeeze(gt.numpy())

    prediction = pred.to("cpu")
    prediction = np.squeeze(prediction.numpy())
    print("Results of the prediction before finetuning")
    print("---------------------------------------------")
    visualize_results(input_image, gt, prediction)
    print("---------------------------------------------")
    del x, y, gt, prediction, input_image, pred

    # loop over epochs
    print("[INFO] training the network...")
    startTime = time.time()
    for e in tqdm(range(NUM_EPOCHS)):
        # set the model in training mode
        model_instance.train()
        # initialize the total training and validation loss
        totalTrainLoss = 0
        totalTestLoss = 0
        # loop over the training set

        for (i, (x, y)) in enumerate(trainLoader):
            # send the input to the device
            (x, y) = (x.to(DEVICE), y.to(DEVICE))
            # perform a forward pass and calculate the training loss
            pred = model_instance(x)
            loss = lossFunc(pred, y)
            # first, zero out any previously accumulated gradients, then
            # perform backpropagation, and then update model parameters
            opt.zero_grad()
            loss.backward()
            opt.step()
            # add the loss to the total training loss so far; detach it so that
            # the computation graph of each batch is not kept in memory
            totalTrainLoss += loss.detach()
        # switch off autograd
        with torch.no_grad():
            # set the model in evaluation mode
            model_instance.eval()
            # loop over the validation set
            for (x, y) in testLoader:
                # send the input to the device
                (x, y) = (x.to(DEVICE), y.to(DEVICE))
                # make the predictions and calculate the validation loss
                pred = model_instance(x)
                totalTestLoss += lossFunc(pred, y)
        if DEVICE=="cuda":
            torch.save(model_instance.state_dict(),
                       os.path.join(BASE_OUTPUT, "finetuned_last.pth"))
        else:
            torch.save(model_instance.cpu().state_dict(),
                       os.path.join(BASE_OUTPUT, "finetuned_last.pth"))

        # calculate the average training and validation loss
        avgTrainLoss = totalTrainLoss / trainSteps
        avgTestLoss = totalTestLoss / testSteps
        # update our training history
        H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
        H["test_loss"].append(avgTestLoss.cpu().detach().numpy())
        # print the model training and validation information
        print("[INFO] EPOCH: {}/{}".format(e + 1, NUM_EPOCHS))
        print("Train loss: {:.6f}, Test loss: {:.4f}".format(
            avgTrainLoss, avgTestLoss))
    # display the total time needed to perform the training
    endTime = time.time()
    print("[INFO] total time taken to train the model: {:.2f}s".format(endTime - startTime))

    # Show the last validation batch processed with the finetuned model
    input_image = x.to("cpu")
    input_image = np.squeeze(input_image.numpy())

    gt = y.to("cpu")
    gt = np.squeeze(gt.numpy())

    prediction = pred.to("cpu")
    prediction = np.squeeze(prediction.numpy())

    visualize_results(input_image, gt, prediction)

    return model_instance, H

# get the python file defining the architecture.
# this is only required for models with pytorch_state_dict weights
def get_architecture_source(rdf):
    # load the resource description; this mirrors the calls used in section 6
    # instead of the legacy raw-resource API
    model_descr = load_description(rdf)
    # the architecture entry of the pytorch weights (python file and callable name)
    architecture = model_descr.weights.pytorch_state_dict.architecture
    # download the source file if necessary
    source_file = str(download(architecture.source, sha256=architecture.sha256).path)
    assert os.path.exists(source_file), source_file
    class_name = architecture.callable
    return f"{source_file}:{class_name}"

### **5.2. Start the fine-tuning**

---

##### Run the following cell to visualize the results of the pretrained model on the new images before running the fine-tuning

In [None]:
#@markdown ##Indicate the path to the image

# You might want to use the test images of the model (as in section 4.1)
use_test_image = False #@param {type:"boolean"}

#@markdown ### If you have an image in a folder to segment, copy the path to it here:
Image_path = ""  #@param {type:"string"}
#@markdown ### Indicate where to save the output of the model:
Results_folder = ""  #@param {type:"string"}

if use_test_image:
  # Download and take the input paths from the model description
  input_paths = {ipt.id: download(ipt.test_tensor).path for ipt in model.inputs}
else:
  # Load the paths to the input images
  input_paths = {"input0": Path(Image_path)}

# The prediction pipeline expects a Sample object from bioimageio.core
input_sample = create_sample_for_model(
    model=model, inputs=input_paths, sample_id="my_demo_sample"
)

# Run the model on the input sample
prediction = predict(model=model, inputs=input_sample)

# Convert both input and output sample tensors into a NumPy format to save and plot them
input_sample_tensor = input_sample.members["input0"].data
input_sample_tensor = np.squeeze(input_sample_tensor)
prediction_tensor = prediction.members["output0"].data
prediction_tensor = np.squeeze(prediction_tensor)

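# Save the output results, one file per output channel
name, extension = os.path.splitext(os.path.basename(Image_path))
# Fall back to a generic name and TIFF when no image path is given (e.g. when using the test image)
name, extension = name or "test_image", extension or ".tif"
for i, f in enumerate(prediction_tensor):
  imsave(uri=os.path.join(Results_folder, f"{name}_{i}{extension}"), im=f)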

# Plot the input image and the predicted results
plt.figure(figsize=(15,15))
plt.subplot(1,1+prediction_tensor.shape[0],1)
plt.imshow(np.squeeze(input_sample_tensor), cmap="afmhot")
plt.axis('off')
plt.title("Input image to process")

for pred_idx, pred in enumerate(prediction_tensor):
    plt.subplot(1,1+prediction_tensor.shape[0],pred_idx+2)
    plt.imshow(pred, cmap="afmhot")
    plt.axis('off')
    plt.title(f"Processed image {pred_idx+1}")
plt.show()



##### Run the following cell to set up the parameters for the fine-tuning and run it.

The fine-tuning requires the following parameters:
- Initial learning rate `INIT_LR`, number of training epochs `NUM_EPOCHS`, and batch size `BATCH_SIZE`
- The input image dimensions `INPUT_IMAGE_WIDTH` and `INPUT_IMAGE_HEIGHT`, and number of label classes `CLASSES` (including background)
- Paired folders of images and masks (`TRAIN_IM`/`TRAIN_MASK` for training and `TEST_IM`/`TEST_MASK` for validation)
- An output directory `BASE_OUTPUT`
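
Before launching a long training run, it can help to confirm that the image and mask folders really are paired. The following is a minimal sketch, not part of the notebook's pipeline: `check_paired_folders` is a hypothetical helper, and it assumes that each image and its mask share the same file name, which may not match your naming scheme.

```python
from pathlib import Path


def check_paired_folders(image_dir, mask_dir):
    """Return the sorted file names found in both folders.

    Hypothetical helper: assumes an image and its mask share the same file name.
    Raises ValueError if any file has no partner in the other folder.
    """
    image_names = {p.name for p in Path(image_dir).iterdir() if p.is_file()}
    mask_names = {p.name for p in Path(mask_dir).iterdir() if p.is_file()}
    # Symmetric difference: names present in only one of the two folders
    unpaired = image_names ^ mask_names
    if unpaired:
        raise ValueError(f"Files without a partner: {sorted(unpaired)}")
    return sorted(image_names)
```

For example, `check_paired_folders(TRAIN_IM, TRAIN_MASK)` and `check_paired_folders(TEST_IM, TEST_MASK)` could be called once the paths in the next cell are set.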


In [None]:
# @markdown ## Set up training parameters

# @markdown Initialize learning rate, number of training epochs, and batch size
INIT_LR = 0.000001 #@param {type:"number"}
NUM_EPOCHS = 200 #@param {type:"integer"}
BATCH_SIZE = 20 #@param {type:"integer"}
# @markdown Shape of the input images and number of classes
INPUT_IMAGE_WIDTH = 256 #@param {type:"integer"}
INPUT_IMAGE_HEIGHT = 256 #@param {type:"integer"}
CLASSES = 3 #@param {type:"integer"}
# @markdown Path to the training data paired folders
TRAIN_IM = "" #@param {type:"string"}
TRAIN_MASK = "" #@param {type:"string"}
# @markdown Path to the test data paired folders
TEST_IM = "" #@param {type:"string"}
TEST_MASK = "" #@param {type:"string"}

#@markdown ### Path to the directory to save the new model
# Define the path to the base output directory
BASE_OUTPUT = "" #@param {type:"string"}

# Make the output directory
os.makedirs(BASE_OUTPUT, exist_ok=True)

# Define the path to the output serialized model, model training
# plot, and testing image paths
MODEL_PATH = os.path.join(BASE_OUTPUT, "finetuned_bioimageio.pth")
PLOT_PATH = os.path.sep.join([BASE_OUTPUT, "plot.png"])

## -----
## Release some memory
# del model_instance
# del finetuned_model
torch.cuda.empty_cache()
# gc.collect()
## -----

finetuned_model, H = finetune_bioimageio_model(
    model, TRAIN_IM, TRAIN_MASK, TEST_IM, TEST_MASK, BASE_OUTPUT,
    NUM_EPOCHS=NUM_EPOCHS, INIT_LR=INIT_LR, BATCH_SIZE=BATCH_SIZE,
    INPUT_IMAGE_WIDTH=INPUT_IMAGE_WIDTH, INPUT_IMAGE_HEIGHT=INPUT_IMAGE_HEIGHT,
    CLASSES=CLASSES)

## Save the model in two different formats

# Save the model as a pytorch statedict
MODEL_STATEDICT_PATH = os.path.join(BASE_OUTPUT, "finetuned_bioimageio_statedict_model.pth")
torch.save(finetuned_model.cpu().state_dict(),MODEL_STATEDICT_PATH)

# Convert the model to a torchscript format and save it as torchscript
MODEL_TORCHSCRIPT_PATH = os.path.join(BASE_OUTPUT, "finetuned_bioimageio_torchscript_model.pt")
with torch.no_grad():
    # load the test input of the original model and use it as the example input for tracing
    input_data = [np.load(download(inp.test_tensor.source).path).astype("float32") for inp in model.inputs]
    input_data = tuple(torch.from_numpy(inp) for inp in input_data)
    scripted_model = torch.jit.trace(finetuned_model.cpu(), input_data)
    scripted_model.save(MODEL_TORCHSCRIPT_PATH)

# Plot the training loss and test loss
plt.style.use("ggplot")
plt.figure(figsize=(20, 10))
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["test_loss"], label="test_loss")
plt.title("Training Loss on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss")
plt.legend(loc="lower left")
plt.savefig(PLOT_PATH)

## **6. Create a BioImage Model Zoo model**

---

Let's recreate a model based on parts of the loaded model description from above!

`bioimageio.core` also implements functionality to create a model package compatible with the [BioImageIO Model Spec](https://bioimage.io/docs/#/bioimageio_model_spec) ready to be shared via the [BioImage Model Zoo](https://bioimage.io/#/).
Here, we will use this functionality to create a new model with the finetuned weights.

For this, we reuse some information from the previous model.
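
The export cell below writes the model package as a `.zip` file, so once it has run, the archive can be inspected with the Python standard library. This is a minimal sketch under that assumption: `list_package_contents` is a hypothetical helper, and it makes no assumption about which file names the package contains.

```python
import zipfile


def list_package_contents(zip_path):
    """Return the names of the files stored in a zip archive, sorted alphabetically."""
    with zipfile.ZipFile(zip_path) as archive:
        return sorted(archive.namelist())
```

After the export, `list_package_contents(output_path)` should list the model description, the weights, and the test tensors that were packaged.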

Run the following cell to export the model

In [None]:
# ------
# Information about the model

#@markdown ##Export the new model to the bioimage model zoo format

Trained_model_name = "My new model" #@param {type:"string"}
Trained_model_description = "Model finetuned" #@param {type:"string"}

#@markdown ### Choose a test image
Image_path = ""  #@param {type:"string"}
training_data_bioimageio_id = "zero/dataset_u-net_2d_multilabel_deepbacs"  #@param {type:"string"}

#@markdown ### Main directory where the new model checkpoint is saved
Model_folder = "" #@param {type:"string"}
MODEL_TORCHSCRIPT_PATH = os.path.join(Model_folder, "finetuned_bioimageio_torchscript_model.pt")
output_path = os.path.join(Model_folder, "finetuned_bioimageio_model.zip")

#####

author_list = []
for author in model.authors:
  author_list.append(Author(name=author.name,
                            affiliation=author.affiliation,
                            github_user=author.github_user))

Trained_model_license = model.license
readme_path = model.documentation
citation_list = []
for citation in model.cite:
  citation_list.append(CiteEntry(text=citation.text,
                                 doi=citation.doi,
                                 url=citation.url))


## Define the new input

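# Load the input image, crop it and save it as a NumPy file
new_input_path = f"{Model_folder}/new_test_input.npy"
test_im = imread(Image_path)
# Crop a 256 x 256 patch (last 256 rows, first 256 columns); assumes a 2D single-channel image
test_im = test_im[-256:, :256]
# Add the batch and channel axes: (y, x) -> (1, 1, y, x)
test_im = np.expand_dims(test_im, axis=(0,1))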
np.save(new_input_path, np.float32(test_im))

# Define the input axes (take the ones from previous model)
input_axes = []
for axis in model.inputs[0].axes:
  if axis.id == "batch":
    input_axes.append(BatchAxis(id=axis.id,
                                 description=axis.description,
                                 type=axis.type,
                                 size=axis.size))
  elif axis.id == "channel":
    input_axes.append(ChannelAxis(id=axis.id,
                                 description=axis.description,
                                 type=axis.type,
                                 channel_names=[Identifier("raw")]))
  else: # x, y or z
    input_axes.append(SpaceInputAxis(id=axis.id,
                                      description=axis.description,
                                      type=axis.type,
                                      unit=axis.unit,
                                      scale=axis.scale,
                                      concatenable=axis.concatenable,
                                      size=axis.size))

# Define the data description (which is the data type)
data_descr = IntervalOrRatioDataDescr(type="float32")

# Define the preprocessing functions, we take the ones that were already defined
preprocessing_list = model.inputs[0].preprocessing

# Create the input tensor description
new_input_descr = InputTensorDescr(id=TensorId("raw"),
                               axes=input_axes,
                               test_tensor=FileDescr(source=new_input_path),
                               data=data_descr,
                               preprocessing=preprocessing_list
)

## Define the new output. At this step we don't have the real output yet, so we
## create an auxiliary placeholder output (built from the input). Once the model is
## created, we run a prediction with it and replace the placeholder with the real
## output.

# Create a temporary output
aux_new_output_path = f"{Model_folder}/aux_new_test_output.npy"
# The output is expected to have 2 channels, so concatenate two copies of the input along the channel axis
np.save(aux_new_output_path, np.float32(np.concatenate((test_im,test_im), axis=1)))

# Define the output axes (take the ones from previous model)
output_axes = []
for axis in model.outputs[0].axes:
  if axis.id == "batch":
    output_axes.append(BatchAxis(id=axis.id,
                                 description=axis.description,
                                 type=axis.type,
                                 size=axis.size))
  elif axis.id == "channel":
    output_axes.append(ChannelAxis(id=axis.id,
                                 description=axis.description,
                                 type=axis.type,
                                 channel_names=axis.channel_names))
  else: # x, y or z
    output_axes.append(SpaceOutputAxis(id=axis.id,
                                      description=axis.description,
                                      type=axis.type,
                                      unit=axis.unit,
                                      scale=axis.scale,
                                      size=SizeReference(tensor_id=TensorId("raw"),
                                                         axis_id=axis.size.axis_id,
                                                         offset=axis.size.offset)))

# Define the postprocessing functions, we take the ones that were already defined
# postprocessing_list = model.outputs[0].postprocessing
postprocessing_list = [] # This model does not require postprocessing

# Create the output tensor description
new_output_descr = OutputTensorDescr(id=TensorId("prob"),
                                    axes=output_axes,
                                    test_tensor=FileDescr(source=aux_new_output_path),
                                    postprocessing=postprocessing_list)

# Define the training data
training_data = LinkedDataset(id=training_data_bioimageio_id)

# Define the PyTorch architecture with the one you previously loaded
# model_source = get_architecture_source("affable-shark")
pytorch_architecture = ArchitectureFromFileDescr(
        source=download(model.weights.pytorch_state_dict.architecture.source,
                        sha256=model.weights.pytorch_state_dict.architecture.sha256).path,
        sha256=model.weights.pytorch_state_dict.architecture.sha256,
        callable=model.weights.pytorch_state_dict.architecture.callable,
        kwargs=model.weights.pytorch_state_dict.architecture.kwargs
    )

# Get PyTorch version
try:
    import torch
except ImportError:
    pytorch_version = Version("1.15")
else:
    pytorch_version = Version(torch.__version__)

# Define the weights using provided info
weights = WeightsDescr(
            torchscript=TorchscriptWeightsDescr(
                source=MODEL_TORCHSCRIPT_PATH,
                sha256=None,
                pytorch_version=pytorch_version
            )
          )

# We create the model, process the input image and create the model again with the correct output.
for i in range(2):
  # The test input and output data are passed as list because we support multiple inputs / outputs per model
  my_model_descr = ModelDescr(
      name = Trained_model_name,
      description = Trained_model_description,
      authors = author_list,
      license = Trained_model_license,
      documentation = readme_path,
      weights= weights,
      inputs = [new_input_descr],
      outputs =  [new_output_descr],
      tags=["in-silico-labeling","pytorch", "cyclegan", "conditional-gan",
            "zerocostdl4mic", "deepimagej", "actin", "dapi", "cells", "nuclei",
            "fluorescence-light-microscopy", "2d"],  # the tags are used to make models more findable on the website
      cite = citation_list,
      training_data = training_data,
      # add_deepimagej_config=True,
      )

  if i == 0:

    # Define the new input sample (taken from the new model description)
    new_input_paths = {ipt.id: download(ipt.test_tensor).path for ipt in my_model_descr.inputs}

    # The prediction pipeline expects a Sample object from bioimageio.core
    input_sample = create_sample_for_model(
        model=my_model_descr, inputs=new_input_paths, sample_id="my_demo_sample"
    )

    # Create the new prediction
    prediction = predict(model=my_model_descr, inputs=input_sample)

    # Save the new prediction to a NumPy file
    new_output_path = f"{Model_folder}/new_test_output.npy"
    prediction_tensor = prediction.members["prob"].data
    np.save(new_output_path, prediction_tensor)

    # Define the postprocessing functions, we take the ones that were already defined
    # postprocessing_list = model.outputs[0].postprocessing
    postprocessing_list = [] # This model does not require postprocessing

    # Create the output tensor description
    new_output_descr = OutputTensorDescr(id=TensorId("prob"),
                          axes=output_axes,
                          test_tensor=FileDescr(source=new_output_path),
                          postprocessing=postprocessing_list)

# Check that the model works and display the result of the test
summary = test_model(my_model_descr)
summary.display()


# Save the bioimage.io model package regardless of the test result, then report the result
save_bioimageio_package(my_model_descr, output_path=Path(output_path))
if summary.status == "passed":
  print("The bioimage.io model was successfully exported to", output_path)
else:
  print("The bioimage.io model was exported to", output_path)
  print("Some tests of the model did not pass.")
  print("You can still download and test the model, but it may not work as expected.")
