In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import warnings

import iris
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
from joblib import Memory, Parallel, delayed
from matplotlib import animation, rc
from tqdm import tqdm, tqdm_notebook

from wildfires.analysis.plotting import cube_plotting
from wildfires.data.cube_aggregation import get_ncpus
from wildfires.data.datasets import DATA_DIR, GSMaP_dry_day_period, data_map_plot
from wildfires.utils import Time, TqdmContext
from wildfires.utils import land_mask as get_land_mask

# NOTE: tqdm_notebook does not work here for some reason, so it is not used.

# Ignore iris' "Collapsing a non-contiguous coordinate" warning, presumably
# emitted by the `collapsed` calls further down.
warnings.filterwarnings("ignore", ".*Collapsing a non-contiguous coordinate.*")

# Number of available CPUs (not used further in this notebook as shown).
ncpus = get_ncpus()

# joblib on-disk cache rooted at DATA_DIR; used to memoize `get_js_animation`.
memory = Memory(DATA_DIR)

In [None]:
# Load the GSMaP dry day period dataset and print a summary of its first cube.
dry_day_period = GSMaP_dry_day_period()
first_cube = dry_day_period[0]
print(first_cube)

In [None]:
mpl.rcParams["figure.figsize"] = (10, 4)

# One log-scaled map per temporal aggregate of the dry day period.
summary_plots = [
    (iris.analysis.MEAN, "Log Mean Dry Day Period"),
    (iris.analysis.MIN, "Log Min Dry Day Period"),
    (iris.analysis.MAX, "Log Max Dry Day Period"),
]
for aggregator, plot_title in summary_plots:
    aggregated_cube = dry_day_period[0].collapsed("time", aggregator)
    cube_plotting(aggregated_cube, log=True, title=plot_title)

In [None]:
# Collect selected cubes, then regrid them.
# Only every `time_step`-th time slice is kept: 1 keeps all of them, larger
# values subsample the time series (e.g. for quicker test renders).
time_step = 1

new_dataset = dry_day_period.copy()
new_dataset.cubes = iris.cube.CubeList()
for i, single_time_cube in enumerate(dry_day_period[0].slices_over("time")):
    if i % time_step == 0:
        new_dataset.cubes.append(single_time_cube)

print("Starting regridding")
new_dataset.regrid()

# Define and use relevant masks (masked-array convention: True = exclude).
# NOTE(review): `land_mask` holds the *inverse* of `get_land_mask()`;
# presumably that helper is True over land, so this excludes non-land
# points - confirm.
land_mask = ~get_land_mask()

# Define a latitude mask which ignores data beyond `max_abs_latitude` degrees,
# as the precipitation data does not extend to those latitudes.
max_abs_latitude = 60
lats = new_dataset[0].coord("latitude").points
lons = new_dataset[0].coord("longitude").points

# Expand the latitudes to a (latitude, longitude) grid; assumes each time slice
# is a 2D latitude x longitude field - TODO confirm.
latitude_grid = np.meshgrid(lats, lons, indexing="ij")[0]
lat_mask = np.abs(latitude_grid) > max_abs_latitude

# Apply the masks. The combined mask is the same for every slice, so build it
# once. `np.ma.masked_where` ORs it into any existing mask and, unlike
# `data.mask |= ...`, also works if the regridded data is a plain ndarray.
exclusion_mask = land_mask | lat_mask
for single_time_cube in new_dataset:
    single_time_cube.data = np.ma.masked_where(
        exclusion_mask, single_time_cube.data
    )

# Global extrema of the masked data across all time slices (vmax sets the
# colour range of the animation below).
vmin = min(
    cube.collapsed(["longitude", "latitude", "time"], iris.analysis.MIN).data
    for cube in new_dataset
)
vmax = max(
    cube.collapsed(["longitude", "latitude", "time"], iris.analysis.MAX).data
    for cube in new_dataset
)

