In [133]:
import os
import pandas as pd
import numpy as np
from osgeo import gdal, osr
from math import pi
import datetime
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

In [131]:
def read_tif(tif_file):
    dataset = gdal.Open(tif_file)
    cols = dataset.RasterXSize
    rows = dataset.RasterYSize
    im_proj = (dataset.GetProjection())
    im_Geotrans = (dataset.GetGeoTransform())
    im_data = dataset.ReadAsArray(0, 0, cols, rows)
    if im_data.ndim == 3:
        im_data = np.moveaxis(dataset.ReadAsArray(0, 0, cols, rows), 0, -1)
    dataset = None
    return im_data, im_Geotrans, im_proj,rows, cols
    
def array_to_geotiff(array, output_path, geo_transform, projection, band_names=None):
    rows, cols, num_bands = array.shape
    driver = gdal.GetDriverByName('GTiff')
    dataset = driver.Create(output_path, cols, rows, num_bands, gdal.GDT_Float32)
    
    dataset.SetGeoTransform(geo_transform)
    dataset.SetProjection(projection)
    
    for band_num in range(num_bands):
        band = dataset.GetRasterBand(band_num + 1)
        band.WriteArray(array[:, :, band_num])
        band.FlushCache()
        
        if band_names:
            band.SetDescription(band_names[band_num])
    
    dataset = None
    band = None
    return

def get_corner(image_file):
    dataset = gdal.Open(image_file)
    geo_transform = dataset.GetGeoTransform()
    x_res = geo_transform[1]
    y_res = geo_transform[5] 
    x_min = geo_transform[0]
    y_max = geo_transform[3]
    x_max = x_min + x_res * dataset.RasterXSize
    y_min = y_max + y_res * dataset.RasterYSize
    
    x_size = dataset.RasterXSize
    y_size = dataset.RasterYSize
    im_proj = dataset.GetProjection()
    return im_proj, x_res, y_res, x_size, y_size, (x_min, y_min, x_max, y_max)

def find_closest_band(c_match, band_value):
    closest_key = None
    min_diff = float('inf') 

    for key, (low, high) in c_match.items():
        if low <= band_value <= high:
            return key
        else:
            diff = min(abs(band_value - low), abs(band_value - high))
            if diff < min_diff:
                min_diff = diff
                closest_key = key
    return closest_key
    
def sec(arr):
    return 1.0 / np.cos(arr)

def cal_Kvol(sZenith, vZenith, rAzimuth):
    # adjust relative azimuth angle (0-pi)
    rAzimuth = abs(rAzimuth)
    rAzimuth = np.where(rAzimuth > pi, 2 * pi - rAzimuth, rAzimuth)

    # calculate kernel value
    sz = sZenith
    vz = vZenith
    relaz = rAzimuth

    cosxi = np.cos(sz) * np.cos(vz) + np.sin(sz) * np.sin(vz) * np.cos(relaz)
    cosxi = np.where(cosxi > 1, 1, cosxi)
    cosxi = np.where(cosxi <= -1, -1, cosxi)

    xi = np.arccos(cosxi)
    Kvol = ((pi / 2 - xi) * cosxi + np.sin(xi)) / (np.cos(sz) + np.cos(vz)) - pi / 4

    return Kvol

def cal_Kgeo(sZenith, vZenith, rAzimuth):
    # adjust relative azimuth angle (0-pi)
    rAzimuth = abs(rAzimuth)
    rAzimuth = np.where(rAzimuth > pi, 2 * pi - rAzimuth, rAzimuth)

    # calculate kernel value
    hbratio = 2
    hrratio = 1

    sZenith = np.arctan(hrratio * np.tan(sZenith))
    vZenith = np.arctan(hrratio * np.tan(vZenith))

    coszeta = np.cos(sZenith) * np.cos(vZenith) + np.sin(sZenith) * np.sin(vZenith) * np.cos(rAzimuth)
    D2 = np.tan(sZenith) ** 2 + np.tan(vZenith) ** 2 - 2 * np.tan(sZenith) * np.tan(vZenith) * np.cos(rAzimuth)
    cost = hbratio * np.sqrt(D2 + (np.tan(sZenith) * np.tan(vZenith) * np.sin(rAzimuth)) ** 2) / (
                sec(sZenith) + sec(vZenith))

    cost = np.where(cost > 1, 1, cost)
    cost = np.where(cost <= -1, -1, cost)

    t = np.arccos(cost)
    overlap = 1 / pi * (t - np.sin(t) * cost) * (sec(sZenith) + sec(vZenith))
    Kgeo = overlap - sec(sZenith) - sec(vZenith) + 1 / 2 * (1 + coszeta) * sec(vZenith) * sec(sZenith)

    # output
    return Kgeo

