In [1]:
import os
from glob import glob
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import torch
from torchvision.io import decode_jpeg, read_file
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [45]:
def generate_3d_grid(shape=(32, 32, 32), device='cpu', dtype=None):
    """Build a normalized 3D coordinate grid spanning [-1, 1] along every axis.

    Args:
        shape: (D, H, W) number of samples along the z, y and x axes.
        device: torch device the tensors are created on.
        dtype: optional floating dtype for the coordinates; ``None`` keeps
            torch's default dtype (the behavior before this parameter existed).

    Returns:
        A tuple ``(grid, xx)``:
          * ``grid`` -- tensor of shape (3, D, H, W); channel 0 holds the z
            coordinate, channel 1 the y coordinate, channel 2 the x coordinate.
          * ``xx`` -- the x-coordinate volume alone, shape (D, H, W); it has
            the same values as ``grid[2]``.
    """
    D, H, W = shape
    z = torch.linspace(-1, 1, steps=D, device=device, dtype=dtype)
    y = torch.linspace(-1, 1, steps=H, device=device, dtype=dtype)
    x = torch.linspace(-1, 1, steps=W, device=device, dtype=dtype)
    # 'ij' indexing keeps the axis order (D, H, W) matching (z, y, x).
    zz, yy, xx = torch.meshgrid(z, y, x, indexing='ij')  # each: (D, H, W)
    grid = torch.stack([zz, yy, xx], dim=0)  # (3, D, H, W)
    return grid, xx

In [76]:
# volume:  (3, 32, 32, 32) stacked (z, y, x) coordinates in [-1, 1]
# svolume: the x-coordinate channel on its own, (32, 32, 32)
volume, svolume = generate_3d_grid(shape=(32, 32, 32))

In [79]:
# Add a leading batch axis: (D, H, W) -> (1, D, H, W).
# Guarded so re-running this cell does not keep stacking extra leading axes
# (the unguarded `svolume = svolume.unsqueeze(0)` grows by one dim per run).
if svolume.dim() == 3:
    svolume = svolume.unsqueeze(dim=0)

In [112]:
# Top-k keypoints instead of dynamic torch.where:
# collapse every axis after the first so topk can run along dim 1.
flat_nms = volume.reshape(volume.size(0), -1)  # (B, D*H*W)

In [113]:
TOP_K = 10  # number of highest-scoring voxels kept per row of flat_nms

conf_vals, indices = torch.topk(flat_nms, k=TOP_K, dim=1)  # (B, K)

# Convert flat indices back to 3D (z, y, x) voxel coordinates.
# unravel_index returns a tuple of three (B, K) tensors; stacking them on the
# last axis yields (B, K, 3) directly, so no permute is needed.
zyx = torch.stack(torch.unravel_index(indices, volume.shape[-3:]), dim=-1)  # (B, K, 3)

In [120]:
# Total number of top-k scores across all rows (B * K).
conf_vals.flatten().shape

torch.Size([30])

In [110]:
zyx.size()  # expected (B, K, 3)

torch.Size([1, 10, 3])

In [137]:
# Batch index for every keypoint: row b is filled with b, shape (B, K).
# K is read from zyx so it stays in sync with the top-k cell (it was a
# hard-coded 10), and the index is created on zyx's device so the later
# torch.cat with zyx-derived tensors also works off-CPU.
batch_idx = (
    torch.arange(zyx.shape[0], device=zyx.device)
    .unsqueeze(1)
    .expand(-1, zyx.shape[1])
)

In [136]:
zyx.size(0)  # number of rows (B)

3

In [139]:
# Inspect the per-keypoint batch index: each row should repeat its own row number.
batch_idx

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]])

In [88]:
# One row per keypoint so the two tensors can be concatenated column-wise:
# zyx -> (N, 3) and batch_idx -> (N, 1).
zyx_flat = zyx.reshape(-1, 3)
batch_idx = batch_idx.reshape(-1).unsqueeze(-1)

In [90]:
batch_idx.size()  # expected (N, 1) after flattening

torch.Size([10, 1])

In [93]:
# Prepend the batch index as column 0, giving rows of (b, z, y, x).
bzyx = torch.cat((batch_idx, zyx_flat), dim=1)

In [108]:
# Scratch tensor for checking squeeze behavior.
zeros = torch.zeros(2, 2, 5, 5, 5)

In [109]:
# squeeze(1) only drops axis 1 when its size is 1; here it is 2,
# so the shape comes back unchanged.
zeros.squeeze(1).size()

torch.Size([2, 2, 5, 5, 5])