In [11]:
import os
from osgeo import gdal,gdalconst
import geopandas as gpd
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

def read_tif(tif_file):
    """Read a raster file into a NumPy array together with its georeferencing.

    Parameters
    ----------
    tif_file : str
        Path of the raster to open with GDAL.

    Returns
    -------
    im_data : numpy.ndarray
        Pixel values. A 2-D result from GDAL is returned as-is; a 3-D result
        is reordered from band-first to band-last, i.e. (rows, cols, bands).
    im_Geotrans : tuple
        Value returned by ``dataset.GetGeoTransform()``.
    im_proj : str
        Value returned by ``dataset.GetProjection()``.
    rows, cols : int
        Raster height (``RasterYSize``) and width (``RasterXSize``) in pixels.

    Raises
    ------
    RuntimeError
        If GDAL cannot open the file or cannot read its pixels.
    """
    dataset = gdal.Open(tif_file)
    if dataset is None:
        # With GDAL exceptions disabled, Open() reports failure by returning None.
        raise RuntimeError(f"GDAL could not open raster: {tif_file}")

    cols = dataset.RasterXSize
    rows = dataset.RasterYSize
    im_proj = dataset.GetProjection()
    im_Geotrans = dataset.GetGeoTransform()

    im_data = dataset.ReadAsArray(0, 0, cols, rows)
    if im_data is None:
        raise RuntimeError(f"GDAL could not read pixel data from: {tif_file}")

    if im_data.ndim == 3:
        # Reuse the array already in memory instead of reading the file again.
        im_data = np.moveaxis(im_data, 0, -1)
    return im_data, im_Geotrans, im_proj, rows, cols

def array_to_geotiff(array, output_path, geo_transform, projection, band_names=None):
    """Write a NumPy array to a Float32 GeoTIFF.

    Parameters
    ----------
    array : array-like
        Pixel values shaped (rows, cols, bands). A 2-D (rows, cols) array is
        accepted and written as a single band.
    output_path : str
        Path of the GeoTIFF to create.
    geo_transform : sequence
        Passed unchanged to ``dataset.SetGeoTransform``.
    projection : str
        Passed unchanged to ``dataset.SetProjection``.
    band_names : sequence of str, optional
        Band descriptions; entry ``i`` is applied to band ``i + 1``. Extra
        entries beyond the number of bands are ignored.

    Raises
    ------
    ValueError
        If ``array`` is not 2-D or 3-D, or ``band_names`` has fewer entries
        than there are bands.
    RuntimeError
        If GDAL cannot create the output file.
    """
    array = np.asarray(array)
    if array.ndim == 2:
        # Single-band input: add a trailing band axis so the loop below applies.
        array = array[:, :, np.newaxis]
    elif array.ndim != 3:
        raise ValueError(f"array must be 2-D or 3-D, got {array.ndim} dimension(s)")

    rows, cols, num_bands = array.shape
    if band_names and len(band_names) < num_bands:
        # Fail before touching the disk rather than midway through writing bands.
        raise ValueError(
            f"band_names has {len(band_names)} entries but the array has {num_bands} bands"
        )

    driver = gdal.GetDriverByName('GTiff')
    dataset = driver.Create(output_path, cols, rows, num_bands, gdal.GDT_Float32)
    if dataset is None:
        raise RuntimeError(f"GDAL could not create GeoTIFF: {output_path}")

    dataset.SetGeoTransform(geo_transform)
    dataset.SetProjection(projection)

    band = None
    for band_num in range(num_bands):
        band = dataset.GetRasterBand(band_num + 1)
        band.WriteArray(array[:, :, band_num])
        if band_names:
            band.SetDescription(band_names[band_num])
        band.FlushCache()

    dataset.FlushCache()
    # Drop the band reference before the dataset so the file is closed cleanly.
    band = None
    dataset = None
    return

In [6]:
## Extract EMIT values
# Samples every upscaled trait raster and the EMIT reflectance image at the
# centroid of each selected EMIT point, drops points with invalid trait
# values, and writes one CSV row per remaining point.
points_file = '/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/5_extract_training_samples/0_sparse_vagetated_area/2_selected_EMIT_points.shp'
out_points = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/5_extract_training_samples/0_sparse_vagetated_area/2_EMIT_points.csv"