def brdf_model(angles, f_iso, f_vol, f_geo):
    sZenith, vZenith, rAzimuth = np.radians(angles)
    
    k_vol = cal_Kvol(sZenith, vZenith, rAzimuth)
    k_geo = cal_Kgeo(sZenith, vZenith, rAzimuth)

    reflectance = f_iso + f_vol * k_vol + f_geo * k_geo
    return reflectance

### Step 1: get the extent of imagery on each sites and then download the MODIS BRDF parameters products on GEE
* MCD43A1.061 MODIS BRDF-Albedo Model Parameters Daily 500m (https://developers.google.com/earth-engine/datasets/catalog/MODIS_061_MCD43A1)

In [25]:
data_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/2_PRISMA_L2D/7_PRISMA_latlon/"
folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL','D08_TALL',
           'D10_CPER','D13_MOAB','D14_JORN','D14_SRER','D16_WREF','D19_BONA','D19_HEAL']

coordinates = {}
all_dates = {}
for folder in folders:
    path = f"{data_path}{folder}"
    file_list = os.listdir(path)
    file_list = [x for x in file_list if "LATLON.tif" in x]
    dates = [x.split("_")[3] for x in file_list] 
    dates = [f"{x[0:4]}-{x[4:6]}-{x[6:8]}" for x in dates]
    dates = {folder: dates}
    all_dates.update(dates)
    
    ul_x_all,lr_y_all, lr_x_all, ul_y_all = [],[],[],[]
    for file in file_list:
        im_data, im_Geotrans, im_proj,rows, cols = read_tif(f"{path}/{file}")
        ul_x, lr_y, lr_x, ul_y = im_data[0,0,1],im_data[-1,-1,0],im_data[-1,-1,1],im_data[0,0,0]
        ul_x_all.append(ul_x)
        lr_y_all.append(lr_y)
        lr_x_all.append(lr_x)
        ul_y_all.append(ul_y)
    ul_x, lr_y, lr_x, ul_y = min(ul_x_all),min(lr_y_all), max(lr_x_all), max(ul_y_all)
    ul_x, lr_y, lr_x, ul_y = ul_x-0.05, lr_y-0.05, lr_x+0.05, ul_y+0.05
    coord = {folder: [ul_x, lr_y, lr_x, ul_y]}
    coordinates.update(coord)
print(coordinates)
print(all_dates)

