In [60]:
import sys
from pathlib import Path
from argparse import ArgumentParser, Namespace
from tqdm import tqdm
import math
import pickle
from typing import Tuple, List, Dict, Union, Optional, TypeVar, Type
from itertools import permutations

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader, random_split

sys.path.append('..')

from src.engineer.geowiki import GeoWikiDataInstance, GeoWikiEngineer
from src.exporters.sentinel.cloudfree import BANDS
from src.models import STR2MODEL, train_model

%reload_ext autoreload
%autoreload 2

### Create custom Dataset class

In [104]:
GeowikiDatasetType = TypeVar('GeowikiDatasetType', bound='GeowikiDataset')  # for typing

class GeowikiDataset(Dataset):
    """Dataset of pickled GeoWiki data instances with binary crop labels.

    Samples are read from ``data_dir / 'all'``: one pickle file per row of the
    labels table (matched on its ``filename`` column). Each item is a tuple
    ``(x, label, is_local)``:

    * ``x`` -- the instance's ``labelled_array``, normalized with the normalizing
      dict and, if ``remove_b1_b10``, without bands B1 and B10;
    * ``label`` -- 1.0 if ``crop_probability >= crop_probability_threshold`` (0.5)
      else 0.0 (the raw probability if the threshold is set to None);
    * ``is_local`` -- 1 if the file identifier belongs to ``countries_to_weight``, else 0.
    """

    def __init__(self, data_dir: Union[Path, str],
                 csv_file: str = 'geowiki_labels_country_crs4326.csv',
                 countries_subset: Optional[List[str]] = None,
                 countries_to_weight: Optional[List[str]] = None,
                 remove_b1_b10: bool = True,
                 labels: Optional[pd.DataFrame] = None,  # if this is passed csv_file will be ignored
                 ) -> None:
        """
        Args:
            data_dir: Directory holding the labels CSV, the normalizing dict pickle(s)
                and an ``all`` sub-directory with the data-instance pickles.
            csv_file: Name of the labels CSV inside ``data_dir``; ignored if ``labels`` is given.
            countries_subset: Keep only rows whose ``country`` is in this list
                (case-insensitive). Only applied when the CSV is read. Also selects
                which normalizing dict file is looked up.
            countries_to_weight: Countries whose samples are flagged ``is_local == 1``.
            remove_b1_b10: Drop bands B1 and B10 from every sample.
            labels: Ready-made labels DataFrame (e.g. one side of a train/val split).

        Raises:
            AssertionError: if the pickle files found do not match the labels' file names.
        """
        # Constructor arguments
        # Normalize to Path: the type hint allows str, but the code below relies on Path's `/`.
        self.data_dir = Path(data_dir)
        self.csv_file = csv_file
        self.countries_subset = countries_subset
        self.countries_to_weight = countries_to_weight
        self.remove_b1_b10 = remove_b1_b10

        # Instance parameters
        self.bands_to_remove = ["B1", "B10"]
        self.crop_probability_threshold = 0.5

        # Labels: read (and optionally country-filtered) from the CSV, or taken as given
        if labels is None:
            self.labels = pd.read_csv(self.data_dir / self.csv_file)
            self.labels.loc[self.labels['country'].isnull(), 'country'] = 'unknown'
            if self.countries_subset:
                wanted = [country.lower() for country in self.countries_subset]
                self.labels = self.labels[self.labels['country'].str.lower().isin(wanted)].reset_index(drop=True)
        else:
            self.labels = labels
        self.pickle_files = self.get_pickle_files_paths(self.data_dir / 'all')
        self.file_identifiers_countries_to_weight = self.get_file_ids_for_countries(self.countries_to_weight)
        print('length labels:', len(self.labels))
        print('length pickle files:', len(self.pickle_files))
        print('length local ids:', len(self.file_identifiers_countries_to_weight))

        # Normalizing dictionary: loaded from data_dir if found, otherwise computed
        self.normalizing_dict = self.get_normalizing_dict()
        print(self.normalizing_dict)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the ``(data, label, is_local)`` tensors for the ``index``-th pickle file.

        Raises:
            RuntimeError: if the pickle does not contain a GeoWikiDataInstance.
        """
        target_file = self.pickle_files[index]
        # File names start with the numeric sample identifier followed by '_'
        identifier = int(target_file.name.split('_')[0])

        # NOTE(review): pickle.load can execute arbitrary code; only use trusted data files.
        with target_file.open("rb") as f:
            target_datainstance = pickle.load(f)

        if not isinstance(target_datainstance, GeoWikiDataInstance):
            raise RuntimeError(f"Unrecognized data instance type {type(target_datainstance)}")

        if self.crop_probability_threshold is None:
            label = target_datainstance.crop_probability
        else:
            label = int(target_datainstance.crop_probability >= self.crop_probability_threshold)

        is_local = int(identifier in self.file_identifiers_countries_to_weight)

        return (
            torch.from_numpy(
                self.remove_bands(x=self._normalize(target_datainstance.labelled_array))
            ).float(),
            torch.tensor(label).float(),
            torch.tensor(is_local).long(),
        )

    @property
    def num_output_classes(self) -> int:
        return 1

    @property
    def num_input_features(self) -> int:
        # assumes the first value in the tuple is x
        assert len(self.pickle_files) > 0, "No files to load!"
        output_tuple = self[0]
        return output_tuple[0].shape[1]

    @property
    def num_timesteps(self) -> int:
        # assumes the first value in the tuple is x
        assert len(self.pickle_files) > 0, "No files to load!"
        output_tuple = self[0]
        return output_tuple[0].shape[0]

    def remove_bands(self, x: np.ndarray) -> np.ndarray:
        """Drop ``self.bands_to_remove`` from the band axis when ``remove_b1_b10`` is set.

        Expects the input to be of shape [timesteps, bands] or [batches, timesteps, bands];
        band positions are looked up in ``BANDS``. Returns ``x`` unchanged otherwise.
        """
        if not self.remove_b1_b10:
            return x
        indices_to_remove = [BANDS.index(band) for band in self.bands_to_remove]
        bands_index = 1 if len(x.shape) == 2 else 2
        indices_to_keep = [i for i in range(x.shape[bands_index]) if i not in indices_to_remove]
        if len(x.shape) == 2:
            # timesteps, bands
            return x[:, indices_to_keep]
        # batches, timesteps, bands
        return x[:, :, indices_to_keep]

    def _normalize(self, array: np.ndarray) -> np.ndarray:
        """Standardize ``array`` with the normalizing dict's mean/std (no-op if there is none)."""
        if self.normalizing_dict is None:
            return array
        return (array - self.normalizing_dict["mean"]) / self.normalizing_dict["std"]

    @staticmethod
    def load_files_and_normalizing_dict(
        features_dir: Path, subset_name: str = 'training', file_name: str = "normalizing_dict.pkl"
    ) -> Tuple[List[Path], Optional[Dict[str, np.ndarray]]]:
        """Return the ``*.pkl`` files in ``features_dir / subset_name`` and the normalizing
        dict stored in ``features_dir / file_name`` (None if that file does not exist)."""
        pickle_files = list((features_dir / subset_name).glob("*.pkl"))

        # try loading the normalizing dict. By default, if it exists we will use it
        if (features_dir / file_name).exists():
            with (features_dir / file_name).open("rb") as f:
                normalizing_dict = pickle.load(f)
        else:
            normalizing_dict = None

        return pickle_files, normalizing_dict

    def search_normalizing_dict(self, default_file_name: str = "normalizing_dict.pkl") -> Optional[Path]:
        '''
        Searches for the normalizing dict file in the self.data_dir directory and returns its path. Returns None if it was not found.

        Without a country subset the file is ``default_file_name``; with a subset it is
        ``<prefix>_<country>_<country>....pkl`` with the subset's countries in any order.

        Raises:
            NotImplementedError: for the single-country subset ``['Africa']`` (not supported yet).
            AssertionError: for subsets of 10 or more countries (too many orderings to try).
        '''
        prefix = default_file_name.split('.')[0]

        if not self.countries_subset:
            file_path = self.data_dir / default_file_name
            if file_path.exists():
                print(f'Found normalizing dict {file_path.name}')
                return file_path
        elif len(self.countries_subset) == 1 and self.countries_subset[0].lower() == 'africa':
            raise NotImplementedError  # TODO
        else:
            assert len(self.countries_subset) < 10, 'Execution time will be too big!'  # TODO: add warning when passing subset to constructor
            # Iterate lazily: stops at the first ordering whose file exists
            for permutation in permutations(self.countries_subset):
                file_name = f"{prefix}_{'_'.join(permutation)}.pkl"
                file_path = self.data_dir / file_name
                if file_path.exists():
                    print(f'Found normalizing dict {file_name}')
                    return file_path
        print('Normalizing dict not found.')
        return None

    def get_normalizing_dict(self, save: bool = False) -> Dict:
        """Return the normalizing dict, loading it from ``self.data_dir`` if it is found.

        Otherwise it is computed over all of ``self.pickle_files`` with ``GeoWikiEngineer``
        and, if ``save`` is True, written to ``self.data_dir`` as ``normalizing_dict.pkl``
        (or ``normalizing_dict_<countries>.pkl`` when a country subset is set).
        """
        default_file_name = "normalizing_dict.pkl"
        file_path = self.search_normalizing_dict(default_file_name)
        if file_path:
            print('Loading normalizing dict.')
            # search_normalizing_dict already checked the file exists; load it directly
            # instead of also globbing an unrelated 'training' folder.
            with file_path.open("rb") as f:
                return pickle.load(f)

        print('Calculating normalizing dict...')
        assert len(self) == len(self.pickle_files), 'Length of self.labels must be the same as of the list of pickle files.'
        # NOTE(review): the engineer's data folder is hard-coded relative to the working directory.
        geowiki_engineer = GeoWikiEngineer(Path('../data'))

        for pickle_path in tqdm(self.pickle_files):
            with pickle_path.open("rb") as f:
                target_datainstance = pickle.load(f)
            geowiki_engineer.update_normalizing_values(target_datainstance.labelled_array)

        normalizing_dict = geowiki_engineer.calculate_normalizing_dict()

        # Write file
        if save:
            if self.countries_subset:
                prefix = default_file_name.split('.')[0]
                countries_str = '_'.join(self.countries_subset)
                file_name = f"{prefix}_{countries_str}.pkl"
            else:
                file_name = default_file_name
            file_path = self.data_dir / file_name
            print('Saving normalizing dict', file_path.name)
            with file_path.open("wb") as f:
                pickle.dump(normalizing_dict, f)

        return normalizing_dict

    def get_pickle_files_paths(self, folder_path: Path) -> List[Path]:
        """Return the ``*.pkl`` files in ``folder_path`` whose names appear in ``labels.filename``.

        Raises:
            AssertionError: if the found files do not match the labels' file names exactly.
        """
        # A set gives O(1) membership per globbed file (a list was O(#labels) each).
        wanted_names = set(self.labels.filename.tolist())
        print('Checking for data files')
        pickle_files = [path for path in tqdm(folder_path.glob('*.pkl')) if path.name in wanted_names]
        self._check_label_files(pickle_files)
        return pickle_files

    def _check_label_files(self, pickle_files: List[Path]) -> None:
        """Assert that ``pickle_files`` covers exactly the file names listed in ``self.labels``."""
        same_files = {file.name for file in pickle_files} == set(self.labels.filename.tolist())
        assert same_files, "Some pickle files of the labels were not found!"
        print('All pickle files were found!')

    def get_file_ids_for_countries(self, countries_list: Optional[List[str]]) -> List[int]:
        """Return the ``identifier`` of every label row whose country is in ``countries_list``.

        Matching is case-insensitive; ``None`` or an empty list gives ``[]``.
        """
        if not countries_list:
            return []
        countries_list_lowercase = [country.lower() for country in countries_list]
        in_countries = self.labels['country'].str.lower().isin(countries_list_lowercase)
        return self.labels.loc[in_countries, 'identifier'].tolist()

    @classmethod
    def train_val_split(cls: Type[GeowikiDatasetType], class_instance: GeowikiDatasetType,
                        train_size: float = 0.8,
                        stratify_column: Optional[str] = None) -> Tuple[GeowikiDatasetType, GeowikiDatasetType]:
        """Split ``class_instance``'s labels and build one new dataset per split.

        Uses sklearn's ``train_test_split`` with ``random_state=42``, optionally stratified
        on ``stratify_column``. The new datasets reuse ``data_dir``, ``countries_to_weight``
        and ``remove_b1_b10``; they get no ``countries_subset``, so they look up the default
        (``normalizing_dict.pkl``) normalizing dict.

        Returns:
            ``(train_dataset, val_dataset)``.
        """
        stratify = class_instance.labels[stratify_column] if stratify_column else None
        df_train, df_val = train_test_split(class_instance.labels, train_size=train_size,
                                            stratify=stratify, random_state=42)
        df_train.reset_index(drop=True, inplace=True)
        df_val.reset_index(drop=True, inplace=True)

        # Propagate band handling too, so both splits yield the same features as the source
        shared_kwargs = dict(countries_to_weight=class_instance.countries_to_weight,
                             remove_b1_b10=class_instance.remove_b1_b10)
        print('Train split')
        train_dataset = cls(class_instance.data_dir, labels=df_train, **shared_kwargs)
        print('Val split')
        val_dataset = cls(class_instance.data_dir, labels=df_val, **shared_kwargs)
        return train_dataset, val_dataset

    def get_file_by_identifier(self):
        # TODO: not implemented yet.
        pass

