In [1]:
import os
import re
import sys
import dask.array as da
import pandas as pd
import numpy as np
import rasterio
import geopandas as gpd
from osgeo import gdal, osr
import matplotlib.pyplot as plt

In [2]:
def read_tif(tif_file):
    """Read a GDAL-readable raster fully into memory.

    Parameters
    ----------
    tif_file : str
        Path of the raster passed to ``gdal.Open``.

    Returns
    -------
    tuple
        ``(im_data, im_Geotrans, im_proj, rows, cols)``: the pixel array, the
        geotransform from ``GetGeoTransform()``, the projection string from
        ``GetProjection()``, and the raster height and width. When GDAL
        returns a 3-D array its first axis is moved to the end, so the array
        is band-last; a 2-D array is returned unchanged.

    Raises
    ------
    RuntimeError
        If GDAL cannot open ``tif_file``.
    """
    dataset = gdal.Open(tif_file)
    if dataset is None:
        # Without gdal.UseExceptions(), gdal.Open signals failure by returning None.
        raise RuntimeError(f"GDAL could not open raster: {tif_file}")
    cols = dataset.RasterXSize
    rows = dataset.RasterYSize
    im_proj = dataset.GetProjection()
    im_Geotrans = dataset.GetGeoTransform()
    im_data = dataset.ReadAsArray(0, 0, cols, rows)
    dataset = None  # drop the handle as soon as the pixels are in memory
    if im_data.ndim == 3:
        # Reuse the array already read instead of reading the whole raster a second time.
        im_data = np.moveaxis(im_data, 0, -1)
    return im_data, im_Geotrans, im_proj, rows, cols
    
def array_to_geotiff(array, output_path, geo_transform, projection, band_names=None):
    """Write an array to a Float32 GeoTIFF, one raster band per last-axis slice.

    Parameters
    ----------
    array : array-like
        Either ``(rows, cols, bands)`` or a single-band ``(rows, cols)`` image.
    output_path : str
        Destination path handed to the GTiff driver.
    geo_transform : sequence
        Geotransform passed to ``SetGeoTransform``.
    projection : str
        Projection string passed to ``SetProjection``.
    band_names : sequence of str, optional
        Description for each band, in band order. Must have at least one
        entry per band when given; an empty or ``None`` value is ignored.

    Raises
    ------
    ValueError
        If ``array`` is not 2-D or 3-D, or ``band_names`` is shorter than the
        number of bands.
    RuntimeError
        If the GTiff driver cannot create ``output_path``.
    """
    array = np.asarray(array)
    if array.ndim == 2:
        # Treat a single-band image as (rows, cols, 1) so callers need not reshape.
        array = array[:, :, np.newaxis]
    if array.ndim != 3:
        raise ValueError(f"expected a 2-D or 3-D array, got shape {array.shape}")
    rows, cols, num_bands = array.shape

    # Validate before creating the file so a mismatch does not leave a half-written raster.
    if band_names and len(band_names) < num_bands:
        raise ValueError(
            f"band_names has {len(band_names)} entries but the array has {num_bands} bands"
        )

    driver = gdal.GetDriverByName('GTiff')
    dataset = driver.Create(output_path, cols, rows, num_bands, gdal.GDT_Float32)
    if dataset is None:
        raise RuntimeError(f"GDAL could not create raster: {output_path}")

    dataset.SetGeoTransform(geo_transform)
    dataset.SetProjection(projection)

    for band_num in range(num_bands):
        band = dataset.GetRasterBand(band_num + 1)  # GDAL bands are 1-based
        band.WriteArray(array[:, :, band_num])
        if band_names:
            band.SetDescription(band_names[band_num])
        band.FlushCache()
        band = None  # release the band before the dataset that owns it

    dataset.FlushCache()
    dataset = None  # dropping the last reference closes the file
    return
def pixel_to_coord(gt, x, y):
    """Apply a six-term affine geotransform to a pixel position.

    ``x`` is scaled by ``gt[1]`` and ``gt[4]``, ``y`` by ``gt[2]`` and
    ``gt[5]``; ``gt[0]`` and ``gt[3]`` are the offsets.

    Returns the transformed ``(x_coord, y_coord)`` pair.
    """
    x_offset, x_per_x, x_per_y = gt[0], gt[1], gt[2]
    y_offset, y_per_x, y_per_y = gt[3], gt[4], gt[5]
    x_geo = x_offset + x * x_per_x + y * x_per_y
    y_geo = y_offset + x * y_per_x + y * y_per_y
    return x_geo, y_geo

