In [None]:
%%capture
# Install the zea package; %%capture hides pip's output in the rendered notebook.
# NOTE(review): consider pinning a version (e.g. `%pip install zea==X.Y.Z`) so re-runs are reproducible.
%pip install zea

In [1]:
import os

# Select the Keras backend. This must be set before `keras` is imported below,
# because Keras picks its backend at import time.
os.environ["KERAS_BACKEND"] = "jax"
# Intended to quiet TensorFlow's C++ logging; some CUDA/XLA registration
# messages can still be printed (as in this cell's output).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import matplotlib.pyplot as plt
from keras import ops
from PIL import Image
import numpy as np
import requests
from io import BytesIO

from zea import init_device
from zea.visualize import set_mpl_style
from zea.display import scan_convert_2d, inverse_scan_convert_2d
from zea.utils import translate

# Initialize the compute device via zea's helper (with its verbose printing disabled).
init_device(verbose=False)
# Apply zea's matplotlib style to the figures created below.
set_mpl_style()

[1m[38;5;36mzea[0m[0m: Using backend 'jax'


E0000 00:00:1758438658.962206  905083 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1758438658.966328  905083 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1758438658.978618  905083 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1758438658.978634  905083 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1758438658.978636  905083 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1758438658.978637  905083 computation_placer.cc:177] computation placer already registered. Please check linka

In [2]:
# NOTE: this is a synthetic PLAX view image generated by a diffusion model.
url = "https://raw.githubusercontent.com/tue-bmd/zea/main/docs/source/notebooks/assets/plax.png"
# Bound the request time and fail fast with a clear HTTP error, instead of
# hanging or letting PIL choke on an error page / empty body.
response = requests.get(url, timeout=30)
response.raise_for_status()
img = Image.open(BytesIO(response.content)).convert("RGBA")

# Composite onto an opaque black background so transparent pixels become black
# before the alpha channel is dropped.
black_bg = Image.new("RGBA", img.size, (0, 0, 0, 255))
img = Image.alpha_composite(black_bg, img)

# Convert to single-channel 8-bit grayscale ("L" mode)
img = img.convert("L")

# Convert to a float32 numpy array of pixel intensities (0-255)
img_np = np.asarray(img).astype(np.float32)
# Convert to the polar domain (inverse scan conversion)
img_polar_np = inverse_scan_convert_2d(img_np)

# Plot the Cartesian input next to its polar counterpart
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

# Plot input (Cartesian) image
ax1.imshow(img_np, cmap="gray")
ax1.set_title("Cartesian", fontsize=15)
ax1.axis("off")

# Plot the inverse-scan-converted (polar) image
ax2.imshow(img_polar_np, cmap="gray")
ax2.set_title("Polar", fontsize=15)
ax2.axis("off")

plt.tight_layout()
plt.savefig("cartesian_polar.png")
plt.close(fig)

[1m[38;5;36mzea[0m[0m: ❗️ It is recommended to use [34mnumpy[0m backend for `fit_scan_cone()`.


![Cartesian Polar input](./cartesian_polar.png)

### Define the downstream task function

In [3]:
from zea.models.echonetlvh import EchoNetLVH

# Load model from zeahub
model = EchoNetLVH.from_preset("echonetlvh")


def lvid_downstream_task(posterior_sample, rho_max=None, theta_range_deg=(-45, 45)):
    """Estimate the LVID length from a polar-domain posterior sample.

    The sample is scan-converted to Cartesian space and passed through the
    EchoNetLVH key-point model. The function returns the Euclidean distance
    between the two key points used for the LVID measurement.

    Args:
        posterior_sample: Polar-domain image batch accepted by ``scan_convert_2d``.
            NOTE(review): assumed to have the layout EchoNetLVH expects after
            scan conversion -- confirm.
        rho_max: Upper end of the radial range used for scan conversion.
            Defaults to the number of rows of the global ``img_polar_np``,
            read at call time (the previous hard-coded behavior).
        theta_range_deg: Angular span of the sector, in degrees.

    Returns:
        Scalar tensor: distance between the LVID key points, in the index
        (pixel) units produced by ``extract_key_points_as_indices``.
    """
    if rho_max is None:
        rho_max = ops.shape(img_polar_np)[0]
    posterior_sample_sc = scan_convert_2d(
        posterior_sample,
        rho_range=(0, rho_max),
        theta_range=np.deg2rad(theta_range_deg),
    )
    logits = model(posterior_sample_sc)
    # Only the first batch element is measured.
    key_points = model.extract_key_points_as_indices(logits)[0]
    # Key points 1 and 2 are taken as the bottom/top ends of the LVID.
    lvid_bottom_coords, lvid_top_coords = key_points[1], key_points[2]
    return ops.sqrt(ops.sum((lvid_top_coords - lvid_bottom_coords) ** 2))

### Simulate a sparse acquisition

In [5]:
from zea.agent.selection import EquispacedLines

img_shape = (256, 256)  # (height, width) of the working polar grid

# Resize the polar image to the working grid and add batch/channel axes.
batch = ops.image.resize(ops.convert_to_tensor(img_polar_np[None, ..., None]), img_shape)
# Map pixel intensities from [0, 255] to [-1, 1].
batch_normalized = translate(batch, range_from=(0, 255), range_to=(-1, 1))

# Select n_actions = 256 // 1 // 32 = 8 of the 256 possible scan lines.
line_thickness = 1
factor = 32
agent = EquispacedLines(
    n_actions=img_shape[1] // line_thickness // factor,
    n_possible_actions=img_shape[1] // line_thickness,
    img_width=img_shape[1],
    img_height=img_shape[0],
)

_, mask = agent.sample()
mask = ops.expand_dims(mask, axis=-1)  # add a trailing channel axis

# Keep only the sampled lines. The normalized image is used so the measurements
# share the [-1, 1] range that `batch_normalized` was translated to, which is
# presumably what the diffusion prior expects (confirm against the preset).
# Unobserved pixels get -1.0, i.e. black, the normalized equivalent of the
# pixel-space fill value 0.
measurements = ops.where(mask, batch_normalized, -1.0)

In [12]:
# Save the first batch item's single channel as a grayscale image for the markdown below.
frame_to_save = measurements[0, ..., 0]
plt.imsave("measurements.png", frame_to_save, cmap="gray")
plt.close()

![Measurements](./measurements.png)

Put the sparse measurements into the last frame of a 3-frame measurement buffer (the two preceding frames are zero-filled), since we use a 3-frame diffusion model for perception.

In [7]:
# Two all-zero frames followed by the current sparse measurement as the last frame.
empty_frames = ops.zeros((1, *img_shape, 2))
measurement_buffer = ops.concatenate([empty_frames, measurements], axis=-1)

In [8]:
from zea.models.diffusion import DiffusionModel

# 3-frame diffusion prior, matching the 3-frame measurement buffer.
diffusion_preset = "diffusion-echonetlvh-3-frame"
diffusion_model = DiffusionModel.from_preset(diffusion_preset)

config.json:   0%|          | 0.00/858 [00:00<?, ?B/s]

model.weights.h5:   0%|          | 0.00/31.7M [00:00<?, ?B/s]

### Perception step

In [None]:
# Sanity check: the mask must be (batch, height, width, 1) on the working grid.
assert tuple(mask.shape) == (1, *img_shape, 1), mask.shape
mask.shape

(1, 256, 256, 1)

In [9]:
# Posterior sampling settings. 500 steps is slow: the saved run of this cell
# was interrupted. NOTE(review): `omega` presumably weights the
# measurement-consistency term -- confirm in DiffusionModel.posterior_sample.
n_posterior_samples = 1
n_diffusion_steps = 500
guidance_weight = 10

posterior_samples = diffusion_model.posterior_sample(
    measurements=measurement_buffer,
    mask=mask,
    n_samples=n_posterior_samples,
    n_steps=n_diffusion_steps,
    omega=guidance_weight,
)

KeyboardInterrupt: 