In [1]:
# ==========================================================
# FINAL OUT-OF-SAMPLE TEST SCRIPT FOR 2D ZONAL-MEAN MODELS
# ==========================================================
# This script contains all necessary code and corrected paths.
# Please use this to replace your entire out-of-sample notebook.

import os
import gc
import numpy as np
import tensorflow as tf
from sklearn.metrics import r2_score
from scipy.stats import pearsonr
from keras.models import load_model
import xarray as xr
import scipy.io as sio
import matplotlib.pyplot as plt
from tensorflow.keras import backend as K

# --- 1. Configuration: Set ONE Correct Path for Model Results ---
print("--- Setting up for Zonal-Mean Out-of-Sample Test ---")

# For CESM1:
path_model_dir = '/ocean/projects/ees250004p/ezhu3/data/CESM1/trained_model_2D_changedLRandKS'
# For CESM2:
#path_model_dir = '/ocean/projects/ees250004p/ezhu3/data/CESM2/trained_model_2D_changedLRandKS'

# Define variable names for your test files
In_name = "TS"
Out_name = "TOA_anom"

# Define paths to your TWO separate 4xCO2 test files
# # For CESM1:
file_4xCO2_input = "/ocean/projects/ees250004p/ezhu3/data/CESM1/test/test.4xCO2.ANN.new.nc"
file_4xCO2_output = "/ocean/projects/ees250004p/ezhu3/data/CESM1/test/test.4xCO2.zmean.ANN.new.nc"

# For CESM2:
#file_4xCO2_input = "/ocean/projects/ees250004p/ezhu3/data/CESM2/test/test.CESM2-4xCO2.ANN.nc"
#file_4xCO2_output = "/ocean/projects/ees250004p/ezhu3/data/CESM2/test/test.CESM2-4xCO2.zmean.ANN.nc"

# --- 2. Load Normalization Data from the Training Run ---
print("Loading normalization data...")
normalization_path = os.path.join(path_model_dir, 'Normalization_zonal.mat')
normalization = sio.loadmat(normalization_path)
X_mean = normalization['X_mean']
X_std = normalization['X_std']
y_mean = normalization['y_mean']
y_std = normalization['y_std']
print("✅ Normalization data loaded successfully.")
print(X_mean)
print(X_std)
print(y_mean)
print(y_std)
# --- 3. Load and Preprocess 4xCO2 Test Data ---
print("\nLoading and preprocessing 4xCO2 test data...")
ds_4xCO2_X = xr.open_dataset(file_4xCO2_input)
ds_4xCO2_y = xr.open_dataset(file_4xCO2_output)

# Extract variables from the correct files
TS_4xCO2_raw = ds_4xCO2_X[In_name]
TOA_4xCO2_truth = ds_4xCO2_y[Out_name].values
lat = ds_4xCO2_X['lat'].values
time_4xCO2 = ds_4xCO2_X['year'].values if 'year' in ds_4xCO2_X else ds_4xCO2_X['time'].values

# Normalize inputs correctly by adding the channel dimension first
TS_4xCO2_norm = (TS_4xCO2_raw.values[..., np.newaxis] - X_mean) / X_std
print("✅ Test data preprocessed.")

# --- 4. Prediction Loop ---
print("\n--- Running Ensemble Predictions for 4xCO2 ---")
n_folds = 5
predictions_from_folds = []

for fold_no in range(1, n_folds + 1):
    K.clear_session(); gc.collect()
    
    # Use the corrected, direct path to load the model
    model_path = os.path.join(path_model_dir, f'model_fold{fold_no}_ens1.h5')
    print(f"    Loading and predicting with model: {model_path}")
    
    model = load_model(model_path)
    
    pred_4xco2_norm = model.predict(TS_4xCO2_norm)
    pred_4xco2_unnorm = pred_4xco2_norm * y_std + y_mean
    predictions_from_folds.append(pred_4xco2_unnorm)

# Average predictions across the folds
Model_pred_4xco2 = np.mean(np.stack(predictions_from_folds), axis=0)
print("\n✅ Prediction complete.")

2025-08-06 02:57:31.315361: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-08-06 02:57:31.421680: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-06 02:57:31.454552: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