### Apply the GLT data (bands are processed in batches because of memory limitations)

In [3]:
# Acquisition date and the folders derived from the project root.
date = "20220420"
data_path = "/Users/fji/Desktop/SHIFT/"
glt_path = f"{data_path}mosaic_glt/"
image_path = f"{data_path}SHIFT_20220420_imagery/"
out_path = f"{data_path}processed_imagery/"

# Mosaic GLT raster: read once and reused by the band-batch loop in the next cell.
glt_image = f"{glt_path}{date}_box_mosaic_glt_phase"
glt_data, glt_Geotrans, glt_proj, glt_rows, glt_cols = read_tif(glt_image)

# One flight name per row of the text file; rows are numbered from 1 so the
# numbers can be used as keys of the `flight` lookup.
flight_lines = f"{glt_path}{date}_box_lines.txt"
df = pd.read_csv(flight_lines, header=None)
df.columns = ["flight_names"]
df["flight_lines"] = np.arange(1, len(df) + 1)
flight = dict(
    zip(df["flight_lines"], df["flight_names"])
)

print(glt_data.shape, flight)
print(glt_Geotrans, glt_proj)

(12023, 13739, 3) {1: 'ang20220420t184735', 2: 'ang20220420t190012', 3: 'ang20220420t191418', 4: 'ang20220420t192635', 5: 'ang20220420t194128', 6: 'ang20220420t195351', 7: 'ang20220420t200950', 8: 'ang20220420t202328', 9: 'ang20220420t204018'}
(717720.0, 5.0, -0.0, 3865865.0, -0.0, -5.0) PROJCS["WGS 84 / UTM zone 10N",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]],PROJECTION["Transverse_Mercator"],PARAMETER["latitude_of_origin",0],PARAMETER["central_meridian",-123],PARAMETER["scale_factor",0.9996],PARAMETER["false_easting",500000],PARAMETER["false_northing",0],UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH],AUTHORITY["EPSG","32610"]]


In [4]:
# Bands are gathered in batches because a mosaic of every band at once does not fit in memory.
band_splits = 110
total_bands = 425

# GLT channel 2 stores, per mosaic pixel, the flight-line number that fills it;
# channels 0 and 1 store the 1-based source column and row in that line's image.
line_index = glt_data[:, :, 2]

for start_band in range(0, total_bands, band_splits):
    end_band = min(start_band + band_splits, total_bands)
    band_nums = end_band - start_band
    print(f"processing band {start_band+1} to band {end_band}.")
    # float32 matches the GDT_Float32 raster written by array_to_geotiff and
    # needs half the memory of NumPy's default float64.
    out_array = np.full(
        (glt_data.shape[0], glt_data.shape[1], band_nums), np.nan, dtype=np.float32
    )

    for line, image in flight.items():
        # Mosaic pixels filled by this flight line.
        out_rows, out_cols = np.nonzero(line_index == line)
        print(f"  line {line}: total {out_rows.size} points. opening....")

        with rasterio.open(f"{image_path}{image}_rfl_v0p1") as dataset:
            # rasterio band indexes are 1-based.
            im_data = dataset.read(list(range(start_band + 1, end_band + 1)))

        im_data = np.moveaxis(im_data, 0, -1)
        print(f"    opened line {line} {im_data.shape}")

        # Vectorised gather: look up every source pixel at once instead of
        # looping over millions of coordinates in Python.
        src_cols = glt_data[out_rows, out_cols, 0].astype(np.int64) - 1
        src_rows = glt_data[out_rows, out_cols, 1].astype(np.int64) - 1
        out_array[out_rows, out_cols, :] = im_data[src_rows, src_cols, :]
        im_data = None  # free the flight-line cube before loading the next one

    print(f"start saving band{start_band+1} to band{end_band} imagery")       
    out_tif = f"{out_path}SHIFT_RFL_{date}_band{start_band+1}_band{end_band}.tif"
    array_to_geotiff(out_array, out_tif, glt_Geotrans, glt_proj, band_names=[f"band {x}" for x in range(start_band+1, end_band+1)])
    out_array = None  # release the batch mosaic before allocating the next

