In [1]:
import xarray as xr

def preprocess_tnlwrfcs(filepath, output_path='tnlwrfcs_anomalies_2001_2015.nc'):
    """Compute daily anomalies of ``avg_tnlwrfcs`` and write them to NetCDF.

    Steps: open ``filepath``, rename a ``valid_time`` dimension to ``time`` if
    present, average onto a daily grid, subtract the all-time mean at every
    grid point, and save the result under the variable name ``avg_tnlwrfcs``.

    Parameters
    ----------
    filepath : str
        Path of the input dataset; it must contain the variable
        ``avg_tnlwrfcs``.
    output_path : str, optional
        Destination NetCDF file. Defaults to the file name the downstream
        cells read.

    Returns
    -------
    None
    """
    # The context manager releases the input file handle even if a later step
    # raises; the original left the dataset open for the life of the kernel.
    with xr.open_dataset(filepath) as ds:
        var = ds['avg_tnlwrfcs']

        # Ensure time is renamed for consistency
        if 'valid_time' in var.dims:
            var = var.rename({'valid_time': 'time'})

        # Daily mean. NOTE(review): resampling yields one entry per calendar
        # day between the first and last timestamp; days without any input
        # sample come out all-NaN (downstream cells drop them) — confirm the
        # sampling frequency of the input.
        var_daily = var.resample(time='1D').mean()

        # Climatology here is a single time-mean field (mean across all
        # days), not a day-of-year climatology.
        climatology = var_daily.mean(dim='time')

        # Daily anomalies
        anomalies = var_daily - climatology
        anomalies.name = 'avg_tnlwrfcs'

        # Write while the source file is still open so lazily loaded data
        # can be read.
        anomalies.to_netcdf(output_path)

    print(f"✅ Daily anomalies saved to '{output_path}'")

# Run the preprocessing step when this cell/script is executed directly.
if __name__ == "__main__":
    preprocess_tnlwrfcs('OLR_data.nc')


✅ Daily anomalies saved to 'tnlwrfcs_anomalies_2001_2015.nc'


In [3]:
import xarray as xr
# Sanity-check the saved anomalies. The file is opened in a context manager so
# the handle is released afterwards; a dataset left open in the kernel can
# prevent a later re-run of the preprocessing cell from overwriting the file.
with xr.open_dataset('tnlwrfcs_anomalies_2001_2015.nc') as ds:
    print(ds)
    print(ds['avg_tnlwrfcs'].isel(time=0))  # Check a sample day
    print(ds['avg_tnlwrfcs'].mean().item())  # Should not be nan


<xarray.Dataset> Size: 340MB
Dimensions:       (latitude: 129, longitude: 121, time: 5449)
Coordinates:
  * latitude      (latitude) float64 1kB 38.0 37.75 37.5 37.25 ... 6.5 6.25 6.0
  * longitude     (longitude) float64 968B 68.0 68.25 68.5 ... 97.5 97.75 98.0
  * time          (time) datetime64[ns] 44kB 2001-01-01 ... 2015-12-02
    number        int64 8B ...
Data variables:
    avg_tnlwrfcs  (time, latitude, longitude) float32 340MB ...
<xarray.DataArray 'avg_tnlwrfcs' (latitude: 129, longitude: 121)> Size: 62kB
[15609 values with dtype=float32]
Coordinates:
  * latitude   (latitude) float64 1kB 38.0 37.75 37.5 37.25 ... 6.5 6.25 6.0
  * longitude  (longitude) float64 968B 68.0 68.25 68.5 ... 97.5 97.75 98.0
    time       datetime64[ns] 8B 2001-01-01
    number     int64 8B ...
-5.4464104323415086e-05


In [6]:
import xarray as xr
import numpy as np
from eofs.xarray import Eof
from sklearn.cluster import KMeans

