In [2]:
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score
from train import apply_inverse_zca_whitening_4d_torch
from unet import UNet
from utils import *
from torch.utils.data import TensorDataset, DataLoader

setup_random_seeds(42)
device = get_device()

# Remote location of the train/test/ZCA zarr archives, and the storage options
# passed to every open_zarr call.
base_path = "gs://leap-persistent/YueWang/SSH/data"
storage_opts = {"token": "cloud", "asynchronous": False}

# Constants used by geostrophic_vel.
g = 9.81  # gravitational acceleration (presumably m s^-2)
f_cor = -8.6e-5  # Coriolis parameter; a negative value corresponds to the southern hemisphere

# Grid spacing per patch size (assumed metres) -- adjust if grid spacing differs.
dx_map = dict.fromkeys((54, 80, 108), 1500.0)
dy_map = dict.fromkeys((54, 80, 108), 1500.0)

# ============================================================================
# CONFIG: patch sizes and checkpoint paths
# ============================================================================
# Evaluated in list order. Each entry names the model checkpoint plus the
# train/test/ZCA zarr archives (relative to base_path) for one patch size.
configs = [
    dict(size=108, checkpoint="/home/jovyan/GRL_ssh/checkpoints/ps108.pth",
         train="train_108_sst.zarr", test="test_108_sst.zarr", zca="zca_108.zarr"),
    dict(size=80, checkpoint="/home/jovyan/GRL_ssh/checkpoints/sst_ssh.pth",
         train="train_80_sst.zarr", test="test_80_sst.zarr", zca="zca_80.zarr"),
    dict(size=54, checkpoint="/home/jovyan/GRL_ssh/checkpoints/ps54.pth",
         train="train_54_sst.zarr", test="test_54_sst.zarr", zca="zca_54.zarr"),
]

# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def geostrophic_vel(field_2d, dx, dy, gravity=None, coriolis=None):
    """Return the geostrophic velocity ``(u, v)`` implied by a 2-D height field.

    Geostrophic balance is applied with finite-difference gradients:

        u = -(gravity / coriolis) * d(field)/dy
        v =  (gravity / coriolis) * d(field)/dx

    Gradients use ``np.gradient`` with ``edge_order=2``, which requires at
    least 3 points along each differentiated axis.

    Parameters
    ----------
    field_2d : array_like
        Height field with y along axis 0 and x along axis 1.
    dx : float
        Grid spacing along axis 1 (x).
    dy : float
        Grid spacing along axis 0 (y).
    gravity : float, optional
        Gravitational acceleration. Defaults to the module-level ``g``,
        looked up at call time.
    coriolis : float, optional
        Coriolis parameter. Defaults to the module-level ``f_cor``, looked
        up at call time.

    Returns
    -------
    tuple of ndarray
        ``(u, v)``, each with the same shape as ``field_2d``.

    Raises
    ------
    ValueError
        If the Coriolis parameter is zero (geostrophic balance is undefined).
    """
    gravity = g if gravity is None else gravity
    coriolis = f_cor if coriolis is None else coriolis
    if coriolis == 0:
        raise ValueError("Coriolis parameter must be non-zero for geostrophic balance")

    d_eta_dy = np.gradient(field_2d, dy, axis=0, edge_order=2)
    d_eta_dx = np.gradient(field_2d, dx, axis=1, edge_order=2)
    # Negating the quotient is exact in IEEE arithmetic, so this matches
    # (-gravity / coriolis) * d_eta_dy bit-for-bit.
    factor = gravity / coriolis
    u = -factor * d_eta_dy
    v = factor * d_eta_dx
    return u, v

