# Sales Forecasting - Distributed Training


In [None]:
# Install the Kubeflow training SDK, the MLflow client and yamlmagic, then load
# yamlmagic so the %%yaml cell below can define the `parameters` dict.
%pip install -q kubeflow-training mlflow yamlmagic
%load_ext yamlmagic


## Configuration


In [None]:
%%yaml parameters

# =============================================================================
# Cluster Configuration
# =============================================================================
# Namespace the TrainJob is submitted to (also used to build the MLflow URI).
namespace: feast-trainer-demo
# PVC mounted at /shared in the "node" training containers.
shared_pvc: feast-pvc
# Name of the training runtime fetched with trainer_client.get_runtime().
runtime: torch-distributed

# =============================================================================
# Training Hyperparameters
# =============================================================================
epochs: 50
# Batch size per training process (each process sees its own data shard).
batch_size: 256
# Initial AdamW learning rate; cosine-annealed over `epochs`.
learning_rate: 0.001

# =============================================================================
# Model Architecture
# =============================================================================
model:
  # Widths of the MLP hidden layers (Linear -> BatchNorm -> ReLU -> Dropout each).
  hidden_dims: [256, 128, 64]
  dropout: 0.2

# =============================================================================
# Feature Columns
# =============================================================================
# Only the columns listed here that actually exist in features.parquet are used.
features:
  - lag_1
  - lag_2
  - lag_4
  - lag_8
  - lag_52
  - rolling_mean_4w
  - store_size
  - temperature
  - fuel_price
  - cpi
  - unemployment

# =============================================================================
# Distributed Training
# =============================================================================
# Number of training nodes (passed to CustomTrainer as num_nodes).
num_workers: 1
# CPU/memory requested for each training node.
resources_per_worker:
  cpu: 4
  memory: 16Gi
  
# GPU Configuration: "none", "nvidia", or "amd"
gpu_type: nvidia
# GPUs requested per node; ignored when gpu_type is "none".
gpu_count: 1

# =============================================================================
# MLflow Tracking
# =============================================================================
mlflow:
  experiment_name: sales-forecasting

# =============================================================================
# Data Paths (must match PVC mount: /shared)
# =============================================================================
paths:
  # Must contain features.parquet.
  data_dir: /shared/data
  # Receives best_model.pt, scalers.joblib, model_metadata.json,
  # training_history.json and training_curves.png.
  model_dir: /shared/models

In [None]:
# Pull the most frequently used settings out of the YAML-backed `parameters`.
NAMESPACE, SHARED_PVC, RUNTIME = (
    parameters[key] for key in ("namespace", "shared_pvc", "runtime")
)
# MLflow tracking server, addressed by its in-cluster Service DNS name
# (assumes a Service called `mlflow` on port 5000 in NAMESPACE).
MLFLOW_URI = f"http://mlflow.{NAMESPACE}.svc.cluster.local:5000"


## Authentication


In [None]:
import os

# Credentials for explicit bearer-token access to the Kubernetes API.
# Export K8S_TOKEN / K8S_API_SERVER, or edit the placeholder values below.
K8S_TOKEN = os.environ.get("K8S_TOKEN", "<YOUR_TOKEN>")
K8S_API_SERVER = os.environ.get("K8S_API_SERVER", "<YOUR_API_SERVER>")


In [None]:
from kubernetes import client as k8s
from kubeflow.training import TrainerClient, CustomTrainer
from kubeflow.training.types import KubernetesBackendConfig
from kubeflow.training.types import (
    PodTemplateOverrides, PodTemplateOverride,
    PodSpecOverride, ContainerOverride,
    Labels, Annotations
)

def _auth_value_is_set(value):
    """Return True if *value* is a real credential, not empty or a "<...>" placeholder."""
    value = (value or "").strip()
    return bool(value) and not (value.startswith("<") and value.endswith(">"))


# Use explicit token auth only when BOTH the token and the API server were
# actually provided. The "<YOUR_...>" placeholder defaults are non-empty
# strings, so a plain truthiness check would always enable token auth against
# a bogus host. When not set, no client configuration is passed and the SDK
# uses its own default config loading (presumably kubeconfig / in-cluster —
# confirm for the installed SDK version).
USE_TOKEN_AUTH = _auth_value_is_set(K8S_TOKEN) and _auth_value_is_set(K8S_API_SERVER)

