### Note that the two datasets measure inherently different things!

WWLLN measures ground strikes explicitly, while CAPE x Precip is a proxy for (cloud) lightning.

In [None]:
import logging
import os
import re
import sys
import warnings
from collections import namedtuple
from functools import reduce
from itertools import combinations
from operator import mul

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
from loguru import logger as loguru_logger
from matplotlib.patches import Rectangle
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score, train_test_split
from tqdm import tqdm

import wildfires.analysis
from alepython import ale_plot
from alepython.ale import _second_order_ale_quant
from wildfires.analysis import *
from wildfires.dask_cx1 import get_client
from wildfires.data import *
from wildfires.logging_config import enable_logging
from wildfires.qstat import get_ncpus
from wildfires.utils import *

# Logging setup: enable loguru messages originating from `alepython`, drop the
# handlers loguru currently has, and emit WARNING and above to stderr.
loguru_logger.enable("alepython")
loguru_logger.remove()
loguru_logger.add(sys.stderr, level="WARNING")

logger = logging.getLogger(__name__)

# Project logging configuration, using its "jupyter" setting.
enable_logging("jupyter")

# The message argument is a regular expression matched against the start of
# the warning text, hence the leading `.*`. Each pattern ends in `.*` (not a
# bare `*`, which would merely repeat the final letter of the phrase).
warnings.filterwarnings("ignore", ".*Collapsing a non-contiguous coordinate.*")
warnings.filterwarnings("ignore", ".*DEFAULT_SPHERICAL_EARTH_RADIUS.*")
warnings.filterwarnings("ignore", ".*guessing contiguous bounds.*")

# Plot defaults shared by the cells of this notebook.
normal_coast_linewidth = 0.5
mpl.rc("figure", figsize=(14, 6))
mpl.rc("font", size=9.0)

# Figure output location and cache handle, both keyed on this analysis' name.
# NOTE(review): the path starts with a literal "~"; presumably FigureSaver
# expands it - confirm.
figure_saver = FigureSaver(
    directories=os.path.join("~", "tmp", "analysis_wwlln_vs_cape_precip"), debug=True
)
memory = get_memory("analysis_wwlln_vs_cape_precip", verbose=1)

# Load WWLLN and CAPExPRECIP Data
### Compare at their native resolutions

In [None]:
# Instantiate the two dataset objects compared in this notebook.
wwlln = WWLLN()
cape_precip = ERA5_CAPEPrecip()

In [None]:
# Bundle both datasets into one `Datasets` collection, then display element [2]
# of what `dataset_times` returns for it (requested with lat_lon=True).
datasets = Datasets([wwlln, cape_precip])
dataset_times(datasets, lat_lon=True)[2]

In [None]:
# Plot the WWLLN cube with log=True; the result is bound to `_` so the
# notebook does not echo it.
_ = cube_plotting(wwlln.cube, log=True)

In [None]:
# Plot the CAPE x Precip cube with log=True; the result is bound to `_` so the
# notebook does not echo it.
_ = cube_plotting(cape_precip.cube, log=True)

## Scale to the same grid and compute correlations overall and over land

In [None]:
# Unpack the three selections returned by `prepare_selection`.
# NOTE(review): going by the section heading these share a common grid -
# confirm against `prepare_selection`.
monthly, mean, climatology = prepare_selection(datasets)

## Regridded Mean Datasets

In [None]:
# Display element [2] of what `dataset_times` returns for the `monthly`
# selection (requested with lat_lon=True).
dataset_times(monthly, lat_lon=True)[2]

### Mean maps

In [None]:
# One map, drawn with log=True, for each cube in the `mean` selection.
for mean_cube in mean.cubes:
    cube_plotting(mean_cube, log=True)

### Standard deviation maps

