# Woolsey Fire linear regressions

This notebook calculates linear regressions between variables related to the Woolsey Fire (burn severity, vegetation mortality, drought climate, and topography). Because it produces a large number of plots, the figures are saved to disk instead of being displayed inside the notebook.

In [None]:
import itertools
import json
import os
import re
import warnings

import earthpy as et
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import numpy_groupies as npg
import rioxarray as rxr
import seaborn as sns
import xarray as xr

from ea_drought_burn.config import CRS
from ea_drought_burn.utils import (
    aggregate,
    create_figure,
    create_sampling_mask,
    open_raster,
    plot_bands,
    plot_rgb,
    plot_regression,
    reproject_match
)




# Choose which data to plot: "prefire" or "postfire". This controls both the
# boundary used to crop the study area and which response variables are
# regressed.
DATA_TO_PLOT = "postfire"

# Default plotting parameters: constrained layout with 12/72 in (12 pt) of
# padding around each axes
plt.rcParams.update({
    "figure.constrained_layout.use": True,
    "figure.constrained_layout.h_pad": 12 / 72,
    "figure.constrained_layout.w_pad": 12 / 72,
})

# Work from the Woolsey Fire folder inside the earthpy data directory
woolsey_dir_parts = ("earth-analytics", "data", "woolsey-fire")
os.chdir(os.path.join(et.io.HOME, *woolsey_dir_parts))

In [None]:
def mask_burned(xda, burned):
    """Return a boolean mask of burned or unburned pixels.

    A pixel counts as burned when its value is at least 100. Passing
    "burned" (case-insensitive) selects those pixels; any other string
    selects pixels with values below 100. NaN pixels fail both comparisons
    and are therefore excluded from either mask.

    Returns the mask as a NumPy array (the ``.values`` of the comparison).
    """
    wants_burned = burned.lower() == "burned"
    selected = (xda >= 100) if wants_burned else (xda < 100)
    return selected.values
        
    
def mask_community(xda, name):
    """Return a boolean mask selecting a single vegetation community.

    ``name`` must be one of the keys below; any other value raises KeyError.
    "All communities" selects every pixel. Every other name selects the
    pixels of ``xda`` equal to that community's integer code.

    Returns the mask as a NumPy array with the same shape as ``xda``.
    """
    code = {
        "All communities": None,
        "Annual grass": 1,
        "Chaparral": 2,
        "Coastal sage scrub": 3,
        "Oak woodland": 4,
        "Riparian": 5,
        "Substrate": 6,
    }[name]
    if code is None:
        # Keep every pixel
        return np.ones(xda.shape, dtype=bool)
    return (xda == code).values


def mask_aspect(xda, aspect):
    """Return a boolean mask of pixels facing the given cardinal direction.

    Each direction covers a 90-degree sector centered on it:

    * "N": aspect > 315 or aspect <= 45
    * "E": 45 < aspect <= 135
    * "S": 135 < aspect <= 225
    * "W": 225 < aspect <= 315

    Any value other than "E", "S", or "W" selects north-facing pixels. NaN
    pixels fail every comparison and are never selected.

    Parameters
    ----------
    xda : DataArray or ndarray
        Aspect in degrees.
    aspect : str
        One of "N", "E", "S", "W".

    Returns
    -------
    numpy.ndarray of bool
        Same shape as ``xda``. Returned as a plain NumPy array, matching
        mask_burned and mask_community, so masks can be combined freely.
    """
    # NOTE(review): negative values (some tools encode flat cells as -1)
    # satisfy "<= 45" and would count as north-facing. Confirm how the aspect
    # raster encodes flat/nodata cells.
    try:
        i = ["E", "S", "W"].index(aspect)
    except ValueError:
        return np.asarray((xda > 315) | (xda <= 45))
    # Sectors start 45 degrees before each direction so that each one is
    # centered on it (E: 45-135, S: 135-225, W: 225-315)
    mindeg = 45 + 90 * i
    maxdeg = mindeg + 90
    return np.asarray((xda > mindeg) & (xda <= maxdeg))

    