In [107]:
data_dir = Path('../data/features/geowiki_landcover_2017')
# Countries to restrict the dataset to; None keeps every country. Subsets tried before:
#   ['Ghana', 'Togo', 'Nigeria', 'Cameroon', 'Benin']
#   ['Nigeria']
#   ['Nigeria', 'India']
#   ['Ghana', 'Togo', 'Nigeria', 'Chad', 'Democratic Republic of the Congo', 'Ethiopia', 'Mali']
subset = None
dataset = GeowikiDataset(data_dir, 'geowiki_labels_country_crs4326.csv', countries_subset=subset, countries_to_weight=['Nigeria'])

Checking for data files


35599it [00:08, 4383.52it/s] 


All pickle files were found!
length labels: 35599
length pickle files: 35599
length local ids: 490
Found normalizing dict normalizing_dict.pkl
Loading normalizing dict.
{'mean': array([0.19353804, 0.17112217, 0.16083624, 0.16354993, 0.18635676,
       0.25554994, 0.29061711, 0.28009877, 0.31469831, 0.10141977,
       0.0087153 , 0.22964706, 0.15255525, 0.3221835 ]), 'std': array([0.14932182, 0.15265479, 0.14360899, 0.16329558, 0.15796025,
       0.14746618, 0.15011357, 0.14306833, 0.14913972, 0.09338568,
       0.02771975, 0.1111936 , 0.09549155, 0.23958353])}


