In [3]:
%load_ext autoreload
%autoreload 2

from __future__ import annotations

from collections.abc import Iterable
from glob import glob
from pprint import pprint
import os
from typing import Optional

import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm.auto import tqdm

#import cartopy.crs as ccrs
#import cartopy.feature as cfeature
import matplotlib.pyplot as plt

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
DHS_FOLDER = 'data/ke_dhs_tfrecords_raw'

TFRECORD_PATH_TEMPLATE = 'data/ke_dhs_tfrecords_raw/kenya_{year}.tfrecord.gz'

# Inspecting a single record in the TFRecord file

In [5]:
def get_feature_types(tfrecord_path):
    '''Gets the types and shapes of each feature in a feature_map.
    Args
    - tfrecord_path: str, path to a TFRecord file with GZIP compression
    Returns:
    - feature_types: dict, maps feature names (str) to tuple of (ft_type, ft_shape)
    '''
    options = tf.io.TFRecordOptions(compression_type = 'GZIP')
    iterator = tf.compat.v1.io.tf_record_iterator(tfrecord_path, options=options)

    # get the first Example stored in the TFRecords file
    record_str = next(iterator)
    example = tf.train.Example.FromString(record_str) 
    feature_map = example.features.feature  # get Features message within the Example

    # use the WhichOneof() method on messages with `oneof` fields to determine the type of the field
    feature_types = {}
    for name in feature_map.keys():
        ft_type = feature_map[name].WhichOneof('kind')
        ft_shape = np.array(feature_map[name].__getattribute__(ft_type).value).shape
        feature_types[name] = (ft_type, ft_shape)
        
    return feature_types

In [7]:
tfrecord_path = TFRECORD_PATH_TEMPLATE.format(year=2014)
feature_types = get_feature_types(tfrecord_path)
print(f'TFRecord path: {tfrecord_path}')
print('Features and types:')
pprint(feature_types)

TFRecord path: data/ke_dhs_tfrecords_raw/kenya_2014.tfrecord.gz
Features and types:
{'BLUE': ('float_list', (65025,)),
 'GREEN': ('float_list', (65025,)),
 'LAT': ('float_list', (65025,)),
 'LON': ('float_list', (65025,)),
 'NIGHTLIGHTS': ('float_list', (65025,)),
 'NIR': ('float_list', (65025,)),
 'RED': ('float_list', (65025,)),
 'SWIR1': ('float_list', (65025,)),
 'SWIR2': ('float_list', (65025,)),
 'TEMP1': ('float_list', (65025,)),
 'country': ('bytes_list', (1,)),
 'households': ('float_list', (1,)),
 'lat': ('float_list', (1,)),
 'lon': ('float_list', (1,)),
 'system:index': ('bytes_list', (1,)),
 'urban_rural': ('float_list', (1,)),
 'wealthpooled': ('float_list', (1,)),
 'year': ('float_list', (1,))}


# Cross checking TFrecord Files against Survey Data Files

In [35]:
def split_survey(survey_path, out_dir, year):   
    '''
    Args
    - survey_path: str, path to complete survey CSV file
    - out_dir: str, path to output directory
    '''
    os.makedirs(out_dir, exist_ok=True)
    data = pd.read_csv(survey_path, float_precision='high')
    data = data.loc[data.year == year].drop(columns = ['GID_1', 'GID_2'])
    
    survey_out_name = f'kenya_{year}.csv'
    survey_out_path = os.path.join(out_dir, survey_out_name)
        
    # save CSV: pandas uses float64 which maintains enough precision for all of our numbers
    data.to_csv(survey_out_path, index=False)

In [36]:
survey_path = 'data/ke_dhs_clusters.csv'
out_dir = 'data/ke_dhs_surveys'
split_survey(survey_path, out_dir, 2014)
split_survey(survey_path, out_dir, 2015)

In [39]:
# checks that the fields in the TFRecords match the original CSV files.

REQUIRED_BANDS = ['BLUE', 'GREEN', 'LAT', 'LON', 'NIGHTLIGHTS', 'NIR', 'RED', 'SWIR1', 'SWIR2', 'TEMP1']

