# Remap for Poisson Solver

## Remap in Real Space

Note: Try remapping to shearing coordinates, as Nathan suggested. Afterwards, test the whole setup on a Kida vortex.

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
import jax_cfd.sb as cfd

import os
# Hide every CUDA device so JAX falls back to the CPU backend.
# NOTE(review): this only takes effect if no JAX backend was initialised before
# this line (jax and jax_cfd are imported above) — confirm nothing created arrays yet.
os.environ.update(CUDA_VISIBLE_DEVICES="")
jax.devices()  # list the devices JAX sees (shown as the cell output)

In [None]:
# Number of grid cells per side and the extent of the square domain.
size = 4

Lx = Ly = 2 * jnp.pi
# Uniform cell widths along x and y.
dx, dy = Lx / size, Ly / size

print(dy)

In [None]:
# Background shear rate and evaluation time; chosen so that shear_rate * time == 3/4.
shear_rate = 3 / 2
time = 2 / 3 * 3 / 4
time  # notebook display of the time value

In [226]:
# Cell-centre coordinates: the domain starts at -pi and each centre sits half a cell in.
x = jnp.arange(size) * dx - jnp.pi + dx / 2
y = jnp.arange(size) * dy - jnp.pi + dy / 2

In [None]:
# Test field built from the profile -shear_rate * x, laid out so that it varies
# along axis 1 (data[i, j] == background_shear[j]) and is constant along axis 0.
# A shift along axis 1 is therefore visible in the plots.
background_shear = -shear_rate * x
arr = background_shear  # same values; not used in the visible part of the notebook
data = jnp.tile(background_shear, (size, 1))
data

In [228]:
# Per-row remap shift: |shear_rate| * time * x wrapped modulo Ly, signed like shear_rate.
shift = jnp.sign(shear_rate) * jnp.mod(jnp.abs(shear_rate) * time * x, Ly)
# Split shift/dy into a whole number of cells (truncated toward zero) and the
# remaining fraction, which serves as the linear-interpolation weight.
m = jnp.trunc(shift / dy).astype(int)
eps = jnp.abs(shift / dy - m)

def shift_column(col, shifti, mi, ei):
    """Periodically shift a 1-D array by a whole plus fractional number of cells.

    The array is rolled by the whole-cell part ``mi``. The fractional weight
    ``ei`` then linearly blends in the neighbour one cell further along the
    direction given by the sign of ``shifti``.

    Args:
        col: 1-D array to shift.
        shifti: signed shift; only its sign is used (>= 0 means +1 neighbour).
        mi: whole number of cells to roll by.
        ei: weight given to the neighbour (in this notebook, the fractional
            part of shift/dy).

    Returns:
        ``(1 - ei) * roll(col, mi) + ei * roll(roll(col, mi), +-1)``.
    """
    step = jnp.where(shifti >= 0, 1, -1)
    whole = jnp.roll(col, mi)
    return (1.0 - ei) * whole + ei * jnp.roll(whole, step)
    
# Remap every row of `data` by its own shift (in_axes defaults to 0 for all arguments).
data_remapped = jax.vmap(shift_column)(data, shift, m, eps)

In [None]:
# Inspect the shift decomposition and the remapped field.
for _value in (shift, m, eps, data_remapped):
    print(_value)

In [230]:
# Inverse remap: the same per-row shift with the opposite sign, applied to the
# already-remapped data so the round trip can be compared with `data`.
shift = jnp.sign(-shear_rate) * jnp.mod(jnp.abs(shear_rate) * time * x, Ly)
m = jnp.trunc(shift / dy).astype(int)
eps = jnp.abs(shift / dy - m)

data_remapped_remapped = jax.vmap(shift_column)(data_remapped, shift, m, eps)

