# NADF Training Pipeline

This notebook provides a complete pipeline for training Neural Adversarial Distance Functions (NADF).

## Pipeline Steps:
1. **Setup**: Configure paths and parameters
2. **Generate/Load Adversarial Data**: Create or load PGD attacks
3. **Train Probe**: Train the NADF probe model with optional augmentation and upweighting
4. **Evaluate**: Assess probe performance

Run cells in order for the complete workflow.


## 1. Setup and Configuration

Import libraries and configure training parameters.


In [1]:
# Imports
import os
import torch
from types import SimpleNamespace
from dotenv import load_dotenv

from nadf.data.adversarial import load_or_create_dataset
# from nadf.data.datasets import create_training_datasets
# from nadf.training.pipeline import apply_augmentation, save_augmented_dataset, train_probe_model

# Load environment variables
load_dotenv()

print("✓ Environment setup complete!")
print(f"  CUDA available: {torch.cuda.is_available()}")
print(f"  Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")


✓ Environment setup complete!
  CUDA available: False
  Device: cpu


In [2]:
# Configuration - Edit these parameters as needed
args = SimpleNamespace(
    # Data paths
    data_folder=os.getenv("RESNET_MODEL_FOLDER"),
    cache_folder="/rds/general/user/nk1924/home/nadf/data",
    target_class=-1,
    recreate_data=True,
    
    # Model architecture
    model_type="mlp",
    depth=5,
    width=256,
    activation="relu",
    
    # Training hyperparameters
    lr=1e-4,
    weight_decay=1e-4,
    batch_size=512,
    epochs=100,
    
    # Loss function
    loss="mse",
    huber_delta=0.05,
    
    # Upweighting and augmentation
    upweight=1.0,
    augmentation="none",
    num_augmentations=1,
    
    # Output
    save_dir="trained_models",
    checkpoint_name=None,
    verbose=False
)

print("="*60)
print("NADF Probe Training Configuration")
print("="*60)
print(f"Model: {args.model_type}, depth={args.depth}, width={args.width}")
print(f"Upweighting: {args.upweight}x, Loss: {args.loss}")
print(f"Augmentation: {args.augmentation}")
print("="*60)


NADF Probe Training Configuration
Model: mlp, depth=5, width=256
Upweighting: 1.0x, Loss: mse
Augmentation: none


## Debug: Check Model Configuration


In [4]:
# Check what path the model is configured to use
print(f"RESNET_MODEL_FOLDER from .env: {os.getenv('RESNET_MODEL_FOLDER')}")
print(f"Data folder being used: {args.data_folder}")

# Load the model's saved args to see what data path it expects
model_args = torch.load(os.path.join(args.data_folder, "args.info"), map_location="cpu", weights_only=False)
print(f"\nModel's saved configuration:")
print(f"  Dataset: {model_args.dataset}")
print(f"  Data path: {model_args.path}")
print(f"  Path exists: {os.path.exists(model_args.path)}")

# Check if CIFAR10 data exists at that path
cifar10_path = os.path.join(model_args.path, "cifar-10-batches-py")
print(f"\nCIFAR10 data path: {cifar10_path}")
print(f"CIFAR10 data exists: {os.path.exists(cifar10_path)}")


RESNET_MODEL_FOLDER from .env: /rds/general/user/nk1924/home/shared_models/cifar10/adam_augbasic_cosine-60-0.0_wd0.01_resnet18repr-128_cifar10_lr0.001_s0
Data folder being used: /rds/general/user/nk1924/home/shared_models/cifar10/adam_augbasic_cosine-60-0.0_wd0.01_resnet18repr-128_cifar10_lr0.001_s0

Model's saved configuration:
  Dataset: cifar10
  Data path: ./data
  Path exists: True

CIFAR10 data path: ./data/cifar-10-batches-py
CIFAR10 data exists: False


## Fix: Update Model's Data Path


In [5]:
# Update the model's data path to point to where CIFAR10 actually exists
model_args.path = "/rds/general/user/nk1924/home/shared_models"

# Save the updated args back to the model folder
torch.save(model_args, os.path.join(args.data_folder, "args.info"))

print("✓ Updated model's data path to:", model_args.path)
print("✓ Verifying CIFAR10 data exists...")

cifar10_check = os.path.join(model_args.path, "cifar10", "cifar-10-batches-py")
print(f"  Checking: {cifar10_check}")
print(f"  Exists: {os.path.exists(cifar10_check)}")


✓ Updated model's data path to: /rds/general/user/nk1924/home/shared_models
✓ Verifying CIFAR10 data exists...
  Checking: /rds/general/user/nk1924/home/shared_models/cifar10/cifar-10-batches-py
  Exists: False


## Debug: Check what's in the cifar10 directory


In [6]:
# Check what's actually in the cifar10 directory
import subprocess

