In [22]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import geopandas as gpd
import glob
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from shapely.geometry import Point
import xarray as xr

import warnings
warnings.filterwarnings('ignore')

In [27]:
def get_spi_dataset(acc_period: int = 1, years: list = None, data_root_folder: str = '/data1/drought_dataset/spi/'):
    """Open the monthly SPI files for one accumulation period and a set of years.

    Parameters
    ----------
    acc_period : int, default 1
        Accumulation period in months. Selects the ``spi<acc_period>`` sub-folder
        and the ``SPI<acc_period>_gamma_...`` file prefix.
    years : list of int, optional
        Years to load (defaults to ``[2020]``). Matching files are sorted within
        each year and concatenated in the order the years are given.
    data_root_folder : str, optional
        Root folder that contains the ``spi<n>`` sub-folders.

    Returns
    -------
    xarray.Dataset
        All matching files combined along ``time`` (nested combine), chunked
        automatically along ``time``.

    Raises
    ------
    FileNotFoundError
        If no file matches for any requested year. This subclasses ``OSError``,
        so callers that caught the error from ``xr.open_mfdataset`` still work.
    """
    # None sentinel instead of a shared mutable list default.
    if years is None:
        years = [2020]

    spi_folder = os.path.join(data_root_folder, f'spi{acc_period}')
    spi_paths = []
    for year in years:
        pattern = os.path.join(
            spi_folder, f'SPI{acc_period}_gamma_global_era5_moda_ref1991to2020_{year}*.nc')
        spi_paths.extend(sorted(glob.glob(pattern)))

    # Fail with a message that says what was searched, rather than an opaque
    # error from deep inside xarray.
    if not spi_paths:
        raise FileNotFoundError(
            f'No SPI{acc_period} files found for years {list(years)} under {spi_folder}')

    return xr.open_mfdataset(spi_paths, chunks={'time': "auto"}, concat_dim="time", combine='nested', parallel=False)


def get_spei_dataset(acc_period: int = 1, years: list = None, data_root_folder: str = '/data1/drought_dataset/spei/'):
    """Open the monthly SPEI files for one accumulation period and a set of years.

    Parameters
    ----------
    acc_period : int, default 1
        Accumulation period in months. Selects the ``spei<acc_period>`` sub-folder
        and the ``SPEI<acc_period>_genlogistic_...`` file prefix.
    years : list of int, optional
        Years to load (defaults to ``[2020]``). Matching files are sorted within
        each year and concatenated in the order the years are given.
    data_root_folder : str, optional
        Root folder that contains the ``spei<n>`` sub-folders.

    Returns
    -------
    xarray.Dataset
        All matching files combined along ``time`` (nested combine), chunked
        automatically along ``time``.

    Raises
    ------
    FileNotFoundError
        If no file matches for any requested year. This subclasses ``OSError``,
        so callers that caught the error from ``xr.open_mfdataset`` still work.
    """
    # None sentinel instead of a shared mutable list default.
    if years is None:
        years = [2020]

    # The sub-folder is 'spei<n>' (the previous, unused variable pointed at 'spi<n>').
    spei_folder = os.path.join(data_root_folder, f'spei{acc_period}')
    spei_paths = []
    for year in years:
        pattern = os.path.join(
            spei_folder, f'SPEI{acc_period}_genlogistic_global_era5_moda_ref1991to2020_{year}*.nc')
        spei_paths.extend(sorted(glob.glob(pattern)))

    if not spei_paths:
        raise FileNotFoundError(
            f'No SPEI{acc_period} files found for years {list(years)} under {spei_folder}')

    return xr.open_mfdataset(spei_paths, chunks={'time': "auto"}, concat_dim="time", combine='nested', parallel=False)


def mask_invalid_values(ds, index, value=-9999):
    """Replace a sentinel fill value with NaN in one variable of ``ds``.

    The variable ``ds[index]`` is overwritten on the object that was passed
    in, and that same object is returned so the call can be chained.
    """
    variable = ds[index]
    keep = variable != value
    ds[index] = variable.where(keep, np.nan)
    return ds

def get_spei_significance_dataset(index='SPEI1', year=2020):
    """Open the twelve monthly SPEI significance files of one year as one dataset.

    One file per month (01..12) is read from
    ``<root>/<index lower-cased>/parameter/`` and the files are combined
    (nested) along ``time`` in month order.
    """
    data_root_folder = '/data1/drought_dataset/spei/'
    month_strs = [f'{m:02d}' for m in range(1, 13)]
    quality_paths = [
        f'{data_root_folder}{index.lower()}/parameter/{index}_significance_global_era5_moda_{year}{month_str}_ref1991to2020.nc'
        for month_str in month_strs
    ]
    return xr.open_mfdataset(quality_paths, concat_dim="time", combine='nested', parallel=False)


