In [2]:
import sys
from argparse import ArgumentParser, Namespace

sys.path.append("..")

from src.models import STR2MODEL, train_model

## 1. Initialize model 

In [3]:
parser = ArgumentParser()
parser.add_argument("--max_epochs", type=int, default=1000)
parser.add_argument("--patience", type=int, default=10)

_StoreAction(option_strings=['--patience'], dest='patience', nargs=None, const=None, default=10, type=<class 'int'>, choices=None, help=None, metavar=None)

**Slight modification below:** adding args=[] makes it work for Jupyter notebook https://stackoverflow.com/questions/30656777/how-to-call-module-written-with-argparse-in-ipython-notebook

In [4]:
model_args = STR2MODEL["land_cover"].add_model_specific_args(parser).parse_args(args=[])

In [5]:
model = STR2MODEL["land_cover"](model_args)
model.hparams

Number of geowiki instances in training set: 27947


Namespace(add_geowiki=True, add_togo=True, alpha=10, batch_size=64, data_folder='/home/gajo/code/togo-crop-mask/notebooks/../data', hidden_vector_size=64, learning_rate=0.001, lstm_dropout=0.2, max_epochs=1000, model_base='lstm', multi_headed=True, num_classification_layers=2, num_lstm_layers=1, patience=10, probability_threshold=0.5, remove_b1_b10=True)

### New models args

In [6]:
new_model_args_dict = vars(model_args)

In [7]:
new_model_args_dict

{'max_epochs': 1000,
 'patience': 10,
 'data_folder': '/home/gajo/code/togo-crop-mask/notebooks/../data',
 'model_base': 'lstm',
 'hidden_vector_size': 64,
 'learning_rate': 0.001,
 'batch_size': 64,
 'probability_threshold': 0.5,
 'num_classification_layers': 2,
 'alpha': 10,
 'add_togo': True,
 'add_geowiki': True,
 'remove_b1_b10': True,
 'multi_headed': True,
 'num_lstm_layers': 1,
 'lstm_dropout': 0.2}

In [36]:
# SET MODIFICATIONS TO DEFAULT MODEL ARGUMENTS:

new_model_args_dict['add_togo'] = False
new_model_args_dict['multi_headed'] = False
new_model_args_dict['num_classification_layers'] = 1
new_model_args_dict['max_epochs'] = 1 # Just for dev
new_model_args_dict['gpus'] = 0
new_model_args_dict['remove_b1_b10'] = False

In [37]:
# Initialize model with new arguments
new_model_args = Namespace(**new_model_args_dict)
model = STR2MODEL["land_cover"](new_model_args)
model.hparams

Number of geowiki instances in training set: 27947


Namespace(add_geowiki=True, add_togo=False, alpha=10, batch_size=64, data_folder='/home/gajo/code/togo-crop-mask/notebooks/../data', gpus=0, hidden_vector_size=64, learning_rate=0.001, lstm_dropout=0.2, max_epochs=1, model_base='lstm', multi_headed=False, num_classification_layers=1, num_lstm_layers=1, patience=10, probability_threshold=0.5, remove_b1_b10=False)

## 2. Check data class distribution

In [71]:
from collections import Counter

### Training set

In [90]:
loader = model.train_dataloader()
train_counter = Counter()
for sample in loader:
    x, y, weight = sample
    train_counter.update(y.numpy())

Number of geowiki instances in training set: 27947


In [92]:
print('Non-croplad / Cropland in training set (%):')
a = train_counter[0.0] / (train_counter[0.0] + train_counter[1.0])
a, 1-a

Non-croplad / Cropland in training set (%):


(0.7786166672630336, 0.22138333273696642)

### Validation set

In [93]:
loader = model.val_dataloader()
val_counter = Counter()
for sample in loader:
    x, y, weight = sample
    val_counter.update(y.numpy())

Number of geowiki instances in validation set: 7301


In [94]:
val_counter
print('Non-croplad / Cropland in validation set (%):')
a = val_counter[0.0] / (val_counter[0.0] + val_counter[1.0])
a, 1-a

Non-croplad / Cropland in validation set (%):


(0.7883851527188056, 0.2116148472811944)

### Test set

In [95]:
loader = model.test_dataloader()
test_counter = Counter()
for sample in loader:
    x, y, weight = sample
    test_counter.update(y.numpy())

Number of geowiki instances in testing set: 351


In [96]:
test_counter
print('Non-croplad / Cropland in test set (%):')
a = test_counter[0.0] / (test_counter[0.0] + test_counter[1.0])
a, 1-a

Non-croplad / Cropland in test set (%):


(0.7777777777777778, 0.2222222222222222)

## 3. Train model 

In [13]:
train_model(model, new_model_args)

Number of geowiki instances in validation set: 7301


Validation sanity check:   0%|          | 0/5 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Number of geowiki instances in training set: 27947
Number of geowiki instances in validation set: 7301


Validating:   0%|          | 0/115 [00:00<?, ?it/s]

(LandCoverMapper(
   (base): LSTM(
     (lstm): UnrolledLSTM(
       (rnn): UnrolledLSTMCell(
         (forget_gate): Sequential(
           (0): Linear(in_features=76, out_features=64, bias=True)
           (1): Sigmoid()
         )
         (update_gate): Sequential(
           (0): Linear(in_features=76, out_features=64, bias=True)
           (1): Sigmoid()
         )
         (update_candidates): Sequential(
           (0): Linear(in_features=76, out_features=64, bias=True)
           (1): Tanh()
         )
         (output_gate): Sequential(
           (0): Linear(in_features=76, out_features=64, bias=True)
           (1): Sigmoid()
         )
         (cell_state_activation): Tanh()
       )
       (dropout): Dropout(p=0.2, inplace=False)
     )
   )
   (global_classifier): Sequential(
     (0): Linear(in_features=64, out_features=1, bias=True)
   )
 ),
 <pytorch_lightning.trainer.trainer.Trainer at 0x7f0aef0b5550>)