In [1]:
from argparse import ArgumentParser, Namespace
import h5py
from itertools import permutations
from pathlib import Path
from typing import cast, Optional, List, Tuple, Dict, Type, TypeVar, Sequence
from tqdm import tqdm
import sys

import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
from shapely.geometry import Point
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score, accuracy_score
from cropharvest.datasets import CropHarvest, CropHarvestLabels, Task
from cropharvest.columns import NullableColumns, RequiredColumns
from cropharvest.config import FEATURES_DIR
from cropharvest.engineer import Engineer
from cropharvest.utils import load_normalizing_dict
from cropharvest.bands import BANDS, DYNAMIC_BANDS, STATIC_BANDS, REMOVED_BANDS

# Put the parent directory on the import path so the project-local `src`
# package can be imported when the notebook runs from its own folder.
sys.path.append("..")

from src.models import STR2MODEL, STR2BASE, train_model


# Band names used for the optical features (presumably Sentinel-2 bands plus
# the derived NDVI index -- confirm against the feature engineering code).
S2_BANDS = [
    "B2", "B3", "B4", "B5", "B6", "B7",
    "B8", "B8A", "B9", "B11", "B12",
    "NDVI",
]

## LSTM model

### Recycle old model and train on Nigeria train split

In [2]:
# Generic training options; the model class contributes its own options below.
parser = ArgumentParser()
parser.add_argument("--max_epochs", default=100, type=int)
parser.add_argument("--patience", default=10, type=int)
parser.add_argument("--gpus", default=0, type=int)
parser.add_argument("--wandb", action="store_true", default=False)

# Resolve the model class once and let it register its model-specific options.
land_cover_cls = STR2MODEL["land_cover"]
full_parser = land_cover_cls.add_model_specific_args(parser)

# Parse an explicit empty argument list so every option takes its default
# instead of reading the kernel's own command line.
model_args = full_parser.parse_args(args=[])
model = land_cover_cls(model_args)

Found normalizing dict geowiki_normalizing_dict.h5
Loading normalizing dict geowiki_normalizing_dict.h5
Creating Geowiki train split
Creating Geowiki val split
Number of instances in Geowiki training set: 19808
Number of instances in Nigeria training set: 913
Total number of files used for training: 20721
Number of model parameters: 25473


In [3]:
# Display the parsed default arguments (a Namespace) for inspection.
model_args

Namespace(add_geowiki=True, add_nigeria=True, add_togo=False, alpha=10, batch_size=64, data_folder='/home/gajo/code/togo-crop-mask/notebooks/../data', geowiki_subset='world', gpus=0, hidden_vector_size=64, learning_rate=0.001, lstm_dropout=0.2, max_epochs=100, model_base='lstm', multi_headed=False, num_classification_layers=2, num_lstm_layers=1, patience=10, probability_threshold=0.5, remove_b1_b10=True, wandb=False, weighted_loss_fn=False)

In [4]:
# vars() returns the Namespace's own attribute dict, not a copy. Take a
# shallow copy so that overriding entries below leaves `model_args` (the
# defaults) untouched and this cell stays safe to re-run.
new_model_args_dict = dict(vars(model_args))

In [17]:
# Overrides applied on top of the default model arguments.
new_model_args_dict.update(
    add_geowiki=True,
    add_nigeria=True,
    max_epochs=50,  # Just for dev
)

# Further knobs currently left at their defaults:
# new_model_args_dict['multi_headed'] = False
# new_model_args_dict['num_classification_layers'] = 1
# new_model_args_dict['weighted_loss_fn'] = True
# new_model_args_dict['hidden_vector_size'] = 64

In [18]:
# Seed the numpy and torch RNGs before the model is built so that random
# initialisation and any numpy-/torch-driven sampling that follows start from
# a known state on every Restart & Run All.
# NOTE(review): this alone may not make training bit-for-bit deterministic
# (e.g. other RNG sources or non-deterministic kernels) -- confirm if needed.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Initialize model with new arguments
new_model_args = Namespace(**new_model_args_dict)
model = STR2MODEL["land_cover"](new_model_args)

# Show the hyperparameters the model actually recorded.
model.hparams

Found normalizing dict geowiki_normalizing_dict.h5
Loading normalizing dict geowiki_normalizing_dict.h5
Creating Geowiki train split
Creating Geowiki val split
Number of instances in Geowiki training set: 19808
Number of instances in Nigeria training set: 913
Total number of files used for training: 20721
Number of model parameters: 25473


