In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
from dl.utils.helpers import load_master_csv
from dl.dataset.datamodes.npz.rgbd import RGBDNPZ

# Alternative dataset location, kept for reference (disabled):
# dataset_path = '/data2/jupiter/datasets/Jupiter_train_v6_2/'
# anno_path = 'master_annotations_20231019_clean.csv'
# Dataset root and annotation CSV. The two strings are joined by plain
# concatenation below, so dataset_path must keep its trailing slash.
dataset_path = '/mnt/datasets/halo_rgb_stereo_train_v6_1/'
anno_path = 'master_annotations_1k.csv'
df = load_master_csv(dataset_path + anno_path)
rgbdnpz = RGBDNPZ(dataset_path)
# Load the artifacts for the first annotation row and show their 'depth' entry.
artifacts = rgbdnpz.get_artifacts(df.iloc[0])
plt.imshow(artifacts['depth'])

In [None]:
# Walk the annotations in order and report the position of the first row whose
# label map contains more than one pixel equal to 11, then stop. `artifacts`
# is left pointing at that row's data for the cells that follow.
for i, (_, row) in enumerate(df.iterrows()):
    artifacts = rgbdnpz.get_artifacts(row)
    n_matching = np.count_nonzero(artifacts['label'] == 11)
    if n_matching <= 1:
        continue
    print(i)
    break


In [None]:
from dl.dataset.datamodes.npz.rgbd import RGBDNPZ

rgbdnpz = RGBDNPZ(dataset_path)


def viz_some_images(rows):
    """Plot each row's image next to a mask of where its label equals 11.

    One subplot row is drawn per entry of ``rows``; the artifacts are loaded
    through the module-level ``rgbdnpz`` reader.
    """
    _, axes = plt.subplots(len(rows), 2, figsize=(5, 16), squeeze=False)
    for axis_pair, (_, record) in zip(axes, rows.iterrows()):
        loaded = rgbdnpz.get_artifacts(record)
        image_ax, mask_ax = axis_pair
        image_ax.imshow(loaded['image'])
        mask_ax.imshow(loaded['label'] == 11)


viz_some_images(df[24:25])

In [None]:
# Keep the 'image' entry of the current `artifacts` for the sampling experiments below.
im = artifacts['image']
# Show the matching 'label' map.
plt.imshow(artifacts['label'])

In [None]:
import torch
from torch.nn.functional import grid_sample

In [None]:
# Normalised sample coordinates along the first two image axes: index k of an
# axis of length n maps to 2k/n - 1, so values start at -1 and stay below 1.
im_h, im_w = im.shape[0], im.shape[1]
xvals = torch.arange(im_h, dtype=torch.float32) / (.5 * im_h) - 1
yvals = torch.arange(im_w, dtype=torch.float32) / (.5 * im_w) - 1
# Add a leading batch axis and move the last axis to position 1, the layout
# grid_sample is fed with below.
im_tensor = torch.Tensor(im).unsqueeze(0).permute(0, 3, 1, 2)

In [None]:
# Expand the two coordinate vectors into per-pixel row / column coordinates.
# 'ij' is the layout torch.meshgrid already produced by default; passing it
# explicitly avoids the deprecation warning and pins the layout should the
# default ever change.
grid_row, grid_col = torch.meshgrid(xvals, yvals, indexing='ij')
# grid_sample reads each coordinate pair as (x, y) = (column, row), so the
# column coordinates are stacked first. Result shape: [1, 2, rows, cols].
grid = torch.stack([grid_col, grid_row], dim=0)[None, :]

In [None]:
# Move the coordinate-pair axis last, which is where grid_sample expects it.
out_tensor = grid_sample(im_tensor, grid.permute([0,  2, 3, 1]))
# Move axis 1 back to the end and drop the batch axis to get a plain array for plotting.
out_np = np.array(out_tensor.permute([0, 2, 3, 1])[0])
out_np.shape

In [None]:
# Display the image produced by the grid_sample call above.
plt.imshow(out_np)

In [None]:
import math
# Build a normalised 2-D Gaussian and wrap it as a frozen [1, 1, k, k] convolution weight.
kernel_size = 10
sigma = 8

