In [None]:
!pip -q install "copick[all]" copick-utils zarr matplotlib torch

In [None]:
import copick

# Pick ONE data source for `root`.
#   Option A: a locally-synced competition mirror -> set COPICK_CONFIG to the
#             path of its copick config JSON.
#   Option B: the CZ CryoET Data Portal (dataset id 10440) -> leave
#             COPICK_CONFIG as None. This mirrors what the overview uses when
#             exploring portal-hosted data.
COPICK_CONFIG = None  # e.g. "/path/to/copick_config.json"

if COPICK_CONFIG is not None:
    # Option A: open a locally-synced competition mirror
    root = copick.from_file(COPICK_CONFIG)
else:
    # Option B: open directly from the CZ CryoET Data Portal; user-written
    # (overlay) data goes to the local overlay directory.
    root = copick.from_czcdp_datasets([10440], overlay_root="/tmp/overlay")

In [None]:
# Project overview: list the pickable object classes with their label ids,
# then the names of the first five runs.
print("Pickable objects (name → label id):")
for pickable in root.pickable_objects:
    obj_name, obj_label = pickable.name, pickable.label
    # Right-align the name in a 22-character column so the arrows line up.
    print(f" {obj_name:>22s} → {obj_label}")

print("\nFirst few runs in this project:")
leading_runs = root.runs[:5]
for run_entry in leading_runs:
    print(" ", run_entry.name)

In [None]:
import numpy as np, zarr

# Load one tomogram of the first run into memory as a NumPy array.
VOXEL_SPACING = 10.0    # Å (≈ competition scale)
# NOTE(review): a portal-hosted project may register a slightly different
# spacing value than a local mirror; the error below lists what is available.
TOMO_TYPE = "denoised"  # commonly "wbp" in raw form

run = root.get_run(root.runs[0].name)  # e.g., "TS_5_4"

vs = run.get_voxel_spacing(VOXEL_SPACING)
if vs is None:
    # Fail with an actionable message instead of an AttributeError on None.
    available_spacings = [v.voxel_size for v in run.voxel_spacings]
    raise ValueError(
        f"Run {run.name!r} has no voxel spacing {VOXEL_SPACING}; "
        f"available: {available_spacings}"
    )

tomo = vs.get_tomogram(TOMO_TYPE)
if tomo is None:
    available_types = [t.tomo_type for t in vs.tomograms]
    raise ValueError(
        f"Run {run.name!r} has no {TOMO_TYPE!r} tomogram at spacing "
        f"{VOXEL_SPACING}; available: {available_types}"
    )

# Zarr stores multiple scales; "0" is unbinned, "1" is bin-2, etc.
tomo_vol = np.array(zarr.open(tomo.zarr())["0"])  # shape: (Z, Y, X)
print("Tomogram volume:", tomo_vol.shape)

In [None]:
# Load a dense segmentation (semantic labels) as a NumPy array.
# A run can hold several segmentations (different users/algorithms); this
# simply takes the first one returned — pick another index to inspect others.
seg = run.get_segmentations()[0]

# Scale "0" of the zarr group is materialised in memory.
# Expected to share the tomogram's (Z, Y, X) layout — TODO confirm for your data.
seg_store = seg.zarr()
seg_vol = np.array(zarr.open(seg_store)["0"])
print("Segmentation volume:", seg_vol.shape, seg_vol.dtype)

In [None]:
import matplotlib.pyplot as plt

# Show the central z-slice of the tomogram next to the matching slice of the
# segmentation mask.
z = tomo_vol.shape[0] // 2  # middle slice
fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

# (axes, image, title, imshow options) for each panel, left to right.
panels = (
    (ax0, tomo_vol[z], "Tomogram Slice", {"cmap": "gray"}),
    # Nearest-neighbour keeps label values from being blended at edges.
    (ax1, seg_vol[z], "Segmentation Mask", {"interpolation": "nearest"}),
)
for panel_ax, panel_img, panel_title, imshow_opts in panels:
    panel_ax.imshow(panel_img, **imshow_opts)
    panel_ax.set_title(panel_title)
    panel_ax.axis("off")

