In [1]:
from copy import copy
from pathlib import Path
from functools import partial
import multiprocessing
import json

import pandas as pd
import numpy as np
from scipy.special import expit as sigmoid
import matplotlib.pyplot as plt

from tqdm import tqdm

import pystac_client
import planetary_computer
from geopy import distance
from odc import stac
import rioxarray

In [2]:
def create_date_range(date: str, days_before: int, days_after: int):
    """Build a ``start/end`` date-range string around ``date``.

    Args:
        date: Anything ``pd.to_datetime`` can parse.
        days_before: Number of days subtracted from ``date`` for the start.
        days_after: Number of days added to ``date`` for the end.

    Returns:
        ``"YYYY-MM-DD/YYYY-MM-DD"`` with the start date first.
    """
    anchor = pd.to_datetime(date)
    window_start = anchor - pd.Timedelta(f'{days_before}D')
    window_end = anchor + pd.Timedelta(f'{days_after}D')
    return "/".join(
        stamp.strftime('%Y-%m-%d') for stamp in (window_start, window_end)
    )

In [3]:
def get_bounding_box(center: tuple, dist: int = 2560):
    """Return a bounding box centred on ``center``.

    Args:
        center: ``(latitude, longitude)`` of the box centre.
        dist: Full side length in meters; half of it (floor division) is
            travelled from the centre towards each cardinal direction.

    Returns:
        ``[min_long, min_lat, max_long, max_lat]``.
    """
    half_side = distance.distance(meters=dist // 2)

    def _travel(bearing):
        # Point reached after moving ``half_side`` from the centre.
        return half_side.destination(center, bearing=bearing)

    south, west, north, east = (_travel(deg) for deg in (180, 270, 0, 90))

    # Index 0 of a destination is its latitude, index 1 its longitude.
    return [west[1], south[0], east[1], north[0]]

In [4]:
def get_elevation_href(catalog, search_bbox, *args, **kwargs):
    """Return the ``data`` asset href of the first ``cop-dem-glo-30`` item.

    Args:
        catalog: Object exposing ``search(collections=..., bbox=...)``.
        search_bbox: Bounding box forwarded to the search.
        *args, **kwargs: Accepted but unused.

    Raises:
        StopIteration: If the search yields no items.
    """
    dem_search = catalog.search(collections=["cop-dem-glo-30"], bbox=search_bbox)
    first_tile = next(dem_search.items())
    return first_tile.assets["data"].href

In [5]:
def float_timedelta(max_date, datetime):
    """Return the time elapsed from ``datetime`` to ``max_date`` in days.

    Both arguments are parsed with ``pd.to_datetime(..., utc=True)``:
    tz-aware values are converted to UTC and tz-naive values are treated as
    UTC, so aware and naive inputs can be mixed without a ``TypeError``.

    Args:
        max_date: End of the window (string or datetime-like).
        datetime: Moment to measure from (string or datetime-like).

    Returns:
        Fractional number of days; negative when ``datetime`` is later than
        ``max_date``.
    """
    window_end = pd.to_datetime(max_date, utc=True)
    moment = pd.to_datetime(datetime, utc=True)
    # Dividing by a one-day Timedelta yields a plain float number of days.
    return (window_end - moment) / pd.to_timedelta(1, unit='D')

def scoring_fun(ccov, ftd, hour, alpha=1.2, beta=1.5, gamma=2.0, sigma=50):
    """Penalty score for a candidate capture (``select_best_item`` keeps the lowest).

    The score is the sum of two terms:

    * ``expm1(ccov ** alpha) / 1.75`` multiplied by ``ftd`` clipped from
      below at ``beta``;
    * ``gamma`` times a sigmoid (steepness ``sigma``) of
      ``cos(pi * (hour - 1.5) / 12) - 0.6``, which approaches ``gamma`` for
      hours close to 1.5 and 0 for hours far from it.

    Args:
        ccov: Cloud cover as a fraction.
        ftd: Age of the capture in days.
        hour: Hour of the capture.
        alpha, beta, gamma, sigma: Shape parameters described above.
    """
    cloud_term = np.expm1(ccov ** alpha) / 1.75
    age_term = np.clip(ftd, beta, None)

    hour_phase = np.cos(np.pi * (hour - 1.5) / 12) - 0.6
    hour_term = gamma * sigmoid(sigma * hour_phase)

    return cloud_term * age_term + hour_term

def select_best_item(items, date_range):
    """Pick the item with the lowest ``scoring_fun`` score.

    Args:
        items: Non-empty, indexable collection of items exposing
            ``datetime`` and ``properties["eo:cloud_cover"]`` (in percent).
        date_range: ``"start/end"`` string; the part after the last ``/`` is
            the reference date used to compute each item's age.

    Returns:
        The first item reaching the minimum score. If no score is strictly
        below infinity, ``items[0]`` is returned.

    Raises:
        ValueError: If ``items`` is empty.
    """
    if not items:
        raise ValueError("Not enough items!")

    window_end = date_range.split('/')[-1]

    def _penalty(candidate):
        hour = pd.to_datetime(candidate.datetime).hour
        age_days = float_timedelta(window_end, candidate.datetime)
        cloud_fraction = candidate.properties["eo:cloud_cover"] / 100
        value = scoring_fun(cloud_fraction, age_days, hour)
        # Scores that are not strictly below inf (inf itself, NaN) never win.
        return value if value < np.inf else np.inf

    # ``min`` keeps the earliest of equally scored items; the default is
    # never used for a non-empty collection.
    return min(items, key=_penalty, default=items[0])

In [6]:
def get_landsat_data(catalog, search_bbox, date_range, *args, **kwargs):
    """Fetch one Landsat scene for a bbox/date window and derive features.

    Searches ``landsat-c2-l2`` for items whose ``platform`` is
    ``>= 'landsat-8'``, keeps the item chosen by ``select_best_item`` and
    loads six bands cropped to ``search_bbox``.

    Args:
        catalog: STAC client exposing ``search``.
        search_bbox: Bounding box used for both the search and the crop.
        date_range: ``"start/end"`` string used for the search and scoring.
        *args, **kwargs: Accepted but unused.

    Returns:
        ``(capture, metadata)`` where ``capture`` is a depth-stacked array
        (R, G, B, masked QA bits, NDVI, NDWI, EVI) and ``metadata`` is a dict
        with ``l_``-prefixed keys describing the selected scene.

    Raises:
        ValueError: Propagated from ``select_best_item`` when nothing matches.
    """
    search = catalog.search(
        collections=["landsat-c2-l2"],
        bbox=search_bbox,
        datetime=date_range,
        filter={
            "op": "gte",
            "args": [{"property": "platform"}, 'landsat-8'],
        },
    )
    scene = select_best_item(search.get_all_items(), date_range)

    band_names = ['red', 'green', 'blue', 'nir08', 'qa_aerosol', 'swir16']
    pixels = stac.stac_load(
        [scene], bands=band_names, bbox=search_bbox
    ).isel(time=0)

    def _as_float(band):
        return pixels[band].astype("float")

    # RGB moved to channel-last order and scaled by the uint16 maximum.
    rgb_stack = pixels[['red', 'green', 'blue']].to_array().to_numpy()
    img = rgb_stack.astype(float).transpose(1, 2, 0) / np.iinfo(np.uint16).max

    # Keep only bits 6 and 7 of the qa_aerosol band.
    # NOTE(review): the variable is called "clouds", but these bits look like
    # the aerosol-level field of the Landsat QA band -- confirm the intent.
    aerosol_qa = pixels[['qa_aerosol']].to_array().to_numpy()
    high_bits = (1 << 7) | (1 << 6)
    clouds = aerosol_qa[0] & high_bits

    red, blue = _as_float("red"), _as_float("blue")
    nir, swir = _as_float("nir08"), _as_float("swir16")

    ndvi = (nir - red) / (nir + red)
    ndwi = (nir - swir) / (nir + swir)
    evi = (nir - red) / (nir + 6 * red - 7.5 * blue + 1)

    props = scene.properties
    metadata = {
        'l_platform': props['platform'],
        'l_cloud_cover': props['eo:cloud_cover'],
        'l_sun_azimuth': props['view:sun_azimuth'],
        'l_sun_elevation': props['view:sun_elevation'],
    }
    captured_at = pd.to_datetime(scene.datetime)
    metadata['l_cap_date'] = captured_at.strftime('%Y-%m-%d')
    metadata['l_cap_time'] = captured_at.strftime('%H:%M')

    # 7 channels: R, G, B, C, V, W, E
    capture = np.dstack([img, clouds, ndvi, ndwi, evi])

    return capture, metadata

In [7]:
def get_sentinel_data(catalog, search_bbox, date_range, *args, **kwargs):
    """Fetch one Sentinel-2 scene for a bbox/date window and derive features.

    Searches ``sentinel-2-l2a``, keeps the item chosen by
    ``select_best_item`` and loads six bands cropped to ``search_bbox``.

    Args:
        catalog: STAC client exposing ``search``.
        search_bbox: Bounding box used for both the search and the crop.
        date_range: ``"start/end"`` string used for the search and scoring.
        *args, **kwargs: Accepted but unused.

    Returns:
        ``(capture, metadata)`` where ``capture`` is a depth-stacked array
        (R, G, B, SCL, NDVI, NDWI, EVI) and ``metadata`` is a dict with
        ``s_``-prefixed keys describing the selected scene.

    Raises:
        ValueError: Propagated from ``select_best_item`` when nothing matches.
    """
    search = catalog.search(
        collections=["sentinel-2-l2a"],
        bbox=search_bbox,
        datetime=date_range,
    )
    scene = select_best_item(search.get_all_items(), date_range)

    band_names = ['B02', 'B03', 'B04', 'B08', 'SCL', 'B11']
    pixels = stac.stac_load(
        [scene], bands=band_names, bbox=search_bbox
    ).isel(time=0)

    def _as_float(band):
        return pixels[band].astype("float")

    # B04/B03/B02 (red, green, blue) in channel-last order, scaled by the
    # uint16 maximum.
    rgb_stack = pixels[['B04', 'B03', 'B02']].to_array().to_numpy()
    img = rgb_stack.astype(float).transpose(1, 2, 0) / np.iinfo(np.uint16).max

    scl = pixels['SCL']

    red, blue = _as_float("B04"), _as_float("B02")
    nir, swir = _as_float("B08"), _as_float("B11")

    ndvi = (nir - red) / (nir + red)
    ndwi = (nir - swir) / (nir + swir)
    evi = (nir - red) / (nir + 6 * red - 7.5 * blue + 1)

    props = scene.properties
    metadata = {
        's_platform': props['platform'],
        's_cloud_cover': props['eo:cloud_cover'],
        's_sun_azimuth': props['s2:mean_solar_azimuth'],
        # Elevation is stored as the complement of the mean solar zenith.
        's_sun_elevation': 90 - props['s2:mean_solar_zenith'],
    }
    captured_at = pd.to_datetime(scene.datetime)
    metadata['s_cap_date'] = captured_at.strftime('%Y-%m-%d')
    metadata['s_cap_time'] = captured_at.strftime('%H:%M')

    capture = np.dstack([img, scl, ndvi, ndwi, evi])

    return capture, metadata

In [8]:
# Root folder for the input metadata and every downloaded artefact.
data_path = Path("../data")

# One row per sample; later cells read its uid, date, latitude and longitude
# columns.
df = pd.read_csv(data_path.joinpath("metadata.csv"))

# Planetary Computer STAC endpoint; the modifier signs returned items in place.
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
)

