In [1]:
import os
import time
import numpy as np
import pandas as pd

import rasterio
from affine import Affine

import dask_gateway
import dask.array as da
import dask.dataframe as daf

import planetary_computer as pc

import xarray as xr
import rioxarray as rioxr

from joblib import load

import raster_to_features as rm
import data_sampling_workflow.sample_rasters as sr

import matplotlib.pyplot as plt

In [2]:
# Number of Dask workers. The scene is split into a workers_sqrt x workers_sqrt
# grid of chunks, so n_workers must be a perfect square (e.g. 16 or 36);
# workers_sqrt is derived here so the two values cannot drift apart.
n_workers = 16
workers_sqrt = int(round(n_workers ** 0.5))
assert workers_sqrt ** 2 == n_workers, 'n_workers must be a perfect square'

In [3]:
# initialize DASK cluster: start a gateway cluster, scale it to `n_workers`
# workers, then attach a client (its repr is shown as the cell output).
cluster = dask_gateway.GatewayCluster()
cluster.scale(n_workers)
client = cluster.get_client()
client

0,1
Connection method: Cluster object,Cluster type: dask_gateway.GatewayCluster
Dashboard: https://pccompute.westeurope.cloudapp.azure.com/compute/services/dask-gateway/clusters/prod.56f301a5a39d427198d2c7b36515d5d5/status,


In [4]:
# NAIP scene (STAC item id) to classify.
# Alternative for batch runs: read ids from temp/coastal_scenes_ids_2020.csv
# (column `itemid`) instead of hard-coding a single scene.
itemid = 'ca_m_3412037_ne_10_060_20200607'

In [5]:
# ---------------------------------------
# Load the pre-trained random forest classifier from disk.
# joblib files are pickles: only ever load a model file from a trusted source.
rfc = load('spectral_rfc.joblib')
print('loaded model')

loaded model


In [6]:
# ---------------------------------------
# access NAIP scene
item = sr.get_item_from_id(itemid)
href = pc.sign(item.assets["image"].href)

# save dimensions (proj:shape is [rows, cols])
y_n = item.properties['proj:shape'][0]
x_n = item.properties['proj:shape'][1]
n_pixels = x_n * y_n

