# APRA Orchestrator

This notebook orchestrates resilient APRA experiment runs, supports memory-efficient sketching, detached background execution, checkpointing and resume logic, shadow evaluation, analysis and plotting. Run cells sequentially. Use a conda environment with TensorFlow installed.


## Quick Start

- Ensure your conda env with TensorFlow is active: `conda activate base` (or your chosen env).
- Execute cells top-to-bottom.
- The notebook will monkey-patch `fl_helpers.random_projection_sketch` with a chunked implementation to avoid large dense matrices in memory.


In [1]:
# Cell-1 Section 1 Install Dependencies

# Install commonly used packages for orchestration and monitoring. Run this cell once.
import sys
import subprocess
import pkg_resources

packages = [
    "psutil",
    "nbclient",
    "nbformat",
    "tqdm",
    "dask",
    "pytest",
    "pandas",
    "matplotlib"
]

def ensure_packages(pkgs):
    to_install = []
    for p in pkgs:
        try:
            pkg_resources.get_distribution(p)
        except Exception:
            to_install.append(p)
    if to_install:
        print('Installing:', to_install)
        subprocess.check_call([sys.executable, "-m", "pip", "install"] + to_install)
    else:
        print('All packages present')

ensure_packages(packages)

# Verify imports
import platform, os
print('Python', platform.python_version())
print('Platform', platform.platform())
print('CWD', os.getcwd())
import psutil
print('psutil OK, CPU count:', psutil.cpu_count(logical=False))


All packages present
Python 3.13.9
Platform Windows-11-10.0.26100-SP0
CWD c:\Users\rravi\FL_Improvements_Research\submission_package\notebooks
psutil OK, CPU count: 10


  import pkg_resources


In [2]:
# Cell-2 Section 2 Environment Diagnostics

import json, platform, shutil, os
from datetime import datetime
import psutil

diag = {
    'python_version': platform.python_version(),
    'platform': platform.platform(),
    'cwd': os.getcwd(),
    'cpu_count_logical': psutil.cpu_count(logical=True),
    'cpu_count_physical': psutil.cpu_count(logical=False),
    'memory_total_GB': round(psutil.virtual_memory().total/1024**3,2),
    'disk_free_GB': round(shutil.disk_usage(os.getcwd()).free/1024**3,2),
    'time': datetime.utcnow().isoformat()+'Z'
}

# GPU info (if available)
try:
    import tensorflow as tf
    gpus = tf.config.list_physical_devices('GPU')
    diag['tensorflow_version'] = tf.__version__
    diag['gpus'] = [str(g) for g in gpus]
except Exception as e:
    diag['tensorflow_error'] = str(e)

with open('diagnostics.json','w') as f:
    json.dump(diag,f,indent=2)

print(json.dumps(diag, indent=2))


  'time': datetime.utcnow().isoformat()+'Z'


{
  "python_version": "3.13.9",
  "platform": "Windows-11-10.0.26100-SP0",
  "cwd": "c:\\Users\\rravi\\FL_Improvements_Research\\submission_package\\notebooks",
  "cpu_count_logical": 12,
  "cpu_count_physical": 10,
  "memory_total_GB": 7.7,
  "disk_free_GB": 242.85,
  "time": "2025-12-12T02:01:14.227091Z",
  "tensorflow_version": "2.20.0",
  "gpus": []
}


