In [41]:
import json
import warnings
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional

import duckdb
import numpy as np
import polars as pl

from cpom.gias.gia import GIA
from cpom.gridding.gridareas import GridArea

In [42]:
# Directory of the compacted gridded altimetry; grid_meta.json is read from here.
griddir = "/home/willisc3/luna/CPOM/willisc3/EOCIS/altimetry/landice/gridded_altimetry/cs2/greenland_5km_cryotempo_li/compacted/"
# Alternative input kept for reference:
# griddir = "/home/willisc3/luna/CPOM/maddalen/lancs_mass_balance/gris/cs2_gr/surfacefit_latest"
# (The parquet glob is defined once, from params.surface_fit_dir, after Parser is set up.)

In [46]:
# Load the gridded-altimetry metadata. Joining with Path means this works whether or
# not griddir ends in "/". NOTE: json_data is reassigned when surface_fit_meta.json
# is loaded further down.
with open(Path(griddir) / "grid_meta.json", "r") as f:
    json_data = json.load(f)

@dataclass
class Parser:
    """Run configuration for this notebook, standing in for command-line arguments.

    Fields are positional in the order below, so existing ``Parser(...)`` calls
    keep working. The dataclass does not validate types; annotations document
    the values actually passed.
    """

    griddir: str  # gridded altimetry directory (presumably the compacted grid; not read below)
    surface_fit_dir: str  # holds x_part=* parquet partitions and surface_fit_meta.json
    outdir: str  # output directory (not used in the visible cells)
    epoch_length: int  # epoch length in days (used as timedelta(days=...))
    epoch_start: Optional[str]  # first epoch start, "YYYY/MM/DD" or "YYYY.MM.DD"; checked against None
    epoch_end: Optional[str]  # end of the epoch window, same formats as epoch_start
    gia_model: Optional[str]  # GIA model name, e.g. "ice5g"; None skips the GIA correction

# Notebook stand-in for the command-line arguments of the epoch-averaging step.
params = Parser(
    griddir="/home/willisc3/luna/CPOM/willisc3/EOCIS/altimetry/landice/gridded_altimetry/cs2/greenland_5km_cryotempo_li/compacted",
    surface_fit_dir="/home/willisc3/luna/CPOM/willisc3/EOCIS/altimetry/landice/surface_fit/cs2/greenland_5km_cryotempo_li",
    outdir="",
    epoch_length=140,
    epoch_start="2010/07/20",
    epoch_end="2025/07/18",
    gia_model="ice5g",
)

# Surface-fit parquet files, split into x_part=* subdirectories; all downstream reads use this glob.
parquet_glob = f"{params.surface_fit_dir}/x_part=*/**/*.parquet"

In [47]:
def get_min_max_time(epoch_time, parquet_glob):
    """
    Get the earliest and latest sample times in a set of parquet files.

    The ``time`` column is scanned lazily with polars and interpreted as
    seconds elapsed since ``epoch_time``.

    Args:
        epoch_time (str | datetime): Reference epoch. Strings are parsed with
            ``datetime.fromisoformat`` (e.g. "1991-01-01T00:00:00").
        parquet_glob (str): Directory glob matching the parquet files.

    Returns:
        tuple: ``(min_dt, max_dt, min_secs, max_secs)`` -- the earliest/latest
        times as datetimes, and as the raw seconds-since-epoch values.

    Raises:
        ValueError: If the files contain no non-null ``time`` values.
    """
    # Accept an already-parsed datetime as well as an ISO string.
    if not isinstance(epoch_time, datetime):
        epoch_time = datetime.fromisoformat(epoch_time)

    min_max = (
        pl.scan_parquet(parquet_glob)
        .select(
            pl.col("time").min().alias("min_time"),
            pl.col("time").max().alias("max_time"),
        )
        .collect()
    )
    min_secs = min_max["min_time"][0]
    max_secs = min_max["max_time"][0]

    # min/max over no (non-null) rows is null; fail clearly instead of inside timedelta().
    if min_secs is None or max_secs is None:
        raise ValueError(f"No non-null 'time' values found in {parquet_glob}")

    min_dt = epoch_time + timedelta(seconds=min_secs)
    max_dt = epoch_time + timedelta(seconds=max_secs)
    return min_dt, max_dt, min_secs, max_secs


In [48]:
from pathlib import Path 
import json 

