In [None]:
import os

import xarray as xa
import xcdat as xc
import numpy as np
from matplotlib import pyplot as plt
import cartopy
from matplotlib import colors as mcolors
from matplotlib import colorbar
from matplotlib import cm

In [None]:
# Root of the PMP precipitation-variability demo output. Every spectrum file
# opened below lives under it, so repointing the notebook needs one change only.
PV_OUTPUT_DIR = "/pscratch/sd/d/duan0000/PMP/demo_output/precip_variability"

# Emulator power spectra. The plain files are plotted as "Total" and the
# "_unforced" files as "Anomaly" further down.
ace_forced = xa.open_dataset(f"{PV_OUTPUT_DIR}/ACE2-PCMDI/PS_pr.day_regrid.180x90_Aday.amip.nc")
ace_unforced = xa.open_dataset(f"{PV_OUTPUT_DIR}/ACE2-PCMDI/PS_pr.day_regrid.180x90_Aday.amip_unforced.nc")
ngcm_precip_forced = xa.open_dataset(f"{PV_OUTPUT_DIR}/NeuralGCM-precip/PS_pr.day_regrid.180x90_Aday.amip.nc")
ngcm_precip_unforced = xa.open_dataset(f"{PV_OUTPUT_DIR}/NeuralGCM-precip/PS_pr.day_regrid.180x90_Aday.amip_unforced.nc")
ngcm_evap_forced = xa.open_dataset(f"{PV_OUTPUT_DIR}/NeuralGCM-evap/PS_pr.day_regrid.180x90_Aday.amip.nc")
ngcm_evap_unforced = xa.open_dataset(f"{PV_OUTPUT_DIR}/NeuralGCM-evap/PS_pr.day_regrid.180x90_Aday.amip_unforced.nc")

In [None]:
# Observational reference (GPCP 1.3): total-field spectra first, then the
# "_unforced" anomaly spectra.
gpcp_forced, gpcp_unforced = (
    xa.open_dataset(spectrum_path)
    for spectrum_path in (
        '/pscratch/sd/d/duan0000/PMP/demo_output/precip_variability/GPCP-1-3/PS_pr.day_regrid.180x90_GPCP-1-3.gn.nc',
        '/pscratch/sd/d/duan0000/PMP/demo_output/precip_variability/GPCP-1-3/PS_pr.day_regrid.180x90_GPCP-1-3.gn_unforced.nc',
    )
)

In [None]:
# Forcing categories present in the PMP output.
frcs = ["forced", "unforced"]

# Timescale keys used in the PMP output, each paired with its plot title.
frqs, frqtlt = (
    list(column)
    for column in zip(
        ("annual", "Annual"),
        ("semi-annual", "Semi-annual"),
        ("diurnal", "Diurnal"),
        ("interannual", "Interannual"),
        ("seasonal-annual", "Seasonal"),
        ("sub-seasonal", "Sub-seasonal"),
        ("synoptic", "Synoptic"),
        ("sub-daily", "Sub-daily"),
    )
)

In [None]:
# Spatial domain keys in the PMP output, each paired with a display title.
doms, domtlts = (
    list(column)
    for column in zip(
        ("Total_50S50N", "GLOBAL"),
        ("Total_30N50N", "NHEX"),
        ("Total_30S30N", "TROPICS"),
        ("Total_50S30S", "SHEX"),
    )
)
# Spectrum variable names (only "power" is read in this notebook).
# NOTE(review): this name shadows the built-in `vars()` for the rest of the kernel session.
vars = ["power", "sig95"]
# y-axis limits [bottom, top] for the power-spectrum plots.
yl = [1e-4, 50]

In [None]:
forced_list = [gpcp_forced, ace_forced, ngcm_precip_forced, ngcm_evap_forced]
unforced_list = [gpcp_unforced, ace_unforced, ngcm_precip_unforced, ngcm_evap_unforced]
colors = ['black', 'tab:blue', 'tab:red', 'tab:green']
names = ['GPCP-1-3', 'ACE2', 'NGCM-precip', 'NGCM-evap']

