In [1]:
import numpy as np
from tfrecords2numpy import TFRecordsParser
from tfrecords2numpy import TFRecordsElevation
import os
import pickle
from tqdm import tqdm
import pandas as pd
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.preprocessing import MinMaxScaler
import webdataset as wds
import xarray as xr
import h5py
import random

In [2]:
# --- Inputs -----------------------------------------------------------------
# Pickled forcing data; extract_data indexes it as forcing_data[file_date][idx].
PATH_TO_FORCING = "/glade/work/yiwenz/TransferLearning/forcing_daily.pkl"
# Directory walked by extract_data for per-date TFRecord files (name "<date>.*").
PATH_TO_DATA = "/glade/scratch/yiwenz/Data_TFRecord_Daily/"
# CSV with "lon"/"lat" columns, read by inverse_distance_mean.
PATH_TO_LATLONG = "/glade/u/home/yiwenz/TransferLearning/modis_lon_lat.csv"
# TFRecord handed to TFRecordsElevation for the per-pixel elevation patches.
PATH_TO_ELEVATIONS = "/glade/work/yiwenz/AWS3D30_cropped.tfrecord"

# --- Output -----------------------------------------------------------------
# HDF5 file that extract_data writes the "phoenix" dataset into.
PATH_TO_STORE = "/glade/scratch/priyamm/transfer_learning_data/phoenix_data.h5"

# Bands requested from TFRecordsParser. Order matters: extract_data reads
# position 0 as Red, 3 as NIR and 4 as SWIR1 when deriving NDVI/NDBI.
channel_choices = [
    "Red",
    "Green",
    "Blue",
    "NIR",
    "SWIR1",
]

In [22]:
def inverse_distance_mean(path=None, n_neighbors=7):
    """Build normalised inverse-distance weights to each pixel's nearest neighbours.

    Reads the ``lon``/``lat`` columns of a CSV (one row per pixel). For every row
    ``idx`` it finds the ``n_neighbors`` closest *other* rows by Euclidean
    distance in (lon, lat) space, and weights them by 1 / distance.

    :param path: CSV to read; ``PATH_TO_LATLONG`` is used when None.
    :param n_neighbors: neighbours per pixel. It is capped at ``n_rows - 1``.
    :return: dict mapping row index -> ``{"weights": ndarray, "idx": ndarray}``.
        ``idx`` holds neighbour row indices, nearest first. ``weights`` sums to 1
        (it is empty when the CSV has a single row).
    """
    if path is None:
        path = PATH_TO_LATLONG
    coords = pd.read_csv(path, usecols=["lon", "lat"]).to_numpy(dtype=float)
    n_rows = len(coords)
    k = max(0, min(n_neighbors, n_rows - 1))

    weights_idx_dict = {}
    for idx in range(n_rows):
        # One row of distances at a time: O(n) memory instead of an n x n matrix.
        dist = np.linalg.norm(coords - coords[idx], axis=1)
        # Exclude the pixel itself explicitly. Slicing off position 0 of the sort
        # is wrong when another pixel shares its coordinates (tie at distance 0).
        dist[idx] = np.inf
        sort_idx = np.argsort(dist, kind="stable")[:k]
        closest = dist[sort_idx]

        coincident = closest == 0
        if coincident.any():
            # 1/0 would be inf and turn into nan after normalising. Coincident
            # neighbours instead share the full weight equally.
            weights = coincident.astype(float)
        else:
            weights = 1.0 / closest
        if k:
            weights = weights / weights.sum()
        weights_idx_dict[idx] = {"weights": weights, "idx": sort_idx}

    return weights_idx_dict

def _normalized_difference(band_a, band_b):
    """Return (a - b) / (a + b) of two reflectance bands, reshaped to (-1, 33, 33).

    Both bands are first mapped back to digital numbers with (refl + 0.2) / 2.75e-05,
    exactly as the original pipeline did.
    """
    dn_a = (band_a + 0.2) / 2.75e-05
    dn_b = (band_b + 0.2) / 2.75e-05
    # NOTE(review): inf/nan when dn_a + dn_b == 0 (both bands at -0.2); such
    # values are written through unchanged.
    return ((dn_a - dn_b) / (dn_a + dn_b)).reshape(-1, 33, 33)


def _smooth_with_neighbours(features, records, neighbours, center_weight, outside_weight):
    """Blend a pixel's features with the inverse-distance-weighted mean of its neighbours."""
    weights = np.array(neighbours["weights"]) * outside_weight
    surround_features = np.array([records[i][0] for i in neighbours["idx"]])
    weighted_avg = np.average(surround_features, weights=weights, axis=0)
    # NOTE(review): neighbours are not screened for -9999 fill values or missing
    # LST before being averaged in.
    return np.average([features, weighted_avg], weights=[center_weight, outside_weight], axis=0)


