In [None]:
import numpy as np
import pandas as pd
from collections import defaultdict
import shutil
import os
import math
import seaborn as sns
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import matplotlib.ticker as ticker
from scipy.stats import gaussian_kde
from scipy.signal import savgol_filter
from matplotlib import gridspec
from sklearn.metrics import mean_squared_error
from scipy import stats
import warnings
from datetime import datetime
import cartopy.crs as ccrs
import joypy
from osgeo import gdal
import matplotlib as mpl
from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter
# Silence every Python warning for the rest of the kernel session, e.g. numpy's
# RuntimeWarning from np.nanmean on pixels that are NaN in every input image.
# NOTE(review): this is global and may also hide GDAL/library deprecation
# warnings; consider narrowing it to specific categories/messages.
warnings.filterwarnings('ignore')

In [None]:
def read_tif(tif_file):
    """Read a raster file fully into memory with GDAL.

    Parameters
    ----------
    tif_file : str
        Path of the raster to open.

    Returns
    -------
    tuple
        ``(im_data, im_Geotrans, im_proj, rows, cols)``: the pixel array
        (``(rows, cols)`` for a single band, ``(rows, cols, bands)`` for
        multi-band rasters), the six-element GDAL geotransform, the projection
        string from ``GetProjection``, and the raster height/width in pixels.

    Raises
    ------
    OSError
        If GDAL cannot open ``tif_file`` (``gdal.Open`` returned ``None``).
    """
    dataset = gdal.Open(tif_file)
    if dataset is None:
        raise OSError(f"GDAL could not open raster: {tif_file}")
    cols = dataset.RasterXSize
    rows = dataset.RasterYSize
    im_proj = dataset.GetProjection()
    im_Geotrans = dataset.GetGeoTransform()
    im_data = dataset.ReadAsArray(0, 0, cols, rows)
    if im_data.ndim == 3:
        # GDAL returns (bands, rows, cols); move bands to the last axis.
        # Reuse the array already read instead of reading the raster again.
        im_data = np.moveaxis(im_data, 0, -1)
    dataset = None  # release the GDAL handle
    return im_data, im_Geotrans, im_proj, rows, cols
    
def array_to_geotiff(array, output_path, geo_transform, projection, band_names=None):
    """Write an array to a Float32 GeoTIFF, one raster band per slice of the last axis.

    Parameters
    ----------
    array : numpy.ndarray
        ``(rows, cols, bands)`` array. A 2-D ``(rows, cols)`` array is accepted
        and written as a single band.
    output_path : str
        Destination path, written with the GDAL ``GTiff`` driver.
    geo_transform : sequence
        Six-element GDAL geotransform assigned to the output.
    projection : str
        Projection string assigned with ``SetProjection``.
    band_names : list of str, optional
        Band descriptions; when given it needs one entry per band.

    Raises
    ------
    ValueError
        If ``array`` is not 2-D/3-D, or ``band_names`` has fewer entries than bands.
    RuntimeError
        If GDAL cannot create ``output_path``.
    """
    if array.ndim == 2:
        array = array[:, :, np.newaxis]
    if array.ndim != 3:
        raise ValueError(f"expected a 2-D or 3-D array, got shape {array.shape}")
    rows, cols, num_bands = array.shape
    if band_names and len(band_names) < num_bands:
        # Fail before writing anything rather than leaving a half-described file.
        raise ValueError(f"{num_bands} bands but only {len(band_names)} band names")

    driver = gdal.GetDriverByName('GTiff')
    dataset = driver.Create(output_path, cols, rows, num_bands, gdal.GDT_Float32)
    if dataset is None:
        raise RuntimeError(f"GDAL could not create {output_path}")

    dataset.SetGeoTransform(geo_transform)
    dataset.SetProjection(projection)

    for band_num in range(num_bands):
        band = dataset.GetRasterBand(band_num + 1)
        band.WriteArray(array[:, :, band_num])
        band.FlushCache()
        if band_names:
            band.SetDescription(band_names[band_num])
        # Drop the band handle before the dataset is closed below.
        band = None

    # Dereferencing the dataset closes it and flushes everything to disk.
    dataset = None
    return