def get_spi_significance_dataset(index='SPI1', year=2020):
    """Open the twelve monthly SPI significance files of one year as one dataset.

    One file per month (01..12) is read from
    ``<root>/<index lower-cased>/parameter/`` and the files are combined
    (nested) along ``time`` in month order.
    """
    data_root_folder = '/data1/drought_dataset/spi/'
    quality_paths = list()
    for month_str in map('{:02d}'.format, range(1, 13)):
        quality_paths.append(
            f'{data_root_folder}{index.lower()}/parameter/{index}_significance_global_era5_moda_{year}{month_str}_ref1991to2020.nc'
        )
    return xr.open_mfdataset(quality_paths, concat_dim="time", combine='nested', parallel=False)


def create_drought_dataset(years: list, indices=('SPEI6',)):
    """Assemble the requested SPI/SPEI indices into a single xarray Dataset.

    Parameters
    ----------
    years : list of int
        Years whose monthly files are loaded for every requested index.
    indices : iterable of str, optional
        Index names of the form ``'SPI<n>'`` or ``'SPEI<n>'`` (e.g. ``'SPI3'``,
        ``'SPEI12'``). Defaults to ``('SPEI6',)``. Significance layers are not
        covered here; load them with ``get_spi_significance_dataset`` /
        ``get_spei_significance_dataset``.

    Returns
    -------
    xarray.Dataset
        One data variable per index, named after the index (e.g. ``'SPEI6'``).
        If a source dataset holds more than one data variable, each is stored
        as ``'<index>_<var>'`` so none silently overwrites another.

    Raises
    ------
    ValueError
        If an entry of ``indices`` is not a recognised SPI/SPEI name.
    """
    import re

    # Validate every requested name before opening any files, so a typo fails fast.
    requests = []
    for name in indices:
        match = re.fullmatch(r'(SPEI|SPI)(\d+)', name)
        if match is None:
            raise ValueError(
                f"Unsupported drought index {name!r}; expected e.g. 'SPI3' or 'SPEI6'.")
        requests.append((name, match.group(1), int(match.group(2))))

    loaders = {'SPI': get_spi_dataset, 'SPEI': get_spei_dataset}

    drought_dataset = xr.Dataset()
    for name, family, acc_period in requests:
        ds = loaders[family](acc_period=acc_period, years=years)
        data_vars = list(ds.data_vars)
        for var in data_vars:
            # Previously every variable was written under the same key, so only the
            # last one survived; keep the plain index name for the common 1-var case.
            key = name if len(data_vars) == 1 else f'{name}_{var}'
            drought_dataset[key] = ds[var]

    return drought_dataset

In [28]:
# Load the full 1991-2023 (inclusive) record used by the maps below.
drought_dataset = create_drought_dataset(years=list(range(1991, 2024)))

In [29]:
def classify_drought(spei_value):
    """Map one SPI/SPEI value to a drought-category code.

    Codes (matching the colorbar labels used in the plotting cell):
        5.0  Extreme Drought   value < -2
        4.0  Severe Drought    -2   <= value < -1.5
        3.0  Moderate Drought  -1.5 <= value < -1
        2.0  Mild Drought      -1   <= value < -0.5
        1.0  Normal            -0.5 <= value <= 0.5
        0.0  Wet Conditions    value > 0.5
        NaN                    value is NaN (missing data)

    Codes are returned as floats so that NaN fits in the same output array.
    ``np.vectorize`` (used by ``xr.apply_ufunc(..., vectorize=True)``) infers the
    output dtype from the first result, and an integer array cannot hold NaN.

    NOTE(review): sentinel fill values such as -9999 are not NaN and would be
    classified as 5.0; mask them first (e.g. with ``mask_invalid_values``).
    """
    # NaN fails every comparison below; without this guard missing cells fell
    # through to the last branch and were reported as "Wet Conditions".
    if np.isnan(spei_value):
        return np.nan
    if spei_value < -2:
        return 5.0  # Extreme Drought
    if spei_value < -1.5:
        return 4.0  # Severe Drought
    if spei_value < -1:
        return 3.0  # Moderate Drought
    if spei_value < -0.5:
        return 2.0  # Mild Drought
    if spei_value <= 0.5:
        return 1.0  # Normal
    return 0.0  # Wet Conditions

In [30]:
from pathlib import Path
import shutil

index = 'SPEI6'

# Per-month PNG frames for this index are written here. The folder is wiped on
# every run so that frames from a previous run never leak into the animation.
dest_folder = Path().resolve().joinpath('animation', index)
if dest_folder.exists():
    shutil.rmtree(dest_folder)
dest_folder.mkdir(parents=True, exist_ok=True)