{'D01_BART': [-71.60350494384765, 43.71203536987305, -70.81015319824219, 44.39507369995117], 'D01_HARV': [-72.59523468017578, 42.18336029052735, -71.87100524902344, 42.89200286865234], 'D02_SCBI': [-78.42472534179687, 38.670466613769534, -77.86292572021485, 39.12803726196289], 'D03_OSBS': [-82.29445343017578, 29.448540878295898, -81.72926635742188, 29.9076717376709], 'D07_MLBS': [-81.0091064453125, 37.04331893920899, -80.2543212890625, 37.75365142822265], 'D07_ORNL': [-84.72948150634765, 35.6282341003418, -84.04168243408203, 36.3220718383789], 'D08_TALL': [-87.81290893554687, 32.585356903076175, -86.96782989501953, 33.31374816894531], 'D10_CPER': [-105.06370239257812, 40.454165649414065, -104.33207244873047, 41.15525894165039], 'D13_MOAB': [-109.69263153076172, 38.016291809082034, -109.00702209472657, 38.51306228637695], 'D14_JORN': [-107.15126495361328, 32.23346633911133, -106.46416778564453, 33.04036407470703], 'D14_SRER': [-111.2750747680664, 31.35882110595703, -110.36632080078125, 

### Step 2: Apply BRDF correction methods (https://doi.org/10.1016/j.rse.2016.01.023)

In [137]:
data_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/2_PRISMA_L2D/3_PRISMA_full_band_data/4_smoothed_data/"
modis_paras_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/2_PRISMA_L2D/3_PRISMA_full_band_data/5_MODIS_BRDF_parameters/"
wvl_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/2_PRISMA_L2D/4_PRISMA_wvl_data/"
angle_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/2_PRISMA_L2D/5_PRISMA_angles_data/"
out_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/2_PRISMA_L2D/3_PRISMA_full_band_data/6_BRDF_correction/"

folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL','D08_TALL',
           'D10_CPER','D13_MOAB','D14_JORN','D14_SRER','D16_WREF','D19_BONA','D19_HEAL']

for folder in folders[0:1]:
    os.makedirs(f"{out_path}/{folder}", exist_ok=True)
    imagery_path = f"{data_path}{folder}"
    angle_folder = f"{angle_path}{folder}"
    wvl_folder = f"{wvl_path}{folder}"
    modis_folder = f"{modis_paras_path}{folder}"
    output_path = f"{out_path}{folder}"
    
    file_name = os.listdir(imagery_path)
    file_name = [x for x in file_name if "_FULL.tif" in x and "._" not in x and ".aux.xml" not in x]
    
    for kk, file in enumerate(file_name):
        print(f"{datetime.datetime.now().replace(microsecond=0)}, {folder}, {kk+1}/{len(file_name)}: {file}")
        
        date = file.split("_")[3]
        imagery_file = f"{imagery_path}/{file}"
        angle_file = f"{angle_folder}/{file[:-9]}_ANG.tif"
        wvl_file = f"{wvl_folder}/{file.split('.')[0]}.wvl"
        modis_paras_file = f"{modis_folder}/{date[0:4]}_{date[4:6]}_{date[6:8]}.tif"
        ####
        df = pd.read_csv(wvl_file,delimiter=" ")
        
        ####
        input_ds = gdal.Open(modis_paras_file)
        out = f"{modis_folder}/{date[0:4]}_{date[4:6]}_{date[6:8]}_resample.tif"
        proj, x_res, y_res, x_size, y_size, bounds = get_corner(imagery_file)
        gdal.Warp(out, input_ds, xRes=x_res, yRes=abs(y_res),dstSRS=proj, outputBounds=bounds, 
                  width=x_size, height=y_size, resampleAlg=gdal.GRA_NearestNeighbour)
        input_ds = None
        out = None
        
        ####
        modis_paras_file = f"{modis_folder}/{date[0:4]}_{date[4:6]}_{date[6:8]}_resample.tif"
        
        im_data, im_Geotrans, im_proj,im_rows, im_cols = read_tif(imagery_file)
        im_data = np.nan_to_num(im_data, nan=0)
        angle_data, angle_Geotrans, angle_proj,angle_rows, angle_cols = read_tif(angle_file)
        paras_data,paras_Geotrans, paras_proj,paras_rows, paras_cols = read_tif(modis_paras_file)
        paras_data = paras_data*0.001
        print(im_data.shape, angle_data.shape, paras_data.shape)
        
        view_zenith_angle = angle_data[:,:,0].flatten()
        relative_azimuth_angle = angle_data[:,:,1].flatten()
        solar_zenith_angle = angle_data[:,:,2].flatten()
        angles = (solar_zenith_angle, view_zenith_angle, relative_azimuth_angle)
        
        paras_slice = [paras_data[:, :, i*3:(i+1)*3] for i in range(7)] # 21 layers, every 3 layers are f_iso, f_vol, f_geo for each MODIS band (band 1 -- band7). 
        start_var = True
        for idx, geo_paras in enumerate(paras_slice):
            f_iso = geo_paras[:,:,0].flatten()
            f_vol = geo_paras[:,:,1].flatten()
            f_geo = geo_paras[:,:,2].flatten()
            
            sza_ref = np.full(f_iso.shape, 45)
            vza_ref = np.full(f_iso.shape, 0)
            raz_ref = np.full(f_iso.shape, 0)
            reference_angles = (sza_ref, vza_ref, raz_ref)
            
            c_factor = brdf_model(reference_angles, f_iso, f_vol, f_geo)/brdf_model(angles, f_iso, f_vol, f_geo)
            c_factor = c_factor.reshape(geo_paras.shape[0],geo_paras.shape[1], 1)
            if start_var:
                c_factors = c_factor
                start_var = False
            else:
                c_factors = np.concatenate((c_factors, c_factor), axis=-1) # band 1: 620-670nm; band 2: 841-876nm; band 3: 459-479nm; band 4: 545-565nm; band 5: 1230-1250nm; band 6: 1628-1652nm; band 7: 2105-2155nm
        
        c_match = {0:(620, 670), 1: (841, 876), 2:(459, 479), 3:(545,565), 4:(1230,1250), 5:(1628, 1652), 6:(2105,2155)} ## Key represents the index in c_factors, value represents the covered wavelenth.
        start_var = True
        for i, wl in enumerate(df["wl"].tolist()):
            idx = find_closest_band(c_match, wl)
            c_array = c_factors[:,:,idx]
            image_array = im_data[:,:,i]
            nbar_refl = c_array*image_array
            nbar_refl = nbar_refl[:,:, np.newaxis]
            if start_var:
                corrected_nbar = nbar_refl
                start_var = False
            else:
                corrected_nbar = np.concatenate((corrected_nbar, nbar_refl), axis=-1)
        corrected_nbar[corrected_nbar <= 0] = np.nan
        band_names = [f"{round(x,2)} nm" for x in df["wl"]]
        out_tif = f"{output_path}/{file[:-4]}_NBAR.tif"
        array_to_geotiff(corrected_nbar, out_tif, im_Geotrans, im_proj, band_names=band_names)
        im_data = None
        paras_data = None
        angle_data = None

2024-10-21 23:21:45, D01_BART, 1/28: PRS_L2D_STD_20210525153514_20210525153518_0001_HCO_FULL.tif