# Surface-fit metadata written alongside the parquet partitions.
json_metadata_path = Path(params.surface_fit_dir) / "surface_fit_meta.json"
try:
    with open(json_metadata_path, "r") as f_meta:
        json_data = json.load(f_meta)
except OSError as exc:
    # Best-effort as before, but no longer silent: json_data keeps whatever an
    # earlier cell assigned (the grid metadata), which is easy to miss.
    warnings.warn(
        f"Could not read {json_metadata_path} ({exc}); "
        "continuing with the previously loaded json_data",
        stacklevel=1,
    )

# Force the standard epoch, overriding any value from the metadata file.
# NOTE(review): confirm this matches the epoch the `time` column was written against.
json_data['standard_epoch'] = "1991-01-01T00:00:00"
epoch_time = datetime.fromisoformat(json_data['standard_epoch'])

def get_date(epoch_time, timedt=None):
    """
    Parse a configured date and express it relative to ``epoch_time``.

    Args:
        epoch_time (datetime): Reference epoch.
        timedt (str | datetime | None): Date as "YYYY/MM/DD" or "YYYY.MM.DD",
            or an already-parsed datetime. None (the default) yields
            ``(None, None)``.

    Returns:
        tuple: ``(time_dt, seconds)`` -- the parsed datetime and the float number
        of seconds from ``epoch_time`` to it, or ``(None, None)`` when
        ``timedt`` is None.

    Raises:
        ValueError: If ``timedt`` is a string in neither supported format.
    """
    # Previously a None timedt fell through to an UnboundLocalError on return.
    if timedt is None:
        return None, None

    if isinstance(timedt, datetime):
        time_dt = timedt
    elif "/" in timedt:
        time_dt = datetime.strptime(timedt, "%Y/%m/%d")
    elif "." in timedt:
        time_dt = datetime.strptime(timedt, "%Y.%m.%d")
    else:
        raise ValueError(
            f"Unrecognized date format: {timedt}, pass as YYYY/MM/DD or YYYY.MM.DD "
        )
    seconds = (time_dt - epoch_time).total_seconds()

    return time_dt, seconds

# Time extent of the surface-fit data, as datetimes and as seconds since the standard epoch.
dtstart, dtend, data_min_secs, data_max_secs = get_min_max_time(
    epoch_time=json_data['standard_epoch'], parquet_glob=parquet_glob
)

# Epoch window: use each configured date when given, otherwise fall back to the
# data extent. Previously a missing date left epoch_start_dt / epoch_end_dt
# undefined and the subtraction below raised NameError.
if params.epoch_start is not None:
    epoch_start_dt, epoch_min_secs = get_date(epoch_time, params.epoch_start)
else:
    epoch_start_dt, epoch_min_secs = dtstart, data_min_secs

if params.epoch_end is not None:
    epoch_end_dt, epoch_max_secs = get_date(epoch_time, params.epoch_end)
else:
    epoch_end_dt, epoch_max_secs = dtend, data_max_secs

epoch_length = float(params.epoch_length)  # days
epoch_period = epoch_end_dt - epoch_start_dt
num_days_in_period = epoch_period.days
# Counts complete epochs only (floor). The `epochs` table is generated with
# closed="left", so it can contain one more, partial, epoch than this.
num_epochs = int(num_days_in_period / epoch_length)

# Grid whose cell centres the GIA uplift is interpolated onto.
grid = GridArea("greenland", 5000) # TODO : get this from json_data
# np.meshgrid defaults to "xy" indexing, so these arrays are (n_y, n_x).
grid_x, grid_y = np.meshgrid(grid.cell_x_centres, grid.cell_y_centres)

if params.gia_model is not None:
    # Use the configured model instead of a hard-coded "ice5g".
    gia = GIA(params.gia_model)
    # Keep the 2-D array under its own name; `uplift_grid` is the LazyFrame below.
    uplift_array = gia.interp_gia(grid.crs_bng, grid.crs_wgs, grid_x, grid_y)
    # Row index -> y_bin, column index -> x_bin.
    # NOTE(review): assumes the surface fit's x_bin/y_bin index cell_x_centres /
    # cell_y_centres in the same order -- TODO confirm.
    y_idx, x_idx = np.indices(uplift_array.shape)

    uplift_grid = pl.LazyFrame({
        "x_bin": x_idx.ravel(),
        "y_bin": y_idx.ravel(),
        "uplift_value": uplift_array.ravel()
    })

