In [None]:
import os
import sys
import random
from random import shuffle
import gc
import math
from pathlib import Path
from datetime import datetime
from collections import Counter, defaultdict, OrderedDict
import urllib.parse as urlparse
import boto3
import shutil
import tqdm
import itertools

import numpy as np
from numpy import inf
import pandas as pd
from sklearn import metrics
import rasterio
import pickle
import json

from IPython.core.debugger import set_trace

In [None]:
def make_reproducible(seed = 42):
    """Seed Python's and NumPy's global random generators with one shared value.

    Args:
      seed (int) -- value handed to ``random.seed`` and ``np.random.seed`` and
                    written (as a string) to the ``PYTHONHASHSEED`` environment variable.
    """
    for seed_generator in (random.seed, np.random.seed):
        seed_generator(seed)
    # Only the environment variable is written here; nothing else in this function reads it.
    os.environ['PYTHONHASHSEED'] = str(seed)

make_reproducible()

##  Step 1. Rename the data such that the gridIDs in the filenames have leading zeros

In [None]:
"""
import os
from pathlib import Path
"""

def rename_w_leading_0s(root_dir, tif_content, num_digits, ftype, country, source, verbose = False):
    """
    Append the (zero-padded) grid ID taken from each sub-folder name to the files inside it,
    so that sorting the file names sorts them by grid ID.

    Args:
      root_dir (str) -- path to the main directory in which the data resides.
                        Example: "C:/My_documents/Data".
      tif_content (str) -- "mask" for label rasters or "data" for image tiles.
      num_digits (int) -- width the grid ID is zero-padded to (e.g. 6 -> "000334").
      ftype (str) -- for masks only: "tif" renames ``.tif`` files, "npy" renames ``.npy`` and
                     ``.json`` files. Ignored for "data", where ``.tif`` and ``.json`` are renamed.
      country (str) -- country folder below ``root_dir``; "SouthSudan" or "Ghana".
      source (str) -- folder below the country folder; "Sentinel-1", "Sentinel-2" or "Labels".
      verbose (bool) -- if True, print every old/new name pair. Default is False.

    Naming scheme:
      mask -- "<stem>_<padded last '_' field of the folder name><ext>"
      data -- "<stem>_<last five '_' fields of the folder name, second one padded><ext>"

    Files whose stem already ends with the suffix are left alone, so running the cell
    twice does not append the suffix a second time.
    """
    assert tif_content in ["mask", "data"]
    assert country in ["SouthSudan", "Ghana"]
    assert source in ["Sentinel-1", "Sentinel-2", "Labels"]

    path_to_src = Path(root_dir) / country / source

    # Extensions that are renamed for the requested content type.
    if tif_content == "mask":
        extensions = {"tif": (".tif",), "npy": (".npy", ".json")}.get(ftype, ())
    else:
        extensions = (".tif", ".json")

    rename_pairs = []

    for dirname in sorted(os.listdir(path_to_src)):
        grid_dir = path_to_src / dirname
        if not grid_dir.is_dir():
            # Stray files next to the grid folders carry no grid ID.
            continue

        name_parts = str(dirname).split("_")
        if tif_content == "mask":
            # Pad the grid ID itself; padding the whole file stem (as was done before)
            # never added any zeros because the stem is already longer than num_digits.
            suffix = "_" + name_parts[-1].zfill(num_digits)
        else:
            tail = name_parts[-5:]
            if len(tail) < 2:
                # Folder name does not contain a grid ID field (e.g. an output folder).
                if verbose:
                    print('Skipping {}: no grid ID in folder name'.format(grid_dir))
                continue
            tail[1] = tail[1].zfill(num_digits)
            suffix = "_" + "_".join(tail)

        for filename in sorted(os.listdir(grid_dir)):
            ext = next((e for e in extensions if filename.endswith(e)), None)
            if ext is None:
                continue
            stem = filename[:-len(ext)]
            if stem.endswith(suffix):
                # Already renamed by an earlier run.
                continue
            rename_pairs.append((grid_dir / filename, grid_dir / (stem + suffix + ext)))

    for old_path, new_path in rename_pairs:
        os.rename(old_path, new_path)
        if verbose:
            print('Renaming {} to {}'.format(old_path, new_path))

In [None]:
# Step 1 configuration: arguments passed to rename_w_leading_0s in the next cell.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
root_dir = "C:/My_documents/CropTypeData_Rustowicz/toy_Ghana"
tif_content = "mask"  # rename label rasters ("mask") rather than image tiles ("data")
num_digits = 6  # number of digits used to represent a grid ID
ftype = "tif"  # file format to rename
country = "Ghana"  # country folder below root_dir
source = "Labels"  # folder below the country folder

In [None]:
# Renames files on disk in place (os.rename); uses the Step 1 configuration cell above.
rename_w_leading_0s(root_dir, tif_content, num_digits, ftype, country, source, verbose = False)

## Step 2. Remove those Grid-IDs where there is no actual label recorded

In [None]:
"""
import os
from pathlib import Path
import numpy as np
import pickle
import rasterio
from collections import Counter
import json
"""

def get_grid_nums(root_dir, country, source, ftype, verbose = False):
    """
    Walk ``root_dir/country/source`` and collect the matching files plus the grid ID
    parsed from each file path.

    Args:
      root_dir (str) -- base directory of the dataset.
      country (str) -- country folder; for ``ftype == "tif"`` it also selects which
                       "_"-separated field holds the grid ID ("Ghana": 4th from the end,
                       "SouthSudan": 3rd from the end).
      source (str) -- folder below the country folder.
      ftype (str) -- "tif" collects ``.tif`` files, "npy" collects ``.npy`` and ``.json`` files
                     (grid ID is the 4th field from the end).
      verbose (bool) -- if True, print the grid IDs next to the files.

    Returns:
      (grid_numbers, files) -- both sorted, each on its own; the two lists are therefore
      not guaranteed to line up element by element.
    """
    src_path = Path(root_dir) / country / source

    def _collect(suffixes):
        # Every file below src_path whose name ends with one of the suffixes, sorted by path.
        matches = []
        for dirpath, _dirnames, filenames in os.walk(src_path):
            matches.extend(Path(dirpath) / name for name in filenames if name.endswith(suffixes))
        return sorted(matches)

    if ftype == "tif":
        files = _collect((".tif",))

        if country == "Ghana":
            grid_numbers = sorted(str(path).split("_")[-4] for path in files)

        elif country == "SouthSudan":
            grid_numbers = sorted(str(path).split("_")[-3] for path in files)

    elif ftype == "npy":
        files = _collect((".npy", ".json"))
        grid_numbers = sorted(str(path).split("_")[-4] for path in files)

    if verbose:
        for grid_id, path in zip(grid_numbers, files):
            print('Grid-ID: {}\nAssociated file with same ID {}\n'.format(grid_id, path))

    return grid_numbers, files

##################################################

def get_empty_grids(root_dir, country, source, lbl_fldrname, verbose = True, ftype = None):
    """
    Find the grid IDs whose label raster has no positive pixel and that also occur
    among the files of ``source``.

    Args:
      root_dir (str) -- base directory of the dataset.
      country (str) -- country folder below ``root_dir``.
      source (str) -- folder of the image source whose grid IDs are checked
                      (passed on to ``get_grid_nums``).
      lbl_fldrname (str) -- folder below the country folder that holds the label ``.tif`` files.
      verbose (bool) -- if True, print the counts and the IDs selected for deletion.
      ftype (str or None) -- file type passed on to ``get_grid_nums``. When omitted, the
                             notebook-level variable ``ftype`` is used if it exists (this is
                             what the function relied on implicitly before), else "tif".

    Returns:
      set of str -- grid IDs that have an empty label raster and appear in ``source``.
    """
    if ftype is None:
        # Backward-compatible fallback for callers that relied on the global variable.
        ftype = globals().get("ftype", "tif")

    lbl_dir = Path(root_dir) / country / lbl_fldrname

    mask_fnames = sorted(Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(lbl_dir)
                         for f in filenames if f.endswith('.tif'))

    valid_pixels_list = []
    empty_masks = []

    for mask_fname in mask_fnames:
        # Derive the ID from the very file being read. Sorting paths and IDs separately
        # (as before) can pair a raster with another grid's ID when the two orders differ.
        mask_id = str(mask_fname).split('_')[-1].replace('.tif', '')
        with rasterio.open(mask_fname) as src:
            valid_pixels = np.sum(src.read() > 0)
        valid_pixels_list.append((mask_id, valid_pixels))
        if valid_pixels == 0:
            empty_masks.append(mask_id)

    grid_numbers, _source_files = get_grid_nums(root_dir, country, source, ftype, verbose = False)

    # IDs that are both empty in the labels and present in the source.
    delete_me = sorted(set(empty_masks) & set(grid_numbers))

    if verbose:
        print("valid pixels list: ", len(valid_pixels_list))
        print('empty masks: ', len(empty_masks))
        print('delete me length: ', len(delete_me))
        print('delete me: ', delete_me)

    return set(delete_me)

##################################################

def remove_irrelevant_files(root_dir, country, source, delete_list, ftype, verbose = True):
    """
    Delete every file of ``source`` that belongs to one of the grid IDs in ``delete_list``.

    Args:
      root_dir (str) -- base directory of the dataset.
      country (str) -- country folder below ``root_dir``.
      source (str) -- folder below the country folder whose files are removed.
      delete_list (collection of str) -- grid IDs to remove; a file is selected when its
                                         path contains "_<grid ID>".
      ftype (str) -- file type passed on to ``get_grid_nums``.
      verbose (bool) -- if True, print each grid ID and the files removed for it.

    Returns:
      None. Files are removed from disk with ``os.remove``.
    """
    if len(delete_list) == 0:
        print("There is no empty grid to remove.")
        return

    _grid_nums, source_files = get_grid_nums(root_dir, country, source, ftype, verbose = False)

    already_removed = set()

    for grid_to_rm in sorted(delete_list):
        marker = '_' + grid_to_rm
        files_to_rm = [str(f) for f in source_files
                       if marker in str(f) and str(f) not in already_removed]

        if verbose:
            print("grid to remove: {}\n".format(grid_to_rm))
            print("files to remove: {}\n".format(files_to_rm))

        # Remove inside the loop: previously only the last grid's files were deleted
        # because the list was overwritten on every iteration.
        for path in files_to_rm:
            os.remove(path)
            already_removed.add(path)

In [None]:
# Step 2 configuration: arguments for get_empty_grids / remove_irrelevant_files below.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
root_dir = "C:/My_documents/CropTypeData_Rustowicz/toy_Ghana"
lbl_fldrname = "Labels"  # folder holding the label rasters
ftype = "tif"  # file type of the source files to inspect and remove
country = "Ghana"  # country folder below root_dir
source = "Sentinel-1"  # image source whose files are removed for empty grids

In [None]:
# Collect the grid IDs whose label raster is empty ...
grids_to_delete = get_empty_grids(root_dir, country, source, lbl_fldrname)

# ... and delete the matching source files. Destructive: files are removed with os.remove.
remove_irrelevant_files(root_dir, country, source, grids_to_delete, ftype, verbose = True)

## Step 3. Get image statistics for normalization

