In [1]:
import sys
from pathlib import Path
import math
import pickle
from typing import Union, Tuple, List

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, random_split

sys.path.append('..')

from src.engineer.geowiki import GeoWikiDataInstance
from src.exporters.sentinel.cloudfree import BANDS

%reload_ext autoreload
%autoreload 2

### Create custom Dataset class

In [2]:
class GeowikiDataset(Dataset):
    """PyTorch dataset over the GeoWiki landcover pickle files.

    Each item is built from one pickled ``GeoWikiDataInstance`` found in
    ``<root_dir>/all`` and is returned as a tuple of
    ``(features, label, is_local)`` tensors (see ``__getitem__``).

    Args:
        root_dir: Directory containing ``csv_file`` and the ``all/`` folder of
            ``*.pkl`` files. A ``str`` is accepted and converted to ``Path``.
        csv_file: Name of the labels CSV inside ``root_dir``. It must contain at
            least the ``filename``, ``country`` and ``identifier`` columns.
        remove_b1_b10: If True, drop the bands listed in ``bands_to_remove``
            (B1 and B10) from the features.
    """

    def __init__(self, root_dir: Union[Path, str],
                 csv_file: str = 'geowiki_labels_country_crs4326.csv',
                 remove_b1_b10: bool = True) -> None:
        # Constructor arguments. Path() makes the ``/`` joins below work when a
        # plain string is passed, as the type hint allows.
        self.root_dir = Path(root_dir)
        self.csv_file = csv_file
        self.remove_b1_b10 = remove_b1_b10
        # Instance parameters
        self.bands_to_remove = ["B1", "B10"]
        # Crop probabilities >= this value become label 1; set to None to use
        # the raw crop probability as the target instead.
        self.crop_probability_threshold = 0.5
        self.local_countries = ['Nigeria', 'Togo', 'Ghana']  # could try with all points in africa
        # Optional {"mean": ..., "std": ...}; None disables normalization.
        self.normalizing_dict = None
        # Load the labels and index the pickle files
        self.labels = pd.read_csv(self.root_dir / self.csv_file)
        self.pickle_files = self.get_pickle_files_paths(self.root_dir / 'all')
        self.local_file_identifiers = self.get_local_file_ids()
        self.check_labels()

    @property
    def picke_files(self) -> List[Path]:
        """Backward-compatible alias for the originally misspelled attribute name."""
        return self.pickle_files

    @picke_files.setter
    def picke_files(self, value: List[Path]) -> None:
        self.pickle_files = value

    def __len__(self) -> int:
        # __getitem__ indexes pickle_files, so the length must come from that
        # list: check_labels only guarantees the files are a subset of the CSV
        # rows, so the CSV may be longer.
        return len(self.pickle_files)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Load one pickle file and return ``(x, label, is_local)`` tensors.

        ``x`` is the (optionally normalized, band-filtered) ``labelled_array``
        as a float tensor; ``label`` is the crop probability, binarized with
        ``crop_probability_threshold`` unless that is None; ``is_local`` is a
        long tensor that is 1 when the file's identifier belongs to one of
        ``local_countries`` and 0 otherwise.

        Raises:
            RuntimeError: if the pickle does not hold a GeoWikiDataInstance.
        """
        target_file = self.pickle_files[index]
        # File names look like "<identifier>_<start-date>_<end-date>.pkl"
        identifier = int(target_file.name.split('_')[0])

        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with target_file.open("rb") as f:
            target_datainstance = pickle.load(f)

        if not isinstance(target_datainstance, GeoWikiDataInstance):
            raise RuntimeError(f"Unrecognized data instance type {type(target_datainstance)}")

        if self.crop_probability_threshold is None:
            label = target_datainstance.crop_probability
        else:
            label = int(target_datainstance.crop_probability >= self.crop_probability_threshold)

        is_local = int(identifier in self.local_file_identifiers)

        x = self.remove_bands(x=self._normalize(target_datainstance.labelled_array))
        return (
            torch.from_numpy(x).float(),
            torch.tensor(label).float(),
            torch.tensor(is_local).long(),
        )

    @property
    def num_output_classes(self) -> int:
        return 1

    @property
    def num_input_features(self) -> int:
        """Size of axis 1 of the first sample's ``x`` (number of bands)."""
        # assumes the first value in the tuple is x, shaped [timesteps, bands]
        assert len(self.pickle_files) > 0, "No files to load!"
        return self[0][0].shape[1]

    @property
    def num_timesteps(self) -> int:
        """Size of axis 0 of the first sample's ``x`` (number of timesteps)."""
        # assumes the first value in the tuple is x, shaped [timesteps, bands]
        assert len(self.pickle_files) > 0, "No files to load!"
        return self[0][0].shape[0]

    def remove_bands(self, x: np.ndarray) -> np.ndarray:
        """Drop the bands in ``bands_to_remove`` when ``remove_b1_b10`` is set.

        Bands are expected on the last axis, e.g. ``[timesteps, bands]`` or
        ``[batches, timesteps, bands]``; band positions are looked up in ``BANDS``.
        """
        if not self.remove_b1_b10:
            return x
        indices_to_remove = {BANDS.index(band) for band in self.bands_to_remove}
        indices_to_keep = [i for i in range(x.shape[-1]) if i not in indices_to_remove]
        # Ellipsis indexing selects along the last axis whatever the rank
        return x[..., indices_to_keep]

    def _normalize(self, array: np.ndarray) -> np.ndarray:
        """Standardize with ``normalizing_dict`` mean/std, or return ``array`` unchanged."""
        if self.normalizing_dict is None:
            return array
        return (array - self.normalizing_dict["mean"]) / self.normalizing_dict["std"]

    @staticmethod
    def get_pickle_files_paths(folder_path: Path) -> List[Path]:
        """Return the ``*.pkl`` files directly inside ``folder_path``, sorted.

        Sorting makes the index -> file mapping deterministic; ``Path.glob``
        yields entries in arbitrary filesystem order.
        """
        return sorted(folder_path.glob('*.pkl'))

    def check_labels(self) -> None:
        """Assert that every pickle file has a matching ``filename`` row in the CSV."""
        known_filenames = set(self.labels['filename'])
        missing = [file.name for file in self.pickle_files if file.name not in known_filenames]
        assert not missing, (
            "Some of the pickle files are not present in the provided csv file! "
            f"e.g. {missing[:5]}"
        )

    def get_local_file_ids(self) -> List[int]:
        """Return the ``identifier`` of every label row whose country is local.

        Country matching is case-insensitive; rows with a missing country are
        never considered local.
        """
        local_countries = {country.lower() for country in self.local_countries}
        is_local = self.labels['country'].str.lower().isin(local_countries)
        return self.labels.loc[is_local, 'identifier'].tolist()

    def subset_by_country(self):
        """Placeholder: not implemented yet."""
        raise NotImplementedError("subset_by_country is not implemented yet")

    def random_by_country(self):
        """Placeholder: not implemented yet."""
        raise NotImplementedError("random_by_country is not implemented yet")

    def get_file_by_identifier(self):
        """Placeholder: not implemented yet."""
        raise NotImplementedError("get_file_by_identifier is not implemented yet")

