In [None]:
# default_exp data.tfrecord

In [None]:
# export
import os 
import os.path as path
from typing import Tuple

import dateutil.parser as parser
import numpy as np
import pandas as pd
import skimage
import tensorflow as tf
from osgeo import gdal

import airathon.data as data
import airathon.paths as paths
from airathon.model.modis import load_modis

from tqdm import tqdm


# Convenience alias so the dataset class can be referred to as `Dataset`.
Dataset = tf.data.Dataset

In [None]:
import matplotlib.pyplot as plt
from IPython.display import display

In [None]:
# Show the installed TensorFlow version so the run can be reproduced.
tf.__version__

'2.7.0'

# TF Record IO

## Loading Metadata

In [None]:
# export 
# Grid metadata table, read from the dataset metadata directory. The helpers
# in this module read its `grid_id` and `location` columns.
grid_metadata = pd.read_csv(
    os.path.join(paths.dataset_metadata(), "grid_metadata.csv")
)


In [None]:
# export
def get_grid_data(metadata: pd.DataFrame, grid_id: str) -> pd.DataFrame:
    """Return the rows of `metadata` whose `grid_id` column equals `grid_id`.

    The result is an empty frame when no row matches.
    """
    matches = metadata["grid_id"].eq(grid_id)
    return metadata.loc[matches]

In [None]:
# export 
# Satellite granule metadata table, read from the dataset metadata directory.
satellite_metadata = pd.read_csv(path.join(
    paths.dataset_metadata(), "satellite_metadata.csv"))

# `time_end` holds full timestamps with a UTC offset (for example
# "2018-02-01 19:10:00+00:00"), which a date-only format string does not
# describe. Let pandas parse the timestamps and normalise them to UTC, which
# gives a timezone-aware `Date` column usable with the `.dt` accessor.
satellite_metadata['Date'] = pd.to_datetime(
    satellite_metadata['time_end'], utc=True)

In [None]:
# Quick look at the satellite metadata: distinct location codes, the first
# rows, and the total row count.
print(set(satellite_metadata["location"]))
display(satellite_metadata.head())
print(f"size = {satellite_metadata.shape[0]}")

{'tpe', 'dl', 'la'}


Unnamed: 0,granule_id,time_start,time_end,product,location,split,us_url,eu_url,as_url,cksum,granule_size,Date
0,20180201T191000_maiac_la_0.hdf,2018-02-01T17:25:00.000Z,2018-02-01 19:10:00+00:00,maiac,la,train,s3://drivendata-competition-airathon-public-us...,s3://drivendata-competition-airathon-public-eu...,s3://drivendata-competition-airathon-public-as...,911405771,10446736,2018-02-01 19:10:00+00:00
1,20180202T195000_maiac_la_0.hdf,2018-02-02T18:05:00.000Z,2018-02-02 19:50:00+00:00,maiac,la,train,s3://drivendata-competition-airathon-public-us...,s3://drivendata-competition-airathon-public-eu...,s3://drivendata-competition-airathon-public-as...,2244451908,11090180,2018-02-02 19:50:00+00:00
2,20180203T203000_maiac_la_0.hdf,2018-02-03T17:10:00.000Z,2018-02-03 20:30:00+00:00,maiac,la,train,s3://drivendata-competition-airathon-public-us...,s3://drivendata-competition-airathon-public-eu...,s3://drivendata-competition-airathon-public-as...,3799527997,12468482,2018-02-03 20:30:00+00:00
3,20180204T194000_maiac_la_0.hdf,2018-02-04T17:55:00.000Z,2018-02-04 19:40:00+00:00,maiac,la,train,s3://drivendata-competition-airathon-public-us...,s3://drivendata-competition-airathon-public-eu...,s3://drivendata-competition-airathon-public-as...,4105997844,13064424,2018-02-04 19:40:00+00:00
4,20180205T202000_maiac_la_0.hdf,2018-02-05T17:00:00.000Z,2018-02-05 20:20:00+00:00,maiac,la,train,s3://drivendata-competition-airathon-public-us...,s3://drivendata-competition-airathon-public-eu...,s3://drivendata-competition-airathon-public-as...,1805072340,12549313,2018-02-05 20:20:00+00:00


size = 7721


