In [1]:
import os
import geopandas as gpd
import numpy as np
import pandas as pd
import warnings
import psutil
import xarray as xr
import joblib
import logging
import torch.nn.init as init
import torch.nn as nn
import seaborn as sns
import datetime
from joblib import Parallel, delayed
from tqdm import tqdm
import torch.optim as optim
from scipy import stats
import geopandas as gpd
from sklearn.model_selection import ParameterGrid
import torch
import copy
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
warnings.filterwarnings("ignore")
from osgeo import gdal, ogr,gdalconst
from sklearn.linear_model import LinearRegression
from shapely.geometry import Point, Polygon,box

In [2]:
def read_tif(tif_file):
    dataset = gdal.Open(tif_file)
    cols = dataset.RasterXSize
    rows = dataset.RasterYSize
    im_proj = (dataset.GetProjection())
    im_Geotrans = (dataset.GetGeoTransform())
    im_data = dataset.ReadAsArray(0, 0, cols, rows)
    if im_data.ndim == 3:
        im_data = np.moveaxis(dataset.ReadAsArray(0, 0, cols, rows), 0, -1)
    return im_data, im_Geotrans, im_proj,rows, cols
def array_to_geotiff(array, output_path, geo_transform, projection, band_names=None):
    rows, cols, num_bands = array.shape
    driver = gdal.GetDriverByName('GTiff')
    dataset = driver.Create(output_path, cols, rows, num_bands, gdal.GDT_Float32)
    
    dataset.SetGeoTransform(geo_transform)
    dataset.SetProjection(projection)
    
    for band_num in range(num_bands):
        band = dataset.GetRasterBand(band_num + 1)
        band.WriteArray(array[:, :, band_num])
        band.FlushCache()
        
        if band_names:
            band.SetDescription(band_names[band_num])
    
    dataset = None
    band = None

class BandInfo:
    def __init__(self):
        self.centers = None
        self.bandwidths = None
        self.centers_stdevs = None
        self.bandwidth_stdevs = None
        self.band_quantity = None
        self.band_unit = None

def erf_local(x):
    sign = 1 if x >= 0 else -1
    x = abs(x)
    a1 =  0.254829592
    a2 = -0.284496736
    a3 =  1.421413741
    a4 = -1.453152027
    a5 =  1.061405429
    p  =  0.3275911

    t = 1.0/(1.0 + p*x)
    y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*math.exp(-x*x)
    return sign*y # erf(-x) = -erf(x)

try:
    from math import erf
except:
    try:
        from scipy.special import erf
    except:
        erf = erf_local

def erfc(z):
    '''Complement of the error function.'''
    return 1.0 - erf(z)

def normal_cdf(x):
    '''CDF of the normal distribution.'''
    sqrt2 = 1.4142135623730951
    return 0.5 * erfc(-x / sqrt2)

def normal_integral(a, b):
    '''Integral of the normal distribution from a to b.'''
    return normal_cdf(b) - normal_cdf(a)

def ranges_overlap(R1, R2):
    '''Returns True if there is overlap between ranges of pairs R1 and R2.'''
    if (R1[0] < R2[0] and R1[1] < R2[0]) or \
       (R1[0] > R2[1] and R1[1] > R2[1]):
        return False
    return True

def overlap(R1, R2):
    '''Returns (min, max) of overlap between the ranges of pairs R1 and R2.'''
    return (max(R1[0], R2[0]), min(R1[1], R2[1]))

def normal(mean, stdev, x):
    sqrt_2pi = 2.5066282746310002
    return math.exp(-((x - mean) / stdev)**2 / 2.0) / (sqrt_2pi * stdev)

def build_fwhm(centers):
    '''Returns FWHM list, assuming FWHM is midway between adjacent bands.
    '''
    fwhm = [0] * len(centers)
    fwhm[0] = centers[1] - centers[0]
    fwhm[-1] = centers[-1] - centers[-2]
    for i in range(1, len(centers) - 1):
        fwhm[i] = (centers[i + 1] - centers[i - 1]) / 2.0
    return fwhm

def create_resampling_matrix(centers1, fwhm1, centers2, fwhm2):
    logger = logging.getLogger('spectral')

    sqrt_8log2 = 2.3548200450309493

    N1 = len(centers1)
    N2 = len(centers2)
    bounds1 = [[centers1[i] - fwhm1[i] / 2.0, centers1[i] + fwhm1[i] /
                2.0] for i in range(N1)]
    bounds2 = [[centers2[i] - fwhm2[i] / 2.0, centers2[i] + fwhm2[i] /
                2.0] for i in range(N2)]

    M = np.zeros([N2, N1])

    jStart = 0
    nan = float('nan')
    for i in range(N2):
        stdev = fwhm2[i] / sqrt_8log2
        j = jStart

        # Find the first original band that overlaps the new band
        while j < N1 and bounds1[j][1] < bounds2[i][0]:
            j += 1

        if j == N1:
            logger.info(('No overlap for target band %d (%f / %f)' % (
                i, centers2[i], fwhm2[i])))
            M[i, 0] = nan
            continue

        matches = []

        # Get indices for all original bands that overlap the new band
        while j < N1 and bounds1[j][0] < bounds2[i][1]:
            if ranges_overlap(bounds1[j], bounds2[i]):
                matches.append(j)
            j += 1

        # Put NaN in first element of any row that doesn't produce a band in
        # the new schema.
        if len(matches) == 0:
            logger.info('No overlap for target band %d (%f / %f)',
                         i, centers2[i], fwhm2[i])
            M[i, 0] = nan
            continue

        # Determine the weights for the original bands that overlap the new
        # band. There may be multiple bands that overlap or even just a single
        # band that only partially overlaps.  Weights are normoalized so either
        # case can be handled.

        overlaps = [overlap(bounds1[k], bounds2[i]) for k in matches]
        contribs = np.zeros(len(matches))
        A = 0.
        for k in range(len(matches)):
            #endNorms = [normal(centers2[i], stdev, x) for x in overlaps[k]]
            #dA = (overlaps[k][1] - overlaps[k][0]) * sum(endNorms) / 2.0
            (a, b) = [(x - centers2[i]) / stdev for x in overlaps[k]]
            dA = normal_integral(a, b)
            contribs[k] = dA
            A += dA
        contribs = contribs / A
        for k in range(len(matches)):
            M[i, matches[k]] = contribs[k]
    return M

