### 1. Convert to GeoTIFF, reproject, and clip

In [4]:
import os
from osgeo import gdal
# Reproject every high-resolution trait map to EPSG:32611 and store it as GeoTIFF.
data_path = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/1_HR_trait_maps/"
out_path = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/2_reprojection_trait_maps/"
filenames = os.listdir(data_path)
# Skip sidecar files (.hdr headers and .aux.xml metadata); only the raster payloads are warped.
filenames = [x for x in filenames if ".aux.xml" not in x and ".hdr" not in x]

for file in filenames:
    in_file = f"{data_path}{file}"
    out_file = f"{out_path}{file}.tif"

    input_ds = gdal.Open(in_file)
    if input_ds is None:
        # Fail with the offending path instead of handing None to gdal.Warp.
        raise RuntimeError(f"GDAL could not open raster: {in_file}")
    output_ds = gdal.Warp(out_file, input_ds, dstSRS='EPSG:32611', format='GTiff')

    # GDAL only finishes writing a dataset when its last Python reference goes away.
    # Release both handles explicitly so every output (including the last one in the
    # loop) is flushed to disk before the next cell reads it.
    output_ds = None
    input_ds = None
    

# in_file = "/130TB_raid0/fujiang/SmallSat_part2/1_imagery_data/SHIFT_RFL_20220420_vnir.tif"
# out_file = "/130TB_raid0/fujiang/SmallSat_part2/1_imagery_data/SHIFT_RFL_20220420_vnir_reprojection.tif"

# input_ds = gdal.Open(in_file)
# output_ds = gdal.Warp(out_file, input_ds, dstSRS='EPSG:32611', format='GTiff')

In [6]:
import os
from osgeo import gdal,gdalconst
import geopandas as gpd
import warnings
warnings.filterwarnings("ignore")

# Clip every reprojected trait map to a common window and resample it to a 5 x 5 grid.
data_path = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/2_reprojection_trait_maps/"
out_path = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/3_clipped_trait_maps/"
# Window corners, passed to GDAL below in (min x, min y, max x, max y) order.
# presumably EPSG:32611 coordinates, since the inputs come from the reprojection step - TODO confirm
ul_x_new, lr_y_new, lr_x_new, ul_y_new = 174714.77272756316, 3816019.5146135557, 193347.5743273191, 3843232.8023862587

filenames = [x for x in os.listdir(data_path) if ".aux.xml" not in x]
for file in filenames:
    in_file = data_path + file
    # file[:-4] strips the four-character ".tif" suffix before "_clipped.tif" is appended.
    out_file = out_path + file[:-4] + "_clipped.tif"
    gdal.Warp(
        out_file,
        in_file,
        format='GTiff',
        outputBounds=(ul_x_new, lr_y_new, lr_x_new, ul_y_new),
        xRes=5,
        yRes=5,
        resampleAlg=gdalconst.GRA_Bilinear,
        cropToCutline=True,
    )
  
# in_file = "/130TB_raid0/fujiang/SmallSat_part2/1_imagery_data/SHIFT_RFL_20220420_vnir_reprojection.tif"
# out_file = "/130TB_raid0/fujiang/SmallSat_part2/1_imagery_data/SHIFT_RFL_20220420_vnir_reprojection_clipped.tif" 
# gdal.Warp(out_file, in_file, xRes=5, yRes=5, outputBounds=(ul_x_new, lr_y_new, lr_x_new, ul_y_new), cropToCutline=True, format='GTiff',resampleAlg=gdalconst.GRA_Bilinear)

### 2. Upscale trait maps

In [47]:
import os
import numpy as np
from osgeo import gdal,gdalconst
import matplotlib.pyplot as plt

