In [105]:
# import the needed modules
import os, sys, pickle, multiprocessing
import numpy as np, pandas as pd
import rasterio
from pathlib import Path
from collections import OrderedDict
from tqdm.auto import tqdm
from extract_data import extract_s2_train

In [106]:
# Directory layout: the training split lives in ../data/train and the raw
# band arrays go into its bands-raw sub-folder.
OUTPUT_DIR = '/'.join(['../data', 'train'])
OUTPUT_DIR_BANDS = '/'.join([OUTPUT_DIR, 'bands-raw'])
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR_BANDS, exist_ok=True)

# Sentinel-2 assets and whether each one is selected (True) or ignored (False).
# The insertion order is kept so the derived band list has a stable order.
DOWNLOAD_S2 = OrderedDict([
    ('B01', False),
    ('B02', True),   # Blue
    ('B03', True),   # Green
    ('B04', True),   # Red
    ('B05', False),
    ('B06', False),
    ('B07', False),
    ('B08', True),   # NIR
    ('B8A', False),  # NIR2
    ('B09', False),
    ('B11', True),   # SWIR1
    ('B12', True),   # SWIR2
    ('CLM', True),
])

In [107]:
# Load the training index CSV from the training output directory
df_train = pd.read_csv(f'{OUTPUT_DIR}/train_data.csv')
# Parse the `datetime` column into a datetime64 `date` column. The unit is given
# explicitly: pandas 2.x raises a TypeError when asked to cast to the unit-less
# np.datetime64 dtype, while 'datetime64[ns]' works on old and new versions.
df_train['date'] = df_train['datetime'].astype('datetime64[ns]')
# Names of the assets flagged True in DOWNLOAD_S2, in declaration order
bands = [band for band, selected in DOWNLOAD_S2.items() if selected]

In [4]:
# Create a sorted list of the unique tile ids
tile_ids = sorted(df_train.tile_id.unique())
print(f'extracting data from {len(tile_ids)} tiles for bands {bands}')

# Check the number of CPU cores
num_processes = multiprocessing.cpu_count()
print(f'processesing on : {num_processes} cpus')

# Split the tile ids into one contiguous batch per core.
# The boundaries are computed with integer arithmetic: the float variant
# (len / n followed by int(k * len / n)) can round the last boundary down
# and silently drop the final tile.
num_tiles = len(tile_ids)
batches = []
for num_process in range(1, num_processes + 1):
    start_index = (num_process - 1) * num_tiles // num_processes
    end_index = num_process * num_tiles // num_processes
    sublist = tile_ids[start_index:end_index]
    # Each batch is wrapped in a 1-tuple because it is passed as `args=` below
    batches.append((sublist,))
    print(f"Task # {num_process} process tiles {len(sublist)}")

# Run the extract function once per batch in a pool with one worker per core.
# The `with` block guarantees the worker processes are cleaned up, also when a
# task fails; the results must therefore be collected inside the block.
with multiprocessing.Pool(num_processes) as pool:
    # apply_async submits each task immediately and returns a handle
    results = [pool.apply_async(extract_s2_train, args=batch) for batch in batches]
    # get() blocks until the task is done and re-raises any worker exception
    all_results = [result.get() for result in results]

extracting data from 2650 tiles for bands ['B02', 'B03', 'B04', 'B08', 'B11', 'B12', 'CLM']
processesing on : 8 cpus
Task # 1 process tiles 331
Task # 2 process tiles 331
Task # 3 process tiles 331
Task # 4 process tiles 332
Task # 5 process tiles 331
Task # 6 process tiles 331
Task # 7 process tiles 331
Task # 8 process tiles 332


100%|██████████| 331/331 [1:19:42<00:00, 14.45s/it]  
100%|██████████| 331/331 [1:20:39<00:00, 14.62s/it]
100%|██████████| 331/331 [1:21:41<00:00, 14.81s/it]
100%|██████████| 331/331 [1:22:20<00:00, 14.93s/it]
100%|██████████| 331/331 [1:22:46<00:00, 15.01s/it]
100%|██████████| 332/332 [1:23:00<00:00, 15.00s/it]
100%|██████████| 331/331 [1:24:46<00:00, 15.37s/it]
100%|██████████| 332/332 [1:25:07<00:00, 15.39s/it]


In [39]:
# Merge the per-batch results into one frame ordered by field id (with a fresh
# 0..n-1 index) and store it as a pickle file next to the training data
df_train_meta = (
    pd.concat(all_results)
    .sort_values(by=['field_id'])
    .reset_index(drop=True)
)
df_train_meta.to_pickle(f'{OUTPUT_DIR}/field_meta_train.pkl')

# Report where the outputs ended up
print(f'Training bands saved to {OUTPUT_DIR}')
print(f'Training metadata saved to {OUTPUT_DIR}/field_meta_train.pkl')

Training bands saved to ../data/train
Training metadata saved to ../data/train/field_meta_train.pkl


---
### Testing area
Don't run these cells; they were only for me to figure out how the extract function in `extract_data.py` works!

In [None]:
fields = []
labels = []
dates = []
tiles = []