class BandResampler:
    def __init__(self, centers1, centers2, fwhm1=None, fwhm2=None):
        if isinstance(centers1, BandInfo):
            fwhm1 = centers1.bandwidths
            centers1 = centers1.centers
        if isinstance(centers2, BandInfo):
            fwhm2 = centers2.bandwidths
            centers2 = centers2.centers
        if fwhm1 is None:
            fwhm1 = build_fwhm(centers1)
        if fwhm2 is None:
            fwhm2 = build_fwhm(centers2)
        self.matrix = create_resampling_matrix(
            centers1, fwhm1, centers2, fwhm2)

    def __call__(self, spectrum):
        return np.dot(self.matrix, spectrum)
def do_resample(spectra, source_wvl, source_fwhm):
    centers1 = source_wvl
    centers2 = np.arange(400,1001,1)
    fwhm1 = source_fwhm
    fwhm2 = build_fwhm(centers2)
    resampler = BandResampler(centers1, centers2, fwhm1, fwhm2)
    resampled_spectra = resampler(spectra)
    return resampled_spectra

def compute_weighted_reflectance(response_curve_wl, hyper_wl, hyper_data, response_curve, hyper_fwhm):
    interp_reflectance = do_resample(hyper_data, hyper_wl, hyper_fwhm)    
    numerator = np.trapz(interp_reflectance * response_curve, response_curve_wl)
    denominator = np.trapz(response_curve, response_curve_wl)
    weighted_reflectance = numerator/denominator
    return weighted_reflectance

def run_parallel(image_chunk, df, wvl, fwhm):
    results = np.zeros(shape = (image_chunk.shape[0], image_chunk.shape[1],8))
    for i in range(image_chunk.shape[0]):
        for j in range(image_chunk.shape[1]):
            ratio_data = image_chunk[i, j, :]
            response_curve_wl = df[df.columns[0]]
            simulated_nosie = []
            for kk in range(1,9):
                response_curve = df[df.columns[kk]]
                noise = compute_weighted_reflectance(response_curve_wl, wvl, ratio_data, response_curve, fwhm)
                simulated_nosie.append(noise)
            results[i,j,:] = simulated_nosie
    return results

def raster_to_points(geotiff, shp_name):
    inDs = gdal.Open(geotiff)
    DsoutDs = gdal.Translate(f"{shp_name}.xyz", inDs, format='XYZ', creationOptions=["ADD_HEADER_LINE=YES"])
    outDs = None
    try:
        os.remove(f'{shp_name}.csv')
    except OSError:
        pass

    os.rename(f'{shp_name}.xyz', f'{shp_name}.csv')
    os.system('ogr2ogr -f "ESRI Shapefile" -oo X_POSSIBLE_NAMES=X* -oo Y_POSSIBLE_NAMES=Y* -oo KEEP_GEOM_COLUMNS=NO {0}.shp {0}.csv'.format(shp_name))
    try:
        os.remove(f'{shp_name}.csv')
    except OSError:
        pass
    crs_wkt = inDs.GetProjection()
    shp_layer = gpd.read_file(f"{shp_name}.shp")
    shp_layer.crs = crs_wkt
    shp_layer.to_file(f"{shp_name}.shp")
    return

def filter_points(gdf, min_distance):
    gdf_copy = gdf.copy()
    to_remove = []

    for index, row in gdf_copy.iterrows():
        if index not in to_remove:
            distances = gdf_copy.geometry.distance(row.geometry)
            close_points = distances[distances < min_distance].index.tolist()
            close_points.remove(index)
            to_remove.extend(close_points)

    gdf_copy.drop(index=to_remove, inplace=True)
    gdf_copy.reset_index(drop = True, inplace = True)
    return gdf_copy

## 1. process data of 20230422

In [None]:
emit_path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/2_processed_data/1_SHIFT_areas/"
out_path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/"

file_name = "EMIT_L2A_RFL_001_20230422T195924_2311213_002_reflectance.img"
src_ds = gdal.Open(f"{emit_path}{file_name}")
out_ds = f"{out_path}EMIT_L2A_RFL_20230422.tif"
gdal.Translate(out_ds, src_ds, format='GTiff')
src_ds = None