# Fixed-length epochs starting at epoch_start_dt: each row is the interval
# [epoch_lo_dt, epoch_hi_dt) with a 0-based epoch_number. Only epochs that
# overlap the data extent (dtstart, dtend) are kept.
epoch_starts = pl.datetime_range(
    start=epoch_start_dt,
    end=epoch_end_dt,
    interval=timedelta(days=epoch_length),
    closed="left",
    eager=True,
)
epoch_span = pl.duration(days=epoch_length)

epochs = (
    pl.LazyFrame({"epoch_lo_dt": epoch_starts})
    .with_columns(
        (pl.col("epoch_lo_dt") + epoch_span).alias("epoch_hi_dt"),
        pl.int_range(0, pl.len()).alias("epoch_number"),
    )
    .filter((pl.col("epoch_lo_dt") < dtend) & (pl.col("epoch_hi_dt") > dtstart))
)

# Surface-fit samples with their `time` (seconds since the standard epoch) converted to datetimes.
surface_fit_grid = pl.scan_parquet(parquet_glob).with_columns(
    (pl.lit(epoch_time) + pl.duration(seconds=pl.col("time"))).alias("time_dt")
)

conn = duckdb.connect()
conn.register('surface_fit_grid', surface_fit_grid)
conn.register('epochs', epochs)

# Tag each sample with the epoch whose half-open interval [epoch_lo_dt, epoch_hi_dt)
# contains it. Epochs are contiguous and non-overlapping, so a sample matches at most
# one epoch; samples outside every epoch are dropped (inner join). This gives the
# same rows as a LEFT JOIN ... BETWEEN followed by NOT NULL / "< epoch_hi_dt" filters.
surface_fit_grid_with_epoch = conn.execute("""
    SELECT
        s.*,
        e.*
    FROM surface_fit_grid s
    INNER JOIN epochs e
        ON s.time_dt >= e.epoch_lo_dt
       AND s.time_dt < e.epoch_hi_dt
""").pl()

# Columns kept in the per-sample output, in this order, for both branches below.
output_columns = [
    "x_bin", "y_bin", "epoch_number", "epoch_lo_dt", "epoch_hi_dt", "dh", "time_dt",
]

if params.gia_model is not None:
    surface_fit_grid_gia_corrected = (
        surface_fit_grid_with_epoch.lazy()
        .with_columns(
            # Earliest time_years within each (x_bin, y_bin, epoch_number) group.
            pl.col("time_years").min().over(["x_bin", "y_bin", "epoch_number"]).alias("min_time"),
        )
        .join(uplift_grid, on=["x_bin", "y_bin"], how="left")
        .with_columns(
            # Subtract uplift_value * (time since the group's first sample) from dh.
            # NOTE(review): presumably uplift_value is a rate per year -- confirm units.
            (
                pl.col("dh")
                - (pl.col("time_years") - pl.col("min_time")) * pl.col("uplift_value")
            ).alias("dh")
        )
        .select(output_columns)
        .collect()
    )
else:
    # No GIA model configured: pass dh through unchanged. The name is kept so the
    # downstream cells still find the frame (previously it was never defined).
    surface_fit_grid_gia_corrected = surface_fit_grid_with_epoch.select(output_columns)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

In [49]:
# Latest sample time in the data (the data extent, not params.epoch_end); bound and displayed.
(epoch_end := dtend)

datetime.datetime(2024, 7, 30, 14, 2, 30, 200238)

In [50]:
# Materialise the lazy epoch table to inspect the epoch boundaries and numbering.
epochs.collect()

epoch_lo_dt,epoch_hi_dt,epoch_number
datetime[μs],datetime[μs],i64
2010-07-20 00:00:00,2010-12-07 00:00:00,0
2010-12-07 00:00:00,2011-04-26 00:00:00,1
2011-04-26 00:00:00,2011-09-13 00:00:00,2
2011-09-13 00:00:00,2012-01-31 00:00:00,3
2012-01-31 00:00:00,2012-06-19 00:00:00,4
…,…,…
2022-10-25 00:00:00,2023-03-14 00:00:00,32
2023-03-14 00:00:00,2023-08-01 00:00:00,33
2023-08-01 00:00:00,2023-12-19 00:00:00,34
2023-12-19 00:00:00,2024-05-07 00:00:00,35


In [51]:
# Short alias for the per-sample, epoch-tagged surface-fit table used by the
# exploratory cells below.
df = surface_fit_grid_gia_corrected