processing band 1 to band 110.
  line 1: total 6336104 points. opening....


  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)


    opened line 1 (17446, 598, 110)
  line 2: total 5581688 points. opening....
    opened line 2 (16048, 598, 110)
  line 3: total 6619687 points. opening....
    opened line 3 (13045, 598, 110)
  line 4: total 5455417 points. opening....
    opened line 4 (15642, 598, 110)
  line 5: total 6238058 points. opening....
    opened line 5 (12990, 598, 110)
  line 6: total 5031907 points. opening....
    opened line 6 (15078, 598, 110)
  line 7: total 6019322 points. opening....
    opened line 7 (17157, 598, 110)
  line 8: total 6835884 points. opening....
    opened line 8 (16887, 598, 110)
  line 9: total 5823821 points. opening....
    opened line 9 (17063, 598, 110)
start saving band1 to band110 imagery
processing band 111 to band 220.
  line 1: total 6336104 points. opening....
    opened line 1 (17446, 598, 110)
  line 2: total 5581688 points. opening....
    opened line 2 (16048, 598, 110)
  line 3: total 6619687 points. opening....
    opened line 3 (13045, 598, 110)
  line 4: tot

### Clip the data to the EMIT and Planet extent

In [5]:
data_path = "/Users/fji/Desktop/SHIFT/processed_imagery/"
shp = "/Users/fji/Desktop/SHIFT/shp_clip/clipped_areas.shp"

# Band-batch mosaics written by the GLT step.
file1 = "SHIFT_RFL_20220420_band1_band110.tif"
file2 = "SHIFT_RFL_20220420_band111_band220.tif"
file3 = "SHIFT_RFL_20220420_band221_band330.tif"
file4 = "SHIFT_RFL_20220420_band331_band425.tif"
files = [file1, file2, file3, file4]

gdf = gpd.read_file(shp)
bounds = gdf.bounds
# Only the first row of the bounds table is used.
# NOTE(review): if the shapefile holds several features, confirm that the
# first one is the intended clip extent.
min_x, min_y, max_x, max_y = bounds[["minx", "miny", "maxx", "maxy"]].values[0]

# Corner naming used by the clip cell: upper-left and lower-right.
ul_x, ul_y = min_x, max_y
lr_x, lr_y = max_x, min_y
print(ul_x, ul_y, lr_x, lr_y)

723109.8808592575 3841945.93245933 745869.6095620227 3812729.2183499457


In [6]:
# Clip window, entered by hand from the shapefile extent printed by the previous cell.
# NOTE(review): ul_y and lr_y are not multiples of the 5 m pixel size printed for
# the GLT grid, so the warp may resample onto a shifted grid — confirm this is acceptable.
ul_x, ul_y, lr_x, lr_y = 723110, 3841946, 745870, 3812729  ## adjust according to the values from previous step

for file in files:
    input_tif = os.path.join(data_path, file)
    # os.path.join avoids the doubled "/" produced when data_path already ends with one.
    output_tif = os.path.join(data_path, f"{os.path.splitext(file)[0]}_clipped.tif")
    # outputBounds is ordered (minX, minY, maxX, maxY).
    clipped = gdal.Warp(output_tif, input_tif, format='GTiff', outputBounds=(ul_x, lr_y, lr_x, ul_y))
    if clipped is None:
        raise RuntimeError(f"gdal.Warp failed for {input_tif}")
    clipped = None  # releasing the dataset (not the path string) closes and flushes the file

### Merge the band-batch files into a single full-spectrum image

In [7]:
head_file = "/Users/fji/Desktop/SHIFT/SHIFT_20220420_imagery/ang20220420t191418_rfl_v0p1.hdr"
with open(head_file, 'r') as hdr_file:
    hdr_content = hdr_file.read()


def _read_header_list(hdr_text, key):
    """Return the floats stored as ``key = { v1, v2, ... }`` in header text.

    Raises ``ValueError`` when the key is missing, so a failed match cannot
    leave a stale or undefined variable behind for later cells.
    """
    match = re.search(re.escape(key) + r'\s*=\s*\{(.*?)\}', hdr_text, re.DOTALL)
    if match is None:
        raise ValueError(f"'{key}' list not found in {head_file}")
    # Skip empty tokens (e.g. after a trailing comma) instead of failing in float().
    return [float(value.strip()) for value in match.group(1).split(',') if value.strip()]


wavelength = _read_header_list(hdr_content, 'wavelength')
fwhm = _read_header_list(hdr_content, 'fwhm')

# Band descriptions used when the merged image is written.
band_names = [f"{round(x,3)} nm" for x in wavelength]

In [10]:
data_path = "/Users/fji/Desktop/SHIFT/processed_imagery/"