In [110]:
# Number of labelled points per country (re-enable the .to_csv call to export the table).
dataset.labels['country'].value_counts()#.to_csv('geowiki_points_per_country.csv')

China                       3923
United States of America    3489
Brazil                      3165
Russia                      2511
India                       1589
                            ... 
Northern Cyprus                1
Samoa                          1
Saint Lucia                    1
Trinidad and Tobago            1
Mauritius                      1
Name: country, Length: 177, dtype: int64

In [113]:
# Country boundaries shapefile (judging by the file name, presumably Natural Earth 1:50m admin-0 countries).
import geopandas as gpd
world_map = gpd.read_file('../assets/ne_50m_admin_0_countries/ne_50m_admin_0_countries.shp')
world_map

Unnamed: 0,featurecla,scalerank,LABELRANK,SOVEREIGNT,SOV_A3,ADM0_DIF,LEVEL,TYPE,ADMIN,ADM0_A3,...,FCLASS_TR,FCLASS_ID,FCLASS_PL,FCLASS_GR,FCLASS_IT,FCLASS_NL,FCLASS_SE,FCLASS_BD,FCLASS_UA,geometry
0,Admin-0 country,1,3,Zimbabwe,ZWE,0,2,Sovereign country,Zimbabwe,ZWE,...,,,,,,,,,,"POLYGON ((31.28789 -22.40205, 31.19727 -22.344..."
1,Admin-0 country,1,3,Zambia,ZMB,0,2,Sovereign country,Zambia,ZMB,...,,,,,,,,,,"POLYGON ((30.39609 -15.64307, 30.25068 -15.643..."
2,Admin-0 country,1,3,Yemen,YEM,0,2,Sovereign country,Yemen,YEM,...,,,,,,,,,,"MULTIPOLYGON (((53.08564 16.64839, 52.58145 16..."
3,Admin-0 country,3,2,Vietnam,VNM,0,2,Sovereign country,Vietnam,VNM,...,,,,,,,,,,"MULTIPOLYGON (((104.06396 10.39082, 104.08301 ..."
4,Admin-0 country,5,3,Venezuela,VEN,0,2,Sovereign country,Venezuela,VEN,...,,,,,,,,,,"MULTIPOLYGON (((-60.82119 9.13838, -60.94141 9..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
237,Admin-0 country,1,3,Afghanistan,AFG,0,2,Sovereign country,Afghanistan,AFG,...,,,,,,,,,,"POLYGON ((66.52227 37.34849, 66.82773 37.37129..."
238,Admin-0 country,1,5,Kashmir,KAS,0,2,Indeterminate,Siachen Glacier,KAS,...,Unrecognized,Unrecognized,Unrecognized,Unrecognized,Unrecognized,Unrecognized,Unrecognized,Unrecognized,Unrecognized,"POLYGON ((77.04863 35.10991, 77.00449 35.19634..."
239,Admin-0 country,3,4,Antarctica,ATA,0,2,Indeterminate,Antarctica,ATA,...,,,,,,,,,,"MULTIPOLYGON (((-45.71777 -60.52090, -45.49971..."
240,Admin-0 country,3,6,Netherlands,NL1,1,2,Country,Sint Maarten,SXM,...,,,,,,,,,,"POLYGON ((-63.12305 18.06895, -63.01118 18.068..."


