# Create grid and use batch to request data

In [1]:
from collections import defaultdict

import geopandas as gpd
from oauthlib.oauth2 import BackendApplicationClient
from requests_oauthlib import OAuth2Session

from eogrow.utils.fs import LocalFile
from eolearn.core.utils.fs import get_filesystem
from sentinelhub import CRS, BBox, DataCollection, MimeType, SentinelHubRequest, SHConfig, UtmGridSplitter

In [2]:
import os

# Select the "batch" AWS profile for everything run from this notebook
os.environ.update(AWS_PROFILE="batch")

In [4]:
# Pixel size and tile edge length, both in meters
RESOLUTION = 10.0
BBOX_SIZE = 5000

# Local name of the grid GeoPackage and the S3 prefix it is copied to
GRID_FILENAME = "grid_slovenia_border_UTM33N_500x500.gpkg"
BUCKET_PATH = "s3://batch-slo-24/input-data/"

In [None]:
# Read the area-of-interest geometry from the cache/ folder and draw its outline
aoi = gpd.read_file("cache/area_slovenia_border_UTM33N_BatchAreaManager_0.2_0.004.gpkg")
aoi.boundary.plot(linewidth=2, color="red")

In [38]:
# Split the first AOI geometry, reprojected to WGS84, into boxes of BBOX_SIZE
aoi_wgs84 = aoi.to_crs("EPSG:4326")
splitter = UtmGridSplitter([aoi_wgs84.geometry.iloc[0]], CRS.WGS84, BBOX_SIZE)
# Buffer option left disabled: buffer=(BBOX_OVERLAP / (BBOX_SIZE))
bboxes = splitter.get_bbox_list()

In [39]:
# Split the first AOI geometry in its own CRS; this rebinds splitter and bboxes
first_shape = aoi.geometry.iloc[0]
splitter = UtmGridSplitter([first_shape], aoi.crs, BBOX_SIZE)
# Buffer option left disabled: buffer=(BBOX_OVERLAP / (BBOX_SIZE))
bboxes = splitter.get_bbox_list()

In [None]:
# Group the bounding boxes by the EPSG code of their CRS
bboxes_dict = defaultdict(list)
for tile_bbox in bboxes:
    bboxes_dict[tile_bbox.crs.epsg].append(tile_bbox)

# Report how many boxes fall into each CRS
for epsg_code, zone_bboxes in bboxes_dict.items():
    print(f"EPSG:{epsg_code} - {len(zone_bboxes)}")

In [44]:
# Build one GeoDataFrame per CRS group. The ids keep counting across groups so
# that they stay unique over the whole grid.
grid_list = []
n_total = 0

