In [1]:
#!/usr/bin/env python
# coding: utf-8

In [2]:
import argparse
import csv
import json
import numpy as np
import os
import pandas as pd
import rasterio
import tensorflow as tf
from glob import glob
from tqdm import tqdm
from random import choices

In [3]:
# Path to the BigEarthNet extracted files
eurosat_path = '/workspace/app/data/raw/EuroSat/fulldata/'

In [4]:
# Models folder is already checkin. No need to download the models
eurosat_folder = '/workspace/app/data/raw/eurosat-models/'

In [5]:
# Stores the TFRecords
out_folder = '/workspace/app/data/processed/EuroSat'

In [6]:
if not os.path.exists(eurosat_path):
    print('folder', eurosat_path, 'does not exist')
    print('Downloading Data...')
    # Downloads the data from EuroSat website
    # os.system("curl http://madm.dfki.de/files/sentinel/EuroSAT.zip -o /data/raw/eurosat_rgb.zip")
#     os.system("unzip /data/raw/eurosat_rgb.zip -d /data/raw")
#     os.mkdir(eurosat_path)
#     os.rename("/data/raw/ds/..../*)   

In [7]:
if not os.path.exists(eurosat_folder):
    print('ERROR: folder', eurosat_folder, 'does not exist')
    os.mkdir(eurosat_folder)

In [8]:
if not os.path.exists(out_folder):
    print('ERROR: folder', out_folder, 'does not exist')
    os.mkdir(out_folder)

In [9]:
print(f'Using Python Version: {pd.__version__}')
print(f'Using TensorFlow Version: {tf.__version__}')

Using Python Version: 1.1.5
Using TensorFlow Version: 2.3.0


In [10]:
cont = 0
label_list = os.listdir(eurosat_path)
label_indices = {'original_labels':{}}
for lbl in label_list:
    label_indices['original_labels'][lbl] = cont
    cont += 1

In [11]:
def prep_example_eurosat(bands, original_labels, original_labels_multi_hot, patch_name):
    return tf.train.Example(
            features=tf.train.Features(
                feature={
                    'B01': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B01']))),
                    'B02': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B02']))),
                    'B03': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B03']))),
                    'B04': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B04']))),
                    'B05': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B05']))),
                    'B06': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B06']))),
                    'B07': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B07']))),
                    'B08': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B08']))),
                    'B8A': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B8A']))),
                    'B09': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B09']))),
                    'B11': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B11']))),
                    'B12': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B12']))),
                    'original_labels': tf.train.Feature(
                        bytes_list=tf.train.BytesList(
                            value=[i.encode('utf-8') for i in original_labels])),
                    'original_labels_multi_hot': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=original_labels_multi_hot)),
                    'patch_name': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[patch_name.encode('utf-8')]))
                }))

In [12]:
bands_l =['B01','B02','B03','B04','B05','B06','B07','B08',
          'B09','B10','B11','B12','B8A']

# Split options [Train = 0, Test= 1, Validation= 2]
split_option = [0,1,2]
# Split Probabilities [Train = 50%, Test= 25%, Validation= 25%]
weights = [0.5,0.25, 0.25]

# TFRecords Writers
TFRec_writer_train = tf.io.TFRecordWriter(os.path.join(out_folder, 'train.tfrecord'))
TFRec_writer_test = tf.io.TFRecordWriter(os.path.join(out_folder, 'test.tfrecord'))
TFRec_writer_val = tf.io.TFRecordWriter(os.path.join(out_folder, 'val.tfrecord'))
patch_name = out_folder
for tifile in label_list:
    # create Folder
    print('\n Label: {} \n'.format(tifile))

    try:
        os.mkdir(patch_name)
    except:
        print('Folder {} exist already'.format(tifile))
        
    # create labels
    original_labels = [tifile]
    
    # hot encode label
    original_labels_multi_hot = np.zeros(len(label_list),dtype=int)
    lidx = label_indices['original_labels'][tifile]
    original_labels_multi_hot[lidx] = 1

    # loop in the folder
    files_list = os.listdir(os.path.join(eurosat_path,tifile))
   
    # write holder tfrecord
    progress_bar = tf.keras.utils.Progbar(target = len(files_list))
    for findex,fex in enumerate(files_list):
        sname = fex.split('_')[1][:-4]
        band_ds = rasterio.open(os.path.join(eurosat_path,tifile,fex))
        bands = {}
        for idx in range(13): # 13 bands
            if bands_l[idx] == 'B10':
                continue
            bands[bands_l[idx]] = np.array(band_ds.read(idx+1))

        # prep example dev example
        example = prep_example_eurosat(bands,
                                       original_labels, 
                                       original_labels_multi_hot, 
                                       patch_name)
        
        pick = choices(split_option,weights)[0]
        if pick == 0:
            TFRec_writer_train.write(example.SerializeToString()) # This is the full path to tfrecord train
        elif pick == 1:
            TFRec_writer_test.write(example.SerializeToString()) # This is the full path to tfrecord test
        else:
            TFRec_writer_val.write(example.SerializeToString()) # This is the full path to tfrecord val           
            
        progress_bar.update(findex)

print('Completed!!!')


 Label: Pasture 

Folder Pasture exist already
 Label: Industrial 

Folder Industrial exist already
 Label: PermanentCrop 

Folder PermanentCrop exist already
 Label: AnnualCrop 

Folder AnnualCrop exist already
 Label: Highway 

Folder Highway exist already
 Label: HerbaceousVegetation 

Folder HerbaceousVegetation exist already
 Label: Residential 

Folder Residential exist already
 Label: Forest 

Folder Forest exist already
 Label: River 

Folder River exist already
 Label: SeaLake 

Folder SeaLake exist already