In [None]:
# Compare the original field, the remapped field and the field remapped back.
# Plot signed values: the diverging icefire colormap is centred on zero, so
# plotting jnp.abs() discarded the sign and left half the colour range unused.
# Colour limits follow the data instead of a fixed +/-3, which clipped the field
# (max |data| = shear_rate * max|x| ~= 3.53 for size = 4).
vlim = float(jnp.max(jnp.abs(data)))
imshow_kwargs = dict(origin='lower', cmap=sns.cm.icefire, vmin=-vlim, vmax=vlim)
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
data_plot = ax[0].imshow(data.transpose(), **imshow_kwargs)
remap_plot = ax[1].imshow(data_remapped.transpose(), **imshow_kwargs)
# Named `diff` historically, but it shows the round-trip result, not a difference.
diff = ax[2].imshow(data_remapped_remapped.transpose(), **imshow_kwargs)
plt.colorbar(data_plot, ax=ax[0])
plt.colorbar(remap_plot, ax=ax[1])
plt.colorbar(diff, ax=ax[2])
ax[0].set_title('Original Data')
ax[1].set_title('Remapped Data')
ax[2].set_title('Remapped Back')

## Remap in Fourier Space

In [None]:
# Transform the test field along y (axis 1) only.
data_fft_in_y = jnp.fft.fft(data, axis=1)
# Alternative shift formula, kept for reference (not used here):
#   shift = mod(shear_rate * time * Lx, Ly) * x / Lx
# Same per-row shift as in the real-space section.
shift = jnp.sign(shear_rate) * jnp.mod(jnp.abs(shear_rate) * time * x, Ly)
m = jnp.trunc(shift / dy).astype(int)
eps = jnp.abs(shift / dy - m)
for _value in (shift, m, eps):
    print(_value)

In [233]:
def shift_column(col, shifti, mi, ei):
    """Re-definition of the real-space helper: roll by ``mi`` cells, then blend.

    The result is ``(1 - ei)`` times ``col`` rolled by ``mi``, plus ``ei`` times
    that rolled array shifted one more cell in the direction of ``sign(shifti)``
    (a ``shifti`` of zero counts as positive).
    """
    base = jnp.roll(col, mi)
    toward = jnp.roll(base, jnp.where(shifti >= 0, 1, -1))
    return (1.0 - ei) * base + ei * toward
    
# Apply the per-row remap to the y-Fourier coefficients of each x row.
# NOTE(review): for whole-cell rolls, rolling coefficients along the ky axis
# multiplies the field by a phase in real space rather than shifting it in y —
# confirm this is the intended remap.
data_fft_in_y_remap = jax.vmap(shift_column)(data_fft_in_y, shift, m, eps)

In [234]:
# Full 2-D spectrum: additionally transform along x (axis 0).
data_fft_full = jnp.fft.fft(data_fft_in_y_remap, n=None, axis=0)

In [235]:
# Undo the x transform (axis 0), returning to the (x, ky) representation.
data_ifft_full = jnp.fft.ifft(data_fft_full, n=None, axis=0)

In [None]:
# Inverse remap in the (x, ky) representation: shift each row by exactly the
# negative of the forward shift. This uses the same per-row formula as this
# section's forward shift and as the real-space inverse:
#   sign(-shear_rate) * mod(|shear_rate| * time * x, Ly).
# The previous formula, -mod(shear_rate * time * Lx, Ly) * x / Lx, only cancels
# the forward shift (modulo Ly) while 0 <= shear_rate * time * Lx < Ly. The
# round trip therefore broke for negative shear rates and at later times.
shift = jnp.sign(-shear_rate) * jnp.mod(jnp.abs(shear_rate) * time * x, Ly)
# Whole cells (truncated toward zero) and the fractional interpolation weight.
m = jnp.trunc(shift / dy).astype(int)
eps = jnp.abs(shift / dy - m)
print(shift)
print(m)
print(eps)
data_ifft_in_y_remap = jax.vmap(shift_column)(data_ifft_full, shift, m, eps)

In [None]:
# Display the original field for comparison with the round-trip result.
data

In [None]:
# Back to real space along y (axis 1), keeping only the real part.
data_ifft = jnp.real(jnp.fft.ifft(data_ifft_in_y_remap, axis=1))
data_ifft

