# Model training 🏋
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nasaharvest/openmapflow/blob/main/openmapflow/notebooks/train.ipynb)

Example notebook for training models.

# 1. Load project

If you don't already have one, create a GitHub Personal Access Token by following the steps [here](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token). Save this token somewhere private.

In [None]:
try:
    from google.colab import auth
    IN_COLAB = True
except:
    IN_COLAB = False
    
if IN_COLAB:
    github_url = input("Github HTTPS URL: ")
    email = input("Github email: ")
    username = input("Github username: ")

    !git config --global user.email $username
    !git config --global user.name $email

    from getpass import getpass
    token = getpass('Github Personal Access Token:')

    !git clone {github_url.replace("https://", f"https://{username}:{token}@")}

    # Temporarily install from Github
    !pip install git+https://ivanzvonkov:$token@github.com/nasaharvest/openmapflow.git -q
    !pip install pyyaml==5.4.1 -q
else:
    print("Running notebook outside Google Colab. Assuming in local repository.")
!pip install cmocean torch wandb tsai earthengine-api google-auth -q

In [None]:
from pathlib import Path
openmapflow_yaml_path = input("Path to openmapflow.yaml: ")
%cd {Path(openmapflow_yaml_path).parent}

In [None]:
# Fetch the DVC-tracked files quietly.
!dvc pull -q
# Extract the archive whose path is printed by
# `openmapflow datapath COMPRESSED_FEATURES` into the data/ directory.
!tar -xzf $(openmapflow datapath COMPRESSED_FEATURES) -C data

In [None]:
# Name passed to train.py / evaluate.py below and reused as the git branch
# name and commit message when the results are pushed.
MODEL_NAME = input("Model name: ")

In [None]:
!python train.py --model_name $MODEL_NAME

In [None]:
!python evaluate.py --model_name $MODEL_NAME

In [None]:
# Record the changed DVC-tracked files, then upload them to the DVC remote
# (the push only runs if the commit succeeds).
!dvc commit -q && dvc push -q

In [None]:
# Push changes to github
# A new branch named after the model is created; every change in the working
# tree is staged and committed with the model name as the message.
!git checkout -b"$MODEL_NAME"
!git add .
!git commit -m "$MODEL_NAME"
!git push --set-upstream origin "$MODEL_NAME"

Create a Pull Request so the model can be merged into the main branch.

# 5. [OPTIONAL] Create small map with model

<img src="https://storage.googleapis.com/harvest-public-assets/openmapflow/basic_inference.png" width="80%"/>

### 5.1 Setup

In [None]:
# Region of interest and time window for the demo map.
bbox_name = "Togo_2019_demo"
bbox = BBox(
    min_lat=6.31,
    max_lat=6.34,
    min_lon=1.70,
    max_lon=1.74,
)
start_date = date(2019, 2, 1)
end_date = date(2020, 2, 1)

# Shared prefix used to find this region's files in later cells.
prefix = f"{bbox_name}_{start_date}_{end_date}"
print(bbox.url)

# Scratch directory for downloaded and generated files.
temp_dir = tempfile.gettempdir()

colab_gee_gcloud_login(GCLOUD_PROJECT_ID, google)

### 5.2 Download earth observation data for entire region (bbox)

In [None]:
# List the tifs for this region/time window that already exist in the bucket.
client = storage.Client()
cloud_tif_list_iterator = client.list_blobs(BucketNames.LABELED_TIFS, prefix=prefix)
cloud_tif_list = [
    blob.name
    for blob in tqdm(cloud_tif_list_iterator, desc="Loading tifs already on Google Cloud")
]

if not cloud_tif_list:
    # Nothing exported yet: start an Earth Engine export. The dates come from
    # the same start_date/end_date variables that build `prefix`, so the
    # export window and the bucket lookup cannot drift apart.
    # Re-run this cell once the export has finished to download the tifs;
    # `local_tif_paths` is only defined by the download branch below.
    exporter = EarthEngineExporter(
        check_ee=False, check_gcp=False, dest_bucket=BucketNames.LABELED_TIFS
    )
    exporter.export_for_bbox(
        bbox=bbox,
        bbox_name=bbox_name,
        start_date=start_date,
        end_date=end_date,
        metres_per_polygon=50000,
        file_dimensions=256,
    )
    print("Earth observation data is being exported, progress: https://code.earthengine.google.com/tasks")
else:
    # Reuse the client created above instead of constructing a second one.
    bucket = client.bucket(BucketNames.LABELED_TIFS)
    local_tif_paths = []
    for gs_path in tqdm(cloud_tif_list, desc="Downloading tifs"):
        # Flatten the blob name into a single file name inside temp_dir.
        local_path = Path(f"{temp_dir}/{gs_path.replace('/', '_')}")
        # Skip files that already exist locally.
        if not local_path.exists():
            bucket.blob(gs_path).download_to_filename(local_path)
        local_tif_paths.append(local_path)

