In [None]:
import glob, re, os
import matplotlib as plt
import numpy as np
import xarray as xr
import pandas as pd
import threading
import OpenVisus as ov
from datetime import datetime, timedelta
from tqdm import tqdm
from dask import delayed
from concurrent.futures import ThreadPoolExecutor, as_completed





In [None]:
# !pip install dask_jobqueue

In [None]:
import dask 
from dask_jobqueue import PBSCluster
from dask.distributed import Client
from dask.distributed import performance_report

In [None]:
# ---- Filesystem locations ----
# Root directory under which the Dask spill and log folders are placed.
lustre_scratch = "/glade/work/dpanta"
# Location of the e5.oper.an.sfc data (not referenced by the cells visible here).
era5_surface_data = "/gdex/data/special_projects/harshah/ARCO/e5.oper.an.sfc"

In [None]:
BASE_YEAR = 1940  # leap days are counted from this year onward


def _is_leap(y: int) -> bool:
    """Return True when ``y`` is a leap year under the Gregorian rule."""
    return (y % 4 == 0) and (y % 100 != 0 or y % 400 == 0)


def _leap_days_since_base(year: int) -> int:
    """Count the leap years in ``[BASE_YEAR, year)``; 0 when ``year <= BASE_YEAR``."""
    return sum(_is_leap(y) for y in range(BASE_YEAR, year))


def hour_index_from_iso_big(iso: str) -> int:
    """Convert a UTC date/hour string into the integer hour index used by this notebook.

    Accepted spellings:
      * ``YYYY-MM-DD``            -> midnight of that day
      * ``YYYY-MM-DD HH``, ``YYYY-MM-DD-HH``, ``YYYY-MM-DDTHH`` -> that hour
      * anything else is handed to the original lenient parsing branches.

    Returns:
        ``year*365*24 + leap_days*24 + hours_since_Jan_1 + 1`` where
        ``leap_days`` is the number of leap years in ``[BASE_YEAR, year)``.
        NOTE(review): presumably this matches the timestep numbering of the
        IDX dataset read further down -- confirm against the writer.

    Raises:
        ValueError (from pandas) if the string cannot be parsed.
    """
    text = iso.strip()

    # Explicit path for hyphenated dates. Previously these inputs skipped the
    # 1- and 2-part branches (hyphens were turned into spaces before splitting)
    # and were rewritten to a non-ISO string such as "1949T01 01 00", which only
    # parsed thanks to pandas' lenient fallback.
    match = re.fullmatch(r"(\d{4}-\d{2}-\d{2})(?:[ T-](\d{1,2}))?", text)
    if match:
        date = match.group(1)
        hour = int(match.group(2) or 0)
        ts = pd.to_datetime(f"{date} {hour:02d}:00:00", utc=True)
    else:
        # Legacy behaviour, kept unchanged for every other spelling.
        parts = text.replace("-", " ").split()
        if len(parts) == 2:
            date, hh = parts
            ts = pd.to_datetime(f"{date} {hh}:00:00", utc=True)
        elif len(parts) == 1:
            ts = pd.to_datetime(parts[0] + " 00:00:00", utc=True)
        else:
            ts = pd.to_datetime(iso.replace("-", " ", 3).replace(" ", "T", 1), utc=True)

    y = ts.year
    year_start = pd.Timestamp(year=y, month=1, day=1, tz="UTC")
    # Whole hours elapsed since 00:00 UTC on 1 January; any minutes are truncated.
    hours_since_year_start = int((ts - year_start) / pd.Timedelta(hours=1))
    leap_days = _leap_days_since_base(y)
    return y*365*24 + leap_days*24 + hours_since_year_start + 1


In [None]:
# One single-core, single-process worker per PBS job on the 'casper' queue.
# Spill files and job logs both go under <lustre_scratch>/dask/.
cluster = PBSCluster(
    job_name="dask-osdf-25",
    queue="casper",
    account="P43713000",
    walltime="3:00:00",
    cores=1,
    processes=1,
    memory="4GiB",
    resource_spec="select=1:ncpus=1:mem=4GB",
    local_directory=f"{lustre_scratch}/dask/spill",
    log_directory=f"{lustre_scratch}/dask/logs/",
    # 'ib0' was the alternative interface tried here; 'ext' is the one in use.
    interface="ext",
)

In [None]:
# Number of PBS worker jobs to request.
n_workers = 5

# Attach a client, ask for the workers, and block until all of them have joined.
client = Client(cluster)
cluster.scale(n_workers)
client.wait_for_workers(n_workers=n_workers)