In [9]:
# Search window: the 14 days leading up to (and including) each sample date.
df["date_range_-14+0"] = df["date"].apply(
    lambda sample_date: create_date_range(sample_date, days_before=14, days_after=0)
)

# Bounding box (default size) centred on each sample's coordinates.
df["bbox"] = df[["latitude", "longitude"]].apply(
    lambda coords: get_bounding_box(coords),
    axis=1,
)

In [10]:
# Output locations: image arrays and per-sample metadata live side by side
# under the data folder.
save_path = data_path / "more_arrays"
save_path_meta = data_path / "more_metadata"

save_path.mkdir(parents=True, exist_ok=True)
save_path_meta.mkdir(parents=True, exist_ok=True)

# Fetchers with the shared catalog pre-bound; callers only pass
# ``search_bbox`` and ``date_range``.
get_l = partial(get_landsat_data, catalog=catalog)
get_s = partial(get_sentinel_data, catalog=catalog)

def _fetch_and_store(fetch, target, fallback_meta, search_bbox, date_range):
    """Fetch one capture, save it to ``target`` and return its metadata.

    Best-effort: a copy of ``fallback_meta`` is returned when ``target``
    already exists or when fetching/saving raises an ``Exception``.
    """
    try:
        if target.exists():
            # Already on disk from an earlier run: skip the download. The
            # scene metadata is not recomputed in that case.
            return copy(fallback_meta)
        capture, meta = fetch(search_bbox=search_bbox, date_range=date_range)
        np.savez_compressed(target, caption=capture)
        return meta
    except Exception:
        # Narrower than a bare ``except`` on purpose: KeyboardInterrupt and
        # SystemExit must still be able to stop a worker.
        return copy(fallback_meta)