In [None]:
def get_img_stats(root_dir, country, source, clip):
    """
    Compute, per band, the global minimum and maximum over all "source" ``.tif`` tiles
    after clipping each tile's band to its [clip, 100 - clip] percentile range.

    Pixels equal to 0 are treated as no-data and ignored. A band that is entirely
    no-data in a tile does not contribute to the statistics.

    Args:
      root_dir (str) -- base directory of the dataset.
      country (str) -- country folder below ``root_dir``.
      source (str) -- "Sentinel-1" (3 bands expected); any other value is treated as
                      Sentinel-2 (10 bands expected).
      clip (float) -- percentile cut from each tail before taking min/max.

    Returns:
      [mins, maxs] -- two tuples with one value per band. A band without any valid pixel
      in the whole dataset keeps the sentinels 10e9 (min) and -10e9 (max).

    Raises:
      ValueError -- if a tile does not have the number of bands expected for ``source``.
    """
    src_path = Path(root_dir).joinpath(country, source)
    files = [Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(src_path)
             for f in filenames if f.endswith(".tif") if "source" in f]

    if source == "Sentinel-1":
        header = "----- Sentinel-1 range statistics per band -----"
        band_labels = ["B1 ('VV')", "B2 ('VH')", "B3 ('VH/VV)"]
    else:
        header = "----- Sentinel-2 range statistics per band -----"
        band_labels = ["B1 ('Blue')", "B2 ('Green')", "B3 ('Red')", "B4 ('Red Edge 1')",
                       "B5 ('Red Edge 2')", "B6 ('Red Edge 3')", "B7 ('NIR')",
                       "B8 ('Red Edge 4')", "B9 ('SWIR 1')", "B10 ('SWIR 2')"]

    num_bands = len(band_labels)

    # Running extrema. The maximum starts at -10e9 for both sensors; the Sentinel-2
    # branch used to start at 10e-9, which is a (tiny) positive number, not a lower bound.
    band_mins = [10e9] * num_bands
    band_maxs = [-10e9] * num_bands

    for file in tqdm.tqdm(files):
        with rasterio.open(file, "r") as src:
            tile = src.read().astype(float)

        if tile.shape[0] != num_bands:
            raise ValueError("{} has {} bands but {} are expected for source '{}'."
                             .format(file, tile.shape[0], num_bands, source))

        # 0 marks no-data; turn it into NaN so the nan-aware reductions skip it.
        tile = np.where(tile == 0, np.nan, tile)

        for i in range(num_bands):
            band = tile[i, :, :]
            if np.isnan(band).all():
                continue

            left_tail_clip, right_tail_clip = np.nanpercentile(band, [clip, 100 - clip])
            clipped_band = np.clip(band, left_tail_clip, right_tail_clip)

            band_mins[i] = min(band_mins[i], np.nanmin(clipped_band))
            band_maxs[i] = max(band_maxs[i], np.nanmax(clipped_band))

    print(header)
    for label, band_min, band_max in zip(band_labels, band_mins, band_maxs):
        print("{} --> min:{}, max:{}".format(label, band_min, band_max))

    return [tuple(band_mins), tuple(band_maxs)]

In [None]:
# Step 3 configuration: arguments passed to get_img_stats in the next cell.
# NOTE(review): absolute, machine-specific paths -- adjust before running elsewhere.
#root_dir = "C:/My_documents/CropTypeData_Rustowicz/example_tile539059/sentinel2_time_series_best"
#root_dir = "C:/My_documents/CropTypeData_Rustowicz/toy_Ghana"
root_dir = "D:/CropType/Ghana/Original_dataset"
country = "Ghana"  # country folder below root_dir
source = "Sentinel-1"  # image source to compute statistics for
clip = 1.5  # percentile clipped from each tail before taking min/max

In [None]:
# Prints the per-band range and displays the returned [mins, maxs] as the cell output.
get_img_stats(root_dir, country, source, clip)

## Step 4. Normalize tiles, add a DOY band to both sources, remove tiles with NaN bands, optionally add spectral indices to Sentinel-2, create a temporal stack for each grid, and save the stacks as .npy files

In [None]:
"""
import numpy as np
"""

def normalize(grid, source, country, norm_type, clip):
    """Scale every band of a tile with dataset-wide statistics.

    Args:
      grid (numpy array) -- tile to normalize; bands are indexed along the first axis.
      source (str) -- satellite sensor, key into the statistics tables.
      country (str) -- geographic region, key into the statistics tables.
      norm_type (str) -- "z-value" (subtract mean, divide by std) or "min/max"
                         (subtract min, divide by range). Any other value returns
                         ``grid`` untouched.
      clip (float or None) -- only used for "min/max": when truthy, each band is first
                              clipped to its own [clip, 100 - clip] percentile range
                              (zeros excluded from the percentiles) and zeros of the
                              input stay zero in the output.

    Returns:
      numpy array -- the normalized tile. For "z-value" and for "min/max" without
      clipping this is the input array itself, modified in place; with clipping a
      new array is returned.

    NOTE: The statistics are hard-coded per band for the whole temporal and spatial extent.
    """

    # Hard-coded global statistics, one entry per band.
    MEANS = {
        "Sentinel-1": {"Ghana": np.array([-10.50, -17.24, 1.17])},
        "Sentinel-2": {"Ghana": np.array([2620.00, 2519.89, 2630.31, 2739.81, 3225.22,
                                          3562.64, 3356.57, 3788.05, 2915.40, 2102.65])}
    }

    STDS = {
        "Sentinel-1": {"Ghana": np.array([3.57, 4.86, 5.60])},
        "Sentinel-2": {"Ghana": np.array([2171.62, 2085.69, 2174.37, 2084.56, 2058.97,
                                          2117.31, 1988.70, 2099.78, 1209.48, 918.19])}
    }

    MINS = {
        "Sentinel-1": {"Ghana": np.array([-24.2179511259, -29.877275167, -22.03031768])},
        "Sentinel-2": {"Ghana": np.array([735.0, 578.0, 385.425, 426.0, 490.4245,
                                          528.0, 434.0, 456.0, 174.0, 84.0])}
    }

    MAXS = {
        "Sentinel-1": {"Ghana": np.array([1.353708618, -0.30558315, 29.4155171438])},
        "Sentinel-2": {"Ghana": np.array([14077.45, 13880.575, 14869.3, 14414.575, 14802.3,
                                          15427.15, 14417.0, 15758.575, 14392.0, 14211.025])}
    }

    band_indices = range(grid.shape[0])

    if norm_type == "z-value":
        band_means = MEANS[source][country]
        band_stds = STDS[source][country]
        for b in band_indices:
            grid[b, :, :] = (grid[b, :, :] - band_means[b]) / band_stds[b]
        return grid

    if norm_type != "min/max":
        # Unknown normalization type: hand the tile back unchanged.
        return grid

    band_mins = MINS[source][country]
    band_maxs = MAXS[source][country]

    if not clip:
        for b in band_indices:
            grid[b, :, :] = (grid[b, :, :] - band_mins[b]) / (band_maxs[b] - band_mins[b])
        return grid

    scaled_bands = []
    for b in band_indices:
        band = grid[b, :, :]
        # Zeros are excluded from the percentile computation.
        band_no_zeros = np.where(band == 0, np.nan, band)

        lower = np.nanpercentile(band_no_zeros, clip)
        upper = np.nanpercentile(band_no_zeros, 100 - clip)

        # Raise values below the lower percentile, then cap values above the upper one.
        floor_applied = np.where(band_no_zeros < lower, lower, band)
        clipped = np.where(floor_applied > upper, upper, floor_applied)

        scaled_bands.append((clipped - band_mins[b]) / (band_maxs[b] - band_mins[b]))

    stacked = np.stack(scaled_bands, axis=0)
    # Pixels that were zero in the input stay zero in the output.
    return np.where(grid == 0, 0, stacked)

#########################

def date2doy(date, in_shape, norm_type, doy_mode, origin):
    """
    Turn a date string into a normalized day-of-year band.

    Parameters:
        date (str [format yyyy_mm_dd]) -- acquisition date.
        in_shape (tuple of 3) -- (bands, width, height) of the tile the band is built for;
                                 only the last two entries are used.
        norm_type (str) -- "z-value" scales with (doy - 177.5) / 177.5; anything else
                           scales with (doy - 1) / 364.
        doy_mode (str) -- "absolute" counts from Jan 1st of the date's year; any other
                          value counts from ``origin`` (which then is day 1).
        origin (str [format yyyy_mm_dd]) -- start date, only read in non-absolute mode.

    Returns:
        np.array of shape (1, width, height) filled with the normalized day value.

    NOTE: Both scalings assume a 365-day span; revise them if the relative period differs.
    """
    date_format = '%Y_%m_%d'

    if doy_mode == "absolute":
        day_number = datetime.strptime(date, date_format).date().timetuple().tm_yday
    else:
        start_day = datetime.strptime(origin, date_format).date()
        current_day = datetime.strptime(date, date_format).date()
        # +1 so that the origin itself is day 1 rather than day 0.
        day_number = (current_day - start_day).days + 1

    doy = np.array([day_number])

    norm_doy = (doy - 177.5) / 177.5 if norm_type == "z-value" else (doy - 1) / 364.0

    _num_bands, width, height = in_shape

    # Spread the single value over a width x height plane and add the band axis in front.
    plane = np.broadcast_to(norm_doy[np.newaxis, :], (width, height))
    return plane[np.newaxis, :]

#########################

def get_spectral_indices(img):
    """
    Compute five spectral indices from a Sentinel-2 tile.

    Args:
      img (numpy array) -- tile with bands along the first axis, in the order
                           blue, green, red, red edge 1-3, NIR, red edge 4, SWIR 1, SWIR 2.
                           At least 10 bands are required.

    Returns:
      numpy array of shape (5, rows, columns) holding, in this order:
      NDVI, EVI, NDWI, RGVI and BI.

    Raises:
      AssertionError -- if ``img`` has fewer than 10 bands.
    """
    # Band name -> index along the first axis.
    S2_BANDS = {"BLUE": 0, "GREEN": 1, "RED": 2, "RDED1": 3, "RDED2": 4,
                "RDED3": 5, "NIR": 6, "RDED4": 7, "SWIR1": 8, "SWIR2": 9}

    # The previous check (`assert img.shape[0]`) only rejected an empty band axis.
    assert img.shape[0] >= len(S2_BANDS), "Incorrect number of bands."

    blue = img[S2_BANDS["BLUE"], :, :]
    green = img[S2_BANDS["GREEN"], :, :]
    red = img[S2_BANDS["RED"], :, :]
    nir = img[S2_BANDS["NIR"], :, :]
    swir1 = img[S2_BANDS["SWIR1"], :, :]
    swir2 = img[S2_BANDS["SWIR2"], :, :]

    # EVI coefficients: gain, aerosol resistance terms and canopy background adjustment.
    G = 2.5
    C1 = 6
    C2 = 7.5
    L = 0.5

    ndvi = (nir - red) / (nir + red)  # Normalized Difference Vegetation Index
    evi = G * (nir - red) / (nir + C1 * red - C2 * blue + L)  # Enhanced Vegetation Index
    ndwi = (nir - swir2) / (nir + swir2)  # Normalized Difference Water Index
    # NOTE(review): brightness index is usually defined with a sum, sqrt((red^2 + green^2) / 2);
    # the ratio below is kept as found -- confirm which form is intended.
    bi = np.sqrt(((red * red) / (green * green)) / 2)

    # Rice Growth Vegetation Index.
    # Absorption properties of the middle infrared band cause a low reflectance of rice plants
    # in this channel (Lilliesand & Kiefer 1994). In irrigated rice fields, especially in early
    # transplanting periods, the water environment plays an important role in the rice
    # spectrum (Nuarsa et al., 2011).
    rgvi = 1 - (blue + red) / (nir + swir1 + swir2)

    # One index per band along the first axis.
    return np.stack([ndvi, evi, ndwi, rgvi, bi], axis=0)
 
##################################################

"""
import os
from pathlib import Path
import numpy as np
import pickle
import rasterio
from collections import Counter
import json
"""

