## MOSAIKS feature extraction

This tutorial demonstrates the **MOSAIKS** method for extracting _feature vectors_ from satellite imagery patches for use in downstream modeling tasks. It will show:
- How to extract 1 km$^2$ patches of Sentinel-2 or Landsat multispectral imagery for a list of (latitude, longitude) points
- How to extract summary features from each of these imagery patches
- How to use the summary features in a linear model of the population density at each point

### Background

Consider the case where you have a dataset of latitude and longitude points associated with some dependent variable (for example: population density, weather, housing prices, biodiversity) and, potentially, other independent variables. You would like to model the dependent variable as a function of the independent variables, but instead of including latitude and longitude directly in this model, you would like to include some high-dimensional representation of what the Earth looks like at that point (one that hopefully explains some of the variance in the dependent variable!). From the computer vision literature, there are various [representation learning techniques](https://en.wikipedia.org/wiki/Feature_learning) that can be used to do this, i.e. to extract _feature vectors_ from imagery. This notebook gives an implementation of the technique described in [Rolf et al. 2021](https://www.nature.com/articles/s41467-021-24638-z), "A generalizable and accessible approach to machine learning with global satellite imagery", called Multi-task Observation using Satellite Imagery & Kitchen Sinks (**MOSAIKS**). For more information about **MOSAIKS** see the [project's webpage](http://www.globalpolicy.science/mosaiks).

### Environment setup
This notebook works with or without an API key, but you will be given more permissive access to the data with an API key.
- If you're running this on the [Planetary Computer Hub](http://planetarycomputer.microsoft.com/compute), make sure to choose the **GPU - PyTorch** profile when presented with the form to choose your environment.
- The Planetary Computer Hub is pre-configured to use your API key.
- To use your API key locally, set the environment variable `PC_SDK_SUBSCRIPTION_KEY` or use `pc.settings.set_subscription_key(<YOUR API Key>)`.
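
As a minimal sketch of the local setup, the environment variable can be set from Python before `planetary_computer` is first used. The helper name `configure_pc_key` is hypothetical (it is not part of `planetary_computer`), and the key value is a placeholder.

```python
import os


def configure_pc_key(key=None):
    """Hypothetical helper: export the key and report whether one is configured."""
    if key:
        # Same variable name as in the bullet above
        os.environ["PC_SDK_SUBSCRIPTION_KEY"] = key
    return bool(os.environ.get("PC_SDK_SUBSCRIPTION_KEY"))


# Placeholder value for illustration only; substitute your real key.
has_key = configure_pc_key("<YOUR API KEY>")
```

Calling the helper with no argument simply reports whether a key is already present in the environment.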
    
**Notes**:
- This example uses either
    - [sentinel-2-l2a data](https://planetarycomputer.microsoft.com/dataset/sentinel-2-l2a)
    - [landsat-c2-l2 data](https://planetarycomputer.microsoft.com/dataset/landsat-c2-l2)
- The techniques used here apply equally well to other remote-sensing datasets.

In [20]:
!pip install -q git+https://github.com/geopandas/dask-geopandas

In [20]:
import warnings
import time
import os
import gc
import calendar
import re

RASTERIO_BEST_PRACTICES = dict(  # See https://github.com/pangeo-data/cog-best-practices
    CURL_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt",
    GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR",
    AWS_NO_SIGN_REQUEST="YES",
    GDAL_MAX_RAW_BLOCK_CACHE_SIZE="200000000",
    GDAL_SWATH_SIZE="200000000",
    VSI_CURL_CACHE_SIZE="200000000",
)
os.environ.update(RASTERIO_BEST_PRACTICES)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import rasterio
import rasterio.warp
import rasterio.mask
import shapely.geometry
import geopandas
import dask_geopandas
from dask.distributed import Client
import dask.dataframe as dd
import dask_gateway

from pystac import Item
import stackstac
import pyproj

warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")
warnings.filterwarnings(action="ignore", category=FutureWarning)
warnings.filterwarnings(action="ignore", category=RuntimeWarning)
warnings.filterwarnings(action="ignore", category=UserWarning)

import pystac_client
import planetary_computer as pc


# Disabling the benchmarking feature with torch.backends.cudnn.benchmark = False 
# causes cuDNN to deterministically select an algorithm, possibly at the cost of reduced performance.
# https://pytorch.org/docs/stable/notes/randomness.html
torch.backends.cudnn.benchmark = False

np.random.seed(42)
torch.manual_seed(42)

import random
random.seed(42)

## Set Parameters

In [21]:
num_features = 1000
country_code = 'ZMB'
use_file = False  # set to True to load previously sampled points from a feather file

In [22]:
satellite = "landsat-c2-l2"
#bands = [
    #"red",
    #"green", 
    #"blue",
    #"nir08",
    #"swir16",
    #"swir22"
#]
bands = [    # Landsat-8 Bands
    # "SR_B1", # Coastal/Aerosol Band (B1)
    "SR_B2", # Blue Band (B2)
    "SR_B3", # Green Band (B3)
    "SR_B4", # Red Band (B4)
    "SR_B5", # Near Infrared Band 0.8 (B5)
    "SR_B6", # Short-wave Infrared Band 1.6 (B6)
    "SR_B7" # Short-wave Infrared Band 2.2 (B7)
]

In [6]:
# satellite = "sentinel-2-l2a"
# bands = [  # Sentinel-2 Bands
#     "B02", # B02 (blue) 10 meter
#     "B03", # B03 (green) 10 meter
#     "B04", # B04 (red) 10 meter
#     "B05", # B05(Veg Red Edge 1) 20 meter
#     "B06", # B06(Veg Red Edge 2) 20 meter
#     "B07", # B07(Veg Red Edge 3) 20 meter
#     "B08", # B08 (NIR) 10 meter
#     "B11", # B11 (SWIR (1.6)) 20 meter
#     "B12", # B12 (SWIR (2.2)) 20 meter
# ]

In [23]:
if satellite == "landsat-c2-l2":
    resolution = 30
    min_image_edge = 6
else:
    resolution = 10
    min_image_edge = 20

In [24]:
channels = len(bands)
dat_re = re.compile(r'\d+') 
l = [str(int(dat_re.search(x).group())) for x in bands if dat_re.search(x)]
bands_short = '-'.join(l)

In [5]:
#channels = len(bands)
#bands_short = "r-g-b-nir-swir16-swir22"

## Create grid and sample points to featurize

In [25]:
if use_file:
    gdf = pd.read_feather('data/keep/ZMB_crop_weights_20k-points.feather')
    gdf = (
        geopandas
        .GeoDataFrame(
            gdf, 
            geometry = geopandas.points_from_xy(x = gdf.lon, y = gdf.lat), 
            crs='EPSG:4326')
    )
else:
    cell_size = 0.01  # Roughly 1 km
    ### get country shape
    country_file_name = "~/PlanetaryComputerExamples/Capstone/Featurization/geoBoundaries-ZMB-ADM0-all(1)/geoBoundaries-ZMB-ADM0_simplified.geojson"
    zambia = geopandas.read_file(country_file_name)
    gdf_sea = geopandas.read_file('~/PlanetaryComputerExamples/Featurization/SEAs/features.geojson', crs = 'EPSG:4326')

    ### Create grid of points
    xmin, ymin, xmax, ymax = gdf_sea.total_bounds
    xs = list(np.arange(xmin, xmax + cell_size, cell_size))
    ys = list(np.arange(ymin, ymax + cell_size, cell_size))
    def make_cell(x, y, cell_size):
        ring = [
            (x, y),
            (x + cell_size, y),
            (x + cell_size, y + cell_size),
            (x, y + cell_size)
        ]
        cell = shapely.geometry.Polygon(ring).centroid
        return cell
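
The `cell_size = 0.01` comment says the grid spacing is roughly 1 km. As a rough check, the ground distance covered by a step in degrees can be estimated under a spherical-Earth approximation. The mean radius of 6,371,008.8 m, the helper name `degrees_to_metres`, and the sample latitude of -13.5 degrees are illustrative assumptions, not values taken from the notebook.

```python
import math

EARTH_RADIUS_M = 6_371_008.8  # mean Earth radius; spherical approximation


def degrees_to_metres(step_deg, lat_deg):
    """Return (east-west, north-south) ground distance in metres for a step in degrees."""
    metres_per_degree = math.pi * EARTH_RADIUS_M / 180.0
    north_south = step_deg * metres_per_degree
    # Meridians converge towards the poles, so east-west distance shrinks with cos(latitude)
    east_west = north_south * math.cos(math.radians(lat_deg))
    return east_west, north_south


ew, ns = degrees_to_metres(0.01, -13.5)
print(f"0.01 degree at latitude -13.5: {ew:.0f} m east-west, {ns:.0f} m north-south")
```

Under these assumptions a 0.01 degree cell is a little over 1 km on each side at Zambia's latitudes, and its east-west extent narrows further from the equator.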


In [14]:
cluster = dask_gateway.GatewayCluster()
client = cluster.get_client()
cluster.adapt(minimum=2, maximum=50)
print(cluster.dashboard_link)

https://pccompute.westeurope.cloudapp.azure.com/compute/services/dask-gateway/clusters/prod.ffe3eaebe4b046559284045b27c49095/status


In [8]:
center_points = []
for x in xs:
    for y in ys:
        cell = make_cell(x, y, cell_size)
        center_points.append(cell)
            
### Put grid into a GeoDataFrame for cropping to country shape
gdf = geopandas.GeoDataFrame({'geometry': center_points}, crs = 'EPSG:4326')
gdf['lon'], gdf['lat'] = gdf.geometry.x, gdf.geometry.y
#gdf = gdf.sample(frac=0.1, random_state=43, ignore_index=False) 

# convert GeoDataFrame to Dask GeoDataFrame
dask_gdf = dask_geopandas.from_geopandas(gdf, npartitions=47)

# use Dask to apply the within() method in parallel
dask_gdf_within = dask_gdf.map_partitions(
    lambda partition: partition[partition.geometry.within(gdf_sea.unary_union)],
    meta=gdf.head(0)
)

# convert Dask GeoDataFrame back to GeoDataFrame
gdf = dask_gdf_within.compute()

points = gdf[["lon", "lat"]].to_numpy()
pt_len = gdf.shape[0]
gdf.shape

(14452, 3)

In [13]:
#gdf.to_file('sea.geojson', driver = 'GeoJSON')

In [26]:
gdf = geopandas.read_file('~/PlanetaryComputerExamples/Capstone/Featurization/sea.geojson')
points = gdf[["lon", "lat"]].to_numpy()
pt_len = gdf.shape[0]
gdf.shape

(14452, 3)

In [9]:
cluster.close()

First we define the PyTorch model that we will use to extract the features. The **MOSAIKS** methodology describes several ways to do this and we use the simplest.

In [27]:
class RCF(nn.Module):
    """A model for extracting Random Convolution Features (RCF) from input imagery."""
    def __init__(self, num_features=16, kernel_size=3, num_input_channels=channels):
        super(RCF, self).__init__()
        # We create `num_features / 2` filters so require `num_features` to be divisible by 2
        assert num_features % 2 == 0, "Please enter an even number of features."
        # Applies a 2D convolution over an input image composed of several input planes.
        self.conv1 = nn.Conv2d(
            num_input_channels,
            num_features // 2,
            kernel_size=kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            bias=True,
        )
        # Fills the input Tensor 'conv1.weight' with values drawn from the normal distribution
        nn.init.normal_(self.conv1.weight, mean=0.0, std=1.0)
        # Fills the input Tensor 'conv1.bias' with the value 'val = -1'.
        nn.init.constant_(self.conv1.bias, -1.0)
    def forward(self, x):
        # The rectified linear activation function or ReLU for short is a piecewise linear function 
        # that will output the input directly if it is positive, otherwise, it will output zero.
        x1a = F.relu(self.conv1(x), inplace=True)
        # Apply ReLU to the negated response as well; the two halves are concatenated below
        x1b = F.relu(-self.conv1(x), inplace=True)
        # Applies a 2D adaptive average pooling over an input signal composed of several input planes.
        x1a = F.adaptive_avg_pool2d(x1a, (1, 1)).squeeze()
        x1b = F.adaptive_avg_pool2d(x1b, (1, 1)).squeeze()
        if len(x1a.shape) == 1:  # case where we passed a single input
            return torch.cat((x1a, x1b), dim=0)
        elif len(x1a.shape) == 2:  # case where we passed a batch of > 1 inputs
            return torch.cat((x1a, x1b), dim=1)
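
To make the forward pass concrete, here is a minimal NumPy sketch of the same idea on a single image: convolve with random filters, apply ReLU to the response and to its negation, then average over the spatial dimensions. It is an illustration under simplifying assumptions, not the notebook's model: the function name `rcf_numpy`, the seed, and the toy sizes are hypothetical, and it uses plain Python loops instead of a GPU convolution.

```python
import numpy as np


def rcf_numpy(image, weights, bias):
    """Random convolutional features for one image.

    image:   (channels, height, width)
    weights: (n_filters, channels, k, k), drawn at random
    bias:    (n_filters,)
    Returns a vector of length 2 * n_filters.
    """
    n_filters, _, k, _ = weights.shape
    _, height, width = image.shape
    out_h, out_w = height - k + 1, width - k + 1  # no padding, stride 1
    response = np.empty((n_filters, out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            patch = image[:, i:i + k, j:j + k]
            # Dot product of every filter with the current patch, plus the bias
            response[:, i, j] = np.tensordot(weights, patch, axes=3) + bias
    positive = np.maximum(response, 0.0).mean(axis=(1, 2))   # ReLU(conv), average pooled
    negative = np.maximum(-response, 0.0).mean(axis=(1, 2))  # ReLU(-conv), average pooled
    return np.concatenate([positive, negative])


rng = np.random.default_rng(0)
image = rng.random((6, 33, 34))               # 6 bands, a patch of about 33 x 34 pixels
weights = rng.normal(0.0, 1.0, (8, 6, 3, 3))  # 8 random 3x3 filters
bias = np.full(8, -1.0)
features = rcf_numpy(image, weights, bias)
print(features.shape)  # (16,)
```

Eight filters give sixteen features because each filter contributes one pooled value for the positive part and one for the negative part, which is why the class requires `num_features` to be even.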

Next, we initialize the model and PyTorch components.

In [28]:
device = torch.device("cuda")
model = RCF(num_features).eval().to(device)

### Extract features from the imagery around each point

We need to find a suitable satellite scene (Landsat by default in this notebook) for each point. As usual, we'll use `pystac-client` to search for items matching some conditions, but we don't want to make a separate `.search()` call for each of the 14,452 points, because each HTTP request is relatively slow. Instead, we will *batch* our points and search *in parallel*.

We need to be a bit careful with how we batch up our points though. Since a single scene covers many points, we want points that are spatially close together to end up in the same batch. In short, we need to spatially partition the dataset. This is implemented in `dask-geopandas`.
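
The idea of spatial batching can be sketched in pure Python: map each point to its position along a Hilbert curve, sort by that position, and cut the sorted list into batches. This is only an illustration of the ordering idea on a fixed grid. The names `hilbert_index` and `spatial_batches` are hypothetical, and this is not a description of how `dask-geopandas` computes `hilbert_distance` internally.

```python
def hilbert_index(order, x, y):
    """Distance along a Hilbert curve of cell (x, y) in a 2**order by 2**order grid."""
    n = 1 << order
    d = 0
    s = n >> 1
    while s > 0:
        rx = 1 if x & s else 0
        ry = 1 if y & s else 0
        d += s * s * ((3 * rx) ^ ry)
        # Rotate/flip the quadrant so the curve stays continuous
        if ry == 0:
            if rx == 1:
                x, y = n - 1 - x, n - 1 - y
            x, y = y, x
        s >>= 1
    return d


def spatial_batches(points, batch_size, order=10):
    """Sort (lon, lat) points along a Hilbert curve and split them into batches."""
    lons = [p[0] for p in points]
    lats = [p[1] for p in points]
    lon_min, lat_min = min(lons), min(lats)
    # Avoid division by zero when all points share a coordinate
    lon_span = (max(lons) - lon_min) or 1.0
    lat_span = (max(lats) - lat_min) or 1.0
    cells = (1 << order) - 1

    def key(p):
        gx = int((p[0] - lon_min) / lon_span * cells)
        gy = int((p[1] - lat_min) / lat_span * cells)
        return hilbert_index(order, gx, gy)

    ordered = sorted(points, key=key)
    return [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]


# Toy 4 x 4 grid of points split into 4 batches
grid = [(27.0 + 0.01 * i, -13.0 + 0.01 * j) for i in range(4) for j in range(4)]
batches = spatial_batches(grid, batch_size=4)
```

In this toy case each batch ends up holding one 2 x 2 block of neighbouring points, so a single scene search per batch would cover a compact area rather than points scattered across the grid.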

So the overall workflow will be

1. Find an appropriate STAC item for each point (in parallel, using the spatially partitioned dataset)
2. Feed the points and STAC items to a custom `Dataset` that can read imagery given a point and the STAC item of an overlapping scene
3. Use a `DataLoader`, which wraps our `Dataset`, to feed imagery to the model and save the corresponding features
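
Step 1 amounts to choosing, for each point, the least cloudy scene whose footprint covers it. A simplified sketch of that selection follows. It assumes scene footprints can be approximated by bounding boxes (the notebook uses the full scene geometries via `geopandas`), the scene records are made up, and `pick_scene` is a hypothetical helper name.

```python
def pick_scene(point, scenes):
    """Return the id of the least cloudy scene whose bbox contains the point, or None."""
    lon, lat = point
    best = None
    for scene in scenes:
        min_lon, min_lat, max_lon, max_lat = scene["bbox"]
        if min_lon <= lon <= max_lon and min_lat <= lat <= max_lat:
            if best is None or scene["cloud_cover"] < best["cloud_cover"]:
                best = scene
    return best["id"] if best else None


# Made-up scenes for illustration; real STAC items carry many more fields.
scenes = [
    {"id": "scene-a", "bbox": (27.0, -14.0, 29.0, -12.0), "cloud_cover": 12.0},
    {"id": "scene-b", "bbox": (28.0, -14.0, 30.0, -12.0), "cloud_cover": 3.5},
]
points = [(27.5, -13.0), (28.5, -13.0), (31.0, -13.0)]
matches = [pick_scene(p, scenes) for p in points]
print(matches)  # ['scene-a', 'scene-b', None]
```

Points with no covering scene get `None`, mirroring how the notebook later drops rows without a `stac_item`.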

In [29]:
NPARTITIONS = 250

ddf = dask_geopandas.from_geopandas(gdf, npartitions=1)
hd = ddf.hilbert_distance().compute()
gdf["hd"] = hd
gdf = gdf.sort_values("hd")

dgdf = dask_geopandas.from_geopandas(gdf, npartitions=NPARTITIONS, sort=False)

In [30]:
del ddf
del hd
del gdf
gc.collect()

3040014

In [31]:
%%time

start_month = 7
year_start = 2015
year_end = 2021

buffer_size = 0.005 #do we need a buffer for SEAs?
cloud_limit = 20

batch_size = 30
workers = os.cpu_count()

print("Using:  \n", 
      f"  Satellite: {satellite}  \n",
      f"  Pixel Resolution: {resolution}  \n",
      f"  Grid Resolution: {buffer_size * 2} degree squared (WGS84) \n",
      f"  Cloud Limit: less than {cloud_limit}%  \n",
      f"  Bands: {bands} \n",
      f"  Points: {pt_len} \n",
      f"  Number Features: {num_features} features \n",
      f"  Year Range: {year_start} to {year_end} \n")

for yr in range(year_start, year_end+1):
    features = pd.DataFrame()
    ft = []
    
    if (yr == year_start):
        month_range = range(start_month, 13)
    else:
        month_range = range(1, 13) 
        
    for mn in month_range:

        # Zero-pad the month so the dates built below are well-formed
        month = f"{mn:02d}"

        def query(points):
            """
            Find a STAC item for points in the `points` DataFrame

            Parameters
            ----------
            points : geopandas.GeoDataFrame
                A GeoDataFrame

            Returns
            -------
            geopandas.GeoDataFrame
                A new geopandas.GeoDataFrame with a `stac_item` column containing the STAC
                item that covers each point.
            """
            intersects = shapely.geometry.mapping(points.unary_union.convex_hull)

            catalog = pystac_client.Client.open(
                "https://planetarycomputer.microsoft.com/api/stac/v1"
            )
            # Define search date range for query
            ending_day = calendar.monthrange(yr, int(mn))[1]
            search_start = f"{yr}-{month}-01"
            search_end = f"{yr}-{month}-{ending_day}" 
            
            # The time frame in which we search for non-cloudy imagery
            search = catalog.search(
                collections=[satellite],  
                intersects=intersects,
                datetime=[search_start, search_end],
                query={"eo:cloud_cover": {"lt": cloud_limit},
                      "platform": {"in": ["landsat-8"]}},
                limit=500,
            )
            ic = search.get_all_items_as_dict()
            features = ic["features"]
            features_d = {item["id"]: item for item in features}
            data = {
                "eo:cloud_cover": [],
                "geometry": [],
            }
            index = []
            for item in features:
                data["eo:cloud_cover"].append(item["properties"]["eo:cloud_cover"])
                data["geometry"].append(shapely.geometry.shape(item["geometry"]))
                index.append(item["id"])
            items = geopandas.GeoDataFrame(data, index=index, geometry="geometry").sort_values(
                "eo:cloud_cover"
            )
            point_list = points.geometry.tolist()
            point_items = []
            for point in point_list:
                covered_by = items[items.covers(point)]
                if len(covered_by):
                    point_items.append(features_d[covered_by.index[0]])
                else:
                    # There weren't any scenes matching our conditions for this point (too cloudy)
                    point_items.append(None)
            return points.assign(stac_item=point_items)

        tic = time.time()
        print("Matching images to points for: ", mn, "-", yr, sep = "")

        with Client(n_workers=16) as client:
            meta = dgdf._meta.assign(stac_item=[])
            df2 = dgdf.map_partitions(query, meta=meta).compute()
        df3 = df2.dropna(subset=["stac_item"]).reset_index(drop = True)

        matching_items = []
        for item in df3.stac_item.tolist():
            signed_item = pc.sign(Item.from_dict(item))
            matching_items.append(signed_item)


        points = df3[["lon", "lat"]].to_numpy()
        
        print("Found acceptable images for ", 
              points.shape[0], "/", pt_len,
              " points in ", 
              f"{time.time()-tic:0.2f} seconds", 
              sep = "")


        class CustomDataset(Dataset):
            def __init__(self, points, items, buffer=buffer_size):
                self.points = points
                self.items = items
                self.buffer = buffer

            def __len__(self):
                return self.points.shape[0]

            def __getitem__(self, idx):

                lon, lat = self.points[idx]
                fn = self.items[idx]

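                if fn is None:
                    return None
                try:
                    # Lazily stack the requested bands of the scene at the target resolution
                    stack = stackstac.stack(
                        fn,
                        assets=bands,
                        resolution=resolution,
                    )
                    # Project the lon/lat box around the point into the scene's CRS
                    proj = pyproj.Proj(stack.crs)
                    x_min, y_min = proj(lon - self.buffer, lat - self.buffer)
                    x_max, y_max = proj(lon + self.buffer, lat + self.buffer)
                    # y coordinates decrease down the image, hence y_max:y_min
                    aoi = stack.loc[..., y_max:y_min, x_min:x_max]
                    data = aoi.compute(scheduler="single-threaded")
                    out_image = data.data
                    # Min-max scale the patch to [0, 1]
                    out_image = (out_image - out_image.min()) / (out_image.max() - out_image.min())
                except ValueError:
                    # Reading or cropping failed; treat the point as having no image
                    return None
                return torch.from_numpy(out_image).float()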

        dataset = CustomDataset(points, matching_items)

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            collate_fn=lambda x: x,
            pin_memory=False,
            # persistent_workers=True,
        )

        x_all = np.zeros((points.shape[0], num_features), dtype=float)
        tic = time.time()
        toc = time.time()
        i = 0
        print("Featurizing: ", month, "-", yr, sep = "")
        for images in dataloader:
            for image in images:

                if i % 1000 == 0:
                    print(
                        f"{i}/{points.shape[0]} -- {i / points.shape[0] * 100:0.2f}%"
                        + f" -- {time.time()-tic:0.2f} seconds"
                    )
                    tic = time.time()

                    # LS 8 scene size is 185 km x 180 km

                if image is not None:
                    # Each image should have dims (time, bands, height, width), so
                    # len(image.shape) == 4, with a single timestamp (image.shape[0] == 1).
                    # Ideally image.shape will be about (1, len(bands), 33, 34).
                    assert len(image.shape) == 4 and image.shape[0] == 1, image.shape
                    # A full image should be ~33x34 pixels (i.e. ~1km^2 at a 30m/px spatial
                    # resolution), however we can receive smaller images if an input point
                    # happens to be at the edge of a Landsat scene (a literal edge case). To deal
                    # with these (edge) cases we crudely drop all images where either spatial
                    # dimension is smaller than `min_image_edge` pixels.

                    # if type(image) == torch.Tensor: 
                    try:
                        if image.shape[2] >= min_image_edge and image.shape[3] >= min_image_edge:
                            image = image.to(device)
                            with torch.no_grad():
                                feats = model(image).cpu().numpy()
                            x_all[i] = feats
                        else:
                            # this happens if the point is close to the edge 
                            # of a scene (one or both of the spatial dimensions
                            # of the image are very small)
                            pass
                    except ValueError: 
                        pass 
                else:
                    pass  # this happens if we do not find a scene for some point
                i += 1
                
                
                torch.cuda.empty_cache()

        features_monthly = pd.DataFrame(x_all)
        features_monthly[["lon", "lat"]] = points.tolist()
        features_monthly['year'] = yr
        features_monthly['month'] = mn
        
        # ft.append(features_monthly)

        features_monthly.columns = features_monthly.columns.astype(str)
        
        # Save the features to a feather file
        file_name = (f'data/{satellite}_bands-{bands_short}_{country_code}_{pt_len/1000:.0f}'+
                    f'k-points_{num_features}-features_{yr}_{mn}.feather')
        
        print("Saving file as:", file_name)
        features_monthly.to_feather(file_name)
        
        # Free memory before loop iterates
        print("Freeing RAM")
        del meta
        del query
        del df2
        del df3
        del points
        del dataset
        del dataloader
        del x_all
        del features_monthly
        del CustomDataset
        gc.collect()
        print(f"Done in {(time.time()-toc)/60:0.2f} minutes")
        print('')
    # features = pd.concat(ft).reset_index(drop = True)
    
    # features.columns = features.columns.astype(str)
    
    # # Save the features to a feather file
    # file_name = (f'data/{satellite}_bands-{bands_short}_{country_code}_{pt_len/1000:.0f}'+
    #              f'k-points_{num_features}-features_{yr}.feather')
    
    # print("Saving file as:", file_name)
    # features.to_feather(file_name)
    
    # display(FileLink(file_name))
    
    print("Save finished!")
    # Free memory before loop iterates
    print("Freeing RAM")
    # del features
    # del ft
    # gc.collect()
    print('')

Using:  
   Satellite: landsat-c2-l2  
   Pixel Resolution: 30  
   Grid Resolution: 0.01 degree squared (WGS84) 
   Cloud Limit: less than 20%  
   Bands: ['SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'] 
   Points: 14452 
   Number Features: 1000 features 
   Year Range: 2015 to 2021 

Matching images to points for: 7-2015
Found acceptable images for 14452/14452 points in 58.43 seconds
Featurizing: 07-2015


TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/srv/conda/envs/notebook/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/srv/conda/envs/notebook/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<timed exec>", line 138, in __getitem__
  File "/srv/conda/envs/notebook/lib/python3.9/site-packages/stackstac/stack.py", line 287, in stack
    asset_table, spec, asset_ids, plain_items = prepare_items(
  File "/srv/conda/envs/notebook/lib/python3.9/site-packages/stackstac/prepare.py", line 333, in prepare_items
    out_bounds = geom_utils.snapped_bounds(out_bounds, out_resolutions_xy)
  File "/srv/conda/envs/notebook/lib/python3.9/site-packages/stackstac/geom_utils.py", line 72, in snapped_bounds
    minx, miny, maxx, maxy = bounds
TypeError: cannot unpack non-iterable NoneType object
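
The traceback shows `stackstac` failing while working out the output bounds for an item. The notebook does not identify the cause. One thing worth ruling out, offered here as an assumption rather than a diagnosis, is a mismatch between the requested `bands` and the asset keys actually present on the matched STAC items. A small check along these lines could be run on one of the item dictionaries in `df3.stac_item`; the helper `missing_assets` and the sample item are hypothetical.

```python
def missing_assets(item, requested):
    """Return the requested asset keys that the STAC item dictionary does not provide."""
    available = set(item.get("assets", {}))
    return [key for key in requested if key not in available]


# Made-up item for illustration; inspect a real item from the search results instead.
sample_item = {"id": "example", "assets": {"red": {}, "green": {}, "blue": {}}}
print(missing_assets(sample_item, ["red", "SR_B2"]))  # ['SR_B2']
```

If the returned list is non-empty, comparing it with the item's own `assets` keys shows which names the collection actually uses.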
