In [1]:
from argparse import ArgumentParser, Namespace
import h5py
from itertools import permutations
from pathlib import Path
from typing import cast, Optional, List, Tuple, Dict, Type, TypeVar, Sequence
from tqdm import tqdm
import sys

import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
from shapely.geometry import Point
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score, accuracy_score
from cropharvest.datasets import CropHarvest, CropHarvestLabels, Task
from cropharvest.columns import NullableColumns, RequiredColumns
from cropharvest.config import FEATURES_DIR
from cropharvest.engineer import Engineer
from cropharvest.utils import load_normalizing_dict
from cropharvest.bands import BANDS, DYNAMIC_BANDS, STATIC_BANDS, REMOVED_BANDS

sys.path.append("..")

from src.models import STR2MODEL, STR2BASE, train_model


# Optical band names as used in this notebook (B1 and B10 are not included), plus NDVI.
S2_BANDS = "B2 B3 B4 B5 B6 B7 B8 B8A B9 B11 B12 NDVI".split()

## LSTM model

### Re-create the model with modified arguments and train on the Geowiki + Nigeria training splits

In [2]:
# Build the default argument namespace for the "land_cover" model. The generic
# training options are registered here, and the model class then adds its own
# options through add_model_specific_args. parse_args(args=[]) parses an empty
# argument list instead of the notebook kernel's sys.argv, so only defaults apply.
land_cover_cls = STR2MODEL["land_cover"]

parser = ArgumentParser()
parser.add_argument("--max_epochs", type=int, default=100)
parser.add_argument("--patience", type=int, default=10)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--wandb", action="store_true", default=False)

model_args = land_cover_cls.add_model_specific_args(parser).parse_args(args=[])
model = land_cover_cls(model_args)

Found normalizing dict geowiki_normalizing_dict.h5
Loading normalizing dict geowiki_normalizing_dict.h5
Creating Geowiki train split
Creating Geowiki val split
Number of instances in Geowiki training set: 19808
Number of instances in Nigeria training set: 913
Total number of files used for training: 20721
Number of model parameters: 25473


In [3]:
# Display the argument namespace parsed above (defaults for the land_cover model).
model_args

Namespace(add_geowiki=True, add_nigeria=True, add_togo=False, alpha=10, batch_size=64, data_folder='/home/gajo/code/togo-crop-mask/notebooks/../data', geowiki_subset='world', gpus=0, hidden_vector_size=64, learning_rate=0.001, rnn_dropout=0.2, max_epochs=100, model_base='lstm', multi_headed=False, num_classification_layers=2, num_rnn_layers=1, patience=10, probability_threshold=0.5, remove_b1_b10=True, wandb=False, weighted_loss_fn=False)

In [4]:
# Copy the parsed defaults into a NEW dict that the next cell can edit.
# vars() returns the Namespace's own __dict__ (not a copy). Editing that
# directly would silently mutate model_args too, which makes the displayed
# defaults stale and breaks re-runs of this notebook.
new_model_args_dict = dict(vars(model_args))

In [17]:
# SET MODIFICATIONS TO DEFAULT MODEL ARGUMENTS:
new_model_args_dict['add_geowiki'] = True
new_model_args_dict['add_nigeria'] = True
#new_model_args_dict['multi_headed'] = False
# new_model_args_dict['num_classification_layers'] = 1
new_model_args_dict['max_epochs'] = 50 # Just for dev
# new_model_args_dict['weighted_loss_fn'] = True
# new_model_args_dict['hidden_vector_size'] = 64

In [18]:
# Initialize model with new arguments
new_model_args = Namespace(**new_model_args_dict)
model = STR2MODEL["land_cover"](new_model_args)
model.hparams

Found normalizing dict geowiki_normalizing_dict.h5
Loading normalizing dict geowiki_normalizing_dict.h5
Creating Geowiki train split
Creating Geowiki val split
Number of instances in Geowiki training set: 19808
Number of instances in Nigeria training set: 913
Total number of files used for training: 20721
Number of model parameters: 25473


"add_geowiki":               True
"add_nigeria":               True
"add_togo":                  False
"alpha":                     10
"batch_size":                64
"data_folder":               /home/gajo/code/togo-crop-mask/notebooks/../data
"geowiki_subset":            world
"gpus":                      0
"hidden_vector_size":        64
"learning_rate":             0.001
"rnn_dropout":              0.2
"max_epochs":                50
"model_base":                lstm
"multi_headed":              False
"num_classification_layers": 2
"num_rnn_layers":           1
"patience":                  10
"probability_threshold":     0.5
"remove_b1_b10":             True
"wandb":                     False
"weighted_loss_fn":          False

In [19]:
# Display the modified argument namespace that is passed to training below.
new_model_args

