# ASTR-113: Real Data Loading Smoke Test

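This notebook smoke-tests the real-data collection and loading pipeline against a running Astrid API (`ASTRID_API_BASE`, default `http://127.0.0.1:8000`):
- Collect validated detections into a `TrainingDataset`
- Persist `TrainingSample` rows
- Load the dataset and create train/val/test splits
- Fetch SkyView reference cutouts, upload them to R2, and record them in `real_data_manifest.json`

Run this before integrating with `notebooks/training/model_training.ipynb`.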


In [None]:
# Setup imports and environment
import os, sys
from pathlib import Path

# Presumably this notebook sits two directories below the project root
# (e.g. notebooks/<topic>/) — adjust if the notebook is moved.
project_root = Path.cwd().parent.parent
# Guard the insert so re-running this cell does not stack duplicate sys.path entries.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

print(f"Project root added: {project_root}")


In [None]:
# Base URL of the Astrid API; override with the ASTRID_API_BASE environment variable.
# Trailing slashes are stripped so f"{API_BASE}/training/..." never yields a "//" path.
API_BASE = os.environ.get("ASTRID_API_BASE", "http://127.0.0.1:8000").rstrip("/")
API_BASE


In [None]:
# Auth headers for the training-pipeline API, reused by every request in this notebook.
from src.core.constants import TRAINING_PIPELINE_API_KEY

# Surface a missing key up front; otherwise it presumably shows up later as an opaque auth error.
if not TRAINING_PIPELINE_API_KEY:
    print("WARNING: TRAINING_PIPELINE_API_KEY is empty; authenticated API calls will likely be rejected")

AUTH_HEADERS = {
    "X-API-Key": TRAINING_PIPELINE_API_KEY,
    "Content-Type": "application/json",
}

In [None]:
# Ask the API to collect validated HST detections from 2024 into a small, named dataset.
# The parsed JSON response is kept in `resp` so later cells can read the new dataset_id.
import requests
from datetime import datetime

params = dict(
    survey_ids=["hst"],
    start="2024-01-01T00:00:00",
    end="2024-12-31T23:59:59",
    confidence_threshold=0.7,
    max_samples=50,
    name="smoketest_hst_2024",
)

collect_url = f"{API_BASE}/training/datasets/collect"
r = requests.post(collect_url, json=params, headers=AUTH_HEADERS, timeout=60)
r.raise_for_status()
resp = r.json()
resp


In [None]:
# Verify the collected dataset is listed by the API.
# A timeout keeps the smoke test from hanging forever if the API stops responding.
r = requests.get(f"{API_BASE}/training/datasets", headers=AUTH_HEADERS, timeout=60)
r.raise_for_status()
datasets = r.json()["data"]

print(f"Datasets: {len(datasets)}")
# Show the three most recent; assumes created_at strings sort chronologically (e.g. ISO-8601) — confirm
sorted(datasets, key=lambda d: d.get("created_at", ""), reverse=True)[:3]


In [None]:
# Real data scaffold: fetch SkyView cutout and upload to R2\nimport os, json, tempfile, uuid\nfrom pathlib import Path\n\nfrom src.adapters.external.skyview import SkyViewClient\nfrom src.adapters.external.r2 import R2StorageClient\nfrom src.adapters.imaging.fits_io import FITSProcessor\nfrom src.adapters.imaging.utils import to_display_image\nfrom src.core.constants import MLFLOW_BUCKET_NAME, CLOUDFLARE_R2_BUCKET_NAME\n\n# Config\nbucket = MLFLOW_BUCKET_NAME or "astrid-models"\nmanifest_path = Path.cwd() / "real_data_manifest.json"\n\n# Minimal fetch + upload (DSS2 Red around a bright field)\nra, dec = 180.0, 45.0\nclient = SkyViewClient(timeout=60)\nproc = FITSProcessor()\nr2 = R2StorageClient()\n\n# Get displayable image via hips2fits/SkyView helper\nimg, info = client.fetch_reference_image(ra, dec, size_pixels=512, fov_deg=0.05, survey="DSS2 Red")\nprint("Source:", info.get("source"), "error:", info.get("error"))\n\n# Save to a temp FITS and upload\ntmpdir = Path(tempfile.gettempdir())\nout = tmpdir / f"skyview_{uuid.uuid4().hex[:8]}.fits"\n# Convert display image to FITS primary array for this scaffold\nimport numpy as np\nif img is None:\n    raise RuntimeError("No image returned from SkyView/HiPS2FITS")\nproc.save_fits(np.asarray(img), str(out))\n\nobject_key = f"references/skyview/DSS2_Red/{ra:.4f}_{dec:.4f}.fits"\nres = await r2.upload_file(local_path=str(out), object_key=object_key, bucket=bucket, metadata={"ra": str(ra), "dec": str(dec), "survey": "DSS2 Red"})\nprint("Uploaded:", res.get("url"))\n\n# Write simple manifest for training pipeline consumption\nentry = {"bucket": bucket, "key": object_key, "ra": ra, "dec": dec, "survey": "DSS2 Red"}\nprev = []\nif manifest_path.exists():\n    try:\n        prev = json.loads(manifest_path.read_text())\n    except Exception:\n        prev = []\nprev.append(entry)\nmanifest_path.write_text(json.dumps(prev, indent=2))\nprint("Manifest updated:", manifest_path)


In [None]:
# Batch: build a tiny grid of positions and upload multiple cutouts\nimport numpy as np, time\nfrom typing import List, Dict\n\nbase_ra, base_dec = 180.0, 45.0\noffsets = np.linspace(-0.05, 0.05, 3)  # 3x3 grid\npositions = [(base_ra + dx, base_dec + dy) for dx in offsets for dy in offsets]\n\nadded: List[Dict] = []\nfor (ra, dec) in positions:\n    try:\n        img, info = client.fetch_reference_image(ra, dec, size_pixels=384, fov_deg=0.04, survey="DSS2 Red")\n        if img is None:\n            print("skip (no image)", ra, dec, info.get("error"))\n            continue\n        tmp = tmpdir / f"skyview_{uuid.uuid4().hex[:8]}.fits"\n        proc.save_fits(np.asarray(img), str(tmp))\n        key = f"references/skyview/DSS2_Red/{ra:.4f}_{dec:.4f}.fits"\n        res = await r2.upload_file(local_path=str(tmp), object_key=key, bucket=bucket, metadata={"ra": str(ra), "dec": str(dec), "survey": "DSS2 Red"})\n        entry = {"bucket": bucket, "key": key, "ra": ra, "dec": dec, "survey": "DSS2 Red"}\n        added.append(entry)\n        print("+", res.get("url"))\n        time.sleep(0.2)\n    except Exception as e:\n        print("err", ra, dec, e)\n\n# Append to manifest\nprev = []\nif manifest_path.exists():\n    try:\n        prev = json.loads(manifest_path.read_text())\n    except Exception:\n        prev = []\nmanifest_path.write_text(json.dumps(prev + added, indent=2))\nprint(f"Appended {len(added)} entries to manifest =>", manifest_path)


In [None]:
# Batch builder: create multiple SkyView cutouts and append manifest\nimport asyncio\nimport json\nfrom typing import Sequence, Tuple\n\nasync def build_skyview_manifest(\n    positions: Sequence[Tuple[float, float]],\n    survey: str = "DSS2 Red",\n    fov_deg: float = 0.05,\n    size_pixels: int = 512,\n) -> list[dict]:\n    client = SkyViewClient(timeout=60)\n    proc = FITSProcessor()\n    r2c = R2StorageClient()\n\n    entries: list[dict] = []\n    for ra, dec in positions:\n        img, info = client.fetch_reference_image(\n            ra, dec, size_pixels=size_pixels, fov_deg=fov_deg, survey=survey\n        )\n        if img is None:\n            print("skip (no image)", ra, dec, info.get("error"))\n            continue\n        tmp = Path(tempfile.gettempdir()) / f"skyview_{uuid.uuid4().hex[:8]}.fits"\n        import numpy as np\n        proc.save_fits(np.asarray(img), str(tmp))\n        key = f"references/skyview/{survey.replace(' ', '_')}/{ra:.4f}_{dec:.4f}_{fov_deg:.3f}deg.fits"\n        up = await r2c.upload_file(local_path=str(tmp), object_key=key, bucket=bucket, metadata={\n            "ra": str(ra), "dec": str(dec), "survey": survey, "fov_deg": str(fov_deg)\n        })\n        entries.append({"bucket": bucket, "key": key, "url": up.get("url"), "ra": ra, "dec": dec, "survey": survey})\n        print("uploaded", key)\n\n    # Append to manifest\n    prev: list[dict] = []\n    if manifest_path.exists():\n        try:\n            prev = json.loads(manifest_path.read_text())\n        except Exception:\n            prev = []\n    prev.extend(entries)\n    manifest_path.write_text(json.dumps(prev, indent=2))\n    print(f"Manifest now has {len(prev)} entries -> {manifest_path}")\n    return entries\n\n# Example small grid (RA/Dec near 180,45)\npos_list = [(180.0, 45.0), (180.05, 45.0), (179.95, 45.02)]\nawait build_skyview_manifest(pos_list)


In [None]:
# Inspect dataset via API (no direct DB access needed in notebook)
# Resolve dataset_id from the collect response, falling back to the most recent dataset listed by the API.
# The type checks tolerate a response whose "data" field is missing, null, or not an object.
collect_data = resp.get("data") if isinstance(resp, dict) else None
if isinstance(collect_data, dict) and collect_data.get("dataset_id"):
    dataset_id = str(collect_data["dataset_id"])  # ensure string, same as the fallback path
else:
    r = requests.get(f"{API_BASE}/training/datasets", headers=AUTH_HEADERS, timeout=60)
    r.raise_for_status()
    datasets = r.json().get("data", [])
    if not datasets:
        raise RuntimeError("No training datasets available")
    # pick the most recent (assumes created_at strings sort chronologically, e.g. ISO-8601 — confirm)
    datasets_sorted = sorted(datasets, key=lambda d: d.get("created_at", ""), reverse=True)
    dataset_id = str(datasets_sorted[0]["id"])  # ensure string

print("Dataset ID:", dataset_id)

# Get dataset details (timeout so the smoke test cannot hang indefinitely)
detail = requests.get(f"{API_BASE}/training/datasets/{dataset_id}", headers=AUTH_HEADERS, timeout=60)
detail.raise_for_status()
detail.json()


In [None]:
# Placeholder for previewing sample image paths: the API does not expose per-sample
# listings yet (future enhancement), so this cell only points back at the counts above.
preview_note = "For now, verify counts above. Integration with training notebook will consume by dataset_id."
print(preview_note)
