# Notebook to adjust radiometry of Planet images using snow-covered pixels and dark areas

Rainey Aberle

Spring 2022

In [None]:
import os
import glob
import matplotlib.pyplot as plt
import numpy as np
import rasterio as rio
from shapely.geometry import Polygon
from scipy.interpolate import interp2d

In [None]:
# -----Output switches
# save_images: when True, adjusted raster image files are written to out_path
# save_figures: when True, output figures are written to figures_out_path
save_figures = True
save_images = True

# ----Study site ID (appears in output file names)
site_ID = 'GG'

# -----Directory layout
# folder holding the Planet images
im_path = '/Users/raineyaberle/Research/PhD/study-sites/Gulkana/imagery/2021-04-01_2021-10-01/PSScene4Band/'
# folder for adjusted images, expressed relative to im_path
out_path = f'{im_path}../filtered-adjusted-radiometry/'
# folder for figures, expressed relative to im_path
figures_out_path = f'{im_path}../../../figures/filtered-adjusted-radiometry/'

In [None]:
# -----Load Planet image and metadata file names from directory
# Work from inside the image directory so the globbed names stay bare file
# names (no folder prefix); os.chdir returns None, so nothing is kept from it.
os.chdir(im_path)
# lexicographic sort of the file names (the original notes this orders them by date)
im_names = sorted(glob.glob('*SR_clip.tif')) # surface reflectance image file names
meta_names = sorted(glob.glob('*metadata_clip.xml')) # metadata file names

# ----Create output folders if they do not exist
# os.makedirs also creates missing parent folders, which os.mkdir would fail on
for folder in (out_path, figures_out_path):
    if not os.path.isdir(folder):
        os.makedirs(folder, exist_ok=True)
        print(folder+' directory made')

In [None]:
# -----Define 'bright' area to use for image radiometric adjustment

# define minx, maxx, miny, and maxy for bright area
# NOTE(review): presumably in the same projected coordinates (metres) as the
# image bounds, since they are compared against them directly - confirm
# Wolverine
# bright_minx, bright_maxx, bright_miny, bright_maxy = 393.5e3, 396.5e3, 6699.2e3, 6700.5e3
# Gulkana
bright_minx, bright_maxx, bright_miny, bright_maxy = 576.5e3, 577.6e3, 7017.4e3, 7018.1e3

# create Shapely Polygon of bright area (rectangle, first vertex repeated to close the ring)
bright_poly = Polygon([[bright_minx, bright_miny],
                       [bright_maxx, bright_miny],
                       [bright_maxx, bright_maxy],
                       [bright_minx, bright_maxy],
                       [bright_minx, bright_miny]])

# load one image to use as a backdrop for the bright area outline
if not im_names:
    raise FileNotFoundError('No *SR_clip.tif images found in '+im_path)
# use the 4th image when available, otherwise fall back to the last one
example_index = min(3, len(im_names)-1)
im_scalar = 10000 # scalar multiplier for image reflectance values
# the context manager closes the dataset once everything needed has been read
with rio.open(im_names[example_index]) as im:
    # define bands (blue, green, red, near infrared)
    b = im.read(1).astype(float) / im_scalar
    g = im.read(2).astype(float) / im_scalar
    r = im.read(3).astype(float) / im_scalar
    nir = im.read(4).astype(float) / im_scalar
    # define coordinates grid (one value per column / row, spanning the image bounds)
    x = np.linspace(im.bounds.left, im.bounds.right, num=np.shape(b)[1])
    y = np.linspace(im.bounds.top, im.bounds.bottom, num=np.shape(b)[0])

# plot image with the bright area outline (coordinates converted from m to km)
fig = plt.figure(figsize=(8,8))
plt.rcParams.update({'font.size': 12, 'font.serif': 'Arial'})
plt.imshow(np.dstack([r, g, b]), extent=(np.min(x)/1000, np.max(x)/1000, np.min(y)/1000, np.max(y)/1000))
poly_x, poly_y = bright_poly.exterior.xy
plt.plot(np.asarray(poly_x)/1000, np.asarray(poly_y)/1000,
         color='black', linewidth=2, label='SCA')
plt.xlabel('Easting [km]')
plt.ylabel('Northing [km]')
plt.legend()
plt.show()

In [None]:
# -----Apply radiometric correction by 'stretching' image to bright and darkest points

# Define desired SR values at the bright area and darkest point for each band
# bright area
bright_b_adj = 0.94
bright_g_adj = 0.95
bright_r_adj = 0.94
bright_nir_adj = 0.78
# dark point
dark_adj = 0.0

# maximum cloud cover (skips image if cloud cover exceeds max cloud cover)
max_cloud_cover = 20.0

# scalar multiplier for image reflectance values
im_scalar = 10000