cfg = k8s.Configuration()
if USE_TOKEN_AUTH:
    cfg.host = K8S_API_SERVER
    # NOTE(review): TLS certificate verification is disabled here; prefer
    # setting cfg.ssl_ca_cert to the cluster CA outside of demo environments.
    cfg.verify_ssl = False
    cfg.api_key = {"authorization": f"Bearer {K8S_TOKEN}"}

trainer_client = TrainerClient(
    KubernetesBackendConfig(
        namespace=NAMESPACE,
        # Same condition as above, so a half-configured cfg is never used.
        client_configuration=cfg if USE_TOKEN_AUTH else None,
    )
)


## Training Runtime


In [None]:
# Fetch the training runtime named by RUNTIME (the `runtime` config key); the
# result is passed to trainer_client.train() when the job is submitted.
runtime = trainer_client.get_runtime(RUNTIME)

## Training Function


In [None]:
def train_sales_model(parameters):
    """
    Train the SalesMLP sales regressor with PyTorch DDP and track it in MLflow.

    This function is handed to ``CustomTrainer`` and executed inside the
    training pods, so every import and helper it needs is kept inside the
    function body.

    Steps:
      1. Pick device and backend (NCCL when torch.cuda is available, Gloo on
         CPU) and join the default process group.
      2. Load ``{data_dir}/features.parquet``, keep the configured feature
         columns that exist, drop rows with missing values and split 80/20
         (``random_state=42``) into train/test.
      3. Fit the feature/target scalers on the training split only (so no
         test statistics leak into training), train for ``epochs`` epochs and
         compute MAE, RMSE, MAPE and R² on the test split after every epoch.
      4. On rank 0: log tags/params/metrics to MLflow (best effort) and write
         the best checkpoint, scalers, metadata, history and training plots
         to ``model_dir``.

    Args:
        parameters: Dict parsed from the notebook's YAML config. Keys read:
            ``paths.data_dir``, ``paths.model_dir``, ``epochs``,
            ``batch_size``, ``learning_rate``, ``features``,
            ``model.hidden_dims``, ``model.dropout`` and
            ``mlflow.experiment_name``.

    Environment:
        LOCAL_RANK: GPU index for this process (default 0).
        MLFLOW_TRACKING_URI: MLflow server (default ``http://mlflow:5000``).
        RUN_NAME: MLflow run name (default ``train-<MMDD-HHMMSS>``).

    Raises:
        ValueError: If none of the configured feature columns are present in
            the data.
    """
    import os, json, time
    import torch, torch.nn as nn, torch.distributed as dist
    import pandas as pd, numpy as np
    from sklearn.preprocessing import StandardScaler
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
    from pathlib import Path
    from datetime import datetime

    start_time = time.time()

    # =========================================================================
    # Device Detection
    # =========================================================================
    if torch.cuda.is_available():
        # ROCm builds of PyTorch expose AMD GPUs through the torch.cuda API.
        is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None
        device_type = "rocm" if is_rocm else "cuda"
        backend = "nccl"
        local_rank = int(os.getenv("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device, backend, device_type = torch.device("cpu"), "gloo", "cpu"

    # =========================================================================
    # Distributed Setup
    # =========================================================================
    dist.init_process_group(backend=backend)
    rank, world_size = dist.get_rank(), dist.get_world_size()

    gpu_name, gpu_mem_gb = "", 0
    if rank == 0:
        if device_type != "cpu":
            # Describe the GPU this process is actually bound to.
            gpu_name = torch.cuda.get_device_name(device)
            gpu_mem_gb = torch.cuda.get_device_properties(device).total_memory / 1e9
            print(f"Device: {device_type.upper()} ({gpu_name}, {gpu_mem_gb:.1f}GB) | Workers: {world_size}")
        else:
            print(f"Device: CPU | Workers: {world_size}")

    # =========================================================================
    # Configuration
    # =========================================================================
    mlflow_uri = os.getenv("MLFLOW_TRACKING_URI", "http://mlflow:5000")
    data_dir = parameters.get('paths', {}).get('data_dir', '/shared/data')
    model_dir = parameters.get('paths', {}).get('model_dir', '/shared/models')
    epochs = parameters.get('epochs', 50)
    batch_size = parameters.get('batch_size', 256)
    lr = parameters.get('learning_rate', 0.001)
    feature_cols = parameters.get('features', [])
    hidden_dims = parameters.get('model', {}).get('hidden_dims', [256, 128, 64])
    dropout = parameters.get('model', {}).get('dropout', 0.2)

    # =========================================================================
    # Model Definition
    # =========================================================================
    class SalesMLP(nn.Module):
        """MLP regressor: [Linear -> BatchNorm -> ReLU -> Dropout] * N -> Linear(1)."""

        def __init__(self, input_dim, hidden_dims, dropout):
            super().__init__()
            layers = []
            prev = input_dim
            for h in hidden_dims:
                layers.extend([nn.Linear(prev, h), nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(dropout)])
                prev = h
            layers.append(nn.Linear(prev, 1))
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            return self.net(x).squeeze(-1)

    def _snapshot(module):
        """Return a detached CPU copy of *module*'s state_dict.

        ``state_dict()`` holds references to the live tensors, so a shallow
        dict copy would keep changing as training continues.
        """
        return {k: v.detach().cpu().clone() for k, v in module.state_dict().items()}

    # =========================================================================
    # Load Data
    # =========================================================================
    df = pd.read_parquet(f"{data_dir}/features.parquet")
    cols = [c for c in feature_cols if c in df.columns]
    missing_cols = [c for c in feature_cols if c not in df.columns]
    if not cols:
        raise ValueError(
            f"None of the configured feature columns {feature_cols} exist in "
            f"{data_dir}/features.parquet"
        )
    df = df.dropna(subset=cols + ["weekly_sales"])
    X, y = df[cols].values, df["weekly_sales"].values
    # Coarse dataset fingerprint (row and feature counts), not a content hash.
    data_hash = f"{len(df)}-{len(cols)}f"

    if rank == 0:
        if missing_cols:
            print(f"Warning: skipping feature columns not found in data: {missing_cols}")
        print(f"Data: {len(df):,} samples, {len(cols)} features")

    # =========================================================================
    # Preprocessing
    # =========================================================================
    # Split first, then fit the scalers on the training split only; fitting on
    # all rows would leak test-set statistics into training and evaluation.
    X_train_raw, X_test_raw, y_train_raw, y_test_raw = train_test_split(
        X, y, test_size=0.2, random_state=42)
    scaler_X, scaler_y = StandardScaler(), StandardScaler()
    X_train = scaler_X.fit_transform(X_train_raw)
    X_test = scaler_X.transform(X_test_raw)
    y_train = scaler_y.fit_transform(y_train_raw.reshape(-1, 1)).ravel()
    y_test = scaler_y.transform(y_test_raw.reshape(-1, 1)).ravel()

    # Evaluation targets in original units, and the MAPE mask that skips
    # near-zero sales; both are fixed for the whole run.
    y_true = np.asarray(y_test_raw, dtype=np.float64)
    mape_mask = np.abs(y_true) > 100

    # =========================================================================
    # Model + DDP
    # =========================================================================
    model = SalesMLP(len(cols), hidden_dims, dropout).to(device)
    if world_size > 1:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[device.index] if device_type != "cpu" else None)
    base_model = model.module if hasattr(model, 'module') else model

    train_dataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_train), torch.FloatTensor(y_train))
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    # BatchNorm1d raises on a training batch with a single sample. If this
    # rank's shard would end with such a batch, drop it. len(sampler) is the
    # same on every rank, so all ranks keep the same number of steps.
    drop_last = len(sampler) % batch_size == 1
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=sampler, drop_last=drop_last)
    X_test_t = torch.FloatTensor(X_test).to(device)

    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_val_loss, best_mape, best_epoch = float('inf'), float('inf'), 0
    best_state = None
    history = {"train_loss": [], "val_loss": [], "mae": [], "rmse": [], "mape": [], "r2": [], "lr": []}

    # =========================================================================
    # MLflow Setup (rank 0)
    # =========================================================================
    mlflow_active = False
    if rank == 0:
        try:
            import mlflow
            mlflow.set_tracking_uri(mlflow_uri)
            exp_name = parameters.get('mlflow', {}).get('experiment_name', 'sales-forecasting')
            run_name = os.getenv("RUN_NAME", f"train-{datetime.now().strftime('%m%d-%H%M%S')}")
            mlflow.set_experiment(exp_name)
            mlflow.start_run(run_name=run_name, description=f"Sales forecasting with {len(cols)} features")

            # Tags
            mlflow.set_tags({"model_type": "SalesMLP", "framework": "pytorch", "task": "regression",
                "device": device_type, "gpu_name": gpu_name or "N/A", "data_hash": data_hash, "environment": "kubeflow"})

            # Parameters
            mlflow.log_params({"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
                "optimizer": "AdamW", "weight_decay": 1e-5, "scheduler": "CosineAnnealingLR",
                "hidden_dims": str(hidden_dims), "dropout": dropout, "num_features": len(cols),
                "train_samples": len(X_train), "test_samples": len(X_test), "world_size": world_size,
                "torch_version": torch.__version__, "gpu_memory_gb": f"{gpu_mem_gb:.1f}" if gpu_mem_gb else "N/A"})
            mlflow_active = True
            print(f"MLflow: tracking at {mlflow_uri}")
        except Exception as e:
            print(f"MLflow init: {e}")

    # =========================================================================
    # Training Loop
    # =========================================================================
    for epoch in range(epochs):
        sampler.set_epoch(epoch)
        model.train()
        train_losses = []

        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        # LR actually used this epoch, read before the scheduler advances.
        current_lr = scheduler.get_last_lr()[0]
        scheduler.step()

        # Evaluation on the full test split (every rank computes the same
        # numbers because DDP keeps the weights in sync).
        model.eval()
        with torch.no_grad():
            y_pred_scaled = model(X_test_t).cpu().numpy()
        val_loss = criterion(torch.FloatTensor(y_pred_scaled), torch.FloatTensor(y_test)).item()

        y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).ravel()
        mae = mean_absolute_error(y_true, y_pred)
        rmse = np.sqrt(mean_squared_error(y_true, y_pred))
        r2 = r2_score(y_true, y_pred)
        if mape_mask.any():
            mape = np.mean(np.abs((y_true[mape_mask] - y_pred[mape_mask]) / y_true[mape_mask])) * 100
        else:
            mape = 0.0

        train_loss = float(np.mean(train_losses)) if train_losses else float("nan")
        for k, v in [("train_loss", train_loss), ("val_loss", val_loss), ("mae", mae), ("rmse", rmse), ("mape", mape), ("r2", r2), ("lr", current_lr)]:
            history[k].append(v)

        if val_loss < best_val_loss:
            best_val_loss, best_mape, best_epoch = val_loss, mape, epoch
            if rank == 0:
                # Deep snapshot: freezes this epoch's weights for the checkpoint.
                best_state = _snapshot(base_model)

        if rank == 0:
            if mlflow_active:
                try:
                    mlflow.log_metrics({"train_loss": train_loss, "val_loss": val_loss, "mae": mae, "rmse": rmse, "mape": mape, "r2": r2, "learning_rate": current_lr}, step=epoch)
                except Exception as e:
                    print(f"MLflow log_metrics (epoch {epoch}): {e}")
            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs} | Val: {val_loss:.6f} | MAPE: {mape:.2f}% | R²: {r2:.4f}")

    # =========================================================================
    # Save Model & Artifacts (rank 0)
    # =========================================================================
    dist.barrier()
    training_duration = time.time() - start_time

    if rank == 0:
        import joblib
        Path(model_dir).mkdir(parents=True, exist_ok=True)

        if best_state is None:
            # No epoch beat the initial infinite loss (e.g. NaN validation
            # loss or zero epochs); save the final weights instead of None.
            print("Warning: validation loss never improved; saving final weights")
            best_state = _snapshot(base_model)

        torch.save(best_state, f"{model_dir}/best_model.pt")
        joblib.dump({"scaler_X": scaler_X, "scaler_y": scaler_y, "feature_cols": cols}, f"{model_dir}/scalers.joblib")

        metadata = {"model_type": "SalesMLP", "input_dim": len(cols), "hidden_dims": hidden_dims, "dropout": dropout,
            "feature_columns": cols, "best_val_loss": float(best_val_loss), "best_mape": float(best_mape),
            "best_epoch": best_epoch, "total_epochs": epochs, "device_type": device_type,
            "training_duration_sec": training_duration, "data_hash": data_hash, "created_at": datetime.now().isoformat()}
        with open(f"{model_dir}/model_metadata.json", "w") as f:
            json.dump(metadata, f, indent=2)
        with open(f"{model_dir}/training_history.json", "w") as f:
            json.dump(history, f)

        # Training plots (best effort: a plotting failure must not lose the model)
        try:
            import matplotlib
            matplotlib.use('Agg')
            import matplotlib.pyplot as plt
            fig, axes = plt.subplots(2, 2, figsize=(12, 10))
            axes[0, 0].plot(history["train_loss"], label="Train")
            axes[0, 0].plot(history["val_loss"], label="Val")
            axes[0, 0].axvline(best_epoch, color='r', linestyle='--', alpha=0.5)
            axes[0, 0].set_title("Loss")
            axes[0, 0].legend()
            axes[0, 0].grid(True, alpha=0.3)
            axes[0, 1].plot(history["mape"], color="green")
            axes[0, 1].axhline(best_mape, color='r', linestyle='--', alpha=0.5)
            axes[0, 1].set_title(f"MAPE (Best: {best_mape:.2f}%)")
            axes[0, 1].grid(True, alpha=0.3)
            axes[1, 0].plot(history["r2"], color="purple")
            axes[1, 0].set_title("R² Score")
            axes[1, 0].grid(True, alpha=0.3)
            axes[1, 1].plot(history["lr"], color="orange")
            axes[1, 1].set_title("Learning Rate")
            axes[1, 1].grid(True, alpha=0.3)
            plt.tight_layout()
            plt.savefig(f"{model_dir}/training_curves.png", dpi=150)
            plt.close(fig)
        except Exception as e:
            print(f"Training plots skipped: {e}")

        # MLflow artifacts
        if mlflow_active:
            try:
                mlflow.log_metrics({"best_val_loss": best_val_loss, "best_mape": best_mape, "best_epoch": best_epoch,
                    "final_r2": history["r2"][-1] if history["r2"] else float("nan"),
                    "training_duration_min": training_duration / 60})
                mlflow.log_artifact(f"{model_dir}/best_model.pt", "model")
                mlflow.log_artifact(f"{model_dir}/scalers.joblib", "preprocessing")
                mlflow.log_artifact(f"{model_dir}/model_metadata.json", "metadata")
                mlflow.log_artifact(f"{model_dir}/training_history.json", "metrics")
                if Path(f"{model_dir}/training_curves.png").exists():
                    mlflow.log_artifact(f"{model_dir}/training_curves.png", "plots")
                mlflow.end_run()
            except Exception as e:
                print(f"MLflow artifacts: {e}")
                try:
                    mlflow.end_run()
                except Exception as end_err:
                    print(f"MLflow end_run: {end_err}")

        print(f"Done! Best MAPE: {best_mape:.2f}% @ epoch {best_epoch} | Duration: {training_duration/60:.1f}min")

    dist.barrier()
    dist.destroy_process_group()

