For MPC authentication (if not on MPC jupyterhub), get your subscription key from [here](https://planetarycomputer.developer.azure-api.net/profile) and follow the instructions on [this link](https://planetarycomputer.microsoft.com/docs/concepts/sas/#:~:text=data%20catalog.-,planetary%2Dcomputer%20Python%20package,-The%20planetary%2Dcomputer). Then use the python `planetary_computer` package to sign each image request.

If you don't have `dask-geopandas` installed, run `pip install dask-geopandas`.

#### To Do:
- Figure out sizing of images
- Can we go older than 2013?

In [1]:
import warnings
import time
import os
import gc

from tqdm.notebook import tqdm

import torch
from torch.utils.data import DataLoader

import numpy as np
import pandas as pd
import geopandas as gpd
import dask_geopandas as dask_gpd
from dask.distributed import Client

import matplotlib.pyplot as plt

In [2]:
from custom.mpc_imagery import (
    sort_by_hilbert_distance, 
    filter_points_with_buffer,
    fetch_least_cloudy_stac_items, 
    CustomDataset
)
from custom.models import featurize, RCF

In [3]:
warnings.filterwarnings(action="ignore", category=RuntimeWarning)
warnings.filterwarnings(action="ignore", category=UserWarning)

In [4]:
RASTERIO_BEST_PRACTICES = dict(  # See https://github.com/pangeo-data/cog-best-practices
    CURL_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt",
    GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR",
    AWS_NO_SIGN_REQUEST="YES",
    GDAL_MAX_RAW_BLOCK_CACHE_SIZE="200000000",
    GDAL_SWATH_SIZE="200000000",
    VSI_CURL_CACHE_SIZE="200000000",
)
os.environ.update(RASTERIO_BEST_PRACTICES)

## Load point coordinates to fetch images for

In [5]:
data_dir = "/home/jovyan/ds_nudge_up/data/01_preprocessed/mosaiks_request_points"
points_df = pd.read_csv(f"{data_dir}/INDIA_SHRUG_request_points.csv")

points_gdf = gpd.GeoDataFrame(
    points_df, 
    geometry=gpd.points_from_xy(
        points_df['lon'], 
        points_df['lat']
    )
)

del points_df

In [6]:
# # Filer points at the edge of the shape to within a buffer
# states = gpd.read_file("/home/jovyan/ds_nudge_up/data/00_raw/SHRUG/geometries_shrug-v1.5.samosa-open-polygons-shp/state.shp")
# states['zero_column'] = 0
# country = states.dissolve(by='zero_column')
# BUFFER_DISTANCE = 0.005
# points_gdf = filter_points_with_buffer(points_gdf, country, BUFFER_DISTANCE)

In [7]:
points_gdf = sort_by_hilbert_distance(points_gdf)

In [8]:
points_gdf.shape[0]

96167

In [9]:
points_gdf = points_gdf.sample(1000, random_state=0)

In [10]:
# points_gdf.plot()

## Get the imagery around each point

Convert to dask geodataframe

In [11]:
NPARTITIONS = 250
points_dgdf = dask_gpd.from_geopandas(points_gdf, npartitions=NPARTITIONS, sort=False)

del points_gdf
gc.collect()

110

Get stac_item references to the least cloudy image that corresponds to each point

In [12]:
with Client(n_workers=16) as client:
    print(client.dashboard_link)

    # `meta` is the expected output format:
    # an empty df with correct column types
    meta = points_dgdf._meta
    meta = meta.assign(stac_item=pd.Series([], dtype="object"))

    points_gdf_with_stac = points_dgdf.map_partitions(
        fetch_least_cloudy_stac_items, 
        satellite="landsat-8-c2-l2",
        search_start="2013-01-01",
        search_end="2013-12-31",
        meta=meta)

    points_gdf_with_stac = points_gdf_with_stac.compute()

/user/amirali1376@gmail.com/proxy/8787/status


Key:       ('fetch_least_cloudy_stac_items-928d3ce5ba4910e6d8ea4ff3e931cda0', 193)
Function:  subgraph_callable-27f95ddb-62ce-4197-8949-0d173598
args:      (             lat        lon                   geometry          hd
92170  31.954878  78.234486  POINT (78.23449 31.95488)  1740509227
39907  22.254878  74.334486  POINT (74.33449 22.25488)  1172045905
54103  24.204878  85.384486  POINT (85.38449 24.20488)  2188416215
34309  21.354878  83.434486  POINT (83.43449 21.35488)  2151148514)
kwargs:    {}
Exception: "APIError('<html>\\r\\n<head><title>502 Bad Gateway</title></head>\\r\\n<body>\\r\\n<center><h1>502 Bad Gateway</h1></center>\\r\\n<hr><center>nginx</center>\\r\\n</body>\\r\\n</html>\\r\\n')"



KeyboardInterrupt: 

Filter out points with no image found

In [None]:
points_gdf_with_stac_clean = points_gdf_with_stac.dropna(subset=["stac_item"])
matched_stac_items = points_gdf_with_stac_clean.stac_item.tolist()
matched_points_list = points_gdf_with_stac_clean[["lon", "lat"]].to_numpy()

In [13]:
NUM_POINTS = len(matched_points_list)

Setup Dataset object

In [14]:
BUFFER_DISTANCE = 0.005
bands = [
    # "SR_B1", # Coastal/Aerosol Band (B1)
    "SR_B2",  # Blue Band (B2)
    "SR_B3",  # Green Band (B3)
    "SR_B4",  # Red Band (B4)
    "SR_B5",  # Near Infrared Band 0.8 (B5)
    "SR_B6",  # Short-wave Infrared Band 1.6 (B6)
    "SR_B7",  # Short-wave Infrared Band 2.2 (B7)
]
resolution = 30

dataset = CustomDataset(
    matched_points_list, 
    matched_stac_items, 
    buffer=BUFFER_DISTANCE,
    bands=bands,
    resolution=resolution
)

Setup PyTorch DataLoader

In [15]:
batch_size = 1

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=os.cpu_count(),
    collate_fn=lambda x: x,
    pin_memory=False,
    persistent_workers=True,
)

