# disRNN Result Access — Example Usage

This notebook demonstrates how to use `aind-disrnn-result-access` to query
W&B run metadata and download training artifacts.

**Prerequisites:**
- Install the package: `uv sync` or `pip install -e .`
- Set the `WANDB_API_KEY` environment variable (or run `wandb login`)
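Before you create the client, you can check that credentials are actually available. The sketch below uses a hypothetical helper, `find_wandb_api_key`, which is not part of `aind-disrnn-result-access`. It assumes the usual W&B lookup order: first the `WANDB_API_KEY` environment variable, then the `api.wandb.ai` entry that `wandb login` writes to `~/.netrc` by default.

```python
import netrc
import os
from pathlib import Path
from typing import Optional


def find_wandb_api_key(netrc_path: Optional[Path] = None) -> Optional[str]:
    """Return a configured W&B API key, or None (hypothetical helper)."""
    # 1. An explicit environment variable takes precedence.
    key = os.environ.get("WANDB_API_KEY")
    if key:
        return key
    # 2. Fall back to the netrc entry written by `wandb login` (assumed default location).
    path = netrc_path or Path.home() / ".netrc"
    if not path.exists():
        return None
    try:
        auth = netrc.netrc(str(path)).authenticators("api.wandb.ai")
    except netrc.NetrcParseError:
        return None
    # authenticators() returns (login, account, password); the API key is the password.
    return auth[2] if auth else None


if find_wandb_api_key() is None:
    print("No W&B credentials found: set WANDB_API_KEY or run `wandb login`.")
```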

In [None]:
# Automatically reload edited modules before each cell runs (useful while developing the package)
%load_ext autoreload
%autoreload 2

## 1. Initialize the Client

In [None]:
from aind_disrnn_result_access import WandbClient

client = WandbClient()  # defaults to entity="AIND-disRNN"

## 2. List Available Projects

In [None]:
projects = client.get_projects()
print("Available projects:")
for p in projects:
    print(f"  - {p}")

## 3. Browse Runs

List the runs in a project. The examples below also show how to filter runs and how to sort them once they are loaded into a DataFrame.

In [None]:
# Pick a project (change this to whichever project you want to explore)
PROJECT = "han_cpu_gpu_test"

In [None]:
runs = client.get_runs(project=PROJECT)
print(f"Found {len(runs)} runs in '{PROJECT}'")
for run in runs[:5]:  # show first 5
    print(f"  [{run.id}] {run.name} — state={run.state}")

### Filter runs

Pass a [MongoDB-style query](https://docs.wandb.ai/guides/runs/filter-runs) through the `filters` argument to select runs:

In [None]:
finished_runs = client.get_runs(
    project=PROJECT,
    filters={"state": "finished"},
)
print(f"Found {len(finished_runs)} finished runs")
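You can combine several conditions into one query. The sketch below assumes that `get_runs(filters=...)` forwards the dict unchanged to the W&B public API, as the `{"state": "finished"}` example above suggests. `build_run_filter` is a hypothetical helper. Addressing nested config values with a `config.<key>` path is also an assumption, so check it against the W&B filter documentation.

```python
from typing import Any, Dict, Iterable, Optional


def build_run_filter(
    state: Optional[str] = None,
    tags: Optional[Iterable[str]] = None,
    config_equals: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Combine common run conditions into one MongoDB-style query (hypothetical helper)."""
    clauses = []
    if state is not None:
        clauses.append({"state": state})
    if tags:
        # Match runs carrying at least one of the given tags.
        clauses.append({"tags": {"$in": list(tags)}})
    for key, value in (config_equals or {}).items():
        # Config values are addressed with a "config." prefix (assumption).
        clauses.append({f"config.{key}": value})
    if not clauses:
        return {}
    if len(clauses) == 1:
        return clauses[0]
    return {"$and": clauses}


query = build_run_filter(state="finished", config_equals={"data.batch_size": 32})
print(query)
# Assuming filters are passed through to W&B:
# client.get_runs(project=PROJECT, filters=query)
```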

### Get runs as DataFrame

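For easier analysis, you can fetch all runs as a pandas DataFrame with one row per run, similar to the runs table in the W&B web UI. Nested config and summary values are flattened into dotted columns such as `config.data.batch_size`: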

In [None]:
# Get runs as DataFrame with flattened config and summary
df = client.get_runs_dataframe(project=PROJECT)

print(f"DataFrame shape: {df.shape}")
print(f"\nColumns ({len(df.columns)} total):")
print(f"  Basic: {[c for c in df.columns if not c.startswith(('config.', 'summary.'))]}")
print(f"  Config: {[c for c in df.columns if c.startswith('config.')][:5]}...")
print(f"  Summary: {[c for c in df.columns if c.startswith('summary.')][:5]}...")

# Display first few rows with selected columns
display_cols = ['id', 'name', 'state', 'summary.likelihood', 'summary.final.val_loss']
available_cols = [c for c in display_cols if c in df.columns]
print("\nFirst 5 runs:")
df[available_cols].head()
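The dotted column names used here, such as `config.data.batch_size` and `summary.final.val_loss`, suggest that nested dictionaries are flattened by joining their keys with `.`. The following is a minimal sketch of that convention. `flatten_dict` is a hypothetical helper, not the package's own implementation, and the input values are toy numbers.

```python
from collections.abc import Mapping
from typing import Any, Dict


def flatten_dict(data: Mapping, prefix: str = "", sep: str = ".") -> Dict[str, Any]:
    """Flatten nested mappings into {"a.b.c": value} (hypothetical helper)."""
    flat: Dict[str, Any] = {}
    for key, value in data.items():
        # Prepend the parent path, e.g. "config" + "data" -> "config.data"
        name = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            # Recurse into nested sections such as config["data"]
            flat.update(flatten_dict(value, prefix=name, sep=sep))
        else:
            flat[name] = value
    return flat


# Toy values for illustration only, not results from a real run
config = {"data": {"batch_size": 32}}
summary = {"likelihood": 0.9, "final": {"val_loss": 0.4}}
row = {**flatten_dict(config, "config"), **flatten_dict(summary, "summary")}
print(row)
# {'config.data.batch_size': 32, 'summary.likelihood': 0.9, 'summary.final.val_loss': 0.4}
```

If you pass a list of such rows to `pandas.DataFrame`, you get one row per run with these column names. `pandas.json_normalize` applies the same dotted flattening (its default `sep` is `"."`).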

### DataFrame analysis examples

The DataFrame makes it easy to filter and analyze runs:

In [None]:
# Example: Filter runs by performance
if 'summary.likelihood' in df.columns:
    high_performers = df[df['summary.likelihood'] > 0.8]
    print(f"High-performing runs (likelihood > 0.8): {len(high_performers)}")
    print(high_performers[['id', 'name', 'summary.likelihood']].head())

# Example: Sort by validation loss
if 'summary.final.val_loss' in df.columns:
    best_runs = df.sort_values('summary.final.val_loss').head(3)
    print("\nTop 3 runs by validation loss:")
    print(best_runs[['id', 'name', 'summary.final.val_loss']])

# Example: Group by config parameter
if 'config.data.batch_size' in df.columns:
    print("\nRuns by batch size:")
    print(df.groupby('config.data.batch_size').size())

## 4. Inspect a Single Run

Get detailed metadata for a specific run.

In [None]:
# Use the first run from our list (or replace with a known run ID)
if runs:
    run = client.get_run(runs[0].id, project=PROJECT)
    print(f"Run: {run.name} ({run.id})")
    print(f"State: {run.state}")
    print(f"Tags: {run.tags}")
    print(f"Created: {run.created_at}")
    print(f"URL: {run.url}")
else:
    print("No runs found — adjust PROJECT above.")

### Run config (training hyperparameters)

In [None]:
if runs:
    print("Config:")
    for key, value in run.config.items():
        print(f"  {key}: {value}")
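The loop above prints each nested config section as a single long line. If `run.config` is a plain, JSON-like dictionary (an assumption), `json.dumps` with indentation gives a more readable view. `format_config` is a hypothetical helper. `default=str` is a fallback for values that are not JSON-serializable.

```python
import json
from typing import Any, Mapping


def format_config(config: Mapping[str, Any]) -> str:
    """Pretty-print a (possibly nested) config dict as indented JSON."""
    # sort_keys gives a stable order that is easy to diff between runs;
    # default=str converts non-JSON values (e.g. Path objects) to strings.
    return json.dumps(dict(config), indent=2, sort_keys=True, default=str)


print(format_config({"seed": 0, "data": {"batch_size": 32}}))
# In the notebook, assuming run.config is dict-like:
# print(format_config(run.config))
```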

### Run summary (final metrics)

In [None]:
if runs:
    print("Summary metrics:")
    for key, value in run.summary.items():
        print(f"  {key}: {value}")
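W&B run summaries usually also contain bookkeeping entries whose keys start with an underscore, such as `_step`, `_runtime` and `_timestamp`. The minimal filter below keeps only user-logged metrics. It assumes `run.summary` can be iterated like a dictionary, and `user_metrics` is a hypothetical helper.

```python
from typing import Any, Dict, Mapping


def user_metrics(summary: Mapping[str, Any]) -> Dict[str, Any]:
    """Drop W&B-internal summary entries (keys starting with "_")."""
    return {k: v for k, v in summary.items() if not k.startswith("_")}


example = {"_step": 999, "_runtime": 120.5, "likelihood": 0.9}  # toy values
print(user_metrics(example))  # {'likelihood': 0.9}
# In the notebook:
# for key, value in user_metrics(run.summary).items():
#     print(f"  {key}: {value}")
```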

## 5. Download Artifacts

Download training output artifacts (model parameters, plots, CSVs) for a run.

**Default location:** files are downloaded to `/root/capsule/results/downloaded_artifacts/<artifact_name>/`.

Note: artifact names typically contain the run ID (e.g., `disrnn-output-0q45cmry`).

In [None]:
if runs:
    artifacts = client.download_artifact(
        runs[0].id,
        project=PROJECT,
    )
    for art in artifacts:
        print(f"Artifact: {art.name} (type={art.type}, version={art.version})")
        print(f"  Downloaded to: {art.download_path}")
        print(f"  Files: {art.files}")
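Because artifact names embed the run ID, you can recover the run from a name and predict the default download folder. The sketch below makes several assumptions: names follow the `<prefix>-<run_id>` pattern of the example above, run IDs are W&B's default 8-character lowercase alphanumeric strings, and a `:vN` version suffix may be present. Both helper functions are hypothetical.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed naming pattern: "<prefix>-<8-char run ID>", optionally followed by ":v<N>"
ARTIFACT_NAME_RE = re.compile(r"^(?P<prefix>.+)-(?P<run_id>[a-z0-9]{8})(?::v\d+)?$")
# Default download root, as documented in this notebook
DEFAULT_ROOT = Path("/root/capsule/results/downloaded_artifacts")


def run_id_from_artifact_name(name: str) -> Optional[str]:
    """Return the run ID embedded in an artifact name, or None if it doesn't match."""
    match = ARTIFACT_NAME_RE.match(name)
    return match.group("run_id") if match else None


def default_download_dir(artifact_name: str, root: Path = DEFAULT_ROOT) -> Path:
    """Expected default location: <root>/<artifact_name>/."""
    return root / artifact_name


print(run_id_from_artifact_name("disrnn-output-0q45cmry"))  # 0q45cmry
print(default_download_dir("disrnn-output-0q45cmry"))
```

If `art.name` follows this pattern, you can compare `run_id_from_artifact_name(art.name)` against `runs[0].id` to confirm that an artifact belongs to the expected run.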

### Download specific files only

You can download only specific files instead of the entire artifact:

In [None]:
if runs:
    # Download only params.json
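    artifacts = client.download_artifact(
        runs[0].id,
        project=PROJECT,
        files=["params.json"],
    )
    # Loop instead of indexing [0] so a run without artifacts doesn't raise IndexError
    for art in artifacts:
        print(f"{art.name}: downloaded {len(art.files)} file(s): {art.files}")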
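When you request only some files, check that each one was actually present in the artifact. The sketch below assumes `art.files` is a list of file paths or names. The notebook only prints this attribute, so its exact type is an assumption. `missing_files` is a hypothetical helper, and `loss.csv` is a made-up file name.

```python
from pathlib import PurePath
from typing import Iterable, List


def missing_files(requested: Iterable[str], downloaded: Iterable[str]) -> List[str]:
    """Return requested file names that do not appear among the downloaded files."""
    # Compare by file name, so "params.json" also matches "subdir/params.json".
    got = {PurePath(str(f)).name for f in downloaded}
    return [name for name in requested if PurePath(name).name not in got]


print(missing_files(["params.json", "loss.csv"], ["params.json"]))  # ['loss.csv']
# In the notebook:
# missing_files(["params.json"], artifacts[0].files)
```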

## 6. Explore Downloaded Files

After downloading, artifacts are available as local files in `/root/capsule/results/downloaded_artifacts/`.

In [None]:
import json
from pathlib import Path

if runs:
    # Download all files using default settings
    artifacts = client.download_artifact(runs[0].id, project=PROJECT)
    if not artifacts:
        print("No artifacts found for this run.")
    else:
        # Wrap in Path so the `/` operator below also works if download_path is a str
        artifact_dir = Path(artifacts[0].download_path)

        print(f"Contents of {artifact_dir}:")
        for f in sorted(artifact_dir.rglob("*")):
            if f.is_file():
                # Show paths relative to the artifact folder so nested files stay distinguishable
                print(f"  {f.relative_to(artifact_dir)}")

        # Example: load params.json if it exists
        params_file = artifact_dir / "params.json"
        if params_file.exists():
            with open(params_file) as fh:
                params = json.load(fh)
            print(f"\nLoaded params.json with {len(params)} keys")
            for k, v in list(params.items())[:5]:
                print(f"  {k}: {v}")
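The training outputs can also include CSV files (see the list of artifact contents in section 5). The stdlib-only sketch below reads every CSV under a download directory into lists of row dictionaries. It assumes each CSV has a header row, and `load_csvs` is a hypothetical helper. For analysis, `pandas.read_csv` is the more usual choice.

```python
import csv
from pathlib import Path
from typing import Dict, List, Union


def load_csvs(artifact_dir: Union[str, Path]) -> Dict[str, List[Dict[str, str]]]:
    """Map each CSV's path (relative to artifact_dir) to its rows as dicts."""
    root = Path(artifact_dir)
    tables: Dict[str, List[Dict[str, str]]] = {}
    for path in sorted(root.rglob("*.csv")):
        with open(path, newline="") as fh:
            # DictReader uses the first row as column names; values stay strings.
            tables[path.relative_to(root).as_posix()] = list(csv.DictReader(fh))
    return tables


# In the notebook, after the cell above:
# for name, rows in load_csvs(artifact_dir).items():
#     print(f"{name}: {len(rows)} rows")
```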