In [None]:
import ee
import folium
import numpy as np
import time
import random
import geemap
import geopandas as gpd

# NOTE(review): this silences *every* warning (including DeprecationWarnings
# that flag real problems); consider narrowing it to specific categories.
import warnings
warnings.filterwarnings('ignore')

import os
import sys

# Work from the repository root (two levels above the notebook's directory) so
# the relative "data/..." paths used below resolve. The notebook directory is
# remembered on the first run, so re-running this cell does not climb further.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
REPO_ROOT = os.path.abspath(os.path.join(NOTEBOOK_DIR, "..", ".."))
os.chdir(REPO_ROOT)

# Make the project's src/ modules importable; add the entry only once.
SRC_DIR = os.path.join(REPO_ROOT, "src")
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

# NOTE(review): star imports hide where helpers used below (genSamplePoints,
# formatToGPD, mosaicByDate, prepImage, pointReducer, saveSampleData and,
# presumably, folium's add_ee_layer patch) come from; prefer explicit imports.
from etl import *
from viz import *

In [None]:
# Authenticate the Earth Engine client.
# NOTE(review): this may prompt for an interactive sign-in; once credentials
# are cached it typically does not need re-running — confirm for the installed
# earthengine-api version.
ee.Authenticate()

In [None]:
# Start an Earth Engine session. Set the EE_PROJECT environment variable to pick
# the Google Cloud project explicitly; when it is unset, project=None keeps the
# same behaviour as a bare ee.Initialize().
ee.Initialize(project=os.environ.get("EE_PROJECT"))

In [None]:
# Read the unburned bounding boxes and the burned fire perimeters from disk,
# then convert each GeoDataFrame into an Earth Engine FeatureCollection.
bbox_df = gpd.read_file("data/unburned/bbox.shp")
fireBounds_df = gpd.read_file("data/burned/fireBounds.shp")

bbox_EE, fireBounds_EE = (geemap.gdf_to_ee(frame, geodesic=True)
                          for frame in (bbox_df, fireBounds_df))

In [None]:
# Optional sanity check: uncomment to print each collection's feature count and confirm both were pushed to EE
# print(bbox_EE.size().getInfo(),
#       fireBounds_EE.size().getInfo())

In [None]:
# Generate sample points for every fire: inside its perimeter (burned) and
# inside its bounding box (unburned). Larger fires get a larger gridScale.
SAMPLING_SEED = 42    # fixes the per-fire EE seeds so re-runs draw the same points
MAX_EE_SEED = 10**6   # inclusive upper bound for each EE seed (an int, see below)


def _gridScaleForAcres(acres):
    """Return the base gridScale for a fire, stepping up with fire size.

    `acres` is the fire's GIS_ACRES value (presumably acres). The thresholds and
    scales are this notebook's tuning constants; a NaN size compares False
    everywhere and therefore falls through to the largest scale.
    """
    for upperBound, scale in ((40000, 80), (90000, 100), (150000, 110)):
        if acres < upperBound:
            return scale
    return 125


# A dedicated RNG keeps the seeds reproducible without touching global `random`.
seedRng = random.Random(SAMPLING_SEED)
firePts, bboxPts = ee.List([]), ee.List([])

for fireName, size in fireBounds_df[["FIRE_NAME", "GIS_ACRES"]].values:
    gridScale = _gridScaleForAcres(size)

    # Perimeter points use gridScale-15 and box points gridScale+10
    # (presumably a denser grid inside the burn — confirm in genSamplePoints).
    # randint needs integer bounds: the previous float bound (1e6) is
    # deprecated since Python 3.10 and raises TypeError from 3.12.
    firePts = firePts.add(genSamplePoints(collection=fireBounds_EE,
                                          fireName=fireName,
                                          gridScale=gridScale - 15,
                                          pointScale=1/8,
                                          seed=seedRng.randint(0, MAX_EE_SEED)))

    bboxPts = bboxPts.add(genSamplePoints(collection=bbox_EE,
                                          fireName=fireName,
                                          gridScale=gridScale + 10,
                                          pointScale=1/8,
                                          seed=seedRng.randint(0, MAX_EE_SEED)))

In [None]:
# Pull the generated points out of Earth Engine (each getInfo() is a blocking
# round-trip) and convert them with formatToGPD using fireBounds_df's fire names.
startTime = time.time()
fireNameCol = fireBounds_df["FIRE_NAME"]

firePts_df = formatToGPD(list(fireNameCol), firePts.getInfo())   # ~167k points
bboxPts_df = formatToGPD(list(fireNameCol), bboxPts.getInfo())   # ~165k points

elapsedMinutes = np.round((time.time() - startTime) / 60, 3)
print("Point Sampling Runtime: {} minutes".format(elapsedMinutes))

In [None]:
# Per-fire search bounds together with the pre-/post-fire image dates.
boundsPath = "data/bounds/bounds.shp"
bounds_df = gpd.read_file(boundsPath)

In [None]:
# For each fire: build pre-/post-fire Landsat 8 images, combine them with
# prepImage, then reduce the combined image (mean, scale=30) over that fire's
# burned and unburned sample points and pull the results client-side.
_LANDSAT8_L2 = "LANDSAT/LC08/C02/T1_L2"


def _firstMosaicOn(date, region):
    """Return the first mosaicByDate() result for Landsat 8 scenes that
    intersect `region` within the one-day window [date, date + 1 day)."""
    sameDayScenes = (ee.ImageCollection(_LANDSAT8_L2)
                     .filterBounds(region)
                     .filterDate(date, ee.Date(date).advance(1, "day")))
    return ee.Image(mosaicByDate(sameDayScenes).get(0))


