In [450]:
import xarray as xr
import rioxarray
import numpy as np
import matplotlib.pyplot as plt
import rasterio as rio
import geopandas as gpd
from glob import glob
import os
from matplotlib.colors import ListedColormap, BoundaryNorm
from utils import draw_legend
from config import Config
import re
from datetime import datetime
import sys
from rasterio.enums import Resampling
import pandas as pd

In [451]:
import warnings

# Suppress every Python warning for the rest of the kernel session.
warnings.simplefilter('ignore')

# Project configuration object; only the directory separator is pulled out here.
config = Config()
dir_separator = config.dir_sep

In [452]:
# NOTE: fire 25993 (Landsat) still has many dNBR pairs left; optional, but re-run it if they are needed.
fire_id = '9844'
satellite = 'Sentinel'

# Every event-specific path hangs off this root folder.
fire_events_root = '/home/aramakrishnan/Documents/Firedpy/Fire_events'
image_dir = f'{fire_events_root}/{fire_id}/{satellite}/'

# Sentinel scenes are .SAFE folders; Landsat scenes are folders ending in _T1 / _T2.
scene_pattern = '*.SAFE' if satellite == 'Sentinel' else '*_T[12]'
file_address = glob(os.path.join(image_dir, scene_pattern))

shp_filepath = f'{fire_events_root}/selected_events.shp'

In [453]:
# Scene folders matched by the glob above (*.SAFE for Sentinel, *_T1 / *_T2 for Landsat).
file_address