In [None]:
# Collapse every monthly cube over its "time" coordinate with the STD_DEV
# aggregator and map the result, titled after the cube's name.
for monthly_cube in monthly.cubes:
    std_cube = monthly_cube.collapsed("time", iris.analysis.STD_DEV)
    plot_title = f"STD: {monthly_cube.name()}"
    cube_plotting(std_cube, log=True, title=plot_title)

### Correlations

In [None]:
# Give all cubes within a selection one shared mask: the element-wise OR of
# the individual masks, i.e. a point is masked if it is masked in any cube.
# `reduce` (functools) and `np` are provided by the notebook's import cell.
for selection in (monthly, climatology):
    selection.homogenise_masks()
    overall_mask = reduce(np.logical_or, [cube.data.mask for cube in selection.cubes])
    selection.apply_masks(overall_mask)

### Monthly correlations

In [None]:
# Correlation coefficient between the two monthly variables.
# `get_unmasked` presumably returns only the non-masked values - confirm.
monthly_arrays = [get_unmasked(cube.data) for cube in monthly.cubes]
corr_mat = np.corrcoef(*monthly_arrays)
assert corr_mat.shape[0] == 2, "Expect only 2 variables."
print("Monthly, all, corr:", corr_mat[0, 1])

In [None]:
# Repeat the monthly correlation on a deep copy of the selection that has the
# inverted land mask applied, leaving `monthly` itself untouched.
# NOTE(review): presumably the inversion marks non-land points as masked -
# confirm against `get_land_mask` / `apply_masks`.
land_mask = ~get_land_mask()
monthly_land = monthly.copy(deep=True)
monthly_land.apply_masks(land_mask)

monthly_land_arrays = [get_unmasked(cube.data) for cube in monthly_land.cubes]
corr_mat = np.corrcoef(*monthly_land_arrays)
assert corr_mat.shape[0] == 2, "Expect only 2 variables."
print("Monthly, land, corr:", corr_mat[0, 1])

In [None]:
# Hexagonal-bin plot of one land-only monthly variable against the other,
# using bins="log"; axis labels come from the selection's pretty names.
plt.figure()
names = list(monthly_land.pretty_variable_names)
arrs = [get_unmasked(cube.data) for cube in monthly_land.cubes]
plt.hexbin(*arrs, bins="log")
plt.xlabel(names[0])
# Bind the return value so the notebook does not echo it.
_ = plt.ylabel(names[1])

### Climatological correlations

In [None]:
# Correlation coefficient between the two climatological variables.
# `get_unmasked` presumably returns only the non-masked values - confirm.
corr_mat = np.corrcoef(*[get_unmasked(cube.data) for cube in climatology.cubes])
assert corr_mat.shape[0] == 2, "Expect only 2 variables."
# Label the output as climatological; it was previously printed as "Monthly".
print("Climatology, all, corr:", corr_mat[0, 1])

In [None]:
# Repeat the climatological correlation on a deep copy of the selection that
# has the inverted land mask applied, leaving `climatology` itself untouched.
# NOTE(review): presumably the inversion marks non-land points as masked -
# confirm against `get_land_mask` / `apply_masks`.
land_mask = ~get_land_mask()
climatology_land = climatology.copy(deep=True)
climatology_land.apply_masks(land_mask)
corr_mat = np.corrcoef(*[get_unmasked(cube.data) for cube in climatology_land.cubes])
assert corr_mat.shape[0] == 2, "Expect only 2 variables."
# Label the output as climatological; it was previously printed as "Monthly".
print("Climatology, land, corr:", corr_mat[0, 1])

In [None]:
# Hexagonal-bin plot of one land-only climatological variable against the
# other, using bins="log"; axis labels come from the selection's pretty names.
plt.figure()
names = list(climatology_land.pretty_variable_names)
arrs = [get_unmasked(cube.data) for cube in climatology_land.cubes]
plt.hexbin(*arrs, bins="log")
plt.xlabel(names[0])
# Bind the return value so the notebook does not echo it.
_ = plt.ylabel(names[1])