In [None]:
# Original field, the remapped y-spectrum (real part), and the field after the inverse remap.
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
imshow_opts = dict(origin='lower', cmap=sns.cm.icefire)
data_plot = ax[0].imshow(data.transpose(), **imshow_opts)
shifted_data_ft_plot = ax[1].imshow(data_fft_in_y_remap.real.transpose(), **imshow_opts)
shifted_data_plot = ax[2].imshow(data_ifft.transpose(), **imshow_opts)
for axis_, handle_, title_ in zip(
    ax,
    (data_plot, shifted_data_ft_plot, shifted_data_plot),
    ('Original Data', 'Shifted Data FT', 'Remap back to Original Data'),
):
    plt.colorbar(handle_, ax=axis_)
    axis_.set_title(title_)

## Remap in Fourier Space Using a Complex Exponential (Phase Factor)

In [None]:
# Spectrum of the test field along x (axis 0) and the matching frequency grids.
data_fft = jnp.fft.fft(data, axis=0)
# NOTE(review): fftfreq returns cycles per unit length. A shift-theorem phase
# exp(-i k * shift) needs angular wavenumbers 2*pi*fftfreq(...) — confirm.
# With indexing='ij', kx varies along axis 0 and ky along axis 1.
kx, ky = jnp.meshgrid(jnp.fft.fftfreq(size, dx), jnp.fft.fftfreq(size, dy), indexing='ij')
ky

In [None]:
# Display the x-spectrum (FFT along axis 0) of the test field.
data_fft

In [None]:
# Forward map: multiply by a phase factor; time is taken modulo
# Ly / (shear_rate * Lx) (presumably the remap period — confirm).
# NOTE(review): ky varies along axis 1 (meshgrid indexing='ij') while data_fft was
# transformed along axis 0, and the phase has no x dependence. This is therefore
# not a y-shift proportional to x — confirm the intended formula.
time_since_remap = jnp.mod(time, Ly / (shear_rate * Lx))
map_data_fft = data_fft * jnp.exp(-1j * shear_rate * ky * time_since_remap)
map_data_fft

In [None]:
# Inverse transform along x (axis 0); the result stays complex.
map_data = jnp.fft.ifft(map_data_fft, axis=0)
map_data

In [244]:
# Undo the forward map with the conjugate phase factor (same real exponent, opposite sign).
_inverse_phase = jnp.exp(1j * shear_rate * ky * jnp.mod(time, Ly / (shear_rate * Lx)))
remap_data_fft = map_data_fft * _inverse_phase

In [245]:
# Back to real space along x (axis 0), keeping only the real part.
remap_data = jnp.real(jnp.fft.ifft(remap_data_fft, axis=0))

In [None]:
# Field before mapping, magnitude of the (complex) mapped field, and the field mapped back.
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
panel_images = (
    data.transpose(),
    jnp.abs(map_data).transpose(),
    remap_data.transpose(),
)
data_plot, map_data_plot, remap_data_plot = (
    axis_.imshow(image_, origin='lower', cmap=sns.cm.icefire)
    for axis_, image_ in zip(ax, panel_images)
)
for axis_, handle_, title_ in zip(
    ax,
    (data_plot, map_data_plot, remap_data_plot),
    ('Before Mapping', 'Mapped Data', 'Remapped Data'),
):
    plt.colorbar(handle_, ax=axis_)
    axis_.set_title(title_)

## Check Remap Function

In [247]:
# The same [-pi, pi] x [-pi, pi] square, described as a jax_cfd grid with size x size cells.
domain = ((-jnp.pi, jnp.pi),) * 2
grid = cfd.grids.Grid((size,) * 2, domain=domain)

In [248]:
# Cell-centre coordinates from the jax_cfd grid; index [0] presumably selects the x component.
# NOTE(review): this rebinds the 1-D `x` built earlier from arange — confirm the
# shape returned by grid.mesh is what the following cells expect.
x = grid.mesh(grid.cell_center)[0]