In [None]:
# Silence numpy RuntimeWarnings about division by zero / invalid values, which
# presumably arise from taking the log of zero values when plotting with log=True.
for ignored_message in (".*divide by zero.*", ".*invalid value encountered.*"):
    warnings.filterwarnings("ignore", ignored_message)

mpl.rcParams["figure.figsize"] = (14, 9)

# Grid dimensions of the regridded data: N cells have N + 1 contiguous bounds.
# They size the blank frame drawn by the animation's `init` function.
lat_bounds = new_dataset[0].coord("latitude").contiguous_bounds()
lon_bounds = new_dataset[0].coord("longitude").contiguous_bounds()
n_lat, n_lon = len(lat_bounds) - 1, len(lon_bounds) - 1


@memory.cache
def get_js_animation(N_frames=len(new_dataset)):
    """Render the masked dry day period time series as an HTML/JS animation.

    Parameters
    ----------
    N_frames : int
        Number of time slices (frames) to animate, starting from the first.
        Defaults to ``len(new_dataset)`` as evaluated when this cell runs.
        Values larger than the dataset length are clamped to it.

    Returns
    -------
    str
        The HTML/JavaScript produced by ``FuncAnimation.to_jshtml``.

    Notes
    -----
    The result is cached on disk by ``joblib.Memory``, keyed only on the
    arguments and this function's source code. The module-level state read
    here (``new_dataset``, ``vmax``, ``land_mask``, ``lat_mask``, ``n_lat``,
    ``n_lon``) is *not* part of the cache key. After changing the data, call
    ``get_js_animation.clear()`` to avoid getting a stale animation.
    """
    # Never request more frames than there are time slices; otherwise
    # `new_dataset[i]` raises an IndexError part-way through rendering.
    N_frames = min(N_frames, len(new_dataset))

    # Build the figure from the first time slice, with a log colour scale whose
    # upper limit comes from the global maximum `vmax`.
    fig, ax, mesh, cb, suptitle_text = cube_plotting(
        new_dataset[0],
        log=True,
        animation_output=True,
        title="",
        vmin=0,
        vmax=np.log(vmax),
    )
    # Per-frame title above the axes, filled in by `animate`.
    title_text = ax.text(
        0.5, 1.08, "", transform=ax.transAxes, ha="center", fontsize=15
    )
    plt.close()  # Prevent display of (duplicate) static figure due to %matplotlib inline

    # 12 frames per second, i.e. one second per year if the slices are monthly.
    interval = 1000 / 12

    # The same exclusion mask applies to every frame; combine it only once.
    exclusion_mask = land_mask | lat_mask

    def init():
        """Blank the mesh and the title before the first frame."""
        mesh.set_array(np.zeros(n_lat * n_lon))
        title_text.set_text("")
        # Return every artist `animate` updates so blitting treats them alike.
        return (mesh, title_text)

    with TqdmContext(unit=" plots", desc="Plotting", total=N_frames) as t:

        def animate(i):
            """Draw time slice ``i`` and set the title to its date."""
            single_time_cube = new_dataset[i]
            # Mutates the shared cube's mask in place (no-op if already masked).
            single_time_cube.data.mask |= exclusion_mask
            cube_plotting(
                single_time_cube,
                log=True,
                ax=ax,
                mesh=mesh,
                animation_output=False,
                new_colorbar=False,
                title="",
            )
            title_text.set_text(
                # Ignore the time, which flip-flops between 00:00:00 and 12:00:00.
                "Dry Day Period "
                + str(single_time_cube.coord("time").cell(0).point)[:10]
            )
            t.update()

            return (mesh, title_text)

        anim = animation.FuncAnimation(
            fig, animate, init_func=init, frames=N_frames, interval=interval, blit=True
        )

        # Rendering happens here, so keep it inside the progress-bar context.
        js_output = anim.to_jshtml()
        return js_output


# Render the (disk-cached) animation inline as interactive HTML.
js_animation = get_js_animation()
HTML(js_animation)