In [None]:
# export
def get_satellite_meta(
        metadata, datetime: str, location: str, datatype: str, split: str):
    """Select the satellite granules matching a label's place and calendar day.

    Args:
        metadata: satellite metadata frame with `location`, `product` and
            `split` columns plus a datetime-typed `Date` column.
        datetime: timestamp string of the label; parsed with dateutil.
        location: either a full name ("Delhi", "Taipei") or one of the short
            codes used by the metadata ("dl", "tpe", "la"). Any other value
            falls back to "la".
        datatype: value to match in the `product` column.
        split: value to match in the `split` column.

    Returns:
        The rows of `metadata` for that location/product/split whose `Date`
        has the same month and day as `datetime` and a year not later than
        the year of `datetime`.
    """
    # Callers pass both full city names and the metadata's own short codes.
    # Short codes must map to themselves; otherwise "dl" and "tpe" would fall
    # through to the "la" default and select the wrong city's granules.
    location_codes = {"Delhi": "dl", "Taipei": "tpe", "dl": "dl", "tpe": "tpe"}
    location = location_codes.get(location, "la")

    # filtering
    metadata = metadata[metadata['location'] == location]
    metadata = metadata[metadata['product'] == datatype]
    metadata = metadata[metadata['split'] == split]
    dateobject = parser.parse(datetime)

    # NOTE(review): `<=` on the year also matches the same calendar day in
    # earlier years -- confirm this is intended rather than `==`.
    return metadata.loc[(metadata['Date'].dt.month == dateobject.month) &
                        (metadata['Date'].dt.day == dateobject.day) &
                        (metadata['Date'].dt.year <= dateobject.year)]

In [None]:
# export
# Names for the 13 sub-datasets of a granule, by position: "sds_0" .. "sds_12".
maiac_subset_names = ["sds_" + str(position) for position in range(13)]
# Positions of the sub-datasets that are actually read and stored.
maiac_subset_indices = [0, 3, 4, 8]

def fetch_subset(year: str, granule_id: str, split: str) -> dict:
    """Load the selected sub-datasets of one granule as resized float32 arrays.

    Args:
        year: forwarded unchanged to `load_modis`.
        granule_id: forwarded unchanged to `load_modis`.
        split: forwarded unchanged to `load_modis`.

    Returns:
        A dict with one entry per index in `maiac_subset_indices`, keyed by
        the matching name from `maiac_subset_names`. Each value is a float32
        array resized to shape (240, 240, 4).
    """
    modis = load_modis(year, granule_id, split)
    subdataset = modis.GetSubDatasets()  # List[tuple]

    features = dict()

    for index in maiac_subset_indices:
        url, _ = subdataset[index]
        raster = gdal.Open(url)
        raster = raster.ReadAsArray()

        # Swap the first and last axes before resizing.
        # NOTE(review): presumably turns a bands-first array into bands-last;
        # this also exchanges the two spatial axes -- confirm that is intended.
        raster = np.swapaxes(raster, 0, 2)

        # Only the resized copy is kept; the full-resolution raster is not
        # retained, so it can be freed before the next sub-dataset is read.
        raster = skimage.transform.resize(
            raster,
            output_shape=(240, 240, 4),
            anti_aliasing=False)

        features[maiac_subset_names[index]] = raster.astype(np.float32)

    return features

In [None]:
# Smoke test: fetch the sub-datasets of one training granule and list the
# names that come back.
subset_features = fetch_subset(
    "2018", "20180201T191000_maiac_la_0.hdf", "train")

subset_features.keys()

dict_keys(['sds_0', 'sds_3', 'sds_4', 'sds_8'])

In [None]:
# Training labels table, read from the dataset metadata directory.
df = pd.read_csv(os.path.join(paths.dataset_metadata(), "train_labels.csv"))

df.head()

Unnamed: 0,datetime,grid_id,value
0,2018-02-01T08:00:00Z,3S31A,11.4
1,2018-02-01T08:00:00Z,A2FBI,17.0
2,2018-02-01T08:00:00Z,DJN0F,11.1
3,2018-02-01T08:00:00Z,E5P9N,22.1
4,2018-02-01T08:00:00Z,FRITQ,29.8


## Creating TF Record

In [None]:
# export
def image_feature(image: np.ndarray):
    """Wrap an array as a float-list `tf.train.Feature`, flattened to 1-D."""
    flat_values = image.flatten()
    float_list = tf.train.FloatList(value=flat_values)
    return tf.train.Feature(float_list=float_list)


def float_feature(v: float):
    """Wrap a single float as a one-element float-list `tf.train.Feature`."""
    float_list = tf.train.FloatList(value=[v])
    return tf.train.Feature(float_list=float_list)