file_name = "EMIT_L2A_RFLUNCERT_001_20230422T195924_2311213_002_reflectance_uncertainty.img"
src_ds = gdal.Open(f"{emit_path}{file_name}")
out_ds = f"{out_path}EMIT_L2A_RFLUNCERT_20230422.tif"
gdal.Translate(out_ds, src_ds, format='GTiff')
src_ds = None

### project EMIT data

In [32]:
data_path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/"

emit_file = "EMIT_L2A_RFL_20230422.tif"
in_tif = f"{data_path}/{emit_file}"
out_tif = f"{data_path}/{emit_file[:-4]}_projection.tif"
input_ds = gdal.Open(in_tif)
output_ds = gdal.Warp(out_tif, input_ds, dstSRS='EPSG:32611')

emit_file = "EMIT_L2A_RFLUNCERT_20230422.tif"
in_tif = f"{data_path}/{emit_file}"
out_tif = f"{data_path}/{emit_file[:-4]}_projection.tif"
input_ds = gdal.Open(in_tif)
output_ds = gdal.Warp(out_tif, input_ds, dstSRS='EPSG:32611')

### clip the data

In [33]:
data_path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/"
shp = f"{data_path}0_clip_shp/0_first_beginning_clip_shp.shp"
emit_file = "EMIT_L2A_RFL_20230422_projection.tif"
planet_file = "PlanetScope_RFL_20230422.tif"
emit_file2 = "EMIT_L2A_RFLUNCERT_20230422_projection.tif"

gdf = gpd.read_file(shp)
bounds = gdf.bounds
min_x = bounds["minx"].values[0]
min_y = bounds["miny"].values[0]
max_x = bounds["maxx"].values[0]
max_y = bounds["maxy"].values[0]
ul_x, ul_y = (min_x, max_y)
lr_x, lr_y = (max_x, min_y)

### clip EMIT data
input_tif = f"{data_path}/{emit_file}"
output_tif = f"{data_path}/{emit_file[:-4]}_clipped.tif"
gdal.Warp(output_tif, input_tif, format = 'GTiff', outputBounds=(ul_x, lr_y, lr_x, ul_y))
output_tif = None

### clip Planet data
input_tif = f"{data_path}/{planet_file}"
output_tif = f"{data_path}/{planet_file[:-4]}_clipped.tif"
gdal.Warp(output_tif, input_tif, format = 'GTiff', outputBounds=(ul_x, lr_y, lr_x, ul_y))
output_tif = None


### clip EMIT uncertainty data
input_tif = f"{data_path}/{emit_file2}"
output_tif = f"{data_path}/{emit_file2[:-4]}_clipped.tif"
gdal.Warp(output_tif, input_tif, format = 'GTiff', outputBounds=(ul_x, lr_y, lr_x, ul_y))
output_tif = None


### geo-corrections

In [None]:
# # Use the upscaled Planet data as reference to correct EMIT data.
# data_path = "/Volumes/ChenLab-1/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/"
# emit_file = "EMIT_L2A_RFL_20230422_projection_clipped.tif"

# emit_data, emit_Geotrans, emit_proj,_, _ = read_tif(f"{data_path}{emit_file}")
# nir = emit_data[:,:,51]
# red = emit_data[:,:,39]
# green = emit_data[:,:,23]

# data = [nir, red, green]
# data = np.stack(data, axis = 2)

# output_path = f"{data_path}0_geocorrections/EMIT_L2A_RFL_20230422_3bands.tif"
# array_to_geotiff(data, output_path, emit_Geotrans, emit_proj, band_names=["nir", "red", "green"])

### clip again

In [34]:
data_path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/"
shp = f"{data_path}0_clip_shp/20230422_clip_shp.shp"
emit_file = "EMIT_L2A_RFL_20230422_projection_clipped_modified.tif"
planet_file = "PlanetScope_RFL_20230422_clipped.tif"
emit_file2 = "EMIT_L2A_RFLUNCERT_20230422_projection_clipped_modified.tif"

gdf = gpd.read_file(shp)
bounds = gdf.bounds
min_x = bounds["minx"].values[0]
min_y = bounds["miny"].values[0]
max_x = bounds["maxx"].values[0]
max_y = bounds["maxy"].values[0]
ul_x, ul_y = (min_x, max_y)
lr_x, lr_y = (max_x, min_y)

# ### clip EMIT data
# input_tif = f"{data_path}/{emit_file}"
# output_tif = f"{data_path}/{emit_file[:-4]}_clipped.tif"
# gdal.Warp(output_tif, input_tif, format = 'GTiff', outputBounds=(ul_x, lr_y, lr_x, ul_y))
# output_tif = None

# ### clip Planet data
# input_tif = f"{data_path}/{planet_file}"
# output_tif = f"{data_path}/{planet_file[:-4]}_clipped.tif"
# gdal.Warp(output_tif, input_tif, format = 'GTiff', outputBounds=(ul_x, lr_y, lr_x, ul_y))
# output_tif = None


### clip EMIT uncertainty data
input_tif = f"{data_path}/{emit_file2}"
output_tif = f"{data_path}/{emit_file2[:-4]}_clipped.tif"
gdal.Warp(output_tif, input_tif, format = 'GTiff', outputBounds=(ul_x, lr_y, lr_x, ul_y))
output_tif = None

### process the bad bands of EMIT data

