# Search STAC API
This notebook uses the **PySTAC Client** (```pystac-client```, https://pystac-client.readthedocs.io/en/latest/) Python package to find suitable Sentinel-2 MSIL2A data by using a query with a number of defined paramteres. 

The STAC API for AWS is available on the endpoint (https://stac-api-dev.terradue.com/).

# Set-up

In [1]:
import pystac
from pystac_client import Client

from shapely.geometry import box, mapping, Polygon
import shapely.wkt

import os
import numpy as np
from datetime import datetime

import matplotlib.pyplot as plt
import matplotlib.colors as colors

import xarray as xr
import stackstac

from helpers import *

In [2]:
# Access to Catalog 
URL = 'https://stac-api-dev.terradue.com/'

headers = []

cat = Client.open(URL, headers=headers)
# cat.set_conforms_to('https://api.stacspec.org/v1.0.0-rc.1/item-search')
cat

## Define params


In [56]:
# Collection 
collections=["ai-extensions-svv-dataset-labels"]

# Start and End dates of the item
start_date_str_it = '2023-08-01'; start_date_it = datetime.fromisoformat(start_date_str_it)
stop_date_str_it = '2023-08-15'; stop_date_it = datetime.fromisoformat(stop_date_str_it)

# Start and End dates of the EO data
start_date_str_eo = '2023-06-01'; start_date_eo = datetime.fromisoformat(start_date_str_eo)
stop_date_str_eo = '2023-06-15'; stop_date_eo = datetime.fromisoformat(stop_date_str_eo)

# BBOX or AOI
# bbox = [-74.9211627, 17.818496, -72.2759771, 19.654895]
aoi='POLYGON((-120.90 37.55,-120.32 37.55,-120.32 37.10,-120.90 37.10,-120.90 37.55))'; bbox=shapely.wkt.loads(aoi).bounds 

# Other metadata
cloud_cover = 3

# Define EPSG code
epsg = 'EPSG:4326'

## Query the Catalog
Query on the catalog is for general params AOI and TOI (depending which time is recorded on the `item.properties`). For other params such as cloud cover, bands etc, will be used for filtering out the items once they are extracted from the STAC API endpoint.

In [4]:
# Query by AOI, start/end date of the item
query = cat.search(
    collections=collections, 
    datetime = (start_date_it, stop_date_it), # this is likely the datetime of the item creation, not the labels or EO data within them 
    sortby="properties.datetime",
    bbox=bbox, 
)

# items_stack = query.get_all_items()
items = query.item_collection()

print(f'There are {len(items)} items in the collection')
display(items)

There are 4 items in the collection


## Show map with folium

In [5]:
mymap = showMap_BBOX(items)
mymap

## Now filter items with other params (eg cloud_cover)

In [53]:
display(items)

In [61]:
cloud_cover = 1.9

In [62]:
# Get href of the selected S2 scenes if cc lower than threshold
s2_hrefs = []
index_to_remove = []

for index, item in enumerate(items):
    item_id = item.id
    s2_href = [l.href for l in item.links if l.rel == 'source'][0]
    
    # Read STAC Item to check metadata
    s2item = pystac.read_file(s2_href)
    print(s2item.properties['eo:cloud_cover'])
    if s2item.properties['eo:cloud_cover'] <= cloud_cover: 
        s2_hrefs.append(s2_href)
    else:         
        # Save item_id to remove 
        index_to_remove.append(index)

print(index_to_remove)
display(items)
print(s2_hrefs)

0.03861
1.799635
3.357495
1.974181
[2, 3]


['https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a/items/S2A_10SFG_20230618_0_L2A', 'https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a/items/S2A_10SGG_20230618_0_L2A']


pystac.item_collection.ItemCollection

In [63]:
# Remove the item if found
if len(index_to_remove) > 0:
    items.remove(index_to_remove)
display(items)

AttributeError: 'ItemCollection' object has no attribute 'remove'

In [42]:
# Read STAC Item
s2item = pystac.read_file(s2_hrefs[0])
display(s2item.properties)

{'created': '2023-06-19T05:09:52.909Z',
 'platform': 'sentinel-2a',
 'constellation': 'sentinel-2',
 'instruments': ['msi'],
 'eo:cloud_cover': 0.03861,
 'proj:epsg': 32610,
 'mgrs:utm_zone': 10,
 'mgrs:latitude_band': 'S',
 'mgrs:grid_square': 'FG',
 'grid:code': 'MGRS-10SFG',
 'view:sun_azimuth': 132.17348750068,
 'view:sun_elevation': 70.6778731316069,
 's2:degraded_msi_data_percentage': 0.0239,
 's2:nodata_pixel_percentage': 0,
 's2:saturated_defective_pixel_percentage': 0,
 's2:dark_features_percentage': 0.024459,
 's2:cloud_shadow_percentage': 0.018825,
 's2:vegetation_percentage': 37.547791,
 's2:not_vegetated_percentage': 60.54256,
 's2:water_percentage': 1.067375,
 's2:unclassified_percentage': 0.760376,
 's2:medium_proba_clouds_percentage': 0.031788,
 's2:high_proba_clouds_percentage': 0.004439,
 's2:thin_cirrus_percentage': 0.002382,
 's2:snow_ice_percentage': 0,
 's2:product_type': 'S2MSI2A',
 's2:processing_baseline': '05.09',
 's2:product_uri': 'S2A_MSIL2A_20230618T184921

In [16]:
for item in items:
    print(items['features'])

TypeError: list indices must be integers or slices, not str

In [11]:
eo_data = [item.links['href'] for item in items if item.links['rel'] == 'source']
eo_data

TypeError: list indices must be integers or slices, not str

In [None]:
# Query by AOI, start/end date of the item
query = cat.search(
    collections=collections, 
    datetime = (start_date_it, stop_date_it), # this is likely the datetime of the item creation, not the labels or EO data within them 
    sortby="properties.datetime",
    bbox=bbox, 
    # query = {
    #     'eo:cloud_cover': {'lt': 5}, # cloud_cover less than 15 percent by default
    #     # 'sentinel:product_id': {'eq': product_id}
    # }
)

In [None]:
# Display properties of one item 
items[0]
#items[0].properties

In [None]:
# print id of the first item
index = 0
item_id = items[index].id 
print(f'id: {item_id}')
date = items[index].datetime
print(f'date: {date}')

# print available bands
print(f'bands: {list(items[index].assets.keys())}')

In [None]:
import os
from osgeo import gdal, ogr, osr
import pyproj
import numpy as np
import pandas as pd
import json

gdal.UseExceptions()

In [None]:
os.environ['PROJ_DATA']='/workspace/.conda/envs/env_labels/share/proj' 

In [None]:
import pystac
from pystac import Link, Asset
from pystac.extensions.label import LabelExtension
from pystac.extensions.label import LabelType
from pystac.extensions.label import LabelClasses
from pystac.extensions.label import LabelStatistics
from pystac.extensions.version import ItemVersionExtension

In [None]:
# Set dataframe to None 
df = None

In [None]:
# Read GeoJSON file and extract point coordinates
def read_geojson_coordinates(geojson_file):
    with open(geojson_file, 'r') as file:
        geojson_data = json.load(file)
    #for f in geojson_data['features'][:10]: print(f)
    
    points = []
    luc = []
    for feature in geojson_data['features']:
        if feature['geometry']['type'] == 'Point':
            # Add lon and lat
            lon, lat, _ = feature['geometry']['coordinates']
            points.append((lon, lat))
            
            # Add classification 
            luc.append(feature['properties']['class'])
    return points, luc

In [None]:
# Function to transform unprojected coordinates to projected coordinates
def transform_coordinates(coordinates, epsg_s, epsg_t):
    source_crs = pyproj.CRS(f'EPSG:{epsg_s}') 
    target_crs = pyproj.CRS(f'EPSG:{epsg_t}')  
    transformer = pyproj.Transformer.from_crs(source_crs, target_crs, always_xy=True)
    transformed_coords = [transformer.transform(lon, lat) for lon, lat in coordinates]
    
    transformed_coords_int = [[int(tc[0]), int(tc[1])] for tc in transformed_coords]
    return transformed_coords_int

In [None]:
# Function to extract pixel values from a GeoTIFF at given coordinates
def extract_pixel_values(b_g, transformed_coords):
    gt = b_g.GetGeoTransform()
    b_rst = b_g.GetRasterBand(1)
    
    values = []

    for lon, lat in transformed_coords:
        px = int((lon - gt[0]) / gt[1])  # Convert longitude to pixel x
        py = int((lat - gt[3]) / gt[5])  # Convert latitude to pixel y

        value = b_rst.ReadAsArray(px, py, 1, 1)[0][0]
        values.append(value)
    
    # Empty raster 
    b_rst = None
    
    return values

## Read Label STAC Item

In [None]:
# Define name of the ML item
item_label_fname = "item-label-train.json"

In [None]:
item = pystac.read_file(item_label_fname)
display(item.properties)

## Load S2 scene

In [None]:
s2_href = [l.target for l in item.links if l.rel == 'source'][0]
print('href of the S2 scene:', s2_href)

In [None]:
# Read STAC Item
s2item = pystac.read_file(s2_href)
display(s2item.properties)

In [None]:
epsg_t = s2item.properties['proj:epsg']
print(f'- Target EPSG:{epsg_t}')
print(f'- Available bands: {list(s2item.assets.keys())}')

## Load geojson points

In [None]:
geojson_href = item.assets['labels'].href
print('href of the geojson file:', geojson_href)

In [None]:
# Open asset (geojson format) and read all the coordinates within 
coordinates, luc = read_geojson_coordinates(geojson_href)
coordinates[:5], luc[:5]

In [None]:
# Transofrm coordinates
epsg_s = '4326'
transformed_coords = transform_coordinates(coordinates, epsg_s, epsg_t)
transformed_coords[:5]

## Extract values of selected band(s) for each point in the geojson file


In [None]:
# Define dictionary of other Common Band Names, for those that are missing in the metadata
other_cbn = {
    'B05': 'rededge70', 
    'B06': 'rededge74', 
    'B07': 'rededge78', 
    'B8A': 'nir08', 
    'B09': 'nir09'
}

In [None]:
for band in list(s2item.assets.keys()): 
    b_metadata = s2item.assets[band].to_dict()
    if 'eo:bands' in b_metadata.keys() and len(b_metadata['eo:bands']) == 1:
        if 'common_name' in b_metadata['eo:bands'][0].keys(): 
            cbn = b_metadata['eo:bands'][0]['common_name']
        else: 
            # cbn does not exist in metadata - use dictionary of other_cbn
            cbn = other_cbn[b_metadata['eo:bands'][0]['name']]
    else: 
        print(f'{band} is not eo band, skipping.')
        continue

    if (df is not None) and (cbn in df.columns): 
        print(f'Band {cbn} exists already in the dataframe, skipping.')
        continue
    
    print('Band:', band)
    print(f'- Common Band Name: {cbn}')
    print(f'- Res: {b_metadata["gsd"]}m')
    print(f'- Center Wavelenght: {b_metadata["eo:bands"][0]["center_wavelength"]}')

    # Extract band
    b_href = s2item.assets[band].href
    print('- href:', b_href)

    # Get gdal object
    b_g = gdal.Open(b_href)
    
    # Extract pixel values
    pixel_values = extract_pixel_values(b_g, transformed_coords)
    # for (lon0, lat0), (lon1, lat1), value, lu in zip(coordinates[:5], transformed_coords[:5], pixel_values[:5], luc[:5]):
    #     print(f"Coords (Unprj): {np.round(lon0,3)}, {np.round(lat0,3)} - Coords (Prj): {lon1}, {lat1} - Pixel Value: {value} - LUC: {lu}")
    
    # Empty b_g
    b_g = None
    
    # Make or Append to Pandas dataframe
    data = {
        'long': [x[0] for x in transformed_coords], 
        'lat': [x[1] for x in transformed_coords], 
        'LUC': luc,
        cbn: pixel_values,
    }

    if df is None: 
        print('Creating Dataframe')
        # Create a DataFrame from the dictionary
        df = pd.DataFrame(data)
        df.index.name = 'Index'

    else: 
        print('Adding to existing Dataframe')

        # Create temp dataframe
        df2 = pd.DataFrame(data)
        df2.index.name = 'Index'

        # Assert the two dataframes have the same long and lat values
        assert df['long'].isin(df2['long']).value_counts().values[0] == len(pixel_values)
        assert df['lat'].isin(df2['lat']).value_counts().values[0] == len(pixel_values)

        # Merge temp dataframe with original dataframe, based on matching columns
        df = pd.merge(df, df2, on=['Index', 'long', 'lat', 'LUC'])   
        # Empty memory
        df2 = None

    display(df)
    print()

print('\n--- Complete Dataframe with all Sentinel-2 bands ---')
display(df)

## Add NDVI and NDWI bands

In [None]:
# Add NDVI
assert 'nir' in df.columns and 'red' in df.columns
df['ndvi'] = ((df['nir'].astype(int) - df['red'].astype(int)) / (df['nir'].astype(int) + df['red'].astype(int)) * 10000).astype(int)
df

In [None]:
# Add NDWI
# Formula for S2: NDWI = (NIR - MIR) / (NIR + MIR) using Sentinel-2 Band 8 (NIR) and Band 12 (MIR=SWIR for Sentinel2)
assert 'nir' in df.columns and 'swir16' in df.columns and 'swir22' in df.columns
df['ndwi1'] = ((df['nir'].astype(int) - df['swir16'].astype(int)) / (df['nir'].astype(int) + df['swir16'].astype(int)) * 10000).astype(int)
df['ndwi2'] = ((df['nir'].astype(int) - df['swir22'].astype(int)) / (df['nir'].astype(int) + df['swir22'].astype(int)) * 10000).astype(int)
df

## Add Water Label

In [None]:
# Add Water label from the LUC value (LUC=6 is water)
df.loc[df['LUC'] == 6, 'water'] = 1
df.loc[df['LUC'] != 6, 'water'] = 0
df

## Show some statistics

In [None]:
df.describe().T

## Export Dataframe to CSV

In [None]:
# Export dataframe 
df.to_csv('df_extractedpixels.csv')

In [None]:
print('END')

**Note**: If the ```epsg``` and ```resolution``` are not defined in all Items/Assets, they must be explicitly defined in the ```stackstac.stack()``` call. The ```resolution``` refers to the output resolution and must be set in the same unit as the ```epsg``` field ([stackstac documentation](https://stackstac.readthedocs.io/en/latest/api/main/stackstac.stack.html)).