def run_eof_clustering(n_pcs=7, n_clusters=4):
    """Run an EOF analysis on the saved anomalies and cluster the leading PCs.

    Reads ``tnlwrfcs_anomalies_2001_2015.nc``, drops time steps that are
    entirely NaN, computes the leading principal components, standardises
    them, and assigns each remaining time step to a K-means cluster.

    Files written:
      * ``pcs_tnlwrfcs.npy`` - PC time series, shape ``(time, n_pcs)``.
      * ``explained_variance_tnlwrfcs.npy`` - variance fraction of each PC.
      * ``eof_weather_regimes_tnlwrfcs.nc`` - variable ``regime`` along ``time``.

    Parameters
    ----------
    n_pcs : int, optional
        Number of principal components to retain (default 7, the value the
        original code hard-coded in two places).
    n_clusters : int, optional
        Number of K-means clusters / regimes (default 4).

    Returns
    -------
    None
    """
    # Open inside a context manager so the file handle is released; the data
    # is loaded into memory before the file closes.
    with xr.open_dataset('tnlwrfcs_anomalies_2001_2015.nc') as ds:
        data = ds['avg_tnlwrfcs']

        if 'valid_time' in data.dims:
            data = data.rename({'valid_time': 'time'})

        # Drop time steps that are fully NaN
        data = data.dropna(dim='time', how='all').load()

    # Optional: Print remaining data size
    print("✅ Cleaned data shape:", data.shape)
    print("Remaining NaNs:", np.isnan(data.values).sum())

    # NOTE(review): no latitude weights are passed to the solver — confirm
    # that an unweighted EOF is intended for this lat/lon grid.
    solver = Eof(data)
    pcs = solver.pcs(npcs=n_pcs, pcscaling=1)
    pcs_np = pcs.values

    np.save('pcs_tnlwrfcs.npy', pcs_np)
    variance = solver.varianceFraction().values[:n_pcs]
    np.save('explained_variance_tnlwrfcs.npy', variance)

    # Standardise each PC column (zero mean, unit standard deviation) so all
    # retained PCs contribute equally to the Euclidean distance in K-means.
    pcs_norm = (pcs_np - pcs_np.mean(axis=0)) / pcs_np.std(axis=0)

    # Fixed random_state keeps the cluster labels reproducible across runs.
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42).fit(pcs_norm)
    labels = kmeans.labels_

    regimes_ds = xr.Dataset({'regime': (['time'], labels)}, coords={'time': data['time']})
    regimes_ds.to_netcdf('eof_weather_regimes_tnlwrfcs.nc')

    print("✅ EOF clustering for tnlwrfcs completed and saved.")

# Run the EOF + clustering step when this cell/script is executed directly.
if __name__ == "__main__":
    run_eof_clustering()


✅ Cleaned data shape: (360, 129, 121)
Remaining NaNs: 0
✅ EOF clustering for tnlwrfcs completed and saved.


In [7]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
import cartopy.crs as ccrs
import os
import pandas as pd
from scipy.interpolate import interp1d

def ensure_time(ds):
    """Return ``ds`` with a ``valid_time`` dimension renamed to ``time``.

    Objects that have no ``valid_time`` dimension are returned unchanged.
    """
    has_valid_time = 'valid_time' in ds.dims
    return ds.rename({'valid_time': 'time'}) if has_valid_time else ds

def plot_regime_frequency(ds):
    """Save a bar chart of how many time steps fall into each regime.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing the variable ``regime`` (one label per time step).

    The figure is written to ``plots_tnlwrfcs/regime_frequency.png``; the
    directory must already exist.
    """
    labels, counts = np.unique(ds['regime'].values, return_counts=True)
    plt.figure(figsize=(7, 5))
    # seaborn deprecates `palette` without `hue`; assigning the x variable to
    # `hue` with the legend disabled gives the same per-bar colouring.
    sns.barplot(x=labels, y=counts, hue=labels, palette="deep", legend=False)
    plt.title("TNLWRFCS Regime Frequency (2001–2015)", fontsize=14)
    plt.xlabel("Regime", fontsize=12)
    plt.ylabel("Days", fontsize=12)
    plt.tight_layout()
    plt.savefig('plots_tnlwrfcs/regime_frequency.png', dpi=300)
    plt.close()

