In [1]:
import numpy as np
from tfrecords2numpy_qc import TFRecordsParser
from tfrecords2numpy_qc import TFRecordsElevation
import os
import pickle
from tqdm import tqdm
import pandas as pd
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.preprocessing import MinMaxScaler
import webdataset as wds
import xarray as xr
import h5py
import random
from collections import Counter

In [5]:
# --- Input / output locations (absolute paths on the GLADE filesystem) ---
# Pickled forcing data, indexed by date string and then by pixel index.
PATH_TO_FORCING = "/glade/scratch/yiwenz/CESM_0.125_US_cities/forcing_Phoenix_Mesa_AZ.pkl"
# Directory holding one TFRecord file per day (file name stem is the date).
PATH_TO_DATA = "/glade/scratch/yiwenz/Data_TFRecord_Daily/"
# HDF5 file the extracted samples are written to.
PATH_TO_STORE = "/glade/scratch/yiwenz/transfer_learning_hdf5/phoenix_data.h5"
# CSV with "lon" / "lat" columns used to find neighbouring grid points.
PATH_TO_LATLONG = "/glade/scratch/yiwenz/modis_lat_lon_US_cities/modis_grid_Phoenix.csv"
# TFRecord file with the elevation patches.
PATH_TO_ELEVATIONS = "/glade/work/yiwenz/AWS3D30_cropped.tfrecord"

# --- Extraction settings ---
# Channels requested from the TFRecord parser; the order matters because
# extract_data indexes the resulting feature array positionally.
channel_choices = ["Red", "Green", "Blue", "NIR", "SWIR1"]
# A day is used only if at most this fraction of its records has a missing LST.
threshold_cloud = 0.15

In [6]:
def inverse_distance_mean(path=None, n_neighbors=7):
    """
    Compute normalised inverse-distance weights to each grid point's nearest neighbours.

    :param path: CSV file with "lon" and "lat" columns. Defaults to PATH_TO_LATLONG.
    :param n_neighbors: number of nearest neighbours kept per point (default 7).
    :return: dict mapping each row index of the CSV to
        {"weights": array of n_neighbors weights summing to 1,
         "idx": array of the row indices of those neighbours}.

    Distances are plain Euclidean distances on the raw lon/lat values.
    If two rows share identical coordinates, their mutual distance is 0 and the
    inverse below yields inf, which turns that point's weights into NaN.
    """
    if path is None:
        path = PATH_TO_LATLONG

    latlong = pd.read_csv(path, usecols=["lon", "lat"])
    coords = latlong.values
    dist = euclidean_distances(coords, coords)

    weights_idx_dict = {}
    for idx in range(len(dist)):
        # Position 0 of the sorted row is skipped: it is the point itself
        # (distance 0) as long as no other row has the same coordinates.
        sort_idx = dist[idx, :].argsort()[1:n_neighbors + 1]
        closest = dist[idx, sort_idx]**-1
        # Normalise so the weights of one point sum to 1.
        closest = closest / closest.sum()
        weights_idx_dict[idx] = {"weights": closest, "idx": sort_idx}

    return weights_idx_dict