def crosscheck_records(csv_path, tfrecord_path):   
    '''crosschecks individual TFRecord files (for a given country-year survey). 
    1) verifies that it contains the required bands
    2) checks whether the other features match values from the survey CSV

    Args
    - csv_path: path to csv file containing survey data
    - tfrecord_paths: path to exported TFRecords file
    - df: pd.DataFrame, index is sequential and starts at 0
    '''
    df = pd.read_csv(csv_path, float_precision='high', index_col=False)
    df = df.drop(columns=['Unnamed: 0'])

    for col in df.columns:
        if df[col].dtype == np.float64:
            df[col] = df[col].astype(np.float32)
        elif df[col].dtype == object: 
            df[col] = df[col].astype(bytes)

    i = 0
    progbar = tqdm(total=len(df))
    
    options = tf.io.TFRecordOptions(compression_type = 'GZIP')
    iterator = tf.compat.v1.io.tf_record_iterator(tfrecord_path, options=options)

    for record_str in iterator:
        example = tf.train.Example.FromString(record_str)
        feature_map = example.features.feature

        for band in REQUIRED_BANDS:
            assert band in feature_map, f'Band "{band}" not in record {i} of {tfrecord_path}'

        csv_feats = df.loc[i, :].to_dict()       
        for col, val in csv_feats.items():
            ft_type = feature_map[col].WhichOneof('kind')
            ex_val = feature_map[col].__getattribute__(ft_type).value[0] 
            if val != ex_val:
                print(f'Expected {col}={val}, but found {col}={ex_val} instead') 
                  
            i += 1
            progbar.update(1)
    progbar.close()

In [40]:
crosscheck_records(csv_path='data/ke_dhs_surveys/kenya_2014.csv', tfrecord_path=TFRECORD_PATH_TEMPLATE.format(year=2014))

  0%|          | 0/1585 [00:00<?, ?it/s]

Expected lat=-1.312427043914795, but found lat=-1.2787810564041138 instead
Expected lon=36.78540802001953, but found lon=36.75844192504883 instead
Expected wealthpooled=0.6667720079421997, but found wealthpooled=0.2383285015821457 instead
Expected households=23, but found households=22.0 instead
Expected lat=-1.254804015159607, but found lat=-1.2796460390090942 instead
Expected lon=36.91013717651367, but found lon=36.74592971801758 instead
Expected wealthpooled=1.2472656965255737, but found wealthpooled=1.0458334684371948 instead
Expected households=19, but found households=25.0 instead
Expected lat=-1.2795250415802002, but found lat=-1.2803800106048584 instead
Expected lon=36.893592834472656, but found lon=36.69709014892578 instead
Expected wealthpooled=1.137634515762329, but found wealthpooled=0.8064674735069275 instead
Expected households=21, but found households=20.0 instead
Expected lat=-1.2842650413513184, but found lat=-1.272063970565796 instead
Expected lon=36.886566162109375, 

Expected lat=-1.0343929529190063, but found lat=-1.2850940227508545 instead
Expected lon=37.36174011230469, but found lon=36.90538787841797 instead
Expected wealthpooled=-0.3726208209991455, but found wealthpooled=1.9097405672073364 instead
Expected households=21, but found households=23.0 instead
Expected urban_rural=0, but found urban_rural=1.0 instead
Expected lat=-1.1389659643173218, but found lat=-1.2842650413513184 instead
Expected lon=36.75346755981445, but found lon=36.886566162109375 instead
Expected wealthpooled=0.5871864557266235, but found wealthpooled=1.4598958492279053 instead
Expected households=25, but found households=19.0 instead
Expected urban_rural=0, but found urban_rural=1.0 instead
Expected lat=-1.2214349508285522, but found lat=-1.2840720415115356 instead
Expected lon=36.7347297668457, but found lon=36.908843994140625 instead
Expected wealthpooled=0.719716489315033, but found wealthpooled=0.7588577270507812 instead
Expected households=21, but found households=22