def r2_corr_arrays(truth, pred):
    """Summarise per-sample R² and Pearson correlation between paired rows.

    Rows of ``truth`` and ``pred`` are paired with ``zip``. For each pair only
    positions finite in both rows are scored; a pair with fewer than two such
    positions contributes NaN to both metrics.

    Returns
    -------
    tuple
        ``((r2_mean, r2_p05, r2_p95), (corr_mean, corr_p05, corr_p95))``,
        computed with NaN-ignoring mean and percentiles.
    """
    def _summarise(scores):
        values = np.array(scores)
        return (
            np.nanmean(values),
            np.nanpercentile(values, 5),
            np.nanpercentile(values, 95),
        )

    r2_scores = []
    corr_scores = []
    for row_true, row_pred in zip(truth, pred):
        valid = np.isfinite(row_true) & np.isfinite(row_pred)
        if valid.sum() >= 2:
            kept_true = row_true[valid]
            kept_pred = row_pred[valid]
            r2_scores.append(r2_score(kept_true, kept_pred))
            corr_scores.append(np.corrcoef(kept_true, kept_pred)[0, 1])
        else:
            r2_scores.append(np.nan)
            corr_scores.append(np.nan)

    return _summarise(r2_scores), _summarise(corr_scores)

# ============================================================================
# EVALUATE EACH PATCH SIZE
# ============================================================================
def _fmt_ci(stats):
    """Format a ``(mean, p05, p95)`` triple as ``"mean (p05, p95)"`` with 2 decimals."""
    center, low, high = stats
    return f"{center:.2f} ({low:.2f}, {high:.2f})"


def _flatten_clean(arr, clean):
    """Keep the samples selected by boolean mask ``clean`` and flatten each to 1-D.

    The column count is computed explicitly because ``reshape(0, -1)`` raises
    when no sample survives the mask; this way an empty selection yields a
    ``(0, n_cols)`` array instead of aborting the evaluation.
    """
    n_cols = int(np.prod(arr.shape[1:]))
    return arr[clean].reshape(int(clean.sum()), n_cols)


def _geostrophic_fields(fields, dx, dy):
    """Apply ``geostrophic_vel`` once per 2-D field and stack the ``(u, v)`` results."""
    u_list, v_list = [], []
    for field in fields:
        u_field, v_field = geostrophic_vel(field, dx, dy)
        u_list.append(u_field)
        v_list.append(v_field)
    return np.array(u_list), np.array(v_list)


all_records_r2 = []
all_records_corr = []