# Reference periods (days) marked on the top axis, with their tick labels.
# The 1-day and 12-hour marks were commented out originally; presumably because
# these spectra come from daily data -- confirm before re-enabling them.
prdday = np.array([1095, 365, 180, 90, 20, 5])
prdlabels = ["3yr", "1yr", "180dy", "90dy", "20dy", "5dy"]
frq3hr = 1.0 / (prdday * 8.0)  # frequencies for 3-hourly data (not used below)
frqday = 1.0 / (prdday * 1.0)  # frequencies [cycle/day] for daily data

# Latitude bounds and figure title for each key in `doms`. A dict lookup raises
# KeyError on an unknown key instead of silently reusing the previous bounds.
domain_info = {
    "Total_50S50N": (-50, 50, "Global"),
    "Total_30N50N": (30, 50, "NHEX"),
    "Total_30S30N": (-30, 30, "Tropics"),
    "Total_50S30S": (-50, -30, "SHEX"),
}

# Timescale band names written above the anomaly panel: (x position, text, alignment).
band_labels = [
    (frqday[0], "Interannual", "center"),
    (frqday[2], "Seasonal", "center"),
    (frqday[3], "Sub-seasonal", "left"),
    (frqday[5], "Synoptic", "center"),
]

# savefig does not create missing directories.
os.makedirs("demo_output", exist_ok=True)

for dom in doms:
    lat1, lat2, dname = domain_info[dom]
    fig = plt.figure(figsize=(12, 4))
    panels = [(forced_list, "Total"), (unforced_list, "Anomaly")]
    for fc, (data_list, ptlt) in enumerate(panels):
        ax = fig.add_subplot(1, 2, fc + 1)
        # Twin x-axis on top, labelled with periods instead of frequencies.
        ax2 = ax.twiny()
        ax2.set_xscale("log")
        for i, data in enumerate(data_list):
            freqs = np.array(data['freqs'])
            power = data["power"].sel(lat=slice(lat1, lat2))
            # cos(latitude) weights give an area-weighted domain-mean spectrum.
            weights = np.cos(np.deg2rad(power.lat))
            am = np.array(power.weighted(weights).mean(dim=['lat', 'lon']))
            ax.loglog(freqs, am, lw=1.5, ls='solid', color=colors[i], label=names[i])
            # Legend-style label in data coordinates, stacked upward per dataset.
            ax.text(
                0.001,
                0.001 * 2**i + 0.00001,
                names[i],
                color=colors[i],
                fontsize=11,
                ha="left",
            )
        ax.tick_params(axis="both", labelsize=9)

        ax2.set_xticks(frqday)
        ax2.set_xticklabels(prdlabels, fontweight="light", rotation="vertical", va="top")
        ax2.tick_params(
            axis="x",
            which="both",
            direction="in",
            top=True,
            bottom=False,
            labeltop=True,
            labelbottom=False,
            pad=-5,
            labelsize=11,
        )
        ax2.tick_params(axis="x", which="minor", top=False)
        # freqs[0] is skipped -- presumably the zero frequency, which a log axis
        # cannot show. The limits come from the last dataset plotted.
        ax.set_xlim(freqs[1], freqs[-1])
        ax.set_ylim(yl[0], yl[1])
        ax2.set_xlim(freqs[1], freqs[-1])
        if ptlt == "Anomaly":
            for x_pos, band_text, align in band_labels:
                ax2.text(x_pos, 1.1, band_text, color='k', fontsize=11, ha=align)
        ax.set_title(ptlt, loc="left")
        ax.set_xlabel("Frequency [cycle day$^{-1}$]", fontsize=11)
        ax.set_ylabel("Power [mm$^{2}$ day$^{-2}$]", fontsize=11)

        # Dotted lines at the band boundaries: 1 yr, 90 days, 20 days.
        for boundary in (frqday[1], frqday[3], frqday[4]):
            ax.axvline(x=boundary, c="gray", lw=0.5, ls="dotted")

    plt.tight_layout()
    fig.suptitle(dname, fontsize=14, y=1.05, weight='bold')
    fig.savefig(f"demo_output/{dname}_frequency_power.png", bbox_inches='tight', dpi=150)
    plt.show()