def read_tif(tif_file):
    """Read a whole raster into memory.

    Parameters
    ----------
    tif_file : str
        Path handed to ``gdal.Open``.

    Returns
    -------
    tuple
        ``(im_data, im_Geotrans, im_proj, rows, cols)`` where ``im_data`` is the
        full raster as returned by ``ReadAsArray``; when that array is 3-D its
        first axis is moved to the end, so the result is ``(rows, cols, bands)``.
        A 2-D result is returned unchanged. ``im_Geotrans`` and ``im_proj`` are
        the dataset's geotransform and projection.

    Raises
    ------
    RuntimeError
        If GDAL returns no dataset for ``tif_file``.
    """
    dataset = gdal.Open(tif_file)
    if dataset is None:
        raise RuntimeError(f"GDAL could not open raster: {tif_file}")

    cols = dataset.RasterXSize
    rows = dataset.RasterYSize
    im_proj = dataset.GetProjection()
    im_Geotrans = dataset.GetGeoTransform()

    # Read the pixels once and reorder the array already in memory,
    # rather than pulling a second copy from disk for the axis move.
    im_data = dataset.ReadAsArray(0, 0, cols, rows)
    if im_data.ndim == 3:
        im_data = np.moveaxis(im_data, 0, -1)
    return im_data, im_Geotrans, im_proj, rows, cols

def array_to_geotiff(array, output_path, geo_transform, projection, band_names=None):
    """Write an array to a Float32 GeoTIFF.

    Parameters
    ----------
    array : array-like
        ``(rows, cols, bands)`` data. A 2-D ``(rows, cols)`` array is accepted
        and written as a single band.
    output_path : str
        Destination file passed to the GTiff driver.
    geo_transform, projection
        Passed unchanged to ``SetGeoTransform`` / ``SetProjection``.
    band_names : sequence of str, optional
        Band descriptions, one per band in band order. Extra names are ignored.

    Raises
    ------
    ValueError
        If ``array`` is not 2-D or 3-D, or fewer names than bands are given.
    RuntimeError
        If the driver cannot create ``output_path``.
    """
    array = np.asarray(array)
    if array.ndim == 2:
        # Promote a single-band image to (rows, cols, 1).
        array = array[:, :, np.newaxis]
    if array.ndim != 3:
        raise ValueError(f"expected a 2-D or 3-D array, got {array.ndim} dimension(s)")
    rows, cols, num_bands = array.shape

    # Validate before touching the disk so a bad call never leaves a half-written file.
    if band_names and len(band_names) < num_bands:
        raise ValueError(f"got {len(band_names)} band name(s) for {num_bands} band(s)")

    driver = gdal.GetDriverByName('GTiff')
    dataset = driver.Create(output_path, cols, rows, num_bands, gdal.GDT_Float32)
    if dataset is None:
        raise RuntimeError(f"GDAL could not create raster: {output_path}")

    dataset.SetGeoTransform(geo_transform)
    dataset.SetProjection(projection)

    band = None
    for band_num in range(num_bands):
        band = dataset.GetRasterBand(band_num + 1)  # GDAL band numbers start at 1
        band.WriteArray(array[:, :, band_num])
        if band_names:
            band.SetDescription(band_names[band_num])
        band.FlushCache()

    # Drop the band handle first, then the dataset, so the file is closed and flushed.
    band = None
    dataset = None
    return

# Build a vegetation mask from the PlanetScope image: 1 where a pixel counts as
# vegetated, NaN everywhere else, with a trailing band axis for broadcasting.
planet = "/130TB_RAID0/fujiang/SmallSat_part2/1_imagery_data/PlanetScope_RFL_20230422_clipped_resampled.tif"
planet_data, planet_Geotrans, planet_proj,_, _ = read_tif(planet)

# 0-based bands 5 and 7; these are labelled 665 nm and 865 nm where this same
# image is sampled later in the notebook.
red_band = planet_data[:, :, 5]
nir_band = planet_data[:, :, 7]
ndvi = (nir_band - red_band) / (nir_band + red_band)

# A pixel is kept only when both tests pass: NDVI above 0.3 and NIR above 0.1.
condition1 = ndvi > 0.3
condition2 = nir_band > 0.1
mask1 = np.where(condition1, 1, np.nan)
mask2 = np.where(condition2, 1, np.nan)