"add_geowiki":               True
"add_nigeria":               True
"add_togo":                  False
"alpha":                     10
"batch_size":                64
"data_folder":               /home/gajo/code/togo-crop-mask/notebooks/../data
"geowiki_subset":            world
"gpus":                      0
"hidden_vector_size":        64
"learning_rate":             0.001
"lstm_dropout":              0.2
"max_epochs":                50
"model_base":                lstm
"multi_headed":              False
"num_classification_layers": 2
"num_lstm_layers":           1
"patience":                  10
"probability_threshold":     0.5
"remove_b1_b10":             True
"wandb":                     False
"weighted_loss_fn":          False

In [19]:
# Display the Namespace built from the overridden arguments; this is the
# object handed to train_model in the next cell.
new_model_args

Namespace(add_geowiki=True, add_nigeria=True, add_togo=False, alpha=10, batch_size=64, data_folder='/home/gajo/code/togo-crop-mask/notebooks/../data', geowiki_subset='world', gpus=0, hidden_vector_size=64, learning_rate=0.001, lstm_dropout=0.2, max_epochs=50, model_base='lstm', multi_headed=False, num_classification_layers=2, num_lstm_layers=1, patience=10, probability_threshold=0.5, remove_b1_b10=True, wandb=False, weighted_loss_fn=False)

In [20]:
# Train the model with the overridden arguments. train_model comes from
# src.models; its return value is kept because the test run is launched
# from it later in the notebook.
trainer = train_model(model, new_model_args) 

GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name              | Type       | Params
-------------------------------------------------
0 | base              | LSTM       | 21 K  
1 | global_classifier | Sequential | 4 K   


Number of instances in Geowiki validation set: 4953
Number of instances in Nigeria validation set: 454
Total number of files used for validation: 5407


Validation sanity check: 0it [00:00, ?it/s]

confusion matrix: [[52  0]
 [76  0]]
Number of instances in Geowiki training set: 19808
Number of instances in Nigeria training set: 913
Total number of files used for training: 20721
Number of instances in Geowiki validation set: 4953
Number of instances in Nigeria validation set: 454
Total number of files used for validation: 5407


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

confusion matrix: [[1713  664]
 [ 964 2066]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1666  711]
 [ 850 2180]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1704  673]
 [ 838 2192]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1593  784]
 [ 711 2319]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1721  656]
 [ 820 2210]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1761  616]
 [ 830 2200]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1794  583]
 [ 896 2134]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1699  678]
 [ 770 2260]]


In [None]:
# Inspect the normalizing statistics held by the model (the displayed value
# is a dict with 'mean' and 'std' arrays).
model.normalizing_dict

{'mean': array([-1.11145875e+01, -1.83695470e+01,  1.52960034e+03,  1.52106053e+03,
         1.57221665e+03,  1.79048354e+03,  2.54191155e+03,  2.96519233e+03,
         2.80609321e+03,  3.26681681e+03,  7.02900024e+02,  2.86986717e+03,
         1.90742782e+03,  2.99788003e+02,  3.93802458e-03,  3.24957571e+02,
         3.50859630e+00,  3.30125508e-01]),
 'std': array([3.88630623e+00, 4.88522927e+00, 7.66991639e+02, 7.60186317e+02,
        9.44870974e+02, 8.73766432e+02, 8.06759587e+02, 8.77939336e+02,
        8.19407925e+02, 9.05889813e+02, 5.56644420e+02, 1.11925709e+03,
        1.03326184e+03, 2.78615972e+00, 3.94202669e-03, 2.19158225e+02,
        4.01912168e+00, 1.28905131e-01])}

In [None]:
# Run the test loop on the held-out test split; per the output below it
# prints the confusion matrix and the test metrics, and the metrics dict is
# displayed as the cell's value.
trainer.test()

Number of instances in Nigeria testing set: 455


Testing: 0it [00:00, ?it/s]

confusion matrix: [[237  35]
 [ 46 137]]
--------------------------------------------------------------------------------
TEST RESULTS
{'test_accuracy': 0.8219780219780219,
 'test_f1_score': 0.771830985915493,
 'test_loss': 0.39456379413604736,
 'test_precision_score': 0.7965116279069767,
 'test_recall_score': 0.7486338797814208,
 'test_roc_auc_score': 0.9066618450658953}
--------------------------------------------------------------------------------


{'test_loss': 0.39456379413604736,
 'test_roc_auc_score': 0.9066618450658953,
 'test_precision_score': 0.7965116279069767,
 'test_recall_score': 0.7486338797814208,
 'test_f1_score': 0.771830985915493,
 'test_accuracy': 0.8219780219780219}