Inspect images

In [16]:
# for images in dataloader:
#     for image in images:
#         # print(image)
#         array = np.array(image[0][:3])
#         reshaped_array = np.swapaxes(array, 0, 2)
#         plt.imshow(reshaped_array)
#         plt.show()

Clear memory

In [17]:
del points_gdf_with_stac
del points_gdf_with_stac_clean
del matched_stac_items
# del matched_points_list

## Define featurization model and apply to images

In [18]:
NUM_FEATURES = 1024

In [19]:
DEVICE = torch.device("cuda")
MODEL = RCF(NUM_FEATURES).eval().to(DEVICE)

### Apply featurization to images

In [20]:
min_image_edge = 6

In [15]:
x_all = np.zeros((NUM_POINTS, NUM_FEATURES), dtype=float)

i = 0
for images in tqdm(dataloader):
    for image in images:

        if image is not None:
            # A full image should be 36x36(?) pixels (i.e. ~1km^2 at a 30m/px spatial
            # resolution), however we can receive smaller images if an input point
            # happens to be at the edge of a scene. To deal with these we crudely 
            # drop all images where the spatial dimensions aren't both greater than 6 pixels.
            if image.shape[2] >= min_image_edge and image.shape[3] >= min_image_edge:
                mosaiks_features = featurize(image, MODEL, DEVICE)
                x_all[i] = mosaiks_features
            else:
                pass
        else:
            pass  # this happens if we do have not found an image for some point
        i += 1

In [None]:
x_all.shape

### Save to file

In [20]:
x_all_df = pd.DataFrame(x_all, columns=range(4))

In [21]:
x_all_df.insert(0, "lat", matched_points_list[:,0])
x_all_df.insert(1, "lon", matched_points_list[:,1])

Unnamed: 0,0,1,2,3
0,1,2,3,4
1,1,2,3,4


In [None]:
x_all_df.to_csv("/home/jovyan/ds_nudge_up/data/01_preprocessed/mosaiks_features/df_x_all_2013.csv")

In [None]:
# np.savetxt("/home/jovyan/ds_nudge_up/data/01_preprocessed/mosaiks_features/x_all_2013_latlons.csv", matched_points_list, delimiter=",")
# np.savetxt("/home/jovyan/ds_nudge_up/data/01_preprocessed/mosaiks_features/x_all_2013.csv", x_all, delimiter=",")