Expected lat=-2.2814130783081055, but found lat=-1.2613279819488525 instead
Expected lon=40.896915435791016, but found lon=36.75200653076172 instead
Expected wealthpooled=0.3369019031524658, but found wealthpooled=1.115096092224121 instead
Expected households=22, but found households=24.0 instead
Expected lat=-2.228303909301758, but found lat=-1.2627209424972534 instead
Expected lon=40.85572052001953, but found lon=36.80139923095703 instead
Expected wealthpooled=0.15950517356395721, but found wealthpooled=2.2479119300842285 instead
Expected households=19, but found households=21.0 instead
Expected urban_rural=0, but found urban_rural=1.0 instead
Expected lat=-2.4263288974761963, but found lat=-1.2829409837722778 instead
Expected lon=40.72547149658203, but found lon=36.81635284423828 instead
Expected wealthpooled=-0.25757142901420593, but found wealthpooled=1.6946938037872314 instead
Expected urban_rural=0, but found urban_rural=1.0 instead
Expected lat=-2.1076691150665283, but found la

Expected lat=0.021916000172495842, but found lat=-0.6127099990844727 instead
Expected lon=38.04643630981445, but found lon=36.46508026123047 instead
Expected wealthpooled=-0.5290536284446716, but found wealthpooled=-0.0431256964802742 instead
Expected households=22, but found households=25.0 instead
Expected lat=-0.5442590117454529, but found lat=-0.6091920137405396 instead
Expected lon=37.449073791503906, but found lon=36.59257888793945 instead
Expected wealthpooled=1.3390763998031616, but found wealthpooled=1.1875593662261963 instead
Expected households=23, but found households=21.0 instead
Expected lat=-0.4774619936943054, but found lat=-0.570946991443634 instead
Expected lon=37.44856643676758, but found lon=36.57275390625 instead
Expected wealthpooled=0.18332457542419434, but found wealthpooled=0.14348289370536804 instead
Expected households=24, but found households=25.0 instead
Expected lat=-0.43243399262428284, but found lat=-0.5157830119132996 instead
Expected lon=37.57339477539

Expected lat=1.7372039556503296, but found lat=-0.48871299624443054 instead
Expected lon=40.0569953918457, but found lon=37.00611114501953 instead
Expected wealthpooled=0.9609355926513672, but found wealthpooled=0.6038946509361267 instead
Expected households=24, but found households=25.0 instead
Expected urban_rural=1, but found urban_rural=0.0 instead
Expected lat=1.6382880210876465, but found lat=-0.5879639983177185 instead
Expected lon=39.422847747802734, but found lon=36.937259674072266 instead
Expected wealthpooled=-0.9255240559577942, but found wealthpooled=0.6378392577171326 instead
Expected households=20, but found households=25.0 instead
Expected lat=2.832750082015991, but found lat=-0.52920001745224 instead
Expected lon=40.918399810791016, but found lon=36.946807861328125 instead
Expected wealthpooled=-0.6655680537223816, but found wealthpooled=0.09650744497776031 instead
Expected households=22, but found households=24.0 instead
Expected urban_rural=1, but found urban_rural=0

Expected lat=-0.686152994632721, but found lat=-0.6287069916725159 instead
Expected lon=34.77263641357422, but found lon=37.41975402832031 instead
Expected wealthpooled=0.8522014617919922, but found wealthpooled=-0.48903489112854004 instead
Expected households=24, but found households=25.0 instead
Expected urban_rural=1, but found urban_rural=0.0 instead
Expected lat=-0.6536999940872192, but found lat=-0.6705340147018433 instead
Expected lon=34.78951644897461, but found lon=37.42005920410156 instead
Expected wealthpooled=-0.2736639976501465, but found wealthpooled=-0.4975834786891937 instead
Expected households=24, but found households=21.0 instead
Expected lat=-0.8070930242538452, but found lat=-0.6041370034217834 instead
Expected lon=34.944244384765625, but found lon=37.37736892700195 instead
Expected wealthpooled=0.3310481309890747, but found wealthpooled=-0.14451463520526886 instead
Expected urban_rural=1, but found urban_rural=0.0 instead
Expected lat=-0.9236339926719666, but foun

