In [6]:
import numpy as np
import pandas as pd
from collections import defaultdict
import shutil
import os
import math
import logging
from tqdm import tqdm
import seaborn as sns
import psutil
import joblib
from joblib import Parallel, delayed
import xarray as xr
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import matplotlib.ticker as ticker
from scipy.stats import gaussian_kde
from scipy.signal import savgol_filter
from matplotlib import gridspec
from sklearn.metrics import mean_squared_error
from scipy import stats
import warnings
from datetime import datetime
import cartopy.crs as ccrs
import joypy
from osgeo import gdal
import matplotlib as mpl
from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter
warnings.filterwarnings('ignore')

In [7]:
def read_tif(tif_file):
    """Read a raster file into a NumPy array together with its georeferencing.

    Args:
        tif_file: Path of the raster to open with GDAL.

    Returns:
        A tuple ``(im_data, im_Geotrans, im_proj, rows, cols)`` where
        ``im_data`` is the pixel array (a 3-D read is reordered so the band
        axis comes last, a 2-D read is returned unchanged), ``im_Geotrans`` is
        the GDAL geotransform, ``im_proj`` the projection string, and
        ``rows`` / ``cols`` the raster size.

    Raises:
        RuntimeError: If GDAL cannot open ``tif_file``.
    """
    dataset = gdal.Open(tif_file)
    if dataset is None:
        # Without gdal.UseExceptions() a failed open returns None; fail here
        # with a clear message instead of an AttributeError further down.
        raise RuntimeError(f"GDAL could not open raster: {tif_file}")

    cols = dataset.RasterXSize
    rows = dataset.RasterYSize
    im_proj = dataset.GetProjection()
    im_Geotrans = dataset.GetGeoTransform()

    im_data = dataset.ReadAsArray(0, 0, cols, rows)
    if im_data.ndim == 3:
        # Move the first (band) axis to the end, reusing the array that was
        # already read rather than reading the whole raster a second time.
        im_data = np.moveaxis(im_data, 0, -1)

    dataset = None  # release the GDAL handle
    return im_data, im_Geotrans, im_proj, rows, cols
    
def array_to_geotiff(array, output_path, geo_transform, projection, band_names=None):
    """Write an array to a Float32 GeoTIFF.

    Args:
        array: Pixel data indexed as (rows, cols, bands). A 2-D (rows, cols)
            array is accepted as well and written as a single band.
        output_path: Path of the GeoTIFF to create.
        geo_transform: GDAL geotransform to store in the file.
        projection: Projection string to store in the file.
        band_names: Optional sequence of band descriptions, one per band.

    Raises:
        RuntimeError: If GDAL cannot create ``output_path``.
    """
    if array.ndim == 2:
        # Promote a single-band image so the band loop below works unchanged.
        array = array[:, :, np.newaxis]
    rows, cols, num_bands = array.shape

    driver = gdal.GetDriverByName('GTiff')
    dataset = driver.Create(output_path, cols, rows, num_bands, gdal.GDT_Float32)
    if dataset is None:
        raise RuntimeError(f"GDAL could not create raster: {output_path}")

    dataset.SetGeoTransform(geo_transform)
    dataset.SetProjection(projection)

    for band_num in range(num_bands):
        # GDAL band numbers are 1-based.
        band = dataset.GetRasterBand(band_num + 1)
        band.WriteArray(array[:, :, band_num])
        if band_names:
            band.SetDescription(band_names[band_num])
        band.FlushCache()

    # Drop the band handle before the dataset that owns it, then release the
    # dataset so the file is flushed to disk.
    band = None
    dataset = None
    return

def get_corner(image_file):
    """Return the projection, pixel size, raster size and bounds of a raster.

    The bounds are computed from the geotransform origin and pixel sizes only;
    the rotation terms of the geotransform are not used.

    Args:
        image_file: Path of the raster to inspect.

    Returns:
        ``(im_proj, x_res, y_res, x_size, y_size, (x_min, y_min, x_max, y_max))``.
        ``y_res`` is returned exactly as stored in the geotransform
        (presumably negative for north-up rasters - confirm for your data).

    Raises:
        RuntimeError: If GDAL cannot open ``image_file``.
    """
    dataset = gdal.Open(image_file)
    if dataset is None:
        raise RuntimeError(f"GDAL could not open raster: {image_file}")

    geo_transform = dataset.GetGeoTransform()
    x_size = dataset.RasterXSize
    y_size = dataset.RasterYSize
    im_proj = dataset.GetProjection()
    dataset = None  # release the handle explicitly, as read_tif does

    x_res = geo_transform[1]
    y_res = geo_transform[5]
    x_min = geo_transform[0]
    y_max = geo_transform[3]
    x_max = x_min + x_res * x_size
    y_min = y_max + y_res * y_size

    return im_proj, x_res, y_res, x_size, y_size, (x_min, y_min, x_max, y_max)