In [16]:
# Number of distinct (x_bin, y_bin) cells; DuckDB resolves `df` from the notebook namespace.
unique_xy_sql = """
    SELECT COUNT(*) as unique_xy_bin
    FROM (SELECT DISTINCT x_bin, y_bin FROM df)
"""
result = conn.execute(unique_xy_sql).fetchdf()
print(result)

   unique_xy_bin
0          63708


In [11]:
# Distinct values of each coordinate *separately* (x and y extent in bins).
count = df.select(
    pl.col(["x_bin", "y_bin"]).unique().count()
)
print(count)

# Distinct (x_bin, y_bin) cells: the polars counterpart of the DuckDB query above.
# The per-column counts are not this number.
unique_xy_bin = df.n_unique(subset=["x_bin", "y_bin"])
print(unique_xy_bin)

shape: (1, 2)
┌───────┬───────┐
│ x_bin ┆ y_bin │
│ ---   ┆ ---   │
│ u32   ┆ u32   │
╞═══════╪═══════╡
│ 295   ┆ 528   │
└───────┴───────┘


In [77]:
from astropy.stats import sigma_clipped_stats

def clipped_stat(col, **kwargs):
    """Return the sigma-clipped mean of ``col``.

    Thin wrapper around astropy's ``sigma_clipped_stats``, which returns
    ``(mean, median, std)``. Extra keyword arguments (e.g. ``sigma``,
    ``maxiters``) are forwarded unchanged; with none, astropy's defaults apply.
    """
    mean, _median, _std = sigma_clipped_stats(col, **kwargs)
    return mean

# Sigma-clipped mean dh per (cell, epoch). Inside group_by().agg(), map_elements
# passes each group's dh values to the function; reuse clipped_stat instead of an
# equivalent lambda, and declare the output dtype instead of letting polars infer it.
grouped = df.group_by(["y_bin", "x_bin", "epoch_number"]).agg(
    pl.col("dh").map_elements(clipped_stat, return_dtype=pl.Float64).alias("mean_dh")
)



In [78]:
# Display the per-(cell, epoch) sigma-clipped means.
grouped

y_bin,x_bin,epoch_number,mean_dh
i64,i64,i64,f64
209,201,2,0.250418
412,141,25,0.018222
559,225,23,-2.539529
237,169,32,0.265979
397,219,19,0.009936
…,…,…,…
428,284,1,0.162138
240,320,26,-10.519561
265,163,22,0.924707
137,173,4,0.256676


In [76]:

# Gather each (y_bin, x_bin, epoch_number) group's dh samples into a list column.
grouped = df.group_by(["y_bin", "x_bin","epoch_number"]).agg(
    pl.col("dh").alias("dh_array"),
)

# Sigma-clipped (mean, median, std) per group, kept in `grouped` row order. The
# previous loop overwrote one variable each iteration and lost every result but
# the last; it also relied on the list column's position (row[3]).
# NOTE: one astropy call per group, which is slow for millions of groups.
cell_stats = [
    sigma_clipped_stats(np.asarray(dh_values))
    for dh_values in grouped["dh_array"]
]

KeyboardInterrupt: 

In [None]:
# One dict per group row, keyed by column name.
rows = list(grouped.iter_rows(named=True))


In [74]:
# Preview only: `rows` holds one dict per group, and displaying it all floods the notebook.
rows[:5]

