In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Import the folium library.
import folium

# Define a method for displaying Earth Engine image tiles to a folium map.
def add_ee_layer(self, ee_image_object, vis_params, name, show=True, opacity=1, min_zoom=0):
    """Add an Earth Engine image to this folium map as a toggleable tile overlay.

    Written to be attached to ``folium.Map`` as a method, so ``self`` is the
    map that receives the layer.

    Args:
        ee_image_object: Object passed to ``ee.Image(...)``.
        vis_params: Visualization parameters passed to ``getMapId``.
        name: Name given to the tile layer.
        show: Forwarded to the tile layer's ``show`` option.
        opacity: Forwarded to the tile layer's ``opacity`` option.
        min_zoom: Forwarded to the tile layer's ``min_zoom`` option.

    Note:
        ``ee`` is resolved when the function is called, so the ``ee`` module
        must already be imported in the notebook before the first call.
    """
    # Ask Earth Engine for a map id and pull out the tile URL template.
    tile_url = ee.Image(ee_image_object).getMapId(vis_params)['tile_fetcher'].url_format
    attribution = 'Map Data &copy; <a href="https://earthengine.google.com/">Google Earth Engine</a>'

    # Build the overlay first, then attach it to the map.
    layer = folium.raster_layers.TileLayer(
        tiles=tile_url,
        attr=attribution,
        name=name,
        show=show,
        opacity=opacity,
        min_zoom=min_zoom,
        overlay=True,
        control=True,
    )
    layer.add_to(self)

# Attach the helper to folium.Map so that every map instance can call it as a
# method, e.g. `some_map.add_ee_layer(image, vis_params, name)`.
folium.Map.add_ee_layer = add_ee_layer

# USE CASE
# map = folium.Map(location=[lat, lon], zoom_start=14)
# map.add_ee_layer(least_cloudy_image, {}, 'Normalized image')
# display(map)

# OR
# cropped_xarray.wx.rgb(bands=["SR_B4", "SR_B3", "SR_B2"], stretch=0.85, col_wrap=4)

In [None]:
import ee
import wxee

# Authenticate against Earth Engine (may prompt for credentials interactively).
ee.Authenticate()

# Initialize the Earth Engine client; this must run before the `ee` calls further down.
ee.Initialize()

## Install the mosaiks package

In [None]:
# Locally
# !pip install -e .. --upgrade

In [None]:
# From github
# 🚨 Make sure you update github token in the secrets file 🚨 
# import src.mosaiks.utils as utl
# mosaiks_package_link = utl.get_mosaiks_package_link
# !pip install {mosaiks_package_link} --upgrade

## Import packages

In [None]:
import logging
# Grab the root logger and lower its threshold to INFO.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [None]:
import sys
import os
import warnings

# Put the parent directory on the import path and suppress all warnings
# for the rest of the session.
sys.path.append("../")
warnings.filterwarnings(action="ignore")

In [None]:
import mosaiks.utils as utl
from mosaiks.featurize import *
from mosaiks.dask_run import *

# Setup Rasterio

In [None]:
# Load the settings from YAML and export every entry as an environment variable.
# NOTE(review): the file name is spelled "rasterioc_config.yaml" — confirm it
# matches the real config file and is not a typo for "rasterio_config.yaml".
# os.environ only accepts string keys/values, so the YAML entries are assumed
# to be strings — TODO confirm.
rasterio_config = utl.load_yaml_config("rasterioc_config.yaml")
os.environ.update(rasterio_config)

# Setup Dask Cluster and Client

In [None]:
# Create the client that is passed to the featurization runs below.
# NOTE(review): `get_local_dask_client` comes from one of the star-imported
# mosaiks modules; presumably it starts a local Dask cluster — confirm.
client = get_local_dask_client()

# Load params

In [None]:
# Load the featurization settings and the per-satellite settings from YAML.
featurization_config = utl.load_yaml_config("featurisation.yaml")
satellite_config = utl.load_yaml_config("satellite_config.yaml")
# Keep only the entry for the satellite named in the featurization config;
# from here on `satellite_config` refers to that single satellite's settings.
satellite_config = satellite_config[
    featurization_config["satellite_search_params"]["satellite_name"]
]

# Load point coords