class BandInfo:
    """Plain container for spectral band metadata.

    Every attribute starts out as ``None`` and is expected to be filled in by
    the caller. ``BandResampler`` reads ``centers`` and ``bandwidths`` (the
    latter is used as the FWHM of each band).
    """

    def __init__(self):
        # Band positions and widths.
        self.centers = self.bandwidths = None
        # Standard deviations of the two quantities above.
        self.centers_stdevs = self.bandwidth_stdevs = None
        # Name and unit of the quantity the bands are expressed in.
        self.band_quantity = self.band_unit = None

def erf_local(x):
    '''Approximate the error function of a scalar `x`.

    Evaluates a five-coefficient polynomial in t = 1 / (1 + p*|x|) scaled by
    exp(-x^2). The coefficients look like those of Abramowitz & Stegun
    formula 7.1.26 - confirm against the reference if the accuracy matters.
    '''
    # Coefficients listed highest order first for Horner evaluation.
    coefficients = (1.061405429, -1.453152027, 1.421413741,
                    -0.284496736, 0.254829592)
    p = 0.3275911

    negative = not (x >= 0)
    magnitude = abs(x)

    t = 1.0/(1.0 + p*magnitude)
    poly = 0.0
    for coefficient in coefficients:
        poly = poly*t + coefficient

    approx = 1.0 - poly*t*math.exp(-magnitude*magnitude)
    # erf is odd: erf(-x) = -erf(x)
    return -1*approx if negative else 1*approx

# Pick an erf implementation: the standard library first, then SciPy, and the
# local polynomial approximation only if neither import is available.
# Only ImportError is caught so that unrelated failures (e.g. a
# KeyboardInterrupt during import) are not silently swallowed.
try:
    from math import erf
except ImportError:
    try:
        from scipy.special import erf
    except ImportError:
        erf = erf_local

def erfc(z):
    '''Return 1 - erf(z), the complementary error function of `z`.'''
    erf_value = erf(z)
    return 1.0 - erf_value

def normal_cdf(x):
    '''Return the cumulative distribution function of the standard normal
    distribution at `x`, computed as erfc(-x / sqrt(2)) / 2.
    '''
    root_two = 1.4142135623730951
    scaled = -x / root_two
    return 0.5 * erfc(scaled)

def normal_integral(a, b):
    '''Return the area under the standard normal density between `a` and `b`
    (the difference of the CDF at the two limits).
    '''
    upper = normal_cdf(b)
    lower = normal_cdf(a)
    return upper - lower

def ranges_overlap(R1, R2):
    '''Tell whether the (low, high) pairs R1 and R2 overlap.

    The ranges are reported as non-overlapping only when both ends of R1 lie
    strictly below the start of R2 or strictly above the end of R2; ranges
    that merely touch count as overlapping.
    '''
    start1, end1 = R1[0], R1[1]
    start2, end2 = R2[0], R2[1]
    if start1 < start2 and end1 < start2:
        # R1 lies entirely below R2.
        return False
    if start1 > end2 and end1 > end2:
        # R1 lies entirely above R2.
        return False
    return True

def overlap(R1, R2):
    '''Return the (min, max) pair bounding the intersection of the ranges
    R1 and R2: the larger of the two starts and the smaller of the two ends.
    '''
    lower = max(R1[0], R2[0])
    upper = min(R1[1], R2[1])
    return (lower, upper)

def normal(mean, stdev, x):
    '''Return the density at `x` of the normal distribution with the given
    `mean` and standard deviation `stdev`.
    '''
    root_two_pi = 2.5066282746310002
    z = (x - mean) / stdev
    exponent = -(z**2) / 2.0
    return math.exp(exponent) / (root_two_pi * stdev)

def build_fwhm(centers):
    '''Derive a FWHM value for each band from the spacing of band centers.

    The first and last bands use the distance to their only neighbour; every
    interior band uses half the distance between its two neighbours. At least
    two centers are required.
    '''
    count = len(centers)
    first = centers[1] - centers[0]
    last = centers[-1] - centers[-2]
    interior = [(centers[k + 1] - centers[k - 1]) / 2.0
                for k in range(1, count - 1)]
    return [first] + interior + [last]