['/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2A_MSIL2A_20210130T151701_N0500_R125_T18NZJ_20230608T061259.SAFE',
 '/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2A_MSIL2A_20210301T151701_N0500_R125_T19NBE_20230514T042315.SAFE',
 '/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2B_MSIL2A_20210115T151709_N0500_R125_T19NBD_20230611T085846.SAFE',
 '/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2B_MSIL2A_20210115T151709_N0500_R125_T18NZK_20230611T085846.SAFE',
 '/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2B_MSIL2A_20210105T151709_N0500_R125_T19NBE_20230323T055111.SAFE',
 '/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2A_MSIL2A_20201211T151701_N9999_R125_T18NZK_20230317T111349.SAFE',
 '/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2B_MSIL2A_20210105T151709_N0500_R125_T18NZJ_20230323T055111.SAFE',
 '/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Senti

In [454]:
# Delete unnecessary files and folders from a fire event's scene directory.

def clean_up_data_folder(event_id=None, sensor=None, directory=None):
    '''
    Delete every file under a fire-event scene folder except the band/mask files
    used later in this notebook, then remove directories left empty.

    Kept files:
      * Sentinel: B12/B8A 20 m JP2, B02/B03/B04/B08 10 m JP2,
        Sentinel_footprints.geojson and MSK_CLASSI_B00.jp2
      * otherwise (Landsat): B2, B3, B4, B5, B7 TIFs

    Parameters
    ----------
    event_id : str, optional
        Fire event ID used to build the default folder; defaults to the
        notebook-level ``fire_id``.
    sensor : str, optional
        'Sentinel' or any other value (treated as Landsat); defaults to the
        notebook-level ``satellite``.
    directory : str, optional
        Folder to clean; defaults to ``.../Fire_events/<event_id>/<sensor>``.

    Notes
    -----
    Deletion is best effort: failures are printed and the walk continues.
    '''
    if sensor is None:
        sensor = satellite
    if directory is None:
        event = fire_id if event_id is None else event_id
        directory = f'/home/aramakrishnan/Documents/Firedpy/Fire_events/{event}/{sensor}'

    # File names to keep (all anchored, with literal dots escaped).
    if sensor == 'Sentinel':
        keep_file_patterns = [
            r'^.*_B12_20m\.jp2$',
            r'^.*_B8A_20m\.jp2$',
            r'^.*_B0[2348]_10m\.jp2$',
            r'^.*Sentinel_footprints\.geojson$',
            r'^.*MSK_CLASSI_B00\.jp2$',
        ]
    else:
        keep_file_patterns = [
            r'^.*_B5\.TIF$',
            r'^.*_B7\.TIF$',
            r'^.*_B4\.TIF$',
            r'^.*_B3\.TIF$',
            r'^.*_B2\.TIF$',
        ]
    patterns = [re.compile(pattern) for pattern in keep_file_patterns]

    # Pass 1: delete every file that no keep-pattern matches.
    for root, _, files in os.walk(directory):
        for filename in files:
            if any(pattern.match(filename) for pattern in patterns):
                continue
            file_path = os.path.join(root, filename)
            try:
                os.remove(file_path)
                print(f"Deleted: {file_path}")
            except OSError as e:
                print(f"Error deleting {file_path}: {e}")

    # Pass 2: walk bottom-up, so children are removed before their parent is
    # checked. One pass then clears whole chains of nested empty directories.
    for root, dirs, _ in os.walk(directory, topdown=False):
        for dir_name in dirs:
            dir_path = os.path.join(root, dir_name)
            # isdir() skips broken symlinks, which listdir() would choke on.
            if os.path.isdir(dir_path) and not os.listdir(dir_path):
                try:
                    os.rmdir(dir_path)
                    print(f"Deleted empty directory: {dir_path}")
                except OSError as e:
                    print(f"Error deleting directory {dir_path}: {e}")

In [455]:
# Repeated passes guard against folders that only become empty after an earlier pass.
for _ in range(10):
    clean_up_data_folder()

In [456]:
# Load all selected fire events and list their IDs (the 'id' column).
fire_events = gpd.read_file(shp_filepath)
fire_event_ids = fire_events['id'].tolist()
fire_event_ids

['9844',
 '11862',
 '12887',
 '202',
 '280',
 '5072',
 '5532',
 '7123',
 '25993',
 '47422',
 '2282',
 '2614',
 '4413',
 '5587',
 '5615',
 '5950',
 '2819',
 '5447',
 '7792',
 '8676']

In [457]:
# #TO DELETE ALL UNNECESSARY FILES IN ONE GO
# for fid in fire_event_ids:
#    fire_id=fid
#    for sat in ['Sentinel', 'Landsat']:
#        satellite=sat
#        for i in range(10):
#            try:
#                clean_up_data_folder()
#            except:
#               continue

In [458]:
# Show the selected event, looked up by its 'id' column (the key get_fid_dates and mask_geom use).
fire_events[fire_events['id'] == fire_id]

Unnamed: 0,fid_1,id,ig_date,ig_day,ig_month,ig_year,last_date,event_dur,tot_pix,tot_ar_km2,...,lc_mode,lc_name,lc_desc,lc_type,tot_perim,layer,path,main_clim,Koppen_c,geometry
0,9844,9844,2020-12-18,353,12,2020,2021-03-03,76,2418,519.044672,...,255,Unclassified,Has not received a map label because of missin...,IGBP global vegetation classification scheme,465205.967394,h10_LC_selected,E:\Projects\UCB\FiredPy\firedpy\proj\outputs\s...,Tropical,Am,"MULTIPOLYGON (((-72.03003 3.38749, -72.03420 3..."


In [459]:
def get_fid_dates(fire_id, fires=None):
    '''
    Return the ignition date and last date of a fire event.

    Parameters
    ----------
    fire_id : str or int
        Event ID; compared as a string against the ``id`` column.
    fires : DataFrame, optional
        Already-loaded fire events table (e.g. ``fire_events``). When omitted,
        the shapefile at the notebook-level ``shp_filepath`` is read.

    Returns
    -------
    (datetime, datetime)
        ``ig_date`` and ``last_date`` of the event, parsed with '%Y-%m-%d'.

    Raises
    ------
    ValueError
        If no row has ``id == str(fire_id)``.
    '''
    if fires is None:
        fires = gpd.read_file(shp_filepath)

    fire_event = fires[fires['id'] == str(fire_id)]
    if fire_event.empty:
        # Raise rather than sys.exit(): SystemExit is a poor failure signal inside a notebook kernel.
        raise ValueError("invalid fire ID provided! Make sure the ID exists in the fire events shapefile")

    start_date = datetime.strptime(fire_event['ig_date'].iloc[0], "%Y-%m-%d")
    end_date = datetime.strptime(fire_event['last_date'].iloc[0], "%Y-%m-%d")
    return start_date, end_date

In [460]:
# Fire window for the selected event: (ignition date, last date) as datetimes.
start_date, end_date = get_fid_dates(fire_id)

In [461]:
# Ignition date of the selected event (parsed from 'ig_date').
start_date

datetime.datetime(2020, 12, 18, 0, 0)

In [462]:
# Last date of the selected event (parsed from 'last_date').
end_date

datetime.datetime(2021, 3, 3, 0, 0)

In [463]:
def extract_date_from_path_sentinel(file_path):
    ''' Extracts the acquisition datetime from a Sentinel-2 band file path.

    The band file name looks like ``T18NZJ_20210130T151701_B12_20m.jp2``; the
    second '_'-separated field is parsed as '%Y%m%dT%H%M%S'. ``os.path.basename``
    keeps this working with OS-native path separators.
    '''
    date_str = os.path.basename(file_path).split('_')[1]
    return datetime.strptime(date_str, "%Y%m%dT%H%M%S")

In [464]:
def extract_date_from_path_landsat(file_path):
    ''' Extracts the acquisition date from a Landsat band file path.

    The fourth '_'-separated field of the file name (YYYYMMDD) is returned as a
    datetime at midnight. ``os.path.basename`` keeps this working with OS-native
    path separators.
    '''
    date_str = os.path.basename(file_path).split('_')[3]
    return datetime.strptime(date_str, "%Y%m%d")

In [465]:
def filter_paths_by_date(start_date, end_date, file_paths):
    ''' Split image paths into pre-fire, post-fire and during-fire lists.

    A scene is pre-fire if acquired on a day before ``start_date``, post-fire if
    acquired on a day after ``end_date``, and during-fire otherwise (both
    boundary days included). The acquisition date is parsed from the path with
    the parser matching the notebook-level ``satellite``.

    Returns
    -------
    (list, list, list)
        pre_fire_paths, post_fire_paths, in_fire_paths (input order preserved).
    '''
    def _as_day(value):
        # datetime (and subclasses such as pandas.Timestamp) -> date; plain dates pass through.
        return value.date() if isinstance(value, datetime) else value

    # Compare calendar days, not timestamps. Sentinel dates carry a time of day
    # while the fire window is given at midnight, so a Sentinel scene from the
    # last fire day would otherwise be filed as post-fire, whereas a Landsat
    # scene (date only) from the same day is during-fire.
    start_day = _as_day(start_date)
    end_day = _as_day(end_date)

    pre_fire_paths = []
    post_fire_paths = []
    in_fire_paths = []

    for path in file_paths:
        if satellite == 'Sentinel':
            acquired = extract_date_from_path_sentinel(path)
        else:
            acquired = extract_date_from_path_landsat(path)
        acquired_day = _as_day(acquired)

        if acquired_day < start_day:
            pre_fire_paths.append(path)
        elif acquired_day > end_day:
            post_fire_paths.append(path)
        else:
            in_fire_paths.append(path)

    return pre_fire_paths, post_fire_paths, in_fire_paths

In [466]:
def extract_file_path_components_sentinel(path1, path2):
    ''' Parse the naming components of a Sentinel-2 scene pair.

    Mission, processing level and tile ID come from ``path1``; the acquisition
    timestamp is read from both paths (the text between ``/<tile>_`` and ``_B``).

    Returns
    -------
    tuple
        (satellite_mission, processing_level, location_id, date1, date2)
    '''
    mission = re.search(r'Sentinel/(.*?)_', path1).group(1)
    level = re.search(r'GRANULE/(.*?)_', path1).group(1).split('_')[0]
    tile = re.search(r'R\d+m/(.*?)_', path1).group(1)

    date_pattern = rf'/{tile}_(.*?)_B'
    first_date, second_date = (re.search(date_pattern, p).group(1) for p in (path1, path2))

    return mission, level, tile, first_date, second_date

In [467]:
def extract_file_path_components_landsat(path1, path2):
    ''' Parse the naming components of a Landsat scene pair.

    The mission is the text after ``Landsat/`` in ``path1``. Processing level,
    path/row and date are the 2nd-4th '_'-separated fields of ``path1``'s file
    name; the second date is the 4th field of ``path2``'s file name.

    Returns
    -------
    tuple
        (satellite_mission, processing_level, location_id, date1, date2)
    '''
    satellite_mission = re.search(r'Landsat/(.*?)_', path1).group(1)
    fields1 = path1.split('/')[-1].split('_')
    fields2 = path2.split('/')[-1].split('_')
    return (satellite_mission, fields1[1], fields1[2], fields1[3], fields2[3])

In [468]:
def make_file_paths(file_path_components, resolution='10m', event_id=None):
    ''' Build the dNBR and mask output file names for one scene pair.

    Parameters
    ----------
    file_path_components : sequence
        (satellite_mission, processing_level, location_id, date1, date2), as
        returned by the extract_file_path_components_* helpers.
    resolution : str
        Resolution tag embedded in the name (default '10m').
    event_id : str, optional
        Fire event ID to embed; defaults to the notebook-level ``fire_id``.

    Returns
    -------
    list of str
        [dNBR_file_name, mask_file_name] (bare names, no directory).
    '''
    if event_id is None:
        event_id = fire_id

    satellite_mission, process_level, location_id, img1_acquired_date, img2_acquired_date = file_path_components[:5]

    # Both outputs share everything but the suffix.
    stem = (f'{satellite_mission}_{process_level}_{location_id}_{event_id}_'
            f'{img1_acquired_date}_{img2_acquired_date}_{resolution}')
    return [f'{stem}_DNBR.TIF', f'{stem}_MASK.TIF']

In [469]:
def make_rgb_path(file_path, satellite, resolution='10m', event_id=None):
    ''' Build the RGB output file name for one scene.

    Parameters
    ----------
    file_path : str
        Reference band path of the scene (Sentinel B12 20 m JP2 or Landsat band TIF).
    satellite : str
        'Sentinel' or any other value (parsed as a Landsat file name).
    resolution : str
        Resolution tag embedded in the name (default '10m').
    event_id : str, optional
        Fire event ID to embed; defaults to the notebook-level ``fire_id``.

    Returns
    -------
    str
        ``<mission>_<level>_<location>_<event>_<date>_<resolution>_RGB.TIF``
    '''
    if event_id is None:
        event_id = fire_id

    if satellite == 'Sentinel':
        satellite_mission = re.search(r'Sentinel/(.*?)_', file_path).group(1)
        process_level = re.search(r'GRANULE/(.*?)_', file_path).group(1).split('_')[0]
        location_id = re.search(r'R\d+m/(.*?)_', file_path).group(1)
        acquired_date = re.search(rf'/{location_id}_(.*?)_B', file_path).group(1)
    else:
        satellite_mission = re.search(r'Landsat/(.*?)_', file_path).group(1)
        name_fields = file_path.split('/')[-1].split('_')
        process_level = name_fields[1]
        location_id = name_fields[2]
        acquired_date = name_fields[3]

    return f'{satellite_mission}_{process_level}_{location_id}_{event_id}_{acquired_date}_{resolution}_RGB.TIF'

In [470]:
# Collect one reference band file per scene: the 20 m B12 JP2 of the first
# granule for Sentinel, the B5 TIF for Landsat. The band readers later derive
# the other band paths from these.
image_DIRs = []

for file in file_address:
    if satellite == 'Sentinel':
        granules = glob(os.path.join(file, "GRANULE/*"))
        # A .SAFE folder may have no GRANULE entry; skip it instead of raising IndexError.
        candidates = glob(os.path.join(granules[0], "IMG_DATA/R20m/*B12_20m.jp2")) if granules else []
    else:
        candidates = glob(os.path.join(file, "*_B5.TIF"))

    if not candidates:
        print(f'Skipping {file}: reference band file not found')
        continue
    image_DIRs.append(candidates[0])

In [471]:
# One reference band path per scene (Sentinel: B12 20 m JP2; Landsat: B5 TIF).
image_DIRs

['/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2A_MSIL2A_20210130T151701_N0500_R125_T18NZJ_20230608T061259.SAFE/GRANULE/L2A_T18NZJ_A029294_20210130T151704/IMG_DATA/R20m/T18NZJ_20210130T151701_B12_20m.jp2',
 '/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2A_MSIL2A_20210301T151701_N0500_R125_T19NBE_20230514T042315.SAFE/GRANULE/L2A_T19NBE_A029723_20210301T152109/IMG_DATA/R20m/T19NBE_20210301T151701_B12_20m.jp2',
 '/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2B_MSIL2A_20210115T151709_N0500_R125_T19NBD_20230611T085846.SAFE/GRANULE/L2A_T19NBD_A020171_20210115T151703/IMG_DATA/R20m/T19NBD_20210115T151709_B12_20m.jp2',
 '/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2B_MSIL2A_20210115T151709_N0500_R125_T18NZK_20230611T085846.SAFE/GRANULE/L2A_T18NZK_A020171_20210115T151703/IMG_DATA/R20m/T18NZK_20210115T151709_B12_20m.jp2',
 '/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2B_MSIL2A_20210105T151709_N05

In [472]:
# Split scenes by acquisition date: before ignition, after the last fire date, or in between.
pre_fire_paths, post_fire_paths, in_fire_paths = filter_paths_by_date(start_date, end_date, image_DIRs)

In [473]:
# Scenes acquired before start_date.
pre_fire_paths

['/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2A_MSIL2A_20201211T151701_N9999_R125_T18NZK_20230317T111349.SAFE/GRANULE/L2A_T18NZK_A028579_20201211T151700/IMG_DATA/R20m/T18NZK_20201211T151701_B12_20m.jp2',
 '/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2A_MSIL2A_20201211T151701_N9999_R125_T19NBE_20230317T111437.SAFE/GRANULE/L2A_T19NBE_A028579_20201211T151700/IMG_DATA/R20m/T19NBE_20201211T151701_B12_20m.jp2']

In [474]:
# Scenes acquired after end_date.
post_fire_paths

['/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2A_MSIL2A_20210321T151701_N9999_R125_T19NBD_20230331T190239.SAFE/GRANULE/L2A_T19NBD_A030009_20210321T151703/IMG_DATA/R20m/T19NBD_20210321T151701_B12_20m.jp2',
 '/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2A_MSIL2A_20210321T151701_N9999_R125_T18NZJ_20230331T190221.SAFE/GRANULE/L2A_T18NZJ_A030009_20210321T151703/IMG_DATA/R20m/T18NZJ_20210321T151701_B12_20m.jp2',
 '/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2A_MSIL2A_20210321T151701_N0500_R125_T18NZJ_20230606T151200.SAFE/GRANULE/L2A_T18NZJ_A030009_20210321T151703/IMG_DATA/R20m/T18NZJ_20210321T151701_B12_20m.jp2',
 '/home/aramakrishnan/Documents/Firedpy/Fire_events/9844/Sentinel/S2A_MSIL2A_20210321T151701_N0500_R125_T19NBD_20230606T151200.SAFE/GRANULE/L2A_T19NBD_A030009_20210321T151703/IMG_DATA/R20m/T19NBD_20210321T151701_B12_20m.jp2']

In [475]:
# Report how many scenes fall into each phase of the fire.
image_counts = {
    'pre-fire': len(pre_fire_paths),
    'post-fire': len(post_fire_paths),
    'during-fire': len(in_fire_paths),
}
for phase, count in image_counts.items():
    print(f'No. of {phase} images: {count}')

No. of pre-fire images: 2
No. of post-fire images: 4
No. of during-fire images: 21


In [476]:
# One slot per NBR image; get_dNBR_and_mask fills these in.
labels = ["pre_fire", "post_fire"]
nbr_raws = {label: None for label in labels}

In [477]:
# NOTE(review): re-reads the same shapefile already loaded into fire_events in an earlier cell.
fire_events = gpd.read_file(shp_filepath)

In [478]:
# CRS of the fire polygons; passed to rio.clip together with mask_geom when MASK_IMAGES is True.
fire_events.crs

<Geographic 2D CRS: EPSG:4326>
Name: WGS 84
Axis Info [ellipsoidal]:
- Lat[north]: Geodetic latitude (degree)
- Lon[east]: Geodetic longitude (degree)
Area of Use:
- name: World.
- bounds: (-180.0, -90.0, 180.0, 90.0)
Datum: World Geodetic System 1984 ensemble
- Ellipsoid: WGS 84
- Prime Meridian: Greenwich

In [479]:
# Geometry values of the selected event; used to clip bands when MASK_IMAGES is True.
selected_event = fire_events['id'] == fire_id
mask_geom = fire_events.loc[selected_event, 'geometry'].values

In [480]:
# When True, the band readers clip every band to mask_geom (the event geometry); False keeps full scenes.
MASK_IMAGES = False

In [481]:
def resample_band(band, upscale_factor=3):
    """
    Upsample a raster band onto a grid ``upscale_factor`` times finer, using
    bilinear interpolation.

    Written for Landsat bands: the default factor of 3 turns 30 m pixels into
    the 10 m pixels used for dNBR. The band keeps its own CRS; only the number
    of rows and columns is multiplied by the factor (truncated to int).

    Parameters:
        band (xarray.DataArray): Band with a rioxarray ``.rio`` accessor.
        upscale_factor (int, optional): Grid multiplier. Default is 3.

    Returns:
        xarray.DataArray: The band reprojected onto the finer grid.
    """
    target_shape = (
        int(band.rio.height * upscale_factor),
        int(band.rio.width * upscale_factor),
    )
    return band.rio.reproject(
        band.rio.crs,
        shape=target_shape,
        resampling=Resampling.bilinear,
    )

In [482]:
def read_landsat_image(image_path, shift_dn):
    """
    Load the NIR (B5) and SWIR (B7) bands of a Landsat scene, upsampled 3x.

    ``image_path`` is any band file of the scene (in this notebook the B5 file).
    The sibling ``..._B5.TIF`` and ``..._B7.TIF`` paths are built by keeping
    everything up to the last ``_B`` in the path. If the global ``MASK_IMAGES``
    is True, each band is first clipped to ``mask_geom``. Both bands are then
    resampled with ``resample_band`` (default factor 3) and stacked along a
    ``band`` dimension labelled ``["nir", "swir"]``.

    Parameters:
        image_path (str): Path to one band file of the Landsat scene.
        shift_dn (bool): If True, subtract 1000 from every pixel value.
            NOTE(review): the purpose of this offset is not documented here;
            confirm it against the product's scaling metadata before enabling.

    Returns:
        xarray.DataArray: Stacked image with band coordinate ["nir", "swir"].

    Raises:
        ValueError: If ``image_path`` contains no ``_B`` band suffix.
    """
    # rfind: the band suffix is the *last* '_B' in the path; find() would stop at
    # any earlier '_B' in a directory name.
    cut = image_path.rfind('_B')
    if cut == -1:
        raise ValueError(f"Cannot find a '_B<band>' suffix in {image_path}")
    band_prefix = image_path[:cut + 2]

    bands = []
    for band_number in (5, 7):  # NIR, SWIR
        band = rioxarray.open_rasterio(f"{band_prefix}{band_number}.TIF")
        if MASK_IMAGES:
            band = band.rio.clip(mask_geom, fire_events.crs)
        bands.append(resample_band(band))

    print('Resampled NIR shape: ', bands[0].shape)
    print("Resampled SWIR shape: ", bands[1].shape)

    image = xr.concat(bands, dim="band")
    image = image.assign_coords(dict(band=["nir", "swir"]))
    if shift_dn:  # TODO: confirm what this offset compensates for
        image = image - 1000

    print("Image SR loaded\n")
    return image

In [483]:
def read_image(image_path, shift_dn, is_sr=True):
    '''
    Load the NIR (B08, 10 m) and SWIR (B12, 20 m) bands of a Sentinel-2 scene.

    ``image_path`` is the scene's B12 20 m JP2 inside ``R20m``. The B08 10 m
    file is found by swapping ``R20m`` for ``R10m`` in the same band-file prefix
    (everything up to the last ``_B``). SWIR is reprojected onto the NIR grid,
    and the two are stacked along a ``band`` dimension labelled
    ``["nir", "swir"]``. If the global ``MASK_IMAGES`` is True, both bands are
    clipped to ``mask_geom`` first.

    Parameters
    ----------
    image_path : str
        Path to the B12 20 m JP2 of the scene.
    shift_dn : bool
        If True, subtract 1000 from every pixel value after stacking.
        NOTE(review): this presumably compensates the -1000 BOA offset that ESA
        added to L2A products from processing baseline 04.00 onward (e.g. the
        N0500 scenes used here). If so, NBR computed without it is biased, and
        pairing products of different baselines mixes offset and non-offset
        data. Confirm against BOA_ADD_OFFSET in each product's MTD_MSIL2A.xml.
    is_sr : bool
        If True, pixel values <= 0 are set to NaN (treated as no-data).

    Returns
    -------
    xarray.DataArray
        Stacked image with band coordinate ["nir", "swir"] on the NIR grid.

    Raises
    ------
    ValueError
        If ``image_path`` contains no ``_B`` band suffix.
    '''
    # rfind: the band suffix is the *last* '_B' in the path.
    cut = image_path.rfind('_B')
    if cut == -1:
        raise ValueError(f"Cannot find a '_B<band>' suffix in {image_path}")
    swir_image_path = image_path[:cut + 2]
    # Only the resolution folder differs between the 20 m and 10 m band files.
    nir_image_path = swir_image_path.replace('R20m', 'R10m')

    nir = rioxarray.open_rasterio(f"{nir_image_path}08_10m.jp2")
    swir = rioxarray.open_rasterio(f"{swir_image_path}12_20m.jp2")
    if MASK_IMAGES:
        nir = nir.rio.clip(mask_geom, fire_events.crs)
        swir = swir.rio.clip(mask_geom, fire_events.crs)

    if is_sr:
        # Non-positive DNs are treated as no-data.
        nir = nir.where(lambda x: x > 0, other=np.nan)
        swir = swir.where(lambda x: x > 0, other=np.nan)

    # Put SWIR onto the NIR pixel grid so the two can be stacked.
    swir_resampled = swir.rio.reproject_match(nir)
    print('NIR shape: ', nir.shape)
    print("Resampled SWIR shape: ", swir_resampled.shape)

    image = xr.concat([nir, swir_resampled], dim="band")
    image = image.assign_coords(dict(band=["nir", "swir"]))
    if shift_dn:
        image = image - 1000

    print("Image SR loaded\n")
    return image

In [484]:
def open_RGB(image_path, satellite, shift_dn=False):
    '''
    Load the red, green and blue bands of a scene and stack them into an RGB image.

    ``image_path`` is any band file of the scene (in this notebook the Sentinel
    B12 20 m JP2 or the Landsat B5 TIF). Sibling band files are found by keeping
    everything up to the last ``_B``:

    * Sentinel: ``..._B04/B03/B02_10m.jp2`` (``R20m`` swapped for ``R10m``)
    * otherwise (Landsat): ``..._B4/B3/B2.TIF``

    If the global ``MASK_IMAGES`` is True, each band is clipped to ``mask_geom``.

    Parameters
    -----------
    image_path : str
        Path to one band file of the scene.
    satellite : str
        'Sentinel' or any other value for Landsat.
    shift_dn : bool
        If True, subtract 1000 from every pixel value.

    Returns
    -------
    xarray.DataArray
        Bands stacked along ``band`` with coordinate ["r", "g", "b"].

    Raises
    ------
    ValueError
        If ``image_path`` contains no ``_B`` band suffix.
    '''
    cut = image_path.rfind('_B')
    if cut == -1:
        raise ValueError(f"Cannot find a '_B<band>' suffix in {image_path}")
    band_prefix = image_path[:cut + 2]

    if satellite == 'Sentinel':
        # RGB bands are only needed at 10 m, which lives in the R10m folder.
        band_prefix = band_prefix.replace('R20m', 'R10m')
        band_files = [f"{band_prefix}0{x}_10m.jp2" for x in (4, 3, 2)]
    else:
        band_files = [f"{band_prefix}{x}.TIF" for x in (4, 3, 2)]

    rgb = [rioxarray.open_rasterio(path) for path in band_files]
    if MASK_IMAGES:
        rgb = [band.rio.clip(mask_geom, fire_events.crs) for band in rgb]

    rgb = xr.concat(rgb, dim="band")
    rgb = rgb.assign_coords(dict(band=["r", "g", "b"]))

    if shift_dn:
        rgb = rgb - 1000

    print('RGB loaded')
    return rgb

In [485]:
def get_NBRs(image_sr):
    '''
    Compute the Normalized Burn Ratio of a stacked NIR/SWIR image.

        NBR = (NIR - SWIR) / (NIR + SWIR)

    ``image_sr`` must expose the bands via ``.sel(band="nir")`` and
    ``.sel(band="swir")`` (as produced by read_image / read_landsat_image).
    '''
    nir = image_sr.sel(band="nir")
    swir = image_sr.sel(band="swir")
    return (nir - swir) / (nir + swir)

In [486]:
def plot_dNBR_mask(dNBR, mask):
    '''Show a dNBR raster and its 3-class mask (1 Burned, 2 Unburned, 3 Out of range) side by side.'''
    class_cmap = ListedColormap(['darkorange', 'lightyellow', 'grey'])
    # Bin edges centred on the integer class values 1, 2, 3.
    class_norm = BoundaryNorm([0.5, 1.5, 2.5, 3.5], class_cmap.N)

    # cmap='RdBu_r' can be passed to the dNBR panel if needed
    fig, (ax_dnbr, ax_mask) = plt.subplots(1, 2, figsize=(14, 6))
    dNBR.plot.imshow(ax=ax_dnbr)
    mask_image = mask.plot.imshow(cmap=class_cmap, norm=class_norm, add_colorbar=False, ax=ax_mask)
    draw_legend(mask_image, titles=['Burned', 'Unburned', 'Out of range'])

    # Titles/labels are set after plotting so they override xarray's defaults.
    for ax, title in ((ax_dnbr, "dNBR"), (ax_mask, "Mask")):
        ax.set_title(title)
        ax.set_xlabel('x coordinate of projection (metre)')
        ax.set_ylabel('y coordinate of projection (metre)')
    plt.show()

In [487]:
def plot_pre_post_dNBR(pre_fire, post_fire, dNBR):
    ''' Show pre-fire NBR, post-fire NBR and dNBR in one row of three panels. '''
    fig, axes = plt.subplots(1, 3, figsize=(26, 8))
    panels = (
        (pre_fire, "NBR pre fire"),
        (post_fire, "NBR post fire"),
        (dNBR, "dNBR (pre-post) fire"),
    )
    for ax, (raster, title) in zip(axes, panels):
        raster.plot.imshow(ax=ax)
        ax.set_title(title)
        ax.set_xlabel('x coordinate of projection (metre)')
        ax.set_ylabel('y coordinate of projection (metre)')
    plt.show()

In [488]:
def plot_mask(mask):
    ''' Plot a 3-class dNBR mask (1 Burned, 2 Unburned, 3 Out of range) with a legend. '''
    palette = ListedColormap(['firebrick', 'wheat', 'grey'])
    # Bin edges centred on the integer class values 1, 2, 3.
    class_norm = BoundaryNorm([0.5, 1.5, 2.5, 3.5], palette.N)

    fig2, ax2 = plt.subplots(figsize=(7, 8))
    mask_image = mask.plot.imshow(cmap=palette, norm=class_norm, add_colorbar=False, ax=ax2)
    draw_legend(mask_image, titles=['Burned', 'Unburned', 'Out of range'])
    ax2.set(title="Segmentation Mask")
    ax2.set_xlabel('x coordinate of projection (metre)')
    ax2.set_ylabel('y coordinate of projection (metre)')
    plt.show()

In [489]:
def plot_rgb_pre_post(pre_fire, post_fire):
    ''' Show pre- and post-fire RGB images side by side (values / 2000, clipped to [0, 1]). '''
    fig, axes = plt.subplots(1, 2, figsize=(26, 8))
    scenes = ((pre_fire, 'pre-fire RGB'), (post_fire, 'post-fire RGB'))
    for ax, (rgb, title) in zip(axes, scenes):
        (rgb / 2000).clip(0, 1).plot.imshow(ax=ax)
        ax.set(title=title)
        ax.set_xlabel('x coordinate of projection (metre)')
        ax.set_ylabel('y coordinate of projection (metre)')
    plt.show()

In [490]:
def plot_rgb(rgb, title):
    ''' Show one RGB image (values / 2000, clipped to [0, 1]) titled "<title> RGB image". '''
    fig, ax = plt.subplots(figsize=(7, 8))
    scaled = (rgb / 2000).clip(0, 1)
    scaled.plot.imshow(ax=ax)
    ax.set(title=f'{title} RGB image')
    ax.set_xlabel('x coordinate of projection (metre)')
    ax.set_ylabel('y coordinate of projection (metre)')
    plt.show()

In [491]:
def apply_cloud_mask(img, img_path, rgb=False):
    '''
    Set to NaN the pixels flagged in the Sentinel-2 MSK_CLASSI_B00 mask.

    The mask is read from ``<granule>/QI_DATA/MSK_CLASSI_B00.jp2``, where
    ``<granule>`` is everything in ``img_path`` before ``/IMG_DATA/``. It is
    reprojected onto ``img``'s grid, and a pixel is kept only where mask bands
    1, 2 and 3 are all 0.
    NOTE(review): these bands are presumably opaque clouds, cirrus and
    snow/ice; confirm against the product format specification.

    Parameters
    ----------
    img : xarray.DataArray
        Image to mask: a 2-D (y, x) array such as an NBR when ``rgb`` is False,
        or a (band, y, x) stack when ``rgb`` is True.
    img_path : str
        Path to a band file inside a Sentinel-2 granule's IMG_DATA folder.
    rgb : bool
        If False, ``img`` is transposed to (y, x) before masking.

    Returns
    -------
    xarray.DataArray
        ``img`` with flagged pixels set to NaN.

    Raises
    ------
    ValueError
        If ``img_path`` has no ``/IMG_DATA/`` component (e.g. a Landsat path).
    '''
    match = re.search(r'(.*?)/IMG_DATA/', img_path)
    if match is None:
        raise ValueError(f'Cloud masking needs a Sentinel-2 IMG_DATA path, got: {img_path}')
    cloud_path = match.group(1) + '/QI_DATA/MSK_CLASSI_B00.jp2'

    if not rgb:
        img = img.transpose('y', 'x')

    # Close the mask file once it has been resampled; reproject_match builds a
    # new array (no dask chunks are requested here).
    with rioxarray.open_rasterio(cloud_path) as cloud_mask:
        ds_cloud_resized = cloud_mask.transpose('band', 'y', 'x').rio.reproject_match(img)

    clear_sky = (
        (ds_cloud_resized.sel(band=1) == 0)
        & (ds_cloud_resized.sel(band=2) == 0)
        & (ds_cloud_resized.sel(band=3) == 0)
    )
    return xr.where(clear_sky, img, np.nan)

In [492]:
def get_dNBR_and_mask(pre_fire_path, post_fire_path, cloud_mask=True, plot_figs=False, burned_threshold=0.1):
    ''' Compute the dNBR between two scenes and a 3-class mask derived from it.

    NBR is computed for each scene (``read_image`` for Sentinel,
    ``read_landsat_image`` otherwise, both with shift_dn=False), optionally
    cloud-masked, and stored in the global ``nbr_raws`` dict under 'pre_fire' /
    'post_fire'. dNBR = NBR_pre - NBR_post, clipped to [-1, 1].

    Mask classes:
        1 = dNBR >= burned_threshold (burned)
        2 = dNBR <  burned_threshold (unburned)
        3 = dNBR is NaN (fails both comparisons; "out of range")

    Parameters
    ----------
    pre_fire_path, post_fire_path : str
        Reference band paths of the earlier and later scene.
    cloud_mask : bool
        Apply ``apply_cloud_mask`` to both NBRs (expects Sentinel-2 granule paths).
    plot_figs : bool
        Show the NBR/dNBR panels and the mask.
    burned_threshold : float
        dNBR cut-off between burned and unburned (default 0.1).

    Returns
    -------
    (xarray.DataArray, xarray.DataArray)
        dNBR and mask.
    '''
    def _load_nbr(path):
        # NOTE(review): shift_dn=False for both sensors. Check whether the products
        # carry a DN offset (e.g. Sentinel-2 L2A BOA_ADD_OFFSET) that should be
        # removed before computing NBR.
        if satellite == 'Sentinel':
            image_sr = read_image(path, False, True)
        else:
            image_sr = read_landsat_image(path, False)
        return get_NBRs(image_sr)

    nbr_raws['pre_fire'] = _load_nbr(pre_fire_path)
    nbr_raws['post_fire'] = _load_nbr(post_fire_path)

    if cloud_mask:
        print('Applying cloud mask...')
        nbr_raws['pre_fire'] = apply_cloud_mask(nbr_raws['pre_fire'], pre_fire_path)
        nbr_raws['post_fire'] = apply_cloud_mask(nbr_raws['post_fire'], post_fire_path)

    dNBR_post_raw = nbr_raws['pre_fire'] - nbr_raws['post_fire']
    # Replace only the values, keeping coords/attrs/encoding for rio.to_raster.
    dNBR_post_raw = dNBR_post_raw.copy(deep=True, data=np.clip(dNBR_post_raw, -1., 1.))

    # NaN fails both comparisons and therefore falls through to class 3.
    class_values = xr.where(dNBR_post_raw >= burned_threshold, 1.,
                            xr.where(dNBR_post_raw < burned_threshold, 2., 3.))
    dNBR_post_raw_class_ma = dNBR_post_raw.copy(deep=True, data=class_values)

    if plot_figs:
        print('Plotting images...')
        plot_pre_post_dNBR(nbr_raws['pre_fire'], nbr_raws['post_fire'], dNBR_post_raw)
        plot_mask(dNBR_post_raw_class_ma)

    return dNBR_post_raw, dNBR_post_raw_class_ma

In [493]:
def download_dNBR_scenes(path_list1, path_list2, dNBR_dir, mask_dir, cloud_mask=False, plot_figs=False, plot_RGB=False):
    ''' Compute and save dNBR and mask GeoTIFFs for every same-location scene pair.

    Each path in ``path_list1`` (earlier scene) is paired with every path in
    ``path_list2`` (later scene) that has the same location ID: the Sentinel tile
    prefix of the file name, or the third '_' field of a Landsat file name.
    Pairs whose dNBR *and* mask files already exist in the output folders are
    skipped. dNBR is written as float32 and the mask as uint8.

    Parameters
    ----------
    path_list1, path_list2 : list of str
        Reference band paths of the earlier / later scenes.
    dNBR_dir, mask_dir : str
        Output folders.
    cloud_mask, plot_figs : bool
        Forwarded to ``get_dNBR_and_mask``.
    plot_RGB : bool
        Also show the RGB images of each newly processed pair.

    Returns
    -------
    int
        Number of dNBR/mask pairs written.
    '''
    def _location_id(path):
        name = path.rsplit('/', 1)[-1]
        if satellite == 'Sentinel':
            return name.split('_', 1)[0]
        return name.split('_')[2]

    # List each output folder once. Names written below are added to the sets, so
    # a later pair mapping to the same file name (e.g. two products of one
    # acquisition) is still skipped, exactly as a fresh listdir would do.
    existing_dNBR = set(os.listdir(dNBR_dir))
    existing_masks = set(os.listdir(mask_dir))

    n_outputs = 0
    for path1 in path_list1:
        location_id1 = _location_id(path1)

        for path2 in path_list2:
            if _location_id(path2) != location_id1:
                continue

            if satellite == 'Sentinel':
                path_components = extract_file_path_components_sentinel(path1, path2)
            else:
                path_components = extract_file_path_components_landsat(path1, path2)
            file_paths = make_file_paths(path_components)
            dNBR_path, mask_path = file_paths[0], file_paths[1]

            if dNBR_path in existing_dNBR and mask_path in existing_masks:
                continue

            dNBR, mask = get_dNBR_and_mask(path1, path2, cloud_mask=cloud_mask, plot_figs=plot_figs)
            dNBR.astype('float32').rio.to_raster(os.path.join(dNBR_dir, dNBR_path))
            mask.astype('uint8').rio.to_raster(os.path.join(mask_dir, mask_path))
            existing_dNBR.add(dNBR_path)
            existing_masks.add(mask_path)
            n_outputs += 1

            if plot_RGB:
                # open_RGB needs the sensor name to locate the RGB band files.
                rgb1 = open_RGB(path1, satellite).transpose('band', 'y', 'x')
                rgb2 = open_RGB(path2, satellite).transpose('band', 'y', 'x')
                plot_rgb_pre_post(rgb1, rgb2)

    return n_outputs

In [494]:
def download_RGB_scenes(path_list, dir, satellite, cloud_mask=False):
    ''' Save an RGB GeoTIFF for every scene in ``path_list`` not already in ``dir``.

    Output names come from ``make_rgb_path``. Scenes whose RGB bands cannot be
    loaded are reported and skipped, so one bad scene does not stop the batch.

    Parameters
    ----------
    path_list : list of str
        Reference band paths of the scenes.
    dir : str
        Output folder (name kept for backward compatibility; it shadows the builtin).
    satellite : str
        'Sentinel' or any other value for Landsat.
    cloud_mask : bool
        Blank pixels flagged by ``apply_cloud_mask`` (expects Sentinel-2 granule paths).
    '''
    # List the folder once. Names written below are added to the set, so a
    # duplicate later in path_list (same output name) is still skipped.
    existing = set(os.listdir(dir))

    for img_path in path_list:
        rgb_path = make_rgb_path(img_path, satellite=satellite)
        if rgb_path in existing:
            continue
        try:
            rgb = open_RGB(img_path, satellite)
        except Exception as err:
            # Best effort: report and skip. Unlike a bare ``except:``, this lets
            # KeyboardInterrupt stop the loop.
            print(f'Skipping {img_path}: {err}')
            continue

        rgb = rgb.transpose('band', 'y', 'x')
        if cloud_mask:
            rgb = apply_cloud_mask(rgb, img_path, rgb=True).transpose('band', 'y', 'x')
        rgb.rio.to_raster(os.path.join(dir, rgb_path))
        existing.add(rgb_path)

In [495]:
# Folders to store output files (all under one dataset root).
dataset_root = '/Bhaltos/ASHWATH/Dataset'
dNBR_dir = f'{dataset_root}/dNBR'
mask_dir = f'{dataset_root}/Masks'
rgb_dir = f'{dataset_root}/RGB'

In [496]:
# Run switches
get_dNBR_mask = True
get_RGB = False
plot_RGB = False
plot_figs = False
cloud_mask = False

n_outputs = 0

# Save RGB scenes for every phase.
if get_RGB:
    for scene_paths in (pre_fire_paths, post_fire_paths, in_fire_paths):
        download_RGB_scenes(scene_paths, rgb_dir, satellite, cloud_mask)

# Compute dNBR scenes and segmentation masks for each (earlier, later) pairing,
# optionally plotting them and/or their respective RGB scenes.
if get_dNBR_mask:
    pairings = [
        (pre_fire_paths, post_fire_paths, 'pre-fire with post-fire done.'),
        (pre_fire_paths, in_fire_paths, 'pre-fire with in-fire done.'),
        (in_fire_paths, post_fire_paths, 'in-fire with post-fire done.'),
    ]
    for earlier_paths, later_paths, done_message in pairings:
        n_outputs += download_dNBR_scenes(earlier_paths, later_paths, dNBR_dir, mask_dir, cloud_mask, plot_figs, plot_RGB)
        print(done_message)
    print('Images saved.')

print(f'\nNumber of dNBR scenes: {n_outputs}')

pre-fire with post-fire done.
pre-fire with in-fire done.
in-fire with post-fire done.
Images saved.

Number of dNBR scenes: 0


In [497]:
# Code to fetch features of the fire event from the geojson file
#gpath ='/Users/ashwath/Documents/Firedpy/Fire_events/8676/Sentinel/Sentinel_footprints.geojson'
#
#with open(gpath) as f:
#    gj = geojson.load(f)
#features = gj['features'][0]
#features

In [498]:
# Code to generate list of uuids / product ids for the fire event
#j=0
#for i in gj['features']:
#    j+=1
#    print(i['properties']['uuid'])
#print(f'J len: {j}')

In [499]:
# # EXAMPLE of reading and plotting output files
# read_img_path = '/Bhaltos/ASHWATH/Dataset/dNBR/S2B_L2A_T12SXJ_7792_20200713T175909_20201006T180231_10m_DNBR.TIF'
# read_mask_path = '/Bhaltos/ASHWATH/Dataset/Masks/S2B_L2A_T12SXJ_7792_20200713T175909_20201006T180231_10m_MASK.TIF'
# #
# img = rioxarray.open_rasterio(read_img_path).squeeze()
# ma = rioxarray.open_rasterio(read_mask_path).squeeze()
# #
# print(img.shape)
# print(ma.shape)
# #
# plot_dNBR_mask(img, ma)

In [500]:
#read_rgb_dir = '/Users/ashwath/Documents/Firedpy/Dataset/RGB'
#for rgb_path in os.listdir(read_rgb_dir):
#    if rgb_path[-4:] != '.TIF':
#        continue
#    rgb = rioxarray.open_rasterio(os.path.join(read_rgb_dir, rgb_path))
#    plot_rgb(rgb, 'fire scene')