## LandCoverNet Data Preparation

<img src='https://radiant-assets.s3-us-west-2.amazonaws.com/PrimaryRadiantMLHubLogo.png' alt='Radiant MLHub Logo' width='300'/>

This tutorial delves into building a scalable model on the LandCoverNet dataset.

This portion of the tutorial prepares the LandCoverNet data for training a semantic segmentation model. Here:

1. We will inspect the source imagery for the labels we have

2. We will process the source imagery in parallel using Dask

3. We will collect the labels and the filtered source images computed by Dask

4. We will save the images and their labels as a `pickle` file (`.pkl`) in our working directory, to be loaded later for model training

#### Store your MLHub API Developer Key

In [1]:
import getpass

MLHUB_API_KEY = getpass.getpass(prompt="MLHub API Key: ")
MLHUB_ROOT_URL = "https://api.radiant.earth/mlhub/v1"

In [2]:
import warnings
# silence noisy library warnings before the heavy imports below
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.filterwarnings("ignore", "Creating an ndarray from ragged")

# standard library
import contextlib
import itertools as it
import os
import pickle
from datetime import datetime, timedelta
from typing import List, Tuple
from urllib.parse import urljoin

# data handling, geospatial and plotting libraries
import dask
import dask_gateway
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio as rio
import rasterio.plot
import requests
import rioxarray
import shapely.geometry
import stackstac
from shapely.geometry import mapping, shape

# STAC libraries
import pystac
import pystac_client
from pystac import Item
from pystac.extensions.eo import EOExtension
from pystac.extensions.label import LabelExtension, LabelRelType
from pystac.item_collection import ItemCollection

#### Instantiate the MLHub API client

In [3]:
mlhub_client = pystac_client.Client.open(
    url=MLHUB_ROOT_URL,
    parameters={"key": MLHUB_API_KEY},
    ignore_conformance=True
)

In [4]:
tmp_dir = "/home/jovyan/PlanetaryComputerExamples"
# create the folder the labels are stored in (no error if it already exists)
os.makedirs(f"{tmp_dir}/landcovnet/labels", exist_ok=True)

#### Loading the source imagery

The [esip-summer-2021-geospatial-ml tutorial](https://github.com/TomAugspurger/esip-summer-2021-geospatial-ml/blob/main/segmentation-model.ipynb) was helpful for this task, particularly for loading the STAC items and Sentinel-2 scenes.

In [5]:
# read the label Collection from its local copy under tmp_dir
catalog = pystac.read_file(
    f"{tmp_dir}/landcovnet/labels/ref_landcovernet_v1_labels/collection.json"
)

In [6]:
from dask.distributed import Client

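# start a local Dask cluster and connect a client to it
client = Client()
# apply the same warning filter on every worker process
client.run(lambda: warnings.filterwarnings("ignore", "Creating an ndarray from ragged"))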

{'tcp://127.0.0.1:32979': None,
 'tcp://127.0.0.1:36787': None,
 'tcp://127.0.0.1:37301': None,
 'tcp://127.0.0.1:45593': None}

In [8]:
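# Dask Gateway options (e.g. on the Planetary Computer Hub). The rest of this notebook uses the
# local `client` created above; these options only take effect if you create a Gateway cluster,
# e.g. cluster = gateway.new_cluster(options); client = cluster.get_client()
gateway = dask_gateway.Gateway()
options = gateway.cluster_options()
options["worker_cores"] = 6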

In [9]:
#client.shutdown() #use this if you want to shutdown dask client

Here, we resolve every label Item linked from the label Collection; each label is later matched with its source images for processing.

In [10]:
links = catalog.get_item_links() #links from the catalog
label_items = [link.resolve_stac_object().target for link in links]

In [11]:
len(label_items)

1980

In [12]:
def calculate_cloud_cover(img_arr: np.ndarray) -> int:
    """Takes a chip's cloud-probability band (values 0-100) and returns its cloud cover as an
    integer percentage: the sum of the normalized values divided by the chip area (H x W)

    Args:
    img_arr: np.ndarray - cloud mask array of a 256 x 256 chip (NaNs are treated as 0)

    Returns:
    arr_cc: int - cloud cover score, truncated to an integer
    """
    CHIP_AREA = 256 * 256
    arr_filled = np.nan_to_num(img_arr)
    arr_norm = arr_filled / 100
    arr_sum = arr_norm.sum()
    arr_cc = arr_sum / CHIP_AREA * 100
    return int(arr_cc)
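The score works out to the mean cloud probability over the 256 x 256 chip, with NaN pixels counted as 0 (clear) rather than ignored, and the result truncated by `int()`. The standalone check below repeats the arithmetic of `calculate_cloud_cover` on synthetic arrays; the `cloud_cover_score` name and the data are made up for illustration.

```python
import numpy as np

def cloud_cover_score(img_arr: np.ndarray, chip_area: int = 256 * 256) -> int:
    """Hypothetical copy of calculate_cloud_cover(): NaN -> 0, values in percent,
    normalize to [0, 1], divide by the chip area, convert back to a truncated percent."""
    arr_norm = np.nan_to_num(img_arr) / 100
    return int(arr_norm.sum() / chip_area * 100)

# Synthetic chip: top half fully cloudy (100), bottom half clear (0)
chip = np.zeros((256, 256), dtype="float32")
chip[:128, :] = 100
print(cloud_cover_score(chip))  # 50

# One NaN pixel in a uniformly 30% chip: it counts as clear, and int() truncates 29.99... to 29
chip_with_gap = np.full((256, 256), 30.0)
chip_with_gap[0, 0] = np.nan
print(cloud_cover_score(chip_with_gap))  # 29
```

Because of the truncation, a chip that is very slightly below a whole percentage is reported one point lower.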

In [13]:
def get_median_date(id_arr: np.ndarray) -> int:
    """Takes an array of source Item IDs for a quarter and returns their median date

    Args:
    id_arr: np.ndarray - 1d array of source Item ID strings, each ending in a YYYYMMDD date

    Returns:
    median_date: int - the median date as a YYYYMMDD integer (0 if the array is empty)
    """
    # the last 8 characters of each ID hold the acquisition date
    dates = sorted(int(s[-8:]) for s in id_arr)

    # base case: no source items matched the search criteria
    if not dates:
        return 0

    # index n // 2 is the middle element for odd n and the upper of the two
    # middle elements for even n (the old (n + 1) / 2 index was one too high for odd n)
    return dates[len(dates) // 2]
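If you prefer the standard library, `statistics.median_high` gives the same result as indexing `dates[len(dates) // 2]` on the sorted dates: the middle value for an odd count and the larger of the two middle values for an even count. It always returns one of the input dates, which matters here because the date is turned back into an Item ID; `statistics.median` would average the two middle dates instead. `get_median_date_stdlib` is a hypothetical name and the IDs below are made up.

```python
import statistics
from typing import Sequence

def get_median_date_stdlib(ids: Sequence[str]) -> int:
    """Hypothetical stdlib variant: parse the trailing YYYYMMDD of each ID and
    return the upper median, or 0 when there are no IDs."""
    dates = [int(s[-8:]) for s in ids]
    if not dates:
        return 0
    # median_high always returns an element of `dates`, never an average of two
    return statistics.median_high(dates)

print(get_median_date_stdlib(["tile_20180105", "tile_20180301", "tile_20180210"]))  # 20180210
print(get_median_date_stdlib(["tile_20180105", "tile_20180210", "tile_20180301", "tile_20180320"]))  # 20180301
```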

In [14]:
def filter_quarter_items(cc_df: pd.DataFrame) -> pd.DataFrame:
    """Takes a dataframe of source Items with metadata and keeps one Item per calendar quarter:
    among the Items tied for the lowest cloud cover in that quarter, the one with the median date

    Args:
    cc_df: pd.DataFrame - unfiltered dataframe with 'id', 'cloud_cover' and 'date_time' columns

    Returns:
    filtered_df: pd.DataFrame - filtered dataframe
    """
    # assign each Item its quarter and its cloud-cover rank within that quarter
    cc_df['date_time'] = pd.to_datetime(cc_df['date_time'])
    cc_df['quarter'] = cc_df['date_time'].dt.quarter
    cc_df['rank'] = cc_df.groupby("quarter")["cloud_cover"].rank(method="min", ascending=True)

    # assumes all IDs share a prefix and end in an 8-character date
    id_prefix = cc_df.iloc[0]['id'][:-8]
    median_dates = []

    # keep only the Items tied for the lowest cloud cover in their quarter
    min_cc_df = cc_df[cc_df['rank'] == 1]

    # for each quarter in the year, rebuild the ID of the median-date Item
    for i in range(1, 5):
        quarter_df = min_cc_df[min_cc_df['quarter'] == i]
        quarter_median_date = get_median_date(quarter_df['id'].values)
        quarter_median_id = id_prefix + str(quarter_median_date)
        print(f'Selected source item for quarter {i}: {quarter_median_id}')
        median_dates.append(quarter_median_id)

    # keep only the median-date Item from each quarter
    filtered_df = min_cc_df[min_cc_df['id'].isin(median_dates)]
    return filtered_df
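The ranking step relies on `rank(method="min")`, which gives tied values the same lowest rank, so every Item tied for the minimum cloud cover in a quarter gets rank 1 and is passed on to the median-date step. A toy DataFrame with made-up IDs and values shows the effect:

```python
import pandas as pd

# Toy metadata table; IDs and cloud-cover values are invented for illustration
df = pd.DataFrame({
    "id": ["tile_20180105", "tile_20180220", "tile_20180310", "tile_20180415", "tile_20180502"],
    "cloud_cover": [12.0, 3.0, 3.0, 40.0, 8.0],
    "date_time": ["2018-01-05", "2018-02-20", "2018-03-10", "2018-04-15", "2018-05-02"],
})
df["date_time"] = pd.to_datetime(df["date_time"])
df["quarter"] = df["date_time"].dt.quarter

# rank within each quarter; the two 3.0 values in Q1 both get rank 1
df["rank"] = df.groupby("quarter")["cloud_cover"].rank(method="min", ascending=True)
best = df[df["rank"] == 1]
print(best[["id", "quarter", "cloud_cover", "rank"]])
```

Here Q1 keeps two tied candidates and Q2 keeps one; the median-date step then picks a single Item from each.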

In [15]:
def get_season_min_cloud_cover(item_list: List[Item]) -> ItemCollection:
    """Takes a list of source Items and returns a single chip per season (calendar quarter),
    ranked by the minimum cloud cover from the eo:cloud_cover property

    Args:
    item_list: List[Item] - iterable of source Items returned from search

    Returns:
    ItemCollection - STAC iterable containing Items filtered by cloud cover (empty if item_list is empty)
    """
    # construct a DataFrame row from each source Item's properties
    df_list = []
    for ui in item_list:
        if 'eo:cloud_cover' in ui.properties:
            cloud_cover = ui.properties['eo:cloud_cover']
        else:
            # fall back to computing cloud cover from the CLD (cloud mask) asset
            with rio.open(ui.get_assets()['CLD'].href) as src:
                cloud_cover = calculate_cloud_cover(src.read())
        df_list.append({
            'item': ui,
            'id': ui.id,
            'cloud_cover': cloud_cover,
            'date_time': ui.datetime,
        })

    cc_df = pd.DataFrame(df_list)

    # return an empty collection (not None) so callers can always call len() on the result
    if cc_df.empty:
        return ItemCollection([])

    # filter source items by cloud-cover rank and date
    filtered_df = filter_quarter_items(cc_df)
    return ItemCollection(filtered_df['item'].tolist())

In [16]:
def get_label_item_collection(label_item: Item) -> ItemCollection:
    """Takes a label Item from the LandCoverNet Collection and searches
    for source imagery for chips that match spatial and temporal criteria

    Uses the global `mlhub_client` and `quarter_ranges`. For each quarter, the
    cloud-cover threshold starts at 10% and is relaxed in 5% steps until the
    search returns Items or the threshold passes 100%.

    Args:
    label_item: Item - item of current iteration in the get_item() Dask parallelization

    Returns:
    ItemCollection - STAC iterable containing the Items selected by get_season_min_cloud_cover()
    """
    year_collection = ItemCollection([])

    # iterate over each start and end date per quarter
    for start, end in quarter_ranges:
        cc_thresh = 10  # reset the cloud-cover threshold for every quarter
        item_results = ItemCollection([])

        # the upper bound ends the loop even if a quarter has no imagery at all
        while not item_results and cc_thresh <= 100:
            # temporal and spatial search for the label chip
            search = mlhub_client.search(
                collections=['ref_landcovernet_v1_source'],
                intersects=mapping(shape(label_item.geometry)),
                datetime=[start, end],
                query={"eo:cloud_cover": {"lt": cc_thresh}},
            )
            # collect the search results as an ItemCollection
            item_results = search.get_all_items()
            if not item_results:
                cc_thresh += 5

        year_collection += item_results  # concatenate the ItemCollections for each quarter

    filtered_items = get_season_min_cloud_cover(year_collection.items)
    return filtered_items
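The search loop relaxes the cloud-cover limit until a quarter returns imagery. The same idea, separated from the MLHub client so it runs offline, is sketched below; `relax_until_found`, `fake_search` and the scene catalog are hypothetical. In this sketch, `max_thresh` guarantees the loop terminates even when no threshold ever yields results.

```python
from typing import Callable, List, Tuple

def relax_until_found(
    search_fn: Callable[[float], List[str]],
    start: float = 10,
    step: float = 5,
    max_thresh: float = 100,
) -> Tuple[List[str], float]:
    """Call search_fn with a growing cloud-cover threshold until it returns
    results or the threshold exceeds max_thresh (hypothetical helper)."""
    thresh = start
    while thresh <= max_thresh:
        results = search_fn(thresh)
        if results:
            return results, thresh
        thresh += step
    return [], thresh  # nothing found even at the loosest threshold

# Fake catalog standing in for an MLHub search: scene ID -> cloud cover (%)
scenes = {"scene_a": 37.0, "scene_b": 62.5}

def fake_search(thresh: float) -> List[str]:
    # mimics query={"eo:cloud_cover": {"lt": thresh}}
    return [sid for sid, cc in scenes.items() if cc < thresh]

print(relax_until_found(fake_search))   # (['scene_a'], 40)
print(relax_until_found(lambda t: []))  # ([], 105)
```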

In [17]:
# code reference from https://github.com/TomAugspurger/esip-summer-2021-geospatial-ml/blob/main/segmentation-model.ipynb

# This function loads the source imagery and label for one chip into xarray for further processing
def get_item(label_item: Item, assets: Tuple[str, ...]):
    """Takes a label Item and asset band names and builds the arrays for model training

    Args:
    label_item: Item - item of current iteration in the get_item() Dask parallelization
    assets: Tuple[str, ...] - the Asset band names to stack, e.g. ("B04", "B03", "B02")

    Returns:
    (data, labels): X and y xarray DataArrays for model training,
    or None when no source imagery was found for the label
    """
    labels = rioxarray.open_rasterio(
        f"{tmp_dir}/landcovnet/labels/ref_landcovernet_v1_labels/{label_item.id}/labels.tif",
    ).squeeze()

    source_item_collection = get_label_item_collection(label_item)
    if len(source_item_collection) == 0:
        return None

    # round the label bounds to whole metres
    bounds = tuple(round(x, 0) for x in labels.rio.bounds())

    # stack the requested bands of every selected source Item on the label's 10 m grid
    data = stackstac.stack(
        items=source_item_collection,
        assets=list(assets),
        dtype="float32",
        resolution=10,
        bounds=bounds,
        epsg=labels.rio.crs.to_epsg(),
    )
    # data has dims (time, band, y, x); its last two dims should match labels.shape
    # assert data.shape[-2:] == labels.shape
    data = data.assign_coords(x=labels.x.data, y=labels.y.data)
    # scale to [0, 1]; note that `data` is still lazy (dask-backed) at this point
    data /= 4000
    data = np.clip(data, 0, 1)

    return data, labels.astype("int64")
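The last two steps of `get_item` rescale the pixel values. A small numeric sketch, assuming Sentinel-2 L2A digital numbers where 10000 corresponds to 100% surface reflectance (an assumption; the notebook does not state why 4000 was chosen): dividing by 4000 maps a reflectance of 0.4 to 1.0, and `np.clip` saturates anything brighter, which behaves like a fixed brightness stretch for the RGB bands. The pixel values are made up.

```python
import numpy as np

# Hypothetical digital numbers for five pixels of one band
dn = np.array([0, 800, 2000, 4000, 9000], dtype="float32")

# same scaling as get_item: divide by 4000, then clip to [0, 1]
scaled = np.clip(dn / 4000, 0, 1)

# fraction of pixels saturated by the clip (DN >= 4000)
saturated = float((dn >= 4000).mean())
print(scaled, saturated)
```

Bright surfaces such as clouds or snow will often exceed the 4000 cut-off, so they all map to 1.0.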

In [18]:
def get_quarter_ranges() -> List[List[str]]:
    """Builds [start, end] date pairs (YYYY-MM-DD) for each calendar quarter between the
    global `temporal_start` and `temporal_end` strings, which must be defined before calling

    Args: None
    Returns:
    quarter_ranges: List[List[str]] - a list of [start, end] date string pairs
    """
    td = timedelta(days=1)

    # last day of every quarter that ends inside [temporal_start, temporal_end]
    quarter_ends = pd.date_range(temporal_start, temporal_end, freq='Q')
    quarter_ranges = []

    for ix, quarter_end in enumerate(quarter_ends):
        # the first quarter starts at the collection start; later ones start the day after the previous quarter end
        start = temporal_start if ix == 0 else (quarter_ends[ix - 1] + td).strftime('%Y-%m-%d')
        quarter_ranges.append([start, quarter_end.strftime('%Y-%m-%d')])
    return quarter_ranges
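A variant that takes the date bounds as arguments instead of reading globals can be sketched with pandas quarterly periods. `quarter_ranges_from_periods` is a hypothetical helper, not part of the original notebook; unlike a `date_range` of quarter ends, which only covers quarters whose last day falls inside the interval, it also keeps a trailing partial quarter, clipped to the end date.

```python
from typing import List

import pandas as pd

def quarter_ranges_from_periods(temporal_start: str, temporal_end: str) -> List[List[str]]:
    """Hypothetical helper: one [start, end] pair per calendar quarter,
    clipped to the requested interval."""
    start_ts, end_ts = pd.Timestamp(temporal_start), pd.Timestamp(temporal_end)
    ranges = []
    for period in pd.period_range(start_ts, end_ts, freq="Q"):
        # clip the first and last quarters to the interval's own start and end dates
        q_start = max(period.start_time, start_ts)
        q_end = min(period.end_time.normalize(), end_ts)
        ranges.append([q_start.strftime("%Y-%m-%d"), q_end.strftime("%Y-%m-%d")])
    return ranges

print(quarter_ranges_from_periods("2018-01-01", "2018-12-31"))
print(quarter_ranges_from_periods("2018-02-15", "2018-08-10"))
```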

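#### Extracting Source Imagery

Next, we find the source imagery associated with each label. A label Item's source imagery links have a `"rel"` type of `"source"`; the `get_label_item_collection` function above instead finds matching imagery by searching the `ref_landcovernet_v1_source` collection for Items that intersect the label chip in each quarter.

#### Loading the source imagery for each label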

In [20]:
# global start and end dates of the label Collection's temporal extent
temporal_start = catalog.extent.temporal.intervals[0][0].strftime("%Y-%m-%d")
temporal_end = catalog.extent.temporal.intervals[0][1].strftime("%Y-%m-%d")
quarter_ranges = get_quarter_ranges()  # [start, end] date pairs for each quarter
assets = ("B04", "B03", "B02")  # we will make use of the RGB bands (red, green, blue)

In [None]:
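%%time

# wrap get_item so each call returns one lazy Delayed object;
# get_item returns a (data, labels) pair or None, so nout is not used
get_item_ = dask.delayed(get_item)

Xys_list = []
chunk_size = 20  # number of labels processed per batch

for i in range(0, len(label_items), chunk_size):
    label_chunk = label_items[i:i + chunk_size]

    # build one task per label and run the batch on the Dask cluster
    tasks = [get_item_(label, assets) for label in label_chunk]
    results = dask.compute(*tasks)

    # drop labels for which no source imagery was found
    Xys_list.append([r for r in results if r is not None])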
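The batching pattern can be tried without a cluster or network access. This toy sketch (not the notebook's code) replaces `get_item` with a hypothetical `fake_get_item` that returns `None` for some labels, creates one `Delayed` per call, and passes `scheduler="synchronous"` so everything runs in the current process. `dask.delayed(..., nout=k)` is only meant for functions that always return a tuple of exactly `k` values, so it is not used here.

```python
import dask

def fake_get_item(label_id: int, assets):
    """Hypothetical stand-in for get_item(): None for odd IDs mimics labels without imagery."""
    if label_id % 2:
        return None
    return (f"data-{label_id}", f"labels-{label_id}")

label_ids = list(range(7))
assets = ("B04", "B03", "B02")
chunk_size = 3

get_item_ = dask.delayed(fake_get_item)
Xys_list = []
for i in range(0, len(label_ids), chunk_size):
    tasks = [get_item_(label, assets) for label in label_ids[i:i + chunk_size]]
    # run the batch in-process; on a cluster, drop the scheduler argument
    results = dask.compute(*tasks, scheduler="synchronous")
    Xys_list.append([r for r in results if r is not None])

flat_list = [item for sublist in Xys_list for item in sublist]
print(flat_list)
```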

In [21]:
# flatten the per-chunk lists into one list of (data, labels) pairs
flat_list = [item for sublist in Xys_list for item in sublist]

In [22]:
# save the pairs; "wb" overwrites any earlier file, whereas "ab" would append another pickle to it
with open(f'{tmp_dir}/landcovnet/items.pkl', 'wb') as f:
    pickle.dump(flat_list, f)

In [23]:
client.shutdown()  # shut down the Dask cluster once the data is saved
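A file written with `pickle.dump` is read back with `pickle.load`. If the file is ever written in append mode (`'ab'`), every run adds another pickle to the same file and a single `pickle.load` returns only the first one. The hypothetical helper `load_all_pickles` below (not part of the original notebook) reads every pickle in a file until `EOFError`; the demo writes to a temporary file rather than the real `items.pkl`.

```python
import os
import pickle
import tempfile

def load_all_pickles(path: str) -> list:
    """Read every object pickled into `path`, e.g. after several dumps in append mode."""
    objects = []
    with open(path, "rb") as f:
        while True:
            try:
                objects.append(pickle.load(f))
            except EOFError:  # reached the end of the file
                break
    return objects

# Demo: two runs appending to the same file
path = os.path.join(tempfile.mkdtemp(), "items.pkl")
for run in (["run1-a", "run1-b"], ["run2-a"]):
    with open(path, "ab") as f:
        pickle.dump(run, f)

with open(path, "rb") as f:
    print(pickle.load(f))       # ['run1-a', 'run1-b'] (first dump only)
print(load_all_pickles(path))   # [['run1-a', 'run1-b'], ['run2-a']]
```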