In [35]:
months = ["January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"]
years = [y for y in range(1991, 2024)]
map_paths = []

# Loop-invariant styling. Class k (0 = wet ... 5 = extreme drought, as returned by
# classify_drought) falls in the bin [k - 0.5, k + 0.5), so each class gets one colour.
cmap = ListedColormap(['#5AB1A7', '#B6E3DC', '#EFDDAF', '#CEA053', '#995D12', '#543005'])
levels = [-0.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5]
category_labels = ['Wet Conditions', 'Normal', 'Mild Drought', 'Moderate Drought', 'Severe Drought', 'Extreme Drought']

for target_year in years:
    print(target_year)
    for target_month in range(1, 13):
        month_name = months[target_month - 1]

        # Zero-pad the month ('1991-01'). The previous '{:2d}' format produced a
        # space-padded label such as '1991- 1' for January-September.
        # NOTE(review): method='nearest' silently falls back to the closest available
        # time step if this month is missing from the dataset.
        spei_data = drought_dataset[index].sel(time=f'{target_year}-{target_month:02d}', method='nearest').squeeze()
        classified_spei = xr.apply_ufunc(classify_drought, spei_data, vectorize=True, dask='parallelized')

        fig = plt.figure(figsize=(15, 10))
        ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())

        # Data is drawn at zorder 1; the white ocean layer (zorder 2) hides values over sea.
        im = classified_spei.plot(
            ax=ax,
            transform=ccrs.PlateCarree(),
            cmap=cmap,
            levels=levels,
            extend='neither',
            cbar_kwargs={
                'label': 'Drought Category',
                'ticks': [0, 1, 2, 3, 4, 5],
                'format': '%d',
                'shrink': 0.5,
            },
            zorder=1,
        )
        im.colorbar.set_ticklabels(category_labels)

        ax.coastlines(zorder=3)
        ax.add_feature(cfeature.BORDERS, zorder=4)
        ax.add_feature(cfeature.OCEAN, color='white', zorder=2)

        # Labelled lat/lon grid: 30 deg in longitude, 15 deg in latitude, bottom/left labels only.
        gl = ax.gridlines(draw_labels=True, linewidth=0.5, color='gray', alpha=0.7, linestyle='--', zorder=5)
        gl.top_labels = False
        gl.right_labels = False
        gl.xlocator = plt.FixedLocator(range(-180, 181, 30))
        gl.ylocator = plt.FixedLocator(range(-90, 91, 15))
        gl.xlabel_style = {'size': 10, 'color': 'black'}
        gl.ylabel_style = {'size': 10, 'color': 'black'}

        # Include the month so consecutive animation frames are distinguishable.
        ax.set_title(f'Global Drought Conditions ({index}) for {month_name} {target_year}')

        map_path = os.path.join(dest_folder, f'{target_year}-{month_name}.png')
        fig.savefig(map_path, dpi=100, bbox_inches='tight')
        # Close each figure explicitly: hundreds of open figures would exhaust memory.
        plt.close(fig)
        map_paths.append(map_path)

1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023


In [None]:
from IPython.display import HTML
import matplotlib.animation as animation
from PIL import Image
import matplotlib as mpl

# Size limit for animations embedded inline via ani.to_jshtml(); raised to 100 MB.
# It does not affect ani.save().
mpl.rcParams['animation.embed_limit'] = 100  # in MB


months = ["January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"]
years = [y for y in range(1991, 2024)]
index = 'SPEI6'

dest_folder = Path().resolve() / 'animation' / index

# Frame order is chronological (year, then month) and must match the file names
# written by the plotting cell.
map_paths = [
    os.path.join(dest_folder, f'{target_year}-{months[target_month - 1]}.png')
    for target_year in years
    for target_month in range(1, 13)
]

# PIL opens files lazily and keeps the handle open; copy() forces the pixel data
# into memory so each file can be closed immediately instead of leaking handles.
images = []
for frame_path in map_paths:
    with Image.open(frame_path) as img:
        images.append(img.copy())
print('Images loaded...')

fig, ax = plt.subplots(figsize=(16, 10))
# Hide ticks and the axes frame once; these are axes-level settings, not per-frame ones.
ax.set_xticks([])
ax.set_yticks([])
for spine in ax.spines.values():
    spine.set_visible(False)

# ArtistAnimation expects a list of frames, each frame being a list of artists.
ims = [[ax.imshow(image, animated=True)] for image in images]

# 50 ms per frame (20 fps), 1 s pause before the animation repeats.
ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
ani.save('animation-22.gif', writer='ffmpeg', dpi=100)
plt.close(fig)

# To preview inline instead of (or in addition to) writing the GIF:
# HTML(ani.to_jshtml())

Images loaded...
