In [None]:
!pip install hmmlearn
!pip install -U kaleido




In [None]:
import pandas as pd
import numpy as np

index = pd.read_excel('ENSO.xlsx')
anomaly = pd.read_csv('WWVcleaned.csv')
print(index)
print(anomaly)

    Year  DJF  JFM  FMA  MAM  AMJ  MJJ  JJA  JAS  ASO  SON  OND  NDJ
0   1980  0.6  0.5  0.3  0.4  0.5  0.5  0.3  0.0 -0.1  0.0  0.1  0.0
1   1981 -0.3 -0.5 -0.5 -0.4 -0.3 -0.3 -0.3 -0.2 -0.2 -0.1 -0.2 -0.1
2   1982  0.0  0.1  0.2  0.5  0.7  0.7  0.8  1.1  1.6  2.0  2.2  2.2
3   1983  2.2  1.9  1.5  1.3  1.1  0.7  0.3 -0.1 -0.5 -0.8 -1.0 -0.9
4   1984 -0.6 -0.4 -0.3 -0.4 -0.5 -0.4 -0.3 -0.2 -0.2 -0.6 -0.9 -1.1
5   1985 -1.0 -0.8 -0.8 -0.8 -0.8 -0.6 -0.5 -0.5 -0.4 -0.3 -0.3 -0.4
6   1986 -0.5 -0.5 -0.3 -0.2 -0.1  0.0  0.2  0.4  0.7  0.9  1.1  1.2
7   1987  1.2  1.2  1.1  0.9  1.0  1.2  1.5  1.7  1.6  1.5  1.3  1.1
8   1988  0.8  0.5  0.1 -0.3 -0.9 -1.3 -1.3 -1.1 -1.2 -1.5 -1.8 -1.8
9   1989 -1.7 -1.4 -1.1 -0.8 -0.6 -0.4 -0.3 -0.3 -0.2 -0.2 -0.2 -0.1
10  1990  0.1  0.2  0.3  0.3  0.3  0.3  0.3  0.4  0.4  0.3  0.4  0.4
11  1991  0.4  0.3  0.2  0.3  0.5  0.6  0.7  0.6  0.6  0.8  1.2  1.5
12  1992  1.7  1.6  1.5  1.3  1.1  0.7  0.4  0.1 -0.1 -0.2 -0.3 -0.1
13  1993  0.1  0.3  0.5  0.7  0.7 

In [None]:

import pandas as pd
import numpy as np
from hmmlearn import hmm
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import os
from google.colab import files

# Create output directory
output_dir = 'hmm_outputs'
os.makedirs(output_dir, exist_ok=True)


# Define periods
periods = ['DJF', 'JFM', 'FMA', 'MAM', 'AMJ', 'MJJ', 'JJA', 'JAS', 'ASO', 'SON', 'OND', 'NDJ']

# Correlation coefficients from your analysis
corr_coeffs = {
    'MJJ': -0.02436115975109819,
    'OND': 0.4120160038567676,
    'SON': 0.35330147648830007,
    'ASO': 0.2759602399103643,
    'JAS': 0.18994442388434107,
    'JJA': 0.13703542153643897
}

# Step 2: Define hidden states based on ONI thresholds
def classify_enso(oni):
    if oni >= 0.5:
        return 0  # El Niño
    elif oni <= -0.5:
        return 1  # La Niña
    else:
        return 2  # Neutral

# Step 3: Train HMM for each period
results = {}
state_names = {0: 'El Niño', 1: 'La Niña', 2: 'Neutral'}

for period in periods:
    print(f"\nProcessing period: {period}")

    # Prepare data
    data = pd.DataFrame({
        'Year': index['Year'],
        f'ONI_{period}': index[period],
        f'WWV_{period}': anomaly[period]
    }).dropna()

    # Classify states
    data['State'] = data[f'ONI_{period}'].apply(classify_enso)
    state_sequence = data['State'].values

    # Prepare observations
    observations = data[[f'ONI_{period}', f'WWV_{period}']].values

    # Initialize and train HMM
    n_states = 3
    model = hmm.GaussianHMM(n_components=n_states, covariance_type="full", n_iter=100, random_state=42)
    try:
        model.fit(observations)
    except Exception as e:
        print(f"Warning: HMM fitting failed for {period}: {e}")
        continue

    # Extract parameters
    transition_matrix = model.transmat_
    initial_probs = model.startprob_
    emission_means = model.means_
    emission_covars = model.covars_

    # Predict states
    predicted_states = model.predict(observations)
    data['Predicted_State'] = predicted_states
    data['Predicted_State_Name'] = data['Predicted_State'].map(state_names)

    # Evaluate performance
    accuracy = accuracy_score(data['State'], data['Predicted_State'])
    cm = confusion_matrix(data['State'], data['Predicted_State'], labels=[0, 1, 2])

    # Store results
    results[period] = {
        'data': data,
        'transition_matrix': transition_matrix,
        'initial_probs': initial_probs,
        'emission_means': emission_means,
        'emission_covars': emission_covars,
        'accuracy': accuracy,
        'confusion_matrix': cm
    }

    # Save parameters
    np.save(f'hmm_outputs/transition_matrix_{period}.npy', transition_matrix)
    np.save(f'hmm_outputs/initial_probs_{period}.npy', initial_probs)
    np.save(f'hmm_outputs/emission_means_{period}.npy', emission_means)
    np.save(f'hmm_outputs/emission_covars_{period}.npy', emission_covars)
    data.to_csv(f'hmm_outputs/predictions_{period}.csv', index=False)

    # Print results
    print(f"Transition Matrix for {period}:")
    print(transition_matrix)
    print(f"\nInitial Probabilities for {period}:")
    print(initial_probs)
    print(f"\nEmission Means for {period} (ONI, WWV):")
    print(emission_means)
    print(f"\nEmission Covariances for {period}:")
    for i in range(n_states):
        print(f"State {i}:")
        print(emission_covars[i])
    print(f"\nAccuracy for {period}: {accuracy:.4f}")
    print(f"Confusion Matrix for {period}:\n{cm}")

    # Plot predictions
    plt.figure(figsize=(12, 6))
    plt.plot(data['Year'], data[f'ONI_{period}'], label=f'ONI ({period})', color='blue')
    plt.plot(data['Year'], data[f'WWV_{period}'], label=f'WWV ({period})', color='red')
    for state in range(n_states):
        mask = data['Predicted_State'] == state
        plt.scatter(data['Year'][mask], data[f'ONI_{period}'][mask], label=f'Predicted {state_names[state]}', s=50)
    plt.xlabel('Year')
    plt.ylabel('Value')
    plt.title(f'HMM Predicted ENSO States for {period}')
    plt.legend()
    plt.savefig(f'hmm_outputs/enso_hmm_states_{period}.png')
    plt.close()

# Step 4: Detailed Analysis
print("\n=== Detailed Analysis ===")

# Summary of accuracies
accuracies = {period: results[period]['accuracy'] for period in results}
print("\nModel Accuracies by Period:")
for period, acc in accuracies.items():
    corr = corr_coeffs.get(period, 'N/A')
    print(f"{period}: Accuracy = {acc:.4f}, Correlation with NDJ WWV = {corr}")

# Identify best period
best_period = max(accuracies, key=accuracies.get)
print(f"\nBest Performing Period: {best_period} (Accuracy = {accuracies[best_period]:.4f})")

# Analyze confusion matrices
print("\nConfusion Matrix Analysis:")
for period in results:
    cm = results[period]['confusion_matrix']
    print(f"\n{period}:")
    print("Rows: Actual (El Niño, La Niña, Neutral)")
    print("Columns: Predicted (El Niño, La Niña, Neutral)")
    print(cm)
    # Calculate per-class precision
    precision = np.diag(cm) / np.sum(cm, axis=0)
    print(f"Precision (El Niño, La Niña, Neutral): {[f'{p:.4f}' if not np.isnan(p) else 'N/A' for p in precision]}")

# Correlation vs. Accuracy
print("\nCorrelation vs. Accuracy Analysis:")
for period in corr_coeffs:
    acc = accuracies.get(period, 'N/A')
    print(f"{period}: Correlation = {corr_coeffs[period]:.4f}, Accuracy = {acc if acc != 'N/A' else 'N/A'}")

# Save summary
with open('hmm_outputs/summary.txt', 'w') as f:
    f.write("=== HMM Summary ===\n")
    f.write("\nModel Accuracies by Period:\n")
    for period, acc in accuracies.items():
        corr = corr_coeffs.get(period, 'N/A')
        f.write(f"{period}: Accuracy = {acc:.4f}, Correlation with NDJ WWV = {corr}\n")
    f.write(f"\nBest Performing Period: {best_period} (Accuracy = {accuracies[best_period]:.4f})\n")
    f.write("\nConfusion Matrix Analysis:\n")
    for period in results:
        cm = results[period]['confusion_matrix']
        f.write(f"\n{period}:\n")
        f.write("Rows: Actual (El Niño, La Niña, Neutral)\n")
        f.write("Columns: Predicted (El Niño, La Niña, Neutral)\n")
        f.write(f"{cm}\n")
        precision = np.diag(cm) / np.sum(cm, axis=0)
        f.write(f"Precision (El Niño, La Niña, Neutral): {[f'{p:.4f}' if not np.isnan(p) else 'N/A' for p in precision]}\n")

# Download outputs (optional)
print("\nDownloading output files...")
for fname in os.listdir('hmm_outputs'):
    files.download(f'hmm_outputs/{fname}')

print("\nSummary saved to 'hmm_outputs/summary.txt'")
print("Model parameters, predictions, and plots saved in 'hmm_outputs/' directory")


Processing period: DJF
Transition Matrix for DJF:
[[2.26744154e-09 9.99999973e-01 2.44387261e-08]
 [6.82676952e-01 1.24073756e-05 3.17310641e-01]
 [1.58741991e-01 8.41258009e-01 1.30202873e-13]]

Initial Probabilities for DJF:
[1.00000000e+00 1.02702201e-37 1.01885117e-29]

Emission Means for DJF (ONI, WWV):
[[-0.64373172  0.24779321]
 [-0.109633    0.15579874]
 [ 1.68340366 -0.08271336]]

Emission Covariances for DJF:
State 0:
[[0.50879039 0.24021446]
 [0.24021446 0.24813871]]
State 1:
[[0.65004734 0.20579118]
 [0.20579118 1.12304139]]
State 2:
[[ 0.41306414 -0.19132983]
 [-0.19132983  0.47330649]]

Accuracy for DJF: 0.2326
Confusion Matrix for DJF:
[[1 7 6]
 [9 9 0]
 [6 5 0]]

Processing period: JFM
Transition Matrix for JFM:
[[3.28218730e-01 6.71781175e-01 9.45387286e-08]
 [1.04512581e-01 3.64167345e-01 5.31320074e-01]
 [6.15147365e-01 1.10948181e-09 3.84852633e-01]]

Initial Probabilities for JFM:
[5.39212453e-15 1.00000000e+00 2.37032811e-19]

