In [1]:
%load_ext autoreload
%autoreload 2
import tiled_convolution as tc

In [2]:
import numpy as np
import torch

In [3]:
# Geometry for the single-tile sanity check: a 2-D image and a 2-D kernel.
size_input, size_kernel = (21, 24), (3, 4)
stride, dilation = (1, 1), (1, 1)

# Padding mode handed to both the tiled helpers and to torch's conv2d.
# An explicit tuple, (0, 6), was tried here previously and is left disabled.
padding = 'same'

In [4]:
# Seed torch's global RNG so that Restart & Run All regenerates the same
# random image and kernel (and hence the same printed tensors) every time.
torch.manual_seed(0)

# Random 2-D test image and 2-D convolution kernel, both float32.
image = torch.rand(size=size_input, dtype=torch.float32)
kernel = torch.rand(size=size_kernel, dtype=torch.float32)

In [5]:
# Turn the padding spec (here the string 'same') into explicit amounts.
# The 4-tuple appears to be ordered (top, bottom, left, right), matching the
# index order used by the receptive-field cells -- TODO confirm in
# tiled_convolution.
padding_val = tc.compute_padding_amount(
    padding=padding,
    size_input=size_input,
    size_kernel=size_kernel,
    dilation=dilation,
    stride=stride,
)

padding_val

(1, 1, 1, 2)

In [6]:
# Output size of the convolution given the explicit padding amounts.
size_output = tc.conv2d_output_size(
    size_input=size_input,
    size_kernel=size_kernel,
    dilation=dilation,
    stride=stride,
    padding=padding_val,
)
size_output

(21, 24)

In [7]:
# Inclusive index bounds of the whole output: first/last row and column.
idx_out_top, idx_out_left = 0, 0
idx_out_bottom = size_output[0] - 1
idx_out_right = size_output[1] - 1
idx_out_top, idx_out_bottom, idx_out_left, idx_out_right

(0, 20, 0, 23)

In [8]:
# Receptive field of the top-left output pixel; only its top and left input
# bounds are kept. A negative bound presumably means the field starts inside
# the padded border -- verify against get_receptive_field_indices.
idx_top, _, idx_left, _ = tc.get_receptive_field_indices(
    indices=(idx_out_top, idx_out_left),
    size_kernel=size_kernel,
    dilation=dilation,
    stride=stride,
    padding=padding_val,
)
idx_top, idx_left

(-1, -1)

In [9]:
# Receptive field of the bottom-right output pixel; only its bottom and right
# input bounds are kept. A bound past the last input index presumably reaches
# into the padded border -- verify against get_receptive_field_indices.
_, idx_bottom, _, idx_right = tc.get_receptive_field_indices(
    indices=(idx_out_bottom, idx_out_right),
    size_kernel=size_kernel,
    dilation=dilation,
    stride=stride,
    padding=padding_val,
)
idx_bottom, idx_right

(21, 25)

In [10]:
# Convert the (possibly out-of-range) input bounds into the padding that the
# tile needs on each side, relative to the full input size.
padding_for_tile = tc.indices_to_padding(
    size_input_full=size_input,
    indices=(idx_top, idx_bottom, idx_left, idx_right),
)
padding_for_tile

(1, 1, 1, 2)

In [11]:
# Here the "tile" is the entire image, padded by the amounts computed above.
tile_padded = tc.pad_tile(padding=padding_for_tile, tile=image)
tile_padded.shape

torch.Size([23, 27])

In [12]:
# Convolve the manually padded tile with no additional padding ('valid').
# Indexing with two leading None axes adds the batch and channel dimensions
# that conv2d expects.
out_custom = torch.nn.functional.conv2d(
    input=tile_padded[None, None],
    weight=kernel[None, None],
    padding='valid',
    stride=stride,
    dilation=dilation,
)
out_custom.shape

torch.Size([1, 1, 21, 24])

