# Model training 🏋
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nasaharvest/openmapflow/blob/main/crop-mask-example/notebooks/train.ipynb)

<img src="https://storage.googleapis.com/harvest-public-assets/openmapflow/train_model.png" width=80%/>

# 1. Setup

If you don't already have one, obtain a GitHub Personal Access Token using the steps [here](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token). Save this token somewhere private.

In [None]:
try:
    from google.colab import auth
    IN_COLAB = True
except:
    IN_COLAB = False
    
if IN_COLAB:
    email = input("Github email: ")
    username = input("Github username: ")

    !git config --global user.email $username
    !git config --global user.name $email

    from getpass import getpass
    token = getpass('Github Personal Access Token:')

    # TODO: Generate below two lines from config
    !git clone https://$username:$token@github.com/nasaharvest/openmapflow.git
    !cd openmapflow && pip install -r requirements.txt -q
    %cd openmapflow/crop-mask-example
else:
    print("Running notebook outside Google Colab. Assuming in local repository.")
    !cd ../.. && pip install -r requirements.txt -q
    !pip install earthengine-api google-auth -q
    %cd ..

In [None]:
# Install the extra packages this notebook imports: cmocean, torch, wandb and tsai (quiet mode).
!pip install cmocean torch wandb tsai -q

In [None]:
from cropharvest.bands import DYNAMIC_BANDS
from cropharvest.eo import EarthEngineExporter
from cropharvest.inference import Inference
from cropharvest.countries import BBox
from google.cloud import storage
from datetime import date
from pathlib import Path
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader
from sklearn.metrics import (
    roc_auc_score,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix, 
    ConfusionMatrixDisplay
)

import cmocean
import numpy as np
import matplotlib.pyplot as plt
import rasterio as rio
import tempfile
import torch
import wandb
import warnings
import yaml
import sys
sys.path.append("..")

from openmapflow.config import RELATIVE_PATHS, FULL_PATHS, PROJECT_ROOT
from openmapflow.pytorch_dataset import PyTorchDataset
from openmapflow.config import SUBSET

from datasets import datasets



warnings.simplefilter("ignore", UserWarning) # TorchScript throws excessive warnings

# 2. Download latest data

In [None]:
# Pull each DVC-tracked path (models, processed data, compressed features) in turn.
for path_key in tqdm(["models", "processed", "compressed_features"]):
    !dvc pull {RELATIVE_PATHS[path_key]} -q

# Extract the compressed features archive into the local data/ directory.
!tar -xzf {RELATIVE_PATHS["compressed_features"]} -C data

In [None]:
# Currently available models: names of the *.pt files in the models directory, sorted
sorted(model_file.stem for model_file in FULL_PATHS["models"].glob("*.pt"))

In [None]:
# Print data/datasets.txt, which lists the datasets available for training and evaluation
!cat data/datasets.txt

# 3. Train model