def _bright_area_median(band, rows_in_area, cols_in_area):
    """Return the median of `band` over the bright area.

    `rows_in_area` / `cols_in_area` are boolean masks over the band's rows and
    columns. Returns NaN when the masks select no pixels.
    """
    values = band[np.ix_(rows_in_area, cols_in_area)]
    if values.size == 0:
        return np.nan
    return np.median(values)


def _stretch_band(band, bright, bright_target, dark_target):
    """Linearly rescale `band` so `bright` maps to `bright_target` and the band minimum to `dark_target`.

    band_adjusted = band*A - B
    A = (bright_adjusted - dark_adjusted) / (bright - dark)
    B = (dark*bright_adjusted - bright*dark_adjusted) / (bright - dark)

    Pixels equal to 0 in the input are treated as no data and returned as NaN.
    The caller must make sure that `bright` differs from the band minimum.
    """
    # NOTE(review): the minimum is taken over ALL pixels, including the 0-valued
    # no-data pixels, so `dark` is 0 whenever the image has gaps - confirm intended
    dark = np.min(band)
    A = (bright_target - dark_target) / (bright - dark)
    B = (dark*bright_target - bright*dark_target) / (bright - dark)
    return np.where(band==0, np.nan, (band * A) - B) # replace no data values with nan


plt.rcParams.update({'font.size': 12, 'font.serif': 'Arial'})

# extent and outline of the bright area
bright_xmin, bright_ymin, bright_xmax, bright_ymax = bright_poly.bounds
poly_x_km = np.asarray(bright_poly.exterior.xy[0]) / 1000
poly_y_km = np.asarray(bright_poly.exterior.xy[1]) / 1000