In [None]:
# Load the coordinate set named in the featurization config.
# NOTE(review): judging by the helper's name this returns a GeoDataFrame built
# from latitude/longitude columns — confirm against mosaiks.utils.
request_points_gdf = utl.load_df_w_latlons_to_gdf(dataset_name=featurization_config["coord_set_name"])

In [None]:
# Work on a small, reproducible sample (5 points, fixed seed) for quick testing.
# For a larger trial, raise the sample size, e.g. `.sample(200, random_state=0)`.
points_gdf = request_points_gdf.sample(5, random_state=0)

## GEE Proof-of-Concept

In [None]:
# Proof of concept: fetch one Landsat scene around a single sample point via
# Earth Engine and turn it into an array of the configured bands.

# Pick the second sampled point. The coordinates are selected by column name
# (the same "Lat"/"Lon" columns that are joined onto the results later in this
# notebook) instead of by position, so a change in column order cannot silently
# swap or corrupt them.
lat, lon = points_gdf[["Lat", "Lon"]].iloc[1]
buffer = 1200  # distance passed to point.buffer(); presumably metres — TODO confirm
search_start = '2013-04-01'
search_end = '2014-03-31'

# Earth Engine points take (longitude, latitude), in that order.
point = ee.Geometry.Point(lon, lat)
# Bounding box of the buffered point; used both to filter and to crop.
crop = point.buffer(buffer).bounds()

# All scenes touching the crop area in the search window, sorted by the
# CLOUD_COVER property so that the first one is the least cloudy.
collection = (
    ee.ImageCollection('LANDSAT/LC08/C02/T1_L2')
    .filterBounds(crop)
    .filterDate(search_start, search_end)
    .sort('CLOUD_COVER')
)
least_cloudy_image = collection.first()

# Download the cropped scene through the wxee accessor, keep only the
# configured bands, and stack them along a "variable" dimension.
xarray = least_cloudy_image.wx.to_xarray(region=crop, scale=30)
bands = satellite_config["bands"]
xarray = xarray[bands].to_array()
# Reorder the dimensions and drop any of length one (e.g. a single time step).
final_xarray = xarray.transpose("time", "variable", "y", "x").squeeze()

In [None]:
import torch
from mosaiks.featurize.stacs import minmax_normalize_image

# Pull the raw values out of the xarray object, convert them to a float
# torch tensor, and normalize with the project's min-max helper.
image = final_xarray.values
torch_image = torch.from_numpy(image).float()
torch_image = minmax_normalize_image(torch_image)

In [None]:
from mosaiks.featurize.nn_forward_pass import featurize
# Build the RCF model from the configured feature count and kernel size,
# with one input channel per configured band.
model = RCF(
    featurization_config["model"]["num_features"],
    featurization_config["model"]["kernel_size"],
    len(satellite_config["bands"]),
)
# Featurize the single proof-of-concept image on the CPU.
features = featurize(torch_image, model, "cpu")
# Show the first element of the result as the cell output.
features[0]

## Setup Dask pipeline

In [None]:
# Wrap the sampled points for the Dask pipeline using the configured chunk size.
# NOTE(review): presumably returns a dask-geopandas frame split into
# partitions of `chunksize` rows — confirm against `get_dask_gdf`.
points_dgdf = get_dask_gdf(points_gdf, featurization_config["dask"]["chunksize"])

In [None]:
# Build the list of delayed partitions to featurize, depending on the
# configured imagery source.
if featurization_config["imagery_source"] == "GEE":
    # For GEE the points are used directly.
    partitions = points_dgdf.to_delayed()
elif featurization_config["imagery_source"] == "MPC":
    # need to fetch STACs for each point for MPC
    points_gdf_with_stac = fetch_image_refs(
        points_dgdf, featurization_config["satellite_search_params"]
    )
    partitions = points_gdf_with_stac.to_delayed()
else:
    # Fail fast: without this branch `partitions` would stay undefined (or keep
    # a stale value from an earlier run) and a later cell would fail obscurely.
    raise ValueError(
        "Unsupported imagery_source {!r}; expected 'GEE' or 'MPC'.".format(
            featurization_config["imagery_source"]
        )
    )


In [None]:
# Model used by the Dask pipeline runs below. This rebinds `model` with the
# same arguments as the proof-of-concept cell above (feature count, kernel
# size, one input channel per configured band).
model = RCF(
    featurization_config["model"]["num_features"],
    featurization_config["model"]["kernel_size"],
    len(satellite_config["bands"]),
)