def mask_slope(xda, steepness="flat"):
    """Return a boolean mask of pixels in the given steepness class.

    Bins (presumably degrees, given the 90 upper bound; confirm against the
    slope raster):

    * "flat": 0 <= slope <= 10
    * "shallow": 10 < slope <= 30
    * "steep": 30 < slope <= 90

    ``steepness`` is case-insensitive; unknown names raise KeyError. NaN
    pixels are never selected.

    Returns
    -------
    numpy.ndarray of bool
        Same shape as ``xda``. Returned as a plain NumPy array, matching
        mask_burned and mask_community, so masks can be combined freely.
    """
    # FIXME: Bin these better
    steepnesses = {
        "flat": (0, 10),
        "shallow": (10, 30),
        "steep": (30, 90)
    }
    minslope, maxslope = steepnesses[steepness.lower()]
    # The lowest bin includes its lower edge so perfectly flat (slope == 0)
    # pixels are not dropped from every bin
    above_min = (xda >= minslope) if minslope == 0 else (xda > minslope)
    return np.asarray(above_min & (xda <= maxslope))


def slugify(val):
    """Lower-case ``val`` and collapse each run of spaces into one underscore.

    Only literal spaces are replaced; other whitespace is left untouched.
    """
    lowered = val.lower()
    return re.sub(r" +", "_", lowered)

In [None]:
# Fail fast if DATA_TO_PLOT is not one of the two supported options
if DATA_TO_PLOT not in ("prefire", "postfire"):
    raise ValueError("DATA_TO_PLOT must be either prefire or postfire")

# Load data

In [None]:
# Load the Woolsey Fire perimeter and project it to the study CRS
woolsey_fire = gpd.read_file(
    os.path.join(
        "shapefiles",
        "nifc_woolsey_perimeter",
        "2018-CAVNC-091023.shp",
    )
).to_crs(CRS)

# Choose the crop boundary: pre-fire plots use the fire envelope (its
# bounding rectangle); everything else, including post-fire plots, uses the
# fire scar itself
if DATA_TO_PLOT == "prefire":
    crop_bound = woolsey_fire.envelope.geometry
else:
    crop_bound = woolsey_fire.geometry

In [None]:
# Open SMM stack and blank out values below -1e38 (presumably the raster's
# fill value; confirm against the .dat header)
smm_stack = open_raster(
    os.path.join("aviris-climate-vegetation", "SMMDroughtstack.dat"),
    crs=CRS,
    crop_bound=crop_bound,
)
smm_stack = smm_stack.where(smm_stack >= -1e38)

# Calculate fraction dead from fraction alive, discarding FAL values outside
# the open 1st-99th percentile range
fal = smm_stack[1:5]
minfal, maxfal = np.nanpercentile(fal, (1, 99))
fal = fal.where((fal > minfal) & (fal < maxfal))
fdd = 1 - fal

# Calculate dFAL as pre- minus post-fire FAL. Positive values indicate loss
# of living vegetation. Align dFAL with the second year of the range for each
# (for example, align 2013-2014 to 2014). The first array is therefore empty.
dfal = []
dfal.append(fal[0].where(~np.isfinite(fal[0])))  # first band is all NaN
dfal.append(fal[0] - fal[1])
dfal.append(fal[1] - fal[2])
dfal.append(fal[2] - fal[3])
for i, band in enumerate(dfal):
    band["band"] = i + 1
dfal = xr.concat(dfal, dim="band")

# Calculate fraction dead since each of the years in the dataset. For each
# pixel, calculate the minimum fraction dead between the current year and
# the end of the dataset (every slice runs to the last band), then subtract
# the fraction dead from the preceding year. The first year has no
# preceding year, so nothing is subtracted from it.
dead_since = [fdd.min(axis=0)]
for i in range(1, len(fdd)):
    dead_since.append(fdd[i:].min(axis=0) - fdd[i - 1])
for i, band in enumerate(dead_since):
    band["band"] = i + 1
dead_since = xr.concat(dead_since, dim="band")

In [None]:
# Read MTBS dNBR, blank out the -9999 nodata value, and align the result
# with the SMM grid
path = os.path.join(
    "mtbs-burn-severity",
    "ca3424011870020181108",
    "ca3424011870020181108_20171215_20181215_dnbr.tif",
)
dnbr = open_raster(path, crs=CRS)
dnbr = reproject_match(dnbr.where(dnbr != -9999, np.nan), smm_stack[0])