## Submit Training Job


In [None]:
from datetime import datetime

# Timestamp that tags this submission (job labels, annotation, MLflow run name).
# Seconds are included, matching the training function's fallback run-name
# format, so two submissions within the same minute remain distinguishable.
job_id = datetime.now().strftime("%m%d-%H%M%S")
# Informational only: this name is not passed to trainer_client.train(); the
# submitted job is identified by the value train() returns.
job_name = f"sales-training-{job_id}"


In [None]:
# Build the per-node resource requests. GPUs are requested under the vendor's
# extended-resource name when gpu_type is "nvidia" or "amd".
resources = {
    "cpu": parameters['resources_per_worker']['cpu'],
    "memory": parameters['resources_per_worker']['memory'],
}

GPU_RESOURCE_NAMES = {"nvidia": "nvidia.com/gpu", "amd": "amd.com/gpu"}

# Normalise so "NVIDIA" / " amd " work, and an empty/null value means "none".
gpu_type = str(parameters.get('gpu_type') or 'none').strip().lower()
gpu_count = parameters.get('gpu_count', 0) or 0

# A typo here would otherwise silently schedule a CPU-only job.
if gpu_type != 'none' and gpu_type not in GPU_RESOURCE_NAMES:
    raise ValueError(
        f"Unsupported gpu_type {gpu_type!r}; expected 'none', 'nvidia' or 'amd'"
    )