def get_corner_coordinates(geotrans, cols, rows):
    """Upper-left and lower-right map coordinates of a raster.

    Applies the full GDAL affine geotransform (including the rotation terms
    ``geotrans[2]`` and ``geotrans[4]``) to pixel corner ``(cols, rows)``.

    Returns
    -------
    tuple
        ``(ul_x, ul_y, lr_x, lr_y)``.
    """
    origin_x, pixel_w, rot_x = geotrans[0], geotrans[1], geotrans[2]
    origin_y, rot_y, pixel_h = geotrans[3], geotrans[4], geotrans[5]
    lower_right_x = origin_x + cols * pixel_w + rows * rot_x
    lower_right_y = origin_y + cols * rot_y + rows * pixel_h
    return origin_x, origin_y, lower_right_x, lower_right_y

def nanmean_images(image_list):
    """Pixel-wise mean of equally shaped arrays, ignoring NaN values.

    The inputs are stacked along a new trailing axis and averaged over it, so
    the result has the shape of a single input array. Pixels that are NaN in
    every input stay NaN.
    """
    return np.nanmean(np.stack(image_list, axis=-1), axis=-1)

def get_corner(image_file):
    """Projection, pixel size, raster size and bounding box of a raster file.

    Returns
    -------
    tuple
        ``(im_proj, x_res, y_res, x_size, y_size, (x_min, y_min, x_max, y_max))``
        where ``x_res``/``y_res`` are geotransform entries 1 and 5 (``y_res``
        keeps GDAL's sign) and the bounds are computed from the origin and pixel
        size only, ignoring the rotation terms of the geotransform.
    """
    raster = gdal.Open(image_file)
    gt = raster.GetGeoTransform()
    width, height = raster.RasterXSize, raster.RasterYSize
    left, pixel_w = gt[0], gt[1]
    top, pixel_h = gt[3], gt[5]
    right = left + pixel_w * width
    bottom = top + pixel_h * height
    return raster.GetProjection(), pixel_w, pixel_h, width, height, (left, bottom, right, top)

def transfer_lulc_pft(array):
    """Convert land-cover (LULC) class codes into plant functional type (PFT) codes.

    The input is cast to ``int`` (truncating any fractional values) and each
    pixel is mapped with the table below. Pixels whose class is not listed
    (e.g. wetlands, impervious surfaces, bare areas, water, ice, fill values)
    become NaN. The result is a float array with the same shape as ``array``.
    """
    # (LULC classes, PFT code) pairs:
    #   CPR croplands 10/11/12/20          -> 100
    #   EBF evergreen broadleaved 51/52    -> 200
    #   DBF deciduous broadleaved 61/62    -> 300
    #   ENF evergreen needle-leaved 71/72  -> 400
    #   DNF deciduous needle-leaved 81/82  -> 500
    #   MF  mixed leaf forest 91/92        -> 600
    #   SHR shrubland 120/121/122          -> 700
    #   GRA grassland 130                  -> 800
    lulc_to_pft = (
        ([10, 11, 12, 20], 100),
        ([51, 52], 200),
        ([61, 62], 300),
        ([71, 72], 400),
        ([81, 82], 500),
        ([91, 92], 600),
        ([120, 121, 122], 700),
        ([130], 800),
    )
    class_codes = array.astype(int)
    pft_codes = np.full(class_codes.shape, np.nan)
    for lulc_classes, pft_value in lulc_to_pft:
        pft_codes[np.isin(class_codes, lulc_classes)] = pft_value
    return pft_codes

## 1. Apply cloud masks

In [None]:
# Multiply every trait band of each scene by that scene's cloud mask and save the
# result (same band descriptions) under 1_clouds_masked/<site>/.
data_path = "/Volumes/ChenLab-1/Fujiang/0_Seasonal_PRISMA_traits/4_Extract_training_data/7_PRISMA_estimated_traits/NBAR_refl_with_LAI/"
mask_path = "/Volumes/ChenLab-1/Fujiang/0_Seasonal_PRISMA_traits/4_Extract_training_data/8_PRISMA_cloud_mask/"
out_path = "/Users/fji/Desktop/NBAR_refl_with_LAI/0_trait_maps/"
os.makedirs(f"{out_path}1_clouds_masked", exist_ok=True)

