In [1]:
from pystac_client import Client
from utils import (
    read_url,
    UserSettings,
    coords_to_image,
    process_row,
    get_asset_by_common_name,
    convert_coordinates,
)
from pystac import Asset, read_file
from urllib.parse import urlparse
import geopandas as gpd
import pandas as pd
import boto3, botocore
import io
import os
from pystac.stac_io import DefaultStacIO, StacIO
import rasterio
from shapely.wkt import loads
from pystac import Item
import pyproj
import numpy as np
from loguru import logger

# os.environ["GDAL_DATA"] = "/opt/conda/envs/env_label/share/gdal"
# os.environ["PROJ_LIB"] = "/opt/conda/envs/env_label/share/proj"

In [2]:
# STAC API that is queried for the label items.
stac_endpoint = "https://stac-api-dev.terradue.com/"

# No extra HTTP headers are sent. Client.open takes the headers as a mapping of
# header name to value, so the empty value is a dict rather than a list.
headers = {}

cat = Client.open(stac_endpoint, headers=headers, ignore_conformance=True)
cat



In [3]:
# Collection whose items are searched for.
collections = ["ai-extensions-svv-dataset-labels"]

# Search the opened catalogue, restricted to that collection.
query = cat.search(collections=collections)

In [4]:
# List the "labels" asset of every item returned by the search.
[item.get_assets()["labels"] for item in query.item_collection()]