def create_resampling_matrix(centers1, fwhm1, centers2, fwhm2):
    '''Build the matrix that resamples spectra from one band set to another.

    Each band is treated as the interval [center - fwhm/2, center + fwhm/2].
    Each target band is modelled as a Gaussian whose standard deviation is
    fwhm / sqrt(8 ln 2); the weight given to a source band is the integral of
    that Gaussian over the part of the source interval that falls inside the
    target interval. The weights of each target band are normalized to sum
    to one.

    Arguments:
        centers1: centers of the source bands.
        fwhm1: full width at half maximum of each source band.
        centers2: centers of the target bands.
        fwhm2: full width at half maximum of each target band.

    Returns:
        An array of shape (len(centers2), len(centers1)). A target band that
        no source band overlaps gets NaN in column 0 of its row, so resampled
        values for that band come out as NaN.

    The scan for overlapping source bands stops at the first source band that
    starts at or beyond the end of the target band, so the source bands are
    presumably expected in increasing order - confirm against callers.
    '''
    logger = logging.getLogger('spectral')

    sqrt_8log2 = 2.3548200450309493

    N1 = len(centers1)
    N2 = len(centers2)
    bounds1 = [[centers1[i] - fwhm1[i] / 2.0, centers1[i] + fwhm1[i] /
                2.0] for i in range(N1)]
    bounds2 = [[centers2[i] - fwhm2[i] / 2.0, centers2[i] + fwhm2[i] /
                2.0] for i in range(N2)]

    M = np.zeros([N2, N1])

    nan = float('nan')
    for i in range(N2):
        stdev = fwhm2[i] / sqrt_8log2
        j = 0

        # Find the first original band that overlaps the new band
        while j < N1 and bounds1[j][1] < bounds2[i][0]:
            j += 1

        if j == N1:
            logger.info('No overlap for target band %d (%f / %f)',
                        i, centers2[i], fwhm2[i])
            M[i, 0] = nan
            continue

        # Get indices for all original bands that overlap the new band
        matches = []
        while j < N1 and bounds1[j][0] < bounds2[i][1]:
            if ranges_overlap(bounds1[j], bounds2[i]):
                matches.append(j)
            j += 1

        # Put NaN in first element of any row that doesn't produce a band in
        # the new schema.
        if len(matches) == 0:
            logger.info('No overlap for target band %d (%f / %f)',
                        i, centers2[i], fwhm2[i])
            M[i, 0] = nan
            continue

        # Determine the weights for the original bands that overlap the new
        # band. There may be multiple bands that overlap or even just a single
        # band that only partially overlaps.  Weights are normalized so either
        # case can be handled.
        overlaps = [overlap(bounds1[k], bounds2[i]) for k in matches]
        contribs = np.zeros(len(matches))
        A = 0.
        for k in range(len(matches)):
            # Express the overlap limits in standard deviations from the
            # target band center, then integrate the standard normal.
            (a, b) = [(x - centers2[i]) / stdev for x in overlaps[k]]
            dA = normal_integral(a, b)
            contribs[k] = dA
            A += dA
        M[i, matches] = contribs / A
    return M

class BandResampler:
    '''Callable that resamples spectra from one set of bands to another.

    The resampling matrix is built once in the constructor; calling the
    instance multiplies that matrix with the given spectrum.
    '''

    def __init__(self, centers1, centers2, fwhm1=None, fwhm2=None):
        '''Build the resampling matrix.

        Arguments:
            centers1: source band centers, or a BandInfo whose `centers` and
                `bandwidths` are used as the source centers and FWHM.
            centers2: target band centers, or a BandInfo used the same way.
            fwhm1: FWHM of the source bands; derived from the band spacing
                when omitted.
            fwhm2: FWHM of the target bands; derived from the band spacing
                when omitted.
        '''
        source_centers, source_fwhm = centers1, fwhm1
        if isinstance(centers1, BandInfo):
            source_centers, source_fwhm = centers1.centers, centers1.bandwidths

        target_centers, target_fwhm = centers2, fwhm2
        if isinstance(centers2, BandInfo):
            target_centers, target_fwhm = centers2.centers, centers2.bandwidths

        # Fall back to widths estimated from the spacing of the centers.
        if source_fwhm is None:
            source_fwhm = build_fwhm(source_centers)
        if target_fwhm is None:
            target_fwhm = build_fwhm(target_centers)

        self.matrix = create_resampling_matrix(
            source_centers, source_fwhm, target_centers, target_fwhm)

    def __call__(self, spectrum):
        '''Return `spectrum` resampled to the target bands.'''
        resampled = np.dot(self.matrix, spectrum)
        return resampled

def do_resample(spectra, source_wvl, source_fwhm):
    """Resample `spectra` from the source bands onto a fixed target grid.

    The target grid is ``np.arange(580, 930, 1)`` (350 centers with unit
    spacing); its FWHM values are derived from that spacing. The grid is in
    the same unit as `source_wvl` (presumably nanometres - confirm).

    Args:
        spectra: Values at the source bands.
        source_wvl: Centers of the source bands.
        source_fwhm: FWHM of the source bands.

    Returns:
        The spectrum resampled to the 350 target bands.
    """
    target_centers = np.arange(580,930,1)
    target_fwhm = build_fwhm(target_centers)
    band_resampler = BandResampler(source_wvl, target_centers,
                                   source_fwhm, target_fwhm)
    return band_resampler(spectra)