def download_single_row(row):
    """Download the Landsat and Sentinel captures for one metadata row.

    Each capture is saved as ``<uid>_landsat.npz`` / ``<uid>_sentinel.npz``
    under ``save_path``. A capture whose file already exists, or whose
    download fails, contributes only ``None`` values to the metadata.

    Args:
        row: Mapping with ``'uid'``, ``'date_range_-14+0'`` and ``'bbox'``.

    Returns:
        The metadata dict that was written to
        ``<uid>_metadata.json`` under ``save_path_meta``, or ``None`` when
        that file already existed (it is never overwritten).
    """
    meta_l_fail = dict.fromkeys((
        'l_platform',
        'l_cloud_cover',
        'l_sun_azimuth',
        'l_sun_elevation',
        'l_cap_date',
        'l_cap_time',
    ))
    meta_s_fail = dict.fromkeys((
        's_platform',
        's_cloud_cover',
        's_sun_azimuth',
        's_sun_elevation',
        's_cap_date',
        's_cap_time',
    ))

    date_range = row['date_range_-14+0']
    search_bbox = row['bbox']
    uid = row['uid']

    meta_l = _fetch_and_store(
        get_l, save_path / (uid + '_landsat.npz'), meta_l_fail,
        search_bbox, date_range,
    )
    meta_s = _fetch_and_store(
        get_s, save_path / (uid + '_sentinel.npz'), meta_s_fail,
        search_bbox, date_range,
    )

    metadata_fname = save_path_meta / (uid + '_metadata.json')
    if metadata_fname.exists():
        return None

    metadata = {'uid': uid, **meta_l, **meta_s}
    # Context manager guarantees the handle is flushed and closed.
    with open(metadata_fname, 'w') as handle:
        json.dump(metadata, handle)
    return metadata

In [None]:
# Fan the rows out over 6 worker processes. ``imap`` yields results in
# submission order, so ``results`` lines up with the rows of ``df``.
with multiprocessing.Pool(processes=6) as pool:
    pending = pool.imap(
        download_single_row,
        (row for _, row in df.iterrows()),
    )
    results = list(tqdm(pending, total=len(df)))

  1%|▍                                                                            | 122/23570 [01:36<4:52:21,  1.34it/s]