def extract_data(path):
    """
    Extract samples from every daily TFRecord file and append them to an HDF5 dataset.

    For each file under PATH_TO_DATA whose name stem is a date present in the
    forcing data, the records are parsed, the day is skipped if too many
    records lack an LST value (see threshold_cloud), and every remaining valid
    pixel becomes one flat float row in the resizable dataset 'phoenix' of the
    HDF5 file at ``path``.

    Each row has length 8721. For a single row ``arr``:
        arr[:7]                   -> 7 forcing variables
        arr[7]                    -> month indicator (characters 2-3 of the date string)
        arr[8]                    -> LST target
        arr[9:].reshape(8,33,33)  -> image tensor: the parsed channels, then NDBI,
                                     NDVI and elevation (in that order)

    File order and the row order within each file are shuffled with the global
    ``random`` / ``np.random`` state, which is not seeded here.

    :param path: location of the HDF5 file; opened in append mode.
    :return: the string "DONE".
    """
    # Constants used to turn the stored band values back into the values the
    # indices are computed from.
    # NOTE(review): presumably the Landsat surface-reflectance scale/offset - confirm.
    band_offset = 0.2
    band_scale = 2.75e-05

    tot_samples = 0
    hdf5_dataset = None  # created lazily once the first non-empty batch is known

    # NOTE(review): the file is opened with "a", but the dataset is always
    # created with create_dataset; if 'phoenix' already exists in the file
    # (e.g. on a re-run) that call fails. Delete the file or dataset first.
    with h5py.File(path, "a") as hp5:
        # center_weight == 1 disables the neighbourhood blending below
        # (outside_weight becomes 0); lower it to mix in surrounding pixels.
        center_weight = 1
        outside_weight = 1 - center_weight
        weights_idx_dict = inverse_distance_mean()
        elevations_dict = TFRecordsElevation(filepath=PATH_TO_ELEVATIONS).tfrecrods2numpy()
        # pickle.load executes arbitrary code from the file: only use trusted input.
        with open(PATH_TO_FORCING, 'rb') as f:
            forcing_data = pickle.load(f)

        # Set membership keeps the per-file date lookup O(1).
        avail_dates = set(forcing_data.keys())

        for root, _dirs, files in os.walk(PATH_TO_DATA):
            random.shuffle(files)
            for file in tqdm(files, desc="Total Progress"):
                file_date = file.split(".")[0]
                if file_date not in avail_dates:
                    continue
                print(file_date)

                # Join with the directory actually being walked so files in
                # sub-directories resolve correctly.
                path_to_tf = os.path.join(root, file)
                records = TFRecordsParser(path_to_tf, channels=channel_choices).tfrecrods2numpy()
                all_pixels = len(records)
                if all_pixels == 0:
                    # Nothing parsed: avoid dividing by zero below.
                    continue

                # Records whose LST entry equals False are counted as cloudy.
                cloud_pixels = Counter(elem[1] for elem in records)[False]
                cloud_pct = cloud_pixels/all_pixels
                month = np.array(int(file_date[2:4])).reshape(-1,)
                if cloud_pct > threshold_cloud:
                    continue
                print(f'Meet cloud threshold. Cloud cover pct: {cloud_pct}')

                samples = []
                for idx, (features, lst, qc_flag) in enumerate(records):
                    # Keep a pixel only if QC and LST are present, bits 1 and 7
                    # of the QC flag are clear, and no feature value is the
                    # -9999 fill value.
                    usable = (
                        (qc_flag is not False)
                        and (qc_flag >> 1 & 1) == 0
                        and (qc_flag >> 7 & 1) == 0
                        and (lst is not False)
                        and (features != -9999).all()
                    )
                    if not usable:
                        continue

                    if outside_weight != 0:
                        # Blend the pixel with the inverse-distance weighted
                        # mean of its nearest neighbours.
                        weights = np.array(weights_idx_dict[idx]["weights"]) * outside_weight
                        min_idx = weights_idx_dict[idx]["idx"]
                        surround_features = np.array([records[i][0] for i in min_idx])
                        weighted_avg = np.average(surround_features, weights=weights, axis=0)
                        features = np.average([features, weighted_avg], weights=[center_weight, outside_weight], axis=0)

                    # Channel positions follow channel_choices
                    # (0 = Red, 3 = NIR, 4 = SWIR1) - assumes the parser keeps that order.
                    red_dn = (features[0] + band_offset) / band_scale
                    nir_dn = (features[3] + band_offset) / band_scale
                    swir1_dn = (features[4] + band_offset) / band_scale

                    features_ndbi = ((swir1_dn - nir_dn) / (swir1_dn + nir_dn)).reshape(-1, 33, 33)
                    features_ndvi = ((nir_dn - red_dn) / (nir_dn + red_dn)).reshape(-1, 33, 33)
                    elevations = elevations_dict[idx].reshape(-1, 33, 33)
                    # Final channel order: parsed channels, NDBI, NDVI, elevation.
                    features = np.concatenate([features, features_ndbi, features_ndvi], axis=0)
                    features = np.vstack([features, elevations])

                    forcing = forcing_data[file_date][idx]
                    sample_image = features.flatten()
                    lst = np.array(lst).reshape(-1,)

                    samples.append(np.concatenate((forcing, month, lst, sample_image)))

                if not samples:
                    ## ONLY APPEND IF DATA EXISTS ##
                    continue

                samples = np.array(samples)
                new_samples_num = samples.shape[0]
                tot_samples += new_samples_num
                np.random.shuffle(samples)
                _, feature_length = samples.shape
                assert feature_length == 8721, f"unexpected sample length {feature_length}"

                if hdf5_dataset is None:
                    # First batch: create the growable dataset and fill it.
                    hdf5_dataset = hp5.create_dataset('phoenix', (new_samples_num, feature_length), maxshape=(None, feature_length), dtype='float32')
                    hdf5_dataset[:] = samples
                else:
                    # Later batches: grow along axis 0 and write into the new tail.
                    hdf5_dataset.resize(tot_samples, axis=0)
                    hdf5_dataset[-new_samples_num:] = samples
    return "DONE"
                