Expected lat=1.1787110567092896, but found lat=-0.7302809953689575 instead
Expected lon=35.159263610839844, but found lon=37.16413497924805 instead
Expected wealthpooled=-0.21344874799251556, but found wealthpooled=0.6530092358589172 instead
Expected households=24, but found households=22.0 instead
Expected urban_rural=0, but found urban_rural=1.0 instead
Expected lat=1.1089049577713013, but found lat=-0.7106530070304871 instead
Expected lon=34.970977783203125, but found lon=37.172088623046875 instead
Expected wealthpooled=-0.2660691440105438, but found wealthpooled=1.2242066860198975 instead
Expected households=24, but found households=23.0 instead
Expected urban_rural=0, but found urban_rural=1.0 instead
Expected lat=0.49400898814201355, but found lat=-0.6489030122756958 instead
Expected lon=35.73976135253906, but found lon=37.080020904541016 instead
Expected wealthpooled=0.6449058055877686, but found wealthpooled=-0.048194002360105515 instead
Expected households=21, but found househ

Expected lat=-0.6952729821205139, but found lat=-0.8377479910850525 instead
Expected lon=36.44089126586914, but found lon=37.09904479980469 instead
Expected wealthpooled=0.8302162289619446, but found wealthpooled=-0.4837827980518341 instead
Expected households=25, but found households=24.0 instead
Expected urban_rural=1, but found urban_rural=0.0 instead
Expected lat=-0.5136070251464844, but found lat=-0.8014010190963745 instead
Expected lon=36.338218688964844, but found lon=37.10893249511719 instead
Expected wealthpooled=-0.13657121360301971, but found wealthpooled=0.14615757763385773 instead
Expected households=23, but found households=22.0 instead
Expected urban_rural=0, but found urban_rural=1.0 instead
Expected lat=-0.4944239854812622, but found lat=-0.7962909936904907 instead
Expected lon=35.57129669189453, but found lon=37.09922409057617 instead
Expected wealthpooled=-0.5616186261177063, but found wealthpooled=0.03785029053688049 instead
Expected lat=-1.0151699781417847, but fou

Expected lat=0.3196989893913269, but found lat=-1.1368989944458008 instead
Expected lon=34.44491195678711, but found lon=36.768638610839844 instead
Expected wealthpooled=-0.42868340015411377, but found wealthpooled=0.6877092719078064 instead
Expected households=24, but found households=21.0 instead
Expected urban_rural=1, but found urban_rural=0.0 instead
Expected lat=0.11484400182962418, but found lat=-1.0928889513015747 instead
Expected lon=34.722816467285156, but found lon=36.79035186767578 instead
Expected wealthpooled=-0.2578733563423157, but found wealthpooled=0.5269313454627991 instead
Expected households=22, but found households=25.0 instead
Expected urban_rural=1, but found urban_rural=0.0 instead
Expected lat=0.11049800366163254, but found lat=-1.0312550067901611 instead
Expected lon=34.76728057861328, but found lon=36.84781265258789 instead
Expected wealthpooled=-0.05261940136551857, but found wealthpooled=0.5093560814857483 instead
Expected households=23, but found househol

KeyError: 1589

In [42]:
crosscheck_records(csv_path='data/ke_dhs_surveys/kenya_2015.csv', tfrecord_path=TFRECORD_PATH_TEMPLATE.format(year=2015))

  0%|          | 0/245 [00:00<?, ?it/s]