# Tile edge length in pixels. BBOX_SIZE // RESOLUTION is a float because
# RESOLUTION is a float, so it is cast: a pixel count is a whole number.
tile_size_px = int(BBOX_SIZE // RESOLUTION)

# The loop variable is zone_bboxes (not bboxes) so the full bbox list built by
# the splitter is not overwritten by the last group.
for crs, zone_bboxes in bboxes_dict.items():
    utm_grid = gpd.GeoDataFrame(geometry=[bbox.geometry for bbox in zone_bboxes], crs=crs)
    utm_grid["id"] = n_total + utm_grid.index
    utm_grid["identifier"] = utm_grid["id"].apply(lambda x: f"tile_{x}")
    utm_grid["width"] = tile_size_px
    utm_grid["height"] = tile_size_px
    utm_grid["resolution"] = RESOLUTION

    grid_list.append(utm_grid)
    n_total += len(utm_grid)

In [None]:
# Write every grid of grid_list into the same GeoPackage file, one layer per grid.
# NOTE(review): layer is passed as an integer index; confirm the I/O engine in use
# accepts a non-string layer name when grid_list holds more than one grid.
for idx, grid in enumerate(grid_list):
    grid.to_file(GRID_FILENAME, layer=idx, driver="GPKG")

# Copy the grid file to the bucket (IPython shell escape; needs the AWS CLI)
!aws s3 cp {GRID_FILENAME} {BUCKET_PATH}

In [5]:
# Load grid file
# NOTE(review): no layer is given, so presumably only the default layer is read;
# confirm this is enough if the file holds several layers.
grid = gpd.read_file(GRID_FILENAME, driver="GPKG")

## Batch V2 request

In [6]:
# Sentinel Hub configuration created without arguments.
# Uncomment the lines below to set the credentials by hand.
config = SHConfig()
#config.instance_id = ''
#config.sh_client_id = ''
#config.sh_client_secret = ''

In [31]:
# OAuth2 session used for the raw Batch API calls, built from the client id and
# secret held in config.
# NOTE(review): config is rebound to the "travniki" profile in a later cell; run
# that cell first if those are the credentials meant to be used here.
token_endpoint = "https://services.sentinel-hub.com/auth/realms/main/protocol/openid-connect/token"

client = BackendApplicationClient(client_id=config.sh_client_id)
oauth = OAuth2Session(client=client)

# Get token for the session
token = oauth.fetch_token(
    token_url=token_endpoint,
    client_secret=config.sh_client_secret,
    include_client_id=True,
)

In [8]:
# Use the configuration stored under the name "travniki"
# (the commented call is presumably how it was saved — confirm).
#config.save("travniki")
config = SHConfig("travniki")

In [9]:
# Data collection to request and the Batch V2 processing endpoint
data_collection = DataCollection.SENTINEL2_L2A
url_batch_v2 = "https://services.sentinel-hub.com/api/v2/batch/process"

In [10]:
# Read the evalscript source. The encoding is given explicitly so the text does
# not depend on the platform's default encoding.
with open("config/signals/sentinel2_l2a/evalscript.js", encoding="utf-8") as fp:
    evalscript = fp.read()

In [11]:
import calendar

def get_month_range(year: int, month: int) -> tuple:
    """Return the first and last day of a month as ISO date strings.

    Args:
        year: The year (e.g., 2024).
        month: The month (1-12).

    Returns:
        A tuple of strings ``(start_date, end_date)`` in 'YYYY-MM-DD' format.

    Raises:
        calendar.IllegalMonthError: If ``month`` is outside 1-12 (raised by
            ``calendar.monthrange``; it is a ``ValueError`` subclass).
    """
    # monthrange returns (weekday of the first day, number of days in the month)
    _, num_days = calendar.monthrange(year, month)

    # Zero-pad the year too, so years below 1000 still give 'YYYY-MM-DD'
    prefix = f"{year:04d}-{month:02d}"
    return (f"{prefix}-01", f"{prefix}-{num_days:02d}")

In [21]:
# Month number of 2024 to process. It is also the key used to look up
# batch_requests in the "Check results" section, which needs an entry for it.
month = 12

In [22]:
# Output ids requested from the evalscript: every band of the collection plus
# the QA layers.
# data_bands = ["blue", "green", "red", "nir", "alpha"]
data_bands = [band.name for band in data_collection.bands]
# qa_bands = ["clear", "snow", "shadow", "haze_light", "haze_heavy", "cloud", "confidence", "udm", "dataMask"]
qa_bands = ["dataMask", "CLP", "CLM"]  # how to get OUT_PROBA?

# One TIFF response per output id, followed by the JSON "userdata" response
responses = [
    *(SentinelHubRequest.output_response(output_id, MimeType.TIFF) for output_id in [*data_bands, *qa_bands]),
    SentinelHubRequest.output_response("userdata", MimeType.JSON),
]

# Input covering the whole selected month of 2024.
# Short test interval: (f"2024-{month:02d}-01", f"2024-{month:02d}-05")
collection_input = SentinelHubRequest.input_data(
    data_collection=data_collection,
    time_interval=get_month_range(2024, month),
    # upsampling="BICUBIC",
    # downsampling="BICUBIC",
)

# The request spans the bounds of the whole grid.
# For a quick test use instead:
#   bbox=BBox(grid.iloc[0].geometry.bounds, crs=grid.crs), size=(200, 200)
sh_req = SentinelHubRequest(
    evalscript=evalscript,
    input_data=[collection_input],
    responses=responses,
    bbox=BBox(grid.unary_union.bounds, crs=grid.crs),
    resolution=(10, 10),  # alternative: (RESOLUTION, RESOLUTION)
    config=config,
)

In [23]:
# Take the POST values of the first download request of sh_req
request_dict = sh_req.download_list[0].post_values

# resx/y are taken from the geopackage and should not be part of the processRequest
output_spec = request_dict["output"]
del output_spec["resx"], output_spec["resy"]
# processRequest.output must not specify width, height, nor resolution.
# del output_spec["width"]
# del output_spec["height"]

In [24]:
# IAM role and region shared by the input and the output S3 locations
s3_access = {
    "iamRoleARN": "arn:aws:iam::621520595318:role/batch-role",
    "region": "eu-west-1",
}

# Body of the batch request: the process request, the grid to run it on and
# where to deliver the rasters.
batch_payload = {
    "processRequest": request_dict,
    "input": {
        "type": "geopackage",
        "features": {
            # Built from BUCKET_PATH (which ends with "/") so it always points
            # to the place the grid file was copied to.
            "s3": {"url": f"{BUCKET_PATH}{GRID_FILENAME}", **s3_access},
        },
    },
    "output": {
        "type": "raster",
        "delivery": {
            "s3": {
                # <tileName>, <outputId> and <format> are sent literally;
                # presumably the service fills them in per output file.
                "url": f"s3://batch-slo-24/tiffs/2024-{month:02d}/<tileName>/<outputId>.<format>",
                **s3_access,
            },
        },
    },
    "description": f"Travniki-2024-{month:02d} Basemaps",
}

In [25]:
# create the batch request
headers = {"Content-Type": "application/json"}
response = oauth.post(url_batch_v2, json=batch_payload, headers=headers)
# Stop here if the service answered with an HTTP error status
response.raise_for_status()

In [None]:
# Keep the id from the response for the follow-up calls and display it
batch_request_id = response.json()["id"]
batch_request_id

In [None]:
# start the batch process
start_url = f"{url_batch_v2}/{batch_request_id}/start"
response = oauth.post(start_url)
response.status_code

In [None]:
# check the status of the batch request
status_url = f"{url_batch_v2}/{batch_request_id}"
response = oauth.get(status_url)
response.json()["status"]

In [None]:
# Show the raw body of the latest response
response.text

## Check results

In [32]:
# month = 6

# Batch request id per month number; other ids noted for a month are kept in
# comments.
batch_requests = {
    1: "c3a860c4-ffbb-49f0-8948-ab768629ffd2",
    5: "c0c03557-b220-42f3-ac05-deb21a36ad20",  # may
    # june; other ids noted for it: b8b99012-fd61-433b-ad9b-a4c7bd34b035,
    # 355cb8b5-8d4b-4c45-8766-91235af2959f, 24564414-63f8-425d-ac75-badb224fe4b5
    6: "a5c1928f-6fb4-4a59-a7dc-da9eee5d94fa",
    7: "7795443e-8f0b-4b24-834d-90afee24c54b",  # july
}

# Use the stored id for the selected month (KeyError if the month has no entry)
batch_request_id = batch_requests[month]

In [None]:
# List the entries delivered for the selected month and keep the tile folders
fs_shai = get_filesystem("s3://batch-slo-24/")
month_dir = f"tiffs/2024-{month:02d}"
tiles_down = [entry for entry in fs_shai.listdir(month_dir) if "tile_" in entry]
len(tiles_down)

In [None]:
# One manifest layer is read per key of bboxes_dict (the EPSG codes of the grid)
utm_zones = list(bboxes_dict)

manifest_path = f"tiffs/2024-{month:02d}/featureManifest-{batch_request_id}.gpkg"
feat_manifest = []
with LocalFile(manifest_path, mode="r", filesystem=fs_shai) as f:
    # The layers are read through the local path exposed by LocalFile
    feat_manifest.extend(gpd.read_file(f.path, layer=f"feature_{utmz}") for utmz in utm_zones)

In [None]:
# sum all unique identifiers from all feature manifest layers
sum(len(fm.identifier.drop_duplicates()) for fm in feat_manifest)

In [45]:
# check if all identifiers across all layers are unique or if any are missing
all_identifiers = set()
for fm in feat_manifest:
    all_identifiers.update(fm.identifier.drop_duplicates().tolist())
assert len(all_identifiers) == n_total