# 1 * 1 stays 1; any NaN factor makes the product NaN.
mask = (mask1 * mask2)[..., np.newaxis]

In [51]:
# Aggregate every clipped trait map into 12 x 12 pixel blocks (NaN-aware mean) and
# write the result with a 60-unit pixel size. The clip step produces a 5-unit
# grid, so 12 input pixels span one output pixel.
data_path = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/3_clipped_trait_maps/"
out_path = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/4_upscaling_trait_maps/"

filenames = os.listdir(data_path)
filenames = [x for x in filenames if ".aux.xml" not in x]

block_size = 12  # input pixels per output pixel along each axis

for idx, file in enumerate(filenames):
    trait_file = f"{data_path}{file}"
    trait_array, trait_Geotrans, trait_proj, rows, cols = read_tif(trait_file)
    # The two isotope maps are only cut below -100; every other trait must be >= 0.
    if file in ('20220420_d15N_clipped.tif', '20220420_d13C_clipped.tif'):
        trait_array[trait_array<-100] = np.nan
    else:
        trait_array[trait_array<0] = np.nan

    # mask is 1 on vegetated pixels and NaN elsewhere, so non-vegetated pixels become NaN.
    masked_trait_map = mask * trait_array

    # Ceiling division: one output cell per (possibly partial) block. The previous
    # "(n // 12) + 1" added an all-zero row/column whenever n was a multiple of 12.
    out_rows = -(-rows // block_size)
    out_cols = -(-cols // block_size)

    aggregated_traits = np.zeros(shape = (out_rows, out_cols, masked_trait_map.shape[2]))
    for ii in range(0, rows, block_size):
        for jj in range(0, cols, block_size):
            # Slices are clamped at the array edge, so edge blocks are simply smaller.
            data_array = masked_trait_map[ii: ii + block_size, jj: jj + block_size, :]
            aggregated_traits[ii // block_size, jj // block_size, :] = np.nanmean(data_array, axis = (0,1))

    output_path = f"{out_path}/{file[:-4]}_60m.tif"
    mask_path = f"{out_path}/0_20220420_vegetation_fraction.tif"

    geo_transform = (trait_Geotrans[0], 60.0, trait_Geotrans[2],
                     trait_Geotrans[3],trait_Geotrans[4],-60.0)

    array_to_geotiff(aggregated_traits, output_path, geo_transform, trait_proj, band_names = ["mean", "uncertainty"])

    if idx ==0:
        # The vegetation fraction depends only on the mask, so it is computed and
        # written once (with the first file's grid) instead of once per trait map.
        veg_fraction = np.zeros(shape = (out_rows, out_cols, mask.shape[2]))
        for ii in range(0, rows, block_size):
            for jj in range(0, cols, block_size):
                data_array2 = mask[ii: ii + block_size, jj: jj + block_size, :]
                one_counts = np.count_nonzero(data_array2 == 1)
                # Always divided by the full block (144 pixels), so partial edge
                # blocks report the fraction of a complete output cell.
                veg_fraction[ii // block_size, jj // block_size, :] = one_counts / (block_size * block_size)
        array_to_geotiff(veg_fraction, mask_path, geo_transform, trait_proj, band_names = ["mask"])

### 3. Extract training samples

In [1]:
import os
from osgeo import gdal,gdalconst
import geopandas as gpd
import numpy as np 
import pandas as pd
import warnings
warnings.filterwarnings("ignore")

def raster_to_points(geotiff, shp_name):
    """Convert a raster into a point shapefile carrying the raster's projection.

    The raster is exported with GDAL's XYZ driver (header line enabled), the
    text file is renamed to ``.csv``, ``ogr2ogr`` turns it into
    ``<shp_name>.shp``, and the shapefile is rewritten with the raster's
    projection assigned as its CRS.

    Parameters
    ----------
    geotiff : str
        Path of the input raster.
    shp_name : str
        Output path without extension; ``.xyz``, ``.csv`` and ``.shp`` are appended.

    Raises
    ------
    RuntimeError
        If GDAL cannot open ``geotiff``.
    subprocess.CalledProcessError
        If ``ogr2ogr`` exits with a non-zero status.
    """
    import subprocess  # local import: keeps this cell self-contained

    inDs = gdal.Open(geotiff)
    if inDs is None:
        raise RuntimeError(f"GDAL could not open raster: {geotiff}")
    crs_wkt = inDs.GetProjection()

    outDs = gdal.Translate(f"{shp_name}.xyz", inDs, format='XYZ', creationOptions=["ADD_HEADER_LINE=YES"])
    # Release the handles so the XYZ file is completely written before it is renamed.
    # (The original stored the handle under a different name, so it was never released.)
    outDs = None
    inDs = None

    # os.replace overwrites a leftover .csv from an earlier run.
    os.replace(f'{shp_name}.xyz', f'{shp_name}.csv')
    try:
        # Argument list instead of a shell string: no quoting problems with paths and
        # no shell globbing of "X*"; check=True stops here if the conversion fails
        # instead of silently continuing with a stale shapefile.
        subprocess.run(
            [
                "ogr2ogr", "-f", "ESRI Shapefile",
                "-oo", "X_POSSIBLE_NAMES=X*",
                "-oo", "Y_POSSIBLE_NAMES=Y*",
                "-oo", "KEEP_GEOM_COLUMNS=NO",
                f"{shp_name}.shp", f"{shp_name}.csv",
            ],
            check=True,
        )
    finally:
        # Best-effort clean-up of the intermediate text file.
        try:
            os.remove(f'{shp_name}.csv')
        except OSError:
            pass

    shp_layer = gpd.read_file(f"{shp_name}.shp")
    shp_layer.crs = crs_wkt
    shp_layer.to_file(f"{shp_name}.shp")
    return

def filter_points(gdf, min_distance):
    """Greedily thin points so kept points are at least ``min_distance`` apart.

    Rows are visited in order. Every row that has not been discarded is kept,
    and all other rows closer than ``min_distance`` to it are discarded, so the
    result depends on row order. The input frame is not modified.

    Parameters
    ----------
    gdf : GeoDataFrame
        Points to thin; ``gdf.geometry.distance`` is used for the distances.
    min_distance : float
        Rows strictly closer than this to a kept row are removed. With a value
        <= 0 nothing is removed.

    Returns
    -------
    GeoDataFrame
        The kept rows with a fresh 0..n-1 index.
    """
    to_remove = set()  # set: O(1) membership test and no duplicate labels

    for index, row in gdf.iterrows():
        if index in to_remove:
            continue
        distances = gdf.geometry.distance(row.geometry)
        close_points = distances[distances < min_distance].index
        # Skip the point itself by comparison rather than list.remove(), which raised
        # ValueError whenever the point was not within its own threshold
        # (min_distance <= 0, or a NaN distance).
        to_remove.update(label for label in close_points if label != index)

    return gdf.drop(index=list(to_remove)).reset_index(drop=True)

In [None]:
# Convert the vegetation-fraction raster to a point shapefile.
data = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/4_upscaling_trait_maps/0_20220420_vegetation_fraction.tif"
out_path = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/5_extract_training_samples/"
raster_to_points(data, f"{out_path}1_extracted_points")

# Thin the points so that kept points are at least 100 (CRS units) apart, to limit
# spatial autocorrelation.
# NOTE(review): the previous comment said "60 distance" while the code passes 100 - confirm which is intended.
# Both shapefile paths are built from out_path so they cannot drift from the line above.
points = gpd.read_file(f"{out_path}1_extracted_points.shp")
filtered_points = filter_points(points, 100)
filtered_points.to_file(f"{out_path}2_extracted_points_filter.shp")

In [None]:
# Extract trait values and reflectance at every sample point and write one CSV per point set.
point_path = '/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/5_extract_training_samples/'
points_files = ["2_extracted_points_filter.shp", "2_extracted_points_filter_train.shp", "2_extracted_points_filter_test.shp"]
out_name = ["3_extracted_points.csv", "3_extracted_points_train.csv", "3_extracted_points_test.csv"]

HR_image_path = "/130TB_RAID0/fujiang/SmallSat_part2/1_imagery_data/1_fused_imagery/"
HR_image = "/130TB_RAID0/fujiang/SmallSat_part2/1_imagery_data/PlanetScope_RFL_20230422_clipped_resampled.tif"
LR_image = "/130TB_RAID0/fujiang/SmallSat_part2/1_imagery_data/EMIT_L2A_RFL_20230422_clipped_resampled.tif"

HR_traits = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/3_clipped_trait_maps/"
upscaled_traits = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/4_upscaling_trait_maps/"
out_path = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/5_extract_training_samples/"

filenames1 = os.listdir(HR_traits)
filenames1 = [x for x in filenames1 if ".aux.xml" not in x]

filenames2 = os.listdir(upscaled_traits)
filenames2 = [x for x in filenames2 if ".aux.xml" not in x and "vegetation_fraction" not in x]

traits = ['LWC_area', 'magnesium', 'd15N', 'phosphorus', 'Manganese_ppm', 'hemicellulose', 'cellulose', 'd13C', 'chl_area', 'phenolics_mg_g', 'LWC', 
          'sugars_mg_g', 'Phenolics_DS', 'Sugars_DS', 'potassium', 'sulfur', 'NSC_DS', 'nitrogen', 'lignin', 'LMA', 'Aluminum_ppm', '%C', 'calcium']

# Traits used for the validity filter; d15N and d13C are left out of it.
sub_traits = ['LWC_area', 'magnesium', 'phosphorus', 'Manganese_ppm', 'hemicellulose', 'cellulose', 'chl_area', 'phenolics_mg_g', 'LWC', 'sugars_mg_g',
              'Phenolics_DS', 'Sugars_DS', 'potassium', 'sulfur', 'NSC_DS', 'nitrogen', 'lignin', 'LMA', 'Aluminum_ppm', '%C', 'calcium']


def _sample_raster_at_points(tiff_ds, points, band_numbers=None):
    """Read the pixel under every point for the requested bands.

    Parameters
    ----------
    tiff_ds : gdal.Dataset
        Opened raster; only the origin and pixel-size terms of its geotransform
        are used (rotation terms are ignored).
    points : DataFrame
        Must have a "geometry" column whose items expose ``.x`` and ``.y``.
    band_numbers : iterable of int, optional
        1-based band numbers; all bands when omitted.

    Returns
    -------
    list of list
        One list per band, each holding one value per point in row order.
        Points outside the raster, or reads for which GDAL returns None,
        yield ``np.nan``.

    Raises
    ------
    RuntimeError
        If ``tiff_ds`` is None (the raster could not be opened).
    """
    if tiff_ds is None:
        raise RuntimeError("raster could not be opened by GDAL")

    # Fetch the geotransform, raster size and band handles once instead of per point.
    geo = tiff_ds.GetGeoTransform()
    x_origin, x_res, y_origin, y_res = geo[0], geo[1], geo[3], geo[5]
    n_cols, n_rows = tiff_ds.RasterXSize, tiff_ds.RasterYSize
    if band_numbers is None:
        band_numbers = range(1, tiff_ds.RasterCount + 1)
    bands = [tiff_ds.GetRasterBand(number) for number in band_numbers]

    sampled = [[] for _ in bands]
    for geom in points["geometry"]:
        # floor, not int(): truncation toward zero would map points up to one pixel
        # outside the left/top edge onto column/row 0.
        px = int(np.floor((geom.x - x_origin) / x_res))
        py = int(np.floor((geom.y - y_origin) / y_res))
        inside = 0 <= px < n_cols and 0 <= py < n_rows
        for values, band in zip(sampled, bands):
            arr = band.ReadAsArray(px, py, 1, 1) if inside else None
            values.append(np.nan if arr is None else arr[0][0])
    return sampled


# The list of fused images does not depend on the point set, so read it once.
filenames = os.listdir(HR_image_path)
filenames = [x for x in filenames if ".aux.xml" not in x]

for idx, points_file in enumerate(points_files):
    print(f"***********************{points_file}***********************")
    points = gpd.read_file(f"{point_path}{points_file}")
    # "Z" holds the value written by the raster-to-points step (vegetation fraction).
    points["vege_frac"] = points["Z"].astype(float)
    points = points.drop(columns=['Z']).reset_index(drop = True)

    # **********************************************Extract HR_traits**********************************************
    print("Start extract the high resolution trait values")
    for file in filenames1:
        tiff_ds = gdal.Open(f"{HR_traits}{file}")
        # Column name: file name without its first and last "_"-separated tokens.
        col_name = "HR_" + ("_").join(file.split("_")[1:-1])
        points[col_name] = _sample_raster_at_points(tiff_ds, points, band_numbers=[1])[0]

    # Drop points where any filtered trait is negative or missing.
    filter_cols = [f"HR_{x}" for x in sub_traits]
    invalid = ((points[filter_cols] < 0) | (points[filter_cols].isna())).any(axis=1)
    points = points[~invalid].reset_index(drop = True)

    # **********************************************Extract upscaled_traits**********************************************
    print("Start extract the upscaled trait values")
    for file in filenames2:
        tiff_ds = gdal.Open(f"{upscaled_traits}{file}")
        # Column name: file name without its first token and its last two tokens.
        col_name = "upscaled_" + ("_").join(file.split("_")[1:-2])
        points[col_name] = _sample_raster_at_points(tiff_ds, points, band_numbers=[1])[0]

    filter_cols = [f"upscaled_{x}" for x in sub_traits]
    invalid = ((points[filter_cols] < 0) | (points[filter_cols].isna())).any(axis=1)
    points = points[~invalid].reset_index(drop = True)

    # **********************************************Extract EMIT_reflectance**********************************************
    print("Start extract the reflectance of EMIT imagery")
    tiff_ds = gdal.Open(LR_image)
    extracted_values = pd.DataFrame(np.array(_sample_raster_at_points(tiff_ds, points))).T
    # 201 labels (400..2400 step 10); the assignment fails if the band count differs.
    extracted_values.columns = [f"EMIT_{x}_nm" for x in np.arange(400, 2401, 10)]
    points = pd.concat([points, extracted_values], axis = 1)

    # **********************************************Extract Planet_reflectance**********************************************
    print("Start extract the reflectance of PlanetScope imagery")
    tiff_ds = gdal.Open(HR_image)
    extracted_values = pd.DataFrame(np.array(_sample_raster_at_points(tiff_ds, points))).T
    extracted_values.columns = [f"Planet_{x}_nm" for x in [443, 490, 531, 565, 610, 665, 705, 865]]
    points = pd.concat([points, extracted_values], axis = 1)

    # **********************************************Extract fused_image_reflectance**********************************************
    print("Start extract the reflectance of fused images")
    fused_frames = []
    for file in filenames:
        model = file.split("_")[0]
        # Two file prefixes are renamed for the column labels.
        model = {"Fusion3": "MSHFNET", "Fusion5": "MSAHFNET"}.get(model, model)
        print(f"   -{model}")

        tiff_ds = gdal.Open(f"{HR_image_path}{file}")
        extracted_values = pd.DataFrame(np.array(_sample_raster_at_points(tiff_ds, points))).T
        extracted_values.columns = [f"{model}_model_{x}_nm" for x in np.arange(400, 2401, 10)]
        fused_frames.append(extracted_values)

    # With no fused images there is nothing to join; the row filter is skipped as well,
    # because with zero "_model_" columns it would discard every point.
    if fused_frames:
        fused_refl = pd.concat(fused_frames, axis = 1).reset_index(drop = True)
        points = pd.concat([points, fused_refl], axis =1)
        # Drop points for which every fused-model value is missing.
        points = points[~points.filter(like="_model_").isna().all(axis=1)]

    saved_points = points.drop(columns = ["geometry"])
    saved_points.to_csv(f'{out_path}{out_name[idx]}',index = False)