In [1]:
import os
from pprint import pprint
from PIL import Image
import numpy as np
import rasterio
from rio_tiler.io import STACReader
from rio_tiler.io.rasterio import Reader
from rasterio.plot import reshape_as_image
from pystac_client import Client
from pystac import Item
from dataclasses import dataclass

In [219]:
# Earth Search (Element 84) STAC API root and the Sentinel-2 collection searched below.
ELEMENT84_BASE_URL = 'https://earth-search.aws.element84.com/v1/'
E84_COLLECTION = 'sentinel-2-c1-l2a'

# Named search areas, each as [min_lon, min_lat, max_lon, max_lat].
BBOX = {
    'test':           [-118.6106, 33.6712, -116.6877, 34.9933],
    'uganda':         [27.5968, -4.3307, 36.8378, 3.2935],
    'sadat':          [30.317347, 30.326325, 30.575375, 30.473935],
    'tocopilla':      [-70.29301, -22.158244, -70.068124, -22.001797],
    'topeka':         [-95.794928, 39.008084, -95.614374, 39.132846],
    'charleston, wv': [-83.2226, 37.677, -80.6133, 39.1619],
    'AE_C001_0002': [
        54.988546510801605, 24.907741125789755,
        55.00809926263708, 24.927582519228114,
    ],
}

In [220]:
def update_assets_prefer_s3(item: Item):
    for asset in item.assets.values():
        match asset.extra_fields:
            case {'alternate': {'s3': {'href': uri}}}:
                asset.href = uri

def get_stac_first_item(
    base_url: str,
    collection: str,
    bbox: list[float],
    datetime: str = '2024-10-01T00:00:00Z/2024-10-07T00:00:00Z',
):
    """Return the first STAC item matching a collection / bbox / time search.

    The returned item's asset hrefs are rewritten to their S3 alternates via
    ``update_assets_prefer_s3``.

    Args:
        base_url: Root URL of the STAC API.
        collection: Collection id to search.
        bbox: Search area as ``[min_lon, min_lat, max_lon, max_lat]``.
        datetime: Instant or ``start/end`` interval passed to the search.
            The default is the window that used to be hard-coded here.

    Raises:
        LookupError: if the search yields no items.
    """
    client = Client.open(base_url)
    results = client.search(
        method='GET',
        bbox=bbox,
        datetime=datetime,
        collections=[collection],
        limit=5,
    )
    # Use a default instead of letting a bare StopIteration escape. That
    # exception carries no context, and inside a generator it is turned into
    # RuntimeError.
    item = next(results.items(), None)
    if item is None:
        raise LookupError(
            f'No items in collection {collection!r} for bbox={bbox} '
            f'datetime={datetime!r} at {base_url}'
        )
    update_assets_prefer_s3(item)
    return item

def get_rescale_range(stats, keys: list[str]):
    """Build per-band ``(min, max)`` rescale ranges for an RGB triple.

    Args:
        stats: Mapping from band key to a statistics object exposing ``min``
            and ``max`` attributes.
        keys: Band keys in red, green, blue order. Only the first three are
            used, and fewer than three raises ``IndexError``.

    Returns:
        A 3-tuple of ``(min, max)`` pairs, one per band, in ``keys`` order.
    """
    selected = [stats[keys[idx]] for idx in range(3)]
    return tuple((band.min, band.max) for band in selected)

In [221]:
item = get_stac_first_item(ELEMENT84_BASE_URL, E84_COLLECTION, BBOX['test'])
stac = STACReader(None, item=item.to_dict())
cloud_cover = stac.item.properties.get('eo:cloud_cover', 0)

# Crop to a window centred on the item footprint whose width and height are
# `shrink_pc` times the footprint's width and height.
item_bounds = stac.bounds
west, south, east, north = item_bounds[0], item_bounds[1], item_bounds[2], item_bounds[3]
center = [(west + east) / 2, (south + north) / 2]
dims = [abs(east - west), abs(north - south)]
shrink_pc = 0.5  # fraction of the footprint width/height to keep
half_width = shrink_pc * dims[0] / 2
half_height = shrink_pc * dims[1] / 2
bounds = [
    center[0] - half_width,
    center[1] - half_height,
    center[0] + half_width,
    center[1] + half_height,
]

