In [None]:
import logging
import logging.config
import os
import pickle
import warnings
from copy import deepcopy

import cf_units
import ipdb
import iris
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.api as sm
from IPython.display import HTML
from joblib import Memory, Parallel, delayed
from matplotlib import animation, rc
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm, tqdm_notebook

import wildfires.utils as utils
from wildfires.analysis.plotting import (
    cube_plotting,
    get_cubes_vmin_vmax,
    map_model_output,
    partial_dependence_plot,
)
from wildfires.analysis.processing import log_map, log_modulus, map_name, vif
from wildfires.data.cube_aggregation import *
from wildfires.data.cube_aggregation import Datasets, prepare_selection
from wildfires.data.datasets import *
from wildfires.data.datasets import DATA_DIR, GSMaP_dry_day_period, data_map_plot
from wildfires.logging_config import LOGGING
from wildfires.utils import Time, TqdmContext
from wildfires.utils import land_mask as get_land_mask
from wildfires.utils import match_shape, polygon_mask

# Module-level logger for this notebook, followed by the project-wide logging
# configuration.
# NOTE(review): the logger is created before `dictConfig` runs. If LOGGING
# leaves `disable_existing_loggers` at its default (True) and does not name
# this logger explicitly, the logger ends up disabled — confirm against the
# LOGGING dict.
logger = logging.getLogger(__name__)
logging.config.dictConfig(LOGGING)

# NOTE: tqdm_notebook does not work for some reason (this remark is unrelated
# to the warnings filter below).

# Silence warnings whose message matches "Collapsing a non-contiguous
# coordinate".
warnings.filterwarnings("ignore", ".*Collapsing a non-contiguous coordinate.*")

In [None]:
# Notebook-wide constants. They are not referenced by the cells visible here;
# presumably figure styling defaults (output resolution, coastline line width,
# base size) — TODO confirm where they are consumed.
dpi = 600
normal_coast_linewidth = 0.5
normal_size = 9.0

In [None]:
# Bundle the two dry-day-period datasets that are compared in this notebook
# (ERA5 and GSMaP) and show the resulting selection.
selection = Datasets((ERA5_DryDayPeriod(), GSMaP_dry_day_period()))
selection.show("pretty")

# Tabulate the time information of the selected datasets; the first two return
# values are kept as `min_time` / `max_time`.
min_time, max_time, times_df = dataset_times(selection.datasets)
print(times_df.to_string(index=False))
# print(times_df.to_latex(index=False))

# Derive the three variants of the selection used by the following cells.
# NOTE(review): going by the names these are monthly, time-mean and
# climatological aggregates — confirm against `prepare_selection`.
monthly_datasets, mean_datasets, climatology_datasets = prepare_selection(selection)

# Get land mask.
# NOTE(review): the utility's result is inverted (`~`) here, so despite its
# name `land_mask` presumably flags the non-land cells, which is the form
# needed for OR-ing into a data mask below — confirm against
# `wildfires.utils.land_mask`.
land_mask = ~get_land_mask(n_lon=1440)

# Define a latitude mask which ignores data beyond 60 degrees, as the precipitation data does not extend to those latitudes.
# The polygon covers the band between -60 and 60; inverting it flags everything
# outside that band.
lat_mask = ~polygon_mask([(180, -60), (-180, -60), (-180, 60), (180, 60), (180, -60)])

In [None]:
# Work on a deep copy: the masks below are OR-ed into the cube data in place,
# and the unmasked `mean_datasets` must remain available.
masked_mean_datasets = mean_datasets.copy(deep=True)

for cube in masked_mean_datasets.cubes:
    # A cell becomes masked as soon as either the land mask or the latitude
    # mask (each matched to the cube's shape) flags it.
    cube.data.mask |= (
        match_shape(land_mask, cube.shape) | match_shape(lat_mask, cube.shape)
    )

mpl.rcParams["figure.figsize"] = (11, 6)

# A single colour range is computed over all cubes and shared by every map so
# that the maps are directly comparable.
# NOTE(review): (10, 90) is presumably a pair of percentiles — confirm against
# `get_cubes_vmin_vmax`.
vmin, vmax = get_cubes_vmin_vmax(masked_mean_datasets.cubes, (10, 90))
_mean_plot_kwargs = dict(log=False, vmin=vmin, vmax=vmax)

# One map per dataset, titled with the dataset's name.
for dataset in masked_mean_datasets:
    cube_plotting(
        dataset.cube,
        title=f"Mean Dry Day Period ({dataset.name})",
        **_mean_plot_kwargs,
    )

In [None]:
# Work on a deep copy: the masks below are OR-ed into the cube data in place,
# and the unmasked `monthly_datasets` must remain available.
masked_monthly_datasets = monthly_datasets.copy(deep=True)

for cube in masked_monthly_datasets.cubes:
    # A cell becomes masked as soon as either the land mask or the latitude
    # mask (each matched to the cube's shape) flags it.
    cube.data.mask |= (
        match_shape(land_mask, cube.shape) | match_shape(lat_mask, cube.shape)
    )

mpl.rcParams["figure.figsize"] = (11, 6)
era5, gsmap = "ERA5_DryDayPeriod", "GSMaP_dry_day_period"