# Integer coordinates of every kernel cell, stacked as (kernel_size, kernel_size, 2).
x_cord = torch.arange(kernel_size)
x_grid = x_cord.unsqueeze(0).repeat(kernel_size, 1)
y_grid = x_grid.t()
xy_grid = torch.stack((x_grid, y_grid), dim=-1)

mean = (kernel_size - 1) / 2.
variance = sigma ** 2.

# Squared distance of each cell from the kernel centre.
sq_dist = ((xy_grid - mean) ** 2.).sum(dim=-1)
norm_const = 1. / (2. * math.pi * variance)
gaussian_kernel = norm_const * torch.exp(-sq_dist / (2 * variance))
# Rescale so that the kernel entries sum to 1.
gaussian_kernel = gaussian_kernel / gaussian_kernel.sum()

# Shape as a single-channel 2-D convolution weight and exclude it from training.
gaussian_kernel = torch.nn.Parameter(
    gaussian_kernel.view(1, kernel_size, kernel_size).repeat(1, 1, 1, 1),
    requires_grad=False,
)

In [None]:
# Use the 'label' entry of the current artifacts directly as the saliency signal.
saliency = artifacts['label']

In [None]:
# Convert to a tensor and prepend two singleton axes ahead of the map's own two axes.
saliency_torch = torch.Tensor(saliency)[None, None, :, :]