Emission Means for JFM (ONI, WWV):
[

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


Summary saved to 'hmm_outputs/summary.txt'
Model parameters, predictions, and plots saved in 'hmm_outputs/' directory


In [None]:
import pandas as pd
import numpy as np
from hmmlearn import hmm
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
import os
from scipy.stats import multivariate_normal
from google.colab import files
import warnings
warnings.filterwarnings('ignore')

# Set random seed for reproducibility
np.random.seed(42)

# Create output directory
output_dir = 'hmm_outputs'
os.makedirs(output_dir, exist_ok=True)

# Define periods and correlation coefficients
periods = ['DJF', 'JFM', 'FMA', 'MAM', 'AMJ', 'MJJ', 'JJA', 'JAS', 'ASO', 'SON', 'OND', 'NDJ']
corr_coeffs = {
    'MJJ': -0.02436115975109819, 'OND': 0.4120160038567676, 'SON': 0.35330147648830007,
    'ASO': 0.2759602399103643, 'JAS': 0.18994442388434107, 'JJA': 0.13703542153643897
}
state_names = {0: 'El Niño', 1: 'La Niña', 2: 'Neutral'}

# 1. Classify ENSO States
def classify_enso(oni):
    """Classify ENSO state based on ONI thresholds."""
    if oni >= 0.5:
        return 0  # El Niño
    elif oni <= -0.5:
        return 1  # La Niña
    else:
        return 2  # Neutral

# 2. Train HMM
def train_hmm(observations, n_states=3, n_iter=100):
    """Train HMM with Gaussian emissions."""
    model = hmm.GaussianHMM(n_components=n_states, covariance_type='full', n_iter=n_iter, random_state=42)
    try:
        model.fit(observations)
        states = model.predict(observations)
        state_probs = model.predict_proba(observations)
        return model, states, state_probs
    except Exception as e:
        print(f"HMM fitting failed: {e}")
        return None, None, None

# 3. Analyze Transition Matrix
def analyze_transition_matrix(transmat, period):
    """Analyze stability and volatility of transition matrix."""
    stability = np.diag(transmat)  # Probability of staying in the same state
    volatility = 1 - stability  # Probability of switching states
    return {
        'stability': dict(zip(state_names.values(), stability)),
        'volatility': dict(zip(state_names.values(), volatility)),
        'most_likely_transitions': {
            state: state_names[np.argmax(transmat[i, :])]
            for i, state in enumerate(state_names.values())
        }
    }

# 4. Analyze Emission Properties
def analyze_emission_properties(means, covars):
    """Characterize statistical properties of emission distributions."""
    properties = {}
    for i, state in enumerate(state_names.values()):
        properties[state] = {
            'sst_mean': means[i, 0],
            'wwv_mean': means[i, 1],
            'sst_variance': covars[i][0, 0],
            'wwv_variance': covars[i][1, 1],
            'sst_wwv_covariance': covars[i][0, 1]
        }
    return properties

# 5. Change-Point Detection
def detect_change_points(states, years):
    """Detect regime shifts based on state transitions."""
    transitions = np.where(states[:-1] != states[1:])[0] + 1
    change_points = [(years[i], state_names[states[i-1]], state_names[states[i]]) for i in transitions]
    return change_points

# 6. Cluster Years by Emission Characteristics
def cluster_years(observations, states, years, n_clusters=3):
    """Group years with similar emission characteristics."""
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    clusters = kmeans.fit_predict(observations)
    cluster_data = pd.DataFrame({
        'Year': years,
        'Cluster': clusters,
        'State': [state_names[s] for s in states]
    })
    return cluster_data

# 7. Visualization Functions
def plot_time_series(data, period, output_dir):
    """Plot ONI and WWV with predicted states."""
    plt.figure(figsize=(12, 6))
    plt.plot(data['Year'], data[f'ONI_{period}'], label=f'ONI ({period})', color='blue')
    plt.plot(data['Year'], data[f'WWV_{period}'], label=f'WWV ({period})', color='red')
    for state in range(3):
        mask = data['Predicted_State'] == state
        plt.scatter(data['Year'][mask], data[f'ONI_{period}'][mask], label=f'Predicted {state_names[state]}', s=50)
    plt.xlabel('Year')
    plt.ylabel('Value')
    plt.title(f'HMM Predicted ENSO States for {period}')
    plt.legend()
    plt.savefig(os.path.join(output_dir, f'enso_hmm_states_{period}.png'), dpi=300)
    plt.close()

def plot_posterior_probabilities(years, state_probs, period, output_dir):
    """Plot posterior probabilities of each state."""
    plt.figure(figsize=(12, 4))
    for i, state in enumerate(state_names.values()):
        plt.plot(years, state_probs[:, i], label=state)
    plt.title(f'{period} Posterior State Probabilities')
    plt.xlabel('Year')
    plt.ylabel('Probability')
    plt.legend()
    plt.savefig(os.path.join(output_dir, f'state_probs_{period}.png'), dpi=300)
    plt.close()

def plot_confusion_matrix(cm, period, output_dir):
    """Plot confusion matrix."""
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=state_names.values(), yticklabels=state_names.values())
    plt.title(f'{period} Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig(os.path.join(output_dir, f'confusion_matrix_{period}.png'), dpi=300)
    plt.close()

def plot_emission_distributions(means, covars, period, output_dir):
    """Plot emission distributions for each state."""
    x = np.linspace(-3, 3, 100)
    y = np.linspace(-3, 3, 100)
    X, Y = np.meshgrid(x, y)
    pos = np.dstack((X, Y))

    plt.figure(figsize=(8, 6))
    for i, state in enumerate(state_names.values()):
        rv = multivariate_normal(mean=means[i], cov=covars[i])
        Z = rv.pdf(pos)
        plt.contour(X, Y, Z, levels=5, label=state)
    plt.title(f'{period} Emission Distributions')
    plt.xlabel('ONI (°C)')
    plt.ylabel('WWV Anomaly')
    plt.legend(list(state_names.values()))
    plt.savefig(os.path.join(output_dir, f'emission_dist_{period}.png'), dpi=300)
    plt.close()

# 8. Main Processing Function
def process_period(period, index, anomaly, output_dir):
    """Process HMM for a single period."""
    print(f"\nProcessing period: {period}")

    # Prepare data
    data = pd.DataFrame({
        'Year': index['Year'],
        f'ONI_{period}': index[period],
        f'WWV_{period}': anomaly[period]
    }).dropna()

    if len(data) < 10:
        print(f"Skipping {period} due to insufficient data")
        return None

    # Classify true states
    data['State'] = data[f'ONI_{period}'].apply(classify_enso)
    observations = data[[f'ONI_{period}', f'WWV_{period}']].values

    # Train HMM
    model, predicted_states, state_probs = train_hmm(observations)
    if model is None:
        return None

    data['Predicted_State'] = predicted_states
    data['Predicted_State_Name'] = data['Predicted_State'].map(state_names)

    # Evaluate
    accuracy = accuracy_score(data['State'], data['Predicted_State'])
    cm = confusion_matrix(data['State'], data['Predicted_State'], labels=[0, 1, 2])

    # Analyze transition matrix
    trans_analysis = analyze_transition_matrix(model.transmat_, period)

    # Analyze emission properties
    emission_analysis = analyze_emission_properties(model.means_, model.covars_)

    # Detect change points
    change_points = detect_change_points(predicted_states, data['Year'].values)

    # Cluster years
    cluster_data = cluster_years(observations, predicted_states, data['Year'].values)

    # Save outputs
    np.save(os.path.join(output_dir, f'transition_matrix_{period}.npy'), model.transmat_)
    np.save(os.path.join(output_dir, f'initial_probs_{period}.npy'), model.startprob_)
    np.save(os.path.join(output_dir, f'emission_means_{period}.npy'), model.means_)
    np.save(os.path.join(output_dir, f'emission_covars_{period}.npy'), model.covars_)
    data.to_csv(os.path.join(output_dir, f'predictions_{period}.csv'), index=False)
    cluster_data.to_csv(os.path.join(output_dir, f'clusters_{period}.csv'), index=False)

    # Visualizations
    plot_time_series(data, period, output_dir)
    plot_posterior_probabilities(data['Year'], state_probs, period, output_dir)
    plot_confusion_matrix(cm, period, output_dir)
    plot_emission_distributions(model.means_, model.covars_, period, output_dir)

    return {
        'data': data,
        'transition_matrix': model.transmat_,
        'initial_probs': model.startprob_,
        'emission_means': model.means_,
        'emission_covars': model.covars_,
        'accuracy': accuracy,
        'confusion_matrix': cm,
        'transition_analysis': trans_analysis,
        'emission_analysis': emission_analysis,
        'change_points': change_points,
        'cluster_data': cluster_data
    }

# 9. Main Execution
def main(index, anomaly):
    results = {}
    for period in periods:
        result = process_period(period, index, anomaly, output_dir)
        if result:
            results[period] = result

    # Detailed Analysis
    print("\n=== Detailed Analysis ===")

    # Accuracies and correlations
    accuracies = {p: r['accuracy'] for p, r in results.items()}
    print("\nModel Accuracies by Period:")
    for period, acc in accuracies.items():
        corr = corr_coeffs.get(period, 'N/A')
        print(f"{period}: Accuracy = {acc:.4f}, Correlation with NDJ WWV = {corr}")

    best_period = max(accuracies, key=accuracies.get)
    print(f"\nBest Performing Period: {best_period} (Accuracy = {accuracies[best_period]:.4f})")

    # Transition analysis
    print("\nTransition Matrix Analysis:")
    for period, res in results.items():
        print(f"\n{period}:")
        print(f"Stability: {res['transition_analysis']['stability']}")
        print(f"Volatility: {res['transition_analysis']['volatility']}")
        print(f"Most Likely Transitions: {res['transition_analysis']['most_likely_transitions']}")

    # Emission properties
    print("\nEmission Properties Analysis:")
    for period, res in results.items():
        print(f"\n{period}:")
        for state, props in res['emission_analysis'].items():
            print(f"{state}: SST Mean = {props['sst_mean']:.4f}, WWV Mean = {props['wwv_mean']:.4f}, "
                  f"SST Variance = {props['sst_variance']:.4f}, WWV Variance = {props['wwv_variance']:.4f}, "
                  f"Covariance = {props['sst_wwv_covariance']:.4f}")

    # Change points
    print("\nChange-Point Detection:")
    for period, res in results.items():
        print(f"\n{period}:")
        for year, from_state, to_state in res['change_points']:
            print(f"Year {year}: {from_state} -> {to_state}")

    # Clustering
    print("\nClustering Analysis:")
    for period, res in results.items():
        print(f"\n{period}:")
        print(res['cluster_data'].groupby('Cluster').agg({
            'Year': 'count',
            'State': lambda x: x.value_counts().to_dict()
        }).rename(columns={'Year': 'Count', 'State': 'State Distribution'}))

    # Save summary
    with open(os.path.join(output_dir, 'summary.txt'), 'w') as f:
        f.write("=== HMM Summary ===\n")
        f.write("\nModel Accuracies by Period:\n")
        for period, acc in accuracies.items():
            corr = corr_coeffs.get(period, 'N/A')
            f.write(f"{period}: Accuracy = {acc:.4f}, Correlation with NDJ WWV = {corr}\n")
        f.write(f"\nBest Performing Period: {best_period} (Accuracy = {accuracies[best_period]:.4f})\n")

        f.write("\nTransition Matrix Analysis:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            f.write(f"Stability: {res['transition_analysis']['stability']}\n")
            f.write(f"Volatility: {res['transition_analysis']['volatility']}\n")
            f.write(f"Most Likely Transitions: {res['transition_analysis']['most_likely_transitions']}\n")

        f.write("\nEmission Properties Analysis:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            for state, props in res['emission_analysis'].items():
                f.write(f"{state}: SST Mean = {props['sst_mean']:.4f}, WWV Mean = {props['wwv_mean']:.4f}, "
                      f"SST Variance = {props['sst_variance']:.4f}, WWV Variance = {props['wwv_variance']:.4f}, "
                      f"Covariance = {props['sst_wwv_covariance']:.4f}\n")

        f.write("\nChange-Point Detection:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            for year, from_state, to_state in res['change_points']:
                f.write(f"Year {year}: {from_state} -> {to_state}\n")

        f.write("\nClustering Analysis:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            f.write(str(res['cluster_data'].groupby('Cluster').agg({
                'Year': 'count',
                'State': lambda x: x.value_counts().to_dict()
            }).rename(columns={'Year': 'Count', 'State': 'State Distribution'})) + "\n")

    # Download outputs
    print("\nDownloading output files...")
    for fname in os.listdir(output_dir):
        files.download(os.path.join(output_dir, fname))

    print("\nSummary saved to 'hmm_outputs/summary.txt'")
    print("Model parameters, predictions, and plots saved in 'hmm_outputs/' directory")

if __name__ == '__main__':
    main(index, anomaly)


Processing period: DJF

Processing period: JFM

Processing period: FMA

Processing period: MAM

Processing period: AMJ

Processing period: MJJ

Processing period: JJA

Processing period: JAS

Processing period: ASO

Processing period: SON

Processing period: OND

Processing period: NDJ

=== Detailed Analysis ===

Model Accuracies by Period:
DJF: Accuracy = 0.2326, Correlation with NDJ WWV = N/A
JFM: Accuracy = 0.1364, Correlation with NDJ WWV = N/A
FMA: Accuracy = 0.1818, Correlation with NDJ WWV = N/A
MAM: Accuracy = 0.4884, Correlation with NDJ WWV = N/A
AMJ: Accuracy = 0.3023, Correlation with NDJ WWV = N/A
MJJ: Accuracy = 0.4651, Correlation with NDJ WWV = -0.02436115975109819
JJA: Accuracy = 0.4884, Correlation with NDJ WWV = 0.13703542153643897
JAS: Accuracy = 0.4419, Correlation with NDJ WWV = 0.18994442388434107
ASO: Accuracy = 0.5116, Correlation with NDJ WWV = 0.2759602399103643
SON: Accuracy = 0.6047, Correlation with NDJ WWV = 0.35330147648830007
OND: Accuracy = 0.0930, Co

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


Summary saved to 'hmm_outputs/summary.txt'
Model parameters, predictions, and plots saved in 'hmm_outputs/' directory


In [None]:

import pandas as pd
import numpy as np
from hmmlearn import hmm
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
import os
from scipy.stats import multivariate_normal
from google.colab import files
import warnings
import pkg_resources


# Set random seed for reproducibility
np.random.seed(42)

# Create output directory
output_dir = 'hmm_outputs'
os.makedirs(output_dir, exist_ok=True)

# Define periods and correlation coefficients
periods = ['DJF', 'JFM', 'FMA', 'MAM', 'AMJ', 'MJJ', 'JJA', 'JAS', 'ASO', 'SON', 'OND', 'NDJ']
corr_coeffs = {
    'MJJ': -0.02436115975109819, 'OND': 0.4120160038567676, 'SON': 0.35330147648830007,
    'ASO': 0.2759602399103643, 'JAS': 0.18994442388434107, 'JJA': 0.13703542153643897
}
state_names = {0: 'El Niño', 1: 'La Niña', 2: 'Neutral'}

# 1. Classify ENSO States
def classify_enso(oni):
    """Classify ENSO state based on ONI thresholds."""
    if oni >= 0.5:
        return 0  # El Niño
    elif oni <= -0.5:
        return 1  # La Niña
    else:
        return 2  # Neutral

# 2. Align HMM States
def align_states(model, observations):
    """
    Reassign HMM states based on ONI means to match ENSO phases.
    Returns aligned model and updated state names.
    """
    if model.means_ is None:
        print("Warning: Model has no means_. Skipping alignment.")
        return model, state_names

    # Get ONI means (first column of means_)
    oni_means = model.means_[:, 0]

    # Define desired order: El Niño (highest ONI), La Niña (lowest), Neutral (middle)
    sorted_indices = np.argsort(oni_means)[::-1]  # Descending order
    new_order = np.zeros(3, dtype=int)
    new_order[0] = sorted_indices[0]  # El Niño: highest ONI
    new_order[1] = sorted_indices[2]  # La Niña: lowest ONI
    new_order[2] = sorted_indices[1]  # Neutral: middle ONI

    # Copy parameters to new model
    aligned_model = hmm.GaussianHMM(n_components=model.n_components, covariance_type='full')
    aligned_model.startprob_ = model.startprob_[new_order]
    aligned_model.transmat_ = model.transmat_[new_order][:, new_order]
    aligned_model.means_ = model.means_[new_order]
    aligned_model.covars_ = model.covars_[new_order]

    # Regularize covariances to prevent singularity
    for i in range(model.n_components):
        aligned_model.covars_[i] += np.eye(2) * 1e-6

    # Set n_features based on observations
    aligned_model.n_features = observations.shape[1]

    # Fit the aligned model to ensure internal consistency
    try:
        aligned_model.fit(observations)
    except Exception as e:
        print(f"Warning: Failed to fit aligned model: {e}")

    # Verify ONI means
    aligned_oni_means = aligned_model.means_[:, 0]
    print(f"Aligned ONI Means: El Niño={aligned_oni_means[0]:.4f}, La Niña={aligned_oni_means[1]:.4f}, Neutral={aligned_oni_means[2]:.4f}")

    return aligned_model, state_names

# 3. Train HMM
def train_hmm(observations, n_states=3, max_iter=200):
    """Train HMM with Gaussian emissions, ensuring convergence."""
    print(f"Training HMM with observations shape: {observations.shape}")
    if observations.shape[0] < n_states or observations.shape[1] != 2:
        print("Error: Invalid observations shape or insufficient data.")
        return None, None, None

    model = hmm.GaussianHMM(n_components=n_states, covariance_type='full', n_iter=max_iter, random_state=42)
    try:
        model.fit(observations)
        if not model.monitor_.converged:
            print("Warning: HMM did not converge. Consider increasing max_iter or checking data.")

        # Align states
        model, aligned_state_names = align_states(model, observations)

        # Predict with aligned model
        states = model.predict(observations)
        state_probs = model.predict_proba(observations)
        return model, states, state_probs
    except Exception as e:
        print(f"HMM fitting failed: {e}")
        return None, None, None

# 4. Analyze Transition Matrix
def analyze_transition_matrix(transmat, period):
    """Analyze stability and volatility of transition matrix."""
    stability = np.diag(transmat)
    volatility = 1 - stability
    return {
        'stability': dict(zip(state_names.values(), stability)),
        'volatility': dict(zip(state_names.values(), volatility)),
        'most_likely_transitions': {
            state: state_names[np.argmax(transmat[i, :])]
            for i, state in enumerate(state_names.values())
        }
    }

# 5. Analyze Emission Properties
def analyze_emission_properties(means, covars):
    """Characterize statistical properties of emission distributions."""
    properties = {}
    for i, state in enumerate(state_names.values()):
        properties[state] = {
            'sst_mean': means[i, 0],
            'wwv_mean': means[i, 1],
            'sst_variance': covars[i][0, 0],
            'wwv_variance': covars[i][1, 1],
            'sst_wwv_covariance': covars[i][0, 1]
        }
    return properties

# 6. Change-Point Detection
def detect_change_points(states, years):
    """Detect regime shifts based on state transitions."""
    transitions = np.where(states[:-1] != states[1:])[0] + 1
    change_points = [(years[i], state_names[states[i-1]], state_names[states[i]]) for i in transitions]
    return change_points

# 7. Cluster Years
def cluster_years(observations, states, years, n_clusters=3):
    """Group years with similar emission characteristics."""
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    clusters = kmeans.fit_predict(observations)
    cluster_data = pd.DataFrame({
        'Year': years,
        'Cluster': clusters,
        'State': [state_names[s] for s in states]
    })
    return cluster_data

# 8. Visualization Functions
def plot_time_series(data, period, output_dir):
    """Plot ONI and WWV with true and predicted states."""
    fig, ax1 = plt.subplots(figsize=(14, 7))
    ax1.plot(data['Year'], data[f'ONI_{period}'], label=f'ONI ({period})', color='blue')
    ax1.set_xlabel('Year')
    ax1.set_ylabel('ONI (°C)', color='blue')
    ax1.tick_params(axis='y', labelcolor='blue')

    ax2 = ax1.twinx()
    ax2.plot(data['Year'], data[f'WWV_{period}'], label=f'WWV ({period})', color='red')
    ax2.set_ylabel('WWV Anomaly', color='red')
    ax2.tick_params(axis='y', labelcolor='red')

    # Plot true and predicted states
    for state in range(3):
        mask_true = data['State'] == state
        mask_pred = data['Predicted_State'] == state
        ax1.scatter(data['Year'][mask_true], data[f'ONI_{period}'][mask_true],
                   label=f'True {state_names[state]}', marker='o', s=50, alpha=0.6)
        ax1.scatter(data['Year'][mask_pred], data[f'ONI_{period}'][mask_pred],
                   label=f'Predicted {state_names[state]}', marker='x', s=50, alpha=0.6)

    plt.title(f'HMM Predicted vs True ENSO States for {period}')
    fig.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=5)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'enso_hmm_states_{period}.png'), dpi=300, bbox_inches='tight')
    plt.close()

