# disRNN Result Access — Example Usage

This notebook demonstrates how to use `aind-disrnn-result-access` to query
W&B run metadata and download training artifacts.

**Prerequisites:**
- Install the package: `uv sync` or `pip install -e .`
- Set the `WANDB_API_KEY` environment variable (or run `wandb login`)
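
Before creating the client, you may want to confirm that credentials are configured. The helper below is a minimal sketch. `wandb_credentials_available` is a hypothetical name, and the `~/.netrc` fallback assumes that `wandb login` stores the key there under the host `api.wandb.ai`. That is common, but newer wandb versions may store it elsewhere.

```python
import os
from pathlib import Path


def wandb_credentials_available() -> bool:
    """Return True if W&B credentials appear to be configured.

    Assumption: `wandb login` writes the key to ~/.netrc under the host
    api.wandb.ai; other wandb versions may store it elsewhere.
    """
    # The environment variable is the most direct signal.
    if os.environ.get("WANDB_API_KEY"):
        return True
    # Fallback: a simple textual check of ~/.netrc for the W&B host.
    netrc_path = Path.home() / ".netrc"
    if netrc_path.is_file():
        return "api.wandb.ai" in netrc_path.read_text()
    return False


print("W&B credentials found:", wandb_credentials_available())
```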

In [1]:
# Reload edited modules automatically, so package changes take effect without restarting the kernel
%load_ext autoreload
%autoreload 2

## 1. Initialize the Client

In [2]:
from aind_disrnn_result_access import WandbClient

client = WandbClient()  # defaults to entity="AIND-disRNN"

[34m[1mwandb[0m: [wandb.Api()] Loaded credentials for https://api.wandb.ai from WANDB_API_KEY.


## 2. List Available Projects

In [3]:
projects = client.get_projects()
print("Available projects:")
for p in projects:
    print(f"  - {p}")

Available projects:
  - alex_test
  - pochen_mice_multisubject
  - rachel_mice_grurnn_parascan
  - debug-synthetic_task_trained_rnn
  - synthetic_task_trained_rnn
  - han_synthetic_rl_disrnn
  - han_mice_disrnn
  - han_cpu_gpu_test
  - han_mice_disrnn_bottleneck_overtime
  - han_mice_disrnn_parascan
  - test
  - han-test
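
With many projects, it can help to narrow the list by prefix. The sketch below uses `str(p)`. The printout above shows that each project's string form is its name, so this works whether `get_projects()` returns plain strings or project objects. `projects_matching` is a hypothetical helper, not part of the package.

```python
def projects_matching(projects, prefix: str) -> list[str]:
    """Return the names of projects whose name starts with `prefix`."""
    return [str(p) for p in projects if str(p).startswith(prefix)]


# A few names from the listing above; note "han-test" uses a hyphen, not an underscore
names = ["alex_test", "han_mice_disrnn", "han_cpu_gpu_test", "han-test", "test"]
print(projects_matching(names, "han_"))
```

In the notebook, call it as `projects_matching(client.get_projects(), "han_")`.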


## 3. Browse Runs

List runs in a project. You can filter and sort them.

In [4]:
# Pick a project (change this to whichever project you want to explore)
PROJECT = "han_cpu_gpu_test"

In [5]:
runs = client.get_runs(project=PROJECT)
print(f"Found {len(runs)} runs in '{PROJECT}'")
for run in runs[:5]:  # show first 5
    print(f"  [{run.id}] {run.name} — state={run.state}")

Found 73 runs in 'han_cpu_gpu_test'
  [0q45cmry] 1gpu4cpu64G_bs2048_wu500_5_3_4_0_beta0.001_lr0.005_grad_clip1 — state=finished
  [lzij3ybo] 1gpu4cpu64G_bs256_wu500_5_3_4_0_beta0.001_lr0.005_grad_clip1 — state=finished
  [0ec4vqq6] 1gpu4cpu64G_bs1024_wu500_5_3_4_0_beta0.001_lr0.005_grad_clip1 — state=finished
  [cju5e411] 1gpu4cpu64G_bs2048_wu500_5_3_4_0_beta0.001_lr0.005_grad_clip1 — state=finished
  [1mqgpxrq] 1gpu4cpu64G_bs128_wu500_5_3_4_0_beta0.001_lr0.005_grad_clip1 — state=finished
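
The run names in this project appear to encode hyperparameters as tokens (for example, `bs2048` matches `data.batch_size: 2048` in the config of run `0q45cmry` shown in Section 4). The sketch below extracts those tokens. The token meanings (`bs` for batch size, `lr` for learning rate, `wu`, `beta`, `grad_clip`) are inferred from the names alone and are assumptions. `run.config` remains the authoritative source. `parse_run_name` is a hypothetical helper.

```python
import re

# Hypothetical naming convention inferred from the run names printed above,
# e.g. "1gpu4cpu64G_bs2048_wu500_5_3_4_0_beta0.001_lr0.005_grad_clip1".
# Each entry maps a token name to (regex, type used to convert the value).
_TOKENS = {
    "bs": (r"_bs(\d+)", int),
    "wu": (r"_wu(\d+)", int),
    "beta": (r"_beta([0-9.]+)", float),
    "lr": (r"_lr([0-9.]+)", float),
    "grad_clip": (r"_grad_clip([0-9.]+)", float),
}


def parse_run_name(name: str) -> dict:
    """Extract known hyperparameter tokens from a run name; missing tokens are omitted."""
    parsed = {}
    for key, (pattern, cast) in _TOKENS.items():
        match = re.search(pattern, name)
        if match:
            parsed[key] = cast(match.group(1))
    return parsed


print(parse_run_name("1gpu4cpu64G_bs2048_wu500_5_3_4_0_beta0.001_lr0.005_grad_clip1"))
# {'bs': 2048, 'wu': 500, 'beta': 0.001, 'lr': 0.005, 'grad_clip': 1.0}
```

For example, `{r.id: parse_run_name(r.name).get("bs") for r in runs}` maps each run ID to its batch size, which makes it easy to group the CPU/GPU comparison runs.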


### Filter runs

Use [MongoDB-style queries](https://docs.wandb.ai/guides/runs/filter-runs) to filter runs.

In [6]:
finished_runs = client.get_runs(
    project=PROJECT,
    filters={"state": "finished"},
)
print(f"Found {len(finished_runs)} finished runs")

Found 45 finished runs
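
Several conditions can be combined into one filter dict. The sketch below assumes two things. First, it assumes that `get_runs(filters=...)` passes the dict through to W&B unchanged. Second, it assumes that W&B accepts the `$and`, `$in`, and `$regex` operators and the `display_name` field, as described in its run-filtering docs. Check both against your versions. `build_run_filters` is a hypothetical helper.

```python
def build_run_filters(state=None, tags=None, name_regex=None):
    """Combine optional conditions into one MongoDB-style W&B filter dict.

    Assumes W&B's "$and", "$in" and "$regex" operators and its
    "display_name" run field.
    """
    conditions = []
    if state is not None:
        conditions.append({"state": state})
    if tags:
        # Match runs that carry at least one of the given tags
        conditions.append({"tags": {"$in": list(tags)}})
    if name_regex is not None:
        # Match on the run's display name, e.g. "_bs2048_" for batch size 2048
        conditions.append({"display_name": {"$regex": name_regex}})
    if not conditions:
        return {}
    if len(conditions) == 1:
        return conditions[0]
    return {"$and": conditions}


print(build_run_filters(state="finished", name_regex="_bs2048_"))
# {'$and': [{'state': 'finished'}, {'display_name': {'$regex': '_bs2048_'}}]}
```

Usage: `client.get_runs(project=PROJECT, filters=build_run_filters(state="finished", tags=["cpu"]))`.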


## 4. Inspect a Single Run

Get detailed metadata for a specific run.

In [7]:
# Use the first run from our list (or replace with a known run ID)
if runs:
    run = client.get_run(runs[0].id, project=PROJECT)
    print(f"Run: {run.name} ({run.id})")
    print(f"State: {run.state}")
    print(f"Tags: {run.tags}")
    print(f"Created: {run.created_at}")
    print(f"URL: {run.url}")
else:
    print("No runs found — adjust PROJECT above.")

Run: 1gpu4cpu64G_bs2048_wu500_5_3_4_0_beta0.001_lr0.005_grad_clip1 (0q45cmry)
State: finished
Tags: ['batch_size', 'cpu', 'disrnn', 'synthetic']
Created: 2026-02-06T22:20:05Z
URL: https://wandb.ai/AIND-disRNN/han_cpu_gpu_test/runs/0q45cmry
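
If you have a run URL copied from the browser instead of a run ID, you can split it into its parts. This sketch assumes the `https://wandb.ai/<entity>/<project>/runs/<run_id>` layout shown in the output above. `parse_run_url` is a hypothetical helper.

```python
from urllib.parse import urlparse


def parse_run_url(url: str) -> tuple[str, str, str]:
    """Split a W&B run URL into (entity, project, run_id).

    Query strings and extra trailing path segments are ignored.
    """
    parts = [p for p in urlparse(url).path.split("/") if p]
    if len(parts) < 4 or parts[2] != "runs":
        raise ValueError(f"Unrecognized W&B run URL: {url}")
    entity, project, _, run_id = parts[:4]
    return entity, project, run_id


print(parse_run_url("https://wandb.ai/AIND-disRNN/han_cpu_gpu_test/runs/0q45cmry"))
# ('AIND-disRNN', 'han_cpu_gpu_test', '0q45cmry')
```

The project and run ID can then be passed to `client.get_run(run_id, project=project)`, as in the cell above.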


### Run config (training hyperparameters)

In [8]:
if runs:
    print("Config:")
    for key, value in run.config.items():
        print(f"  {key}: {value}")

Config:
  data: {'seed': 0, 'task': {'mean': 0, 'seed': 0, 'type': 'random_walk', 'p_max': 1, 'p_min': 0, 'sigma': 0.15, 'num_trials': 500, 'reward_baiting': True}, 'type': 'synthetic', 'agent': {'seed': 0, 'type': 'ForagerQLearning', 'agent_class': 'ForagerQLearning', 'agent_kwargs': {'choice_kernel': 'none', 'action_selection': 'softmax', 'number_of_forget_rate': 0, 'number_of_learning_rate': 1}, 'agent_params': {'learn_rate': 0.5, 'softmax_inverse_temperature': 10}, 'loader_target': 'data_loaders.synthetic.SyntheticCognitiveAgents', 'agent_params_session_var': {'biasL': {'max': 3, 'min': -3, 'type': 'uniform'}}}, '_target_': 'data_loaders.synthetic.SyntheticCognitiveAgents', 'batch_mode': 'random', 'batch_size': 2048, 'num_trials': 500, 'eval_every_n': 2, 'num_sessions': 1000, 'run_name_component': 'synthetic_ForagerQLearning_random_walk'}
  model: {'seed': 0, 'type': 'disrnn', '_target_': 'model_trainers.disrnn_trainer.DisrnnTrainer', 'training': {'lr': 0.005, 'loss': 'penalized_ca
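
The config is deeply nested (for example, `data.task.sigma`). That makes runs hard to compare side by side. Flattening it into dotted keys is one option. Below is a minimal sketch with a hypothetical `flatten_config` helper, applied to a trimmed copy of the config printed above.

```python
from collections.abc import Mapping


def flatten_config(config, prefix=""):
    """Flatten nested mappings into dotted keys.

    Example: {"data": {"task": {"sigma": 0.15}}} -> {"data.task.sigma": 0.15}.
    Empty nested mappings produce no keys.
    """
    flat = {}
    for key, value in config.items():
        full_key = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            flat.update(flatten_config(value, full_key))
        else:
            flat[full_key] = value
    return flat


# A trimmed copy of the config printed above
example = {
    "data": {"seed": 0, "batch_size": 2048, "task": {"type": "random_walk", "sigma": 0.15}},
    "model": {"seed": 0, "type": "disrnn"},
}
for key, value in flatten_config(example).items():
    print(f"{key}: {value}")
```

With real runs, `flatten_config(run.config)` yields keys such as `data.batch_size` and `model.type`. These keys can serve as columns when tabulating many runs.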

### Run summary (final metrics)

In [9]:
if runs:
    print("Summary metrics:")
    for key, value in run.summary.items():
        print(f"  {key}: {value}")

Summary metrics:
  _runtime: 4008
  _step: 5504
  _timestamp: 1770420397.23481
  _wandb: {'runtime': 4008}
  elapsed_seconds: 3581.5833921432495
  fig/bottlenecks: {'_type': 'image-file', 'format': 'png', 'height': 500, 'path': 'media/images/fig/bottlenecks_5500_c0f37d17c21160d7cd00.png', 'sha256': 'c0f37d17c21160d7cd002c18c7e20e0731515207728393dfcabb0d13f0b6ddf7', 'size': 60570, 'width': 1500}
  fig/choice_rule: {'_type': 'image-file', 'format': 'png', 'height': 480, 'path': 'media/images/fig/choice_rule_5501_207bb5ad61bd2fec6d51.png', 'sha256': '207bb5ad61bd2fec6d51f34a354c4984492f6d79de1c6936e536f3c8e6aaa350', 'size': 25705, 'width': 640}
  fig/update_rule_0: {'_type': 'image-file', 'format': 'png', 'height': 550, 'path': 'media/images/fig/update_rule_0_5502_f8aee37072a6f5891b2a.png', 'sha256': 'f8aee37072a6f5891b2a12fab0d017bb7e1c91657bd6ab8583de429bac8673d7', 'size': 127483, 'width': 1800}
  fig/update_rule_1: {'_type': 'image-file', 'format': 'png', 'height': 550, 'path': 'media/
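
The summary mixes W&B bookkeeping keys (`_runtime`, `_step`, ...), numeric metrics, and `image-file` records for figures. The sketch below keeps only numeric, user-logged values. It assumes that the values are plain Python numbers and dicts, as in the printout. `scalar_metrics` is a hypothetical helper.

```python
from numbers import Number


def scalar_metrics(summary):
    """Return numeric, user-logged entries from a run summary.

    Skips keys starting with "_" (W&B bookkeeping such as _runtime and _step)
    and non-numeric values such as the image-file records under fig/*.
    Booleans are excluded even though bool is a subclass of int.
    """
    return {
        key: value
        for key, value in summary.items()
        if not str(key).startswith("_")
        and isinstance(value, Number)
        and not isinstance(value, bool)
    }


# Values copied from the summary printed above
example = {
    "_runtime": 4008,
    "_step": 5504,
    "elapsed_seconds": 3581.5833921432495,
    "fig/bottlenecks": {"_type": "image-file", "format": "png"},
}
print(scalar_metrics(example))
# {'elapsed_seconds': 3581.5833921432495}
```

In the notebook, call it as `scalar_metrics(run.summary)`.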

## 5. Download Artifacts

Download training output artifacts (model parameters, plots, CSVs) for a run.

**Default behavior:** Files are downloaded into a subdirectory named after the artifact, under `/root/capsule/results/artifacts/<run_id>/`.

In [18]:
if runs:
    artifacts = client.download_artifact(
        runs[0].id,
        project=PROJECT,
    )
    for art in artifacts:
        print(f"Artifact: {art.name} (type={art.type}, version={art.version})")
        print(f"  Downloaded to: {art.download_path}")
        print(f"  Files: {art.files}")

[34m[1mwandb[0m: Downloading large artifact 'disrnn-output-0q45cmry:v0', 83.80MB. 10 files...
[34m[1mwandb[0m:   10 of 10 files downloaded.  
Done. 00:00:00.2 (510.1MB/s)


Artifact: disrnn-output-0q45cmry:v0 (type=training-output, version=v0)
  Downloaded to: /root/capsule/results/artifacts/0q45cmry/disrnn-output-0q45cmry:v0
  Files: ['bottlenecks.png', 'choice_rule.png', 'output_df.csv', 'output_summary.json', 'params.json', 'update_rule_0.png', 'update_rule_1.png', 'update_rule_2.png', 'validation.png', 'warmup_validation.png']


### Download specific files only

You can download only specific files instead of the entire artifact:

In [None]:
if runs:
    # Download only params.json
    artifacts = client.download_artifact(
        runs[0].id,
        project=PROJECT,
        files=["params.json"]
    )
    print(f"Downloaded {len(artifacts[0].files)} file(s): {artifacts[0].files}")
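
To fetch a group of files, you can select names by glob pattern from the file list reported by a full download (Section 5 output above). This assumes two things. First, that `files=` accepts several names, as its list type suggests. Second, that names are relative to the artifact root. `select_files` is a hypothetical helper.

```python
from fnmatch import fnmatch


def select_files(available, patterns):
    """Return the names in `available` that match any of the glob `patterns`, in original order."""
    return [name for name in available if any(fnmatch(name, pat) for pat in patterns)]


# File list as reported by the full download above
available = [
    "bottlenecks.png", "choice_rule.png", "output_df.csv", "output_summary.json",
    "params.json", "update_rule_0.png", "update_rule_1.png", "update_rule_2.png",
    "validation.png", "warmup_validation.png",
]
wanted = select_files(available, ["update_rule_*.png", "*.json"])
print(wanted)
```

The result can then be passed as `client.download_artifact(runs[0].id, project=PROJECT, files=wanted)`.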

### Custom output directory

You can pass a custom base directory with `output_dir`; the `run_id` is appended automatically:

In [None]:
if runs:
    artifacts = client.download_artifact(
        runs[0].id,
        project=PROJECT,
        output_dir="/root/capsule/data",  # Will download to /root/capsule/data/<run_id>/
    )
    print(f"Downloaded to: {artifacts[0].download_path}")
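
In this notebook's output, the artifact landed in `<output_dir>/<run_id>/<artifact name>` (for example, `.../0q45cmry/disrnn-output-0q45cmry:v0`). Predicting that path lets you skip repeated downloads. The sketch below treats the layout and the `disrnn-output-<run_id>:v0` artifact name as observations from a single run, not documented guarantees. Both helpers are hypothetical.

```python
from pathlib import Path


def expected_artifact_dir(output_dir, run_id, artifact_name):
    """Predict <output_dir>/<run_id>/<artifact_name>, the layout observed in this notebook."""
    return Path(output_dir) / run_id / artifact_name


def already_downloaded(path):
    """True if `path` is a directory containing at least one file (searched recursively)."""
    path = Path(path)
    return path.is_dir() and any(p.is_file() for p in path.rglob("*"))


target = expected_artifact_dir(
    "/root/capsule/results/artifacts", "0q45cmry", "disrnn-output-0q45cmry:v0"
)
print(target, "exists:", already_downloaded(target))
```

A guard such as `if not already_downloaded(target): client.download_artifact(...)` then avoids fetching the same files twice.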

## 6. Explore Downloaded Files

After downloading, artifacts are available as local files, by default under `/root/capsule/results/artifacts/<run_id>/`.

In [None]:
import json
from pathlib import Path

if runs:
    # Download all files for demonstration
    artifacts = client.download_artifact(runs[0].id, project=PROJECT)
    # Wrap in Path so the "/" joins below work whether download_path is a str or a Path
    artifact_dir = Path(artifacts[0].download_path)

    print(f"Contents of {artifact_dir}:")
    for f in sorted(artifact_dir.rglob("*")):
        if f.is_file():
            print(f"  {f.name}")

    # Example: load params.json if it exists
    params_file = artifact_dir / "params.json"
    if params_file.exists():
        with open(params_file) as fh:
            params = json.load(fh)
        print(f"\nLoaded params.json with {len(params)} keys")
        for k, v in list(params.items())[:5]:
            print(f"  {k}: {v}")
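
Besides `params.json`, the artifact file list in Section 5 includes `output_summary.json` and `output_df.csv`. The stdlib-only sketch below loads both. It makes no assumptions about their columns or keys, since their contents are not shown here. `load_training_outputs` is a hypothetical helper. If pandas is installed, `pd.read_csv` on `output_df.csv` is a convenient alternative to `csv.DictReader`.

```python
import csv
import json
from pathlib import Path


def load_training_outputs(artifact_dir):
    """Load output_summary.json and output_df.csv from a downloaded artifact directory.

    Returns (summary, rows); either is None if its file is missing.
    Rows are dicts of strings, as produced by csv.DictReader (no type conversion).
    """
    artifact_dir = Path(artifact_dir)
    summary_path = artifact_dir / "output_summary.json"
    csv_path = artifact_dir / "output_df.csv"

    summary = None
    if summary_path.is_file():
        with summary_path.open() as fh:
            summary = json.load(fh)

    rows = None
    if csv_path.is_file():
        with csv_path.open(newline="") as fh:
            rows = list(csv.DictReader(fh))

    return summary, rows
```

Usage: `summary, rows = load_training_outputs(artifacts[0].download_path)`.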