### 5.3 Make predictions for each pixel in the earth observation data

In [None]:
# Run the model over every downloaded tif, writing one output file per tif
# into temp_dir and remembering each output path.
inference = Inference(
    model=model,
    normalizing_dict=None,
    device=device,
    batch_size=batch_size,
)
local_pred_paths = []
for local_tif_path in tqdm(local_tif_paths, desc="Making predictions"):
    local_pred_path = Path(f"{temp_dir}/pred_{local_tif_path.stem}.nc")
    inference.run(local_path=local_tif_path, start_date=start_date, dest_path=local_pred_path)
    local_pred_paths.append(local_pred_path)

### 5.4 Merge pixel predictions into single map

<img src="https://storage.googleapis.com/harvest-public-assets/openmapflow/merging_predictions.png" width="60%"/>

In [None]:
def merge_tifs(full_prefix: str) -> str:
  """Merge every file whose path starts with ``full_prefix`` into one GeoTIFF.

  Builds ``<full_prefix>.vrt`` from all files matching ``<full_prefix>*`` with
  ``gdalbuildvrt``, then converts it to ``<full_prefix>.tif`` with
  ``gdal_translate``, assigning EPSG:4326 as the spatial reference.

  Args:
    full_prefix: Path prefix shared by the input files.

  Returns:
    The path of the merged ``.tif`` file.
  """
  # NOTE(review): this glob also matches the .vrt/.tif outputs of a previous
  # run with the same prefix - confirm that re-running this cell is safe.
  vrt_in_file = f"{full_prefix}*"
  vrt_out_file = f"{full_prefix}.vrt"
  merged_file = f"{full_prefix}.tif"
  !gdalbuildvrt {vrt_out_file} {vrt_in_file}
  !gdal_translate -a_srs EPSG:4326 -of GTiff {vrt_out_file} {merged_file}
  return merged_file

# Merge the downloaded earth observation files and the prediction files.
merged_eo_file = merge_tifs(full_prefix=f"{temp_dir}/{prefix}")
merged_pred_file = merge_tifs(full_prefix=f"{temp_dir}/pred_{prefix}")

### 5.5 Visualize earth observation data and predictions map

In [None]:
def normalize(array):
    """Scale ``array`` into the [0, 1] range for display.

    The upper bound is 60% of the array maximum rather than the maximum
    itself, which brightens the image; values above that bound are clipped
    to 1.

    Args:
        array: Numeric numpy array.

    Returns:
        Float array of the same shape with values in [0, 1]. If the scaled
        range is empty or negative (e.g. a constant array), an all-zero
        array is returned instead of NaN or inverted values.
    """
    import numpy as np  # local import keeps this helper self-contained

    array_min = array.min()
    array_max = array.max() * 0.6
    value_range = array_max - array_min
    # Guard against division by zero and against a negative range, which
    # would flip the sign of every output value.
    if value_range <= 0:
        return np.zeros(array.shape, dtype=float)
    return np.clip((array - array_min) / value_range, 0.0, 1.0)

# Month (0-based position in the time series) to preview.
month = 2
# 0-based positions of the red, green and blue bands within one month.
rgb_indexes = [DYNAMIC_BANDS.index(b) for b in ["B4", "B3", "B2"]]
bands_per_month = len(DYNAMIC_BANDS)

# rasterio band indexes are 1-based, so the 0-based offset needs +1;
# without it every channel would read the band just before the intended one.
# NOTE(review): assumes bands are stored month by month in DYNAMIC_BANDS
# order - confirm against the exporter.
# The context manager closes the dataset once the bands are read.
with rio.open(merged_eo_file) as eo_data:
    colors = [eo_data.read(i + month * bands_per_month + 1) for i in rgb_indexes]

normalized_colors = [normalize(c) for c in colors]
rgb = np.dstack(normalized_colors)
plt.figure(figsize=(10,10))
plt.title("Earth Observation data for one month")
plt.axis('off')
plt.imshow(rgb);

In [None]:
# Choose a colormap from the project name.
if "maize" in PROJECT:
    cmap = cmocean.cm.solar
elif "crop" in PROJECT:
    cmap = cmocean.cm.speed
else:
    cmap = cmocean.cm.thermal

# Read the first band of the merged predictions; the context manager closes
# the dataset as soon as the data is in memory.
with rio.open(merged_pred_file) as predictions_map:
    predictions = predictions_map.read(1)

plt.figure(figsize=(10,10))
plt.imshow(predictions, cmap=cmap)
plt.title(f"Map Preview: {PROJECT}")
plt.colorbar(fraction=0.03, pad=0.04)
plt.axis("off");