[<Asset href=s3://argo-wfs/svv-dataset/S2B_10SFH_20230613_0_L2A/label-S2B_10SFH_20230613_0_L2A.geojson>,
 <Asset href=s3://argo-wfs/svv-dataset/S2B_10SFG_20230613_0_L2A/label-S2B_10SFG_20230613_0_L2A.geojson>,
 <Asset href=s3://argo-wfs/svv-dataset/S2A_10SFH_20230618_0_L2A/label-S2A_10SFH_20230618_0_L2A.geojson>,
 <Asset href=s3://argo-wfs/svv-dataset/S2A_11SKB_20230618_0_L2A/label-S2A_11SKB_20230618_0_L2A.geojson>,
 <Asset href=s3://argo-wfs/svv-dataset/S2A_10SGG_20230618_0_L2A/label-S2A_10SGG_20230618_0_L2A.geojson>,
 <Asset href=s3://argo-wfs/svv-dataset/S2A_10SFG_20230618_0_L2A/label-S2A_10SFG_20230618_0_L2A.geojson>]

In [5]:
settings = UserSettings("usersettings.json")

# Configure S3 access for the location of the first item's "labels" asset.
# NOTE(review): set_s3_environment presumably exports the AWS_* environment
# variables that are read further down - confirm against utils.UserSettings.
settings.set_s3_environment(
    query.item_collection()[0].get_assets()["labels"].get_absolute_href()
)

# Only report whether the access key is available. Printing its value would
# store the credential in the notebook output and leak it when the file is shared.
print("AWS_ACCESS_KEY_ID set:", "AWS_ACCESS_KEY_ID" in os.environ)

[REDACTED]


In [6]:
# Make pystac's DefaultStacIO the default StacIO implementation.
StacIO.set_default(DefaultStacIO)

In [7]:
# Keep the first item of the search result and display it.
label_item = query.item_collection()[0]
label_item

In [8]:
# Show the assets attached to the selected label item.
label_item.get_assets()

{'labels': <Asset href=s3://argo-wfs/svv-dataset/S2B_10SFH_20230613_0_L2A/label-S2B_10SFH_20230613_0_L2A.geojson>}

In [9]:
# create normalized difference function
def nd(a, b):
    """Return the normalized difference ``(a - b) / (a + b)``.

    The operands only need to support ``-``, ``+`` and ``/``, so scalars and
    element-wise containers are handled the same way. No guard is applied for
    ``a + b == 0``; the result is whatever the operand type yields for that
    division.
    """
    difference = a - b
    total = a + b
    return difference / total


def read_geojson(
    label_item: Item, user_settings="usersettings.json", asset_key="labels"
):
    """Fetch an asset of ``label_item`` from S3 and load it with geopandas.

    Args:
        label_item: STAC item that holds the asset to read.
        user_settings: path handed to ``UserSettings``.
        asset_key: key of the asset in ``label_item.get_assets()``.

    Returns:
        The result of ``gpd.read_file`` on the downloaded object's bytes.
    """
    # NOTE(review): set_s3_environment presumably exports the AWS_* variables
    # read just below - confirm against utils.UserSettings.
    UserSettings(user_settings).set_s3_environment(
        label_item.get_assets()[asset_key].get_absolute_href()
    )

    # The connection parameters all come from the process environment.
    client_options = {
        "service_name": "s3",
        "region_name": os.environ.get("AWS_REGION"),
        "use_ssl": True,
        "endpoint_url": os.environ.get("AWS_S3_ENDPOINT"),
        "aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"),
        "aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
    }
    s3_client = botocore.session.Session().create_client(**client_options)

    # Split the asset href into bucket (network location) and key (path
    # without its leading slash).
    location = urlparse(label_item.get_assets()[asset_key].get_absolute_href())
    response = s3_client.get_object(Bucket=location.netloc, Key=location.path[1:])

    payload = io.BytesIO(response["Body"].read())
    return gpd.read_file(payload)


def sample_data(label_item, source_item=None, common_bands=["red", "nir"]):
    """Sample raster bands at the points of a label item.

    The label features are read with ``read_geojson``. Each point is converted
    from EPSG:4326 to the CRS of the first requested band, and every requested
    band is sampled at the converted coordinates. Normalized-difference
    columns are added when their input bands were requested.

    Args:
        label_item: STAC item whose "labels" asset is read by ``read_geojson``.
            When ``source_item`` is not given it must also carry a link with
            rel "source".
        source_item: STAC item holding the band assets. When ``None`` it is
            read from the first "source" link of ``label_item``.
        common_bands: common band names to sample. The CRS is taken from the
            first one. The default list is never mutated.

    Returns:
        The label frame with the added columns ``utm_x``, ``utm_y``, one column
        per sampled band, and ``ndvi`` / ``ndwi1`` / ``ndwi2`` when the
        required bands are present.

    Raises:
        IndexError: if no "source" link exists while ``source_item`` is
            ``None``, or if ``common_bands`` is empty.
    """
    gdf = read_geojson(label_item)

    if source_item is None:
        source_links = [
            link.target for link in label_item.get_links() if link.rel in ["source"]
        ]
        source_item = read_file(source_links[0])

    def convert_row(row, target_crs):
        """Add the point's coordinates in ``target_crs`` as utm_x / utm_y."""
        # The label geometries are taken to be lon/lat points in EPSG:4326.
        src_crs = "EPSG:4326"
        longitude = row.geometry.x
        latitude = row.geometry.y

        row["utm_x"], row["utm_y"] = convert_coordinates(
            src_crs, target_crs, longitude, latitude
        )

        return pd.Series(row)

    # Divisor applied to the raw sampled values.
    # NOTE(review): presumably the reflectance quantification value of the
    # source product - confirm.
    scale = 10000

    dataset = {}
    try:
        for common_band in common_bands:
            logger.info(f"Reading {common_band} band")
            dataset[common_band] = rasterio.open(
                get_asset_by_common_name(source_item, common_band).get_absolute_href()
            )

        # All bands are sampled with coordinates expressed in the CRS of the
        # first band.
        crs_info = dataset[common_bands[0]].crs
        target_crs = f"EPSG:{crs_info.to_epsg()}"
        gdf = gdf.apply(convert_row, target_crs=target_crs, axis=1)

        points_utm = list(zip(gdf["utm_x"], gdf["utm_y"]))

        for common_band in common_bands:
            logger.info(f"Sampling {common_band} band")
            # sample() is consumed here, while the dataset is still open.
            gdf[common_band] = [
                val[0] / scale for val in dataset[common_band].sample(points_utm, 1)
            ]
    finally:
        # Release every raster handle that was opened, even when reading or
        # sampling failed part-way; the original code never closed them.
        for opened in dataset.values():
            opened.close()

    if "red" in common_bands and "nir" in common_bands:
        gdf["ndvi"] = nd(gdf["nir"], gdf["red"])
    if "green" in common_bands and "nir" in common_bands:
        gdf["ndwi1"] = nd(gdf["green"], gdf["nir"])
    if "nir" in common_bands and "swir16" in common_bands:
        gdf["ndwi2"] = nd(gdf["nir"], gdf["swir16"])

    return gdf

In [10]:
# Common band names sampled for every label item.
bands_to_sample = [
    "coastal",
    "red",
    "green",
    "blue",
    "nir",
    "nir08",
    "nir09",
    "swir16",
    "swir22",
]

# One sampled frame per label item returned by the search.
tmp_gdfs = []
for label_item in query.item_collection():
    sampled_data = sample_data(label_item=label_item, common_bands=bands_to_sample)
    tmp_gdfs.append(sampled_data)

# Stack the per-item frames into a single table; row indices are kept as-is.
gdf = pd.concat(tmp_gdfs)


[32m2023-08-04 05:55:51.347[0m | [1mINFO    [0m | [36m__main__[0m:[36msample_data[0m:[36m45[0m - [1mReading coastal band[0m
[32m2023-08-04 05:55:53.840[0m | [1mINFO    [0m | [36m__main__[0m:[36msample_data[0m:[36m45[0m - [1mReading red band[0m
[32m2023-08-04 05:55:55.560[0m | [1mINFO    [0m | [36m__main__[0m:[36msample_data[0m:[36m45[0m - [1mReading green band[0m
[32m2023-08-04 05:55:57.270[0m | [1mINFO    [0m | [36m__main__[0m:[36msample_data[0m:[36m45[0m - [1mReading blue band[0m
[32m2023-08-04 05:55:59.067[0m | [1mINFO    [0m | [36m__main__[0m:[36msample_data[0m:[36m45[0m - [1mReading nir band[0m
[32m2023-08-04 05:56:00.814[0m | [1mINFO    [0m | [36m__main__[0m:[36msample_data[0m:[36m45[0m - [1mReading nir08 band[0m
[32m2023-08-04 05:56:02.525[0m | [1mINFO    [0m | [36m__main__[0m:[36msample_data[0m:[36m45[0m - [1mReading nir09 band[0m
[32m2023-08-04 05:56:04.242[0m | [1mINFO    [0m | [36m__main

In [13]:
# Summary statistics of the numeric columns of the sampled dataset.
gdf.describe()

Unnamed: 0,utm_x,utm_y,coastal,red,green,blue,nir,nir08,nir09,swir16,swir22,ndvi,ndwi1,ndwi2
count,1534.0,1534.0,1534.0,1534.0,1534.0,1534.0,1534.0,1534.0,1534.0,1534.0,1534.0,1534.0,1534.0,1534.0
mean,606369.687313,4187560.0,0.064559,0.109282,0.099238,0.077235,0.194373,0.206784,0.208543,0.193951,0.139377,0.19784,-0.219207,0.010196
std,154536.548895,53664.58,0.050475,0.06683,0.055794,0.055445,0.114833,0.121875,0.117879,0.119755,0.086418,0.335595,0.369345,0.186601
min,200009.978139,4090250.0,0.0002,0.0036,0.0088,0.0001,0.0001,0.0001,0.0001,0.003,0.0031,-0.973333,-0.845459,-0.987097
25%,615859.964959,4148850.0,0.0362,0.062425,0.0664,0.04505,0.0804,0.082325,0.090025,0.079525,0.064825,-0.068093,-0.5052,-0.109481
50%,655189.998843,4186600.0,0.0527,0.09655,0.0892,0.067,0.2165,0.2348,0.24045,0.19295,0.12945,0.197909,-0.34964,-0.00643
75%,689600.010583,4223425.0,0.07445,0.146,0.11455,0.091375,0.2734,0.294075,0.2904,0.296575,0.20465,0.431809,0.108895,0.0836
max,753550.02967,4299910.0,0.708,0.7424,0.7328,0.7076,0.7296,0.7354,0.6875,0.6179,0.491,0.926432,0.991111,0.859485


In [12]:
# Save the sampled dataframe as a pickle file under a relative path.
gdf.to_pickle('sprint-0-STAC-labels-to-dataframe.pkl')