upscaled_traits = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/4_upscaling_trait_maps/"
LR_image = "/130TB_RAID0/fujiang/SmallSat_part2/1_imagery_data/EMIT_L2A_RFL_20230422_clipped_resampled.tif"

traits = ['LWC_area', 'magnesium', 'd15N', 'phosphorus', 'Manganese_ppm', 'hemicellulose', 'cellulose', 'd13C', 'chl_area', 'phenolics_mg_g', 'LWC', 
          'sugars_mg_g', 'Phenolics_DS', 'Sugars_DS', 'potassium', 'sulfur', 'NSC_DS', 'nitrogen', 'lignin', 'LMA', 'Aluminum_ppm', '%C', 'calcium']

# Traits whose values must be present and non-negative for a point to be kept.
sub_traits = ['LWC_area', 'magnesium', 'phosphorus', 'Manganese_ppm', 'hemicellulose', 'cellulose', 'chl_area', 'phenolics_mg_g', 'LWC', 'sugars_mg_g',
              'Phenolics_DS', 'Sugars_DS', 'potassium', 'sulfur', 'NSC_DS', 'nitrogen', 'lignin', 'LMA', 'Aluminum_ppm', '%C', 'calcium']


def _sample_raster(dataset, xs, ys, band_numbers=None):
    """Sample `dataset` at map coordinates and return one list of values per band.

    Parameters
    ----------
    dataset : gdal.Dataset
        Open raster. Only geotransform terms 0, 1, 3 and 5 are used, so the
        raster is assumed to be north-up (no rotation) -- TODO confirm.
    xs, ys : iterable of float
        Map coordinates of the points to sample; assumed to be in the raster's
        coordinate system -- TODO confirm.
    band_numbers : list of int, optional
        1-based band numbers to read; defaults to every band of the raster.

    Returns
    -------
    list of list
        ``result[i][j]`` is the value of band ``band_numbers[i]`` at point
        ``j``, or ``np.nan`` when the point lies outside the raster or the
        read fails.
    """
    # Fetch the loop-invariant pieces once instead of once per point.
    origin_x, pixel_w, _, origin_y, _, pixel_h = dataset.GetGeoTransform()
    n_cols, n_rows = dataset.RasterXSize, dataset.RasterYSize
    if band_numbers is None:
        band_numbers = range(1, dataset.RasterCount + 1)
    bands = [dataset.GetRasterBand(number) for number in band_numbers]

    samples = [[] for _ in bands]
    for x, y in zip(xs, ys):
        # floor (not int truncation) so coordinates just left of / above the
        # raster map to index -1 and are rejected instead of reading pixel 0.
        px = int(np.floor((x - origin_x) / pixel_w))
        py = int(np.floor((y - origin_y) / pixel_h))
        inside = 0 <= px < n_cols and 0 <= py < n_rows
        for values, band in zip(samples, bands):
            arr = band.ReadAsArray(px, py, 1, 1) if inside else None
            values.append(np.nan if arr is None else arr[0][0])
    return samples


filenames2 = os.listdir(upscaled_traits)
filenames2 = [x for x in filenames2 if ".aux.xml" not in x and "vegetation_fraction" not in x]

points = gpd.read_file(points_file)
points['centroid'] = points.geometry.centroid

# ---------------------------- Extract upscaled_traits ----------------------------
print("Start extract the upscaled trait values")
centroid_x, centroid_y = points['centroid'].x, points['centroid'].y
for file in filenames2:
    in_tif = f"{upscaled_traits}{file}"
    tiff_ds = gdal.Open(in_tif)
    if tiff_ds is None:
        raise RuntimeError(f"GDAL could not open raster: {in_tif}")

    # Trait maps are sampled from band 1 only.
    extracted_values = _sample_raster(tiff_ds, centroid_x, centroid_y, band_numbers=[1])[0]

    # Column name = file name without its first and last two "_"-separated tokens.
    splits = file.split("_")[1:-2]
    aa = ("_").join(splits)
    col_name = f"upscaled_{aa}"
    points[col_name] = extracted_values

# Keep only points whose filter traits are all present and non-negative.
filter_cols = [f"upscaled_{x}" for x in sub_traits]
invalid = (points[filter_cols] < 0) | points[filter_cols].isna()
points = points[~invalid.any(axis=1)].reset_index(drop=True)

