# Predicting on images

## setup access

## Authenticate with NASA Earthdata portal

The new way: log in with the `earthaccess` library

In [1]:
# Authenticate with NASA Earthdata Login through the earthaccess library.
import earthaccess
# NOTE(review): persist=True presumably stores the credentials locally (e.g. in
# ~/.netrc) so later sessions do not prompt again — confirm in the earthaccess docs.
auth = earthaccess.login(persist=True)

EARTHDATA_USERNAME and EARTHDATA_PASSWORD are not set in the current environment, try setting them or use a different strategy (netrc, interactive)
You're now authenticated with NASA Earthdata Login
Using token with expiration date: 10/01/2023


The old way: request temporary S3 credentials with the project helper in `utils.s3_access`

In [2]:
#from utils.s3_access import write_creds
#write_creds()

In [2]:
# Request temporary S3 credentials through the project helper. The returned
# mapping is indexed later by 'accessKeyId', 'secretAccessKey' and 'sessionToken'.
from utils.s3_access import get_temp_creds
temp_creds_req = get_temp_creds()

## Select the data we want

In [3]:
#imports
%load_ext autoreload
%autoreload 2
import s3fs
import xarray as xr
import rioxarray as riox
import hvplot.xarray
import holoviews as hv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import json
from einops import rearrange

from xgboost import XGBRegressor

#our modules
from utils.emit_tools import emit_xarray, quality_mask, ortho_xr


In [4]:
from dask.distributed import Client, LocalCluster, progress

# Start a local Dask cluster with 4 workers and connect a client to it.
# `client` is used further down to submit the model-loading task.
cluster = LocalCluster(n_workers=4)
client = Client(cluster)

In [5]:
# Hand the temporary credentials to s3fs for authenticated (non-anonymous) access.
# Each s3fs argument is paired with the field of temp_creds_req that supplies it.
fs_s3 = s3fs.S3FileSystem(
    anon=False,
    **{
        s3fs_arg: temp_creds_req[cred_field]
        for s3fs_arg, cred_field in (
            ('key', 'accessKeyId'),
            ('secret', 'secretAccessKey'),
            ('token', 'sessionToken'),
        )
    },
)

In [6]:
# Both products of a granule sit in the same S3 "directory", named after the
# reflectance (RFL) granule; only the product tag in the file name differs.
EMIT_RFL_PREFIX = 's3://lp-prod-protected/EMITL2ARFL.001'
GRANULE_IDS = ('20230119T114247_2301907_005',
               '20230123T100615_2302306_006')

# Reflectance files, one per granule.
f_url = [f'{EMIT_RFL_PREFIX}/EMIT_L2A_RFL_001_{gid}/EMIT_L2A_RFL_001_{gid}.nc'
         for gid in GRANULE_IDS]
# Matching quality-mask files, in the same order as f_url.
f_mask_url = [f'{EMIT_RFL_PREFIX}/EMIT_L2A_RFL_001_{gid}/EMIT_L2A_MASK_001_{gid}.nc'
              for gid in GRANULE_IDS]

## Open dataset

In [7]:
# Open s3 url
# Open the first granule's reflectance and mask files on S3 as binary file objects.
# NOTE(review): the handles are never closed in this notebook; presumably they
# must stay open while the dask-backed dataset reads from them lazily — confirm.
fp = fs_s3.open(f_url[0], mode='rb')
fp_mask = fs_s3.open(f_mask_url[0], mode='rb')

In [8]:
# Build the quality mask from the mask file using flag index 7, which
# quality_mask reports as 'Aggregate Flag' in its output.
flags=[7]
mask = quality_mask(fp_mask,flags)

Flags used: ['Aggregate Flag']


In [9]:
# Open the reflectance granule without orthorectification, in 100 x 100 spatial
# chunks; -1 keeps all wavelengths of a pixel together in a single chunk.
ds = emit_xarray(fp, 
                 ortho=False,
                 chunk={'downtrack':100,'crosstrack':100,'wavelengths':-1})
# Keep only the bands where good_wavelengths == 1; drop=True removes the rest.
ds = ds.where(ds.good_wavelengths.compute()==1,drop=True)
ds

Unnamed: 0,Array,Chunk
Bytes,1.45 GiB,9.31 MiB
Shape,"(1280, 1242, 244)","(100, 100, 244)"
Dask graph,169 chunks in 7 graph layers,169 chunks in 7 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.45 GiB 9.31 MiB Shape (1280, 1242, 244) (100, 100, 244) Dask graph 169 chunks in 7 graph layers Data type float32 numpy.ndarray",244  1242  1280,