# Last expression: render the cluster widget (includes the dashboard URL).
cluster

In [None]:
# NOTE(review): this cell repeats definitions already made in an earlier cell of
# this notebook; being later, this copy is the one in effect after "Run All".
# Consider keeping a single copy.
BASE_YEAR = 1940  # leap days are counted from this year onward


def _is_leap(y: int) -> bool:
    """Return True when ``y`` is a leap year under the Gregorian rule."""
    return (y % 4 == 0) and (y % 100 != 0 or y % 400 == 0)


def _leap_days_since_base(year: int) -> int:
    """Count the leap years in ``[BASE_YEAR, year)``; 0 when ``year <= BASE_YEAR``."""
    return sum(_is_leap(y) for y in range(BASE_YEAR, year))


def hour_index_from_iso_big(iso: str) -> int:
    """Convert a UTC date/hour string into the integer hour index used by this notebook.

    Accepted spellings:
      * ``YYYY-MM-DD``            -> midnight of that day
      * ``YYYY-MM-DD HH``, ``YYYY-MM-DD-HH``, ``YYYY-MM-DDTHH`` -> that hour
      * anything else is handed to the original lenient parsing branches.

    Returns:
        ``year*365*24 + leap_days*24 + hours_since_Jan_1 + 1`` where
        ``leap_days`` is the number of leap years in ``[BASE_YEAR, year)``.
        NOTE(review): presumably this matches the timestep numbering of the
        IDX dataset read further down -- confirm against the writer.

    Raises:
        ValueError (from pandas) if the string cannot be parsed.
    """
    text = iso.strip()

    # Explicit path for hyphenated dates. Previously these inputs skipped the
    # 1- and 2-part branches (hyphens were turned into spaces before splitting)
    # and were rewritten to a non-ISO string such as "1949T01 01 00", which only
    # parsed thanks to pandas' lenient fallback.
    match = re.fullmatch(r"(\d{4}-\d{2}-\d{2})(?:[ T-](\d{1,2}))?", text)
    if match:
        date = match.group(1)
        hour = int(match.group(2) or 0)
        ts = pd.to_datetime(f"{date} {hour:02d}:00:00", utc=True)
    else:
        # Legacy behaviour, kept unchanged for every other spelling.
        parts = text.replace("-", " ").split()
        if len(parts) == 2:
            date, hh = parts
            ts = pd.to_datetime(f"{date} {hh}:00:00", utc=True)
        elif len(parts) == 1:
            ts = pd.to_datetime(parts[0] + " 00:00:00", utc=True)
        else:
            ts = pd.to_datetime(iso.replace("-", " ", 3).replace(" ", "T", 1), utc=True)

    y = ts.year
    year_start = pd.Timestamp(year=y, month=1, day=1, tz="UTC")
    # Whole hours elapsed since 00:00 UTC on 1 January; any minutes are truncated.
    hours_since_year_start = int((ts - year_start) / pd.Timedelta(hours=1))
    leap_days = _leap_days_since_base(y)
    return y*365*24 + leap_days*24 + hours_since_year_start + 1




In [None]:
# Dataset opened by _get_db_and_access().
IDX_PATH = "/glade/work/dpanta/era5/idx/2T/era5_sfc_2T_zip.idx"

# Grid size: H rows spanning +90 .. -90 degrees, W columns.
H = 721
W = 1440
lat = np.linspace(90.0, -90.0, H, dtype=np.float32)

# One cos(latitude) weight per row, and the total weight of a grid with no gaps.
W1D = np.cos(np.deg2rad(lat)).astype(np.float32)
FULL_DEN = float(W1D.sum() * W)

# Filled in on first use by _get_db_and_access().
_DB = None
_ACCESS = None

def _get_db_and_access():
    """Return ``(_DB, _ACCESS)``, loading the dataset on the first call.

    The dataset at ``IDX_PATH`` is loaded once and cached in the module
    globals. ``_ACCESS`` is the result of ``_DB.createAccess()``, or ``None``
    if that call raised.
    """
    global _DB, _ACCESS
    if _DB is not None:
        # Already initialised in this process.
        return _DB, _ACCESS

    _DB = ov.LoadDataset(IDX_PATH)
    try:
        _ACCESS = _DB.createAccess()
    except Exception:
        # Best effort: callers fall back to reading without an access object.
        _ACCESS = None
    return _DB, _ACCESS

