# HelioNetCDFDataset AWS / Local DataLoader Smoke Test

This notebook validates that:

1. `HelioNetCDFDataset` (from `helio_updated.py`) can be imported and instantiated.
2. A child dataset pattern (from `template_dataset.py`) can subclass the AWS wrapper (`helio_aws.py`).
3. A PyTorch `DataLoader` yields the expected `(sample_dict, metadata)` pair, where:
   - `sample_dict` contains the HelioFM keys (`ts`, `time_delta_input`, `forecast`, `lead_time_delta`) and optionally latitude keys if enabled.
   - `metadata` contains file/timestamp context.

Because this environment cannot access your real S3 bucket, the notebook creates a **tiny local NetCDF toy dataset** and runs the same data path your training loop uses. In AWS, you will swap the index to `s3://...` URIs and the same class will stream from S3 (optionally with `simplecache`).

In [None]:
# If running in a clean environment, ensure dependencies are available.
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr
import torch
from torch.utils.data import DataLoader

# Report interpreter and library versions so environment mismatches are
# visible before any dataset code runs.
for label, version in (
    ("Python:", sys.version),
    ("Torch:", torch.__version__),
    ("Pandas:", pd.__version__),
    ("Xarray:", xr.__version__),
):
    print(label, version)

In [None]:
# Import helio_updated.py and helio_aws.py from local files.
# This avoids depending on your repo packaging for the smoke test.

import importlib.util

def import_from_path(module_name: str, path: str):
    """Load the Python source file at ``path`` as a module named ``module_name``.

    The module is executed and returned; it is not registered in
    ``sys.modules``.

    Args:
        module_name: Name assigned to the loaded module (its ``__name__``).
        path: Filesystem location of the source file to execute.

    Returns:
        The freshly executed module object.

    Raises:
        ImportError: If no import spec or loader can be derived from ``path``
            (for example, the file suffix is not a recognised Python one).
    """
    spec = importlib.util.spec_from_file_location(module_name, path)
    # Validate before the spec is used: spec_from_file_location returns None
    # when it cannot pick a loader, and module_from_spec(None) would otherwise
    # fail with an unhelpful AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot create an import spec for module {module_name!r} from {path!r}"
        )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

# Locations of the two source files exercised by this smoke test.
HELIO_UPDATED_PATH = "/mnt/data/helio_updated.py"
HELIO_AWS_PATH = "/mnt/data/helio_aws.py"

# Load both files as standalone modules, base module first.
helio_updated, helio_aws = (
    import_from_path(name, location)
    for name, location in (
        ("helio_updated", HELIO_UPDATED_PATH),
        ("helio_aws", HELIO_AWS_PATH),
    )
)

# Expose the two dataset classes under short names for the cells below.
HelioNetCDFDataset, HelioNetCDFDatasetAWS = (
    helio_updated.HelioNetCDFDataset,
    helio_aws.HelioNetCDFDatasetAWS,
)

print("Imported:", HelioNetCDFDataset, HelioNetCDFDatasetAWS)

In [None]:
# Create a small local NetCDF toy dataset.
# We make 3 files at times: t0, t0+60m, t0+120m to satisfy the dataset's required_timesteps logic.

# Directory that receives the toy NetCDF files (created if missing).
toy_root = Path("/mnt/data/toy_nc")
toy_root.mkdir(parents=True, exist_ok=True)

# Three timestamps spaced 60 minutes apart, starting at t0.
t0 = pd.Timestamp("2013-01-01 00:00:00")
timesteps = [t0 + pd.Timedelta(minutes=m) for m in (0, 60, 120)]

# `channels` is the single source of truth for the variable names written to
# each file; the same list is passed to the dataset later on.
channels = ["ch1", "ch2"]
H, W = 8, 8

paths = []
for ts in timesteps:
    # One random float32 (H, W) field per channel, on dims ("y", "x").
    ds = xr.Dataset(
        {
            name: (("y", "x"), np.random.rand(H, W).astype(np.float32))
            for name in channels
        }
    )
    # File name encodes the timestamp as YYYYMMDD_HHMM.
    fp = toy_root / f"{ts.strftime('%Y%m%d_%H%M')}.nc"
    # Use h5netcdf engine to match your reader (engine='h5netcdf').
    ds.to_netcdf(fp, engine="h5netcdf")
    paths.append(str(fp))

# Notebook display: the first paths and how many files were written.
paths[:3], len(paths)

In [None]:
# Build a CSV index in the same shape as create_csv_index.py outputs:
# columns: timestep, path, present

# One index row per toy file. `present` is derived from the number of paths so
# the frame stays consistent if the list of timesteps is changed above.
index_df = pd.DataFrame(
    {
        "timestep": timesteps,
        "path": paths,
        "present": [1] * len(paths),
    }
)

# Persist the index without the pandas row index column.
index_csv = Path("/mnt/data/toy_index.csv")
index_df.to_csv(index_csv, index=False)

# Notebook display: where the index was written and its first rows.
index_csv, index_df.head()

In [None]:
# Define a child dataset following the template pattern (template_dataset.py),
# but minimal: it adds a dummy downstream label under the key 'ds_label'.
#
# In your real use, you would import HelioNetCDFDatasetAWS from your package:
#   from Surya.surya.datasets.helio_aws import HelioNetCDFDatasetAWS