output_path = f"{out_path}1_clouds_masked/"

folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL',
           'D08_TALL','D10_CPER','D13_MOAB','D14_JORN','D16_WREF']

# File-name suffixes of the two trait-model products (named so the builtin
# `filter` is not shadowed for the rest of the notebook).
model_suffixes = ["_all_data_models_traits.tif", "_PFT_specific_models_traits.tif"]

for folder in folders:
    print(folder)
    os.makedirs(f"{output_path}{folder}", exist_ok=True)

    trait_tif_folder = f"{data_path}{folder}"
    mask_tif_path = f"{mask_path}{folder}"
    out_tif_path = f"{output_path}{folder}"

    file_name = os.listdir(trait_tif_folder)
    for fil in model_suffixes:
        # Skip macOS "._" resource-fork files and GDAL .aux.xml sidecars.
        file_name1 = [x for x in file_name if fil in x and "._" not in x and ".aux.xml" not in x]

        for file in file_name1:
            # The mask shares the first 8 underscore-separated tokens of the trait file name.
            mask_name = f"{'_'.join(file.split('_')[0:8])}_mask.tif"
            trait_tif = f"{trait_tif_folder}/{file}"
            mask_tif = f"{mask_tif_path}/{mask_name}"

            mask_ds = gdal.Open(mask_tif)
            if mask_ds is None:
                raise FileNotFoundError(f"Cloud mask missing or unreadable: {mask_tif}")
            # Trailing axis so the 2-D mask broadcasts over every trait band.
            mask = np.expand_dims(mask_ds.ReadAsArray(), axis=-1)
            mask_ds = None

            im_data, im_Geotrans, im_proj, _, _ = read_tif(trait_tif)

            # Keep the original band descriptions for the masked output.
            trait_ds = gdal.Open(trait_tif)
            band_names = [trait_ds.GetRasterBand(i).GetDescription() for i in range(1, trait_ds.RasterCount + 1)]
            trait_ds = None

            masked_trait = mask * im_data
            out_tif = f"{out_tif_path}/{file[:-4]}_masked.tif"
            array_to_geotiff(masked_trait, out_tif, im_Geotrans, im_proj, band_names=band_names)

## 2. Mosaic scenes acquired on the same date

In [None]:
# Merge scenes of the same site/model acquired on the same date into one mosaic.
# The mosaic replaces the first scene's file; the other scenes are deleted.
data_path = "/Users/fji/Desktop/NBAR_refl_with_LAI/0_trait_maps/"
in_path = f"{data_path}1_clouds_masked/"
out_path = f"{data_path}1_clouds_masked/"

# Named `folders` (it was misspelled `ffolders`, so the loop silently depended
# on `folders` leaking from the previous cell).
folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL',
           'D08_TALL','D10_CPER','D13_MOAB','D14_JORN','D16_WREF']

model_suffixes = ["_all_data_models_traits_masked.tif", "_PFT_specific_models_traits_masked.tif"]

# NaN is nodata on both sides; INIT_DEST=NO_DATA starts the new output as NaN
# so pixels covered by no scene stay NaN.
warp_options = gdal.WarpOptions(format='GTiff', srcNodata='nan', dstNodata='nan',
                                warpOptions=['INIT_DEST=NO_DATA'])