In [3]:
# Add parent directory to path to import fl_helpers and scripts
import sys
import os
notebook_dir = os.path.dirname(os.path.abspath('apra_orchestrator.ipynb'))
project_root = os.path.dirname(os.path.dirname(notebook_dir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Now import fl_helpers
try:
    from fl_helpers import aggregate_dispatcher
    print('✓ fl_helpers imported successfully')
except ImportError as e:
    print(f'Warning: Could not import fl_helpers: {e}')
    print('This is optional for orchestrator notebook.')
    aggregate_dispatcher = None

# Notebook aggregator wrapper for orchestrator
ORCHESTRATOR_AGGREGATOR = globals().get('ORCHESTRATOR_AGGREGATOR', 'farpa')

if aggregate_dispatcher:
    def federated_averaging(local_weights_list):
        agg, meta = aggregate_dispatcher(ORCHESTRATOR_AGGREGATOR, local_weights_list, sketch_dim_per_layer=64, n_sketches=2, eps_sketch=1.0, z_thresh=3.0, seed=0)
        globals()['ORCH_LAST_AGG_META'] = meta
        return agg
    
    print(f'Orchestrator federated_averaging dispatches to: {ORCHESTRATOR_AGGREGATOR}')
else:
    print('Orchestrator aggregator wrapper skipped (fl_helpers not available)')

✓ fl_helpers imported successfully
Orchestrator federated_averaging dispatches to: farpa


In [4]:
# Cell-3 Section 3 Memory-efficient sketch

# This cell defines a chunked random projection sketch and monkey-patches fl_helpers.random_projection_sketch.
# It avoids building the full (D x k) matrix at once by processing the flattened weight vector in chunks.

import numpy as np
import importlib

chunk_default = 10000

def random_projection_sketch_sparse(weights, sketch_dim, seed=None, chunk_size=chunk_default):
    # Accept either a list of ndarrays (weights per layer) or a single ndarray
    if isinstance(weights, (list, tuple)):
        vec = np.concatenate([w.flatten() for w in weights]).astype(np.float32)
    else:
        vec = np.asarray(weights).astype(np.float32).flatten()
    n = vec.size
    rng = np.random.RandomState(seed)
    sk = np.zeros(sketch_dim, dtype=np.float32)
    # scale factor computed from full vector length for normalization stability
    scale = 1.0 / np.sqrt(float(max(1, n)))
    for i in range(0, n, chunk_size):
        c = vec[i:i+chunk_size]
        # create projection for this chunk only
        proj = rng.normal(loc=0.0, scale=scale, size=(c.size, sketch_dim)).astype(np.float32)
        # accumulate c^T * proj -> shape (sketch_dim,)
        # using dot product c (shape (m,)) dot proj (m,k) -> (k,)
        sk += c.dot(proj)
    # final normalization
    norm = np.linalg.norm(sk)
    if norm > 0:
        sk /= norm
    return sk

# Try to patch fl_helpers at runtime
try:
    import fl_helpers
    importlib.reload(fl_helpers)
    fl_helpers.random_projection_sketch = random_projection_sketch_sparse
    print('Monkey-patched fl_helpers.random_projection_sketch with chunked implementation')
except Exception as e:
    print('Could not import fl_helpers to patch:', e)

# Quick sanity test on a synthetic vector
try:
    vec = np.random.randn(87050).astype(np.float32)
    sk = random_projection_sketch_sparse(vec, sketch_dim=64, seed=42, chunk_size=5000)
    print('Synthetic sketch shape:', sk.shape, 'norm:', np.linalg.norm(sk))
except Exception as e:
    print('Synthetic test failed:', e)


Monkey-patched fl_helpers.random_projection_sketch with chunked implementation
Synthetic sketch shape: (64,) norm: 1.0


In [5]:
# Cell-5 Section 4 Utilities

import subprocess, shlex, os, glob
import pandas as pd
from datetime import datetime

def run_subprocess(cmd, log_path=None, detached=False):
    """Run a shell command. If detached on Windows, use creationflags to detach.
    Returns subprocess.Popen object."""
    if isinstance(cmd, (list,tuple)):
        shell_cmd = cmd
    else:
        shell_cmd = cmd
    if log_path:
        logf = open(log_path, 'ab')
    else:
        logf = subprocess.PIPE
    if os.name == 'nt' and detached:
        # DETACHED_PROCESS flag
        creationflags = 0x00000008
        p = subprocess.Popen(shell_cmd, stdout=logf, stderr=subprocess.STDOUT, shell=True, creationflags=creationflags)
    else:
        p = subprocess.Popen(shell_cmd, stdout=logf, stderr=subprocess.STDOUT, shell=True)
    return p


def list_runs(outdir='apra_mnist_runs_full'):
    if not os.path.isdir(outdir):
        return []
    subs = [d for d in os.listdir(outdir) if os.path.isdir(os.path.join(outdir,d))]
    return subs


def read_results_csv(path):
    if not os.path.exists(path):
        return None
    df = pd.read_csv(path)
    return df

print('Utilities loaded run_subprocess, list_runs, read_results_csv')


Utilities loaded run_subprocess, list_runs, read_results_csv


In [6]:
# Cell-6 Section 5 Detached / External Runner & Retry Executor

import os, subprocess, time, threading, sys
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeout
import logging

logging.basicConfig(level=logging.INFO)

def run_detached(cmd, log_path=None):
    """Run a command detached from the notebook/VS Code process.
    On Windows, uses DETACHED_PROCESS; on POSIX attempts setsid.
    Returns subprocess.Popen object. Child processes will have the project
    root added to `PYTHONPATH` so they can import sibling modules like `fl_helpers`.
    """
    if isinstance(cmd, (list,tuple)):
        cmd_str = ' '.join([str(c) for c in cmd])
    else:
        cmd_str = str(cmd)
    if log_path:
        logf = open(log_path, 'ab')
    else:
        logf = subprocess.PIPE

    # Ensure child processes can import project-level modules by setting PYTHONPATH
    env = os.environ.copy()
    cwd = os.getcwd()
    prev = env.get('PYTHONPATH', '')
    if prev:
        env['PYTHONPATH'] = cwd + os.pathsep + prev
    else:
        env['PYTHONPATH'] = cwd

    if os.name == 'nt':
        # DETACHED_PROCESS
        creationflags = 0x00000008
        p = subprocess.Popen(cmd_str, stdout=logf, stderr=subprocess.STDOUT, shell=True, creationflags=creationflags, env=env)
    else:
        # POSIX: setsid to detach
        p = subprocess.Popen(cmd_str, stdout=logf, stderr=subprocess.STDOUT, shell=True, preexec_fn=os.setsid, env=env)
    logging.info('Launched detached: pid=%s cmd=%s', getattr(p,'pid',None), cmd_str)
    return p


def run_with_timeout_and_retries(func, args=(), timeout=600, max_retries=3, backoff=5):
    """Run callable with timeout and retry on exception. Returns func result or raises.
    Persists success after each try by returning result; caller should persist outputs as needed.
    """
    attempt = 0
    while attempt < max_retries:
        attempt += 1
        try:
            with ThreadPoolExecutor(max_workers=1) as ex:
                fut = ex.submit(func, *args)
                res = fut.result(timeout=timeout)
                logging.info('Attempt %d succeeded', attempt)
                return res
        except FutureTimeout:
            logging.warning('Attempt %d timed out after %ds', attempt, timeout)
        except Exception as e:
            logging.exception('Attempt %d failed: %s', attempt, e)
        if attempt < max_retries:
            sleep = backoff * attempt
            logging.info('Sleeping %ds before retry %d', sleep, attempt+1)
            time.sleep(sleep)
    raise RuntimeError('All attempts failed')

print('Detached runner and retry executor ready')


Detached runner and retry executor ready


In [7]:
# Cell-7 Section 6 Resource Monitor and Health Checks

import threading, time, json
import psutil

class ResourceMonitor(threading.Thread):
    def __init__(self, interval=10, out_path='health.jsonl', cpu_thresh=90, mem_thresh=90):
        super().__init__(daemon=True)
        self.interval = interval
        self.out_path = out_path
        self.cpu_thresh = cpu_thresh
        self.mem_thresh = mem_thresh
        self._stop = threading.Event()

    def run(self):
        with open(self.out_path, 'a') as f:
            while not self._stop.is_set():
                try:
                    cpu = psutil.cpu_percent(interval=1)
                    mem = psutil.virtual_memory().percent
                    record = {
                        'time': time.time(),
                        'cpu_percent': cpu,
                        'mem_percent': mem,
                        'cpu_warn': cpu > self.cpu_thresh,
                        'mem_warn': mem > self.mem_thresh
                    }
                    f.write(json.dumps(record) + '\n')
                    f.flush()
                    time.sleep(self.interval)
                except Exception as e:
                    print(f'ResourceMonitor error: {e}')
                    break

    def stop(self):
        self._stop.set()

print('ResourceMonitor class defined')


ResourceMonitor class defined


In [8]:

# Cell-8 Section 7 Sweep Orchestrator & Resume Capability

import itertools

# Parameter builder and command generator

def build_commands(outdir='apra_mnist_runs_full', sketch_dims=[64,128], n_sketches=[1,2], z_thresh=[2.0,3.0], rounds=25, local_epochs=3, batch_size=32, clients=100, attack='none'):
    """Build commands for full parallel grid sweep.
    
    OPTIMIZATION: Generates 4 commands per grid (one per aggregator method),
    all launched in parallel instead of sequentially within the training script.
    This reduces total wall time from ~4x to ~1x (plus minimal overhead).
    
    Total commands: 8 grids Ã— 4 aggregators = 32 processes running in parallel.
    """
    cmds = []
    agg_methods = ['apra_weighted', 'apra_basic', 'trimmed', 'median']
    for sd, ns, zt in itertools.product(sketch_dims, n_sketches, z_thresh):
        run_dir = os.path.join(outdir, f'sd{sd}_ns{ns}_zt{zt}')
        # Generate 4 commands per grid: one for each aggregator method
        # These will be launched in parallel (not sequentially within run_apra_mnist_full.py)
        for agg_method in agg_methods:
            # use unbuffered Python (-u) so stdout/stderr are flushed to run.log promptly
            # Use run_apra_mnist_full.py which integrates APRA, APS, privacy, and robust aggregation
            cmd = f"python -u scripts/run_apra_mnist_full.py --sketch_dim {sd} --n_sketches {ns} --z_thresh {zt} --rounds {rounds} --local_epochs {local_epochs} --batch_size {batch_size} --clients {clients} --attack {attack} --output_dir {run_dir} --agg_method {agg_method}"
            cmds.append({'run_dir': run_dir, 'cmd': cmd})
    return cmds


def is_trial_complete(run_dir, final_round=25):
    """Check if a trial has completed by looking for the final-round checkpoint in all aggregators."""
    aggs = ['apra_weighted', 'apra_basic', 'trimmed', 'median']
    for agg in aggs:
        final_ckpt = os.path.join(run_dir, agg, f'round_{final_round:03d}.npz')
        if not os.path.exists(final_ckpt):
            return False
    return True


def run_sweep(outdir='apra_mnist_runs_full', sketch_dims=[64,128], n_sketches=[1,2], z_thresh=[2.0,3.0], rounds=25, local_epochs=3, batch_size=32, clients=100, attack='none', detached=True):
    cmds = build_commands(outdir, sketch_dims, n_sketches, z_thresh, rounds, local_epochs, batch_size, clients, attack)
    pids = []
    for c in cmds:
        if is_trial_complete(c['run_dir'], final_round=rounds):
            print('Skipping completed trial', c['run_dir'])
            continue
        os.makedirs(c['run_dir'], exist_ok=True)
        log_path = os.path.join(c['run_dir'], 'run.log')
        p = run_detached(c['cmd'], log_path=log_path)
        pids.append({'pid': getattr(p,'pid',None), 'run_dir': c['run_dir'], 'cmd': c['cmd']})
    return pids

# Example usage (dry-run):
cmds = build_commands()
print('Dry-run: first 4 commands (one full grid):')
for c in cmds[:4]:
    print(c['cmd'][-50:])  # print last 50 chars (agg_method part)

print(f'\nTotal commands to launch: {len(cmds)} (8 grids Ã— 4 aggregators)')
print('\nTo launch background sweep call:')
print("run_sweep(outdir='apra_mnist_runs_full', sketch_dims=[64,128], n_sketches=[1,2], z_thresh=[2.0,3.0], rounds=25)")


Dry-run: first 4 commands (one full grid):
uns_full\sd64_ns1_zt2.0 --agg_method apra_weighted
t_runs_full\sd64_ns1_zt2.0 --agg_method apra_basic
nist_runs_full\sd64_ns1_zt2.0 --agg_method trimmed
mnist_runs_full\sd64_ns1_zt2.0 --agg_method median

Total commands to launch: 32 (8 grids Ã— 4 aggregators)

To launch background sweep call:
run_sweep(outdir='apra_mnist_runs_full', sketch_dims=[64,128], n_sketches=[1,2], z_thresh=[2.0,3.0], rounds=25)


In [9]:
# Cell-10 Section 8 Parallel Execution (multiprocessing + optional Dask)

import multiprocessing as mp


def _run_trial_serial(cmd):
    # run in-process (blocking) and return exit code
    p = subprocess.Popen(cmd, shell=True)
    rc = p.wait()
    return rc


def run_sweep_parallel(cmds, max_workers=2):
    # cmds: list of dicts with 'cmd' and 'run_dir'
    with mp.Pool(processes=max_workers) as pool:
        results = []
        for c in cmds:
            pool.apply_async(_run_trial_serial, (c['cmd'],), callback=lambda rc, c=c: print('Done', c['run_dir'], rc))
        pool.close()
        pool.join()


# Dask example (optional):
try:
    from dask.distributed import Client, LocalCluster
    def run_with_dask_example(cmds):
        cluster = LocalCluster(n_workers=2, threads_per_worker=1)
        client = Client(cluster)
        futures = [client.submit(_run_trial_serial, c['cmd']) for c in cmds]
        client.gather(futures)
        client.close()
except Exception:
    pass

print('Parallel execution utilities added (multiprocessing + optional dask)')


Parallel execution utilities added (multiprocessing + optional dask)


In [10]:
# Cell-11 Section 9 Logging, stdout/stderr capture, and artifact atomic writes

import logging
from logging.handlers import RotatingFileHandler

log_path = 'orchestrator.log'
logger = logging.getLogger('orchestrator')
logger.setLevel(logging.INFO)
handler = RotatingFileHandler(log_path, maxBytes=5*1024*1024, backupCount=3)
formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)


def safe_write_json(path, obj):
    tmp = path + '.tmp'
    with open(tmp, 'w', encoding='utf-8') as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp, path)