def compute_weighted_reflectance(response_curve_wl, hyper_wl, hyper_data, response_curve, hyper_fwhm):
    """Return the response-weighted mean of a spectrum.

    The spectrum is first resampled with ``do_resample`` and then averaged
    with ``response_curve`` as the weight, using trapezoidal integration over
    ``response_curve_wl``.

    Args:
        response_curve_wl: Sample positions of ``response_curve``; must have
            the same length as the output of ``do_resample``.
        hyper_wl: Band centers of ``hyper_data``.
        hyper_data: Spectrum to be weighted.
        response_curve: Response weights sampled at ``response_curve_wl``.
        hyper_fwhm: FWHM of the bands of ``hyper_data``.

    Returns:
        The integral of (resampled spectrum * response) divided by the
        integral of the response.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use
    # the new name when it exists and fall back for older NumPy versions.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    interp_reflectance = do_resample(hyper_data, hyper_wl, hyper_fwhm)
    numerator = integrate(interp_reflectance * response_curve, response_curve_wl)
    denominator = integrate(response_curve, response_curve_wl)
    weighted_reflectance = numerator/denominator
    return weighted_reflectance

def run_parallel(image_chunk, response_curve_wl, response_curve_data, wvl, fwhm):
    """Convolve every pixel spectrum of an image chunk with two response curves.

    Rows 7 and 11 of ``response_curve_data`` are used, each restricted to
    samples 200:550 so that it lines up with ``response_curve_wl``.

    Args:
        image_chunk: Array indexed as (row, col, band) holding the spectra.
        response_curve_wl: Sample positions of the sliced response curves.
        response_curve_data: 2-D array of response curves, one per row.
        wvl: Band centers of the spectra in ``image_chunk``.
        fwhm: FWHM of the bands of the spectra in ``image_chunk``.

    Returns:
        Array of shape (rows, cols, 2) with the response-weighted value of
        each pixel for the two response curves, in the order (7, 11).
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    # The resampling matrix depends only on the band definitions, so build it
    # once per chunk instead of once per pixel and per response curve. The
    # target grid must stay identical to the one used by do_resample().
    target_wvl = np.arange(580, 930, 1)
    resampler = BandResampler(wvl, target_wvl, fwhm, build_fwhm(target_wvl))

    # The response curves and their integrals are pixel-independent as well.
    response_curves = [response_curve_data[kk, :][200:550] for kk in (7, 11)]
    curve_areas = [integrate(curve, response_curve_wl) for curve in response_curves]

    n_rows, n_cols = image_chunk.shape[0], image_chunk.shape[1]
    results = np.zeros(shape=(n_rows, n_cols, 2))
    for i in range(n_rows):
        for j in range(n_cols):
            # Missing values are treated as zero before resampling.
            spectra = np.nan_to_num(image_chunk[i, j, :], nan=0)
            resampled = resampler(spectra)
            for band, (curve, area) in enumerate(zip(response_curves, curve_areas)):
                results[i, j, band] = integrate(resampled * curve, response_curve_wl) / area
    return results

### 1. Convolve PRISMA imagery with the MODIS spectral response curves

In [None]:
# Convolve PRISMA scenes with two Terra MODIS spectral response curves and
# write the result as a two-band GeoTIFF per scene.

# NetCDF file holding the sensor's relative spectral response (RSR) curves.
RSR = "/Volumes/ChenLab/Fujiang/0_PhD_dissertation_data/12_RTM_estimation_through_given_LAI/1_LAI_estimation/2_MODIS_convolved_PRISMA_imagery/terra_modis_RSR.nc"
spectral_response = xr.open_dataset(RSR)
# exported_bands = spectral_response["bands"].data[1:13]
# Only entries 7 and 11 are exported; run_parallel uses the same two row
# indices into the RSR array. (The VI cell below reads output band 0 as red
# and band 1 as NIR - presumably that is what these two entries are; confirm.)
exported_bands = spectral_response["bands"].data[np.array([7,11])]

# Samples 200:550 of the RSR wavelength axis; run_parallel slices every
# response curve with the same range so the two stay aligned.
response_curve_wl = spectral_response["wavelength"].data[200:550]
response_curve_data = spectral_response["RSR"].data

############################################################################
# Input imagery, per-scene wavelength tables, and output location.
data_path = "/Volumes/ChenLab/Fujiang/0_PhD_dissertation_data/4_Extract_training_data/9_PRISMA_imagery_smoothed_tif/"
wvl_path = "/Volumes/ChenLab/Fujiang/0_PhD_dissertation_data/2_PRISMA_L2D/2_PRISMA_L2D_tif_2020_2023/"
out_path = "/Volumes/ChenLab/Fujiang/0_PhD_dissertation_data/12_RTM_estimation_through_given_LAI/1_LAI_estimation/2_MODIS_convolved_PRISMA_imagery/"

folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL','D08_TALL',
           'D10_CPER','D13_MOAB','D14_JORN','D14_SRER','D16_WREF','D19_BONA','D19_HEAL']

