In [2]:
import os
import tempfile

import azure.storage.blob
import adlfs
import dask.base
import dask.array
from dask.distributed import Client, wait
from dask_gateway import Gateway
import rioxarray
import rasterio.vrt
import numpy as np
import xarray as xr
import hvplot.xarray
import panel
import pandas as pd
from bokeh.models.tools import BoxZoomTool
import stac_vrt

import utils

# Land Use / Land Cover

This notebook demonstrates applying a land use / land cover classification model to [NAIP](https://www.fsa.usda.gov/programs-and-services/aerial-photography/imagery-programs/naip-imagery/index) imagery stored in Azure Blob storage. We'll use various components of Microsoft's Planetary Computer to facilitate the workload. We'll use a Dask Cluster to distribute the prediction.

Some background: The US Department of Agriculture [publishes a dataset](https://www.fsa.usda.gov/programs-and-services/aerial-photography/imagery-programs/naip-imagery/index) of high-resolution aerial photography. The high resolution makes it great for "land use / land cover" tasks. We maintain a cloud optimized, analysis-ready version of that dataset in Azure.

A land use / land cover model takes an image and classifies each pixel into a category (e.g. "water", "tree canopy", "road", "structure", etc.). We're using a neural network trained by our teammates in the [AI for Good](https://www.microsoft.com/en-us/ai/ai-for-good) program. We'll use the model to analyze how land use changed over a portion of Maryland from 2013 to 2017.

## Step 1: Cluster Setup

This is a somewhat large computation, and we'll handle the scale in two ways:

1. We'll use a cloud-native workflow, reading data directly from Blob Storage into memory on VMs running in Azure, skipping a slow local download step.
2. We'll use a cluster of machines to perform the computation in parallel.

All of the infrastructure around distributed computing can be challenging. We'll use the Planetary Computer's **managed scalable compute** to start a [Dask Cluster](https://docs.dask.org/en/latest/) for us.

In [3]:
# --- Alternative: managed multi-node cluster through Dask Gateway ----------
# Uncomment (and comment out the local cluster below) to request GPU workers
# from the Planetary Computer's Dask Gateway instead of using this node.
#
# N_WORKERS = 8
# g = Gateway()
# options = g.cluster_options()
# options['gpu'] = True
# options['worker_memory'] = 64
# options["worker_cores"] = 5
# options['environment'] = {
#     "AZURE_STORAGE_CONNECTION_STRING": os.environ["AZURE_STORAGE_CONNECTION_STRING"],
#     "DASK_DISTRIBUTED__WORKERS__RESOURCES__GPU": 1,
# }
#
# cluster = g.new_cluster(options)
# client = cluster.get_client()
# cluster.scale(N_WORKERS)

# --- Active: local cluster on this single GPU node -------------------------
from dask_cuda import LocalCUDACluster
from distributed import Client

# NOTE(review): N_WORKERS is not passed to LocalCUDACluster below, so it does
# not control how many local workers start; it mirrors the Gateway option.
N_WORKERS = 2

# Every worker runs 5 threads and advertises one abstract "GPU" resource.
cluster = LocalCUDACluster(resources={"GPU": 1}, threads_per_worker=5)
client = Client(cluster)

## Step 1b: Load the Model

We've stored the model in Blob Storage. We'll load it here.

In [None]:
import azure.storage.blob
from pathlib import Path
import segmentation_models_pytorch
import torch

# Local cache of the model weights; downloaded from Blob Storage on first run.
p = Path("unet_both_lc.pt")
if not p.exists():
    blob_client = azure.storage.blob.BlobClient(
        account_url="https://gtclandcoverdemo.blob.core.windows.net/",
        container_name="models",
        blob_name="unet_both_lc.pt"
    )
    # Read the whole blob before touching the filesystem, so a failed download
    # cannot leave a truncated file behind that the exists() check above would
    # happily accept on the next run.
    p.write_bytes(blob_client.download_blob().readall())

device = torch.device("cuda")

# The architecture must match the checkpoint exactly for load_state_dict.
model = segmentation_models_pytorch.Unet(
    encoder_name='resnet18', encoder_depth=3,
    encoder_weights=None,
    decoder_channels=(128, 64, 64), in_channels=4,
    classes=13
)
# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code -- only load weights from a trusted location like this one.
model.load_state_dict(torch.load(p, map_location='cuda:0'))
# eval() switches layers such as BatchNorm to inference behaviour. In the
# default training mode BatchNorm would normalize each chip with its own batch
# statistics and overwrite its running statistics on every forward pass.
model = model.to(device).eval()

Each worker needs a copy of the model, so we'll send it to them ahead of time.

In [4]:
# Broadcast one copy of the model to every worker up front; tasks then refer to
# it through the returned future rather than shipping the weights themselves.
remote_model = client.scatter(data=model, broadcast=True)
# The workers hold their own copies now, so release the local reference.
del model

## Step 2: Data Discovery

Just *finding* the images you're interested in can be a real challenge. The full dataset consists of millions of individual images, but we only care about a few hundred of them. How do we find the ones we need?

With the Planetary Computer's **metadata query API**, that's straightforward.

In [7]:
from satsearch import Search
import json

with open("aoi.geojson") as f:
    area_of_interest = json.load(f)

# Projection properties to include in every returned item; the mosaics in the
# next step are built from the items' metadata.
fields = [
    "properties.proj:epsg",
    "properties.proj:shape",
    "properties.proj:transform",
    "properties.proj:bbox",
]


def search_for_year(year):
    """Return a satsearch ``Search`` for items over the AOI acquired in ``year``.

    The datetime interval covers exactly the calendar year in UTC
    (``year``-01-01T00:00:00Z through ``year``-12-31T23:59:59Z), so items dated
    on the last day of the previous year or the first instant of the next year
    are not picked up.
    """
    time_range = f"{year}-01-01T00:00:00Z/{year}-12-31T23:59:59Z"
    return Search(
        url="https://pct-pqe-staging.westeurope.cloudapp.azure.com/stac/v1",
        intersects=area_of_interest,
        datetime=time_range,
        fields={"include": fields}
    )


search_2013 = search_for_year(2013)
search_2017 = search_for_year(2017)

print('2013: %s items' % search_2013.found())
print('2017: %s items' % search_2017.found())

2013: 425 items
2017: 425 items


The search returns STAC items that point to GeoTIFF files stored in Azure Blob Storage.

## Step 3: Aligning Images

We have URLs to many files in blob storage. We want to treat all those as one big, logical dataset, so we'll use some open-source libraries to stitch them all together.

In [8]:
# Build one VRT mosaic per year from the raw STAC item dictionaries.
# NOTE(review): ``_items`` / ``_data`` are private satsearch attributes.
vrt_options = {"block_width": 512, "block_height": 512, "data_type": "Byte"}

items = search_2013.items()
data_2013 = [stac_item._data for stac_item in items._items]
naip_2013 = stac_vrt.build_vrt(data_2013, **vrt_options)

items = search_2017.items()
data_2017 = [stac_item._data for stac_item in items._items]
mosaic_2017 = stac_vrt.build_vrt(data_2017, **vrt_options)

In [9]:
# Warp the 2017 mosaic onto the 2013 mosaic's grid (CRS, transform and size)
# so the two years line up pixel-for-pixel. Without ``crs`` the warped VRT
# would keep the 2017 source CRS and only be aligned if both already matched.
# The 2013 dataset is only needed for its grid metadata, so it is closed right
# away; the 2017 source must stay open because the WarpedVRT reads from it.
with rasterio.open(naip_2013) as reference_2013:
    naip_2017 = rasterio.vrt.WarpedVRT(
        rasterio.open(mosaic_2017),
        crs=reference_2013.crs,
        transform=reference_2013.transform,
        height=reference_2013.height,
        width=reference_2013.width,
    )

xarray provides a convenient data structure for working with large, n-dimensional, labeled datasets like this.

In [10]:
# Open both mosaics lazily as dask-backed arrays; each chunk spans all four
# bands and 8192 x 8192 pixels, and lock=False lets chunks be read in parallel.
ds1, ds2 = [
    rioxarray.open_rasterio(source, chunks=(4, 8192, 8192), lock=False)
    for source in (naip_2013, naip_2017)
]

# Stack the two years along a new, labelled "time" dimension.
ds = xr.concat([ds1, ds2], dim=pd.Index([2013, 2017], name="time"))
ds

Unnamed: 0,Array,Chunk
Bytes,168.33 GB,268.44 MB
Shape,"(2, 4, 149498, 140744)","(1, 4, 8192, 8192)"
Count,2054 Tasks,684 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 168.33 GB 268.44 MB Shape (2, 4, 149498, 140744) (1, 4, 8192, 8192) Count 2054 Tasks 684 Chunks Type uint8 numpy.ndarray",2  1  140744  149498  4,

Unnamed: 0,Array,Chunk
Bytes,168.33 GB,268.44 MB
Shape,"(2, 4, 149498, 140744)","(1, 4, 8192, 8192)"
Count,2054 Tasks,684 Chunks
Type,uint8,numpy.ndarray


## Step 4: Pre-processing for Neural Network

Now we have a big dataset that's been pixel-aligned on a common grid for the two time periods.
The model requires a bit of pre-processing up front.

In [11]:
# Per-band normalization statistics for the NAIP imagery, one set per year.
bands = xr.DataArray([1, 2, 3, 4], name="band",
                     dims=["band"], coords={'band': [1, 2, 3, 4]})
NAIP_2013_MEANS = xr.DataArray(
    np.array([117.00, 130.75, 122.50, 159.30], dtype="float32"),
    name="mean",
    coords=[bands],
)
NAIP_2013_STDS = xr.DataArray(
    np.array([38.16, 36.68, 24.30, 66.22], dtype="float32"),
    name="std",
    coords=[bands],
)
NAIP_2017_MEANS = xr.DataArray(
    np.array([72.84, 86.83, 76.78, 130.82], dtype="float32"),
    name="mean",
    coords=[bands],
)
NAIP_2017_STDS = xr.DataArray(
    np.array([41.78, 34.66, 28.76, 58.95], dtype="float32"),
    name="std",
    coords=[bands],
)

# Stack the years along a labelled "time" index (the same labels as ``ds``) so
# that ``(ds - mean) / std`` matches statistics to imagery by year label
# instead of relying on both arrays listing the years in the same order.
mean = xr.concat(
    [NAIP_2013_MEANS, NAIP_2017_MEANS],
    dim=pd.Index([2013, 2017], name="time"),
)
std = xr.concat(
    [NAIP_2013_STDS, NAIP_2017_STDS],
    dim=pd.Index([2013, 2017], name="time"),
)

In [12]:
# Standardize each year with its own per-band mean and standard deviation.
normalized = (ds - mean) / std

# Drop trailing rows/columns so the y and x extents are multiples of 32 --
# presumably so every chip handed to the UNet (encoder_depth=3) divides evenly
# by its downsampling factor. A zero remainder maps to ``None`` (keep all).
slices = {
    dim: slice(None, -(ds.sizes[dim] % 32) or None)
    for dim in ("y", "x")
}

normalized = normalized.isel(**slices)
normalized

Unnamed: 0,Array,Chunk
Bytes,673.15 GB,1.07 GB
Shape,"(2, 4, 149472, 140736)","(1, 4, 8192, 8192)"
Count,4116 Tasks,684 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 673.15 GB 1.07 GB Shape (2, 4, 149472, 140736) (1, 4, 8192, 8192) Count 4116 Tasks 684 Chunks Type float32 numpy.ndarray",2  1  140736  149472  4,

Unnamed: 0,Array,Chunk
Bytes,673.15 GB,1.07 GB
Shape,"(2, 4, 149472, 140736)","(1, 4, 8192, 8192)"
Count,4116 Tasks,684 Chunks
Type,float32,numpy.ndarray


## Step 5: Predicting Land Use for Each Pixel

At this point, we're ready to make predictions. We'll apply the model to the entire dataset, taking care to not over-saturate the GPUs.

In [13]:
def predict_chip(data: torch.Tensor, model) -> torch.Tensor:
    """Return per-pixel class labels for a batch of image chips.

    ``data`` is expected to live on the same device as ``model``. The model is
    run without autograd tracking, the arg-max over dimension 1 (the class
    scores) is taken, and the labels are returned as a uint8 tensor on the CPU.
    """
    with torch.no_grad():
        scores = model(data)
        labels = torch.argmax(scores, dim=1).to(dtype=torch.uint8)
    return labels.cpu()


def copy_and_predict_chunked(tile, model, token=None, chip_size=2048):
    """Classify every pixel of one dask block, one GPU-sized chip at a time.

    The block is cut into ``chip_size`` x ``chip_size`` spatial windows (windows
    on the right/bottom edge may be smaller). Each window is copied to the GPU,
    classified with ``predict_chip`` and written into a uint8 label array.

    Parameters
    ----------
    tile : numpy.ndarray
        Normalized imagery shaped ``(band, y, x)``, or ``(1, band, y, x)``
        when the block still carries a leading ``time`` axis.
    model : torch.nn.Module
        The segmentation model to apply (expected to already live on the GPU).
    token : optional
        Unused; kept so existing callers that pass it keep working.
    chip_size : int, default 2048
        Edge length of the spatial windows fed to the model. NOTE(review):
        windows whose sides are not divisible by the UNet's downsampling
        factor may fail in the decoder; the upstream trim to multiples of 32
        plus a multiple-of-32 ``chip_size`` keeps every window divisible.

    Returns
    -------
    numpy.ndarray
        uint8 class labels shaped ``(y, x)``, or ``(1, y, x)`` when ``tile``
        had a time axis.

    Raises
    ------
    ValueError
        If ``chip_size`` is not positive.
    """
    if chip_size <= 0:
        raise ValueError(f"chip_size must be positive, got {chip_size!r}")

    has_time = tile.ndim == 4
    if has_time:
        assert tile.shape[0] == 1
        tile = tile[0]

    # Use an explicit window size rather than borrowing dask's auto-chunking of
    # ``tile.shape``: that made the chip layout depend on dask's chunk-size
    # configuration and could yield odd-sized chips for partial edge blocks.
    height, width = tile.shape[1:]
    out = np.empty(shape=(height, width), dtype="uint8")
    device = torch.device("cuda")

    for y0 in range(0, height, chip_size):
        for x0 in range(0, width, chip_size):
            window = (slice(y0, y0 + chip_size), slice(x0, x0 + chip_size))
            chip = tile[(slice(None),) + window]
            # Add a batch axis and copy this window to the GPU.
            gpu_chip = torch.as_tensor(chip[np.newaxis, ...]).to(device)
            out[window] = predict_chip(gpu_chip, model).cpu().numpy()[0]

    if has_time:
        out = out[np.newaxis, ...]
    return out

In [14]:
# Zero-size uint8 array: tells dask the output block type without running the
# model just to infer it.
meta = np.empty((0, 0), dtype=np.uint8)

# Apply the model block-by-block. The band axis (1) is consumed by the model,
# and "predict" becomes the prefix of the output task keys.
predictions_array = normalized.data.map_blocks(
    copy_and_predict_chunked,
    model=remote_model,
    drop_axis=1,
    meta=meta,
    name="predict",
)

predictions = xr.DataArray(
    predictions_array,
    dims=("time", "y", "x"),
    coords=normalized.drop_vars("band").coords,
)
predictions

Unnamed: 0,Array,Chunk
Bytes,42.07 GB,67.11 MB
Shape,"(2, 149472, 140736)","(1, 8192, 8192)"
Count,4800 Tasks,684 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 42.07 GB 67.11 MB Shape (2, 149472, 140736) (1, 8192, 8192) Count 4800 Tasks 684 Chunks Type uint8 numpy.ndarray",140736  149472  2,

Unnamed: 0,Array,Chunk
Bytes,42.07 GB,67.11 MB
Shape,"(2, 149472, 140736)","(1, 8192, 8192)"
Count,4800 Tasks,684 Chunks
Type,uint8,numpy.ndarray


In [15]:
# Smoke test: persist a 200 x 200 pixel corner for both years before launching
# the full job (the model still runs on each whole block containing it).
predictions.isel(y=slice(None, 200), x=slice(None, 200)).persist()

Unnamed: 0,Array,Chunk
Bytes,80.00 kB,40.00 kB
Shape,"(2, 200, 200)","(1, 200, 200)"
Count,2 Tasks,2 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 80.00 kB 40.00 kB Shape (2, 200, 200) (1, 200, 200) Count 2 Tasks 2 Chunks Type uint8 numpy.ndarray",200  200  2,

Unnamed: 0,Array,Chunk
Bytes,80.00 kB,40.00 kB
Shape,"(2, 200, 200)","(1, 200, 200)"
Count,2 Tasks,2 Chunks
Type,uint8,numpy.ndarray


Finally, we can compute the result we're interested in: which pixels (spots on the earth) changed land cover / land use over the four years.

In [16]:
# True wherever the predicted class differs between the two years.
class_2013 = predictions.sel(time=2013)
class_2017 = predictions.sel(time=2017)
change = class_2013 != class_2017
change

Unnamed: 0,Array,Chunk
Bytes,21.04 GB,67.11 MB
Shape,"(149472, 140736)","(8192, 8192)"
Count,5826 Tasks,342 Chunks
Type,bool,numpy.ndarray
"Array Chunk Bytes 21.04 GB 67.11 MB Shape (149472, 140736) (8192, 8192) Count 5826 Tasks 342 Chunks Type bool numpy.ndarray",140736  149472,

Unnamed: 0,Array,Chunk
Bytes,21.04 GB,67.11 MB
Shape,"(149472, 140736)","(8192, 8192)"
Count,5826 Tasks,342 Chunks
Type,bool,numpy.ndarray


## Step 6: Inspect Model Results

Upon inspection, the model is a bit too sensitive, so we smooth the output: a pixel is flagged as having changed classification only when all of its neighbors have *also* changed.

In [17]:
import utils

In [18]:
n_classes = 13

# Encode each (2013 class, 2017 class) pair as a single integer
# n_classes * c2013 + c2017. Changed pixels always get a non-zero code because
# their two classes differ; unchanged pixels are set to 0 below.
# NOTE(review): the arithmetic stays in uint8, which only fits while
# n_classes ** 2 <= 256 (13 classes -> max code 168).
other = (
    predictions.sel(time=2017).data
    + n_classes * predictions.sel(time=2013).data
)
change2 = np.where(change.data, other, 0)

# utils.smooth is applied with a 3-pixel halo on each side of every block, so
# neighbourhoods spanning block edges are handled; non-zero output = changed.
smoothed_codes = change2.map_overlap(
    utils.smooth, depth=(3, 3), meta=change2._meta
)
smoothed_arr = smoothed_codes != 0

smoothed = xr.DataArray(
    smoothed_arr,
    dims=change.dims,
    coords=change.coords,
    attrs=change.attrs,
)

# Keep the predicted classes only where the smoothed change mask is set.
changed_predictions = predictions.where(smoothed, other=0)
changed_predictions

Unnamed: 0,Array,Chunk
Bytes,42.07 GB,67.11 MB
Shape,"(2, 149472, 140736)","(1, 8192, 8192)"
Count,14422 Tasks,684 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 42.07 GB 67.11 MB Shape (2, 149472, 140736) (1, 8192, 8192) Count 14422 Tasks 684 Chunks Type uint8 numpy.ndarray",140736  149472  2,

Unnamed: 0,Array,Chunk
Bytes,42.07 GB,67.11 MB
Shape,"(2, 149472, 140736)","(1, 8192, 8192)"
Count,14422 Tasks,684 Chunks
Type,uint8,numpy.ndarray


In [19]:
# Smoke test the smoothed change pipeline on the same 200 x 200 pixel corner.
changed_predictions.isel(y=slice(None, 200), x=slice(None, 200)).persist()

Unnamed: 0,Array,Chunk
Bytes,80.00 kB,40.00 kB
Shape,"(2, 200, 200)","(1, 200, 200)"
Count,2 Tasks,2 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 80.00 kB 40.00 kB Shape (2, 200, 200) (1, 200, 200) Count 2 Tasks 2 Chunks Type uint8 numpy.ndarray",200  200  2,

Unnamed: 0,Array,Chunk
Bytes,80.00 kB,40.00 kB
Shape,"(2, 200, 200)","(1, 200, 200)"
Count,2 Tasks,2 Chunks
Type,uint8,numpy.ndarray


We'll do some visual spot checking of our model. This requires loading full-resolution imagery and predictions into local memory, so we limit ourselves to a 10,000 × 10,000 pixel window starting at the center of the scene.

In [20]:
# Spot-check window: 10,000 x 10,000 pixels starting at the centre of the
# scene (ds is laid out as (time, band, y, x)).
middle = (ds.shape[2] // 2, ds.shape[3] // 2)
slice_y, slice_x = (slice(start, start + 10_000) for start in middle)

parts = [
    arr.isel(y=slice_y, x=slice_x)
    for arr in (ds, predictions, changed_predictions)
]
# To pin the "predict" tasks to GPU workers explicitly, their keys could be
# given {"GPU": 1} resource annotations before computing.

In [21]:
# Materialize the imagery, raw predictions and smoothed change predictions for
# the window in one dask.compute call so shared tasks are only run once.
ds_local, predictions_local, changed_predictions_local = dask.compute(*parts)

In [22]:
def logo(plot, element):
    """Plot hook that removes the Bokeh logo from the plot's toolbar."""
    plot.state.toolbar.logo = None


zoom = BoxZoomTool(match_aspect=True)
style_kwargs = {"width": 450, "height": 400, "xaxis": False, "yaxis": False}
kwargs = {
    "x": "x",
    "y": "y",
    "cmap": utils.lc_cmap,
    "rasterize": True,
    "aggregator": "mode",
    "colorbar": False,
    "tools": ["pan", zoom, "wheel_zoom", "reset"],
    "clim": (0, utils.lc_cmap.N - 1),
}


def _year_row(year):
    """Build one figure row: NAIP imagery next to the changed-pixel classes."""
    imagery = ds_local.sel(time=year).hvplot.rgb(
        bands="band", rasterize=True, hover=False, title=f"NAIP {year}",
        **style_kwargs,
    ).opts(default_tools=[], hooks=[logo])
    classes = changed_predictions_local.sel(time=year).hvplot.image(
        title=f"Classification {year}", **kwargs, **style_kwargs,
    ).opts(default_tools=[])
    return panel.Row(imagery, classes)


panel.Column(_year_row(2013), _year_row(2017))

We took this workload and applied it to a larger area: https://pcstoraccount.z6.web.core.windows.net/.