In [1]:

import os
import time
from contextlib import nullcontext

import psutil
import hydra
from hydra.utils import to_absolute_path
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
import nvtx
import wandb

from physicsnemo import Module
from physicsnemo.models.diffusion import UNet, EDMPrecondSuperResolution
from physicsnemo.distributed import DistributedManager
from physicsnemo.metrics.diffusion import RegressionLoss, ResidualLoss, RegressionLossCE
from physicsnemo.utils.patching import RandomPatching2D
from physicsnemo.launch.logging.wandb import initialize_wandb
from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper
from physicsnemo.launch.utils import (
    load_checkpoint,
    save_checkpoint,
    get_checkpoint_dir,
)
from physicsnemo.experimental.metrics.diffusion import tEDMResidualLoss
from physicsnemo.experimental.models.diffusion.preconditioning import (
    tEDMPrecondSuperRes,
)

from datasets.dataset import init_train_valid_datasets_from_config, register_dataset
from helpers.train_helpers import (
    set_patch_shape,
    set_seed,
    configure_cuda_for_consistent_precision,
    compute_num_accumulation_rounds,
    handle_and_clip_gradients,
    is_time_for_periodic_task,
)


  from physicsnemo.experimental.metrics.diffusion import tEDMResidualLoss


In [26]:
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np

# -----------------------------
# 1️⃣ Load datasets
# -----------------------------
ds_truth = xr.open_dataset("corrdiff_output.nc", group="truth", decode_times=True)
ds_pred = xr.open_dataset("corrdiff_output.nc", group="prediction", decode_times=True)
ds_input = xr.open_dataset("corrdiff_output.nc", group="input", decode_times=True)

# Choose variable (can switch to "pr")
var_name = "pr"
print(f"Analyzing variable: {var_name}")

truth = ds_truth[var_name]                # (time, y, x)
pred = ds_pred[var_name].isel(ensemble=0) # (time, y, x)
coarse = ds_input[var_name]               # (time, y, x)

# Align all datasets
truth, pred, coarse = xr.align(truth, pred, coarse)

# Compute bias
bias = pred - truth

# -----------------------------
# 2️⃣ Determine plotting parameters
# -----------------------------
tsteps = truth.time
n_time = len(tsteps)

# Common color scale for fields (not bias)
vmin = float(min(truth.min(), pred.min(), coarse.min()))
vmax = float(max(truth.max(), pred.max(), coarse.max()))

# Bias color scale symmetric around 0
bmax = float(np.nanmax(np.abs(bias)))
bmin = -bmax

# Make nice time labels
def format_time(t):
    t_str = str(t.values)

time_labels = [format_time(t) for t in tsteps]

# -----------------------------
# 3️⃣ Large Text Styling
# -----------------------------
plt.rcParams.update({
    "font.size": 30,        # base font size (3× larger)
    "axes.titlesize": 36,   # title font
    "axes.labelsize": 30,   # axis labels
    "xtick.labelsize": 26,  # tick labels
    "ytick.labelsize": 26,
    "legend.fontsize": 28,
})

# -----------------------------
# 4️⃣ Create panel plot
# -----------------------------
rows = ["Prediction", "Truth", "Input (Coarse)", "Bias (Pred − Truth)"]

fig, axes = plt.subplots(
    nrows=len(rows),
    ncols=n_time,
    figsize=(3 * n_time, 12),
    constrained_layout=True,
    sharex=True,
    sharey=True
)

for j, t in enumerate(tsteps):
    t_label = time_labels[j]
    p = pred.sel(time=t)
    t_ = truth.sel(time=t)
    c = coarse.sel(time=t)
    b = bias.sel(time=t)

    for i, (data, title, cmap, vmin_i, vmax_i) in enumerate([
        (p, "Prediction", "coolwarm", vmin, vmax),
        (t_, "Truth", "coolwarm", vmin, vmax),
        (c, "Input", "coolwarm", vmin, vmax),
        (b, "Bias", "RdBu_r", bmin, bmax),
    ]):
        ax = axes[i, j]
        im = ax.pcolormesh(
            data["x"], data["y"], data,
            shading="auto", cmap=cmap, vmin=vmin_i, vmax=vmax_i
        )
        if j == 0:
            ax.set_ylabel(title, fontsize=34, labelpad=20)
        if i == 0:
            ax.set_title(t_label, fontsize=34, pad=20)
        ax.set_xticks([])
        ax.set_yticks([])