# NOTE: only the first folder is processed here.
for folder in folders[0:1]:
    os.makedirs(f"{out_path}/{folder}", exist_ok=True)
    imagery_path = f"{data_path}{folder}"
    wvl_folder = f"{wvl_path}{folder}"
    output_path = f"{out_path}{folder}"
    
    # Keep the "_FULL.tif" scenes; skip "._" entries and ".aux.xml" sidecars.
    file_name = os.listdir(imagery_path)
    file_name = [x for x in file_name if "_FULL.tif" in x and "._" not in x and ".aux.xml" not in x]

    idx = 1
    # NOTE: only the first scene of the folder is processed here.
    for file in file_name[0:1]:
        print(f"{folder} -- {idx}/{len(file_name)} -- {file}")
        
        # The wavelength table shares the scene's base name (text before the
        # first ".") and has space-delimited "wl" and "fwhm" columns.
        wvl_file = f"{wvl_folder}/{file.split('.')[0]}.wvl"
        df = pd.read_csv(wvl_file,delimiter=" ")
        im_data, im_Geotrans, im_proj,rows, cols = read_tif(f"{imagery_path}/{file}")
        
        # Restrict the table and the image to the same 40 bands (20..59).
        # Presumably these cover the 580-930 target grid of do_resample - confirm.
        wvl = df["wl"].values[20:60]
        fwhm = df["fwhm"].values[20:60]
        im_data = im_data[:,:,20:60]
        
        # Split the image into tiles of at most chunk_size x chunk_size
        # pixels, row-major, so the tiles can be processed in parallel.
        chunk_size = 50
        image_chunks = []
        for i in range(0, im_data.shape[0], chunk_size):
                for j in range(0, im_data.shape[1], chunk_size):
                    chunk = im_data[i:i + chunk_size, j:j + chunk_size]
                    image_chunks.append(chunk)
                    
        # One worker per physical core.
        num_processes = psutil.cpu_count(logical=False)
        chunk_results = Parallel(n_jobs=num_processes)(delayed(run_parallel)(image_chunk, response_curve_wl, response_curve_data, wvl, fwhm) for image_chunk in tqdm(image_chunks,desc="Processing Chunks"))

        # Reassemble the tiles in the same row-major order they were cut in.
        convolved_imagery = np.zeros(shape = (im_data.shape[0], im_data.shape[1], 2))
        chunk_index = 0
        for i in range(0, convolved_imagery.shape[0], chunk_size):
            for j in range(0, convolved_imagery.shape[1], chunk_size):
                convolved_imagery[i:i + chunk_size, j:j + chunk_size,:] = chunk_results[chunk_index]
                chunk_index = chunk_index+1
                
        # file[:-4] drops the ".tif" extension.
        out_tif = f"{output_path}/{file[:-4]}_convolved.tif"
        band_names = [f"{x} nm" for x in exported_bands]
        array_to_geotiff(convolved_imagery, out_tif, im_Geotrans, im_proj, band_names=band_names)
        idx = idx +1

### 2. Calculate PRISMA vegetation indices (NDVI and NIRv)

In [8]:
# Compute NDVI and NIRv from the two-band convolved rasters and write them
# as a two-band GeoTIFF per scene.
data_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/12_RTM_estimation_through_given_LAI/1_LAI_estimation/2_MODIS_convolved_PRISMA_imagery/NBAR_refl/"
out_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/12_RTM_estimation_through_given_LAI/1_LAI_estimation/3_NDVI_NIRv/NBAR_refl/"

folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL','D08_TALL',
           'D10_CPER','D13_MOAB','D14_JORN','D14_SRER','D16_WREF','D19_BONA','D19_HEAL']

for folder in folders:
    os.makedirs(f"{out_path}/{folder}", exist_ok=True)
    imagery_path = f"{data_path}{folder}"
    output_path = f"{out_path}{folder}"
    
    # Keep the convolved NBAR scenes; skip "._" entries and ".aux.xml" sidecars.
    file_name = os.listdir(imagery_path)
    file_name = [x for x in file_name if "_FULL_NBAR_convolved.tif" in x and "._" not in x and ".aux.xml" not in x]
    
    for file in file_name:
        im_data, im_Geotrans, im_proj,rows, cols = read_tif(f"{imagery_path}/{file}")
        # Band 0 is treated as red and band 1 as NIR.
        red = im_data[:,:,0]
        nir = im_data[:,:,1]
        # Pixels where nir + red == 0 yield NaN/inf here; warnings are
        # suppressed globally, so they pass through silently.
        ndvi  = (nir - red)/ (nir + red)
        nirv = ndvi*nir
        data_array = np.stack([ndvi, nirv], axis = -1)
        # file[:-4] drops the ".tif" extension.
        out_tif = f"{output_path}/{file[:-4]}_VI.tif"
        array_to_geotiff(data_array, out_tif, im_Geotrans, im_proj, band_names=["NDVI", "NIRV"])

### 3. Clip the LULC data to each VI raster

In [9]:
# Warp the yearly land-cover raster of each site onto the grid (projection,
# bounds, size) of every VI raster, using nearest-neighbour resampling.
data_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/12_RTM_estimation_through_given_LAI/1_LAI_estimation/3_NDVI_NIRv/NBAR_refl/"
out_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/12_RTM_estimation_through_given_LAI/1_LAI_estimation/4_clipped_lulc/NBAR_refl/"
lulc_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/3_GLC_FCS30D_Land_cover_data_US/4_land_cover_NEON_sites/"

folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL','D08_TALL',
           'D10_CPER','D13_MOAB','D14_JORN','D14_SRER','D16_WREF','D19_BONA','D19_HEAL']

