## Pre-requsites

Go through the preprocessing/downloading_satellite_imagery.ipynb notebook.

```
data/
    dhs_tfrecords_raw/
        angola_2011_00.tfrecord.gz
        ...
        zimbabwe_2015_XX.tfrecord.gz
```

## Instructions

This notebook processes the exported TFRecords as follows:
    1. Verifies that the fields in the TFRecords match the original CSV files.
    2. Splits each monolithic TFRecord file exported from Google Earth Engine into one file per record.

After running this notebook, you should a new folder (`dhs_tfrecords`) under `data/`:

```
data/
    dhs_tfrecords/
        angola_2011/
            00000.tfrecord.gz
            ...
            00229.tfrecord.gz
         ...
         zimbabwe_2015/
            00000.tfrecord.gz
            ...
            00399.tfrecord.gz
```

This notebook also calculates the mean and standard deviation of each band

## Imports and Constants

In [1]:
from typing import Iterable
from glob import glob
from pprint import pprint
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
#from tqdm.auto import tqdm

Instructions for updating:
non-resource variables are not supported in the long term


In [2]:
REQUIRED_BANDS = [
    'BLUE', 'GREEN', 'NIGHTLIGHTS', 'NIR', 'RED',
    'SWIR1', 'SWIR2', 'TEMP1']

BANDS_ORDER = [
    'BLUE', 'GREEN', 'RED', 'SWIR1', 'SWIR2', 'TEMP1', 'NIR',
    'DMSP', 'VIIRS']

DHS_SURVEY_CSV = '/cephyr/NOBACKUP/groups/globalpoverty1/data/ssl_clusters_madagascar.csv'
DHS_EXPORT_FOLDER = '/cephyr/NOBACKUP/groups/globalpoverty1/data/ssl_dhs_tfrecords_raw/'
DHS_PROCESSED_FOLDER = '../data/dhs_tfrecords'
DHS_NEW_PROCESSED_FOLDER = '/cephyr/NOBACKUP/groups/globalpoverty1/data/ssl_dhs_tfrecords_raw_processed/'

## Validate and Split Exported TFRecords

In [3]:
def process_dataset(csv_path: str, input_dir: str, processed_dir: str) -> None:
    '''
    Args
    - csv_path: str, path to CSV of DHS or LSMS clusters
    - input_dir: str, path to TFRecords exported from Google Earth Engine
    - processed_dir: str, folder where to save processed TFRecords
    '''
    df = pd.read_csv(csv_path, float_precision='high', index_col=0)
    surveys = list(df.groupby(['country', 'year']).groups.keys())  # (country, year) tuples

    #for country, year in surveys:
    country_year = 'madagascar_2020'
    print('Processing:', country_year)

    #tfrecord_paths = glob(os.path.join(input_dir, country_year + '*'))
    tfrecord_paths = get_tfrecord_paths(input_dir, country_year)
    out_dir = os.path.join(processed_dir, country_year)
    os.makedirs(out_dir, exist_ok=True)
    subset_df = df[(df['country'] == 'madagascar') & (df['year'] == 2020)].reset_index(drop=True)
    validate_and_split_tfrecords(
        tfrecord_paths=tfrecord_paths, out_dir=out_dir, df=subset_df)
        
def get_tfrecord_paths(input_dir: str, country_year: tuple) -> Iterable[str]:
    tfrecord_paths = glob(os.path.join(input_dir, country_year + '*'))
    #tfrecord_paths.sort(key=lambda tfr: int(tfr[tfr.rfind('_') + 1 : tfr.rfind('_') + 1 + 4])) # Sort by file index
    return tfrecord_paths


def validate_and_split_tfrecords(
        tfrecord_paths: Iterable[str],
        out_dir: str,
        df: pd.DataFrame
        ) -> None:
    '''Validates and splits a list of exported TFRecord files (for a
    given country-year survey) into individual TFrecords, one per cluster.

    "Validating" a TFRecord comprises of 2 parts
    1) verifying that it contains the required bands
    2) verifying that its other features match the values from the dataset CSV

    Args
    - tfrecord_paths: str, path to exported TFRecords files
    - out_dir: str, path to dir to save processed individual TFRecords
    - df: pd.DataFrame, index is sequential and starts at 0
    '''
    # Create an iterator over the TFRecords file. The iterator yields
    # the binary representations of Example messages as strings.
    options = tf.io.TFRecordOptions(tf.io.TFRecordCompressionType.GZIP)

    # cast float64 => float32 and str => bytes
    '''
    for col in df.columns:
        if df[col].dtype == np.float64:
            df[col] = df[col].astype(np.float32)
        elif df[col].dtype == object:  # pandas uses 'object' type for str
            try:
                df[col] = df[col].astype(bytes)
            except UnicodeEncodeError:
                print(df.columns, df.dtypes)
    '''
    i = 0
    
    #progbar = tqdm(total=len(df))
    
    for tfrecord_path in tfrecord_paths:
        iterator = tf.io.tf_record_iterator(tfrecord_path, options=options)
        for record_str in iterator:
            # parse into an actual Example message
            ex = tf.train.Example.FromString(record_str)
            feature_map = ex.features.feature

            # verify required bands exist
            for band in REQUIRED_BANDS:
                print(type(band))
                if band not in feature_map: print(f'Band "{band}" not in record {i} of {tfrecord_path}')

            # compare feature map values against CSV values
            '''
            csv_feats = df.loc[i, :].to_dict()
            for col, val in csv_feats.items():
                ft_type = feature_map[col].WhichOneof('kind')
                ex_val = feature_map[col].__getattribute__(ft_type).value[0]
                #assert val == ex_val, f'Expected {col}={val}, but found {ex_val} instead'
            '''
            # serialize to string and write to file
            out_path = os.path.join(out_dir, f'{i:05d}.tfrecord.gz')  # all surveys have < 1e6 clusters
            with tf.io.TFRecordWriter(out_path, options=options) as writer:
                writer.write(ex.SerializeToString())

            i += 1
            #progbar.update(1)
            
    #progbar.close()


In [None]:
process_dataset(
    csv_path=DHS_SURVEY_CSV,
    input_dir=DHS_EXPORT_FOLDER,
    processed_dir=DHS_NEW_PROCESSED_FOLDER)

Processing: madagascar_2020
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<class 'str'>
<clas

In [28]:
df = pd.read_csv(DHS_SURVEY_CSV, float_precision='high', index_col=False)
len(df)

57195

In [33]:
df[df['country'] == 'senegal'].groupby(['year']).size()

year
1992     50
1993    169
1997    270
2005    351
2008    204
2009    114
2010    161
2011    219
2012     78
2013    121
2015    214
2019    214
dtype: int64