def weighted_global_mean_fast(a: np.ndarray, weights=None) -> float:
    """Row-weighted mean of a 2-D field, ignoring non-finite cells.

    Args:
        a: 2-D array; axis 0 must have the same length as the weight vector.
        weights: optional 1-D per-row weights. When omitted, the module-level
            ``W1D`` / ``FULL_DEN`` (built for the H x W grid) are used.

    Returns:
        ``sum_i w_i * sum_j a_ij / sum_i w_i * n_valid_i`` as a Python float,
        or NaN when no cell is finite.

    NaN and +/-inf cells are both excluded from numerator and denominator.
    (Previously only NaN was skipped in the numerator while the count used
    ``isfinite``, so a single inf produced an inf/NaN mean.)
    """
    a = np.asarray(a, dtype=np.float32)
    if weights is None:
        w = W1D
        full_den = FULL_DEN
    else:
        w = np.asarray(weights, dtype=np.float32)
        full_den = float(w.sum() * a.shape[1])

    finite = np.isfinite(a)
    if finite.all():
        # Fast path: every row has the full column count, so the denominator is constant.
        row_sum = a.sum(axis=1, dtype=np.float32)
        return float(np.dot(w, row_sum) / full_den)

    # Zero out every non-finite cell so it contributes to neither sum nor count.
    row_sum = np.where(finite, a, np.float32(0.0)).sum(axis=1, dtype=np.float32)
    row_cnt = finite.sum(axis=1, dtype=np.float32)
    den = float(np.dot(w, row_cnt))
    return float(np.dot(w, row_sum) / den) if den > 0 else np.nan

def gmst_hours_block_threaded(t0: int, n_hours: int, max_threads: int = 4) -> list[float]:
    """Read ``n_hours`` consecutive timesteps starting at ``t0`` and reduce each to one value.

    Each timestep is read with ``db.read`` and passed to
    ``weighted_global_mean_fast``; reads are issued from a thread pool of
    ``max_threads`` workers. The returned list is ordered by timestep
    (element ``k`` belongs to ``t0 + k``).

    NOTE(review): every pool thread uses the same access object here; a later
    cell in this notebook switches to one handle per thread -- confirm whether
    sharing is safe for this backend.
    """
    db, access = _get_db_and_access()

    def read_reduce(t: int) -> float:
        # Pass the access object only when one could be created.
        if access is None:
            frame = db.read(time=t)
        else:
            frame = db.read(time=t, access=access)
        return weighted_global_mean_fast(frame)

    results = [None] * n_hours
    with ThreadPoolExecutor(max_workers=max_threads) as pool:
        # Map each future back to its slot so completion order does not matter.
        slot_of = {}
        for k in range(n_hours):
            slot_of[pool.submit(read_reduce, t0 + k)] = k
        for done in as_completed(slot_of):
            results[slot_of[done]] = done.result()
    return results

# ---- Split one calendar year into fixed-size blocks of hours, one delayed task each ----
YEAR = 1949
start_idx = hour_index_from_iso_big(f"{YEAR}-01-01 00")
end_idx = hour_index_from_iso_big(f"{YEAR+1}-01-01 00")
total_hours = end_idx - start_idx