In [None]:
def prdday_to_frqidx(prdday, frequency, ntd=1):
    """Return the index of the frequency bin nearest to a given period.

    Input
    - prdday: period in days
    - frequency: frequencies to search (array-like supporting subtraction and
      ``argmin``), in cycles per time step
    - ntd: number of time steps per day (daily: 1, 3-hourly: 8)
    Output
    - int index into ``frequency`` of the value closest to 1 / (prdday * ntd)
    """
    target_frq = 1.0 / (float(prdday) * ntd)
    distance = np.abs(frequency - target_frq)
    return int(distance.argmin())

In [None]:
# Timescales whose peak is taken from the total (forced) spectra.
frqs_forced = ["annual", "semi-annual"]
# Display names of the datasets, in the same order as `forced_list`.
model_names = ["GPCP-1-3", "ACE2", "NGCM-precip", "NGCM-evap"]
data_list = forced_list

In [None]:
# Time steps per day for prdday_to_frqidx (the inputs are "pr.day" files).
ntd = 1

# Period window (days) searched for each forced peak: (shortest, longest), both inclusive.
forced_bands = {
    "semi-annual": (180, 183),  # 180day =< period =< 183day
    "annual": (360, 366),  # 360day =< period =< 366day
}

# savefig does not create missing directories.
os.makedirs("demo_output", exist_ok=True)

for count, frq in enumerate(frqs_forced):
    print(frq)
    # KeyError on an unknown key instead of silently reusing the previous map.
    prd_short, prd_long = forced_bands[frq]
    fig = plt.figure(figsize=(12, 6))
    data_list = forced_list
    for i, model in enumerate(model_names):
        data = data_list[i]
        frequency = data['freqs']
        am = data['power'].sel(lat=slice(-50, 50))
        # Assuming `freqs` ascends, the longer period maps to the smaller index.
        idx1 = prdday_to_frqidx(prd_long, frequency, ntd)
        idx2 = prdday_to_frqidx(prd_short, frequency, ntd)
        # Peak power inside the window at every grid point.
        amfm_map = am.isel(frequency=slice(idx1, idx2 + 1)).max(dim="frequency")
        ax = fig.add_subplot(2, 2, i + 1, projection=cartopy.crs.Robinson(central_longitude=180))
        if i == 0:  # GPCP
            # GPCP is the reference: every panel is divided by GPCP's area-weighted
            # mean and shares GPCP's colour scale. The GPCP weights are reused for
            # the models, which assumes all datasets share the same (regridded) grid.
            weights = np.cos(np.deg2rad(amfm_map.lat))
            ave = amfm_map.weighted(weights).mean(dim=['lat', 'lon'])
            normalized_map = amfm_map / ave
            v_max = normalized_map.max().data
            print(v_max)
            norm = mcolors.Normalize(vmin=0, vmax=v_max)
            levels = np.linspace(0, v_max, 20)
        else:
            normalized_map = amfm_map / ave
        normalized_map_mean = float(normalized_map.weighted(weights).mean(dim=['lat', 'lon']))
        cs = ax.contourf(
            amfm_map.lon, amfm_map.lat, normalized_map,
            transform=cartopy.crs.PlateCarree(), cmap='Spectral_r',
            norm=norm, levels=levels, extend='max',
        )
        ax.set_global()
        ax.coastlines()
        # Area-weighted mean of the normalized map (GPCP itself shows 1.00).
        ax.text(
            0.98, 0.99, f"({normalized_map_mean:.2f})",
            transform=ax.transAxes,
            ha='right', va='top',
            fontsize=10, color='k',
        )
        ax.set_title(model)
    cbar_ax = fig.add_axes([0.92, 0.25, 0.015, 0.5])  # [left, bottom, width, height]
    # Build the colour bar from the contour set so it shows the same discrete
    # levels (and max extension) as the maps; all panels share norm and levels.
    cbar = fig.colorbar(cs, cax=cbar_ax)
    cbar.set_label('Normalized variability', fontsize=10)
    fig.savefig(f'demo_output/{frq}.png', bbox_inches='tight', dpi=150)