def plot_spatial_composites(ds):
    """Save one composite map of the mean anomaly for every regime.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing the variable ``regime`` indexed by ``time``.

    For each distinct regime label ``r`` the anomalies from
    ``tnlwrfcs_anomalies_2001_2015.nc`` are masked to the time steps where
    ``regime == r``, averaged over time, and written to
    ``plots_tnlwrfcs/regime_{r}_composite.png`` (directory must exist).
    """
    regime = ds['regime']

    # Open and filter the anomaly field once; this work does not depend on the
    # regime, so repeating it inside the loop only re-reads the file.
    with xr.open_dataset('tnlwrfcs_anomalies_2001_2015.nc') as anomalies:
        anomalies = ensure_time(anomalies)
        z = anomalies['avg_tnlwrfcs'].dropna(dim='time', how='all')

        for r in np.unique(regime.values):
            # Time steps belonging to other regimes become NaN and are
            # skipped by the mean.
            z_regime = z.where(regime == r).mean(dim='time')

            plt.figure(figsize=(8, 6))
            ax = plt.axes(projection=ccrs.PlateCarree())
            z_regime.plot.contourf(ax=ax, transform=ccrs.PlateCarree(),
                                   levels=np.arange(-20, 20.1, 2),
                                   cmap='RdBu_r', extend='both', add_colorbar=True)
            ax.coastlines()
            ax.set_title(f"Regime {r}: Mean TNLWRFCS Anomaly", fontsize=14)
            plt.tight_layout()
            plt.savefig(f'plots_tnlwrfcs/regime_{r}_composite.png', dpi=300)
            plt.close()

def plot_pc_timeseries(pcs, nao_index, enso_df):
    """Plot PC1 together with the NAO and ENSO indices and save the figure.

    Parameters
    ----------
    pcs : numpy.ndarray
        PC time series, shape ``(time, n_pcs)``; only column 0 is plotted.
        Its length is assumed to equal the number of non-empty time steps in
        the anomaly file.
    nao_index : xarray.Dataset
        Must provide ``time`` and ``nao``.
    enso_df : pandas.DataFrame
        Must provide the columns ``Year``, ``Month`` and ``Anomaly``. It is
        not modified.

    The figure is written to ``plots_tnlwrfcs/pc1_vs_enso_nao.png``.
    """
    with xr.open_dataset('tnlwrfcs_anomalies_2001_2015.nc') as anomalies:
        anomalies = ensure_time(anomalies).dropna(dim='time', how='all')
        times = anomalies['time'].values

    explained_variance = np.load('explained_variance_tnlwrfcs.npy')
    pc1_var = explained_variance[0] * 100

    # Build the mid-month date index on a copy: assigning a column on
    # `enso_df` directly would add a `date` column to the caller's frame.
    enso_dated = enso_df.assign(
        date=pd.to_datetime(enso_df[['Year', 'Month']].assign(day=15))
    ).set_index('date')
    enso_df_filtered = enso_dated.loc["2001-01-01":"2015-12-31"]

    plt.figure(figsize=(10, 5))
    plt.plot(times, pcs[:, 0], label='PC1 (EOF1)', linewidth=1.5)
    plt.plot(nao_index['time'].values, nao_index['nao'].values, label='NAO Index', linewidth=1.5)
    plt.plot(enso_df_filtered.index, enso_df_filtered['Anomaly'], label='ENSO Index (2001–2015)', linewidth=1.5)
    plt.legend()
    plt.title(f"PC1 vs NAO and ENSO Index (2001–2015)\nPC1 Explained Variance = {pc1_var:.2f}%", fontsize=14)
    plt.xlabel("Time", fontsize=12)
    plt.ylabel("Index Value", fontsize=12)
    plt.tight_layout()
    plt.savefig('plots_tnlwrfcs/pc1_vs_enso_nao.png', dpi=300)
    plt.close()

def plot_seasonal_cycle(ds):
    """Save a line plot of per-month occurrence counts for every regime.

    ``ds`` must provide ``time`` and ``regime``; a ``valid_time`` dimension is
    renamed to ``time`` first. Output: ``plots_tnlwrfcs/seasonal_cycle.png``.
    """
    ds = ensure_time(ds)
    sample_months = ds['time'].to_index().month
    regime_labels = ds['regime'].values

    month_numbers = range(1, 13)
    month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                   'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

    plt.figure(figsize=(10, 5))
    for regime_id in np.unique(regime_labels):
        in_regime = regime_labels == regime_id
        # Number of samples of this regime falling in each calendar month.
        per_month = []
        for month in month_numbers:
            per_month.append(np.sum((sample_months == month) & in_regime))
        plt.plot(month_numbers, per_month, label=f'Regime {regime_id}')

    plt.xticks(month_numbers, month_names)
    plt.legend()
    plt.title("Seasonal Cycle of Regime Occurrence", fontsize=14)
    plt.xlabel("Month", fontsize=12)
    plt.ylabel("Days", fontsize=12)
    plt.tight_layout()
    plt.savefig('plots_tnlwrfcs/seasonal_cycle.png', dpi=300)
    plt.close()

