# Workshop on Domain-Specific Languages for Performance-Portable Weather and Climate Models

## Session 2A: Intro to Conditionals and Builtins

This notebook provides an introduction to conditional statements in GT4Py.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy

import gt4py
from gt4py import gtscript
from tools import plot_two_ij_slices

In [None]:
# Domain: nx interior points in each horizontal direction, padded by `halo`
# points on every side, with a single vertical level.
nx = 100
halo = 3
shape = (nx + 2 * halo, nx + 2 * halo, 1)

# Both storages are zero-initialised with identical allocation settings.
storage_kwargs = dict(shape=shape, default_origin=(0, 0, 0), dtype=float, backend="numpy")
in_storage = gt4py.storage.zeros(**storage_kwargs)
out_storage = gt4py.storage.zeros(**storage_kwargs)

## Scalar Conditionals
- Conditionals can be specified with a scalar or a field.
- Scalar conditionals behave intuitively: the same branch is applied at every point of the field

In [None]:
# Runtime scalar conditional: `c` is a stencil parameter read at call time, so
# every grid point in the domain takes the same branch on a given call.
@gtscript.stencil(backend="numpy")
def conditional_stencil(
    in_storage: gtscript.Field[float], out_storage: gtscript.Field[float], c: float
):

    with computation(PARALLEL), interval(...):
        if c > 0.:
            # Sum of each point and its neighbour at offset +1 in the first (i) index.
            out_storage = (
                in_storage[0, 0, 0] + in_storage[1, 0, 0]
            )
        else:
            out_storage = 0

If we pass `c=1` in our stencil call, every point executes the `c > 0` branch:

In [None]:
# Set every other i-row inside the central third of the domain to 2.
rows = slice(halo + nx // 3, halo + 2 * nx // 3, 2)
cols = slice(halo + nx // 3, halo + 2 * nx // 3)
in_storage.data[rows, cols, :] = 2

# c = 1.0 satisfies c > 0, so every point takes the summing branch.
conditional_stencil(in_storage, out_storage, 1.0, origin=(halo - 1, halo - 1, 0), domain=(nx + 1, nx + 1, 1))

print("c > 0")
plot_two_ij_slices(in_storage, out_storage)

And if we instead pass `c=-1`, every point executes the `else` branch:

In [None]:
# c = -1.0 fails c > 0, so every point takes the `else` branch and the output
# is set to 0 throughout the compute domain.
conditional_stencil(in_storage, out_storage, -1.0, origin=(halo-1, halo-1, 0), domain=(nx+1, nx+1, 1))

# Label the plot the same way as the c > 0 run so the two outputs can be told apart.
print("c < 0")
plot_two_ij_slices(in_storage, out_storage)

## Field conditionals
Fields can also be used in conditional statements
- This allows different points to be in different branches

In [None]:
# Field conditional: the branch is chosen independently at each grid point
# from the value of `filter_storage` at offset -1 in the first (i) index.
@gtscript.stencil(backend="numpy")
def field_conditional_stencil(
    in_storage: gtscript.Field[float], out_storage: gtscript.Field[float], filter_storage: gtscript.Field[float]
):
    with computation(PARALLEL), interval(...):
        # Zero the output where the filter's i-1 neighbour is positive;
        # elsewhere copy the input through unchanged.
        if filter_storage[-1, 0, 0] > 0.:
            out_storage = 0
        else:
            out_storage = in_storage[0, 0, 0]


# Filter field: a +1/-1 checkerboard over the central third of the domain,
# zero elsewhere. Alternate i-rows get +1/-1, then every other j-column is negated.
filter_storage = gt4py.storage.zeros(
    shape=shape,
    default_origin=(0, 0, 0),
    dtype=float,
    backend="numpy",
)
filter_storage.data[halo + nx // 3 : halo +  2 * nx  // 3 : 2, halo + nx // 3 : halo + 2 * nx // 3, :] = 1
filter_storage.data[halo + nx // 3 + 1 : halo + 2 * nx // 3 :2, halo + nx // 3 : halo + 2 * nx // 3, :] = -1
filter_storage.data[halo + nx // 3 : halo + 2 * nx // 3, halo + nx // 3 : halo + 2 * nx // 3 : 2, :] *= -1

# Input field: a ramp in i (values 1 .. nx//3) across the central third.
in_storage = gt4py.storage.zeros(
    shape=shape,
    default_origin=(0, 0, 0),
    dtype=float,
    backend="numpy",
)
for i in range(nx//3):
    in_storage.data[halo + nx // 3 + i, halo + nx // 3 : halo + 2 * nx // 3, :] = i+1

# Show the filter that selects the branch at each point. The array is indexed
# [i, j], so imshow puts i on the vertical axis and j on the horizontal one.
fig, ax = plt.subplots()
image = ax.imshow(filter_storage.data[:, :, 0])
ax.set_title("Conditional Filter")
ax.set_xlabel("j")
ax.set_ylabel("i")
fig.colorbar(image, ax=ax)
plt.show()

# running the computation
field_conditional_stencil(in_storage, out_storage, filter_storage, origin=(halo-1, halo-1, 0), domain=(nx+1, nx+1, 1))

plot_two_ij_slices(in_storage, out_storage)

## More on Field Conditionals
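- Field conditionals are evaluated point by point, so the condition acts as a mask that selects the branch at each grid point
- The condition can even read the same field (at an offset) that the branches use, which allows some fun self-referential code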

In [None]:
# Self-referential field conditional: the condition reads `in_storage` at
# offset -1 in i, and the else-branch reads the same field at offsets 0 and +1.
# NOTE(review): this redefines `field_conditional_stencil` with a different
# signature (no filter argument), shadowing the previous cell's stencil.
@gtscript.stencil(backend="numpy")
def field_conditional_stencil(
    in_storage: gtscript.Field[float], out_storage: gtscript.Field[float]
):

    with computation(PARALLEL), interval(...):
        if in_storage[-1, 0, 0] > 0.:
            out_storage = 3
        else:
            out_storage = in_storage[0, 0, 0] + in_storage[1, 0, 0]


# Fresh input: every other i-row of the central third is 1, zero elsewhere.
in_storage = gt4py.storage.zeros(shape=shape, default_origin=(0, 0, 0), dtype=float, backend="numpy")
rows = slice(halo + nx // 3, halo + 2 * nx // 3, 2)
cols = slice(halo + nx // 3, halo + 2 * nx // 3)
in_storage.data[rows, cols, :] = 1

# Inspect a single j-column before and after the stencil runs.
probe_j = 45
print(in_storage.data[:, probe_j, 0])

field_conditional_stencil(in_storage, out_storage, origin=(halo - 1, halo - 1, 0), domain=(nx + 1, nx + 1, 1))
plot_two_ij_slices(in_storage, out_storage)

print(out_storage.data[:, probe_j, 0])

<div class="alert alert-block alert-info">
    <b> Now it's your turn: </b><br>
    (Hint: When you modify code, keep the original by commenting it out so that you can undo your modifications.)
    <ol>
        <li style="margin-bottom: 10px">Use conditionals to add a simple flux limiter <code>x_flux = min(0, sign(x_flux, delta_x))</code> to the <code>diffusion</code> stencil below. GTScript does not provide a builtin equivalent to Fortran's <code>sign(a, b)</code> (the magnitude of <code>a</code> with the sign of <code>b</code>), so you can either define your own as a separate function or write the equivalent code in the stencil.</li>
        <li style="margin-bottom: 10px">Open the <code>.gt_cache</code> directory and inspect the generated code.</li>
    </ol>
</div>

In [None]:
# Solution 1
# Exercise starting point: diffusion without a flux limiter yet.
# `lap` is the 5-point Laplacian of the input, `x_flux`/`y_flux` are its
# forward differences, and the output subtracts alpha times their
# backward-differenced divergence (a discrete biharmonic operator).
@gtscript.stencil(backend="numpy")
def diffusion(
    in_storage: gtscript.Field[float], out_storage: gtscript.Field[float], alpha: float
):

    with computation(PARALLEL), interval(...):
        # 5-point Laplacian of the input.
        lap = (
            -4.0 * in_storage[0, 0, 0]
            + in_storage[1, 0, 0]
            + in_storage[-1, 0, 0]
            + in_storage[0, 1, 0]
            + in_storage[0, -1, 0]
        )

        # Fluxes between each point and its +1 neighbour; a flux limiter
        # (the exercise) would modify these before they are used below.
        x_flux = lap[1, 0, 0] - lap[0, 0, 0]
        y_flux = lap[0, 1, 0] - lap[0, 0, 0]

        # Flux divergence. Reading the fluxes at -1 means the input is read up
        # to two points away in each horizontal direction.
        out_storage = in_storage - alpha * (x_flux[0, 0, 0] - x_flux[-1, 0, 0] + y_flux[0, 0, 0] - y_flux[0, -1, 0])
        
alpha = 1./32.
n_steps = 500
in_storage.data[halo + nx // 3 : halo +  2 * nx  // 3, halo + nx // 3 : halo + 2 * nx // 3, :] = 1

# Ping-pong between two buffers instead of passing the same storage as both
# input and output. The stencil reads its input at horizontal offsets, so
# aliasing input and output only works on the numpy backend because each
# statement is evaluated over whole arrays; on parallel backends it is a
# read/write race.
src = deepcopy(in_storage)
dst = deepcopy(in_storage)
for _ in range(n_steps):
    diffusion(src, dst, alpha, origin=(halo-1, halo-1, 0), domain=(nx+1, nx+1, 1))
    src, dst = dst, src
# After the final swap, the most recent result is in `src`.
out_storage = src
plot_two_ij_slices(in_storage, out_storage)

## Externals and Inlining Conditionals
Any stencil arguments are either fields or parameters and are read at run time. Any other name used inside the stencil, such as a module-level variable, is an external and treated as a compile-time constant. Externals are set when the stencil is decorated with `gtscript.stencil`, and the compiled code will have the value substituted. 
<div class="alert alert-block alert-info">
    <b> Now it's your turn: </b><br>
    <ol>
        <li>Open the <code>.gt_cache</code> directory and inspect the <code>example_external_stencil</code> stencil generated by the cell below.</li>
    </ol>
</div>

In [None]:
# Module-level constant read inside the stencil: GT4Py treats it as an
# external and substitutes its value when the stencil is decorated.
EXT = 5.
@gtscript.stencil(backend="numpy")
def example_external_stencil(
    in_storage: gtscript.Field[float], out_storage: gtscript.Field[float]
):
    with computation(PARALLEL), interval(...):
        # Plain `if` on an external: the generated code still contains the
        # condition check, even though its value is fixed at compile time.
        if EXT > 0.:
            out_storage = (
                in_storage[0, 0, 0] + in_storage[1, 0, 0]
            )
        else:
            out_storage = 0.0

# Fresh, all-zero input and output storages for the externals demo.
in_storage, out_storage = (
    gt4py.storage.zeros(shape=shape, default_origin=(0, 0, 0), dtype=float, backend="numpy")
    for _ in range(2)
)

example_external_stencil(in_storage, out_storage, origin=(halo - 1, halo - 1, 0), domain=(nx + 1, nx + 1, 1))

Notably, the compiled stencil contains `__condition_1 = 5.0 > 0.0` and `__condition_1` is then evaluated in subsequent if-statements. To avoid this, you can use `__INLINED` to have the compiler evaluate the conditional and only generate the relevant branch of code.
<div class="alert alert-block alert-info">
    <b> Now it's your turn: </b><br>
    <ol>
        <li>Compare the <code>example_external_stencil</code> code in <code>.gt_cache</code> to the <code>inlined_conditional_stencil</code> stencil generated by the cell below.</li>
    </ol>
</div>

In [None]:
# External value captured when the stencil below is decorated.
EXT = 5.

@gtscript.stencil(backend="numpy")
def inlined_conditional_stencil(
    in_storage: gtscript.Field[float], out_storage: gtscript.Field[float]
):
    with computation(PARALLEL), interval(...):
        # __INLINED: the compiler evaluates the condition and generates only
        # the selected branch.
        if __INLINED(EXT > 0.):
            out_storage = (
                in_storage[0, 0, 0] + in_storage[1, 0, 0]
            )
        else:
            out_storage = 0.0
            
# Every other i-row of the central third is set to 2; the output starts at zero.
rows = slice(halo + nx // 3, halo + 2 * nx // 3, 2)
cols = slice(halo + nx // 3, halo + 2 * nx // 3)
in_storage.data[rows, cols, :] = 2
out_storage = gt4py.storage.zeros(shape=shape, default_origin=(0, 0, 0), dtype=float, backend="numpy")

print("EXT = {0}".format(EXT))

inlined_conditional_stencil(in_storage, out_storage, origin=(halo - 1, halo - 1, 0), domain=(nx + 1, nx + 1, 1))

plot_two_ij_slices(in_storage, out_storage)

Now there is no conditional check in the cached stencil. 

But what happens if we change `EXT` and try to rerun?

In [None]:
# Can we switch branches? EXT is changed *after* the stencil was decorated.
EXT = -1.0

print("EXT = {0}".format(EXT))

inlined_conditional_stencil(in_storage, out_storage, origin=(halo - 1, halo - 1, 0), domain=(nx + 1, nx + 1, 1))
plot_two_ij_slices(in_storage, out_storage)

We're still evaluating as though `EXT>0`. Externals are read at compile time, and their values are frozen when the stencil decorator is invoked. Since `EXT` was 5 when `@gtscript.stencil` was applied, that value is baked into the generated code, and every subsequent call to `inlined_conditional_stencil` uses it. To change the behavior, you have to re-decorate the definition after changing `EXT`. This is easiest to do by calling the decorator explicitly as a function:

In [None]:
# Undecorated definition function. It is compiled later by calling
# gtscript.stencil(definition=...), which captures the value EXT has at that time.
def inlined_conditional_function(
    in_storage: gtscript.Field[float], out_storage: gtscript.Field[float]
):
    with computation(PARALLEL), interval(...):
        # Only the branch selected by EXT at compile time is generated.
        if __INLINED(EXT > 0.):
            out_storage = (
                in_storage[0, 0, 0] + in_storage[1, 0, 0]
            )
        else:
            out_storage = 0.0
            
# Every other i-row of the central third is set to 2; the output starts at zero.
rows = slice(halo + nx // 3, halo + 2 * nx // 3, 2)
cols = slice(halo + nx // 3, halo + 2 * nx // 3)
in_storage.data[rows, cols, :] = 2
out_storage = gt4py.storage.zeros(shape=shape, default_origin=(0, 0, 0), dtype=float, backend="numpy")

# Set the external *before* compiling; the decorator is called as a function.
EXT = -1.0
print("EXT = {0}".format(EXT))

stencil_call = gtscript.stencil(
    definition=inlined_conditional_function,
    backend="numpy",
)
stencil_call(in_storage, out_storage, origin=(halo - 1, halo - 1, 0), domain=(nx + 1, nx + 1, 1))

plot_two_ij_slices(in_storage, out_storage)

Now if we want to change `EXT` in the generated stencil we can redefine `stencil_call` after changing the value for `EXT`:

In [None]:
# Recompile after changing EXT so the new value ends up in the generated code.
EXT = 1.0
print("EXT = {0}".format(EXT))

stencil_call = gtscript.stencil(
    definition=inlined_conditional_function,
    backend="numpy",
)
stencil_call(in_storage, out_storage, origin=(halo - 1, halo - 1, 0), domain=(nx + 1, nx + 1, 1))

plot_two_ij_slices(in_storage, out_storage)

You can also use a dictionary to pass externals through the decorator explicitly. To do this, set the `externals` argument of `gtscript.stencil()` to a dictionary mapping external names to their values, and import those externals inside the stencil function with `from __externals__ import ...`. 

This can be useful if you want to change names inside the stencil, for example:

In [None]:
# Definition function whose external EXT2 is supplied explicitly through the
# `externals=` dict of gtscript.stencil rather than looked up from globals.
# NOTE(review): this rebinds the name `inlined_conditional_stencil`, which
# previously referred to a compiled stencil, to an undecorated function.
def inlined_conditional_stencil(
    in_storage: gtscript.Field[float], out_storage: gtscript.Field[float]
):
    from __externals__ import EXT2
    with computation(PARALLEL), interval(...):
        if __INLINED(EXT2 > 0.):
            out_storage = (
                in_storage[0, 0, 0] + in_storage[1, 0, 0]
            )
        else:
            out_storage = 0.0
            
# Every other i-row of the central third is set to 2.
rows = slice(halo + nx // 3, halo + 2 * nx // 3, 2)
cols = slice(halo + nx // 3, halo + 2 * nx // 3)
in_storage.data[rows, cols, :] = 2

# Running the computation, binding the external EXT2 through `externals=`.
my_value = 1.0
my_other_value = -1.0

stencil_call = gtscript.stencil(
    definition=inlined_conditional_stencil,
    backend="numpy",
    externals={"EXT2": my_value},
)
stencil_call(in_storage, out_storage, origin=(halo - 1, halo - 1, 0), domain=(nx + 1, nx + 1, 1))

plot_two_ij_slices(in_storage, out_storage)

In [None]:
# Recompile the same definition with the other value bound to EXT2.
stencil_call = gtscript.stencil(
    definition=inlined_conditional_stencil,
    backend="numpy",
    externals={"EXT2": my_other_value},
)
stencil_call(in_storage, out_storage, origin=(halo - 1, halo - 1, 0), domain=(nx + 1, nx + 1, 1))

plot_two_ij_slices(in_storage, out_storage)

### Why Inline?
It's faster! The cell below times a runtime-conditional stencil and an inlined version of the same stencil:

In [None]:
# Compile-time external for the inlined benchmark stencil.
C = 1.0

# Inlined variant: only the `C > 0` branch is generated.
@gtscript.stencil(backend="numpy")
def inlined_conditional_stencil(
    in_storage: gtscript.Field[float], out_storage: gtscript.Field[float]
):
    with computation(PARALLEL), interval(...):
        if __INLINED(C > 0.):
            out_storage = (
                in_storage[0, 0, 0] + in_storage[1, 0, 0]
            )
        else:
            out_storage = 0.0

# Runtime-conditional variant for the benchmark: `c` is checked on every call.
# NOTE(review): identical to the `conditional_stencil` defined at the start of
# the notebook; redefined here so the benchmark cell is self-contained.
@gtscript.stencil(backend="numpy")
def conditional_stencil(
    in_storage: gtscript.Field[float], out_storage: gtscript.Field[float], c: float
):

    with computation(PARALLEL), interval(...):
        if c > 0.:
            out_storage = (
                in_storage[0, 0, 0] + in_storage[1, 0, 0]
            )
        else:
            out_storage = 0

# Setting up the domain
nx = 100
halo = 3
shape = (nx+2*halo, nx+2*halo, 1)

# and the storages
in_storage = gt4py.storage.zeros(
    shape=shape,
    default_origin=(0, 0, 0),
    dtype=float,
    backend="numpy",
)
in_storage.data[halo + nx // 3 : halo +  2 * nx  // 3 : 2, halo + nx // 3 : halo + 2 * nx // 3, :] = 2
out_storage = gt4py.storage.zeros(
    shape=shape,
    default_origin=(0, 0, 0),
    dtype=float,
    backend="numpy",
)
            

#Timing the regular stencil:
print("Runtime conditional:")
fields = {"in_storage": in_storage, "out_storage": out_storage}
scalars = {"c": 1.0}
exec_info = {}

%timeit conditional_stencil(**fields, **scalars, origin=(halo-1, halo-1, 0), domain=(nx+1, nx+1, 1), exec_info=exec_info)

call_time = exec_info['call_end_time']-exec_info['call_start_time']
run_time = exec_info['run_end_time']-exec_info['run_start_time']
print(f"  call_time = {call_time * 1000.} ms")
print(f"  run_time = {run_time * 1000.} ms")
print(f"  overhead = {(call_time - run_time)*1000.} ms")
print("")

#And the inlined version:
print("Inlined conditional:")
exec_info = {}

%timeit inlined_conditional_stencil(**fields, origin=(halo-1, halo-1, 0), domain=(nx+1, nx+1, 1), exec_info=exec_info)

call_time = exec_info['call_end_time']-exec_info['call_start_time']
run_time = exec_info['run_end_time']-exec_info['run_start_time']
print(f"  call_time = {call_time * 1000.} ms")
print(f"  run_time = {run_time * 1000.} ms")
print(f"  overhead = {(call_time - run_time)*1000.} ms")
print("")

The inlined version skips the run-time control flow because the compiler generates only one branch of the conditional.

You can also use externals for code re-use:

In [None]:
# Definition function compiled below with gtscript.stencil: writes into `b` the
# average of `a` and `a` shifted by OFFSET in the first (i) index. OFFSET is an
# external bound at compile time through `externals=`.
def avg(a: gtscript.Field[float], b: gtscript.Field[float]):
    from __externals__ import OFFSET
    with computation(PARALLEL), interval(...):
        b = 0.5 * (a + a[OFFSET, 0, 0])

# The central third is 3, with every other i-row overwritten by -1.
center = slice(halo + nx // 3, halo + 2 * nx // 3)
every_other_row = slice(center.start, center.stop, 2)
in_storage.data[center, center, :] = 3
in_storage.data[every_other_row, center, :] = -1

# Bind OFFSET = 1 at compile time.
stencil_call_1 = gtscript.stencil(definition=avg, externals={"OFFSET": 1}, backend="numpy")

stencil_call_1(in_storage, out_storage, origin=(halo - 1, halo - 1, 0), domain=(nx + 1, nx + 1, 1))

plot_two_ij_slices(in_storage, out_storage)

<div class="alert alert-block alert-info">
    <b> Now it's your turn: </b><br>
    <ol>
        <li style="margin-bottom: 10px">Write a <code>fill_horizontal</code> stencil function to fill sparse, small negative values in the input fields using neighboring values along the i-direction. Does your solution work if the first or last value is negative?</li>
        <li style="margin-bottom: 10px">Starting from <code>fill_horizontal</code>, write a <code>fill_horizontal_direction</code> stencil function that can fill along the i or j directions based on a run-time conditional.</li>
        <li style="margin-bottom: 10px">Adapt the <code>fill_horizontal_direction</code> stencil function into <code>fill_horizontal_inlined</code> so that the direction is specified at compile time.</li>
        <li style="margin-bottom: 10px">Open the <code>.gt_cache</code> directory and inspect the generated code.</li>
    </ol>
</div>

In [None]:
# Solution 1
# Exercise stub: should fill small negative values from neighbouring values
# along i. The body is currently an identity placeholder.
@gtscript.stencil(backend="numpy")
def fill_horizontal(
    in_storage: gtscript.Field[float]
):
    with computation(PARALLEL), interval(...):
        # TODO(exercise): replace this identity assignment with the fill logic.
        in_storage = in_storage


# Test input: the central third is 3, with every other i-row overwritten by -1.
in_storage_1 = gt4py.storage.zeros(shape=shape, default_origin=(0, 0, 0), dtype=float, backend="numpy")
center = slice(halo + nx // 3, halo + 2 * nx // 3)
in_storage_1.data[center, center, :] = 3
in_storage_1.data[center.start:center.stop:2, center, :] = -1

# Work on a copy so the original input can be plotted next to the result.
in_copy = deepcopy(in_storage_1)

probe_j = 45
print(in_storage_1.data[:, probe_j, 0], np.mean(in_storage_1.data))

fill_horizontal(in_copy, origin=(halo, halo, 0), domain=(nx, nx, 1))

print(in_copy.data[:, probe_j, 0], np.mean(in_copy.data))

plot_two_ij_slices(in_storage_1, in_copy)

In [None]:
# Solution 2
# Exercise stub: should fill along i or j depending on the run-time parameter
# `fill_direction`. The body is currently an identity placeholder and does not
# read `fill_direction` yet.
@gtscript.stencil(backend="numpy")
def fill_horizontal_direction(
    in_storage: gtscript.Field[float], fill_direction: int
):
    with computation(PARALLEL), interval(...):
        # TODO(exercise): branch on fill_direction and fill negative values.
        in_storage = in_storage

# Test input: the central third is 3, with every other j-column overwritten by -1.
in_storage_2 = gt4py.storage.zeros(shape=shape, default_origin=(0, 0, 0), dtype=float, backend="numpy")

center = slice(halo + nx // 3, halo + 2 * nx // 3)
in_storage_2.data[center, center, :] = 3
in_storage_2.data[center, center.start:center.stop:2, :] = -1

# Work on a copy so the original input can be plotted next to the result.
in_copy = deepcopy(in_storage_2)

probe_i = 45
print(in_storage_2.data[probe_i, :, 0], np.mean(in_storage_2.data))

# Second positional argument is fill_direction.
fill_horizontal_direction(in_copy, -1, origin=(halo, halo, 0), domain=(nx, nx, 1))

print(in_copy.data[probe_i, :, 0], np.mean(in_copy.data))

plot_two_ij_slices(in_storage_2, in_copy)

In [None]:
# Solution 3
# Intended compile-time fill direction (an external once the stencil reads it).
# NOTE(review): the stub body below does not reference DIR yet.
DIR = -1
@gtscript.stencil(backend="numpy")
def fill_horizontal_inlined(
    in_storage: gtscript.Field[float]
):
    with computation(PARALLEL), interval(...):
        # TODO(exercise): select the fill direction with __INLINED(DIR ...).
        in_storage = in_storage


# Fresh copy of the j-striped input from the previous solution cell.
in_copy = deepcopy(in_storage_2)

probe_i = 45
print(in_storage_2.data[probe_i, :, 0], np.mean(in_storage_2.data))

fill_horizontal_inlined(in_copy, origin=(halo, halo, 0), domain=(nx, nx, 1))

print(in_copy.data[probe_i, :, 0], np.mean(in_copy.data))

plot_two_ij_slices(in_storage_2, in_copy)

In [None]:
# Averaging definition function with OFFSET bound at compile time.
# NOTE(review): this redefines `avg` with the argument roles reversed relative
# to the earlier definition: here `a` is the output and `b` is the input.
def avg(a: gtscript.Field[float], b: gtscript.Field[float]):
    from __externals__ import OFFSET
    with computation(PARALLEL), interval(...):
        a = 0.5 * (b + b[OFFSET, 0, 0])

# Compile `avg` with a two-point offset in i.
stencil_call_1 = gtscript.stencil(definition=avg, externals={"OFFSET": 2}, backend="numpy")

# Input: the central third is 3, with every other i-row overwritten by -1.
in_storage_1 = gt4py.storage.zeros(shape=shape, default_origin=(0, 0, 0), dtype=float, backend="numpy")
center = slice(halo + nx // 3, halo + 2 * nx // 3)
in_storage_1.data[center, center, :] = 3
in_storage_1.data[center.start:center.stop:2, center, :] = -1

out_storage = gt4py.storage.zeros(shape=shape, default_origin=(0, 0, 0), dtype=float, backend="numpy")

# In this version of `avg` the first argument (`a`) is written and the second (`b`) is read.
stencil_call_1(out_storage, in_storage_1, origin=(halo - 1, halo - 1, 0), domain=(nx + 1, nx + 1, 1))

plot_two_ij_slices(in_storage_1, out_storage)