In [10]:
import sys
from argparse import ArgumentParser, Namespace
from pathlib import Path

sys.path.append('..')

from src.models import STR2MODEL

In [11]:
# Data directory, expressed relative to the notebook's working directory (one level up).
DATA_DIR = Path("..", "data")

In [12]:
def get_model(add_geowiki: bool, add_nigeria: bool, geowiki_subset: str):
    """Instantiate the "land_cover" model with default arguments plus three overrides.

    A parser is populated with a few trainer-level options, extended by the
    model class's own ``add_model_specific_args``, and parsed against an empty
    argument list so that only defaults are produced. The three parameters of
    this function are then written over those defaults.

    Args:
        add_geowiki: Value stored under the ``add_geowiki`` model argument.
        add_nigeria: Value stored under the ``add_nigeria`` model argument.
        geowiki_subset: Value stored under the ``geowiki_subset`` model
            argument (e.g. 'nigeria', 'neighbours1').

    Returns:
        The object produced by calling ``STR2MODEL["land_cover"]`` with the
        resulting argument namespace.
    """
    model_cls = STR2MODEL["land_cover"]

    # Trainer-level options registered before the model adds its own.
    trainer_options = (
        ("--max_epochs", {"type": int, "default": 100}),
        ("--patience", {"type": int, "default": 10}),
        ("--gpus", {"type": int, "default": 0}),
        ("--wandb", {"default": False, "action": "store_true"}),
    )
    base_parser = ArgumentParser()
    for flag, options in trainer_options:
        base_parser.add_argument(flag, **options)

    # args=[] keeps argparse from reading the notebook kernel's own sys.argv.
    full_parser = model_cls.add_model_specific_args(base_parser)
    hparams = vars(full_parser.parse_args(args=[]))

    # SET MODIFICATIONS TO DEFAULT MODEL ARGUMENTS:
    hparams.update(
        add_geowiki=add_geowiki,
        add_nigeria=add_nigeria,
        geowiki_subset=geowiki_subset,  # 'nigeria', 'neighbours1'
    )

    return model_cls(Namespace(**hparams))

In [13]:
# Geowiki-only model on the 'world' subset; display the is_crop value counts of its labels.
add_geowiki, add_nigeria, geowiki_subset = True, False, 'world'
landcovermapper = get_model(
    add_geowiki=add_geowiki,
    add_nigeria=add_nigeria,
    geowiki_subset=geowiki_subset,
)
df = landcovermapper.geowiki_dataset.labels
df.is_crop.value_counts()

In [29]:
# Geowiki-only model on the 'neighbours1' subset; display the is_crop value counts of its labels.
add_geowiki, add_nigeria, geowiki_subset = True, False, 'neighbours1'
landcovermapper = get_model(
    add_geowiki=add_geowiki,
    add_nigeria=add_nigeria,
    geowiki_subset=geowiki_subset,
)
df = landcovermapper.geowiki_dataset.labels
df.is_crop.value_counts()

Found normalizing dict geowiki_normalizing_dict_Ghana_Togo_Nigeria_Cameroon_Benin.h5
Loading normalizing dict geowiki_normalizing_dict_Ghana_Togo_Nigeria_Cameroon_Benin.h5
Creating Geowiki train split
Creating Geowiki val split
Number of instances in Geowiki training set: 632
Total number of files used for training: 632
Number of model parameters: 25473


1    460
0    330
Name: is_crop, dtype: int64

In [30]:
# Geowiki-only model on the 'nigeria' subset; display the is_crop value counts of its labels.
add_geowiki, add_nigeria, geowiki_subset = True, False, 'nigeria'
landcovermapper = get_model(
    add_geowiki=add_geowiki,
    add_nigeria=add_nigeria,
    geowiki_subset=geowiki_subset,
)
df = landcovermapper.geowiki_dataset.labels
df.is_crop.value_counts()