for cfg in configs:
    ps = cfg["size"]
    print(f"\n{'='*60}")
    print(f"Evaluating patch size {ps}×{ps}")
    print(f"{'='*60}")

    # Load data
    train_ds = open_zarr(f"{base_path}/{cfg['train']}", storage_opts)
    test_ds = open_zarr(f"{base_path}/{cfg['test']}", storage_opts)
    zca_ds = open_zarr(f"{base_path}/{cfg['zca']}", storage_opts)

    # Parameters used to invert the model's ZCA-space output back to physical UBM.
    Vt = torch.from_numpy(zca_ds.ubm_Vt.values).float().to(device)
    scale = torch.from_numpy(zca_ds.ubm_scale.values).float().to(device)
    mean = torch.from_numpy(zca_ds.ubm_mean.values).float().to(device)

    # Normalization statistics come from the training split only.
    x_train_ssh = torch.from_numpy(train_ds.ssh.values).float().unsqueeze(1).to(device)
    x_train_sst = torch.from_numpy(train_ds.sst.values).float().unsqueeze(1).to(device)
    x_train = torch.cat([x_train_ssh, x_train_sst], dim=1)
    _, min_vals, max_vals = min_max_normalize(x_train)
    del x_train, x_train_ssh, x_train_sst  # free memory

    # Test inputs (channel 0 = SSH, channel 1 = SST), normalized with train stats.
    x_test_ssh = torch.from_numpy(test_ds.ssh.values).float().unsqueeze(1).to(device)
    x_test_sst = torch.from_numpy(test_ds.sst.values).float().unsqueeze(1).to(device)
    x_test = torch.cat([x_test_ssh, x_test_sst], dim=1)
    x_test_norm, _, _ = min_max_normalize(x_test, min_vals, max_vals)

    # Physical-space UBM target. The ZCA-space target (test_ds.zca_ubm) is
    # never scored in this cell, so it is not loaded onto the device.
    y_test_phys = torch.from_numpy(test_ds.ubm.values).float().unsqueeze(1).to(device)

    test_loader = DataLoader(TensorDataset(x_test_norm, y_test_phys), batch_size=128, shuffle=False)

    # Load model
    model = UNet(in_channels=2, out_channels=2, initial_features=32, depth=4).to(device)
    ckpt = torch.load(cfg["checkpoint"], map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    # Run inference: output channel 0 is mapped from ZCA space to physical UBM.
    ubm_preds, ubm_trues = [], []
    with torch.no_grad():
        for bx, by in test_loader:  # batches already live on `device`
            mu_zca = model(bx)[:, 0:1, ...]
            ubm_pred = apply_inverse_zca_whitening_4d_torch(mu_zca, Vt, scale, mean)
            ubm_preds.append(ubm_pred.squeeze(1).cpu().numpy())
            ubm_trues.append(by[:, 0, ...].cpu().numpy())

    ubm_pred_all = np.concatenate(ubm_preds, axis=0)
    ubm_true_all = np.concatenate(ubm_trues, axis=0)
    # shuffle=False means the batches cover the test set in order, so the raw
    # SSH aligned with the predictions is simply the full test tensor.
    ssh_all = x_test_ssh.squeeze(1).cpu().numpy()

    # BM = SSH - UBM, for both prediction and truth.
    bm_pred_all = ssh_all - ubm_pred_all
    bm_true_all = ssh_all - ubm_true_all

    # Keep only samples whose UBM target contains no NaN anywhere.
    clean = ~np.isnan(ubm_true_all).reshape(ubm_true_all.shape[0], -1).any(axis=1)

    # Geostrophic velocities from BM: one geostrophic_vel call per field
    # yields both components.
    dx_val, dy_val = dx_map[ps], dy_map[ps]
    u_true, v_true = _geostrophic_fields(bm_true_all, dx_val, dy_val)
    u_pred, v_pred = _geostrophic_fields(bm_pred_all, dx_val, dy_val)

    # Per-sample R² / correlation summaries (mean, p05, p95) for each variable.
    pairs = {
        "UBM": (ubm_true_all, ubm_pred_all),
        "BM": (bm_true_all, bm_pred_all),
        "U": (u_true, u_pred),
        "V": (v_true, v_pred),
    }
    r2_stats, corr_stats = {}, {}
    for key, (truth, pred) in pairs.items():
        r2_stats[key], corr_stats[key] = r2_corr_arrays(
            _flatten_clean(truth, clean), _flatten_clean(pred, clean)
        )

    avg_r2 = np.nanmean([stats[0] for stats in r2_stats.values()])
    avg_c = np.nanmean([stats[0] for stats in corr_stats.values()])

    size_label = f"{ps}×{ps}"
    all_records_r2.append({
        "size": size_label,
        **{key: _fmt_ci(stats) for key, stats in r2_stats.items()},
        "Avg": f"{avg_r2:.2f}",
    })
    all_records_corr.append({
        "size": size_label,
        **{key: _fmt_ci(stats) for key, stats in corr_stats.items()},
        "Avg": f"{avg_c:.2f}",
    })

    # Cleanup: drop every reference to device tensors so empty_cache() can
    # actually release them. The DataLoader's TensorDataset and the checkpoint
    # dict (loaded with map_location=device) would otherwise keep them alive.
    del model, ckpt, test_loader, x_test_ssh, x_test_sst, x_test, x_test_norm, y_test_phys, Vt, scale, mean
    torch.cuda.empty_cache()

    print(
        f"  UBM R²={r2_stats['UBM'][0]:.3f}, BM R²={r2_stats['BM'][0]:.3f}, "
        f"U R²={r2_stats['U'][0]:.3f}, V R²={r2_stats['V'][0]:.3f}"
    )

# ============================================================================
# PRINT RESULTS
# ============================================================================
# Collect the per-size records into tables indexed by the patch-size label,
# then print each under its own heading.
df_r2 = pd.DataFrame(all_records_r2).set_index("size")
df_corr = pd.DataFrame(all_records_corr).set_index("size")
for heading, table in (("\n=== R² Table ===", df_r2), ("\n=== Correlation Table ===", df_corr)):
    print(heading)
    print(table)

# ============================================================================
# GENERATE LATEX
# ============================================================================
def to_latex_table(df, caption, table_num, metric_name):
    """Render a metrics DataFrame as a LaTeX ``table`` environment string.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per mesh size (the index supplies the first column), with
        columns ``UBM``, ``BM``, ``U``, ``V`` and ``Avg``.
    caption : str
        Text placed inside ``\\caption{...}``.
    table_num : str
        Argument passed to the ``\\settablenum`` macro.
    metric_name : str
        Label appended to the ``Avg.`` column header.

    Returns
    -------
    str
        Newline-joined LaTeX source. Cell values are inserted verbatim
        (no LaTeX escaping is performed).
    """
    metric_cols = ("UBM", "BM", "U", "V", "Avg")
    header_labels = (
        ["Mesh Size"]
        + [f"{col} (5th, 95th)" for col in metric_cols[:-1]]
        + [f"Avg. {metric_name}"]
    )
    header = " & ".join(f"\\textbf{{{label}}}" for label in header_labels) + " \\\\"
    body_rows = [
        " & ".join([f"{size}"] + [f"{row[col]}" for col in metric_cols]) + " \\\\"
        for size, row in df.iterrows()
    ]
    parts = [
        r"\begin{table}",
        f"\\settablenum{{{table_num}}}",
        f"\\caption{{{caption}}}",
        r"\centering",
        r"\begin{tabular}{l c c c c c}",
        r"\hline",
        header,
        r"\hline",
        *body_rows,
        r"\hline",
        r"\end{tabular}",
        r"\end{table}",
    ]
    return "\n".join(parts)

separator = "=" * 60
print("\n" + separator)
print("LATEX OUTPUT")
print(separator)

# Supplementary table S3 reports R², S4 reports correlation.
latex_r2 = to_latex_table(
    df_r2,
    r"R$^2$ Performance Comparison of ZCA+SST Model at Different Mesh Resolutions for UBM, BM, and Geostrophic Velocity Prediction. 5th and 95th percentiles are shown in parentheses.",
    "S3",
    r"R$^2$",
)
latex_corr = to_latex_table(
    df_corr,
    r"Correlation Performance Comparison of ZCA+SST Model at Different Mesh Resolutions for UBM, BM, and Geostrophic Velocity Prediction. 5th and 95th percentiles are shown in parentheses.",
    "S4",
    "corr",
)

print(latex_r2)
print()
print(latex_corr)

Using device: cuda

Evaluating patch size 108×108


  return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions)
  return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions)
  return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions)


  UBM R²=0.021, BM R²=0.981, U R²=0.862, V R²=0.857