In [3]:
# Build the dataset from the exported GeoWiki features and their labels CSV
root_dir = Path('../data/features/geowiki_landcover_2017')
labels_csv = 'geowiki_labels_country_crs4326.csv'
dataset = GeowikiDataset(root_dir=root_dir, csv_file=labels_csv)

In [4]:
# Labels table read from the CSV by the dataset
df = dataset.labels
df.head(n=5)

Unnamed: 0,identifier,date,lat,lon,label,filename,country
0,3233,2017-03-28_2018-03-28,-21.547574,138.249959,0.0,3233_2017-03-28_2018-03-28.pkl,Australia
1,2246,2017-03-28_2018-03-28,-26.547597,148.250004,0.8,2246_2017-03-28_2018-03-28.pkl,Australia
2,12913,2017-03-28_2018-03-28,11.151821,-2.848243,0.224,12913_2017-03-28_2018-03-28.pkl,Burkina Faso
3,18861,2017-03-28_2018-03-28,27.45238,56.250033,0.0,18861_2017-03-28_2018-03-28.pkl,Iran
4,32400,2017-03-28_2018-03-28,51.25001,4.952367,0.233333,32400_2017-03-28_2018-03-28.pkl,Belgium


In [5]:
# First sample as returned by __getitem__: (features, label, is_local)
array = dataset[0]
array