def plot_pc_index_correlation(pcs, index_array, label, file_name, time_values=None, index_time=None):
    """Bar-plot the Pearson correlation between each PC and a climate index.

    Parameters
    ----------
    pcs : numpy.ndarray
        PC time series, shape ``(time, n_pcs)``.
    index_array : array-like
        Index values. If it has the same length as ``pcs`` it is correlated
        sample-by-sample; otherwise (and when ``time_values`` is given) the
        PCs are averaged per calendar month and matched to the index by
        month-end date.
    label : str
        Index name used in the plot title.
    file_name : str
        Output file stem; the figure is saved to
        ``plots_tnlwrfcs/{file_name}.png``.
    time_values : array-like of datetimes, optional
        Timestamps of the rows of ``pcs``.
    index_time : array-like of datetimes, optional
        Month-end timestamps of ``index_array``; needed for the monthly path.

    Raises
    ------
    ValueError
        If the monthly path finds fewer than two months shared by the PCs and
        the index, or if the lengths differ and no ``time_values`` are given.
    """
    pcs = np.asarray(pcs)
    index_array = np.asarray(index_array)
    # Derive the labels from the data: a hard-coded range(7) breaks the plot
    # whenever a different number of PCs was retained.
    pc_names = [f"PC{i+1}" for i in range(pcs.shape[1])]

    if time_values is not None and index_array.shape[0] != pcs.shape[0]:
        df = pd.DataFrame(pcs, columns=pc_names,
                          index=pd.DatetimeIndex(pd.to_datetime(time_values), name='date'))
        # Monthly means, labelled by month end at midnight like the old
        # resample('M'). Grouping by period avoids the deprecated 'M' alias
        # and creates no NaN rows for months that have no samples.
        pcs_monthly = df.groupby(df.index.to_period('M')).mean()
        pcs_monthly.index = pcs_monthly.index.to_timestamp(how='end').normalize()

        index_series = pd.Series(index_array, index=pd.DatetimeIndex(index_time), name='_index')
        # Keep only months present in both series and free of NaNs, so a gap
        # in either series cannot turn every correlation into NaN.
        paired = pcs_monthly.join(index_series, how='inner').dropna()
        if len(paired) < 2:
            raise ValueError(
                f"Fewer than two overlapping months between the PCs and {label}; "
                "check that index_time holds month-end dates."
            )

        correlations = [np.corrcoef(paired[name], paired['_index'])[0, 1] for name in pc_names]
    else:
        correlations = [np.corrcoef(pcs[:, i], index_array)[0, 1] for i in range(pcs.shape[1])]

    plt.figure(figsize=(8, 5))
    # `hue` + `legend=False` replaces the deprecated palette-without-hue call.
    sns.barplot(x=pc_names, y=correlations, hue=pc_names, palette="deep", legend=False)
    plt.ylim(-1, 1)
    plt.title(f"Correlation between EOF PCs and {label} (2001–2015)", fontsize=14)
    plt.ylabel("Pearson Correlation", fontsize=12)
    plt.xlabel("Principal Components", fontsize=12)
    plt.tight_layout()
    plt.savefig(f'plots_tnlwrfcs/{file_name}.png', dpi=300)
    plt.close()