# Look up the cube of each dataset by name (ERA5 first, then GSMaP), without
# modifying `masked_monthly_datasets` itself.
era5_cube, gsmap_cube = (
    masked_monthly_datasets.select_datasets(name, inplace=False).dataset.cube
    for name in (era5, gsmap)
)

# Replace GSMaP's time coordinate with ERA5's (as the dimension coordinate of
# axis 0) so that both cubes carry the very same time coordinate in the
# arithmetic below.
# NOTE(review): this assumes both cubes have the same number of time steps —
# confirm that `prepare_selection` guarantees it.
gsmap_cube.remove_coord("time")
gsmap_cube.add_dim_coord(era5_cube.coord("time"), 0)

# Difference relative to GSMaP, expressed in percent.
relative_monthly_differences = 100 * (gsmap_cube - era5_cube) / gsmap_cube
relative_monthly_differences.units = cf_units.Unit("%")

# Average the relative differences over the time coordinate and map the
# result.
mean_relative_monthly_difference = relative_monthly_differences.collapsed(
    "time", iris.analysis.MEAN
)

# Colour limits derived from this one cube.
# NOTE(review): (10, 90) is presumably a pair of percentiles — confirm against
# `get_cubes_vmin_vmax`.
_mean_difference_cubes = iris.cube.CubeList([mean_relative_monthly_difference])
vmin, vmax = get_cubes_vmin_vmax(_mean_difference_cubes, (10, 90))

cube_plotting(
    mean_relative_monthly_difference,
    title="Mean Relative Monthly Differences (GSMaP - ERA5)",
    log=False,
    vmin=vmin,
    vmax=vmax,
)

def _select_along_time(values, index_of_extreme):
    """Pick one entry along axis 0 for every position of the remaining axes.

    `index_of_extreme` (e.g. `np.argmin` or `np.argmax`) is applied to the
    absolute values along axis 0; the signed entries of `values` found at
    those indices are returned, with axis 0 kept as a length-1 axis.
    """
    chosen = index_of_extreme(np.abs(values), axis=0)
    return np.take_along_axis(values, np.expand_dims(chosen, axis=0), axis=0)


# "Best case": for every cell, the value along axis 0 (time) that is closest
# to zero. The selected time step may differ from cell to cell.
min_relative_monthly_difference = dummy_lat_lon_cube(
    _select_along_time(relative_monthly_differences.data, np.argmin),
    units=cf_units.Unit("%"),
)[0]
vmin, vmax = get_cubes_vmin_vmax(
    iris.cube.CubeList([min_relative_monthly_difference]), (0, 100)
)
cube_plotting(
    min_relative_monthly_difference,
    title="(Best-Case) Non-local Min Relative Monthly Differences (GSMaP - ERA5)",
    log=False,
    vmin=vmin,
    vmax=vmax,
)

# "Worst case": for every cell, the value along axis 0 (time) with the largest
# magnitude.
max_relative_monthly_difference = dummy_lat_lon_cube(
    _select_along_time(relative_monthly_differences.data, np.argmax),
    units=cf_units.Unit("%"),
)[0]
vmin, vmax = get_cubes_vmin_vmax(
    iris.cube.CubeList([max_relative_monthly_difference]), (0, 100)
)
_ = cube_plotting(
    max_relative_monthly_difference,
    title="(Worst-Case) Non-local Max Relative Monthly Differences (GSMaP - ERA5)",
    log=False,
    vmin=vmin,
    vmax=vmax,
)

# Rank the individual time steps by the overall magnitude of their relative
# difference field: one Frobenius norm per time step, taken over the two
# non-time axes.
#
# Masked cells are filled with 0 before taking the norm. `np.linalg.norm`
# converts its input to a plain ndarray, which would otherwise let the values
# hidden underneath the land / latitude masks contribute to the ranking.
monthly_difference_norms = np.linalg.norm(
    np.ma.filled(relative_monthly_differences.data, 0.0), axis=(1, 2)
)

# Time step with the SMALLEST overall difference (argmin of the norms).
# Note that this re-binds `min_relative_monthly_difference`, which previously
# held the per-cell minimum.
min_relative_monthly_difference = relative_monthly_differences[
    np.argmin(monthly_difference_norms)
]
vmin, vmax = get_cubes_vmin_vmax(
    iris.cube.CubeList([min_relative_monthly_difference]), (5, 95)
)
cube_plotting(
    min_relative_monthly_difference,
    log=False,
    title="Min Relative Monthly Difference Single Time (GSMaP - ERA5)",
    vmin=vmin,
    vmax=vmax,
)

# Time step with the LARGEST overall difference (argmax of the norms).
# Note that this re-binds `max_relative_monthly_difference`, which previously
# held the per-cell maximum.
max_relative_monthly_difference = relative_monthly_differences[
    np.argmax(monthly_difference_norms)
]
vmin, vmax = get_cubes_vmin_vmax(
    iris.cube.CubeList([max_relative_monthly_difference]), (5, 95)
)
# Assigning to `_` keeps the return value of the cell's last statement from
# being echoed as cell output.
_ = cube_plotting(
    max_relative_monthly_difference,
    log=False,
    title="Max Relative Monthly Difference Single Time (GSMaP - ERA5)",
    vmin=vmin,
    vmax=vmax,
)