In [None]:
# Read USGS slope, then aspect, aligning each with the SMM grid
path = os.path.join("usgs-terrain", "usgs_2016_18_merge_rs15m_fix_slope.tif")
slope = reproject_match(open_raster(path), smm_stack[0])

path = os.path.join("usgs-terrain", "usgs_2016_18_merge_rs15m_fix_aspect.tif")
aspect = reproject_match(open_raster(path), smm_stack[0])

# Folded aspect |180 - |aspect - 225||: folds aspect (degrees) about the
# NE-SW axis, giving 0 at 45 degrees (NE) and 180 at 225 degrees (SW)
folded_aspect = np.absolute(180 - np.absolute(aspect - 225))

In [None]:
# Build the sampling mask on the SMM grid with 7,000 training and 3,000
# validation pixels. The fixed seed is meant to make the sample reproducible
# (assumes create_sampling_mask seeds its RNG with it; confirm).
sampling_mask = create_sampling_mask(
    smm_stack[0],
    seed=20210421,
    counts={"training": 7000, "validation": 3000},
)

In [None]:
# Read the PRISM grid (later used as the reference grid for aggregation),
# unmasked and cropped to the study area
prism_grid = open_raster(
    os.path.join("masks", "prism_grid.tif"),
    masked=False,
    crop_bound=crop_bound,
    crs=CRS,
)

In [None]:
# Lookup of the full (unsampled) datasets, keyed by display name
datasets = {
    # Vegetation community
    "Community": smm_stack[0],
    # Burn severity
    "dNBR": dnbr,
    # Fraction alive
    "FAL": smm_stack[1:5],
    "dFAL": dfal,
    "Dead Since": dead_since,
    # Topography
    "Aspect": aspect,
    "Folded Aspect": folded_aspect,
    "Slope": slope,
    # Drought climate
    "Days Precipitation": smm_stack[5:9],
    "Max VPD": smm_stack[9:13],
    "Min Temperature": smm_stack[13:17],
    "Heat Days Over 95": smm_stack[17:21],
    "Cumulative Precipitation": smm_stack[21:25]
}

# Add "<Direction>ness" layers: the cosine of aspect measured from each
# cardinal direction (1 when facing that direction, -1 when facing away)
cardinal_offsets = {"North": 0, "East": 90, "South": 180, "West": 270}
for cdir, offset in cardinal_offsets.items():
    aspect_rad = xr.apply_ufunc(np.deg2rad, datasets["Aspect"] - offset)
    datasets[f"{cdir}ness"] = xr.apply_ufunc(np.cos, aspect_rad)

In [None]:
# Every layer must share one grid. First compare the trailing (row, column)
# shape of each layer.
shapes = {
    "prism_grid": prism_grid.shape[-2:],
    "sampling_mask": sampling_mask.shape[-2:],
    **{name: layer.shape[-2:] for name, layer in datasets.items()},
}
if len({tuple(shape) for shape in shapes.values()}) != 1:
    display(shapes)
    raise ValueError("Shapes do not match")

# ...then compare their spatial bounds
bounds = {
    "prism_grid": prism_grid.rio.bounds(),
    "sampling_mask": sampling_mask.rio.bounds(),
    **{name: layer.rio.bounds() for name, layer in datasets.items()},
}
if len(set(bounds.values())) != 1:
    display(bounds)
    raise ValueError("Bounds do not match")

In [None]:
# Training subset: every layer masked to the first band of the sampling mask
# (presumably the "training" sample, which create_sampling_mask receives
# first; confirm)
training = {
    name: layer.where(sampling_mask[0]) for name, layer in datasets.items()
}

# Regressions

In [None]:
# This cell creates A LOT of plots, so don't show them: switch to the
# non-interactive Agg backend. Figures are written to disk by
# plot_regressions when an output directory is given.
%matplotlib agg

