In [3]:
import sys
from pathlib import Path
from tqdm import tqdm
import math
import pickle
from typing import Union, Tuple, List, Optional, TypeVar, Type

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, random_split

sys.path.append('..')

from src.engineer.geowiki import GeoWikiDataInstance
from src.exporters.sentinel.cloudfree import BANDS

%reload_ext autoreload
%autoreload 2

### Create custom Dataset class

In [10]:
GeowikiDatasetType = TypeVar('GeowikiDatasetType', bound='Parent') # for annotations

class GeowikiDataset(Dataset):
    
    def __init__(self, data_dir: Union[Path, str],
                csv_file: str='geowiki_labels_country_crs4326.csv',
                countries_subset: Optional[List[str]]=None,
                countries_to_weight: Optional[List[str]]=None,
                remove_b1_b10: bool=True,
                labels: Optional[pd.DataFrame]=None # if this is passed csv_file will be ignored
                ) -> None:
        
        # Constructor arguments
        self.data_dir = data_dir
        self.csv_file = csv_file
        self.countries_subset = countries_subset
        self.countries_to_weight = countries_to_weight
        self.remove_b1_b10 = remove_b1_b10
        
        # Instance parameters
        self.bands_to_remove = ["B1", "B10"]
        self.crop_probability_threshold = 0.5
        self.normalizing_dict = None

        # Functions
        if labels is None:
            self.labels = pd.read_csv(self.data_dir / self.csv_file)
            self.labels.loc[self.labels['country'].isnull(), 'country'] = 'unknown'
            if self.countries_subset:
                self.labels = self.labels[self.labels['country'].str.lower().isin(list(map(str.lower, self.countries_subset)))].reset_index(drop=True)
        else:
            self.labels = labels
        self.picke_files = self.get_pickle_files_paths(self.data_dir / 'all')
        self.file_identifiers_countries_to_weight = self.get_file_ids_for_countries(self.countries_to_weight)
        print('length labels:', len(self.labels))
        print('length pickle files:', len(self.picke_files))
        print('length local ids:', len(self.file_identifiers_countries_to_weight))

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the data, label, and weight tensors.
        """
        target_file = self.picke_files[index]
        identifier = int(target_file.name.split('_')[0])

        with target_file.open("rb") as f:
            target_datainstance = pickle.load(f)

        if isinstance(target_datainstance, GeoWikiDataInstance):
            if self.crop_probability_threshold is None:
                label = target_datainstance.crop_probability
            else:
                label = int(target_datainstance.crop_probability >= self.crop_probability_threshold)
        else:
            raise RuntimeError(f"Unrecognized data instance type {type(target_datainstance)}")

        is_local = 0
        if identifier in self.file_identifiers_countries_to_weight:
           is_local = 1

        return (
            torch.from_numpy(
                self.remove_bands(x=self._normalize(target_datainstance.labelled_array))
            ).float(),
            torch.tensor(label).float(),
            torch.tensor(is_local).long(),
        )

    @property
    def num_output_classes(self) -> int:
        return 1

    @property
    def num_input_features(self) -> int:
        # assumes the first value in the tuple is x
        assert len(self.pickle_files) > 0, "No files to load!"
        output_tuple = self[0]
        return output_tuple[0].shape[1]

    @property
    def num_timesteps(self) -> int:
        # assumes the first value in the tuple is x
        assert len(self.pickle_files) > 0, "No files to load!"
        output_tuple = self[0]
        return output_tuple[0].shape[0]

    def remove_bands(self, x: np.ndarray) -> np.ndarray:
        """
        Expects the input to be of shape [timesteps, bands]
        """
        if self.remove_b1_b10:
            indices_to_remove: List[int] = []
            for band in self.bands_to_remove:
                indices_to_remove.append(BANDS.index(band))

            bands_index = 1 if len(x.shape) == 2 else 2
            indices_to_keep = [i for i in range(x.shape[bands_index]) if i not in indices_to_remove]
            if len(x.shape) == 2:
                # timesteps, bands
                return x[:, indices_to_keep]
            else:
                # batches, timesteps, bands
                return x[:, :, indices_to_keep]
        else:
            return x

    def _normalize(self, array: np.ndarray) -> np.ndarray:
        if self.normalizing_dict is None:
            return array
        else:
            return (array - self.normalizing_dict["mean"]) / self.normalizing_dict["std"]

    def get_pickle_files_paths(self, folder_path: Path) -> Tuple[List[Path]]:
        file_paths = self.labels.filename.tolist()
        print('Checking for data files')
        pickle_files = [path for path in tqdm(folder_path.glob('*.pkl')) if path.name in file_paths]
        self._check_label_files(pickle_files)
        return pickle_files

    def _check_label_files(self, pickle_files) -> None:
        same_files = set([file.name for file in pickle_files]) == set(self.labels.filename.tolist())
        assert same_files, "Some pickle files of the labels were not found!"
        print('All pickle files were found!')

    def get_file_ids_for_countries(self, countries_list: List[str]) -> List[int]:
        file_ids = []
        if countries_list:
            countries_list_lowercase = list(map(str.lower, countries_list))
            file_ids.extend(self.labels[self.labels['country'].str.lower().isin(countries_list_lowercase)]['identifier'].tolist())
        return file_ids

    @classmethod
    def train_val_split(cls: Type[GeowikiDatasetType], class_instance: Type[GeowikiDatasetType], train_size: float=0.8, stratify_column: Optional[str]=None) -> Tuple[GeowikiDatasetType]:
        # Split labels dataframe
        stratify = None if not stratify_column else class_instance.labels[stratify_column]
        df_train, df_val = train_test_split(class_instance.labels, train_size=train_size, stratify=stratify, random_state=42)
        df_train.reset_index(drop=True, inplace=True) 
        df_val.reset_index(drop=True, inplace=True)

        # Create two new GeowikiDataset instances (train and val)
        print('Train split')
        train_dataset = cls(class_instance.data_dir, countries_to_weight=class_instance.countries_to_weight, labels=df_train)
        print('Val split')
        val_dataset = cls(class_instance.data_dir, countries_to_weight=class_instance.countries_to_weight, labels=df_val)
        return train_dataset, val_dataset

    def get_file_by_identifier(self):
        pass

In [24]:
data_dir = Path('../data/features/geowiki_landcover_2017')
subset = ['Ghana', 'Togo', 'Nigeria', 'Chad', 'Democratic Republic of the Congo', 'Ethiopia', 'Chad', 'Mali']
dataset = GeowikiDataset(data_dir, 'geowiki_labels_country_crs4326.csv', countries_subset=subset, countries_to_weight=['Nigeria'])

Checking for data files


35599it [00:00, 36993.80it/s]

All pickle files were found!
length labels: 2284
length pickle files: 2284
length local ids: 490





In [25]:
df = dataset.labels
df

Unnamed: 0,identifier,date,lat,lon,label,filename,country
0,12178,2017-03-28_2018-03-28,9.651814,0.550622,0.154286,12178_2017-03-28_2018-03-28.pkl,Togo
1,10292,2017-03-28_2018-03-28,3.952363,24.249976,0.000000,10292_2017-03-28_2018-03-28.pkl,Democratic Republic of the Congo
2,6515,2017-03-28_2018-03-28,-10.047612,24.749978,0.000000,6515_2017-03-28_2018-03-28.pkl,Democratic Republic of the Congo
3,13225,2017-03-28_2018-03-28,11.651823,37.550612,0.160000,13225_2017-03-28_2018-03-28.pkl,Ethiopia
4,11281,2017-03-28_2018-03-28,7.452379,5.249979,0.000000,11281_2017-03-28_2018-03-28.pkl,Nigeria
...,...,...,...,...,...,...,...
2279,10224,2017-03-28_2018-03-28,3.452360,21.749965,0.000000,10224_2017-03-28_2018-03-28.pkl,Democratic Republic of the Congo
2280,13134,2017-03-28_2018-03-28,11.452397,-5.249979,0.553846,13134_2017-03-28_2018-03-28.pkl,Mali
2281,14245,2017-03-28_2018-03-28,13.651742,20.952351,0.000000,14245_2017-03-28_2018-03-28.pkl,Chad
2282,14813,2017-03-28_2018-03-28,14.952413,-4.249975,0.000000,14813_2017-03-28_2018-03-28.pkl,Mali


In [26]:
train, test = dataset.train_val_split(dataset)

Train split
Checking for data files


35599it [00:00, 45286.90it/s]


All pickle files were found!
length labels: 1827
length pickle files: 1827
length local ids: 374
Val split
Checking for data files


35599it [00:00, 144565.50it/s]

All pickle files were found!
length labels: 457
length pickle files: 457
length local ids: 116





In [27]:
train.labels

Unnamed: 0,identifier,date,lat,lon,label,filename,country
0,13582,2017-03-28_2018-03-28,12.452402,4.749977,0.500000,13582_2017-03-28_2018-03-28.pkl,Nigeria
1,13309,2017-03-28_2018-03-28,11.952399,6.249984,0.150000,13309_2017-03-28_2018-03-28.pkl,Nigeria
2,11739,2017-03-28_2018-03-28,8.651809,9.952390,0.180000,11739_2017-03-28_2018-03-28.pkl,Nigeria
3,7507,2017-03-28_2018-03-28,-6.749986,17.952337,0.000000,7507_2017-03-28_2018-03-28.pkl,Democratic Republic of the Congo
4,10968,2017-03-28_2018-03-28,6.452374,35.250026,0.000000,10968_2017-03-28_2018-03-28.pkl,Ethiopia
...,...,...,...,...,...,...,...
1822,15040,2017-03-28_2018-03-28,15.651751,-4.047584,0.000000,15040_2017-03-28_2018-03-28.pkl,Mali
1823,14366,2017-03-28_2018-03-28,13.952408,-6.249984,0.000000,14366_2017-03-28_2018-03-28.pkl,Mali
1824,6800,2017-03-28_2018-03-28,-9.047607,26.749988,0.006667,6800_2017-03-28_2018-03-28.pkl,Democratic Republic of the Congo
1825,11427,2017-03-28_2018-03-28,7.952381,47.749994,0.000000,11427_2017-03-28_2018-03-28.pkl,Ethiopia


In [20]:
test.labels

Unnamed: 0,identifier,date,lat,lon,label,filename,country
0,13587,2017-03-28_2018-03-28,12.452402,7.249988,0.000000,13587_2017-03-28_2018-03-28.pkl,Nigeria
1,13268,2017-03-28_2018-03-28,11.851159,21.050626,0.000000,13268_2017-03-28_2018-03-28.pkl,Chad
2,12575,2017-03-28_2018-03-28,10.452392,8.249993,0.406667,12575_2017-03-28_2018-03-28.pkl,Nigeria
3,12953,2017-03-28_2018-03-28,11.250007,-8.047602,0.000000,12953_2017-03-28_2018-03-28.pkl,Mali
4,12013,2017-03-28_2018-03-28,9.249997,34.651749,0.000000,12013_2017-03-28_2018-03-28.pkl,Ethiopia
...,...,...,...,...,...,...,...
452,12253,2017-03-28_2018-03-28,9.851150,-0.949385,0.850000,12253_2017-03-28_2018-03-28.pkl,Ghana
453,12767,2017-03-28_2018-03-28,10.851155,13.050590,0.180000,12767_2017-03-28_2018-03-28.pkl,Nigeria
454,14275,2017-03-28_2018-03-28,13.651742,-6.047593,0.000000,14275_2017-03-28_2018-03-28.pkl,Mali
455,11413,2017-03-28_2018-03-28,7.952381,5.249979,0.025000,11413_2017-03-28_2018-03-28.pkl,Nigeria


In [None]:
file_paths = dataset.labels.filename.tolist()
folder_path = dataset.data_dir / 'all'
pickle_files = [path for path in tqdm(folder_path.glob('*.pkl')) if path.name in file_paths]
pickle_files

In [None]:
array = dataset[0]
array

## How to get splits

### Attempt by splitting Geowiki dataset class

In [None]:
test_ratio = 0.2
test_size = math.floor(len(dataset) * test_ratio)
lenghts = [len(dataset) - test_size, test_size]
lenghts

In [None]:
sum(lenghts)

In [None]:
train_dataset, val_dataset = random_split(dataset, lenghts)  #generator is not yet available in this pytorch version, generator=torch.Generator().manual_seed(42))

In [None]:
train_dataset.pickle_files # Subset class doesn't inheret properties like pickles files plus I would need to subset those as well

### With sklearn train test split on dataframe with labels and then create separate Geowiki datasets for each set

In [None]:
df_subset = df[df['country'].isin(['Nigeria', 'Ghana'])]
df_subset.groupby('country').size()

In [None]:
df_train, df_test = train_test_split(df_subset, test_size=0.1, stratify=df_subset['country'])

In [None]:
df_train.groupby('country').size()

In [None]:
df_test.groupby('country').size()

In [None]:
df_train, df_test = train_test_split(df_subset, test_size=0.1)

In [None]:
df_train.groupby('country').size()

In [None]:
df_test.groupby('country').size()

In [None]:
df_train[df_train['country'] == 'Ghana']

In [None]:
df['country'].value_counts().to_csv('geowiki_points_per_country.csv')

In [None]:
df['country'].isnull().value_counts()

TODOS:
- Figure out normalizing dict
- Support splitting dataset into train/val and update attributes (self.labels, self.pickle_files)
    - Could either create dataset and split dataset, or split dataframe with sklearn and then create separate Geowiki datasets for each subset --> went for a hybrid
    - Stratify:
        - By country -> OK (just need to get rid of nans if there are any)
        - By label (would need to define therhesold first)
        - By both (https://stackoverflow.com/questions/45516424/sklearn-train-test-split-on-pandas-stratify-by-multiple-columns)
- Test multihead training with Togo and then Nigeria or two countries
- Train of subset of countries
- Later maybe see a way if I could just inherent from LandTypeClassificationDataset so I don't repeat too much code