In [1]:
import sys
from pathlib import Path
from argparse import ArgumentParser, Namespace
from tqdm import tqdm
import math
import pickle
from typing import Tuple, List, Dict, Union, Optional, TypeVar, Type

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader, random_split

sys.path.append('..')

from src.engineer.geowiki import GeoWikiDataInstance
from src.exporters.sentinel.cloudfree import BANDS
from src.models import STR2MODEL, train_model

%reload_ext autoreload
%autoreload 2

### Create custom Dataset class

In [2]:
GeowikiDatasetType = TypeVar('GeowikiDatasetType', bound='Parent') # for typing

class GeowikiDataset(Dataset):
    
    def __init__(self, data_dir: Union[Path, str],
                csv_file: str='geowiki_labels_country_crs4326.csv',
                countries_subset: Optional[List[str]]=None,
                countries_to_weight: Optional[List[str]]=None,
                remove_b1_b10: bool=True,
                labels: Optional[pd.DataFrame]=None # if this is passed csv_file will be ignored
                ) -> None:
        
        # Constructor arguments
        self.data_dir = data_dir
        self.csv_file = csv_file
        self.countries_subset = countries_subset
        self.countries_to_weight = countries_to_weight
        self.remove_b1_b10 = remove_b1_b10
        
        # Instance parameters
        self.bands_to_remove = ["B1", "B10"]
        self.crop_probability_threshold = 0.5
        self.normalizing_dict = self.load_files_and_normalizing_dict(data_dir)[1]
        print(self.normalizing_dict)

        # Functions
        if labels is None:
            self.labels = pd.read_csv(self.data_dir / self.csv_file)
            self.labels.loc[self.labels['country'].isnull(), 'country'] = 'unknown'
            if self.countries_subset:
                self.labels = self.labels[self.labels['country'].str.lower().isin(list(map(str.lower, self.countries_subset)))].reset_index(drop=True)
        else:
            self.labels = labels
        self.picke_files = self.get_pickle_files_paths(self.data_dir / 'all')
        self.file_identifiers_countries_to_weight = self.get_file_ids_for_countries(self.countries_to_weight)
        print('length labels:', len(self.labels))
        print('length pickle files:', len(self.picke_files))
        print('length local ids:', len(self.file_identifiers_countries_to_weight))

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the data, label, and weight tensors.
        """
        target_file = self.picke_files[index]
        identifier = int(target_file.name.split('_')[0])

        with target_file.open("rb") as f:
            target_datainstance = pickle.load(f)

        if isinstance(target_datainstance, GeoWikiDataInstance):
            if self.crop_probability_threshold is None:
                label = target_datainstance.crop_probability
            else:
                label = int(target_datainstance.crop_probability >= self.crop_probability_threshold)
        else:
            raise RuntimeError(f"Unrecognized data instance type {type(target_datainstance)}")

        is_local = 0
        if identifier in self.file_identifiers_countries_to_weight:
           is_local = 1

        return (
            torch.from_numpy(
                self.remove_bands(x=self._normalize(target_datainstance.labelled_array))
            ).float(),
            torch.tensor(label).float(),
            torch.tensor(is_local).long(),
        )

    @property
    def num_output_classes(self) -> int:
        return 1

    @property
    def num_input_features(self) -> int:
        # assumes the first value in the tuple is x
        assert len(self.pickle_files) > 0, "No files to load!"
        output_tuple = self[0]
        return output_tuple[0].shape[1]

    @property
    def num_timesteps(self) -> int:
        # assumes the first value in the tuple is x
        assert len(self.pickle_files) > 0, "No files to load!"
        output_tuple = self[0]
        return output_tuple[0].shape[0]

    def remove_bands(self, x: np.ndarray) -> np.ndarray:
        """
        Expects the input to be of shape [timesteps, bands]
        """
        if self.remove_b1_b10:
            indices_to_remove: List[int] = []
            for band in self.bands_to_remove:
                indices_to_remove.append(BANDS.index(band))

            bands_index = 1 if len(x.shape) == 2 else 2
            indices_to_keep = [i for i in range(x.shape[bands_index]) if i not in indices_to_remove]
            if len(x.shape) == 2:
                # timesteps, bands
                return x[:, indices_to_keep]
            else:
                # batches, timesteps, bands
                return x[:, :, indices_to_keep]
        else:
            return x

    @staticmethod
    def load_files_and_normalizing_dict(
        features_dir: Path, subset_name: str='training'
    ) -> Tuple[List[Path], Optional[Dict[str, np.ndarray]]]:
        pickle_files = list((features_dir / subset_name).glob("*.pkl"))

        # try loading the normalizing dict. By default, if it exists we will use it
        if (features_dir / "normalizing_dict.pkl").exists():
            with (features_dir / "normalizing_dict.pkl").open("rb") as f:
                normalizing_dict = pickle.load(f)
        else:
            normalizing_dict = None

        return pickle_files, normalizing_dict

    def _normalize(self, array: np.ndarray) -> np.ndarray:
        if self.normalizing_dict is None:
            return array
        else:
            return (array - self.normalizing_dict["mean"]) / self.normalizing_dict["std"]

    def get_pickle_files_paths(self, folder_path: Path) -> Tuple[List[Path]]:
        file_paths = self.labels.filename.tolist()
        print('Checking for data files')
        pickle_files = [path for path in tqdm(folder_path.glob('*.pkl')) if path.name in file_paths]
        self._check_label_files(pickle_files)
        return pickle_files

    def _check_label_files(self, pickle_files) -> None:
        same_files = set([file.name for file in pickle_files]) == set(self.labels.filename.tolist())
        assert same_files, "Some pickle files of the labels were not found!"
        print('All pickle files were found!')

    def get_file_ids_for_countries(self, countries_list: List[str]) -> List[int]:
        file_ids = []
        if countries_list:
            countries_list_lowercase = list(map(str.lower, countries_list))
            file_ids.extend(self.labels[self.labels['country'].str.lower().isin(countries_list_lowercase)]['identifier'].tolist())
        return file_ids

    @classmethod
    def train_val_split(cls: Type[GeowikiDatasetType], class_instance: Type[GeowikiDatasetType], train_size: float=0.8, stratify_column: Optional[str]=None) -> Tuple[GeowikiDatasetType]:
        # Split labels dataframe
        stratify = None if not stratify_column else class_instance.labels[stratify_column]
        df_train, df_val = train_test_split(class_instance.labels, train_size=train_size, stratify=stratify, random_state=42)
        df_train.reset_index(drop=True, inplace=True) 
        df_val.reset_index(drop=True, inplace=True)

        # Create two new GeowikiDataset instances (train and val)
        print('Train split')
        train_dataset = cls(class_instance.data_dir, countries_to_weight=class_instance.countries_to_weight, labels=df_train)
        print('Val split')
        val_dataset = cls(class_instance.data_dir, countries_to_weight=class_instance.countries_to_weight, labels=df_val)
        return train_dataset, val_dataset

    def get_file_by_identifier(self):
        pass

In [4]:
data_dir = Path('../data/features/geowiki_landcover_2017')
subset = None #['Ghana', 'Togo', 'Nigeria', 'Chad', 'Democratic Republic of the Congo', 'Ethiopia', 'Chad', 'Mali']
dataset = GeowikiDataset(data_dir, 'geowiki_labels_country_crs4326.csv', countries_subset=subset, countries_to_weight=['Nigeria'])

{'mean': array([0.19353804, 0.17112217, 0.16083624, 0.16354993, 0.18635676,
       0.25554994, 0.29061711, 0.28009877, 0.31469831, 0.10141977,
       0.0087153 , 0.22964706, 0.15255525, 0.3221835 ]), 'std': array([0.14932182, 0.15265479, 0.14360899, 0.16329558, 0.15796025,
       0.14746618, 0.15011357, 0.14306833, 0.14913972, 0.09338568,
       0.02771975, 0.1111936 , 0.09549155, 0.23958353])}
Checking for data files


35599it [00:06, 5234.67it/s] 

All pickle files were found!
length labels: 35599
length pickle files: 35599
length local ids: 490





In [None]:
file_paths = dataset.labels.filename.tolist()
folder_path = dataset.data_dir / 'all'
pickle_files = [path for path in tqdm(folder_path.glob('*.pkl')) if path.name in file_paths]
pickle_files

35599it [00:00, 37640.66it/s]


[PosixPath('../data/features/geowiki_landcover_2017/all/12178_2017-03-28_2018-03-28.pkl'),
 PosixPath('../data/features/geowiki_landcover_2017/all/10292_2017-03-28_2018-03-28.pkl'),
 PosixPath('../data/features/geowiki_landcover_2017/all/6515_2017-03-28_2018-03-28.pkl'),
 PosixPath('../data/features/geowiki_landcover_2017/all/13225_2017-03-28_2018-03-28.pkl'),
 PosixPath('../data/features/geowiki_landcover_2017/all/11281_2017-03-28_2018-03-28.pkl'),
 PosixPath('../data/features/geowiki_landcover_2017/all/10037_2017-03-28_2018-03-28.pkl'),
 PosixPath('../data/features/geowiki_landcover_2017/all/8208_2017-03-28_2018-03-28.pkl'),
 PosixPath('../data/features/geowiki_landcover_2017/all/13822_2017-03-28_2018-03-28.pkl'),
 PosixPath('../data/features/geowiki_landcover_2017/all/7614_2017-03-28_2018-03-28.pkl'),
 PosixPath('../data/features/geowiki_landcover_2017/all/9226_2017-03-28_2018-03-28.pkl'),
 PosixPath('../data/features/geowiki_landcover_2017/all/11237_2017-03-28_2018-03-28.pkl'),
 Po

## How to get splits

### Attempt by splitting Geowiki dataset class

In [10]:
test_ratio = 0.2
test_size = math.floor(len(dataset) * test_ratio)
lenghts = [len(dataset) - test_size, test_size]
lenghts

[1828, 456]

In [11]:
sum(lenghts)

2284

In [12]:
train_dataset, val_dataset = random_split(dataset, lenghts)  #generator is not yet available in this pytorch version, generator=torch.Generator().manual_seed(42))

In [13]:
train_dataset.pickle_files # Subset class doesn't inheret properties like pickles files plus I would need to subset those as well

AttributeError: 'Subset' object has no attribute 'pickle_files'

### With sklearn train test split on dataframe with labels and then create separate Geowiki datasets for each set

In [14]:
df = dataset.labels
df_subset = df[df['country'].isin(['Nigeria', 'Ghana'])]
df_subset.groupby('country').size()

country
Ghana      143
Nigeria    490
dtype: int64

In [15]:
df_train, df_test = train_test_split(df_subset, test_size=0.1, stratify=df_subset['country'])

In [16]:
df_train.groupby('country').size()

country
Ghana      129
Nigeria    440
dtype: int64

In [17]:
df_test.groupby('country').size()

country
Ghana      14
Nigeria    50
dtype: int64

In [18]:
df_train, df_test = train_test_split(df_subset, test_size=0.1)

In [19]:
df_train.groupby('country').size()

country
Ghana      129
Nigeria    440
dtype: int64

In [20]:
df_test.groupby('country').size()

country
Ghana      14
Nigeria    50
dtype: int64

In [22]:
df['country'].value_counts().to_csv('geowiki_points_per_country.csv')

In [23]:
df['country'].isnull().value_counts()

False    2284
Name: country, dtype: int64

## Train a model

In [5]:
train_dataset, val_dataset = dataset.train_val_split(dataset)

Train split
Checking for data files


35599it [00:00, 43137.46it/s]


All pickle files were found!
length labels: 1827
length pickle files: 1827
length local ids: 374
Val split
Checking for data files


35599it [00:00, 142892.72it/s]

All pickle files were found!
length labels: 457
length pickle files: 457
length local ids: 116





In [31]:
parser = ArgumentParser()
parser.add_argument("--max_epochs", type=int, default=1000)
parser.add_argument("--patience", type=int, default=10)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--wandb", default=False, action="store_true")
parser.add_argument("--weighted_loss_fn", default=False, action="store_true")

_StoreTrueAction(option_strings=['--weighted_loss_fn'], dest='weighted_loss_fn', nargs=0, const=True, default=False, type=None, choices=None, help=None, metavar=None)

In [32]:
model_args = STR2MODEL["land_cover"].add_model_specific_args(parser).parse_args(args=[])

In [33]:
model = STR2MODEL["land_cover"](model_args)
model.hparams

Number of geowiki instances in training set: 27947


Namespace(add_geowiki=True, add_togo=True, alpha=10, batch_size=64, data_folder='/home/gajo/code/togo-crop-mask/notebooks/../data', gpus=0, hidden_vector_size=64, learning_rate=0.001, lstm_dropout=0.2, max_epochs=1000, model_base='lstm', multi_headed=True, num_classification_layers=2, num_lstm_layers=1, patience=10, probability_threshold=0.5, remove_b1_b10=True, wandb=False, weighted_loss_fn=False)

In [34]:
new_model_args_dict = vars(model_args)

In [35]:
# SET MODIFICATIONS TO DEFAULT MODEL ARGUMENTS:
new_model_args_dict['add_togo'] = False
new_model_args_dict['multi_headed'] = False
new_model_args_dict['num_classification_layers'] = 1
new_model_args_dict['max_epochs'] = 10 # Just for dev
new_model_args_dict['weighted_loss_fn'] = True

In [36]:
# Initialize model with new arguments
new_model_args = Namespace(**new_model_args_dict)
model = STR2MODEL["land_cover"](new_model_args)
model.hparams

Number of geowiki instances in training set: 27947


Namespace(add_geowiki=True, add_togo=False, alpha=10, batch_size=64, data_folder='/home/gajo/code/togo-crop-mask/notebooks/../data', gpus=0, hidden_vector_size=64, learning_rate=0.001, lstm_dropout=0.2, max_epochs=10, model_base='lstm', multi_headed=False, num_classification_layers=1, num_lstm_layers=1, patience=10, probability_threshold=0.5, remove_b1_b10=True, wandb=False, weighted_loss_fn=True)

In [18]:
model.train_dataloader = lambda: DataLoader(train_dataset, batch_size=model.hparams.batch_size)
model.train_dataloader

<function __main__.<lambda>()>

In [19]:
model.val_dataloader = lambda: DataLoader(val_dataset, batch_size=model.hparams.batch_size)
model.val_dataloader

<function __main__.<lambda>()>

In [37]:
trainer = train_model(model, new_model_args) 

Number of geowiki instances in validation set: 7301


Validation sanity check:   0%|          | 0/5 [00:00<?, ?it/s]

confusion matrix: [[ 50 200]
 [ 15  55]]


0it [00:00, ?it/s]

Number of geowiki instances in training set: 27947
Number of geowiki instances in validation set: 7301


Validating:   0%|          | 0/115 [00:00<?, ?it/s]

confusion matrix: [[3939 1817]
 [ 375 1170]]


Validating:   0%|          | 0/115 [00:00<?, ?it/s]

confusion matrix: [[4098 1658]
 [ 315 1230]]


Validating:   0%|          | 0/115 [00:00<?, ?it/s]

confusion matrix: [[4162 1594]
 [ 307 1238]]


Validating:   0%|          | 0/115 [00:00<?, ?it/s]

confusion matrix: [[4373 1383]
 [ 333 1212]]


Validating:   0%|          | 0/115 [00:00<?, ?it/s]

confusion matrix: [[4586 1170]
 [ 379 1166]]


Validating:   0%|          | 0/115 [00:00<?, ?it/s]

confusion matrix: [[4427 1329]
 [ 305 1240]]


Validating:   0%|          | 0/115 [00:00<?, ?it/s]

confusion matrix: [[4586 1170]
 [ 349 1196]]


Validating:   0%|          | 0/115 [00:00<?, ?it/s]

confusion matrix: [[4593 1163]
 [ 344 1201]]


Validating:   0%|          | 0/115 [00:00<?, ?it/s]

confusion matrix: [[4422 1334]
 [ 296 1249]]


Validating:   0%|          | 0/115 [00:00<?, ?it/s]

confusion matrix: [[4694 1062]
 [ 352 1193]]


In [38]:
trainer.test()

Evaluating using the Nigeria evaluation dataset!
Number of instances in nigeria_farmlands_v2 test set: 739


Testing:   0%|          | 0/12 [00:00<?, ?it/s]

confusion matrix: [[156 190]
 [265 128]]
----------------------------------------------------------------------------------------------------
TEST RESULTS
{'test_loss': 2.4458844661712646, 'test_roc_auc_score': 0.32900175028313405, 'test_precision_score': 0.4025157232704403, 'test_recall_score': 0.3256997455470738, 'test_f1_score': 0.360056258790436, 'test_accuracy': 0.38430311231393777}
----------------------------------------------------------------------------------------------------


## DONE:
- Support splitting dataset into train/val and update attributes (self.labels, self.pickle_files)
    - Could either create dataset and split dataset, or split dataframe with sklearn and then create separate Geowiki datasets for each subset --> went for a hybrid
    - Stratify:
        - By country -> OK (just need to get rid of nans if there are any)
        - By label (would need to define therhesold first)
        - By both (https://stackoverflow.com/questions/45516424/sklearn-train-test-split-on-pandas-stratify-by-multiple-columns)
- Add confusion matrix --> just printing it for now as tensorboard integration of lightning is only accepting scalars
## TODOS:
- Test multihead training with Nigeria (and Togo to see how much the local head helped in the original paper)
    - Figure out normalizing dict for train and val set--> OK. just using normalizing_dict of full geowiki
    - Weighted loss fuction in multihead case. --> Probably need some separate class weights per head
- Train of subset of countries.
    - Need to figure out how to deal with normalizing dict in this case.
- Later maybe see a way if I could just inherent from LandTypeClassificationDataset so I don't repeat too much code