In [None]:
import os
import warnings
from copy import copy, deepcopy
from functools import partial, reduce
from itertools import islice
from pprint import pprint

import cartopy.crs as ccrs
import iris
import iris.analysis.cartography
import iris.coord_categorisation
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.api as sm
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from wildfires.analysis.analysis import get_no_fire_mask
from wildfires.analysis.plotting import cube_plotting
from wildfires.data.cube_aggregation import Datasets, prepare_selection
from wildfires.data.datasets import (
    MCD64CMQ_C6,
    VODCA,
    AvitabileThurnerAGB,
    CarvalhaisGPP,
    CCI_BurnedArea_MERIS_4_1,
    CCI_BurnedArea_MODIS_5_1,
    GFEDv4,
    GFEDv4s,
    GlobFluo_SIF,
    MOD15A2H_LAI_fPAR,
    regions_GFED,
)
from wildfires.logging_config import enable_logging
from wildfires.utils import get_masked_array, get_ncpus, get_unmasked
from wildfires.utils import land_mask as get_land_mask
from wildfires.utils import match_shape, polygon_mask

enable_logging("jupyter")

# Silence repetitive warnings whose messages match these regexes; they are
# registered in the same order as before (all are "ignore", so order is moot).
for _ignored_message in (
    ".*Collapsing a non-contiguous coordinate.*",
    ".*DEFAULT_SPHERICAL_EARTH_RADIUS*",
    r".*spherical earth.*",
):
    warnings.filterwarnings("ignore", message=_ignored_message)

In [None]:
# Burned-area products: four datasets sharing one variable selection, plus the
# MERIS product, from which only its burned-area variable is kept.
fire_datasets = Datasets(
    fire_dataset_cls()
    for fire_dataset_cls in (GFEDv4s, GFEDv4, CCI_BurnedArea_MODIS_5_1, MCD64CMQ_C6)
).select_variables(
    ["CCI MODIS BA", "GFED4 BA", "GFED4s BA", "MCD64CMQ BA"]
) + Datasets(CCI_BurnedArea_MERIS_4_1()).select_variables("CCI MERIS BA")

monthly = prepare_selection(fire_datasets, which="monthly")
pprint(list(monthly))

# Vegetation proxies used for comparison with the burned-area seasonality.
bio_datasets = Datasets(
    bio_dataset_cls()
    for bio_dataset_cls in (
        CarvalhaisGPP,
        AvitabileThurnerAGB,
        VODCA,
        GlobFluo_SIF,
        MOD15A2H_LAI_fPAR,
    )
).select_variables(["AGB Tree", "SIF", "FAPAR", "LAI", "VOD Ku-band"])
pprint(bio_datasets)
pprint(bio_datasets.pretty_variable_names)

bio_monthly = prepare_selection(bio_datasets, which="monthly")

# MERIS "fraction of observed area": flag cells whose averaged observed fraction
# is below 0.8 (True = flagged).
observed = CCI_BurnedArea_MERIS_4_1().cubes.extract_strict(
    iris.Constraint("fraction of observed area")
)
# NOTE(review): reshape(2, T // 2, ...).mean(axis=0) averages time step t with
# t + T // 2 (first half of the record with the second half), NOT consecutive
# pairs. If the intent is to combine two sub-monthly steps per month, this should
# be reshape(T // 2, 2, ...).mean(axis=1) -- confirm against the MERIS time axis.
observed_mask = (
    observed.data.reshape(
        2,
        observed.data.shape[0] // 2,
        observed.data.shape[1],
        observed.data.shape[2],
    ).mean(axis=0)
    < 0.8
)

In [None]:
# Static exclusion masks (True = exclude). `land_mask` inverts the land mask,
# presumably making it True off land -- confirm against wildfires.utils.
land_mask = ~get_land_mask()
no_fire_mask = get_no_fire_mask()

regions = regions_GFED()

monthly.homogenise_masks()
bio_monthly.homogenise_masks()


def _mask_cubes(cubes, masks):
    """OR every mask (passed through `match_shape` to each cube's shape) into that cube's data mask, in place."""
    for cube in cubes:
        cube.data.mask |= reduce(
            np.logical_or,
            map(partial(match_shape, target_shape=cube.shape), masks),
        )


def _monthly_climatology(cube):
    """Return the per-month-number mean of `cube`, adding a month_number coord if missing."""
    if not cube.coords("month_number"):
        iris.coord_categorisation.add_month_number(cube, "time")
    return cube.aggregated_by("month_number", iris.analysis.MEAN)


def _restrict_to_region(cube, region_mask):
    """Return a copy of `cube` additionally masked wherever `region_mask` is True.

    `cube` itself is left untouched, so the same climatology can be restricted to
    several regions independently. `region_mask` (lat, lon) broadcasts over the
    leading month axis.
    """
    combined_mask = np.ma.getmaskarray(cube.data) | region_mask
    return cube.copy(
        data=np.ma.MaskedArray(np.ma.getdata(cube.data), mask=combined_mask)
    )


# Fire cubes additionally exclude cells with poor MERIS observation coverage.
_mask_cubes(monthly.cubes, (land_mask, no_fire_mask, observed_mask))
_mask_cubes(bio_monthly.cubes, (land_mask, no_fire_mask))

climatologies = iris.cube.CubeList([_monthly_climatology(c) for c in monthly.cubes])
bio_climatologies = iris.cube.CubeList(
    [_monthly_climatology(c) for c in bio_monthly.cubes]
)

region_climatologies = dict()
region_bio_climatologies = dict()

for region_index in tqdm(range(1, 15)):
    # True outside the current GFED region; masked region cells count as outside.
    region_mask = np.ma.filled(regions.data != region_index, True)

    # A fresh masked copy per region: previously the mask was OR-ed into the shared
    # climatology cubes in place, so masks accumulated across regions and every
    # region after the first ended up fully masked.
    region_climatologies[region_index] = iris.cube.CubeList(
        [_restrict_to_region(cube, region_mask) for cube in climatologies]
    )
    region_bio_climatologies[region_index] = iris.cube.CubeList(
        [_restrict_to_region(cube, region_mask) for cube in bio_climatologies]
    )

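## Seasonality

Mean monthly climatologies for each GFED region: the area-weighted burned-area fraction of every fire product (left, log scale) and the vegetation proxies, each normalised by its own maximum (right).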

In [None]:
def _area_mean_series(cube):
    """Return the area-weighted latitude/longitude mean of `cube` as an array.

    Works on a copy so that guessing missing lat/lon bounds (required by
    `area_weights`) never modifies the caller's cube. Bounds are now guessed for
    fire cubes too; previously only the vegetation cubes got this treatment.
    """
    cube = cube.copy()
    for coord_name in ("latitude", "longitude"):
        if not cube.coord(coord_name).has_bounds():
            cube.coord(coord_name).guess_bounds()
    return cube.collapsed(
        ("latitude", "longitude"),
        iris.analysis.MEAN,
        weights=iris.analysis.cartography.area_weights(cube),
    ).data


# savefig fails if the target directory does not exist, so create it up front.
PLOT_DIR = os.path.join(os.path.expanduser("~"), "tmp")
os.makedirs(PLOT_DIR, exist_ok=True)
MONTHS = range(1, 13)

for region_index in tqdm(region_climatologies):
    region_name = regions.attributes["regions"][region_index]

    fig, axes = plt.subplots(1, 2, figsize=(11, 6))

    for cube in region_climatologies[region_index]:
        # Drop the last 3 characters of the name (presumably the " BA" suffix).
        axes[0].plot(MONTHS, _area_mean_series(cube), label=cube.name()[:-3])

    for cube in region_bio_climatologies[region_index]:
        tseries = _area_mean_series(cube)
        # Normalise each proxy by its own maximum so they share one unitless axis.
        axes[1].plot(MONTHS, tseries / np.max(tseries), label=cube.name())

    axes[0].legend(loc="best")
    axes[0].set_ylabel("Average Burned Area Fraction (1)")
    axes[0].set_xlabel("Month")
    axes[0].set_yscale("log")

    axes[1].set_xlabel("Month")
    axes[1].set_ylabel("Vegetation Proxy (1)")
    axes[1].legend(loc="best")

    fig.suptitle(region_name + " Climatology")
    fig.tight_layout()
    fig.subplots_adjust(top=0.85)
    fig.savefig(
        os.path.join(PLOT_DIR, f"{region_name.replace(' ', '_')}_meris_mask_plots.png"),
        bbox_inches="tight",
        dpi=800,
    )
    # Display, then release the figure so 14 high-dpi figures don't pile up.
    plt.show()
    plt.close(fig)