# Loop through images
y_count = 0 # counter for adjusted images
n_count = 0 # counter for UN-adjusted / filtered out images
instrument = []
cloud_cover = []
for im_name in im_names:

    # load image; the context manager closes the file once everything is read
    with rio.open(im_name) as im:
        # define bands (blue, green, red, near infrared)
        b = im.read(1).astype(float) / im_scalar
        g = im.read(2).astype(float) / im_scalar
        r = im.read(3).astype(float) / im_scalar
        nir = im.read(4).astype(float) / im_scalar
        bounds = im.bounds
        out_meta = im.meta.copy() # carries the crs and transform of the source image

    # define coordinates grid
    x = np.linspace(bounds.left, bounds.right, num=np.shape(b)[1])
    y = np.linspace(bounds.top, bounds.bottom, num=np.shape(b)[0])

    # skip images that do not fully contain the bright area
    if not (bright_xmin > np.min(x) and bright_xmax < np.max(x)
            and bright_ymin > np.min(y) and bright_ymax < np.max(y)):
        n_count += 1
        continue

    # SR at the bright area for each band. x and y are the image grid itself,
    # so the bright-area samples are exactly the pixels picked by these masks;
    # reading them directly replaces scipy's interp2d (deprecated, removed in
    # SciPy 1.14), which was only ever evaluated at grid nodes.
    cols_in_area = (x > bright_xmin) & (x < bright_xmax)
    rows_in_area = (y > bright_ymin) & (y < bright_ymax)
    bright_b = _bright_area_median(b, rows_in_area, cols_in_area)
    bright_g = _bright_area_median(g, rows_in_area, cols_in_area)
    bright_r = _bright_area_median(r, rows_in_area, cols_in_area)
    bright_nir = _bright_area_median(nir, rows_in_area, cols_in_area)

    # check that every band has a valid, positive value at the bright area that
    # differs from its darkest value (otherwise the stretch divides by zero)
    band_checks = ((b, bright_b), (g, bright_g), (r, bright_r), (nir, bright_nir))
    if not all(np.isfinite(bright) and bright > 0 and bright > np.min(band)
               for band, bright in band_checks):
        n_count += 1
        continue

    # load instrument name and cloud cover percentage from metadata
    meta_name = next((name for name in meta_names if im_name[0:24] in name), None)
    if meta_name is None:
        # without metadata the cloud cover filter cannot be applied; skip rather
        # than silently reuse the values read for a previous image
        print('no metadata file found for '+im_name+', image skipped')
        n_count += 1
        continue
    with open(meta_name) as meta:
        meta_content = meta.readlines()
    # NOTE(review): relies on the instrument name and cloud cover sitting on
    # fixed lines (54 and 149) of the metadata file - confirm for new data sets
    # read instrument name from the file
    inst = meta_content[53].split('>')[1]
    if "PS2" in inst:
        inst = inst[0:3]
    elif "PSB" in inst:
        inst = inst[0:6]
    else:
        inst = inst.split('<')[0] # drop the closing tag so it cannot leak into file names
    instrument.append(inst)
    # read cloud cover percentage from the file
    cc = float(meta_content[148].split('>')[1].split('<')[0])
    cloud_cover.append(cc)

    # continue to next iteration if cloud cover is above the maximum
    if cc > max_cloud_cover:
        n_count += 1
        continue

    y_count += 1 # image passed all filters and will be adjusted

    # adjust SR using bright and dark points
    b_adj = _stretch_band(b, bright_b, bright_b_adj, dark_adj)
    g_adj = _stretch_band(g, bright_g, bright_g_adj, dark_adj)
    r_adj = _stretch_band(r, bright_r, bright_r_adj, dark_adj)
    nir_adj = _stretch_band(nir, bright_nir, bright_nir_adj, dark_adj)

    # plot results
    extent_km = (np.min(x)/1000, np.max(x)/1000, np.min(y)/1000, np.max(y)/1000)
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2,2,figsize=(16,10),gridspec_kw={'height_ratios': [3,1]})
    # original image
    ax1.imshow(np.dstack([r, g, b]), extent=extent_km)
    ax1.plot(poly_x_km, poly_y_km, color='black', linewidth=2, label='SCA')
    ax1.legend()
    ax1.set_xlabel('Easting [km]')
    ax1.set_ylabel('Northing [km]')
    ax1.set_title('Original image')
    # adjusted image
    ax2.imshow(np.dstack([r_adj, g_adj, b_adj]), extent=extent_km)
    ax2.plot(poly_x_km, poly_y_km, color='black', linewidth=2, label='SCA')
    ax2.set_xlabel('Easting [km]')
    ax2.set_title('Adjusted image')
    # histograms
    h1_nir = ax3.hist(nir.flatten(), color='purple', bins=100, alpha=0.5, label='NIR')
    h1_b = ax3.hist(b.flatten(), color='blue', bins=100, alpha=0.5, label='blue')
    h1_g = ax3.hist(g.flatten(), color='green', bins=100, alpha=0.5, label='green')
    h1_r = ax3.hist(r.flatten(), color='red', bins=100, alpha=0.5, label='red')
    # y-limit from the tallest bin, ignoring the first bin of each histogram
    hist_ymax = np.max([h1_nir[0][1:], h1_g[0][1:], h1_r[0][1:], h1_b[0][1:]])+5000
    ax3.set_xlabel('Surface reflectance')
    ax3.set_ylabel('Pixel counts')
    ax3.grid()
    ax3.legend(loc='right')
    ax3.set_ylim(0, hist_ymax)
    ax4.hist(nir_adj.flatten(), color='purple', bins=100, alpha=0.5, label='NIR')
    ax4.hist(b_adj.flatten(), color='blue', bins=100, alpha=0.5, label='blue')
    ax4.hist(g_adj.flatten(), color='green', bins=100, alpha=0.5, label='green')
    ax4.hist(r_adj.flatten(), color='red', bins=100, alpha=0.5, label='red')
    ax4.set_xlabel('Surface reflectance')
    ax4.set_ylim(0, hist_ymax) # same limit as the original-image histograms
    ax4.grid()
    fig.suptitle(im_name[0:8]+' '+im_name[9:11]+':'+im_name[11:13]+':'+im_name[13:15]+', Inst: '+inst)
    fig.tight_layout()

    # save adjusted raster to file
    if save_images:
        # file name
        fn = im_name[0:-4]+'_'+inst+'_adj.tif'
        # metadata: float output with NaN marking the no-data pixels
        out_meta.update({'driver':'GTiff',
                         'width':b_adj.shape[1],
                         'height':b_adj.shape[0],
                         'count':4,
                         'dtype':'float64',
                         'nodata':np.nan})
        # write to file
        with rio.open(out_path+fn, mode='w',**out_meta) as dst:
            dst.write(b_adj, 1)
            dst.write(g_adj, 2)
            dst.write(r_adj, 3)
            dst.write(nir_adj, 4)
        print('adjusted image saved to file')

    # save output figure to file (before showing it, so the saved figure does
    # not depend on what the backend does to the figure in plt.show)
    if save_figures:
        # file name
        fn = site_ID+'_'+im_name[0:8]+'_adj.png'
        # save
        fig.savefig(figures_out_path+fn, dpi=200, facecolor='white', edgecolor='none')
        print('figure saved to file')

    plt.show()
    plt.close(fig) # release the figure so they do not pile up over the loop

print('Number of images:',len(im_names))
print('Number of adjusted images:', y_count)
print('Number of UN-adjusted/filtered out images:', n_count)

In [None]:
# Define image class and list of images
# class image:
#     def __init__(self, im_name, x, y, b, g, r, nir):
#         self.im_name = im_name
#         self.x = x
#         self.y = y
#         self.b = b
#         self.g = g
#         self.r = r
#         self.nir = nir
#     def show(self):
#         print('Image file name:',self.im_name)
# im_list = [] # list for saving image info


    # save image info
#     newImage = image(im_name, x, y, b, g, r, nir)
#     im_list.append(newImage)