In [None]:
# Coordinates for a grid with half as many samples per axis: index k of an
# axis of length n maps to 4k/n - 1, for k in [0, n // 2).
xvals = torch.FloatTensor(list(range(im.shape[0] // 2))) / (.25 * im.shape[0]) - 1
yvals = torch.FloatTensor(list(range(im.shape[1] // 2))) / (.25 * im.shape[1]) - 1
# 'ij' is the layout torch.meshgrid already used by default; stating it
# silences the deprecation warning and keeps rows-first ordering explicit.
grid_row, grid_col = torch.meshgrid(xvals, yvals, indexing='ij')

In [None]:
# Saliency-weighted Gaussian average of the sampling coordinates: smooth
# (saliency * coordinate) and divide by the smoothed saliency.
conv = torch.nn.Conv2d(1, 1, kernel_size, bias=False, padding='same', padding_mode='replicate')
conv.weight = gaussian_kernel
new_grid_denom = conv(saliency_torch)
new_grid_num_x = conv(saliency_torch * grid_col)
new_grid_num_y = conv(saliency_torch * grid_row)

# Channel 0 holds the column (x) coordinate, channel 1 the row (y) coordinate.
new_grid = torch.cat((new_grid_num_x, new_grid_num_y), dim=1) / new_grid_denom

In [None]:
# Resample the image through the saliency-driven grid and display the result.
out_tensor = grid_sample(im_tensor, new_grid.permute([0, 2, 3, 1]))
out_np = out_tensor.permute([0, 2, 3, 1])[0].detach().numpy()
plt.imshow(out_np)
# Looks good, but I don't think we can invert this quickly in ONNX.

# OK, let's try a simple, fixed transform. If we fix the transform globally, then we should be able to invert it quickly.

In [None]:
class NonUniformSamplingLayer(torch.nn.Module):
    """Resample an image on a separable, saliency-driven non-uniform grid.

    The saliency map is reduced to one weight per output row and one weight
    per output column. The running sum of those weights, normalised to
    [-1, 1], gives the source coordinate read by each output row / column.

    Args:
        downsample_rows: number of rows of the resampled output.
        downsample_cols: number of columns of the resampled output.
        n_row_divisions: number of saliency tiles along the row axis; must
            divide ``downsample_rows``.
        n_col_divisions: number of saliency tiles along the column axis; must
            divide ``downsample_cols``.
    """

    def __init__(
            self,
            downsample_rows: int = 32,
            downsample_cols: int = 32,
            n_row_divisions: int = 16,
            n_col_divisions: int = 16,
        ):
        super().__init__()
        assert downsample_rows % n_row_divisions == 0
        assert downsample_cols % n_col_divisions == 0
        # Each saliency tile is repeated this many times to reach the output size.
        self.row_expand_copies = downsample_rows // n_row_divisions
        self.col_expand_copies = downsample_cols // n_col_divisions

    def forward(self, image: torch.Tensor, saliency: torch.Tensor) -> torch.Tensor:
        """Module entry point; same as :meth:`performn_nonuniform_s`."""
        return self.performn_nonuniform_s(image, saliency)

    def performn_nonuniform_s(self, image: torch.Tensor, saliency: torch.Tensor) -> torch.Tensor:
        """
        Resample ``image`` according to ``saliency``.

        image shape: [B, C, H, W]
        saliency shape: floats with value >0 corresponding to the desired zoom for a tile of the image.
        Shape [B, 1, n_row_divisions, n_col_divisions]

        Returns the resampled image with shape
        [B, C, n_row_divisions * row_expand_copies, n_col_divisions * col_expand_copies].
        """
        # For simpler computation, average the saliency of each tile row / tile
        # column and make the extent of every output row / column proportional to it.
        # Both results have shape [B, 1, n_out].
        x_weights = torch.repeat_interleave(torch.mean(saliency, dim=3), repeats=self.row_expand_copies, dim=2)
        y_weights = torch.repeat_interleave(torch.mean(saliency, dim=2), repeats=self.col_expand_copies, dim=2)

        # The output image coordinate at [r, c] takes the value of the image at
        # [xvals[r], yvals[c]], normalized to lie within [-1, 1]. The running sum
        # and its normaliser act on the spatial (last) axis so that every sample
        # in the batch gets its own, independent set of positions.
        xvals = 2 * torch.cumsum(x_weights, dim=-1) / torch.sum(x_weights, dim=-1, keepdim=True) - 1
        yvals = 2 * torch.cumsum(y_weights, dim=-1) / torch.sum(y_weights, dim=-1, keepdim=True) - 1

        # Drop the singleton channel axis: [B, n_rows_out] and [B, n_cols_out].
        xvals = xvals[:, 0]
        yvals = yvals[:, 0]
        n_rows_out, n_cols_out = xvals.shape[1], yvals.shape[1]

        # Outer-expand to a dense [B, n_rows_out, n_cols_out, 2] grid. grid_sample
        # reads the last axis as (x, y) = (column, row), so columns come first.
        grid_rows = xvals[:, :, None].expand(-1, -1, n_cols_out)
        grid_cols = yvals[:, None, :].expand(-1, n_rows_out, -1)
        grid = torch.stack([grid_cols, grid_rows], dim=-1)

        # align_corners=False is grid_sample's default; stating it avoids the warning.
        return torch.nn.functional.grid_sample(image, grid, align_corners=False)


In [None]:
# Stand-alone version of the weight computation in NonUniformSamplingLayer,
# with each division repeated twice.
# NOTE(review): saliency_torch_resize is not defined in any cell visible here — confirm where it comes from.
x_weights = torch.repeat_interleave(torch.mean(saliency_torch_resize, dim=3), repeats=2, dim=2)
y_weights = torch.repeat_interleave(torch.mean(saliency_torch_resize, dim=2), repeats=2, dim=2)

In [None]:
# Inspect the shape of x_weights.
x_weights.shape

In [None]:
# Fixed (input-independent) per-division weights for the global transform.
downsample_rows = 32
downsample_cols = 32
n_row_divisions = 16
n_col_divisions = 16
# Linearly increasing weights starting at 0.1 with a leading singleton axis;
# the step is 1 / number of divisions along the respective axis.
x_weight_base_tensor = torch.arange(.1, 1.1, 1 / n_row_divisions)[None, :]
y_weight_base_tensor = torch.arange(.1, 1.1, 1 / n_col_divisions)[None, :]
# len() reports the size of the leading axis.
print(len(x_weight_base_tensor))
x_weights = torch.nn.Parameter(x_weight_base_tensor)
# The column weights must come from the column base tensor (they were
# previously built from the row one, which only went unnoticed because both
# division counts are equal).
y_weights = torch.nn.Parameter(y_weight_base_tensor)

In [None]:
# Cumulative, normalised weights give the sample positions (increasing for
# positive weights, ending at 1). The running sum goes along the last axis:
# along dim 0 it is a no-op whenever the weights carry a leading singleton
# axis. The result is flattened to 1-D because the following cells slice, pad
# and searchsorted these positions as plain vectors.
xvals = (2 * torch.cumsum(x_weights, dim=-1) / torch.sum(x_weights) - 1).reshape(-1)
yvals = (2 * torch.cumsum(y_weights, dim=-1) / torch.sum(y_weights) - 1).reshape(-1)

In [None]:
# Forward warp with the fixed transform. The row positions drop their first
# and last entry, the column positions only their last one.
# 'ij' is the layout torch.meshgrid already used by default; stating it
# silences the deprecation warning and keeps rows-first ordering explicit.
grid_row, grid_col = torch.meshgrid(xvals[1:-1], yvals[:-1], indexing='ij')
# grid_sample reads each coordinate pair as (x, y) = (column, row).
grid = torch.stack([grid_col, grid_row], dim=0)[None, :]
out_tensor = grid_sample(im_tensor, grid.permute([0, 2, 3, 1]))
# detach() because the positions derive from Parameters and carry gradients.
out_np = np.array(out_tensor.permute([0, 2, 3, 1])[0].detach())
plt.imshow(out_np)

In [None]:
# Evenly spaced normalised coordinates of the transformed image; the 1.0001
# stop value keeps the endpoint 1 in the range.
# NOTE(review): these tables are later indexed with positions found in
# xvals_pad / yvals_pad — confirm their lengths are meant to line up.
transformed_x = torch.arange(-1, 1.0001, 2 / downsample_rows)
transformed_y = torch.arange(-1, 1.0001, 2 / downsample_cols)
# Normalised coordinates in the original image, one per index along each of
# its first two axes, that we want to map back to.
target_orig_x = torch.arange(-1, 1, 2 / im.shape[0])
target_orig_y = torch.arange(-1, 1, 2 / im.shape[1])

In [None]:
# Prepend the lower bound -1 so every target coordinate has a left neighbour
# when bracketing it between consecutive sample positions.
xvals_pad = torch.cat((-torch.ones(1), xvals))
# Column positions must be padded from yvals (previously padded from xvals by mistake).
yvals_pad = torch.cat((-torch.ones(1), yvals))

In [None]:
# I think searchsorted is not available with onnx which is why using global constants will be good
x_r_pix = torch.searchsorted(xvals_pad, target_orig_x, side='right')
y_r_pix = torch.searchsorted(yvals_pad, target_orig_y, side='right')

In [None]:
# Sample positions bracketing each target: the entry just before the found
# index (left) and the entry at the found index (right).
x_l_vals = torch.index_select(xvals_pad, 0, x_r_pix - 1)
x_r_vals = torch.index_select(xvals_pad, 0, x_r_pix)
y_l_vals = torch.index_select(yvals_pad, 0, y_r_pix - 1)
y_r_vals = torch.index_select(yvals_pad, 0, y_r_pix)

In [None]:
# Transformed-image coordinates at the same bracket indices (left = index - 1, right = index).
x_l_inds = transformed_x[x_r_pix - 1]
x_r_inds = transformed_x[x_r_pix]
y_l_inds = transformed_y[y_r_pix - 1]
y_r_inds = transformed_y[y_r_pix]

In [None]:
# Linearly interpolate between the bracketing transformed coordinates, weighted
# by how far each target lies between its left and right sample positions.
# NOTE(review): the weight divides by (right - left); two equal neighbouring
# positions would give 0/0 — confirm that cannot happen.
x_inds = torch.lerp(x_l_inds, x_r_inds, (target_orig_x - x_l_vals) / (x_r_vals - x_l_vals))
y_inds = torch.lerp(y_l_inds, y_r_inds, (target_orig_y - y_l_vals) / (y_r_vals - y_l_vals))

In [None]:
# Inverse warp: sample the transformed image at the interpolated coordinates.
# 'ij' is the layout torch.meshgrid already used by default; stating it
# silences the deprecation warning and keeps rows-first ordering explicit.
grid_row_inv, grid_col_inv = torch.meshgrid(x_inds, y_inds, indexing='ij')
# grid_sample reads each coordinate pair as (x, y) = (column, row).
grid = torch.stack([grid_col_inv, grid_row_inv], dim=0)[None, :]
out_tensor_inv = grid_sample(out_tensor, grid.permute([0, 2, 3, 1]))
out_inv_np = out_tensor_inv.permute([0, 2, 3, 1])[0].detach().numpy()
plt.imshow(out_inv_np)

In [None]:
# Show the original image for visual comparison with the inverse-warped result.
plt.imshow(im)