Unnamed: 0,Array,Chunk
Bytes,1.45 GiB,9.31 MiB
Shape,"(1280, 1242, 244)","(100, 100, 244)"
Dask graph,169 chunks in 7 graph layers,169 chunks in 7 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [10]:
#quick plot
# Sanity check: show a 200 x 200 pixel window of the band nearest to 650
# (in the units of the `wavelengths` coordinate) as a rasterized image.
(ds
 .isel(downtrack=slice(800, 1000),crosstrack=slice(800, 1000))
 .sel(wavelengths=650, method='nearest')
 .hvplot.image(cmap='viridis', aspect = 'equal', frame_width=500, rasterize=True)
)


### prep for model

### apply model

In [12]:
#this function loads the model

#def get_model():
#    with open('models/rocketmodel.pkl', 'rb') as f:
#        model = pickle.load(f)
#    return model

def get_xgb_model(model_path='models/best_xgb_model.json'):
    """Load the trained XGBoost regressor from disk.

    Parameters
    ----------
    model_path : str, optional
        Path of the model file written by XGBoost's ``save_model``. Defaults to
        the model shipped with this project, so ``get_xgb_model()`` (and
        ``client.submit(get_xgb_model)``) behave exactly as before.

    Returns
    -------
    XGBRegressor
        A regressor with the stored model loaded into it.
    """
    model = XGBRegressor()
    model.load_model(model_path)
    return model

# client.submit() sends a task to the Dask distributed scheduler.
# fmodel = client.submit(get_xgb_model) asks the scheduler to run get_xgb_model
# in one of the worker processes. The function returns the XGBoost model, and
# client.submit wraps it in a Future (fmodel): a handle to a result that lives
# on the cluster once the task has finished.
# pred_chunk calls fmodel.result() to obtain that model, so every worker that
# processes a chunk can get hold of the same loaded model, even though the model
# was not in that worker's local memory to begin with.
#
# Why not pass get_xgb_model directly to xr.apply_ufunc? It would be evaluated
# once for every chunk, i.e. the model would be read from disk each time, which
# is very inefficient, especially if the model is large or there are many chunks.
# With fmodel = client.submit(get_xgb_model) the model file is loaded only once
# and kept in cluster memory; pred_chunk can then quickly access it for each chunk.
# This is a typical Dask pattern for a model or other large data structure:
# load it once, then apply it many times, avoiding the overhead of repeated loading.

#fmodel = client.submit(get_model)
# Future for the loaded model; it is handed to pred_chunk through the `kwargs`
# argument of xr.apply_ufunc.
fmodel = client.submit(get_xgb_model)

In [None]:
#this function applies the model to a single chunk
def pred_chunk(arr,fmodel):
    """Predict the model output for every pixel of one chunk.

    Parameters
    ----------
    arr : numpy.ndarray
        3-D chunk whose last axis holds the spectral bands (the core dimension
        handed over by ``xr.apply_ufunc``).
    fmodel : Future or fitted model
        Either a Dask ``Future`` whose ``result()`` is the model, or the model
        itself (anything exposing ``predict``). Accepting both keeps the
        function usable whether or not the scheduler has already replaced the
        Future with its result, and outside a cluster.

    Returns
    -------
    numpy.ndarray
        Predictions clipped to [0, 100] with shape ``(x, y, n_outputs)``; a
        single-output model yields a trailing axis of length 1.
    """
    # Resolve the Future only when we were actually given one.
    model = fmodel.result() if callable(getattr(fmodel, "result", None)) else fmodel

    # Drop the last band.
    # NOTE(review): presumably the model was trained without it — confirm.
    arr = arr[:, :, :-1]
    xs, ys, n_bands = arr.shape

    # Flatten the two spatial axes so each row is one pixel's spectrum,
    # then fill NaNs (the model is fed zeros instead).
    pixels = np.nan_to_num(arr.reshape(xs * ys, n_bands))

    #predict and clip to 0-100
    ypred = np.clip(model.predict(pixels), 0, 100)

    # Restore the spatial layout; -1 keeps however many outputs the model has.
    return ypred.reshape(xs, ys, -1)