(tensor([[0.1606, 0.1863, 0.2670, 0.2753, 0.2995, 0.3166, 0.3207, 0.3426, 0.1343,
          0.4140, 0.3225, 0.0914],
         [0.1489, 0.1714, 0.2357, 0.2594, 0.2811, 0.3039, 0.2938, 0.3280, 0.0971,
          0.3987, 0.2954, 0.1097],
         [0.1549, 0.1761, 0.2477, 0.2610, 0.2802, 0.3071, 0.2913, 0.3314, 0.0970,
          0.3896, 0.2865, 0.0809],
         [0.1511, 0.1722, 0.2384, 0.2545, 0.2757, 0.2952, 0.2954, 0.3192, 0.1522,
          0.3897, 0.3005, 0.1068],
         [0.1503, 0.1720, 0.2370, 0.2534, 0.2739, 0.2906, 0.2946, 0.3187, 0.1542,
          0.3909, 0.3098, 0.1084],
         [0.1582, 0.1801, 0.2507, 0.2776, 0.3008, 0.3195, 0.3083, 0.3452, 0.1512,
          0.4320, 0.3369, 0.1030],
         [0.1600, 0.1866, 0.2597, 0.2689, 0.2896, 0.3178, 0.3016, 0.3434, 0.0741,
          0.4374, 0.3362, 0.0746],
         [0.1638, 0.1924, 0.2705, 0.2694, 0.2929, 0.3183, 0.3062, 0.3468, 0.0722,
          0.4389, 0.3288, 0.0619],
         [0.1588, 0.1863, 0.2674, 0.2755, 0.2910, 0.3169, 0.3073

## How to get splits

### Attempt 1: split the GeowikiDataset instance with torch's `random_split`

In [6]:
test_ratio = 0.2
n_samples = len(dataset)
test_size = math.floor(test_ratio * n_samples)
lenghts = [n_samples - test_size, test_size]  # [train, test] sizes for random_split
lenghts

[28480, 7119]

In [7]:
lenghts[0] + lenghts[1]  # train + test; random_split needs this to equal len(dataset)

35599

In [8]:
# random_split's `generator=` argument is not available in the installed PyTorch
# version, so seed the global torch RNG (used when no generator is passed) to
# make the train/val split reproducible across kernel restarts.
torch.manual_seed(42)
train_dataset, val_dataset = random_split(dataset, lenghts)

In [9]:
# Subset only wraps (dataset, indices): it doesn't inherit the wrapped dataset's
# attributes such as pickle_files, and those would need subsetting as well.
# Catch the error so the notebook still runs top to bottom.
try:
    train_dataset.pickle_files
except AttributeError as err:
    print(f"AttributeError: {err}")

AttributeError: 'Subset' object has no attribute 'pickle_files'

### Attempt 2: split the labels DataFrame with sklearn's `train_test_split`, then create a separate GeowikiDataset for each split

In [10]:
# Keep only the Nigeria and Ghana points, then count them per country
df_subset = df.loc[df['country'].isin(['Nigeria', 'Ghana'])]
df_subset.groupby(by='country').size()

country
Ghana      143
Nigeria    490
dtype: int64

In [11]:
# Stratify on country so both splits keep the Nigeria/Ghana proportions;
# random_state fixes the shuffle so the counts below are reproducible.
df_train, df_test = train_test_split(
    df_subset, test_size=0.1, stratify=df_subset['country'], random_state=42
)

In [12]:
# Per-country counts in the train split
df_train.groupby(by='country').size()

country
Ghana      129
Nigeria    440
dtype: int64

In [13]:
# Per-country counts in the test split
df_test.groupby(by='country').size()

country
Ghana      14
Nigeria    50
dtype: int64

In [14]:
# Same split without stratification, for comparison. Note: this overwrites the
# stratified df_train/df_test above; random_state keeps it reproducible.
df_train, df_test = train_test_split(df_subset, test_size=0.1, random_state=42)

In [15]:
# Per-country counts in the (non-stratified) train split
df_train.groupby(by='country').size()

country
Ghana      128
Nigeria    441
dtype: int64

In [16]:
# Per-country counts in the (non-stratified) test split
df_test.groupby(by='country').size()

country
Ghana      15
Nigeria    49
dtype: int64

In [17]:
# Ghana rows in the train split
df_train.query("country == 'Ghana'")

Unnamed: 0,identifier,date,lat,lon,label,filename,country
24442,11061,2017-03-28_2018-03-28,6.952376,-0.249956,0.040,11061_2017-03-28_2018-03-28.pkl,Ghana
1719,11545,2017-03-28_2018-03-28,8.249993,-1.047660,0.032,11545_2017-03-28_2018-03-28.pkl,Ghana
17445,11102,2017-03-28_2018-03-28,6.952376,-2.249965,0.000,11102_2017-03-28_2018-03-28.pkl,Ghana
32374,10810,2017-03-28_2018-03-28,5.952372,-0.749959,0.000,10810_2017-03-28_2018-03-28.pkl,Ghana
17932,11036,2017-03-28_2018-03-28,6.851226,-0.949385,0.000,11036_2017-03-28_2018-03-28.pkl,Ghana
...,...,...,...,...,...,...,...
27521,12681,2017-03-28_2018-03-28,10.651818,-2.047575,0.290,12681_2017-03-28_2018-03-28.pkl,Ghana
16805,10597,2017-03-28_2018-03-28,5.249979,-1.348236,0.050,10597_2017-03-28_2018-03-28.pkl,Ghana
27186,11362,2017-03-28_2018-03-28,7.749991,0.351196,0.000,11362_2017-03-28_2018-03-28.pkl,Ghana
10387,10949,2017-03-28_2018-03-28,6.452374,-0.249956,0.000,10949_2017-03-28_2018-03-28.pkl,Ghana


In [18]:
# Number of points with a known (non-NaN) country. Per-country counts can be
# exported with df['country'].value_counts().to_csv('geowiki_points_per_country.csv').
df['country'].notna().sum()

35486

In [19]:
# How many points have no country assigned
df['country'].isna().value_counts()

False    35486
True       113
Name: country, dtype: int64

In [20]:
# Rows whose country could not be determined
df.loc[df['country'].isna()]

Unnamed: 0,identifier,date,lat,lon,label,filename,country
347,2631,2017-03-28_2018-03-28,-24.047586,151.750020,0.100000,2631_2017-03-28_2018-03-28.pkl,
462,35284,2017-03-28_2018-03-28,57.952340,-153.250027,0.000000,35284_2017-03-28_2018-03-28.pkl,
501,14337,2017-03-28_2018-03-28,13.952408,-91.250013,1.000000,14337_2017-03-28_2018-03-28.pkl,
653,21850,2017-03-28_2018-03-28,33.952410,132.250021,0.050000,21850_2017-03-28_2018-03-28.pkl,
677,10018,2017-03-28_2018-03-28,2.452356,96.250036,0.000000,10018_2017-03-28_2018-03-28.pkl,
...,...,...,...,...,...,...,...
32111,5606,2017-03-28_2018-03-28,-13.750018,126.952386,0.046667,5606_2017-03-28_2018-03-28.pkl,
32730,8796,2017-03-28_2018-03-28,-2.749968,-42.047578,0.000000,8796_2017-03-28_2018-03-28.pkl,
33334,35440,2017-03-28_2018-03-28,58.952345,-2.749968,0.680000,35440_2017-03-28_2018-03-28.pkl,
33455,9149,2017-03-28_2018-03-28,-1.547662,-52.250014,0.000000,9149_2017-03-28_2018-03-28.pkl,


TODOS:
- Figure out normalizing dict
- Support splitting dataset into train/val and update attributes (self.labels, self.pickle_files)
    - Could either create dataset and split dataset, or split dataframe with sklearn and then create separate Geowiki datasets for each subset
    - Stratify:
        - By country -> OK (just need to get rid of nans if there are any)
        - By label (need to define the threshold first)
        - By both (https://stackoverflow.com/questions/45516424/sklearn-train-test-split-on-pandas-stratify-by-multiple-columns)
- Test multihead training with Togo and then Nigeria or two countries
- Train on a subset of countries
- Later, see whether I could inherit from LandTypeClassificationDataset so I don't repeat too much code