# Run in parallel

## Trial run

Eight simultaneous partitions seem to be about as many as we can run in parallel on a local cluster. We may be able to run more on a Gateway Cluster once that is working.

TODO - CHANGE TO THIS SCHEME: There are better scheduling schemes. For example, kick off another partition whenever one finishes, which might make better use of resources.

In [None]:
%%time
# Restart the client before the trial so the run starts from a clean state.
client.restart()
# Featurize only the first partition as a trial run.
df = run_single_partition(
    partition=partitions[0], 
    satellite_config=satellite_config, 
    featurization_config=featurization_config, 
    model=model, 
    client=client
)

In [None]:
# Quick sanity check of the trial-run output: mean of the per-column means.
print("Average feature value:", df.mean().mean())
# Histogram of the values in the second row of the result.
df.iloc[1].hist()
# Restart the client again; binding the result to `_` keeps it out of the cell output.
_ = client.restart()

Test multi-partition run

In [None]:
# Number of partitions to include in this trial, taken from the Dask config.
n_per_run = featurization_config["dask"]["n_per_run"]

# Run only the first `n_per_run` partitions and write the results to a
# separate "test" folder. The returned value is kept as `failed_ids`.
failed_ids = run_partitions(
    partitions=partitions[:n_per_run],
    satellite_config=satellite_config,
    featurization_config=featurization_config,
    model=model,
    client=client,
    mosaiks_folder_path="test", #places files into "playground/test/"
    partition_ids=None,
)

In [None]:
# Show what the trial run returned as failed (presumably partition IDs — confirm).
failed_ids

In [None]:
# Load one checkpoint file from the "test" folder to inspect what the trial run wrote.
utl.load_dataframe("./test/df_features_000.parquet.gzip")

## Full run

### Setup saving location

In [None]:
# Derive the output folder for the full run from the featurization config and
# the coordinate-set name. The result is later combined with a file name via
# `/`, so it is expected to be a path-like object.
mosaiks_folder_path = utl.make_features_path_from_dict(featurization_config, featurization_config["coord_set_name"])

### Create features and save checkpoints to file

In [None]:
# Full run over every partition, writing into `mosaiks_folder_path`.
# The returned value is kept so the failed partitions can be re-run below.
failed_partition_ids = run_partitions(
    partitions=partitions,
    satellite_config=satellite_config,
    featurization_config=featurization_config,
    model=model,
    client=client,
    mosaiks_folder_path=mosaiks_folder_path
)

## Re-run failed partitions

Use this section to re-run only the partitions that failed during the full run.

In [None]:
%%time

# Keep only the partitions whose IDs were reported as failed by the full run.
failed_partitions = [partitions[failed_id] for failed_id in failed_partition_ids]

# Re-run just those partitions, passing their original IDs via `partition_ids`.
failed_partition_ids_1 = run_partitions(
    partitions=failed_partitions,
    satellite_config=satellite_config,
    featurization_config=featurization_config,
    model=model,
    client=client,
    mosaiks_folder_path=mosaiks_folder_path,
    partition_ids=failed_partition_ids,
)

# IDs that failed again on the second attempt.
failed_partition_ids_1

# Load checkpoint files and combine

In [None]:
# List the files in the output folder whose names start with "df_".
checkpoint_filenames = utl.get_filtered_filenames(
        folder_path=mosaiks_folder_path, prefix="df_"
    )

In [None]:
# Load every checkpoint file and combine them into a single dataframe.
combined_df = utl.load_and_combine_dataframes(
    folder_path=mosaiks_folder_path, filenames=checkpoint_filenames
)
# Attach the coordinates and the "shrid" column of each point (joined on the index).
combined_df = combined_df.join(points_gdf[["Lat", "Lon", "shrid"]])
# The logging functions treat extra positional arguments as %-format arguments,
# so the message needs a placeholder; without one, logging reports a formatting
# error instead of the size.
logging.info("Dataset size in memory (MB): %s", combined_df.memory_usage().sum() / 1000000)

In [None]:
%%time
# Write the combined features next to the checkpoint files in the output folder.
combined_filename = "features.parquet.gzip"
combined_filepath = mosaiks_folder_path / combined_filename
utl.save_dataframe(
    df=combined_df, file_path=combined_filepath
)