# Band statistics (keys such as 'red_b1') used to rescale the plain RGB composite.
stats = stac.merged_statistics(assets=['red', 'green', 'blue', 'visual'])

# RGB composite from the separate red, green and blue assets

In [222]:
# Read the red, green and blue assets over the cropped window as one image.
rgb_assets = ['red', 'green', 'blue']
rgb_im = stac.part(bounds, assets=rgb_assets)

## Plain composite (linear min/max rescale)

In [223]:
# Wrap the raw band values in a separate image object so the rescale step
# operates on `rgb_im_rescaled` rather than on `rgb_im`.
raw_rgb = rgb_im.data
rgb_im_rescaled = rgb_im.from_array(raw_rgb)

In [224]:
# Per-band statistics of the cropped RGB read. End the cell with the value so
# it is rendered for inspection instead of being computed and discarded.
rgb_im_stats = rgb_im.statistics()
rgb_im_stats

In [225]:
# `rescale` updates the image object it is called on. Rebuild
# `rgb_im_rescaled` from the raw bands here so that re-running this cell does
# not rescale values that were already squeezed into the output range.
rgb_im_rescaled = rgb_im.from_array(rgb_im.data)
rgb_im_rescaled.rescale(get_rescale_range(stats, ['red_b1', 'green_b1', 'blue_b1']))
Image.fromarray(rgb_im_rescaled.data_as_image()).save('./_rgbconcat.png')

## Gamma- and offset-corrected composite

In [226]:
# Convert the band values to reflectance with a 1/10_000 scale.
# NOTE(review): only a scale is applied. If the collection's asset metadata
# (`raster:bands`) also advertises an offset, that offset is ignored here —
# confirm against the item.
reflectance = rgb_im.data / 10_000
# NOTE(review): unexplained constant gain of 1.1 applied before the tone curve — confirm intent.
img = reflectance * 1.1


def brighten(band, fact=1):
    """Return ``band`` multiplied by the linear brightness factor ``fact``."""
    return band * fact


def gamma(band, g=0.8):
    """Return ``band ** (1 / g)`` element-wise.

    With the default ``g=0.8`` the exponent is 1.25, so values in (0, 1) are
    pushed towards 0.
    """
    return np.power(band, 1 / g)


minreflect = 0     # lower bound of the final linear stretch (maps to 0)
maxreflect = 0.4   # upper bound of the final stretch, before scaling by brightfact (maps to 255)
rayscatt = [0.013, 0.024, 0.041]  # per-band (R, G, B) offsets, pulled from a sentinelhub example
brightfact = 1

# Raster arrays are laid out as (band, row, col). Apply the tone curve to all
# bands at once, then subtract each band's offset by broadcasting a
# (band, 1, 1) column instead of looping over bands.
assert img.shape[0] == len(rayscatt), 'expected exactly one band per offset (R, G, B)'
img = gamma(brighten(img, brightfact)) - np.asarray(rayscatt)[:, np.newaxis, np.newaxis]

# Linearly stretch [minreflect, maxreflect * brightfact] onto [0, 255],
# clipping values that fall outside that range.
minval, maxval = minreflect, maxreflect * brightfact
img = ((img - minval) / (maxval - minval)).clip(0, 1) * 255

rgb_corr_im = rgb_im.from_array(img.astype(np.uint8))

In [227]:
# The corrected array was already stretched onto uint8 [0, 255] before
# `rgb_corr_im` was built, so it is saved without a further rescale.
Image.fromarray(rgb_corr_im.data_as_image()).save('./_rgbgamma.png')

# Visual
This section only works for Sentinel-2 (S2) data; Landsat 8 (L8) items do not provide a true-color `visual` asset.

In [None]:
# Read the true-color `visual` asset over the same cropped window.
visual_assets = ['visual']
visual_im = stac.part(bounds, assets=visual_assets)

In [None]:
np.shape(visual_im.data)  # NOTE(review): presumably (bands, rows, cols) — confirm

In [None]:
# Convert to an image-ordered array and write it out as PNG.
visual_rgb = visual_im.data_as_image()
Image.fromarray(visual_rgb).save('./_visual.png')