# Test for Boundary Conditions

In [1]:
import jax.numpy as jnp

In [None]:
# Number of cells used to divide Ly into cells of width dy.
size = 4

# Both extents are 2*pi.
Lx = Ly = 2 * jnp.pi

# Width of one cell along y.
dy = Ly / size

print(dy)

In [None]:
shear_rate = 1.0
time = 11 / 8

# Unwrapped shear displacement in the y direction.
total_shift = shear_rate * Lx * time
# Same displacement taken modulo Ly.
dL = jnp.mod(total_shift, Ly)

print(total_shift, dL)

In [None]:
# Shear displacement expressed in units of the cell width.
cells_shifted = dL / dy

# Whole number of cells to shift.
m = jnp.floor(cells_shifted).astype(int)
# Remaining fractional shift within a single cell.
eps = cells_shifted - m

print(m, eps)

In [None]:
# 4x4 float32 sample field; entry (i, j) equals i + j + 1.
x = jnp.asarray(
    [
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6],
        [4, 5, 6, 7],
    ],
    dtype=jnp.float32,
)
x

In [6]:
# Signed padding width along axis 0: a positive value appends rows after the
# last row, a negative value prepends rows before the first row.
width = 1

# Padding specification handed to jnp.pad; axis 1 is never padded.
if width > 0:
    full_padding = ((0, width), (0, 0))
elif width < 0:
    full_padding = ((-width, 0), (0, 0))
else:
    # Without this branch full_padding would be left undefined (or stale from
    # an earlier notebook run) and the failure would surface in a later cell.
    raise ValueError("width must be non-zero")


In [None]:
# Opposite of the sign of width: 1 for lower bound, -1 for upper bound.
sign = jnp.negative(jnp.sign(width))

# Extend x with wrapped-around copies of its own rows as described by
# full_padding.
data = jnp.pad(x, pad_width=full_padding, mode="wrap")

print(data)

In [None]:
# Shift every row along axis 1 by the whole-cell part of the displacement,
# in the direction selected by sign.
whole_cell_shift = sign * m
data = jnp.roll(data, shift=whole_cell_shift, axis=1)

print(data)

In [None]:
# Copy of data shifted one further cell along axis 1 (direction given by
# sign); it supplies the second value for the sub-cell interpolation.
neighbour = jnp.roll(data, shift=sign, axis=1)

print(neighbour)

In [None]:
# Linear interpolation between each cell and its shifted neighbour, weighted
# by the fractional shift eps.
own_weight = 1.0 - eps
data = own_weight * data + eps * neighbour

print(data)

In [None]:
# Put the original, unshifted values of x back into the rows that are not
# padding, so only the padded rows keep the shifted/interpolated values.
# The slice bound is -width (not sign) so the slice has exactly x's row count
# for any padding width, not only for |width| == 1.
if width < 0:
    # Padding was prepended: the original rows start at index -width.
    data = data.at[-width:, :].set(x)
elif width > 0:
    # Padding was appended: the original rows end width rows before the end.
    data = data.at[:-width, :].set(x)

print(data)

In [None]:
# Offset applied to the padded rows: sign * shear_rate * Lx.
offset = sign * shear_rate * Lx
correction = data + offset

# Keep the offset only on the padded rows; the remaining rows are written
# back unchanged. A zero width leaves data untouched.
if width != 0:
    interior = slice(-width, None) if width < 0 else slice(None, -width)
    data = correction.at[interior].set(data[interior])

print(data)