In [1]:
import torch
B = 256  # batch size (optional)
shape = (B, 64, 64)
high = torch.prod(torch.tensor(shape)).to(dtype=torch.long)
data = torch.arange(0, high).reshape(shape)

In [2]:
# index a single element
print(data[124, 5, 52])

tensor(508276)


In [3]:
# index all dimensions given the first is index 0 (the following are equivalent)
print(data[0])
print(data[0, :, :])
print(data[0, ...]) # pytorch only syntax

tensor([[   0,    1,    2,  ...,   61,   62,   63],
        [  64,   65,   66,  ...,  125,  126,  127],
        [ 128,  129,  130,  ...,  189,  190,  191],
        ...,
        [3904, 3905, 3906,  ..., 3965, 3966, 3967],
        [3968, 3969, 3970,  ..., 4029, 4030, 4031],
        [4032, 4033, 4034,  ..., 4093, 4094, 4095]])
tensor([[   0,    1,    2,  ...,   61,   62,   63],
        [  64,   65,   66,  ...,  125,  126,  127],
        [ 128,  129,  130,  ...,  189,  190,  191],
        ...,
        [3904, 3905, 3906,  ..., 3965, 3966, 3967],
        [3968, 3969, 3970,  ..., 4029, 4030, 4031],
        [4032, 4033, 4034,  ..., 4093, 4094, 4095]])
tensor([[   0,    1,    2,  ...,   61,   62,   63],
        [  64,   65,   66,  ...,  125,  126,  127],
        [ 128,  129,  130,  ...,  189,  190,  191],
        ...,
        [3904, 3905, 3906,  ..., 3965, 3966, 3967],
        [3968, 3969, 3970,  ..., 4029, 4030, 4031],
        [4032, 4033, 4034,  ..., 4093, 4094, 4095]])


In [4]:
# index all dimensions given the last is index 5 (the following are equivalent)
print(data[..., 5])
print(data[:, :, 5])

tensor([[      5,      69,     133,  ...,    3909,    3973,    4037],
        [   4101,    4165,    4229,  ...,    8005,    8069,    8133],
        [   8197,    8261,    8325,  ...,   12101,   12165,   12229],
        ...,
        [1036293, 1036357, 1036421,  ..., 1040197, 1040261, 1040325],
        [1040389, 1040453, 1040517,  ..., 1044293, 1044357, 1044421],
        [1044485, 1044549, 1044613,  ..., 1048389, 1048453, 1048517]])
tensor([[      5,      69,     133,  ...,    3909,    3973,    4037],
        [   4101,    4165,    4229,  ...,    8005,    8069,    8133],
        [   8197,    8261,    8325,  ...,   12101,   12165,   12229],
        ...,
        [1036293, 1036357, 1036421,  ..., 1040197, 1040261, 1040325],
        [1040389, 1040453, 1040517,  ..., 1044293, 1044357, 1044421],
        [1044485, 1044549, 1044613,  ..., 1048389, 1048453, 1048517]])


In [5]:
idx = [4, 8, 15, 16, 23, 42]

# index all dimensions given the first follows idx
print(data[idx].shape) # (len(idx), 64, 64)
print(data[idx, ...].shape)
print(data[idx, :, :].shape)

# index all dimensions given the second follows idx
print(data[:, idx].shape)
print(data[:, idx, :].shape)

torch.Size([6, 64, 64])
torch.Size([6, 64, 64])
torch.Size([6, 64, 64])
torch.Size([256, 6, 64])
torch.Size([256, 6, 64])


In [6]:
idx = [4, 8, 15, 16, 23, 42]
idx2 = [5, 2, 7, 1, 32, 4]

# index the last dimension when the first two are (4,5), (8,2), (15,7), (16,1), (23,32), and (42,4)
print(data[idx, idx2].shape)  # (len(idx), 64)

torch.Size([6, 64])


In [7]:
# indices of 5 entries
idx3 = [[0, 5, 3],
        [2, 7, 5],
        [100, 23, 45],
        [3, 6, 4],
        [4, 2, 1]]

In [8]:
print(data[idx3])

IndexError: too many indices for tensor of dimension 3

In [23]:
# easier to convert it to something that allows column indexing first
idx4 = torch.tensor(idx3)
print(data[idx4[:,0], idx4[:,1], idx4[:,2]]) # returns the 5 entries as desired
print(data[torch.unbind(idx4, -1)])  # can also use unbind