In [3]:
data_path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/1_original_data/1_SHIFT_areas/1_Spaceborne_hyperspectral_imagery/"
file_name = "EMIT_L2A_RFL_001_20230422T195924_2311213_002.nc"
ds_nc = xr.open_dataset(f"{data_path}{file_name}", engine="h5netcdf", group="sensor_band_parameters")
fwhm = ds_nc["fwhm"].values
wvl = ds_nc["wavelengths"].values
good_wvl = ds_nc["good_wavelengths"].values

In [5]:
data_path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/"
emit_file = "EMIT_L2A_RFL_20230422_projection_clipped_modified_clipped.tif"

emit_data, emit_Geotrans, emit_proj, _ , _ = read_tif(f"{data_path}{emit_file}")
print(emit_data.shape)

def EMIT_bad_bands_parallel(img_chunk, good_wvl_mask):
    small_image = img_chunk.copy()
    results = np.zeros(shape = (small_image.shape[0], small_image.shape[1],small_image.shape[2]))
    for i in range(small_image.shape[0]):
        for j in range(small_image.shape[1]):
            spectra_data = small_image[i, j, :]
            
            valid_indices = np.where(spectra_data >= 0)[0]
            invalid_indices = np.where(spectra_data < 0)[0]
            if len(valid_indices) > 0 and len(invalid_indices) > 0:
                spectra_data[invalid_indices] = np.interp(invalid_indices, valid_indices, spectra_data[valid_indices])
            results[i,j] = spectra_data
    mask = good_wvl_mask[np.newaxis, np.newaxis, :] 
    results = np.where(mask == 0, 0, results)
    return results

chunk_size = 50
image_chunks = []
for i in range(0, emit_data.shape[0], chunk_size):
        for j in range(0, emit_data.shape[1], chunk_size):
            chunk = emit_data[i:i + chunk_size, j:j + chunk_size]
            image_chunks.append(chunk)

num_processes = psutil.cpu_count(logical=False)
chunk_results = Parallel(n_jobs=num_processes)(delayed(EMIT_bad_bands_parallel)(image_chunk,good_wvl) for image_chunk in tqdm(image_chunks,desc="Processing Chunks"))

final_imagery = np.zeros(shape = (emit_data.shape[0], emit_data.shape[1], emit_data.shape[2]))
chunk_index = 0
for i in range(0, final_imagery.shape[0], chunk_size):
    for j in range(0, final_imagery.shape[1], chunk_size):
        final_imagery[i:i + chunk_size, j:j + chunk_size,:] = chunk_results[chunk_index]
        chunk_index = chunk_index+1
print(final_imagery.shape)

out_tif = f"{data_path}{emit_file[:-4]}_interp.tif"
band_names = [f"{x} nm" for x in wvl]
array_to_geotiff(final_imagery, out_tif, emit_Geotrans, emit_proj, band_names=band_names)

(994, 340, 285)