def _grid_shape(subset_data, n_cols, use_one_figure):
    """Return the (n_rows, n_cols) axes grid used for each regression figure.

    ``subset_data`` is a list of ``(func, xda, options)`` tuples, and
    ``n_cols`` is the starting column count (one per data key).
    """
    n_items = [len(s[2]) for s in subset_data]

    if not use_one_figure:
        # One figure per option of the primary subset. Rows cover every
        # combination of the remaining subsets.
        return int(np.prod(n_items[1:])), n_cols

    # With a single data key, give each option of the primary subset its own
    # column
    if n_cols == 1 and n_items:
        n_cols, n_items = n_items[0], n_items[1:]

    # One row per combination of the remaining subsets (np.prod([]) == 1.0,
    # hence the int cast)
    n_rows = int(np.prod(n_items))

    # A single column reads better as a single row
    if n_cols == 1:
        n_cols, n_rows = n_rows, 1
    return n_rows, n_cols


def _combine_masks(mask, subset_data, filter_args):
    """Return ``mask`` AND-ed with the mask of each (subset, option) pair.

    Subset masks may be NumPy arrays or xarray DataArrays. Every mask is
    converted to a NumPy bool array before combining, because an in-place
    ndarray-by-DataArray operation raises NotImplementedError in xarray.
    """
    combined = np.array(mask, dtype=bool)  # copy; never mutate the input
    for (func, xda, _), arg in zip(subset_data, filter_args):
        part = func(xda, arg) if func else xda
        combined &= np.asarray(part, dtype=bool)
    return combined


def _axis_title(xlabel, ylabel, key, title, subtitle, aggregated):
    """Build the multi-line title for one regression axes."""
    ax_title = f"{ylabel} vs. {xlabel}"
    if key != xlabel or title:
        parts = [
            "Agg" if aggregated else "",
            key if key != xlabel else "",
            title,
        ]
        details = " - ".join(str(s) for s in parts if s)
        if details:
            ax_title += "\n" + details
    if subtitle:
        ax_title += f"\n({subtitle})"
    return ax_title.replace("\n\n", "\n").strip()