Namespace(add_geowiki=True, add_nigeria=True, add_togo=False, alpha=10, batch_size=64, data_folder='/home/gajo/code/togo-crop-mask/notebooks/../data', geowiki_subset='world', gpus=0, hidden_vector_size=64, learning_rate=0.001, rnn_dropout=0.2, max_epochs=50, model_base='lstm', multi_headed=False, num_classification_layers=2, num_rnn_layers=1, patience=10, probability_threshold=0.5, remove_b1_b10=True, wandb=False, weighted_loss_fn=False)

In [20]:
# Train the model with the modified arguments. The returned object is kept as
# `trainer` and used for evaluation further down (trainer.test()).
trainer = train_model(model, new_model_args) 

GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name              | Type       | Params
-------------------------------------------------
0 | base              | LSTM       | 21 K  
1 | global_classifier | Sequential | 4 K   


Number of instances in Geowiki validation set: 4953
Number of instances in Nigeria validation set: 454
Total number of files used for validation: 5407


Validation sanity check: 0it [00:00, ?it/s]

confusion matrix: [[52  0]
 [76  0]]
Number of instances in Geowiki training set: 19808
Number of instances in Nigeria training set: 913
Total number of files used for training: 20721
Number of instances in Geowiki validation set: 4953
Number of instances in Nigeria validation set: 454
Total number of files used for validation: 5407


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

confusion matrix: [[1713  664]
 [ 964 2066]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1666  711]
 [ 850 2180]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1704  673]
 [ 838 2192]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1593  784]
 [ 711 2319]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1721  656]
 [ 820 2210]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1761  616]
 [ 830 2200]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1794  583]
 [ 896 2134]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1699  678]
 [ 770 2260]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1686  691]
 [ 702 2328]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1679  698]
 [ 703 2327]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1878  499]
 [ 927 2103]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1864  513]
 [ 925 2105]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1756  621]
 [ 754 2276]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1790  587]
 [ 803 2227]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1720  657]
 [ 725 2305]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1798  579]
 [ 836 2194]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1734  643]
 [ 727 2303]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1872  505]
 [ 896 2134]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1838  539]
 [ 839 2191]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1773  604]
 [ 765 2265]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1790  587]
 [ 789 2241]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1782  595]
 [ 754 2276]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1890  487]
 [ 863 2167]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1776  601]
 [ 761 2269]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1916  461]
 [ 949 2081]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1789  588]
 [ 757 2273]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1883  494]
 [ 898 2132]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1910  467]
 [ 923 2107]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1862  515]
 [ 860 2170]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1790  587]
 [ 755 2275]]


Validating: 0it [00:00, ?it/s]

confusion matrix: [[1739  638]
 [ 721 2309]]


Validating: 0it [00:00, ?it/s]

Epoch 00032: early stopping triggered.


confusion matrix: [[1832  545]
 [ 818 2212]]


In [21]:
# Inspect the normalization statistics ('mean' and 'std' arrays) held by the model.
model.normalizing_dict

{'mean': array([-1.18093416e+01, -1.86999297e+01,  2.00553716e+03,  1.88579588e+03,
         1.95091810e+03,  2.17022502e+03,  2.81705663e+03,  3.14242626e+03,
         3.01727620e+03,  3.36851746e+03,  1.16118482e+03,  2.44198249e+03,
         1.68557889e+03,  2.89972494e+02,  2.79428964e-03,  6.09329951e+02,
         6.14648560e+00,  2.94451736e-01]),
 'std': array([4.51362275e+00, 5.33090248e+00, 1.79541564e+03, 1.67771422e+03,
        1.87926254e+03, 1.82168139e+03, 1.68772691e+03, 1.69442362e+03,
        1.61445807e+03, 1.66997012e+03, 1.08937314e+03, 1.19782549e+03,
        1.02885346e+03, 1.17090673e+01, 3.41730081e-03, 7.39463455e+02,
        7.78857065e+00, 1.74758663e-01])}

In [22]:
# Evaluate on the test split. Per the logged output, this uses the Nigeria testing set.
# NOTE(review): which weights are evaluated (last vs. best checkpoint) depends on
# the Lightning version's trainer.test() defaults; confirm.
trainer.test()

Number of instances in Nigeria testing set: 455


Testing: 0it [00:00, ?it/s]

confusion matrix: [[211  61]
 [ 38 145]]
--------------------------------------------------------------------------------
TEST RESULTS
{'test_accuracy': 0.7824175824175824,
 'test_f1_score': 0.7455012853470437,
 'test_loss': 0.5160723924636841,
 'test_precision_score': 0.7038834951456311,
 'test_recall_score': 0.7923497267759563,
 'test_roc_auc_score': 0.8682397139183543}
--------------------------------------------------------------------------------


{'test_loss': 0.5160723924636841,
 'test_roc_auc_score': 0.8682397139183543,
 'test_precision_score': 0.7038834951456311,
 'test_recall_score': 0.7923497267759563,
 'test_f1_score': 0.7455012853470437,
 'test_accuracy': 0.7824175824175824}