# why flox?

```
/////
File: why_flox
Author: Thomas Moore
Description: This notebook explores examples showing the benefit of `flox` and will run on a laptop
Date: 1 May 2024
/////
```

In [1]:
# Author / contact metadata for this notebook.
# NOTE(review): the key "orchid_ID" looks like a misspelling of "ORCID"; it is
# kept unchanged here in case anything else reads this key.
Author_dict = dict(
    name="Thomas Moore",
    affiliation="CSIRO",
    email="thomas.moore@csiro.au",
    orchid_ID="https://orcid.org/0000-0003-3930-1946",
)

# setup

In [2]:
import dask
import flox
import xarray as xr
import pandas as pd
import numpy as np
from dask.distributed import Client, LocalCluster

In [3]:
def print_chunks(data_array):
    """Print the chunk sizes of ``data_array`` along each of its dimensions.

    Parameters
    ----------
    data_array
        Any object exposing ``.dims`` (sequence of dimension names) and
        ``.chunks`` (one tuple of block sizes per dimension, in the same
        order as ``.dims``, or ``None`` when the array is not chunked).

    Returns
    -------
    None
        Output goes to stdout: one ``"<dim> chunks: <sizes>"`` line per
        dimension, or a single notice when the array is not chunked.
    """
    chunks = data_array.chunks
    if chunks is None:
        # Nothing to index into: report it instead of raising TypeError.
        print("array is not chunked")
        return
    for dim, sizes in zip(data_array.dims, chunks):
        print(f"{dim} chunks: {sizes}")

In [4]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