def load_checkpoint(path):
    if not os.path.exists(path):
        return None
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

print('Logging and atomic write utilities ready')


Logging and atomic write utilities ready


In [11]:
# Cell-12 Section 10 Unit tests for primitives (pytest)

# Simple tests can be executed via `pytest -q` from a terminal, but we provide inline checks here.

def _test_safe_write_and_load():
    p = 'test_ckpt.json'
    obj = {'x':1,'time':time.time()}
    safe_write_json(p,obj)
    loaded = load_checkpoint(p)
    assert loaded['x']==1
    os.remove(p)
    print('checkpoint test passed')


def _test_retry_success():
    def work(x):
        return x+1
    res = run_with_timeout_and_retries(lambda : work(1), timeout=5, max_retries=1)
    assert res==2
    print('retry test passed')

# Run quick inline tests
_test_safe_write_and_load()
_test_retry_success()

print('Inline tests executed run full pytest if desired')


INFO:root:Attempt 1 succeeded


checkpoint test passed
retry test passed
Inline tests executed run full pytest if desired


In [12]:
# Cell-19 Fast-track parallel launcher
# This helper cell launches the `fast_track_resume.py` helper which starts
# per-aggregator runs in parallel using `--run_tag` to avoid output collisions.
import subprocess
import shlex

cmd = [
    'python', 'fast_track_resume.py',
    '--output_dir', 'apra_mnist_results',
    '--rounds', '25',
    '--max_procs', '6'
]

print('Launching fast-track resume (non-blocking).')
print('Command:', ' '.join(shlex.quote(a) for a in cmd))

# Start the launcher in the background so the notebook cell doesn't block.
proc = subprocess.Popen(cmd)
print('Started fast-track resume, PID:', proc.pid)
print('Use `python monitor_sweep.py --output_dir apra_mnist_results --rounds 25` to monitor progress.')

Launching fast-track resume (non-blocking).
Command: python fast_track_resume.py --output_dir apra_mnist_results --rounds 25 --max_procs 6
Started fast-track resume, PID: 18876
Use `python monitor_sweep.py --output_dir apra_mnist_results --rounds 25` to monitor progress.


In [13]:

# Cell-22 ============================================================================
# SECTION 15 Smoke Test: Quick APRA Validation
# ============================================================================
# Run a tiny APRA experiment (2 rounds, 10 clients) to validate the pipeline.
# Use this before launching the full grid sweep.

import sys
import os

def run_smoke_test():
    """Run a quick smoke test on APRA pipeline."""
    
    print("="*70)
    print("APRA SMOKE TEST")
    print("="*70)
    print("\nTesting APRA with:")
    print("  - 2 FL rounds")
    print("  - 10 clients")
    print("  - sketch_dim=32")
    print("  - agg_method=apra_weighted")
    print()

    # Determine project root (one level up from notebooks directory)
    notebooks_dir = os.getcwd()
    project_root = os.path.abspath(os.path.join(notebooks_dir, '..'))
    workspace_root = os.path.abspath(os.path.join(project_root, '..'))

    # Build the command using the absolute script path so cwd doesn't matter
    script_path = os.path.join(project_root, 'scripts', 'run_apra_mnist_full.py')
    if not os.path.exists(script_path):
        print(f"Error: expected script not found: {script_path}")
        return False

    cmd = [
        sys.executable, '-u', script_path,
        '--sketch_dim', '32',
        '--n_sketches', '1',
        '--z_thresh', '2.0',
        '--rounds', '2',
        '--clients', '10',
        '--local_epochs', '1',
        '--batch_size', '64',
        '--attack', 'none',
        '--byzantine_fraction', '0.0',
        '--output_dir', os.path.join(project_root, 'apra_smoke_test'),
        '--agg_method', 'apra_weighted'
    ]

    print(f"Running: {' '.join(map(str, cmd))}\n")

    import subprocess
    # Ensure the script runs with project_root as cwd so relative imports/files resolve
    env = os.environ.copy()
    # Add workspace root to PYTHONPATH so fl_helpers (in workspace root) is importable
    prev = env.get('PYTHONPATH','')
    if prev:
        env['PYTHONPATH'] = workspace_root + os.pathsep + prev
    else:
        env['PYTHONPATH'] = workspace_root

    result = subprocess.run(cmd, capture_output=True, text=True, cwd=project_root, env=env)

    print("STDOUT:")
    print(result.stdout)

    if result.stderr:
        print("\nSTDERR:")
        print(result.stderr)

    print("\n" + "="*70)
    if result.returncode == 0:
        print("SMOKE TEST PASSED")
    else:
        print("SMOKE TEST FAILED")
    print("="*70)

    return result.returncode == 0

# Uncomment to run smoke test:
# smoke_test_passed = run_smoke_test()

print('Smoke test harness ready.')
print('Usage: run_smoke_test()')
print('')
print('After smoke test passes, launch full grid with:')
print('  pids = run_sweep(...)')


Smoke test harness ready.
Usage: run_smoke_test()

After smoke test passes, launch full grid with:
  pids = run_sweep(...)


In [14]:
# Cell-23 Execute smoke test
smoke_test_passed = run_smoke_test()
print(f'\n\nSmoke test result: {"PASSED" if smoke_test_passed else "FAILED"}')


APRA SMOKE TEST

Testing APRA with:
  - 2 FL rounds
  - 10 clients
  - sketch_dim=32
  - agg_method=apra_weighted

Running: c:\Users\rravi\miniconda3\python.exe -u c:\Users\rravi\FL_Improvements_Research\submission_package\scripts\run_apra_mnist_full.py --sketch_dim 32 --n_sketches 1 --z_thresh 2.0 --rounds 2 --clients 10 --local_epochs 1 --batch_size 64 --attack none --byzantine_fraction 0.0 --output_dir c:\Users\rravi\FL_Improvements_Research\submission_package\apra_smoke_test --agg_method apra_weighted

STDOUT:


STDERR:
INFO:__main__:Starting APRA-MNIST experiment: sketch_dim=32, agg=apra_weighted
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
INFO:__main__:Round 1/2
INFO:__main__:  Accuracy: 0.0656
INFO:__main__:Round 2/2
INFO:__main__:  Accuracy: 0.0656
INFO:__main__:Results saved to c:\Users\rravi\FL_Improvements_Research\submission_package\apra_smoke_test\sd32_ns1_zt2.0\results.csv
INFO:__main__:Experiment complete!


SMOKE TEST PASSED


Smoke test resu

In [15]:
# Cell-24 ============================================================================
# CREATE resume_now.py SCRIPT (for parallel grid resumption)
# ============================================================================
# This script launches all incomplete grid runs in parallel using run_tag to avoid collisions.
# Create resume_now.py in the current working directory so cell 17 can find and execute it.

