# Crop mask and crop type inference 🌍
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nasaharvest/crop-mask/blob/master/notebooks/inference.ipynb)

**Author:** Ivan Zvonkov (izvonkov@umd.edu)

**Description:** This notebook provides all the code to create a crop mask or crop type map using NASA Harvest's Google Cloud architecture. Access to the Google Cloud project is required.

The notebook is in beta mode so issue reports and suggestions are welcome! 


In [None]:
!pip install rasterio>=1.2.6 geopandas==0.9.0 -q

In [None]:
!pip install cropharvest==0.3.0 -q

In [None]:
import ee
import google
import IPython
import ipywidgets as widgets
import os
import re
import requests

from cropharvest.eo import EarthEngineExporter
from cropharvest.countries import BBox
from collections import defaultdict
from datetime import date
from glob import glob
from google.colab import auth
from google.cloud import storage
from ipywidgets import Box
from pathlib import Path
from tqdm.notebook import tqdm

# Functions
Below are the functions used in the notebook. No changes are needed; just run the cell!

In [None]:
#######################################################
# Status functions
#######################################################
def get_ee_task_amount(prefix=None):
    """Count the Earth Engine tasks that are still queued or running.

    Args:
        prefix: Optional substring. When given, only tasks whose description
            contains it are counted. When None (or empty), every task in the
            READY or RUNNING state is counted.

    Returns:
        The number of matching tasks in the READY or RUNNING state.
    """
    active_states = ("READY", "RUNNING")
    amount = 0
    for task in tqdm(ee.data.getTaskList()):
        if task["state"] not in active_states:
            continue
        # Both branches of the old prefix check incremented the counter, so
        # the prefix never filtered anything. Only count matching tasks.
        if not prefix or prefix in task["description"]:
            amount += 1
    return amount

def get_gcs_file_dict_and_amount(bucket_name, prefix):
    """Group the blobs found under `prefix` by their parent folder.

    Args:
        bucket_name: Name of the bucket to list.
        prefix: Only blobs whose name starts with this prefix are listed.

    Returns:
        A tuple `(files_dict, amount)`. `files_dict` maps each parent path
        (as a string) to the list of file stems in it, with any "pred_"
        removed from the stem. `amount` is the number of blobs listed.
    """
    stems_by_folder = defaultdict(list)
    listing = client.list_blobs(bucket_name, prefix=prefix)
    for blob in tqdm(listing, desc=f"From {bucket_name}"):
        blob_path = Path(blob.name)
        folder = str(blob_path.parent)
        stems_by_folder[folder].append(blob_path.stem.replace("pred_", ""))
    # Every listed blob contributed exactly one stem.
    total = sum(len(stems) for stems in stems_by_folder.values())
    return stems_by_folder, total

def get_gcs_file_amount(bucket_name, prefix):
    """Return how many blobs in `bucket_name` have a name starting with `prefix`."""
    matching_blobs = client.list_blobs(bucket_name, prefix=prefix)
    return sum(1 for _ in matching_blobs)

def get_status(model_name_version):
    """Print and return the progress counters for one model/version.

    Args:
        model_name_version: Prefix used to look up files in both buckets.
            Slashes are replaced by dashes when matching Earth Engine tasks.

    Returns:
        A tuple `(ee_task_amount, tifs_amount, predictions_amount)`.
    """
    divider = "------------------------------------------------------------------------------"
    print(divider)
    print(model_name_version)
    print(divider)

    task_prefix = model_name_version.replace("/", "-")
    ee_task_amount = get_ee_task_amount(prefix=task_prefix)
    tifs_amount = get_gcs_file_amount(tifs_bucket_name, prefix=model_name_version)
    predictions_amount = get_gcs_file_amount(preds_bucket_name, prefix=model_name_version)

    report = (
        ("Earth Engine tasks", ee_task_amount),
        ("Data available", tifs_amount),
        ("Predictions", predictions_amount),
    )
    for label, count in report:
        print(f"{label}: {count}")
    return ee_task_amount, tifs_amount, predictions_amount

#######################################################
# Inference functions
#######################################################