def plot_posterior_probabilities(years, state_probs, period, output_dir):
    """Plot posterior probabilities of each state."""
    plt.figure(figsize=(12, 4))
    for i, state in enumerate(state_names.values()):
        plt.plot(years, state_probs[:, i], label=state)
    plt.title(f'{period} Posterior State Probabilities')
    plt.xlabel('Year')
    plt.ylabel('Probability')
    plt.legend()
    plt.savefig(os.path.join(output_dir, f'state_probs_{period}.png'), dpi=300)
    plt.close()

def plot_confusion_matrix(cm, period, output_dir):
    """Plot confusion matrix."""
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=state_names.values(), yticklabels=state_names.values())
    plt.title(f'{period} Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig(os.path.join(output_dir, f'confusion_matrix_{period}.png'), dpi=300)
    plt.close()

def plot_emission_distributions(means, covars, period, output_dir):
    """Plot emission distributions for each state."""
    x = np.linspace(-3, 3, 100)
    y = np.linspace(-3, 3, 100)
    X, Y = np.meshgrid(x, y)
    pos = np.dstack((X, Y))

    plt.figure(figsize=(8, 6))
    for i, state in enumerate(state_names.values()):
        try:
            rv = multivariate_normal(mean=means[i], cov=covars[i])
            Z = rv.pdf(pos)
            plt.contour(X, Y, Z, levels=5, label=state)
        except Exception as e:
            print(f"Warning: Failed to plot emission distribution for {state} in {period}: {e}")
    plt.title(f'{period} Emission Distributions')
    plt.xlabel('ONI (°C)')
    plt.ylabel('WWV Anomaly')
    plt.legend(list(state_names.values()), loc='upper right')
    plt.savefig(os.path.join(output_dir, f'emission_dist_{period}.png'), dpi=300)
    plt.close()

# 9. Process Period
def process_period(period, index, anomaly, output_dir):
    """Process HMM for a single period."""
    print(f"\nProcessing period: {period}")

    # Prepare data
    try:
        data = pd.DataFrame({
            'Year': index['Year'],
            f'ONI_{period}': index[period],
            f'WWV_{period}': anomaly[period]
        }).dropna()
    except KeyError as e:
        print(f"Error: Missing column for {period}: {e}")
        return None

    if len(data) < 10:
        print(f"Skipping {period} due to insufficient data")
        return None

    print(f"Data shape for {period}: {data.shape}")

    # Classify true states
    data['State'] = data[f'ONI_{period}'].apply(classify_enso)
    observations = data[[f'ONI_{period}', f'WWV_{period}']].values

    # Train HMM
    model, predicted_states, state_probs = train_hmm(observations)
    if model is None:
        return None

    data['Predicted_State'] = predicted_states
    data['Predicted_State_Name'] = data['Predicted_State'].map(state_names)

    # Evaluate
    accuracy = accuracy_score(data['State'], data['Predicted_State'])
    cm = confusion_matrix(data['State'], data['Predicted_State'], labels=[0, 1, 2])

    # Analyze transition matrix
    trans_analysis = analyze_transition_matrix(model.transmat_, period)

    # Analyze emission properties
    emission_analysis = analyze_emission_properties(model.means_, model.covars_)

    # Detect change points
    change_points = detect_change_points(predicted_states, data['Year'].values)

    # Cluster years
    cluster_data = cluster_years(observations, predicted_states, data['Year'].values)

    # Save outputs
    np.save(os.path.join(output_dir, f'transition_matrix_{period}.npy'), model.transmat_)
    np.save(os.path.join(output_dir, f'initial_probs_{period}.npy'), model.startprob_)
    np.save(os.path.join(output_dir, f'emission_means_{period}.npy'), model.means_)
    np.save(os.path.join(output_dir, f'emission_covars_{period}.npy'), model.covars_)
    data.to_csv(os.path.join(output_dir, f'predictions_{period}.csv'), index=False)
    cluster_data.to_csv(os.path.join(output_dir, f'clusters_{period}.csv'), index=False)

    # Visualizations
    plot_time_series(data, period, output_dir)
    plot_posterior_probabilities(data['Year'], state_probs, period, output_dir)
    plot_confusion_matrix(cm, period, output_dir)
    plot_emission_distributions(model.means_, model.covars_, period, output_dir)

    return {
        'data': data,
        'transition_matrix': model.transmat_,
        'initial_probs': model.startprob_,
        'emission_means': model.means_,
        'emission_covars': model.covars_,
        'accuracy': accuracy,
        'confusion_matrix': cm,
        'transition_analysis': trans_analysis,
        'emission_analysis': emission_analysis,
        'change_points': change_points,
        'cluster_data': cluster_data
    }

# 10. Main Execution
def main(index, anomaly):
    results = {}
    for period in periods:
        result = process_period(period, index, anomaly, output_dir)
        if result:
            results[period] = result

    # Detailed Analysis
    print("\n=== Detailed Analysis ===")

    # Accuracies and correlations
    accuracies = {p: r['accuracy'] for p, r in results.items()}
    print("\nModel Accuracies by Period:")
    for period, acc in accuracies.items():
        corr = corr_coeffs.get(period, 'N/A')
        print(f"{period}: Accuracy = {acc:.4f}, Correlation with NDJ WWV = {corr}")

    best_period = max(accuracies, key=accuracies.get) if accuracies else 'None'
    if best_period != 'None':
        print(f"\nBest Performing Period: {best_period} (Accuracy = {accuracies[best_period]:.4f})")

    # Transition analysis
    print("\nTransition Matrix Analysis:")
    for period, res in results.items():
        print(f"\n{period}:")
        print(f"Stability: {res['transition_analysis']['stability']}")
        print(f"Volatility: {res['transition_analysis']['volatility']}")
        print(f"Most Likely Transitions: {res['transition_analysis']['most_likely_transitions']}")

    # Emission properties
    print("\nEmission Properties Analysis:")
    for period, res in results.items():
        print(f"\n{period}:")
        for state, props in res['emission_analysis'].items():
            print(f"{state}: SST Mean = {props['sst_mean']:.4f}, WWV Mean = {props['wwv_mean']:.4f}, "
                  f"SST Variance = {props['sst_variance']:.4f}, WWV Variance = {props['wwv_variance']:.4f}, "
                  f"Covariance = {props['sst_wwv_covariance']:.4f}")

    # Change points
    print("\nChange-Point Detection:")
    for period, res in results.items():
        print(f"\n{period}:")
        for year, from_state, to_state in res['change_points']:
            print(f"Year {year}: {from_state} -> {to_state}")

    # Clustering
    print("\nClustering Analysis:")
    for period, res in results.items():
        print(f"\n{period}:")
        print(res['cluster_data'].groupby('Cluster').agg({
            'Year': 'count',
            'State': lambda x: x.value_counts().to_dict()
        }).rename(columns={'Year': 'Count', 'State': 'State Distribution'}))

    # Save summary
    with open(os.path.join(output_dir, 'summary.txt'), 'w') as f:
        f.write("=== HMM Summary ===\n")
        f.write("\nModel Accuracies by Period:\n")
        for period, acc in accuracies.items():
            corr = corr_coeffs.get(period, 'N/A')
            f.write(f"{period}: Accuracy = {acc:.4f}, Correlation with NDJ WWV = {corr}\n")
        if best_period != 'None':
            f.write(f"\nBest Performing Period: {best_period} (Accuracy = {accuracies[best_period]:.4f})\n")

        f.write("\nTransition Matrix Analysis:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            f.write(f"Stability: {res['transition_analysis']['stability']}\n")
            f.write(f"Volatility: {res['transition_analysis']['volatility']}\n")
            f.write(f"Most Likely Transitions: {res['transition_analysis']['most_likely_transitions']}\n")

        f.write("\nEmission Properties Analysis:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            for state, props in res['emission_analysis'].items():
                f.write(f"{state}: SST Mean = {props['sst_mean']:.4f}, WWV Mean = {props['wwv_mean']:.4f}, "
                      f"SST Variance = {props['sst_variance']:.4f}, WWV Variance = {props['wwv_variance']:.4f}, "
                      f"Covariance = {props['sst_wwv_covariance']:.4f}\n")

        f.write("\nChange-Point Detection:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            for year, from_state, to_state in res['change_points']:
                f.write(f"Year {year}: {from_state} -> {to_state}\n")

        f.write("\nClustering Analysis:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            f.write(str(res['cluster_data'].groupby('Cluster').agg({
                'Year': 'count',
                'State': lambda x: x.value_counts().to_dict()
            }).rename(columns={'Year': 'Count', 'State': 'State Distribution'})) + "\n")

    # Download outputs
    print("\nDownloading output files...")
    for fname in os.listdir(output_dir):
        files.download(os.path.join(output_dir, fname))

    print("\nSummary saved to 'hmm_outputs/summary.txt'")
    print("Model parameters, predictions, and plots saved in 'hmm_outputs/' directory")
import pandas as pd
import numpy as np
from hmmlearn import hmm
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
import os
from scipy.stats import multivariate_normal
from google.colab import files
import warnings
import pkg_resources

# Suppress warnings for cleaner output, but log critical ones
warnings.filterwarnings('ignore', category=RuntimeWarning)

# Check hmmlearn version
hmmlearn_version = pkg_resources.get_distribution("hmmlearn").version
print(f"Using hmmlearn version: {hmmlearn_version}")
if hmmlearn_version < '0.2.8':
    print("Warning: hmmlearn version is older than 0.2.8. Consider upgrading: !pip install hmmlearn --upgrade")

# Set random seed for reproducibility
np.random.seed(42)

# Create output directory
output_dir = 'hmm_outputs'
os.makedirs(output_dir, exist_ok=True)

# Define periods and correlation coefficients
periods = ['DJF', 'JFM', 'FMA', 'MAM', 'AMJ', 'MJJ', 'JJA', 'JAS', 'ASO', 'SON', 'OND', 'NDJ']
corr_coeffs = {
    'MJJ': -0.02436115975109819, 'OND': 0.4120160038567676, 'SON': 0.35330147648830007,
    'ASO': 0.2759602399103643, 'JAS': 0.18994442388434107, 'JJA': 0.13703542153643897
}
state_names = {0: 'El Niño', 1: 'La Niña', 2: 'Neutral'}

# ... (rest of the code remains unchanged)

Using hmmlearn version: 0.3.3


In [None]:
import pandas as pd
import numpy as np
from hmmlearn import hmm
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
import os
from scipy.stats import multivariate_normal
from google.colab import files
import warnings
import pkg_resources
import traceback

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore', category=RuntimeWarning)

# Check hmmlearn version
hmmlearn_version = pkg_resources.get_distribution("hmmlearn").version
print(f"Using hmmlearn version: {hmmlearn_version}")
if hmmlearn_version < '0.2.8':
    print("Warning: hmmlearn version is older than 0.2.8. Consider upgrading: !pip install hmmlearn --upgrade")

# Set random seed for reproducibility
np.random.seed(42)

# Create output directory
output_dir = 'hmm_outputs'
os.makedirs(output_dir, exist_ok=True)

# Define periods and correlation coefficients
periods = ['DJF', 'JFM', 'FMA', 'MAM', 'AMJ', 'MJJ', 'JJA', 'JAS', 'ASO', 'SON', 'OND', 'NDJ']
corr_coeffs = {
    'MJJ': -0.02436115975109819, 'OND': 0.4120160038567676, 'SON': 0.35330147648830007,
    'ASO': 0.2759602399103643, 'JAS': 0.18994442388434107, 'JJA': 0.13703542153643897
}
state_names = {0: 'El Niño', 1: 'La Niña', 2: 'Neutral'}

# 1. Classify ENSO States
def classify_enso(oni):
    """Classify ENSO state based on ONI thresholds."""
    if oni >= 0.5:
        return 0  # El Niño
    elif oni <= -0.5:
        return 1  # La Niña
    else:
        return 2  # Neutral

# 2. Align HMM States
def align_states(model, observations):
    """Reassign HMM states based on ONI means to match ENSO phases."""
    print("Aligning states...")
    if model.means_ is None:
        print("Warning: Model has no means_. Skipping alignment.")
        return model, state_names

    # Get ONI means (first column of means_)
    oni_means = model.means_[:, 0]

    # Define desired order: El Niño (highest ONI), La Niña (lowest), Neutral (middle)
    sorted_indices = np.argsort(oni_means)[::-1]  # Descending order
    new_order = np.zeros(3, dtype=int)
    new_order[0] = sorted_indices[0]  # El Niño: highest ONI
    new_order[1] = sorted_indices[2]  # La Niña: lowest ONI
    new_order[2] = sorted_indices[1]  # Neutral: middle ONI

    # Create new model and copy parameters
    aligned_model = hmm.GaussianHMM(n_components=model.n_components, covariance_type='full', n_iter=100, random_state=42)
    aligned_model.startprob_ = model.startprob_[new_order]
    aligned_model.transmat_ = model.transmat_[new_order][:, new_order]
    aligned_model.means_ = model.means_[new_order]
    aligned_model.covars_ = model.covars_[new_order]

    # Regularize covariances to prevent singularity
    for i in range(model.n_components):
        aligned_model.covars_[i] += np.eye(2) * 1e-6

    # Re-fit to ensure internal consistency
    try:
        aligned_model.fit(observations)
        print("Aligned model fitted successfully.")
    except Exception as e:
        print(f"Warning: Failed to fit aligned model: {e}")
        return model, state_names

    # Verify ONI means
    aligned_oni_means = aligned_model.means_[:, 0]
    print(f"Aligned ONI Means: El Niño={aligned_oni_means[0]:.4f}, La Niña={aligned_oni_means[1]:.4f}, Neutral={aligned_oni_means[2]:.4f}")

    return aligned_model, state_names

# 3. Train HMM
def train_hmm(observations, n_states=3, n_iter=100):
    """Train HMM with Gaussian emissions."""
    print(f"Training HMM with observations shape: {observations.shape}")
    if observations.shape[0] < n_states or observations.shape[1] != 2:
        print("Error: Invalid observations shape or insufficient data.")
        return None, None, None

    model = hmm.GaussianHMM(n_components=n_states, covariance_type='full', n_iter=n_iter, random_state=42)
    try:
        model.fit(observations)
        if not model.monitor_.converged:
            print("Warning: HMM did not converge. Consider increasing n_iter or checking data.")

        # Align states
        model, aligned_state_names = align_states(model, observations)

        # Predict with aligned model
        states = model.predict(observations)
        state_probs = model.predict_proba(observations)
        return model, states, state_probs
    except Exception as e:
        print(f"HMM fitting failed: {e}")
        return None, None, None

# 4. Analyze Transition Matrix
def analyze_transition_matrix(transmat, period):
    """Analyze stability and volatility of transition matrix."""
    stability = np.diag(transmat)
    volatility = 1 - stability
    return {
        'stability': dict(zip(state_names.values(), stability)),
        'volatility': dict(zip(state_names.values(), volatility)),
        'most_likely_transitions': {
            state: state_names[np.argmax(transmat[i, :])]
            for i, state in enumerate(state_names.values())
        }
    }

# 5. Analyze Emission Properties
def analyze_emission_properties(means, covars):
    """Characterize statistical properties of emission distributions."""
    properties = {}
    for i, state in enumerate(state_names.values()):
        properties[state] = {
            'sst_mean': means[i, 0],
            'wwv_mean': means[i, 1],
            'sst_variance': covars[i][0, 0],
            'wwv_variance': covars[i][1, 1],
            'sst_wwv_covariance': covars[i][0, 1]
        }
    return properties

# 6. Change-Point Detection
def detect_change_points(states, years):
    """Detect regime shifts based on state transitions."""
    transitions = np.where(states[:-1] != states[1:])[0] + 1
    change_points = [(years[i], state_names[states[i-1]], state_names[states[i]]) for i in transitions]
    return change_points

# 7. Cluster Years
def cluster_years(observations, states, years, n_clusters=3):
    """Group years with similar emission characteristics."""
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    clusters = kmeans.fit_predict(observations)
    cluster_data = pd.DataFrame({
        'Year': years,
        'Cluster': clusters,
        'State': [state_names[s] for s in states]
    })
    return cluster_data

# 8. Visualization Functions
def plot_time_series(data, period, output_dir):
    """Plot ONI and WWV with true and predicted states."""
    fig, ax1 = plt.subplots(figsize=(14, 7))
    ax1.plot(data['Year'], data[f'ONI_{period}'], label=f'ONI ({period})', color='blue')
    ax1.set_xlabel('Year')
    ax1.set_ylabel('ONI (°C)', color='blue')
    ax1.tick_params(axis='y', labelcolor='blue')

    ax2 = ax1.twinx()
    ax2.plot(data['Year'], data[f'WWV_{period}'], label=f'WWV ({period})', color='red')
    ax2.set_ylabel('WWV Anomaly', color='red')
    ax2.tick_params(axis='y', labelcolor='red')

    for state in range(3):
        mask_true = data['State'] == state
        mask_pred = data['Predicted_State'] == state
        ax1.scatter(data['Year'][mask_true], data[f'ONI_{period}'][mask_true],
                   label=f'True {state_names[state]}', marker='o', s=50, alpha=0.6)
        ax1.scatter(data['Year'][mask_pred], data[f'ONI_{period}'][mask_pred],
                   label=f'Predicted {state_names[state]}', marker='x', s=50, alpha=0.6)

    plt.title(f'HMM Predicted vs True ENSO States for {period}')
    fig.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=5)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'enso_hmm_states_{period}.png'), dpi=300, bbox_inches='tight')
    plt.close()

