# Working with multidimensional spatial data in Python
# lesson 2.2 - Transforming imaging spectroscopy data at scale
_Glenn Moncrieff_  
[github.com/GMoncrieff](https://github.com/GMoncrieff)  
[@glennwithtwons](https://twitter.com/glennwithtwons)

We have now trained a model and evaluated its predictions. Let's pretend that we are happy with its performance (I said pretend). Now we want to use the model to transform the reflectance spectra over an entire scene into a map of the biophysical property our model predicts. That means running `model.predict()` over the entire image. Ideally, we would rather not do this on the entire datacube at once, as that would mean loading all the data into memory and would limit our ability to parallelise this costly computation.  
  
Fortunately, we can perform this task on a chunked xarray. This means we only have to load a few chunks at a time, we can stream the data chunk by chunk from S3, and we can use the Dask scheduler to parallelise the work.
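As a sketch of why chunking is safe here: the model treats every pixel independently, so predicting block by block gives the same answer as predicting the whole image at once, while only one block is processed at a time. The toy linear "model" and the array sizes below are illustrative assumptions, and plain NumPy loops stand in for the work Dask schedules in parallel.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical image: 60 x 50 pixels with 20 spectral bands
image = rng.random((60, 50, 20)).astype(np.float32)

# Hypothetical per-pixel "model": a fixed linear map from 20 bands to 4 outputs
weights = rng.random((20, 4)).astype(np.float32)

def predict(pixels):
    """Predict outputs for a (n_pixels, n_bands) array, one row per pixel."""
    return pixels @ weights

def predict_whole(img):
    """Predict every pixel of img at once."""
    x, y, bands = img.shape
    return predict(img.reshape(x * y, bands)).reshape(x, y, -1)

def predict_chunked(img, chunk=16):
    """Predict block by block; each block keeps all bands of its pixels."""
    x, y, _ = img.shape
    out = np.empty((x, y, weights.shape[1]), dtype=np.float32)
    for i in range(0, x, chunk):
        for j in range(0, y, chunk):
            block = img[i:i + chunk, j:j + chunk, :]
            out[i:i + chunk, j:j + chunk, :] = predict_whole(block)
    return out

whole = predict_whole(image)
chunked = predict_chunked(image)
print(np.allclose(whole, chunked))  # True
```

Each block spans the full spectral dimension because every prediction needs a pixel's complete spectrum; this is why the wavelength dimension is left unchunked (`-1`) when the data is opened.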

## 1. Setup access

### Authenticate with NASA Earthdata portal
Earthaccess is a Python library to search for, download, or stream NASA Earth science data. You will also need an account on NASA's Earthdata portal:
https://search.earthdata.nasa.gov/

In [None]:
import earthaccess
auth = earthaccess.login(persist=True)

In [None]:
#The old way
#from utils.s3_access import write_creds
#write_creds()

## 2. Load libraries

In [None]:
import s3fs
import xarray as xr
import rioxarray as riox
import hvplot.xarray
import holoviews as hv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import json
from einops import rearrange

from xgboost import XGBRegressor

#our modules
from utils.emit_tools import emit_xarray, quality_mask, ortho_xr, gamma_adjust

## 3. Setup the dask scheduler

In [None]:
from dask.distributed import Client, LocalCluster, progress

cluster = LocalCluster(n_workers=4)
client = Client(cluster)

## 4. Open the dataset

### If you are in us-west-2: Authenticate with s3 and stream data

In [None]:
#load s3 credentials
from utils.s3_access import get_temp_creds
temp_creds_req = get_temp_creds()

In [None]:
# Pass Authentication to s3fs
fs_s3 = s3fs.S3FileSystem(anon=False, 
                          key=temp_creds_req['accessKeyId'], 
                          secret=temp_creds_req['secretAccessKey'], 
                          token=temp_creds_req['sessionToken'])

#### Link to the s3 file and have a quick look with xarray

In [None]:
#these are our files - the s3 links
f_url = ['s3://lp-prod-protected/EMITL2ARFL.001/EMIT_L2A_RFL_001_20230119T114247_2301907_005/EMIT_L2A_RFL_001_20230119T114247_2301907_005.nc',
         's3://lp-prod-protected/EMITL2ARFL.001/EMIT_L2A_RFL_001_20230123T100615_2302306_006/EMIT_L2A_RFL_001_20230123T100615_2302306_006.nc']
f_mask_url = ['s3://lp-prod-protected/EMITL2ARFL.001/EMIT_L2A_RFL_001_20230119T114247_2301907_005/EMIT_L2A_MASK_001_20230119T114247_2301907_005.nc',
            's3://lp-prod-protected/EMITL2ARFL.001/EMIT_L2A_RFL_001_20230123T100615_2302306_006/EMIT_L2A_MASK_001_20230123T100615_2302306_006.nc']

In [None]:
# Open s3 url
fp = fs_s3.open(f_url[0], mode='rb')
fp_mask = fs_s3.open(f_mask_url[0], mode='rb')

### If you are not in us-west-2: 
You should already have the data downloaded.

In [None]:
# paths to data
fp = 'data/downloads/EMIT_L2A_RFL_001_20230119T114247_2301907_005.nc'
fp_mask = 'data/downloads/EMIT_L2A_MASK_001_20230119T114247_2301907_005.nc'

### Load the mask
Same as before.

In [None]:
flags=[7]
mask = quality_mask(fp_mask,flags)

### Load the data
This is different from before. Now we will use the unorthorectified data, and orthorectify the biophysical product we produce later. Masking does not require orthorectification, because the mask is on the same grid as the image. And because we skip orthorectification here, we can chunk the data and load it into a `dask.array`-backed `xr.Dataset`, saving memory.

> Why do we not orthorectify?  
> Orthorectifying may require resampling the data from the original grid of the focal plane array, which can alter the originally measured spectra. To maintain the spectroscopic fidelity of the measurements, it is better to perform all calculations and modelling on the untransformed spectra, and to resample/orthorectify only the downstream biophysical attributes.


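A small illustration of the point above, using two made-up five-band spectra: nearest-neighbour resampling copies a measured spectrum unchanged, whereas interpolating between neighbouring pixels (as bilinear resampling does) produces a blended spectrum that was never actually measured. All values are invented for illustration.

```python
import numpy as np

# Two hypothetical neighbouring pixels, each a 5-band reflectance spectrum
pixel_a = np.array([0.05, 0.08, 0.30, 0.45, 0.40])  # vegetation-like shape
pixel_b = np.array([0.20, 0.22, 0.25, 0.27, 0.30])  # soil-like shape

# An output pixel centre lying 30% of the way from pixel_a to pixel_b
frac = 0.3

# Nearest neighbour: take whichever input pixel is closer, unchanged
nearest = pixel_a if frac < 0.5 else pixel_b

# Linear interpolation: a weighted mix of the two spectra
interpolated = (1 - frac) * pixel_a + frac * pixel_b

print(np.array_equal(nearest, pixel_a))  # True: an actually measured spectrum
print(np.round(interpolated, 3))         # a mixture matching neither input
```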
In [None]:
ds = emit_xarray(fp, 
                 ortho=False,
                 chunk={'downtrack':100,'crosstrack':100,'wavelengths':-1})
#mask bad bands
ds = ds.where(ds.good_wavelengths.compute()==1,drop=True)
ds

## 5. Modelling

### Get the model
Before we can make predictions, we need a function that loads the saved model onto the Dask workers, where it can be applied to the chunks.

In [None]:
#this function loads the model

#for the xgboost model
def get_xgb_model():
    model = XGBRegressor()
    model.load_model('models/best_xgb_model.json')
    return model

#for the ROCKET model
#import pickle
#def get_model():
#    with open('models/rocketmodel.pkl', 'rb') as f:
#        model = pickle.load(f)
#    return model

#### The short explanation
`Client.submit()` sends a task to the Dask scheduler. With `fmodel = client.submit(get_xgb_model)`, `get_xgb_model()` runs once in a worker process, and the loaded XGBoost model stays in that worker's memory. `fmodel` is a handle to it that every chunk can use, so the model is read from disk once rather than once per chunk.

#### The long explanation

The `client.submit()` function sends a task to the Dask distributed scheduler. When you run `fmodel = client.submit(get_xgb_model)`, you ask the scheduler to run `get_xgb_model()` in one of the worker processes. The function returns your XGBoost model, and `client.submit()` immediately gives you a `Future` object (`fmodel`): a reference to a result that lives, or will live, in that worker's memory. Inside `pred_chunk`, `fmodel.result()` waits for that result and fetches the model from the worker where it was created, so every worker process can use the model when predicting its chunks, even though the model was only loaded on one of them.

If `pred_chunk` instead called `get_xgb_model()` itself, the model would be read from disk once for every chunk. This is very inefficient, especially if your model is large or you have many chunks. By first doing `fmodel = client.submit(get_xgb_model)`, the model is loaded from disk only once and kept in memory, and `pred_chunk` can access it for each chunk. This is a typical pattern when using a model or large data structure with Dask: load it once, then apply it many times, avoiding the overhead of repeatedly loading it.

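The "load once, reuse many times" pattern can be sketched without Dask, using the standard library's `concurrent.futures` as a stand-in. Note that these are threads in a single process, so unlike Dask's separate worker processes nothing is copied between workers here. The loader is submitted once; every chunk task receives the same `Future` and calls `.result()` on it, so the expensive load runs exactly once. `load_model` and `predict_chunk` are hypothetical stand-ins for `get_xgb_model` and `pred_chunk`.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

load_count = 0
count_lock = threading.Lock()

def load_model():
    """Hypothetical stand-in for get_xgb_model(): slow, so we want it run once."""
    global load_count
    with count_lock:
        load_count += 1
    time.sleep(0.1)  # pretend to read a model file from disk
    return lambda values: [v * 2 for v in values]  # toy "model"

def predict_chunk(chunk, model_future):
    """Hypothetical stand-in for pred_chunk(): fetch the shared model, then predict."""
    model = model_future.result()  # blocks until the single load has finished
    return model(chunk)

chunks = [[1, 2], [3, 4], [5, 6], [7, 8]]

with ThreadPoolExecutor(max_workers=4) as pool:
    model_future = pool.submit(load_model)  # analogous to client.submit(get_xgb_model)
    results = list(pool.map(lambda c: predict_chunk(c, model_future), chunks))

print(load_count)  # 1
print(results)     # [[2, 4], [6, 8], [10, 12], [14, 16]]
```

Calling `load_model()` inside `predict_chunk` instead would push `load_count` to 4, one load per chunk, which is the inefficiency described above.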
In [None]:
#XGBoost
fmodel = client.submit(get_xgb_model)
#ROCKET
#fmodel = client.submit(get_model)

### Apply model to each chunk
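The function below runs on each chunk. Its first argument, `arr`, is the chunk as a NumPy array, with the core dimension (`wavelengths`) as its last axis, since `xr.apply_ufunc` moves core dimensions to the end.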

In [None]:
#this function applies the model to each chunk
def pred_chunk(arr,fmodel):
    #prep data
    arr = arr[:,:,:-1] #drop the last band to match the number of bands the model was trained on (a band was dropped somewhere upstream)
    xs, ys, zs = arr.shape #get the shape of the chunk
    arr = rearrange(arr,'x y z -> (x y) z') #stack x and y so each pixel is one observation (row)
    arr = np.nan_to_num(arr) #replace NaNs with 0
    
    #predict (fmodel.result() fetches the model loaded on a worker)
    ypred = fmodel.result().predict(arr)
    
    #prep result
    ypred = np.clip(ypred,0,100) #clip to 0-100
    ypred = rearrange(ypred,'(x y) z -> x y z', x=xs,y=ys) #return to original shape
    return ypred

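The reshaping in `pred_chunk` turns a `(downtrack, crosstrack, bands)` chunk into a table with one row per pixel, the 2-D `(n_samples, n_features)` layout that `predict()` expects, and then folds the predictions back onto the pixel grid. As a sketch, the two `einops.rearrange` patterns are equivalent to plain NumPy reshapes in row-major order (with `x` as the outer index); the sizes below are arbitrary, and it is assumed that `predict()` returns a 2-D `(n_pixels, n_outputs)` array.

```python
import numpy as np

# Hypothetical chunk: 3 downtrack x 4 crosstrack pixels, 5 bands
chunk = np.arange(3 * 4 * 5, dtype=np.float32).reshape(3, 4, 5)
xs, ys, zs = chunk.shape

# 'x y z -> (x y) z': one row per pixel, one column per band
pixels = chunk.reshape(xs * ys, zs)
assert np.array_equal(pixels[1], chunk[0, 1])  # row 1 is pixel (x=0, y=1)

# Pretend predictions: 2 outputs per pixel
ypred = np.stack([pixels.sum(axis=1), pixels.mean(axis=1)], axis=1)

# '(x y) z -> x y z': fold the rows back onto the pixel grid
out = ypred.reshape(xs, ys, -1)
assert np.allclose(out[2, 3], [chunk[2, 3].sum(), chunk[2, 3].mean()])
print(out.shape)  # (3, 4, 2)
```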
### Define how to apply the function to the xarray

In [None]:
res = xr.apply_ufunc(pred_chunk, #the function
                     ds, #the data
                     input_core_dims=[['wavelengths']], #the dims we will lose in the result
                     exclude_dims=set(('wavelengths',)), #core dims excluded from alignment
                     output_core_dims=[["class"]], #the dims we will gain in the result
                     dask="parallelized", #use dask
                     output_dtypes=[np.float32], #dtype of result (predictions are floats)
                     dask_gufunc_kwargs={"output_sizes": {"class": 4}}, #length of new dim
                     keep_attrs='override',
                     kwargs={'fmodel':fmodel}) #additional args to func

We now have the result, although nothing has been computed yet. The next step is to mask out the flagged pixels.

In [None]:
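#add a trailing axis so the 2-D mask broadcasts across the class dimension
qmask = mask[:,:,np.newaxis]
#keep unflagged pixels (mask != 1); set flagged pixels to -9999
res = res.where(qmask != 1,-9999)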

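`mask[:,:,np.newaxis]` adds a length-1 trailing axis so the 2-D pixel mask broadcasts across every output class. A NumPy sketch with a tiny made-up mask (1 = flagged, as above) and made-up predictions shows the effect; `np.where(cond, pred, -9999)` follows the same keep-where-true rule as `res.where(cond, -9999)`.

```python
import numpy as np

# Hypothetical 2 x 3 quality mask: 1 marks a flagged pixel
mask = np.array([[0, 1, 0],
                 [0, 0, 1]])

# Hypothetical predictions: 2 x 3 pixels, 4 classes each
pred = np.full((2, 3, 4), 50.0)

qmask = mask[:, :, np.newaxis]               # shape (2, 3, 1)
masked = np.where(qmask != 1, pred, -9999)   # keep unflagged, fill flagged

print(qmask.shape)   # (2, 3, 1)
print(masked[0, 1])  # [-9999. -9999. -9999. -9999.]
```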
Tell Dask to actually perform the computation.

In [None]:
#.persist - perform the computation in the background 
# and keep result as chunked array
res = res.persist()
#a progress bar if dask dashboard is not working
progress(res)

Finally, because the result only has a few bands (one for each endmember class) rather than the roughly 250 in the image data cube, it fits comfortably in memory. We need the xarray to be backed by `np.ndarray` rather than `dask.array` (i.e. not chunked) in order to orthorectify it.  
  
`.load()` computes the `dask.array`s and replaces them with in-memory `np.ndarray`s.

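A rough back-of-the-envelope check of the memory claim, assuming a scene of about 1280 x 1280 pixels stored as 32-bit floats (an illustrative size, not the dimensions of this particular granule), about 250 bands in the input cube, and 4 classes in the result:

```python
def array_nbytes(*shape, bytes_per_value=4):
    """Size in bytes of a dense array with the given shape (float32 by default)."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_value

# Illustrative sizes (assumptions, not read from the file)
rows, cols, bands, classes = 1280, 1280, 250, 4

cube = array_nbytes(rows, cols, bands)      # full reflectance cube
result = array_nbytes(rows, cols, classes)  # model output

print(f"cube:   {cube / 1e9:.2f} GB")    # cube:   1.64 GB
print(f"result: {result / 1e6:.1f} MB")  # result: 26.2 MB
```

The ratio is simply bands / classes, so the result is about 60 times smaller than the cube it was derived from.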
In [None]:
res = res.load()

## 6. Orthorectify

In [None]:
res = ortho_xr(res, GLT_NODATA_VALUE=0, fill_value = -9999)
res

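For intuition about what `ortho_xr` does with `GLT_NODATA_VALUE=0` and `fill_value=-9999`: as we understand the EMIT convention, the product ships a geometry lookup table (GLT) that, for every cell of the north-up output grid, stores the 1-based column (`glt_x`) and row (`glt_y`) of the source pixel to copy, with 0 meaning "no source pixel". The NumPy function below is a simplified, hypothetical version of that lookup on tiny made-up arrays, not the code in `utils.emit_tools`. Because each output cell copies exactly one source pixel, no predictions are blended.

```python
import numpy as np

def glt_lookup(data, glt_x, glt_y, nodata=0, fill_value=-9999):
    """Place swath-grid pixels onto an output grid using a lookup table.

    data:   (downtrack, crosstrack, n) array on the original sensor grid
    glt_x:  (out_rows, out_cols) 1-based crosstrack (column) index, 0 = no data
    glt_y:  (out_rows, out_cols) 1-based downtrack (row) index, 0 = no data
    """
    out = np.full(glt_x.shape + (data.shape[-1],), fill_value, dtype=np.float64)
    valid = (glt_x != nodata) & (glt_y != nodata)
    rows = glt_y[valid] - 1  # convert 1-based indices to 0-based
    cols = glt_x[valid] - 1
    out[valid] = data[rows, cols]  # copy whole pixels, no interpolation
    return out

# Hypothetical 2 x 2 swath with 1 value per pixel
swath = np.array([[[10.0], [20.0]],
                  [[30.0], [40.0]]])

# Hypothetical 2 x 3 output grid; 0 marks cells outside the swath
glt_x = np.array([[1, 2, 0],
                  [0, 1, 2]])
glt_y = np.array([[1, 1, 0],
                  [0, 2, 2]])

ortho = glt_lookup(swath, glt_x, glt_y)[..., 0]
print(ortho)  # rows: [10, 20, -9999] and [-9999, 30, 40]
```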
### Quick plot

In [None]:
res.isel({'class':3}).hvplot.image(cmap='viridis', clim=(0,100),aspect = 'equal', frame_width=500, rasterize=True)

## 7. Save GeoTIFF

We need to add some information before saving as a GeoTIFF with `rioxarray`.

First, let's add names for the class codes. We saved these when we encoded the class strings as integers.

In [None]:
#read in class names
with open("data/classes.json") as f:
    classes = json.load(f)
classes = list(classes.keys())
res.coords["class"] = classes
#convert to dataset with one var per class

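`list(classes.keys())` relies on the keys in `classes.json` being stored in the same order as the model's output columns. If, as we assume here, the file maps each class name to its integer code, sorting by code is a slightly more defensive way to build the labels. The class names and codes below are made up, and the file is written to a temporary directory so the sketch runs on its own.

```python
import json
import os
import tempfile

# Hypothetical contents of data/classes.json: class name -> integer code
example = {"soil": 2, "vegetation": 0, "water": 3, "impervious": 1}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "classes.json")
    with open(path, "w") as f:
        json.dump(example, f)

    # Read with a context manager so the file is closed afterwards
    with open(path) as f:
        classes = json.load(f)

# Order labels by their integer code rather than by position in the file
labels = sorted(classes, key=classes.get)
print(labels)  # ['vegetation', 'impervious', 'soil', 'water']
```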
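Next, reshape the result into the form rioxarray needs for writing: one variable (band) per class, with an integer dtype.

In [None]:
#res = res.rio.write_crs('epsg:4326')
#one variable per class
res = res["reflectance"].to_dataset(dim="class") 

#our values are 0-100, so they fit in uint8; use 255 as the nodata value
#(neither 255 nor the -9999 fill value fits in int8, whose range is -128 to 127)
res = res.where(res != -9999).fillna(255)
res = res.astype("uint8")

In [None]:
#get the input filename (fp is an s3fs file object when streaming, a string path when local)
in_path = fp.path if hasattr(fp, 'path') else fp
filename = os.path.splitext(os.path.basename(in_path))[0]
#write tif
os.makedirs('data/unmixed', exist_ok=True)
res.rio.to_raster(f'data/unmixed/unmixed_{filename}.tiff', dtype='uint8')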

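Whichever integer type is used for the GeoTIFF, the nodata value has to fit inside it alongside the 0-100 predictions, since out-of-range values produce meaningless numbers when cast. `np.iinfo` reports the range of each integer type; a quick check like the one below (a sketch; the candidate values are the ones that appear in this notebook) catches this before writing.

```python
import numpy as np

def fits(dtype, *values):
    """True if every value lies within the representable range of an integer dtype."""
    info = np.iinfo(dtype)
    return all(info.min <= v <= info.max for v in values)

print(np.iinfo(np.int8).min, np.iinfo(np.int8).max)    # -128 127
print(np.iinfo(np.uint8).min, np.iinfo(np.uint8).max)  # 0 255

print(fits(np.int8, 0, 100, 255))   # False: 255 is out of range
print(fits(np.uint8, 0, 100, 255))  # True
print(fits(np.uint8, -9999))        # False
```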
## 8. Automate this

I have provided a script, `unmix_image.py`, that packages the code from this notebook as an executable Python file taking the input file and mask as arguments. This means we can iterate through a list of files and unmix them from a shell script. The script `run_unmix.sh` goes through all the files listed in `data/infiles.csv` and unmixes them. To run it, enter the following in the terminal (only do this if you are on a machine in us-west-2):
```
chmod +x run_unmix.sh
bash run_unmix.sh
```


## Credits

This lesson has borrowed from:    

[the EMIT-Data-Resources repository by LP DAAC / the EMIT team](https://github.com/nasa/EMIT-Data-Resources)

[Okujeni et al. 2013, RSE, for the original methodology](https://www.sciencedirect.com/science/article/abs/pii/S0034425713002009)