plt.tight_layout()

In [None]:
# Build a simple 2D sampling field by summing across z
# (For 3D patching you keep it volumetric; here we visualize the footprint.)
density_2d = (seg_vol > 0).sum(axis=0).astype(np.float32)
density_2d = density_2d / (density_2d.max() + 1e-6)  # normalize to [0,1]

# Sample patch centers with probability proportional to the density
rng = np.random.default_rng(0)
H, W = density_2d.shape
ps = 96  # patch size (example)
n = 30   # how many rectangles to draw

# `rng.choice` needs a 1-D probability vector that sums to 1. Normalise in
# float64 so the sum passes NumPy's tolerance check, and fall back to uniform
# sampling when the segmentation contains no foreground at all.
flat_probs = density_2d.ravel().astype(np.float64)
density_total = flat_probs.sum()
if density_total > 0:
    flat_probs /= density_total
else:
    flat_probs = np.full(H * W, 1.0 / (H * W))

flat_indices = rng.choice(H * W, size=n, p=flat_probs)
ys, xs = np.unravel_index(flat_indices, (H, W))

# Visualize the sampling plan
fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(density_2d, cmap="jet")
for y, x in zip(ys, xs):
    # Top-left corner of a ps x ps patch centred on (y, x), shifted so the
    # patch stays inside the image on every side (0 if the image is smaller
    # than one patch).
    r0 = max(0, min(y - ps // 2, H - ps))
    c0 = max(0, min(x - ps // 2, W - ps))
    # Rectangle takes (x, y) = (column, row) for its anchor.
    rect = plt.Rectangle((c0, r0), ps, ps, linewidth=1.5,
                         edgecolor="white", facecolor="none", alpha=0.7)
    ax.add_patch(rect)
ax.set_title("Patches")
ax.axis("off")
plt.tight_layout()

In [None]:
import lightning.pytorch as pl
from monai.data import decollate_batch
from monai.losses import TverskyLoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.transforms import AsDiscrete

class Model(pl.LightningModule):
    """Lightning wrapper around a MONAI U-Net for multi-class segmentation.

    Training minimises a Tversky loss; validation accumulates a Dice score
    (background excluded) that is logged as ``val_dice`` once per epoch.

    Batches are dicts with an ``'image'`` and a ``'label'`` entry.
    NOTE(review): the loss (``to_onehot_y=True``) and the one-hot transforms
    presumably expect labels as integer class ids with a single channel
    dimension, i.e. (B, 1, ...) — confirm against the data pipeline.

    This class does not define ``configure_optimizers``; supply one (e.g. in
    a subclass) before fitting.

    Args:
        spatial_dims: Number of spatial dimensions of the U-Net.
        in_channels: Number of input image channels.
        out_channels: Number of output classes (including background).
    """

    def __init__(self, spatial_dims=3, in_channels=1, out_channels=7):
        super().__init__()
        # Kept so validation uses the same class count the network outputs.
        self.num_classes = out_channels

        # U-Net with residual units
        self.model = UNet(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            channels=(48, 64, 80, 80),
            strides=(2, 2, 1),
            num_res_units=1
        )
        # Tversky loss over all classes including background; alpha/beta are
        # not passed, so the library defaults apply.
        self.loss_fn = TverskyLoss(
            include_background=True,
            to_onehot_y=True,
            softmax=True
        )
        # Dice score metric for validation
        self.metric_fn = DiceMetric(
            include_background=False,
            reduction="mean",
            ignore_empty=True
        )
        # Post-processing transforms, built once instead of once per sample.
        self.post_pred = AsDiscrete(argmax=True, to_onehot=self.num_classes)
        self.post_label = AsDiscrete(to_onehot=self.num_classes)

    def forward(self, x):
        """Run the U-Net on ``x`` and return its raw (pre-softmax) output."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Return the Tversky loss for one training batch."""
        x, y = batch['image'], batch['label']  # patch & label mask
        y_pred = self(x)                       # forward pass
        loss = self.loss_fn(y_pred, y)         # compute Tversky loss
        return loss

    def validation_step(self, batch, batch_idx):
        """Accumulate the Dice metric for one validation batch."""
        x, y = batch['image'], batch['label']
        y_pred = self(x)
        # Convert predictions & labels to one-hot, sample by sample, as the
        # Dice metric is fed one-hot tensors here.
        pred_onehot = [self.post_pred(p) for p in decollate_batch(y_pred)]
        label_onehot = [self.post_label(t) for t in decollate_batch(y)]
        self.metric_fn(y_pred=pred_onehot, y=label_onehot)  # accumulate Dice

    def on_validation_epoch_end(self):
        """Log the epoch's mean Dice and clear the metric's buffer.

        Without the reset, scores from earlier epochs would keep accumulating
        in the metric and leak into later results.
        """
        self.log("val_dice", self.metric_fn.aggregate().item())
        self.metric_fn.reset()

In [None]:
# Pseudocode for model inference on a tomogram.
# Assumes these names already exist in the session: torch, cc3d, np, model,
# device, experiment_id, load_volume, extract_3d_patches, voxel_to_world and
# BLOB_MIN_SIZE.
volume = load_volume(experiment_id)  # load 3D numpy array for the tomogram
patches, coords = extract_3d_patches(
    volume, patch_size=96
)  # tile the volume

mask_preds = np.zeros_like(volume, dtype=np.int8)  # initialize empty mask

model.eval()            # inference behaviour for dropout / normalisation layers
with torch.no_grad():   # no autograd bookkeeping needed for prediction
    for patch, coord in zip(patches, coords):
        patch_tensor = torch.tensor(
            patch[None, None, ...]
        ).to(device)  # shape (1,1,96,96,96)

        output = model(patch_tensor)  # output shape (1,7,96,96,96)
        # per-class probability map for the patch
        probs = torch.softmax(output[0], dim=0)

        # Convert to hard labels: keep the most likely class only where its
        # probability exceeds 0.5, otherwise fall back to 0 (background).
        # Softmax probabilities sum to 1, so at most one class can pass the
        # threshold and it is necessarily the argmax.
        confidence, pred_class = probs.max(dim=0)
        pred_class[confidence <= 0.5] = 0

        # place patch prediction into full volume mask
        # presumably `coord` is an index (e.g. tuple of slices) selecting the
        # patch's region in the volume — verify against extract_3d_patches
        mask_preds[coord] = pred_class.cpu().numpy()

# Now perform connected components on each class in the full mask:
detected_particles = []

for class_label in [1, 2, 3, 4, 5, 6]:  # skip 0 (background)
    binary_mask = (mask_preds == class_label)
    cc = cc3d.connected_components(binary_mask)  # label connected regions
    stats = cc3d.statistics(cc)

    # The statistics arrays are indexed by component label, with index 0
    # being the background, so real components are 1 .. len - 1.
    for label_val in range(1, len(stats['voxel_counts'])):
        # centroid in voxel indices, in the axis order of `mask_preds`
        # (z, y, x if the volume is stored as (Z, Y, X))
        centroid_vox = stats['centroids'][label_val]
        size = stats['voxel_counts'][label_val]

        if size < BLOB_MIN_SIZE:
            continue  # skip tiny noise

        # Convert voxel coordinates to physical (x,y,z) and
        # record detection
        centroid_xyz = voxel_to_world(
            centroid_vox)  # e.g. multiply by voxel size
        detected_particles.append(
            (experiment_id, class_label, *centroid_xyz)
        )