In [20]:
#here we define how to apply the func (lazy: nothing is computed yet)
res = xr.apply_ufunc(
    pred_chunk,  # the function, called once per dask chunk
    ds,  # the data
    input_core_dims=[['wavelengths']],  # the dim pred_chunk consumes (moved to the last axis)
    exclude_dims={'wavelengths'},  # this dim is lost in the result, so exclude it from alignment
    output_core_dims=[["class"]],  # the dim we gain in the result
    dask="parallelized",  # use dask
    # NOTE(review): this dtype is only metadata for dask. pred_chunk does not cast
    # its output, so chunks keep whatever dtype predict() returns — confirm uint8 is intended.
    output_dtypes=[np.uint8],  # dtype of result
    # Passing output_sizes directly to apply_ufunc is deprecated; new dimension
    # sizes belong in dask_gufunc_kwargs.
    dask_gufunc_kwargs={"output_sizes": {"class": 4}},  # length of new dim
    keep_attrs='override',
    kwargs={'fmodel': fmodel},  # additional args to func
)

  res = xr.apply_ufunc(pred_chunk, #the function


In [21]:
# Add a trailing axis so the mask broadcasts across the "class" dimension.
# NOTE(review): assumes mask is 2-D (downtrack, crosstrack) — confirm against quality_mask.
qmask = mask[:,:,np.newaxis]
# Keep predictions where the mask is not 1; pixels with mask == 1 become -9999.
res = res.where(qmask != 1,-9999)

In [22]:
#apply the function
#.persist() starts the computation in the background
# and keeps the result as a chunked array
res = res.persist()

In [23]:
#a progress bar in case the dask dashboard is not working
progress(res)

VBox()

In [22]:
#mask nas
#res = res.where(newmask==0,np.nan)

#unstack and return to x y
#res=res.unstack()

In [24]:
# Load every coordinate of res into memory, replacing each coordinate with its
# loaded version (the data variable itself is left as it is).
for coord in res.coords:
    res.coords[coord] = res.coords[coord].load()

In [25]:
# Display the persisted result to check its shape, chunking and dtype.
res

Unnamed: 0,Array,Chunk
Bytes,12.13 MiB,78.12 kiB
Shape,"(1280, 1242, 4)","(100, 100, 4)"
Dask graph,169 chunks in 1 graph layer,169 chunks in 1 graph layer
Data type,int16 numpy.ndarray,int16 numpy.ndarray
"Array Chunk Bytes 12.13 MiB 78.12 kiB Shape (1280, 1242, 4) (100, 100, 4) Dask graph 169 chunks in 1 graph layer Data type int16 numpy.ndarray",4  1242  1280,

Unnamed: 0,Array,Chunk
Bytes,12.13 MiB,78.12 kiB
Shape,"(1280, 1242, 4)","(100, 100, 4)"
Dask graph,169 chunks in 1 graph layer,169 chunks in 1 graph layer
Data type,int16 numpy.ndarray,int16 numpy.ndarray


In [26]:
# Bring the computed result into local memory (no longer a lazy dask array).
res = res.load()

### quick plot

In [28]:
# Orthorectify the predictions with the project helper. fill_value=-9999 is the
# same value already written into masked pixels.
# NOTE(review): GLT_NODATA_VALUE=0 presumably marks empty cells of the geometry
# lookup table — confirm in utils.emit_tools.
ores = ortho_xr(res, GLT_NODATA_VALUE=0, fill_value = -9999)
ores


1.23.5


In [29]:
# Plot the class at index 3 of the orthorectified result; the colour scale is
# fixed to 0-100, so the -9999 fill value falls outside it.
ores.isel({'class':3}).hvplot.image(cmap='viridis', clim=(0,100),aspect = 'equal', frame_width=500, rasterize=True)

In [38]:
# Display the orthorectified result once more before exporting it.
ores

### Save GeoTIFF

In [47]:
#add crs info
res_ras = ores.rio.write_crs('epsg:4326')
#read in class names (the keys of the JSON object); `with` closes the file again
with open("data/classes.json") as classes_file:
    classes = list(json.load(classes_file))
res_ras.coords["class"] = classes
#convert to dataset with one var per class
res_ras = res_ras["reflectance"].to_dataset(dim="class")

#our values are 0-100, so an unsigned byte holds them and leaves 255 free as no-data.
#The previous int8 cast could not represent 255 (int8 stops at 127), and the
#-9999 fill values overflowed as well. Everything outside 0-100 (NaN and the
#-9999 fill included) is mapped to 255 *before* casting so nothing wraps around.
NODATA = 255
is_valid = (res_ras >= 0) & (res_ras <= 100)  # NaN compares False -> treated as no-data
res_ras = res_ras.where(is_valid, NODATA).astype("uint8")

In [48]:
# Display the per-class dataset that is about to be written to disk.
res_ras

In [49]:
#write tif
#name the output after the input granule (file name without its extension)
filename_with_ext = os.path.basename(f_url[0])
granule_stem = os.path.splitext(filename_with_ext)[0]
out_dir = 'data/unmixed'
#make sure the target folder exists so the write cannot fail on a fresh checkout
os.makedirs(out_dir, exist_ok=True)
#no explicit dtype: the raster inherits the dtype res_ras was cast to when it
#was prepared, so the cast is defined in exactly one place
res_ras.rio.to_raster(f'{out_dir}/unmixed_{granule_stem}.tiff')