for folder in folders:
    os.makedirs(f"{out_path}/{folder}", exist_ok=True)
    imagery_path = f"{data_path}{folder}"
    output_path = f"{out_path}{folder}"
    lulc_folder = f"{lulc_path}{folder}"
    
    # Keep the VI rasters; skip "._" entries and ".aux.xml" sidecars.
    file_name = os.listdir(imagery_path)
    file_name = [x for x in file_name if "_FULL_NBAR_convolved_VI.tif" in x and "._" not in x and ".aux.xml" not in x]
    
    for file in file_name:
        vi_tif = f"{imagery_path}/{file}"
        
        # The year is the first four characters of the fourth "_"-separated
        # field of the file name.
        year = file.split("_")[3][0:4]
        # 2023 scenes use the 2022 land cover (presumably because no 2023
        # land-cover file exists - confirm).
        if year == "2023":
            year ="2022"
        lulc_file = f"{lulc_folder}/{folder}_{year}_land_cover.tif"
        out_lulc = f"{output_path}/{file[:-4]}_lulc.tif"

        input_ds = gdal.Open(lulc_file)
        proj, x_res, y_res, x_size, y_size, bounds = get_corner(vi_tif)
        # NOTE(review): both a target resolution (xRes/yRes) and a target size
        # (width/height) are passed; the gdalwarp documentation describes the
        # corresponding -tr and -ts options as mutually exclusive - confirm
        # this call behaves as intended with the installed GDAL version.
        gdal.Warp(out_lulc, input_ds, xRes=x_res, yRes=abs(y_res),dstSRS=proj, outputBounds=bounds, 
                  width=x_size, height=y_size, resampleAlg=gdal.GRA_NearestNeighbour)
        input_ds = None
        # NOTE(review): out_lulc is the output path string, so this only
        # rebinds the name; the dataset returned by gdal.Warp is never stored.
        out_lulc = None

### 4. Merge the VI rasters with the clipped LULC data

In [10]:
# Append the clipped land-cover raster to each VI raster as a third band.
data_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/12_RTM_estimation_through_given_LAI/1_LAI_estimation/3_NDVI_NIRv/NBAR_refl/"
lulc_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/12_RTM_estimation_through_given_LAI/1_LAI_estimation/4_clipped_lulc/NBAR_refl/"
out_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/12_RTM_estimation_through_given_LAI/1_LAI_estimation/5_merged_with_lulc/NBAR_refl/"

folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL','D08_TALL',
           'D10_CPER','D13_MOAB','D14_JORN','D14_SRER','D16_WREF','D19_BONA','D19_HEAL']

for folder in folders:
    os.makedirs(f"{out_path}/{folder}", exist_ok=True)
    imagery_path = f"{data_path}{folder}"
    output_path = f"{out_path}{folder}"
    lulc_folder = f"{lulc_path}{folder}"
    
    # Keep the VI rasters; skip "._" entries and ".aux.xml" sidecars.
    file_name = os.listdir(imagery_path)
    file_name = [x for x in file_name if "_FULL_NBAR_convolved_VI.tif" in x and "._" not in x and ".aux.xml" not in x]
    
    for file in file_name:
        vi_tif = f"{imagery_path}/{file}"
        # Name written by the clipping step: VI base name + "_lulc.tif".
        lulc_file = f"{lulc_folder}/{file[:-4]}_lulc.tif"
        
        im_data, im_Geotrans, im_proj,rows, cols = read_tif(vi_tif)
        # NOTE: this second read overwrites im_proj, rows and cols with the
        # land-cover raster's values; the output below therefore carries the
        # VI geotransform but the land-cover projection string.
        lc_data, lc_Geotrans, im_proj,rows, cols = read_tif(lulc_file)
        # The land-cover read is 2-D here; add a band axis so it can be
        # concatenated with the VI bands.
        lc_data = lc_data[:,:,np.newaxis]

        # Both rasters must have the same rows and cols for this to succeed.
        im_data = np.concatenate((im_data, lc_data), axis=2)
        out_tif = f"{output_path}/{file}"
        array_to_geotiff(im_data, out_tif, im_Geotrans, im_proj, band_names=["NDVI", "NIRV","LULC"])

### 5. Convert LULC classes to PFT codes