def plot_regressions(
    xdata,
    ydata,
    subsets=None,
    agg_to=None,
    xagg=np.nanmean,
    yagg=np.nanmean,
    colors=None,
    titles=None,
    outdir=None,
    use_one_figure=True,
    **kwargs
):
    """Plots a set of regressions of ydata against xdata.

    Parameters
    ----------
    xdata, ydata : DataArray or dict of DataArray
        Explanatory and response data. Either (or both) may be a dict, and
        each key gets its own regression. If both are dicts, their keys
        must match.
    subsets : dict, optional
        Maps a short name to a ``(func, xda, options)`` tuple, where
        ``func(xda, option)`` returns a boolean mask. The first subset
        defines the figures (or columns). Every combination of the
        remaining subsets' options is plotted on its own axes.
    agg_to : optional
        Reference grid. When given, x and y are aggregated to it with
        ``aggregate`` using ``xagg`` and ``yagg``.
    xagg, yagg : callable
        Aggregation functions for x and y.
    colors, titles : list, optional
        One entry per option of the first subset. Default to "black" and
        "Untitled".
    outdir : str, optional
        If given, each completed figure is saved here as a PNG and closed.
    use_one_figure : bool
        Put every plot in one figure. Otherwise, make one figure per option
        of the first subset.
    **kwargs
        Passed to plot_regression. ``xlabel`` and ``ylabel`` are also read
        here for axis titles, filenames, and error messages.

    Raises
    ------
    ValueError
        If aggregation fails or a figure has more axes than plots.
    """
    # Labels are forwarded to plot_regression through kwargs, but are also
    # needed here for titles, filenames, and error messages
    xlabel = kwargs.get("xlabel", "x")
    ylabel = kwargs.get("ylabel", "y")

    subsets = subsets or {}
    subset_names = list(subsets)
    subset_data = list(subsets.values())

    # Default is that each data key gets its own column
    n_cols = next(
        (len(data) for data in (xdata, ydata) if isinstance(data, dict)), 1
    )
    n_rows, n_cols = _grid_shape(subset_data, n_cols, use_one_figure)

    # The first subset defines the figures (or columns)
    if subset_data:
        func, xda, args = subset_data[0]
    else:
        func, xda, args = None, None, [None]

    # Colors default to black and titles to "Untitled" if none supplied
    if colors is None:
        colors = ["black"] * len(args)
    if titles is None:
        titles = ["Untitled"] * len(args)

    # The remaining subsets are combined into every possible filter
    secondary = subset_data[1:]
    secondary_options = [s[2] for s in secondary]

    # Either xdata, ydata, or both can be dicts. If only one is a dict, plot
    # each key against the other array. If both are dicts, their keys must
    # match.
    keys = set()
    for data in (xdata, ydata):
        if isinstance(data, dict):
            keys.update(data)
    keys = sorted(keys) if keys else [None]

    # Large fonts and constrained layout for these figures only. rc_context
    # restores the caller's settings afterwards, even if plotting fails.
    rc = {
        "font.size": 24,
        "axes.labelsize": 24,
        "xtick.labelsize": 24,
        "ytick.labelsize": 24,
        "legend.fontsize": 24,
        "figure.constrained_layout.use": True,
        "figure.constrained_layout.h_pad": 12 / 72,
        "figure.constrained_layout.w_pad": 12 / 72,
    }
    with plt.rc_context(rc):
        axes = []
        if use_one_figure:
            fig, axes = create_figure(n_rows, n_cols)

        for arg, color, title in zip(args, colors, titles):

            if not use_one_figure:
                fig, axes = create_figure(n_rows, n_cols, title)

            # Create mask from primary subset if defined
            if xda is None:
                mask = None
            else:
                mask = func(xda, arg) if func else xda.copy()

            for filter_args in itertools.product(*secondary_options):

                # Combine masks for each subset into a single mask
                subset_mask = None
                subtitle = ""
                if mask is not None:
                    subset_mask = _combine_masks(mask, secondary, filter_args)
                    subtitle = ", ".join(str(s) for s in filter_args)

                for key in keys:
                    x = (xdata[key] if isinstance(xdata, dict) else xdata).copy()
                    y = (ydata[key] if isinstance(ydata, dict) else ydata).copy()

                    # Blank out pixels outside the subset
                    if subset_mask is not None:
                        x = x.where(subset_mask)
                        y = y.where(subset_mask)

                    # Calculate blocks aligned with a reference grid
                    if agg_to is not None:
                        try:
                            x = aggregate(x, agg_to, func=xagg)
                            y = aggregate(y, agg_to, func=yagg)
                        except (IndexError, ValueError) as err:
                            raise ValueError(
                                f"Aggregation failed: agg_to={agg_to.shape},"
                                f" x={x.shape}, xlabel={xlabel},"
                                f" y={y.shape}, ylabel={ylabel}"
                            ) from err

                    # Flatten, then keep values that are finite in both
                    # arrays
                    x = np.ravel(x)
                    y = np.ravel(y)
                    finite = np.isfinite(x) & np.isfinite(y)
                    x = x[finite]
                    y = y[finite]

                    ax_title = _axis_title(
                        xlabel, ylabel, key, title, subtitle, agg_to is not None
                    )
                    plot_regression(x, y, axes.pop(0),
                                    color=color,
                                    title=ax_title,
                                    **kwargs)

            # Save the figure once all of its axes have been used
            if outdir and not axes:
                os.makedirs(outdir, exist_ok=True)

                # Construct filename based on arguments
                parts = [
                    "".join(subset_names) if subset_names else "all",
                    "agg" if agg_to is not None else "pt",
                    ylabel,
                    "vs",
                    xlabel,
                ]
                if not use_one_figure:
                    parts.append(title)
                stem = slugify(" ".join(str(p) for p in parts)).strip("_")
                plt.savefig(os.path.join(outdir, f"{stem}.png"))

                # Remove figure instance
                plt.close(fig)

        if axes:
            raise ValueError(
                f"{ylabel} vs {xlabel}: {len(axes)} axes left unused"
            )


    

# Set names and colors of vegetation communities
communities = [
    "All communities",
    "Annual grass",
    "Chaparral",
    "Coastal sage scrub",
    "Oak woodland",
    "Riparian",
    "Substrate"
]

colors = [
    "black",
    "tab:green",
    "tab:olive",
    "tab:cyan",
    "tab:red",
    "tab:blue",
    "lightgray"
]