file1 = "SHIFT_RFL_20220420_band1_band110_clipped.tif"
file2 = "SHIFT_RFL_20220420_band111_band220_clipped.tif"
file3 = "SHIFT_RFL_20220420_band221_band330_clipped.tif"
file4 = "SHIFT_RFL_20220420_band331_band425_clipped.tif"

files = [file1, file2, file3, file4]

# Collect the parts and concatenate once, rather than re-copying the growing
# array after every file.
parts = []
bands_so_far = 0
for file in files:
    print(f"opening {file}")
    im_data, im_Geotrans, im_proj,rows, cols = read_tif(f"{data_path}{file}")
    print(f"  opened {file}, shape: {im_data.shape}")
    parts.append(im_data)
    bands_so_far += im_data.shape[-1]
    merged_shape = im_data.shape[:-1] + (bands_so_far,)
    print(f"   final image shape: {merged_shape}")
    im_data = None

final_image = np.concatenate(parts, axis=-1)
parts = None  # the individual parts are no longer needed once merged

# Check before the long save: band_names comes from the header-parsing cell
# and must describe every merged band.
if final_image.shape[-1] != len(band_names):
    raise ValueError(
        f"merged image has {final_image.shape[-1]} bands but band_names has {len(band_names)} entries"
    )

print(f"start saving final imagery")       
out_tif = f"{data_path}SHIFT_RFL_20220420_clipped.tif"
# The geotransform and projection of the last file read are used for the merged output.
array_to_geotiff(final_image, out_tif, im_Geotrans, im_proj, band_names=band_names)

opening SHIFT_RFL_20220420_band1_band110_clipped.tif
  opened SHIFT_RFL_20220420_band1_band110_clipped.tif, shape: (5843, 4552, 110)
   final image shape: (5843, 4552, 110)
opening SHIFT_RFL_20220420_band111_band220_clipped.tif
  opened SHIFT_RFL_20220420_band111_band220_clipped.tif, shape: (5843, 4552, 110)
   final image shape: (5843, 4552, 220)
opening SHIFT_RFL_20220420_band221_band330_clipped.tif
  opened SHIFT_RFL_20220420_band221_band330_clipped.tif, shape: (5843, 4552, 110)
   final image shape: (5843, 4552, 330)
opening SHIFT_RFL_20220420_band331_band425_clipped.tif
  opened SHIFT_RFL_20220420_band331_band425_clipped.tif, shape: (5843, 4552, 95)
   final image shape: (5843, 4552, 425)
start saving final imagery


### Clip the SHIFT imagery to the experimental, training, testing, and validation areas, and reproject it from WGS 84 / UTM zone 10N to UTM zone 11N to match the projection of Planet and EMIT

In [15]:
data_path = "/Users/fji/Desktop/SHIFT/processed_imagery/"
file_name = f"{data_path}SHIFT_RFL_20220420_clipped.tif"

shp_path = "/Volumes/ChenLab-1/Fujiang/2_SmallSat_project/3_paired_SHIFT_Planet/1_20230422/4_data_fusion_algorithm/1_original_data/0_experimental_shp/"
shp1 = "1_experimental_shp.shp"
shp2 = "2_whole_training_shp.shp"
shp3 = "3_testing_shp.shp"
shp4 = "4_validation_shp.shp"


shps = [shp1, shp2, shp3, shp4]
out_tifs = ["1_SHIFT_small_area.tif", "2_SHIFT_whole_area.tif", "3_SHIFT_testing_area.tif", "4_SHIFT_validating_area.tif"]

# Each shapefile is paired with the output name at the same position.
for shp, out_name in zip(shps, out_tifs):
    gdf = gpd.read_file(f"{shp_path}{shp}")
    bounds = gdf.bounds
    # Only the first feature's bounding box is used.
    min_x = bounds["minx"].values[0]
    min_y = bounds["miny"].values[0]
    max_x = bounds["maxx"].values[0]
    max_y = bounds["maxy"].values[0]

    input_tif = file_name
    # os.path.join avoids the doubled "/" produced when data_path already ends with one.
    output_tif = os.path.join(data_path, out_name)
    # NOTE(review): the bounds are passed through unchanged, so the shapefiles are
    # presumably already in EPSG:32611 (the destination SRS) — confirm.
    warped = gdal.Warp(output_tif, input_tif, format = 'GTiff', outputBounds=(min_x, min_y, max_x, max_y), dstSRS='EPSG:32611')
    if warped is None:
        raise RuntimeError(f"gdal.Warp failed for {output_tif}")
    warped = None  # releasing the dataset (not the path string) closes and flushes the file