In [None]:
from pathlib import Path
import yaml
import torch
import pandas as pd
import lightning.pytorch as pl
from IPython.display import display
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from poseidon.training import (


Steps: load the experiment YAML, inspect the data module, trace tensor shapes through the model, and run a short training/evaluation cycle with checkpointing.

In [1]:


    create_datamodule_from_config,
    LitRegressor,
    SetEpochOnIterable,
)


In [2]:
config_path = Path("../configs/experiments/example_rect_mlp.yaml").resolve()
cfg = yaml.safe_load(config_path.read_text())
experiment_name = cfg["experiment"]["name"]
checkpoint_dir = Path(cfg["experiment"]["checkpoints_dir"]).resolve()
checkpoint_dir.mkdir(parents=True, exist_ok=True)

torch.manual_seed(int(cfg["data"].get("seed", 0)))

print(f"Loaded config: {config_path}")
print(f"Checkpoint directory: {checkpoint_dir}")

Loaded config: /Users/mako3626/newfrontiers/poseidon/configs/experiments/example_rect_mlp.yaml
Checkpoint directory: /Users/mako3626/newfrontiers/poseidon/notebooks/experiments/example_rect_mlp/checkpoints


In [3]:
datamodule = create_datamodule_from_config(cfg)
datamodule.setup("fit")

train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
test_loader = datamodule.test_dataloader()

summary = {
    "train_batches": len(datamodule.train_ds) if hasattr(datamodule.train_ds, "__len__") else "iterable",
    "val_samples": len(datamodule.val_ds),
    "test_samples": len(datamodule.test_ds),
}
pd.Series(summary)

[ShardedDataModule] Train=22  Val=39076  Test=21350  Shards(npz)=6  time_stats=train  train_mode=pt-batch


train_batches       22
val_samples      39076
test_samples     21350
dtype: int64

In [4]:
first_train = next(iter(train_loader))
first_val = next(iter(val_loader))
first_test = next(iter(test_loader))

def describe_batch(batch):
    return pd.DataFrame(
        {
            key: {"shape": tuple(value.shape), "dtype": str(value.dtype)}
            for key, value in batch.items()
        }
    ).T

train_desc = describe_batch(first_train)
val_desc = describe_batch(first_val)
test_desc = describe_batch(first_test)

print("Train batch overview:")
display(train_desc)
print("Validation batch overview:")
display(val_desc)
print("Test batch overview:")
display(test_desc)



Train batch overview:


Unnamed: 0,shape,dtype
lat,"(8192,)",torch.float32
lon,"(8192,)",torch.float32
t,"(8192,)",torch.float32
y,"(8192,)",torch.float32


Validation batch overview:


Unnamed: 0,shape,dtype
lat,"(64,)",torch.float32
lon,"(64,)",torch.float32
t,"(64,)",torch.float32
y,"(64,)",torch.float32


Test batch overview:


Unnamed: 0,shape,dtype
lat,"(64,)",torch.float32
lon,"(64,)",torch.float32
t,"(64,)",torch.float32
y,"(64,)",torch.float32


In [5]:
lit_module = LitRegressor(cfg, datamodule.stats)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_module.to(device)

lat = first_train["lat"].to(device).float()
lon = first_train["lon"].to(device).float()
t = first_train.get("t")
if t is not None:
    t = t.to(device).float()
y = first_train["y"].to(device).float()

lit_module.eval()
with torch.no_grad():
    pe_out = lit_module.model.pe(lat, lon, t)
    net_in = pe_out
    net_out = lit_module.model.net(net_in)
    model_out = lit_module.model(lat, lon, t)
    if isinstance(model_out, tuple):
        model_out = model_out[0]

shape_report = pd.DataFrame(
    [
        {"stage": "inputs", "shape": tuple(lat.shape), "notes": "lat/lon/t vectors"},
        {"stage": "positional_encoding", "shape": tuple(pe_out.shape), "notes": "rect_baseline output"},
        {"stage": "net_input", "shape": tuple(net_in.shape), "notes": "features passed into MLP"},
        {"stage": "net_output", "shape": tuple(net_out.shape), "notes": "raw MLP prediction"},
        {"stage": "model_output", "shape": tuple(model_out.shape), "notes": "model forward() result"},
    ]
)
lit_module.train()
shape_report

Unnamed: 0,stage,shape,notes
0,inputs,"(8192,)",lat/lon/t vectors
1,positional_encoding,"(8192, 9)",rect_baseline output
2,net_input,"(8192, 9)",features passed into MLP
3,net_output,"(8192,)",raw MLP prediction
4,model_output,"(8192, 1)",model forward() result


In [6]:
trainer_cfg = cfg.get("trainer", {})
experiment_cfg = cfg.get("experiment", {})
output_root = Path(experiment_cfg.get("output_root", checkpoint_dir.parent)).resolve()