def make_temporal_cube(root_dir, country, source, lbl_fldrname, out_path=None, norm_type="min/max", clip=None, 
                       channel_first=True, add_doy=True, doy_mode="absolute", origin=None, add_si=False, verbose=False):
    """
    Build one temporal stack per grid from the "source" ``.tif`` tiles of ``source`` and
    save it as ``.npy`` together with a ``.json`` file listing the dates used.

    For every grid ID the tiles are read in sorted order, tiles containing any NaN are
    dropped, the remaining tiles are normalized with ``normalize`` and optionally extended
    by spectral indices (Sentinel-2 only) and a day-of-year band. The resulting array has
    the dimensions bands x rows x columns x timestamps. For Sentinel-2 the "cloudmask"
    tiles of the same dates are stacked as rows x columns x timestamps and saved as well.

    Args:
      root_dir (str) -- base directory of the dataset.
      country (str) -- country folder below ``root_dir``.
      source (str) -- image source folder; "Sentinel-2" additionally enables cloud masks
                      and spectral indices.
      lbl_fldrname (str) -- label folder; only used to report the number of label files.
      out_path (str or Path or None) -- output folder; defaults to ``<source folder>/npy``.
      norm_type (str) -- normalization passed to ``normalize`` and ``date2doy``.
      clip (float or None) -- percentile clipping passed to ``normalize``.
      channel_first (bool) -- accepted for backward compatibility; it has no effect
                              (the transpose it used to trigger discarded its result).
      add_doy (bool) -- append a day-of-year band as the last band.
      doy_mode (str) -- "absolute" or "relative", passed to ``date2doy``.
      origin (str or None) -- start date for relative mode. If missing, the first usable
                              date of the first processed grid is taken.
                              NOTE(review): that value is then reused for all later grids --
                              confirm a shared origin is intended.
      add_si (bool) -- append the five spectral indices (Sentinel-2 only).
      verbose (bool) -- print progress information.

    Raises:
      FileNotFoundError -- if no source tile is found.
      AssertionError -- if, for Sentinel-2, the cloud masks do not match the usable tiles.
    """

    def _find_tifs(keyword):
        # All .tif files below the source folder whose name contains keyword, sorted by path.
        return sorted(Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(src_path)
                      for f in filenames if f.endswith(".tif") if keyword in f)

    def _name_fields(fname):
        # "_"-separated fields of the file name without the .tif extension.
        return Path(fname).name.replace(".tif", "").split("_")

    def _date_of(fname):
        # The date is the last three fields: yyyy_mm_dd.
        return "_".join(_name_fields(fname)[-3:])

    def _stack_name(fname):
        # Output name is the first three fields of the tile name.
        return "_".join(_name_fields(fname)[0:3])

    is_s2 = source == "Sentinel-2"

    # Label folder: only used for the report below.
    lbl_dir = Path(root_dir) / country / lbl_fldrname
    lbl_fnames = [Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(lbl_dir)
                  for f in filenames if f.endswith(".tif")]
    lbl_ids = [str(f).split("_")[-1].replace(".tif", "") for f in lbl_fnames]

    if verbose:
        print("Number of grids for country {}: {}".format(country, len(lbl_ids)))

    # Folder that contains the image tiles of this source.
    src_path = Path(root_dir) / country / source

    # Always work with a Path so that a string out_path can be joined with file names.
    out_path = src_path.joinpath("npy") if out_path is None else Path(out_path)
    out_path.mkdir(parents=True, exist_ok=True)

    src_files = _find_tifs("source")
    if not src_files:
        raise FileNotFoundError("No 'source' .tif tiles found below {}".format(src_path))
    grid_numbers = sorted(str(f).split("_")[-4] for f in src_files)

    if is_s2:
        # cloud_mask categories --> {"clear":0, "cloud":1, "haze":2, "shadow":3}
        cloud_masks = _find_tifs("cloudmask")

    # Read one image to get the tile dimensions and band count.
    with rasterio.open(src_files[0]) as src:
        tile_meta = src.meta
        img = src.read()

    bands = tile_meta["count"]
    if add_doy:
        bands = bands + 1
    if is_s2 and add_si:
        bands = bands + 5

    if verbose:
        print("-----------------------------")
        print("Image dimensions: {}".format(img.shape))
        print("Current data source: {}".format(source))
        print("Number of grids in set: {}".format(len(set(grid_numbers))))
        print("Maximum timestamps from this data source: {}".format(Counter(grid_numbers).most_common(1)[0][1]))
        print("Set of grid numbers: {}".format(sorted(set(grid_numbers))))

    for grid in sorted(set(grid_numbers)):
        if verbose:
            print("Grid: {}".format(grid))

        cur_grid_files = sorted(str(f) for f in src_files if "_" + grid + "_" in str(f))

        # Keep only the timestamps without any NaN pixel.
        usable_cur_grid_files = []
        dates = []
        for cur_fn in cur_grid_files:
            with rasterio.open(cur_fn) as src:
                src_array = src.read()
            if np.count_nonzero(np.isnan(src_array)) > 0:
                continue
            dates.append(_date_of(cur_fn))
            usable_cur_grid_files.append(cur_fn)

        if verbose:
            diff = len(cur_grid_files) - len(usable_cur_grid_files)
            total = len(cur_grid_files)
            print(f"Droping {diff} bad timestamps out of {total}.")

        if not usable_cur_grid_files:
            # Nothing left to stack; indexing the first usable file below would fail.
            if verbose:
                print("Grid {} has no usable timestamp and is skipped.".format(grid))
            continue

        if (doy_mode == "relative") and not origin:
            origin = _date_of(usable_cur_grid_files[0])

        if is_s2:
            cur_mask_files = sorted(str(f) for f in cloud_masks if "_" + grid + "_" in str(f))
            # Always match masks to tiles by date; equal counts alone do not
            # guarantee that the two lists cover the same timestamps.
            usable_dates = set(dates)
            usable_cur_mask_files = [fn for fn in cur_mask_files if _date_of(fn) in usable_dates]

            assert len(usable_cur_mask_files) == len(usable_cur_grid_files)

            mask_array = np.zeros((img.shape[1], img.shape[2], len(usable_cur_mask_files)))

        # dimensions: bands x rows x columns x timestamps
        data_array = np.zeros((bands, img.shape[1], img.shape[2], len(usable_cur_grid_files)))

        for idx, fname in enumerate(usable_cur_grid_files):
            if verbose:
                print("idx: ", idx)
                print("fname: ", fname)
                if is_s2:
                    print("mname: ", usable_cur_mask_files[idx])

            with rasterio.open(fname) as src:
                tile = src.read().astype(float)

            # Indices are computed from the raw reflectances, i.e. before
            # normalize() gets the chance to modify the tile in place.
            if is_s2 and add_si:
                si_bands = get_spectral_indices(tile)

            layers = [normalize(tile, source, country, norm_type, clip)]
            if is_s2 and add_si:
                layers.append(si_bands)
            if add_doy:
                layers.append(date2doy(_date_of(fname), tile.shape, norm_type, doy_mode, origin))

            # Band order: normalized bands, spectral indices, day of year.
            data_array[:, :, :, idx] = np.concatenate(layers, axis=0)

            if is_s2:
                with rasterio.open(usable_cur_mask_files[idx]) as msrc:
                    mask_array[:, :, idx] = msrc.read()

        out_fname = out_path / _stack_name(usable_cur_grid_files[0])

        # store and save metadata
        date_dict = {"dates": dates}
        with open(str(out_fname) + ".json", "w") as fp:
            json.dump(date_dict, fp)

        # save image stack as .npy
        np.save(out_fname, data_array)

        if is_s2:
            out_mname = out_path / _stack_name(usable_cur_mask_files[0])
            np.save(out_mname, mask_array)

In [None]:
# Arguments for the make_temporal_cube() call in the next cell.
# The commented-out root_dir below is an alternative (toy) dataset location.
#root_dir = "C:/My_documents/CropTypeData_Rustowicz/toy_Ghana"
root_dir = "D:/CropType/Ghana/Original_dataset"
country = "Ghana"
source = "Sentinel-1"
lbl_fldrname = "Labels"
out_path = None 
norm_type = "min/max" 
clip = 1.5
channel_first = True 
add_doy = True 
doy_mode = "absolute" 
origin = None 
add_si = False
verbose = True

In [None]:
# Run make_temporal_cube() with the configuration defined in the previous cell.
make_temporal_cube(root_dir, country, source, lbl_fldrname, out_path, norm_type, 
                   clip, channel_first, add_doy, doy_mode, origin, add_si, verbose)

In [None]:
# Load one saved stack for a manual sanity check.
# NOTE(review): this hard-coded path points at the toy dataset, not at `root_dir`
# configured above -- confirm this is the file you mean to inspect.
img = np.load("C:/My_documents/CropTypeData_Rustowicz/toy_Ghana/Ghana/Sentinel-1/npy/source_s1_000334.npy")

In [None]:
# Display the dimensions of the loaded stack.
img.shape

In [None]:
# Spot-check an 11x11 window at first-axis index 3 and last-axis index 10.
# Presumably the axes are (band, row, col, time), matching how the stack is filled
# when it is saved -- TODO confirm.
img[3,10:21, 10:21, 10]

## Step 5. Reclassify labels

