In [None]:
"""Various interpolation schemes for PyTorch"""
import os
import sys
import time

import numpy as np
import torch

from matplotlib import pyplot as plt
from skimage.io import imread

from EquivariantBasis.EquivariantBasis import EquivariantBasis
from interpolator.core import Interpolator
from interpolator.kernels import BilinearKernel, GaussianKernel

In [None]:
# Legacy tensor-type handles. `dtype` is what tensors are cast to via `.type(dtype)`;
# `dtype_long` is the integer counterpart (presumably for index tensors — TODO confirm).
dtype, dtype_long = torch.FloatTensor, torch.LongTensor

def view(image):
    """Display a channels-first image with matplotlib.

    Accepts a NumPy array or a torch tensor laid out as (C, H, W), or batched as
    (N, C, H, W), in which case only the first item is shown. A 3-channel image
    is shown in colour; any other channel count shows only its first channel.
    A 2-D (H, W) array is shown as-is.

    Args:
        image: NumPy array or torch.Tensor, channels first.
    """
    # imshow needs NumPy data; tensors may require grad or live on another device.
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    image = np.asarray(image)
    # Drop the batch axis, e.g. the (1, C, H, W) layout used when the owl is loaded.
    if image.ndim == 4:
        image = image[0]
    if image.ndim == 3:
        if image.shape[0] == 3:
            image = image.transpose(1, 2, 0)  # (3, H, W) -> (H, W, 3) for imshow
        else:
            image = image[0, ...]  # show only the first channel
    plt.figure(1, figsize=(12, 8))
    plt.imshow(image)
    plt.show()

def rot(theta, image_shape):
    """Sample coordinates that rotate an image by `theta` radians about (H/2, W/2).

    Builds a 1-based (row, col) coordinate grid of size image_shape[0] x
    image_shape[1] and rotates every point by [[cos, -sin], [sin, cos]]
    about the pivot (image_shape[0]/2, image_shape[1]/2).

    Args:
        theta: rotation angle in radians.
        image_shape: two-element (H, W) sequence, e.g. image.shape[2:].

    Returns:
        float ndarray of shape (2, H, W); [0] holds row coordinates and [1]
        column coordinates.

    NOTE(review): the grid runs 1..N but the pivot is N/2, half a pixel off the
    centre (N+1)/2 of a 1-based grid — confirm against the interpolator's
    indexing convention.
    """
    height, width = image_shape[0], image_shape[1]
    rows = np.linspace(1, height, num=height)
    cols = np.linspace(1, width, num=width)
    grid = np.stack(np.meshgrid(rows, cols, indexing='ij'), axis=0)

    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])
    # Pivot, broadcastable against the (2, H, W) grid.
    centre = np.asarray(image_shape)[:, None, None] / 2

    # Rotate all points at once as a (2, H*W) matrix, then restore the grid layout.
    flat = (grid - centre).reshape(2, -1)
    rotated = rotation @ flat
    return rotated.reshape(grid.shape) + centre

# Source: https://www.flickr.com/photos/48509939@N07/5927758528/in/photostream/
IMAGE_PATH = './owl.jpg'

# Load as a (1, C, H, W) float array in [0, 1]. The /255 assumes an 8-bit image, as JPEGs
# normally are. np.atleast_3d lets a greyscale (H, W) file load as a single channel
# instead of failing in the transpose; RGB input is unchanged.
image = np.atleast_3d(imread(IMAGE_PATH)).transpose(2, 0, 1)[np.newaxis, ...] / 255.

In [None]:
# Rotate the owl: rotated sample coordinates for a 0.1 rad turn of the (H, W) image plane.
samples = rot(theta=0.1, image_shape=image.shape[2:])

In [None]:
# Convert to PyTorch tensors. `.type(dtype)` applies the tensor type chosen in the config
# cell. torch.as_tensor also accepts an input that is already a tensor, so re-running this
# cell does not depend on whether `image`/`samples` are still NumPy arrays.
image = torch.as_tensor(image, dtype=torch.float32).type(dtype)
samples = torch.as_tensor(samples, dtype=torch.float32).type(dtype)

# Reference sampling grid for the same 0.1 rotation, from EquivariantBasis.
grid = EquivariantBasis.get_affine_grid(shape=image.shape, rot_thetas=[0.1])

# Map the first grid from normalised [-1, 1] coordinates to image indices.
# Assumes grid[0] is (H, W, 2) with the last axis ordered (x, y), as in
# torch.nn.functional.affine_grid — TODO confirm for get_affine_grid.
# Reorder to (y, x) = (row, col) FIRST: the fancy index returns a copy, so the in-place
# scaling below no longer writes through the transposed view and corrupts `grid`.
aux_grid = np.transpose(grid[0], (2, 0, 1))[[1, 0]]
aux_grid[0] = aux_grid[0] * (image.shape[-2] / 2) + image.shape[-2] / 2  # rows scale with H
aux_grid[1] = aux_grid[1] * (image.shape[-1] / 2) + image.shape[-1] / 2  # cols scale with W