In [118]:
import torch
import einops
import numpy as np


def retrieve_query_features(global_tokens_vol, query_pos, neighborhood_size=3):
    """
    Look up one feature vector per query position in a dense token volume.

    Every query position is an integer ``(t, h, w)`` index into the token
    grid. The feature vector stored at that grid cell is returned as-is: this
    is a plain lookup, with no interpolation and no pooling over neighbours.

    Args:
        global_tokens_vol: tensor of shape ``(B, C, T, H, W)``.
        query_pos: integer tensor of shape ``(B_q, N, 3)`` whose last axis
            holds ``(t, h, w)`` grid indices. The batch count is taken from
            this tensor; if ``B_q < B`` only the first ``B_q`` batch entries
            of the volume are read.
        neighborhood_size: odd integer. Currently only validated; the lookup
            does not use it yet (see TODO below).

    Returns:
        Tensor of shape ``(B_q, N, C)`` with the dtype and device of
        ``global_tokens_vol``.

    Raises:
        AssertionError: if ``neighborhood_size`` is even.
        ValueError: if either input does not have the rank listed above.
        IndexError: if ``query_pos`` is not an integer tensor, or (on CPU) if
            an index falls outside the volume.
    """

    assert neighborhood_size % 2 == 1, "neighborhood size must be odd"

    # TODO: make this work with neighborhoods
    # Unpacking doubles as a rank check: 5-D volume, 3-D position tensor.
    _, _, _, _, _ = global_tokens_vol.shape
    num_batches, _, _ = query_pos.shape

    if query_pos.dtype.is_floating_point or query_pos.dtype.is_complex:
        raise IndexError(
            f"query_pos must hold integer grid indices, got dtype {query_pos.dtype}"
        )
    grid_idx = query_pos.to(device=global_tokens_vol.device, dtype=torch.long)

    # (B_q, 1) so it broadcasts against the (B_q, N) coordinate tensors.
    batch_idx = torch.arange(
        num_batches, device=global_tokens_vol.device
    ).unsqueeze(1)

    # Channels last, so that indexing the four leading axes leaves C behind.
    channels_last = global_tokens_vol.permute(0, 2, 3, 4, 1)

    # One gather replaces the per-(b, n) Python loop and keeps the volume's
    # dtype instead of forcing the result into a default-float32 buffer.
    return channels_last[
        batch_idx, grid_idx[..., 0], grid_idx[..., 1], grid_idx[..., 2]
    ]

In [5]:
# Element count of a 2 x 96 x 24 x 24 x 24 array.
# NOTE(review): presumably (batch, channels, T, H, W) of the input volume — confirm.
original_size = 2 * 96 * 24 * 24 * 24

In [6]:
# Display the element count, 2 * 96 * 24**3.
original_size

2654208

## Checklist

- Tokenization: DONE
- 3D Encoding: DONE
- Feature retrieval: DONE

### Verify that tokenization preserves spatial information

In [94]:
# Toy input: a batch of 4 volumes with 3 feature channels on a 24^3 grid.
bs, f = 4, 3
t = h = w = 24
orig = torch.ones((bs, f, t, h, w))

In [95]:
# Stamp every voxel with its own (t, h, w) index in channels 0, 1 and 2, so any
# value read back later reveals where it came from. The extents come from
# t, h, w rather than a repeated literal 24, and broadcasting replaces the
# 24**3-iteration Python loop.
orig[:, 0] = torch.arange(t, dtype=orig.dtype).view(t, 1, 1)
orig[:, 1] = torch.arange(h, dtype=orig.dtype).view(1, h, 1)
orig[:, 2] = torch.arange(w, dtype=orig.dtype).view(1, 1, w)

In [96]:
# Spot-check one voxel: at (t=0, h=2, w=0) the three channels should read
# back that voxel's own coordinates, i.e. [0, 2, 0].
orig[0, :, 0, 2, 0]

tensor([0., 2., 0.])

In [110]:
# Cut the volume into non-overlapping patches of patch_size voxels. Each patch
# becomes one token whose feature vector concatenates the channels of all its
# voxels, grouped as (dt dh dw c).
patch_size = (2, 2, 2)
dt, dh, dw = patch_size
patches_vol = einops.rearrange(
    orig,
    "b c (t dt) (h dh) (w dw) -> b t h w (dt dh dw c)",
    dt=dt,
    dh=dh,
    dw=dw,
)

# OK

In [111]:
# Token (4, 1, 4) should cover voxels t in {8, 9}, h in {2, 3}, w in {8, 9}.
# Its 24 features are the (t, h, w) triples of those 8 voxels, ordered with
# dt slowest and the channel axis fastest, per the "(dt dh dw c)" grouping.
patches_vol[0, 4, 1, 4, :]

# OK

tensor([8., 2., 8., 8., 2., 9., 8., 3., 8., 8., 3., 9., 9., 2., 8., 9., 2., 9.,
        9., 3., 8., 9., 3., 9.])