def find_missing_predictions(model_name_version, verbose=False):
    """Find the input tifs that do not have a matching prediction yet.

    Files are compared per batch folder by stem (prediction stems have their
    "pred_" removed by `get_gcs_file_dict_and_amount`).

    Args:
        model_name_version: Prefix used to list both the tifs bucket and the
            predictions bucket.
        verbose: When True, also print a header and the missing files of
            every batch.

    Returns:
        A dict mapping each batch folder path to the list of file stems that
        exist in the tifs bucket but have no prediction. Batches without
        missing files are omitted.
    """
    print("Addressing missing files")
    tif_files, _ = get_gcs_file_dict_and_amount(tifs_bucket_name, prefix=model_name_version)
    pred_files, _ = get_gcs_file_dict_and_amount(preds_bucket_name, prefix=model_name_version)
    missing = {}
    for batch in tqdm(tif_files, desc="Missing files"):
        if batch not in pred_files:
            # No predictions at all for this batch: every tif is missing.
            diffs = tif_files[batch]
        else:
            diffs = list(set(tif_files[batch]) - set(pred_files[batch]))
        if diffs:
            missing[batch] = diffs

    batches_with_issues = len(missing)
    # Count the files actually found missing rather than the difference of
    # the bucket totals, which is off whenever the predictions bucket holds
    # files without a matching tif.
    missing_total = sum(len(files) for files in missing.values())
    if verbose:
        print("------------------------------------------------------------------------------")
        # The original printed `prefix`, which is not defined in this function.
        print(model_name_version)
        print("------------------------------------------------------------------------------")
    if batches_with_issues > 0:
        print(f"\u2716 {batches_with_issues}/{len(tif_files)} batches have a total {missing_total} missing predictions")
        if verbose:
            for batch, files in missing.items():
                print("\t--------------------------------------------------")
                print(f"\t{Path(batch).stem}: {len(files)}")
                print("\t--------------------------------------------------")
                for f in files:
                    print(f"\t{f}")
    else:
        print("\u2714 all files in each batch match")
    return missing

def make_new_predictions(missing):
    """Rename every missing tif in the tifs bucket by appending "-retry1".

    Presumably the rename re-triggers inference for the file; that mechanism
    is not visible in this notebook.

    Args:
        missing: Dict mapping a batch folder path to the list of file stems
            to rename, as returned by `find_missing_predictions`.
    """
    tifs_bucket = client.bucket(tifs_bucket_name)
    for batch, stems in tqdm(missing.items(), desc="Going through batches"):
        for stem in tqdm(stems, desc="Renaming files", leave=False):
            source_name = f"{batch}/{stem}.tif"
            source_blob = tifs_bucket.blob(source_name)
            if not source_blob.exists():
                print(f"Could not find: {source_name}")
                continue
            tifs_bucket.rename_blob(source_blob, f"{batch}/{stem}-retry1.tif")

#######################################################
# Map making functions
#######################################################
def gdal_cmd(cmd_type: str, in_file: str, out_file: str, msg=None, print_cmd=False):
    """Build a GDAL command line and run it through the shell.

    Args:
        cmd_type: Either "gdalbuildvrt" or "gdal_translate".
        in_file: Input path or glob pattern, inserted into the command as is.
        out_file: Output path, inserted into the command as is.
        msg: Optional message printed before the command runs.
        print_cmd: When True, print the command line before running it.

    Returns:
        The status returned by `os.system` (0 on success).

    Raises:
        NotImplementedError: If `cmd_type` is not one of the supported commands.
    """
    if cmd_type == "gdalbuildvrt":
        cmd = f"gdalbuildvrt {out_file} {in_file}"
    elif cmd_type == "gdal_translate":
        cmd = f"gdal_translate -a_srs EPSG:4326 -of GTiff {in_file} {out_file}"
    else:
        raise NotImplementedError(f"{cmd_type} not implemented.")
    if msg:
        print(msg)
    if print_cmd:
        print(cmd)
    # The shell is needed here: callers pass glob patterns (e.g. "*.vrt")
    # that rely on shell expansion.
    status = os.system(cmd)
    if status != 0:
        # Report failures instead of silently continuing with missing output.
        print(f"WARNING: command exited with status {status}: {cmd}")
    return status

def build_vrt(prefix):
    """Build one vrt per prediction batch and a final vrt over all of them.

    Batch folders are searched under `{prefix}_preds/*/*/`. Each per-batch vrt
    is written to `{prefix}_vrts/<i>.vrt`, where `<i>` is the integer parsed
    from the "batch_<i>/" part of the folder path. Existing vrt files are
    left untouched. The combined vrt is written to `{prefix}_final.vrt`.

    Args:
        prefix: Common prefix of the local working folders and output files.

    Raises:
        ValueError: If a folder containing "batch" has no "batch_<i>/" part,
            or `<i>` is not an integer.
    """
    print("Building vrt for each batch")
    for batch_dir in tqdm(glob(f"{prefix}_preds/*/*/")):
        if "batch" not in batch_dir:
            continue

        found = re.search("batch_(.*?)/", batch_dir)
        if found is None:
            raise ValueError(f"Cannot parse i from {batch_dir}")
        batch_index = int(found.group(1))

        vrt_path = Path(f"{prefix}_vrts/{batch_index}.vrt")
        if vrt_path.exists():
            continue
        gdal_cmd(cmd_type="gdalbuildvrt", in_file=f"{batch_dir}*", out_file=str(vrt_path))

    # Combine the per-batch vrts into a single one.
    all_vrts_pattern = f"{prefix}_vrts/*.vrt"
    final_vrt = f"{prefix}_final.vrt"
    gdal_cmd(cmd_type="gdalbuildvrt", in_file=all_vrts_pattern, out_file=final_vrt, msg="Building full vrt")


