# In [1]:
def plot_SR(ds, title_plot, band=None, rgb=False, multiband=False):
    '''
    Plot surface reflectance as an RGB composite, a single band, or a band grid.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Data to plot. ``rgb=True`` needs a Dataset with 'R', 'G' and 'B'
        variables. ``multiband=True`` or a non-None ``band`` needs a Dataset.
        With ``band=None`` and ``rgb=False``, ``ds`` is indexed as
        ``ds[0, :, :]``, i.e. a 3-D array whose first slice is shown.
    title_plot : str
        Axes title for the RGB / single-band plot. It is ignored in multiband
        mode, where each panel is titled with its band name.
    band : str, optional
        Variable of ``ds`` to plot in single-band mode. Its first slice along
        the leading axis is shown.
    rgb : bool, default False
        Draw a composite of the 'R', 'G', 'B' variables, with a robust
        (percentile-based) colour stretch.
    multiband : bool, default False
        Draw every data variable of ``ds`` in its own panel, four per row,
        each with its own colorbar.

    Example
    -------
    plot_SR(ds, 'Red band', band='R')
    plot_SR(ds, 'True colour', rgb=True)
    plot_SR(ds, None, multiband=True)
    '''
    if multiband:
        band_names = list(ds.data_vars)
        ncols = 4
        # Ceil-divide so any number of bands fits (8 bands -> the original 2x4).
        nrows = max(1, -(-len(band_names) // ncols))
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                                 figsize=(17, 3.5 * nrows), squeeze=False)
        axes = axes.ravel()
        for ax, name in zip(axes, band_names):
            data = ds[name]
            # The colour limits span every slice of the band, although only
            # the first slice is drawn.
            img = ax.imshow(data[0, :, :],
                            aspect='auto',
                            vmin=np.nanmin(data),
                            vmax=np.nanmax(data),
                            cmap='RdBu')
            # Attach the colorbar to this panel explicitly. A bare
            # plt.colorbar() can steal space from the wrong axes.
            cb = fig.colorbar(img, ax=ax)
            cb.set_label('Reflectance')
            ax.set_title('Band: ' + name)
            ax.set_axis_off()
        # Hide grid cells that have no band to show.
        for ax in axes[len(band_names):]:
            ax.set_visible(False)
        return

    if rgb:
        fig, ax = plt.subplots()
        # to_array() stacks R, G, B along a new 'variable' dimension. Naming
        # it as the RGB dimension avoids ambiguity when x/y also has size 3 or 4.
        ds[['R', 'G', 'B']].squeeze().to_array().plot.imshow(
            ax=ax, robust=True, rgb='variable')
    else:
        data = ds[0, :, :] if band is None else ds[band][0, :, :]
        fig, ax = plt.subplots()
        # NaN-aware limits so masked (NaN) pixels don't break the stretch.
        ax.imshow(data, aspect='auto',
                  vmin=np.nanmin(data),
                  vmax=np.nanmax(data))

    # Set after xarray's imshow, which writes its own slice title.
    ax.set_title(title_plot)
    ax.set_axis_off()
        


def plot_mask(ds, mask_array, granule_id, title, plot_rgb=False):
    '''
    Plot a mask by itself, or show its effect on an RGB composite.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with 'R', 'G' and 'B' variables. Only used when
        ``plot_rgb`` is True.
    mask_array : xarray.DataArray
        Mask in which NaN marks pixels to remove and any non-NaN value marks
        pixels to keep (this is how the RGB branch interprets it). It is not
        modified.
    granule_id : str
        Shown above ``title`` in the mask-only plot. Unused when ``plot_rgb``
        is True.
    title : str
        Title text for the mask-only plot, or suffix for both RGB panels.
    plot_rgb : bool, default False
        If False, draw the mask alone with a two-colour colormap. If True,
        draw the RGB composite before and after masking side by side, with
        the removed pixels overlaid on the "after" panel.
    '''
    if not plot_rgb:
        mask_2d = mask_array.squeeze()
        fig, ax = plt.subplots()
        # plt.get_cmap (unlike the removed plt.cm.get_cmap) works on both old
        # and current matplotlib. lut=2 gives a two-colour map.
        ax.imshow(mask_2d,
                  cmap=plt.get_cmap('tab20b', 2),
                  vmin=np.nanmin(mask_2d),
                  vmax=np.nanmax(mask_2d))
        ax.set_axis_off()
        ax.set_title(granule_id + '\n' + title)
        return

    masked = np.isnan(mask_array.values)
    # Multiplicative keep-mask: 1 on kept pixels, NaN on removed ones, so
    # multiplying a band by it blanks the removed pixels.
    keep = np.where(masked, np.nan, 1.0)
    # Overlay layer: a constant on removed pixels and NaN elsewhere, so only
    # the removed pixels are drawn on top of the masked RGB.
    overlay = mask_array.copy(data=np.where(masked, 1111.0, np.nan))

    # Assigning new variables to the shallow copy leaves `ds` untouched.
    ds_masked = ds.copy()
    for name in ('R', 'G', 'B'):
        ds_masked[name] = ds_masked[name] * keep

    fig, axes = plt.subplots(ncols=2, figsize=(12, 5))
    panels = zip(axes, (ds, ds_masked), ('before masking', 'after masking'))
    for i, (ax, data, label) in enumerate(panels):
        data[['R', 'G', 'B']].squeeze().to_array().plot.imshow(
            ax=ax, robust=True, rgb='variable')
        if i == 1:
            overlay.squeeze().plot.imshow(ax=ax, robust=True,
                                          add_colorbar=False)
        # Title last: xarray's imshow overwrites the axes title.
        ax.set_title(label + ' ' + title)
        ax.set_axis_off()
        

def plot_histogram(ds,
                   band=None,
                   num_bins=200,
                   sat_id=None,
                   show_stat=True,
                   histtype='bar',
                   alpha=1,
                   multiband=False
                   ):
    '''
    Plot histogram(s) of surface reflectance values.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Data to histogram. With ``band=None`` in single mode, ``ds`` itself
        is histogrammed, so it should be a DataArray.
    band : str, optional
        Variable of ``ds`` to histogram in single mode. Its ``long_name``
        attribute (or the band name if absent) is used in the label.
    num_bins : int, default 200
        Number of histogram bins.
    sat_id : str, optional
        Satellite identifier prefixed to labels and titles. It is omitted
        when None.
    show_stat : bool, default True
        Single mode only: put the 99th percentile and mean in the title.
    histtype : str, default 'bar'
        Matplotlib histogram type for single mode. Multiband panels always
        use 'step'.
    alpha : float, default 1
        Histogram transparency.
    multiband : bool, default False
        Draw one histogram per data variable of ``ds``, four panels per row,
        each titled with its 99th percentile and mean.
    '''
    def _label(*parts):
        # Join the non-empty parts with ", ". A missing sat_id is dropped
        # instead of rendering as "None" (or crashing string concatenation).
        return ', '.join(str(p) for p in parts if p is not None and str(p))

    if not multiband:
        if band is None:
            data = ds
            long_name = ''
        else:
            data = ds[band]
            long_name = str(data.attrs.get('long_name', band))
        label = _label(sat_id, long_name)

        fig, ax = plt.subplots()
        # NaNs are dropped by matplotlib's binning, so masked pixels are ignored.
        ax.hist(data.to_numpy().ravel(),
                bins=num_bins,
                histtype=histtype,
                label=label,
                alpha=alpha)
        if show_stat:
            # float() turns the 0-d DataArray results into plain numbers that
            # accept format specs on any xarray version.
            p99 = float(data.quantile(q=.99))
            mean = float(data.mean())
            stats = f"99th per: {p99:.3f}, mean: {mean:.3f}"
            ax.set_title(f"{label}, \n {stats}" if label else stats)
        ax.set_xlabel("Surface Reflectance")
        ax.set_ylabel("Pixel Count")
        return

    band_names = list(ds.data_vars)
    ncols = 4
    # Ceil-divide so any number of bands fits (8 bands -> the original 2x4).
    nrows = max(1, -(-len(band_names) // ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(17, 3.5 * nrows),
                             squeeze=False)
    axes = axes.ravel()
    for ax, name in zip(axes, band_names):
        data = ds[name]
        p99 = float(data.quantile(q=.99))
        mean = float(data.mean())
        ax.hist(data.to_numpy().ravel(),
                bins=num_bins,
                histtype='step',
                label=_label(sat_id, name),
                alpha=alpha)
        ax.legend()
        ax.set_xlabel("Surface Reflectance")
        ax.set_ylabel("Pixel Count")
        ax.set_title(f"{name}, 99th per: {p99:1.3f}, mean: {mean:1.3f}")
    # Hide grid cells that have no band to show.
    for ax in axes[len(band_names):]:
        ax.set_visible(False)
    fig.tight_layout()
        

    
    
    
    
    