In [None]:
# Run the full extraction and write the samples to PATH_TO_STORE.
# Long-running: the progress output below shows 6024 files at roughly 6 s per file.
extract_data(PATH_TO_STORE)

Total Progress:   0%|          | 0/6024 [00:00<?, ?it/s]

180110


Total Progress:   0%|          | 1/6024 [00:06<10:37:17,  6.35s/it]

140928


Total Progress:   0%|          | 2/6024 [00:12<10:33:49,  6.32s/it]

030202


Total Progress:   0%|          | 3/6024 [00:18<10:31:43,  6.30s/it]

150130


Total Progress:   0%|          | 4/6024 [00:25<10:31:46,  6.30s/it]

130206


Total Progress:   0%|          | 5/6024 [00:31<10:28:48,  6.27s/it]

060521


Total Progress:   0%|          | 6/6024 [00:37<10:28:20,  6.26s/it]

Meet cloud threshold. Cloud cover pct: 0.007746716066015493
051118


Total Progress:   0%|          | 7/6024 [00:44<10:34:18,  6.33s/it]

Meet cloud threshold. Cloud cover pct: 0.0006736274840013472
051209


Total Progress:   0%|          | 8/6024 [00:50<10:33:43,  6.32s/it]

Meet cloud threshold. Cloud cover pct: 0.03233411923206467
070929


Total Progress:   0%|          | 9/6024 [00:56<10:30:30,  6.29s/it]

Meet cloud threshold. Cloud cover pct: 0.07611990569215224
180321


Total Progress:   0%|          | 10/6024 [01:02<10:25:19,  6.24s/it]

110406


Total Progress:   0%|          | 11/6024 [01:08<10:23:46,  6.22s/it]

090125


Total Progress:   0%|          | 12/6024 [01:15<10:25:26,  6.24s/it]

130205


Total Progress:   0%|          | 13/6024 [01:21<10:22:05,  6.21s/it]

Meet cloud threshold. Cloud cover pct: 0.0003368137420006736
180404


Total Progress:   0%|          | 14/6024 [01:27<10:18:55,  6.18s/it]

090322


Total Progress:   0%|          | 15/6024 [01:33<10:18:01,  6.17s/it]

Meet cloud threshold. Cloud cover pct: 0.0
100724


Total Progress:   0%|          | 16/6024 [01:39<10:20:47,  6.20s/it]

Meet cloud threshold. Cloud cover pct: 0.03166049174806332
030801


Total Progress:   0%|          | 17/6024 [01:46<10:22:08,  6.21s/it]

161130


Total Progress:   0%|          | 18/6024 [01:52<10:19:05,  6.18s/it]

Meet cloud threshold. Cloud cover pct: 0.0057258336140114515
100629


Total Progress:   0%|          | 19/6024 [01:58<10:19:30,  6.19s/it]

Meet cloud threshold. Cloud cover pct: 0.010441226002020883
071120