## How to get splits

### Attempt 1: split the GeowikiDataset instance with `random_split`

In [None]:
# Train/validation sizes for an 80/20 split; the validation share is rounded down
# (int() truncation equals floor for these non-negative values).
test_ratio = 0.2
test_size = int(test_ratio * len(dataset))
lenghts = [len(dataset) - test_size, test_size]
lenghts

In [None]:
# Sanity check: the two split sizes should add up to len(dataset).
sum(lenghts)

In [None]:
# Seed the global RNG so the split is reproducible: this PyTorch version's random_split
# does not accept a `generator` argument yet (newer versions: generator=torch.Generator().manual_seed(42)).
torch.manual_seed(42)
train_dataset, val_dataset = random_split(dataset, lenghts)

In [None]:
# Expected to fail: random_split returns torch Subset objects, which do not expose
# GeowikiDataset attributes such as pickle_files (and those would need subsetting too).
train_dataset.pickle_files

### Attempt 2: split the labels DataFrame with sklearn's `train_test_split`, then create a separate GeowikiDataset for each split

In [None]:
df = dataset.labels
# Keep only the Nigerian and Ghanaian points and count them per country
df_subset = df.loc[df.country.isin(['Nigeria', 'Ghana'])]
df_subset.groupby('country').size()

In [None]:
# Stratify by country so both splits keep the same country proportions; a fixed
# random_state (as in GeowikiDataset.train_val_split) makes the split reproducible.
df_train, df_test = train_test_split(df_subset, test_size=0.1, stratify=df_subset['country'], random_state=42)