def _build_sample_row(features, lst, month, forcing, elevation):
    """Flatten one pixel into [forcing..., month, lst, image...].

    The image stacks the input bands, NDBI, NDVI and elevation along axis 0.
    """
    # Band positions follow channel_choices: 0=Red, 3=NIR, 4=SWIR1.
    ndbi = _normalized_difference(features[4], features[3])
    ndvi = _normalized_difference(features[3], features[0])
    image = np.concatenate([features, ndbi, ndvi, elevation.reshape(-1, 33, 33)], axis=0)
    lst = np.array(lst).reshape(-1,)
    return np.concatenate((forcing, month, lst, image.flatten()))


def extract_data(path, center_weight=1, seed=None, overwrite=False):
    """
    Convert the daily TFRecord files into one HDF5 dataset named 'phoenix'.

    Every TFRecord file under PATH_TO_DATA whose date (file name before the first
    ".") has an entry in the forcing pickle is parsed. Each pixel with a valid LST
    and no -9999 band values becomes one float32 row of length 8721. Rows are
    shuffled within each file and appended to `path`.

    Row layout, for a single row `arr`:
    arr[:7] -> 7 forcing variables
    arr[7] -> month indicator (int(file_date[2:4]); assumes that date-string layout)
    arr[8] -> LST target
    arr[9:].reshape(8, 33, 33) -> image tensor: the 5 channel_choices bands, NDBI,
                                  NDVI, elevation

    :param path: HDF5 file to write (opened in append mode).
    :param center_weight: weight of the pixel itself versus its neighbours. With
        the default 1, no neighbour smoothing is done and no neighbour table is built.
    :param seed: seed for the file-order and row shuffles (None = unseeded).
    :param overwrite: replace an existing 'phoenix' dataset instead of raising.
    :return: the string "DONE".
    :raises ValueError: if 'phoenix' already exists in `path` and overwrite is False.
    """
    outside_weight = 1 - center_weight
    # The neighbour table is only needed when smoothing is actually enabled.
    weights_idx_dict = inverse_distance_mean() if outside_weight != 0 else None
    elevations_dict = TFRecordsElevation(filepath=PATH_TO_ELEVATIONS).tfrecrods2numpy()
    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    with open(PATH_TO_FORCING, 'rb') as f:
        forcing_data = pickle.load(f)

    py_rng = random.Random(seed)
    np_rng = np.random.default_rng(seed)
    tot_samples = 0
    hdf5_dataset = None

    with h5py.File(path, "a") as hp5:
        if 'phoenix' in hp5:
            if not overwrite:
                raise ValueError(
                    f"dataset 'phoenix' already exists in {path}; pass overwrite=True to rebuild it")
            del hp5['phoenix']

        for root, _dirs, files in os.walk(PATH_TO_DATA):
            py_rng.shuffle(files)
            for file in tqdm(files, desc="Total Progress"):
                file_date = file.split(".")[0]
                # Check the date before parsing, so files without forcing data
                # are never read.
                if file_date not in forcing_data:
                    continue
                # Use the directory os.walk is currently in, not the top-level dir.
                records = TFRecordsParser(os.path.join(root, file), channels=channel_choices).tfrecrods2numpy()
                month = np.array([int(file_date[2:4])])

                rows = []
                # ASSUMPTION: record position idx lines up with the row order of
                # the lat/lon CSV, elevations_dict and forcing_data[file_date].
                for idx, (features, lst) in enumerate(records):
                    if lst is False or not (features != -9999).all():
                        continue
                    if weights_idx_dict is not None:
                        features = _smooth_with_neighbours(
                            features, records, weights_idx_dict[idx], center_weight, outside_weight)
                    rows.append(_build_sample_row(
                        features, lst, month, forcing_data[file_date][idx], elevations_dict[idx]))

                if not rows:
                    continue  # only append if data exists
                samples = np.array(rows)
                np_rng.shuffle(samples)
                new_samples_num, feature_length = samples.shape
                assert feature_length == 8721, f"unexpected row length {feature_length}"

                if hdf5_dataset is None:
                    hdf5_dataset = hp5.create_dataset(
                        'phoenix', (new_samples_num, feature_length),
                        maxshape=(None, feature_length), dtype='float32')
                else:
                    hdf5_dataset.resize(tot_samples + new_samples_num, axis=0)
                # Each batch is written exactly once, into the rows it just got.
                hdf5_dataset[tot_samples:] = samples
                tot_samples += new_samples_num
    return "DONE"
                

In [None]:
# Build the HDF5 training file (the recorded run below processed 6024 files in ~6h15m).
extract_data(path=PATH_TO_STORE)

Total Progress: 100%|██████████| 6024/6024 [6:15:44<00:00,  3.74s/it]  


In [None]:
# Read-only handle on the generated HDF5 file; it stays open until f.close() is called.
f = h5py.File(PATH_TO_STORE, mode="r")