Total Progress:   0%|          | 20/6024 [02:04<10:22:50,  6.22s/it]

071028


Total Progress:   0%|          | 21/6024 [02:10<10:20:45,  6.20s/it]

170907


Total Progress:   0%|          | 22/6024 [02:17<10:18:51,  6.19s/it]

Meet cloud threshold. Cloud cover pct: 0.0
051219


Total Progress:   0%|          | 23/6024 [02:23<10:26:48,  6.27s/it]

021215


Total Progress:   0%|          | 24/6024 [02:29<10:24:10,  6.24s/it]

061228


Total Progress:   0%|          | 25/6024 [02:35<10:22:15,  6.22s/it]

080302


Total Progress:   0%|          | 26/6024 [02:41<10:20:13,  6.20s/it]

Meet cloud threshold. Cloud cover pct: 0.0023576961940047153
021011


Total Progress:   0%|          | 27/6024 [02:48<10:17:38,  6.18s/it]

030623


Total Progress:   0%|          | 28/6024 [02:54<10:16:12,  6.17s/it]

Meet cloud threshold. Cloud cover pct: 0.0
080226


Total Progress:   0%|          | 29/6024 [03:00<10:15:49,  6.16s/it]

Meet cloud threshold. Cloud cover pct: 0.0
030129


Total Progress:   0%|          | 30/6024 [03:06<10:17:22,  6.18s/it]

Meet cloud threshold. Cloud cover pct: 0.0003368137420006736
030804


Total Progress:   1%|          | 31/6024 [03:12<10:15:21,  6.16s/it]

Meet cloud threshold. Cloud cover pct: 0.0006736274840013472
141118


Total Progress:   1%|          | 32/6024 [03:18<10:15:46,  6.17s/it]

Meet cloud threshold. Cloud cover pct: 0.0023576961940047153
160104


Total Progress:   1%|          | 33/6024 [03:25<10:14:58,  6.16s/it]

140116


Total Progress:   1%|          | 34/6024 [03:31<10:14:37,  6.16s/it]

Meet cloud threshold. Cloud cover pct: 0.015493432132030987
121130


Total Progress:   1%|          | 35/6024 [03:37<10:13:31,  6.15s/it]

140428


Total Progress:   1%|          | 36/6024 [03:43<10:12:11,  6.13s/it]

Meet cloud threshold. Cloud cover pct: 0.0
150708


Total Progress:   1%|          | 37/6024 [03:49<10:12:23,  6.14s/it]

Meet cloud threshold. Cloud cover pct: 0.011451667228022903
060617


Total Progress:   1%|          | 38/6024 [03:55<10:12:07,  6.14s/it]

Meet cloud threshold. Cloud cover pct: 0.0
090526


Total Progress:   1%|          | 39/6024 [04:01<10:15:10,  6.17s/it]

Meet cloud threshold. Cloud cover pct: 0.0
040404


Total Progress:   1%|          | 40/6024 [04:08<10:20:08,  6.22s/it]

040623


Total Progress:   1%|          | 41/6024 [04:14<10:18:34,  6.20s/it]

Meet cloud threshold. Cloud cover pct: 0.08251936679016504
070620


Total Progress:   1%|          | 42/6024 [04:20<10:17:24,  6.19s/it]

Meet cloud threshold. Cloud cover pct: 0.0
160630


Total Progress:   1%|          | 43/6024 [04:27<10:48:48,  6.51s/it]

110119


Total Progress:   1%|          | 44/6024 [04:34<10:38:39,  6.41s/it]

091225


Total Progress:   1%|          | 45/6024 [04:40<10:32:45,  6.35s/it]

Meet cloud threshold. Cloud cover pct: 0.06736274840013473
050123


Total Progress:   1%|          | 46/6024 [04:46<10:32:19,  6.35s/it]

170623


Total Progress:   1%|          | 47/6024 [04:52<10:28:17,  6.31s/it]

Meet cloud threshold. Cloud cover pct: 0.0
070118


Total Progress:   1%|          | 48/6024 [04:59<10:26:53,  6.29s/it]