checkpoint_cb = ModelCheckpoint(
    dirpath=str(checkpoint_dir),
    filename=f"{experiment_name}-{{epoch:02d}}-{{val_loss:.4f}}",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)
callbacks = [SetEpochOnIterable(), checkpoint_cb, LearningRateMonitor(logging_interval="epoch")]

logger = CSVLogger(save_dir=str(output_root), name="lightning", version=experiment_name)

trainer_kwargs = {
    "max_epochs": int(trainer_cfg.get("max_epochs", 1)),
    "log_every_n_steps": int(trainer_cfg.get("log_every_n_steps", 1)),
    "gradient_clip_val": float(trainer_cfg.get("gradient_clip_val", 0.0)),
    "precision": trainer_cfg.get("precision", 32),
    "accelerator": trainer_cfg.get("accelerator", "auto"),
    "devices": trainer_cfg.get("devices", "auto"),
    "limit_train_batches": trainer_cfg.get("limit_train_batches"),
    "limit_val_batches": trainer_cfg.get("limit_val_batches"),
    "limit_test_batches": trainer_cfg.get("limit_test_batches"),
    "default_root_dir": str(Path(trainer_cfg.get("default_root_dir", output_root)).resolve()),
}
trainer_kwargs = {k: v for k, v in trainer_kwargs.items() if v is not None}

pl.seed_everything(int(cfg["data"].get("seed", 0)), workers=True)
trainer = pl.Trainer(**trainer_kwargs, callbacks=callbacks, logger=logger)


Seed set to 0
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.


In [8]:
fit_result = trainer.fit(lit_module, datamodule=datamodule)

best_ckpt = checkpoint_cb.best_model_path or "<no checkpoint saved>"
metrics_csv = Path(logger.log_dir) / "metrics.csv"

print(f"Best checkpoint candidate: {best_ckpt}")
print(f"Metrics logged to: {metrics_csv}")

/Users/mako3626/newfrontiers/newfrontiers/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:751: Checkpoint directory /Users/mako3626/newfrontiers/poseidon/notebooks/experiments/example_rect_mlp/checkpoints exists and is not empty.

  | Name  | Type        | Params | Mode 
----------------------------------------------
0 | model | LocEncModel | 34.4 K | train
----------------------------------------------
34.4 K    Trainable params
0         Non-trainable params
34.4 K    Total params
0.138     Total estimated model params size (MB)
11        Modules in train mode
0         Modules in eval mode


[ShardedDataModule] Train=22  Val=39076  Test=21350  Shards(npz)=6  time_stats=train  train_mode=pt-batch


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/mako3626/newfrontiers/newfrontiers/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/mako3626/newfrontiers/newfrontiers/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/mako3626/newfrontiers/newfrontiers/lib/python3.13/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data

Best checkpoint candidate: /Users/mako3626/newfrontiers/poseidon/notebooks/experiments/example_rect_mlp/checkpoints/example_rect_mlp-epoch=03-val_loss=6117.7256.ckpt
Metrics logged to: /Users/mako3626/newfrontiers/poseidon/notebooks/experiments/example_rect_mlp/lightning/example_rect_mlp/metrics.csv


In [None]:
ckpt_path = "best" if checkpoint_cb.best_model_path else None
test_results = trainer.test(ckpt_path=ckpt_path, datamodule=datamodule)

print(f"Evaluated checkpoint: {checkpoint_cb.best_model_path or '<last>'}")

Restoring states from the checkpoint path at /Users/mako3626/newfrontiers/poseidon/notebooks/experiments/example_rect_mlp/checkpoints/example_rect_mlp-epoch=03-val_loss=6117.7256.ckpt
Loaded model weights from the checkpoint at /Users/mako3626/newfrontiers/poseidon/notebooks/experiments/example_rect_mlp/checkpoints/example_rect_mlp-epoch=03-val_loss=6117.7256.ckpt
Loaded model weights from the checkpoint at /Users/mako3626/newfrontiers/poseidon/notebooks/experiments/example_rect_mlp/checkpoints/example_rect_mlp-epoch=03-val_loss=6117.7256.ckpt


[ShardedDataModule] Train=22  Val=39076  Test=21350  Shards(npz)=6  time_stats=train  train_mode=pt-batch


/Users/mako3626/newfrontiers/newfrontiers/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss              79860.984375
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Evaluated checkpoint: /Users/mako3626/newfrontiers/poseidon/notebooks/experiments/example_rect_mlp/checkpoints/example_rect_mlp-epoch=03-val_loss=6117.7256.ckpt
Evaluated checkpoint: /Users/mako3626/newfrontiers/poseidon/notebooks/experiments/example_rect_mlp/checkpoints/example_rect_mlp-epoch=03-val_loss=6117.7256.ckpt


Unnamed: 0,test_loss
0,79860.984375