In [11]:
def transfer_lulc_pft(array):
    """Map land-cover class codes to plant-functional-type (PFT) codes.

    The input is cast to ``int`` before matching. Codes that belong to none
    of the groups below map to 0. The result is a float array with the same
    shape as the input.
    """
    # Land-cover legend, for reference:
    # land_cover_type = {10: "Rainfed cropland",11: "Herbaceous cover cropland",12: "Tree or shrub cover (Orchard) cropland",
    #                    20: "Irrigated cropland",51: "Open evergreen broadleaved forest",52: "Closed evergreen broadleaved forest",
    #                    61: "Open deciduous broadleaved forest",62: "Closed deciduous broadleaved forest",71: "Open evergreen needle-leaved forest",
    #                    72: "Closed evergreen needle-leaved forest",81: "Open deciduous needle-leaved forest",82: "Closed deciduous needle-leaved forest",
    #                    91: "Open mixed leaf forest (broadleaved and needle-leaved)",92: "Closed mixed leaf forest (broadleaved and needle-leaved)", 
    #                    120: "Shrubland",121: "Evergreen shrubland",122: "Deciduous shrubland",130: "Grassland",140: "Lichens and mosses",
    #                    150: "Sparse vegetation",152: "Sparse shrubland",153: "Sparse herbaceous",181: "Swamp",182: "Marsh",183: "Flooded flat",
    #                    184: "Saline",185: "Mangrove",186: "Salt marsh",187: "Tidal flat",190: "Impervious surfaces",200: "Bare areas",
    #                    201: "Consolidated bare areas",202: "Unconsolidated bare areas",210: "Water body",220: "Permanent ice and snow",
    #                    0: "Filled value",250: "Filled value"}

    # (land-cover codes, PFT code) pairs; the groups are disjoint.
    lulc_groups = (
        ([10, 11, 12, 20], 100),   # CPR
        ([51, 52], 200),           # EBF
        ([61, 62], 300),           # DBF
        ([71, 72], 400),           # ENF
        ([81, 82], 500),           # DNF
        ([91, 92], 600),           # MF
        ([120, 121, 122], 700),    # SHR
        ([130], 800),              # GRA
    )

    codes = array.astype(int)
    pft = np.zeros(codes.shape)
    for lulc_codes, pft_code in lulc_groups:
        pft[np.isin(codes, lulc_codes)] = pft_code
    return pft

In [12]:
# Replace the land-cover band (last band) of every merged raster with the
# corresponding PFT codes and write the result under the same file name.
data_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/12_RTM_estimation_through_given_LAI/1_LAI_estimation/5_merged_with_lulc/NBAR_refl/"
out_path = "/Volumes/ChenLab/Fujiang/0_Seasonal_PRISMA_traits/12_RTM_estimation_through_given_LAI/1_LAI_estimation/6_convert_lulc_pft/NBAR_refl/"

folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL','D08_TALL',
           'D10_CPER','D13_MOAB','D14_JORN','D14_SRER','D16_WREF','D19_BONA','D19_HEAL']

for folder in folders:
    os.makedirs(f"{out_path}/{folder}", exist_ok=True)
    imagery_path = f"{data_path}{folder}"
    output_path = f"{out_path}{folder}"
    
    # Keep the merged VI rasters; skip "._" entries and ".aux.xml" sidecars.
    file_name = os.listdir(imagery_path)
    file_name = [x for x in file_name if "_FULL_NBAR_convolved_VI.tif" in x and "._" not in x and ".aux.xml" not in x]
    for file in file_name:
        im_data, im_Geotrans, im_proj,rows, cols = read_tif(f"{imagery_path}/{file}")
        # Convert the last band (land cover) to PFT codes.
        pft = transfer_lulc_pft(im_data[:,:,-1])
        pft = pft[:, :, np.newaxis]
        # All bands except the last are kept unchanged.
        data_array = im_data[:,:,:-1]        
        im_data = np.concatenate((data_array, pft), axis=2)
        
        out_tif = f"{output_path}/{file}"
        array_to_geotiff(im_data, out_tif, im_Geotrans, im_proj, band_names=["NDVI", "NIRV","PFT"])

### 6. Calculate PRISMA LAI from the vegetation indices

In [353]:
def exponential_model(x, a, b):
    """Evaluate the exponential model a * exp(b * x)."""
    growth = np.exp(b * x)
    return a * growth
def linear_model(x, a, b):
    """Evaluate the linear model a * x + b (slope `a`, intercept `b`)."""
    scaled = a * x
    return scaled + b
def estimate_LAI_parallel(image_chunk, model, site):
    """Estimate LAI for every pixel of an image chunk.

    For each pixel the fitted model with the highest R2 is selected from the
    rows of ``model`` that belong to ``site`` and to the pixel's PFT; if the
    site has no model for that PFT, the rows with PFT "all" are used instead.
    A selected row whose "x" column is "NDVI" is applied as an exponential
    model of NDVI; any other row is applied as a linear model of NIRv.

    Args:
        image_chunk: Array indexed as (row, col, band) whose first three
            bands are NDVI, NIRv and the PFT code.
        model: DataFrame with the columns "site", "PFT", "R2", "x", "a", "b".
        site: Site name to match against the "site" column.

    Returns:
        2-D array (rows, cols) of estimated LAI.

    Raises:
        KeyError: If a pixel carries a PFT code that is not in the code table.
    """
    pft_map = {100:"CPR",200:"EBF", 300:"DBF", 400:"ENF", 500:"DNF", 600:"MF", 700:"SHR", 800:"GRA", 0: "all"}

    # The model table is identical for every pixel, so filter it once and
    # remember the best row per PFT instead of repeating the pandas look-ups
    # for each pixel.
    site_models = model[model["site"] == site]
    site_pfts = site_models["PFT"].unique()
    best_fit = {}  # PFT name -> (predictor name, a, b)

    results = np.zeros(shape = (image_chunk.shape[0], image_chunk.shape[1]))
    for i in range(image_chunk.shape[0]):
        for j in range(image_chunk.shape[1]):
            points = image_chunk[i, j, :]
            ndvi, nirv, pft = points[0], points[1], points[2]

            # Limit the indices to [-2, 2] before applying the models.
            ndvi = np.clip(ndvi, -2, 2)
            nirv = np.clip(nirv, -2, 2)

            pft_name = pft_map[pft]
            if pft_name not in best_fit:
                # Resolved lazily so that a PFT without usable models only
                # fails if a pixel of that PFT actually occurs.
                key = pft_name if pft_name in site_pfts else "all"
                candidates = site_models[site_models["PFT"] == key]
                best = candidates.loc[candidates['R2'].idxmax()]
                best_fit[pft_name] = (best["x"], best["a"], best["b"])
            predictor, a, b = best_fit[pft_name]

            if predictor == "NDVI":
                results[i, j] = exponential_model(ndvi, a, b)
            else:
                results[i, j] = linear_model(nirv, a, b)
    return results