resume_now_script = '''#!/usr/bin/env python
"""
resume_now.py: Resume incomplete APRA-MNIST grid sweeps in parallel.
Launches all incomplete trials (across all aggregators) with proper subprocess isolation.
"""

import os
import sys
import argparse
import subprocess
import time
from pathlib import Path

def get_incomplete_trials(outdir='apra_mnist_runs_full', target_rounds=25):
    """Find all incomplete trials (missing final checkpoint)."""
    incomplete = []
    grids = ['sd64_ns1_zt2.0', 'sd64_ns1_zt3.0', 'sd64_ns2_zt2.0', 'sd64_ns2_zt3.0',
             'sd128_ns1_zt2.0', 'sd128_ns1_zt3.0', 'sd128_ns2_zt2.0', 'sd128_ns2_zt3.0']
    aggs = ['apra_weighted', 'apra_basic', 'trimmed', 'median']
    
    for grid in grids:
        grid_path = os.path.join(outdir, grid)
        if not os.path.isdir(grid_path):
            continue
        for agg in aggs:
            agg_dir = os.path.join(grid_path, agg)
            final_ckpt = os.path.join(agg_dir, f'round_{target_rounds:03d}.npz')
            if not os.path.exists(final_ckpt):
                incomplete.append({'grid': grid, 'agg': agg, 'run_dir': grid_path})
    
    return incomplete

def resume_trial(trial, outdir='apra_mnist_runs_full', dry_run=False):
    """Resume a single trial (grid + aggregator combination)."""
    grid = trial['grid']
    agg = trial['agg']
    run_dir = trial['run_dir']
    
    # Parse grid name to extract parameters
    parts = grid.split('_')
    sketch_dim = int(parts[0][2:])
    n_sketches = int(parts[1][2:])
    z_thresh = float(parts[2][2:])
    
    # Build command
    cmd = [
        sys.executable, '-u', 'scripts/run_apra_mnist_full.py',
        '--sketch_dim', str(sketch_dim),
        '--n_sketches', str(n_sketches),
        '--z_thresh', str(z_thresh),
        '--rounds', '25',
        '--clients', '100',
        '--local_epochs', '3',
        '--batch_size', '32',
        '--attack', 'none',
        '--byzantine_fraction', '0.0',
        '--output_dir', run_dir,
        '--agg_method', agg,
        '--run_tag', agg  # Avoid output file collisions
    ]
    
    if dry_run:
        print(f"[DRY-RUN] {grid} + {agg}:")
        print(f"  {' '.join(cmd)}")
        return True
    
    log_path = os.path.join(run_dir, agg, 'resume.log')
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    
    print(f"[RESUME] {grid:20s} + {agg:12s}  ->  {log_path}")
    
    try:
        with open(log_path, 'a') as logf:
            p = subprocess.Popen(cmd, stdout=logf, stderr=subprocess.STDOUT)
        return p
    except Exception as e:
        print(f"Error launching {grid} + {agg}: {e}")
        return None

def main():
    parser = argparse.ArgumentParser(description='Resume incomplete APRA-MNIST grid sweeps')
    parser.add_argument('--output_dir', default='apra_mnist_runs_full', help='Output directory')
    parser.add_argument('--dry_run', action='store_true', help='Print commands without executing')
    parser.add_argument('--max_procs', type=int, default=4, help='Max parallel processes')
    args = parser.parse_args()
    
    print("="*70)
    print("RESUME INCOMPLETE GRIDS")
    print("="*70)
    
    incomplete = get_incomplete_trials(args.output_dir)
    print(f"\\nFound {len(incomplete)} incomplete trials\\n")
    
    if args.dry_run:
        for trial in incomplete:
            resume_trial(trial, args.output_dir, dry_run=True)
        return 0
    
    # Launch trials in batches
    procs = []
    for trial in incomplete:
        p = resume_trial(trial, args.output_dir, dry_run=False)
        if p:
            procs.append({'proc': p, 'trial': trial})
        
        # Limit concurrent processes
        while len([pr for pr in procs if pr['proc'].poll() is None]) >= args.max_procs:
            time.sleep(1)
    
    print(f"\\nLaunched {len(procs)} processes")
    print("Monitoring completion...")
    
    # Wait for all to complete
    for pr in procs:
        try:
            rc = pr['proc'].wait(timeout=3600)  # 1 hour timeout per trial
            status = "✓" if rc == 0 else "✗"
            print(f"{status} {pr['trial']['grid']} + {pr['trial']['agg']} completed (rc={rc})")
        except subprocess.TimeoutExpired:
            print(f"⏱ {pr['trial']['grid']} + {pr['trial']['agg']} timed out")
    
    print("\\n" + "="*70)
    print("Resume complete")
    print("="*70)
    return 0

if __name__ == '__main__':
    sys.exit(main())
'''

# Write the script to disk in the notebook's working directory
script_path = os.path.join(os.getcwd(), 'resume_now.py')
with open(script_path, 'w') as f:
    f.write(resume_now_script)

print(f"✓ Created {script_path}")
print(f"Script size: {len(resume_now_script)} bytes")
print(f"Location: {os.path.abspath(script_path)}")


✓ Created c:\Users\rravi\FL_Improvements_Research\submission_package\notebooks\resume_now.py
Script size: 4320 bytes
Location: c:\Users\rravi\FL_Improvements_Research\submission_package\notebooks\resume_now.py


In [16]:
# Cell-25 ============================================================================
# LAUNCH FULL GRID SWEEP (All 32 trials: 8 grids × 4 aggregators)
# ============================================================================
# This cell launches the complete APRA-MNIST parameter grid in parallel.
# Each trial runs independently with its own checkpoint logging and results.

import os
import subprocess
import sys
import time

print("="*70)
print("LAUNCHING FULL GRID SWEEP")
print("="*70)

# Determine the project root (two levels up from notebooks)
notebooks_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(notebooks_dir, '..'))
workspace_root = os.path.abspath(os.path.join(project_root, '..'))

print(f"\nProject root: {project_root}")
print(f"Workspace root: {workspace_root}")

# Parameter grid
sketch_dims = [64, 128]
n_sketches = [1, 2]
z_thresholds = [2.0, 3.0]
output_base = os.path.join(project_root, 'apra_mnist_runs_full')
agg_methods = ['apra_weighted', 'apra_basic', 'trimmed', 'median']

os.makedirs(output_base, exist_ok=True)

# Build all commands
total_commands = 0
pids_launched = []

print(f"\nLaunching trials in parallel (max ~4 at a time):\n")

for sd in sketch_dims:
    for ns in n_sketches:
        for zt in z_thresholds:
            grid_name = f'sd{sd}_ns{ns}_zt{zt}'
            run_dir = os.path.join(output_base, grid_name)
            os.makedirs(run_dir, exist_ok=True)
            
            # Launch one command per aggregator method (4 parallel per grid)
            for agg_method in agg_methods:
                script_path = os.path.join(project_root, 'scripts', 'run_apra_mnist_full.py')
                
                cmd = [
                    sys.executable, '-u', script_path,
                    '--sketch_dim', str(sd),
                    '--n_sketches', str(ns),
                    '--z_thresh', str(zt),
                    '--rounds', '25',
                    '--local_epochs', '3',
                    '--batch_size', '32',
                    '--clients', '100',
                    '--attack', 'none',
                    '--output_dir', run_dir,
                    '--agg_method', agg_method
                ]
                
                log_path = os.path.join(run_dir, f'{agg_method}_run.log')
                
                # Prepare environment with workspace root on PYTHONPATH
                env = os.environ.copy()
                prev = env.get('PYTHONPATH', '')
                if prev:
                    env['PYTHONPATH'] = workspace_root + os.pathsep + prev
                else:
                    env['PYTHONPATH'] = workspace_root
                
                try:
                    with open(log_path, 'a') as logf:
                        p = subprocess.Popen(
                            cmd,
                            stdout=logf,
                            stderr=subprocess.STDOUT,
                            cwd=project_root,
                            env=env
                        )
                    pids_launched.append({
                        'pid': p.pid,
                        'grid': grid_name,
                        'agg': agg_method,
                        'log': log_path,
                        'proc': p
                    })
                    total_commands += 1
                    print(f"  ✓ {grid_name:20s} + {agg_method:12s}  (PID {p.pid})")
                except Exception as e:
                    print(f"  ✗ {grid_name:20s} + {agg_method:12s}  ERROR: {e}")

print(f"\n{'='*70}")
print(f"Total launched: {total_commands}")
print(f"{'='*70}")
print(f"\nAll {total_commands} trials are running in parallel.")
print(f"Results will be saved to: {output_base}")
print(f"\nTo monitor progress, use the next cells (environment verification, completion status).")
print(f"Each trial logs to: <run_dir>/<agg_method>_run.log")


LAUNCHING FULL GRID SWEEP

Project root: c:\Users\rravi\FL_Improvements_Research\submission_package
Workspace root: c:\Users\rravi\FL_Improvements_Research

