In [None]:
import zmq
import numpy as np
import matplotlib.pyplot as plt

Server setup and image retrieval

In [None]:
# Endpoint the ZeroMQ socket is bound to: TCP on the given port, "*" as the host part.
DOMAIN = "*"
PORT = "7878"
SOCKET_ADDR = "tcp://{}:{}".format(DOMAIN, PORT)

In [None]:
# Plain-text control messages exchanged with the peer.
# HELLO is the first message this notebook sends after binding.
HELLO = "Hello"
# ACK is both the reply expected from the peer and the confirmation this notebook sends.
ACK = "Acknowledge"
# DENIED is defined for completeness; it is not used anywhere in the visible notebook.
DENIED = "Denied"

In [None]:
# Create a ZeroMQ PAIR socket and bind it to SOCKET_ADDR (this notebook is the binding side).
context = zmq.Context()
socket = context.socket(zmq.PAIR)
# Per the pyzmq docs, messages smaller than `copy_threshold` are copied even when
# copy=False is requested; 0 switches that shortcut off for the copy=False calls below.
socket.copy_threshold = 0
b = socket.bind(SOCKET_ADDR)
# Show whatever bind() returned as the cell output.
b

In [None]:
# Handshake: greet the peer, then block until it answers.
socket.send_string(HELLO)
print("Sent hello, waiting for acknowledgement...")
ack = socket.recv_string()
if ack != ACK:
    # Abort instead of merely printing: the following cells block on recv_json() and
    # would hang (or misread data) if the peer did not accept the connection.
    raise RuntimeError(f"Handshake failed: received unknown message {ack!r}, expected {ACK!r}")
print('Received connection ack:', ack)

In [None]:
# Block until the peer sends the image header as JSON. Later cells read its
# 'descr' (dtype description) and 'shape' entries to decode the raw payload.
header = socket.recv_json()
header

In [None]:
# acknowledge receipt of header, ask for image data
# (the next cell then blocks in recv() until the payload arrives)
socket.send_string(ACK)

In [None]:
# Receive the raw image payload. With copy=False pyzmq returns a zmq.Frame that
# wraps the message buffer instead of a bytes copy (see the pyzmq Socket.recv docs).
im_bytes = socket.recv(copy=False)
im_bytes

In [None]:
# Reinterpret the received frame as a flat array of the dtype announced in the header.
buf = memoryview(im_bytes)
im = np.frombuffer(buf, dtype=header['descr'])

# Restore the array shape with reshape() instead of assigning to `.shape` (which NumPy
# discourages). If the header carries the .npy-style 'fortran_order' flag, honour it;
# headers without that key are read as C-ordered, exactly as before.
order = 'F' if header.get('fortran_order', False) else 'C'
im = im.reshape(tuple(header['shape']), order=order)

# Convert to 8 bit. Presumably the peer sends a float image scaled to [0, 1] — TODO confirm.
# Float values are clipped first so that anything outside that range saturates instead of
# wrapping around (or being undefined) in the uint8 cast.
scaled = im * 255
if np.issubdtype(scaled.dtype, np.floating):
    scaled = np.clip(scaled, 0, 255)
im = scaled.astype(np.uint8)
im

In [None]:
# acknowledge receipt of image data
# (sent only after the payload has been decoded successfully in the previous cell)
socket.send_string(ACK)

In [None]:
# Quick visual check of the received image in grayscale
# (it is 8-bit at this point when the cells are run in order).
plt.imshow(im, cmap='gray')

BioImage.IO: download and inspect the model

In [None]:
from pprint import pprint
from typing_extensions import assert_never


from bioimageio.spec.pretty_validation_errors import enable_pretty_validation_errors_in_ipynb
from bioimageio.spec import InvalidDescr, load_description
from bioimageio.spec.model.v0_5 import ModelDescr

from bioimageio.spec.model.v0_5 import ArchitectureFromLibraryDescr, ArchitectureFromFileDescr

In [None]:
# Switch on bioimageio.spec's pretty rendering of validation errors inside notebooks.
enable_pretty_validation_errors_in_ipynb()

In [None]:
# https://bioimage.io/#/?tags=affable-shark&id=10.5281%2Fzenodo.5764892
# Identifier of the model to load; this is the value the next cell hands to load_description().
MODEL_ID = "affable-shark"
# NOTE(review): MODEL_DOI is not referenced anywhere in the visible notebook, and its Zenodo
# record number differs from the one in the URL above — confirm which one is intended.
MODEL_DOI = "10.5281/zenodo.11092561"

In [None]:
# `source` selects what gets loaded; it is reused in the error message of the validity check below.
source = MODEL_ID