In [112]:
# Flatten the token grid into a sequence in two steps: merge (h, w) first,
# then merge t with the result. Token (t, h, w) therefore lands at flat index
# t * (H * W) + h * W + w, where H and W are the token-grid extents.
patches = einops.rearrange(patches_vol, "b t h w c -> b t (h w) c")
patches = einops.rearrange(patches, "b t n c -> b (t n) c")

In [117]:
# Expect (bs, number of tokens, features per token) = (4, 12 * 12 * 12, 2 * 2 * 2 * 3).
print(patches.shape)

torch.Size([4, 1728, 24])


### Verify that retrieving tokens from flattened data works

In [113]:
# Fold the flat token sequence back into its (t, h, w) grid, then move the
# feature axis to position 1 to obtain a channels-first volume.
global_tokens_volume = patches.view(bs, t // dt, h // dh, w // dw, -1).permute(
    0, 4, 1, 2, 3
)

In [116]:
# Channels-first token volume: expect (bs, features, t // dt, h // dh, w // dw).
global_tokens_volume.shape

torch.Size([4, 24, 12, 12, 12])

In [115]:
# Round-trip check: the token at grid cell (4, 1, 4) recovered from the
# flattened sequence must equal the same cell in the original patch grid.
global_tokens_volume[0, :, 4, 1, 4] == patches_vol[0, 4, 1, 4, :]

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True])

In [121]:
# Two (t, h, w) query positions in token-grid coordinates, batch size 1.
# NOTE(review): global_tokens_volume has batch size bs = 4 but pos has batch
# size 1; retrieve_query_features takes its batch count from query_pos, so
# only batch element 0 of the volume is sampled here.
pos = torch.tensor([[[1, 3, 4], [10, 1, 11]]])

pos_feats = retrieve_query_features(global_tokens_volume, pos)

In [122]:
# Token (1, 3, 4) covers voxels t in {2, 3}, h in {6, 7}, w in {8, 9};
# token (10, 1, 11) covers t in {20, 21}, h in {2, 3}, w in {22, 23}.
pos_feats

tensor([[[ 2.,  6.,  8.,  2.,  6.,  9.,  2.,  7.,  8.,  2.,  7.,  9.,  3.,  6.,
           8.,  3.,  6.,  9.,  3.,  7.,  8.,  3.,  7.,  9.],
         [20.,  2., 22., 20.,  2., 23., 20.,  3., 22., 20.,  3., 23., 21.,  2.,
          22., 21.,  2., 23., 21.,  3., 22., 21.,  3., 23.]]])

### Verify that positional encoding matches positional information from tokenization

In [138]:
# Flat index 239 = 1 * 144 + 7 * 12 + 11, i.e. token (t=1, h=7, w=11), whose
# voxels span t in {2, 3}, h in {14, 15}, w in {22, 23}.
patches[0, 239, :]

tensor([ 2., 14., 22.,  2., 14., 23.,  2., 15., 22.,  2., 15., 23.,  3., 14.,
        22.,  3., 14., 23.,  3., 15., 22.,  3., 15., 23.])

In [139]:
def create_pos_volume(depth, height, width):
    """
    Enumerate every cell of a ``depth x height x width`` grid.

    Returns an int64 tensor of shape ``(depth * height * width, 3)`` whose row
    ``i`` is the ``(d, h, w)`` index of the ``i``-th cell in row-major order
    (``w`` varies fastest, ``d`` slowest).
    """
    d_idx, h_idx, w_idx = np.indices((depth, height, width))
    # Flattening each axis grid in C order and placing them side by side gives
    # one (d, h, w) row per cell.
    coords = np.column_stack((d_idx.ravel(), h_idx.ravel(), w_idx.ravel()))
    return torch.from_numpy(coords).long()


# Size the coordinate table from the token grid (t // dt, h // dh, w // dw)
# instead of a hard-coded 12, so it follows the patch size used for tokenization.
pos_volume = create_pos_volume(t // dt, h // dh, w // dw)

In [140]:
# Give every batch element its own copy of the coordinate table:
# (tokens, 3) -> (bs, tokens, 3).
pos_volume_vol = pos_volume[None].repeat(bs, 1, 1)
pos_volume_vol.shape

torch.Size([4, 1728, 3])

In [141]:
# Should be the token-grid coordinate of flat index 239, i.e. (1, 7, 11),
# matching the voxel ranges stored in patches[0, 239, :].
pos_volume_vol[0, 239, :]

tensor([ 1,  7, 11])

In [137]:
# pos_volume_vol is flat, (bs, tokens, 3), so it must be folded back into the
# token grid before a (t, h, w) cell can be addressed. Indexing the flat tensor
# with five indices raises IndexError on a fresh Restart & Run All.
# Expect the cell's own coordinate: [4, 1, 4].
pos_volume_vol.reshape(bs, t // dt, h // dh, w // dw, 3)[0, 4, 1, 4, :]

tensor([4, 1, 4])