Meet cloud threshold. Cloud cover pct: 0.0932974065341866
150424


Total Progress:   1%|          | 49/6024 [05:05<10:21:54,  6.25s/it]

140806


Total Progress:   1%|          | 50/6024 [05:11<10:21:53,  6.25s/it]

Meet cloud threshold. Cloud cover pct: 0.0
060912


Total Progress:   1%|          | 51/6024 [05:17<10:18:56,  6.22s/it]

Meet cloud threshold. Cloud cover pct: 0.04782755136409565
050415


Total Progress:   1%|          | 52/6024 [05:23<10:16:30,  6.19s/it]

061027


Total Progress:   1%|          | 53/6024 [05:30<10:19:01,  6.22s/it]

Meet cloud threshold. Cloud cover pct: 0.0
131223


Total Progress:   1%|          | 54/6024 [05:36<10:18:27,  6.22s/it]

030610


Total Progress:   1%|          | 55/6024 [05:42<10:15:40,  6.19s/it]

140812


Total Progress:   1%|          | 56/6024 [05:48<10:18:17,  6.22s/it]

180211


Total Progress:   1%|          | 57/6024 [05:54<10:19:39,  6.23s/it]

Meet cloud threshold. Cloud cover pct: 0.0
140603


Total Progress:   1%|          | 58/6024 [06:01<10:19:37,  6.23s/it]

080718


Total Progress:   1%|          | 59/6024 [06:07<10:20:00,  6.24s/it]

050327


Total Progress:   1%|          | 60/6024 [06:13<10:23:33,  6.27s/it]

110311


Total Progress:   1%|          | 61/6024 [06:19<10:21:02,  6.25s/it]

Meet cloud threshold. Cloud cover pct: 0.0
080910


Total Progress:   1%|          | 62/6024 [06:26<10:20:33,  6.25s/it]

151212


Total Progress:   1%|          | 63/6024 [06:32<10:21:03,  6.25s/it]

040504


Total Progress:   1%|          | 64/6024 [06:38<10:18:45,  6.23s/it]

Meet cloud threshold. Cloud cover pct: 0.0
081009


Total Progress:   1%|          | 65/6024 [06:44<10:17:37,  6.22s/it]

Meet cloud threshold. Cloud cover pct: 0.0
110326


Total Progress:   1%|          | 66/6024 [06:51<10:17:10,  6.22s/it]

Meet cloud threshold. Cloud cover pct: 0.02761872684405524
080808


Total Progress:   1%|          | 67/6024 [06:57<10:21:16,  6.26s/it]

080330


Total Progress:   1%|          | 68/6024 [07:03<10:20:17,  6.25s/it]

Meet cloud threshold. Cloud cover pct: 0.0
040701


Total Progress:   1%|          | 69/6024 [07:09<10:19:22,  6.24s/it]

Meet cloud threshold. Cloud cover pct: 0.0
180704


Total Progress:   1%|          | 70/6024 [07:15<10:17:06,  6.22s/it]

Meet cloud threshold. Cloud cover pct: 0.07342539575614684
071201


Total Progress:   1%|          | 71/6024 [07:22<10:16:07,  6.21s/it]

040512


Total Progress:   1%|          | 72/6024 [07:28<10:17:38,  6.23s/it]

Meet cloud threshold. Cloud cover pct: 0.0
070907


Total Progress:   1%|          | 73/6024 [07:34<10:15:44,  6.21s/it]

Meet cloud threshold. Cloud cover pct: 0.10946446615021893
100829


Total Progress:   1%|          | 74/6024 [07:40<10:14:50,  6.20s/it]

120413


Total Progress:   1%|          | 75/6024 [07:46<10:14:02,  6.19s/it]

040315


Total Progress:   1%|▏         | 76/6024 [07:53<10:13:58,  6.19s/it]

Meet cloud threshold. Cloud cover pct: 0.0
060110


In [None]:
# Open the generated HDF5 file read-only for inspection.
# NOTE(review): the handle is not closed here - call f.close() when finished.
f = h5py.File(PATH_TO_STORE, 'r')