akankshadhar2/SpotShieldProject

SpotShield

Training deep learning models on GCP Spot VMs with resilient checkpointing. A term project for COMS 4995 Applied ML in the Cloud (Spring 2026).

What This Is

SpotShield measures the cost-reliability tradeoff of training on preemptible cloud GPUs. We train ResNet-50 on CIFAR-100 under three checkpointing strategies (naive fixed-interval, adaptive EMA-based, milestone on validation improvement) and three interrupt patterns (regular, Poisson, bursty), then compare cost, wasted compute, and accuracy against an on-demand baseline.
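To make the three strategies concrete, here is a minimal sketch of how they could sit behind a shared interface. The class names, the `should_save` signature, and all parameter defaults are hypothetical and are not taken from `src/spotshield/checkpoint/`. In particular, the adaptive variant assumes the EMA tracks epoch duration, which the description above does not specify.

```python
from abc import ABC, abstractmethod


class CheckpointStrategy(ABC):
    """Decides, after each epoch, whether to write a checkpoint (hypothetical interface)."""

    @abstractmethod
    def should_save(self, epoch: int, epoch_seconds: float, val_acc: float) -> bool:
        ...


class NaiveStrategy(CheckpointStrategy):
    """Fixed interval: save every `interval` epochs."""

    def __init__(self, interval: int = 5):
        self.interval = interval

    def should_save(self, epoch, epoch_seconds, val_acc):
        return (epoch + 1) % self.interval == 0


class AdaptiveStrategy(CheckpointStrategy):
    """Assumed behavior: keep an EMA of epoch duration and save once the
    estimated unsaved work reaches a time budget."""

    def __init__(self, budget_seconds: float = 120.0, alpha: float = 0.3):
        self.budget_seconds = budget_seconds
        self.alpha = alpha
        self.ema = None
        self.unsaved_epochs = 0

    def should_save(self, epoch, epoch_seconds, val_acc):
        if self.ema is None:
            self.ema = epoch_seconds
        else:
            self.ema = self.alpha * epoch_seconds + (1 - self.alpha) * self.ema
        self.unsaved_epochs += 1
        if self.unsaved_epochs * self.ema >= self.budget_seconds:
            self.unsaved_epochs = 0
            return True
        return False


class MilestoneStrategy(CheckpointStrategy):
    """Save only when validation accuracy beats the best value seen so far."""

    def __init__(self, min_delta: float = 0.0):
        self.best = float("-inf")
        self.min_delta = min_delta

    def should_save(self, epoch, epoch_seconds, val_acc):
        if val_acc > self.best + self.min_delta:
            self.best = val_acc
            return True
        return False
```

A training loop would call `should_save` once per epoch and write a checkpoint whenever it returns `True`.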

The core question: at what preemption rate does Spot training stop being cheaper than on-demand? For our workload, the answer is "never" — even at 4+ interrupts/hour, Spot stays under half the on-demand cost with no accuracy loss.
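A back-of-the-envelope model shows why the break-even point is so far away. The sketch below is not the project's cost model. It assumes the ~61% uninterrupted savings is purely a price difference, and that each interrupt adds a fixed overhead (restart plus redone work) on top of useful training time. The function names and the 2-minute overhead used in the usage lines are illustrative assumptions.

```python
def spot_cost_ratio(price_ratio: float, interrupts_per_hour: float,
                    overhead_hours: float) -> float:
    """Spot cost divided by on-demand cost for the same useful training time."""
    return price_ratio * (1.0 + interrupts_per_hour * overhead_hours)


def break_even_rate(price_ratio: float, overhead_hours: float) -> float:
    """Interrupts per hour at which Spot costs exactly as much as on-demand."""
    return (1.0 / price_ratio - 1.0) / overhead_hours


# Assumption: the ~61% uninterrupted savings reflects the raw price ratio.
PRICE_RATIO = 1.0 - 0.61

# Overhead fraction implied by ~57% savings under simulated interruptions.
implied_overhead = (1.0 - 0.57) / PRICE_RATIO - 1.0

# Hypothetical: 2 minutes lost per interrupt.
rate_at_parity = break_even_rate(PRICE_RATIO, overhead_hours=2 / 60)

print(f"implied overhead: {implied_overhead:.1%}")
print(f"break-even rate: {rate_at_parity:.1f} interrupts/hour")
```

Under these assumptions the observed interruptions add roughly 10% overhead, while Spot would need more than 150% overhead to lose its price advantage.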

Key Findings

  • Spot VMs provide ~57% cost savings under simulated interruptions, ~61% when uninterrupted
  • All three checkpointing strategies preserve accuracy within ±1% of the on-demand baseline (63.6%)
  • Strategy choice doesn't matter for this workload: interrupt intervals (~20 min) are long relative to epoch duration (~38 sec), so checkpoint gaps are always small
  • Zero wasted compute across all conditions — emergency SIGTERM checkpoints work reliably
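The emergency-checkpoint behavior in the last finding can be sketched as a signal handler that only sets a flag, leaving the training loop to save at the next safe point. This is a minimal illustration of the pattern, not the code in `preemption/handler.py`. `PreemptionFlag`, `train`, and the `save_checkpoint` callback are hypothetical names.

```python
import signal


class PreemptionFlag:
    """Records that SIGTERM arrived so the training loop can react safely."""

    def __init__(self):
        self.requested = False

    def install(self):
        # Must be called from the main thread of the training process.
        signal.signal(signal.SIGTERM, self._on_sigterm)

    def _on_sigterm(self, signum, frame):
        # Keep the handler trivial: set a flag and return.
        self.requested = True


def train(flag, num_steps, save_checkpoint):
    """Run up to num_steps; on preemption, save once and stop.

    Returns the index of the last completed step, or num_steps if the run
    finished without a preemption request.
    """
    for step in range(num_steps):
        # ... one training step would go here ...
        if flag.requested:
            save_checkpoint(step)  # emergency checkpoint
            return step
    return num_steps
```

Doing the save in the loop rather than in the handler avoids writing a checkpoint while model or optimizer state is mid-update.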

Full findings with numbers and caveats are in planning/findings.md. Plots and CSV data are in analysis/plots/.

Repository Structure

src/spotshield/          # Core Python package (pip install -e .)
├── config.py            # YAML config loading and merging
├── models.py            # Model factory (ResNet-50)
├── data.py              # CIFAR-100 data loading
├── train.py             # Training loop with checkpoint/preemption hooks
├── verification.py      # Post-run result verification
├── checkpoint/          # Checkpointing strategies
│   ├── base.py          #   Abstract base class
│   ├── naive.py         #   Fixed-interval saves
│   ├── adaptive.py      #   EMA-based interval adjustment
│   └── milestone.py     #   Save on validation improvement
├── preemption/          # Preemption handling
│   ├── handler.py       #   SIGTERM signal handler
│   ├── resume.py        #   Checkpoint discovery and loading
│   └── simulate.py      #   Interrupt simulator (subprocess + SIGTERM)
└── logging/
    └── events.py        # Structured JSONL event logging

configs/                 # Experiment configuration (3-layer YAML)
├── models/              #   Model hyperparameters
├── strategies/          #   Checkpoint strategy parameters
└── experiments/         #   Full experiment definitions (14 configs)

scripts/                 # Experiment orchestration
├── run_experiment.py    #   Run a single experiment config
├── run_controlled.py    #   Run with simulated interrupts
├── run_batch.py         #   Batch orchestrator (sequential configs)
├── collect_results.py   #   Gather event logs into results/
├── verify_results.py    #   Validate result integrity
├── generate_configs.py  #   Generate experiment YAML from templates
├── monitor.sh           #   Poll GCP VM for progress
└── spot_watcher.sh      #   Auto-restart Spot VM after preemption

infra/                   # GCP infrastructure scripts
├── spawn_vm.sh          #   Find GPU capacity and create VM
├── setup_vm.sh          #   Bootstrap VM (drivers, venv, package)
├── setup_nat.sh         #   Create Cloud NAT for outbound internet
├── setup_gcs.sh         #   Create GCS bucket
├── check_quotas.sh      #   Check GPU quota
├── gpu_checker.py       #   Scan zones for GPU availability
└── teardown.sh          #   Delete VM

analysis/                # Post-experiment analysis (Phase 6)
├── metrics.py           #   Per-run metric computation and aggregation
├── plots.py             #   7 plot types (PDF + PNG)
└── analyze.py           #   CLI driver: metrics → plots → CSV

results/                 # Collected experiment event logs (42 runs)
experiment/RUNBOOK.md    # Step-by-step operator instructions for GCP
planning/                # Project planning docs and phase breakdowns
tests/                   # 238 fast tests (pytest)
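
The three config layers under configs/ (models, strategies, experiments) suggest a layered merge in which later layers override earlier ones. The sketch below shows one common way to do that with plain dictionaries. It is an assumption about the approach, not a copy of `config.py`, and every key in the example is hypothetical.

```python
from copy import deepcopy


def deep_merge(base: dict, override: dict) -> dict:
    """Return a new dict where `override` wins, merging nested dicts recursively."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


def resolve_config(model: dict, strategy: dict, experiment: dict) -> dict:
    """Apply layers in order: model, then strategy, then experiment."""
    return deep_merge(deep_merge(model, strategy), experiment)


# Hypothetical layer contents, as they might look after YAML parsing.
model_cfg = {"model": {"name": "resnet50", "lr": 0.1}, "epochs": 30}
strategy_cfg = {"checkpoint": {"strategy": "naive", "interval_epochs": 5}}
experiment_cfg = {"model": {"lr": 0.05}, "interrupts": {"pattern": "poisson"}}

resolved = resolve_config(model_cfg, strategy_cfg, experiment_cfg)
```

Here the experiment layer overrides only `model.lr`, and the rest of the model layer is preserved.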

Getting Started

Setup

python -m venv .venv
source .venv/bin/activate
pip install -e ".[dev,analysis]"

Run tests

pytest tests/ \
    --ignore=tests/test_preemption.py \
    --ignore=tests/test_interrupt_resume.py \
    --ignore=tests/test_train.py \
    --ignore=tests/test_resume.py \
    --ignore=tests/test_data.py \
    --ignore=tests/test_run_experiment.py \
    --ignore=tests/test_run_controlled.py

The ignored tests train real models on CPU and take several minutes each. They pass; they are just slow.

Run analysis on existing results

python -m analysis.analyze --results-dir ./results --output-dir ./analysis/plots --force

This reads the 42 event logs in results/, computes per-run metrics, generates 7 plot types to analysis/plots/, and writes metrics_summary.csv. The --force flag skips the verification check.
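
As an illustration of a per-run metric, the sketch below computes wasted compute from a JSONL event log: the time between the last checkpoint and each interrupt. The event names (`checkpoint_saved`, `interrupt`, `resume`) and the `t` field (wall-clock seconds) are a hypothetical schema, not the format written by `logging/events.py`, and the sample log is invented for the example.

```python
import json


def wasted_seconds(jsonl_text: str) -> float:
    """Sum, over all interrupts, the time elapsed since the last saved progress."""
    last_progress = 0.0  # time of the most recent checkpoint or resume
    wasted = 0.0
    for line in jsonl_text.splitlines():
        if not line.strip():
            continue
        event = json.loads(line)
        if event["event"] in ("checkpoint_saved", "resume"):
            last_progress = event["t"]
        elif event["event"] == "interrupt":
            wasted += event["t"] - last_progress
    return wasted


# Invented sample log: one scheduled checkpoint, then an emergency one.
sample_log = "\n".join([
    '{"event": "checkpoint_saved", "t": 100.0}',
    '{"event": "interrupt", "t": 130.0}',
    '{"event": "resume", "t": 200.0}',
    '{"event": "checkpoint_saved", "t": 300.0}',
    '{"event": "interrupt", "t": 300.5}',
])

total_wasted = wasted_seconds(sample_log)
```

When an emergency checkpoint lands just before every interrupt, as in the second interrupt of the sample, this metric approaches zero.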

Run a local smoke test

python -m scripts.run_controlled \
    --experiment configs/experiments/controlled_regular_naive_resnet.yaml \
    --run-index 0 \
    --data-root ./data \
    --output-dir ./output \
    --max-samples 256

This trains ResNet-50 on 256 CIFAR-100 samples with simulated interrupts. Takes a few minutes on CPU.
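
For intuition about the three interrupt patterns (regular, Poisson, bursty), here is a sketch of how interrupt times could be generated. The definitions are assumptions, especially for "bursty", which is modeled here as clusters of closely spaced interrupts. None of the function names or parameters come from `preemption/simulate.py`.

```python
import random


def regular(mean_gap: float, horizon: float) -> list[float]:
    """Evenly spaced interrupts, one every mean_gap seconds."""
    times, t = [], mean_gap
    while t < horizon:
        times.append(t)
        t += mean_gap
    return times


def poisson(mean_gap: float, horizon: float, rng: random.Random) -> list[float]:
    """Poisson process: exponentially distributed gaps with the given mean."""
    times, t = [], rng.expovariate(1.0 / mean_gap)
    while t < horizon:
        times.append(t)
        t += rng.expovariate(1.0 / mean_gap)
    return times


def bursty(mean_gap: float, horizon: float, rng: random.Random,
           burst_size: int = 3, burst_gap: float = 30.0) -> list[float]:
    """Assumed model: bursts of burst_size interrupts, burst_gap seconds apart.

    Burst starts follow a Poisson process with mean gap burst_size * mean_gap,
    so the long-run interrupt rate stays close to that of the other patterns.
    """
    times = []
    for start in poisson(burst_size * mean_gap, horizon, rng):
        for k in range(burst_size):
            t = start + k * burst_gap
            if t < horizon:
                times.append(t)
    return sorted(times)


rng = random.Random(0)  # seeded for reproducible schedules
schedule = {
    "regular": regular(1200.0, 3600.0),       # ~20 min gaps over one hour
    "poisson": poisson(1200.0, 36000.0, rng),
    "bursty": bursty(1200.0, 36000.0, rng),
}
```

A simulator could then sleep until each scheduled time and send SIGTERM to the training subprocess.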

Run on GCP

See experiment/RUNBOOK.md for full instructions covering VM provisioning, deployment, experiment execution, and result collection. The infra scripts in infra/ handle most of the GCP setup. The spot_watcher.sh script automates restart-after-preemption for Spot VMs.

About

SpotShield is a cloud-native deep learning training framework for reliable and cost-efficient AI workloads on GCP Spot VMs. It features automated checkpoint recovery, interruption handling, experiment orchestration, and resilient PyTorch training pipelines for scalable GPU-based model training.