Evaluating patch size 80×80


  return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions)
  return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions)
  return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions)


  UBM R²=0.010, BM R²=0.969, U R²=0.848, V R²=0.849

Evaluating patch size 54×54


  return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions)
  return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions)
  return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions)


  UBM R²=-0.196, BM R²=0.931, U R²=0.809, V R²=0.805

=== R² Table ===
                         UBM                 BM                  U  \
size                                                                 
108×108   0.02 (-0.97, 0.69)  0.98 (0.92, 1.00)  0.86 (0.63, 0.98)   
80×80     0.01 (-0.94, 0.73)  0.97 (0.87, 1.00)  0.85 (0.57, 0.97)   
54×54    -0.20 (-1.76, 0.73)  0.93 (0.71, 1.00)  0.81 (0.45, 0.97)   

                         V   Avg  
size                              
108×108  0.86 (0.54, 0.98)  0.68  
80×80    0.85 (0.58, 0.98)  0.67  
54×54    0.80 (0.41, 0.97)  0.59  

=== Correlation Table ===
                        UBM                 BM                  U  \
size                                                                
108×108   0.47 (0.06, 0.86)  0.99 (0.97, 1.00)  0.93 (0.81, 0.99)   
80×80     0.48 (0.01, 0.87)  0.99 (0.94, 1.00)  0.92 (0.78, 0.99)   
54×54    0.50 (-0.05, 0.90)  0.98 (0.90, 1.00)  0.91 (0.72, 0.99)   

                         V   A