for folder in folders:
    trait_tif_folder = f"{in_path}{folder}"

    file_name = os.listdir(trait_tif_folder)
    for fil in model_suffixes:
        file_name1 = [x for x in file_name if fil in x and "._" not in x and ".aux.xml" not in x]
        date_dict = defaultdict(list)
        for file in file_name1:
            # Presumably token 3 starts with the acquisition date (YYYYMMDD) — confirm naming.
            date = file.split('_')[3][:8]
            date_dict[date].append(file)
        duplicates = {date: files for date, files in date_dict.items() if len(files) > 1}
        for date, files in duplicates.items():
            print(f"{folder}, Date: {date}")

            # Mosaic ALL scenes of that date, not just the first two.
            data_list = [f"{trait_tif_folder}/{f}" for f in files]
            out_tif = data_list[0]
            # Write to a fresh temporary file: warping into out_tif while it is also a
            # source would update the existing file in place and keep its extent.
            tmp_tif = f"{out_tif[:-4]}.mosaic_tmp.tif"
            if os.path.exists(tmp_tif):
                os.remove(tmp_tif)  # stale leftover from an interrupted run
            mosaic = gdal.Warp(tmp_tif, data_list, options=warp_options)
            if mosaic is None:
                # Do not delete any input when the mosaic was not produced.
                raise RuntimeError(f"Mosaic failed for {folder} {date}: {data_list}")
            mosaic = None  # close so the mosaic is flushed before it is moved
            os.replace(tmp_tif, out_tif)
            print("Mosaic created successfully")
            for merged_tif in data_list[1:]:
                os.remove(merged_tif)

## 3. Warp all scenes to a common geographic extent

In [None]:
# Warp every masked scene of a site/model onto the union of their extents so all
# scenes share one extent before monthly averaging.
data_path = "/Users/fji/Desktop/NBAR_refl_with_LAI/0_trait_maps/"
in_path = f"{data_path}1_clouds_masked/"
os.makedirs(f"{data_path}2_same_geo_extend", exist_ok=True)
out_path = f"{data_path}2_same_geo_extend/"

folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL',
           'D08_TALL','D10_CPER','D13_MOAB','D14_JORN','D16_WREF']

model_suffixes = ["_all_data_models_traits_masked.tif", "_PFT_specific_models_traits_masked.tif"]

for folder in folders:
    os.makedirs(f"{out_path}{folder}", exist_ok=True)
    trait_tif_folder = f"{in_path}{folder}"
    out_tif_path = f"{out_path}{folder}"

    file_name = os.listdir(trait_tif_folder)
    for fil in model_suffixes:
        file_name1 = [x for x in file_name if fil in x and "._" not in x and ".aux.xml" not in x]
        if not file_name1:
            # Nothing to align; min()/max() below would fail on empty lists.
            continue

        ul_x_all, lr_y_all, lr_x_all, ul_y_all = [], [], [], []
        for file in file_name1:
            # Only the geotransform and size are needed, so read the header, not the pixels.
            src_path = f"{trait_tif_folder}/{file}"
            ds = gdal.Open(src_path)
            if ds is None:
                raise OSError(f"GDAL could not open raster: {src_path}")
            ul_x, ul_y, lr_x, lr_y = get_corner_coordinates(ds.GetGeoTransform(), ds.RasterXSize, ds.RasterYSize)
            ds = None
            ul_x_all.append(ul_x)
            lr_y_all.append(lr_y)
            lr_x_all.append(lr_x)
            ul_y_all.append(ul_y)
        # Union extent; assumes north-up rasters (upper-left y is the maximum y).
        ul_x, lr_y, lr_x, ul_y = min(ul_x_all), min(lr_y_all), max(lr_x_all), max(ul_y_all)

        for file in file_name1:
            input_tif = f"{trait_tif_folder}/{file}"
            output_tif = f"{out_tif_path}/{file}"
            # gdalwarp warps INTO an existing output instead of recreating it, so
            # remove results from a previous run to keep this cell re-runnable.
            if os.path.exists(output_tif):
                os.remove(output_tif)
            # NOTE(review): no xRes/yRes is given, so GDAL chooses the output pixel size;
            # step 4 stacks these arrays and needs identical shapes — confirm.
            warped = gdal.Warp(output_tif, input_tif, format='GTiff', outputBounds=(ul_x, lr_y, lr_x, ul_y))
            if warped is None:
                raise RuntimeError(f"gdal.Warp failed for {input_tif}")
            warped = None  # close/flush the output

## 4. Aggregate scenes to monthly means

In [None]:
# Average all scenes of the same calendar month into one map per site/model,
# written as <MM><suffix> under 3_aggregate_monthly/<site>/.
data_path = "/Users/fji/Desktop/NBAR_refl_with_LAI/0_trait_maps/"
in_path = f"{data_path}2_same_geo_extend/"
os.makedirs(f"{data_path}3_aggregate_monthly", exist_ok=True)
out_path = f"{data_path}3_aggregate_monthly/"

folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL',
           'D08_TALL','D10_CPER','D13_MOAB','D14_JORN','D16_WREF']

suffixes = ["_all_data_models_traits_masked.tif", "_PFT_specific_models_traits_masked.tif"]
band_names = ['Chla+b_mean','Chla+b_std', 'Ccar_mean','Ccar_std','EWT_mean','EWT_std','LMA_mean', 'LMA_std', 'Nitrogen_mean','Nitrogen_std']

for folder in folders:
    os.makedirs(f"{out_path}{folder}", exist_ok=True)
    src_dir = f"{in_path}{folder}"
    dst_dir = f"{out_path}{folder}"
    entries = os.listdir(src_dir)

    for suffix in suffixes:
        selected = [e for e in entries if suffix in e and "._" not in e and ".aux.xml" not in e]

        # Group scenes by characters 4-6 of token 3 (presumably MM of a YYYYMMDD date).
        by_month = defaultdict(list)
        for name in selected:
            by_month[name.split('_')[3][4:6]].append(name)

        for month_key, names in by_month.items():
            stack = []
            for name in names:
                im_data, im_Geotrans, im_proj, rows, cols = read_tif(f"{src_dir}/{name}")
                # Non-positive values (masked-out or zero-filled pixels) count as missing.
                stack.append(np.where(im_data <= 0, np.nan, im_data))
            mean_image = nanmean_images(stack)
            array_to_geotiff(mean_image, f"{dst_dir}/{month_key}{suffix}", im_Geotrans, im_proj,
                             band_names=band_names)

## 5. Merge trait maps with LULC data and convert land cover to PFTs

In [None]:
# Resample the 2021 land-cover map of each site onto every monthly trait map's
# grid, convert it to PFT codes, and append it as an 11th "PFTs" band.
data_path = "/Users/fji/Desktop/NBAR_refl_with_LAI/0_trait_maps/"
lulc_path = f"{data_path}0_original_lulc/"
in_path = f"{data_path}3_aggregate_monthly/"
os.makedirs(f"{data_path}4_merge_LULC", exist_ok=True)
out_path = f"{data_path}4_merge_LULC/"

folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL',
           'D08_TALL','D10_CPER','D13_MOAB','D14_JORN','D16_WREF']

model_suffixes = ["_all_data_models_traits_masked.tif", "_PFT_specific_models_traits_masked.tif"]
band_names = ['Chla+b_mean','Chla+b_std', 'Ccar_mean','Ccar_std','EWT_mean','EWT_std','LMA_mean', 'LMA_std', 'Nitrogen_mean','Nitrogen_std',"PFTs"]

for folder in folders:
    os.makedirs(f"{out_path}{folder}", exist_ok=True)
    trait_tif_folder = f"{in_path}{folder}"
    out_tif_path = f"{out_path}{folder}"
    lulc_folder = f"{lulc_path}{folder}"
    file_name = os.listdir(trait_tif_folder)

    landuse = f"{lulc_folder}/{folder}_2021_land_cover.tif"

    for fil in model_suffixes:
        file_name1 = [x for x in file_name if fil in x and "._" not in x and ".aux.xml" not in x]
        for file in file_name1:
            print(folder, fil, file)
            trait_tif = f"{trait_tif_folder}/{file}"
            lulc_file = f"{out_tif_path}/{file[:-4]}_lulc.tif"

            # Target grid = the trait map's projection, bounds and pixel count.
            # (Named trait_bounds so the module-level `bounds` dict of step 7 is not clobbered.)
            proj, x_res, y_res, x_size, y_size, trait_bounds = get_corner(trait_tif)
            if os.path.exists(lulc_file):
                os.remove(lulc_file)  # gdalwarp would otherwise warp into the old file
            # gdalwarp rejects -tr (xRes/yRes) together with -ts (width/height). Bounds
            # plus width/height already give a pixel size of (x_res, |y_res|).
            lulc_ds = gdal.Warp(lulc_file, landuse, dstSRS=proj, outputBounds=trait_bounds,
                                width=x_size, height=y_size, resampleAlg=gdal.GRA_NearestNeighbour)
            if lulc_ds is None:
                raise RuntimeError(f"Resampling land cover failed for {trait_tif}")
            lulc_ds = None  # close/flush before reading it back

            im_data, im_Geotrans, im_proj, im_rows, im_cols = read_tif(trait_tif)
            lulc_data, lulc_Geotrans, lulc_proj, lulc_rows, lulc_cols = read_tif(lulc_file)
            pft_data = transfer_lulc_pft(lulc_data)[:, :, np.newaxis]
            im_data = np.concatenate((im_data, pft_data), axis=2)

            out_name = f"{out_tif_path}/{file}"
            array_to_geotiff(im_data, out_name, im_Geotrans, im_proj, band_names=band_names)