def string_feature(s: str):
    """Wrap a string, UTF-8 encoded, as a one-element bytes-list `tf.train.Feature`."""
    encoded = s.encode("UTF-8")
    bytes_list = tf.train.BytesList(value=[encoded])
    return tf.train.Feature(bytes_list=bytes_list)

In [None]:
# export
def label_to_example(label: dict) -> tf.train.Example:
    """Convert one label dict into a `tf.train.Example`.

    Args:
        label: dict with the scalar entries `location`, `grid_id`, `datetime`
            (strings) and `value` (float). Every other entry is treated as an
            image array and stored as a flattened float list under its key.

    Returns:
        A `tf.train.Example` holding one feature per key of `label`.

    The input dict is left unmodified.
    """
    scalar_keys = ("location", "grid_id", "datetime", "value")

    feature = {
        "location": string_feature(label["location"]),
        "grid_id": string_feature(label["grid_id"]),
        "datetime": string_feature(label["datetime"]),
        "value": float_feature(label["value"])
    }

    # Skip the scalar keys instead of deleting them from the caller's dict,
    # so the label can still be inspected or reused after serialisation.
    for name, image in label.items():
        if name not in scalar_keys:
            feature[name] = image_feature(image)

    example = tf.train.Example(features=tf.train.Features(feature=feature))

    return example

In [None]:
# export
def series_to_locations(series):
    """Return the short location code for the grid cell of a label row.

    Args:
        series: a row with a `grid_id` entry.

    Returns:
        "dl" when the grid metadata lists the cell's location as "Delhi",
        "tpe" for "Taipei", and "la" for any other location.

    Raises:
        IndexError: if `grid_id` has no row in `grid_metadata`.
    """
    grid_id = series["grid_id"]
    grid_data = get_grid_data(grid_metadata, grid_id)

    if len(grid_data) == 0:
        # Same exception type as the positional lookup below would raise,
        # but with a message that names the offending grid id.
        raise IndexError(f"grid_id {grid_id!r} not found in grid metadata")

    location = grid_data.iloc[0]["location"]

    location_codes = {"Delhi": "dl", "Taipei": "tpe"}
    return location_codes.get(location, "la")


def series_to_subset_infos(series, split = "train"):
    """List the satellite granules to use for one label row.

    Args:
        series: a row with `datetime` and `location` entries.
        split: dataset split passed on to `get_satellite_meta`.

    Returns:
        A list of `(granule_id, time_end)` tuples, one per matching "maiac"
        row of `satellite_metadata`, with `time_end` parsed into a datetime.
        The list is empty when no granule matches.
    """
    datetime = series["datetime"]  # type: str
    location = series["location"]

    satellite_data = get_satellite_meta(
        satellite_metadata,
        datetime,
        location,
        "maiac",  # or 'misr'
        split)

    # Walk the two needed columns directly rather than materialising each
    # row with `iloc` twice.
    return [
        (granule_id, parser.parse(time_end))
        for granule_id, time_end in zip(
            satellite_data["granule_id"], satellite_data["time_end"])
    ]

def row_to_label(row, split):
    """Build the label dict for one row, averaging imagery over its granules.

    Args:
        row: mapping with `datetime`, `grid_id`, `location`, `value` and
            `subset_info` (a sequence of `(granule_id, time_end)` tuples).
        split: dataset split passed on to `fetch_subset`.

    Returns:
        A dict with `location`, `grid_id`, `datetime`, `value`, plus one
        entry per selected sub-dataset name holding the element-wise mean of
        that sub-dataset over all granules in `subset_info`.

    Raises:
        ValueError: if `subset_info` is empty. Averaging nothing would yield
            a NaN scalar instead of an image, and the record written from it
            would not have the image shape the reader expects.
    """
    datetime = row["datetime"]  # type: str
    grid_id = row["grid_id"]
    location = row["location"]

    subset_infos = row["subset_info"]

    if len(subset_infos) == 0:
        raise ValueError(
            f"no satellite granules for grid_id {grid_id!r} at {datetime!r}")

    names = [maiac_subset_names[index] for index in maiac_subset_indices]
    images = {name: [] for name in names}

    for granule_id, time_end in subset_infos:
        # Each granule is loaded from the year of its own end time.
        new_images = fetch_subset(str(time_end.year), granule_id, split)

        for name, image in new_images.items():
            images[name].append(image)

    # Collapse the per-granule stack into a single mean image per name.
    for name in names:
        images[name] = np.array(images[name]).mean(axis=0)

    label = {
        "location": location,
        "grid_id": grid_id,
        "datetime": datetime,
        "value": row["value"],
        **images
    }

    return label