# 1. Setup
**Prerequisite**: Access to bsos-geog-harvest Google Cloud project.

In [None]:
# Column layout shared by the widget boxes in this notebook.
box_layout = widgets.Layout(flex_flow="column")

# Radio buttons for choosing which kind of model to run.
model_type_widget = widgets.RadioButtons(
    description="Model type:",
    options=["crop-mask", "crop-type"],
    value="crop-mask",
    style={"description_width": "initial"},
    disabled=False,
)

Box(layout=box_layout, children=[model_type_widget])

In [None]:
# Model type picked in the radio buttons above ("crop-mask" or "crop-type").
model_type = model_type_widget.value

# References to Google Cloud resources
gcloud_project_id = "bsos-geog-harvest1"
# Bucket names and the models endpoint are all derived from the model type.
tifs_bucket_name = f"{model_type}-earthengine" 
preds_bucket_name = f"{model_type}-preds" 
preds_merged_bucket_name = f"{model_type}-preds-merged"
models_url = f"https://{model_type}-management-api-grxg7bzh2a-uc.a.run.app/models"

In [None]:
# Authenticate the Colab user with Google Cloud.
print("Logging into Google Cloud")
auth.authenticate_user()
print("Logging into Earth Engine")
# Request both the Cloud Platform and the Earth Engine scopes.
SCOPES = ['https://www.googleapis.com/auth/cloud-platform', 'https://www.googleapis.com/auth/earthengine']
# NOTE: the project id returned here is not used below; gcloud_project_id is
# passed explicitly to Earth Engine and to the storage client instead.
CREDENTIALS, project_id = google.auth.default(default_scopes=SCOPES)
ee.Initialize(CREDENTIALS, project=gcloud_project_id)
# Cloud Storage client used by the bucket helper functions defined above.
client = storage.Client(project=gcloud_project_id)

In [None]:
# Fetch the model list from the management API for the chosen model type.
response = requests.get(models_url)
assert response.status_code == 200, (
    f"Got {response.status_code}. "
    "Either the url is incorrect or gcloud is not authenticated."
)
# Keep only the model names; they populate the model dropdown below.
available_models = [model_entry["modelName"] for model_entry in response.json()["models"]]
available_models

# 2. Inference configuration



In [None]:
# Widgets for choosing the model, the date range and the map identifier.
model_picker = widgets.Dropdown(description="Model to use", options=available_models)
start_date_select = widgets.DatePicker(value=date(2020, 2, 1), description="Start date")
end_date_select = widgets.DatePicker(value=date(2021, 2, 1), description="End date")
map_identifier = widgets.Text(description="Map identifier")

Box(
    layout=box_layout,
    children=[model_picker, start_date_select, end_date_select, map_identifier],
)

In [None]:
##################################################################
# START: Configuration (edit below code)
##################################################################
# Coordinates for map
# Fill in the centre point of the map here; this cell cannot run while the
# two values below are left blank.
lat = 
lon = 

# Small margin for demos
# margin 0.01 -> 1 min
# margin 0.02 -> 3 mins
# margin 0.03 -> 9 mins
# margin 0.05 -> 10 mins

# Half-width of the bounding box, added to and subtracted from lat and lon
# (presumably in decimal degrees, like the coordinates).
margin = 0.02 

# Square bounding box centred on (lat, lon).
bbox = BBox(
    min_lon=lon-margin, 
    max_lon=lon+margin, 
    min_lat=lat-margin, 
    max_lat=lat+margin
)

##################################################################
# END: Configuration
##################################################################

# Read the selections made in the widgets above.
start_date = start_date_select.value
end_date = end_date_select.value
model_name = model_picker.value
version = map_identifier.value
# Prefix that identifies this map in the buckets: "<model name>/<map identifier>".
model_name_version = f"{model_name}/{version}"

# Verify configuration
# Compare by value: `is not ""` tests object identity, which is unreliable
# for strings and raises a SyntaxWarning on recent Python versions.
assert version != "", "Map identifier not set."

print(f"Preparing to do inference for this region: {bbox.url}")