Expected lat=-0.0021780000533908606, but found lat=-1.3389370441436768 instead
Expected lon=36.364097595214844, but found lon=36.759521484375 instead
Expected wealthpooled=0.9807378649711609, but found wealthpooled=1.2465413808822632 instead
Expected households=25, but found households=17.0 instead
Expected urban_rural=0, but found urban_rural=1.0 instead
Expected lat=-0.5667009949684143, but found lat=-1.2705830335617065 instead
Expected lon=37.04641342163086, but found lon=36.96373748779297 instead
Expected wealthpooled=1.0735020637512207, but found wealthpooled=1.5495524406433105 instead
Expected households=29, but found households=24.0 instead
Expected urban_rural=0, but found urban_rural=1.0 instead
Expected lat=-0.5231860280036926, but found lat=-1.3225150108337402 instead
Expected lon=37.261741638183594, but found lon=36.90370178222656 instead
Expected wealthpooled=1.2819308042526245, but found wealthpooled=1.7608996629714966 instead
Expected households=30, but found households=

Expected lat=0.19295500218868256, but found lat=-0.899258017539978 instead
Expected lon=35.07558059692383, but found lon=37.20903396606445 instead
Expected wealthpooled=0.3301715552806854, but found wealthpooled=1.0255366563796997 instead
Expected households=25, but found households=23.0 instead
Expected lat=-0.0638590008020401, but found lat=-1.061885952949524 instead
Expected lon=36.078407287597656, but found lon=36.6513557434082 instead
Expected wealthpooled=-0.3810509145259857, but found wealthpooled=0.48833975195884705 instead
Expected households=25, but found households=27.0 instead
Expected lat=-0.741104006767273, but found lat=-1.0773639678955078 instead
Expected lon=36.511444091796875, but found lon=36.98975372314453 instead
Expected wealthpooled=0.3665660619735718, but found wealthpooled=0.534129798412323 instead
Expected households=30, but found households=28.0 instead
Expected urban_rural=1, but found urban_rural=0.0 instead
Expected lat=-1.6821240186691284, but found lat=-

KeyError: 245

# Splitting the TFRecord Files into Training, Validation and Testing Datasets

In [10]:
# Load in our dataset, shuffle the input features.
files = os.listdir(DHS_FOLDER) 
files = [f'{DHS_FOLDER}/{fn}' for fn in files] 
files

['data/ke_dhs_tfrecords_raw/kenya_2014.tfrecord.gz',
 'data/ke_dhs_tfrecords_raw/kenya_2015.tfrecord.gz']

In [11]:
# Getting the size/len of the records in each tfrecord files
options = tf.io.TFRecordOptions(compression_type = 'GZIP')

DATASET_SIZE = 0
for file in files:
    size = len([x for x in tf.compat.v1.io.tf_record_iterator(file, options=options)])
    
    DATASET_SIZE += size

DATASET_SIZE   

1830

In [12]:
# splitting the datasets into training, evaluating and testing sets
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

dataset = tf.data.TFRecordDataset(files, compression_type='GZIP')
dataset = dataset.shuffle(buffer_size=100)
trainset = dataset.take(train_size)
testset = dataset.skip(train_size)
valset = testset.skip(test_size)
testset = testset.take(test_size)

In [14]:
def write(dataset, outdir, outfile, compress=True):
    '''
    Writes the datasets into a TFrecord File
  
    Args: 
    dataset: Variable, dataset to write into tfrecord
    outdir: String, path in which to save the serialized dataset
    outfile: String, name in which to save the serialized tfrecord file
    compress: Boolean, whether or not to apply gzip compression
    Returns: None
    '''
    os.makedirs(outdir, exist_ok=True)
    outpath = os.path.join(outdir, outfile)

    if compress == True:
        writer = tf.data.experimental.TFRecordWriter(outpath, compression_type='GZIP')
        writer.write(dataset)
        
    else:
        writer = tf.data.experimental.TFRecordWriter(outpath)
        writer.write(dataset)


In [15]:
outdir = 'data/datasets'

write(dataset=trainset, outdir=outdir, outfile='trainset.tfrecord.gz', compress=True)
write(dataset=valset, outdir=outdir, outfile='valset.tfrecord.gz', compress=True)
write(dataset=testset, outdir=outdir, outfile='testset.tfrecord.gz', compress=True)

In [None]:
feature_types = get_feature_types('data/datasets/valset.tfrecord.gz')
print('Features and types:')
pprint(feature_types)