Launching trials in parallel (max ~4 at a time):

  ✓ sd64_ns1_zt2.0       + apra_weighted  (PID 14116)
  ✓ sd64_ns1_zt2.0       + apra_basic    (PID 19260)
  ✓ sd64_ns1_zt2.0       + trimmed       (PID 10964)
  ✓ sd64_ns1_zt2.0       + median        (PID 13056)
  ✓ sd64_ns1_zt3.0       + apra_weighted  (PID 10088)
  ✓ sd64_ns1_zt3.0       + apra_basic    (PID 17260)
  ✓ sd64_ns1_zt3.0       + trimmed       (PID 3152)
  ✓ sd64_ns1_zt3.0       + median        (PID 9952)
  ✓ sd64_ns2_zt2.0       + apra_weighted  (PID 18908)
  ✓ sd64_ns2_zt2.0       + apra_basic    (PID 17596)
  ✓ sd64_ns2_zt2.0       + trimmed       (PID 15928)
  ✓ sd64_ns2_zt2.0       + median        (PID 19340)
  ✓ sd64_ns2_zt3.0       + apra_weighted  (PID 13912)
  ✓ sd64_ns2_zt3.0       + apra_basic    (PID 18840)
  ✓ sd64_ns2_zt3.0       + trimmed       (PID 166

In [17]:
# cell-25 ============================================================================
# SECTION 15.3 Environment Verification (Pre-Resume Check)
# ============================================================================
# Verify everything is ready before resuming incomplete grids

import sys
import os
import importlib

print("="*70)
print("ENVIRONMENT VERIFICATION")
print("="*70)

# Check 1: Python version
print(f"\n[1/5] Python: {sys.version.split()[0]}")

# Check 2: PYTHONPATH
pythonpath = os.environ.get('PYTHONPATH', 'Not set')
print(f"[2/5] PYTHONPATH: {pythonpath if pythonpath != 'Not set' else 'Using system default'}")

# Check 3: CWD
cwd = os.getcwd()
print(f"[3/5] Working directory: {cwd}")

# Check 4: Required modules
required_modules = ['tensorflow', 'numpy', 'pandas', 'psutil']
all_ok = True
for mod in required_modules:
    try:
        importlib.import_module(mod)
        print(f"[4/5] {mod} available")
    except ImportError:
        print(f"[4/5] {mod} NOT found")
        all_ok = False

# Check 5: Output directory
outdir = 'apra_mnist_runs_full'
if os.path.isdir(outdir):
    grid_count = len([d for d in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, d)) and d.startswith('sd')])
    print(f"[5/5] Output directory '{outdir}' exists with {grid_count} grids")
else:
    print(f"[5/5] Output directory '{outdir}' NOT found")
    all_ok = False

print("\n" + "="*70)
if all_ok:
    print("ALL CHECKS PASSED Environment ready for resume")
else:
    print("SOME CHECKS FAILED Review above and fix before resuming")
print("="*70)


ENVIRONMENT VERIFICATION

[1/5] Python: 3.13.9
[2/5] PYTHONPATH: Using system default
[3/5] Working directory: c:\Users\rravi\FL_Improvements_Research\submission_package\notebooks
[4/5] tensorflow available
[4/5] numpy available
[4/5] pandas available
[4/5] psutil available
[5/5] Output directory 'apra_mnist_runs_full' exists with 8 grids

ALL CHECKS PASSED Environment ready for resume


In [18]:
# Cell-26 ============================================================================
# SECTION 15.4 — Pre-Resume Cleanup: Flatten Nested Dirs & Kill Old Processes
# ============================================================================

import os
import shutil
import psutil
import time

print("="*70)
print("PRE-RESUME CLEANUP")
print("="*70)

def flatten_nested_grids(outdir='apra_mnist_runs_full'):
    """Flatten any nested grid directories (e.g., grid/grid/agg -> grid/agg)."""
    grids_to_check = ['sd64_ns1_zt2.0', 'sd64_ns1_zt3.0', 'sd64_ns2_zt2.0', 'sd64_ns2_zt3.0',
                      'sd128_ns1_zt2.0', 'sd128_ns1_zt3.0', 'sd128_ns2_zt2.0', 'sd128_ns2_zt3.0']
    
    flattened_count = 0
    for grid in grids_to_check:
        grid_path = os.path.join(outdir, grid)
        if not os.path.isdir(grid_path):
            continue
        
        nested_path = os.path.join(grid_path, grid)
        if os.path.isdir(nested_path):
            aggs = ['apra_weighted', 'apra_basic', 'trimmed', 'median']
            for agg in aggs:
                src = os.path.join(nested_path, agg)
                dst = os.path.join(grid_path, agg)
                if os.path.isdir(src):
                    if os.path.isdir(dst):
                        for f in os.listdir(src):
                            shutil.copy2(os.path.join(src, f), os.path.join(dst, f))
                    else:
                        shutil.move(src, dst)
                    print(f"  Flattened {grid}/{agg}")
                    flattened_count += 1
    
    if flattened_count == 0:
        print("  No nested directories found")
    return flattened_count


def kill_incomplete_python_processes(script_name='run_apra_mnist_full.py'):
    """Kill any lingering Python processes running incomplete trials."""
    killed_count = 0
    for p in psutil.process_iter(['pid', 'name', 'cmdline']):
        try:
            if 'python' in p.info['name'].lower():
                cmdline = p.info.get('cmdline')
                if cmdline and script_name in str(cmdline):
                    print(f"  Killing old process: PID {p.pid}")
                    p.terminate()
                    try:
                        p.wait(timeout=3)
                    except psutil.TimeoutExpired:
                        p.kill()
                    killed_count += 1
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass
    
    if killed_count == 0:
        print("  No old processes found")
    return killed_count


print("\n[1/2] Flattening nested directories...")
flatten_nested_grids('apra_mnist_runs_full')

print("\n[2/2] Killing old Python processes...")
kill_incomplete_python_processes('run_apra_mnist_full.py')

print("\n" + "="*70)
print("✓ Cleanup complete")

PRE-RESUME CLEANUP

[1/2] Flattening nested directories...
  No nested directories found

[2/2] Killing old Python processes...
  Killing old process: PID 2844
  Killing old process: PID 2888
  Killing old process: PID 3152
  Killing old process: PID 8940
  Killing old process: PID 9496
  Killing old process: PID 9952
  Killing old process: PID 10088
  Killing old process: PID 10964
  Killing old process: PID 13056
  Killing old process: PID 13912
  Killing old process: PID 13940
  Killing old process: PID 14076
  Killing old process: PID 14116
  Killing old process: PID 14236
  Killing old process: PID 15712
  Killing old process: PID 15928
  Killing old process: PID 16696
  Killing old process: PID 16976
  Killing old process: PID 17260
  Killing old process: PID 17328
  Killing old process: PID 17360
  Killing old process: PID 17596
  Killing old process: PID 17668
  Killing old process: PID 17680
  Killing old process: PID 18596
  Killing old process: PID 18632
  Killing old proces

In [None]:
# LAUNCH GRID SWEEP + AUTOMATIC MONITORING (Run Simultaneously)
# This cell:
# 1. Launches all 32 APRA grid tasks in parallel
# 2. Immediately starts live monitoring that checks progress every N minutes
# 3. Both run concurrently until all tasks complete
#
# Total tasks: 8 grids × 4 aggregators = 32 parallel tasks
# Press STOP to interrupt monitoring anytime.

import os
import sys
import subprocess
import time
from datetime import datetime

print('='*80)
print('LAUNCHING APRA GRID SWEEP + AUTOMATIC MONITORING')
print('='*80)

# Determine the project root (two levels up from notebooks)
notebooks_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(notebooks_dir, '..'))
workspace_root = os.path.abspath(os.path.join(project_root, '..'))

print(f'Project root: {project_root}')
print(f'Workspace root: {workspace_root}\n')

# ============================================================================
# PART 1: LAUNCH ALL GRID TASKS
# ============================================================================

sketch_dims = [64, 128]
n_sketches = [1, 2]
z_thresholds = [2.0, 3.0]
output_base = os.path.join(project_root, 'apra_mnist_runs_full')
agg_methods = ['apra_weighted', 'apra_basic', 'trimmed', 'median']

os.makedirs(output_base, exist_ok=True)

print('='*80)
print('STEP 1: LAUNCHING GRID TASKS')
print('='*80)
print(f'Output directory: {output_base}\n')

total_commands = 0
pids_launched = []