# -----------------------------
# 5️⃣ Colorbars
# -----------------------------
cb_fields = fig.colorbar(
    im, ax=axes[0:3, :].ravel().tolist(),
    orientation="horizontal", fraction=0.03, pad=0.05
)
cb_fields.set_label(f"{var_name} units", fontsize=34)
cb_fields.ax.tick_params(labelsize=28)

cb_bias = fig.colorbar(
    im, ax=axes[-1, :].ravel().tolist(),
    orientation="horizontal", fraction=0.03, pad=0.1
)
cb_bias.set_label("Bias (Pred − Truth)", fontsize=34)
cb_bias.ax.tick_params(labelsize=28)

# -----------------------------
# 6️⃣ Title and show
# -----------------------------
plt.suptitle(
    f"{var_name.upper()} — CorrDiff Output (Prediction / Truth / Input / Bias)",
    fontsize=40, y=1.05
)

plt.savefig(f"{var_name}_panel.svg", format="svg", dpi=300)
print(f"Saved panel figure as {var_name}_panel.svg")
plt.close()

Analyzing variable: pr
Saved panel figure as pr_panel.svg


In [24]:
import numpy as np
import pandas as pd
import xarray as xr

# -----------------------------
# Load datasets
# -----------------------------
file = "corrdiff_output.nc"
var_name = "tas"  # or "pr"

ds_truth = xr.open_dataset(file, group="truth", decode_times=True)
ds_pred = xr.open_dataset(file, group="prediction", decode_times=True)

truth = ds_truth[var_name]                # (time, y, x)
pred = ds_pred[var_name].isel(ensemble=0) # (time, y, x)
truth, pred = xr.align(truth, pred)

# -----------------------------
# Compute Bias (%) and MAPE (%)
# -----------------------------
diff = pred - truth

# Percent bias = mean(diff) / mean(truth) * 100
bias_pct = (diff.mean(dim=("x", "y")) / truth.mean(dim=("x", "y"))) * 100

# MAPE = mean(|diff| / |truth|) * 100
mape_pct = (np.abs(diff) / np.abs(truth)).mean(dim=("x", "y")) * 100

# Correlation
corr = xr.corr(pred, truth, dim=("x", "y"))

# -----------------------------
# Build DataFrame
# -----------------------------
stats_df = pd.DataFrame({
    "time": truth.time.values,
    "Bias (%)": bias_pct.values,
    "MAPE (%)": mape_pct.values,
    "Corr": corr.values
})

# Convert CFTime to pandas datetime if possible
try:
    stats_df["time"] = pd.to_datetime(stats_df["time"])
except Exception:
    pass

# -----------------------------
# Display and save
# -----------------------------
pd.set_option("display.precision", 3)
print(stats_df.iloc[:,1:])

stats_df.to_csv(f"corrdiff_stats_{var_name}_mape.csv", index=False)


    Bias (%)  MAPE (%)   Corr
0     -0.002     0.168  0.998
1     -0.016     0.164  0.997
2     -0.010     0.135  0.998
3     -0.014     0.119  0.994
4     -0.016     0.129  0.997
5     -0.035     0.165  0.995
6     -0.033     0.125  0.994
7     -0.034     0.135  0.996
8      0.010     0.139  0.998
9     -0.039     0.165  0.996
10    -0.030     0.153  0.996
11    -0.053     0.233  0.998


In [21]:
ds_truth

In [2]:
def checkpoint_list(path, suffix=".mdlus"):
    """Helper function to return sorted list, in ascending order, of checkpoints in a path"""
    checkpoints = []
    for file in os.listdir(path):
        if file.endswith(suffix):
            # Split the filename and extract the index
            try:
                index = int(file.split(".")[-2])
                checkpoints.append((index, file))
            except ValueError:
                continue

    # Sort by index and return filenames
    checkpoints.sort(key=lambda x: x[0])
    return [file for _, file in checkpoints]