## 6. Convert trait maps to WGS-84

In [None]:
# Reproject each merged map (trait bands + PFT band) to geographic WGS-84 (EPSG:4326).
data_path = "/Users/fji/Desktop/NBAR_refl_with_LAI/0_trait_maps/"
in_path = f"{data_path}4_merge_LULC/"
os.makedirs(f"{data_path}5_convert_wgs84", exist_ok=True)
out_path = f"{data_path}5_convert_wgs84/"

folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL',
           'D08_TALL','D10_CPER','D13_MOAB','D14_JORN','D16_WREF']

suffixes = ["_all_data_models_traits_masked.tif", "_PFT_specific_models_traits_masked.tif"]

for folder in folders:
    os.makedirs(f"{out_path}{folder}", exist_ok=True)
    src_dir = f"{in_path}{folder}"
    dst_dir = f"{out_path}{folder}"
    entries = os.listdir(src_dir)

    for suffix in suffixes:
        for name in (e for e in entries if suffix in e and "._" not in e and ".aux.xml" not in e):
            src_ds = gdal.Open(f"{src_dir}/{name}")
            # NOTE(review): no srcNodata/dstNodata is passed, so pixels outside the
            # reprojected footprint are presumably written as 0 rather than NaN —
            # confirm downstream steps (e.g. PFT filtering) exclude them.
            dst_ds = gdal.Warp(f"{dst_dir}/{name}", src_ds, dstSRS='EPSG:4326')
            # Release both handles so the output is flushed to disk.
            src_ds = None
            dst_ds = None

## 7. Clip the aggregated maps.

In [None]:
# Clip each WGS-84 map to a fixed per-site box given as (min_lon, min_lat, max_lon, max_lat).
bounds = {
    'D01_BART': (-71.45, 43.95, -71.15, 44.166),
    'D01_HARV': (-72.335, 42.415, -72.025, 42.62),
    'D02_SCBI': (-78.285, 38.785, -77.99, 39.005),
    'D03_OSBS': (-82.17, 29.56, -81.87, 29.79),
    'D07_MLBS': (-80.67, 37.265, -80.38, 37.49),
    'D07_ORNL': (-84.458, 35.815, -84.16, 36.055),
    'D08_TALL': (-87.538, 32.84, -87.245, 33.07),
    'D10_CPER': (-104.905, 40.7, -104.587, 40.93),
    'D13_MOAB': (-109.544, 38.135, -109.247, 38.365),
    'D14_JORN': (-107.0, 32.47, -106.725, 32.706),
    'D14_SRER': (-110.99, 31.72, -110.725, 31.95),
    'D16_WREF': (-122.11, 45.73, -121.82, 45.93),
    'D19_BONA': (-147.76, 65.09, -147.24, 65.28),
    'D19_HEAL': (-149.46, 63.77, -148.95, 63.97),
}

data_path = "/Users/fji/Desktop/NBAR_refl_with_LAI/0_trait_maps/"
in_path = f"{data_path}5_convert_wgs84/"
os.makedirs(f"{data_path}6_clipped_maps", exist_ok=True)
out_path = f"{data_path}6_clipped_maps/"

folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL',
           'D08_TALL','D10_CPER','D13_MOAB','D14_JORN','D16_WREF']

suffixes = ["_all_data_models_traits_masked.tif", "_PFT_specific_models_traits_masked.tif"]