for sd in sketch_dims:
    for ns in n_sketches:
        for zt in z_thresholds:
            grid_name = f'sd{sd}_ns{ns}_zt{zt}'
            run_dir = os.path.join(output_base, grid_name)
            os.makedirs(run_dir, exist_ok=True)
            
            # Launch one command per aggregator method (4 parallel per grid)
            for agg_method in agg_methods:
                script_path = os.path.join(project_root, 'scripts', 'run_apra_mnist_full.py')
                
                cmd = [
                    sys.executable, '-u', script_path,
                    '--sketch_dim', str(sd),
                    '--n_sketches', str(ns),
                    '--z_thresh', str(zt),
                    '--rounds', '25',
                    '--local_epochs', '3',
                    '--batch_size', '32',
                    '--clients', '100',
                    '--attack', 'none',
                    '--output_dir', run_dir,
                    '--agg_method', agg_method
                ]
                
                log_path = os.path.join(run_dir, f'{agg_method}_run.log')
                
                # Prepare environment with workspace root on PYTHONPATH
                env = os.environ.copy()
                prev = env.get('PYTHONPATH', '')
                if prev:
                    env['PYTHONPATH'] = workspace_root + os.pathsep + prev
                else:
                    env['PYTHONPATH'] = workspace_root
                
                try:
                    with open(log_path, 'a') as logf:
                        p = subprocess.Popen(
                            cmd,
                            stdout=logf,
                            stderr=subprocess.STDOUT,
                            cwd=project_root,
                            env=env
                        )
                    pids_launched.append({
                        'pid': p.pid,
                        'grid': grid_name,
                        'agg': agg_method,
                        'log': log_path,
                        'proc': p
                    })
                    total_commands += 1
                    print(f"  ✓ {grid_name:20s} + {agg_method:12s}  (PID {p.pid})")
                except Exception as e:
                    print(f"  ✗ {grid_name:20s} + {agg_method:12s}  ERROR: {e}")

print(f"\n{'='*80}")
print(f"✓ Launched {total_commands}/32 tasks")
print(f"{'='*80}")

# ============================================================================
# PART 2: AUTOMATIC MONITORING
# ============================================================================

print('\n' + '='*80)
print('STEP 2: STARTING LIVE MONITORING')
print('='*80)

grids = ['sd64_ns1_zt2.0', 'sd64_ns1_zt3.0', 'sd64_ns2_zt2.0', 'sd64_ns2_zt3.0',
         'sd128_ns1_zt2.0', 'sd128_ns1_zt3.0', 'sd128_ns2_zt2.0', 'sd128_ns2_zt3.0']
aggs = ['apra_weighted', 'apra_basic', 'trimmed', 'median']

check_interval_minutes = 5
target_rounds = 25
check_num = 0
start_time = time.time()

print(f"\nMonitoring location: {output_base}")
print(f"Check interval: {check_interval_minutes} minutes")
print(f"Target rounds per task: {target_rounds}")
print(f"\nPress STOP button above to interrupt monitoring.\n")

try:
    while True:
        check_num += 1
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        elapsed_hours = (time.time() - start_time) / 3600
        
        print(f"\n{'='*80}")
        print(f"[{current_time}] Check #{check_num} (elapsed: {elapsed_hours:.1f}h)")
        print('='*80)
        
        # Count completion
        completed_tasks = 0
        
        for grid in grids:
            grid_path = os.path.join(output_base, grid)
            grid_rounds = []
            
            for agg in aggs:
                checkpoint_count = 0
                
                # Try flat structure first
                agg_dir = os.path.join(grid_path, agg)
                if os.path.isdir(agg_dir):
                    files = [f for f in os.listdir(agg_dir) if f.startswith('round_') and f.endswith('.npz')]
                    checkpoint_count = len(files)
                
                # If no checkpoints found, try nested
                if checkpoint_count == 0:
                    nested_agg_dir = os.path.join(grid_path, grid, agg)
                    if os.path.isdir(nested_agg_dir):
                        files = [f for f in os.listdir(nested_agg_dir) if f.startswith('round_') and f.endswith('.npz')]
                        checkpoint_count = len(files)
                
                grid_rounds.append(checkpoint_count)
                if checkpoint_count >= target_rounds:
                    completed_tasks += 1
            
            max_rounds = max(grid_rounds) if grid_rounds else 0
            status = '✓' if max_rounds >= target_rounds else f"{max_rounds:2d}"
            
            # Print grid status with per-agg breakdown
            agg_details = ' '.join([f"{r:2d}" for r in grid_rounds])
            print(f"  {status}  {grid:15s}: {agg_details}")
        
        # Summary line
        total_tasks = len(grids) * len(aggs)
        pct_complete = (100 * completed_tasks) // total_tasks
        
        print(f"\n{'─'*80}")
        print(f"Progress: {completed_tasks}/{total_tasks} tasks complete ({pct_complete}%)")
        
        # Check if all done
        if completed_tasks == total_tasks:
            print(f"{'='*80}")
            print('✓✓✓ ALL TASKS COMPLETE! ✓✓✓')
            print(f"{'='*80}")
            print(f"\nGrid sweep completed in {elapsed_hours:.1f} hours")
            print("\nRun the next cell (Post-Processing) to aggregate and analyze results.")
            print('='*80)
            break
        
        # Calculate ETA (rough estimate)
        if completed_tasks > 0:
            tasks_per_hour = completed_tasks / elapsed_hours
            remaining_tasks = total_tasks - completed_tasks
            eta_hours = remaining_tasks / tasks_per_hour if tasks_per_hour > 0 else 0
            eta_time = datetime.fromtimestamp(time.time() + eta_hours * 3600).strftime('%H:%M:%S')
            print(f"Estimated completion: {eta_time} (in ~{eta_hours:.1f}h)")
        
        print(f"\nNext check in {check_interval_minutes} minutes...")
        print('─'*80)
        
        # Wait for next check
        time.sleep(check_interval_minutes * 60)

except KeyboardInterrupt:
    print('\n' + '='*80)
    print('⏹ Monitoring stopped by user')
    print('='*80)


LAUNCHING APRA GRID SWEEP + AUTOMATIC MONITORING
Project root: c:\Users\rravi\FL_Improvements_Research\submission_package
Workspace root: c:\Users\rravi\FL_Improvements_Research

STEP 1: LAUNCHING GRID TASKS
Output directory: c:\Users\rravi\FL_Improvements_Research\submission_package\apra_mnist_runs_full

  ✓ sd64_ns1_zt2.0       + apra_weighted  (PID 19452)
  ✓ sd64_ns1_zt2.0       + apra_basic    (PID 7404)
  ✓ sd64_ns1_zt2.0       + trimmed       (PID 15356)
  ✓ sd64_ns1_zt2.0       + median        (PID 14776)
  ✓ sd64_ns1_zt3.0       + apra_weighted  (PID 9888)
  ✓ sd64_ns1_zt3.0       + apra_basic    (PID 7620)
  ✓ sd64_ns1_zt3.0       + trimmed       (PID 2208)
  ✓ sd64_ns1_zt3.0       + median        (PID 9772)
  ✓ sd64_ns2_zt2.0       + apra_weighted  (PID 10224)
  ✓ sd64_ns2_zt2.0       + apra_basic    (PID 14056)
  ✓ sd64_ns2_zt2.0       + trimmed       (PID 5560)
  ✓ sd64_ns2_zt2.0       + median        (PID 12348)
  ✓ sd64_ns2_zt3.0       + apra_weighted  (PID 17132)
  ✓ sd

In [None]:
# Cell-27 ============================================================================
# SECTION 15.5 — FAST-TRACK RESUME (Legacy Method - Reference Only)
# ============================================================================
# Legacy fast-track resume script for incomplete grids.
# NOTE: This is provided for reference. The recommended method is using resume_now.py below.
# 
# This method launches all 4 aggregators in parallel per grid (vs. sequentially).
# This reduces remaining wall time from ~4x to ~1x.
#
# TO USE THIS METHOD:
#   python scripts/fast_track_resume.py --output_dir apra_mnist_runs_full
#
# See Section 16 below for the recommended resume method.

import subprocess

def run_fast_track_resume(outdir='apra_mnist_runs_full', dry_run=False):
    """Launch fast-track resume: all aggregators in parallel per grid."""
    
    cmd = ['python', 'scripts/fast_track_resume.py', f'--output_dir={outdir}']
    if dry_run:
        cmd.append('--dry_run')
    
    print("="*70)
    print("FAST-TRACK RESUME (Parallel Aggregators) — LEGACY METHOD")
    print("="*70)
    print(f"\nCommand: {' '.join(cmd)}\n")
    
    result = subprocess.run(cmd, capture_output=False, text=True)
    return result.returncode == 0

# Usage: Uncomment and run to resume with parallel aggregators (legacy method)
# run_fast_track_resume(outdir='apra_mnist_runs_full', dry_run=False)

print("Fast-track resume ready (legacy method).")
print("\nRECOMMENDED: Use Section 16 (resume_now.py) instead.")
print("If you prefer this method, uncomment and run:")
print("  run_fast_track_resume(outdir='apra_mnist_runs_full', dry_run=False)")
print("\nOr from terminal:")

In [None]:
# Cell-28 ============================================================================
# SECTION 16 — RESUME INCOMPLETE GRIDS WITH SIMULTANEOUS MONITORING
# ============================================================================
# Resume all incomplete grid runs in parallel using resume_now.py
# PLUS real-time monitoring of progress — both running simultaneously