Processing Chunks:   0%|                                                                           | 0/140 [00:00<?, ?it/s][A
Processing Chunks:  17%|███████████▏                                                     | 24/140 [00:00<00:00, 121.66it/s][A
Processing Chunks:  34%|██████████████████████▋                                           | 48/140 [00:02<00:06, 14.74it/s][A
Processing Chunks:  51%|█████████████████████████████████▉                                | 72/140 [00:03<00:02, 22.87it/s][A
Processing Chunks:  69%|█████████████████████████████████████████████▎                    | 96/140 [00:03<00:01, 33.33it/s][A
Processing Chunks: 100%|█████████████████████████████████████████████████████████████████| 140/140 [00:03<00:00, 38.96it/s][A


(994, 340, 285)


### scale factors to convert PlanetScope reflectance to 0-1

In [26]:
data_path = "/Volumes/ChenLab-1/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/"
planet_file = "PlanetScope_RFL_20230422_clipped_clipped.tif"
im_data, im_Geotrans, im_proj,rows, cols = read_tif(f"{data_path}{planet_file}")
im_data = im_data/10000

dataset = gdal.Open(f"{data_path}{planet_file}") 
num_bands = dataset.RasterCount
band_names = []
for band_index in range(1, num_bands + 1):
    band = dataset.GetRasterBand(band_index)
    band_name = band.GetDescription() 
    band_names.append(band_name)
    
output_path = f"{data_path}{planet_file[:-4]}_scaled.tif"
array_to_geotiff(im_data, output_path, im_Geotrans, im_proj, band_names=band_names)

### upscale PlanetScope to the Spatial resolution of EMIT data

In [27]:
data_path = "/Volumes/ChenLab-1/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/"
planet_file = "PlanetScope_RFL_20230422_clipped_clipped_scaled.tif"
emit_file = "EMIT_L2A_RFL_20230422_projection_clipped_modified_clipped.tif"

emit_data, emit_Geotrans, emit_proj,_, _ = read_tif(f"{data_path}{emit_file}")
x_res = emit_Geotrans[1]
y_res = abs(emit_Geotrans[5])

input_ds = gdal.Open(f"{data_path}{planet_file}")
output_path = f"{data_path}{planet_file[:-4]}_aggregated_60m.tif"

gdal.Warp(output_path, input_ds, xRes=x_res, yRes=y_res,resampleAlg=gdalconst.GRA_Bilinear)
    
input_ds = None

### calculate the EMIT noise ratio

In [54]:
data_path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/1_original_data/1_SHIFT_areas/1_Spaceborne_hyperspectral_imagery/"
file_name = "EMIT_L2A_RFL_001_20230422T195924_2311213_002.nc"
ds_nc = xr.open_dataset(f"{data_path}{file_name}", engine="h5netcdf", group="sensor_band_parameters")
fwhm = ds_nc["fwhm"].values
wvl = ds_nc["wavelengths"].values
good_wvl = ds_nc["good_wavelengths"].values

In [11]:
data_path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/"
emit_file = "EMIT_L2A_RFL_20230422_projection_clipped_modified_clipped.tif"
emit_uncertainty = "EMIT_L2A_RFLUNCERT_20230422_projection_clipped_modified_clipped.tif"

emit_data, emit_Geotrans, emit_proj, _ , _ = read_tif(f"{data_path}{emit_file}")
uncert_data, uncert_Geotrans, uncert_proj, _ , _ = read_tif(f"{data_path}{emit_uncertainty}")
ratio = uncert_data/emit_data

out_tif = f"{data_path}EMIT_reflectance_uncertainty_ratio.tif"
band_names = [f"{x}nm" for x in wvl]
array_to_geotiff(ratio, out_tif, emit_Geotrans, emit_proj, band_names=band_names)

### add EMIT noise to upscaled PlanetScope data

In [40]:
spectral_response = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/1_support_materials/Spectral_response_curve_PlanetScope.csv"
data_path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/"
emit_noise_file = "EMIT_reflectance_uncertainty_ratio.tif"
planet_file = "PlanetScope_RFL_20230422_clipped_clipped_scaled_aggregated_60m.tif"
df = pd.read_csv(spectral_response)

emit_noise, emit_Geotrans, emit_proj, _ , _ = read_tif(f"{data_path}{emit_noise_file}")
planet_data, planet_Geotrans, planet_proj,_, _ = read_tif(f"{data_path}{planet_file}")

chunk_size = 50
image_chunks = []
for i in range(0, emit_noise.shape[0], chunk_size):
        for j in range(0, emit_noise.shape[1], chunk_size):
            chunk = emit_noise[i:i + chunk_size, j:j + chunk_size]
            image_chunks.append(chunk)

num_processes = psutil.cpu_count(logical=False)
chunk_results = Parallel(n_jobs=num_processes)(delayed(run_parallel)(image_chunk, df, wvl, fwhm) for image_chunk in tqdm(image_chunks,desc="Processing Chunks"))

convolved_noise_ratio = np.zeros(shape = (emit_noise.shape[0], emit_noise.shape[1], 8))
chunk_index = 0
for i in range(0, convolved_noise_ratio.shape[0], chunk_size):
    for j in range(0, convolved_noise_ratio.shape[1], chunk_size):
        convolved_noise_ratio[i:i + chunk_size, j:j + chunk_size,:] = chunk_results[chunk_index]
        chunk_index = chunk_index+1

planet_uncertainty = planet_data*convolved_noise_ratio
gaussian_noise = np.random.normal(0, 1, planet_uncertainty.shape) * planet_uncertainty
planet_noisy_array = planet_data + gaussian_noise

dataset = gdal.Open(f"{data_path}{planet_file}") 
num_bands = dataset.RasterCount
band_names = []
for band_index in range(1, num_bands + 1):
    band = dataset.GetRasterBand(band_index)
    band_name = band.GetDescription() 
    band_names.append(band_name)

out_tif = f"{data_path}{planet_file[:-4]}_noisy.tif"
array_to_geotiff(planet_noisy_array, out_tif, planet_Geotrans, planet_proj, band_names=band_names)
out_tif = f"{data_path}Planet_uncertainty_ratio_from_convolved_EMIT_ratio.tif"
array_to_geotiff(convolved_noise_ratio, out_tif, planet_Geotrans, planet_proj, band_names=band_names)


Processing Chunks:   0%|                                                | 0/140 [00:00<?, ?it/s][A
Processing Chunks:  17%|██████▋                                | 24/140 [00:02<00:11, 10.54it/s][A
Processing Chunks:  19%|███████▏                               | 26/140 [00:02<00:11,  9.93it/s][A
Processing Chunks:  19%|███████▌                               | 27/140 [00:02<00:11,  9.83it/s][A
Processing Chunks:  34%|█████████████▎                         | 48/140 [04:36<13:37,  8.88s/it][A
Processing Chunks:  51%|████████████████████                   | 72/140 [09:39<12:17, 10.85s/it][A
Processing Chunks:  52%|████████████████████▎                  | 73/140 [09:39<11:45, 10.53s/it][A
Processing Chunks:  69%|██████████████████████████▋            | 96/140 [13:44<07:45, 10.59s/it][A
Processing Chunks: 100%|██████████████████████████████████████| 140/140 [19:10<00:00,  8.22s/it][A


### inter-calibration for EMIT and PlanetScope imagery

In [None]:
data_path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/"
out_path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/2_extracted_points_for_intercalibration/"
file_name = "EMIT_L2A_RFL_20230422_projection_clipped_modified_clipped.tif"

raster_to_points(f"{data_path}{file_name}", f"{out_path}extracted_points")

points = gpd.read_file(f'{out_path}extracted_points.shp')
filtered_points = filter_points(points, 150)
filtered_points.to_file(f'{out_path}extracted_points_filtered.shp')

In [70]:
data_path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/"
point_path = f"{data_path}2_extracted_points_for_intercalibration/"
out_path = f"{data_path}3_intercalibration_training_data/"

emit_file = "EMIT_L2A_RFL_20230422_projection_clipped_modified_clipped_interp.tif"
planet_file = "PlanetScope_RFL_20230422_clipped_clipped_scaled_aggregated_60m_noisy.tif"
planet_file2 = "PlanetScope_RFL_20230422_clipped_clipped_scaled_aggregated_60m.tif"
shp_name = "extracted_points_filtered.shp"

points = gpd.read_file(f"{point_path}{shp_name}")
points.drop(columns=['Z'],inplace = True)

tiff_ds = gdal.Open(f"{data_path}{emit_file}")
num_bands = tiff_ds.RasterCount
band_names = wvl.tolist()

extracted_bands = band_names[:100]
extracted_bands = [f"{round(x, 5)} nm" for x in extracted_bands]
extracted_values = [[] for _ in range(100)]
for index, row in points.iterrows():
    point = row.geometry
    x, y = point.x, point.y
    
    # Convert point coordinates to pixel coordinates
    px = int((x - tiff_ds.GetGeoTransform()[0]) / tiff_ds.GetGeoTransform()[1])
    py = int((y - tiff_ds.GetGeoTransform()[3]) / tiff_ds.GetGeoTransform()[5])
    
    # Read values from GeoTIFF for each band
    for band_num in range(1, 100 + 1):
        band = tiff_ds.GetRasterBand(band_num)
        value = band.ReadAsArray(px, py, 1, 1)[0][0]
        extracted_values[band_num - 1].append(value)

extracted_values = pd.DataFrame(np.array(extracted_values)).T
extracted_values.columns = extracted_bands

tiff_ds = gdal.Open(f"{data_path}{planet_file}")
num_bands = tiff_ds.RasterCount
band_names = []
for band_number in range(1, num_bands + 1):
    band = tiff_ds.GetRasterBand(band_number)
    band_description = band.GetDescription()
    band_names.append(band_description)
    
extracted_values2 = [[] for _ in range(num_bands)]
for index, row in points.iterrows():
    point = row.geometry
    x, y = point.x, point.y
    
    # Convert point coordinates to pixel coordinates
    px = int((x - tiff_ds.GetGeoTransform()[0]) / tiff_ds.GetGeoTransform()[1])
    py = int((y - tiff_ds.GetGeoTransform()[3]) / tiff_ds.GetGeoTransform()[5])
    
    # Read values from GeoTIFF for each band
    for band_num in range(1, num_bands + 1):
        band = tiff_ds.GetRasterBand(band_num)
        value = band.ReadAsArray(px, py, 1, 1)[0][0]
        extracted_values2[band_num - 1].append(value)

extracted_values2 = pd.DataFrame(np.array(extracted_values2)).T
extracted_values2.columns = band_names


tiff_ds = gdal.Open(f"{data_path}{planet_file2}")
num_bands = tiff_ds.RasterCount
band_names = []
for band_number in range(1, num_bands + 1):
    band = tiff_ds.GetRasterBand(band_number)
    band_description = band.GetDescription()
    band_names.append(band_description)
    
extracted_values3 = [[] for _ in range(num_bands)]
for index, row in points.iterrows():
    point = row.geometry
    x, y = point.x, point.y
    
    # Convert point coordinates to pixel coordinates
    px = int((x - tiff_ds.GetGeoTransform()[0]) / tiff_ds.GetGeoTransform()[1])
    py = int((y - tiff_ds.GetGeoTransform()[3]) / tiff_ds.GetGeoTransform()[5])
    
    # Read values from GeoTIFF for each band
    for band_num in range(1, num_bands + 1):
        band = tiff_ds.GetRasterBand(band_num)
        value = band.ReadAsArray(px, py, 1, 1)[0][0]
        extracted_values3[band_num - 1].append(value)

extracted_values3 = pd.DataFrame(np.array(extracted_values3)).T
extracted_values3.columns = [f"{x}_before_noisy" for x in band_names]

final = pd.concat([points,extracted_values,extracted_values2,extracted_values3], axis = 1)
final.drop(columns=['geometry'],inplace = True)
final.to_csv(f"{out_path}extracted_samples_reflectance.csv", index = False)

#### Start inter-calibration

In [9]:
def rsquared(x, y): 
    slope, intercept, r_value, p_value, std_err = stats.linregress(x, y) 
    a = r_value**2
    return a

class RegressionModel(nn.Module):
    def __init__(self, input_size, nodes):
        super(RegressionModel, self).__init__()
        self.fc1 = nn.Linear(input_size, nodes)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Linear(nodes, nodes//2)
        self.act2 =nn.ReLU()
        self.out = nn.Linear(nodes//2, 8)
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.manual_seed(0)
                init.xavier_normal_(m.weight)
                init.constant_(m.bias, 0)
    def forward(self, x):
        x = self.act1(self.fc1(x))
        x = self.act2(self.fc2(x))
        x =self.out(x)
        return x

In [136]:
spectral_response = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/1_support_materials/Spectral_response_curve_PlanetScope.csv"
df = pd.read_csv(spectral_response)
path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/"
data_path = f"{path}3_intercalibration_training_data/"

data = pd.read_csv(f"{data_path}extracted_samples_reflectance.csv")
data = data[(data >= 0).all(axis=1)]
data.reset_index(drop = True, inplace = True)

emit = data.loc[:,"381.00558 nm":"1118.73682 nm"]
planet = data.loc[:,"coastal_blue":"nir"]
planet2 = data.loc[:,"coastal_blue_before_noisy":"nir_before_noisy"]
emit.to_csv(f'{data_path}0_EMIT_reflectance.csv', index = False)
planet.to_csv(f'{data_path}0_Planet_reflectance.csv', index = False)
planet2.to_csv(f'{data_path}0_Planet_reflectance_before_noisy.csv', index = False)


wl_emit = [float(x.split(' ')[0]) for x in emit.columns]
wl_planet = [443, 490, 531, 565, 610, 665, 705, 865]


def run_para(data_chunk, df, wvl, fwhm):
    var_start = True
    for i in range(len(data_chunk)):
        hyper_spectra = data_chunk.iloc[i]
        response_curve_wl = df[df.columns[0]]
        
        simulated_refl = []
        for kk in range(1,9):
            response_curve = df[df.columns[kk]]
            refl = compute_weighted_reflectance(response_curve_wl, wvl[:100], hyper_spectra, response_curve, fwhm[:100])
            simulated_refl.append(refl)
        simulated_refl = pd.DataFrame(np.array(simulated_refl)).T
        simulated_refl.columns = [f"simulated {x}" for x in planet.columns]
        if var_start:
            final_simulated_refl = simulated_refl
            var_start = False
        else:
            final_simulated_refl = pd.concat([final_simulated_refl, simulated_refl], axis = 0)
    final_simulated_refl.reset_index(drop = True, inplace = True)
    return final_simulated_refl

chunk_size = 300
data_chunks = []
for i in range(0, len(emit), chunk_size):
    chunk = emit.iloc[i:i+chunk_size]
    data_chunks.append(chunk)
num_processes = psutil.cpu_count(logical=False)
chunk_results = Parallel(n_jobs=num_processes)(delayed(run_para)(data_chunk, df, wvl, fwhm) for data_chunk in tqdm(data_chunks,desc="Processing Chunks"))

start_var = True
for chunk in chunk_results:
    chunk_result = chunk
    if start_var:
        final = chunk_result
        start_var = False
    else:
        final = pd.concat([final, chunk_result], axis = 0)
final.reset_index(drop = True, inplace = True)
final.to_csv(f'{data_path}0_Convolved_EMIT_reflectance.csv', index = False)

Processing Chunks: 100%|██████████████████████████████████████| 119/119 [01:54<00:00,  1.04it/s]


In [None]:
path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/"
data_path = f"{path}3_intercalibration_training_data/"
final_simulated_refl = pd.read_csv(f'{data_path}0_Convolved_EMIT_reflectance.csv')
planet = pd.read_csv(f'{data_path}0_Planet_reflectance.csv')

X = planet.values
y = final_simulated_refl.values

## linear model
start_t = datetime.datetime.now()
print('start linear regression:', start_t)
model = LinearRegression()
model.fit(X, y)
joblib.dump(model, f"{data_path}1_saved_linear_model_for_inter_calibration.pkl")
predict_refl = model.predict(X)

predict_refl = pd.DataFrame(predict_refl)
predict_refl.columns = [f"predicted {x}" for x in planet.columns]
predict_refl.to_csv(f'{data_path}1_Corrected_Planet reflectance_linear_regression.csv', index=False)

end_t = datetime.datetime.now()
elapsed_sec = (end_t - start_t).total_seconds()
print('end linear regression:', end_t)
print('   total:', elapsed_sec / 60, 'min')
print('****************************************************************************************************************')

############################################################################################################
start_t = datetime.datetime.now()
print('start train DNN models:', start_t)

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Use {device} to train DNN models")
scaler_x = StandardScaler()
X_scale = scaler_x.fit_transform(X)

scaler_y = StandardScaler()
y_scale = scaler_y.fit_transform(y)

joblib.dump(scaler_x, f'{data_path}2_scaler_x.pkl')
joblib.dump(scaler_y, f'{data_path}2_scaler_y.pkl')

# X_tensor, y_tensor = torch.Tensor(X_scale), torch.Tensor(y_scale)
X_tensor, y_tensor = torch.Tensor(X_scale).to(device), torch.Tensor(y_scale).to(device)  ####### Use GPU to train

param_grid = {'learning_rate': [0.01], 'batch_size': [16, 32, 64], 'nodes': [32, 64, 72]}
# param_grid = {'learning_rate':[0.01],'batch_size':[32], 'nodes':[32,40]}             ####### for code testing
grid = ParameterGrid(param_grid)
loss_fn = nn.L1Loss()
best_accuracy = []
best_params = []
for paras in grid:
    learning_rate = paras['learning_rate']
    batch_size = paras['batch_size']
    nodes = paras['nodes']

    train_loader = DataLoader(TensorDataset(X_tensor, y_tensor), batch_size=batch_size, shuffle=True)
    model = RegressionModel(X_tensor.size()[1], nodes).to(device)  ####### Use GPU to train
    # model = RegressionModel(X_tensor.size()[1], nodes)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.005)
    epoch_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5, gamma=0.9)
    best_mse = pow(10, 10)
    history = []

    for epoch in range(300):
        model.train()
        for X_batch, y_batch in train_loader:
            y_pred = model(X_batch)
            loss = loss_fn(y_pred, y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        epoch_scheduler.step()
        model.eval()
        y_pred = model(X_tensor)
        mse = loss_fn(y_pred, y_tensor)
        mse = float(mse)
        history.append(mse)
        if mse < best_mse:
            best_weights = copy.deepcopy(model.state_dict())
            best_mse = mse

    model.load_state_dict(best_weights)
    model.eval()
    with torch.no_grad():
        pred_dnn = model(X_tensor)
        pred_dnn = pred_dnn.detach().cpu().numpy()
        pred_dnn = scaler_y.inverse_transform(pred_dnn)

        obs_y = y_tensor.detach().cpu().numpy()
        obs_y_raw = scaler_y.inverse_transform(obs_y)
        accu = r2_score(pred_dnn, obs_y_raw)

        print("  ", datetime.datetime.now(), paras, 'R^2:', accu)
        best_accuracy.append(accu)
        best_params.append(paras)

#######obtained  best  parameters
new_paras = best_params[best_accuracy.index(max(best_accuracy))]
print("   bset parameters:", new_paras)
learning_rate = new_paras['learning_rate']
batch_size = new_paras['batch_size']
nodes = new_paras['nodes']

train_loader = DataLoader(TensorDataset(X_tensor, y_tensor), batch_size=batch_size, shuffle=True)
model = RegressionModel(X_tensor.size()[1], nodes).to(device)  ####### Use GPU to train
# model = RegressionModel(X_tensor.size()[1], nodes)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.005)
epoch_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5, gamma=0.9)
best_mse = pow(10, 10)
history = []

para_path = f'{data_path}2_saved_DNN_model_for_inter_calibration.pt'
for epoch in range(300):
    model.train()
    for X_batch, y_batch in train_loader:
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    epoch_scheduler.step()
    model.eval()
    y_pred = model(X_tensor)
    mse = loss_fn(y_pred, y_tensor)
    mse = float(mse)
    history.append(mse)
    if mse < best_mse:
        best_mse = mse
        torch.save(model.state_dict(), para_path)

model.load_state_dict(torch.load(para_path))
model.eval()

fig, ax = plt.subplots(figsize=(9, 4))
ax.set_facecolor((0, 0, 0, 0.03))
ax.grid(color='gray', linestyle=':', linewidth=0.3)
config = {"font.family": 'Helvetica'}
plt.rcParams.update(config)

ax.plot(history, color="red", lw=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('MSE', fontsize=12)
plt.savefig(f'figures/1_Epoch for training DNN model.png', dpi=1000, bbox_inches='tight')
with torch.no_grad():
    pred_dnn = model(X_tensor)
    pred_dnn = pred_dnn.detach().cpu().numpy()
    pred_dnn = scaler_y.inverse_transform(pred_dnn)

pred_dnn = pd.DataFrame(pred_dnn)
pred_dnn.columns = [f"predicted {x}" for x in planet.columns]
pred_dnn.to_csv(f'{data_path}2_Corrected_Planet reflectance_DNN_models.csv', index=False)
print('end training DNN model:', end_t)
print('   total:', elapsed_sec / 60, 'min')
print('****************************************************************************************************************')

#### apply trained model to Planet imagery

In [11]:
def apply_DNN_models(image_chunk, DNN_model, scaler_x, scaler_y):
    image_chunk_reshape = image_chunk.reshape(-1, image_chunk.shape[2])
    X = scaler_x.transform(image_chunk_reshape)
    X_tensor = torch.Tensor(X)
    # model = RegressionModel(X_tensor.size()[1], 32)
    model = RegressionModel(X_tensor.size()[1], 72)
    model.load_state_dict(torch.load(DNN_model, weights_only=True))
    model.eval()
    with torch.no_grad():
        pred_dnn = model(X_tensor)
        pred_dnn = pred_dnn.detach().cpu().numpy()
        pred_dnn = scaler_y.inverse_transform(pred_dnn)
    pred_dnn = pred_dnn.reshape(image_chunk.shape)
    return pred_dnn

data_path = "/Volumes/ChenLab/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/"
planet_file = "PlanetScope_RFL_20230422_clipped_clipped_scaled.tif"
im_data, im_Geotrans, im_proj,rows, cols = read_tif(f"{data_path}{planet_file}")

dataset = gdal.Open(f"{data_path}{planet_file}") 
num_bands = dataset.RasterCount
band_names = []
for band_index in range(1, num_bands + 1):
    band = dataset.GetRasterBand(band_index)
    band_name = band.GetDescription() 
    band_names.append(band_name)
    
DNN_model = f"{data_path}3_intercalibration_training_data/2_saved_DNN_model_for_inter_calibration.pt"
scaler_x =  joblib.load(f"{data_path}3_intercalibration_training_data/2_scaler_x.pkl")
scaler_y =  joblib.load(f"{data_path}3_intercalibration_training_data/2_scaler_y.pkl")

corrected_planet = apply_DNN_models(im_data, DNN_model, scaler_x, scaler_y)

output_path = f"{data_path}{planet_file[:-4]}_intersensors_corrected.tif"
array_to_geotiff(corrected_planet, output_path, im_Geotrans, im_proj, band_names=band_names)