In [None]:
frqs_unforced = ['synoptic', 'sub-seasonal', 'seasonal-annual', 'interannual']

# Period band (days) of each unforced timescale: (shortest, longest). The
# shortest period is included. The longest is excluded because that bin belongs
# to the next, longer band. None means "down to the lowest frequency".
unforced_bands = {
    "synoptic": (1, 20),  # 1day =< period < 20day
    "sub-seasonal": (20, 90),  # 20day =< period < 90day
    "seasonal-annual": (90, 365),  # 90day =< period < 365day
    "interannual": (365, None),  # 365day =< period
}

# savefig does not create missing directories.
os.makedirs("demo_output", exist_ok=True)

for count, frq in enumerate(frqs_unforced):
    print(frq)
    # KeyError on an unknown key instead of silently reusing the previous map.
    prd_short, prd_long = unforced_bands[frq]
    fig = plt.figure(figsize=(12, 6))
    # Unforced timescales are measured on the anomaly ("_unforced") spectra;
    # only the annual/semi-annual peaks come from the total spectra.
    data_list = unforced_list
    for i, model in enumerate(model_names):
        data = data_list[i]
        frequency = data['freqs']
        am = data['power'].sel(lat=slice(-50, 50))
        # Assuming `freqs` ascends, longer periods sit at smaller indices.
        idx2 = prdday_to_frqidx(prd_short, frequency, ntd)
        if prd_long is None:
            idx1 = 0
        else:
            # +1 drops the bin at the longest period, matching "period < longest".
            idx1 = prdday_to_frqidx(prd_long, frequency, ntd) + 1
        # Band-mean power at every grid point.
        amfm_map = am.isel(frequency=slice(idx1, idx2 + 1)).mean(dim="frequency")
        ax = fig.add_subplot(2, 2, i + 1, projection=cartopy.crs.Robinson(central_longitude=180))
        if i == 0:  # GPCP
            # GPCP is the reference: every panel is divided by GPCP's area-weighted
            # mean and shares GPCP's colour scale. The GPCP weights are reused for
            # the models, which assumes all datasets share the same (regridded) grid.
            weights = np.cos(np.deg2rad(amfm_map.lat))
            ave = amfm_map.weighted(weights).mean(dim=['lat', 'lon'])
            normalized_map = amfm_map / ave
            v_max = normalized_map.max().data
            print(v_max)
            norm = mcolors.Normalize(vmin=0, vmax=v_max)
            levels = np.linspace(0, v_max, 20)
        else:
            normalized_map = amfm_map / ave
        normalized_map_mean = float(normalized_map.weighted(weights).mean(dim=['lat', 'lon']))
        cs = ax.contourf(
            amfm_map.lon, amfm_map.lat, normalized_map,
            transform=cartopy.crs.PlateCarree(), cmap='Spectral_r',
            norm=norm, levels=levels, extend='max',
        )
        ax.set_global()
        ax.coastlines()
        # Area-weighted mean of the normalized map (GPCP itself shows 1.00).
        ax.text(
            0.98, 0.99, f"({normalized_map_mean:.2f})",
            transform=ax.transAxes,
            ha='right', va='top',
            fontsize=10, color='k',
        )
        ax.set_title(model)
    cbar_ax = fig.add_axes([0.92, 0.25, 0.015, 0.5])  # [left, bottom, width, height]
    # Build the colour bar from the contour set so it shows the same discrete
    # levels (and max extension) as the maps; all panels share norm and levels.
    cbar = fig.colorbar(cs, cax=cbar_ax)
    cbar.set_label('Normalized variability', fontsize=10)
    fig.savefig(f'demo_output/{frq}.png', bbox_inches='tight', dpi=150)

In [None]:
# Scratch preview of 20 contour levels. 5.56 looks like a v_max printed by the
# map cells -- confirm, or delete this cell once no longer needed.
np.linspace(start=0, stop=5.56, num=20)

In [None]:
# Scratch check: frequency index nearest a 360-day period for the last loaded spectrum.
prdday_to_frqidx(prdday=360, frequency=frequency, ntd=ntd)