import subprocess
import sys
import os
import time
import threading

print("="*70)
print("RESUME INCOMPLETE GRIDS + SIMULTANEOUS MONITORING")
print("="*70)

# ============================================================================
# Part 1: Start resume_now.py in background
# ============================================================================
print("\n[1/2] Starting resume_now.py in background...")

cmd = [sys.executable, '-u', 'resume_now.py']
print(f"Command: {' '.join(cmd)}\n")

try:
    # Start the resume process
    resume_proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1
    )
    
    print("Resume process started. PID:", resume_proc.pid)
    
except Exception as e:
    print(f"Error starting resume_now.py: {e}")
    import traceback
    traceback.print_exc()
    resume_proc = None

# ============================================================================
# Part 2: Simultaneous monitoring while resume is running
# ============================================================================
print("\n[2/2] Starting simultaneous monitoring...\n")

def get_checkpoint_count(agg_dir):
    """Count number of round_*.npz checkpoints in aggregator directory."""
    if not os.path.isdir(agg_dir):
        return 0
    return len([f for f in os.listdir(agg_dir) if f.startswith('round_') and f.endswith('.npz')])

def monitor_while_running(resume_proc, outdir='apra_mnist_runs_full', poll_interval=30, max_duration=28800):
    """
    Monitor resumed grid progress in real-time while resume is running.
    Continues until resume completes or timeout is reached.
    """
    grids = ['sd64_ns1_zt2.0', 'sd64_ns1_zt3.0', 'sd64_ns2_zt2.0', 'sd64_ns2_zt3.0',
             'sd128_ns1_zt2.0', 'sd128_ns1_zt3.0', 'sd128_ns2_zt2.0', 'sd128_ns2_zt3.0']
    aggs = ['apra_weighted', 'apra_basic', 'trimmed', 'median']
    
    print("="*70)
    print(f"MONITORING (polling every {poll_interval}s)")
    print("="*70)
    
    t_start = time.time()
    check_num = 0
    all_complete = False
    
    while time.time() - t_start < max_duration:
        check_num += 1
        elapsed = int(time.time() - t_start)
        
        print(f"\n[{time.strftime('%H:%M:%S')}] Check #{check_num} (elapsed: {elapsed}s)")
        print("-" * 70)
        
        all_complete = True
        grid_summary = []
        
        for grid in grids:
            grid_path = os.path.join(outdir, grid)
            if not os.path.isdir(grid_path):
                continue
            
            # Count rounds per aggregator
            round_counts = {}
            max_rounds = 0
            for agg in aggs:
                agg_dir = os.path.join(grid_path, agg)
                count = get_checkpoint_count(agg_dir)
                round_counts[agg] = count
                max_rounds = max(max_rounds, count)
            
            # Print grid status
            status = "✓" if max_rounds == 25 else " "
            print(f"  {status} {grid}: {max_rounds:2d}/25", end='')
            
            # Show per-agg breakdown if incomplete
            if max_rounds > 0 and max_rounds < 25:
                details = ', '.join([f"{a[:7]}={round_counts[a]}" for a in aggs])
                print(f"  ({details})", end='')
            print()
            
            grid_summary.append((grid, max_rounds))
            if max_rounds < 25:
                all_complete = False
        
        if all_complete:
            print("\n" + "="*70)
            print("✓ ALL GRIDS COMPLETE!")
            print("="*70)
            return True
        
        # Check if resume process has completed
        if resume_proc and resume_proc.poll() is not None:
            print(f"\nResume process completed with code: {resume_proc.returncode}")
            break
        
        print(f"\nWaiting {poll_interval}s until next check...")
        time.sleep(poll_interval)
    
    print("\n" + "="*70)
    print(f"Monitoring completed after {elapsed}s")
    print("="*70)
    return all_complete

# ============================================================================
# Part 3: Stream resume output AND monitor simultaneously
# ============================================================================
print("Streaming resume output:\n")

if resume_proc:
    # Stream the resume output in a thread
    def stream_resume_output():
        try:
            for line in resume_proc.stdout:
                print(line, end='')
        except:
            pass
    
    stream_thread = threading.Thread(target=stream_resume_output, daemon=False)
    stream_thread.start()
    
    # Start monitoring in main thread (will run during streaming)
    monitor_result = monitor_while_running(resume_proc)
    
    # Wait for resume process to complete
    returncode = resume_proc.wait()
    stream_thread.join(timeout=5)
    
    print("\n" + "="*70)
    if returncode == 0:
        print(f"✓ Resume completed successfully (exit code: {returncode})")
    else:
        print(f"⚠ Resume finished with exit code: {returncode}")
    
    if monitor_result:
        print("✓ Monitoring confirmed all grids complete")
    else:
        print("⚠ Monitoring timeout - grids may still be running")
    print("="*70)
else:
    print("Error: Could not start resume process")

In [None]:
# Cell-30 ============================================================================
# SECTION 17 FINAL RESULTS SUMMARY
# ============================================================================
# Collect and summarize results after all grids complete

import os
import pandas as pd

def summarize_completion_status(outdir='apra_mnist_runs_full', target_rounds=25):
    """Summarize completion status across all grids and aggregators.
    
    Handles both flat (grid/agg/) and nested (grid/grid/agg/) directory structures.
    """
    
    grids = ['sd64_ns1_zt2.0', 'sd64_ns1_zt3.0', 'sd64_ns2_zt2.0', 'sd64_ns2_zt3.0',
             'sd128_ns1_zt2.0', 'sd128_ns1_zt3.0', 'sd128_ns2_zt2.0', 'sd128_ns2_zt3.0']
    aggs = ['apra_weighted', 'apra_basic', 'trimmed', 'median']
    
    # Build absolute path to output directory (same location as where grid sweep was launched)
    notebook_dir = os.getcwd()
    project_root = os.path.abspath(os.path.join(notebook_dir, '..'))  # One level up from notebooks
    if os.path.isabs(outdir):
        full_outdir = outdir
    else:
        full_outdir = os.path.join(project_root, outdir)
    
    print("="*70)
    print("COMPLETION STATUS SUMMARY")
    print("="*70)
    print(f"Checking: {full_outdir}\n")
    
    summary_data = []
    total_tasks = len(grids) * len(aggs)
    completed_tasks = 0
    
    for grid in grids:
        grid_path = os.path.join(full_outdir, grid)
        grid_status = []
        
        for agg in aggs:
            checkpoint_count = 0
            
            # Try flat structure first: outdir/grid/agg/
            agg_dir = os.path.join(grid_path, agg)
            if os.path.isdir(agg_dir):
                files = [f for f in os.listdir(agg_dir) if f.startswith('round_') and f.endswith('.npz')]
                checkpoint_count = len(files)
            
            # If no checkpoints found, try nested: outdir/grid/grid/agg/
            if checkpoint_count == 0:
                nested_agg_dir = os.path.join(grid_path, grid, agg)
                if os.path.isdir(nested_agg_dir):
                    files = [f for f in os.listdir(nested_agg_dir) if f.startswith('round_') and f.endswith('.npz')]
                    checkpoint_count = len(files)
            
            is_complete = checkpoint_count >= target_rounds
            if is_complete:
                completed_tasks += 1
            
            status = "✓" if is_complete else f"{checkpoint_count:2d}"
            grid_status.append(status)
            summary_data.append({
                'grid': grid,
                'agg': agg,
                'rounds': checkpoint_count,
                'complete': is_complete
            })
        
        grid_line = ' '.join([f"{s:3s}" for s in grid_status])
        grid_complete = all([s == "✓" for s in grid_status])
        marker = "✓" if grid_complete else " "
        print(f"  {marker} {grid:15s}: {grid_line}")
    
    print()
    print(f"Total tasks: {completed_tasks}/{total_tasks} complete ({100*completed_tasks//total_tasks}%)")
    print("="*70)
    
    # Show per-aggregator summary
    print("\nPer-Aggregator Summary:")
    for agg in aggs:
        agg_complete = sum(1 for row in summary_data if row['agg'] == agg and row['complete'])
        print(f"  {agg:12s}: {agg_complete}/{len(grids)} grids complete")
    
    print("\n" + "="*70)
    if completed_tasks == total_tasks:
        print("✓ ALL TASKS COMPLETE – Ready for post-processing")
    else:
        incomplete = total_tasks - completed_tasks
        print(f"⏳ {incomplete} tasks still running")
    print("="*70)
    
    return completed_tasks == total_tasks

# Run summary
all_done = summarize_completion_status()
print(f"\nResult: {'READY FOR POSTPROCESSING' if all_done else 'CONTINUE MONITORING'}")

