# 🚀 PyReflect GPU Worker

This notebook connects to your VPS to process training jobs on Colab's free GPU.

**How it works:**
1. Connects to your Redis queue on the VPS
2. Polls for queued training jobs
3. Runs training on Colab GPU
4. Sends results back through Redis

**Setup:**
1. Update the `REDIS_URL` below with your VPS credentials
2. Enable GPU runtime: `Runtime > Change runtime type > T4 GPU`
3. Run all cells

## 1. Configuration

Update these settings to match your VPS:

In [None]:
# ===== CONFIGURATION =====
# Update this to your VPS Redis URL
# Format: redis://:PASSWORD@YOUR_VPS_IP:6379
REDIS_URL = "redis://:your_redis_password@YOUR_VPS_IP:6379"

# Queue name (should match your backend)
QUEUE_NAME = "training"

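# Worker settings
import uuid
POLL_INTERVAL = 5  # informational only: rq's Worker waits on Redis with a blocking pop instead of polling
# RQ refuses to register a worker whose name is still marked active in Redis, so a
# random suffix avoids name clashes after a crashed or restarted Colab session.
WORKER_NAME = f"colab-gpu-worker-{uuid.uuid4().hex[:6]}"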
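
Before you run the next cells, check that `REDIS_URL` no longer contains the placeholders. The sketch below uses only Python's `urllib.parse`. `check_redis_url` is a hypothetical helper and is not part of `redis` or `rq`. It inspects the shape of the URL only. It does not test whether the server is reachable or the password is accepted.

```python
from urllib.parse import urlsplit


def check_redis_url(url: str) -> list[str]:
    """Return problems found in a redis:// URL; an empty list means it looks well-formed.

    Hypothetical helper: checks the URL's shape only, not connectivity or credentials.
    """
    problems = []
    parts = urlsplit(url)
    if parts.scheme not in ("redis", "rediss"):
        problems.append(f"unexpected scheme {parts.scheme!r}; expected 'redis' or 'rediss'")
    if not parts.password or parts.password == "your_redis_password":
        problems.append("password is missing or still the placeholder")
    # urlsplit lower-cases the hostname, so compare against the lower-case placeholder
    if not parts.hostname or parts.hostname == "your_vps_ip":
        problems.append("host is missing or still the placeholder")
    try:
        parts.port  # raises ValueError for non-numeric or out-of-range ports
    except ValueError:
        problems.append("port is not a valid number")
    return problems
```

In the notebook, `print(check_redis_url(REDIS_URL) or "REDIS_URL looks well-formed")` reports any problems. A missing port is not flagged, because redis-py falls back to 6379.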

## 2. Install Dependencies

In [None]:
%pip install redis rq torch numpy pymongo huggingface_hub pyreflect-ml -q
print("✅ Dependencies installed!")
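
The backend serializes each job into Redis with `rq`, and this worker deserializes it. Running matching `rq` and `redis` versions on both sides is therefore the safest setup. The sketch below prints what Colab installed so you can compare it with your backend. It uses the standard library's `importlib.metadata`, and `installed_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for pkg, ver in installed_versions(["redis", "rq", "torch", "pyreflect-ml"]).items():
    print(f"{pkg}: {ver or 'not installed'}")
```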

## 3. GPU Check

In [None]:
import torch

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    print(f"✅ GPU available: {gpu_name}")
    DEVICE = torch.device("cuda")
else:
    print("⚠️ No GPU detected! Go to Runtime > Change runtime type > T4 GPU")
    DEVICE = torch.device("cpu")

## 4. Connect to Redis

In [None]:
from redis import Redis
from rq import Queue, Worker

try:
    redis_conn = Redis.from_url(REDIS_URL)
    redis_conn.ping()
    print("✅ Connected to Redis!")

    queue = Queue(QUEUE_NAME, connection=redis_conn)
    print(f"📋 Queue '{QUEUE_NAME}' has {len(queue)} jobs waiting")
except Exception as e:
    print(f"❌ Failed to connect: {e}")
    print("\nTroubleshooting:")
    print("1. Check that your REDIS_URL is correct")
    print("2. Ensure Redis is running on your VPS")
    print("3. Verify that your firewall allows port 6379")
    raise  # stop here so later cells don't run without a connection
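
Both the Colab session and the VPS network can drop briefly, so a single `ping()` may fail on a transient error. Below is a minimal retry sketch with exponential backoff. `connect_with_retry` is a hypothetical helper. Its `sleep` parameter exists only so the delays can be checked without actually waiting.

```python
import time


def connect_with_retry(connect, attempts=5, base_delay=1.0, sleep=time.sleep):
    """Call connect() until it succeeds, doubling the wait after each failure.

    Hypothetical helper: re-raises the last exception if every attempt fails.
    """
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    last_exc = None
    for attempt in range(attempts):
        try:
            return connect()
        except Exception as exc:  # e.g. redis.exceptions.ConnectionError
            last_exc = exc
            if attempt < attempts - 1:
                delay = base_delay * 2 ** attempt
                print(f"Attempt {attempt + 1} failed ({exc}); retrying in {delay:.1f}s")
                sleep(delay)
    raise last_exc


# Usage in this notebook (assumes the Redis import and REDIS_URL from above):
# def _connect():
#     conn = Redis.from_url(REDIS_URL)
#     conn.ping()  # from_url is lazy; ping() forces a real connection
#     return conn
# redis_conn = connect_with_retry(_connect)
```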

## 5. Training Job Code

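This is a copy of your backend's `run_training_job` function, adapted for Colab. RQ workers import a job's function from the dotted path stored with the job. The backend must therefore enqueue it under a path this notebook can resolve. Functions defined in a notebook cell live in `__main__`, so the path would be `__main__.run_training_job`. Keep this copy in sync with the backend by hand: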

In [None]:
import time
import uuid
from datetime import datetime, timezone
from typing import Any
import numpy as np

# Import pyreflect components
from pyreflect import ReflectivityDataGenerator, DataProcessor, CNN
try:
    from pyreflect import compute_nr_from_sld
    COMPUTE_NR_AVAILABLE = True
except ImportError:
    COMPUTE_NR_AVAILABLE = False

# Training constants (matching your backend defaults)
LEARNING_RATE = 0.001
WEIGHT_DECAY = 1e-5
SPLIT_RATIO = 0.8


def _compute_norm_stats(curves: np.ndarray) -> dict:
    """Compute normalization statistics for curves."""
    x_points = curves[:, 0, :]
    y_points = curves[:, 1, :]
    return {
        "x": {"min": float(np.min(x_points)), "max": float(np.max(x_points))},
        "y": {"min": float(np.min(y_points)), "max": float(np.max(y_points))},
    }


def run_training_job(
    job_params: dict[str, Any],
    *,
    user_id: str | None = None,
    name: str | None = None,
    hf_config: dict | None = None,
    mongo_uri: str | None = None,
) -> dict[str, Any]:
    """
    Run a training job on Colab GPU.
    
    This is adapted from your backend's run_training_job function.
    """
    from rq import get_current_job
    
    job = get_current_job()
    logs = []
    
    def log(message: str):
        """Log message to console and job meta."""
        print(message)
        logs.append(message)
        if job:
            job.meta["logs"] = logs
            job.meta["updated_at"] = datetime.now(timezone.utc).isoformat()
            job.save_meta()
    
    def update_progress(epoch: int, total: int, train_loss: float, val_loss: float):
        if job:
            job.meta["progress"] = {
                "epoch": epoch,
                "total": total,
                "trainLoss": train_loss,
                "valLoss": val_loss,
            }
            job.save_meta()
    
    # Initialize job meta
    if job:
        job.meta["status"] = "initializing"
        job.meta["logs"] = logs
        if user_id:
            job.meta["user_id"] = user_id
        if name:
            job.meta["name"] = name
        job.meta["started_at"] = datetime.now(timezone.utc).isoformat()
        job.save_meta()
    
    # Extract parameters
    gen_params = job_params.get("generator", {})
    train_params = job_params.get("training", {})
    
    num_curves = gen_params.get("numCurves", 1000)
    num_film_layers = gen_params.get("numFilmLayers", 3)
    epochs = train_params.get("epochs", 50)
    batch_size = train_params.get("batchSize", 32)
    layers = train_params.get("layers", [512, 256, 128])
    dropout = train_params.get("dropout", 0.2)
    
    total_start = time.perf_counter()
    
    # =====================
    # Data Generation
    # =====================
    log(f"üîÑ Generating {num_curves} synthetic curves with {num_film_layers} film layers...")
    if job:
        job.meta["status"] = "generating"
        job.save_meta()
    
    gen_start = time.perf_counter()
    data_generator = ReflectivityDataGenerator(num_layers=num_film_layers)
    nr_curves, sld_curves = data_generator.generate(num_curves)
    gen_time = time.perf_counter() - gen_start
    
    log(f"   Generated NR shape: {nr_curves.shape}, SLD shape: {sld_curves.shape}")
    log(f"   Generation took {gen_time:.2f}s")
    
    # =====================
    # Preprocessing
    # =====================
    log("üìä Preprocessing data...")
    if job:
        job.meta["status"] = "preprocessing"
        job.save_meta()
    
    nr_log = np.array(nr_curves, copy=True)
    nr_log[:, 1, :] = np.log10(np.clip(nr_log[:, 1, :], 1e-8, None))
    nr_stats = _compute_norm_stats(nr_log)
    normalized_nr = DataProcessor.normalize_xy_curves(nr_curves, apply_log=True, min_max_stats=nr_stats)
    
    sld_stats = _compute_norm_stats(sld_curves)
    normalized_sld = DataProcessor.normalize_xy_curves(sld_curves, apply_log=False, min_max_stats=sld_stats)
    
    reshaped_nr = normalized_nr[:, 1:2, :]
    
    # =====================
    # Training
    # =====================
    log(f"üèãÔ∏è Training CNN model ({epochs} epochs, batch size {batch_size})...")
    if job:
        job.meta["status"] = "training"
        job.save_meta()
    
    model = CNN(layers=layers, dropout_prob=dropout).to(DEVICE)
    model.train()
    
    list_arrays = DataProcessor.split_arrays(reshaped_nr, normalized_sld, size_split=SPLIT_RATIO)
    tensor_arrays = DataProcessor.convert_tensors(list_arrays)
    _, _, _, train_loader, valid_loader, _ = DataProcessor.get_dataloaders(*tensor_arrays, batch_size=batch_size)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    loss_fn = torch.nn.MSELoss()
    
    epoch_list = []
    train_losses = []
    val_losses = []
    
    training_start = time.perf_counter()
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = loss_fn(outputs, y_batch)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        train_loss = running_loss / len(train_loader)
        
        model.eval()
        val_running_loss = 0.0
        with torch.no_grad():
            for X_batch, y_batch in valid_loader:
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
                outputs = model(X_batch)
                val_running_loss += loss_fn(outputs, y_batch).item()
        val_loss = val_running_loss / len(valid_loader)
        
        epoch_list.append(epoch + 1)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        update_progress(epoch + 1, epochs, train_loss, val_loss)
        if (epoch + 1) % 5 == 0 or epoch == 0:
            log(f"   Epoch {epoch + 1}/{epochs} - Train: {train_loss:.6f}, Val: {val_loss:.6f}")
    
    training_time = time.perf_counter() - training_start
    log(f"   Training completed in {training_time:.2f}s")
    
    # =====================
    # Inference
    # =====================
    log("üîç Running inference on test sample...")
    if job:
        job.meta["status"] = "inference"
        job.save_meta()
    
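    # First curve after the training split. This is only a held-out sample
    # if DataProcessor.split_arrays keeps the original (unshuffled) order.
    split_idx = int(len(nr_curves) * SPLIT_RATIO)
    test_idx = split_idx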
    
    gt_nr = nr_curves[test_idx]
    gt_sld = sld_curves[test_idx]
    
    inference_start = time.perf_counter()
    model.eval()
    with torch.no_grad():
        test_nr_normalized = normalized_nr[test_idx : test_idx + 1, 1:2, :]
        test_input = torch.tensor(test_nr_normalized, dtype=torch.float32).to(DEVICE)
        pred_sld_normalized = model(test_input).cpu().numpy()
    
    pred_sld_denorm = DataProcessor.denormalize_xy_curves(pred_sld_normalized, stats=sld_stats, apply_exp=False)
    pred_sld_y = pred_sld_denorm[0, 1, :]
    pred_sld_z = pred_sld_denorm[0, 0, :]
    
    sld_z = np.linspace(0, 450, len(gt_sld[1]))
    
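    # Compute NR from predicted SLD. If pyreflect's compute_nr_from_sld is
    # unavailable or fails, "computed" falls back to the ground-truth NR.
    computed_nr = gt_nr[1].tolist()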
    if COMPUTE_NR_AVAILABLE:
        log("   Computing NR from predicted SLD...")
        try:
            pred_sld_profile = (pred_sld_z, pred_sld_y)
            _, computed_r = compute_nr_from_sld(pred_sld_profile, Q=gt_nr[0], order="substrate_to_air")
            computed_nr = computed_r.tolist()
        except Exception as exc:
            log(f"   Warning: Could not compute NR: {exc}")
    
    # Calculate metrics
    sample_indices = np.linspace(0, len(pred_sld_y) - 1, 50, dtype=int)
    chi = [
        {"x": int(i), "predicted": float(pred_sld_y[idx]), "actual": float(gt_sld[1][idx])}
        for i, idx in enumerate(sample_indices)
    ]
    
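    final_mse = val_losses[-1] if val_losses else 0.0
    # Rough R² proxy: 1 - (final validation MSE) / (variance of all normalized SLD values)
    r2 = 1 - (final_mse / np.var(normalized_sld[:, 1, :]))
    # MAE of the single test sample, in denormalized SLD units
    mae = float(np.mean(np.abs(pred_sld_y - gt_sld[1])))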
    inference_time = time.perf_counter() - inference_start
    total_time = time.perf_counter() - total_start
    
    model_id = str(uuid.uuid4())
    
    log(f"✅ Complete! Total time: {total_time:.2f}s")
    log(f"   MSE: {final_mse:.6f}, R²: {r2:.4f}, MAE: {mae:.4f}")
    
    # =====================
    # Build Result
    # =====================
    result = {
        "nr": {"q": gt_nr[0].tolist(), "groundTruth": gt_nr[1].tolist(), "computed": computed_nr},
        "sld": {"z": sld_z.tolist(), "groundTruth": gt_sld[1].tolist(), "predicted": pred_sld_y.tolist()},
        "training": {"epochs": epoch_list, "trainingLoss": train_losses, "validationLoss": val_losses},
        "chi": chi,
        "metrics": {"mse": float(final_mse), "r2": float(np.clip(r2, 0, 1)), "mae": mae},
        "name": name,
        "model_id": model_id,
        "timing": {
            "generation": gen_time,
            "training": training_time,
            "inference": inference_time,
            "total": total_time,
        },
    }
    
    # Save to MongoDB if configured
    runtime_user_id = None
    runtime_name = name
    if job:
        runtime_user_id = (job.meta or {}).get("user_id") or user_id
        runtime_name = (job.meta or {}).get("name") or name
    else:
        runtime_user_id = user_id
    
    if mongo_uri and runtime_user_id:
        if job:
            job.meta["status"] = "saving_to_history"
            job.save_meta()
        log("üíæ Saving to database...")
        try:
            from pymongo import MongoClient
            client = MongoClient(mongo_uri)
            try:
                # get_default_database() needs a database name in mongo_uri,
                # e.g. mongodb://user:pass@host:27017/dbname
                db = client.get_default_database()
                doc = {
                    "user_id": runtime_user_id,
                    "name": runtime_name,
                    "created_at": datetime.now(timezone.utc),
                    "params": job_params,
                    "result": result,
                }
                db.generations.insert_one(doc)
            finally:
                client.close()
            log("   ✅ Saved to database!")
        except Exception as exc:
            log(f"   ⚠️ Could not save to database: {exc}")
    
    # Finalize job meta
    if job:
        job.meta["status"] = "completed"
        job.meta["completed_at"] = datetime.now(timezone.utc).isoformat()
        job.meta["logs"] = logs
        job.save_meta()
    
    return result


print("✅ Training job function loaded!")
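
The preprocessing step log-transforms the NR reflectivity before computing its min/max stats. The SLD curves are scaled without a log. The sketch below reproduces that idea in plain NumPy. It assumes that `DataProcessor.normalize_xy_curves` and `denormalize_xy_curves` perform global min-max scaling per channel; check the pyreflect source to confirm. `min_max_normalize` and `min_max_denormalize` are hypothetical stand-ins. The example shows why the NR stats must come from the log-transformed copy: with stats taken from the raw curves, the scaled log values would not land in [0, 1].

```python
import numpy as np


def compute_norm_stats(curves):
    """Global min/max of the x and y channels of (N, 2, L) curves (same as _compute_norm_stats)."""
    return {
        "x": {"min": float(curves[:, 0, :].min()), "max": float(curves[:, 0, :].max())},
        "y": {"min": float(curves[:, 1, :].min()), "max": float(curves[:, 1, :].max())},
    }


def min_max_normalize(curves, stats, apply_log=False):
    """Hypothetical stand-in: optional log10 on y, then scale each channel to [0, 1]."""
    out = np.array(curves, dtype=np.float64, copy=True)
    if apply_log:
        out[:, 1, :] = np.log10(np.clip(out[:, 1, :], 1e-8, None))
    for ch, key in ((0, "x"), (1, "y")):
        lo, hi = stats[key]["min"], stats[key]["max"]
        out[:, ch, :] = (out[:, ch, :] - lo) / (hi - lo)
    return out


def min_max_denormalize(curves, stats, apply_exp=False):
    """Hypothetical inverse of min_max_normalize; apply_exp undoes the log10 on y."""
    out = np.array(curves, dtype=np.float64, copy=True)
    for ch, key in ((0, "x"), (1, "y")):
        lo, hi = stats[key]["min"], stats[key]["max"]
        out[:, ch, :] = out[:, ch, :] * (hi - lo) + lo
    if apply_exp:
        out[:, 1, :] = 10.0 ** out[:, 1, :]
    return out


# Two toy NR curves, shape (2, 2, 4): channel 0 is Q, channel 1 is reflectivity R.
q = np.linspace(0.01, 0.2, 4)
nr = np.stack([
    np.stack([q, 10.0 ** -np.arange(1, 5)]),  # R = 1e-1 ... 1e-4
    np.stack([q, 10.0 ** -np.arange(2, 6)]),  # R = 1e-2 ... 1e-5
])

# As in run_training_job: stats come from a log10-transformed copy of the NR curves.
nr_log = nr.copy()
nr_log[:, 1, :] = np.log10(np.clip(nr_log[:, 1, :], 1e-8, None))
nr_stats = compute_norm_stats(nr_log)

normalized = min_max_normalize(nr, nr_stats, apply_log=True)
print(normalized[:, 1, :])  # log-reflectivity now spans [0, 1]
restored = min_max_denormalize(normalized, nr_stats, apply_exp=True)
```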
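
The reported R² is computed as `1 - MSE / Var(y)`. This equals the usual coefficient of determination only when the MSE and the variance are taken over the same values. In `run_training_job` they are not:

- The MSE is the mean of the per-batch validation losses.
- The variance spans the SLD channel of every curve.
- If the model predicts both the depth and SLD channels, the MSE also includes the depth channel. The `pred_sld_denorm[0, 0, :]` and `[0, 1, :]` indexing suggests it does.

The sketch below uses hypothetical helper names. It checks the identity on matching data and shows how it drifts when the variance covers other values. Read the logged R² as a rough indicator only.

```python
import numpy as np


def r2_score(y_true, y_pred):
    """Coefficient of determination over all elements: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)


def r2_from_mse(mse, y_true):
    """The shortcut used in run_training_job: 1 - MSE / Var(y) (population variance)."""
    return float(1.0 - mse / np.var(np.asarray(y_true, dtype=np.float64)))


y_true = np.array([[0.0, 0.5, 1.0], [0.2, 0.4, 0.6]])
y_pred = y_true + np.array([[0.1, -0.1, 0.05], [0.0, 0.1, -0.05]])

mse = float(np.mean((y_true - y_pred) ** 2))
# Same values in both terms, so the two formulas agree.
print(r2_score(y_true, y_pred), r2_from_mse(mse, y_true))
# Variance taken over a different set of values (as when it spans train and validation
# curves), so the shortcut no longer equals R² on the evaluated data.
print(r2_from_mse(mse, np.concatenate([y_true.ravel(), [0.0, 1.0]])))
```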

## 6. Start Worker

Run this cell to start processing jobs. It will:
- Listen to the Redis queue
- Pick up training jobs
- Process them on GPU
- Send results back through Redis

**Keep this running!** Your frontend can follow each job's progress while it runs.

In [None]:
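from rq.worker import SimpleWorker

print(f"🚀 Starting worker '{WORKER_NAME}'...")
print(f"📋 Listening to queue: {QUEUE_NAME}")
print("="*50)
print("Worker is now running! Submit jobs from your UI.")
print("Press the ⬛ stop button to terminate.")
print("="*50)

# SimpleWorker runs each job inside this kernel process. The default Worker forks
# a child process per job, and PyTorch cannot use CUDA in a child forked after
# CUDA was initialized (the GPU check cell already did that). It fails with
# "Cannot re-initialize CUDA in forked subprocess".
# Note: the job timeout is set by whoever enqueues the job (RQ's default is 180 s),
# so long training runs need a larger job_timeout on the backend side.
worker = SimpleWorker(
    [queue],
    connection=redis_conn,
    name=WORKER_NAME,
)

# This blocks and processes jobs until the cell is interrupted
worker.work(with_scheduler=False)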
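
On the backend side, read the progress that `run_training_job` writes to `job.meta` with `rq.job.Job.fetch(job_id, connection=redis_conn)`, then turn it into a status line for the UI. The sketch covers the formatting step only. `summarize_progress` is hypothetical, and reporting an empty `meta` as "queued" is an assumption.

```python
def summarize_progress(meta: dict) -> str:
    """Turn the job.meta dict written by run_training_job into a short status line.

    Hypothetical helper; an empty meta (job not picked up yet) is reported as "queued".
    """
    status = meta.get("status", "queued")
    progress = meta.get("progress") or {}
    if status == "training" and progress.get("total"):
        pct = 100.0 * progress["epoch"] / progress["total"]
        return (
            f"training {progress['epoch']}/{progress['total']} ({pct:.0f}%), "
            f"val loss {progress['valLoss']:.4g}"
        )
    return status


# Backend usage (assumes a Redis connection to the same server):
# from rq.job import Job
# job = Job.fetch(job_id, connection=redis_conn)
# print(summarize_progress(job.meta))
```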

---

## 🔧 Troubleshooting

### Connection Issues
```python
# Test Redis connection manually
redis_conn.ping()  # Should return True
```

### Queue Status
```python
# Check queue status
print(f"Queued: {len(queue)}")
print(f"Job IDs: {queue.job_ids}")
```

### Check Workers
```python
# See all connected workers
from rq import Worker
workers = Worker.all(connection=redis_conn)
for w in workers:
    print(f"{w.name}: {w.state}")
```