# The returned object carries a validation summary (displayed below) and may be an
# InvalidDescr instead of a model description, which a later cell checks for.
loaded_description = load_description(source)

In [None]:
# Display the loaded description object as the cell output.
loaded_description

In [None]:
# Render the validation summary attached to the loaded description.
loaded_description.validation_summary.display()

In [None]:
# Guard: every cell below relies on a successfully loaded description in model format 0.5.
if isinstance(loaded_description, InvalidDescr):
    raise ValueError(f"Failed to load {source}")
if not isinstance(loaded_description, ModelDescr):
    raise ValueError("This notebook expects a model 0.5 description")

# From here on the description is known to be a ModelDescr; it must also carry an id.
model = loaded_description
example_model_id = model.id
assert example_model_id is not None

In [None]:
# General information about the model.
print(f"The model is named '{model.name}'")
print(f"Description:\n{model.description}")
print(f"License: {model.license}")

# People and references: each heading is followed by a pretty-printed list.
for heading, entries in (
    ("\nThe authors of the model are:", model.authors),
    ("\nIn addition to the authors it is maintained by:", model.maintainers),
    ("\nIf you use this model, you are expected to cite:", model.cite),
):
    print(heading)
    pprint(entries)

print(f"\nFurther documentation can be found here: {model.documentation}")

# The repository link is optional in the description.
if model.git_repo is not None:
    print(f"\nThere is an associated GitHub repository: {model.git_repo}.")
else:
    print("\nThere is no associated GitHub repository.")

In [None]:
# Walk over every weights-format slot of the model and report the ones that are filled.
weights = model.weights
candidate_entries = (
    weights.onnx,
    weights.keras_hdf5,
    weights.tensorflow_js,
    weights.tensorflow_saved_model_bundle,
    weights.torchscript,
    weights.pytorch_state_dict,
)
for w in candidate_entries:
    if w is not None:
        print(w.weights_format_name)
        print(f"weights are available at {w.source.absolute()}")
        print(f"and have a SHA-256 value of {w.sha256}")
        # Everything besides location and checksum is reported as format-specific metadata.
        details = w.model_dump(mode="json", exclude_none=True)
        for reported_key in ("source", "sha256"):
            details.pop(reported_key, None)
        if details:
            print(f"additonal metadata for {w.weights_format_name}:")
            pprint(details)

        print()

In [None]:
# Summarise the model's input tensors: axes, data description, test tensor location
# and the preprocessing steps.
print(f"Model '{model.name}' requires {len(model.inputs)} input(s) with the following features:")
for ipt in model.inputs:
    print(f"\ninput '{ipt.id}' with axes:")
    pprint(ipt.axes)
    print(f"Data description: {ipt.data}")
    print(f"Test tensor available at:  {ipt.test_tensor.source.absolute()}")
    # NOTE(review): `> 1` hides a list with exactly one step — presumably to skip a step
    # that the spec library adds implicitly; confirm, otherwise this should be `> 0`.
    if len(ipt.preprocessing) > 1:
        print("This input is preprocessed with: ")
        for p in ipt.preprocessing:
            print(p)

print("\n-------------------------------------------------------------------------------")
# ... and the same summary for the output tensors. A model produces outputs rather than
# requiring them, so the heading differs from the (previously copy-pasted) input heading.
print(f"Model '{model.name}' produces {len(model.outputs)} output(s) with the following features:")
for out in model.outputs:
    print(f"\noutput '{out.id}' with axes:")
    pprint(out.axes)
    print(f"Data description: {out.data}")
    print(f"Test tensor available at:  {out.test_tensor.source.absolute()}")
    # NOTE(review): same `> 1` threshold as for the inputs — confirm.
    if len(out.postprocessing) > 1:
        print("This output is postprocessed with: ")
        for p in out.postprocessing:
            print(p)

In [None]:
# For PyTorch state-dict weights, report where the network architecture is imported from.
assert isinstance(model, ModelDescr)
w = model.weights.pytorch_state_dict
if w is not None:
    arch = w.architecture
    print(f"callable: {arch.callable}")
    if isinstance(arch, ArchitectureFromFileDescr):
        # Architecture shipped as a source file, optionally with a checksum.
        print(f"import from file: {arch.source.absolute()}")
        file_hash = arch.sha256
        if file_hash is not None:
            print(f"SHA-256: {file_hash}")
    elif isinstance(arch, ArchitectureFromLibraryDescr):
        # Architecture provided by an importable module.
        print(f"import from module: {arch.import_from}")
    else:
        # Exhaustiveness check: no other architecture description type is expected.
        assert_never(arch)

BioImage.IO: run prediction