[{'y_bin': 494,
  'x_bin': 205,
  'epoch_number': 20,
  'dh_array': [0.05319062328545806,
   -0.22508518151437737,
   -0.14560237304845935,
   -0.03846243780635935,
   0.18862830127195893,
   -0.30142513346458155,
   0.36042311940762795,
   0.2108578340747605,
   0.01648428018824109,
   0.2561886917393309,
   -0.13047945151289003,
   -0.12966364502347494,
   0.21984167706568464,
   -0.15333141132390127,
   -0.6130133231015075,
   0.039757479837083665,
   0.39866941830459857,
   0.5993576554296068,
   0.23573724224531434,
   0.30021842789296765,
   -0.14872220384667753,
   0.4258260149292823,
   -0.3835801319221446,
   0.593462798448114,
   -0.06645457389732283,
   -0.44520062919068354,
   0.5581370683645129,
   -0.3220785371517847,
   0.011410805273013839,
   -0.297171404744562,
   0.2964000161780416,
   0.31271398865708433,
   0.039913084123192236,
   -0.46330049250756583,
   -0.2155555009585365,
   0.48010697340708514,
   0.4480489554441938,
   0.0451018504233676,
   -0.0619694875941

In [75]:
# Sigma-clipped (mean, median, std) of dh for every group row, kept in row order.
# Previously each result overwrote `this_stats` and was discarded.
counter = 0  # stays defined (0) even when `rows` is empty
row_stats = []
for counter, row in enumerate(rows, start=1):
    if counter % 100000 == 0:
        print(f"Processed {counter} rows")

    dh = np.array(row["dh_array"])
    this_stats = sigma_clipped_stats(dh)
    row_stats.append(this_stats)

Processed 100000 rows
Processed 200000 rows
Processed 300000 rows
Processed 400000 rows
Processed 500000 rows


KeyboardInterrupt: 

In [None]:
from astropy.stats import sigma_clipped_stats  # Robust stats
import time

# NOTE(review): only the counter initialisation of this cell remains; the loop that
# printed the "Processed ... of 2116585 grid cells in ..." lines below is missing,
# so re-running this cell does not reproduce that output. The printed numbers look
# like a percentage and a raw time.time() timestamp rather than a cell count and an
# elapsed time -- TODO confirm before reusing that loop.
i = 0 



Processed 4.724591736216595 of 2116585 grid cells in 1752833050.9256344 
Processed 9.44918347243319 of 2116585 grid cells in 1752833077.787436 
Processed 14.173775208649783 of 2116585 grid cells in 1752833104.637666 
Processed 18.89836694486638 of 2116585 grid cells in 1752833131.3435025 
Processed 23.62295868108297 of 2116585 grid cells in 1752833158.237999 


KeyboardInterrupt: 

In [83]:
df = surface_fit_grid_gia_corrected
# Distinct column, row and epoch indices present in the data.
x_bins, y_bins, epoch_bins = (
    df[column].unique() for column in ("x_bin", "y_bin", "epoch_number")
)


In [84]:
# Lazy view of the per-sample table for building query plans.
test = df.lazy()

In [None]:

# Per-(cell, epoch) sigma-clipped summary of dh, one dict per row of `grouped`.
results = []

for row in grouped.iter_rows(named=True):
    # The list column produced by the group_by(...).agg(pl.col("dh").alias("dh_array"))
    # cell is "dh_array"; reading "dh_list" raised KeyError.
    dh = np.asarray(row["dh_array"])
    mean, _, std = sigma_clipped_stats(dh)
    results.append({
        "y_bin": row["y_bin"],
        "x_bin": row["x_bin"],
        "epoch_number": row["epoch_number"],
        "dh_ave": mean,
        "input_dh_dens" : len(dh),
        "input_dh_stddev": std
    })


In [79]:
import pandas as pd
from astropy.stats import sigma_clipped_stats
from duckdb.typing import VARCHAR, FLOAT, INTEGER, DOUBLE

df = surface_fit_grid_gia_corrected
con = duckdb.connect()
con.register("input_table", df)

def sigma_clipped_mean_std(arr):
    """Sigma-clipped mean and standard deviation of one group's dh values.

    DuckDB calls this once per group with the group's ``list(dh)``. NULL list
    elements (possibly from the left join with the uplift grid) are dropped
    before clipping instead of breaking the float conversion.

    Returns:
        dict: ``{"mean", "std", "count"}``, matching the UDF's STRUCT return
        type. ``count`` is the number of non-null samples; ``mean``/``std`` are
        None when a group has no non-null samples.
    """
    values = np.array([v for v in arr if v is not None], dtype=np.float64)
    if values.size == 0:
        return {"mean": None, "std": None, "count": 0}
    mean, _median, std = sigma_clipped_stats(values)
    return {"mean": float(mean), "std": float(std), "count": int(values.size)}

con.create_function(
    "sigma_clipped_mean_std",
    sigma_clipped_mean_std,
    # Explicit input type: the function has no annotations to infer it from.
    parameters=[duckdb.list_type(DOUBLE)],
    # DOUBLE, not FLOAT: dh is float64, and FLOAT would truncate the results to 32 bits.
    return_type=duckdb.struct_type({"mean": DOUBLE, "std": DOUBLE, "count": INTEGER}),
)

result = con.execute("""
SELECT
  x_bin, y_bin, epoch_number,
  sigma_clipped_mean_std(list(dh)) as stats
FROM input_table
GROUP BY x_bin, y_bin, epoch_number
""")


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

In [None]:
# Materialise the DuckDB per-(cell, epoch) aggregation as a polars DataFrame.
result.pl()