if gpu_type in GPU_RESOURCE_NAMES and gpu_count > 0:
    resources[GPU_RESOURCE_NAMES[gpu_type]] = gpu_count

# Describe what to run: train_sales_model on `num_workers` nodes, each with
# the resources built above plus the extra pip packages and env vars it needs.
custom_trainer = CustomTrainer(
    func=train_sales_model,
    num_nodes=parameters['num_workers'],
    resources_per_node=resources,
    packages_to_install=["scikit-learn", "pandas", "pyarrow", "joblib", "mlflow", "matplotlib"],
    env={"MLFLOW_TRACKING_URI": MLFLOW_URI, "RUN_NAME": f"train-{job_id}"},
)

# Metadata attached to the job for filtering and display.
job_labels = Labels({"app": "sales-forecasting", "job-type": "training", "run-id": job_id})
job_annotations = Annotations({"description": f"Sales forecasting - {job_id}"})

# Mount the shared PVC at /shared in the "node" container so the function can
# read data from and write models to the /shared paths in the config.
shared_volume = {"name": "shared", "persistentVolumeClaim": {"claimName": SHARED_PVC}}
shared_mount = {"name": "shared", "mountPath": "/shared"}
node_container = ContainerOverride(name="node", volume_mounts=[shared_mount])
node_pod_spec = PodSpecOverride(volumes=[shared_volume], containers=[node_container])
pvc_override = PodTemplateOverrides(
    PodTemplateOverride(target_jobs=["node"], spec=node_pod_spec)
)