# ---------------------------- Extract EMIT_reflectance ----------------------------
print("Start extract the reflectance of EMIT imagery")
tiff_ds = gdal.Open(LR_image)
if tiff_ds is None:
    raise RuntimeError(f"GDAL could not open raster: {LR_image}")
num_bands = tiff_ds.RasterCount

extracted_values = _sample_raster(tiff_ds, points['centroid'].x, points['centroid'].y)
extracted_values = pd.DataFrame(np.array(extracted_values)).T
# Column labels assume 10 nm steps from 400 to 2400 nm; pandas raises if the
# band count does not match the number of labels.
extracted_values.columns = [f"EMIT_{x}_nm" for x in np.arange(400, 2401, 10)]
points = pd.concat([points, extracted_values], axis = 1)

points.insert(0, "X", points['centroid'].x, allow_duplicates=False)
points.insert(1, "Y", points['centroid'].y, allow_duplicates=False)
points = points.drop(columns = ["geometry", "centroid"])
points.to_csv(out_points,index = False)

Start extract the upscaled trait values
Start extract the reflectance of EMIT imagery


In [10]:
## Extract HR values
# Samples the high-resolution trait maps, the PlanetScope image and every
# fused image at each selected HR point, then writes the valid points to CSV.
# NOTE: `sub_traits` is not defined in this cell; it must already exist in the
# kernel (it is assigned in the EMIT extraction cell).
points_file = '/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/5_extract_training_samples/0_sparse_vagetated_area/2_selected_HR_points.shp'
out_points = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/5_extract_training_samples/0_sparse_vagetated_area/2_HR_points.csv"

HR_traits = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/3_clipped_trait_maps/"
HR_image_path = "/130TB_RAID0/fujiang/SmallSat_part2/1_imagery_data/1_fused_imagery/"
HR_image = "/130TB_RAID0/fujiang/SmallSat_part2/1_imagery_data/PlanetScope_RFL_20230422_clipped_resampled.tif"


def _sample_raster(dataset, xs, ys, band_numbers=None):
    """Sample `dataset` at map coordinates and return one list of values per band.

    Parameters
    ----------
    dataset : gdal.Dataset
        Open raster. Only geotransform terms 0, 1, 3 and 5 are used, so the
        raster is assumed to be north-up (no rotation) -- TODO confirm.
    xs, ys : iterable of float
        Map coordinates of the points to sample; assumed to be in the raster's
        coordinate system -- TODO confirm.
    band_numbers : list of int, optional
        1-based band numbers to read; defaults to every band of the raster.

    Returns
    -------
    list of list
        ``result[i][j]`` is the value of band ``band_numbers[i]`` at point
        ``j``, or ``np.nan`` when the point lies outside the raster or the
        read fails.
    """
    # Fetch the loop-invariant pieces once instead of once per point.
    origin_x, pixel_w, _, origin_y, _, pixel_h = dataset.GetGeoTransform()
    n_cols, n_rows = dataset.RasterXSize, dataset.RasterYSize
    if band_numbers is None:
        band_numbers = range(1, dataset.RasterCount + 1)
    bands = [dataset.GetRasterBand(number) for number in band_numbers]

    samples = [[] for _ in bands]
    for x, y in zip(xs, ys):
        # floor (not int truncation) so coordinates just left of / above the
        # raster map to index -1 and are rejected instead of reading pixel 0.
        px = int(np.floor((x - origin_x) / pixel_w))
        py = int(np.floor((y - origin_y) / pixel_h))
        inside = 0 <= px < n_cols and 0 <= py < n_rows
        for values, band in zip(samples, bands):
            arr = band.ReadAsArray(px, py, 1, 1) if inside else None
            values.append(np.nan if arr is None else arr[0][0])
    return samples


filenames1 = os.listdir(HR_traits)
filenames1 = [x for x in filenames1 if ".aux.xml" not in x]

points = gpd.read_file(points_file)
# Rows are only filtered at the very end, so the coordinates can be taken once.
point_x, point_y = points.geometry.x, points.geometry.y

# ---------------------------- Extract HR_traits ----------------------------
print("Start extract the high resolution trait values")
for file in filenames1:
    in_tif = f"{HR_traits}{file}"
    tiff_ds = gdal.Open(in_tif)
    if tiff_ds is None:
        raise RuntimeError(f"GDAL could not open raster: {in_tif}")

    # Trait maps are sampled from band 1 only.
    extracted_values = _sample_raster(tiff_ds, point_x, point_y, band_numbers=[1])[0]

    # Column name = file name without its first and last "_"-separated tokens.
    splits = file.split("_")[1:-1]
    aa = ("_").join(splits)
    col_name = f"HR_{aa}"
    points[col_name] = extracted_values