BLOCK_HOURS = 24 * 30  # hours handled by a single task
n_blocks = -(-total_hours // BLOCK_HOURS)  # ceiling division

tasks = []
for block_offset in range(0, total_hours, BLOCK_HOURS):
    # The last block is shorter when total_hours is not a multiple of BLOCK_HOURS.
    block_len = min(BLOCK_HOURS, total_hours - block_offset)
    tasks.append(
        delayed(gmst_hours_block_threaded)(start_idx + block_offset, block_len, max_threads=4)
    )


In [None]:
# %%timeit -n 2 -r 3
# Run every block task, then flatten the per-block lists into one float32 vector.
blocks = dask.compute(*tasks)
hourly_values = (value for chunk in blocks for value in chunk)
gmst_vals = np.fromiter(hourly_values, dtype=np.float32, count=total_hours)
print("Annual GMST", YEAR, ":", float(np.nanmean(gmst_vals)))

In [None]:
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import numpy as np, dask
from dask import delayed
import OpenVisus as ov

# Dataset opened by _get_db_and_access().
IDX_PATH = "/glade/work/dpanta/era5/idx/2T/era5_sfc_2T_zip.idx"
# Value passed as `quality=` to every db.read call below (half-res).
QUALITY = -2

# Lazily initialised state: dataset handle / access object ...
_DB = None
_ACCESS = None
# ... and the row weights / full-grid denominator set by _init_weights_for_quality().
_W1D = None
_FULL_DEN = None

def _get_db_and_access():
    """Return ``(_DB, _ACCESS)``, loading the dataset at ``IDX_PATH`` on first use.

    ``_ACCESS`` is whatever ``_DB.createAccess()`` returned, or ``None`` if it
    raised. The threaded reader below ignores this process-wide handle and
    uses per-thread handles instead.
    """
    global _DB, _ACCESS
    if _DB is not None:
        return _DB, _ACCESS

    _DB = ov.LoadDataset(IDX_PATH)
    try:
        _ACCESS = _DB.createAccess()
    except Exception:
        # Best effort: leave the handle unset rather than fail the load.
        _ACCESS = None
    return _DB, _ACCESS

def _init_weights_for_quality(example_time: int):
    """Populate ``_W1D`` and ``_FULL_DEN`` from the shape of one frame read at ``QUALITY``.

    Does nothing if both globals are already set. Otherwise reads timestep
    ``example_time``, takes the last two dimensions of the result as
    (rows, columns), and builds cos(latitude) row weights for rows evenly
    spaced from +90 to -90 degrees.
    """
    global _W1D, _FULL_DEN
    if not (_W1D is None or _FULL_DEN is None):
        return

    db, acc = _get_db_and_access()
    if acc is None:
        a0 = db.read(time=example_time, quality=QUALITY)
    else:
        a0 = db.read(time=example_time, quality=QUALITY, access=acc)

    n_rows = a0.shape[-2]
    n_cols = a0.shape[-1]
    row_lat = np.linspace(90.0, -90.0, n_rows, dtype=np.float32)
    _W1D = np.cos(np.deg2rad(row_lat)).astype(np.float32)
    # Total weight of a frame in which every cell is valid.
    _FULL_DEN = float(_W1D.sum() * n_cols)

def weighted_global_mean_fast(a: np.ndarray, weights=None) -> float:
    """Row-weighted mean of a 2-D field, ignoring non-finite cells.

    Args:
        a: 2-D array; axis 0 must have the same length as the weight vector.
        weights: optional 1-D per-row weights. When omitted, the module-level
            ``_W1D`` / ``_FULL_DEN`` are used; they must have been set by
            ``_init_weights_for_quality`` beforehand.

    Returns:
        The weighted mean as a Python float, or NaN when no cell is finite.

    NaN and +/-inf cells are both excluded from numerator and denominator.
    (Previously only NaN was skipped in the numerator while the count used
    ``isfinite``, so a single inf produced an inf/NaN mean.)
    """
    a = np.asarray(a, dtype=np.float32)
    if weights is None:
        w = _W1D
        full_den = _FULL_DEN
    else:
        w = np.asarray(weights, dtype=np.float32)
        full_den = float(w.sum() * a.shape[1])

    finite = np.isfinite(a)
    if finite.all():
        # Fast path: every row has the full column count, so the denominator is constant.
        row_sum = a.sum(axis=1, dtype=np.float32)
        return float(np.dot(w, row_sum) / full_den)

    # Zero out every non-finite cell so it contributes to neither sum nor count.
    row_sum = np.where(finite, a, np.float32(0.0)).sum(axis=1, dtype=np.float32)
    row_cnt = finite.sum(axis=1, dtype=np.float32)
    den = float(np.dot(w, row_cnt))
    return float(np.dot(w, row_sum) / den) if den > 0 else np.nan

# ---- one access handle per thread, so concurrent reads do not share a handle ----
# NOTE(review): threading.local objects cannot be pickled; if these functions are
# shipped by value to Dask workers this module-level object may get in the way -- confirm.
_tlocal = threading.local()


def _get_thread_access(db):
    """Return the calling thread's access handle for ``db``, creating it on first use.

    The handle is stored on ``_tlocal``. If ``db.createAccess()`` raises,
    ``None`` is stored and returned (and creation is attempted again on the
    next call, since ``None`` is indistinguishable from "not yet created").
    """
    try:
        cached = _tlocal.access
    except AttributeError:
        cached = None
    if cached is not None:
        return cached

    try:
        fresh = db.createAccess()
    except Exception:
        fresh = None
    _tlocal.access = fresh
    return fresh

def gmst_hours_block_threaded(t0: int, n_hours: int, max_threads: int = 4) -> list[float]:
    """Read ``n_hours`` consecutive timesteps from ``t0`` at ``QUALITY`` and reduce each to one value.

    Reads run on a pool of ``max_threads`` threads, each using its own access
    handle from ``_get_thread_access``. Element ``k`` of the returned list is
    the ``weighted_global_mean_fast`` of timestep ``t0 + k``.
    """
    db, _ = _get_db_and_access()

    def read_reduce(t: int) -> float:
        acc = _get_thread_access(db)
        if acc is None:
            frame = db.read(time=t, quality=QUALITY)
        else:
            frame = db.read(time=t, quality=QUALITY, access=acc)
        return weighted_global_mean_fast(frame)

    results = [None] * n_hours
    with ThreadPoolExecutor(max_workers=max_threads) as pool:
        # Remember which output slot each future fills; completion order is arbitrary.
        slot_of = {}
        for k in range(n_hours):
            slot_of[pool.submit(read_reduce, t0 + k)] = k
        for done in as_completed(slot_of):
            results[slot_of[done]] = done.result()
    return results

# ---- Annual mean for one year at QUALITY resolution ----
YEAR = 1949
start_idx = hour_index_from_iso_big(f"{YEAR}-01-01 00")
end_idx = hour_index_from_iso_big(f"{YEAR+1}-01-01 00")
total_hours = end_idx - start_idx

# Size the row weights from an actual frame before any task is built.
_init_weights_for_quality(start_idx)

BLOCK_HOURS = 24 * 60  # hours handled by a single task
n_blocks = -(-total_hours // BLOCK_HOURS)  # ceiling division

tasks = []
for block_offset in range(0, total_hours, BLOCK_HOURS):
    # The last block is shorter when total_hours is not a multiple of BLOCK_HOURS.
    block_len = min(BLOCK_HOURS, total_hours - block_offset)
    # 4–6 threads works well
    tasks.append(
        delayed(gmst_hours_block_threaded)(start_idx + block_offset, block_len, max_threads=4)
    )

# Run every block task, then flatten the per-block lists into one float32 vector.
blocks = dask.compute(*tasks)
hourly_values = (value for chunk in blocks for value in chunk)
gmst_vals = np.fromiter(hourly_values, dtype=np.float32, count=total_hours)
print("Annual GMST", YEAR, ":", float(np.nanmean(gmst_vals)))


## GMST functions

## Load data and compute GMST

In [None]:
# import numpy as np
# import pandas as pd
# import dask
# from dask import delayed
# import OpenVisus as ov

# IDX_PATH = "/glade/work/dpanta/era5/idx/2T/era5_sfc_2T_zip.idx"
# ov.LoadDataset(IDX_PATH)
# H, W = 721, 1440

# lat = np.linspace(90.0, -90.0, H, dtype=np.float64)
# W1D = np.cos(np.deg2rad(lat)).astype(np.float64)
# W2D = np.repeat(W1D[:, None], W, axis=1)  # shape (H, W)

# _DB = None
# def _get_db():
#     global _DB
#     if _DB is None:
#         _DB = ov.LoadDataset(IDX_PATH)
#     return _DB

# def weighted_global_mean(arr2d: np.ndarray) -> float:
#     a = np.asarray(arr2d, dtype=np.float64)
#     bad = (~np.isfinite(a))
#     if bad.any():
#         a = a.copy()
#         a[bad] = np.nan
#     valid = np.isfinite(a)
#     num = np.nansum(a[valid] * W2D[valid])
#     den = np.sum(W2D[valid])
#     return float(num / den) if den > 0 else np.nan

# def gmst_hours_block(t_start_inclusive: int, n_hours: int) -> list[float]:
#     db = _get_db()
#     out = []
#     for k in range(n_hours):
#         a = db.read(time=t_start_inclusive + k)
#         out.append(weighted_global_mean(a))
#     return out

# YEAR = 1945
# start_idx = hour_index_from_iso_big(f"{YEAR}-01-01 00")
# end_idx   = hour_index_from_iso_big(f"{YEAR+1}-01-01 00")
# total_hours = end_idx - start_idx

# time_index_dt = pd.date_range(f"{YEAR}-01-01 00:00:00", periods=total_hours, freq="H", tz="UTC")

# tasks = [delayed(gmst_hours_block)(start_idx + d*24, 24) for d in range(total_hours // 24)]
# blocks = dask.compute(*tasks)
# gmst_vals = np.fromiter((v for block in blocks for v in block), dtype=np.float64, count=total_hours)

# print("Annual GMST 1950:", np.nanmean(gmst_vals))


In [None]:
# Disconnect the client, then shut the cluster down so the PBS worker jobs
# are released instead of idling until their walltime expires.
client.close()
cluster.close()