In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import yaml
import matplotlib.pyplot as plt
from poseidon.data.sources.swot.shard_builder import build_shards
from poseidon.data.shards import reshard_random_train
import torch


# Build SWOT Shards from Granules
This notebook demonstrates how to convert downloaded SWOT granules into per-cycle, per-pass shards using the functional shard builder utilities.

## Load Sharding Configuration
We reuse `configs/data/swot_example.yaml`, which now includes shard-specific paths and options.

In [2]:
config_path = Path("../configs/data/swot_example.yaml")
raw_cfg = yaml.safe_load(config_path.read_text())

bbox = raw_cfg.get("bbox")
if bbox is None and all(k in raw_cfg for k in ("min_lon", "min_lat", "max_lon", "max_lat")):
    bbox = [raw_cfg["min_lon"], raw_cfg["min_lat"], raw_cfg["max_lon"], raw_cfg["max_lat"]]

shard_params = {
    "granule_dir": raw_cfg["granule_dir"],
    "watermask_dir": raw_cfg["watermask_dir"],
    "shard_outdir": raw_cfg["shard_outdir"],
    "downsampling_factor": raw_cfg.get("downsampling_factor", 30),
}

summary_df = pd.DataFrame(
    list({
        "bbox": bbox,
        **shard_params,
    }.items()),
    columns=["key", "value"],
)
summary_df

Unnamed: 0,key,value
0,bbox,"[-99.0, 17.0, -79.0, 31.0]"
1,granule_dir,../data/example/granules
2,watermask_dir,../data/example/watermask
3,shard_outdir,../data/example/whole_shards
4,downsampling_factor,30


## Verify Inputs
The shard builder expects downloaded SWOT NetCDF granules and matching watermask tiles. Ensure the directories referenced above contain the files produced in the download notebook.

In [3]:
granule_dir = Path(raw_cfg["granule_dir"]).resolve()
watermask_dir = Path(raw_cfg["watermask_dir"]).resolve()
shard_outdir = Path(raw_cfg["shard_outdir"]).resolve()

if not granule_dir.exists():
    raise FileNotFoundError(f"Granule directory not found: {granule_dir}")
if not watermask_dir.exists():
    raise FileNotFoundError(f"Watermask directory not found: {watermask_dir}")

netcdf_files = sorted(granule_dir.glob(raw_cfg.get("granule_glob", "*.nc")))
if not netcdf_files:
    raise FileNotFoundError("No NetCDF files found. Run the download workflow first.")

pd.DataFrame(
    {
        "granules": [len(netcdf_files)],
        "watermask_tiles": [len(list(watermask_dir.glob("*.tif")))],
        "granule_sample": [netcdf_files[0].name],
    }
)

Unnamed: 0,granules,watermask_tiles,granule_sample
0,11,9,SWOT_L2_LR_SSH_Expert_026_231_20250101T013336_...


## Build Shards
Call `build_shards` with the configured bounding box and directories. Use `limit_granules` to constrain runtime during exploration.

In [4]:
limit_granules = 8                                        
bbox_list = bbox if bbox is not None else [raw_cfg["min_lon"], raw_cfg["min_lat"], raw_cfg["max_lon"], raw_cfg["max_lat"]]
bbox_map = {
    "min_lon": float(bbox_list[0]),
    "min_lat": float(bbox_list[1]),
    "max_lon": float(bbox_list[2]),
    "max_lat": float(bbox_list[3]),
}

downsampling = int(raw_cfg.get("downsampling_factor", 30))
shard_outdir.mkdir(parents=True, exist_ok=True)

selected = netcdf_files[:limit_granules] if limit_granules else netcdf_files
written = build_shards(
    [str(p) for p in selected],
    bbox=bbox_map,
    out_dir=shard_outdir,
    watermask_dir=watermask_dir,
    downsampling_factor=downsampling,
)

summary = pd.Series(
    {
        "granules_processed": len(selected),
        "shards_written": len(written),
        "output_dir": str(shard_outdir),
    },
    name="shard_build",
)

display(summary)

pd.DataFrame({"shard": [Path(p).name for p in written]})

  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')