def plot_spatial_composites_new(ds):
    """Save regime composite maps annotated with a weighted-variance score.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing the variable ``regime`` indexed by ``time``; its
        labels must line up row-for-row with ``pcs_tnlwrfcs.npy``.

    For each regime the score is the sum over PCs of (that PC's share of the
    within-regime PC variance) times (that PC's share of the retained
    explained variance). Maps are written to
    ``plots_tnlwrfcs_rand/regime_{r}_composite.png`` (directory must exist).
    """
    pcs = np.load('pcs_tnlwrfcs.npy')
    explained_variance = np.load('explained_variance_tnlwrfcs.npy')
    regime = ds['regime']
    labels = regime.values
    regime_ids = np.unique(labels)
    # Renormalise so the retained PCs' variance fractions sum to 1.
    explained_variance = explained_variance / explained_variance.sum()

    regime_variance_map = {}
    for r in regime_ids:
        regime_pcs = pcs[labels == r]
        pc_var = np.var(regime_pcs, axis=0)
        pc_var_normalized = pc_var / pc_var.sum()
        regime_variance_map[r] = np.sum(pc_var_normalized * explained_variance)

    # Open and filter the anomaly field once; the original reopened the file
    # and recomputed the NaN filter for every regime.
    with xr.open_dataset('tnlwrfcs_anomalies_2001_2015.nc') as anomalies:
        anomalies = ensure_time(anomalies).dropna(dim='time', how='all')
        z = anomalies['avg_tnlwrfcs']

        for r in regime_ids:
            z_regime = z.where(regime == r).mean(dim='time')

            plt.figure(figsize=(8, 6))
            ax = plt.axes(projection=ccrs.PlateCarree())
            z_regime.plot.contourf(ax=ax, transform=ccrs.PlateCarree(),
                                   levels=np.arange(-20, 20.1, 2),
                                   cmap='RdBu_r', extend='both', add_colorbar=True)
            ax.coastlines()
            variance_str = f"{regime_variance_map[r] * 100:.2f}%"
            ax.set_title(f"Regime {r}: Mean TNLWRFCS Anomaly\n(Weighted EOF Variance ≈ {variance_str})", fontsize=13)
            plt.tight_layout()
            plt.savefig(f'plots_tnlwrfcs_rand/regime_{r}_composite.png', dpi=300)
            plt.close()

def main():
    """Create the output directories, load all inputs, and produce every plot.

    Inputs read from the working directory:
    ``eof_weather_regimes_tnlwrfcs.nc``, ``pcs_tnlwrfcs.npy``,
    ``nao_index.nc``, ``Enso_Monthwise_Index.csv`` and
    ``tnlwrfcs_anomalies_2001_2015.nc``.
    """
    os.makedirs("plots_tnlwrfcs", exist_ok=True)
    os.makedirs("plots_tnlwrfcs_rand", exist_ok=True)

    with xr.open_dataset('eof_weather_regimes_tnlwrfcs.nc') as ds:
        ds = ensure_time(ds)
        pcs = np.load('pcs_tnlwrfcs.npy')

        with xr.open_dataset('nao_index.nc') as nao_index:
            enso_df = pd.read_csv('Enso_Monthwise_Index.csv')
            # Month-end dates so the ENSO series lines up with the month-end
            # labels of the monthly-averaged PCs.
            enso_df['date'] = pd.to_datetime(enso_df[['Year', 'Month']].assign(day=1)) + pd.offsets.MonthEnd(0)
            enso_df = enso_df.set_index('date').loc["2001-01-31":"2015-12-31"]
            enso_series = enso_df['Anomaly'].values

            # Only the timestamps of the non-empty time steps are needed;
            # close the file once they are extracted (it was left open before).
            with xr.open_dataset('tnlwrfcs_anomalies_2001_2015.nc') as anomalies:
                anomalies = ensure_time(anomalies).dropna(dim='time', how='all')
                time_values = anomalies['time'].values

            # Linearly interpolate the NAO index onto the PC timestamps
            # (datetimes compared as int64), extrapolating beyond its range.
            f_nao = interp1d(nao_index['time'].values.astype(np.int64),
                             nao_index['nao'].values, kind='linear', fill_value="extrapolate")
            nao_interp = f_nao(time_values.astype(np.int64))

            plot_regime_frequency(ds)
            plot_spatial_composites(ds)
            plot_pc_timeseries(pcs, nao_index, enso_df)
            plot_pc_index_correlation(pcs, nao_interp, "NAO Index", "pc_nao_correlation")
            plot_pc_index_correlation(pcs, enso_series, "ENSO Index", "pc_enso_correlation",
                                      time_values=time_values, index_time=enso_df.index)
            plot_seasonal_cycle(ds)
            plot_spatial_composites_new(ds)

# Generate all diagnostic plots when this cell/script is executed directly.
if __name__ == "__main__":
    main()



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(x=labels, y=counts, palette="deep")

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(x=[f'PC{i+1}' for i in range(7)], y=correlations, palette="deep")
  pcs_monthly = df.resample('M').mean()

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(x=[f'PC{i+1}' for i in range(7)], y=correlations, palette="deep")