In [None]:
def reclassify_lbl(root_dir, country, lbl_fldrname, categories, out_path=None, verbose=False):
    """
    Remap the full crop-type legend of every label raster to a reduced scheme.

    root_dir (str) -- path to the main directory which data resides.
    country (str) -- name of the country folder inside 'root_dir'.
    lbl_fldrname (str) -- name of the folder (inside the country folder) holding the label ".tif" files.
    categories (dict) -- target scheme as {class name: new integer code}. Every key except one must
                         be a class name of the full legend defined below. The single remaining key
                         (e.g. "other_crop") is the aggregator class that receives every legend class
                         which is not listed explicitly.
    out_path (str or Path) -- folder to write the reclassified rasters to.
                              Default is None, which means "<label folder>/reclass".
    verbose (Binary) -- If set to True, prints the grid ID and the merged classes per tile. Default is False.

    Each label is written as a uint8 raster that keeps the input file name.
    Raises AssertionError if 'categories' does not contain exactly one aggregator class.
    """
    lbl_dir = Path(root_dir) / country / lbl_fldrname

    complete_categories = {"unknown": 0, "ground nut": 1, "maize": 2, "rice": 3, "soya bean": 4, "yam": 5, 
                           "intercrop": 6, "sorghum": 7, "okra": 8, "cassava": 9, "millet": 10, "tomato": 11, 
                           "cowpea": 12, "sweet potato": 13, "babala beans": 14, "salad vegetables": 15, 
                           "bra and ayoyo": 16, "watermelon": 17, "zabla": 18, "nili": 19, "kpalika": 20, 
                           "cotton": 21, "akata": 22, "nyenabe": 23, "pepper": 24}

    # The scheme is identical for every tile, so resolve it once instead of once per file.
    categories_to_other = list(np.setdiff1d(list(complete_categories.keys()), list(categories.keys())))
    aggregator_cat = list(np.setdiff1d(list(categories.keys()), list(complete_categories.keys())))
    assert len(aggregator_cat) == 1, "Your classification scheme contains invalid classes."

    # old legend code -> new code
    old_to_new = {}
    for cat, new_code in categories.items():
        if cat == aggregator_cat[0]:
            for cat_to_other in categories_to_other:
                old_to_new[complete_categories[cat_to_other]] = new_code
        else:
            old_to_new[complete_categories[cat]] = new_code

    # Only the requested output folder is created (no stray "reclass" folder when
    # a custom 'out_path' is supplied).
    out_dir = lbl_dir / "reclass" if out_path is None else Path(out_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    # The default output folder lives inside the label folder. Skip it while collecting
    # inputs so that a re-run does not reclassify its own (already reclassified) outputs.
    out_dir_resolved = out_dir.resolve()
    writes_in_place = out_dir_resolved == lbl_dir.resolve()

    lbl_fnames = []
    for dirpath, _, filenames in os.walk(lbl_dir):
        dir_resolved = Path(dirpath).resolve()
        inside_output = dir_resolved == out_dir_resolved or out_dir_resolved in dir_resolved.parents
        if inside_output and not writes_in_place:
            continue
        lbl_fnames.extend(Path(dirpath) / f for f in filenames if f.endswith(".tif"))
    lbl_fnames.sort()

    for fn in lbl_fnames:

        with rasterio.open(fn) as src:
            profile = src.profile
            lbl_array = src.read()

        if verbose:
            lbl_id = str(fn).split("_")[-1].replace(".tif", "")
            print("---Grid: {} ---".format(lbl_id))
            print("List of crop categories: {} merged into the {} category".format(categories_to_other, aggregator_cat))

        # initialize the canvas for reclassed layer; codes absent from the legend stay 0
        remapped_lbl = np.zeros_like((lbl_array), dtype="uint8")
        for old_code, new_code in old_to_new.items():
            remapped_lbl[lbl_array == old_code] = new_code

        profile.update(
            dtype=rasterio.uint8
        )

        with rasterio.open(out_dir / fn.name, "w", **profile) as dst:
            dst.write(remapped_lbl)

In [None]:
# Arguments for the reclassify_lbl() call in the next cell.
root_dir = "D:/CropType/Ghana/Original_dataset"
country = "Ghana"
lbl_fldrname = "Labels"
# Target scheme: "other_crop" is the only key that is not a class of the full legend,
# so it acts as the aggregator class.
categories = {"unknown": 0, "maize": 1, "rice": 2, "other_crop": 3}
#categories = {"unknown": 0, "ground nut": 1, "maize": 2, "rice": 3, "soya bean": 4, "other_crop": 5}
out_path = None
verbose = False

In [None]:
# Reclassify all label rasters with the scheme configured in the previous cell.
reclassify_lbl(root_dir, country, lbl_fldrname, categories, out_path, verbose)

In [None]:
# Original (fn1) and reclassified (fn2) versions of the same label tile, for a manual check.
# NOTE(review): both paths point at the toy dataset rather than `root_dir`, and only fn2
# is opened in the next cell -- confirm fn1 is still needed.
fn1 = "C:/My_documents/CropTypeData_Rustowicz/toy_Ghana/Ghana/Labels/su_african_crops_ghana_labels_000001/labels_000001.tif"
fn2 = "C:/My_documents/CropTypeData_Rustowicz/toy_Ghana/Ghana/Labels/reclass/labels_000001.tif"

In [None]:
# Read the reclassified tile: keep its raster profile and all of its bands.
with rasterio.open(fn2) as src:
    profile = src.profile
    lbl_array = src.read()

In [None]:
# Display the dimensions of the label array that was just read.
lbl_array.shape

In [None]:
# Display the raster profile of the reclassified tile (reclassify_lbl sets its dtype to uint8).
profile

## Step 6. Explore the temporal dimension and remove (a) tiles whose dates do not cover at least half of the growing season (April-August, i.e. through the end of June) and (b) tiles where more than 75% of the crop-field area is covered by cloud in more than 70% of the temporal extent

### Step 6.1 Check if tile IDs match between S1, S2 and Label folders

In [None]:
def check_tile_ids_match(root_dir, country, lbl_fldrname):
    """
    Verify that labels, Sentinel-1 stacks, Sentinel-2 stacks and Sentinel-2 cloud masks
    line up tile by tile.

    root_dir (str) -- path to the main directory which data resides.
    country (str) -- name of the country folder inside 'root_dir'.
    lbl_fldrname (str) -- name of the folder holding the label ".tif" files.

    Files are searched recursively in "<country>/<lbl_fldrname>" (".tif"), "<country>/S1_npy"
    (".npy") and "<country>/S2_npy" (".npy" containing "source" or "cloudmask"). Each list is
    sorted and the lists are compared position by position. The grid ID of a file is the text
    after its last underscore, without the extension.

    Prints the four grid IDs of every position and raises AssertionError on the first mismatch.

    Note: the lists are paired with zip(), so only the first N positions are compared, where N
    is the length of the shortest list.
    """

    def _find(folder, suffix, keyword=None):
        # Recursively collect files by extension (and optional substring), sorted by path.
        found = [Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(folder)
                 for f in filenames if f.endswith(suffix) and (keyword is None or keyword in f)]
        found.sort()
        return found

    def _grid_id(path, suffix):
        return str(path).split("_")[-1].replace(suffix, "")

    lbl_fnames = _find(Path(root_dir) / country / lbl_fldrname, ".tif")
    s1_fnames = _find(Path(root_dir) / country / "S1_npy", ".npy")

    s2_src_path = Path(root_dir) / country / "S2_npy"
    s2_fnames = _find(s2_src_path, ".npy", "source")
    cmasks = _find(s2_src_path, ".npy", "cloudmask")

    for lbl_fn, s1_fn, s2_fn, cmask_fn in zip(lbl_fnames, s1_fnames, s2_fnames, cmasks):

        lbl_grid_id = _grid_id(lbl_fn, ".tif")
        s1_grid_id = _grid_id(s1_fn, ".npy")
        s2_grid_id = _grid_id(s2_fn, ".npy")
        # The cloud-mask ID must come from the cloud-mask file itself; deriving it from
        # the S2 source file would make the cloud-mask comparison always succeed.
        cmask_grid_id = _grid_id(cmask_fn, ".npy")

        print("label: ", lbl_grid_id)
        print("S1: ", s1_grid_id)
        print("S2: ", s2_grid_id)
        print("cmask: ", cmask_grid_id)
        assert lbl_grid_id == s1_grid_id == s2_grid_id == cmask_grid_id, \
            "Grid ID mismatch: lbl: {}, S1: {}, S2: {}, cmask: {}".format(
                lbl_grid_id, s1_grid_id, s2_grid_id, cmask_grid_id)

In [None]:
# Arguments for the check_tile_ids_match() call in the next cell.
root_dir = "D:/CropType/Ghana/Original_dataset"
country = "Ghana"
lbl_fldrname = "Labels"

In [None]:
# Raises AssertionError on the first tile whose IDs disagree between the folders.
check_tile_ids_match(root_dir, country, lbl_fldrname)

### Step 6.2 Explore sequence length and available months

In [None]:
def get_seq_length_info(root_dir, country, source, seq_threshold, ext_file_path, verbose=False):
    """
    Collect the number of acquisition dates per grid and flag grids with short sequences.

    root_dir (str) -- path to the main directory which data resides.
    country (str) -- name of the country folder inside 'root_dir'.
    source (str) -- folder name of the image resource, e.g. "S1_npy". The part before the first
                    underscore is used as prefix of the report file name.
    seq_threshold (int) -- grids with fewer dates than this value are flagged.
    ext_file_path (str) -- folder in which "<prefix>_seq_length_report.txt" is written.
    verbose (Binary) -- If set to True, prints grid ID and temporal length of every grid. Default is False.

    The dates are read from the "dates" entry of every ".json" file whose name contains "source"
    (searched recursively in "<root_dir>/<country>/<source>").

    Returns:
        flagged_grids (list) -- grid IDs with fewer than 'seq_threshold' dates, in sorted file order.
        out_dict (dict) -- {grid ID: number of dates} for all grids.

    A report from a previous run is removed; a new one is written only if at least one grid is flagged.
    """

    src_path = Path(root_dir) / country / source

    files = [Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(src_path)
             for f in filenames if f.endswith(".json") if "source" in f]
    files.sort()

    out_dict = {}
    flagged_grids = []
    report_lines = []

    out_dir = Path(ext_file_path)
    name = "{}_seq_length_report.txt".format(str(source).split("_")[0])
    out_path = out_dir.joinpath(name)

    # Drop the report of a previous run so stale flags cannot be mistaken for current ones.
    if os.path.exists(out_path):
        os.remove(out_path)

    for file in files:
        with open(file, 'r') as f:
            data = json.load(f)

        seq_length = len(data["dates"])
        grid_id = str(file.name).split("_")[-1].replace(".json", "")
        out_dict[grid_id] = seq_length

        if verbose:
            print("--- Listing all ---")
            print("Grid_id: {}".format(grid_id))
            # Report the sequence length here, not the grid ID a second time.
            print("Temporal length: {}".format(seq_length))
            print("---")

        if seq_length < seq_threshold:
            report_lines.append("Grid ID: {}, has {} timestamps.".format(grid_id, seq_length))
            flagged_grids.append(grid_id)

    # Write the report once instead of re-opening the file for every flagged grid.
    if report_lines:
        with open(out_path, "w") as external_file_1:
            external_file_1.write("\n".join(report_lines) + "\n")
        print(f"Report is saved at: {out_path}")
    else:
        print(f"No grid was flagged; no report was written to: {out_path}")

    return flagged_grids, out_dict

#########################

def get_available_months(root_dir, country, source, month_threshold, ext_file_path):
    """
    Count the distinct calendar months covered per grid and flag grids covering too few.

    root_dir (str) -- path to the main directory which data resides.
    country (str) -- name of the country folder inside 'root_dir'.
    source (str) -- folder name of the image resource, e.g. "S2_npy". The part before the first
                    underscore is used as prefix of the report file name.
    month_threshold (int) -- grids covering fewer distinct months than this value are flagged.
    ext_file_path (str) -- folder in which "<prefix>_avail_months_report.txt" is written.

    The dates are read from the "dates" entry of every ".json" file whose name contains "source"
    (searched recursively in "<root_dir>/<country>/<source>"). The month is the second
    underscore-separated field of a date string; the grid ID is the third underscore-separated
    field of the file name.

    Returns:
        flagged_grids (list) -- grid IDs covering fewer than 'month_threshold' months, in sorted file order.
        all_ids_dict (dict) -- {grid ID: number of distinct months} for all grids.

    A report from a previous run is removed; a new one is written only if at least one grid is flagged.
    """
    # Local import keeps the block self-contained.
    from collections import Counter

    src_path = Path(root_dir) / country / source
    files = [Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(src_path) for
             f in filenames if f.endswith(".json") if "source" in f]
    # Sort like get_seq_length_info() does, so the order of the flagged grids and of the
    # report lines does not depend on the file system.
    files.sort()

    all_ids_dict = {}
    flagged_grids = []
    report_lines = []

    out_dir = Path(ext_file_path)
    name = "{}_avail_months_report.txt".format(str(source).split("_")[0])
    out_path = out_dir.joinpath(name)

    # Drop the report of a previous run so stale flags cannot be mistaken for current ones.
    if os.path.exists(out_path):
        os.remove(out_path)

    for file in files:
        with open(file) as f:
            data = json.load(f)

        grid_id = str(file.name).split("_")[2].replace(".json", "")
        months = [int(date.split("_")[1]) for date in data["dates"]]
        # {month: number of dates in that month}, keyed in order of first appearance.
        dic_unique_count_months = dict(Counter(months))
        available_months = len(dic_unique_count_months)

        all_ids_dict[grid_id] = available_months

        if available_months < month_threshold:
            report_lines.append("Grid ID: {}, has {} available months. Detail: {}".format(
                grid_id, available_months, dic_unique_count_months))
            flagged_grids.append(grid_id)

    if report_lines:
        with open(out_path, "w") as external_file_2:
            external_file_2.write("\n".join(report_lines) + "\n")
        print(f"Report is saved at: {out_path}")
    else:
        print(f"No grid was flagged; no report was written to: {out_path}")

    return flagged_grids, all_ids_dict

In [None]:
# Settings shared by the Sentinel-1 and Sentinel-2 exploration cells below.
root_dir = "D:/CropType/Ghana/Original_dataset"
country = "Ghana"
ext_file_path = "D:/CropType/Ghana/Original_dataset/Ghana"  # folder the text reports are written to
s1_seq_threshold = 15  # minimum number of dates for a Sentinel-1 grid
s2_seq_threshold = 20  # minimum number of dates for a Sentinel-2 grid
month_threshold = 8    # minimum number of distinct months for either sensor

In [None]:
# Sentinel-1: sequence length per grid, histogram of the lengths, and month coverage per grid.
s1_seq_flagged_grids, s1_seq_all_grids = get_seq_length_info(root_dir, country, source="S1_npy", 
                                                             seq_threshold=s1_seq_threshold, ext_file_path=ext_file_path, verbose=False)
# {sequence length: number of grids with that length}
s1_unique_tmp_lengths = Counter(list(s1_seq_all_grids.values()))
s1_m_flagged_grids, s1_m_all_grids = get_available_months(root_dir, country, source="S1_npy", 
                                                          month_threshold=month_threshold, ext_file_path=ext_file_path)

print("--- Sentinel-1 ---")
print("Unique temporal lengths and their occurances (unique, count):")
print(sorted(s1_unique_tmp_lengths.items()))
print("")
print("Flagged grids with shorter number of timestamps than {}:".format(s1_seq_threshold))
print(s1_seq_flagged_grids)
print("total flagged tiles: ", len(s1_seq_flagged_grids))
print("")
print("Flagged grids with data from a temporal extent of less than {} months:".format(month_threshold))
print(s1_m_flagged_grids)
print("total flagged tiles: ", len(s1_m_flagged_grids))


In [None]:
# Sentinel-2: same summary as the Sentinel-1 cell above, with the Sentinel-2 sequence threshold.
s2_seq_flagged_grids, s2_seq_all_grids = get_seq_length_info(root_dir, country, source="S2_npy", 
                                                             seq_threshold=s2_seq_threshold, ext_file_path=ext_file_path, verbose=False)
# {sequence length: number of grids with that length}
s2_unique_tmp_lengths = Counter(list(s2_seq_all_grids.values()))
s2_m_flagged_grids, s2_m_all_grids = get_available_months(root_dir, country, source="S2_npy", 
                                                          month_threshold=month_threshold, ext_file_path=ext_file_path)

print("--- Sentinel-2 ---")
print("Unique temporal lengths and their occurances (unique, count):")
print(sorted(s2_unique_tmp_lengths.items()))
print("")
print("Flagged grids with shorter number of timestamps than {}:".format(s2_seq_threshold))
print(s2_seq_flagged_grids)
print("total flagged tiles: ", len(s2_seq_flagged_grids))
print("")
print("Flagged grids with data from a temporal extent of less than {} months:".format(month_threshold))
print(s2_m_flagged_grids)
print("total flagged tiles: ", len(s2_m_flagged_grids))


### Step 6.3 Explore the cloud coverage in the temporal extent

In [None]:
def reclass_cloudmask_stack(cloud_stack):
    """
     Collapse a multi-class cloud mask into a binary cloud / clear mask.
     clear = 0 --> 0, clouds = 1 --> 1, shadows = 2 --> 1, haze = 3 --> 1
     Any other value ends up as 0.

     output: numpy nd array with the same shape and dtype as the input.
    """
    # Every pixel that is cloud, shadow or haze counts as "cloudy".
    is_obscured = (cloud_stack == 1) | (cloud_stack == 2) | (cloud_stack == 3)

    binary_stack = np.zeros_like(cloud_stack)
    binary_stack[is_obscured] = 1

    return binary_stack

############################################################

def explore_cloudiness(root_dir, country, lbl_fldrname, ext_file_path, cloudiness_areal_threshold, cloudiness_numdates_threshold):
    """
    Flag tiles whose labelled area is cloud-covered on too many Sentinel-2 dates.

    root_dir (str) -- path to the main directory which data resides.
    country (str) -- name of the country folder inside 'root_dir'.
    lbl_fldrname (str) -- name of the folder holding the label ".tif" files.
    ext_file_path (str) -- folder in which "S2_cloudiness_report.txt" is written.
    cloudiness_areal_threshold (float) -- a date counts as cloudy when the cloudy share of the
                                          labelled pixels (label > 0) exceeds this ratio.
    cloudiness_numdates_threshold (float) -- a tile is flagged when the share of cloudy dates
                                             exceeds this ratio.

    Labels, date files ("source" ".json") and cloud masks ("cloudmask" ".npy") from
    "<country>/S2_npy" are sorted and paired position by position; their grid IDs must agree.

    Returns the list of flagged grid IDs. Flagged tiles are also appended to the report file.
    """

    def _sorted_files(folder, suffix, keyword=None):
        # Recursive file search by extension and optional substring, sorted by path.
        matches = []
        for dirpath, _, filenames in os.walk(folder):
            for f in filenames:
                if f.endswith(suffix) and (keyword is None or keyword in f):
                    matches.append(Path(dirpath) / f)
        matches.sort()
        return matches

    country_dir = Path(root_dir) / country
    s2_dir = country_dir / "S2_npy"

    label_paths = _sorted_files(country_dir / lbl_fldrname, ".tif")
    date_paths = _sorted_files(s2_dir, ".json", "source")
    cmask_paths = _sorted_files(s2_dir, ".npy", "cloudmask")

    # Start from a clean report; flagged tiles are appended one by one below.
    out_path = Path(ext_file_path).joinpath("S2_cloudiness_report.txt")
    if os.path.exists(out_path):
        os.remove(out_path)

    flagged_grid_ids = []

    for lbl_fn, date_fn, cmask_fn in zip(label_paths, date_paths, cmask_paths):
        lbl_grid_id = str(lbl_fn).split("_")[-1].replace(".tif", "")
        date_grid_id = str(date_fn).split("_")[-1].replace(".json", "")
        cmask_grid_id = str(cmask_fn).split("_")[-1].replace(".npy", "")
        assert lbl_grid_id == date_grid_id == cmask_grid_id, "problematic grid ids: lbl: {}, src: {}, cmask: {}".format(
            lbl_grid_id, date_grid_id, cmask_grid_id)

        with rasterio.open(lbl_fn, "r") as src:
            if src.count != 1:
                raise ValueError("Label must have only 1 band but {} bands were detected.".format(src.count))
            lbl_array = src.read(1)

        # Labelled pixels (any class above 0) define the area of interest.
        crop_mask = np.where(lbl_array>0, 1, 0)
        crop_area = np.sum(crop_mask)

        with open(date_fn, 'r') as f:
            date_ls = json.load(f)["dates"]

        binary_cloud_mask = reclass_cloudmask_stack(np.load(cmask_fn))
        temporal_length = len(date_ls)

        # Label and cloud mask must share rows and columns; the third mask axis is time.
        for axis in (0, 1):
            assert lbl_array.shape[axis] == binary_cloud_mask.shape[axis], "problematic grid id: {}; lbl:{}, cmask:{}".format(
                lbl_grid_id, lbl_array.shape[axis], binary_cloud_mask.shape[axis])
        assert temporal_length == binary_cloud_mask.shape[2]

        detailed_dict = {}
        counter = 0

        for t, date in enumerate(date_ls):
            cloudiness_ratio = np.sum(crop_mask * binary_cloud_mask[:, :, t]) / crop_area
            if cloudiness_ratio > cloudiness_areal_threshold:
                detailed_dict[date] = cloudiness_ratio
                counter += 1

        if counter / temporal_length > cloudiness_numdates_threshold:
            with open(out_path, "a") as external_file:
                print(f"Grid ID: {lbl_grid_id}", file=external_file)

                print(f"With {counter} over areal_threshold cloudy days out of {temporal_length}", file=external_file)
                print(f"details: {detailed_dict}", file=external_file)
                print("", file=external_file)

            flagged_grid_ids.append(lbl_grid_id)

    print(f"Report is saved at: {out_path}")
    return flagged_grid_ids 

In [None]:
# Arguments for the explore_cloudiness() call in the next cell.
root_dir = "D:/CropType/Ghana/Original_dataset"
country = "Ghana"
lbl_fldrname = "Labels"
ext_file_path = "D:/CropType/Ghana/Original_dataset/Ghana"
cloudiness_areal_threshold = 0.75
# NOTE(review): the Step 6 heading speaks of "more than 70%" of the temporal extent,
# while the value used here is 0.68 -- confirm which one is intended.
cloudiness_numdates_threshold = 0.68

In [None]:
# Grid IDs of the tiles flagged as too cloudy; used below as the list of tiles to move.
cloud_flags = explore_cloudiness(root_dir, country, lbl_fldrname, ext_file_path, cloudiness_areal_threshold, cloudiness_numdates_threshold)

In [None]:
# Number of tiles flagged as too cloudy.
len(cloud_flags)

### Step 6.4 Move the flagged tiles

In [None]:
def MoveTiles(root_dir, country, lbl_fldrname, remove_tile_ls):
    """
    Move every file that belongs to a flagged tile into a "moved" subfolder.

    root_dir (str) -- path to the main directory which data resides.
    country (str) -- name of the country folder inside 'root_dir'.
    lbl_fldrname (str) -- name of the folder holding the label ".tif" files.
    remove_tile_ls (list) -- grid IDs (str) of the tiles to move.

    For a flagged tile the following files are moved:
        "<country>/S1_npy": source ".npy" + ".json"                  --> "S1_npy/S1_moved"
        "<country>/S2_npy": source ".npy" + ".json" + cloudmask ".npy" --> "S2_npy/S2_moved"
        "<country>/<lbl_fldrname>": ".tif"                           --> "<lbl_fldrname>/lbl_moved"

    The six file lists are sorted and paired position by position.
    Raises AssertionError if the lists differ in length or if the grid IDs of a position
    disagree, so that no file is moved on behalf of another tile.
    """

    def _find(folder, suffix, keyword=None):
        # Recursively collect files by extension (and optional substring), sorted by path.
        found = [Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(folder)
                 for f in filenames if f.endswith(suffix) and (keyword is None or keyword in f)]
        found.sort()
        return found

    def _grid_id(path, suffix):
        return str(path).split("_")[-1].replace(suffix, "")

    lbl_dir = Path(root_dir) / country / lbl_fldrname
    lbl_fnames = _find(lbl_dir, ".tif")

    s1_path = Path(root_dir) / country / "S1_npy"
    s1_fnames = _find(s1_path, ".npy", "source")
    s1_meta_fnames = _find(s1_path, ".json")

    s2_path = Path(root_dir) / country / "S2_npy"
    s2_fnames = _find(s2_path, ".npy", "source")
    s2_meta_fnames = _find(s2_path, ".json")
    cloud_masks = _find(s2_path, ".npy", "cloudmask")

    # All six lists must have the same length, including the S2 metadata list.
    assert (len(s1_fnames) == len(s1_meta_fnames) == len(s2_fnames) == len(s2_meta_fnames)
            == len(cloud_masks) == len(lbl_fnames)), \
        "Unequal number of files: S1: {}, S1 meta: {}, S2: {}, S2 meta: {}, cloud masks: {}, labels: {}".format(
            len(s1_fnames), len(s1_meta_fnames), len(s2_fnames), len(s2_meta_fnames),
            len(cloud_masks), len(lbl_fnames))

    s1_out_path = s1_path / "S1_moved"
    s2_out_path = s2_path / "S2_moved"
    lbl_out_path = lbl_dir / "lbl_moved"

    for p in [s1_out_path, s2_out_path, lbl_out_path]:
        Path(p).mkdir(parents=True, exist_ok=True)

    tiles_to_move = set(remove_tile_ls)

    for s1_fn, s1_meta_fn, s2_fn, s2_meta_fn, cmask_fn, lbl_fn in zip(s1_fnames, s1_meta_fnames, s2_fnames, s2_meta_fnames, cloud_masks, lbl_fnames):

        s1_grid_id = _grid_id(s1_fn, ".npy")
        s2_grid_id = _grid_id(s2_fn, ".npy")
        lbl_grid_id = _grid_id(lbl_fn, ".tif")
        assert s1_grid_id == s2_grid_id == lbl_grid_id, "Grid Id mis-match between Sentinel-1 & 2 chips."

        # Metadata and cloud mask are moved together with the stacks, so they must
        # belong to the same tile as well.
        companion_ids = [_grid_id(s1_meta_fn, ".json"), _grid_id(s2_meta_fn, ".json"), _grid_id(cmask_fn, ".npy")]
        assert all(cid == s1_grid_id for cid in companion_ids), \
            "Grid Id mis-match between tile {} and its metadata / cloud mask files: {}".format(s1_grid_id, companion_ids)

        if s1_grid_id in tiles_to_move:
            shutil.move(str(s1_fn), str(s1_out_path))
            shutil.move(str(s1_meta_fn), str(s1_out_path))
            shutil.move(str(s2_fn), str(s2_out_path))
            shutil.move(str(s2_meta_fn), str(s2_out_path))
            shutil.move(str(cmask_fn), str(s2_out_path))
            shutil.move(str(lbl_fn), str(lbl_out_path))

    print("End of process")


In [None]:
# Arguments for the MoveTiles() call in the next cell.
root_dir = "D:/CropType/Ghana/Original_dataset"
country = "Ghana"
lbl_fldrname = "Labels"
# Only the cloudiness flags are moved here; the sequence-length and month flags
# computed in Step 6.2 are not part of this list.
remove_tile_ls = cloud_flags

In [None]:
# Moves files on disk. MoveTiles() searches its folders recursively, so a second run
# will also find the files that were already moved.
MoveTiles(root_dir, country, lbl_fldrname, remove_tile_ls)

## Step 7. Add Non-crop class to the labels

In [None]:
def get_fid_coord(num_mesh_cells, img_extent, start_indexing="upper left", index_pos="center"):

    r"""Map the FID of every mesh cell to a (row, col) position in a square image patch.
    Params:
    num_mesh_cells (int) -- Number of mesh cells.
    img_extent (int) -- Either number of rows or columns of the square image patch.
    start_indexing (str) -- "upper left" or "lower left". With "upper left" the row sequence
                            is reversed, so FID 0 receives the largest row value.
    index_pos (str) -- "center" or "upper left": position of the identifier inside each mesh cell.

    Returns:
    dict -- {str(FID): (row, col)}, FIDs running from 0 to 'num_mesh_cells' - 1 in row-major order.

    Note 1: 'num_mesh_cells' must be divisible by 'img_extent'; the cell size is taken as
          'num_mesh_cells' // 'img_extent'.
    Note 2: ArcGIS mesh-grids start indexing for FID of point labels in the center of cells from lower left corner
          of the img as opposed to numpy array indexing that starts from upper left.
    Note 3: If you want fids to sample a numpy array, choose "upper left" for 'start_indexing',
          Otherwise if you wish to simulate FID naming of ArcGIS mesh grid, then choose: "lower left"."""

    assert index_pos in ["upper left", "center"], "Invalid index type."
    assert start_indexing in ["upper left", "lower left"], "Invalid index type."
    assert (num_mesh_cells % img_extent) == 0, "'num_mesh_cells' must be divisible by 'img_extent'."

    fids = np.arange(num_mesh_cells)
    cell_size = len(fids) // img_extent

    # Rows and columns share the same positions because the patch is square.
    if index_pos == "center":
        first = cell_size // 2
        stop = img_extent - (cell_size // 2) + 1
    else:
        first = 0
        stop = img_extent

    rows = range(first, stop, cell_size)
    cols = range(first, stop, cell_size)
    if start_indexing == "upper left":
        rows = rows[::-1]

    positions = itertools.product(rows, cols)
    return {str(fid): position for fid, position in zip(fids, positions)}

In [None]:
# Build the FID -> (row, col) lookup that add_non_crop_to_lbl() uses below.
# With these values get_fid_coord() uses a cell size of 256 // 64 = 4 pixels.
num_mesh_cells = 256
img_extent=64
start_indexing = "upper left" 
index_pos = "center"
mesh_grid_dict = get_fid_coord(num_mesh_cells, img_extent, start_indexing, index_pos)

In [None]:
def add_non_crop_to_lbl(lbl_dir, csv_path, mesh_grid_dict, out_dir=None, verbose=False):
    """
    Burn manually selected non-crop samples (class 4) into the label rasters.

    lbl_dir (str) -- folder holding the label ".tif" files (searched recursively).
    csv_path (str) -- csv file with an "ID" column (tile ID as integer). The first remaining
                      column holds a comma separated list of mesh-cell FIDs for that tile.
    mesh_grid_dict (dict) -- {str(FID): (row, col)} lookup as returned by get_fid_coord().
    out_dir (str or Path) -- output folder. Default is None, which means "<lbl_dir>/reclass".
    verbose (Binary) -- If set to True, prints a report per tile. Default is False.

    Tiles listed in the csv: classes 1, 2 and 3 are kept, and for every FID the 4x4 window
    [row-2:row+2, col-2:col+2] is set to 4 -- but only if the window contains no labelled
    pixel at all. Windows that touch a label are skipped (and reported when 'verbose').
    Tiles not listed in the csv are copied unchanged.
    """

    tile_fid_ls = pd.read_csv(csv_path, index_col="ID")
    tile_ids_to_reclass = set(tile_fid_ls.index)

    lbl_dir = Path(lbl_dir)
    if out_dir is None:
        out_dir = lbl_dir.joinpath("reclass")
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # The default output folder lives inside 'lbl_dir'. Skip it while collecting inputs so
    # that a re-run neither re-processes its own outputs nor copies a file onto itself.
    out_dir_resolved = out_dir.resolve()
    writes_in_place = out_dir_resolved == lbl_dir.resolve()

    lbl_fnames = []
    for dirpath, _, filenames in os.walk(lbl_dir):
        dir_resolved = Path(dirpath).resolve()
        inside_output = dir_resolved == out_dir_resolved or out_dir_resolved in dir_resolved.parents
        if inside_output and not writes_in_place:
            continue
        lbl_fnames.extend(Path(dirpath) / f for f in filenames if f.endswith(".tif"))

    for fn in lbl_fnames:
        lbl_id = str(fn).split("_")[-1].replace(".tif", "")

        if int(lbl_id) not in tile_ids_to_reclass:
            if verbose:
                print(f"No change to the tile: {lbl_id}")
            # Nothing to copy when the file already sits in the output folder.
            if fn.parent.resolve() != out_dir_resolved:
                shutil.copy(fn, out_dir)
            continue

        sample_fid_ls = list(tile_fid_ls.loc[int(lbl_id)])[0].split(",")

        with rasterio.open(fn) as src:
            profile = src.profile
            lbl_array = src.read()

        profile.update(
            dtype=rasterio.uint8
        )

        remapped_lbl = np.zeros_like((lbl_array), dtype="uint8")
        for crop_class in (1, 2, 3):
            remapped_lbl[lbl_array == crop_class] = crop_class

        num_samples = 0
        for fid in sample_fid_ls:
            row, col = mesh_grid_dict[fid.strip()]
            window = (slice(None), slice(row - 2, row + 2), slice(col - 2, col + 2))
            patch = lbl_array[window]

            # Only a window without any labelled pixel may become non-crop.
            # (`patch.all() == 0` would already be true with a single unlabelled pixel
            # and would overwrite crop labels.)
            if patch.any():
                if verbose:
                    print(f"Bad FID sample: {fid}")
                continue

            remapped_lbl[window] = 4
            num_samples += patch.shape[-2] * patch.shape[-1]

        if verbose:
            print(f"{num_samples} non-crop pixels are added to the tile: {lbl_id}")

        with rasterio.open(out_dir / fn.name, "w", **profile) as dst:
            dst.write(remapped_lbl)

In [None]:
# Arguments for the add_non_crop_to_lbl() call in the next cell.
# NOTE(review): these paths are on "C:/My_documents/...", while the previous steps use
# "D:/CropType/..." -- confirm they refer to the same dataset.
lbl_dir = "C:/My_documents/CropTypeData_Rustowicz/Ghana/Labels"
csv_path = "C:/My_documents/CropTypeData_Rustowicz/Ghana/usable_training_data.csv"
out_dir = None
verbose = True

In [None]:
# Uses `mesh_grid_dict` from the get_fid_coord() cell above.
add_non_crop_to_lbl(lbl_dir, csv_path, mesh_grid_dict, out_dir, verbose)

## Step 8. Summarize statistics of categories

In [None]:
"""
import os
from pathlib import Path
import random
import math
import numpy as np
import pandas as pd
import rasterio
"""

def summarize_lbl(lbl_dir, out_filename, category = None):
    """
    Count the pixels of every category in every label raster and save the table as csv.

    lbl_dir (str) -- folder holding the label ".tif" files (searched recursively).
    out_filename (str) -- name of the csv file; it is written into 'lbl_dir'.
    category (dict) -- {category name: pixel value}. Default is None, which selects the
                       complete 25-class legend defined below.

    Returns a DataFrame with one row per grid ID, one column per category name and the
    pixel counts as values (0 where a category does not occur).
    A pixel value that is missing from the category dictionary raises ValueError.
    """

    lbl_dir = Path(lbl_dir)

    lbl_fnames = []
    for dirpath, _, filenames in os.walk(lbl_dir):
        for f in filenames:
            if f.endswith(".tif"):
                lbl_fnames.append(Path(dirpath) / f)

    lbl_ids = [str(f).split("_")[-1].replace(".tif", "") for f in lbl_fnames]

    category_dict = category if category else {
        "unknown": 0, "ground nut": 1, "maize": 2, "rice": 3, "soya bean": 4, "yam": 5,
        "intercrop": 6, "sorghum": 7, "okra": 8, "cassava": 9, "millet": 10, "tomato": 11,
        "cowpea": 12, "sweet potato": 13, "babala beans": 14, "salad vegetables": 15,
        "bra and ayoyo": 16, "watermelon": 17, "zabla": 18, "nili": 19, "kpalika": 20,
        "cotton": 21, "akata": 22, "nyenabe": 23, "pepper": 24}

    key_list = list(category_dict.keys())
    val_list = list(category_dict.values())

    df = pd.DataFrame(columns = key_list, index = lbl_ids)

    for fn in lbl_fnames:
        lbl_id = fn.name.split("_")[-1].replace(".tif", "")

        with rasterio.open(fn) as src:
            lbl_array = src.read()

        pixel_values, pixel_counts = np.unique(lbl_array, return_counts=True)

        for pixel_value, n_pixels in zip(list(pixel_values), list(pixel_counts)):
            if lbl_id not in list(df.index):
                continue
            # pixel value -> category name (first name carrying that value)
            cat_name = key_list[val_list.index(pixel_value)]
            if cat_name in list(df.columns):
                df.loc[lbl_id, cat_name] = n_pixels

    df = df.fillna(0)
    df.to_csv(lbl_dir / out_filename, index_label="Grid-ID")
    return df

In [None]:
# Arguments for the summarize_lbl() call in the next cell.
# NOTE(review): this points at the "validation" subfolder, which is only created by the
# split in Step 9, while Step 9 itself reads "report.csv" from the label folder it is given.
# Confirm which folder should be summarized on a fresh top-to-bottom run.
lbl_dir = "D:/CropType/Ghana/Labels/validation"
out_filename = "report.csv"
category = {"unknown": 0, "maize": 1, "rice": 2, "other_crop": 3}
#category = {"unknown": 0, "maize": 1, "rice": 2, "other_crop": 3, "non_crop": 4}

In [None]:
# Per-tile pixel counts (also written to <lbl_dir>/<out_filename>) ...
report = summarize_lbl(lbl_dir, out_filename, category)
# ... and the total number of pixels per category over all tiles.
report.sum(axis = 0, skipna = True)

## Step 9. Split the dataset into train and validation

In [None]:
"""
import os
from pathlib import Path
import random
import math
import shutil
import numpy as np
import pandas as pd
"""

def create_grid_splits(root_dir, country, sources, lbl_fldrname, csv_fn, split_threshold=0.8):
    """
    Splitting the dataset into train and validation datasets.

    root_dir (str) -- path to the main directory which data resides. Example: "C:/My_documents/Data".
    country (str) -- This is based on the organization of dataset that images and 
                     labels reside inside a country folder.
    sources (list) -- folder name of the image resource. example: ["Sentinel-1", "Sentinel-2"]
    lbl_fldrname (str) -- Name of the folder containing annotated grids.
    csv_fn (str) -- Name of the csv file summarizing the content of each grid. It is read from the
                    label folder and must have a "Grid-ID" column plus one pixel-count column per category.
    split_threshold (float) -- scalar value as a threshold to decide how many of the grids will be
                               in the training folder. Default is 0.8.

    The split is drawn per category (starting with the rarest one) so that every category contributes
    about 'split_threshold' of its grids to the training set. The training set is then padded or
    trimmed with grids of the most frequent category until it holds ceil(split_threshold * #grids) grids.
    Label files and all source files whose name contains the grid ID are MOVED into "train" and
    "validation" subfolders. Sampling uses the global `random` state.
    """

    lbl_dir = Path(root_dir) / country / lbl_fldrname
    lbl_fnames = [Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(lbl_dir)
                  for f in filenames if f.endswith(".tif")]
    lbl_ids = [str(f).split("_")[-1].replace(".tif", "") for f in lbl_fnames]

    report = pd.read_csv(lbl_dir / csv_fn, index_col="Grid-ID")
    assert(report.shape[0] == len(lbl_ids))

    num_train_girds = math.ceil(split_threshold * len(lbl_ids))
    categories = list(report.columns)

    # make a dictionary with the keys as category names and values as list of grid-ids containing
    # those categories.
    cat_grid = {category: list(report[report[category] > 0].index)
                for category in categories if category != "unknown"}

    # Choose the category with the least amount of tiles as the initial category and add the sampled ids
    # in the list of training grids.
    initial_category = min(cat_grid, key=lambda cat: len(cat_grid[cat]))
    num_initial_tiles = math.ceil(len(cat_grid[initial_category]) * split_threshold)
    training_grids = random.sample(cat_grid[initial_category], num_initial_tiles)

    # for each category find the similar grid-ids that are already in the training list. Recalculate the number of
    # samples that need to be taken. Sample unique grid-ids for the category and add it to the list of training grids.
    for category in categories:
        if category in ("unknown", initial_category):
            continue
        similar_grids = set(cat_grid[category]).intersection(training_grids)
        # Grids shared with earlier categories can already exceed the quota of this one;
        # in that case nothing more is sampled (a negative sample size would raise).
        num_samples_to_take = max(0, math.ceil(len(cat_grid[category]) * split_threshold) - len(similar_grids))
        allowable_grids = list(np.setdiff1d(cat_grid[category], training_grids))
        training_grids.extend(random.sample(allowable_grids, num_samples_to_take))

    # make sure that training folder contains correct number of grids as decided by the split threshold.
    # To do that we add or drop grids from the category with the max number of grids.
    biggest_category = max(cat_grid, key=lambda cat: len(cat_grid[cat]))
    if len(training_grids) < num_train_girds:
        num_extra_samples = num_train_girds - len(training_grids)
        allowable_other_grids = list(np.setdiff1d(cat_grid[biggest_category], training_grids))
        training_grids.extend(random.sample(allowable_other_grids, num_extra_samples))
    else:
        num_samples_to_drop = len(training_grids) - num_train_girds
        # random.sample() needs a sequence (sets are rejected since Python 3.11); sorting also
        # makes the draw reproducible for a given seed.
        droppable_candidates = sorted(set(cat_grid[biggest_category]).intersection(training_grids))
        grids_to_drop = set(random.sample(droppable_candidates, num_samples_to_drop))
        # Build a new list instead of removing items from the list being iterated,
        # which would skip the element following every removal.
        training_grids = [grid for grid in training_grids if grid not in grids_to_drop]

    # add preceding zeros to the list of training grids
    training_grids = [str(item).zfill(6) for item in training_grids]
    val_grids = [str(item) for item in np.setdiff1d(lbl_ids, training_grids)]
    training_grids.sort()
    val_grids.sort()

    train_ids = set(training_grids)
    val_ids = set(val_grids)

    # Create proper folders for the splitted dataset
    lbl_train_out_path = lbl_dir / "train"
    lbl_val_out_path = lbl_dir / "validation"
    Path(lbl_train_out_path).mkdir(parents=True, exist_ok=True)
    Path(lbl_val_out_path).mkdir(parents=True, exist_ok=True)

    # Move labels into train, validate subfolders.
    for grid_id, fn in zip(lbl_ids, lbl_fnames):
        if grid_id not in fn.name:
            continue
        if grid_id in train_ids:
            shutil.move(str(fn), str(lbl_train_out_path))
        elif grid_id in val_ids:
            shutil.move(str(fn), str(lbl_val_out_path))

    # Move the img dataset based on the grid-ID to equivalent subfolders.
    for source in sources:

        src_dir = Path(root_dir) / country / source
        src_train_out_path = src_dir / "train"
        src_val_out_path = src_dir / "validation"

        Path(src_train_out_path).mkdir(parents=True, exist_ok=True)
        Path(src_val_out_path).mkdir(parents=True, exist_ok=True)

        src_files = [Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(src_dir)
                     for f in filenames]

        for grid_id in lbl_ids:
            for fn in src_files:
                if grid_id not in fn.name:
                    continue
                if grid_id in train_ids:
                    shutil.move(str(fn), str(src_train_out_path))
                elif grid_id in val_ids:
                    shutil.move(str(fn), str(src_val_out_path))

In [None]:
# Arguments for the create_grid_splits() call in the next cell.
# NOTE(review): root_dir is "D:/CropType" here, whereas Steps 5-6 use
# "D:/CropType/Ghana/Original_dataset" -- confirm the intended dataset location.
root_dir = "D:/CropType"
country = "Ghana"
sources = ["S1_npy", "S2_npy"]
lbl_fldrname = "Labels"
csv_fn = "report.csv"  # read from <root_dir>/<country>/<lbl_fldrname>
split_threshold = 0.61

In [None]:
# Moves (does not copy) label and source files into "train" / "validation" subfolders.
create_grid_splits(root_dir, country, sources, lbl_fldrname, csv_fn, split_threshold)

## Step 10. Make pixel dataset

In [None]:
def reclass_cloudmask_stack(cloud_stack):
    """
     Collapse a multi-class cloud mask into a binary cloud / clear mask.

     Mapping: clear = 0 --> 0, clouds = 1 --> 1, shadows = 2 --> 1, haze = 3 --> 1.
     Any other value is left at 0.

     cloud_stack (numpy ndarray) -- cloud mask holding the class codes above.

     output: new numpy nd array with the same shape and dtype as the input.
    """
    binary_stack = np.zeros_like(cloud_stack)

    # Every non-clear class code is flagged as cloudy; clear pixels keep the initial 0.
    for cloudy_code in (1, 2, 3):
        binary_stack[cloud_stack == cloudy_code] = 1

    return binary_stack

############################################################

def load_data(dataPath, isLabel = False):
    """Read one dataset file from disk into memory.
    Args:
        dataPath (str or Path) -- Path to the file. It is opened with rasterio when isLabel is True,
                                  otherwise it is read with np.load.
        isLabel (binary) -- Set to True when the file is a label raster. Default is False.

    Returns:
        loaded data as numpy ndarray. For a label this is band 1 of the raster.

    Raises:
        ValueError -- if a label raster does not contain exactly one band.
    """

    # Image data is stored as numpy files and needs no further checks.
    if not isLabel:
        return np.load(dataPath)

    with rasterio.open(dataPath, "r") as label_src:
        band_count = label_src.count

        # A label raster is expected to carry a single class band.
        if band_count != 1:
            raise ValueError("Label must have only 1 band but {} bands were detected.".format(band_count))

        return label_src.read(1)

############################################################


def _list_data_files(base_dir, suffix, must_contain = None, skip_dirs = ()):
    """Recursively list files under base_dir that end with suffix, sorted by path.

    base_dir (str or Path) -- folder to walk.
    suffix (str) -- required file name ending, e.g. ".npy".
    must_contain (str) -- optional substring the file name has to contain.
    skip_dirs (iterable of str) -- names of folders directly under base_dir that are not searched.
    """
    skip_dirs = set(skip_dirs)
    found = []

    for dirpath, dirnames, filenames in os.walk(base_dir):
        if skip_dirs and Path(dirpath) == Path(base_dir):
            # Prune in place so os.walk does not descend into the excluded folders.
            dirnames[:] = [d for d in dirnames if d not in skip_dirs]

        for f in filenames:
            if f.endswith(suffix) and (must_contain is None or must_contain in f):
                found.append(Path(dirpath) / f)

    found.sort()
    return found


def _grid_id_from_fname(fname):
    """Return the grid ID, i.e. the text after the last underscore of the file name (extension dropped)."""
    return Path(fname).stem.split("_")[-1]


def Make_pixel_dataset(root_dir, country, sources, lbl_fldrname, categories, usage, verbose,
                       max_unknown_samples = 60):
    """
    Cut the grid-level Sentinel-1 and Sentinel-2 arrays into per-pixel samples grouped by label class.

    For every grid, the label raster, the Sentinel-1 array, the Sentinel-2 "source" array and the
    Sentinel-2 "cloudmask" array are loaded. For each label value found in the grid, the pixels with
    that value are ranked from least to most cloudy and, for each pixel, the slice [:, row, col, :]
    of the Sentinel-1 and Sentinel-2 arrays is saved as
    "<root_dir>/<country>/Sentinel-{1,2}/<usage>/<category>/s{1,2}_<grid id>_sample_<rank>_lbl_<value>.npy".
    Pixels labelled 0 are randomly subsampled to at most max_unknown_samples per grid.

    root_dir (str) -- path to the main directory which data resides.
    country (str) -- country folder under root_dir.
    sources (list) -- not used by this function (the "Sentinel-1" and "Sentinel-2" folder names are
                      fixed below); kept so existing calls keep working.
    lbl_fldrname (str) -- name of the label folder under the country folder.
    categories (dict) -- maps category name (used as output folder name) to its label value.
    usage (str) -- subfolder of the label / image folders to process, e.g. "validation".
    verbose (Binary) -- If set to True, prints the number of samples per grid and category.
    max_unknown_samples (int) -- cap on the number of label-0 pixels sampled per grid. Default is 60.

    Expected array layout (enforced by the assertions below): the label is 2-D (rows, cols); the
    Sentinel-1/2 arrays are 4-D with rows and cols on axes 1 and 2; the cloud mask is 3-D with rows
    and cols on axes 0 and 1. Axis 2 of the cloud mask is presumably time -- TODO confirm.

    Raises KeyError if a grid contains a label value that has no entry in categories.
    """

    lbl_dir = Path(root_dir) / country / lbl_fldrname / usage
    s1_src_path = Path(root_dir) / country / "Sentinel-1" / usage
    s2_src_path = Path(root_dir) / country / "Sentinel-2" / usage

    # The per-pixel samples are written into <usage>/<category> folders inside the very folders
    # that are walked for inputs. Those folders are excluded so that re-running the function does
    # not mistake previously generated samples for grid-level inputs.
    category_dirs = set(categories.keys())

    lbl_fnames = _list_data_files(lbl_dir, ".tif")
    s1_fnames = _list_data_files(s1_src_path, ".npy", skip_dirs = category_dirs)
    s2_fnames = _list_data_files(s2_src_path, ".npy", must_contain = "source", skip_dirs = category_dirs)
    cmasks = _list_data_files(s2_src_path, ".npy", must_contain = "cloudmask", skip_dirs = category_dirs)

    inv_category_map = {v: k for k, v in categories.items()}

    for cat in categories:
        (s1_src_path / cat).mkdir(parents=True, exist_ok=True)
        (s2_src_path / cat).mkdir(parents=True, exist_ok=True)

    # The four sorted lists are paired by position; the grid-ID assertion guards the pairing.
    for lbl_fn, s1_fn, s2_fn, cmask_fn in tqdm.tqdm(zip(lbl_fnames, s1_fnames, s2_fnames, cmasks), total = len(lbl_fnames)):

        lbl_grid_id = _grid_id_from_fname(lbl_fn)
        s1_grid_id = _grid_id_from_fname(s1_fn)
        s2_grid_id = _grid_id_from_fname(s2_fn)
        cmask_grid_id = _grid_id_from_fname(cmask_fn)
        assert lbl_grid_id == s1_grid_id == s2_grid_id == cmask_grid_id, "problematic grid id: {}".format(lbl_grid_id)

        lbl_array = load_data(lbl_fn, isLabel = True)
        s1_array = load_data(s1_fn, isLabel = False)
        s2_array = load_data(s2_fn, isLabel = False)
        # Replace +inf / -inf in the Sentinel-2 data with 0.
        s2_array[np.isinf(s2_array)] = 0
        cmask_array = load_data(cmask_fn, isLabel = False)
        binary_cloud_array = reclass_cloudmask_stack(cmask_array)

        assert lbl_array.shape[0] == s1_array.shape[1] == s2_array.shape[1] == cmask_array.shape[0], "problematic grid id: {}".format(lbl_grid_id)
        assert lbl_array.shape[1] == s1_array.shape[2] == s2_array.shape[2] == cmask_array.shape[1], "problematic grid id: {}".format(lbl_grid_id)

        # Per-pixel share of cloudy observations. The sum runs over axis 2, so the count is
        # normalised by the length of that same axis (not by the number of rows).
        cloudy_days = np.sum(binary_cloud_array, axis=2)
        cloudy_days = cloudy_days / cmask_array.shape[2]

        for val in np.unique(lbl_array):

            if val not in inv_category_map:
                raise KeyError("Label value {} in grid {} has no entry in categories".format(val, lbl_grid_id))
            category_name = inv_category_map[val]

            rows, cols = np.where(lbl_array == val)
            crop_coordinates = list(zip(rows, cols))

            if val == 0:
                # Label 0 is subsampled so it does not dominate the pixel dataset.
                num_samples = min(max_unknown_samples, len(crop_coordinates))
                crop_coordinates = random.sample(crop_coordinates, num_samples)

            # Rank the pixels from least to most cloudy; the rank becomes the sample index.
            cr_ls = [cloudy_days[coord[0], coord[1]] for coord in crop_coordinates]
            df = pd.DataFrame(zip(crop_coordinates, cr_ls), columns=['Coordinates','Cloudiness'])
            df = df.sort_values("Cloudiness")
            ranked_crop_coordinates = list(df.Coordinates)

            if verbose:
                print("Grid ID: {}, number of {} samples: {}".format(lbl_grid_id, category_name, len(crop_coordinates)))

            s1_out_path = s1_src_path / category_name
            s2_out_path = s2_src_path / category_name

            for index, coord in enumerate(ranked_crop_coordinates, start=1):
                lbl_val = lbl_array[coord[0], coord[1]]

                # np.save appends the ".npy" extension to the file name.
                s1_val = s1_array[:,coord[0], coord[1],:]
                s1_out_fname = "s1_"+s1_grid_id+"_sample_"+str(index)+"_lbl_"+str(lbl_val)
                np.save(s1_out_path / s1_out_fname, s1_val)

                s2_val = s2_array[:,coord[0], coord[1],:]
                s2_out_fname = "s2_"+s2_grid_id+"_sample_"+str(index)+"_lbl_"+str(lbl_val)
                np.save(s2_out_path / s2_out_fname, s2_val)

In [None]:
# Arguments for the Make_pixel_dataset call in the next cell (passed positionally).
# Dataset root. This is an absolute local path; adjust it for the machine running the notebook.
root_dir = "D:/CropType"
#root_dir = "C:/My_documents/CropTypeData_Rustowicz/CropType"
# Country folder located directly under root_dir.
country = "Ghana"
# NOTE(review): Make_pixel_dataset accepts this list but does not read it; the "Sentinel-1" and
# "Sentinel-2" folder names are hard-coded inside the function.
sources = ["Sentinel-1", "Sentinel-2"]
# Label folder under root_dir/country.
lbl_fldrname = "Labels"
# Subfolder of the label and image folders that is processed in this run.
usage = "validation"
#categories = {"unknown": 0, "maize": 1, "rice": 2, "other_crop": 3, "non_crop": 4}
# Category name -> label value. The names are used as output subfolder names; every label value
# present in the rasters needs an entry here, otherwise the reverse lookup raises KeyError.
categories = {"unknown": 0, "maize": 1, "rice": 2, "other_crop": 3}
# When True, the number of samples per grid and category is printed.
verbose = False

In [None]:
# Writes one .npy file per sampled pixel into <usage>/<category> subfolders of the
# Sentinel-1 and Sentinel-2 folders.
Make_pixel_dataset(root_dir, country, sources, lbl_fldrname, categories, usage, verbose)

## Optional

### Make temporal cloud mask for non-crop sampling

In [None]:
def reclass_cloudmask_stack(cloud_stack):
    """
     Collapse a multi-class cloud mask into a binary cloud / clear mask.

     Mapping: clear = 0 --> 0, clouds = 1 --> 1, shadows = 2 --> 1, haze = 3 --> 1.
     Any other value is left at 0.

     NOTE(review): the "Make pixel dataset" section of this notebook already defines a function
     with this name and the same logic; this cell rebinds the name.

     cloud_stack (numpy ndarray) -- cloud mask holding the class codes above.

     output: new numpy nd array with the same shape and dtype as the input.
    """
    binary_stack = np.zeros_like(cloud_stack)

    # Every non-clear class code is flagged as cloudy; clear pixels keep the initial 0.
    for cloudy_code in (1, 2, 3):
        binary_stack[cloud_stack == cloudy_code] = 1

    return binary_stack


def Make_temporal_cloud_mask(root_dir, source, cday_threshold):
    """
    Write one binary "mostly clear" raster per grid, derived from the grid's cloud mask stack.

    The "source" and "cloudmask" .npy files found under root_dir/source are paired by sorted
    position. For each pair the cloud mask is reclassified to cloud / clear, summed over axes 0
    and 3, divided by the length of axis 0 and compared with cday_threshold: pixels below the
    threshold get 1, all others 0. The result is written as a single-band GeoTIFF named
    "temporal_cloud_mask_22days_<grid id>.tif" into root_dir/temporal_cloud_mask.

    root_dir (str) -- directory that contains the source folder and receives the output folder.
    source (str) -- name of the folder under root_dir that holds the .npy files.
    cday_threshold (float) -- upper bound (exclusive) on the cloudiness value of a pixel marked 1.

    NOTE(review): a later cell of this notebook defines another Make_temporal_cloud_mask with the
    same signature; when the notebook is run top to bottom, that later definition is the one in
    effect for the final call.
    NOTE(review): summing over axes (0, 3) requires a 4-D cloud mask, whereas the later definition
    sums a 3-D mask over axis 2 -- confirm which layout the data actually has.
    """

    src_path = Path(root_dir) / source
    src_fnames = [Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(src_path) for f in filenames if f.endswith(".npy") if "source" in f]
    cmasks = [Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(src_path) for f in filenames if f.endswith(".npy") if "cloudmask" in f]
    src_fnames.sort()
    cmasks.sort()

    outdir = Path(root_dir) / "temporal_cloud_mask"
    Path(outdir).mkdir(parents=True, exist_ok=True)

    for src_fn, cmask_fn in tqdm.tqdm(zip(src_fnames, cmasks), total = len(src_fnames)):

        # The grid ID is the text after the last underscore, without the extension.
        src_grid_id = str(src_fn).split("_")[-1].replace(".npy", "")
        cmask_grid_id = str(cmask_fn).split("_")[-1].replace(".npy", "")
        # The messages report src_grid_id: it is the only grid ID defined in this function.
        assert src_grid_id == cmask_grid_id, "problematic grid id: {}".format(src_grid_id)

        src_array = np.load(src_fn)
        cmask_array = np.load(cmask_fn)
        binary_cloud_array = reclass_cloudmask_stack(cmask_array)

        # NOTE(review): the second check compares axis 2 of the source with axis 1 of the mask;
        # it may have been meant to use cmask_array.shape[2] -- confirm against the data layout.
        assert src_array.shape[1] == cmask_array.shape[1], "problematic grid id: {}".format(src_grid_id)
        assert src_array.shape[2] == cmask_array.shape[1], "problematic grid id: {}".format(src_grid_id)

        cloudy_days = np.sum(binary_cloud_array, axis=(0,3))
        cloudy_days = cloudy_days / cmask_array.shape[0]
        temporal_cloud_mask = np.where(cloudy_days < cday_threshold, 1, 0)

        # Output raster: fixed 64 x 64, one band, no CRS, identity transform.
        out_profile = {
            'driver': 'GTiff', 'dtype': 'int32', 'nodata': None, 'width': 64, 'height': 64, 'count': 1, 'crs': None,
            'transform': rasterio.Affine(1.0, 0.0, 0.0,0.0, 1.0, 0.0), 'tiled': False, 'interleave': 'band'
                       }

        out_name = "temporal_cloud_mask_22days_{}.tif".format(src_grid_id)
        with rasterio.open(Path(outdir) / out_name, "w", **out_profile) as dst:
            # rasterio expects (bands, rows, cols); add the band axis.
            dst.write(np.expand_dims(temporal_cloud_mask, 0))

    print("SEOO")

In [None]:
def reclass_cloudmask_stack(cloud_stack):
    """
     Collapse a multi-class cloud mask into a binary cloud / clear mask.

     Mapping: clear = 0 --> 0, clouds = 1 --> 1, shadows = 2 --> 1, haze = 3 --> 1.
     Any other value is left at 0.

     NOTE(review): earlier cells of this notebook already define a function with this name and
     the same logic; this cell rebinds the name once more.

     cloud_stack (numpy ndarray) -- cloud mask holding the class codes above.

     output: new numpy nd array with the same shape and dtype as the input.
    """
    binary_stack = np.zeros_like(cloud_stack)

    # Every non-clear class code is flagged as cloudy; clear pixels keep the initial 0.
    for cloudy_code in (1, 2, 3):
        binary_stack[cloud_stack == cloudy_code] = 1

    return binary_stack


def Make_temporal_cloud_mask(root_dir, source, cday_threshold):
    """
    Write one binary "mostly clear" raster per grid, derived from the grid's cloud mask stack.

    Every "cloudmask" .npy file found under root_dir/source is reclassified to cloud / clear and
    reduced over axis 2 to the share of cloudy observations per pixel. Pixels whose share is
    below cday_threshold get 1, all others 0. The result is written as a single-band int32
    GeoTIFF named "temporal_cloud_mask_<cday_threshold>_<grid id>.tif" into
    root_dir/temporal_cloud_mask.

    root_dir (str) -- directory that contains the source folder and receives the output folder.
    source (str) -- name of the folder under root_dir that holds the cloud mask .npy files.
    cday_threshold (float) -- upper bound (exclusive) on the cloudy share of a pixel marked 1.

    The cloud mask is expected to be 3-D with rows and cols on axes 0 and 1; axis 2 is the
    axis that is summed (presumably time -- TODO confirm).
    """

    src_path = Path(root_dir) / source
    cmasks = sorted(
        Path(dirpath) / f
        for (dirpath, dirnames, filenames) in os.walk(src_path)
        for f in filenames
        if f.endswith(".npy") and "cloudmask" in f
    )

    outdir = Path(root_dir) / "temporal_cloud_mask"
    outdir.mkdir(parents=True, exist_ok=True)

    for cmask_fn in tqdm.tqdm(cmasks):

        # The grid ID is the text after the last underscore of the file name, without extension.
        cmask_grid_id = cmask_fn.stem.split("_")[-1]
        cmask_array = np.load(cmask_fn)
        binary_cloud_array = reclass_cloudmask_stack(cmask_array)

        # The count is taken over axis 2, so it is normalised by the length of that same axis.
        # Dividing by shape[0] (the number of rows) only gives a share when both lengths match.
        cloudy_days = np.sum(binary_cloud_array, axis=2)
        cloudy_days = cloudy_days / cmask_array.shape[2]
        # Cast explicitly: the default integer type of np.where is platform dependent, while the
        # raster profile below declares int32.
        temporal_cloud_mask = np.where(cloudy_days < cday_threshold, 1, 0).astype("int32")

        # Size the raster from the data instead of assuming 64 x 64 grids.
        height, width = temporal_cloud_mask.shape
        out_profile = {
            'driver': 'GTiff', 'dtype': 'int32', 'nodata': None, 'width': width, 'height': height, 'count': 1, 'crs': None,
            'transform': rasterio.Affine(1.0, 0.0, 0.0,0.0, 1.0, 0.0), 'tiled': False, 'interleave': 'band'
                       }

        # The threshold used is part of the file name (0.5 yields the former fixed name).
        out_name = "temporal_cloud_mask_{}_{}.tif".format(cday_threshold, cmask_grid_id)
        with rasterio.open(outdir / out_name, "w", **out_profile) as dst:
            # rasterio expects (bands, rows, cols); add the band axis.
            dst.write(np.expand_dims(temporal_cloud_mask, 0))

    print("SEOO")

In [None]:
# Arguments for the Make_temporal_cloud_mask call in the next cell.
# Directory that holds the source folder; the "temporal_cloud_mask" output folder is created here.
# This is an absolute local path; adjust it for the machine running the notebook.
root_dir = "D:/CropType/Ghana/Original_dataset/Ghana"
# Folder under root_dir that is searched for "cloudmask" .npy files.
source = "S2_npy"
# Pixels whose cloudiness value is strictly below this threshold are written as 1, others as 0.
cday_threshold = 0.5

In [None]:
# Writes one single-band GeoTIFF per cloud mask file into root_dir/temporal_cloud_mask.
Make_temporal_cloud_mask(root_dir, source, cday_threshold)