In [None]:
# Points per country in the stratified training split.
df_train.groupby('country').size()

In [None]:
# Points per country in the stratified test split.
df_test.groupby('country').size()

In [None]:
# Same split without stratification, for comparison; seeded for reproducibility.
df_train, df_test = train_test_split(df_subset, test_size=0.1, random_state=42)

In [None]:
# Points per country in the non-stratified training split.
df_train.groupby('country').size()

In [None]:
# Points per country in the non-stratified test split.
df_test.groupby('country').size()

In [None]:
# Export the number of labelled points per country to a CSV in the working directory.
df['country'].value_counts().to_csv('geowiki_points_per_country.csv')

In [None]:
# Count labels with/without a missing country. When the CSV is read, the constructor
# fills missing countries with 'unknown', so every value here should be False.
df['country'].isnull().value_counts()

## Train a model

In [None]:
# Split the labels (train_size defaults to 0.8) and build one GeowikiDataset per split.
# train_val_split is a classmethod, so the dataset to split is passed explicitly.
train_dataset, val_dataset = dataset.train_val_split(dataset)

In [None]:
parser = ArgumentParser()
# Integer-valued training options as (flag, default)
for _flag, _default in (("--max_epochs", 1000), ("--patience", 10), ("--gpus", 0)):
    parser.add_argument(_flag, type=int, default=_default)