In [None]:
import pathlib
import bioimageio.core.io as bio
import bioimageio.core.prediction as bi_pred

In [None]:
# Normalise integer pixel data to float32 in [0, 1] by dividing by the dtype's maximum.
# The cell overwrites `im`, so it is guarded to be a no-op on re-run: np.iinfo() raises
# ValueError for the float32 dtype that a first run leaves behind.
if np.issubdtype(im.dtype, np.integer):
    im = (im / np.iinfo(im.dtype).max).astype(np.float32)

In [None]:
# Zero-pad the bottom and right edges so that both image dimensions become multiples of 64
# (no padding is added along a dimension that already is one).
rows, cols = im.shape[0], im.shape[1]
pad_y = -rows % 64
pad_x = -cols % 64
padded_image = np.pad(
    im,
    pad_width=((0, pad_y), (0, pad_x)),
    mode='constant',
    constant_values=0,
)

In [None]:
# Prepend two singleton axes so the 2-D image becomes a 4-D (1, 1, height, width) tensor.
input_image = padded_image[np.newaxis, np.newaxis, ...]
del padded_image

In [None]:
# Sanity check: show shape and dtype of the tensor that is handed to the model.
input_image.shape, input_image.dtype

In [None]:
# Run the model on the prepared tensor. Pre- and postprocessing from the model description
# are both skipped, so the tensor is fed in as prepared above and the raw output is returned.
out = bi_pred.predict(
    model=model,
    inputs={'input0': input_image},
    skip_preprocessing=True,
    skip_postprocessing=True,
)
# Drop this reference to the input tensor once the prediction is done.
del input_image

In [None]:
# Take the 'output0' member of the prediction result and convert element 0 of its data
# to a plain NumPy array (presumably index 0 selects the single batch entry — TODO confirm).
res = np.array(out.members['output0'].data[0])

In [None]:
# Crop the last two axes back to the size of `im`, dropping the rows/columns that
# were zero-padded at the bottom/right before prediction.
res = res[:, :im.shape[0], :im.shape[1]]

In [None]:
# Show the first output channel.
# NOTE(review): presumably the foreground prediction — confirm against the model documentation.
plt.imshow(res[0,:,:])

In [None]:
# Show the second output channel.
# NOTE(review): presumably the boundary prediction — confirm against the model documentation.
plt.imshow(res[1,:,:])

In [None]:
# Colour composite of the two output channels: red stays empty, channel 0 is shown
# in green and channel 1 in blue. Stacking along the last axis yields (height, width, 3).
empty_channel = np.zeros_like(res[0])
composite = np.stack([empty_channel, res[0], res[1]], axis=-1)
plt.imshow(composite)

In [None]:
from skimage.filters import threshold_otsu
from skimage.segmentation import clear_border
from skimage.measure import label, regionprops
from skimage.morphology import closing, square
from skimage.color import label2rgb

In [None]:
# Segment the first output channel: Otsu threshold, morphological closing with a
# 3x3 square footprint, then connected-component labelling.
res_im = res[0]
thresh = threshold_otsu(res_im)
foreground = res_im > thresh
bw = closing(foreground, square(3))

# Objects touching the image border are kept: clear_border() is not applied before labelling.
label_image = label(bw)

In [None]:
# Binary mask after Otsu thresholding and morphological closing.
plt.imshow(bw)

In [None]:
# Connected-component labels produced by label() from the binary mask.
plt.imshow(label_image)

In [None]:
# Overlay the labels in colour on the input image `im`; label 0 is treated as background.
plt.imshow(label2rgb(label_image, image=im, bg_label=0))

Send back the results

In [None]:
# Build an .npy-format (version 1.0) header dict describing the label image; NumPy documents
# its keys as 'descr', 'fortran_order' and 'shape'. It is sent to the peer ahead of the raw data.
return_header = np.lib.format.header_data_from_array_1_0(label_image)
return_header

In [None]:
# Send the header first, as JSON; the raw label data follows once the peer has acknowledged it.
socket.send_json(return_header)

In [None]:
# Wait for the peer to acknowledge the return header before the raw label data is sent.
ack = socket.recv_string()
if ack != ACK:
    # Abort instead of merely printing: continuing would push the image bytes to a
    # peer that did not accept the header.
    raise RuntimeError(f"Received unknown message {ack!r}, expected {ACK!r}")
print('Received return header ack:', ack)

In [None]:
# Send the label image's underlying buffer without an intermediate copy (copy=False).
# Presumably the peer decodes it with the header sent earlier — TODO confirm on the client side.
socket.send(label_image, copy=False)

In [None]:
# socket.send_string("Cancel")