# Define safe CUDA profiler tools that fallback to no-ops when CUDA is not available
def cuda_profiler():
    if torch.cuda.is_available():
        return torch.cuda.profiler.profile()
    else:
        return nullcontext()


def cuda_profiler_start():
    if torch.cuda.is_available():
        torch.cuda.profiler.start()


def cuda_profiler_stop():
    if torch.cuda.is_available():
        torch.cuda.profiler.stop()


def profiler_emit_nvtx():
    if torch.cuda.is_available():
        return torch.autograd.profiler.emit_nvtx()
    else:
        return nullcontext()



In [2]:
from omegaconf import OmegaConf
from hydra import initialize, compose
initialize(config_path="conf", version_base="1.2")
cfg = compose(config_name="config_training_era5_regression")

OmegaConf.resolve(cfg)
dataset_cfg = OmegaConf.to_container(cfg.dataset)  # TODO needs better handling




In [16]:
import xarray as xr
xr.open_dataset("/beegfs/muduchuru/data/mswx/pr/2024366.nc")

In [7]:
DistributedManager.initialize()
dist = DistributedManager()

  warn(


In [4]:

if hasattr(cfg, "validation"):
    validation = True
    validation_dataset_cfg = OmegaConf.to_container(cfg.validation)
else:
    validation = False
    validation_dataset_cfg = None
fp_optimizations = cfg.training.perf.fp_optimizations
songunet_checkpoint_level = cfg.training.perf.songunet_checkpoint_level
fp16 = fp_optimizations == "fp16"
enable_amp = fp_optimizations.startswith("amp")
amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16
# logger.info(f"Saving the outputs in {os.getcwd()}")
checkpoint_dir = get_checkpoint_dir(
    str(cfg.training.io.get("checkpoint_dir", ".")), cfg.model.name
)
if cfg.training.hp.batch_size_per_gpu == "auto":
    cfg.training.hp.batch_size_per_gpu = (
        cfg.training.hp.total_batch_size // dist.world_size
    )

# Load the current number of images for resuming
try:
    cur_nimg = load_checkpoint(
        path=checkpoint_dir,
    )
except Exception:
    cur_nimg = 0


NameError: name 'dist' is not defined

In [13]:
data_loader_kwargs = {
    "pin_memory": True,
    "num_workers": cfg.training.perf.dataloader_workers,
    "prefetch_factor": 2 if cfg.training.perf.dataloader_workers > 0 else None,
}
(
    dataset,
    dataset_iterator,
    validation_dataset,
    validation_dataset_iterator,
) = init_train_valid_datasets_from_config(
    dataset_cfg,
    data_loader_kwargs,
    batch_size=16,
    seed=0,
    validation_dataset_cfg=validation_dataset_cfg,
    validation=False,
    sampler_start_idx=0,
)

In [14]:
validation_dataset.time()

AttributeError: 'NoneType' object has no attribute 'time'

In [12]:
import xarray as xr
data_path = "./data/hrrr_mini/hrrr_mini_train.nc"  # <- change this
ds_inv = xr.open_dataset(data_path, group="invariant")

In [36]:
coords = xr.open_dataset(data_path).coords

In [15]:
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np

# === USER SETTINGS ===
data_path = "./data/hrrr_mini/hrrr_mini_train.nc"  # <- change this

# --- Load input and output groups directly ---
with xr.open_dataset(data_path, group="input") as ds_in, xr.open_dataset(data_path, group="output") as ds_out:
    print("Input variables:", list(ds_in.keys()))
    print("Output variables:", list(ds_out.keys()))

    # Pick a representative variable (first one)
    var_in = ["t2m"]
    var_out = ["2t"]

    # Select a single timestep (if time dimension exists)
    if "time" in ds_in.dims:
        da_in = ds_in[var_in].isel(time=0)
    else:
        da_in = ds_in[var_in]

    if "time" in ds_out.dims:
        da_out = ds_out[var_out].isel(time=0)
    else:
        da_out = ds_out[var_out]

    # Extract coordinate variables (names may differ depending on file)
    # Try to detect them automatically
    x_in = ds_in.coords.get("x") or ds_in.coords.get("lon") or np.arange(da_in.shape[-1])
    y_in = ds_in.coords.get("y") or ds_in.coords.get("lat") or np.arange(da_in.shape[-2])

    x_out = ds_out.coords.get("x") or ds_out.coords.get("lon") or np.arange(da_out.shape[-1])
    y_out = ds_out.coords.get("y") or ds_out.coords.get("lat") or np.arange(da_out.shape[-2])

    print(f"Input grid shape:  {da_in.shape}, x: {len(x_in)}, y: {len(y_in)}")
    print(f"Output grid shape: {da_out.shape}, x: {len(x_out)}, y: {len(y_out)}")

# --- Plot to visually compare ---
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Plot low-res input
im0 = axes[0].imshow(da_in[0].values, origin="lower", extent=[x_in[0], x_in[-1], y_in[0], y_in[-1]])
axes[0].set_title(f"Input ({var_in})")
plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)