for folder in folders:
    os.makedirs(f"{out_path}{folder}", exist_ok=True)
    src_dir = f"{in_path}{folder}"
    dst_dir = f"{out_path}{folder}"
    site_box = bounds[folder]
    entries = os.listdir(src_dir)

    for suffix in suffixes:
        for name in [e for e in entries if suffix in e and "._" not in e and ".aux.xml" not in e]:
            # Output name: <input stem>_clipped.tif
            gdal.Warp(f"{dst_dir}/{name[:-4]}_clipped.tif", f"{src_dir}/{name}",
                      format='GTiff', outputBounds=site_box)

## 8. Separate monthly stacks for each trait

In [None]:
# For each trait (its *_mean band only), stack the site's monthly maps into one
# GeoTIFF with one band per month plus the PFT band, saved in two versions:
#   <trait><model>_original.tif : monthly values as read;
#   <trait><model>_ex_nan.tif   : a pixel's whole series set to NaN when any month
#                                 is NaN, so every month covers the same pixels.
data_path = "/Users/fji/Desktop/NBAR_refl_with_LAI/0_trait_maps/"
in_path = f"{data_path}6_clipped_maps/"
os.makedirs(f"{data_path}7_trait_separate_analysis", exist_ok=True)
out_path = f"{data_path}7_trait_separate_analysis/"

# Months (MM prefixes of the step-4 outputs) to stack for each site.
M1 = ["04","05","06","07","08","09","10","11"]
M2 = ["04","05","06","07","08","09","10"]
M3 = ["05","06","07","08","09"]
M4 = ["05","06","07","08","09","10"]
M5 = ["06","07","08","09"]  # NOTE(review): currently unused
folder_month_map = {'D14_SRER': M1, 'D07_ORNL': ["04","06","07","08","09","10"], 'D08_TALL': M2,
                    'D16_WREF': M2, 'D13_MOAB': M3}
default_month = M4

folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL',
           'D08_TALL','D10_CPER','D13_MOAB','D14_JORN','D16_WREF']

model_suffixes = ["_all_data_models_traits_masked_clipped.tif", "_PFT_specific_models_traits_masked_clipped.tif"]
# Trait output names and the index of each trait's *_mean band (band order from steps 4/5).
out_names = ["Chla+b", "Ccar", "EWT", "LMA", "Nitrogen"]
mean_band_index = [0, 2, 4, 6, 8]
pft_band_index = 10

for folder in folders:
    os.makedirs(f"{out_path}{folder}", exist_ok=True)
    trait_tif_folder = f"{in_path}{folder}"
    out_tif_path = f"{out_path}{folder}"
    month = folder_month_map.get(folder, default_month)
    band_names = month + ["PFTs"]

    for fil in model_suffixes:
        # monthly_slices[k] collects the (rows, cols) mean band of trait k for each month.
        monthly_slices = [[] for _ in out_names]
        for mon in month:
            im_data, im_Geotrans, im_proj, rows, cols = read_tif(f"{trait_tif_folder}/{mon}{fil}")
            for slices, band in zip(monthly_slices, mean_band_index):
                slices.append(im_data[:, :, band])

        # PFT band of the last month read; step 5 merged every month with the same 2021 land-cover map.
        pft = im_data[:, :, pft_band_index][:, :, np.newaxis]
        # "_all_data_models" or "_PFT_specific_models"
        model_tag = '_'.join(fil.split('_')[0:4])

        for name, slices in zip(out_names, monthly_slices):
            trait_stack = np.stack(slices, axis=-1)  # (rows, cols, n_months)

            # Mask a copy so the "_original" output below keeps the unmasked values
            # (masking the stack in place made both outputs identical).
            ex_nan = trait_stack.copy()
            ex_nan[np.isnan(ex_nan).any(axis=-1)] = np.nan
            array_to_geotiff(np.concatenate((ex_nan, pft), axis=2),
                             f"{out_tif_path}/{name}{model_tag}_ex_nan.tif",
                             im_Geotrans, im_proj, band_names=band_names)

            array_to_geotiff(np.concatenate((trait_stack, pft), axis=2),
                             f"{out_tif_path}/{name}{model_tag}_original.tif",
                             im_Geotrans, im_proj, band_names=band_names)