# Boolean switches that stay off unless passed explicitly
for _flag in ("--wandb", "--weighted_loss_fn"):
    parser.add_argument(_flag, default=False, action="store_true")

In [None]:
# Let the land_cover model class extend the parser (add_model_specific_args); args=[]
# parses only the defaults instead of the notebook kernel's sys.argv.
model_args = STR2MODEL["land_cover"].add_model_specific_args(parser).parse_args(args=[])

In [None]:
# Instantiate the land_cover model with the default arguments and display its hparams.
model = STR2MODEL["land_cover"](model_args)
model.hparams

In [None]:
# Copy: vars() returns the Namespace's own __dict__, so editing it in place would also
# silently modify `model_args`.
new_model_args_dict = dict(vars(model_args))

In [None]:
# Overrides applied on top of the default model arguments.
new_model_args_dict.update(
    add_togo=False,
    multi_headed=False,
    num_classification_layers=1,
    max_epochs=100,  # Just for dev
    weighted_loss_fn=True,
    hidden_vector_size=64,
)

In [None]:
# Re-create the model from the modified argument dict and display its hparams.
new_model_args = Namespace(**new_model_args_dict)
model = STR2MODEL["land_cover"](new_model_args)
model.hparams

In [None]:
print('Number of model parameters:')
# `requires_grad` is the boolean flag; `requires_grad_` is the in-place setter method,
# which is always truthy and so made this count frozen parameters too.
sum(param.numel() for param in model.parameters() if param.requires_grad)

In [None]:
# shuffle=True: DataLoader defaults to shuffle=False, which would feed the training samples
# in the same order (that of the globbed pickle files) every epoch.
model.train_dataloader = lambda: DataLoader(train_dataset, batch_size=model.hparams.batch_size, shuffle=True)
model.train_dataloader

In [None]:
# Attach the validation split as val_dataloader (overrides the method on this instance only).
model.val_dataloader = lambda: DataLoader(val_dataset, batch_size=model.hparams.batch_size)
model.val_dataloader

In [None]:
# Global and local class weights computed from both splits.
# NOTE(review): presumably consumed by the weighted loss (weighted_loss_fn) — confirm in the model class.
model.global_class_weights, model.local_class_weights = model.get_class_weights([train_dataset, val_dataset])
model.global_class_weights, model.local_class_weights

In [None]:
# Train with the modified arguments; the returned object is presumably the trainer used.
trainer = train_model(model, new_model_args) 

In [None]:
# Evaluate with the trainer.
# NOTE(review): only train/val dataloaders are attached in this notebook — confirm the model provides a test dataloader.
trainer.test()

## DONE:
- Support splitting the dataset into train/val and updating the attributes (self.labels, self.pickle_files)
    - Could either create the dataset and then split it, or split the labels DataFrame with sklearn and then create a separate GeowikiDataset for each subset --> went for a hybrid
    - Stratify:
        - By country -> OK (just need to get rid of NaNs if there are any)
        - By label (would need to define a threshold first)
        - By both (https://stackoverflow.com/questions/45516424/sklearn-train-test-split-on-pandas-stratify-by-multiple-columns)
- Add a confusion matrix --> just printing it for now, as Lightning's TensorBoard integration only accepts scalars
- Test multi-head training with Nigeria (and with Togo, to see how much the local head helped in the original paper)
    - Figure out the normalizing dict for the train and val sets --> OK, just using the normalizing_dict of the full GeoWiki dataset
- Train on a subset of countries
    - Need to figure out how to handle the normalizing dict automatically in this case (currently just loading previously computed ones)
    - Weighted loss function in the multi-head case --> probably need separate class weights per head
## TODOS:
- Move this implementation into src
- Add an Africa-only option
- Later, maybe inherit from LandTypeClassificationDataset so I don't repeat too much code