In [None]:
# Try the row helpers on a reproducible sample of ten label rows, then build
# the label dict for the first sampled row.
test_df = (
    df.sample(n=10, axis=0, random_state=0)
    .assign(location=lambda frame: frame.apply(series_to_locations, axis=1))
    .assign(subset_info=lambda frame: frame.apply(
        series_to_subset_infos, axis=1, split="train"))
)

label = row_to_label(test_df.iloc[0], "train")
label.keys()

dict_keys(['location', 'grid_id', 'datetime', 'value', 'sds_0', 'sds_3', 'sds_4', 'sds_8'])

In [None]:
# export
def create_tfrecord(dataframe: pd.DataFrame, split: str, path: str):
    """Write one serialized example per row of `dataframe` to a TFRecord file.

    Args:
        dataframe: label rows; it is copied, so the caller's frame is not
            modified.
        split: dataset split passed on to the row helpers.
        path: destination file for the TFRecord writer.

    Progress messages are printed and a progress bar is shown while writing.
    """
    frame = dataframe.copy()

    print("fetching locations...")
    frame["location"] = frame.apply(series_to_locations, axis=1)

    print("fetching subset_infos...")
    frame["subset_info"] = frame.apply(
        series_to_subset_infos, axis=1, split=split)

    with tf.io.TFRecordWriter(path) as writer:
        with tqdm(total=len(frame)) as progress_bar:
            for _, row in frame.iterrows():
                serialized = label_to_example(
                    row_to_label(row, split)).SerializeToString()
                writer.write(serialized)

                progress_bar.update(1)


In [None]:
# Smoke test: write the first three label rows to a small TFRecord file.
create_tfrecord(df.iloc[:3], split="train", path="test.tfrecord")

fetching locations...
fetching subset_infos...


100%|██████████| 3/3 [00:02<00:00,  1.19it/s]


## Reading TF Record

In [None]:
# export
def _decode(raw_person):
    """Parse one serialized example into a dict of tensors.

    The parsing spec has a scalar float `value`, scalar strings `location`,
    `datetime` and `grid_id`, and one (240, 240, 4) float32 feature for each
    selected sub-dataset name.
    """
    string_keys = ("location", "datetime", "grid_id")
    image_keys = [maiac_subset_names[index] for index in maiac_subset_indices]

    spec = {
        "value": tf.io.FixedLenFeature([], dtype=tf.float32),
        **{
            key: tf.io.FixedLenFeature([], dtype=tf.string)
            for key in string_keys
        },
        **{
            key: tf.io.FixedLenFeature((240, 240, 4), dtype=tf.float32)
            for key in image_keys
        },
    }

    return tf.io.parse_single_example(raw_person, spec)

In [None]:
# export
def load_tfrecord(path: str):
    """Open the TFRecord file at `path` as a dataset of decoded examples."""
    raw_records = tf.data.TFRecordDataset(path)
    return raw_records.map(_decode)

In [None]:
# Read the small TFRecord file back and inspect the per-record structure.
read_ds = load_tfrecord(path="test.tfrecord")
read_ds.element_spec

{'datetime': TensorSpec(shape=(), dtype=tf.string, name=None),
 'grid_id': TensorSpec(shape=(), dtype=tf.string, name=None),
 'location': TensorSpec(shape=(), dtype=tf.string, name=None),
 'sds_0': TensorSpec(shape=(240, 240, 4), dtype=tf.float32, name=None),
 'sds_3': TensorSpec(shape=(240, 240, 4), dtype=tf.float32, name=None),
 'sds_4': TensorSpec(shape=(240, 240, 4), dtype=tf.float32, name=None),
 'sds_8': TensorSpec(shape=(240, 240, 4), dtype=tf.float32, name=None),
 'value': TensorSpec(shape=(), dtype=tf.float32, name=None)}

In [None]:
# Peek at the first two decoded records: their keys and one image shape.
for element in read_ds.take(2):
    print(element.keys(), element["sds_0"].shape, sep="\n")
    print()

dict_keys(['datetime', 'grid_id', 'location', 'sds_0', 'sds_3', 'sds_4', 'sds_8', 'value'])
(240, 240, 4)

dict_keys(['datetime', 'grid_id', 'location', 'sds_0', 'sds_3', 'sds_4', 'sds_8', 'value'])
(240, 240, 4)

