# Tutorial for LEAP-REU data processing and visualization

**Authors: Yu Huang, Sungduk Yu**

GitHub repository links: [tutorials for REU dataset](https://github.com/sungdukyu/LEAP_REU_Dataset_Notebook); [LEAP-REU23 Bootcamp](https://github.com/leap-stc/LEAP-bootcamps).

This is an introductory tutorial for a demo dataset from the climate model [E3SM-MMF](https://www.exascaleproject.org/research-project/e3sm-mmf/). See [E3SM-MMF_baseline](https://github.com/sungdukyu/E3SM-MMF_baseline/tree/main) and [E3SM](https://e3sm.org/wp-content/uploads/2021/11/E3SM_Brochure-2021.pdf) for more information.

This notebook covers preprocessing of the unstructured data, which xarray cannot handle directly, and visualization of climate variables.

The goal is to practice the skills covered during Week 1 of the bootcamp to carry out climate analysis, and to become more familiar with climate concepts using the REU dataset.


## Git Authentication

Use the terminal or the left sidebar to push your files if you want to use Git to keep track of your work.

Run the code below to give LEAP-Pangeo access to your GitHub account.

In [None]:
# import gh_scoped_creds
# %ghscopedcreds

## Data preprocessing
### Install and load python packages

We use [mamba](https://mamba.readthedocs.io/en/latest/installation.html) (instead of conda) to install packages on the Hub. If you cannot import the packages below directly, click the "+" button in the upper-left corner to launch a terminal, copy the command below (the part after "!"), and run it in the terminal.

In [None]:
# ! mamba install -y pynco pynio pyngl

In [None]:
import gcsfs
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xarray as xr
import cftime
from nco import Nco
from tqdm import tqdm
import Ngl
import xesmf as xe

### Load dataset from Google Cloud

#### Open Google Cloud Storage File System


In [None]:
fs = gcsfs.GCSFileSystem()

#### List files in the bucket where the E3SM-MMF dataset is stored


In [None]:
fs.ls("gs://leap-persistent-ro/sungdukyu")

#### Open the file you want using xarray


In [None]:
mapper = fs.get_mapper('leap-persistent-ro/sungdukyu/E3SM-MMF_ne4.train.output.zarr')
ds = xr.open_dataset(mapper, engine='zarr')


#### Check which variables are included & their dimensions/shapes 

Use [E3SM-MMF Dataset Variable List](https://docs.google.com/spreadsheets/d/1ljRfHq6QB36u0TuoxQXcV4_DSQUR0X4UimZ4QHR8f9M/edit#gid=0) to check the physical meaning of each variable.

Check the original data coordinates first. Instead of time, latitude, and longitude, the raw data uses **sample** (time step) and **ncol** (column index) as coordinates.


In [None]:
ds

### Reorganize the temporal dimension/coordinate

#### Add the *time* dimension  
Originally, the time information is encoded in the variables **ymd** and **tod**, and the **sample** index counts the time steps.

**ymd** encodes the date as an integer of the form YMMDD: the leading digit(s) give the year index, the next two digits the month, and the last two digits the day of the month (this is how the code below decodes it).

**tod** is the time of day in seconds since midnight.
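As a quick check of this encoding, the sketch below decodes a few hypothetical `ymd`/`tod` values with the same floor-division and modulo steps used in the next cell. The sample values are made up purely to illustrate the YMMDD layout; they are not taken from the dataset.

```python
import numpy as np

def decode_ymd_tod(ymd, tod):
    """Split YMMDD-encoded dates and seconds-of-day into calendar fields."""
    ymd = np.asarray(ymd, dtype=np.int64)
    tod = np.asarray(tod, dtype=np.int64)
    year = ymd // 10000          # leading digit(s): year index
    month = ymd % 10000 // 100   # next two digits: month
    day = ymd % 100              # last two digits: day of month
    hour = tod // 3600           # whole hours since midnight
    minute = tod % 3600 // 60    # remaining whole minutes
    return year, month, day, hour, minute

# hypothetical samples: year 1 Feb 1 at 00:00 and 00:20, year 2 Dec 31 at 23:40
demo_year, demo_month, demo_day, demo_hour, demo_minute = decode_ymd_tod(
    [10201, 10201, 21231], [0, 1200, 85200])
print(demo_year, demo_month, demo_day, demo_hour, demo_minute)
```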

In [None]:
ds.ymd.values[0], ds.tod.values[0]

In [None]:
# decode ymd (YMMDD) and tod (seconds since midnight) for all sample points
ymd = ds['ymd'].values
tod = ds['tod'].values
year = ymd // 10000
month = ymd % 10000 // 100
day = ymd % 100
hour = tod // 3600
minute = tod % 3600 // 60

# build one no-leap calendar timestamp per sample (cftime expects plain ints)
t = [cftime.DatetimeNoLeap(int(y), int(m), int(d), int(h), int(mi))
     for y, m, d, h, mi in zip(year, month, day, hour, minute)]

# assign the time array to the 'sample' coordinate; then, rename the dimension
ds['sample'] = t
ds = ds.rename({'sample':'time'})

# now the 'time' dimension has replaced 'sample'; drop the encoded variables
ds = ds.drop_vars(['tod','ymd'])

# check the current time coordinate to read the time step
ds.time.values[0:5]

#### Reduce the data size / coarsen the time resolution 

The **time** dimension is large. To avoid memory issues, we need to reduce it, for example by keeping only one sample per day.

You can take daily/monthly mean, max, or min values, or the data at a specific time each day.

There are many ways to do that. You can uncomment the code below to try some of them.
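The next cell keeps a copy of the data at the 20-minute time step, which means 3 samples per hour and 72 per day; that is where `24*3` and the noon offset of 36 (36 × 20 min = 12 h) come from. The NumPy sketch below shows both reductions on a synthetic series, assuming it starts at 00:00 and covers whole days: a daily mean via reshape, and one noon sample per day via strided indexing.

```python
import numpy as np

steps_per_day = 24 * 3   # 20-minute time step -> 3 samples per hour, 72 per day
noon_offset = 36         # 36 steps * 20 min = 12 hours after 00:00

# synthetic series covering 5 whole days (assumed to start at 00:00)
n_days = 5
demo_series = np.arange(n_days * steps_per_day, dtype=float)

# option A: daily mean, by folding the time axis into (day, step-in-day)
demo_daily_mean = demo_series.reshape(n_days, steps_per_day).mean(axis=1)

# option B: one sample per day at noon, same indexing as np.arange(36, len(ds.time), 24*3)
demo_noon_idx = np.arange(noon_offset, len(demo_series), steps_per_day)
demo_noon_values = demo_series[demo_noon_idx]
```

The reshape trick only works when the series holds a whole number of days; xarray's `resample` handles the general case.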

In [None]:
# keep a copy of the original data at 20-min time step and gridded by column index
ds_origin = ds.copy()

# 1. tried to resample the whole dataset but then the memory blows up
# ds = ds.resample(time='1D').mean('time')
# ds

# 2. select data at noon per day for the whole dataset 
# itime = np.arange(36,len(ds.time),24*3)
# ds = ds.isel(time = itime)
# ds

We can go further and take the monthly mean with the *resample* function, which reduces the data size even more. Resampling requires more memory, so we only process two 2D variables here: *total precipitation* and *net surface shortwave radiation*. For convenience, we will loosely refer to these two variables as 'precipitation' and 'radiation' in the later analysis.

In [None]:
# resample the data to be monthly instead of daily
# first do the transformation so we have: total precipitation = snow + rainfall

ds['cam_out_PRECT'] = ds['cam_out_PRECC'] + ds['cam_out_PRECSC']
ds = ds[['cam_out_PRECT', 'cam_out_NETSW']].resample(time='1M').mean(dim='time')
ds

### Remap and interpolate the data so that they have structured lat, lon coordinates

#### Open a file that stores grid information, and check the original lat, lon information

In [None]:
# grid info
mapper = fs.get_mapper("gs://leap-persistent-ro/sungdukyu/E3SM-MMF_ne4.grid-info.zarr")
ds_grid = xr.open_dataset(mapper, engine='zarr')
ds_grid

In [None]:
fig, ax = plt.subplots()
ds_grid['lat'].plot(label='lat', ax=ax)
ds_grid['lon'].plot(label='lon', ax=ax)
ax.set_ylabel('deg')
ax.legend()

In [None]:
print('number of columns:', len(np.unique(ds_grid.ncol.values)))
print('number of (lat, lon) grid points if every unique lat were paired with every unique lon:',
      len(np.unique(ds_grid.lat.values)) * len(np.unique(ds_grid.lon.values)))
np.unique(ds_grid.lon.values.round(2))

#### Change the geo-coordinate from column index to multi-index (lat, lon) and see what the data looks like

In [None]:
# original lat and lon info
lat = ds_grid.lat.values.round(2) 
lon = ds_grid.lon.values.round(2) 

# merge the original grid info with the dataset containing atmos variables
ds_multiindex = ds.copy()
ds_multiindex['lat'] = (('ncol'),lat.T) # (('sample', 'ncol'),lat.T)
ds_multiindex['lon'] = (('ncol'),lon.T)

# set multi-index for the original dataset using lat and lon
ds_multiindex = ds_multiindex.set_index(index_id=["lat", "lon"])
index_id = ds_multiindex.index_id
ds_multiindex = ds_multiindex.drop('index_id')
ds_multiindex = ds_multiindex.rename({'ncol':'index_id'})
ds_multiindex = ds_multiindex.assign_coords(index_id = index_id)
ds_multiindex

In [None]:
# create a dataset with stacked (lat, lon) grids in the original dataset, all values are NaNs
time = ds.time.values

data_np = np.empty(shape=(len(time), len(np.unique(lat)),len(np.unique(lon))))
data_np[:,:] = np.nan

ds_latlon = xr.Dataset(
     data_vars={
         # v: (("time","index_id"), np.zeros([len(time), len(np.unique(lat))*len(np.unique(lon))]))
         v: (("time","lat","lon"), data_np)
         for v in ['cam_out_NETSW','cam_out_PRECT']
     },
     coords={
         "time": ds.time,
         # "index_id": pd.MultiIndex.from_product(
         #    [np.unique(lat), np.unique(lon)], names=["lat", "lon"],),
         "lat": np.unique(lat),
         "lon": np.unique(lon),
         # "lev": ds.lev,
    },
)

# use multi-index so that we can assign the column data to the (lat,lon) data
ds_latlon = ds_latlon.stack(index_id=['lat','lon'])
ds_latlon 

# print(len(ds_multiindex.index_id.values))
# ds_latlon.sel(index_id=(-32.59, 320.27))

In [None]:
# use tqdm to visualize the progress of the loop below
# this cell takes about 20 min to run; please be patient while it runs
ds_precc = ds_latlon.cam_out_PRECT.copy()
ds_netsw = ds_latlon.cam_out_NETSW.copy()

for i in tqdm(ds_multiindex.index_id.values):
    # ds_latlon.loc[{"index_id": i}] = ds_multiindex[[' cam_out_NETSW','cam_out_PRECC']].sel(index_id = i) 
    #### wrong, will lead to all vars have the same values
    ds_precc.loc[{"index_id": i}] = ds_multiindex['cam_out_PRECT'].sel(index_id = i)
    ds_netsw.loc[{"index_id": i}] = ds_multiindex['cam_out_NETSW'].sel(index_id = i)

ds_latlon['cam_out_PRECT'] = ds_precc.copy()
ds_latlon['cam_out_NETSW'] = ds_netsw.copy()

ds_unstack = ds_latlon.unstack('index_id')
ds_unstack

# if we visualize the 2D maps directly, there are many missing values
fig, ax = plt.subplots(ncols=2, figsize=(12,4))
ds_unstack.cam_out_NETSW.mean('time').plot(cmap='RdBu_r',ax=ax[0])
ds_unstack.cam_out_PRECT.mean('time').plot(cmap='RdBu_r',ax=ax[1])

<span style="color:blue"> **The raw dataset looks odd when plotted directly with xarray because it is on an unstructured grid, and xarray cannot arrange these column points on a regular map. See [page 9](https://www.osti.gov/servlets/purl/1807356) to learn more about the raw grid setup.**</span>
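The `tqdm` loop earlier places each column's values onto the (lat, lon) grid one index at a time, which is why it is slow. One possible vectorized alternative, sketched below with NumPy on hypothetical coordinates, uses `np.unique(..., return_inverse=True)` to turn each column's rounded lat and lon into grid indices and then assigns all columns in a single step. It assumes each column has a distinct (lat, lon) pair. It is not the notebook's method, and the result has the same gaps as the maps above, since it does not interpolate.

```python
import numpy as np

def columns_to_grid(values, col_lat, col_lon):
    """Scatter column data (time, ncol) onto a NaN-filled (time, unique lat, unique lon) array."""
    ulat, ilat = np.unique(col_lat, return_inverse=True)   # grid row of each column
    ulon, ilon = np.unique(col_lon, return_inverse=True)   # grid column of each column
    grid = np.full((values.shape[0], ulat.size, ulon.size), np.nan)
    grid[:, ilat, ilon] = values                           # one vectorized assignment
    return ulat, ulon, grid

# hypothetical example: 3 columns, 2 time steps
demo_lat = np.array([-32.59, 10.0, 10.0])
demo_lon = np.array([320.27, 0.0, 90.0])
demo_values = np.array([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0]])
demo_ulat, demo_ulon, demo_grid = columns_to_grid(demo_values, demo_lat, demo_lon)
```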

#### What does the raw grid look like?

Use [pynco](https://pynco.readthedocs.io/en/stable/) to remap the data from the unstructured grid to a structured grid. Do not worry about understanding this kind of grid in detail right now.
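Conceptually, a map file such as `map_ne4pg2_to_180x360_aave.20220722.nc` stores a sparse weight matrix, and remapping is a matrix-vector product: each destination cell receives a weighted sum of the source columns that overlap it. The toy example below assumes the ESMF-style layout in which the weights are stored as `S` with destination/source indices `row`/`col` (assumed 1-based in the file, 0-based here); the three source columns, two destination cells, and weights are all hypothetical. It only illustrates the idea, not NCO's implementation.

```python
import numpy as np

def apply_map(src, row, col, S, n_dst):
    """Remap source values with sparse weights: dst[row[k]] += S[k] * src[col[k]]."""
    dst = np.zeros(n_dst)
    np.add.at(dst, row, S * src[col])   # accumulates repeated destination indices correctly
    return dst

# hypothetical weights: 3 source columns -> 2 destination cells (0-based indices)
demo_src = np.array([1.0, 3.0, 5.0])
demo_row = np.array([0, 0, 1, 1])            # destination cell of each weight
demo_col = np.array([0, 1, 1, 2])            # source column of each weight
demo_S = np.array([0.25, 0.75, 0.5, 0.5])    # overlap-area fractions; each destination sums to 1
demo_dst = apply_map(demo_src, demo_row, demo_col, demo_S, n_dst=2)
```

Because each destination cell's weights sum to 1 here, a constant field stays constant after remapping, which is a handy sanity check for a map file.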

In [None]:
infile = '../E3SM-MMF_ne4_train_output_monthly_raw.nc'
outfile = '../E3SM-MMF_ne4_train_output_monthly_remap.nc'
# map file to convert the unstructured data to gridded data
mapfile = '../map_ne4pg2_to_180x360_aave.20220722.nc'

ds.to_netcdf(infile)

nco = Nco()
nco.ncks(input=infile, output=outfile, map=mapfile)

ds_remap = xr.open_dataset(outfile)

fig, ax = plt.subplots(ncols=2, figsize=(12,4))
ds_remap['cam_out_NETSW'].mean('time').plot(cmap='RdBu_r',ax=ax[0])
ds_remap['cam_out_PRECT'].mean('time').plot(cmap='RdBu_r',ax=ax[1])

In [None]:
ds_remap

#### Interpolate the unstructured data at 2-deg resolution

For simplicity and to save memory, we use the `Ngl.natgrid` function from [PyNGL](https://www.pyngl.ucar.edu/Functions/Ngl.natgrid.shtml) to interpolate the **RAW** data onto the regular grid we want, so that we can carry out climate analyses using the skills you learned earlier this week.
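`Ngl.natgrid` performs natural-neighbor interpolation. If PyNGL is unavailable, or just as a sanity check, a cruder swapped-in technique is nearest-neighbor assignment by great-circle distance, sketched below in plain NumPy. It is not the method used in this notebook and produces blockier fields than natgrid; converting coordinates to 3-D unit vectors handles the 0/360° longitude wrap-around automatically.

```python
import numpy as np

def nearest_on_sphere(lat_src, lon_src, values, lat_out, lon_out):
    """Assign each point of a regular lat-lon grid the value of its nearest source point."""
    def unit_vectors(lat_deg, lon_deg):
        phi, lam = np.deg2rad(lat_deg), np.deg2rad(lon_deg)
        return np.stack([np.cos(phi) * np.cos(lam),
                         np.cos(phi) * np.sin(lam),
                         np.sin(phi)], axis=-1)

    src = unit_vectors(np.asarray(lat_src, float), np.asarray(lon_src, float))  # (npts, 3)
    glat, glon = np.meshgrid(lat_out, lon_out, indexing="ij")                    # (nlat, nlon)
    dst = unit_vectors(glat, glon)                                                # (nlat, nlon, 3)
    # the largest dot product between unit vectors is the smallest great-circle distance
    nearest = np.argmax(dst @ src.T, axis=-1)                                    # (nlat, nlon)
    return np.asarray(values)[nearest]

# hypothetical points on the equator at 0E and 180E; the 350E grid point wraps to 0E
demo_regridded = nearest_on_sphere([0.0, 0.0], [0.0, 180.0], [1.0, 2.0],
                                   [0.0], [10.0, 170.0, 350.0])
```

In the notebook, a call such as `nearest_on_sphere(lat, lon, ds.cam_out_NETSW.isel(time=0).values, nlat, nlon)` would return an array of shape `(len(nlat), len(nlon))` that you could compare with the natgrid result.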


In [None]:
# original column index (lat and lon of each column were read from ds_grid above)
ncol = ds.ncol.values

# new lat and lon grids that we finally want
nlat = np.arange(-90, 90.5, 2)
nlon = np.arange(0, 360, 2)

# create a nan-value xr dataset to contain new remapped & interpolated data 
data_var = np.empty([len(time), len(nlat), len(nlon)])
data_var[:,:,:] = np.nan

ds_new = xr.Dataset(
     data_vars={
         v: (("time","lat","lon"), data_var)
         for v in ['cam_out_NETSW','cam_out_PRECT']
     },
     coords={
         "time": ds.time,
         "lat": nlat,
         "lon": nlon,
        # "lev": ds.lev,
    },
)

ds_new

In [None]:
data_prect = data_var.copy()
data_netsw = data_var.copy()

# interpolate the data using pyngl
# optional: use multiprocessing to save the running time

for it, tt in enumerate(ds.time):
    data = ds.sel(time=tt).cam_out_NETSW.values
    iarr = Ngl.natgrid(lat, lon, data, nlat, nlon) #.squeeze()
    data_netsw[it,:] = iarr
    
    data = ds.sel(time=tt).cam_out_PRECT.values
    iarr = Ngl.natgrid(lat, lon, data, nlat, nlon) #.squeeze()
    data_prect[it,:] = iarr

ds_new['cam_out_NETSW'].values = data_netsw
ds_new['cam_out_PRECT'].values = data_prect

In [None]:
# plot the maps of these two variables at a specific time step

fig, ax = plt.subplots(ncols=2, figsize=(12,4))
ds_new['cam_out_NETSW'].isel(time=10).plot(cmap='RdBu_r',ax=ax[0])
ds_new['cam_out_PRECT'].isel(time=10).plot(cmap='RdBu_r',ax=ax[1])
ax[0].set_title('NET SW Surface')
ax[1].set_title('Total Precip (m/s)')

The maps look smoother and closer to realistic conditions.


## Now, analysis and visualization can be done using xarray

<span style="color:blue"> Work in groups of two; each group should complete two of the analysis sections. Please use the plotting skills you learned to make the figures clear and well annotated. Use Google or ChatGPT if you have questions about the climate concepts. </span>
    
### [Analysis 1] Time series and trend

We can use the original dataset to calculate the *global mean* (spatial mean) time series. The unit for precipitation is m/s. 

Note that this should be an area-weighted mean, using the area of each atmospheric column as the weight. The grid area is provided, so you do not need to calculate it.
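The area-weighted mean is $\bar{x} = \sum_i x_i A_i / \sum_i A_i$, which is what the next cell computes with `(PRECT * (area / area.sum())).sum('ncol')`. The NumPy sketch below does the same on hypothetical numbers, relying on broadcasting an `(ncol,)` area array against `(time, ncol)` data, and shows how it differs from the unweighted mean.

```python
import numpy as np

def area_weighted_mean(values, area):
    """Area-weighted mean over the last (column) axis: sum(x * A) / sum(A)."""
    values = np.asarray(values, dtype=float)
    area = np.asarray(area, dtype=float)
    return (values * area).sum(axis=-1) / area.sum()   # area broadcasts over time

# hypothetical: 2 time steps, 3 columns, the last column twice as large as the others
demo_area = np.array([1.0, 1.0, 2.0])
demo_x = np.array([[1.0, 1.0, 1.0],
                   [0.0, 2.0, 4.0]])
demo_gm = area_weighted_mean(demo_x, demo_area)
```

For the second time step the weighted mean is 2.5, whereas the unweighted mean would be 2.0, because the larger column carries twice the weight.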

In [None]:
# here we use the data whose dimension size was not reduced
ds_origin = xr.merge([ds_grid, ds_origin])

# total precipitation = rainfall + snowfall
PRECT = ds_origin['cam_out_PRECC'] + ds_origin['cam_out_PRECSC']

# area-weighted global mean PRECT
# required concept: avg weights, broadcast, resampling
PRECT_mean = (PRECT * (ds_origin['area']/ds_origin['area'].sum())).sum('ncol')
PRECT_mean_daily = PRECT_mean.resample(time='1D').mean('time')
PRECT_mean_monthly = PRECT_mean.resample(time='1M').mean('time')

# visualization
fig, ax = plt.subplots()
PRECT_mean.plot(label='instantaneous', ax=ax)
PRECT_mean_daily.plot(label='daily mean', ax=ax)
PRECT_mean_monthly.plot(label='monthly mean', ax=ax)
ax.legend()

Alternatively, we can use the processed dataset on structured lat/lon coordinates. In this case, we need to calculate the area (or area weights) of each 2°×2° grid cell. You'll only be able to show the time series of monthly or annual mean data, because the data have been resampled at monthly frequency.

Please show the time series of the global weighted mean of cam_out_NETSW using the processed dataset *ds_new*. Refer to your Assignment #1.
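On a regular lat-lon grid, the area of a cell is proportional to the difference of the sines of its edge latitudes, which for small cells is approximately proportional to the cosine of its central latitude, so cos(lat) is a common choice of weights. The NumPy sketch below builds such weights for the same 2° grid as `nlat`/`nlon` and applies them to a made-up field; with xarray, you could instead pass the weights as a DataArray along `lat` to `DataArray.weighted(...)` and call `.mean(('lat', 'lon'))`. Note that this grid includes points at exactly ±90°, which receive (near-)zero weight.

```python
import numpy as np

lat_2deg = np.arange(-90, 90.5, 2)    # same latitudes as nlat in this notebook
lon_2deg = np.arange(0, 360, 2)       # same longitudes as nlon

# cos(latitude) weights; clip guards against tiny negative round-off at the poles
w_lat = np.clip(np.cos(np.deg2rad(lat_2deg)), 0.0, None)

def weighted_global_mean(field, w_lat):
    """Global mean of a (..., lat, lon) array using cos(lat) weights."""
    w2d = np.broadcast_to(w_lat[:, None], field.shape[-2:])
    return (field * w2d).sum(axis=(-2, -1)) / w2d.sum()

# made-up field: 1 in the tropics (|lat| < 30) and 0 elsewhere
tropics = (np.abs(lat_2deg)[:, None] < 30) * np.ones((lat_2deg.size, lon_2deg.size))
frac = weighted_global_mean(tropics, w_lat)
```

Since sin 30° = 0.5, the band between 30°S and 30°N covers half of the sphere; the discrete grid gives a value slightly below 0.5, while an unweighted mean over latitude would give only 29/91 ≈ 0.32.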

In [None]:
##### to be implemented...
weights  = 

Do you see a trend in the precipitation and radiation data? What else do you notice in the time series?

#### Spatial variability of monthly precipitation 

Within the same month, precipitation in some regions, such as the United States, can be very unevenly distributed: some locations can be extremely dry while others are extremely wet. 

Use ds_new from the previous analysis to show the variability of monthly precipitation over the United States by plotting the time series of the median and of the mean $\pm 1$ standard deviation across all grid points in [120°W, 70°W] × [24°N, 50.5°N]. Note that `lon` in ds_new runs from 0 to 360, so this box corresponds to lon 240 to 290.

Hint: your x-axis should be time (the time step is one month), and your y-axis should be the precipitation values.
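Because the box is given in degrees west while `lon` in ds_new uses 0 to 360, the west longitudes must be converted first. The sketch below does the conversion, builds boolean masks on the same 2° grid, and then computes the per-time median, mean, and standard deviation over a synthetic field standing in for precipitation (unweighted, for simplicity).

```python
import numpy as np

def to_lon360(lon_deg):
    """Convert longitudes in [-180, 180) (west negative) to [0, 360)."""
    return np.mod(lon_deg, 360.0)

us_lon_min, us_lon_max = to_lon360(-120.0), to_lon360(-70.0)   # 120W, 70W
us_lat_min, us_lat_max = 24.0, 50.5

grid_lat = np.arange(-90, 90.5, 2)
grid_lon = np.arange(0, 360, 2)
us_in_lat = (grid_lat >= us_lat_min) & (grid_lat <= us_lat_max)
us_in_lon = (grid_lon >= us_lon_min) & (grid_lon <= us_lon_max)

# synthetic (time, lat, lon) field: 12 months of skewed, precipitation-like values
rng = np.random.default_rng(0)
demo_precip = rng.gamma(shape=0.5, scale=2.0, size=(12, grid_lat.size, grid_lon.size))
us_box = demo_precip[:, us_in_lat][:, :, us_in_lon]

# statistics across all grid points in the box, one value per time step
us_median = np.median(us_box, axis=(1, 2))
us_mean = us_box.mean(axis=(1, 2))
us_std = us_box.std(axis=(1, 2))
```

With xarray, the equivalent box selection is `ds_new.sel(lat=slice(24, 50.5), lon=slice(240, 290))`, since both coordinates are ascending.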

In [None]:
##### to be implemented...
# If you want to take weighted mean, you also need to apply weights to the standard deviation
US_precip = 


### [Analysis 2] Histograms/distributions

Please use PRECT (the variable you created in Analysis 1) and ds_new. 

Convert the unit of precipitation from m/s to mm/hr or mm/day, and plot:

1. a histogram for monthly precipitation from all grids across the globe; 
2. a histogram for daily precipitation from all grids 

to show the statistical distributions of the values. Use a log scale for the y-axis. You can plot them on the same graph or separately.

See [xarray.plot.hist](https://docs.xarray.dev/en/stable/generated/xarray.plot.hist.html). 
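Since 1 m = 1000 mm, 1 h = 3600 s, and 1 day = 86400 s, a rate of 1 m/s equals 3.6 × 10⁶ mm/hr or 8.64 × 10⁷ mm/day. The sketch below encodes these factors and histograms a synthetic, skewed set of rates; the exponential distribution and its scale are assumptions chosen only to mimic the shape of precipitation data.

```python
import numpy as np

MM_PER_M = 1000.0
S_PER_HOUR = 3600.0
S_PER_DAY = 86400.0

def ms_to_mm_per_hr(x):
    return np.asarray(x) * MM_PER_M * S_PER_HOUR   # 1 m/s = 3.6e6 mm/hr

def ms_to_mm_per_day(x):
    return np.asarray(x) * MM_PER_M * S_PER_DAY    # 1 m/s = 8.64e7 mm/day

# synthetic precipitation rates in m/s (skewed, like real precipitation)
rng = np.random.default_rng(0)
demo_rate = rng.exponential(scale=3e-8, size=10_000)
demo_counts, demo_edges = np.histogram(ms_to_mm_per_day(demo_rate), bins=50)
```

For plotting, matplotlib's `hist` accepts `log=True` (or call `ax.set_yscale('log')`), and `xarray.plot.hist` forwards extra keyword arguments to matplotlib's `hist`.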

In [None]:
##### to be implemented...
# m/s = ? mm/hr


Briefly describe what you see, especially the differences between the two histograms.

### [Analysis 3] Climatology, Anomalies and Normalization

#### 2D horizontal maps

**<span style="color:blue"> We'll use *ds_new* for the rest of the analysis unless otherwise specified.</span>**

Review the lecture on Day 2, and read **[Climatology vs weather](https://drought.unl.edu/Education/DroughtIn-depth/WhatisClimatology.aspx)** and **[Current Climate](https://climateknowledgeportal.worldbank.org/country/united-states/climate-data-historical)** to understand what climatology is. Loosely, it can be interpreted as the average weather in an area over decades. In other words, climatology describes what a typical year, or the average condition, looks like for climate variables such as precipitation or temperature, in a region or over the globe. Depending on your goal, it can be the seasonal cycle or just the annual-mean condition.

Choose an appropriate map projection (using cartopy) and visualize 2D maps of the climatology of total precipitation and net surface SW radiation as subplots. Choose a suitable colormap and annotate the figures properly.

Hint: this can be done simply by taking the temporal mean. **You are not yet required to deal with seasonal cycles (namely, don't group the data by month)** in this analysis.

In [None]:
##### to be implemented...
import cartopy.crs as ccrs
import cartopy


#### Zonal mean climatology
Besides the 2D maps, we can also plot curves of zonal-mean radiation and precipitation. 

Hint: plot the curves along latitude, with lat on the x-axis and the zonal-mean variable on the y-axis. You can plot radiation and precipitation in two subplots, or together using **[twin axes](https://matplotlib.org/stable/gallery/subplots_axes_and_figures/two_scales.html#sphx-glr-gallery-subplots-axes-and-figures-two-scales-py)**.


In [None]:
##### to be implemented...


Briefly describe the patterns.

#### Anomalies and standard anomalies

What is an anomaly in climate science? Read **[Temperature Anomalies](https://www.ncei.noaa.gov/access/monitoring/dyk/anomalies-vs-temperature)** to learn more. Anomalies can be visualized as time series or maps. 

When you want to compare the variability of two variables whose original magnitudes differ, it often helps to express the data as **[Standardized Anomalies](https://iridl.ldeo.columbia.edu/dochelp/StatTutorial/Climatologies/index.html)**. 

Visualize the time series of standardized anomalies of *global mean* precipitation and radiation on the same plot. Apply a 3-month rolling mean to make the curves smoother, and show a proper legend.
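A standardized anomaly divides the anomaly by the standard deviation of the reference data, so variables with different units and magnitudes become directly comparable. The NumPy sketch below computes anomalies relative to each calendar month's mean, standardizes them by that month's standard deviation, and applies a 3-month moving average; it assumes a 1-D monthly series and uses synthetic values. Standardizing against the overall mean and standard deviation instead is also common, depending on whether you want to keep the seasonal cycle. In xarray, `groupby('time.month')` and `.rolling(time=3, center=True).mean()` cover the same steps.

```python
import numpy as np

def monthly_anomalies(x, start_month=1):
    """Anomalies and standardized anomalies relative to each calendar month."""
    x = np.asarray(x, dtype=float)
    month = (np.arange(x.size) + start_month - 1) % 12          # 0 = Jan, ..., 11 = Dec
    clim_mean = np.array([x[month == m].mean() for m in range(12)])
    clim_std = np.array([x[month == m].std() for m in range(12)])
    anom = x - clim_mean[month]
    std_anom = anom / clim_std[month]                          # standardized anomaly
    return anom, std_anom

def rolling_mean(x, window=3):
    """Moving average; 'valid' mode drops window-1 points, output i averages x[i:i+window]."""
    return np.convolve(x, np.ones(window) / window, mode="valid")

# synthetic 5-year monthly series: seasonal cycle plus noise
rng = np.random.default_rng(1)
t_month = np.arange(60)
demo_series = 10 + 3 * np.sin(2 * np.pi * t_month / 12) + rng.normal(0, 0.5, t_month.size)
demo_anom, demo_std_anom = monthly_anomalies(demo_series)
demo_smooth = rolling_mean(demo_std_anom, window=3)
```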

In [None]:
##### to be implemented...


Did you find any pattern? Which variable has stronger variability?

### [Analysis 4] Seasonality

In the previous analysis we saw that the data exhibit strong seasonality. We can further visualize the seasonal cycles from the perspective of climatology, using the *groupby* function.

Group the *global mean* precipitation from ds_new by month and take the mean of each group. Plot precipitation against month.

Compare with plots in [paper for net SW surface](https://www.researchgate.net/figure/Climatological-mean-annual-cycle-of-a-net-shortwave-radiation-b-net-longwave-radiation_fig4_314121791) and [paper for precip](https://www.researchgate.net/figure/Mean-seasonal-cycle-of-a-temperatures-over-the-globe-the-tropical-ocean-and-the_fig1_234072312).

Hint: the x-axis should be Jan, Feb, … Dec; and y-axis should be the groupby mean of *global mean* precipitation for each month.
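`groupby('time.month').mean()` averages all samples that share a calendar month. The NumPy sketch below does the same for a 1-D monthly series with `np.bincount`, so it works for any start month; the synthetic series, built to peak in July, is only for illustration. It assumes every calendar month appears at least once (otherwise that month's mean is NaN).

```python
import calendar
import numpy as np

def monthly_climatology(x, start_month=1):
    """Mean for each calendar month of a 1-D monthly series (index 0 = Jan ... 11 = Dec)."""
    x = np.asarray(x, dtype=float)
    month = (np.arange(x.size) + start_month - 1) % 12
    sums = np.bincount(month, weights=x, minlength=12)
    counts = np.bincount(month, minlength=12)
    return sums / counts   # NaN (with a warning) for months that never occur

# tick labels for the x-axis, e.g. 'Jan' ... 'Dec' in an English locale
month_labels = list(calendar.month_abbr)[1:]

# synthetic 3-year series starting in February, built to peak in July (index 6)
demo_months = (np.arange(36) + 1) % 12
demo_x = np.cos(2 * np.pi * (demo_months - 6) / 12)
demo_clim = monthly_climatology(demo_x, start_month=2)
```

The result can be plotted directly with `ax.plot(month_labels, demo_clim)`.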



In [None]:
##### to be implemented...


We can also group the data by season, such as JJA (Jun, Jul, Aug) or DJF (Dec, Jan, Feb).

Plot a global map of the radiation difference between summer and winter (e.g., JJA minus DJF).
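A season is just a label for a set of calendar months, so the JJA-minus-DJF map can be computed by averaging the maps in each set and subtracting. The NumPy sketch below uses a synthetic field whose value equals the calendar month. As with xarray's `groupby('time.season')`, DJF here pools all December, January, and February months rather than pairing each December with the following year.

```python
import numpy as np

SEASON_OF_MONTH = {12: "DJF", 1: "DJF", 2: "DJF",
                   3: "MAM", 4: "MAM", 5: "MAM",
                   6: "JJA", 7: "JJA", 8: "JJA",
                   9: "SON", 10: "SON", 11: "SON"}

def seasonal_difference(field, months, season_a="JJA", season_b="DJF"):
    """Time mean of `field` (time, ...) over season_a minus that over season_b."""
    seasons = np.array([SEASON_OF_MONTH[int(m)] for m in months])
    return field[seasons == season_a].mean(axis=0) - field[seasons == season_b].mean(axis=0)

# synthetic monthly maps (24 months on a 2 x 3 grid) whose value equals the calendar month
demo_months = np.tile(np.arange(1, 13), 2)
demo_maps = demo_months[:, None, None] * np.ones((24, 2, 3))
demo_diff = seasonal_difference(demo_maps, demo_months)
```

In xarray, `ds_new.groupby('time.season').mean()` returns means labeled 'DJF', 'MAM', 'JJA', and 'SON', which can then be differenced with `.sel(season='JJA') - .sel(season='DJF')`.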


In [None]:
##### to be implemented...

season_data = 

What did you find?

### [Analysis 5] Vertical profiles of the 3D variables

#### Interpolation of data

Danger: <span style="color:red">restart your server and select the 64G or 128G memory option for this analysis. Only do this when you work on this task.</span>

So far, we have only processed and visualized a few 2D variables. Following the same processing steps used to generate ds_new, we can also create an interpolated dataset/DataArray for the 3D variable **state_t** (air temperature), which has a *lev* dimension.

To avoid requesting a huge amount of memory, we'll keep only every 4th *lev* level (try every 6th if memory is still insufficient) and resample the data to monthly means.
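Memory grows linearly with every dimension, so subsampling `lev` reduces it proportionally. The small helper below estimates the size of a dense array; the dimension sizes in the example are hypothetical placeholders, and in the notebook you would use values such as `ds_origin.sizes['time']` and `ds_origin.sizes['lev']`. It assumes float64 (8 bytes per value); float32 data would use 4.

```python
def array_gib(*shape, itemsize=8):
    """Size in GiB of a dense array with the given shape (8 bytes per float64 value)."""
    n_values = 1
    for size in shape:
        n_values *= size
    return n_values * itemsize / 2**30

# hypothetical dimension sizes, for illustration only
n_time, n_lev, n_col = 10_000, 60, 384

full_gib = array_gib(n_time, n_lev, n_col)
sub_gib = array_gib(n_time, len(range(0, n_lev, 4)), n_col)   # every 4th level
print(f"all levels: {full_gib:.2f} GiB, every 4th level: {sub_gib:.2f} GiB")
```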

In [None]:
##### to be implemented...

lev = ds_origin.lev
lev = lev[0::4]
ds_3d = ds_origin.sel(lev=lev)

# try to release some memory by deleting useless datasets if needed
# del ds_origin, ds, ds_grid 

# memory failure: ds_3d.resample(time='1M').mean(dim='time')
ds_3d = ds_3d.state_t.

In [None]:
##### to be implemented...
# create a nan-value xr dataset to contain new remapped & interpolated data 
data_var = np.empty([len(time), len(lev), len(nlat), len(nlon)])
data_var[:,:,:,:] = np.nan

ds_new3d = 

#### Vertical structure of state_t from any grid 

Select any single grid point (or domain) you like and take the temporal mean. If you select a domain instead of a single grid point, also take the horizontal mean before plotting.

Plot the state_t values along the vertical (altitude/height) levels. Your x-axis should be the values of state_t, and the y-axis should be lev.

Compare your plots with [this website](https://crisp.nus.edu.sg/~research/tutorial/atmos.htm#:~:text=The%20vertical%20profile%20of%20the,%2C%20mesopause%20and%20thermopause%2C%20respectively.).

In [None]:
##### to be implemented...


Which level, lev=0 or lev=9, do you think is closer to the Earth's surface?

#### Seasonality or distribution across the latitudes

The vertical profiles of state_t from a domain or grid point can show seasonal variation and can also exhibit geographic characteristics with respect to latitude. 

Make ONE plot to show either the seasonal change or the geographic characteristics of the air temperature.

Hint: one way to visualize the seasonal change is to take the JJA and DJF means of the domain-mean state_t and plot these two vertical profiles on the same figure (x-axis: state_t, y-axis: lev). 

Hint: one way to visualize the geographic difference is to take the temporal and zonal mean of state_t, so that you have one state_t value for each lat and lev (x-axis: lat, y-axis: lev, color: state_t). 

You do not need to make the figure exactly following these steps. Please unleash your creativity and it will be fine as long as you plot something that makes sense and helps understand the concepts.



In [None]:
##### to be implemented...


### [Analysis 6] Simple Machine Learning
#### Linear regression

Convert the (weighted or unweighted) *global mean* data from xarray.Dataset to numpy.ndarray format and fit a linear regression between precipitation and radiation (just use the monthly data from ds_new). Namely, fit y = precipitation as a function of x = radiation using linear regression. There are many [Python packages](https://towardsdatascience.com/five-regression-python-modules-that-every-data-scientist-must-know-a4e03a886853) that you can use.

Visualize your predicted data along with the raw data.

You are not asked to tune the regression model for accuracy. 
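Once the global-mean series are NumPy arrays (e.g., via `.values`), an ordinary least-squares line takes only a few lines with `np.polyfit`. The sketch below uses synthetic data with a made-up linear relationship, purely to show the mechanics of fitting, predicting, and computing R²; scikit-learn's `LinearRegression` is an equally valid choice.

```python
import numpy as np

# synthetic stand-ins for radiation (x) and precipitation (y); the relation is made up
rng = np.random.default_rng(2)
x_demo = rng.uniform(0.0, 1.0, size=48)
y_demo = 2.0 * x_demo + 1.0 + rng.normal(0.0, 0.05, size=48)

# ordinary least squares fit y = a * x + b; polyfit returns the highest degree first
fit_slope, fit_intercept = np.polyfit(x_demo, y_demo, deg=1)
y_pred = fit_slope * x_demo + fit_intercept

# coefficient of determination R^2
ss_res = np.sum((y_demo - y_pred) ** 2)
ss_tot = np.sum((y_demo - y_demo.mean()) ** 2)
r2 = 1 - ss_res / ss_tot
```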


In [None]:
##### to be implemented...
x = 
y = 

Replace the precip and radiation data with their standard anomalies. Is there any difference?
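If both series are simply converted to z-scores (subtract the overall mean, divide by the overall standard deviation), the fitted line passes through the origin, its slope equals the Pearson correlation coefficient, and R² is unchanged; only the units of the slope change. The sketch below checks this on synthetic data. Standard anomalies computed against a monthly climatology are different: they also remove the seasonal cycle, so the fit can change substantially, which is worth comparing with your result.

```python
import numpy as np

def zscore(a):
    """Standardize with the overall mean and population standard deviation (ddof=0)."""
    a = np.asarray(a, dtype=float)
    return (a - a.mean()) / a.std()

# synthetic data with arbitrary units and magnitudes
rng = np.random.default_rng(3)
x_raw = rng.normal(170.0, 10.0, size=60)
y_raw = 3e-9 * x_raw + rng.normal(0.0, 5e-8, size=60)

slope_z, intercept_z = np.polyfit(zscore(x_raw), zscore(y_raw), deg=1)
r = np.corrcoef(x_raw, y_raw)[0, 1]   # Pearson correlation of the raw data
```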

In [None]:
##### to be implemented...


Are you satisfied with your model and prediction? How would you interpret the relationship (think about physical mechanisms)?