# StealthGAN-IDS Training on Google Colab

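This notebook provides a complete pipeline for training and evaluating StealthGAN-IDS on Google Colab. The setup cells also detect QuickPod (`/workspace`) and local environments, so the same notebook runs outside Colab.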

**Features:**
- Automatic GPU detection and setup
- Dataset download and preprocessing
- Training with checkpointing
- Evaluation and visualization
- Easy result download

**Supported Datasets:**
- NSL-KDD (legacy)
- CIC-IDS2017
- CIC-IDS2018
- UNSW-NB15

## 1. Setup and Installation

In [1]:
# Check GPU availability
import torch
import os
from pathlib import Path

print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️  No GPU detected! Training will be very slow on CPU.")
    print("Go to Runtime > Change runtime type > Hardware accelerator > GPU")

CUDA available: True
GPU: NVIDIA GeForce RTX 5090
CUDA version: 12.8
GPU Memory: 33.66 GB


In [3]:
# Clone the repository
# Auto-detect environment (Colab vs QuickPod vs local)
import os
from pathlib import Path

# Detect base directory
if Path("/workspace").exists():
    # QuickPod environment
    base_dir = Path("/workspace")
elif Path("/content").exists():
    # Google Colab environment
    base_dir = Path("/content")
else:
    # Local or other environment
    base_dir = Path.cwd()

repo_dir = base_dir / "SGAN-IDS"
repo_url = "https://github.com/yourusername/SGAN-IDS.git"  # ⚠️ UPDATE THIS if cloning

print(f"Detected environment: {base_dir}")
print(f"Repository directory: {repo_dir}")

# Check if repo exists in current directory or base directory
if repo_dir.exists():
    print(f"✅ Repository found at {repo_dir}")
    os.chdir(repo_dir)
    if (repo_dir / ".git").exists():
        !git pull
elif Path("SGAN-IDS").exists():
    # Check if repo is in current working directory
    repo_dir = Path("SGAN-IDS").resolve()
    print(f"✅ Repository found at {repo_dir}")
    os.chdir(repo_dir)
elif Path.cwd().name == "SGAN-IDS":
    # Already in the repo directory
    repo_dir = Path.cwd()
    print(f"✅ Already in repository directory: {repo_dir}")
else:
    # Try to clone or use current directory
    if repo_url != "https://github.com/yourusername/SGAN-IDS.git":
        print(f"Cloning repository to {repo_dir}...")
        !git clone {repo_url} {repo_dir}
        os.chdir(repo_dir)
    else:
        # Use current directory as repo (for QuickPod where files are already there)
        repo_dir = Path.cwd()
        print(f"⚠️  Using current directory as repository: {repo_dir}")
        print("If this is wrong, update repo_url above or ensure SGAN-IDS folder exists")

print(f"Working directory: {os.getcwd()}")
print(f"Repository root: {repo_dir}")

Detected environment: /workspace
Repository directory: /workspace/SGAN-IDS
✅ Repository found at /workspace/SGAN-IDS
Already up to date.
Working directory: /workspace/SGAN-IDS
Repository root: /workspace/SGAN-IDS
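The detection order used above (QuickPod's `/workspace` first, then Colab's `/content`, then the current directory) can be sketched as a pure function, which makes the precedence easy to check without touching the real filesystem. `detect_base_dir` and its injectable `exists` parameter are hypothetical names for illustration, not part of the repository.

```python
from pathlib import Path
from typing import Callable, Optional


def detect_base_dir(
    exists: Callable[[Path], bool] = Path.exists,
    cwd: Optional[Path] = None,
) -> Path:
    """Return the base directory with the same precedence as the cell above.

    `exists` is injectable so the precedence can be tested without real paths.
    """
    if exists(Path("/workspace")):  # QuickPod
        return Path("/workspace")
    if exists(Path("/content")):  # Google Colab
        return Path("/content")
    return cwd if cwd is not None else Path.cwd()  # Local or other environment


# Simulate a Colab machine, where only /content exists
print(detect_base_dir(exists=lambda p: p == Path("/content")))
```

Because `/workspace` is checked first, a machine that happens to have both directories is treated as QuickPod.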


In [4]:
# Install dependencies
# Colab ships a CUDA-enabled PyTorch; reinstall only if CUDA is unavailable.
# Note: a reinstalled torch only takes effect after restarting the runtime.
if not torch.cuda.is_available():
    print("Installing PyTorch with CUDA support...")
    !pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu118
else:
    print(f"PyTorch {torch.__version__} already installed with CUDA support ✓")

!pip install -q -r requirements.txt

# Optional: Install additional evaluation dependencies
!pip install -q xgboost lightgbm pyyaml

print("✅ Dependencies installed!")

PyTorch 2.9.1+cu128 already installed with CUDA support ✓
✅ Dependencies installed!


## 2. Dataset Setup

Choose your dataset and download/prepare it.

In [5]:
# Configuration
from pathlib import Path

# Auto-detect data root based on environment
if Path("/workspace").exists():
    DATA_ROOT = Path("/workspace/data")  # QuickPod
elif Path("/content").exists():
    DATA_ROOT = Path("/content/data")  # Colab
else:
    DATA_ROOT = Path("./data")  # Local

DATA_ROOT.mkdir(exist_ok=True, parents=True)

DATASET = "cic_ids2018"  # Options: nsl_kdd, cic_ids2017, cic_ids2018, unsw_nb15, unified

print(f"Selected dataset: {DATASET}")
print(f"Data root: {DATA_ROOT}")

Selected dataset: cic_ids2018
Data root: /workspace/data


In [None]:
# Download CIC-IDS2018 from Kaggle (requires Kaggle API)
# Option 1: Using Kaggle API (recommended)
if DATASET == "cic_ids2018":
    print("To download CIC-IDS2018:")
    print("1. Install Kaggle: !pip install kaggle")
    print("2. Upload kaggle.json (from Kaggle account settings)")
    print("3. Run: !mkdir -p ~/.kaggle && cp kaggle.json ~/.kaggle/ && chmod 600 ~/.kaggle/kaggle.json")
    print("4. Then run:")
    print(f"   !kaggle datasets download -d solarmainframe/ids-intrusion-csv -p {DATA_ROOT}")
    print(f"   !unzip -q {DATA_ROOT}/ids-intrusion-csv.zip -d {DATA_ROOT}/CIC-IDS2018")
    print("\n⚠️  Or manually upload dataset files via Colab file browser")

# Option 2: Manual upload via Colab file browser
# Files > Upload to session storage > Extract to DATA_ROOT
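After extracting, it is worth checking which classes each CSV file actually contains. The tuning logs further down show the loader stopping after a single file and seeing only two classes (`Benign`, `FTP-BruteForce`), so class coverage depends heavily on which files are read. The sketch below counts labels per file with the standard-library `csv` module. It assumes each CSV has a `Label` column (leading/trailing spaces in the header are tolerated), and `label_counts_per_file` is a hypothetical helper, not part of the repository.

```python
import csv
from collections import Counter
from pathlib import Path
from typing import Dict


def label_counts_per_file(csv_dir: Path, label_col: str = "Label") -> Dict[str, Counter]:
    """Count class labels in every CSV under `csv_dir` (assumes a label column)."""
    results: Dict[str, Counter] = {}
    for csv_path in sorted(Path(csv_dir).rglob("*.csv")):
        counts: Counter = Counter()
        with open(csv_path, newline="", encoding="utf-8", errors="replace") as f:
            reader = csv.DictReader(f)
            # Find the label column even if the header has stray whitespace
            key = next((k for k in (reader.fieldnames or []) if k.strip() == label_col), None)
            if key is None:
                continue  # Skip files without the expected label column
            for row in reader:
                label = (row.get(key) or "").strip()
                if label and label != label_col:  # Ignore repeated header rows
                    counts[label] += 1
        results[csv_path.name] = counts
    return results


# Usage (path follows the unzip command above; adjust if you extracted elsewhere):
# for name, counts in label_counts_per_file(DATA_ROOT / "CIC-IDS2018").items():
#     print(name, dict(counts))
```

This reads every row, so on the full dataset it can take a while; it is a one-off check, not part of training.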

In [6]:
# QuickPod only: pin the repository path explicitly (skipped on other platforms)
import os
from pathlib import Path

quickpod_repo = Path("/workspace/SGAN-IDS")
if quickpod_repo.exists():
    repo_dir = quickpod_repo
    os.chdir(repo_dir)
print(f"Working directory: {os.getcwd()}")
print(f"Repository: {repo_dir}")

Working directory: /workspace/SGAN-IDS
Repository: /workspace/SGAN-IDS


In [7]:
# Preprocess the dataset (optional - training will preprocess if needed)
# Note: preprocess_data.py only supports: nsl_kdd, cic_ids2017, unified
# For cic_ids2018 and unsw_nb15, preprocessing happens automatically during training

import sys
import subprocess

# Only preprocess if dataset is supported by preprocess script
if DATASET in ["nsl_kdd", "cic_ids2017", "unified"]:
    print(f"Preprocessing {DATASET}...")
    
    cmd = [
        "python", "scripts/preprocess_data.py",
        "--data-root", str(DATA_ROOT),
        "--dataset", DATASET
    ]
    
    result = subprocess.run(cmd, cwd=repo_dir)
    
    if result.returncode == 0:
        print("✅ Preprocessing complete!")
    else:
        print(f"❌ Preprocessing failed with exit code {result.returncode}")
else:
    print(f"⚠️  Dataset {DATASET} will be preprocessed automatically during training")
    print("Skipping standalone preprocessing step...")

⚠️  Dataset cic_ids2018 will be preprocessed automatically during training
Skipping standalone preprocessing step...
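Every script in this notebook is launched the same way: `python scripts/<name>.py` followed by `--flag value` pairs and occasional boolean switches such as `--amp`. A small helper can build these lists from keyword arguments. `build_cmd` is a hypothetical convenience function, not something the repository provides; it only formats flags and does not check that the target script accepts them.

```python
from typing import Any, List


def build_cmd(script: str, **options: Any) -> List[str]:
    """Build a ["python", script, --flag, value, ...] list for subprocess.run.

    Underscores become dashes (max_samples -> --max-samples). True adds a bare
    switch, while False and None drop the option entirely.
    """
    cmd = ["python", script]
    for name, value in options.items():
        flag = "--" + name.replace("_", "-")
        if value is True:
            cmd.append(flag)
        elif value is False or value is None:
            continue
        else:
            cmd.extend([flag, str(value)])
    return cmd


# Mirrors the preprocessing command above (paths are illustrative)
print(build_cmd("scripts/preprocess_data.py", data_root="/workspace/data", dataset="nsl_kdd"))
# ['python', 'scripts/preprocess_data.py', '--data-root', '/workspace/data', '--dataset', 'nsl_kdd']
```

Skipping `None` and `False` means optional settings like `MAX_SAMPLES = None` or `USE_AMP = False` need no separate `if` blocks.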


## 3. Hyperparameter Tuning (Optional)

Use Optuna to automatically find optimal hyperparameters before full training.
This step is optional but recommended for best results.

In [8]:
# Tuning configuration
# Use fewer samples and epochs for faster tuning
TUNE_TRIALS = 30  # Number of Optuna trials (more = better results, slower)
TUNE_EPOCHS = 15  # Epochs per trial (fewer = faster, less accurate)
TUNE_SAMPLES = 100000  # Samples for tuning (smaller = faster)
SEED = 42

print("Hyperparameter Tuning Configuration:")
print(f"  Trials: {TUNE_TRIALS}")
print(f"  Epochs per trial: {TUNE_EPOCHS}")
print(f"  Samples: {TUNE_SAMPLES}")
print("\n⚠️  Tuning can take 1-3 hours depending on settings.")

Hyperparameter Tuning Configuration:
  Trials: 30
  Epochs per trial: 15
  Samples: 100000

⚠️  Tuning can take 1-3 hours depending on settings.


In [9]:
# Run hyperparameter tuning with Optuna
# Skip this cell if you want to use default hyperparameters
import subprocess
import gc
import torch

# Clear memory
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

cmd = [
    "python", "scripts/tune_optuna.py",
    "--data-root", str(DATA_ROOT),
    "--dataset", DATASET,
    "--n-trials", str(TUNE_TRIALS),
    "--tune-epochs", str(TUNE_EPOCHS),
    "--max-samples", str(TUNE_SAMPLES),
    "--seed", str(SEED),
    "--output", "best_hyperparams.json",
]

print("Starting hyperparameter tuning...")
print(f"Command: {' '.join(cmd)}")
print("\n⏳ This will run multiple trials to find optimal hyperparameters.")
print("💡 Best parameters will be saved to best_hyperparams.json")

result = subprocess.run(cmd, cwd=repo_dir)

if result.returncode == 0:
    print("\n✅ Tuning completed! Check best_hyperparams.json for optimal values.")
elif result.returncode == -9:
    print("\n❌ Tuning killed (OOM) - Try reducing TUNE_SAMPLES to 50000")
else:
    print(f"\n❌ Tuning failed with exit code {result.returncode}")

Starting hyperparameter tuning...
Command: python scripts/tune_optuna.py --data-root /workspace/data --dataset cic_ids2018 --n-trials 30 --tune-epochs 15 --max-samples 100000 --seed 42 --output best_hyperparams.json

⏳ This will run multiple trials to find optimal hyperparameters.
💡 Best parameters will be saved to best_hyperparams.json


[I 2026-01-19 07:18:30,671] A new study created in memory with name: stealthgan_tuning
  0%|          | 0/30 [00:00<?, ?it/s]

StealthGAN-IDS Hyperparameter Tuning with Optuna
[tune] Using device: cuda

[1/3] Loading data...
[tune] Limiting dataset to 100000 samples
[CIC-IDS2018] Early stop after 1 files, ~150000 rows
[CIC-IDS2018] Sampling 100000 from 150000 rows
[CIC-IDS2018] Loaded shape: (100000, 80)
[DataForge] Loaded datasets: ['cic_ids2018']
[DataForge] Dropped 3 rows with NaN/inf values
[DataForge] Data shape after cleaning: (24904, 80)
[DataForge] Number of classes: 2
[DataForge] Classes: ['Benign', 'FTP-BruteForce']
[DataForge] Split sizes - Train: 17432, Val: 3736, Test: 3736
[DataForge] Fitting transformers on training data...
[DataForge] Transforming validation data...
[DataForge] Transforming test data...
[DataForge] Converting sparse to dense...
[DataForge] Number of features after encoding: 4602
[DataForge] Reducing dimensions from 4602 to 256 using TruncatedSVD...
[DataForge] Explained variance: 98.59%
[DataForge] Final number of features: 256
[tune] Data: 256 features, 2 classes
[tune] Train:

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Best trial: 0. Best value: 0:  33%|███▎      | 10/30 [09:41<18:54, 56.75s/it]


[Trial 0] lr_g=4.33e-05, lr_d=4.12e-04, critic=4, gp=12.4, fm=0.86, batch=256
  Epoch 2: D=-730.42 G=1264.91 | F1 baseline=1.000 aug=1.000 (+0.0000) std=0.640
  Epoch 5: D=-758.59 G=1257.36 | F1 baseline=1.000 aug=1.000 (+0.0000) std=0.619
  Epoch 8: D=-756.62 G=1260.25 | F1 baseline=1.000 aug=1.000 (+0.0000) std=0.604
  Epoch 11: D=-762.65 G=1252.38 | F1 baseline=1.000 aug=1.000 (+0.0000) std=0.601
  Epoch 14: D=-771.99 G=1226.13 | F1 baseline=1.000 aug=1.000 (+0.0000) std=0.607
[I 2026-01-19 07:19:33,145] Trial 0 finished with value: 0.0 and parameters: {'lr_g': 4.3284502212938785e-05, 'lr_d': 0.0004123206532618727, 'critic_updates': 4, 'gp_lambda': 12.374511199743695, 'feature_matching_weight': 0.864491338167939, 'batch_size': 256, 'latent_dim': 100, 'ema_decay': 0.9996021075364038}. Best is trial 0 with value: 0.0.

[Trial 1] lr_g=2.60e-04, lr_d=2.29e-05, critic=1, gp=4.5, fm=1.59, batch=64
  Epoch 2: D=25.82 G=-5.39 | F1 baseline=1.000 aug=1.000 (+0.0000) std=1.237
 

Best trial: 0. Best value: 0:  63%|██████▎   | 19/30 [22:35<14:53, 81.26s/it]

[I 2026-01-19 07:28:11,615] Trial 9 finished with value: 0.0 and parameters: {'lr_g': 2.4474916579073785e-05, 'lr_d': 1.3514082247401414e-05, 'critic_updates': 2, 'gp_lambda': 4.0632044578260835, 'feature_matching_weight': 4.655518496478608, 'batch_size': 256, 'latent_dim': 128, 'ema_decay': 0.995339488194965}. Best is trial 0 with value: 0.0.

[Trial 10] lr_g=2.09e-04, lr_d=3.85e-04, critic=3, gp=18.7, fm=2.05, batch=128
  Epoch 2: D=-357.78 G=634.41 | F1 baseline=1.000 aug=1.000 (+0.0000) std=0.631
  Epoch 5: D=-451.85 G=879.55 | F1 baseline=1.000 aug=1.000 (+0.0000) std=0.711
  Epoch 8: D=-450.21 G=887.17 | F1 baseline=1.000 aug=1.000 (+0.0000) std=1.119
  Epoch 11: D=-444.22 G=887.39 | F1 baseline=1.000 aug=1.000 (+0.0000) std=1.705
  Epoch 14: D=-444.44 G=887.72 | F1 baseline=1.000 aug=1.000 (+0.0000) std=2.628
[I 2026-01-19 07:29:10,895] Trial 10 finished with value: 0.0 and parameters: {'lr_g': 0.00020917152515896608, 'lr_d': 0.0003845072704730259, 'critic_

Best trial: 0. Best value: 0:  97%|█████████▋| 29/30 [31:55<00:50, 50.30s/it]

  Epoch 8: D=-370.25 G=612.48 | F1 baseline=1.000 aug=1.000 (+0.0000) std=1.734
  Epoch 11: D=-375.10 G=597.51 | F1 baseline=1.000 aug=1.000 (+0.0000) std=3.304
  Epoch 14: D=-382.97 G=559.75 | F1 baseline=1.000 aug=1.000 (+0.0000) std=4.122
[I 2026-01-19 07:43:07,663] Trial 19 finished with value: 0.0 and parameters: {'lr_g': 5.300907148719937e-05, 'lr_d': 5.5738676559539255e-05, 'critic_updates': 5, 'gp_lambda': 13.533989746918186, 'feature_matching_weight': 1.941924755842845, 'batch_size': 128, 'latent_dim': 128, 'ema_decay': 0.9959168124080359}. Best is trial 0 with value: 0.0.

[Trial 20] lr_g=1.12e-04, lr_d=2.26e-05, critic=4, gp=2.2, fm=2.58, batch=256
  Epoch 2: D=118.73 G=-4.93 | F1 baseline=1.000 aug=1.000 (+0.0000) std=0.621
  Epoch 5: D=-3.50 G=5.96 | F1 baseline=1.000 aug=1.000 (+0.0000) std=0.626
  Epoch 8: D=-74.01 G=44.11 | F1 baseline=1.000 aug=1.000 (+0.0000) std=0.872
  Epoch 11: D=-233.09 G=251.02 | F1 baseline=1.000 aug=1.000 (+0.0000) std=1.088
  Epoc

Best trial: 0. Best value: 0: 100%|██████████| 30/30 [34:48<00:00, 69.61s/it]



[Trial 29] lr_g=2.68e-05, lr_d=1.26e-04, critic=4, gp=17.4, fm=0.55, batch=64
  Epoch 2: D=-200.51 G=309.52 | F1 baseline=1.000 aug=1.000 (+0.0000) std=0.758
  Epoch 5: D=-198.52 G=303.55 | F1 baseline=1.000 aug=1.000 (+0.0000) std=1.205
  Epoch 8: D=-192.48 G=311.40 | F1 baseline=1.000 aug=1.000 (+0.0000) std=3.966
  Epoch 11: D=-200.87 G=284.85 | F1 baseline=1.000 aug=1.000 (+0.0000) std=2.285
  Epoch 14: D=-201.72 G=288.51 | F1 baseline=1.000 aug=1.000 (+0.0000) std=1.915
[I 2026-01-19 07:53:18,816] Trial 29 finished with value: 0.0 and parameters: {'lr_g': 2.676822676846602e-05, 'lr_d': 0.00012553619932692753, 'critic_updates': 4, 'gp_lambda': 17.426457016105374, 'feature_matching_weight': 0.5465090978585505, 'batch_size': 64, 'latent_dim': 100, 'ema_decay': 0.997673899403075}. Best is trial 0 with value: 0.0.

TUNING COMPLETE

Study statistics:
  Number of finished trials: 30
  Number of pruned trials: 0
  Number of complete trials: 30

Best trial:
  Value (F1 improv
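Every trial above finished with value 0.0 because the per-epoch logs report `F1 baseline=1.000`: this 100,000-row sample came from a single file and held only two classes, so the baseline classifier was already perfect and augmentation had nothing to improve. Judging by the log format (the final summary line is truncated), the objective appears to be augmented F1 minus baseline F1; under that assumption its maximum is `1 - baseline`. The sketch below parses the epoch log lines to flag this saturation automatically. The regular expression targets the exact log format shown here, and `parse_epoch_line` and `headroom` are hypothetical helpers.

```python
import re
from typing import Optional

# Matches lines such as:
#   Epoch 14: D=-771.99 G=1226.13 | F1 baseline=1.000 aug=1.000 (+0.0000) std=0.607
EPOCH_RE = re.compile(
    r"Epoch (?P<epoch>\d+): D=(?P<d>-?[\d.]+) G=(?P<g>-?[\d.]+) \| "
    r"F1 baseline=(?P<baseline>[\d.]+) aug=(?P<aug>[\d.]+) \((?P<delta>[+-][\d.]+)\)"
)


def parse_epoch_line(line: str) -> Optional[dict]:
    """Extract epoch number, losses and F1 scores from one tuning log line."""
    m = EPOCH_RE.search(line)
    if m is None:
        return None
    out = {k: float(v) for k, v in m.groupdict().items()}
    out["epoch"] = int(out["epoch"])
    return out


def headroom(baseline_f1: float) -> float:
    """Largest possible gain if the objective is augmented F1 minus baseline F1."""
    return 1.0 - baseline_f1


line = "  Epoch 14: D=-771.99 G=1226.13 | F1 baseline=1.000 aug=1.000 (+0.0000) std=0.607"
stats = parse_epoch_line(line)
print(stats["epoch"], stats["baseline"], headroom(stats["baseline"]))  # 14 1.0 0.0
```

A headroom of zero means Optuna has no signal to optimize. Tuning on a sample that covers more classes (the 400,000-row rerun below sees three) gives the objective room to move.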

In [16]:
import subprocess
import gc
import torch

# Clear memory
gc.collect()

0

In [15]:
!python scripts/tune_optuna.py --data-root /workspace/data --dataset cic_ids2018 --n-trials 50 --max-samples 400000

StealthGAN-IDS Hyperparameter Tuning with Optuna
[tune] Using device: cuda

[1/3] Loading data...
[tune] Limiting dataset to 400000 samples
[CIC-IDS2018] Early stop after 1 files, ~600000 rows
[CIC-IDS2018] Sampling 400000 from 600000 rows
[CIC-IDS2018] Loaded shape: (400000, 80)
[DataForge] Loaded datasets: ['cic_ids2018']
[DataForge] Dropped 779 rows with NaN/inf values
[DataForge] Data shape after cleaning: (260200, 80)
[DataForge] Number of classes: 3
[DataForge] Classes: ['Benign', 'FTP-BruteForce', 'SSH-Bruteforce']
[DataForge] Split sizes - Train: 182140, Val: 39030, Test: 39030
[DataForge] Fitting transformers on training data...
[DataForge] Transforming validation data...
[DataForge] Transforming test data...
[DataForge] Converting sparse to dense...
[DataForge] Number of features after encoding: 26528
[DataForge] Reducing dimensions from 26528 to 256 using TruncatedSVD...
[DataForge] Explained variance: 98.63%
[DataForge] Final number of features: 256
[tune] Data: 256 feature

In [None]:
# View tuning results
import json
from pathlib import Path

hp_file = repo_dir / "best_hyperparams.json"
if hp_file.exists():
    with open(hp_file) as f:
        best_hp = json.load(f)
    
    print("=" * 50)
    print("Best Hyperparameters Found")
    print("=" * 50)
    print(f"\nF1 Improvement: {best_hp['best_value']:.4f}")
    print(f"Total trials: {best_hp['n_trials']}")
    print("\nOptimal parameters:")
    for key, value in best_hp['params'].items():
        if isinstance(value, float):
            print(f"  {key}: {value:.6f}")
        else:
            print(f"  {key}: {value}")
    
    print("\n💡 The tuned batch_size is reused for full training below; the other values are not passed to train_gan.py by this notebook.")
else:
    print("⚠️  No tuning results found. Run tuning first or skip to use defaults.")

## 4. Full Training

Train the GAN with the default settings below. If `best_hyperparams.json` exists, its tuned batch size replaces the default.

In [None]:
# Training configuration
EPOCHS = 100
BATCH_SIZE = 128  # Will be overridden by tuned value if available
MAX_SAMPLES = 300000  # Limit dataset size for memory
CHECKPOINT_INTERVAL = 10
USE_AMP = True  # Mixed precision
NUM_WORKERS = 2
SEED = 42  # Same seed as the tuning section; redefined so this cell works if tuning is skipped

# Load tuned hyperparameters if available
import json
from pathlib import Path

hp_file = repo_dir / "best_hyperparams.json"
tuned_params = {}
if hp_file.exists():
    with open(hp_file) as f:
        best_hp = json.load(f)
    tuned_params = best_hp.get('params', {})
    if 'batch_size' in tuned_params:
        BATCH_SIZE = tuned_params['batch_size']
    print("✅ Using tuned batch size from best_hyperparams.json")
else:
    print("⚠️  No tuned parameters found, using defaults")

# Clear GPU memory
import gc
import torch
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    free_b, total_b = torch.cuda.mem_get_info()  # (free, total) in bytes
    print(f"GPU memory: {free_b / 1e9:.2f} GB free of {total_b / 1e9:.2f} GB")

print(f"\nTraining configuration:")
print(f"  Dataset: {DATASET}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Max samples: {MAX_SAMPLES}")
print(f"  Epochs: {EPOCHS}")
print(f"  Mixed precision: {USE_AMP}")
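The cell above copies only `batch_size` out of the tuned parameters. To see at a glance which settings come from tuning and which are defaults, you can overlay the two dictionaries and record the source of each value. `merge_hyperparams` and the example values are hypothetical; only `batch_size` is known to reach `train_gan.py`, via the `--batch-size` flag used below.

```python
from typing import Any, Dict, Tuple


def merge_hyperparams(
    defaults: Dict[str, Any], tuned: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, str]]:
    """Overlay tuned values on defaults and record where each value came from."""
    merged = {**defaults, **tuned}
    source = {k: ("tuned" if k in tuned else "default") for k in merged}
    return merged, source


# Illustrative values: 128 is this notebook's default batch size,
# 256 is the batch size trial 0 used in the tuning log above.
merged, source = merge_hyperparams({"batch_size": 128, "epochs": 100}, {"batch_size": 256})
for key in merged:
    print(f"{key}: {merged[key]} ({source[key]})")
# batch_size: 256 (tuned)
# epochs: 100 (default)
```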

In [None]:
# Start full training
import subprocess

cmd = [
    "python", "scripts/train_gan.py",
    "--data-root", str(DATA_ROOT),
    "--dataset", DATASET,
    "--epochs", str(EPOCHS),
    "--batch-size", str(BATCH_SIZE),
    "--checkpoint-interval", str(CHECKPOINT_INTERVAL),
    "--seed", str(SEED),
    "--num-workers", str(NUM_WORKERS),
]

if MAX_SAMPLES:
    cmd.extend(["--max-samples", str(MAX_SAMPLES)])

if USE_AMP:
    cmd.append("--amp")

print("Starting full training...")
print(f"Command: {' '.join(cmd)}")
print("\n⚠️  This may take several hours. Checkpoints will be saved periodically.")
print("💡 If you get OOM errors, reduce MAX_SAMPLES (try 200000) or BATCH_SIZE (try 64)")

result = subprocess.run(cmd, cwd=repo_dir)

if result.returncode == 0:
    print("\n✅ Training completed successfully!")
elif result.returncode == -9:
    print("\n❌ Training killed (exit code -9) - Out of Memory!")
    print("   Try reducing MAX_SAMPLES to 200000 or BATCH_SIZE to 64")
else:
    print(f"\n❌ Training failed with exit code {result.returncode}")
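The `-9` checks in these cells rely on a documented `subprocess` convention: on POSIX, a negative `returncode` of `-N` means the child was terminated by signal `N`, and signal 9 is `SIGKILL`, which is what the Linux out-of-memory killer sends. `SIGKILL` can also come from elsewhere (a user or a container limit), so OOM is the likely cause rather than the only one. A helper that names the signal makes the messages more precise; `describe_returncode` is a hypothetical name.

```python
import signal


def describe_returncode(returncode: int) -> str:
    """Translate a subprocess return code into a short human-readable message."""
    if returncode == 0:
        return "success"
    if returncode < 0:
        # subprocess convention (POSIX): -N means the child was killed by signal N
        try:
            name = signal.Signals(-returncode).name
        except ValueError:
            name = f"signal {-returncode}"
        hint = " (likely out of memory)" if name == "SIGKILL" else ""
        return f"killed by {name}{hint}"
    return f"exited with code {returncode}"


# Usage after any subprocess.run(...) call in this notebook:
# print(describe_returncode(result.returncode))
```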

In [None]:
# Check training outputs
output_files = [
    "training_stats.csv",
    "generator_best.pth",
    "generator_ema_best.pth",
    "generator.pth",
    "discriminator.pth",
]

print("Training outputs:")
for fname in output_files:
    path = repo_dir / fname
    if path.exists():
        size_mb = path.stat().st_size / 1e6
        print(f"  ✅ {fname} ({size_mb:.2f} MB)")
    else:
        print(f"  ❌ {fname} (not found)")

# List checkpoints sorted by epoch number
# (plain string sorting would put checkpoint_epoch_100 before checkpoint_epoch_20)
def _epoch(p):
    tail = p.stem.rsplit("_", 1)[-1]
    return int(tail) if tail.isdigit() else -1

checkpoints = sorted(repo_dir.glob("checkpoint_epoch_*.pth"), key=_epoch)
if checkpoints:
    print(f"\nCheckpoints found: {len(checkpoints)}")
    for cp in checkpoints[-5:]:  # Show last 5
        size_mb = cp.stat().st_size / 1e6
        print(f"  {cp.name} ({size_mb:.2f} MB)")

## 5. Evaluation

Evaluate the trained generator with comprehensive metrics.

In [None]:
# Evaluation configuration
GENERATOR_PATH = "generator_ema_best.pth"  # Use EMA version (better quality)
N_PER_CLASS = 2000  # Synthetic samples per class
TARGET_MINORITY = True  # Focus on minority classes
CV_FOLDS = 5  # Cross-validation folds
OUTPUT_DIR = "eval_outputs"

print(f"Evaluation configuration:")
print(f"  Generator: {GENERATOR_PATH}")
print(f"  Samples per class: {N_PER_CLASS}")
print(f"  Target minority: {TARGET_MINORITY}")
print(f"  CV folds: {CV_FOLDS}")

In [None]:
# Run evaluation
import subprocess
import gc
import torch

# Clear memory before evaluation
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

cmd = [
    "python", "scripts/eval_gan.py",
    "--data-root", str(DATA_ROOT),
    "--dataset", DATASET,
    "--generator-path", GENERATOR_PATH,
    "--n-per-class", str(N_PER_CLASS),
    "--cv-folds", str(CV_FOLDS),
    "--output-dir", OUTPUT_DIR,
]

# Pass max_samples to prevent OOM during data loading
if MAX_SAMPLES:
    cmd.extend(["--max-samples", str(MAX_SAMPLES)])

if TARGET_MINORITY:
    cmd.append("--target-minority")

print("Starting evaluation...")
print(f"Command: {' '.join(cmd)}")

result = subprocess.run(cmd, cwd=repo_dir)

if result.returncode == 0:
    print("\n✅ Evaluation completed successfully!")
elif result.returncode == -9:
    print("\n❌ Evaluation killed (exit code -9) - Out of Memory!")
    print("   Try reducing MAX_SAMPLES to 200000 or N_PER_CLASS to 1000")
else:
    print(f"\n❌ Evaluation failed with exit code {result.returncode}")

In [None]:
# View evaluation results
import json

results_file = repo_dir / OUTPUT_DIR / "evaluation_results.json"
if results_file.exists():
    with open(results_file) as f:
        results = json.load(f)
    
    print("Evaluation Results Summary:")
    print("=" * 50)
    
    if "classifier_results" in results:
        print("\nClassifier Performance:")
        for classifier, metrics in results["classifier_results"].items():
            if "baseline" in metrics and "augmented" in metrics:
                baseline = metrics["baseline"]["mean"]
                augmented = metrics["augmented"]["mean"]
                improvement = augmented - baseline
                print(f"  {classifier}:")
                print(f"    Baseline F1: {baseline:.4f}")
                print(f"    Augmented F1: {augmented:.4f}")
                print(f"    Improvement: {improvement:+.4f}")
    
    if "distribution_metrics" in results:
        print("\nDistribution Quality:")
        for metric, value in results["distribution_metrics"].items():
            print(f"  {metric}: {value:.4f}")
else:
    print("⚠️  Results file not found. Run evaluation first.")
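For a report or README, the same `classifier_results` structure can be rendered as a Markdown table. This sketch assumes the JSON layout the cell above reads (`classifier_results` → classifier name → `baseline`/`augmented` → `mean`); `results_to_markdown` is a hypothetical helper.

```python
from typing import Any, Dict


def results_to_markdown(results: Dict[str, Any]) -> str:
    """Render baseline vs. augmented F1 per classifier as a Markdown table."""
    lines = [
        "| Classifier | Baseline F1 | Augmented F1 | Improvement |",
        "|---|---|---|---|",
    ]
    for name, metrics in results.get("classifier_results", {}).items():
        if "baseline" not in metrics or "augmented" not in metrics:
            continue  # Same guard as the summary cell above
        base = metrics["baseline"]["mean"]
        aug = metrics["augmented"]["mean"]
        lines.append(f"| {name} | {base:.4f} | {aug:.4f} | {aug - base:+.4f} |")
    return "\n".join(lines)


# Usage after running the cell above:
# print(results_to_markdown(results))
```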

### 5.1 Visualizations

View generated plots and visualizations.

In [None]:
# Display evaluation plots
from IPython.display import Image, display

plots_dir = repo_dir / OUTPUT_DIR / "plots"
if plots_dir.exists():
    plot_files = list(plots_dir.glob("*.png"))
    if plot_files:
        print(f"Found {len(plot_files)} plot(s):")
        for plot_file in plot_files:
            print(f"\n{plot_file.name}:")
            display(Image(str(plot_file)))
    else:
        print("No plots found in plots directory.")
else:
    print("Plots directory not found. Run evaluation first.")

## 6. Download Results

Download your trained models and results.

In [None]:
# Create download package
import shutil
from pathlib import Path

download_dir = base_dir / "downloads"  # /content/downloads on Colab, /workspace/downloads on QuickPod
download_dir.mkdir(parents=True, exist_ok=True)

# Copy important files
files_to_download = [
    "generator_ema_best.pth",
    "generator_best.pth",
    "discriminator.pth",
    "training_stats.csv",
]

# Copy checkpoints
checkpoints = list(repo_dir.glob("checkpoint_epoch_*.pth"))
if checkpoints:
    checkpoint_dir = download_dir / "checkpoints"
    checkpoint_dir.mkdir(exist_ok=True)
    for cp in checkpoints:
        shutil.copy2(cp, checkpoint_dir / cp.name)
    print(f"Copied {len(checkpoints)} checkpoints")

# Copy evaluation outputs
eval_dir = repo_dir / OUTPUT_DIR
if eval_dir.exists():
    shutil.copytree(eval_dir, download_dir / OUTPUT_DIR, dirs_exist_ok=True)
    print("Copied evaluation outputs")

# Copy files
for fname in files_to_download:
    src = repo_dir / fname
    if src.exists():
        shutil.copy2(src, download_dir / fname)

print(f"\n✅ Files prepared for download in {download_dir}")
print("\nTo download:")
print("1. Open the file browser (left sidebar in Colab)")
print(f"2. Navigate to {download_dir}")
print("3. Right-click files and select 'Download'")

In [None]:
# Alternative: Create a zip file for easy download
import shutil
from pathlib import Path

download_dir = base_dir / "downloads"
zip_base = base_dir / "stealthgan_results"  # make_archive appends ".zip"

if download_dir.exists():
    zip_path = Path(shutil.make_archive(str(zip_base), "zip", str(download_dir)))
    size_mb = zip_path.stat().st_size / 1e6
    print(f"✅ Created zip file: {zip_path} ({size_mb:.2f} MB)")
    try:
        from google.colab import files  # Only importable on Google Colab
    except ImportError:
        print("Not on Colab: download the zip from your platform's file browser instead.")
    else:
        print("\nDownloading zip file...")
        files.download(str(zip_path))
else:
    print("⚠️  Download directory not found. Run the previous cell first.")

## 7. Resume Training (Optional)

Resume training from a checkpoint if your session disconnects.
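Rather than typing the checkpoint name by hand, you can pick the newest `checkpoint_epoch_*.pth` file by its epoch number. Sorting names as plain strings would be wrong here, because `checkpoint_epoch_100.pth` sorts before `checkpoint_epoch_20.pth`. This sketch assumes the `checkpoint_epoch_<N>.pth` naming used in the cells above; `latest_checkpoint` is a hypothetical helper.

```python
import re
from pathlib import Path
from typing import Optional

CKPT_RE = re.compile(r"checkpoint_epoch_(\d+)\.pth$")


def latest_checkpoint(directory: Path) -> Optional[Path]:
    """Return the checkpoint with the highest epoch number, or None if none exist."""
    best: Optional[Path] = None
    best_epoch = -1
    for path in Path(directory).glob("checkpoint_epoch_*.pth"):
        m = CKPT_RE.search(path.name)
        if m and int(m.group(1)) > best_epoch:
            best, best_epoch = path, int(m.group(1))
    return best


# Usage in the resume cell:
# ckpt = latest_checkpoint(repo_dir)
# if ckpt is not None:
#     CHECKPOINT_PATH = ckpt.name
```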

In [None]:
# Resume training from checkpoint
CHECKPOINT_PATH = "checkpoint_epoch_50.pth"  # ⚠️ Update with your checkpoint name
RESUME_EPOCHS = 100  # Total epochs (will continue from checkpoint)

cmd = [
    "python", "scripts/train_gan.py",
    "--data-root", str(DATA_ROOT),
    "--dataset", DATASET,
    "--epochs", str(RESUME_EPOCHS),
    "--batch-size", str(BATCH_SIZE),
    "--resume", CHECKPOINT_PATH,
    "--seed", str(SEED),
    "--num-workers", str(NUM_WORKERS),
]

if MAX_SAMPLES:
    cmd.extend(["--max-samples", str(MAX_SAMPLES)])

if USE_AMP:
    cmd.append("--amp")

print("To resume training, update CHECKPOINT_PATH above and uncomment the last line:")
print(f"{' '.join(cmd)}")

# Uncomment to run:
# result = subprocess.run(cmd, cwd=repo_dir)

## Troubleshooting

**Common Issues:**

1. **Out of Memory (exit code -9)**:
   - Reduce `MAX_SAMPLES` (try 200000 or 100000)
   - Reduce `BATCH_SIZE` (try 32 or 16)
   - Use a smaller dataset (`nsl_kdd` instead of `cic_ids2018`)
2. **Session Timeout**: Colab free-tier sessions are limited to 12 hours. Use checkpoints to resume (see Resume Training).
3. **Dataset Not Found**: Ensure dataset is downloaded and extracted correctly.
4. **Slow Training**: Enable GPU (Runtime > Change runtime type > GPU)

**Memory Guide for Colab Free Tier (12GB RAM):**
| Dataset | Recommended MAX_SAMPLES |
|---------|------------------------|
| NSL-KDD | None (all ~125K) |
| CIC-IDS2017 | 500000 |
| CIC-IDS2018 | 500000 |
| UNSW-NB15 | None (all ~175K) |
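A rough way to sanity-check these limits is to estimate the size of a dense feature matrix: rows × columns × bytes per value. The tuning log above is instructive: its 400,000-sample run reports 26,528 encoded features, a 182,140-row training split, and a sparse-to-dense conversion around the SVD step. If that training matrix were held densely as float32 (4 bytes per value; an assumption, since the log does not state the dtype), it alone would need about 19.3 GB, while the same rows after reduction to 256 features need under 0.2 GB. `dense_matrix_gb` is a hypothetical helper.

```python
def dense_matrix_gb(rows: int, cols: int, bytes_per_value: int = 4) -> float:
    """Memory for a dense rows x cols matrix in GB (1 GB = 1e9 bytes).

    bytes_per_value: 4 for float32, 8 for float64.
    """
    return rows * cols * bytes_per_value / 1e9


# Training split from the 400,000-sample tuning run shown earlier
print(f"{dense_matrix_gb(182_140, 26_528):.1f} GB before SVD (float32)")  # 19.3 GB
print(f"{dense_matrix_gb(182_140, 256):.2f} GB after SVD (float32)")  # 0.19 GB
```

Because memory scales linearly with rows, halving `MAX_SAMPLES` roughly halves the peak, which is why reducing it is the first remedy listed above.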

**Tips:**
- Save checkpoints frequently
- Use Colab Pro for longer sessions (24hr limit) and more RAM
- Download results before session expires
- Monitor GPU usage: `!nvidia-smi`

In [None]:
# Monitor GPU usage
!nvidia-smi
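To monitor memory from Python (for example inside a long-running cell), `nvidia-smi` can also emit machine-readable CSV through its `--query-gpu` and `--format=csv,noheader,nounits` options. The sketch below runs that query and parses used/total memory in MiB; `gpu_memory_mib` and `parse_gpu_memory` are hypothetical helpers, and the parser assumes one output line per GPU.

```python
import subprocess
from typing import List, Tuple


def parse_gpu_memory(csv_text: str) -> List[Tuple[int, int]]:
    """Parse 'used, total' lines (MiB) from nvidia-smi CSV output, one per GPU."""
    rows = []
    for line in csv_text.strip().splitlines():
        used, total = (int(x.strip()) for x in line.split(","))
        rows.append((used, total))
    return rows


def gpu_memory_mib() -> List[Tuple[int, int]]:
    """Query used/total memory for every visible GPU via nvidia-smi."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used,memory.total",
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_gpu_memory(out)


# Usage on a GPU machine:
# for i, (used, total) in enumerate(gpu_memory_mib()):
#     print(f"GPU {i}: {used} / {total} MiB")
```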