# Warn when the chosen start year does not appear in the model name, since
# that suggests the date range does not match the selected model.
if str(start_date.year) not in model_name:
    print("-" * 100)
    print(f"WARNING: Start year: {start_date.year} not in model name {model_name}, verify start and end date.")
    print("-" * 100)

# 3. Run inference

![inference](https://github.com/nasaharvest/crop-mask/blob/master/assets/inference.png?raw=true)



In [None]:
# Inference can take time so you may need to rerun this cell multiple times
ee_task_amount, tifs_amount, predictions_amount = get_status(model_name_version)

if ee_task_amount > 0:
    # Exports are still in progress: nothing to do until they finish.
    print(
        f"Please wait for all {ee_task_amount} Earth Engine tasks to complete and rerun this cell."
        "\nView progress here: https://code.earthengine.google.com/tasks."
    )
elif tifs_amount == 0:
    # No data exported yet: start the Earth Engine exports for the bounding box.
    print("Starting earth engine exports...")
    EarthEngineExporter(
        check_ee=False,
        check_gcp=False,
        dest_bucket=tifs_bucket_name,
    ).export_for_bbox(
        bbox=bbox,
        bbox_name=model_name_version,
        start_date=start_date,
        end_date=end_date,
        metres_per_polygon=50000,
        file_dimensions=256,
    )
    print("Waiting for some data to become available, wait a couple seconds and rerun this cell.")
elif tifs_amount > predictions_amount:
    # Some tifs have no prediction yet: rename them so they get another attempt.
    missing = find_missing_predictions(model_name_version)
    make_new_predictions(missing)
    print("Wait 5 seconds then rerun this cell.")
else:
    print("Inference complete! Time to merge predictions into a map.")
    

# 4. Merge predictions into a map

<img src="https://github.com/nasaharvest/crop-mask/blob/master/assets/merging-predictions.png?raw=true" alt="merging-predictions" width="500"/>

In [None]:
if ee_task_amount > 0:
    print(f"Please wait for all {ee_task_amount} Earth Engine tasks to complete and rerun the above cell before moving on.")
else:
    # Local working folders for the downloaded predictions, the vrts and the tifs.
    prefix = f"{model_name}_{version}"
    for folder_kind in ("preds", "vrts", "tifs"):
        Path(f"{prefix}_{folder_kind}").mkdir(exist_ok=True)

In [None]:
# Copy every prediction for this model/version into the local preds folder.
# gsutil flags: -m parallel copy, -n skip files that already exist locally,
# -r recurse into folders.
print("Download predictions as nc files (may take several minutes)")
!gsutil -m cp -n -r gs://{preds_bucket_name}/{model_name_version}* {prefix}_preds

In [None]:
# Build one vrt per downloaded prediction batch, then {prefix}_final.vrt over all of them.
build_vrt(prefix)

In [None]:
# Translate vrt for all predictions into a tif file
# (-of GTiff selects GeoTIFF output, -a_srs EPSG:4326 assigns the output coordinate system).
!gdal_translate -a_srs EPSG:4326 -of GTiff {prefix}_final.vrt {prefix}_final.tif

# 5. Upload map to Earth Engine

In [None]:
# Cloud Storage destination for the merged map, named after the
# model/version and the start and end dates.
dest = f"gs://{preds_merged_bucket_name}/{model_name_version}_{start_date}_{end_date}"

In [None]:
# Upload the merged tif to the destination defined above.
!gsutil cp {prefix}_final.tif {dest}

In [None]:
# Ask which Earth Engine user folder the asset should be created in.
earthengine_user = input("Enter your earthengine username: ")
request_id = ee.data.newTaskId()[0]

# Ingestion request: asset name, source file in Cloud Storage and the time
# range covered by the map.
params = dict(
    name=f"projects/earthengine-legacy/assets/users/{earthengine_user}/{prefix}",
    tilesets=[{"sources": [{"uris": [dest]}]}],
    start_time=f"{start_date}T00:00:00Z",
    end_time=f"{end_date}T00:00:00Z",
)
ee.data.startIngestion(request_id=request_id, params=params, allow_overwrite=True)
print("See map upload here: https://code.earthengine.google.com/tasks")

# 6. Visualize on GEE

Click **View asset** on the image just created here: https://code.earthengine.google.com/tasks


Then click **Import** and add the following to the script to view the map (replace `lon` and `lat` with the coordinates used above):
```
var palettes = require('users/gena/packages:palettes');
var palette = palettes.cmocean.Speed[7]

Map.setCenter(lon, lat, 11); 
Map.addLayer(image.gt(0.5), {min: 0, max: 1.0, palette: palette.slice(0,-2)}, 'mask');
Map.addLayer(image, {min: 0, max: 1.0, palette: palette}, 'My crop map');
```