# Submit; the return value is passed as `name=` to the monitoring calls.
job = trainer_client.train(
    trainer=custom_trainer,
    runtime=runtime,
    parameters=parameters,
    options=[job_labels, job_annotations, pvc_override],
)

## Monitor Progress


In [None]:
# Wait (timeout=300, presumably seconds) until the submitted job reports "Running".
trainer_client.wait_for_job_status(name=job, status={"Running"}, timeout=300)


In [None]:
# Follow the training logs of the submitted job.
# NOTE(review): depending on the SDK version, get_job_logs may return an
# iterator that must be consumed (e.g. printed line by line) for any log output
# to appear — confirm against the installed kubeflow SDK.
_ = trainer_client.get_job_logs(name=job, follow=True)


In [None]:
# Wait (timeout=3600, presumably seconds) until the job reaches a terminal state.
trainer_client.wait_for_job_status(name=job, status={"Complete", "Failed"}, timeout=3600)


## Inspect Recent Runs in MLflow


In [None]:
import mlflow

# Show the five most recent runs of this notebook's MLflow experiment.
mlflow.set_tracking_uri(MLFLOW_URI)
# Same lookup and default as the training function uses when creating runs.
experiment_name = parameters.get('mlflow', {}).get('experiment_name', 'sales-forecasting')
experiment = mlflow.get_experiment_by_name(experiment_name)

cols = ["tags.mlflow.runName", "metrics.best_mape", "metrics.final_r2", "metrics.best_epoch", "metrics.training_duration_min", "tags.device"]

if experiment is None:
    # get_experiment_by_name returns None when the experiment does not exist,
    # e.g. if the training job never reached the MLflow server.
    print(f"MLflow experiment {experiment_name!r} not found at {MLFLOW_URI}; no runs to show.")
    runs = None
    display_cols = []
    runs_table = None
else:
    runs = mlflow.search_runs(
        experiment_ids=[experiment.experiment_id],
        max_results=5,
        order_by=["start_time DESC"]
    )
    # Only keep columns that exist; metrics/tags appear once a run logs them.
    display_cols = [c for c in cols if c in runs.columns]
    runs_table = runs[display_cols]

# Last expression of the cell, so the notebook renders the table.
runs_table


## Cleanup


In [None]:
# trainer_client.delete_job(name=job)