# ---------------------------- Extract Planet_reflectance ----------------------------
print("Start extract the reflectance of PlanetScope imagery")
tiff_ds = gdal.Open(HR_image)
if tiff_ds is None:
    raise RuntimeError(f"GDAL could not open raster: {HR_image}")
num_bands = tiff_ds.RasterCount

extracted_values = pd.DataFrame(np.array(_sample_raster(tiff_ds, point_x, point_y))).T
extracted_values.columns = [f"Planet_{x}_nm" for x in [443, 490, 531, 565, 610, 665, 705, 865]]
points = pd.concat([points, extracted_values], axis = 1)

# ---------------------------- Extract fused_image_reflectance ----------------------------
print("Start extract the reflectance of fused images")
filenames = os.listdir(HR_image_path)
filenames = [x for x in filenames if ".aux.xml" not in x]

# File-name prefixes that are reported under a different model name.
model_aliases = {"Fusion3": "MSHFNET", "Fusion5": "MSAHFNET"}

fused_frames = []
for file in filenames:
    prefix = file.split("_")[0]
    model = model_aliases.get(prefix, prefix)
    print(f"   -{model}")

    fused_tif = f"{HR_image_path}{file}"
    tiff_ds = gdal.Open(fused_tif)
    if tiff_ds is None:
        raise RuntimeError(f"GDAL could not open raster: {fused_tif}")
    num_bands = tiff_ds.RasterCount

    extracted_values = pd.DataFrame(np.array(_sample_raster(tiff_ds, point_x, point_y))).T
    extracted_values.columns = [f"{model}_model_{x}_nm" for x in np.arange(400, 2401, 10)]
    fused_frames.append(extracted_values)

if not fused_frames:
    # Avoid silently reusing a `fused_refl` left over from an earlier run.
    raise FileNotFoundError(f"No fused imagery found in {HR_image_path}")

# Concatenate once rather than growing the frame inside the loop.
fused_refl = pd.concat(fused_frames, axis = 1)
fused_refl.reset_index(drop = True, inplace = True)
points = pd.concat([points, fused_refl], axis =1)

# Keep points whose filter traits are all present and non-negative ...
filter_cols = [f"HR_{x}" for x in sub_traits]
invalid = (points[filter_cols] < 0) | points[filter_cols].isna()
points = points[~invalid.any(axis=1)]
# ... and that have at least one non-missing fused-model reflectance value.
points = points[~points.filter(like="_model_").isna().all(axis=1)]
points.to_csv(out_points, index = False)

Start extract the high resolution trait values
Start extract the reflectance of PlanetScope imagery
Start extract the reflectance of fused images
   -DSSFNET
   -MSHFNET
   -MSAHFNET
   -SSFCNN
   -TFNET
   -CONSSFCNN
   -RESTFNET
   -MSDCNN
   -SSRNET


In [12]:
## Extract HR values (retrained fusion models)
# Samples the high-resolution trait maps and the per-trait MSAHFNet fused
# images at each selected HR point, then writes the valid points to CSV.
points_file = '/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/5_extract_training_samples/0_sparse_vagetated_area/2_selected_HR_points.shp'
out_points = "/130TB_RAID0/fujiang/SmallSat_part2/7_retrain_fusion_model/4_extract_training_samples/0_sparse_vagetated_area/2_HR_points.csv"

HR_traits = "/130TB_RAID0/fujiang/SmallSat_part2/2_high_resolution_trait_maps/3_clipped_trait_maps/"
# Directory listing kept for parity with the original cell; not used below.
filenames1 = os.listdir(HR_traits)
filenames1 = [x for x in filenames1 if ".aux.xml" not in x]

HR_image_path = "/130TB_RAID0/fujiang/SmallSat_part2/7_retrain_fusion_model/3_fused_imagery/"

traits = ['chl_area', 'LWC_area', 'LWC', 'LMA', 'lignin', 'cellulose', 'hemicellulose', 'sugars_mg_g', '%C', 'NSC_DS', 
          'nitrogen', 'phenolics_mg_g', 'phosphorus', 'potassium', 'calcium', 'Manganese_ppm','magnesium', 'sulfur']