# Ceiling division: each axis splits into exactly `workers_sqrt` chunks and the
# pixel table into at most `n_workers` partitions, with no tiny leftover chunk
# when the sizes do not divide evenly.
chunk_x = -(-x_n // workers_sqrt)
chunk_y = -(-y_n // workers_sqrt)
pixel_chunksize = -(-n_pixels // n_workers)

# open raster as dask-backed DataArray
raster = rioxr.open_rasterio(href, chunks={"x": chunk_x, "y": chunk_y})
# NOTE: rioxarray reads the CRS from the 'spatial_ref' coordinate and derives
# rio.transform() from x/y, so after this drop those may be unavailable on `raster`.
raster = raster.drop_vars(['spatial_ref','x','y'])

# ---------------------------------------
# make dask.DataFrame with pixels: one row per pixel (y outer, x inner),
# one integer-named column per band (0, 1, 2, ...).
# Use the underlying dask array (.data) so each partition is built from its own
# lazy raster chunks instead of embedding the whole DataArray in every task.
band_array = raster.stack(z=("y", "x")).drop_vars('z').T.data
pixels = daf.from_dask_array(band_array.rechunk({0: pixel_chunksize, 1: -1}))

# convert into int16 to calculate ndvi and ndwi: sums/differences of two bands
# must not wrap around (NOTE(review): raw NAIP bands presumably uint8 — confirm)
pixels = pixels.astype('int16')
# 0/0 (both bands zero) gives NaN; NaN comparisons are False, so such pixels
# never pass the thresholds below.
pixels['ndvi'] = (pixels[3] - pixels[0])/(pixels[3] + pixels[0])
pixels['ndwi'] = (pixels[1] - pixels[3])/(pixels[1] + pixels[3])

# add column with pixel number (a static index). Generated lazily with
# da.arange using the same chunking (hence the same divisions) as `pixels`,
# rather than shipping a full np.arange of every pixel number in the graph.
pixels['pix_n'] = daf.from_dask_array(da.arange(n_pixels, chunks=pixel_chunksize))

# ---------------------------------------
# remove water and low ndvi pixels
not_water = pixels[pixels.ndwi < 0.3]
is_veg = not_water[not_water.ndvi > 0.05]

# ---------------------------------------
# clean dataframe and add date features
# keep copy of pixel # of vegetation pixels
is_veg_index = is_veg.pix_n

date = item.datetime
kwargs = {'year' : date.year,
         'month' : date.month,
         'day_in_year' : sr.day_in_year(date.day, date.month, date.year)}
is_veg = is_veg.assign(**kwargs)

is_veg = is_veg.drop(['ndwi','pix_n'], axis=1)

# Optional: keep the filtered table in cluster memory so later .compute() calls
# reuse it instead of re-reading the raster (costs worker memory). persist()
# returns a new collection, so the result must be assigned.
# is_veg = is_veg.persist()

In [8]:
# compute predictions
def _predict_block(block):
    """Classify one 2-D (rows x features) block of vegetation pixels.

    Blocks whose partition was entirely filtered out are empty; they return an
    empty label array instead of reaching rfc.predict, because scikit-learn's
    input validation rejects 0-sample input.
    """
    if block.shape[0] == 0:
        return np.empty(0, dtype=np.int64)
    return rfc.predict(block)

# Run the classifier block-by-block on the workers. Calling
# rfc.predict(dask_array) directly lets a plain scikit-learn model convert the
# whole array to numpy on this client, and it returns a numpy array, which has
# no .compute() method.
scene_preds = is_veg.to_dask_array().map_blocks(
    _predict_block,
    drop_axis=1,  # (rows, features) -> (rows,)
    meta=np.empty(0, dtype=np.int64),  # NOTE(review): assumes integer class labels — confirm
)
scene_preds = scene_preds.compute()

argument of type 'NoneType' is not iterable
Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/comm/tcp.py", line 223, in read
    frames_nbytes = await stream.read_bytes(fmt_size)
tornado.iostream.StreamClosedError: Stream is closed

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/utils.py", line 799, in wrapper
    return await func(*args, **kwargs)
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/client.py", line 1427, in _handle_report
    msgs = await self.scheduler_comm.comm.read()
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/comm/tcp.py", line 239, in read
    convert_stream_closed_error(self, e)
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/comm/tcp.py", line 140, in convert_stream_closed_error
    if "UNKNOWN_CA

KeyboardInterrupt: 

  self.scheduler_comm.close_rpc()
2022-10-12 22:53:59,644 - distributed.client - ERROR - argument of type 'NoneType' is not iterable
Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/comm/tcp.py", line 223, in read
    frames_nbytes = await stream.read_bytes(fmt_size)
tornado.iostream.StreamClosedError: Stream is closed

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/client.py", line 1556, in _close
    await asyncio.wait_for(asyncio.shield(handle_report_task), 0.1)
  File "/srv/conda/envs/notebook/lib/python3.10/asyncio/tasks.py", line 445, in wait_for
    return fut.result()
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/utils.py", line 799, in wrapper
    return await func(*args, **kwargs)
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/client.

In [None]:
# ---------------------------------------
# Materialise the pixel numbers of the vegetation pixels; they become the
# index that places each prediction back onto the image grid.
veg_index = is_veg_index.compute()

In [None]:
# ---------------------------------------
# recover pixel indices for iceplant classifications
# row i of scene_preds is the prediction for pixel number veg_index[i]
preds_df = pd.DataFrame(scene_preds,
                         columns=['is_iceplant'],
                         index = veg_index)
is_iceplant_index = preds_df[preds_df.is_iceplant == 1].index.to_numpy()
non_iceplant_index = preds_df[preds_df.is_iceplant == 0].index.to_numpy()

# ---------------------------------------
# reconstruct indices into image
indices = [non_iceplant_index,
           is_iceplant_index]
values = [0,    # values assigned to pixels from each index
          1]

reconstruct = rm.indices_to_image(y_n,x_n, indices, values, back_value=0)

# ---------------------------------------
# save raster
# raster.rio.crs can be None here: rioxarray reads the CRS from the
# 'spatial_ref' coordinate, which was dropped when the scene was opened.
# Fall back to the STAC item's EPSG code, just as the transform below is
# taken from the item metadata.
scene_crs = raster.rio.crs
if scene_crs is None:
    # NOTE(review): assumes the item exposes 'proj:epsg' — confirm for this catalog
    scene_crs = rasterio.crs.CRS.from_epsg(item.properties['proj:epsg'])

filename = 'S_preds_' + itemid +'.tif'
out_dir = os.path.join(os.getcwd(), 'temp')
os.makedirs(out_dir, exist_ok=True)  # GDAL does not create missing directories

with rasterio.open(
    os.path.join(out_dir, filename),  # file path
    'w',           # w = write
    driver = 'GTiff', # format
    height = y_n,
    width = x_n,
    count = 1,  # number of raster bands in the dataset
    dtype = rasterio.uint8,
    crs = scene_crs,
    # affine transform comes from the STAC item because x/y coords were dropped
    transform = Affine(*item.properties['proj:transform'][0:6]),
) as dst:
    dst.write(reconstruct.astype(rasterio.uint8), 1)
# ---------------------------------------
print('FINISHED: ', itemid , '\n')