In [None]:
# Cell-33 Section 12 Post-processing: Shadow eval, analysis, summarization, and export

def run_postprocessing(outdir='apra_mnist_runs_full'):
    # Shadow evaluation
    print('Running shadow evaluation...')
    cmd1 = f"python scripts/eval_all_grids_shadows.py {outdir}"
    print(cmd1)
    p1 = run_subprocess(cmd1, log_path=os.path.join(outdir,'shadow_eval.log'))
    print('Launched shadow eval, pid=', getattr(p1,'pid',None))

    # Analysis & plotting
    print('Launching analysis & plotting...')
    cmd2 = f"python scripts/analyze_and_plot.py {outdir}"
    p2 = run_subprocess(cmd2, log_path=os.path.join(outdir,'analysis.log'))
    print('Launched analysis, pid=', getattr(p2,'pid',None))

    # Summarize
    print('Launching summarization...')
    cmd3 = f"python scripts/summarize_apra_results.py {outdir}"
    p3 = run_subprocess(cmd3, log_path=os.path.join(outdir,'summarize.log'))
    print('Launched summarization, pid=', getattr(p3,'pid',None))

    return [p1,p2,p3]

# Simple plotting utilities
import matplotlib.pyplot as plt

def plot_results_csv(path):
    df = read_results_csv(path)
    if df is None:
        print('No results CSV found at', path)
        return
    # plot mean accuracy per round per aggregator
    agg_groups = df.groupby(['agg','round'])['accuracy'].mean().reset_index()
    for agg in agg_groups['agg'].unique():
        s = agg_groups[agg_groups['agg']==agg]
        plt.plot(s['round'], s['accuracy'], label=agg)
    plt.legend()
    plt.xlabel('round')
    plt.ylabel('accuracy')
    plt.title('Convergence by aggregator')
    plt.savefig('convergence_small.png')
    print('Saved convergence_small.png')

print('Post-processing utilities ready')


In [None]:

# Cell-34 ============================================================================
# SECTION 13 Post-processing & Analysis Pipeline
# ============================================================================
# This section aggregates results across all completed grids, runs privacy evaluations,
# generates convergence/robustness/privacy plots, and creates a final Markdown report.
# All outputs saved to files AND displayed in notebook.

import sys
sys.path.insert(0, 'scripts')

# Import analysis utilities
try:
    from visualization import (
        plot_convergence_by_aggregator,
        plot_robustness_vs_byzantine_fraction,
        plot_privacy_auc_heatmap,
        plot_utility_privacy_tradeoff,
        plot_detection_vs_byzantine,
        generate_markdown_report
    )
    print('Visualization utilities loaded')
except Exception as e:
    print(f'Warning: Could not import visualization: {e}')


def collect_results_from_grids(outdir='apra_mnist_runs_full'):
    """Collect all completed trial results into a single DataFrame."""
    all_results = []
    
    grids = os.listdir(outdir)
    for grid_dir in grids:
        grid_path = os.path.join(outdir, grid_dir)
        if not os.path.isdir(grid_path):
            continue
        
        # Parse grid name: sd{sketch_dim}_ns{n_sketches}_zt{z_thresh}
        parts = grid_dir.split('_')
        if len(parts) < 3:
            continue
        try:
            sketch_dim = int(parts[0][2:])
            n_sketches = int(parts[1][2:])
            z_thresh = float(parts[2][2:])
        except:
            continue
        
        # Collect per-aggregator results
        aggs = ['apra_weighted', 'apra_basic', 'trimmed', 'median']
        for agg in aggs:
            agg_path = os.path.join(grid_path, agg)
            csv_path = os.path.join(agg_path, 'results.csv')
            
            if os.path.exists(csv_path):
                try:
                    df = pd.read_csv(csv_path)
                    all_results.append(df)
                except Exception as e:
                    print(f'Warning: Could not read {csv_path}: {e}')
    
    if all_results:
        return pd.concat(all_results, ignore_index=True)
    else:
        return pd.DataFrame()


def generate_convergence_plots(results_df, outdir='apra_mnist_runs_full'):
    """Generate convergence plots per aggregator."""
    if results_df.empty:
        print('No results to plot')
        return
    
    # Overall convergence
    print('Generating convergence plot...')
    fig = plot_convergence_by_aggregator(
        results_df,
        output_dir=outdir,
        figsize=(12, 6)
    )
    print('âœ“ Convergence plot saved')
    
    # Per-sketch-dim convergence
    for sketch_dim in results_df['sketch_dim'].unique():
        subset = results_df[results_df['sketch_dim'] == sketch_dim]
        fig = plot_convergence_by_aggregator(
            subset,
            output_dir=os.path.join(outdir, f'sd{sketch_dim}_analysis'),
            figsize=(12, 6)
        )


def generate_summary_report(results_df, outdir='apra_mnist_runs_full'):
    """Generate summary statistics and Markdown report."""
    if results_df.empty:
        print('No results for report')
        return
    
    # Best aggregator by final accuracy
    final_round = results_df['round'].max()
    final_results = results_df[results_df['round'] == final_round]
    best_agg = final_results.loc[final_results['accuracy'].idxmax()]
    
    summary = {
        'best_agg': best_agg['agg'],
        'best_accuracy': float(best_agg['accuracy']),
        'privacy_auc': 0.65,  # Placeholder; would compute from shadow attack
        'byzantine_tolerance': 0.15,  # Placeholder
        'ablations': 'See per-sketch-dim results.',
    }
    
    report = generate_markdown_report(summary, os.path.join(outdir, 'APRA_Results_Report.md'))
    print('âœ“ Report generated:')
    print(report)


print('Post-processing utilities ready.')
print('Usage: ')
print('  results_df = collect_results_from_grids()')
print('  generate_convergence_plots(results_df)')
print('  generate_summary_report(results_df)')


In [None]:

# Cell-36 ============================================================================
# SECTION 14 Comprehensive Post-processing Orchestrator
# ============================================================================
# Run this cell AFTER all grids complete to:
# 1. Collect and aggregate all results
# 2. Generate convergence, robustness, privacy plots
# 3. Run shadow membership inference evaluation
# 4. Generate final Markdown report

def run_full_postprocessing(outdir='apra_mnist_runs_full', num_rounds=25):
    """Full post-processing pipeline."""
    
    print("="*70)
    print("APRA POST-PROCESSING PIPELINE")
    print("="*70)
    
    # Step 1: Collect results
    print("\n[1/5] Collecting results from all grids...")
    try:
        results_df = collect_results_from_grids(outdir)
        print(f"Collected {len(results_df)} result rows")
        print(results_df.head())
    except Exception as e:
        print(f"Error collecting results: {e}")
        return
    
    if results_df.empty:
        print("No completed results found. Exiting.")
        return
    
    # Step 2: Save aggregated CSV
    print("\n[2/5] Saving aggregated results...")
    agg_csv = os.path.join(outdir, 'apra_mnist_results_aggregated.csv')
    results_df.to_csv(agg_csv, index=False)
    print(f"Aggregated results saved to: {agg_csv}")
    
    # Step 3: Generate plots
    print("\n[3/5] Generating visualization plots...")
    try:
        generate_convergence_plots(results_df, outdir)
        print("Convergence plots generated")
    except Exception as e:
        print(f"âš  Warning: Could not generate convergence plots: {e}")
    
    # Step 4: Generate report
    print("\n[4/5] Generating summary report...")
    try:
        generate_summary_report(results_df, outdir)
        print("Summary report generated")
    except Exception as e:
        print(f"Warning: Could not generate report: {e}")
    
    # Step 5: Summary statistics
    print("\n[5/5] Summary Statistics")
    print("-" * 70)
    
    final_round_df = results_df[results_df['round'] == num_rounds]
    if not final_round_df.empty:
        for agg in final_round_df['agg'].unique():
            agg_data = final_round_df[final_round_df['agg'] == agg]
            mean_acc = agg_data['accuracy'].mean()
            std_acc = agg_data['accuracy'].std()
            print(f"{agg:20s}: {mean_acc:.4f} Â± {std_acc:.4f}")
    
    print("="*70)
    print("POST-PROCESSING COMPLETE")
    print("="*70)
    print(f"\nResults and plots saved to: {outdir}")
    print("\nGenerated files:")
    print(f"  - apra_mnist_results_aggregated.csv (main results)")
    print(f"  - convergence.png / convergence.pdf")
    print(f"  - APRA_Results_Report.md")


# Example usage (uncomment to run after grids complete):
# run_full_postprocessing(outdir='apra_mnist_runs_full', num_rounds=25)

print('Full post-processing orchestrator ready.')
print('Call: run_full_postprocessing() after grids complete.')