def _sample_raster(dataset, xs, ys, band_numbers=None):
    """Sample `dataset` at map coordinates and return one list of values per band.

    Parameters
    ----------
    dataset : gdal.Dataset
        Open raster. Only geotransform terms 0, 1, 3 and 5 are used, so the
        raster is assumed to be north-up (no rotation) -- TODO confirm.
    xs, ys : iterable of float
        Map coordinates of the points to sample; assumed to be in the raster's
        coordinate system -- TODO confirm.
    band_numbers : list of int, optional
        1-based band numbers to read; defaults to every band of the raster.

    Returns
    -------
    list of list
        ``result[i][j]`` is the value of band ``band_numbers[i]`` at point
        ``j``, or ``np.nan`` when the point lies outside the raster or the
        read fails.
    """
    # Fetch the loop-invariant pieces once instead of once per point.
    origin_x, pixel_w, _, origin_y, _, pixel_h = dataset.GetGeoTransform()
    n_cols, n_rows = dataset.RasterXSize, dataset.RasterYSize
    if band_numbers is None:
        band_numbers = range(1, dataset.RasterCount + 1)
    bands = [dataset.GetRasterBand(number) for number in band_numbers]

    samples = [[] for _ in bands]
    for x, y in zip(xs, ys):
        # floor (not int truncation) so coordinates just left of / above the
        # raster map to index -1 and are rejected instead of reading pixel 0.
        px = int(np.floor((x - origin_x) / pixel_w))
        py = int(np.floor((y - origin_y) / pixel_h))
        inside = 0 <= px < n_cols and 0 <= py < n_rows
        for values, band in zip(samples, bands):
            arr = band.ReadAsArray(px, py, 1, 1) if inside else None
            values.append(np.nan if arr is None else arr[0][0])
    return samples


points = gpd.read_file(points_file)

# ---------------------------- Extract HR_traits ----------------------------
print("Start extract the high resolution trait values")
point_x, point_y = points.geometry.x, points.geometry.y
for tr in traits:
    in_tif = f"{HR_traits}20220420_{tr}_clipped.tif"
    tiff_ds = gdal.Open(in_tif)
    if tiff_ds is None:
        raise RuntimeError(f"GDAL could not open raster: {in_tif}")

    # Trait maps are sampled from band 1 only.
    points[f"HR_{tr}"] = _sample_raster(tiff_ds, point_x, point_y, band_numbers=[1])[0]

# Filter on EVERY trait column. The original comprehension interpolated the
# leaked loop variable `tr`, so it only ever checked the last trait.
filter_cols = [f"HR_{x}" for x in traits]
invalid = (points[filter_cols] < 0) | points[filter_cols].isna()
points = points[~invalid.any(axis=1)].reset_index(drop=True)

# ---------------------------- Extract fused_image_reflectance ----------------------------
print("Start extract the reflectance of fused images")
# Directory listing kept for parity with the original cell; not used below.
filenames = os.listdir(HR_image_path)
filenames = [x for x in filenames if ".aux.xml" not in x]

# Coordinates are re-read because rows were dropped by the filter above.
point_x, point_y = points.geometry.x, points.geometry.y
fused_frames = []
for tr in traits:
    fused_tif = f"{HR_image_path}{tr}_MSAHFNet_model_fused_imagery.tif"
    tiff_ds = gdal.Open(fused_tif)
    if tiff_ds is None:
        raise RuntimeError(f"GDAL could not open raster: {fused_tif}")
    num_bands = tiff_ds.RasterCount

    extracted_values = pd.DataFrame(np.array(_sample_raster(tiff_ds, point_x, point_y))).T
    extracted_values.columns = [f"{tr}_MSAHFNet_model_{x}_nm" for x in np.arange(400, 2401, 10)]
    fused_frames.append(extracted_values)

# Concatenate once rather than growing the frame inside the loop.
fused_refl = pd.concat(fused_frames, axis = 1)
fused_refl.reset_index(drop = True, inplace = True)
points = pd.concat([points, fused_refl], axis =1)
# Drop points for which every fused-model reflectance value is missing.
points = points[~points.filter(like="_model_").isna().all(axis=1)]

saved_points = points.drop(columns = ["geometry"])
saved_points.to_csv(out_points,index = False)

Start extract the high resolution trait values
Start extract the reflectance of fused images