# Walk through every tile and rebuild, step by step, the arrays for that tile.
# Every raster is opened in a `with` block so its file handle is closed again
# right after reading instead of staying open for the rest of the session.
for tile_id in tqdm(tile_ids):
    df_tile = df_train[df_train['tile_id']==tile_id]
    # Sorted, de-duplicated dates of the rows whose platform is 's2'
    tile_dates = sorted(df_tile[df_tile['satellite_platform']=='s2']['date'].unique())

    # band name -> array with one 2-D raster per date (dates along axis 0)
    ARR = {}
    for band in bands:
        band_arr = []
        for date in tile_dates:
            band_path = df_tile[(df_tile['date']==date) & (df_tile['asset']==band)]['file_path'].values[0]
            with rasterio.open(band_path) as src:
                band_arr.append(src.read(1))
        ARR[band] = np.array(band_arr,dtype='float32')

    # Stacking adds a leading band axis: (bands, dates, rows, cols)
    multi_band_arr = np.stack(list(ARR.values())).astype(np.float32)
    multi_band_arr = multi_band_arr.transpose(2,3,0,1) # rows, cols, bands, dates

    label_path = df_tile[df_tile['asset']=='labels']['file_path'].values[0]
    with rasterio.open(label_path) as label_src:
        label_array = label_src.read(1)

    field_path = df_tile[df_tile['asset']=='field_ids']['file_path'].values[0]
    with rasterio.open(field_path) as field_src:
        fields_arr = field_src.read(1) #fields in tile

In [69]:
# Shape of one band's array: one 2-D raster per acquisition date along axis 0
ARR['B02'].shape

(38, 256, 256)

In [72]:
# Put the per-band arrays on top of each other along a new leading band axis
# and look at the resulting shape
multi_band_arr = np.stack([ARR[name] for name in ARR], axis=0).astype(np.float32)
multi_band_arr.shape

(1, 38, 256, 256)

In [73]:
# Build the transposed cube directly from ARR instead of transposing
# multi_band_arr in place: re-running this cell then always yields the same
# axis order rather than flipping an already-transposed array back.
# Axis order after the transpose: rows, cols, bands, dates.
multi_band_arr = np.stack(list(ARR.values())).astype(np.float32).transpose(2,3,0,1)
multi_band_arr.shape

(256, 256, 1, 38)

In [76]:
# Read band 1 of the tile's 'labels' raster; the `with` block closes the
# dataset again once the array has been read
with rasterio.open(df_tile[df_tile['asset']=='labels']['file_path'].values[0]) as label_src:
    label_array = label_src.read(1)
label_array.shape

(256, 256)

In [82]:
# Read band 1 of the tile's 'field_ids' raster; the `with` block closes the
# dataset again once the array has been read
with rasterio.open(df_tile[df_tile['asset']=='field_ids']['file_path'].values[0]) as field_src:
    fields_arr = field_src.read(1) #fields in tile
fields_arr

array([[   0,    0,    0, ...,    0,    0,    0],
       [   0,    0,    0, ...,    0,    0,    0],
       [   0,    0,    0, ...,    0,    0,    0],
       ...,
       [5694, 5694, 5694, ...,    0,    0,    0],
       [   0, 5694, 5694, ...,    0,    0,    0],
       [   0, 5694, 5694, ...,    0,    0,    0]], dtype=uint32)

In [104]:
# For each field id in the tile, print the pixel count of its mask, the shape
# of the full cube and the shape of the patch cut out by the mask. Fields
# whose pixels do not carry exactly one non-zero label value are skipped.
for field_id in np.unique(fields_arr):
    # id 0 is not treated as a field
    if not field_id:
        continue
    mask = fields_arr==field_id
    field_label = [l for l in np.unique(label_array[mask]) if l!=0]

    if len(field_label)!=1:
        continue
    field_label = field_label[0]
    # Boolean indexing over the two leading axes keeps one row per masked pixel
    patch = multi_band_arr[mask]
    print(
        'mask____',
        np.count_nonzero(mask),
        'multi____',
        multi_band_arr.shape,
        'patch____',
        patch.shape,
        sep='\n',
    )

mask____
50
multi____
(256, 256, 1, 38)
patch____
(50, 1, 38)
mask____
269
multi____
(256, 256, 1, 38)
patch____
(269, 1, 38)
mask____
31
multi____
(256, 256, 1, 38)
patch____
(31, 1, 38)
mask____
99
multi____
(256, 256, 1, 38)
patch____
(99, 1, 38)
mask____
94
multi____
(256, 256, 1, 38)
patch____
(94, 1, 38)
mask____
6
multi____
(256, 256, 1, 38)
patch____
(6, 1, 38)
mask____
600
multi____
(256, 256, 1, 38)
patch____
(600, 1, 38)
mask____
202
multi____
(256, 256, 1, 38)
patch____
(202, 1, 38)
mask____
353
multi____
(256, 256, 1, 38)
patch____
(353, 1, 38)
mask____
26
multi____
(256, 256, 1, 38)
patch____
(26, 1, 38)
mask____
2
multi____
(256, 256, 1, 38)
patch____
(2, 1, 38)
mask____
22
multi____
(256, 256, 1, 38)
patch____
(22, 1, 38)
mask____
222
multi____
(256, 256, 1, 38)
patch____
(222, 1, 38)
mask____
67
multi____
(256, 256, 1, 38)
patch____
(67, 1, 38)
mask____
668
multi____
(256, 256, 1, 38)
patch____
(668, 1, 38)
mask____
958
multi____
(256, 256, 1, 38)
patch____
(958, 1, 38)

In [86]:
# Label values of the pixels selected by `mask`, i.e. of whichever field id
# the mask was most recently built for
label_array[mask]

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=uint8)