granules_processed                                                    8
shards_written                                                        6
output_dir            /Users/mako3626/newfrontiers/poseidon/data/exa...
Name: shard_build, dtype: object

Unnamed: 0,shard
0,shard_c026_p231_SWOT_L2_LR_SSH_Expert_026_231_...
1,shard_c026_p244_SWOT_L2_LR_SSH_Expert_026_244_...
2,shard_c026_p259_SWOT_L2_LR_SSH_Expert_026_259_...
3,shard_c026_p272_SWOT_L2_LR_SSH_Expert_026_272_...
4,shard_c026_p300_SWOT_L2_LR_SSH_Expert_026_300_...
5,shard_c026_p328_SWOT_L2_LR_SSH_Expert_026_328_...


## Inspect a Shard
Load one of the generated `.npz` files to confirm schema and value ranges.

In [5]:
if not written:
    raise RuntimeError("No shards were produced; inspect previous step for issues.")

shard_path = Path(written[0])
with np.load(shard_path) as shard:
    shard_keys = sorted(shard.files)
    shard_shapes = {key: shard[key].shape for key in shard_keys}
    shard_dtypes = {key: str(shard[key].dtype) for key in shard_keys}

summary_df = pd.DataFrame(
    {
        "shape": shard_shapes,
        "dtype": shard_dtypes,
    }
)

print(f"Inspecting shard: {shard_path.name}")
summary_df

Inspecting shard: shard_c026_p231_SWOT_L2_LR_SSH_Expert_026_231_20250101T013336_20250101T022504_PIC2_01.npz


Unnamed: 0,shape,dtype
cycle,"(39076,)",int16
lat,"(39076,)",float32
lon,"(39076,)",float32
pas,"(39076,)",int16
t,"(39076,)",float32
y,"(39076,)",float32


## Reshard Training Batches
Convert the consolidated training shards into batched `.pt` files stored alongside the example dataset. Adjust the ratios or batch size as needed for your experiment.

In [6]:
pt_outdir = shard_outdir.parent / "batch_shards"
pt_outdir.mkdir(parents=True, exist_ok=True)

reshard_result = reshard_random_train(
    src_dir=shard_outdir,
    out_dir=pt_outdir,
    seed=0,
    batch_size=8192,
    batches_per_file=8,
    val_ratio=0.05,
    test_ratio=0.05,
 )

split_counts = {f"{split}_groups": len(groups) for split, groups in reshard_result.split_groups.items()}

reshard_summary = pd.Series(
    {
        "samples_loaded": reshard_result.samples_loaded,
        "batches_written": reshard_result.batches_written,
        "batch_size": reshard_result.batch_size,
        "dropped_samples": reshard_result.dropped_samples,
        "files_written": len(reshard_result.written),
        "missing_shards": reshard_result.missing_shards,
        "output_dir": str(pt_outdir.resolve()),
        **split_counts,
    },
    name="reshard",
)

reshard_summary

samples_loaded                                                131767
batches_written                                                   16
batch_size                                                      8192
dropped_samples                                                  695
files_written                                                      2
missing_shards                                                     0
output_dir         /Users/mako3626/newfrontiers/poseidon/data/exa...
train_groups                                                       4
val_groups                                                         1
test_groups                                                        1
Name: reshard, dtype: object

In [7]:
split_rows = []
for split, groups in reshard_result.split_groups.items():
    for cycle, pas in groups:
        split_rows.append({"split": split, "cycle": cycle, "pass": pas})
split_overview = pd.DataFrame(split_rows) if split_rows else pd.DataFrame(columns=["split", "cycle", "pass"])
split_overview

Unnamed: 0,split,cycle,pass
0,train,26,231
1,train,26,244
2,train,26,259
3,train,26,300
4,val,26,328
5,test,26,272


In [8]:

if reshard_result.written:
    batch_info = []
    for path in reshard_result.written[:2]:
        data = torch.load(path)
        batch_info.append(
            {
                "file": Path(path).name,
                "X_shape": tuple(data["X"].shape),
                "Y_shape": tuple(data["Y"].shape),
            }
        )
    pd.DataFrame(batch_info)
else:
    raise RuntimeError("Resharding did not produce any batch shards.")

Issue : Fix cycle and pass being transformed into full length to minimize I/O operations