In [21]:
import sys
from pathlib import Path
from argparse import ArgumentParser, Namespace
from tqdm import tqdm
import math
import pickle
from typing import Tuple, List, Dict, Union, Optional, TypeVar, Type
from itertools import permutations

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader, random_split

sys.path.append('..')

from src.engineer.geowiki import GeoWikiDataInstance, GeoWikiEngineer
from src.exporters.sentinel.cloudfree import BANDS
from src.models import STR2MODEL, train_model

%reload_ext autoreload
%autoreload 2

### Create custom Dataset class

In [22]:
GeowikiDatasetType = TypeVar('GeowikiDatasetType', bound='GeowikiDataset')  # for typing

class GeowikiDataset(Dataset):
    """PyTorch ``Dataset`` over the engineered GeoWiki pickle files.

    Items are ``(x, label, is_local)`` tuples:

    * ``x`` -- the instance's ``labelled_array`` normalized with
      ``normalizing_dict`` and, if ``remove_b1_b10``, without bands B1/B10;
    * ``label`` -- the crop probability, binarized with
      ``crop_probability_threshold`` unless that threshold is ``None``;
    * ``is_local`` -- 1 if the point's identifier belongs to one of
      ``countries_to_weight``, else 0.

    The points are defined by a labels dataframe (read from ``csv_file`` or
    passed in directly). Its ``filename`` column must match pickle files in
    ``<data_folder>/features/<dataset_name>/all``.
    """

    dataset_name: str = GeoWikiEngineer.dataset
    csv_file: str = 'geowiki_labels_country_crs4326.csv'

    def __init__(self, data_folder: Path = Path('../data'),
                 countries_subset: Optional[List[str]] = None,
                 countries_to_weight: Optional[List[str]] = None,
                 crop_probability_threshold: Optional[float] = 0.5,
                 remove_b1_b10: bool = True,
                 normalizing_dict: Optional[Dict] = None,
                 labels: Optional[pd.DataFrame] = None,  # if this is passed csv_file will be ignored
                 ) -> None:
        """
        Args:
            data_folder: root data folder. Features are read from
                ``data_folder / "features" / dataset_name``.
            countries_subset: when ``labels`` is not given, keep only the
                points whose ``country`` matches one of these names
                (case-insensitive).
            countries_to_weight: countries whose points get ``is_local == 1``.
            crop_probability_threshold: threshold used to binarize the crop
                probability. ``None`` keeps the raw probability as the label.
            remove_b1_b10: drop bands B1 and B10 from ``x``.
            normalizing_dict: ``{"mean": ..., "std": ...}`` used for
                normalization. If ``None``, it is looked up on disk or computed.
            labels: pre-built labels dataframe. When passed, ``csv_file`` is
                not read and no ``countries_subset`` filtering is applied.
        """
        # Attributes
        self.data_folder = data_folder
        self.dataset_dir = data_folder / "features" / self.dataset_name
        self.countries_subset = countries_subset
        self.countries_to_weight = countries_to_weight
        self.bands_to_remove = ["B1", "B10"]
        self.remove_b1_b10 = remove_b1_b10
        self.crop_probability_threshold = crop_probability_threshold

        # Labels: either read (and optionally filtered by country) or given.
        if labels is None:
            self.labels = pd.read_csv(self.dataset_dir / self.csv_file)
            self.labels.loc[self.labels['country'].isnull(), 'country'] = 'unknown'
            if self.countries_subset:
                wanted_countries = [country.lower() for country in self.countries_subset]
                in_subset = self.labels['country'].str.lower().isin(wanted_countries)
                self.labels = self.labels[in_subset].reset_index(drop=True)
        else:
            self.labels = labels
        self.pickle_files = self.get_pickle_files_paths(self.dataset_dir / 'all')
        self.file_identifiers_countries_to_weight = self.get_file_ids_for_countries(self.countries_to_weight)
        print('length labels:', len(self.labels))
        print('length pickle files:', len(self.pickle_files))
        print('length local ids:', len(self.file_identifiers_countries_to_weight))

        # Normalizing dictionary
        if normalizing_dict is None:
            self.normalizing_dict = self.get_normalizing_dict()
        else:
            self.normalizing_dict = normalizing_dict
        print(self.normalizing_dict)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the ``(x, label, is_local)`` tensors for item ``index``.

        Raises:
            RuntimeError: if the pickle does not hold a ``GeoWikiDataInstance``.
        """
        target_file = self.pickle_files[index]
        # File names are expected to start with an integer identifier followed by "_".
        identifier = int(target_file.name.split('_')[0])

        with target_file.open("rb") as f:
            target_datainstance = pickle.load(f)

        if not isinstance(target_datainstance, GeoWikiDataInstance):
            raise RuntimeError(f"Unrecognized data instance type {type(target_datainstance)}")

        if self.crop_probability_threshold is None:
            label = target_datainstance.crop_probability
        else:
            label = int(target_datainstance.crop_probability >= self.crop_probability_threshold)

        is_local = int(identifier in self.file_identifiers_countries_to_weight)

        x = self.remove_bands(x=self._normalize(target_datainstance.labelled_array))
        return (
            torch.from_numpy(x).float(),
            torch.tensor(label).float(),
            torch.tensor(is_local).long(),
        )

    @property
    def num_output_classes(self) -> int:
        """Number of model outputs (a single crop-probability output)."""
        return 1

    @property
    def num_input_features(self) -> int:
        """Number of bands in ``x``, read from the first item."""
        # assumes the first value in the tuple is x, shaped [timesteps, bands]
        assert len(self.pickle_files) > 0, "No files to load!"
        output_tuple = self[0]
        return output_tuple[0].shape[1]

    @property
    def num_timesteps(self) -> int:
        """Number of timesteps in ``x``, read from the first item."""
        # assumes the first value in the tuple is x, shaped [timesteps, bands]
        assert len(self.pickle_files) > 0, "No files to load!"
        output_tuple = self[0]
        return output_tuple[0].shape[0]

    def remove_bands(self, x: np.ndarray) -> np.ndarray:
        """Drop ``self.bands_to_remove`` from the last (band) axis of ``x``.

        Expects the bands on the last axis, e.g. ``[timesteps, bands]`` or
        ``[batches, timesteps, bands]``. Returns ``x`` unchanged when
        ``remove_b1_b10`` is False.
        """
        if not self.remove_b1_b10:
            return x
        indices_to_remove = {BANDS.index(band) for band in self.bands_to_remove}
        indices_to_keep = [i for i in range(x.shape[-1]) if i not in indices_to_remove]
        # Ellipsis indexing selects along the last axis whatever the rank.
        return x[..., indices_to_keep]

    def _normalize(self, array: np.ndarray) -> np.ndarray:
        """Standardize ``array`` with ``normalizing_dict`` (no-op if it is None)."""
        if self.normalizing_dict is None:
            return array
        return (array - self.normalizing_dict["mean"]) / self.normalizing_dict["std"]

    @staticmethod
    def load_files_and_normalizing_dict(
        features_dir: Path, subset_name: str = 'training', file_name: str = "normalizing_dict.pkl"
    ) -> Tuple[List[Path], Optional[Dict[str, np.ndarray]]]:
        """Return the pickle files in ``features_dir / subset_name`` and the
        normalizing dict stored at ``features_dir / file_name`` (or ``None``)."""
        pickle_files = list((features_dir / subset_name).glob("*.pkl"))

        # try loading the normalizing dict. By default, if it exists we will use it
        if (features_dir / file_name).exists():
            with (features_dir / file_name).open("rb") as f:
                normalizing_dict = pickle.load(f)
        else:
            normalizing_dict = None

        return pickle_files, normalizing_dict

    def search_normalizing_dict(self, default_file_name: str = "normalizing_dict.pkl") -> Optional[Path]:
        """Look for a saved normalizing dict in ``self.dataset_dir``.

        Without ``countries_subset`` the file is ``default_file_name``. With a
        subset, files are named ``<prefix>_<country>_..._<country>.pkl``. Every
        ordering of the subset's countries is tried, because the dict may have
        been saved with the countries in any order.

        Returns:
            The path of the first matching file, or ``None`` if there is none.

        Raises:
            NotImplementedError: for the ``['africa']`` subset.
        """
        prefix = default_file_name.split('.')[0]

        if not self.countries_subset:
            file_path = self.dataset_dir / default_file_name
            if file_path.exists():
                print(f'Found normalizing dict {file_path.name}')
                return file_path
        elif len(self.countries_subset) == 1 and self.countries_subset[0].lower() == 'africa':
            raise NotImplementedError("Normalizing dict lookup for 'africa' is not implemented yet")  # TODO
        else:
            # The number of orderings grows factorially with the subset size.
            assert len(self.countries_subset) < 10, 'Execution time will be too big!'  # TODO: add warning when passing subset to constructor
            # Generate orderings lazily so the search stops at the first match
            # instead of materializing all n! of them.
            for permutation in permutations(self.countries_subset):
                countries_str = '_'.join(permutation)
                file_name = f"{prefix}_{countries_str}.pkl"
                file_path = self.dataset_dir / file_name
                if file_path.exists():
                    print(f'Found normalizing dict {file_name}')
                    return file_path
        print('Normalizing dict not found.')
        return None

    def get_normalizing_dict(self, save: bool = False) -> Dict:
        """Return the normalizing dict for this dataset.

        A dict found by ``search_normalizing_dict`` is loaded. Otherwise one
        is computed from ``self.pickle_files`` with ``GeoWikiEngineer``. When
        ``save`` is True, it is written to ``self.dataset_dir`` under the
        name ``search_normalizing_dict`` looks for.
        """
        default_file_name = "normalizing_dict.pkl"
        file_path = self.search_normalizing_dict(default_file_name)
        if file_path:
            print('Loading normalizing dict.')
            with file_path.open("rb") as f:
                return pickle.load(f)

        print('Calculating normalizing dict...')
        assert len(self) == len(self.pickle_files), 'Length of self.labels must be the same as of the list of pickle files.'
        # Use this dataset's data folder rather than a hard-coded '../data'.
        geowiki_engineer = GeoWikiEngineer(self.data_folder)

        for pickle_path in tqdm(self.pickle_files):
            with pickle_path.open("rb") as f:
                target_datainstance = pickle.load(f)
            geowiki_engineer.update_normalizing_values(target_datainstance.labelled_array)

        normalizing_dict = geowiki_engineer.calculate_normalizing_dict()

        if save:
            if self.countries_subset:
                prefix = default_file_name.split('.')[0]
                countries_str = '_'.join(self.countries_subset)
                file_name = f"{prefix}_{countries_str}.pkl"
            else:
                file_name = default_file_name
            # Save where search_normalizing_dict looks for it (self.dataset_dir).
            save_path = self.dataset_dir / file_name
            print('Saving normalizing dict', save_path.name)
            with save_path.open("wb") as f:
                pickle.dump(normalizing_dict, f)

        return normalizing_dict

    def get_pickle_files_paths(self, folder_path: Path) -> List[Path]:
        """Return the pickle files in ``folder_path`` listed in ``self.labels.filename``.

        The result is sorted so item indices are stable across runs (``glob``
        order depends on the filesystem).

        Raises:
            AssertionError: if some labelled files are missing from ``folder_path``.
        """
        # A set gives O(1) membership tests. The folder holds tens of thousands
        # of files, of which only the labelled ones are kept.
        wanted_names = set(self.labels.filename.tolist())
        print('Checking for data files')
        pickle_files = sorted(
            path for path in tqdm(folder_path.glob('*.pkl')) if path.name in wanted_names
        )
        self._check_label_files(pickle_files)
        return pickle_files

    def _check_label_files(self, pickle_files: List[Path]) -> None:
        """Assert every filename in ``self.labels`` was found among ``pickle_files``."""
        same_files = {file.name for file in pickle_files} == set(self.labels.filename.tolist())
        assert same_files, "Some pickle files of the labels were not found!"
        print('All pickle files were found!')

    def get_file_ids_for_countries(self, countries_list: Optional[List[str]]) -> List[int]:
        """Return the ``identifier`` values of labels whose country is in
        ``countries_list`` (case-insensitive). Returns ``[]`` if the list is empty or None."""
        if not countries_list:
            return []
        countries_list_lowercase = [country.lower() for country in countries_list]
        in_countries = self.labels['country'].str.lower().isin(countries_list_lowercase)
        return self.labels[in_countries]['identifier'].tolist()

    @classmethod
    def train_val_split(cls: Type[GeowikiDatasetType], class_instance: GeowikiDatasetType,
                        train_size: float = 0.8, stratify_column: Optional[str] = None
                        ) -> Tuple[GeowikiDatasetType, GeowikiDatasetType]:
        """Split ``class_instance`` into new train and validation datasets.

        The labels dataframe is split with ``train_test_split``
        (``random_state=42``, optionally stratified on ``stratify_column``).
        A new ``cls`` instance is built for each part. Both parts inherit the
        parent's data folder, country settings, label threshold and band
        settings. They also share its normalizing dict, so validation data is
        normalized with the same statistics as training data.
        """
        # Made it a class method to be able to generate two child instances of the class.
        stratify = class_instance.labels[stratify_column] if stratify_column else None
        df_train, df_val = train_test_split(class_instance.labels, train_size=train_size,
                                            stratify=stratify, random_state=42)
        df_train = df_train.reset_index(drop=True)
        df_val = df_val.reset_index(drop=True)

        shared_kwargs = dict(
            data_folder=class_instance.data_folder,
            countries_subset=class_instance.countries_subset,
            countries_to_weight=class_instance.countries_to_weight,
            crop_probability_threshold=class_instance.crop_probability_threshold,
            remove_b1_b10=class_instance.remove_b1_b10,
            normalizing_dict=class_instance.normalizing_dict,
        )
        print('Train split')
        train_dataset = cls(labels=df_train, **shared_kwargs)
        print('Val split')
        val_dataset = cls(labels=df_val, **shared_kwargs)
        return train_dataset, val_dataset

    def get_file_by_identifier(self):
        raise NotImplementedError


In [23]:
# Countries whose GeoWiki points are kept (matched case-insensitively).
subset = [
    'Ghana',
    'Togo',
    'Nigeria',
    'Cameroon',
    'Benin',
]
# Points from these countries are flagged with is_local == 1.
local_countries = ['Nigeria']
dataset = GeowikiDataset(countries_subset=subset, countries_to_weight=local_countries)

Checking for data files


35599it [00:00, 93443.17it/s] 


All pickle files were found!
length labels: 837
length pickle files: 837
length local ids: 490
Found normalizing dict normalizing_dict_Nigeria_Cameroon_Benin_Togo_Ghana.pkl
Loading normalizing dict.
{'mean': array([0.17615442, 0.15639622, 0.15478129, 0.16052005, 0.18210957,
       0.25861192, 0.30205652, 0.28616927, 0.3337731 , 0.07249666,
       0.00730804, 0.29148372, 0.19433004, 0.3125438 ]), 'std': array([0.07856543, 0.08016264, 0.07900642, 0.09663687, 0.08940903,
       0.08163167, 0.08889727, 0.08184323, 0.09073389, 0.05365633,
       0.01807999, 0.10828142, 0.10158471, 0.19213894])}


## How to get train/validation splits

### Attempt 1: split the GeowikiDataset instance with `random_split`

In [188]:
# Hold out 20% of the points for validation; the rest is for training.
test_ratio = 0.2
n_total = len(dataset)
test_size = math.floor(n_total * test_ratio)
lenghts = [n_total - test_size, test_size]  # name kept: used by the next cells
lenghts

[670, 167]

In [189]:
# Sanity check: the two split sizes should add up to len(dataset).
sum(lenghts)

837

In [190]:
# NOTE: this split is not seeded. The installed PyTorch version's random_split
# does not accept generator=torch.Generator().manual_seed(42).
train_dataset, val_dataset = random_split(dataset, lenghts)

In [None]:
# Expected to fail: random_split returns Subset objects, which do not expose
# GeowikiDataset attributes such as pickle_files. Those attributes would also
# need to be restricted to the split's indices.
train_dataset.pickle_files

### Attempt 2: split the labels dataframe with sklearn's `train_test_split`, then create a separate GeowikiDataset for each split

In [None]:
# Restrict the labels dataframe to two countries and show points per country.
df = dataset.labels
selected_countries = ['Nigeria', 'Ghana']
in_selection = df['country'].isin(selected_countries)
df_subset = df[in_selection]
df_subset.groupby('country').size()

In [None]:
# Stratified 90/10 split by country. random_state is fixed (as in
# GeowikiDataset.train_val_split) so the split is reproducible across runs.
df_train, df_test = train_test_split(
    df_subset, test_size=0.1, stratify=df_subset['country'], random_state=42
)

In [None]:
# Points per country in the stratified train split.
df_train.groupby('country').size()

In [None]:
# Points per country in the stratified test split.
df_test.groupby('country').size()

In [None]:
# Same 90/10 split without stratification, for comparison with the stratified
# one above. random_state is fixed so the comparison is reproducible.
df_train, df_test = train_test_split(df_subset, test_size=0.1, random_state=42)

In [None]:
# Points per country in the unstratified train split.
df_train.groupby('country').size()

In [None]:
# Points per country in the unstratified test split.
df_test.groupby('country').size()

In [None]:
# Save the number of labelled points per country (path is relative to the
# notebook's working directory).
df['country'].value_counts().to_csv('geowiki_points_per_country.csv')

In [None]:
# Count missing countries. GeowikiDataset.__init__ replaces them with 'unknown'
# when it reads the CSV, so this is expected to be all False.
df['country'].isnull().value_counts()

## Train a model

In [24]:
# train_val_split is a classmethod that takes the parent dataset explicitly.
# Stratification is off by default (pass stratify_column='country' to enable it).
train_dataset, val_dataset = GeowikiDataset.train_val_split(dataset, train_size=0.8)

Train split
Checking for data files


35599it [00:00, 106750.68it/s]


All pickle files were found!
length labels: 669
length pickle files: 669
length local ids: 388
{'mean': array([0.17615442, 0.15639622, 0.15478129, 0.16052005, 0.18210957,
       0.25861192, 0.30205652, 0.28616927, 0.3337731 , 0.07249666,
       0.00730804, 0.29148372, 0.19433004, 0.3125438 ]), 'std': array([0.07856543, 0.08016264, 0.07900642, 0.09663687, 0.08940903,
       0.08163167, 0.08889727, 0.08184323, 0.09073389, 0.05365633,
       0.01807999, 0.10828142, 0.10158471, 0.19213894])}
Val split
Checking for data files


35599it [00:00, 266608.51it/s]

All pickle files were found!
length labels: 168
length pickle files: 168
length local ids: 102
{'mean': array([0.17615442, 0.15639622, 0.15478129, 0.16052005, 0.18210957,
       0.25861192, 0.30205652, 0.28616927, 0.3337731 , 0.07249666,
       0.00730804, 0.29148372, 0.19433004, 0.3125438 ]), 'std': array([0.07856543, 0.08016264, 0.07900642, 0.09663687, 0.08940903,
       0.08163167, 0.08889727, 0.08184323, 0.09073389, 0.05365633,
       0.01807999, 0.10828142, 0.10158471, 0.19213894])}





In [50]:
# Run options. The model-specific hyperparameters are added to this parser by
# the model class in the next cell.
parser = ArgumentParser()
parser.add_argument("--max_epochs", type=int, default=1000)
parser.add_argument("--patience", type=int, default=10)  # presumably early-stopping patience — TODO confirm in train_model
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--wandb", default=False, action="store_true")
parser.add_argument("--weighted_loss_fn", default=False, action="store_true")

_StoreTrueAction(option_strings=['--weighted_loss_fn'], dest='weighted_loss_fn', nargs=0, const=True, default=False, type=None, choices=None, help=None, metavar=None)

In [51]:
# Let the land-cover model add its own hyperparameters, then parse an empty
# argument list so only defaults are used (inside Jupyter, sys.argv holds the
# kernel's own arguments).
land_cover_model_class = STR2MODEL["land_cover"]
model_parser = land_cover_model_class.add_model_specific_args(parser)
model_args = model_parser.parse_args(args=[])

In [52]:
# Build the land-cover model with default arguments. Judging by the printed
# output, construction also loads and splits the GeoWiki dataset.
model = STR2MODEL["land_cover"](model_args)
model.hparams

Checking for data files


35599it [00:00, 92487.73it/s]


All pickle files were found!
length labels: 837
length pickle files: 837
length local ids: 490
Found normalizing dict normalizing_dict_Nigeria_Cameroon_Benin_Togo_Ghana.pkl
Loading normalizing dict.
{'mean': array([0.17615442, 0.15639622, 0.15478129, 0.16052005, 0.18210957,
       0.25861192, 0.30205652, 0.28616927, 0.3337731 , 0.07249666,
       0.00730804, 0.29148372, 0.19433004, 0.3125438 ]), 'std': array([0.07856543, 0.08016264, 0.07900642, 0.09663687, 0.08940903,
       0.08163167, 0.08889727, 0.08184323, 0.09073389, 0.05365633,
       0.01807999, 0.10828142, 0.10158471, 0.19213894])}
Train split
Checking for data files


35599it [00:00, 124717.39it/s]


All pickle files were found!
length labels: 669
length pickle files: 669
length local ids: 388
{'mean': array([0.17615442, 0.15639622, 0.15478129, 0.16052005, 0.18210957,
       0.25861192, 0.30205652, 0.28616927, 0.3337731 , 0.07249666,
       0.00730804, 0.29148372, 0.19433004, 0.3125438 ]), 'std': array([0.07856543, 0.08016264, 0.07900642, 0.09663687, 0.08940903,
       0.08163167, 0.08889727, 0.08184323, 0.09073389, 0.05365633,
       0.01807999, 0.10828142, 0.10158471, 0.19213894])}
Val split
Checking for data files


35599it [00:00, 183432.82it/s]


All pickle files were found!
length labels: 168
length pickle files: 168
length local ids: 102
{'mean': array([0.17615442, 0.15639622, 0.15478129, 0.16052005, 0.18210957,
       0.25861192, 0.30205652, 0.28616927, 0.3337731 , 0.07249666,
       0.00730804, 0.29148372, 0.19433004, 0.3125438 ]), 'std': array([0.07856543, 0.08016264, 0.07900642, 0.09663687, 0.08940903,
       0.08163167, 0.08889727, 0.08184323, 0.09073389, 0.05365633,
       0.01807999, 0.10828142, 0.10158471, 0.19213894])}
Number of model parameters: 28162


Namespace(add_geowiki=True, add_togo=True, alpha=10, batch_size=64, data_folder='/home/gajo/code/togo-crop-mask/notebooks/../data', gpus=0, hidden_vector_size=64, learning_rate=0.001, lstm_dropout=0.2, max_epochs=1000, model_base='lstm', multi_headed=True, num_classification_layers=2, num_lstm_layers=1, patience=10, probability_threshold=0.5, remove_b1_b10=True, wandb=False, weighted_loss_fn=False)

In [53]:
# vars() returns the Namespace's own __dict__, so editing it in place would
# silently mutate model_args too. Work on a copy instead.
new_model_args_dict = dict(vars(model_args))

In [54]:
# Overrides applied on top of the default model arguments.
new_model_args_dict.update(
    add_togo=False,
    multi_headed=False,
    num_classification_layers=1,
    max_epochs=100,  # Just for dev
    weighted_loss_fn=True,
    hidden_vector_size=64,
)

In [55]:
# Initialize model with new arguments
new_model_args = Namespace(**new_model_args_dict)
model = STR2MODEL["land_cover"](new_model_args)
model.hparams

Checking for data files


35599it [00:00, 94942.14it/s]


All pickle files were found!
length labels: 837
length pickle files: 837
length local ids: 490
Found normalizing dict normalizing_dict_Nigeria_Cameroon_Benin_Togo_Ghana.pkl
Loading normalizing dict.
{'mean': array([0.17615442, 0.15639622, 0.15478129, 0.16052005, 0.18210957,
       0.25861192, 0.30205652, 0.28616927, 0.3337731 , 0.07249666,
       0.00730804, 0.29148372, 0.19433004, 0.3125438 ]), 'std': array([0.07856543, 0.08016264, 0.07900642, 0.09663687, 0.08940903,
       0.08163167, 0.08889727, 0.08184323, 0.09073389, 0.05365633,
       0.01807999, 0.10828142, 0.10158471, 0.19213894])}
Train split
Checking for data files


35599it [00:00, 125406.64it/s]


All pickle files were found!
length labels: 669
length pickle files: 669
length local ids: 388
{'mean': array([0.17615442, 0.15639622, 0.15478129, 0.16052005, 0.18210957,
       0.25861192, 0.30205652, 0.28616927, 0.3337731 , 0.07249666,
       0.00730804, 0.29148372, 0.19433004, 0.3125438 ]), 'std': array([0.07856543, 0.08016264, 0.07900642, 0.09663687, 0.08940903,
       0.08163167, 0.08889727, 0.08184323, 0.09073389, 0.05365633,
       0.01807999, 0.10828142, 0.10158471, 0.19213894])}
Val split
Checking for data files


35599it [00:00, 214960.44it/s]


All pickle files were found!
length labels: 168
length pickle files: 168
length local ids: 102
{'mean': array([0.17615442, 0.15639622, 0.15478129, 0.16052005, 0.18210957,
       0.25861192, 0.30205652, 0.28616927, 0.3337731 , 0.07249666,
       0.00730804, 0.29148372, 0.19433004, 0.3125438 ]), 'std': array([0.07856543, 0.08016264, 0.07900642, 0.09663687, 0.08940903,
       0.08163167, 0.08889727, 0.08184323, 0.09073389, 0.05365633,
       0.01807999, 0.10828142, 0.10158471, 0.19213894])}
Number of global labels: 837
Number of local labels: 0
Global class weights: tensor([1, 6])
Local class weights: None
Number of model parameters: 19777


Namespace(add_geowiki=True, add_togo=False, alpha=10, batch_size=64, data_folder='/home/gajo/code/togo-crop-mask/notebooks/../data', gpus=0, hidden_vector_size=64, learning_rate=0.001, lstm_dropout=0.2, max_epochs=100, model_base='lstm', multi_headed=False, num_classification_layers=1, num_lstm_layers=1, patience=10, probability_threshold=0.5, remove_b1_b10=True, wandb=False, weighted_loss_fn=True)

In [56]:
print('Number of model parameters:')
# Count trainable parameters only. `requires_grad` is the boolean flag;
# `requires_grad_` (trailing underscore) is an in-place setter method, which is
# always truthy and so counted every parameter.
sum(param.numel() for param in model.parameters() if param.requires_grad)

Number of model parameters:


19777

In [57]:
# Bind the dataset and model as default arguments so the loader keeps using
# this split even if the notebook globals train_dataset/model are reassigned
# later (e.g. by re-running the random_split cell). The batch size is still
# read from hparams at call time.
model.train_dataloader = lambda _dataset=train_dataset, _model=model: DataLoader(
    _dataset, batch_size=_model.hparams.batch_size, shuffle=True
)
model.train_dataloader

<function __main__.<lambda>()>

In [58]:
# Bind the dataset and model as default arguments so the loader keeps using
# this split even if the notebook globals val_dataset/model are reassigned
# later. The batch size is still read from hparams at call time.
model.val_dataloader = lambda _dataset=val_dataset, _model=model: DataLoader(
    _dataset, batch_size=_model.hparams.batch_size
)
model.val_dataloader

<function __main__.<lambda>()>

In [59]:
# Recompute the class weights (presumably used by the weighted loss, since
# weighted_loss_fn=True).
# NOTE(review): the printed count (837 global labels) matches the full dataset,
# not the 669-point train split. The weights appear to come from the model's
# own dataset rather than train_dataset; confirm in get_class_weights.
model.global_class_weights, model.local_class_weights = model.get_class_weights()
model.global_class_weights, model.local_class_weights

Number of global labels: 837
Number of local labels: 0


(tensor([1, 6]), None)

In [60]:
# Give the model the same normalization statistics as the splits.
# train_val_split passes the parent dataset's normalizing_dict to both splits.
model.normalizing_dict = train_dataset.normalizing_dict

In [61]:
# Fit the model with the overridden arguments (e.g. max_epochs=100).
# The per-epoch validation confusion matrices are printed below.
trainer = train_model(model, new_model_args)

Validation sanity check:   0%|          | 0/5 [00:00<?, ?it/s]

confusion matrix: [[ 16 120]
 [  0  32]]


0it [00:00, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[69 67]
 [ 4 28]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[83 53]
 [ 8 24]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[88 48]
 [ 9 23]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[88 48]
 [ 9 23]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[83 53]
 [ 8 24]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[76 60]
 [ 5 27]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[77 59]
 [ 6 26]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[83 53]
 [ 8 24]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[77 59]
 [ 6 26]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[75 61]
 [ 4 28]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[76 60]
 [ 5 27]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[81 55]
 [ 7 25]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[75 61]
 [ 5 27]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[75 61]
 [ 5 27]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[77 59]
 [ 5 27]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[78 58]
 [ 6 26]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[76 60]
 [ 4 28]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[74 62]
 [ 4 28]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[78 58]
 [ 6 26]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[79 57]
 [ 6 26]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[81 55]
 [ 6 26]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[73 63]
 [ 4 28]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[83 53]
 [ 7 25]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[77 59]
 [ 6 26]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[77 59]
 [ 4 28]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[84 52]
 [ 7 25]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[71 65]
 [ 2 30]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[82 54]
 [ 7 25]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[82 54]
 [ 4 28]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[72 64]
 [ 5 27]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[81 55]
 [ 7 25]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[73 63]
 [ 5 27]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[83 53]
 [ 6 26]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[82 54]
 [ 8 24]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[77 59]
 [ 3 29]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[84 52]
 [ 8 24]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[79 57]
 [ 2 30]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[90 46]
 [10 22]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[79 57]
 [ 3 29]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[84 52]
 [ 7 25]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[87 49]
 [ 8 24]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[86 50]
 [ 7 25]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[88 48]
 [ 9 23]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[83 53]
 [ 4 28]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[82 54]
 [ 5 27]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[90 46]
 [ 8 24]]


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

confusion matrix: [[86 50]
 [ 5 27]]


In [62]:
# Evaluate on the test set. The output shows the Nigeria evaluation set
# (nigeria_farmlands_v2, 739 instances).
# NOTE(review): the recorded run scores below chance (ROC AUC ~0.36).
trainer.test()

Evaluating using the Nigeria evaluation dataset!
Number of instances in nigeria_farmlands_v2 test set: 739
{'mean': array([0.17615442, 0.15639622, 0.15478129, 0.16052005, 0.18210957,
       0.25861192, 0.30205652, 0.28616927, 0.3337731 , 0.07249666,
       0.00730804, 0.29148372, 0.19433004, 0.3125438 ]), 'std': array([0.07856543, 0.08016264, 0.07900642, 0.09663687, 0.08940903,
       0.08163167, 0.08889727, 0.08184323, 0.09073389, 0.05365633,
       0.01807999, 0.10828142, 0.10158471, 0.19213894])}


Testing:   0%|          | 0/12 [00:00<?, ?it/s]

confusion matrix: [[185 161]
 [268 125]]
----------------------------------------------------------------------------------------------------
TEST RESULTS
{'test_loss': 8.46878719329834, 'test_roc_auc_score': 0.36391180926326316, 'test_precision_score': 0.4370629370629371, 'test_recall_score': 0.31806615776081426, 'test_f1_score': 0.36818851251840945, 'test_accuracy': 0.41948579161028415}
----------------------------------------------------------------------------------------------------


## DONE:
- Support splitting dataset into train/val and update attributes (self.labels, self.pickle_files)
    - Could either create dataset and split dataset, or split dataframe with sklearn and then create separate Geowiki datasets for each subset --> went for a hybrid
    - Stratify:
        - By country -> OK (just need to drop NaNs, if there are any)
        - By label (would need to define a threshold first)
        - By both (https://stackoverflow.com/questions/45516424/sklearn-train-test-split-on-pandas-stratify-by-multiple-columns)
- Add confusion matrix --> just printing it for now, since Lightning's TensorBoard integration only accepts scalars
- Test multihead training with Nigeria (and Togo to see how much the local head helped in the original paper)
    - Figure out the normalizing dict for the train and val sets --> OK: just using the normalizing_dict of the full GeoWiki dataset
- Train on a subset of countries.
    - Need to figure out how to handle the normalizing dict automatically in this case (currently just loading previously computed ones)
    - Weighted loss function in the multihead case --> probably need separate class weights per head
## TODOS:
- Implement into src
    - Pass countries_subset as a list to the argparser, and avoid TensorBoard logging errors by not logging lists such as the confusion matrix or countries_subset.
    - Generate the CSV file of labels per country (in GeoWikiEngineer, so it can be called from the engineer.py script)
    - Move all GeoWiki pickle files to the `all` folder
- Add an Africa-only option (countries_subset=['Africa'] is not implemented yet)
- Later, check whether GeowikiDataset could inherit from LandTypeClassificationDataset to avoid repeating code