# Working with Large Datasets

In climate science we often work with datasets that are too large to fit in memory, or that fit but take a long time to process. We can use `dask` to handle the data in `chunks` that fit in memory and can be computed in parallel.

### Dask Arrays

[Dask](https://dask.org/)

[Dask and Xarray](http://xarray.pydata.org/en/stable/dask.html)

A `dask` array looks and feels a lot like a `numpy` array. However, a `dask` array doesn’t directly hold any data. Instead, it symbolically represents the computations needed to generate the data. Nothing is actually computed until the actual numerical values are needed. This mode of operation is called “lazy”; it allows one to build up complex, large calculations symbolically before turning them over to the scheduler for execution.

From 
https://earth-env-data-science.github.io/lectures/dask/dask_arrays.html

In [None]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

import dask.array as da

### Let's make a big array...
...as a `numpy` object using the `numpy` method `ones`, which creates an array filled with the number 1.0.

In [None]:
shape=(1000,4000)
ones_np=np.ones(shape)
ones_np

In [None]:
print(f"{ones_np.nbytes/1e6} million bytes")

The size of this array containing 4 million floating point numbers is 32 million bytes.* 
That means 8 bytes per number.
This is "double precision", and is the default for floating point numbers in Python.
"Single precision" would be 4 bytes per number.

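\*Note that memory sizes are often reported in binary units, where a "kilobyte" is 1024 bytes and a "megabyte" is 1,048,576 bytes rather than exactly one thousand or one million bytes. 
Strictly, these binary units are the kibibyte (1 KiB = 1024 bytes) and the mebibyte (1 MiB = 1,048,576 bytes), which is why `dask` labels its sizes with MiB.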
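
The bytes-per-number arithmetic above can be checked directly. This is a minimal sketch using the `numpy` attributes `itemsize` and `nbytes`; the variable names `ones_double` and `ones_single` are only illustrative.

```python
import numpy as np

shape = (1000, 4000)

# Default dtype is float64 (double precision): 8 bytes per element
ones_double = np.ones(shape)
# Requesting float32 (single precision) halves the memory footprint
ones_single = np.ones(shape, dtype=np.float32)

for name, arr in [("double", ones_double), ("single", ones_single)]:
    print(f"{name}: {arr.itemsize} bytes/element, "
          f"{arr.nbytes / 1e6:.0f} million bytes, "
          f"{arr.nbytes / 2**20:.2f} MiB")
```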

### Dask objects

We can also create a big array like this as a `dask` object using the corresponding `dask` method:

In [None]:
ones=da.ones(shape)
ones

What we see looks just like what we have seen for DataArrays in `xarray`. 
This is because `xarray` can use `dask` arrays to hold and manage its data.

### Chunks

Notice that we have two columns in our table: 
1. The first is called "Array", and its meaning should be obvious.
2. The second is called "Chunk". It describes how the data are grouped in the computer's memory.

In this case, there is only one "chunk" containing all the data in the array.
However, we can specify that this large array should be broken into smaller pieces:

In [None]:
chunk_shape=(1000,1000)
ones=da.ones(shape,chunks=chunk_shape)
ones

Our `dask` depiction of the data is now different. 
We now see that there are 4 chunks, each with the square shape we specified (1000x1000).
Each has a size in bytes that is 1/4 the size of the entire array.
Also, the pictorial representation shows this division into chunks with vertical lines.

Try some other values for chunk shapes... see what happens.

What happens if you choose a chunk dimension that does not divide evenly into the array's dimensions?
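
One way to explore this is to inspect the `chunks` and `numblocks` attributes of the array. The sketch below uses an arbitrary chunk shape of (300, 1500), chosen only because it does not divide evenly into (1000, 4000).

```python
import dask.array as da

shape = (1000, 4000)
uneven = da.ones(shape, chunks=(300, 1500))

# .chunks lists the chunk lengths along each dimension;
# the last chunk on each axis holds whatever is left over
print(uneven.chunks)
# .numblocks gives the number of chunks along each dimension
print(uneven.numblocks)
```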

### Dask methods

`dask` is _lazy_. That is, `dask` will not initiate a calculation when it encounters a line of code _unless_ its results are to be:
1. displayed (as either text output or a graphical plot)
2. written to disk
3. shared with another non-dask function that needs the result for its operation

This is done to conserve memory, and speed the execution of the code. Unlike the example above, most of the time
`dask` is employed when reading data from files stored on disk, or in the cloud. 
Often these are large files, and/or many files, that are opened at once. 
Performing calculations with large datasets on disk is an operation that is **IO bound**, meaning that the
speed limitations of input and output (hard disk speeds, networking, etc) are the main factor that
slows computations. `dask` is designed to avoid unnecessary input/output (IO) as well as conserve 
computer memory.
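
The lazy behavior can be seen with a small in-memory array. In this sketch, building the expression only records tasks, and `compute()` triggers the work. The name `lazy_total` is illustrative.

```python
import dask.array as da

ones = da.ones((1000, 4000), chunks=(1000, 1000))

# Building the expression only records tasks; no arithmetic runs yet
lazy_total = (ones + 1).sum()
print(type(lazy_total))            # still a dask array
print(len(lazy_total.__dask_graph__()), "tasks recorded")

# Asking for the value triggers the actual computation
total = lazy_total.compute()
print(total)
```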

The combination of the _laziness_ of `dask` and its chunking feature is often used in distributed computing environments like
clusters. It can allow a computation to spread across multiple CPUs using parallel computing, much like MPI or OpenMP in 
programming languages like Fortran.

One way to spur `dask` calculations into action is to call one of several special `dask` methods.
One is `compute`:

In [None]:
ones.compute?

In [None]:
test = ones.compute()
test

Our array called `ones` did not actually exist until the previous cell was executed. 
What we saw above were descriptions of what `ones` _would look like_, how much memory it _would require_ and 
how the chunks _would be distributed_ once the array was in memory. 
But memory space was not occupied until we executed the command:
`test = ones.compute()`
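
A quick way to confirm the difference is to compare the types of the two objects. This is a small sketch that repeats the same setup so it can run on its own.

```python
import numpy as np
import dask.array as da

ones = da.ones((1000, 4000), chunks=(1000, 1000))
test = ones.compute()

print(type(ones))   # dask array: a recipe, not the data
print(type(test))   # numpy array: the data in memory
print(test.nbytes / 1e6, "million bytes now held in memory")
```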

In [None]:
!jupyter --version

--------------------
--------------------

### Visualizing what Dask does

There is a handy method called `visualize` that shows how `dask` defines chunks and manages them as it performs calculations.

In [None]:
ones.visualize()

In [None]:
sum_of_ones=ones.sum()

In [None]:
sum_of_ones.compute()

In [None]:
sum_of_ones.visualize()

In [None]:
fancy_calculation=(ones*ones[::-1,::-1]).mean()
fancy_calculation.visualize()

In [None]:
fancy_calculation.compute()

------------------
------------------
### A Big Calculation

The examples above were toy examples (32 MB). These data are not big enough to warrant the use of `dask`. Let's try a much bigger example.

In [None]:
bigshape=(200000,4000)
big_ones=da.ones(bigshape,chunks=chunk_shape)
big_ones

In [None]:
print(big_ones.nbytes/1e9,"GB")

DO NOT VISUALIZE THIS! With 800 chunks, the task graph is far too large to draw usefully.
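
Instead of drawing the graph, the size of the problem can be checked from the array's lazy attributes. This sketch recreates the same array, and no memory is allocated for the data.

```python
import dask.array as da

big_ones = da.ones((200000, 4000), chunks=(1000, 1000))

# No data is allocated here; these are properties of the lazy array
print(big_ones.numblocks)     # chunks along each dimension
print(big_ones.npartitions)   # total number of chunks
print(big_ones.nbytes / 1e9, "GB if fully loaded")
```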

### Dask has some other tools to help us understand what is happening 

In [None]:
from dask.diagnostics import ProgressBar

big_calc = (big_ones * big_ones[::-1, ::-1]).mean()

with ProgressBar():
    result = big_calc.compute()
result

#### Most of the usual `numpy` (and `xarray`) methods work on `dask` arrays.

In [None]:
big_ones_reduce=(np.cos(big_ones)**2).mean(axis=0)
big_ones_reduce
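
The same pattern can be checked on a smaller array, where computing the result is cheap. This sketch assumes a (2000, 4000) array of ones and compares the `dask` result with the value expected from plain `numpy`.

```python
import numpy as np
import dask.array as da

ones = da.ones((2000, 4000), chunks=(1000, 1000))

# numpy ufuncs applied to a dask array return another lazy dask array
reduced = (np.cos(ones) ** 2).mean(axis=0)
print(type(reduced), reduced.shape)

# Compare against the same calculation done eagerly in numpy
expected = np.cos(1.0) ** 2
print(np.allclose(reduced.compute(), expected))
```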

---------------------------
## Xarray uses dask by default when you use open_mfdataset 
`dask` can be invoked by specifying `chunks` when you open and read your data.

Example: ERA5 daily atmospheric data (multiple pressure levels)
One file for every day from 1979 to 2020 
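
The ERA5 files are not available everywhere, so here is a sketch with synthetic in-memory data showing what chunking looks like on an `xarray` object. The sizes, the dimension names and `ds_demo` are hypothetical stand-ins. Calling `.chunk()` converts the variables to `dask` arrays, similar in effect to passing `chunks` when opening files.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Synthetic stand-in for a daily dataset (hypothetical sizes and names)
time = pd.date_range("1979-01-01", periods=365, freq="D")
z = xr.DataArray(
    np.zeros((365, 5, 100)),
    dims=("time", "plev", "rgrid"),
    coords={"time": time},
    name="z",
)
ds_demo = z.to_dataset()
print(ds_demo["z"].chunks)        # None: plain numpy data, no dask yet

# Split the time dimension into chunks of (up to) 31 days
ds_chunked = ds_demo.chunk({"time": 31})
print(ds_chunked["z"].chunks)     # chunk lengths along each dimension
```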


In [None]:
path = '/home/pdirmeye/ERA5_z/'
fname = 'ea_global_an_daily_'


In [None]:
# Create a list of many files to open and read
fnames = path+fname+'*.nc4'
fnames

In [None]:
# Note - you can use a wildcard string in the file name to open multiple files
ds=xr.open_mfdataset(fnames,combine='nested',concat_dim='time')
ds

### Reduced Gaussian grids are a type of irregular grid
* [How they work](https://confluence.ecmwf.int/display/FCST/Gaussian+grids)
* [The N320 (grid used by ERA5) table by latitude rows](https://confluence.ecmwf.int/display/EMOS/N320)

We need to map each of the grid cells in the reduced Gaussian grid onto their corresponding longitudes and latitudes. 
There is more than one way to do this... 
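
To make the idea concrete, here is a toy reduced grid with only four latitude rows. The latitudes and the number of points per row are invented for illustration and are not the N320 values. The sketch builds one latitude and one longitude for every point in the flattened vector.

```python
import numpy as np

# Hypothetical toy grid: 4 latitude rows with fewer points near the poles
lats = np.array([60.0, 20.0, -20.0, -60.0])
points_per_row = np.array([4, 8, 8, 4])

# One latitude per grid point: repeat each row's latitude
lat_vec = np.repeat(lats, points_per_row)
# One longitude per grid point: evenly spaced around each row
lon_vec = np.concatenate(
    [np.arange(n) * 360.0 / n for n in points_per_row]
)

print(lat_vec.size, lon_vec.size)   # total number of grid points
print(lon_vec[:4])                  # longitudes of the first row
```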

In [None]:
# One way to reindex the "rgrid" reduced Gaussian grid vector into latitudes and longitudes
rgg_file = path+'N320_reduced_grid.nc4'
rgg = xr.open_dataset(rgg_file)
# rgg is an rgrid-length pair of vectors for the corresponding latitudes and longitudes from the rectangular grid
rgg

In [None]:
# Another way to reindex the "rgrid" reduced Gaussian grid vector into latitudes and longitudes
latlon_file = path+'N320_index.nc' 
latlon = xr.open_dataset(latlon_file)
# latlon is a 640x1280 lat-lon grid containing the representative "rgrid" value to map to each point on the regular grid. 
latlon

### (Re)projecting data

* How would the two different reindexing datasets be applied? 
* What would be the result of each?
* Could you produce the same grids of data, and the same maps, from them?
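
As a hint for the index-based approach, here is a toy sketch of filling a regular grid by lookup. The arrays `values` and `index` are invented, and this is only an assumption about how such an index grid could be used, not necessarily how any particular remapping function works.

```python
import numpy as np

# Hypothetical reduced grid: rows of 2 and 4 points, flattened to length 6
values = np.array([10.0, 11.0, 20.0, 21.0, 22.0, 23.0])

# Hypothetical index grid on a regular 2x4 lat-lon grid: each entry is the
# position in `values` of the reduced-grid point representing that cell
index = np.array([[0, 0, 1, 1],
                  [2, 3, 4, 5]])

# Fancy indexing fills the whole regular grid in one step
regular = values[index]
print(regular)
```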

In [None]:
gravity = 9.8 # m/s**2
fig = plt.figure(figsize=(13,8))
ax=plt.axes(projection=ccrs.PlateCarree())

plt.scatter(rgg['lon'],rgg['lat'],c=ds["z"][0,0]/gravity,s=0.1,marker='s',transform=ccrs.PlateCarree(),cmap="GnBu_r")
ax.coastlines()

plt.title('300hPa Geopotential Heights',fontsize=20)
plt.colorbar(shrink=0.7,aspect=30,orientation='horizontal',label='meters') ;

### A function to reconstitute reduced to full Gaussian Grids
`era5_remap` is a Python function that remaps reduced grid data onto the full rectangular (1280x640) Gaussian grid.
* Functions in a `.py` script file can be imported like any other Python package or library 

In [None]:
!cp /home/pdirmeye/classes/clim680_2022/era5_remap.py .
from era5_remap import era5_remap
help(era5_remap)

In [None]:
reg_grid = era5_remap(ds["z"][0,0]/gravity,'rgg',latlon)

fig = plt.figure(figsize=(13,8))
ax = plt.axes(projection=ccrs.PlateCarree())
plt.pcolormesh(reg_grid.lon,reg_grid.lat,reg_grid,cmap='GnBu_r',transform=ccrs.PlateCarree(),shading='nearest')
ax.coastlines()

plt.title('300hPa Geopotential Heights',fontsize=20)
plt.colorbar(shrink=0.7,aspect=30,orientation='horizontal',label='meters') ;

## How long does it take?

We can use timers to see how long it takes for blocks of code to run. This is a great way to find inefficiencies and understand code performance.

In [None]:
from time import perf_counter

In [None]:
start = perf_counter()
reg_grid = era5_remap(ds["z"][0,0],'rgg',latlon) # One level on one day
end_1d = perf_counter()
print(f"Horizontal field required {end_1d-start:.3g} seconds")

reg_grid = era5_remap(ds["z"][0],'rgg',latlon) # All levels on one day
end_2d = perf_counter()
print(f"Horizontal and vertical field required {end_2d-end_1d:.3g} seconds")

reg_grid = era5_remap(ds["z"][:31,0],'rgg',latlon) # One level for all days in the first month (31 days)
end_mo = perf_counter()
print(f"Horizontal and time (one month) field required {end_mo-end_2d:.3g} seconds")


### Note that there is a significant amount of _scaling_ here. 
"Scaling" means that the amount of time it takes to do a larger task does not grow as quickly as the task grows.
* Processing 5 levels did not take 5x longer than processing 1 level; it took less than 2x as long.
* Processing 31 days took around 4 times longer than processing one day.

The function `era5_remap` uses a technique called [_list comprehension_](https://en.wikipedia.org/wiki/Comparison_of_programming_languages_(list_comprehension)#Python) to regrid the data.
* A _list comprehension_ is usually faster than an equivalent explicit `for` loop that appends to a list, although the size of the speedup varies.
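
The difference can be measured on your own machine. This sketch times a plain loop against a list comprehension with the standard-library `timeit` module. The two functions are illustrative and unrelated to `era5_remap`, and the timings will vary from system to system.

```python
from timeit import timeit

def with_loop(n):
    """Build a list with an explicit for loop and append."""
    out = []
    for i in range(n):
        out.append(i * 2)
    return out

def with_comprehension(n):
    """Build the same list with a list comprehension."""
    return [i * 2 for i in range(n)]

n = 100_000
t_loop = timeit(lambda: with_loop(n), number=20)
t_comp = timeit(lambda: with_comprehension(n), number=20)
print(f"loop: {t_loop:.3f} s, comprehension: {t_comp:.3f} s")
```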

---------------------

### Let's convert these awkward timestamps into `dtype=datetime64`

In [None]:
import pandas as pd
dtobj = pd.to_datetime(list(ds['time'].values), format='%Y%m%d.%f')
ds['time'] = dtobj
ds['time']

### Now we can use our beloved `.groupby()` method to parse through the time dimension

In [None]:
t0 = perf_counter()
z_climo_jan = ds['z'].groupby('time.month')[1].mean(dim='time')
t1 = perf_counter()
print(f"{t1-t0:.3g} seconds")

z_climo_jan.load()
t2 = perf_counter()
print(f"{t2-t1:.3g} seconds")


Now we can clearly see that our big calculation didn't actually happen at the line where we assigned `z_climo_jan`.

`dask` does not perform actual calculations or even load the data from a file into memory _until it absolutely must_ (e.g., to plot a result).
The `.load()` method forces `dask` to read the data from the file on disk into memory, much as `.compute()` forces `dask` to perform a calculation.
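
In `xarray`, the two methods also differ in what they return. The sketch below uses a small synthetic array with hypothetical names and sizes: `.compute()` returns a new in-memory object and leaves the original lazy, while `.load()` brings the values into memory in place.

```python
import numpy as np
import xarray as xr

lazy = xr.DataArray(np.ones((365, 100)), dims=("time", "rgrid")).chunk({"time": 73})
mean_lazy = lazy.mean(dim="time")

# .compute() returns a new in-memory object; mean_lazy itself stays lazy
mean_in_memory = mean_lazy.compute()
was_lazy = mean_lazy.chunks is not None
print(was_lazy, mean_in_memory.chunks is None)

# .load() brings the values into memory in place
mean_lazy.load()
print(mean_lazy.chunks is None)
```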

#### When would you want to force `dask` to be _eager_ rather than _lazy_?

* When you want to use computations over and over
    * Example: If you calculate anomalies for a really large dataset and then you want to use the anomalies for the rest of the program without asking `dask` to recompute them each time.
<br><br>

* When you have a performance issue
    * If it is taking a long time to do the calculations, you can tell `dask` to go ahead and `load` the data ahead of time, if you have enough memory, or go ahead and `compute` the computations up to this point.  
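
The anomaly case might look like the following sketch, which uses two years of random synthetic data in place of a real dataset (all names and sizes are hypothetical). The anomalies are computed once and then held in memory for reuse.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Two years of synthetic daily data, chunked so that it is dask-backed
time = pd.date_range("2001-01-01", periods=730, freq="D")
rng = np.random.default_rng(0)
data = xr.DataArray(
    rng.random((730, 4)), dims=("time", "x"), coords={"time": time}
).chunk({"time": 365})

climo = data.groupby("time.month").mean("time")
# Compute once; the result lives in memory and can be reused
anom = (data.groupby("time.month") - climo).compute()

# Later uses of `anom` no longer re-read or recompute anything
print(float(anom.std()), abs(float(anom.mean())))
```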


See how long it is taking with the progress bar...

In [None]:
z_decade_climo_dec = ds["z"].sel(plev=50000.,time=slice("2001-01-01", "2010-12-31")).groupby('time.month')[12].mean(dim='time')

fig = plt.figure(figsize=(13,8))
ax = plt.axes(projection=ccrs.PlateCarree())

with ProgressBar():
    era5_remap(z_decade_climo_dec,'rgg',latlon).plot(cmap='GnBu_r',transform=ccrs.PlateCarree()) 
    
ax.coastlines()
plt.title('500hPa Heights - Dec (2001-2010)') ;

In [None]:
# It's not the remapping or the plotting that takes all the time - it is crunching through the calculation across all that data on disk.
with ProgressBar():
    ds_load=z_decade_climo_dec.load()

`ProgressBar` only monitors `dask` actions. Below we are not invoking `dask`, so the function does nothing.

In [None]:
fig = plt.figure(figsize=(13,8))
ax = plt.axes(projection=ccrs.PlateCarree())
with ProgressBar():
    era5_remap(ds_load,'rgg',latlon).plot(cmap='GnBu_r')
ax.coastlines()
plt.title('500hPa Heights - Dec (2001-2010)') ;
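
The same behavior can be checked without any plotting. This sketch redirects the bar's text into a buffer through the `out` argument of `ProgressBar`, then compares a `dask` computation with a plain `numpy` one. It assumes the default local scheduler is in use.

```python
import io
import numpy as np
import dask.array as da
from dask.diagnostics import ProgressBar

# Capture the bar's text in buffers so we can inspect what was written
dask_out, numpy_out = io.StringIO(), io.StringIO()

with ProgressBar(out=dask_out):
    da.ones((2000, 2000), chunks=(1000, 1000)).sum().compute()   # dask work

with ProgressBar(out=numpy_out):
    np.ones((2000, 2000)).sum()                                  # no dask involved

print("dask wrote a bar:", len(dask_out.getvalue()) > 0)
print("numpy wrote a bar:", len(numpy_out.getvalue()) > 0)
```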