In [16]:
def get_raster(compare_type,run_id, granule_id, band, step, sr_key_format=None, file_format=None):
    tile   = granule_id.split(".")[2]
    sat_id = granule_id.split(".")[1]
    
    # Map Spectral Bands
    # Ignore Coastal aerosol and Cirrus bands
    sr_bands_l30 = ["B02","B03","B04","B05","B06","B07", "Fmask"]
    sr_bands_s30 = ["B02","B03","B04","B8A","B11","B12", "Fmask"]
    common_bands = ["B","G","R","NIR","SWIR1", "SWIR2", 'Fmask']

    sr_bands_common_l30 = dict(zip(common_bands,sr_bands_l30))
    sr_bands_common_s30 = dict(zip(common_bands,sr_bands_s30))

    
    # L30 intermediate outputs
    post_AC_L30   = "gridded"      ## After atmosphetic correction
    post_BRDF_L30 = granule_id     ## After BRDF normalization
    post_BP_L30   = post_BRDF_L30  ## After Band pass adjustment
    inter_L30     = [post_AC_L30, post_BRDF_L30, post_BP_L30]

    ## S30 intermediate outputs
    post_AC_S30   = "resample30m"      ## After atmospheric correction
    post_BRDF_S30 = "nbarIntermediate" ## After BRDF normalization
    post_BP_S30   = granule_id         ## After band pass adjustment
    inter_S30     = [post_AC_S30, post_BRDF_S30, post_BP_S30]

    ## common intermediate outputs
    common_fname  = ["post_AC", "post_BRDF", "post_BP"]

    intermediate_common_l30 = dict(zip(common_fname,inter_L30))
    intermediate_common_s30 = dict(zip(common_fname,inter_S30))

    if sat_id == "L30":
        try:
            b = sr_bands_common_l30[band] 
        except:
            pass
        s = intermediate_common_l30[step]
    elif sat_id == "S30":
        try:
            b = sr_bands_common_s30[band]
        except:
            pass
        s = intermediate_common_s30[step]
        
    if file_format is None:
        file_format = '.tif'
    else:
        file_format = file_format

    if band == '':
        product_key = f"{s}"
    else:
        if 'sat' in compare_type:
            product_key = f"{s}.{b}"
        else:
            product_key = f"{s}+{b}"
    
    if sr_key_format is None:
        sr_key = f"{bucket}/{run_id}/{granule_id}/"+product_key+file_format
    else:
        sr_key = sr_key_format+product_key+file_format
    
    return sr_key



In [57]:
def add_regression_line(x, y, ax):
    """
        Calculate and plot a linear regression line 
    """
    # remove nans
    mask = ~np.isnan(x) & ~np.isnan(y)
    slope, intercept, r, p, se = scipy.stats.linregress(x[mask], y[mask])
    line = f'y={intercept:.3f}+{slope:.3f}x,\n$R^2$={r*r:.5f}'
    ax.plot(x, intercept + slope * x, label=line, color='k')
    ax.legend(loc=2, frameon=False)
    


def plot_scatter(sr_ds1, sr_ds2, p1_label, p2_label, ax):
    """
    """
    #fig, ax = plt.subplots(figsize=(6, 4))

    ## Set common axes limits
    dat_min = np.nanmin([np.nanmin(sr_ds1.squeeze()), np.nanmin(sr_ds2.squeeze())])
    dat_max = np.nanmax([np.nanmax(sr_ds1.squeeze()), np.nanmax(sr_ds2.squeeze())])
    

    density_hist = ax.hist2d(sr_ds1.squeeze().values.ravel(), sr_ds2.squeeze().values.ravel(),
                            range=[[dat_min, dat_max], [dat_min, dat_max]],
                            cmin = 1,
                            norm=mpl_colors.LogNorm(vmin=1,vmax=5E6),
                            cmap = 'viridis',
                            bins=200)
    # plot ideal correlation line
    line = mlines.Line2D([dat_min, dat_max], [dat_min, dat_max], color='red')
    ax.add_line(line)
    # plot regression line
    add_regression_line(sr_ds1.squeeze().values.ravel(), sr_ds2.squeeze().values.ravel(), ax)
    plt.colorbar(density_hist[3], ax=ax, label='Number of points per pixel')  
    ax.set_xlabel(p1_label)
    ax.set_ylabel(p2_label)
    ax.set_xlim(dat_min, dat_max)
    ax.set_ylim(dat_min, dat_max)