tensor([   323,   8645, 411117,  12676,  16513])
tensor([   323,   8645, 411117,  12676,  16513])


In [10]:
from multidim_indexing import torch_view as view

# simple wrapper with bounds checking
data_multi = view.TorchMultidimView(data)
# another view into the data, treating it as a batch of 2 dimensional grid data with X in [-5, 5] and Y in [0, 10]
# can specify value to assign a query if it's out of bounds (defaults to -1)
# note that the invalid value needs to be of the same type as the source, so we can't for example use float('inf') here
data_batch = view.TorchMultidimView(data, value_ranges=[[0, B], [-5, 5], [0, 10]], invalid_value=-1)
# another view into the data, treating it as a 3D grid data with X in [-2.5, 5], Y in [0, 4], and Z in [0, 10]
data_3d = view.TorchMultidimView(data, value_ranges=[[-2.5, 5], [0, 4], [0, 10]])

In [11]:
# convert index to the corresponding type (pytorch vs numpy)
key = torch.tensor(idx3, dtype=torch.long)
print(data_multi[key]) # returns the 5 entries as desired

tensor([   323,   8645, 411117,  12676,  16513])


In [22]:
# query the other views using grid values
# first, let's try keying the same 2D values across all batches
value_key_per_batch = torch.tensor([[-3.5, 0.2],
                                    [-4, 0.1],
                                    [-7, 0.5],  # this is out of bounds
                                    [3, 2]])
# number of entries to query
N = value_key_per_batch.shape[0]
print(torch.arange(B, dtype=value_key_per_batch.dtype).reshape(B, 1, 1).repeat(1, N, 1).shape)
# make the indices for all batches
value_key_batch = torch.cat(
    (torch.arange(B, dtype=value_key_per_batch.dtype).reshape(B, 1, 1).repeat(1, N, 1),
     value_key_per_batch.repeat(B, 1, 1)), dim=-1)
# keys can have an additional batch indices at the front
print(value_key_batch.shape)  # (B, N, 3)
# these 2 should be the same apart from the first batch index
print(value_key_batch[0:N])
print(value_key_batch[12*N:13*N])

# should see some -1 to indicate invalid value
print(data_batch[value_key_batch]) 

# also there is a shorthand for directly using the per batch indices
print(data_batch[value_key_per_batch.repeat(B,1,1)]) # should be the same as above

torch.Size([256, 4, 1])
torch.Size([256, 4, 3])
tensor([[[ 0.0000, -3.5000,  0.2000],
         [ 0.0000, -4.0000,  0.1000],
         [ 0.0000, -7.0000,  0.5000],
         [ 0.0000,  3.0000,  2.0000]],

        [[ 1.0000, -3.5000,  0.2000],
         [ 1.0000, -4.0000,  0.1000],
         [ 1.0000, -7.0000,  0.5000],
         [ 1.0000,  3.0000,  2.0000]],

        [[ 2.0000, -3.5000,  0.2000],
         [ 2.0000, -4.0000,  0.1000],
         [ 2.0000, -7.0000,  0.5000],
         [ 2.0000,  3.0000,  2.0000]],

        [[ 3.0000, -3.5000,  0.2000],
         [ 3.0000, -4.0000,  0.1000],
         [ 3.0000, -7.0000,  0.5000],
         [ 3.0000,  3.0000,  2.0000]]])
tensor([[[48.0000, -3.5000,  0.2000],
         [48.0000, -4.0000,  0.1000],
         [48.0000, -7.0000,  0.5000],
         [48.0000,  3.0000,  2.0000]],

        [[49.0000, -3.5000,  0.2000],
         [49.0000, -4.0000,  0.1000],
         [49.0000, -7.0000,  0.5000],
         [49.0000,  3.0000,  2.0000]],

        [[50.0000, -3.5000, 

In [18]:
value_key_3d = torch.tensor([[-2.5, 0., 0.],  # right on the boundary of validity
                             [-2.51, 0.5, 0], # out of bounds
                             [5, 4, 10] # right on the boundary
                            ]  
                        )
print(data_3d[value_key_3d]) # (0, -1 for invalid, high - 1)
print(torch.prod(torch.tensor(data.shape)) - 1)
print(high - 1)

tensor([      0,      -1, 1048575])
tensor(1048575)
tensor(1048576)