def plot_temperature_at_time(ds, time_index, depth_index=0):
    """Draw a global map of ``ds['temperature']`` at one time step and depth level.

    Parameters
    ----------
    ds
        Dataset providing a ``temperature`` variable indexed by ``time`` and
        ``depth``, plus ``longitude``, ``latitude`` and ``time`` coordinates.
    time_index : int
        Positional index along ``time``.
    depth_index : int, optional
        Positional index along ``depth`` (default 0).

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    # 2-D slice to draw
    field = ds['temperature'].isel(time=time_index, depth=depth_index)

    # Global map axes in a Plate Carree projection
    fig = plt.figure(figsize=(15, 7))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
    ax.set_global()

    # Map features; land gets a high zorder (presumably so it is drawn over
    # the temperature mesh)
    map_features = (
        (cfeature.LAND, dict(zorder=100, edgecolor='k')),
        (cfeature.COASTLINE, {}),
        (cfeature.BORDERS, dict(linestyle=':')),
    )
    for feature, style in map_features:
        ax.add_feature(feature, **style)

    # Temperature field on the longitude/latitude grid
    mesh = ax.pcolormesh(
        ds.longitude,
        ds.latitude,
        field,
        transform=ccrs.PlateCarree(),
        cmap='viridis',
    )

    # Color bar with a fixed label
    colorbar = plt.colorbar(mesh, orientation='vertical', pad=0.02, aspect=50)
    colorbar.set_label('Temperature (degrees Celsius)')

    # Title carries the timestamp of the selected step
    timestamp = ds.time[time_index].values
    plt.title(f'Ocean Temperature on {timestamp}')
    plt.show()


# start LocalCluster

In [5]:
# Local Dask cluster: 4 workers, 1 thread per worker
cluster = LocalCluster(n_workers=4, threads_per_worker=1)

# Client connected to that cluster
client = Client(cluster)



# make example xarray object & write to `netcdf`

In [None]:
# Dimension sizes for the synthetic dataset
n_leap = 8  # leap days in the 33 years starting 1990 (counted by hand)
n_time = 33 * 365 + n_leap  # days in 33 years, leap days included
n_lat = 60
n_lon = 144
n_depth = 20

In [None]:
# Daily timestamps starting 1990-01-01, n_time entries long
time = pd.date_range(start='1990-01-01', periods=n_time, freq='D')
time

In [None]:
# Coordinate axes
latitude = np.linspace(-90, 90, num=n_lat)  # -90 to 90, both endpoints included
longitude = np.linspace(0, 360, num=n_lon, endpoint=False)  # 0 up to (not including) 360
depth = np.linspace(0, 5000, num=n_depth)  # 0 to 5000 meters

In [None]:
# Initialize random temperature data, spread uniformly between a and b
a = 1.0
b = 32.0
# A seeded generator makes a re-run reproduce the same field. Drawing float32
# samples directly avoids first allocating a float64 array of the full
# (n_time, n_lat, n_lon, n_depth) shape, which is twice the size of the result.
rng = np.random.default_rng(42)
data = (a + rng.random((n_time, n_lat, n_lon, n_depth), dtype=np.float32) * (b - a)).astype(np.float32, copy=False)

In [None]:
# Temperature is scaled down linearly with depth: factor 1.0 at the first
# depth level, 0.2 at the last (adjust the end points as needed)
depth_scale_factor = np.linspace(1, 0.2, num=n_depth).astype(np.float32)

# The factor varies along the trailing axis only; shape it as (1, 1, 1, n_depth)
# so it broadcasts against the 4-D data
depth_scaled_data = data * depth_scale_factor.reshape(1, 1, 1, -1)

In [None]:
# Latitude effect: exp(0.1 * |latitude|) grows away from the equator
# (adjust the 0.1 coefficient to change the strength)
exponential_factor = np.exp(0.1 * np.abs(latitude)).astype(np.float32)

# Invert it so the factor shrinks towards the poles, and shape it as
# (n_lat, 1, 1) so it broadcasts along the latitude axis of the data
latitude_scale_factor = (1 / exponential_factor)[:, np.newaxis, np.newaxis]

# (factor + 1) / 2 maps the factor onto the range [0.5, 1] before it is applied
latitude_scaled_data = depth_scaled_data * (latitude_scale_factor + 1) / 2

In [None]:
# Guarantee float32 storage. copy=False returns the array itself when it is
# already float32, instead of duplicating the full-size array.
latitude_scaled_data = latitude_scaled_data.astype(np.float32, copy=False)

In [None]:
# Assemble the dataset: one data variable on (time, latitude, longitude, depth)
temperature_dims = ("time", "latitude", "longitude", "depth")
coordinates = {
    "time": time,
    "latitude": latitude,
    "longitude": longitude,
    "depth": depth,
}
ds = xr.Dataset(
    data_vars={"temperature": (temperature_dims, latitude_scaled_data)},
    coords=coordinates,
)

# Variable metadata
for attr_name, attr_value in (
    ('units', 'degrees Celsius'),
    ('description', 'Simulated ocean temperatures'),
):
    ds['temperature'].attrs[attr_name] = attr_value
#ds['time'].attrs['calendar'] = 'gregorian'


In [None]:
# Size of the dataset in gigabytes (bytes / 1e9)
ds.nbytes/1e9

In [None]:
# Map of the first time step at the first depth level
plot_temperature_at_time(ds, time_index=0, depth_index=0)

In [None]:
# Display the dataset summary as the cell output
ds

In [None]:
# Chunking strategy: 100 time steps per chunk; latitude (60), longitude (144)
# and depth (20) are each requested as a single chunk of that length
chunks = dict(time=100, latitude=60, longitude=144, depth=20)

# Write the dataset to NetCDF4; the on-disk chunk shape of the variable is set
# to the same (time, latitude, longitude, depth) sizes.
# NOTE(review): absolute, user-specific path - the reload step reads the same
# location, so change both together if this is made configurable.
output_path = '/Users/moo270/data/climatology-demo/dummy_ocean_temperature.nc'
ds.chunk(chunks).to_netcdf(
    output_path,
    mode='w',
    format='NETCDF4',
    engine='netcdf4',
    encoding={'temperature': {'chunksizes': (100, 60, 144, 20)}},
)

# RELOAD dataset from `netcdf`

In [None]:
# NOTE(review): absolute, user-specific path - must match where the file was written
netcdf_path = '/Users/moo270/data/climatology-demo/dummy_ocean_temperature.nc'
# chunks={} asks for dask-backed arrays (presumably following the on-disk
# chunking - confirm against the xarray documentation)
loaded_ds = xr.open_dataset(netcdf_path, chunks={})

In [None]:
# Display the reloaded dataset summary as the cell output
loaded_ds

In [None]:
# Second name for the reloaded dataset (same object, no copy); the name refers
# to the time chunk length of 100 used when the file was written
ds_chunk_time_100 = loaded_ds

In [None]:
# Display the dataset summary as the cell output
ds_chunk_time_100

In [None]:
# Current xarray options, tabulated with one row per (option, value) pair
options = xr.get_options()
options_df = pd.DataFrame(
    [(name, value) for name, value in options.items()],
    columns=['Option', 'Value'],
)

# Plain-text dump of the table
print(options_df)

In [None]:
%%time
# Monthly climatology under the current xarray options: mean over time within
# each calendar month, computed eagerly
clim_flox = ds_chunk_time_100.groupby('time.month').mean(dim='time').compute()

In [None]:
# Size of the computed climatology in gigabytes (bytes / 1e9)
clim_flox.nbytes/1e9

In [None]:
%%time
with xr.set_options(use_flox=False):
    # Tabulate the options in effect inside this context, so the printed
    # table shows the use_flox setting that applies to the computation below
    options = xr.get_options()
    options_df = pd.DataFrame(
        [(name, value) for name, value in options.items()],
        columns=['Option', 'Value'],
    )
    print(options_df)

    # Monthly climatology with use_flox switched off
    clim_noflox = ds_chunk_time_100.groupby('time.month').mean(dim='time').compute()

In [None]:
import gc
# Explicit deletion of the two climatology results is left disabled
#del clim_noflox
#del clim_flox
# Run a garbage-collection pass in this process, then call restart on the
# Dask client before the next set of examples
gc.collect()
client.restart()

# flox docs examples

In [13]:
import dask.array as da
# Dask-backed DataArray of random values with a daily time coordinate
sst_times = pd.date_range("1981-09-01 12:00", "2021-06-14 12:00", freq="D")
# Shape (time, lat, lon); 20 time steps per chunk, lat/lon unsplit
sst_values = da.random.random((14532, 720, 144), chunks=(20, 720, 144))
oisst = xr.DataArray(
    sst_values,
    dims=("time", "lat", "lon"),
    coords={"time": sst_times},
    name="sst",
)

In [14]:
# Nominal size of the array in gigabytes (bytes / 1e9)
oisst.nbytes/1e9

12.05342208

In [15]:
# Monthly mean over time, left uncomputed so the cell output shows the
# resulting dask array layout
oisst.groupby('time.month').mean(dim='time')

Unnamed: 0,Array,Chunk
Bytes,9.49 MiB,810.00 kiB
Shape,"(12, 720, 144)","(1, 720, 144)"
Dask graph,12 chunks in 79 graph layers,12 chunks in 79 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 9.49 MiB 810.00 kiB Shape (12, 720, 144) (1, 720, 144) Dask graph 12 chunks in 79 graph layers Data type float64 numpy.ndarray",144  720  12,

Unnamed: 0,Array,Chunk
Bytes,9.49 MiB,810.00 kiB
Shape,"(12, 720, 144)","(1, 720, 144)"
Dask graph,12 chunks in 79 graph layers,12 chunks in 79 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [16]:
%%time
# Same monthly mean, computed eagerly under the current xarray options
oisst.groupby('time.month').mean(dim='time').compute()

CPU times: user 2.33 s, sys: 360 ms, total: 2.69 s
Wall time: 9.76 s


In [17]:
%%time
with xr.set_options(use_flox=False):
    # Tabulate the options in effect inside this context, so the printed
    # table shows the use_flox setting that applies to the computation below
    options = xr.get_options()
    options_df = pd.DataFrame(
        [(name, value) for name, value in options.items()],
        columns=['Option', 'Value'],
    )
    print(options_df)

    # Monthly mean with use_flox switched off; the result is not kept
    oisst.groupby('time.month').mean(dim='time').compute()

                      Option    Value
0       arithmetic_broadcast     True
1            arithmetic_join    inner
2             cmap_divergent   RdBu_r
3            cmap_sequential  viridis
4           display_max_rows       12
5   display_values_threshold      200
6              display_style     html
7              display_width       80
8       display_expand_attrs  default
9      display_expand_coords  default
10  display_expand_data_vars  default
11       display_expand_data  default
12     display_expand_groups  default
13    display_expand_indexes  default
14   display_default_indexes    False
15        enable_cftimeindex     True
16        file_cache_maxsize      128
17                keep_attrs  default
18   warn_for_unclosed_files    False
19            use_bottleneck     True
20                  use_flox    False
21               use_numbagg     True
22            use_opt_einsum     True
CPU times: user 1.59 s, sys: 238 ms, total: 1.83 s
Wall time: 4.21 s