print("Contents of /rds/general/user/nk1924/home/shared_models/:")
result = subprocess.run(['ls', '-la', '/rds/general/user/nk1924/home/shared_models/'], 
                       capture_output=True, text=True)
print(result.stdout)

print("\nContents of /rds/general/user/nk1924/home/shared_models/cifar10/:")
result = subprocess.run(['ls', '-la', '/rds/general/user/nk1924/home/shared_models/cifar10/'], 
                       capture_output=True, text=True)
print(result.stdout)

# Check if maybe it's directly in shared_models without the extra cifar10 subdirectory
cifar_batch_path = "/rds/general/user/nk1924/home/shared_models/cifar-10-batches-py"
print(f"\nDoes {cifar_batch_path} exist? {os.path.exists(cifar_batch_path)}")

# Or maybe the pytorch CIFAR10 dataset structure
pytorch_cifar_path = "/rds/general/user/nk1924/home/shared_models/CIFAR10"
print(f"Does {pytorch_cifar_path} exist? {os.path.exists(pytorch_cifar_path)}")


Contents of /rds/general/user/nk1924/home/shared_models/:
total 3
drwxr-sr-x.  3 nk1924 hpc-tbirdal 4096 Aug  1 13:22 .
drwx--s---. 26 nk1924 hpc-tbirdal 4096 Oct 28 13:09 ..
drwxr-xr-x. 18 nk1924 hpc-tbirdal 4096 Jul 31 16:48 cifar10


Contents of /rds/general/user/nk1924/home/shared_models/cifar10/:
total 946
drwxr-xr-x. 18 nk1924 hpc-tbirdal  4096 Jul 31 16:48 .
drwxr-sr-x.  3 nk1924 hpc-tbirdal  4096 Aug  1 13:22 ..
drwxr-xr-x.  4 nk1924 hpc-tbirdal  4096 Jul 31 16:46 adam_augbasic_cosine-60-0.0_wd0.001_resnet18_cifar10_lr0.0001_s0
-rw-r--r--.  1 nk1924 hpc-tbirdal 31823 Jul 31 16:46 adam_augbasic_cosine-60-0.0_wd0.001_resnet18_cifar10_lr0.0001_s0_traj.log
drwxr-xr-x.  4 nk1924 hpc-tbirdal  4096 Jul 31 16:46 adam_augbasic_cosine-60-0.0_wd0.001_resnet18_cifar10_lr0.001_s0
-rw-r--r--.  1 nk1924 hpc-tbirdal 43249 Jul 31 16:46 adam_augbasic_cosine-60-0.0_wd0.001_resnet18_cifar10_lr0.001_s0_traj.log
drwxr-xr-x.  4 nk1924 hpc-tbirdal  4096 Jul 31 16:48 adam_augbasic_cosine-60-0.0_wd0.001

## Download CIFAR10 Dataset


In [None]:
# Download CIFAR10 dataset to the shared_models directory
# This will bypass SSL verification since we're on a cluster
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

from torchvision import datasets

# Download CIFAR10 to the shared_models directory
print("Downloading CIFAR10 dataset (this may take a few minutes)...")
cifar_download_path = "/rds/general/user/nk1924/home/shared_models"

# Download train and test sets
train_dataset = datasets.CIFAR10(root=cifar_download_path, train=True, download=True)
test_dataset = datasets.CIFAR10(root=cifar_download_path, train=False, download=True)

print(f"✓ CIFAR10 downloaded successfully to {cifar_download_path}")
print(f"✓ Train set: {len(train_dataset)} samples")
print(f"✓ Test set: {len(test_dataset)} samples")

# Verify it exists
cifar_batch_path = os.path.join(cifar_download_path, "cifar-10-batches-py")
print(f"\n✓ Dataset files exist at: {cifar_batch_path}")
print(f"  Verified: {os.path.exists(cifar_batch_path)}")


## 2. Generate or Load Adversarial Examples


In [3]:
num_attacks_eps_coef = [(4, 0.25), (2, 0.5), (3, 1), (1, 2)]

dataset = load_or_create_dataset(
    folder=args.data_folder,
    target_class=args.target_class,
    num_attacks_eps_coef=num_attacks_eps_coef,
    splits=["train", "val", "test"],
    recreate=args.recreate_data,
)

for split in ["train", "val", "test"]:
    print(f"{split}: {len(dataset['z_clean'][split])} clean, {len(dataset['z_adv'][split])} adversarial")


  Loading train data...


RuntimeError: Dataset not found or corrupted. You can use download=True to download it

## 3. Apply Augmentation


In [None]:
augmentation_stats = apply_augmentation(dataset, args)


## 4. Create Training Datasets


In [None]:
datasets = create_training_datasets(dataset, args)


## 5. Save Augmented Dataset (if applicable)


In [None]:
if args.augmentation != "none" and augmentation_stats:
    save_augmented_dataset(dataset, args)


## 6. Train Probe


In [None]:
train_probe_model(datasets, args)
print("✓ Training complete!")
