# ADAPT Data API Demo

This notebook demonstrates the **DataAPI** — the subscription-based read interface
for downstream consumers of ADAPT pipeline outputs.

## Architecture Recap

```
Thread 1 (Acquisition)  →  DataStore  ←  Thread 2 (Processing)
                              ↑
                     Thread 3 (External)
                     This notebook / scripts
```

- The **DataStore** (SQLite) is the sole communication channel.
- The **DataAPI** wraps the store with a read-only, subscription-based interface.
- External scripts **never** import pipeline internals — they only use `DataAPI`.

## Two Data Kinds

| Kind   | Format  | Reader           | Examples                            |
|--------|---------|------------------|-------------------------------------|
| Grid   | NetCDF  | `xr.Dataset`     | Reflectivity, segmentation masks, flow vectors |
| Series | SQLite  | `pd.DataFrame`   | Cell statistics timeseries          |

## Setup

We create a temporary DataStore, register some fake products, and then use the
DataAPI to read them — exactly as a real visualization or analysis script would.

In [None]:
import tempfile
import sqlite3
from pathlib import Path
from datetime import datetime, timedelta

import numpy as np
import xarray as xr
import pandas as pd

from adapt.core.store import DataStore
from adapt.data_access import DataAPI, Subscription

In [None]:
# Scratch workspace: every file this demo writes (the DataStore database and
# the fake products) lives under a fresh temporary directory.
tmpdir = Path(tempfile.mkdtemp(prefix="adapt_demo_"))
db_path = tmpdir.joinpath("adapt.db")
print("Working directory: {}".format(tmpdir))

## 1. Simulating Pipeline Output

In production, the pipeline writes NetCDF files and registers them in the DataStore.
Here we simulate that by creating fake grid and series products.

In [None]:
def create_fake_gridded_netcdf(path, scan_time, shape=(100, 100)):
    """Write a synthetic single-level reflectivity grid to *path* as NetCDF.

    The field starts at -10 everywhere. Two Gaussian bumps are added to it:
    amplitude 55 / sigma 20 centered at (y=30, x=-50), and amplitude 45 /
    sigma 15 centered at (y=-20, x=60). Both axes span -150..150 km.

    Parameters
    ----------
    path : pathlib.Path
        Output file. Missing parent directories are created.
    scan_time : datetime.datetime
        Recorded as ISO-8601 text in the ``scan_time`` global attribute.
    shape : tuple[int, int]
        Grid size ``(ny, nx)``; defaults to 100 x 100.

    Returns
    -------
    pathlib.Path
        The same *path* that was written.
    """
    n_y, n_x = shape
    y_km = np.linspace(-150, 150, n_y)
    x_km = np.linspace(-150, 150, n_x)
    grid_y, grid_x = np.meshgrid(y_km, x_km, indexing='ij')

    # Background plus one Gaussian per synthetic storm cell, given as
    # (center_y, center_x, amplitude, sigma).
    dbz = np.full(shape, -10.0, dtype='float32')
    for cy, cx, amplitude, sigma in ((30, -50, 55, 20), (-20, 60, 45, 15)):
        dist = np.sqrt((grid_y - cy)**2 + (grid_x - cx)**2)
        dbz += amplitude * np.exp(-dist**2 / (2 * sigma**2))

    dataset = xr.Dataset(
        data_vars={'reflectivity': (('y', 'x'), dbz)},
        coords={'y': y_km, 'x': x_km},
        attrs={
            'scan_time': scan_time.isoformat(),
            'radar_id': 'KLOT',
            'z_level_m': 2000,
        },
    )
    path.parent.mkdir(parents=True, exist_ok=True)
    dataset.to_netcdf(path)
    dataset.close()
    return path


def create_fake_segmented_netcdf(path, scan_time, shape=(100, 100)):
    """Write a synthetic segmentation product to *path* as NetCDF.

    The file holds four (y, x) variables on a -150..150 km grid:

    * ``reflectivity``: the same two Gaussian cells as the gridded fake,
      centered at (y=30, x=-50) and (y=-20, x=60).
    * ``cell_labels``: 1 for pixels closer than 30 km to the first center,
      2 for pixels closer than 25 km to the second, 0 elsewhere (background).
    * ``heading_x`` / ``heading_y``: a constant (5, 2) inside labelled
      pixels and (0, 0) outside.

    Parameters
    ----------
    path : pathlib.Path
        Output file. Missing parent directories are created.
    scan_time : datetime.datetime
        Recorded as ISO-8601 text in the ``scan_time`` global attribute.
    shape : tuple[int, int]
        Grid size ``(ny, nx)``; defaults to 100 x 100.

    Returns
    -------
    pathlib.Path
        The same *path* that was written.
    """
    n_y, n_x = shape
    y_km = np.linspace(-150, 150, n_y)
    x_km = np.linspace(-150, 150, n_x)
    grid_y, grid_x = np.meshgrid(y_km, x_km, indexing='ij')

    dbz = np.full(shape, -10.0, dtype='float32')
    cell_labels = np.zeros(shape, dtype='int32')
    # (center_y, center_x, amplitude, sigma, label_radius); label ids are 1, 2.
    cells = ((30, -50, 55, 20, 30), (-20, 60, 45, 15, 25))
    for label_id, (cy, cx, amplitude, sigma, radius) in enumerate(cells, start=1):
        dist = np.sqrt((grid_y - cy)**2 + (grid_x - cx)**2)
        dbz += amplitude * np.exp(-dist**2 / (2 * sigma**2))
        # A later cell would overwrite an earlier one where the discs overlap.
        cell_labels[dist < radius] = label_id

    # Simulated storm motion: a single constant vector inside every cell.
    inside = cell_labels > 0
    heading_x = np.where(inside, 5.0, 0.0).astype('float32')
    heading_y = np.where(inside, 2.0, 0.0).astype('float32')

    grid_dims = ('y', 'x')
    dataset = xr.Dataset(
        {
            'reflectivity': (grid_dims, dbz),
            'cell_labels': (grid_dims, cell_labels),
            'heading_x': (grid_dims, heading_x),
            'heading_y': (grid_dims, heading_y),
        },
        coords={'y': y_km, 'x': x_km},
        attrs={
            'scan_time': scan_time.isoformat(),
            'radar_id': 'KLOT',
            'z_level_m': 2000,
        },
    )
    path.parent.mkdir(parents=True, exist_ok=True)
    dataset.to_netcdf(path)
    dataset.close()
    return path


def create_fake_analysis_sqlite(path, scan_times, rng=None):
    """Write a synthetic per-cell statistics table to a new SQLite file.

    Creates table ``cells`` with one row per (scan_time, cell) for cells 1
    and 2. The reflectivity and area values are random. The centroids are
    jittered around the cell centers drawn by the fake NetCDF grids, so the
    series and grid products describe the same storms.

    Parameters
    ----------
    path : pathlib.Path
        Output database. Parent directories are created. The ``cells``
        table must not already exist in it, or ``CREATE TABLE`` fails.
    scan_times : iterable of datetime.datetime
        One pair of rows per time, stored as ISO-8601 text.
    rng : numpy.random.Generator, optional
        Source of randomness for reproducible output. When omitted, numpy's
        legacy global generator (``np.random.rand`` / ``np.random.randn``)
        is used.

    Returns
    -------
    pathlib.Path
        The same *path* that was written.
    """
    # (centroid_x, centroid_y) in km for each synthetic cell. These match the
    # Gaussian centers used by create_fake_gridded_netcdf and
    # create_fake_segmented_netcdf.
    centers = {1: (-50.0, 30.0), 2: (60.0, -20.0)}

    if rng is None:
        uniform, normal = np.random.rand, np.random.randn
    else:
        uniform, normal = rng.random, rng.standard_normal

    rows = []
    for t in scan_times:
        for cell_id, (cx, cy) in centers.items():
            rows.append((
                t.isoformat(),
                cell_id,
                round(40 + uniform() * 20, 1),   # reflectivity_max in [40, 60]
                round(25 + uniform() * 15, 1),   # reflectivity_mean in [25, 40]
                round(50 + uniform() * 100, 1),  # area_km2 in [50, 150]
                round(cx + normal() * 5, 1),
                round(cy + normal() * 5, 1),
            ))

    path.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(path))
    # sqlite3's connection context manager only commits or rolls back and never
    # closes, so close explicitly even when a statement raises.
    try:
        conn.execute("""
            CREATE TABLE cells (
                scan_time TEXT,
                cell_label INTEGER,
                reflectivity_max REAL,
                reflectivity_mean REAL,
                area_km2 REAL,
                centroid_x REAL,
                centroid_y REAL
            )
        """)
        conn.executemany(
            "INSERT INTO cells VALUES (?, ?, ?, ?, ?, ?, ?)",
            rows,
        )
        conn.commit()
    finally:
        conn.close()
    return path


print("Helper functions defined.")

In [None]:
# Five synthetic scans, 5 minutes apart, starting at 18:00.
base_time = datetime(2025, 6, 15, 18, 0, 0)
scan_times = [base_time + timedelta(minutes=5 * i) for i in range(5)]

store = DataStore(db_path)

# For each scan, write a gridded file and register it (producer
# "acquisition"). Then write a segmented file and register it with the grid
# as its parent (producer "detection").
for t in scan_times:
    stamp = t.strftime('%Y%m%d_%H%M%S')

    grid_file = create_fake_gridded_netcdf(
        tmpdir / "gridded" / f"KLOT_{stamp}_gridded.nc", t
    )
    grid_id = store.register_product(
        product_type="gridded_netcdf",
        file_path=grid_file,
        radar_id="KLOT",
        scan_time=t,
        producer_module="acquisition",
    )

    seg_file = create_fake_segmented_netcdf(
        tmpdir / "segmented" / f"KLOT_{stamp}_segmented.nc", t
    )
    store.register_product(
        product_type="segmented_netcdf",
        file_path=seg_file,
        radar_id="KLOT",
        scan_time=t,
        producer_module="detection",
        parent_ids=[grid_id],
    )

# A single analysis database covers every scan; it is tagged with the last
# scan time.
analysis_path = create_fake_analysis_sqlite(
    tmpdir / "analysis" / "KLOT_cells.db", scan_times
)
store.register_product(
    product_type="analysis_sqlite",
    file_path=analysis_path,
    radar_id="KLOT",
    scan_time=scan_times[-1],
    producer_module="analysis",
)

stats = store.get_statistics(radar_id="KLOT")
by_type = stats['products_by_type']
print(f"Registered {stats['total_products']} products:")
for type_name, n_registered in by_type.items():
    print(f"  {type_name}: {n_registered}")

## 2. Using the DataAPI

From this point, we use **only** the `DataAPI`. This is the same interface a real
visualization script or Jupyter analysis session would use.

In [None]:
# Open the read-only consumer interface over the same database file.
api = DataAPI.from_path(db_path)
print("DataAPI connected to: " + str(db_path))

### 2a. List Available Grid Products

In [None]:
# List all segmented grids
seg_products = api.list_grids(
    product_type="segmented_netcdf",
    radar_id="KLOT",
)

print(f"Found {len(seg_products)} segmented grids:")
for p in seg_products:
    print(f"  {p.product_id[:8]}...  scan={p.scan_time}  file={Path(p.file_path).name}")

### 2b. Read a Grid Product (NetCDF → xr.Dataset)

In [None]:
# Read the most recent segmented grid
ds = api.get_latest_grid("segmented_netcdf", "KLOT")

print(f"Scan time: {ds.attrs['scan_time']}")
print(f"Variables: {list(ds.data_vars)}")
print(f"Grid shape: {ds['reflectivity'].shape}")
print(f"Cells detected: {len(np.unique(ds['cell_labels'].values)) - 1}")
ds

### 2c. Visualize Grid Data

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Panel 1: reflectivity
ds['reflectivity'].plot(ax=axes[0], cmap='turbo', vmin=-10, vmax=70)
axes[0].set_title('Reflectivity (dBZ)')

# Panel 2: cell labels, with background (0) masked out.
# pcolormesh(shading='nearest') centers each pixel on its (x, y) coordinate.
# imshow's `extent` means the outer image edges, so passing center
# coordinates there shifts the labels by half a pixel.
labels = ds['cell_labels'].values
masked = np.ma.masked_where(labels == 0, labels)
n_cells = int(np.count_nonzero(np.unique(labels)))  # distinct non-zero labels
axes[1].pcolormesh(ds.x.values, ds.y.values, masked,
                   cmap='tab20', alpha=0.8, shading='nearest')
axes[1].set_title(f'Cell Labels ({n_cells} cells)')
axes[1].set_xlabel('x (km)')
axes[1].set_ylabel('y (km)')

# Panel 3: heading vectors (every `step`-th pixel inside a cell) overlaid on
# reflectivity.
ds['reflectivity'].plot(ax=axes[2], cmap='gray', alpha=0.4, add_colorbar=False)
step = 10
Y, X = np.meshgrid(ds.y.values[::step], ds.x.values[::step], indexing='ij')
hx = ds['heading_x'].values[::step, ::step]
hy = ds['heading_y'].values[::step, ::step]
mask = ds['cell_labels'].values[::step, ::step] > 0
axes[2].quiver(X[mask], Y[mask], hx[mask], hy[mask], color='red', scale=50)
axes[2].set_title('Heading Vectors')

fig.suptitle(f"KLOT — {ds.attrs['scan_time']}", fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

### 2d. Read Series Data (SQLite → pd.DataFrame)

In [None]:
# List series products
series_products = api.list_series(
    product_type="analysis_sqlite",
    radar_id="KLOT",
)
print(f"Found {len(series_products)} analysis product(s)")

# Read cell statistics
df = api.read_series(
    series_products[0],
    table="cells",
    columns=["scan_time", "cell_label", "reflectivity_max", "area_km2"],
)
print(f"\nCell statistics: {len(df)} rows")
df.head(10)

In [None]:
# Read with filters: only cell 1
df_cell1 = api.read_series(
    series_products[0],
    table="cells",
    where={"cell_label": "1"},
)
print(f"Cell 1 records: {len(df_cell1)}")
df_cell1

In [None]:
# Read with time range filter
df_recent = api.read_series(
    series_products[0],
    table="cells",
    after="2025-06-15T18:10:00",
    before="2025-06-15T18:20:00",
)
print(f"Records in time window: {len(df_recent)}")
df_recent

### 2e. Visualize Timeseries

In [None]:
# Full timeseries for both cells
df_all = api.read_series(
    series_products[0],
    table="cells",
)
df_all['scan_time'] = pd.to_datetime(df_all['scan_time'])

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for cell_id, grp in df_all.groupby('cell_label'):
    axes[0].plot(grp['scan_time'], grp['reflectivity_max'], 'o-', label=f'Cell {cell_id}')
    axes[1].plot(grp['scan_time'], grp['area_km2'], 's-', label=f'Cell {cell_id}')

axes[0].set_ylabel('Max Reflectivity (dBZ)')
axes[0].set_title('Peak Reflectivity Over Time')
axes[0].legend()
axes[0].tick_params(axis='x', rotation=30)

axes[1].set_ylabel('Area (km²)')
axes[1].set_title('Cell Area Over Time')
axes[1].legend()
axes[1].tick_params(axis='x', rotation=30)

fig.suptitle('Cell Statistics Timeseries — KLOT', fontsize=14)
plt.tight_layout()
plt.show()

## 3. Subscription Model

The `Subscription` class enables poll-based notification: a downstream consumer
subscribes to product types and polls for new data. Each `poll()` returns only
products added since the last poll.

In [None]:
# Create a subscription
sub = api.subscribe(
    product_types=["segmented_netcdf"],
    radar_id="KLOT",
)

# First poll: gets all existing products
batch1 = sub.poll()
print(f"First poll: {len(batch1)} products")
for p in batch1:
    print(f"  {p.scan_time} — {Path(p.file_path).name}")

In [None]:
# Second poll: nothing new
batch2 = sub.poll()
print(f"Second poll: {len(batch2)} products (nothing new)")

# has_new() check
print(f"Has new: {sub.has_new()}")

In [None]:
# Simulate pipeline producing a new scan
new_time = base_time + timedelta(minutes=25)
new_ts = new_time.strftime('%Y%m%d_%H%M%S')
new_path = create_fake_segmented_netcdf(
    tmpdir / f"segmented/KLOT_{new_ts}_segmented.nc", new_time
)
store.register_product(
    product_type="segmented_netcdf",
    file_path=new_path,
    radar_id="KLOT",
    scan_time=new_time,
    producer_module="detection",
)

# Now poll picks up only the new one
print(f"Has new: {sub.has_new()}")
batch3 = sub.poll()
print(f"Third poll: {len(batch3)} new product(s)")
for p in batch3:
    print(f"  {p.scan_time} — {Path(p.file_path).name}")

### 3a. get_latest() — Peek Without Advancing the Cursor

In [None]:
# Look at the newest product visible to the subscription.
latest = sub.get_latest()
latest_name = Path(latest.file_path).name
print(f"Latest product: {latest.scan_time} — {latest_name}")

## 4. Multi-Product Subscription

Subscribe to multiple product types simultaneously. Useful for a script that
needs both segmentation masks and analysis data.

In [None]:
# A single subscription covering a grid product type and a series product type.
multi_sub = api.subscribe(
    radar_id="KLOT",
    product_types=["segmented_netcdf", "analysis_sqlite"],
)

products = multi_sub.poll()
print(f"Multi-subscription poll: {len(products)} products")
for prod in products:
    label = f"[{prod.product_type}] {prod.scan_time} — {Path(prod.file_path).name}"
    print(f"  {label}")

## 5. Real-World Usage Pattern

A typical visualization script would look like this:

```python
from adapt.data_access import DataAPI
import time

api = DataAPI.from_path("output/adapt.db")

sub = api.subscribe(
    product_types=["segmented_netcdf"],
    radar_id="KLOT",
)

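while True:
    for product in sub.poll():
        # Close each dataset after use so a long-running loop does not
        # accumulate open NetCDF file handles.
        with api.read_grid(product) as ds:
            plot(ds)  # your plotting function
    time.sleep(10)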
```

A timeseries analysis script:

```python
from adapt.data_access import DataAPI

api = DataAPI.from_path("output/adapt.db")

products = api.list_series(
    product_type="analysis_sqlite",
    radar_id="KLOT",
)

df = api.read_series(
    products[0],
    table="cells",
    columns=["scan_time", "cell_label", "reflectivity_max"],
)
```

## Cleanup

In [None]:
import shutil

# Release file handles before deleting the workspace. The dataset read in
# section 2b (`ds`) may still hold its NetCDF file open (xarray opens files
# lazily), and an open file blocks deletion on some platforms, notably Windows.
# Using globals().get() keeps this cell working if that cell was never run.
_open_ds = globals().get("ds")
if _open_ds is not None:
    _open_ds.close()

store.close()

# NOTE(review): DataAPI's interface is not shown in this notebook; close it
# too if it exposes close().
_api_close = getattr(globals().get("api"), "close", None)
if callable(_api_close):
    _api_close()

# Best-effort removal. Report leftovers instead of claiming success silently.
shutil.rmtree(tmpdir, ignore_errors=True)
if tmpdir.exists():
    print(f"Warning: could not fully remove {tmpdir}")
else:
    print(f"Cleaned up {tmpdir}")