def plot_posterior_probabilities(years, state_probs, period, output_dir):
    """Plot posterior probabilities of each state."""
    plt.figure(figsize=(12, 4))
    for i, state in enumerate(state_names.values()):
        plt.plot(years, state_probs[:, i], label=state)
    plt.title(f'{period} Posterior State Probabilities')
    plt.xlabel('Year')
    plt.ylabel('Probability')
    plt.legend()
    plt.savefig(os.path.join(output_dir, f'state_probs_{period}.png'), dpi=300)
    plt.close()

def plot_confusion_matrix(cm, period, output_dir):
    """Plot confusion matrix."""
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=state_names.values(), yticklabels=state_names.values())
    plt.title(f'{period} Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig(os.path.join(output_dir, f'confusion_matrix_{period}.png'), dpi=300)
    plt.close()

def plot_emission_distributions(means, covars, period, output_dir):
    """Plot emission distributions for each state."""
    x = np.linspace(-3, 3, 100)
    y = np.linspace(-3, 3, 100)
    X, Y = np.meshgrid(x, y)
    pos = np.dstack((X, Y))

    plt.figure(figsize=(8, 6))
    for i, state in enumerate(state_names.values()):
        try:
            rv = multivariate_normal(mean=means[i], cov=covars[i])
            Z = rv.pdf(pos)
            plt.contour(X, Y, Z, levels=5, label=state)
        except Exception as e:
            print(f"Warning: Failed to plot emission distribution for {state} in {period}: {e}")
    plt.title(f'{period} Emission Distributions')
    plt.xlabel('ONI (°C)')
    plt.ylabel('WWV Anomaly')
    plt.legend(list(state_names.values()), loc='upper right')
    plt.savefig(os.path.join(output_dir, f'emission_dist_{period}.png'), dpi=300)
    plt.close()

# 9. Process Period
def process_period(period, index, anomaly, output_dir):
    """Process HMM for a single period."""
    print(f"\nProcessing period: {period}")

    # Prepare data
    try:
        data = pd.DataFrame({
            'Year': index['Year'],
            f'ONI_{period}': index[period],
            f'WWV_{period}': anomaly[period]
        }).dropna()
        print(f"Data shape for {period}: {data.shape}")
    except KeyError as e:
        print(f"Error: Missing column for {period}: {e}")
        return None
    except Exception as e:
        print(f"Error preparing data for {period}: {e}")
        return None

    if len(data) < 10:
        print(f"Skipping {period} due to insufficient data")
        return None

    # Classify true states
    data['State'] = data[f'ONI_{period}'].apply(classify_enso)
    observations = data[[f'ONI_{period}', f'WWV_{period}']].values
    print(f"Observations shape for {period}: {observations.shape}")

    # Train HMM
    model, predicted_states, state_probs = train_hmm(observations)
    if model is None:
        print(f"Skipping {period} due to HMM fitting failure")
        return None

    data['Predicted_State'] = predicted_states
    data['Predicted_State_Name'] = data['Predicted_State'].map(state_names)

    # Evaluate
    accuracy = accuracy_score(data['State'], data['Predicted_State'])
    cm = confusion_matrix(data['State'], data['Predicted_State'], labels=[0, 1, 2])
    print(f"Accuracy for {period}: {accuracy:.4f}")

    # Analyze transition matrix
    trans_analysis = analyze_transition_matrix(model.transmat_, period)

    # Analyze emission properties
    emission_analysis = analyze_emission_properties(model.means_, model.covars_)

    # Detect change points
    change_points = detect_change_points(predicted_states, data['Year'].values)

    # Cluster years
    cluster_data = cluster_years(observations, predicted_states, data['Year'].values)

    # Save outputs
    np.save(os.path.join(output_dir, f'transition_matrix_{period}.npy'), model.transmat_)
    np.save(os.path.join(output_dir, f'initial_probs_{period}.npy'), model.startprob_)
    np.save(os.path.join(output_dir, f'emission_means_{period}.npy'), model.means_)
    np.save(os.path.join(output_dir, f'emission_covars_{period}.npy'), model.covars_)
    data.to_csv(os.path.join(output_dir, f'predictions_{period}.csv'), index=False)
    cluster_data.to_csv(os.path.join(output_dir, f'clusters_{period}.csv'), index=False)

    # Visualizations
    plot_time_series(data, period, output_dir)
    plot_posterior_probabilities(data['Year'], state_probs, period, output_dir)
    plot_confusion_matrix(cm, period, output_dir)
    plot_emission_distributions(model.means_, model.covars_, period, output_dir)

    return {
        'data': data,
        'transition_matrix': model.transmat_,
        'initial_probs': model.startprob_,
        'emission_means': model.means_,
        'emission_covars': model.covars_,
        'accuracy': accuracy,
        'confusion_matrix': cm,
        'transition_analysis': trans_analysis,
        'emission_analysis': emission_analysis,
        'change_points': change_points,
        'cluster_data': cluster_data
    }

# 10. Main Execution
def main(index, anomaly):
    print("Entering main function")
    print("Index columns:", list(index.columns))
    print("Anomaly columns:", list(anomaly.columns))
    print("Index shape:", index.shape)
    print("Anomaly shape:", anomaly.shape)
    print("Missing values in index:\n", index.isna().sum())
    print("Missing values in anomaly:\n", anomaly.isna().sum())

    results = {}
    for period in periods:
        try:
            result = process_period(period, index, anomaly, output_dir)
            if result:
                results[period] = result
        except Exception as e:
            print(f"Error processing period {period}: {e}")
            print(traceback.format_exc())

    # Detailed Analysis
    print("\n=== Detailed Analysis ===")

    # Accuracies and correlations
    accuracies = {p: r['accuracy'] for p, r in results.items()}
    print("\nModel Accuracies by Period:")
    for period, acc in accuracies.items():
        corr = corr_coeffs.get(period, 'N/A')
        print(f"{period}: Accuracy = {acc:.4f}, Correlation with NDJ WWV = {corr}")

    best_period = max(accuracies, key=accuracies.get) if accuracies else 'None'
    if best_period != 'None':
        print(f"\nBest Performing Period: {best_period} (Accuracy = {accuracies[best_period]:.4f})")

    # Transition analysis
    print("\nTransition Matrix Analysis:")
    for period, res in results.items():
        print(f"\n{period}:")
        print(f"Stability: {res['transition_analysis']['stability']}")
        print(f"Volatility: {res['transition_analysis']['volatility']}")
        print(f"Most Likely Transitions: {res['transition_analysis']['most_likely_transitions']}")

    # Emission properties
    print("\nEmission Properties Analysis:")
    for period, res in results.items():
        print(f"\n{period}:")
        for state, props in res['emission_analysis'].items():
            print(f"{state}: SST Mean = {props['sst_mean']:.4f}, WWV Mean = {props['wwv_mean']:.4f}, "
                  f"SST Variance = {props['sst_variance']:.4f}, WWV Variance = {props['wwv_variance']:.4f}, "
                  f"Covariance = {props['sst_wwv_covariance']:.4f}")

    # Change points
    print("\nChange-Point Detection:")
    for period, res in results.items():
        print(f"\n{period}:")
        for year, from_state, to_state in res['change_points']:
            print(f"Year {year}: {from_state} -> {to_state}")

    # Clustering
    print("\nClustering Analysis:")
    for period, res in results.items():
        print(f"\n{period}:")
        print(res['cluster_data'].groupby('Cluster').agg({
            'Year': 'count',
            'State': lambda x: x.value_counts().to_dict()
        }).rename(columns={'Year': 'Count', 'State': 'State Distribution'}))

    # Confusion matrices
    print("\nConfusion Matrix Analysis:")
    for period in results:
        cm = results[period]['confusion_matrix']
        print(f"\n{period}:")
        print("Rows: Actual (El Niño, La Niña, Neutral)")
        print("Columns: Predicted (El Niño, La Niña, Neutral)")
        print(cm)
        precision = np.diag(cm) / np.sum(cm, axis=0)
        print(f"Precision (El Niño, La Niña, Neutral): {[f'{p:.4f}' if not np.isnan(p) else 'N/A' for p in precision]}")

    # Correlation vs. accuracy
    print("\nCorrelation vs. Accuracy Analysis:")
    for period in corr_coeffs:
        acc = accuracies.get(period, 'N/A')
        print(f"{period}: Correlation = {corr_coeffs[period]:.4f}, Accuracy = {acc if acc != 'N/A' else 'N/A'}")

    # Save summary
    with open(os.path.join(output_dir, 'summary.txt'), 'w') as f:
        f.write("=== HMM Summary ===\n")
        f.write("\nModel Accuracies by Period:\n")
        for period, acc in accuracies.items():
            corr = corr_coeffs.get(period, 'N/A')
            f.write(f"{period}: Accuracy = {acc:.4f}, Correlation with NDJ WWV = {corr}\n")
        if best_period != 'None':
            f.write(f"\nBest Performing Period: {best_period} (Accuracy = {accuracies[best_period]:.4f})\n")

        f.write("\nTransition Matrix Analysis:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            f.write(f"Stability: {res['transition_analysis']['stability']}\n")
            f.write(f"Volatility: {res['transition_analysis']['volatility']}\n")
            f.write(f"Most Likely Transitions: {res['transition_analysis']['most_likely_transitions']}\n")

        f.write("\nEmission Properties Analysis:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            for state, props in res['emission_analysis'].items():
                f.write(f"{state}: SST Mean = {props['sst_mean']:.4f}, WWV Mean = {props['wwv_mean']:.4f}, "
                      f"SST Variance = {props['sst_variance']:.4f}, WWV Variance = {props['wwv_variance']:.4f}, "
                      f"Covariance = {props['sst_wwv_covariance']:.4f}\n")

        f.write("\nChange-Point Detection:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            for year, from_state, to_state in res['change_points']:
                f.write(f"Year {year}: {from_state} -> {to_state}\n")

        f.write("\nClustering Analysis:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            f.write(str(res['cluster_data'].groupby('Cluster').agg({
                'Year': 'count',
                'State': lambda x: x.value_counts().to_dict()
            }).rename(columns={'Year': 'Count', 'State': 'State Distribution'})) + "\n")

        f.write("\nConfusion Matrix Analysis:\n")
        for period in results:
            cm = results[period]['confusion_matrix']
            f.write(f"\n{period}:\n")
            f.write("Rows: Actual (El Niño, La Niña, Neutral)\n")
            f.write("Columns: Predicted (El Niño, La Niña, Neutral)\n")
            f.write(f"{cm}\n")
            precision = np.diag(cm) / np.sum(cm, axis=0)
            f.write(f"Precision (El Niño, La Niña, Neutral): {[f'{p:.4f}' if not np.isnan(p) else 'N/A' for p in precision]}\n")

    # Download outputs
    print("\nDownloading output files...")
    for fname in os.listdir(output_dir):
        files.download(os.path.join(output_dir, fname))

    print("\nSummary saved to 'hmm_outputs/summary.txt'")
    print("Model parameters, predictions, and plots saved in 'hmm_outputs/' directory")

if __name__ == '__main__':
    try:
        main(index, anomaly)
    except Exception as e:
        print(f"Error in main execution: {e}")
        print(traceback.format_exc())

Using hmmlearn version: 0.3.3
Entering main function
Index columns: ['Year', 'DJF', 'JFM', 'FMA', 'MAM', 'AMJ', 'MJJ', 'JJA', 'JAS', 'ASO', 'SON', 'OND', 'NDJ']
Anomaly columns: ['Year', 'DJF', 'JFM', 'FMA', 'MAM', 'AMJ', 'MJJ', 'JJA', 'JAS', 'ASO', 'SON', 'OND', 'NDJ']
Index shape: (44, 13)
Anomaly shape: (44, 13)
Missing values in index:
 Year    0
DJF     0
JFM     0
FMA     0
MAM     0
AMJ     1
MJJ     1
JJA     1
JAS     1
ASO     1
SON     1
OND     1
NDJ     1
dtype: int64
Missing values in anomaly:
 Year    0
DJF     1
JFM     0
FMA     0
MAM     1
AMJ     1
MJJ     1
JJA     1
JAS     1
ASO     1
SON     1
OND     1
NDJ     1
dtype: int64

Processing period: DJF
Data shape for DJF: (43, 3)
Observations shape for DJF: (43, 2)
Training HMM with observations shape: (43, 2)
Aligning states...
HMM fitting failed: 'GaussianHMM' object has no attribute 'n_features'
Skipping DJF due to HMM fitting failure

Processing period: JFM
Data shape for JFM: (44, 3)
Observations shape for JFM:

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


Summary saved to 'hmm_outputs/summary.txt'
Model parameters, predictions, and plots saved in 'hmm_outputs/' directory


In [None]:

import pandas as pd
import numpy as np
from hmmlearn import hmm
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
import os
from scipy.stats import multivariate_normal
from google.colab import files
import warnings
import pkg_resources
import traceback

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore', category=RuntimeWarning)

# Check hmmlearn version
hmmlearn_version = pkg_resources.get_distribution("hmmlearn").version
print(f"Using hmmlearn version: {hmmlearn_version}")
if hmmlearn_version < '0.2.8':
    print("Warning: hmmlearn version is older than 0.2.8. Consider upgrading: !pip install hmmlearn --upgrade")

# Set random seed for reproducibility
np.random.seed(42)

# Create output directory
output_dir = 'hmm_outputs'
os.makedirs(output_dir, exist_ok=True)

# Define periods and correlation coefficients
periods = ['DJF', 'JFM', 'FMA', 'MAM', 'AMJ', 'MJJ', 'JJA', 'JAS', 'ASO', 'SON', 'OND', 'NDJ']
corr_coeffs = {
    'MJJ': -0.02436115975109819, 'OND': 0.4120160038567676, 'SON': 0.35330147648830007,
    'ASO': 0.2759602399103643, 'JAS': 0.18994442388434107, 'JJA': 0.13703542153643897
}
state_names = {0: 'El Niño', 1: 'La Niña', 2: 'Neutral'}

# 1. Classify ENSO States
def classify_enso(oni):
    """Classify ENSO state based on ONI thresholds."""
    if oni >= 0.5:
        return 0  # El Niño
    elif oni <= -0.5:
        return 1  # La Niña
    else:
        return 2  # Neutral

# 2. Align HMM States
def align_states(model, observations):
    """Reassign HMM states based on ONI means without creating a new model."""
    print("Aligning states...")
    if model.means_ is None:
        print("Warning: Model has no means_. Skipping alignment.")
        return model, state_names, np.arange(3)  # No permutation

    # Get ONI means (first column of means_)
    oni_means = model.means_[:, 0]
    print(f"Original ONI Means: {oni_means}")

    # Define desired order: El Niño (highest ONI), La Niña (lowest), Neutral (middle)
    sorted_indices = np.argsort(oni_means)[::-1]  # Descending order
    new_order = np.zeros(3, dtype=int)
    new_order[0] = sorted_indices[0]  # El Niño: highest ONI
    new_order[1] = sorted_indices[2]  # La Niña: lowest ONI
    new_order[2] = sorted_indices[1]  # Neutral: middle ONI

    # Create permutation mapping
    permutation = np.zeros(3, dtype=int)
    for i, idx in enumerate(new_order):
        permutation[idx] = i

    # Permute model parameters
    model.startprob_ = model.startprob_[new_order]
    model.transmat_ = model.transmat_[new_order][:, new_order]
    model.means_ = model.means_[new_order]
    model.covars_ = model.covars_[new_order]

    # Regularize covariances to prevent singularity
    for i in range(model.n_components):
        model.covars_[i] += np.eye(2) * 1e-6

    # Verify ONI means
    aligned_oni_means = model.means_[:, 0]
    print(f"Aligned ONI Means: El Niño={aligned_oni_means[0]:.4f}, La Niña={aligned_oni_means[1]:.4f}, Neutral={aligned_oni_means[2]:.4f}")

    return model, state_names, permutation

# 3. Train HMM
def train_hmm(observations, n_states=3, n_iter=100):
    """Train HMM with Gaussian emissions."""
    print(f"Training HMM with observations shape: {observations.shape}")
    if observations.shape[0] < n_states or observations.shape[1] != 2:
        print("Error: Invalid observations shape or insufficient data.")
        return None, None, None

    model = hmm.GaussianHMM(n_components=n_states, covariance_type='full', n_iter=n_iter, random_state=42)
    try:
        model.fit(observations)
        if not model.monitor_.converged:
            print("Warning: HMM did not converge. Consider increasing n_iter or checking data.")

        # Align states
        model, aligned_state_names, permutation = align_states(model, observations)

        # Predict with aligned model and permute states
        states = model.predict(observations)
        state_probs = model.predict_proba(observations)
        # Apply permutation to predicted states and probabilities
        aligned_states = np.array([permutation[state] for state in states])
        aligned_state_probs = state_probs[:, new_order]

        return model, aligned_states, aligned_state_probs
    except Exception as e:
        print(f"HMM fitting failed: {e}")
        return None, None, None

# 4. Analyze Transition Matrix
def analyze_transition_matrix(transmat, period):
    """Analyze stability and volatility of transition matrix."""
    stability = np.diag(transmat)
    volatility = 1 - stability
    return {
        'stability': dict(zip(state_names.values(), stability)),
        'volatility': dict(zip(state_names.values(), volatility)),
        'most_likely_transitions': {
            state: state_names[np.argmax(transmat[i, :])]
            for i, state in enumerate(state_names.values())
        }
    }

# 5. Analyze Emission Properties
def analyze_emission_properties(means, covars):
    """Characterize statistical properties of emission distributions."""
    properties = {}
    for i, state in enumerate(state_names.values()):
        properties[state] = {
            'sst_mean': means[i, 0],
            'wwv_mean': means[i, 1],
            'sst_variance': covars[i][0, 0],
            'wwv_variance': covars[i][1, 1],
            'sst_wwv_covariance': covars[i][0, 1]
        }
    return properties

# 6. Change-Point Detection
def detect_change_points(states, years):
    """Detect regime shifts based on state transitions."""
    transitions = np.where(states[:-1] != states[1:])[0] + 1
    change_points = [(years[i], state_names[states[i-1]], state_names[states[i]]) for i in transitions]
    return change_points

# 7. Cluster Years
def cluster_years(observations, states, years, n_clusters=3):
    """Group years with similar emission characteristics."""
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    clusters = kmeans.fit_predict(observations)
    cluster_data = pd.DataFrame({
        'Year': years,
        'Cluster': clusters,
        'State': [state_names[s] for s in states]
    })
    return cluster_data

# 8. Visualization Functions
def plot_time_series(data, period, output_dir):
    """Plot ONI and WWV with true and predicted states."""
    fig, ax1 = plt.subplots(figsize=(14, 7))
    ax1.plot(data['Year'], data[f'ONI_{period}'], label=f'ONI ({period})', color='blue')
    ax1.set_xlabel('Year')
    ax1.set_ylabel('ONI (°C)', color='blue')
    ax1.tick_params(axis='y', labelcolor='blue')

    ax2 = ax1.twinx()
    ax2.plot(data['Year'], data[f'WWV_{period}'], label=f'WWV ({period})', color='red')
    ax2.set_ylabel('WWV Anomaly', color='red')
    ax2.tick_params(axis='y', labelcolor='red')

    for state in range(3):
        mask_true = data['State'] == state
        mask_pred = data['Predicted_State'] == state
        ax1.scatter(data['Year'][mask_true], data[f'ONI_{period}'][mask_true],
                   label=f'True {state_names[state]}', marker='o', s=50, alpha=0.6)
        ax1.scatter(data['Year'][mask_pred], data[f'ONI_{period}'][mask_pred],
                   label=f'Predicted {state_names[state]}', marker='x', s=50, alpha=0.6)

    plt.title(f'HMM Predicted vs True ENSO States for {period}')
    fig.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=5)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'enso_hmm_states_{period}.png'), dpi=300, bbox_inches='tight')
    plt.close()