### 3.1 Import model
Any PyTorch-based model that can take sequence data as input will work here.
This example uses a PyTorch model from [tsai](https://github.com/timeseriesAI/tsai).

In [None]:
from tsai.models.TransformerModel import TransformerModel

### 3.2 Setup training parameters

In [None]:
# ------------ Dataloaders -------------------------------------
batch_size = 64

# Labels of the first configured dataset; the SUBSET column holds each row's
# training / validation / testing assignment.
df = datasets[0].load_labels()
split_dfs = {
    split: df[df[SUBSET] == split]
    for split in ("training", "validation", "testing")
}
data = {
    split: PyTorchDataset(df=split_df, start_month="February", subset=split)
    for split, split_df in split_dfs.items()
}

data_loaders = {}
batch_amount = {}
for k, d in data.items():
    # Only the training split is shuffled.
    data_loaders[k] = DataLoader(d, batch_size=batch_size, shuffle=(k == "training"))
    # len(DataLoader) is the exact number of batches; `1 + len(d) // batch_size`
    # overcounts by one whenever len(d) is a multiple of batch_size.
    batch_amount[k] = len(data_loaders[k])

# The first element of a training sample is a (timesteps, bands) array.
num_timesteps, num_bands = data["training"][0][0].shape

# ------------ Model -----------------------------------------
class Model(torch.nn.Module):
    """Wrap tsai's TransformerModel and squash its single output with a sigmoid."""

    def __init__(self):
        super().__init__()
        # One input channel per band, one output value per sample.
        self.model = TransformerModel(c_in=num_bands, c_out=1)

    def forward(self, x):
        # Swap axes 1 and 2 before calling the wrapped model
        # (presumably (batch, timesteps, bands) -> (batch, bands, timesteps) - TODO confirm).
        swapped = x.transpose(1, 2)
        # Drop the size-1 output axis, then map to the (0, 1) range.
        raw_output = self.model(swapped).squeeze(dim=1)
        return torch.sigmoid(raw_output)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = Model().to(device)

# ------------ Optimizer -------------------------------------
lr = 1e-4
params_to_update = model.parameters()
# SGD with momentum, minimising binary cross-entropy.
optimizer = torch.optim.SGD(params=params_to_update, lr=lr, momentum=0.9)
criterion = torch.nn.BCELoss()

### 3.3 Training loop
Inspired by [PyTorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)

In [None]:
# Train
#%%wandb
model_name = input("Model name: ")
num_epochs = 100
config = {
    "model_name": model_name,
    # Log the class name (not the class object), like the optimizer and loss entries.
    "model": model.__class__.__name__,
    "batch_size": batch_size,
    "num_epochs": num_epochs,
    "lr": lr,
    "optimizer": optimizer.__class__.__name__,
    "loss": criterion.__class__.__name__,
}
run = wandb.init(project=PROJECT_ROOT.name, config=config)

# Lowest validation loss seen so far; None until the first epoch has finished.
lowest_validation_loss = None

# try/finally so the wandb run is closed even when training fails or is interrupted.
try:
    for epoch in tqdm(range(num_epochs), total=num_epochs):

        # ------------------------ Training ----------------------------------------
        total_train_loss = 0.0
        model.train()
        for x in tqdm(data_loaders["training"], total=batch_amount["training"], desc="Train", leave=False):
            inputs, labels = x[0].to(device), x[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # Get model outputs and calculate loss
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by the batch size so the epoch value is a per-sample average.
            total_train_loss += loss.item() * len(inputs)

        # ------------------------ Validation --------------------------------------
        total_val_loss = 0.0
        y_true = []
        y_score = []
        y_pred = []
        model.eval()
        with torch.no_grad():
            for x in tqdm(data_loaders["validation"], total=batch_amount["validation"], desc="Validate", leave=False):
                inputs, labels = x[0].to(device), x[1].to(device)

                # Get model outputs and calculate loss
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                total_val_loss += loss.item() * len(inputs)

                y_true += labels.tolist()
                y_score += outputs.tolist()
                # Scores above 0.5 count as the positive class.
                y_pred += (outputs > 0.5).long().tolist()

        # ------------------------ Metrics + Logging -------------------------------
        train_loss = total_train_loss / len(data["training"])
        val_loss = total_val_loss / len(data["validation"])
        cm = confusion_matrix(y_true, y_pred)
        ConfusionMatrixDisplay(cm, display_labels=["Negative", "Positive"]).plot()
        to_log = {
            "train_loss": train_loss,
            "val_loss": val_loss,
            "epoch": epoch,
            "accuracy": accuracy_score(y_true, y_pred),
            "f1": f1_score(y_true, y_pred),
            "precision": precision_score(y_true, y_pred),
            "recall": recall_score(y_true, y_pred),
            "roc_auc": roc_auc_score(y_true, y_score),
            "confusion_matrix": wandb.Image(plt),
        }
        wandb.log(to_log)
        # Close the confusion-matrix figure so figures do not accumulate across epochs.
        plt.close("all")

        # ------------------------ Model saving --------------------------
        # Keep only the checkpoint with the lowest validation loss, saved as TorchScript.
        if lowest_validation_loss is None or val_loss < lowest_validation_loss:
            lowest_validation_loss = val_loss
            sm = torch.jit.script(model)
            model_path = FULL_PATHS["models"] / f"{model_name}.pt"
            if model_path.exists():
                model_path.unlink()
            sm.save(str(model_path))
finally:
    run.finish()

In [None]:
# Newly available models: names of the *.pt files in the models directory, sorted
sorted(model_file.stem for model_file in FULL_PATHS["models"].glob("*.pt"))

### 3.4 Record test metrics

In [None]:
# Evaluate the checkpoint saved during training (lowest validation loss), not the
# in-memory model left over from the final epoch.
model_pt = torch.jit.load(
    str(FULL_PATHS["models"] / f"{model_name}.pt"), map_location=device
)
model_pt.eval()

# Also keep the in-memory model in eval mode, since the map-making cell reuses it.
model.eval()

y_true = []
y_score = []
y_pred = []
with torch.no_grad():
    for x in tqdm(data_loaders["testing"], total=batch_amount["testing"], desc="Testing", leave=False):
        inputs, labels = x[0].to(device), x[1].to(device)

        # Score the batch with the saved checkpoint.
        outputs = model_pt(inputs)

        y_true += labels.tolist()
        y_score += outputs.tolist()
        # Scores above 0.5 count as the positive class.
        y_pred += (outputs > 0.5).long().tolist()

metrics = {
    "accuracy": accuracy_score(y_true, y_pred),
    "f1": f1_score(y_true, y_pred),
    "precision": precision_score(y_true, y_pred),
    "recall": recall_score(y_true, y_pred),
    "roc_auc": roc_auc_score(y_true, y_score),
}
# Plain Python floats rounded to 4 decimals so the YAML stays readable.
metrics = {k: round(float(v), 4) for k, v in metrics.items()}

all_metrics = {}
if FULL_PATHS["metrics"].exists():
    with FULL_PATHS["metrics"].open() as f:
        # safe_load returns None for an empty file; fall back to an empty mapping.
        all_metrics = yaml.safe_load(f) or {}

# Store the wandb run URL alongside the test metrics for this model.
all_metrics[model_name] = {"params": run.url, "test_metrics": metrics}

with FULL_PATHS["metrics"].open("w") as f:
    yaml.dump(all_metrics, f)

# 4. Pushing the model to the repository

In [None]:
# Commit the models path to DVC, then push the DVC-tracked data to the remote (both quietly).
!dvc commit {RELATIVE_PATHS["models"]} -q -f
!dvc push -q

In [None]:
# Push changes to github
# Creates a branch named after the model, commits everything in the working tree,
# and pushes that branch to origin.
!git checkout -b'{model_name}'
!git add .
!git commit -m 'Trained new: {model_name}'
!git push --set-upstream origin "{model_name}"

Create a Pull Request so the model can be merged into the main branch.

# 5. [OPTIONAL] Create small map with model

### 5.1 Setup

In [None]:
# Region, destination bucket and date range for the demo map.
bbox_name = "Togo_2019_demo"
bbox = BBox(min_lat=6.31, max_lat=6.34, min_lon=1.70, max_lon=1.74)
dest_bucket = "crop-mask-tifs2"
start_date = date(2019, 2, 1)
end_date = date(2020, 2, 1)
# Prefix used both for the cloud tif names and for the local file names.
prefix = f"{bbox_name}_{start_date}_{end_date}"
print(bbox.url)

temp_dir = tempfile.gettempdir()

# `auth` is only imported when the notebook runs in Google Colab; calling it
# unconditionally raises NameError in a local run.
if IN_COLAB:
    auth.authenticate_user()
else:
    print("Not running in Google Colab: skipping Colab authentication, Google Cloud credentials must be configured locally.")

### 5.2 Download earth observation data for entire region (bbox)

In [None]:
client = storage.Client()
cloud_tif_list_iterator = client.list_blobs(dest_bucket, prefix=prefix)
cloud_tif_list = [
    blob.name
    for blob in tqdm(cloud_tif_list_iterator, desc="Loading tifs already on Google Cloud")
]

if len(cloud_tif_list) == 0:
  !earthengine authenticate
  EarthEngineExporter(check_ee=False, check_gcp=False, dest_bucket=dest_bucket).export_for_bbox(    
    bbox=bbox,
    bbox_name=bbox_name,
    start_date=date(2019, 2, 1),
    end_date=date(2020,2,1),
    metres_per_polygon=50000,
    file_dimensions=256
  )
  print("Earth observation data is being exported, progress: https://code.earthengine.google.com/tasks")
else:
  bucket = storage.Client().bucket(dest_bucket)
  local_tif_paths = []
  for gs_path in tqdm(cloud_tif_list, desc="Downloading tifs"):
    local_path = Path(f"{temp_dir}/{gs_path.replace('/', '_')}")
    if not local_path.exists():
      bucket.blob(gs_path).download_to_filename(local_path)
    local_tif_paths.append(local_path)

### 5.3 Make predictions for each pixel in the earth observation data

<img src="https://storage.googleapis.com/harvest-public-assets/openmapflow/basic_inference.png" width="80%"/>

In [None]:
# Run the in-memory model over every downloaded tif, writing one NetCDF prediction file per tif.
inference = Inference(
    model=model,
    normalizing_dict=None,
    device=device,
    batch_size=batch_size,
)
local_pred_paths = []
for local_tif_path in tqdm(local_tif_paths, desc="Making predictions"):
    # Prediction file sits in the temp directory, named after the source tif.
    local_pred_path = Path(temp_dir) / f"pred_{local_tif_path.stem}.nc"
    inference.run(local_path=local_tif_path, start_date=start_date, dest_path=local_pred_path)
    local_pred_paths.append(local_pred_path)

### 5.4 Merge pixel predictions into single map

<img src="https://storage.googleapis.com/harvest-public-assets/openmapflow/merging_predictions.png" width="60%"/>

In [None]:
def merge_tifs(full_prefix):
  """Merge every file matching `{full_prefix}*` into one GeoTIFF.

  Builds a virtual raster `{full_prefix}.vrt` over the matching files with
  gdalbuildvrt, then converts it to `{full_prefix}.tif` with gdal_translate,
  assigning EPSG:4326 as the spatial reference.

  Returns the merged GeoTIFF path as a string.
  """
  # NOTE(review): this glob would also match the .vrt/.tif outputs left by an
  # earlier run with the same prefix - confirm whether they need removing first.
  vrt_in_file = f"{full_prefix}*"
  vrt_out_file = f"{full_prefix}.vrt"
  merged_file = f"{full_prefix}.tif"
  !gdalbuildvrt {vrt_out_file} {vrt_in_file}
  !gdal_translate -a_srs EPSG:4326 -of GTiff {vrt_out_file} {merged_file}
  return merged_file

# Merge the downloaded earth observation tiles and the per-tile predictions separately.
merged_eo_file = merge_tifs(full_prefix=f"{temp_dir}/{prefix}")
merged_pred_file = merge_tifs(full_prefix=f"{temp_dir}/pred_{prefix}")

### 5.5 Visualize earth observation data and predictions map

In [None]:
def normalize(array):
    """Rescale `array` linearly so its minimum maps to 0 and 60% of its maximum maps to 1.

    Values above 60% of the maximum therefore come out greater than 1.
    """
    lower = array.min()
    upper = array.max() * 0.6
    shifted = array - lower
    return shifted / (upper - lower)

month = 2
rgb_indexes = [DYNAMIC_BANDS.index(b) for b in ["B4", "B3", "B2"]]

# Open the merged earth observation GeoTIFF and read the three colour bands for the
# chosen month; the dataset is closed again once the bands are in memory.
# NOTE(review): rasterio band indexes are 1-based while DYNAMIC_BANDS.index is
# 0-based - confirm this offset matches the band layout of the tif.
with rio.open(merged_eo_file) as eo_map:
    colors = [eo_map.read(i + month*len(DYNAMIC_BANDS)) for i in rgb_indexes]

normalized_colors = [normalize(c) for c in colors]
# Stack the three normalised bands along a trailing axis to form an RGB image.
rgb = np.dstack(normalized_colors)
plt.title("Earth Observation data for one month")
plt.axis('off')
plt.imshow(rgb);

In [None]:
# `import rasterio as rio` alone does not guarantee the plot submodule is loaded,
# so import it explicitly before using rio.plot.show.
import rasterio.plot

predictions_map = rio.open(merged_pred_file)
plt.title("Model predicted map")
plt.axis('off')
rio.plot.show(predictions_map, cmap=cmocean.cm.speed);