# Define explanatory and response values to evaluate as (x, y). Uncomment
# entries to include them in a run.
xy = {
    # Vegetation mortality variables
    #"FAL": ["dNBR"],
    #"dFAL": ["dNBR"],
    #"Dead Since": ["dNBR"],
    # Climate variables
    #"Days Precipitation": ["dNBR", "dFAL", "FAL"],
    #"Max VPD": ["dNBR", "dFAL", "FAL"],
    #"Min Temperature": ["dNBR", "dFAL", "FAL"],
    #"Heat Days Over 95": ["dNBR", "dFAL", "FAL"],
    #"Cumulative Precipitation": ["dNBR", "dFAL", "FAL"],
    # Topography variables
    "Folded Aspect": ["dNBR", "dFAL", "FAL"],
    "Slope": ["dNBR", "dFAL", "FAL"]
}

# Define parameters to use for plotting
params = {
    "xydata": [],
    "agg": [True, False],
}

# Select y based on if we're plotting pre- or post-fire data: pre-fire runs
# skip dNBR, post-fire runs use only dNBR
for x, ys in xy.items():
    for y in ys:
        if (
            DATA_TO_PLOT == "prefire" and y == "dNBR"
            or DATA_TO_PLOT == "postfire" and y != "dNBR"
        ):
            continue
        params["xydata"].append((x, y))

# Calculate all possible combinations of params
param_sets = list(itertools.product(*params.values()))

# Define subsets that will be used to segment the data as (mask function,
# data passed to it, options). Each subset will be used for each set of
# parameters.
subsets = {
    "c": (mask_community, datasets["Community"].copy(), communities[:-1]),
    "b": (mask_burned, dnbr, ["Burned", "Unburned"]),
    "s": (mask_slope, slope, ["Flat", "Shallow", "Steep"]),
    "a": (mask_aspect, aspect, ["N", "S"])
}

# Calculate all possible combinations of subsets, starting with no subset
keys = list(subsets)
subset_sets = [""]
for r in range(1, len(keys) + 1):
    subset_sets.extend(itertools.combinations(keys, r))

# Explanatory variables computed from the aspect raster. Subsetting them by
# aspect ("a") would filter x by its own data.
aspect_derived = {
    "Aspect",
    "Folded Aspect",
    "Northness",
    "Eastness",
    "Southness",
    "Westness",
}

for xydata, agg in param_sets:

    # Set data source (all data if aggregating, training otherwise)
    lookup = datasets if agg else training

    # Get x and y data
    xlabel, ylabel = xydata

    # Convert banded arrays to year-keyed dicts
    xdata = lookup[xlabel].copy()
    if len(xdata.shape) == 3:
        xdata = dict(zip(range(2013, 2017), xdata))

    ydata = lookup[ylabel].copy()
    if len(ydata.shape) == 3:
        ydata = dict(zip(range(2013, 2017), ydata))

    # Iterate through each subset
    for subset_keys in subset_sets:

        # Check for subsets
        if subset_keys:
            subset = {k: subsets[k] for k in subset_keys}
            titles = subset[subset_keys[0]][2]
        else:
            subset = {}
            titles = [""]

        # Skip subsets that use the same data as x
        if (
            xlabel in aspect_derived and "a" in subset
            or xlabel == "Slope" and "s" in subset
        ):
            continue

        print(f"Processing {ylabel} vs {xlabel} ({''.join(subset_keys)})...")

        # Name output directories using the response variable
        outdir = os.path.join(*["outputs", "plots", DATA_TO_PLOT, ylabel])

        # Style axes based on labels
        kwargs = {}
        if xlabel == "FAL":
            kwargs["xlim"] = (0, 1)
        if ylabel == "FAL":
            kwargs["ylim"] = (0, 1)
        if "ness" in xlabel:
            kwargs["xlim"] = (-1, 1)

        plot_regressions(
            xdata,
            ydata,
            subsets=subset,
            agg_to=prism_grid if agg else None,
            xagg=np.nanmean if xlabel != "dFAL" else np.nansum,
            yagg=np.nanmean if ylabel != "dFAL" else np.nansum,
            colors=colors if "c" in subset_keys else None,
            titles=titles,
            outdir=outdir,
            xlabel=xlabel,
            ylabel=ylabel,
            **kwargs
        )