In [336]:
# Estimate LAI for every NDVI / NIRv / PFT raster with the saved per-site
# models and write the result as a single-band GeoTIFF.
# NOTE(review): this cell reads from ".../0_PhD_dissertation_data/.../6_convert_lulc_pft/"
# and filters on "_FULL_convolved_VI.tif", whereas the PFT-conversion cell
# writes "_FULL_NBAR_convolved_VI.tif" files under
# ".../0_Seasonal_PRISMA_traits/.../6_convert_lulc_pft/NBAR_refl/" - confirm
# which inputs are intended.
data_path = "/Volumes/ChenLab/Fujiang/0_PhD_dissertation_data/12_RTM_estimation_through_given_LAI/1_LAI_estimation/6_convert_lulc_pft/"
out_path = "/Volumes/ChenLab/Fujiang/0_PhD_dissertation_data/12_RTM_estimation_through_given_LAI/1_LAI_estimation/7_estimated_LAI/"
model_path = "/Volumes/ChenLab/Fujiang/0_PhD_dissertation_data/12_RTM_estimation_through_given_LAI/1_LAI_estimation/1_original_data/8_extract_VI_LAI_to_points/"
# Table of fitted models; estimate_LAI_parallel reads its "site", "PFT",
# "R2", "x", "a" and "b" columns.
model = pd.read_csv(f"{model_path}2_saved_models.csv")

folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL','D08_TALL',
           'D10_CPER','D13_MOAB','D14_JORN','D14_SRER','D16_WREF','D19_BONA','D19_HEAL']

# NOTE: only the fourth folder ('D03_OSBS') is processed here.
for folder in folders[3:4]:
    os.makedirs(f"{out_path}/{folder}", exist_ok=True)
    imagery_path = f"{data_path}{folder}"
    output_path = f"{out_path}{folder}"
    
    # Keep the VI rasters; skip "._" entries and ".aux.xml" sidecars.
    file_name = os.listdir(imagery_path)
    file_name = [x for x in file_name if "_FULL_convolved_VI.tif" in x and "._" not in x and ".aux.xml" not in x]

    # The site name is the last four characters of the folder name.
    site = folder[-4:]
    for file in file_name:
        im_data, im_Geotrans, im_proj,rows, cols = read_tif(f"{imagery_path}/{file}")
        # Replace NaN with 0 in every band; for the PFT band this maps the
        # pixel to code 0, i.e. the "all" models.
        im_data = np.where(np.isnan(im_data), 0, im_data)
        
        # Split the image into tiles of at most chunk_size x chunk_size
        # pixels, row-major, so the tiles can be processed in parallel.
        chunk_size = 50
        image_chunks = []
        for i in range(0, im_data.shape[0], chunk_size):
                for j in range(0, im_data.shape[1], chunk_size):
                    chunk = im_data[i:i + chunk_size, j:j + chunk_size]
                    image_chunks.append(chunk)

        # One worker per physical core.
        num_processes = psutil.cpu_count(logical=False)
        chunk_results = Parallel(n_jobs=num_processes)(delayed(estimate_LAI_parallel)(image_chunk, model, site) for image_chunk in tqdm(image_chunks,desc="Processing Chunks"))
        
        # Reassemble the tiles in the same row-major order they were cut in.
        result_imagery = np.zeros(shape = (im_data.shape[0], im_data.shape[1]))
        chunk_index = 0
        for i in range(0, result_imagery.shape[0], chunk_size):
            for j in range(0, result_imagery.shape[1], chunk_size):
                result_imagery[i:i + chunk_size, j:j + chunk_size] = chunk_results[chunk_index]
                chunk_index = chunk_index+1
                
        # array_to_geotiff unpacks a three-element shape, so add a band axis.
        result_imagery = result_imagery[:, :, np.newaxis]
        # file[:-17] strips the 17-character suffix "_convolved_VI.tif".
        out_tif = f"{output_path}/{file[:-17]}_LAI.tif"
        array_to_geotiff(result_imagery, out_tif, im_Geotrans, im_proj, band_names=["LAI"])