def get_stats_histogram(ds):
    
    #fig, ax = plt.subplots(figsize=(6, 4))
    hist_99_per = ds.quantile(q=.99)
    hist_mean = ds.mean()
    sub_title   = f", \n 99th per: {hist_99_per:.3f}, mean: {hist_mean:.3f}"

    return sub_title
    
def plot_histogram(sr_ds1, sr_ds2, p1_label, p2_label, ax):
    """
    """
    
    
    hist1  = ax.hist(sr_ds1.to_numpy().ravel(), 
                        100, 
                        histtype='bar', 
                        label=p1_label+get_stats_histogram(sr_ds1), 
                        alpha= 0.5)
    hist2  = ax.hist(sr_ds2.to_numpy().ravel(), 
                        100, 
                        histtype='bar', 
                        label=p2_label+get_stats_histogram(sr_ds2), 
                        alpha= 0.5)
    ax.legend(loc='upper right',  frameon=False)    

    ax.set_xlabel("Surface Reflectance")
    ax.set_ylabel("Pixel Count")
    ax.set_aspect('auto')
    
        
def plot_reflectance_difference(sr_ds1, sr_ds2, p1_label, p2_label, ax):
    """
    """
    #fig, ax = plt.subplots(figsize=(8, 6))
    
    def get_reflect_diff(dat1, dat2, abs_diff=False):
        """
            Calculate the difference in surface reflectance
            between two rasters
        """
        if len(dat1) == len(dat2):
            if abs_diff:
                diff = np.abs(dat1) - np.abs(dat2)
            else:
                diff = dat1 - dat2
        else:
            raise ValueError("data unequal size")

        return diff
    diff = get_reflect_diff(sr_ds1.squeeze(), 
                                     sr_ds2.squeeze())
    vmin = -0.05
    vmax = 0.05
    im = ax.imshow(diff, 
                    aspect="auto",
                    vmin=vmin, vmax=vmax,cmap="RdBu")
    plt.colorbar(im, ax=ax, cmap="RdBu", label= f"{p1_label} - {p2_label}")
    ax.set_xticklabels([]), ax.set_yticklabels([])
    ax.set_title("Surface Reflectance Difference")
    

def combine_plot(sr_ds1,sr_ds2,processing_step1,band_selector1,processing_step2,band_selector2,axes):
    p1_label = f"Product 1 {processing_step1} {band_selector1}"
    p2_label = f"Product 2 {processing_step2} {band_selector2}"


    index=0
    for ax in axes:
        # Plot scatter
        if index == 0:
            plot_histogram(sr_ds1, sr_ds2, p1_label, p2_label,ax)
        # Plot histogram
        if index == 1:
            plot_scatter(sr_ds1, sr_ds2, p1_label, p2_label,ax)
        # Plot reflectance difference
        if index == 2:
            plot_reflectance_difference(sr_ds1, sr_ds2, p1_label, p2_label,ax)

        index=index+1


# Time Series Visualization

In [1]:
def apply_SR_scale_factor(data):
    """
        Surface reflectance have scale factors and offset.
        Apply this to each pixel the data.
        Note: Collection 1 and Collection 2 have different 
              scale factors. Always get this from metadata
        See:
        https://www.usgs.gov/faqs/how-do-i-use-scale-factor-landsat-level-2-science-products#:~:text=Landsat%20Collection%202%20surface%20reflectance,offset%20of%20%2D0.2%20per%20pixel.
    """
    sf = data.attrs["scale_factor"] # scale factor
    ao = data.attrs["add_offset"]   # additional offset
    
    data.values = (data.values * sf) + ao    
    #return data

# Plot the bands
def plot_bands_time_series(ax, band_data, cbar_limits=None, sat_id=None, point=False,x=False,y=False):
        
        if cbar_limits is not None:
            im = ax.imshow(band_data, 
                           vmin = cbar_limits[0], 
                           vmax = cbar_limits[1],
                          cmap = 'RdBu'
                          )
            
        else:
            vmin = np.nanmin(band_data)
            vmax = np.nanmax(band_data)
            
            #im = ax.imshow(band_data,
             #             cmap = 'RdBu'
                           #vmin = vmin, 
                           #vmax = vmax
                  #  )
            
            im = band_data.plot(ax=ax)
                          
                
        if point is True:        
            ax.plot(x,y, 'bo', markersize=10)
        sub_title = sat_id
        ax.set_title(sub_title)
        ax.set_aspect('auto')


IndentationError: unexpected indent (87660770.py, line 34)