In [None]:
import leafmap
import numpy as np
import rioxarray as rxa  # noqa: F401 -- imported for its side effect: registers the `.rio` accessor
import xarray as xr
from dea_tools.spatial import xr_vectorize
from dea_tools.spatial import xr_rasterize
from fsspec.implementations.http import HTTPFileSystem
from samgeo import SamGeo

from emit_tools import emit_xarray
from utils import get_rgb_dataset, get_earthdata_token, gamma_adjust

In [None]:
# Obtain the Earthdata token that is sent as a bearer credential when
# streaming the EMIT granule below.
# See README.md for instructions on how to get an Earthdata token.
token = get_earthdata_token()

In [None]:
%%time
# Loading data can take around 3-4 minutes on a 100 Mbps connection

# Refer to the README.md for instructions on how to find granule IDs
granule = "EMIT_L2A_RFL_001_20230316T045211_2307503_006" # Canberra

s3_url = "s3://lp-prod-protected/EMITL2ARFL.001/" + granule + "/" + granule + ".nc"
http_url = s3_url.replace("s3://", "https://data.lpdaac.earthdatacloud.nasa.gov/")

fs = HTTPFileSystem(headers={
    "Authorization": f"bearer {token}"
})
ds = emit_xarray(fs.open(http_url))
ds

In [None]:
# Build an 8-bit RGB composite by sampling the reflectance cube at one
# wavelength (in nm) per display channel.
bands = {
    "red": 650,
    "green": 560,
    "blue": 470,
}

channels = {}
for band, wavelength in bands.items():
    # gamma_adjust is asked NOT to replace NaNs, so the result may contain NaN
    # (presumably no-data pixels -- confirm against utils.gamma_adjust).
    # Casting NaN or values outside 0-255 straight to uint8 is undefined and
    # raises "invalid value encountered in cast", so fill and clip first.
    scaled = np.asarray(gamma_adjust(ds, wavelength, 0.4, replace_nans=False)) * 255
    scaled = np.clip(np.nan_to_num(scaled, nan=0.0), 0, 255)
    channels[band] = xr.DataArray(
        scaled,
        dims=("latitude", "longitude"),
        coords={"longitude": ds.longitude, "latitude": ds.latitude},
    ).astype(np.uint8)

dataset = xr.Dataset(channels)

# rioxarray looks for spatial dimensions named x / y.
dataset = dataset.rename({"longitude": "x", "latitude": "y"})
# The axes are geographic latitude/longitude, so tag the composite as WGS84.
# Without a CRS, a GeoTIFF written via `rio.to_raster` has a transform but no
# coordinate reference system, so map/GIS tools cannot place it correctly.
dataset = dataset.rio.write_crs("EPSG:4326")

# Inspect the structure (rich display rather than print).
dataset

In [None]:
# Path of the RGB GeoTIFF; the map and SAM cells read the image back from it,
# so write to this same variable instead of repeating the literal.
image = "scene_rgb.tif"
dataset.rio.to_raster(image)

In [None]:
# Show the saved image on a map, centred on the composite's own coordinates
# rather than a hard-coded location, so the view follows whichever granule
# was loaded.
scene_center = [float(dataset.y.mean()), float(dataset.x.mean())]
m = leafmap.Map(center=scene_center, zoom=15)
m.add_basemap("SATELLITE")
m.add_raster(image, layer_name="Image")
m

In [None]:
# When we're happy, delete the raw data to save memory
del ds

In [None]:
# Configure the Segment Anything model (ViT-H backbone) via samgeo.
# Option names match segment_anything.SamAutomaticMaskGenerator; see its
# documentation for the meaning of each threshold.
sam_kwargs = dict(
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,
)

sam = SamGeo(
    checkpoint="sam_vit_h_4b8939.pth",
    model_type="vit_h",
    sam_kwargs=sam_kwargs,
)

In [None]:
# Segment the RGB GeoTIFF; the second argument is the output mask path.
mask_output = "scene_rgb_mask.tif"
sam.generate(image, mask_output)

In [None]:
# Visualise the segmentation masks produced by the generate step
# (presumably the result SamGeo keeps from its last generate call).
sam.show_masks()