# Plot high-res output
im1 = axes[1].imshow(da_out[0].values, origin="lower", extent=[x_out[0], x_out[-1], y_out[0], y_out[-1]])
axes[1].set_title(f"Output ({var_out})")
plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

# --- Overlay check (optional) ---
plt.figure(figsize=(6, 6))
plt.contour(y_out, x_out, da_out[0].values, levels=8, colors='r', linewidths=1.0)
plt.contour(y_in, x_in, da_in[0].values, levels=8, colors='k', linewidths=1.0)
plt.title("Overlay of Input (black) and Output (red) Grids")
plt.xlabel("X or Lon")
plt.ylabel("Y or Lat")
plt.show()


Input variables: ['u10m', 'v10m', 't2m', 'tcwv', 'sp', 'msl', 'u1000', 'u850', 'u500', 'u250', 'v1000', 'v850', 'v500', 'v250', 'z1000', 'z850', 'z500', 'z250', 't1000', 't850', 't500', 't250', 'q1000', 'q850', 'q500', 'q250']
Output variables: ['2t', '10u', '10v', 'tp']


AttributeError: 'Dataset' object has no attribute 'shape'

In [3]:
import glob
import xarray as xr
import pandas as pd

tru = xr.open_dataset('corrdiff_output.nc',group='truth')
pred = xr.open_dataset('corrdiff_output.nc',group='prediction')
input = xr.open_dataset('corrdiff_output.nc',group='input')

In [114]:
ds_day['t500']

In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import os
import glob
from concurrent.futures import ProcessPoolExecutor, as_completed

data_path = "/data01/FDS/muduchuru/ERA5/europe/ERA5"

era5_surface_channels = [
    "pr",
    "tas",
    "tasmax",
    "tasmin",
]
era5_isobaric_channels = [
    "vo850",
    "vo700",
    "vo500",
    "vo200",
    "t850",
    "t700",
    "t500",
    "t200",
    "r850",
    "r700",
    "r500",
    "r200",
    "q850",
    "q700",
    "q500",
    "q200",
]

In [7]:
import os
import glob
import re
import xarray as xr
import pandas as pd


def extract_dates(root="/beegfs/muduchuru/data/era5"):
    """Extract all available dates for each variable and return their intersection."""
    pattern = re.compile(r"_(\d{4}-\d{2}-\d{2})_")

    all_dates_per_var = []
    for sub in [os.path.join(root, "sf", "*"), os.path.join(root, "pl", "*")]:
        for var_dir in glob.glob(sub):
            files = glob.glob(os.path.join(var_dir, "*.nc"))
            dates = set()
            for f in files:
                m = pattern.search(os.path.basename(f))
                if m:
                    dates.add(m.group(1))
            if dates:
                all_dates_per_var.append(dates)

    common_dates = sorted(set.intersection(*all_dates_per_var))
    print(f"✅ Found {len(common_dates)} common dates across all variables.")
    return common_dates