class ToyChildDataset(HelioNetCDFDatasetAWS):
    """Child dataset that attaches a dummy ``ds_label`` to every parent sample.

    Labels are evenly spaced float32 values in [0, 1], one per entry of
    ``self.valid_indices`` (an attribute expected to be set by the parent
    constructor).
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent, then precompute the labels."""
        super().__init__(*args, **kwargs)
        n_samples = len(self.valid_indices)
        self._labels = np.linspace(0, 1, num=n_samples, dtype=np.float32)

    def __getitem__(self, idx: int):
        """Return the parent's ``(sample, metadata)`` with ``ds_label`` added."""
        sample, meta = super().__getitem__(idx)
        # Stored as a 0-d float32 array, mirroring the template's label key.
        sample["ds_label"] = np.asarray(self._labels[idx], dtype=np.float32)
        return sample, meta


print("ToyChildDataset defined.")

In [None]:
# Instantiate dataset and dataloader.
#
# Settings chosen to require exactly 3 timesteps per sample:
#   - n_input_timestamps = 1
#   - time_delta_input_minutes = [0]
#   - rollout_steps = 1  => needs 2 target steps (rollout_steps + 1)
#   - time_delta_target_minutes = 60 => target at +60m and +120m
#
# Resulting shapes (for C=2, H=W=8, n_input=1, rollout_steps=1):
#   ts:        (C, T=1, H, W)
#   forecast:  (C, L=2, H, W)

# Keyword arguments for the toy dataset, gathered in one place so they are
# easy to inspect or tweak from the notebook.
dataset_config = {
    "index_path": str(index_csv),
    "time_delta_input_minutes": [0],
    "time_delta_target_minutes": 60,
    "n_input_timestamps": 1,
    "rollout_steps": 1,
    "scalers": None,
    "num_mask_aia_channels": 0,
    "drop_hmi_probability": 0.0,
    # Keep latitudes disabled for this smoke test.
    "use_latitude_in_learned_flow": False,
    "channels": channels,
    "phase": "val",
    # AWS/S3 knobs are accepted (not used here since we are local):
    "s3_use_simplecache": True,
    "s3_cache_dir": "/tmp/helio_s3_cache",
}
dataset = ToyChildDataset(**dataset_config)

# Single-process, unshuffled loading keeps the smoke test deterministic in order.
loader = DataLoader(dataset=dataset, batch_size=2, shuffle=False, num_workers=0)

print("len(dataset):", len(dataset))

In [None]:
# Fetch one batch and validate keys / shapes.

# Pull the first batch and split it into the sample dict and its metadata.
batch = next(iter(loader))
sample_dict, metadata = batch

print("Sample keys:", sorted(sample_dict))
for key, value in sample_dict.items():
    if not torch.is_tensor(value):
        # Anything the collate step did not turn into a tensor is reported by type only.
        print(f"{key:15s} type={type(value)}")
        continue
    print(f"{key:15s} tensor shape={tuple(value.shape)} dtype={value.dtype}")

print("\nMetadata type:", type(metadata))
# When metadata arrives as a dict, show at most its first ten keys.
if isinstance(metadata, dict):
    print("Metadata keys:", list(metadata)[:10])

In [None]:
# Assertions: expected dictionary contract.
# NOTE: latitude keys are only present when use_latitude_in_learned_flow=True.

# Base keys every sample dict must carry, regardless of the latitude option.
expected_base_keys = {"ts", "time_delta_input", "forecast", "lead_time_delta"}
missing_keys = expected_base_keys - set(sample_dict.keys())
assert not missing_keys, f"Missing expected keys: {missing_keys}"

# Key contributed by the child dataset.
assert "ds_label" in sample_dict, "Missing child-added key 'ds_label'"

# After batching, axis 0 is the batch and axis 1 must match the channel count.
C = len(channels)
ts_channel_count = sample_dict["ts"].shape[1]
assert ts_channel_count == C, "Expected channel dimension at axis=1 after batching (B, C, T, H, W)"
forecast_channel_count = sample_dict["forecast"].shape[1]
assert forecast_channel_count == C, "Expected channel dimension at axis=1 after batching (B, C, L, H, W)"

print("All assertions passed.")

## Running against real S3 (AWS environment)

In AWS (EC2/EKS), you would point `index_path` to your S3 index CSV (or a local copy of it) where each `path` is an S3 URI, for example:

`s3://nasa-surya-bench/2013/01/20130130_1924.nc`

Then instantiate with the same class:

```python
dataset = HelioNetCDFDatasetAWS(
    index_path="surya_aws_s3_val.csv",
    time_delta_input_minutes=[0, 12, 24, 36, 48, 60],
    time_delta_target_minutes=12,
    n_input_timestamps=6,
    rollout_steps=4,
    scalers=scalers,
    channels=[...],
    phase="val",
    s3_use_simplecache=True,
    s3_cache_dir="/tmp/helio_s3_cache",  # ideally fast local disk
)
```

Operational guidance:
- Run in the same AWS region as the bucket.
- Prefer instance roles (IAM) over static keys.
- Keep `s3_use_simplecache=True` unless you have a strong reason to disable it.