# Workshop on Domain-Specific Languages for Performance-Portable Weather and Climate Models

## Session 2A: Intro to Conditionals and Builtins

This notebook provides an introduction to conditional statements in GT4Py.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import gt4py
from gt4py import gtscript

## Scalar Conditionals
- Conditionals can be specified with a scalar or a field.
- Scalar conditionals behave fairly intuitively and are applied across a field

In [None]:
# Scalar (run-time) conditional: `c` is a scalar stencil parameter, so the same
# branch is selected for every point of the compute domain.
@gtscript.stencil(backend="numpy")
def conditional_stencil(
    in_storage: gtscript.Field[float],
    out_storage: gtscript.Field[float],
    c: float,
):
    with computation(PARALLEL), interval(...):
        if c > 0.0:
            # Local value plus the value at offset +1 along the first axis.
            out_storage = in_storage[0, 0, 0] + in_storage[1, 0, 0]
        else:
            out_storage = 0

In [None]:
# Domain: nx x nx points plus `halo` extra points on each horizontal side, one level.
nx = 100
halo = 3
shape = (nx + 2 * halo, nx + 2 * halo, 1)

# Input storage: zero everywhere, except every other row of the central block set to 2.
in_storage = gt4py.storage.zeros(
    backend="numpy", dtype=float, default_origin=(0, 0, 0), shape=shape
)
in_storage.data[halo + nx // 3 : halo + 2 * nx // 3 : 2, halo + nx // 3 : halo + 2 * nx // 3, :] = 2
# in_storage.data[halo + nx // 3 + 1 : halo + 2 * nx // 3 :2, halo + nx // 3 : halo + 2 * nx // 3, :] = -1

# Output storage, zero-initialised.
out_storage = gt4py.storage.zeros(
    backend="numpy", dtype=float, default_origin=(0, 0, 0), shape=shape
)

# Running the computation
print("Initial Condition")
plt.imshow(in_storage.data[:, :, 0])
plt.show()

# Run the stencil once per branch: first with a positive c, then with a negative one.
for c_value, caption in ((1.0, "Output with c > 0"), (-1.0, "Output with c <= 0")):
    conditional_stencil(
        in_storage,
        out_storage,
        c_value,
        origin=(halo - 1, halo - 1, 0),
        domain=(nx + 1, nx + 1, 1),
    )

    print(caption)
    plt.imshow(out_storage.data[:, :, 0])
    plt.show()

## Field conditionals
Fields can also be used in conditional statements
- This allows different points to be in different branches

In [None]:
# Field conditional: the branch is selected point by point from the value of
# `filter_storage` at offset -1 along the first axis.
@gtscript.stencil(backend="numpy")
def field_conditional_stencil(
    in_storage: gtscript.Field[float],
    out_storage: gtscript.Field[float],
    filter_storage: gtscript.Field[float],
):
    with computation(PARALLEL), interval(...):
        if filter_storage[-1, 0, 0] > 0.0:
            out_storage = 0
        else:
            # Pass the input through unchanged where the filter is not positive.
            out_storage = in_storage[0, 0, 0]

# Domain: nx x nx points plus `halo` extra points on each horizontal side, one level.
nx = 100
halo = 3
shape = (nx + 2 * halo, nx + 2 * halo, 1)

# Filter storage: rows of the central block alternate between +1 and -1 ...
filter_storage = gt4py.storage.zeros(
    backend="numpy", dtype=float, default_origin=(0, 0, 0), shape=shape
)
filter_storage.data[halo + nx // 3 : halo + 2 * nx // 3 : 2, halo + nx // 3 : halo + 2 * nx // 3, :] = 1
filter_storage.data[halo + nx // 3 + 1 : halo + 2 * nx // 3 : 2, halo + nx // 3 : halo + 2 * nx // 3, :] = -1
# ... and every other column of the block then has its sign flipped.
filter_storage.data[halo + nx // 3 : halo + 2 * nx // 3, halo + nx // 3 : halo + 2 * nx // 3 : 2, :] *= -1

# Input storage: row number i of the central block holds the value i + 1.
in_storage = gt4py.storage.zeros(
    backend="numpy", dtype=float, default_origin=(0, 0, 0), shape=shape
)
for i in range(nx // 3):
    in_storage.data[halo + nx // 3 + i, halo + nx // 3 : halo + 2 * nx // 3, :] = i + 1

# Output storage, zero-initialised.
out_storage = gt4py.storage.zeros(
    backend="numpy", dtype=float, default_origin=(0, 0, 0), shape=shape
)

# Running the computation
print("Initial Condition")
plt.imshow(in_storage.data[:, :, 0])
plt.show()
print(in_storage.data[:, 45, 0])


print("Conditional Filter")
plt.imshow(filter_storage.data[:, :, 0])
plt.show()

field_conditional_stencil(
    in_storage,
    out_storage,
    filter_storage,
    origin=(halo - 1, halo - 1, 0),
    domain=(nx + 1, nx + 1, 1),
)

print("Output")
plt.imshow(out_storage.data[:, :, 0])
plt.show()
print(out_storage.data[:, 45, 0])

## More on Field Conditionals
- Field conditionals use a mask to check branching
- This allows some fun self-referential tricks: the condition can test the same field that the branches read

In [None]:
# NOTE: this definition re-uses the name `field_conditional_stencil`, so it
# replaces any stencil previously bound to that name in the kernel.
# Here the condition tests the input field itself (at offset -1 along the
# first axis) rather than a separate filter field.
@gtscript.stencil(backend="numpy")
def field_conditional_stencil(
    in_storage: gtscript.Field[float],
    out_storage: gtscript.Field[float],
):
    with computation(PARALLEL), interval(...):
        if in_storage[-1, 0, 0] > 0.0:
            out_storage = 10
        else:
            # Local value plus the value at offset +1 along the first axis.
            out_storage = in_storage[0, 0, 0] + in_storage[1, 0, 0]

# Domain: nx x nx points plus `halo` extra points on each horizontal side, one level.
nx = 100
halo = 3
shape = (nx + 2 * halo, nx + 2 * halo, 1)

# Input storage: zero everywhere, except every other row of the central block set to 1.
in_storage = gt4py.storage.zeros(
    backend="numpy", dtype=float, default_origin=(0, 0, 0), shape=shape
)
in_storage.data[halo + nx // 3 : halo + 2 * nx // 3 : 2, halo + nx // 3 : halo + 2 * nx // 3, :] = 1

# Output storage, zero-initialised.
out_storage = gt4py.storage.zeros(
    backend="numpy", dtype=float, default_origin=(0, 0, 0), shape=shape
)

# Running the computation
print("Initial Condition")
plt.imshow(in_storage.data[:, :, 0])
plt.show()
print(in_storage.data[:, 45, 0])

field_conditional_stencil(
    in_storage,
    out_storage,
    origin=(halo - 1, halo - 1, 0),
    domain=(nx + 1, nx + 1, 1),
)

print("Output")
plt.imshow(out_storage.data[:, :, 0])
plt.show()
print(out_storage.data[:, 45, 0])

<div class="alert alert-block alert-info">
    <b> Now it's your turn: </b><br>
    (Hint: When you modify code, retain the original code by commenting it out so that you can undo any of the modifications you make.)
    <ol>
        <li style="margin-bottom: 10px">Use conditionals to add a simple flux limiter <code>x_flux = min(0, sign(x_flux, delta_x))</code> to the <code>diffusion</code> stencil below. Python does not have an equivalent to Fortran's <code>sign</code>, so you can either define your own as a separate function or write the equivalent code in the stencil.</li>
        <li style="margin-bottom: 10px">Open the <code>.gt_cache</code> directory and inspect the generated code.</li>
    </ol>
</div>

In [None]:
# Solution 1
# Starting point for the exercise: a diffusion-like stencil built from a
# Laplacian and its differences. No flux limiter is applied here yet.
@gtscript.stencil(backend="numpy")
def diffusion(
    in_storage: gtscript.Field[float], out_storage: gtscript.Field[float], alpha: float
):

    with computation(PARALLEL), interval(...):
        # Five-point Laplacian of the input over the first two axes.
        lap = (
            -4.0 * in_storage[0, 0, 0]
            + in_storage[1, 0, 0]
            + in_storage[-1, 0, 0]
            + in_storage[0, 1, 0]
            + in_storage[0, -1, 0]
        )
        
        # Forward differences of the Laplacian along the first and second axis.
        x_flux = lap[1, 0, 0] - lap[0, 0, 0]
        y_flux = lap[0, 1, 0] - lap[0, 0, 0]
        
        # Update: input minus alpha times the backward differences of the two fluxes.
        out_storage = in_storage - alpha * (x_flux[0, 0, 0] - x_flux[-1, 0, 0] + y_flux[0, 0, 0] - y_flux[0, -1, 0])

## Externals and Inlining Conditionals
All stencil arguments are either fields or parameters and are read at runtime. Anything else, however, is an external and is treated as a compile-time constant. Externals are set when the stencil is decorated with `gtscript.stencil`, and the compiled code will have their values substituted.
<div class="alert alert-block alert-info">
    <b> Now it's your turn: </b><br>
    <ol>
        <li style="margin-bottom: 10px">Open the <code>.gt_cache</code> directory and inspect the <code>example_external_stencil</code> stencil generated by this:</li>
    </ol>
</div>

In [None]:
# `EXT` is neither a field nor a parameter of the stencil below; it is a
# module-level name referenced from inside the stencil body.
EXT = 5.0
@gtscript.stencil(backend="numpy")
def example_external_stencil(
    in_storage: gtscript.Field[float],
    out_storage: gtscript.Field[float],
):
    with computation(PARALLEL), interval(...):
        # Plain `if` on the external (no __INLINED).
        if EXT > 0.0:
            out_storage = in_storage[0, 0, 0] + in_storage[1, 0, 0]
        else:
            out_storage = 0.0

# Fresh zero-initialised input and output storages; `shape`, `halo` and `nx`
# are taken from an earlier cell.
in_storage = gt4py.storage.zeros(
    backend="numpy", dtype=float, default_origin=(0, 0, 0), shape=shape
)

out_storage = gt4py.storage.zeros(
    backend="numpy", dtype=float, default_origin=(0, 0, 0), shape=shape
)

# Run the stencil once.
example_external_stencil(
    in_storage,
    out_storage,
    origin=(halo - 1, halo - 1, 0),
    domain=(nx + 1, nx + 1, 1),
)

Notably, the compiled stencil contains `__condition_1 = 5.0 > 0.0` and `__condition_1` is then evaluated in subsequent if-statements. To avoid this, you can use `__INLINED` to have the compiler evaluate the conditional and only generate the relevant branch of code.
<div class="alert alert-block alert-info">
    <b> Now it's your turn: </b><br>
    <ol>
        <li style="margin-bottom: 10px">Compare the last stencil in <code>.gt_cache</code> to the <code>inlined_conditional_stencil</code> stencil generated here:</li>
    </ol>
</div>

In [None]:
# Value of the external at the moment the stencil below is decorated.
EXT = 5.0


# Same computation as `example_external_stencil`, but the condition is wrapped
# in __INLINED.
@gtscript.stencil(backend="numpy")
def inlined_conditional_stencil(
    in_storage: gtscript.Field[float],
    out_storage: gtscript.Field[float],
):
    with computation(PARALLEL), interval(...):
        if __INLINED(EXT > 0.0):
            out_storage = in_storage[0, 0, 0] + in_storage[1, 0, 0]
        else:
            out_storage = 0.0
            
# `in_storage` comes from an earlier cell: set every other row of its central block to 2.
in_storage.data[halo + nx // 3 : halo + 2 * nx // 3 : 2, halo + nx // 3 : halo + 2 * nx // 3, :] = 2
out_storage = gt4py.storage.zeros(
    backend="numpy", dtype=float, default_origin=(0, 0, 0), shape=shape
)

# Running the computation
print("Initial Condition")
plt.imshow(in_storage.data[:, :, 0])
plt.show()

inlined_conditional_stencil(
    in_storage,
    out_storage,
    origin=(halo - 1, halo - 1, 0),
    domain=(nx + 1, nx + 1, 1),
)

print("Output 1")
plt.imshow(out_storage.data[:, :, 0])
plt.show()

# Can we switch branches? Rebind EXT after decoration and call the same stencil again.
EXT = -1.0

inlined_conditional_stencil(
    in_storage,
    out_storage,
    origin=(halo - 1, halo - 1, 0),
    domain=(nx + 1, nx + 1, 1),
)

print("Output trying to change EXT")
plt.imshow(out_storage.data[:, :, 0])
plt.show()

Now there is no conditional check in the cached stencil. However, we can't change EXT to switch the branch of the conditional. Since EXT was set to 5 when the stencil was decorated, the stencil was compiled with `EXT=5`, and changing EXT afterwards won't change the behavior.

Also, when coded this way, EXT had to be set before the stencil definition, which is not very pythonic. Instead we could decorate the stencil later, when it is called:

In [None]:
# Undecorated stencil definition: it is compiled later by passing it to
# `gtscript.stencil(definition=...)`.
def inlined_conditional_stencil(
    in_storage: gtscript.Field[float],
    out_storage: gtscript.Field[float],
):
    with computation(PARALLEL), interval(...):
        if __INLINED(EXT > 0.0):
            out_storage = in_storage[0, 0, 0] + in_storage[1, 0, 0]
        else:
            out_storage = 0.0
            
# `in_storage` comes from an earlier cell: set every other row of its central block to 2.
in_storage.data[halo + nx // 3 : halo + 2 * nx // 3 : 2, halo + nx // 3 : halo + 2 * nx // 3, :] = 2
out_storage = gt4py.storage.zeros(
    shape=shape,
    default_origin=(0, 0, 0),
    dtype=float,
    backend="numpy",
)

# Running the computation
print("Initial Condition")
plt.imshow(in_storage.data[:, :, 0])
plt.show()

EXT = -1.0

# Decorate the definition now that EXT has been set, then actually run the
# resulting stencil. Without the call, the plot below would only show the
# untouched, zero-initialised `out_storage`.
stencil_call = gtscript.stencil(definition=inlined_conditional_stencil, backend="numpy")
stencil_call(in_storage, out_storage, origin=(halo - 1, halo - 1, 0), domain=(nx + 1, nx + 1, 1))

print("Output")
plt.imshow(out_storage.data[:, :, 0])
plt.show()

EXT = 1.0

# Decorate and run again after rebinding EXT, to see whether the branch changes.
stencil_call = gtscript.stencil(definition=inlined_conditional_stencil, backend="numpy")
stencil_call(in_storage, out_storage, origin=(halo - 1, halo - 1, 0), domain=(nx + 1, nx + 1, 1))

print("Output trying to change EXT again")
plt.imshow(out_storage.data[:, :, 0])
plt.show()

EXT is still frozen at the first decoration of our stencil, though. If we wanted to invoke our stencil with a different EXT we would have to pass it as an external:

In [None]:
# Undecorated stencil definition whose external `EXT` is imported explicitly
# from `__externals__`; its value is supplied through the `externals` argument
# of `gtscript.stencil` when the definition is compiled.
def inlined_conditional_stencil(
    in_storage: gtscript.Field[float],
    out_storage: gtscript.Field[float],
):
    from __externals__ import EXT

    with computation(PARALLEL), interval(...):
        if __INLINED(EXT > 0.0):
            out_storage = in_storage[0, 0, 0] + in_storage[1, 0, 0]
        else:
            out_storage = 0.0
            
# `in_storage` comes from an earlier cell: set every other row of its central block to 2.
in_storage.data[halo + nx // 3 : halo + 2 * nx // 3 : 2, halo + nx // 3 : halo + 2 * nx // 3, :] = 2
out_storage = gt4py.storage.zeros(
    backend="numpy", dtype=float, default_origin=(0, 0, 0), shape=shape
)

# Running the computation
print("Initial Condition")
plt.imshow(in_storage.data[:, :, 0])
plt.show()

# Compile and run the stencil once per EXT value, passing EXT explicitly as an external.
for ext_value, caption in ((1.0, "Output"), (-1.0, "Output successfully changing EXT")):
    EXT = ext_value

    stencil_call = gtscript.stencil(
        definition=inlined_conditional_stencil,
        backend="numpy",
        externals={"EXT": EXT},
    )
    stencil_call(
        in_storage,
        out_storage,
        origin=(halo - 1, halo - 1, 0),
        domain=(nx + 1, nx + 1, 1),
    )

    print(caption)
    plt.imshow(out_storage.data[:, :, 0])
    plt.show()

This is a bit clunky, since we have to re-decorate (and thus recompile) the stencil to generate the new code, but you can do it.

### Why Inline?
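It's faster! The compiler evaluates the conditional once and only generates the relevant branch, so nothing is left to check at run time.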

In [None]:
# Module-level constant referenced inside the stencil below.
C = 1.0


# Variant whose condition on `C` is wrapped in __INLINED.
@gtscript.stencil(backend="numpy")
def inlined_conditional_stencil(
    in_storage: gtscript.Field[float],
    out_storage: gtscript.Field[float],
):
    with computation(PARALLEL), interval(...):
        if __INLINED(C > 0.0):
            out_storage = in_storage[0, 0, 0] + in_storage[1, 0, 0]
        else:
            out_storage = 0.0

# Run-time counterpart of `inlined_conditional_stencil`: the condition tests
# the scalar parameter `c` passed on every call.
@gtscript.stencil(backend="numpy")
def conditional_stencil(
    in_storage: gtscript.Field[float],
    out_storage: gtscript.Field[float],
    c: float,
):
    with computation(PARALLEL), interval(...):
        if c > 0.0:
            out_storage = in_storage[0, 0, 0] + in_storage[1, 0, 0]
        else:
            out_storage = 0

# Setting up the domain
nx = 100
halo = 3
shape = (nx+2*halo, nx+2*halo, 1)

# and the storages
in_storage = gt4py.storage.zeros(
    shape=shape,
    default_origin=(0, 0, 0),
    dtype=float,
    backend="numpy",
)
# Every other row of the central block is set to 2.
in_storage.data[halo + nx // 3 : halo +  2 * nx  // 3 : 2, halo + nx // 3 : halo + 2 * nx // 3, :] = 2
out_storage = gt4py.storage.zeros(
    shape=shape,
    default_origin=(0, 0, 0),
    dtype=float,
    backend="numpy",
)
            

# Timing the regular stencil:
print("Runtime conditional:")
fields = {"in_storage": in_storage, "out_storage": out_storage}
scalars = {"c": 1.0}
# Dictionary passed to the stencil call; the timing keys read below are expected in it afterwards.
exec_info = {}

%timeit conditional_stencil(**fields, **scalars, origin=(halo-1, halo-1, 0), domain=(nx+1, nx+1, 1), exec_info=exec_info)

# NOTE(review): %timeit calls the stencil many times with the same dict, so these
# values presumably describe only the last invocation -- confirm.
call_time = exec_info['call_end_time']-exec_info['call_start_time']
run_time = exec_info['run_end_time']-exec_info['run_start_time']
print(f"  call_time = {call_time * 1000.} ms")
print(f"  run_time = {run_time * 1000.} ms")
print(f"  overhead = {(call_time - run_time)*1000.} ms")
print("")

# And the inlined version:
print("Inlined conditional:")
# Start from an empty dict so no timings from the previous measurement carry over.
exec_info = {}

%timeit inlined_conditional_stencil(**fields, origin=(halo-1, halo-1, 0), domain=(nx+1, nx+1, 1), exec_info=exec_info)

call_time = exec_info['call_end_time']-exec_info['call_start_time']
run_time = exec_info['run_end_time']-exec_info['run_start_time']
print(f"  call_time = {call_time * 1000.} ms")
print(f"  run_time = {run_time * 1000.} ms")
print(f"  overhead = {(call_time - run_time)*1000.} ms")
print("")

<div class="alert alert-block alert-info">
    <b> Now it's your turn: </b><br>
    <ol>
        <li style="margin-bottom: 10px">Write a <code>fill_horizontal</code> stencil function to fill sparse, small negative values in the input fields using neighboring values along the i-direction. Does your solution work if the first or last value is negative?</li>
        <li style="margin-bottom: 10px">Starting from <code>fill_horizontal</code>, write a <code>fill_horizontal_direction</code> stencil function that can fill along the i or j directions based on a run-time conditional.</li>
        <li style="margin-bottom: 10px">Adapt the <code>fill_horizontal_direction</code> stencil function into <code>fill_horizontal_inlined</code> so that the direction is specified at compile time.</li>
        <li style="margin-bottom: 10px">Open the <code>.gt_cache</code> directory and inspect the generated code.</li>
    </ol>
</div>

In [None]:
# remove solution
# Solution 1
# Placeholder for the exercise: currently just copies the input to the output.
@gtscript.stencil(backend="numpy")
def fill_horizontal(
    in_storage: gtscript.Field[float],
    out_storage: gtscript.Field[float],
):
    with computation(PARALLEL), interval(...):
        out_storage = in_storage

# Solution 2
# Placeholder for the exercise: `fill_direction` is accepted but not used yet;
# the stencil currently just copies the input to the output.
@gtscript.stencil(backend="numpy")
def fill_horizontal_direction(
    in_storage: gtscript.Field[float],
    out_storage: gtscript.Field[float],
    fill_direction: int,
):
    with computation(PARALLEL), interval(...):
        out_storage = in_storage

                
# Solution 3
# Placeholder for the exercise: currently just copies the input to the output.
@gtscript.stencil(backend="numpy")
def fill_horizontal_inlined(
    in_storage: gtscript.Field[float],
    out_storage: gtscript.Field[float],
):
    with computation(PARALLEL), interval(...):
        out_storage = in_storage
                
                
# Domain: nx x nx points plus `halo` extra points on each horizontal side, one level.
nx = 101
halo = 3
alpha = 1.0 / 32.0
shape = (nx + 2 * halo, nx + 2 * halo, 1)

# Test field 1: central block set to 3, with every other row of it overwritten by -1.
in_storage_1 = gt4py.storage.zeros(
    shape=shape,
    default_origin=(0, 0, 0),
    dtype=float,
    backend="numpy",
)

in_storage_1.data[halo + nx // 3 : halo + 2 * nx // 3, halo + nx // 3 : halo + 2 * nx // 3, :] = 3
in_storage_1.data[halo + nx // 3 : halo + 2 * nx // 3 : 2, halo + nx // 3 : halo + 2 * nx // 3, :] = -1

# Test field 2: central block set to 3, with every other column of it overwritten by -1.
in_storage_2 = gt4py.storage.zeros(
    shape=shape,
    default_origin=(0, 0, 0),
    dtype=float,
    backend="numpy",
)

in_storage_2.data[halo + nx // 3 : halo + 2 * nx // 3, halo + nx // 3 : halo + 2 * nx // 3, :] = 3
in_storage_2.data[halo + nx // 3 : halo + 2 * nx // 3, halo + nx // 3 : halo + 2 * nx // 3 : 2, :] = -1

out_storage = gt4py.storage.zeros(
    shape=shape,
    default_origin=(0, 0, 0),
    dtype=float,
    backend="numpy",
)

# Each report prints one column of the field and the field mean, so the mean
# before and after filling can be compared.
print("Initial Condition 1")
plt.imshow(in_storage_1.data[:, :, 0])
plt.show()
# Mean of the field actually used as input here (not a stale `in_storage` from an earlier cell).
print(in_storage_1.data[:, 45, 0], np.mean(in_storage_1.data))

fill_horizontal(in_storage_1, out_storage, origin=(halo, halo, 0), domain=(nx, nx, 1))

print("Output 1")
plt.imshow(out_storage.data[:, :, 0])
plt.show()
print(out_storage.data[:, 45, 0], np.mean(out_storage.data))



print("Initial Condition 2")
plt.imshow(in_storage_2.data[:, :, 0])
plt.show()
# Mean of the field actually used as input here (not a stale `in_storage` from an earlier cell).
print(in_storage_2.data[:, 45, 0], np.mean(in_storage_2.data))

fill_horizontal_direction(in_storage_2, out_storage, 1, origin=(halo, halo, 0), domain=(nx, nx, 1))

print("Output 2")
plt.imshow(out_storage.data[:, :, 0])
plt.show()
print(out_storage.data[:, 45, 0], np.mean(out_storage.data))



fill_horizontal_inlined(in_storage_2, out_storage, origin=(halo, halo, 0), domain=(nx, nx, 1))

print("Output 3")
plt.imshow(out_storage.data[:, :, 0])
plt.show()
print(out_storage.data[:, 45, 0], np.mean(out_storage.data))