def plot_posterior_probabilities(years, state_probs, period, output_dir):
    """Plot posterior probabilities of each state."""
    plt.figure(figsize=(12, 4))
    for i, state in enumerate(state_names.values()):
        plt.plot(years, state_probs[:, i], label=state)
    plt.title(f'{period} Posterior State Probabilities')
    plt.xlabel('Year')
    plt.ylabel('Probability')
    plt.legend()
    plt.savefig(os.path.join(output_dir, f'state_probs_{period}.png'), dpi=300)
    plt.close()

def plot_confusion_matrix(cm, period, output_dir):
    """Plot confusion matrix."""
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=state_names.values(), yticklabels=state_names.values())
    plt.title(f'{period} Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig(os.path.join(output_dir, f'confusion_matrix_{period}.png'), dpi=300)
    plt.close()

def plot_emission_distributions(means, covars, period, output_dir):
    """Plot emission distributions for each state."""
    x = np.linspace(-3, 3, 100)
    y = np.linspace(-3, 3, 100)
    X, Y = np.meshgrid(x, y)
    pos = np.dstack((X, Y))

    plt.figure(figsize=(8, 6))
    for i, state in enumerate(state_names.values()):
        try:
            rv = multivariate_normal(mean=means[i], cov=covars[i])
            Z = rv.pdf(pos)
            plt.contour(X, Y, Z, levels=5, label=state)
        except Exception as e:
            print(f"Warning: Failed to plot emission distribution for {state} in {period}: {e}")
    plt.title(f'{period} Emission Distributions')
    plt.xlabel('ONI (°C)')
    plt.ylabel('WWV Anomaly')
    plt.legend(list(state_names.values()), loc='upper right')
    plt.savefig(os.path.join(output_dir, f'emission_dist_{period}.png'), dpi=300)
    plt.close()

# 9. Process Period
def process_period(period, index, anomaly, output_dir):
    """Process HMM for a single period."""
    print(f"\nProcessing period: {period}")

    # Prepare data
    try:
        data = pd.DataFrame({
            'Year': index['Year'],
            f'ONI_{period}': index[period],
            f'WWV_{period}': anomaly[period]
        }).dropna()
        print(f"Data shape for {period}: {data.shape}")
    except KeyError as e:
        print(f"Error: Missing column for {period}: {e}")
        return None
    except Exception as e:
        print(f"Error preparing data for {period}: {e}")
        return None

    if len(data) < 10:
        print(f"Skipping {period} due to insufficient data")
        return None

    # Classify true states
    data['State'] = data[f'ONI_{period}'].apply(classify_enso)
    observations = data[[f'ONI_{period}', f'WWV_{period}']].values
    print(f"Observations shape for {period}: {observations.shape}")

    # Train HMM
    model, predicted_states, state_probs = train_hmm(observations)
    if model is None:
        print(f"Skipping {period} due to HMM fitting failure")
        return None

    data['Predicted_State'] = predicted_states
    data['Predicted_State_Name'] = data['Predicted_State'].map(state_names)

    # Evaluate
    accuracy = accuracy_score(data['State'], data['Predicted_State'])
    cm = confusion_matrix(data['State'], data['Predicted_State'], labels=[0, 1, 2])
    print(f"Accuracy for {period}: {accuracy:.4f}")

    # Analyze transition matrix
    trans_analysis = analyze_transition_matrix(model.transmat_, period)

    # Analyze emission properties
    emission_analysis = analyze_emission_properties(model.means_, model.covars_)

    # Detect change points
    change_points = detect_change_points(predicted_states, data['Year'].values)

    # Cluster years
    cluster_data = cluster_years(observations, predicted_states, data['Year'].values)

    # Save outputs
    np.save(os.path.join(output_dir, f'transition_matrix_{period}.npy'), model.transmat_)
    np.save(os.path.join(output_dir, f'initial_probs_{period}.npy'), model.startprob_)
    np.save(os.path.join(output_dir, f'emission_means_{period}.npy'), model.means_)
    np.save(os.path.join(output_dir, f'emission_covars_{period}.npy'), model.covars_)
    data.to_csv(os.path.join(output_dir, f'predictions_{period}.csv'), index=False)
    cluster_data.to_csv(os.path.join(output_dir, f'clusters_{period}.csv'), index=False)

    # Visualizations
    plot_time_series(data, period, output_dir)
    plot_posterior_probabilities(data['Year'], state_probs, period, output_dir)
    plot_confusion_matrix(cm, period, output_dir)
    plot_emission_distributions(model.means_, model.covars_, period, output_dir)

    return {
        'data': data,
        'transition_matrix': model.transmat_,
        'initial_probs': model.startprob_,
        'emission_means': model.means_,
        'emission_covars': model.covars_,
        'accuracy': accuracy,
        'confusion_matrix': cm,
        'transition_analysis': trans_analysis,
        'emission_analysis': emission_analysis,
        'change_points': change_points,
        'cluster_data': cluster_data
    }

# 10. Main Execution
def main(index, anomaly):
    print("Entering main function")
    print("Index columns:", list(index.columns))
    print("Anomaly columns:", list(anomaly.columns))
    print("Index shape:", index.shape)
    print("Anomaly shape:", anomaly.shape)
    print("Missing values in index:\n", index.isna().sum())
    print("Missing values in anomaly:\n", anomaly.isna().sum())

    # Remove 'Year' from anomaly if present
    if 'Year' in anomaly.columns:
        anomaly = anomaly.drop(columns=['Year'])
        print("Removed 'Year' from anomaly. New columns:", list(anomaly.columns))

    # Impute missing values
    index = index.fillna(0)
    anomaly = anomaly.fillna(0)
    print("Imputed missing values. New missing values in index:\n", index.isna().sum())
    print("New missing values in anomaly:\n", anomaly.isna().sum())

    results = {}
    for period in periods:
        try:
            result = process_period(period, index, anomaly, output_dir)
            if result:
                results[period] = result
        except Exception as e:
            print(f"Error processing period {period}: {e}")
            print(traceback.format_exc())

    # Detailed Analysis
    print("\n=== Detailed Analysis ===")

    # Accuracies and correlations
    accuracies = {p: r['accuracy'] for p, r in results.items()}
    print("\nModel Accuracies by Period:")
    for period, acc in accuracies.items():
        corr = corr_coeffs.get(period, 'N/A')
        print(f"{period}: Accuracy = {acc:.4f}, Correlation with NDJ WWV = {corr}")

    best_period = max(accuracies, key=accuracies.get) if accuracies else 'None'
    if best_period != 'None':
        print(f"\nBest Performing Period: {best_period} (Accuracy = {accuracies[best_period]:.4f})")

    # Transition analysis
    print("\nTransition Matrix Analysis:")
    for period, res in results.items():
        print(f"\n{period}:")
        print(f"Stability: {res['transition_analysis']['stability']}")
        print(f"Volatility: {res['transition_analysis']['volatility']}")
        print(f"Most Likely Transitions: {res['transition_analysis']['most_likely_transitions']}")

    # Emission properties
    print("\nEmission Properties Analysis:")
    for period, res in results.items():
        print(f"\n{period}:")
        for state, props in res['emission_analysis'].items():
            print(f"{state}: SST Mean = {props['sst_mean']:.4f}, WWV Mean = {props['wwv_mean']:.4f}, "
                  f"SST Variance = {props['sst_variance']:.4f}, WWV Variance = {props['wwv_variance']:.4f}, "
                  f"Covariance = {props['sst_wwv_covariance']:.4f}")

    # Change points
    print("\nChange-Point Detection:")
    for period, res in results.items():
        print(f"\n{period}:")
        for year, from_state, to_state in res['change_points']:
            print(f"Year {year}: {from_state} -> {to_state}")

    # Clustering
    print("\nClustering Analysis:")
    for period, res in results.items():
        print(f"\n{period}:")
        print(res['cluster_data'].groupby('Cluster').agg({
            'Year': 'count',
            'State': lambda x: x.value_counts().to_dict()
        }).rename(columns={'Year': 'Count', 'State': 'State Distribution'}))

    # Confusion matrices
    print("\nConfusion Matrix Analysis:")
    for period in results:
        cm = results[period]['confusion_matrix']
        print(f"\n{period}:")
        print("Rows: Actual (El Niño, La Niña, Neutral)")
        print("Columns: Predicted (El Niño, La Niña, Neutral)")
        print(cm)
        precision = np.diag(cm) / np.sum(cm, axis=0)
        print(f"Precision (El Niño, La Niña, Neutral): {[f'{p:.4f}' if not np.isnan(p) else 'N/A' for p in precision]}")

    # Correlation vs. accuracy
    print("\nCorrelation vs. Accuracy Analysis:")
    for period in corr_coeffs:
        acc = accuracies.get(period, 'N/A')
        print(f"{period}: Correlation = {corr_coeffs[period]:.4f}, Accuracy = {acc if acc != 'N/A' else 'N/A'}")

    # Save summary
    with open(os.path.join(output_dir, 'summary.txt'), 'w') as f:
        f.write("=== HMM Summary ===\n")
        f.write("\nModel Accuracies by Period:\n")
        for period, acc in accuracies.items():
            corr = corr_coeffs.get(period, 'N/A')
            f.write(f"{period}: Accuracy = {acc:.4f}, Correlation with NDJ WWV = {corr}\n")
        if best_period != 'None':
            f.write(f"\nBest Performing Period: {best_period} (Accuracy = {accuracies[best_period]:.4f})\n")

        f.write("\nTransition Matrix Analysis:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            f.write(f"Stability: {res['transition_analysis']['stability']}\n")
            f.write(f"Volatility: {res['transition_analysis']['volatility']}\n")
            f.write(f"Most Likely Transitions: {res['transition_analysis']['most_likely_transitions']}\n")

        f.write("\nEmission Properties Analysis:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            for state, props in res['emission_analysis'].items():
                f.write(f"{state}: SST Mean = {props['sst_mean']:.4f}, WWV Mean = {props['wwv_mean']:.4f}, "
                      f"SST Variance = {props['sst_variance']:.4f}, WWV Variance = {props['wwv_variance']:.4f}, "
                      f"Covariance = {props['sst_wwv_covariance']:.4f}\n")

        f.write("\nChange-Point Detection:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            for year, from_state, to_state in res['change_points']:
                f.write(f"Year {year}: {from_state} -> {to_state}\n")

        f.write("\nClustering Analysis:\n")
        for period, res in results.items():
            f.write(f"\n{period}:\n")
            f.write(str(res['cluster_data'].groupby('Cluster').agg({
                'Year': 'count',
                'State': lambda x: x.value_counts().to_dict()
            }).rename(columns={'Year': 'Count', 'State': 'State Distribution'})) + "\n")

        f.write("\nConfusion Matrix Analysis:\n")
        for period in results:
            cm = results[period]['confusion_matrix']
            f.write(f"\n{period}:\n")
            f.write("Rows: Actual (El Niño, La Niña, Neutral)\n")
            f.write("Columns: Predicted (El Niño, La Niña, Neutral)\n")
            f.write(f"{cm}\n")
            precision = np.diag(cm) / np.sum(cm, axis=0)
            f.write(f"Precision (El Niño, La Niña, Neutral): {[f'{p:.4f}' if not np.isnan(p) else 'N/A' for p in precision]}\n")

    # Download outputs
    print("\nDownloading output files...")
    for fname in os.listdir(output_dir):
        files.download(os.path.join(output_dir, fname))

    print("\nSummary saved to 'hmm_outputs/summary.txt'")
    print("Model parameters, predictions, and plots saved in 'hmm_outputs/' directory")

if __name__ == '__main__':
    try:
        main(index, anomaly)
    except Exception as e:
        print(f"Error in main execution: {e}")
        print(traceback.format_exc())


Using hmmlearn version: 0.3.3
Entering main function
Index columns: ['Year', 'DJF', 'JFM', 'FMA', 'MAM', 'AMJ', 'MJJ', 'JJA', 'JAS', 'ASO', 'SON', 'OND', 'NDJ']
Anomaly columns: ['Year', 'DJF', 'JFM', 'FMA', 'MAM', 'AMJ', 'MJJ', 'JJA', 'JAS', 'ASO', 'SON', 'OND', 'NDJ']
Index shape: (44, 13)
Anomaly shape: (44, 13)
Missing values in index:
 Year    0
DJF     0
JFM     0
FMA     0
MAM     0
AMJ     1
MJJ     1
JJA     1
JAS     1
ASO     1
SON     1
OND     1
NDJ     1
dtype: int64
Missing values in anomaly:
 Year    0
DJF     1
JFM     0
FMA     0
MAM     1
AMJ     1
MJJ     1
JJA     1
JAS     1
ASO     1
SON     1
OND     1
NDJ     1
dtype: int64
Removed 'Year' from anomaly. New columns: ['DJF', 'JFM', 'FMA', 'MAM', 'AMJ', 'MJJ', 'JJA', 'JAS', 'ASO', 'SON', 'OND', 'NDJ']
Imputed missing values. New missing values in index:
 Year    0
DJF     0
JFM     0
FMA     0
MAM     0
AMJ     0
MJJ     0
JJA     0
JAS     0
ASO     0
SON     0
OND     0
NDJ     0
dtype: int64
New missing values 

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


Summary saved to 'hmm_outputs/summary.txt'
Model parameters, predictions, and plots saved in 'hmm_outputs/' directory


In [None]:
import pandas as pd
import numpy as np
from hmmlearn import hmm
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.cluster import KMeans
from scipy.stats import multivariate_normal, mode
import matplotlib.pyplot as plt
import seaborn as sns
import os
import warnings
warnings.filterwarnings('ignore')

# Constants
periods = ['DJF', 'JFM', 'FMA', 'MAM', 'AMJ', 'MJJ', 'JJA', 'JAS', 'ASO', 'SON', 'OND', 'NDJ']
corr_coeffs = {
    'MJJ': -0.02436115975109819, 'OND': 0.4120160038567676, 'SON': 0.35330147648830007,
    'ASO': 0.2759602399103643, 'JAS': 0.18994442388434107, 'JJA': 0.13703542153643897
}
state_names = {0: 'El Niño', 1: 'La Niña', 2: 'Neutral'}
output_dir = 'hmm_outputs'
os.makedirs(output_dir, exist_ok=True)

# Helper Functions
def classify_enso(oni):
    if oni >= 0.5:
        return 0  # El Niño
    elif oni <= -0.5:
        return 1  # La Niña
    else:
        return 2  # Neutral

def align_states(true_states, predicted_states, n_states=3):
    state_map = {}
    new_predicted = np.copy(predicted_states)
    for i in range(n_states):
        mask = predicted_states == i
        if np.any(mask):
            majority_label = mode(true_states[mask], keepdims=True)[0][0]
            state_map[i] = majority_label
        else:
            state_map[i] = 2
    for i in range(n_states):
        new_predicted[predicted_states == i] = state_map[i]
    return new_predicted, state_map

def train_hmm(observations, n_states=3, n_iter=100):
    model = hmm.GaussianHMM(n_components=n_states, covariance_type='full', n_iter=n_iter, random_state=42)
    model.fit(observations)
    states = model.predict(observations)
    probs = model.predict_proba(observations)
    return model, states, probs

def process_period(period, index, anomaly):
    print(f"\nProcessing period: {period}")
    data = pd.DataFrame({
        'Year': index['Year'],
        f'ONI_{period}': index[period],
        f'WWV_{period}': anomaly[period]
    }).dropna()

    if len(data) < 10:
        print(f"Skipping {period} due to insufficient data")
        return None

    data['State'] = data[f'ONI_{period}'].apply(classify_enso)
    observations = data[[f'ONI_{period}', f'WWV_{period}']].values
    model, raw_states, state_probs = train_hmm(observations)
    aligned_states, state_map = align_states(data['State'].values, raw_states)

    data['Predicted_State'] = aligned_states
    data['Predicted_State_Name'] = data['Predicted_State'].map(state_names)

    acc = accuracy_score(data['State'], aligned_states)
    cm = confusion_matrix(data['State'], aligned_states, labels=[0, 1, 2])

    np.save(f'{output_dir}/transition_matrix_{period}.npy', model.transmat_)
    np.save(f'{output_dir}/initial_probs_{period}.npy', model.startprob_)
    np.save(f'{output_dir}/emission_means_{period}.npy', model.means_)
    np.save(f'{output_dir}/emission_covars_{period}.npy', model.covars_)
    data.to_csv(f'{output_dir}/predictions_{period}.csv', index=False)

    return {
        'accuracy': acc,
        'confusion_matrix': cm,
        'transition_matrix': model.transmat_,
        'initial_probs': model.startprob_,
        'emission_means': model.means_,
        'emission_covars': model.covars_,
        'state_map': state_map
    }

def main(index, anomaly):
    results = {}
    for period in periods:
        result = process_period(period, index, anomaly)
        if result:
            results[period] = result

    print("\n=== Summary ===")
    for period in results:
        print(f"\n{period}:")
        print(f"Accuracy: {results[period]['accuracy']:.4f}")
        print("Confusion Matrix:")
        print(results[period]['confusion_matrix'])
        print("Transition Matrix:")
        print(results[period]['transition_matrix'])
        print("Emission Means:")
        print(results[period]['emission_means'])
        print("Emission Covariances:")
        print(results[period]['emission_covars'])
        print("State Map (HMM -> ENSO):")
        print(results[period]['state_map'])

main(index, anomaly)



Processing period: DJF

Processing period: JFM

Processing period: FMA

Processing period: MAM

Processing period: AMJ

Processing period: MJJ

Processing period: JJA

Processing period: JAS

Processing period: ASO

Processing period: SON

Processing period: OND

Processing period: NDJ

=== Summary ===

DJF:
Accuracy: 0.5581
Confusion Matrix:
[[ 6  8  0]
 [ 0 18  0]
 [ 0 11  0]]
Transition Matrix:
[[2.26744154e-09 9.99999973e-01 2.44387261e-08]
 [6.82676952e-01 1.24073756e-05 3.17310641e-01]
 [1.58741991e-01 8.41258009e-01 1.30202873e-13]]
Emission Means:
[[-0.64373172  0.24779321]
 [-0.109633    0.15579874]
 [ 1.68340366 -0.08271336]]
Emission Covariances:
[[[ 0.50879039  0.24021446]
  [ 0.24021446  0.24813871]]

 [[ 0.65004734  0.20579118]
  [ 0.20579118  1.12304139]]

 [[ 0.41306414 -0.19132983]
  [-0.19132983  0.47330649]]]
State Map (HMM -> ENSO):
{0: np.int64(1), 1: np.int64(1), 2: np.int64(0)}

JFM:
Accuracy: 0.6136
Confusion Matrix:
[[14  0  0]
 [ 4 13  0]
 [10  3  0]]
Transit

In [None]:
import pandas as pd
import numpy as np
from hmmlearn import hmm
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.cluster import KMeans
from scipy.stats import multivariate_normal, mode
import matplotlib.pyplot as plt
import seaborn as sns
import os
import warnings
warnings.filterwarnings('ignore')

# Constants
periods = ['DJF', 'JFM', 'FMA', 'MAM', 'AMJ', 'MJJ', 'JJA', 'JAS', 'ASO', 'SON', 'OND', 'NDJ']
corr_coeffs = {
    'MJJ': -0.02436115975109819, 'OND': 0.4120160038567676, 'SON': 0.35330147648830007,
    'ASO': 0.2759602399103643, 'JAS': 0.18994442388434107, 'JJA': 0.13703542153643897
}
state_names = {0: 'El Niño', 1: 'La Niña', 2: 'Neutral'}
output_dir = 'hmm_outputs_v2'
os.makedirs(output_dir, exist_ok=True)

# Helper Functions
def classify_enso(oni):
    if oni >= 0.5:
        return 0
    elif oni <= -0.5:
        return 1
    else:
        return 2

def align_states(true_states, predicted_states, n_states=3):
    state_map = {}
    new_predicted = np.copy(predicted_states)
    for i in range(n_states):
        mask = predicted_states == i
        if np.any(mask):
            majority_label = mode(true_states[mask], keepdims=True)[0][0]
            state_map[i] = majority_label
        else:
            state_map[i] = 2
    for i in range(n_states):
        new_predicted[predicted_states == i] = state_map[i]
    return new_predicted, state_map

def train_hmm(observations, n_states=3, n_iter=100):
    model = hmm.GaussianHMM(n_components=n_states, covariance_type='full', n_iter=n_iter, random_state=42)
    model.fit(observations)
    states = model.predict(observations)
    probs = model.predict_proba(observations)
    return model, states, probs

def detect_change_points(states, years):
    transitions = np.where(states[:-1] != states[1:])[0] + 1
    change_points = [(years[i], state_names[states[i-1]], state_names[states[i]]) for i in transitions]
    return change_points

def cluster_years(observations, states, years, n_clusters=3):
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    clusters = kmeans.fit_predict(observations)
    return pd.DataFrame({
        'Year': years,
        'Cluster': clusters,
        'State': [state_names[s] for s in states]
    })

def plot_confusion_matrix(cm, period):
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=state_names.values(), yticklabels=state_names.values())
    plt.title(f'{period} Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig(f'{output_dir}/confusion_matrix_{period}.png')
    plt.close()

def plot_emission_distributions(means, covars, period):
    x = np.linspace(-3, 3, 100)
    y = np.linspace(-3, 3, 100)
    X, Y = np.meshgrid(x, y)
    pos = np.dstack((X, Y))
    plt.figure(figsize=(8, 6))
    for i, state in enumerate(state_names.values()):
        rv = multivariate_normal(mean=means[i], cov=covars[i])
        Z = rv.pdf(pos)
        plt.contour(X, Y, Z, levels=5, label=state)
    plt.title(f'{period} Emission Distributions')
    plt.xlabel('ONI (°C)')
    plt.ylabel('WWV Anomaly')
    plt.legend(list(state_names.values()))
    plt.savefig(f'{output_dir}/emission_distributions_{period}.png')
    plt.close()

def plot_predictions(data, period):
    plt.figure(figsize=(12, 6))
    plt.plot(data['Year'], data[f'ONI_{period}'], label='ONI', color='blue')
    plt.plot(data['Year'], data[f'WWV_{period}'], label='WWV', color='red')
    for state in range(3):
        mask = data['Predicted_State'] == state
        plt.scatter(data['Year'][mask], data[f'ONI_{period}'][mask], label=f'Predicted {state_names[state]}', s=50)
    plt.xlabel('Year')
    plt.ylabel('Value')
    plt.title(f'Predicted ENSO States - {period}')
    plt.legend()
    plt.savefig(f'{output_dir}/predicted_states_{period}.png')
    plt.close()

def plot_state_probabilities(years, state_probs, period):
    plt.figure(figsize=(12, 4))
    for i, state in enumerate(state_names.values()):
        plt.plot(years, state_probs[:, i], label=state)
    plt.title(f'{period} Posterior State Probabilities')
    plt.xlabel('Year')
    plt.ylabel('Probability')
    plt.legend()
    plt.savefig(f'{output_dir}/state_probs_{period}.png')
    plt.close()

def process_period(period, index, anomaly):
    print(f"\nProcessing period: {period}")
    data = pd.DataFrame({
        'Year': index['Year'],
        f'ONI_{period}': index[period],
        f'WWV_{period}': anomaly[period]
    }).dropna()

    if len(data) < 10:
        print(f"Skipping {period} due to insufficient data")
        return None

    data['State'] = data[f'ONI_{period}'].apply(classify_enso)
    observations = data[[f'ONI_{period}', f'WWV_{period}']].values
    model, raw_states, state_probs = train_hmm(observations)
    aligned_states, state_map = align_states(data['State'].values, raw_states)

    data['Predicted_State'] = aligned_states
    data['Predicted_State_Name'] = data['Predicted_State'].map(state_names)

    acc = accuracy_score(data['State'], aligned_states)
    cm = confusion_matrix(data['State'], aligned_states, labels=[0, 1, 2])
    change_points = detect_change_points(aligned_states, data['Year'].values)
    cluster_data = cluster_years(observations, aligned_states, data['Year'].values)

    np.save(f'{output_dir}/transition_matrix_{period}.npy', model.transmat_)
    np.save(f'{output_dir}/initial_probs_{period}.npy', model.startprob_)
    np.save(f'{output_dir}/emission_means_{period}.npy', model.means_)
    np.save(f'{output_dir}/emission_covars_{period}.npy', model.covars_)
    data.to_csv(f'{output_dir}/predictions_{period}.csv', index=False)
    cluster_data.to_csv(f'{output_dir}/clusters_{period}.csv', index=False)

    plot_confusion_matrix(cm, period)
    plot_emission_distributions(model.means_, model.covars_, period)
    plot_predictions(data, period)
    plot_state_probabilities(data['Year'], state_probs, period)

    return {
        'accuracy': acc,
        'confusion_matrix': cm,
        'transition_matrix': model.transmat_,
        'initial_probs': model.startprob_,
        'emission_means': model.means_,
        'emission_covars': model.covars_,
        'state_map': state_map,
        'change_points': change_points,
        'cluster_data': cluster_data
    }

def main(index, anomaly):
    results = {}
    for period in periods:
        result = process_period(period, index, anomaly)
        if result:
            results[period] = result

    print("\n=== Summary ===")
    for period in results:
        print(f"\n{period}:")
        print(f"Accuracy: {results[period]['accuracy']:.4f}")
        print("Confusion Matrix:")
        print(results[period]['confusion_matrix'])
        print("Transition Matrix:")
        print(results[period]['transition_matrix'])
        print("Emission Means:")
        print(results[period]['emission_means'])
        print("Emission Covariances:")
        print(results[period]['emission_covars'])
        print("State Map (HMM -> ENSO):")
        print(results[period]['state_map'])
        print("Change Points:")
        for cp in results[period]['change_points']:
            print(f"Year {cp[0]}: {cp[1]} → {cp[2]}")
        print("Cluster Summary:")
        print(results[period]['cluster_data'].groupby('Cluster').agg({
            'Year': 'count',
            'State': lambda x: x.value_counts().to_dict()
        }).rename(columns={'Year': 'Count', 'State': 'State Distribution'}))

main(index, anomaly)


Processing period: DJF

Processing period: JFM

Processing period: FMA

Processing period: MAM

Processing period: AMJ

Processing period: MJJ

Processing period: JJA

Processing period: JAS

Processing period: ASO

Processing period: SON

Processing period: OND

Processing period: NDJ

=== Summary ===

DJF:
Accuracy: 0.5581
Confusion Matrix:
[[ 6  8  0]
 [ 0 18  0]
 [ 0 11  0]]
Transition Matrix:
[[2.26744154e-09 9.99999973e-01 2.44387261e-08]
 [6.82676952e-01 1.24073756e-05 3.17310641e-01]
 [1.58741991e-01 8.41258009e-01 1.30202873e-13]]
Emission Means:
[[-0.64373172  0.24779321]
 [-0.109633    0.15579874]
 [ 1.68340366 -0.08271336]]
Emission Covariances:
[[[ 0.50879039  0.24021446]
  [ 0.24021446  0.24813871]]

 [[ 0.65004734  0.20579118]
  [ 0.20579118  1.12304139]]

 [[ 0.41306414 -0.19132983]
  [-0.19132983  0.47330649]]]
State Map (HMM -> ENSO):
{0: np.int64(1), 1: np.int64(1), 2: np.int64(0)}
Change Points:
Year 1983: La Niña → El Niño
Year 1984: El Niño → La Niña
Year 1987: L