In [13]:
# Reference result: let PyTorch apply the padding mode itself on the raw image.
out_torch = torch.nn.functional.conv2d(
    input=image[None, None],
    weight=kernel[None, None],
    padding=padding,
    stride=stride,
    dilation=dilation,
)
out_torch.shape

  out_torch = torch.nn.functional.conv2d(


torch.Size([1, 1, 21, 24])

In [14]:
# Display the result of the manually padded ('valid') convolution so it can be
# compared by eye with the PyTorch reference tensor.
out_custom

tensor([[[[1.1680, 1.4680, 2.0761, 1.9334, 2.7274, 2.5891, 1.9201, 1.5034,
           1.1677, 1.3519, 2.1082, 2.1855, 2.0022, 2.4742, 2.0512, 2.4671,
           3.1116, 2.4050, 2.2750, 2.4208, 2.2453, 2.8572, 1.9628, 1.5523],
          [2.1124, 2.5579, 2.9993, 2.7260, 3.6293, 3.5731, 3.2433, 3.5011,
           2.2640, 1.9230, 2.6361, 3.2479, 3.7682, 3.1345, 2.7856, 3.0936,
           3.5349, 2.8015, 3.3131, 2.7763, 2.5641, 3.4692, 2.3083, 2.2438],
          [2.5488, 3.1607, 3.0132, 3.0381, 3.4410, 3.6903, 3.9633, 3.5094,
           2.8501, 2.6991, 2.6489, 3.5937, 3.1688, 3.0379, 2.6715, 2.9625,
           2.9373, 2.6916, 2.8819, 2.3220, 2.5514, 2.6305, 1.8973, 1.8253],
          [3.1550, 3.1235, 3.7999, 2.9618, 3.2402, 3.1960, 3.3345, 3.8692,
           2.9735, 3.3827, 2.7628, 2.8919, 2.8341, 3.3131, 2.7631, 2.2389,
           2.7483, 2.8846, 2.4299, 2.5873, 2.8139, 2.3843, 1.8999, 1.8021],
          [2.6745, 3.4760, 3.7582, 3.0131, 2.3905, 3.1122, 3.4268, 3.9708,
           3.5823, 3.

In [15]:
# Display the PyTorch reference result (conv2d applied with the padding mode
# directly on the unpadded image).
out_torch

tensor([[[[1.1680, 1.4680, 2.0761, 1.9334, 2.7274, 2.5891, 1.9201, 1.5034,
           1.1677, 1.3519, 2.1082, 2.1855, 2.0022, 2.4742, 2.0512, 2.4671,
           3.1116, 2.4050, 2.2750, 2.4208, 2.2453, 2.8572, 1.9628, 1.5523],
          [2.1124, 2.5579, 2.9993, 2.7260, 3.6293, 3.5731, 3.2433, 3.5011,
           2.2640, 1.9230, 2.6361, 3.2479, 3.7682, 3.1345, 2.7856, 3.0936,
           3.5349, 2.8015, 3.3131, 2.7763, 2.5641, 3.4692, 2.3083, 2.2438],
          [2.5488, 3.1607, 3.0132, 3.0381, 3.4410, 3.6903, 3.9633, 3.5094,
           2.8501, 2.6991, 2.6489, 3.5937, 3.1688, 3.0379, 2.6715, 2.9625,
           2.9373, 2.6916, 2.8819, 2.3220, 2.5514, 2.6305, 1.8973, 1.8253],
          [3.1550, 3.1235, 3.7999, 2.9618, 3.2402, 3.1960, 3.3345, 3.8692,
           2.9735, 3.3827, 2.7628, 2.8919, 2.8341, 3.3131, 2.7631, 2.2389,
           2.7483, 2.8846, 2.4299, 2.5873, 2.8139, 2.3843, 1.8999, 1.8021],
          [2.6745, 3.4760, 3.7582, 3.0131, 2.3905, 3.1122, 3.4268, 3.9708,
           3.5823, 3.

In [16]:
# Numerical agreement check (default tolerances) between the manual-padding
# path and the PyTorch reference.
torch.allclose(out_custom, out_torch)

True

In [17]:
# Geometry for the multi-tile experiment: a larger image, same kernel shape.
size_input, size_kernel = (41, 34), (3, 4)
stride, dilation = (1, 1), (1, 1)

# Padding mode handed to both the tiled helpers and to torch's conv2d.
# An explicit tuple, (0, 6), was tried here previously and is left disabled.
padding = 'same'

In [18]:
# Fresh random image and kernel for the multi-tile experiment (float32).
# The image is drawn before the kernel; keep that order so RNG use is unchanged.
image = torch.rand(size_input, dtype=torch.float32)
kernel = torch.rand(size_kernel, dtype=torch.float32)

In [19]:
# Small output-tile size. It does not evenly divide the 41 x 34 output, so the
# last tile along each axis is truncated.
size_tile = (9, 15)

In [20]:
print(f'kernel shape: {kernel.shape}')
print(f'padding: {padding}')

# Resolve the padding spec into explicit per-side amounts.
padding_val = tc.compute_padding_amount(
    padding=padding,
    size_input=size_input,
    size_kernel=size_kernel,
    dilation=dilation,
    stride=stride,
)
print(f'padding_val: {padding_val}')

# Size of the full output that the tiles will be written into.
shape_out = tc.conv2d_output_size(
    size_input=size_input,
    size_kernel=size_kernel,
    dilation=dilation,
    stride=stride,
    padding=padding_val,
)
print(f'shape_out: {shape_out}')

# Uninitialised buffer; the tiles below cover it without gaps, so every element
# gets overwritten.
out = torch.empty(size=shape_out, dtype=torch.float32)

# Inclusive (top, bottom, left, right) output bounds of each tile, enumerated
# row-major. Tiles on the bottom/right edge are truncated to the output size.
idx_tiles = [
    (
        row_start,
        min(row_start + size_tile[0], shape_out[0]) - 1,
        col_start,
        min(col_start + size_tile[1], shape_out[1]) - 1,
    )
    for row_start in range(0, shape_out[0], size_tile[0])
    for col_start in range(0, shape_out[1], size_tile[1])
]

for i_tile, (idx_out_top, idx_out_bottom, idx_out_left, idx_out_right) in enumerate(idx_tiles):
    print(f"Tile: {i_tile} / {len(idx_tiles)}")
    print(f"idx_out: {idx_out_top, idx_out_bottom, idx_out_left, idx_out_right}")

    # Input-space bounds needed by this tile: top/left come from the receptive
    # field of the tile's first output pixel, bottom/right from its last one.
    idx_in_top, _, idx_in_left, _ = tc.get_receptive_field_indices(
        indices=(idx_out_top, idx_out_left),
        size_kernel=size_kernel,
        dilation=dilation,
        stride=stride,
        padding=padding_val,
    )
    _, idx_in_bottom, _, idx_in_right = tc.get_receptive_field_indices(
        indices=(idx_out_bottom, idx_out_right),
        size_kernel=size_kernel,
        dilation=dilation,
        stride=stride,
        padding=padding_val,
    )
    print(f"idx_in: {idx_in_top, idx_in_bottom, idx_in_left, idx_in_right}")

    # Clamp the bounds to the image so the slice below stays inside real data;
    # whatever falls outside is supplied by padding further down.
    idx_in_top_clip = max(0, idx_in_top)
    idx_in_bottom_clip = min(size_input[0] - 1, idx_in_bottom)
    idx_in_left_clip = max(0, idx_in_left)
    idx_in_right_clip = min(size_input[1] - 1, idx_in_right)
    print(f"idx_in_clip: {idx_in_top_clip, idx_in_bottom_clip, idx_in_left_clip, idx_in_right_clip}")

    # Cut the tile out of the image (bounds are inclusive, hence the + 1).
    tile = image[
        idx_in_top_clip:idx_in_bottom_clip + 1,
        idx_in_left_clip:idx_in_right_clip + 1,
    ]
    print(f'tile in shape: {tile.shape}')

    # Padding this tile needs, derived from the unclipped bounds.
    padding_for_tile = tc.indices_to_padding(
        size_input_full=size_input,
        indices=(idx_in_top, idx_in_bottom, idx_in_left, idx_in_right),
    )
    print(f'padding for tile: {padding_for_tile}')

    tile_padded = tc.pad_tile(padding=padding_for_tile, tile=tile)
    print(f'tile padded shape: {tile_padded.shape}')

    # Padding has already been applied by hand, so convolve with 'valid'.
    out_custom = torch.nn.functional.conv2d(
        input=tile_padded[None, None],
        weight=kernel[None, None],
        padding='valid',
        stride=stride,
        dilation=dilation,
    )
    print(f'out_custom shape: {out_custom.shape}')
    print(f'target indices: {slice(idx_out_top, idx_out_bottom + 1)}, {slice(idx_out_left, idx_out_right + 1)}')
    print(f'target shape: {out[slice(idx_out_top, idx_out_bottom + 1), slice(idx_out_left, idx_out_right + 1)].shape}')

    # Drop the batch/channel axes and store the tile at its place in the output.
    out[idx_out_top:idx_out_bottom + 1, idx_out_left:idx_out_right + 1] = out_custom[0, 0]

kernel shape: torch.Size([3, 4])
padding: same
padding_val: (1, 1, 1, 2)
shape_out: (41, 34)
Tile: 0 / 15
idx_out: (0, 8, 0, 14)
idx_in: (-1, 9, -1, 16)
idx_in_clip: (0, 9, 0, 16)
tile in shape: torch.Size([10, 17])
padding for tile: (1, 0, 1, 0)
tile padded shape: torch.Size([11, 18])
out_custom shape: torch.Size([1, 1, 9, 15])
target indices: slice(0, 9, None), slice(0, 15, None)
target shape: torch.Size([9, 15])
Tile: 1 / 15
idx_out: (0, 8, 15, 29)
idx_in: (-1, 9, 14, 31)
idx_in_clip: (0, 9, 14, 31)
tile in shape: torch.Size([10, 18])
padding for tile: (1, 0, 0, 0)
tile padded shape: torch.Size([11, 18])
out_custom shape: torch.Size([1, 1, 9, 15])
target indices: slice(0, 9, None), slice(15, 30, None)
target shape: torch.Size([9, 15])
Tile: 2 / 15
idx_out: (0, 8, 30, 33)
idx_in: (-1, 9, 29, 35)
idx_in_clip: (0, 9, 29, 33)
tile in shape: torch.Size([10, 5])
padding for tile: (1, 0, 0, 2)
tile padded shape: torch.Size([11, 7])
out_custom shape: torch.Size([1, 1, 9, 4])
target indices:

In [21]:
# Compare the stitched tile output against a single whole-image conv2d call
# (batch/channel axes added for the call, then stripped again).
torch.allclose(
    out,
    torch.nn.functional.conv2d(
        input=image[None, None],
        weight=kernel[None, None],
        padding=padding,
        stride=stride,
        dilation=dilation,
    )[0, 0],
)

True

In [22]:
# Geometry for the large-image run. 50000 x 100000 elements is 5e9 values, so
# a float32 array of this shape takes about 20 GB.
size_input, size_kernel = (50000, 100000), (11, 11)
stride, dilation = (1, 1), (1, 1)

# Padding mode handed to the tiled convolution and to torch's conv2d.
# An explicit tuple, (0, 6), was tried here previously and is left disabled.
padding = 'same'

In [23]:
# Random float32 image and kernel for the large-image run.
# The image is drawn before the kernel; keep that order so RNG use is unchanged.
image = torch.rand(size_input, dtype=torch.float32)
kernel = torch.rand(size_kernel, dtype=torch.float32)

In [24]:
# Tile size for the large-image run; small relative to the 50000 x 100000 input.
size_tile = (1600, 1600)

In [25]:
# Tiled convolution of the in-memory image: compute on the GPU, return on CPU.
# The array is passed positionally because this notebook names that argument
# inconsistently (`arr=` in some cells, `input=` in the zarr cell), so one of
# the keyword spellings cannot match the current tc.conv2d_tiled signature.
# NOTE(review): assumes the array is the first parameter -- confirm in
# tiled_convolution.
out = tc.conv2d_tiled(
    image,
    kernel=kernel,
    size_tile=size_tile,
    stride=stride,
    padding=padding,
    dilation=dilation,
    device_compute='cuda:0',
    device_return='cpu',
    dtype_compute=torch.float32,
)

In [None]:
# TorchScript-compile the tiled convolution; the scripted callable is used by
# the following cell.
# NOTE(review): this cell carries no execution count -- confirm that
# tc.conv2d_tiled is actually scriptable before relying on it.
conv2d_tiled_compiled = torch.jit.script(tc.conv2d_tiled)

In [35]:
# Same tiled convolution, run through the TorchScript-compiled callable.
# The array is passed positionally because this notebook names that argument
# inconsistently (`arr=` in some cells, `input=` in the zarr cell), so one of
# the keyword spellings cannot match the current tc.conv2d_tiled signature.
# NOTE(review): assumes the array is the first parameter -- confirm in
# tiled_convolution.
out = conv2d_tiled_compiled(
    image,
    kernel=kernel,
    size_tile=size_tile,
    stride=stride,
    padding=padding,
    dilation=dilation,
    device_compute='cuda:0',
    device_return='cpu',
    dtype_compute=torch.float32,
)

In [29]:
# Compare the tiled result against a single whole-image conv2d call
# (batch/channel axes added for the call, then stripped again).
# Note that the reference convolves the entire image in one go.
torch.allclose(
    out,
    torch.nn.functional.conv2d(
        input=image[None, None],
        weight=kernel[None, None],
        padding=padding,
        stride=stride,
        dilation=dilation,
    )[0, 0],
)

True

In [3]:
import zarr

In [36]:
# Geometry for the zarr-backed run. 50000 x 100000 elements is 5e9 values, so
# a float32 array of this shape takes about 20 GB.
size_input, size_kernel = (50000, 100000), (11, 11)
stride, dilation = (1, 1), (1, 1)

# Padding mode handed to the tiled convolution.
# An explicit tuple, (0, 6), was tried here previously and is left disabled.
padding = 'same'

In [37]:
# Random float32 image and kernel for the zarr-backed run.
# The image is drawn before the kernel; keep that order so RNG use is unchanged.
image = torch.rand(size_input, dtype=torch.float32)
kernel = torch.rand(size_kernel, dtype=torch.float32)

In [38]:
# Tile size for the zarr-backed run; also reused as the zarr chunk shape.
size_tile = (2000, 2000)

In [39]:
# On-disk zarr copy of the input image, chunked with the convolution tile size.
zarr_path_input = '/media/rich/bigSSD/tmp/image.zarr'
store = zarr.storage.LocalStore(zarr_path_input)

zarr_input = zarr.create_array(
    store=store,
    shape=size_input,
    dtype=tc.torch_dtype_to_numpy_dtype(image.dtype),
    chunks=size_tile,
    order='C',
    overwrite=True,  # replace any array left over from a previous run
)

# Write the whole in-memory image into the array in a single assignment.
zarr_input[:] = image.numpy()

In [40]:
# Shape of the convolution output, computed from the explicit padding amounts
# that correspond to the chosen padding mode.
shape_output = tc.conv2d_output_size(
    size_input=size_input,
    size_kernel=size_kernel,
    padding=tc.compute_padding_amount(
        size_input=size_input,
        size_kernel=size_kernel,
        stride=stride,
        dilation=dilation,
        padding=padding,
    ),
    stride=stride,
    dilation=dilation,
)

# The output location gets its own name so that zarr_path_input keeps pointing
# at the input array instead of being silently overwritten.
zarr_path_output = '/media/rich/bigSSD/tmp/output.zarr'
# `store` is rebound to the output location; the input array stays reachable
# through zarr_input.
store = zarr.storage.LocalStore(zarr_path_output)

# Pre-allocate the on-disk output array, chunked with the convolution tile size.
# (Sharding is left disabled: shards=(1000, 1000).)
zarr_output = zarr.create_array(
    store=store,
    shape=shape_output,
    chunks=size_tile,
    dtype=np.float32,
    overwrite=True,  # replace any array left over from a previous run
    order='C',
)

In [None]:
# Tiled convolution that reads from the zarr input array and is handed the
# pre-allocated zarr array as `output`; everything runs on the CPU.
# With verbose=True a per-tile progress bar is shown.
out = tc.conv2d_tiled(
    input=zarr_input,
    kernel=kernel,
    output=zarr_output,
    size_tile=size_tile,
    padding=padding,
    stride=stride,
    dilation=dilation,
    device_compute='cpu',
    device_return='cpu',
    dtype_compute=torch.float32,
    dtype_return=None,
    verbose=True,
    # kind_return='zarr',  (left disabled)
)

Processing tiles:   0%|          | 0/1250 [00:00<?, ?it/s]

In [None]:
# NOTE(review): looks like leftover scratch work. The shape (5000, 10000) does
# not match the 50000 x 100000 image that the following cell writes, and no
# overwrite flag is passed although `store` was already used for another
# array -- confirm intent or remove.
zarr_array = zarr.create_array(
    store=store,
    shape=(5000, 10000),
    dtype=np.float32,
    chunks=(1000, 1000),
    # shards=(1000, 1000),  (left disabled)
)

In [None]:
# Copy the in-memory image into the zarr array in one assignment.
# NOTE(review): requires zarr_array to have the same shape as `image` -- verify
# against the cell that creates zarr_array.
zarr_array[:] = image.numpy()

In [103]:
# NOTE(review): `zarr_path` is not assigned in any visible cell of this
# notebook; on a fresh kernel this line would raise NameError -- define it (or
# reuse one of the zarr_path_* variables) before running.
store = zarr.storage.LocalStore(zarr_path)

# Array sized after the convolution result, replacing whatever is stored there.
zarr_array = zarr.create_array(
    store=store,
    shape=out.shape,
    dtype=np.float32,
    chunks=(1000, 1000),
    overwrite=True,
)

In [106]:
# Persist the convolution result into the zarr array in one assignment.
# NOTE(review): assumes `out` is a CPU torch tensor (it must expose .numpy())
# -- confirm which conv2d_tiled call produced it.
zarr_array[:] = out.numpy()