Found normalizing dict geowiki_normalizing_dict_Nigeria.h5
Loading normalizing dict geowiki_normalizing_dict_Nigeria.h5
Creating Geowiki train split
Creating Geowiki val split
Number of instances in Geowiki training set: 361
Total number of files used for training: 361
Number of model parameters: 25473


1    312
0    140
Name: is_crop, dtype: int64

In [31]:
# Nigeria-only configuration: the geowiki data is switched off and the Nigeria data is switched on.
add_geowiki = False
add_nigeria = True
# Set explicitly so this cell does not depend on whichever earlier cell last assigned
# geowiki_subset. 'nigeria' is the value assigned by the preceding geowiki cell, so the
# call below receives the same arguments as before.
# NOTE(review): presumably ignored when add_geowiki is False — confirm against the model code.
geowiki_subset = 'nigeria'
landcovermapper = get_model(add_geowiki, add_nigeria, geowiki_subset)

Number of instances in Nigeria training set: 913
Total number of files used for training: 913
Number of model parameters: 25473


In [32]:
# Build the training, validation and testing datasets from the current landcovermapper.
# The training split is requested first and without a normalizing_dict; the validation and
# testing splits are each passed landcovermapper.normalizing_dict.
# NOTE(review): presumably this makes val/test reuse the training normalisation — confirm
# in get_dataset. Keep the statement order: the attribute is read after the training call.
train_dataset = landcovermapper.get_dataset(subset="training")
val_dataset = landcovermapper.get_dataset(subset="validation", normalizing_dict=landcovermapper.normalizing_dict)
test_dataset = landcovermapper.get_dataset(subset="testing", normalizing_dict=landcovermapper.normalizing_dict)

Number of instances in Nigeria training set: 913
Total number of files used for training: 913
Number of instances in Nigeria validation set: 454
Total number of files used for validation: 454
Number of instances in Nigeria testing set: 455


In [33]:
# Combined number of instances across the training, validation and testing datasets.
sum(len(split) for split in (train_dataset, val_dataset, test_dataset))

1822

In [37]:
# Summary statistics of is_crop for the 'nigeria' entry of the training dataset.
# NOTE(review): the values appear to be 0/1 (see min/max in the output), in which case the
# mean reads as the fraction of crop-labelled instances — confirm.
train_dataset.datasets['nigeria'].is_crop.describe()

count    913.000000
mean       0.417306
std        0.493384
min        0.000000
25%        0.000000
50%        0.000000
75%        1.000000
max        1.000000
Name: is_crop, dtype: float64

In [40]:
# Summary statistics of is_crop for the 'nigeria' entry of the validation dataset.
# NOTE(review): the values appear to be 0/1 (see min/max in the output), in which case the
# mean reads as the fraction of crop-labelled instances — confirm.
val_dataset.datasets['nigeria'].is_crop.describe()

count    454.000000
mean       0.398678
std        0.490166
min        0.000000
25%        0.000000
50%        0.000000
75%        1.000000
max        1.000000
Name: is_crop, dtype: float64

In [39]:
# Summary statistics of is_crop for the 'nigeria' entry of the testing dataset.
# NOTE(review): the values appear to be 0/1 (see min/max in the output), in which case the
# mean reads as the fraction of crop-labelled instances — confirm.
test_dataset.datasets['nigeria'].is_crop.describe()

count    455.000000
mean       0.402198
std        0.490881
min        0.000000
25%        0.000000
50%        0.000000
75%        1.000000
max        1.000000
Name: is_crop, dtype: float64

In [44]:
# Pooled is_crop mean over the three 'nigeria' splits: summed is_crop divided by summed row count.
(
    sum(split.datasets['nigeria'].is_crop.sum() for split in (train_dataset, val_dataset, test_dataset))
    / sum(split.datasets['nigeria'].is_crop.shape[0] for split in (train_dataset, val_dataset, test_dataset))
)

0.4088913282107574