def combine_era5_channels(date_str, root="/beegfs/muduchuru/data/era5"):
    """Combine all ERA5 variables for a given date into one dataset with a single variable 'image'."""
    subdirs = [os.path.join(root, "sf", "*"), os.path.join(root, "pl", "*")]
    files = []
    for sub in subdirs:
        files.extend(glob.glob(os.path.join(sub, f"*_{date_str}_*.nc")))

    if not files:
        raise FileNotFoundError(f"No ERA5 files found for {date_str}")

    datasets = []
    channel_names = []

    for f in sorted(files):
        try:
            ds = xr.open_dataset(f)
            var = list(ds.data_vars.keys())[0]
            da = ds[var].squeeze(drop=True)

            # Drop plev if present (we already encode it in the var name)
            if "plev" in da.coords:
                da = da.drop_vars("plev")

            da = da.expand_dims("channel").assign_coords(channel=[var])
            datasets.append(da)
            channel_names.append(var)
        except Exception as e:
            print(f"⚠️ Skipping {f}: {e}")

    # Combine into single 3D array
    image = xr.concat(datasets, dim="channel", coords="minimal", compat="override")
    image = image.sortby("lat", ascending=False)
    image.attrs["variables"] = ",".join(channel_names)

    # Add time coordinate (scalar)
    time_val = pd.to_datetime(date_str)
    image = image.expand_dims("time").assign_coords(time=[time_val])

    # Wrap into Dataset with one variable: 'image'
    ds_out = xr.Dataset({"image": image})
    ds_out["channel"] = image["channel"]

    print(f"✅ Combined {len(channel_names)} vars for {date_str}")
    return ds_out


def combine_all_common_dates(root="/beegfs/muduchuru/data/era5", out_dir=None):
    """Loop through all common dates and save combined files as E5pl00_1D_<date>_<nvar>var.nc"""
    if out_dir is None:
        out_dir = os.path.join(root, "combined")
    os.makedirs(out_dir, exist_ok=True)

    common_dates = extract_dates(root)

    for date_str in common_dates:
        try:
            out_filename = f"E5pl00_1D_{date_str}_{20}var.nc"
            out_path = os.path.join(out_dir, out_filename)
            if not os.path.exists(out_path):    
                ds_out = combine_era5_channels(date_str, root)
            
                ds_out.to_netcdf(out_path)
                print(f"💾 Saved {out_path}")
            else:
                print(f"⏭️ Skipped (already exists): {out_path}")
        except Exception as e:
            print(f"❌ Skipping {date_str}: {e}")


# === Run ===
combine_all_common_dates("/beegfs/muduchuru/data/era5")


✅ Found 9497 common dates across all variables.
⏭️ Skipped (already exists): /beegfs/muduchuru/data/era5/combined/E5pl00_1D_1999-01-01_20var.nc
⏭️ Skipped (already exists): /beegfs/muduchuru/data/era5/combined/E5pl00_1D_1999-01-02_20var.nc
⏭️ Skipped (already exists): /beegfs/muduchuru/data/era5/combined/E5pl00_1D_1999-01-03_20var.nc
⏭️ Skipped (already exists): /beegfs/muduchuru/data/era5/combined/E5pl00_1D_1999-01-04_20var.nc
⏭️ Skipped (already exists): /beegfs/muduchuru/data/era5/combined/E5pl00_1D_1999-01-05_20var.nc
⏭️ Skipped (already exists): /beegfs/muduchuru/data/era5/combined/E5pl00_1D_1999-01-06_20var.nc
⏭️ Skipped (already exists): /beegfs/muduchuru/data/era5/combined/E5pl00_1D_1999-01-07_20var.nc
⏭️ Skipped (already exists): /beegfs/muduchuru/data/era5/combined/E5pl00_1D_1999-01-08_20var.nc
⏭️ Skipped (already exists): /beegfs/muduchuru/data/era5/combined/E5pl00_1D_1999-01-09_20var.nc
⏭️ Skipped (already exists): /beegfs/muduchuru/data/era5/combined/E5pl00_1D_1999-01-10_2

In [2]:
import xarray as xr
ds = xr.open_dataset("/beegfs/muduchuru/data/era5/combined/E5pl00_1D_2009-04-17_20var.nc")

module 'pycparser' has no attribute 'c_ast'
  external_backend_entrypoints = backends_dict_from_pkg(entrypoints_unique)


In [3]:
ds