def _minutesSince(t0):
    """Elapsed wall-clock minutes since `t0`, rounded to 3 decimals."""
    return np.round((time.time() - t0) / 60, 3)


startTime = time.time()
imageLst = ee.List([])
fireSampleData, bboxSampleData = [], []

fireRows = bounds_df[["FIRE_NAME", "pre-date", "post-date", "geometry"]].values
for fireName, preFireDate, postFireDate, fireGeom in fireRows:
    fireStart = time.time()

    # This fire's burned and unburned sample points as EE FeatureCollections.
    firePoints = geemap.gdf_to_ee(firePts_df[firePts_df["FIRE_NAME"] == fireName])
    bboxPoints = geemap.gdf_to_ee(bboxPts_df[bboxPts_df["FIRE_NAME"] == fireName])
    points = ee.List([firePoints, bboxPoints])

    # EE rectangle built from the shapely geometry's (minx, miny, maxx, maxy).
    region = ee.Geometry.Rectangle(list(fireGeom.bounds))

    combined = prepImage(_firstMosaicOn(preFireDate, region),
                         _firstMosaicOn(postFireDate, region),
                         fireName, region, postFireDate)
    imageLst = imageLst.add(combined)

    def reduceCombined(collection):
        # Must take exactly one argument: ee counts the callable's parameters
        # when wrapping it as the server-side function for ee.List.map.
        return pointReducer(image=combined,
                            collection=collection,
                            scale=30,
                            reducer=ee.Reducer.mean())

    # Blocking round-trip: [burned results, unburned results].
    fireData, bboxData = points.map(reduceCombined).getInfo()
    fireSampleData.append(fireData)
    bboxSampleData.append(bboxData)
    print("{} Runtime: {} minutes".format(fireName, _minutesSince(fireStart)))

print("Total Runtime: {} minutes".format(_minutesSince(startTime)))

In [None]:
# Column names for the sample tables: FIRE_NAME followed by the band names of
# the first combined image (assumes every combined image has the same bands).
firstCombined = ee.Image(imageLst.get(0))
keys = ["FIRE_NAME"] + firstCombined.bandNames().getInfo()

In [None]:
# Save sample data as CSV.
# fireSampleData / bboxSampleData were accumulated by walking bounds_df and, per
# fire, selecting the firePts_df / bboxPts_df rows with that FIRE_NAME. The point
# frames themselves follow fireBounds_df's fire order, so the geometries are
# reordered to the same sequence; otherwise rows and geometries are mispaired
# whenever the two shapefiles list fires in different orders.
# NOTE(review): assumes pointReducer returns one record per point in input
# order and that saveSampleData pairs data and geometry positionally — confirm.


def _geometryInSampleOrder(points_df, fireNames):
    """Return points_df["geometry"] in the order the sampling loop visited rows.

    Rows are grouped by fire following `fireNames`, each fire keeping its
    original row order; the result carries a fresh 0..n-1 index.
    """
    pointFires = points_df["FIRE_NAME"].to_numpy()
    positions = [np.flatnonzero(pointFires == name) for name in fireNames]
    order = np.concatenate(positions) if positions else np.array([], dtype=np.intp)
    return points_df["geometry"].iloc[order].reset_index(drop=True)


saveSampleData(data=fireSampleData,
               keys=keys,
               geometry=_geometryInSampleOrder(firePts_df, bounds_df["FIRE_NAME"]),
               path="data/burned/postFireData.csv")

saveSampleData(data=bboxSampleData,
               keys=keys,
               geometry=_geometryInSampleOrder(bboxPts_df, bounds_df["FIRE_NAME"]),
               path="data/unburned/postFireData.csv")

In [None]:
# df_1 = pd.read_csv("data/burned/postFireData.csv")
# df_2 = pd.read_csv("data/unburned/postFireData.csv")

In [None]:
# One interactive folium map per fire with three toggleable layers from that
# fire's combined image: the SR_B7/SR_B5/SR_B3 composite (labelled "Post Fire
# <date>"), NLCD land cover, and thresholded burn severity.
# NOTE(review): add_ee_layer is not a stock folium.Map method; it is presumably
# patched in by a star-imported module — confirm.

burnPalette = ["706c1e", "4e9d5c", "fff70b", "ff641b", "a41fd6"]
landCoverPalette = ["A2D6F2", "FF7F68", "258914", "FFF100", "7CD860", "B99B56"]

for name, date, fireGeom in bounds_df[["FIRE_NAME", "post-date", "geometry"]].values:
    region = ee.Geometry.Rectangle(list(fireGeom.bounds))
    # EE returns [lon, lat]; folium wants [lat, lon].
    center = region.centroid().getInfo()["coordinates"][::-1]

    matching = imageLst.filter(ee.Filter.eq("FIRE_NAME", name))
    fireImage = ee.Image(matching.get(0)).clip(region)

    fireMap = folium.Map(location=center, zoom_start=11.25)

    # (visualization params, layer name) for each layer, built fresh per fire.
    layerSpecs = [
        ({"bands": ["SR_B7", "SR_B5", "SR_B3"],
          "gamma": [1.1, 1.1, 1],
          "min": 1000, "max": 25000},
         "Post Fire {}".format(date)),
        ({"bands": ["landCover"],
          "min": 1, "max": 6,
          "palette": landCoverPalette},
         "Land Cover"),
        ({"bands": ["burnSeverity"],
          "min": 1, "max": 5,
          "palette": burnPalette},
         "Burn Severity"),
    ]
    for visParams, layerName in layerSpecs:
        fireMap.add_ee_layer(fireImage, visParams, layerName)

    fireMap.add_child(folium.LayerControl())
    print(name)
    display(fireMap)
    print("\n \n")