--- Setting up for Zonal-Mean Out-of-Sample Test ---
Loading normalization data...
✅ Normalization data loaded successfully.
[[[220.6603 ]
  [220.86879]
  [220.88463]
  ...
  [220.85564]
  [220.6545 ]
  [220.89789]]

 [[221.35854]
  [221.40932]
  [221.33986]
  ...
  [221.18336]
  [221.2454 ]
  [221.24274]]

 [[221.73988]
  [221.76971]
  [221.81885]
  ...
  [221.57878]
  [221.65894]
  [221.67375]]

 ...

 [[250.342  ]
  [250.36276]
  [250.38281]
  ...
  [250.28914]
  [250.30632]
  [250.32236]]

 [[250.1685 ]
  [250.17897]
  [250.18936]
  ...
  [250.13574]
  [250.14691]
  [250.15811]]

 [[250.01157]
  [250.01212]
  [250.01262]
  ...
  [250.0096 ]
  [250.0104 ]
  [250.01102]]]
[[[0.53236264]
  [0.5272247 ]
  [0.52714396]
  ...
  [0.52715623]
  [0.53207755]
  [0.52702093]]

 [[0.5789421 ]
  [0.57852644]
  [0.58601373]
  ...
  [0.57942045]
  [0.58237875]
  [0.58337593]]

 [[0.6280233 ]
  [0.61705726]
  [0.6132673 ]
  ...
  [0.62475264]
  [0.6240839 ]
  [0.61795306]]

 ...

 [[1.0660723 ]
  

2025-08-06 02:57:34.666864: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-08-06 02:57:35.099741: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 469 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:8a:00.0, compute capability: 7.0
2025-08-06 02:57:45.907360: W tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 216.00MiB (rounded to 226492416)requested by op model/conv2d/Conv2D
If the cause is memory fragmentation maybe the environment variable 'TF_GPU_ALLOCATOR=cuda_malloc_async' will improve the situation. 
Current allocation

ResourceExhaustedError: Graph execution error:

Detected at node 'model/conv2d/Conv2D' defined at (most recent call last):
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/traitlets/config/application.py", line 992, in launch_instance
      app.start()
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 711, in start
      self.io_loop.start()
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
      self._run_once()
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
      handle._run()
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/asyncio/events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 411, in do_execute
      res = shell.run_cell(
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 531, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2961, in run_cell
      result = self._run_cell(
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3016, in _run_cell
      result = runner(coro)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3221, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3400, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/var/tmp/ipykernel_42437/175387373.py", line 82, in <module>
      pred_4xco2_norm = model.predict(TS_4xCO2_norm)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/engine/training.py", line 2253, in predict
      tmp_batch_outputs = self.predict_function(iterator)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/engine/training.py", line 2041, in predict_function
      return step_function(self, iterator)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/engine/training.py", line 2027, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/engine/training.py", line 2015, in run_step
      outputs = model.predict_step(data)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/engine/training.py", line 1983, in predict_step
      return self(x, training=False)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/engine/training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/engine/functional.py", line 510, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/engine/functional.py", line 667, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/layers/convolutional/base_conv.py", line 283, in call
      outputs = self.convolution_op(inputs, self.kernel)
    File "/jet/home/ezhu3/.conda/envs/tf210/lib/python3.8/site-packages/keras/layers/convolutional/base_conv.py", line 255, in convolution_op
      return tf.nn.convolution(
Node: 'model/conv2d/Conv2D'
OOM when allocating tensor with shape[32,32,192,288] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node model/conv2d/Conv2D}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_predict_function_576]

In [None]:
# ==========================================================
# Out-of-Sample Analysis and Visualization (Corrected Plotting Axis)
# ==========================================================
print("\n--- Starting Out-of-Sample Analysis ---")

# --- Task 1: Calculate Overall Pattern Correlation ---
print("\n    Calculating Overall Pattern Correlation...")
truth_flat = TOA_4xCO2_truth.flatten()
pred_flat = Model_pred_4xco2.flatten()
pattern_r, _ = pearsonr(truth_flat, pred_flat)
print(f"✅ Overall Pattern Correlation (r) = {pattern_r:.4f}")

# --- Task 2: Plot R-squared as a Function of Latitude ---
print("\n    Calculating and plotting R-squared per latitude...")
r2_by_latitude = [r2_score(TOA_4xCO2_truth[:, i], Model_pred_4xco2[:, i]) for i in range(len(lat))]

plt.figure(figsize=(10, 6))
plt.plot(lat, r2_by_latitude, marker='o', linestyle='-')
plt.title('Out-of-Sample Performance (R²) by Latitude - 4xCO2', fontsize=16)
plt.xlabel('Latitude', fontsize=12)
plt.ylabel('R-squared Score', fontsize=12)
plt.grid(True, linestyle='--'); plt.ylim(0, 1)
plt.show()

# --- Task 3: Plot Truth vs. Prediction as a 2D Contour Map ---
print("\n    Plotting Truth vs. Prediction as contour maps...")

# This creates a simple numerical axis [0, 1, 2, ...] for plotting
time_axis_for_plot = np.arange(TOA_4xCO2_truth.shape[0])

fig, axes = plt.subplots(1, 2, figsize=(18, 6), sharey=True)
vmax = np.percentile(np.abs(TOA_4xCO2_truth), 98)
vmin = -vmax

axes[0].set_title('Ground Truth TOA Zonal Mean', fontsize=16)
# Use the new simple time axis for plotting
cf1 = axes[0].contourf(time_axis_for_plot, lat, TOA_4xCO2_truth.T, levels=20, cmap='RdBu_r', vmin=vmin, vmax=vmax)
axes[0].set_xlabel('Time (Model Years)', fontsize=12)
axes[0].set_ylabel('Latitude', fontsize=12)

axes[1].set_title('Predicted TOA Zonal Mean', fontsize=16)
# Use the new simple time axis for plotting here as well
cf2 = axes[1].contourf(time_axis_for_plot, lat, Model_pred_4xco2.T, levels=20, cmap='RdBu_r', vmin=vmin, vmax=vmax)
axes[1].set_xlabel('Time (Model Years)', fontsize=12)

fig.colorbar(cf1, ax=axes.ravel().tolist(), shrink=0.8, label='TOA Anomaly (W/m²)')
fig.suptitle("Out-of-Sample Results - 4xCO2 Scenario", fontsize=18, fontweight='bold')
plt.show()

In [None]:
# =========================================================
# Task 4: Calculate and Compare Global Mean Time Series
# =========================================================
print("\n--- Calculating and Comparing Weighted Global Means ---")

# --- Step 1: Calculate Latitude Weights ---
# To get a true global mean, we must weight each latitude by the cosine
# of its angle to account for the smaller grid cell areas near the poles.
lat_radians = np.deg2rad(lat)
weights = np.cos(lat_radians)
# Ensure weights have the correct shape for broadcasting during the average
weights = weights[np.newaxis, :]

# --- Step 2: Calculate Weighted Average for Truth and Prediction ---
# We average over the latitude axis (axis=1) to get a single global
# mean value for each time step.
global_mean_truth = np.average(TOA_4xCO2_truth, axis=1, weights=weights.flatten())
global_mean_pred = np.average(Model_pred_4xco2, axis=1, weights=weights.flatten())

print("✅ Weighted global means calculated.")

# --- Step 3: Print and Compare the Overall Mean Values ---
# This gives a single number summary of the entire time series
print(f"    Overall Mean of Ground Truth: {np.mean(global_mean_truth):.4f} W/m²")
print(f"    Overall Mean of Prediction:   {np.mean(global_mean_pred):.4f} W/m²")

# --- Step 4: Plot the Global Mean Time Series for Comparison ---
print("\n    Plotting global mean time series comparison...")
plt.figure(figsize=(12, 6))
time_axis_for_plot = np.arange(global_mean_truth.shape[0])

plt.plot(time_axis_for_plot, global_mean_truth, label='Ground Truth', color='black', linewidth=2)
plt.plot(time_axis_for_plot, global_mean_pred, label='Model Prediction', color='red', linestyle='--')

plt.title('Out-of-Sample: Global Mean TOA Anomaly Time Series', fontsize=16)
plt.xlabel('Time (Model Years)', fontsize=12)
plt.ylabel('Global Mean Anomaly (W/m²)', fontsize=12)
plt.legend()
plt.grid(True, linestyle=':')
plt.show()

In [None]:
# =========================================================
# Task 4: Calculate and Compare Global Mean Time Series
# =========================================================
print("\n--- Calculating and Comparing Weighted Global Means ---")

# --- Step 1: Calculate Latitude Weights ---
# To get a true global mean, we must weight each latitude by the cosine
# of its angle to account for the smaller grid cell areas near the poles.
lat_radians = np.deg2rad(lat)
weights = np.cos(lat_radians)
# Ensure weights have the correct shape for broadcasting during the average
weights = weights[np.newaxis, :]

# --- Step 2: Calculate Weighted Average for Truth and Prediction ---
# We average over the latitude axis (axis=1) to get a single global
# mean value for each time step.
global_mean_truth = np.average(TOA_4xCO2_truth, axis=1, weights=weights.flatten())
global_mean_pred = np.average(Model_pred_4xco2, axis=1, weights=weights.flatten())

print("✅ Weighted global means calculated.")

# --- Step 3: Print and Compare the Overall Mean Values ---
# This gives a single number summary of the entire time series
print(f"    Overall Mean of Ground Truth: {np.mean(global_mean_truth):.4f} W/m²")
print(f"    Overall Mean of Prediction:   {np.mean(global_mean_pred):.4f} W/m²")

# --- Step 4: Plot the Global Mean Time Series for Comparison ---
print("\n    Plotting global mean time series comparison...")
plt.figure(figsize=(12, 6))
time_axis_for_plot = np.arange(global_mean_truth.shape[0])

plt.plot(time_axis_for_plot, global_mean_truth, label='Ground Truth', color='black', linewidth=2)
plt.plot(time_axis_for_plot, global_mean_pred, label='Model Prediction', color='red', linestyle='--')

plt.title('Out-of-Sample: Global Mean TOA Anomaly Time Series', fontsize=16)
plt.xlabel('Time (Model Years)', fontsize=12)
plt.ylabel('Global Mean Anomaly (W/m²)', fontsize=12)
plt.legend()
plt.grid(True, linestyle=':')
plt.show()

In [None]:
# =================================================================
# Task 5: Calculate and Plot Climate Feedback Parameter (λ)
# =================================================================
from scipy.stats import linregress

print("\n--- Calculating and Plotting Climate Feedback Parameter (λ) ---")

# --- Step 1: Calculate the X-axis data (Area-Weighted Global Mean TS Anomaly) ---
# We already have the latitude weights from the previous task.
# First, take the mean across the longitude axis of the raw input data.
TS_zonal_mean_truth = np.mean(TS_4xCO2_raw.values, axis=2)

# Now, calculate the latitude-weighted average to get the global mean.
global_mean_TS_truth = np.average(TS_zonal_mean_truth, axis=1, weights=weights.flatten())

# --- Step 2: Set up the side-by-side plots ---
fig, axes = plt.subplots(1, 2, figsize=(20, 8), sharey=True)
fig.suptitle('Climate Feedback Parameter (λ) - 4xCO2 Out-of-Sample', fontsize=20, weight='bold')

# --- Step 3: Plot for the Ground Truth ---
ax1 = axes[0]
# Perform linear regression to find the slope (lambda)
slope_truth, intercept_truth, r_value_truth, _, _ = linregress(global_mean_TS_truth, global_mean_truth)
lambda_truth = slope_truth
r2_truth = r_value_truth**2

# Scatter plot
ax1.scatter(global_mean_TS_truth, global_mean_truth, alpha=0.6, label='Yearly Data (Ground Truth)')
# Best-fit line
fit_line_truth = slope_truth * global_mean_TS_truth + intercept_truth
ax1.plot(global_mean_TS_truth, fit_line_truth, color='red', linestyle='--', label='Linear Best Fit')

# Add text box with results
text_truth = (f'λ = {lambda_truth:.3f} W/m²/K\n'
              f'$R^2$ = {r2_truth:.3f}')
ax1.text(0.05, 0.95, text_truth, transform=ax1.transAxes, fontsize=14,
         verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

ax1.set_title('Ground Truth Data', fontsize=16)
ax1.set_xlabel('Area-Weighted Global Mean TS Anomaly (K)', fontsize=12)
ax1.set_ylabel('Global Mean TOA Anomaly (W/m²)', fontsize=12)
ax1.grid(True, linestyle=':')
ax1.legend()


# --- Step 4: Plot for the Model Prediction ---
ax2 = axes[1]
# Perform linear regression
slope_pred, intercept_pred, r_value_pred, _, _ = linregress(global_mean_TS_truth, global_mean_pred)
lambda_pred = slope_pred
r2_pred = r_value_pred**2

# Scatter plot
ax2.scatter(global_mean_TS_truth, global_mean_pred, alpha=0.6, label='Yearly Data (Model Prediction)')
# Best-fit line
fit_line_pred = slope_pred * global_mean_TS_truth + intercept_pred
ax2.plot(global_mean_TS_truth, fit_line_pred, color='red', linestyle='--', label='Linear Best Fit')

# Add text box with results
text_pred = (f'λ = {lambda_pred:.3f} W/m²/K\n'
             f'$R^2$ = {r2_pred:.3f}')
ax2.text(0.05, 0.95, text_pred, transform=ax2.transAxes, fontsize=14,
         verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

ax2.set_title('Model Prediction', fontsize=16)
ax2.set_xlabel('Area-Weighted Global Mean TS Anomaly (K)', fontsize=12)
# No Y-label needed as it's shared with ax1
ax2.grid(True, linestyle=':')
ax2.legend()

plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust layout to make room for suptitle
plt.show()

In [None]:
# ==========================================================
# NEW ANALYSIS: Global Mean of Zonal-Mean Prediction vs. Truth
# ==========================================================
print("\n--- Starting new analysis: Comparing the area-weighted global mean ---")

# --- 1. Helper Function for Area-Weighted Mean ---
# This function correctly calculates the global mean from zonal-mean [time, lat] data.
def calculate_area_weighted_global_mean_zonal(data_2d, lat_coords):
    """
    Calculates the area-weighted global mean from a [time, lat] array.
    """
    # The weights for each latitude band are proportional to the cosine of the latitude.
    weights = np.cos(np.deg2rad(lat_coords))
    # np.average calculates the weighted average over the latitude axis (axis=1).
    global_mean_timeseries = np.average(data_2d, axis=1, weights=weights)
    return global_mean_timeseries

# --- 2. Calculate the Global Mean Time Series ---
# This assumes 'TOA_4xCO2_truth' and 'Model_pred_4xco2' are in memory from the previous cell.
# It also assumes 'lat' (your latitude coordinate array) is in memory.

print("    Calculating global mean for both truth and prediction...")
global_mean_truth = calculate_area_weighted_global_mean_zonal(TOA_4xCO2_truth, lat)
global_mean_pred = calculate_area_weighted_global_mean_zonal(Model_pred_4xco2, lat)

# --- 3. Plot the Comparison Graph ---
print("    Plotting the global mean comparison graph...")

# Calculate the R-squared score for the global mean time series
r2_global_mean = r2_score(global_mean_truth, global_mean_pred)

# === THIS IS THE FIX ===
# Instead of using the complex time objects from the file, we create a simple numerical axis.
# This will be an array like [0, 1, 2, ...] that has the correct length for plotting.
time_axis_for_plot = np.arange(len(global_mean_truth))
# ======================

# Create the plot
fig, ax = plt.subplots(figsize=(12, 6))

# Use the new, simple time axis for plotting
ax.plot(time_axis_for_plot, global_mean_truth, label="Truth (Global Mean)", color="k", linewidth=2.5)
ax.plot(time_axis_for_plot, global_mean_pred, label="Prediction (Global Mean)", color="C3", linestyle='--')

# Add R-squared annotation
ax.text(0.02, 0.95, f"$R^2$ = {r2_global_mean:.3f}", transform=ax.transAxes,
        fontsize=16, bbox=dict(facecolor="white", edgecolor="black", alpha=0.7))

# Style the plot
ax.set_xlabel("Time (Model Years)", fontsize=16) # Label updated for clarity
ax.set_ylabel("Global Mean TOA Anomaly (W/m²)", fontsize=16)
ax.grid(True, linestyle="--", alpha=0.5)
ax.tick_params(axis='both', labelsize=14)
ax.set_title("Model Performance on Global Mean (from Zonal-Mean Prediction)", fontsize=18, pad=15)
ax.legend(fontsize=14, loc="best")

plt.tight_layout()
plt.show()