# 02 - Distributed Training with Kubeflow & Feast

![Workflow](../docs/02-training-workflow.png)

## What This Notebook Does

| Step | Component | Action |
|------|-----------|--------|
| 1 | Kubeflow SDK | Submit TrainJob to cluster |
| 2 | Feast + Ray | Distributed feature retrieval (65K rows) |
| 3 | PyTorch DDP | Multi-GPU distributed training |
| 4 | MLflow | Track experiments, log model |

## Architecture

```
┌─────────────┐     ┌─────────────────┐     ┌─────────────┐
│  Kubeflow   │────▶│   TrainJob      │────▶│   MLflow    │
│    SDK      │     │  (N GPU nodes)  │     │  Tracking   │
└─────────────┘     └────────┬────────┘     └─────────────┘
                             │
                    ┌────────▼────────┐
                    │   Feast + Ray   │
                    │  (distributed   │
                    │   PIT join)     │
                    └─────────────────┘
```

**Prerequisites:** `01-feast-features.ipynb` completed; MLflow running in the namespace; RayCluster `feast-ray` running (needed only when `USE_RAY=True`).

In [None]:
# Run this cell first, then restart kernel if needed
%pip install -q kubeflow kubernetes "mlflow>=3.0"

In [None]:
%pip show kubeflow

## Configuration

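| Parameter | Value | Purpose |
|-----------|-------|---------|
| `NAMESPACE` | `feast-trainer-demo` | Kubernetes namespace for the TrainJob, MLflow, and RayCluster |
| `PVC` | `shared` | PersistentVolumeClaim mounted at `/shared` (feature repo, data, model artifacts) |
| `RUNTIME` | `torch-distributed` | ClusterTrainingRuntime for PyTorch DDP |
| `MLFLOW_URI` | `http://mlflow.<NAMESPACE>.svc.cluster.local:5000` | In-cluster MLflow tracking server |
| `EPOCHS` | `50` | Number of training epochs |
| `USE_RAY` | `True` | Enable Ray for distributed feature retrieval |
| `RAY_CLUSTER` | `feast-ray` | RayCluster that Feast uses when `USE_RAY=True` |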

In [None]:
NAMESPACE = "feast-trainer-demo"
PVC = "shared"
RUNTIME = "torch-distributed"
MLFLOW_URI = f"http://mlflow.{NAMESPACE}.svc.cluster.local:5000"
EPOCHS = 50
USE_RAY = True
RAY_CLUSTER = "feast-ray"

## Kubernetes Authentication

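Connect to the cluster with a service account token. Set `K8S_TOKEN` and `K8S_API_SERVER` in the environment; if `K8S_TOKEN` is empty, the client falls back to the default Kubernetes configuration (kubeconfig or in-cluster).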

In [None]:
import os
# Never hardcode credentials: read the token from the environment (empty -> default kube config)
K8S_TOKEN = os.getenv("K8S_TOKEN", "")
K8S_API = os.getenv("K8S_API_SERVER", "https://api.oai-kft-ibm.ibm.rh-ods.com:6443")

## Initialize Kubeflow TrainerClient

The `TrainerClient` manages the TrainJob lifecycle:
- `train()` → Submit job
- `get_job()` → Check status
- `wait_for_job_status()` → Block until complete
- `delete_job()` → Cleanup

In [None]:
from kubernetes import client as k8s
from kubeflow.trainer import TrainerClient, CustomTrainer
from kubeflow.common.types import KubernetesBackendConfig

cfg = k8s.Configuration()
if K8S_TOKEN:
    cfg.host = K8S_API
    cfg.verify_ssl = False  # skips TLS verification (e.g. self-signed API certs); avoid outside demos
    cfg.api_key = {"authorization": f"Bearer {K8S_TOKEN}"}

# Without a token, pass None so the SDK uses the default Kubernetes configuration
trainer = TrainerClient(
    KubernetesBackendConfig(namespace=NAMESPACE, client_configuration=cfg if K8S_TOKEN else None)
)
runtime = trainer.get_runtime(RUNTIME)
print(f"Runtime: {RUNTIME}")

## Training Function

This function runs **inside the TrainJob pods**. Key components:

### 1. Feast Feature Retrieval (Rank 0 only)
```python
entity_df = pd.DataFrame(entity_rows)  # one row per (store_id, dept_id, week): 104 × 45 × 14 ≈ 65K
df = store.get_historical_features(entity_df=entity_df, features=store.get_feature_service('training_features')).to_df()
```
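With `use_ray=True`, Feast is configured (via `feature_store_ray.yaml` and the `FEAST_RAY_*` environment variables) to run these point-in-time joins on the `feast-ray` RayCluster.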
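
A point-in-time join attaches to each entity row the most recent feature value whose timestamp is at or before that row's `event_timestamp`, so no future information leaks into training. The sketch below reproduces that semantics on toy data with `pandas.merge_asof`; the key columns mirror the entity DataFrame above, but the `avg_sales` feature and its values are made up, and Feast additionally applies feature-view TTLs that this sketch ignores. It also checks the entity-row count: 104 weeks × 45 stores × 14 departments.

```python
from datetime import datetime, timedelta, timezone

import pandas as pd

# Entity rows: one per (store, dept, week), as built in train_fn
n_entity_rows = 104 * 45 * 14

# Toy feature table with its own timestamps (hypothetical 'avg_sales' feature)
t0 = datetime(2022, 1, 1, tzinfo=timezone.utc)
features = pd.DataFrame({
    "store_id": [1, 1, 1, 2],
    "dept_id": [1, 1, 1, 1],
    "feature_ts": [t0, t0 + timedelta(weeks=1), t0 + timedelta(weeks=3), t0],
    "avg_sales": [100.0, 110.0, 130.0, 50.0],
})

# Entity DataFrame: the timestamps at which we want feature values
entity_df = pd.DataFrame({
    "store_id": [1, 1, 2],
    "dept_id": [1, 1, 1],
    "event_timestamp": [t0 + timedelta(weeks=2), t0 + timedelta(weeks=3), t0 + timedelta(weeks=1)],
})

# merge_asof requires both frames to be sorted by their time keys
joined = pd.merge_asof(
    entity_df.sort_values("event_timestamp"),
    features.sort_values("feature_ts"),
    left_on="event_timestamp",
    right_on="feature_ts",
    by=["store_id", "dept_id"],
    direction="backward",  # latest feature row at or before event_timestamp
)
print(n_entity_rows)
print(joined[["store_id", "event_timestamp", "avg_sales"]])
```

Store 1 at week 2 receives the week-1 value (110.0), not the later week-3 value, which is exactly the leakage a point-in-time join prevents.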

### 2. PyTorch DDP Training (All ranks)
```python
model = DDP(MLP(...).to(device))  # Wrap model; gradients are averaged across ranks in backward()
sampler = DistributedSampler(...)  # Give each rank a disjoint shard of the training set
```
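
`DistributedSampler` gives each rank an equally sized slice of the dataset indices: it optionally shuffles with a seed derived from the epoch (which is why `train_fn` calls `sampler.set_epoch(ep)`), pads the index list so it divides evenly by the world size, then takes every `world`-th index starting at `rank`. The pure-Python sketch below mimics that assignment for illustration only; `shard_indices` is a hypothetical helper, and it shuffles with `random.Random` rather than PyTorch's generator, so the shuffled order will not match `torch` exactly.

```python
import math
import random


def shard_indices(n, world, rank, epoch=0, seed=0, shuffle=True):
    """Mimic DistributedSampler's index assignment with drop_last=False."""
    indices = list(range(n))
    if shuffle:
        # Every rank builds the same permutation: the seed depends only on seed + epoch
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating the list so the total divides evenly by world
    total = math.ceil(n / world) * world
    indices = (indices * math.ceil(total / n))[:total]
    # Strided slice: rank r takes positions r, r + world, r + 2 * world, ...
    return indices[rank:total:world]


shards = [shard_indices(10, world=4, rank=r, shuffle=False) for r in range(4)]
print(shards)
```

With 10 samples on 4 ranks, ranks 2 and 3 receive the padded duplicates `0` and `1`, so a few samples are seen twice per epoch whenever the dataset size is not divisible by the world size.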

### 3. MLflow Logging (Rank 0 only)
```python
mlflow.log_metrics({'mape': mape, 'loss': loss})
mlflow.pytorch.log_model(m, name='model')  # m: plain MLP reloaded with the best weights
```

In [None]:
def train_fn(epochs=50, use_ray=True, ray_cluster='feast-ray', namespace='feast-trainer-demo', output_dir='/shared/models', feature_repo='/shared/feature_repo'):
    """Training function - SDK passes func_args as **kwargs"""
    import os, json, numpy as np, pandas as pd, torch, torch.nn as nn, torch.distributed as dist, joblib, mlflow, shutil, time
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data import DataLoader, Dataset, DistributedSampler
    from sklearn.preprocessing import StandardScaler
    from datetime import datetime, timezone, timedelta
    
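    # Generous timeout: other ranks block in dist.barrier() while rank 0 runs the Feast retrieval
    pg_timeout = timedelta(seconds=int(os.environ.get('RDZV_TIMEOUT', '1800')))
    dist.init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo', timeout=pg_timeout)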
    rank, world = dist.get_rank(), dist.get_world_size()
    device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}" if torch.cuda.is_available() else "cpu")
    print(f"[Rank {rank}] Device: {device}")
    
    OUT = output_dir
    REPO = feature_repo
    EPOCHS = epochs
    
    class MLP(nn.Module):
        def __init__(self, inp, hidden=[512, 256, 128, 64], drop=0.3):
            super().__init__()
            layers = []
            for h in hidden:
                layers.extend([nn.Linear(inp, h), nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(drop)])
                inp = h
            layers.append(nn.Linear(inp, 1))
            self.net = nn.Sequential(*layers)
        def forward(self, x):
            return self.net(x).squeeze(-1)
    
    class DS(Dataset):
        def __init__(self, X, y):
            self.X = torch.tensor(X, dtype=torch.float32)
            self.y = torch.tensor(y, dtype=torch.float32)
        def __len__(self): return len(self.X)
        def __getitem__(self, i): return self.X[i], self.y[i]
    
    if rank == 0:
        os.makedirs(OUT, exist_ok=True)
        mlflow.set_tracking_uri(os.getenv('MLFLOW_TRACKING_URI', 'http://mlflow:5000'))
        mlflow.set_experiment('sales-forecasting')
        mlflow.start_run(run_name=f"train-{datetime.now().strftime('%Y%m%d-%H%M%S')}")
        
        # Point Feast at the RayCluster (KubeRay) for distributed point-in-time joins
        from pathlib import Path
        fs_yaml = None
        if use_ray:
            os.environ['FEAST_RAY_USE_KUBERAY'] = 'true'
            os.environ['FEAST_RAY_CLUSTER_NAME'] = ray_cluster
            os.environ['FEAST_RAY_NAMESPACE'] = namespace
            os.environ['FEAST_RAY_SKIP_TLS'] = 'true'
            token_path = '/var/run/secrets/kubernetes.io/serviceaccount/token'
            if os.path.exists(token_path):
                with open(token_path) as f:
                    os.environ['FEAST_RAY_AUTH_TOKEN'] = f.read().strip()
                os.environ['FEAST_RAY_AUTH_SERVER'] = f"https://{os.environ.get('KUBERNETES_SERVICE_HOST')}:{os.environ.get('KUBERNETES_SERVICE_PORT')}"
            ray_cfg = f"{REPO}/feature_store_ray.yaml"
            if os.path.exists(ray_cfg):
                # Use the Ray config directly instead of overwriting feature_store.yaml on the shared PVC
                fs_yaml = Path(ray_cfg)
        
        from feast import FeatureStore
        store = FeatureStore(repo_path=REPO, fs_yaml_file=fs_yaml)
        
        # Entity DF
        entity_rows = [{'store_id': s, 'dept_id': d, 'event_timestamp': datetime(2022,1,1,tzinfo=timezone.utc) + timedelta(weeks=w)}
            for w in range(104) for s in range(1, 46) for d in range(1, 15)]
        entity_df = pd.DataFrame(entity_rows)
        print(f"Entity DF: {len(entity_df):,} rows")
        
        t0 = time.time()
        df = store.get_historical_features(entity_df=entity_df, features=store.get_feature_service('training_features')).to_df()
        print(f"✅ Feast: {len(df):,} rows in {time.time()-t0:.1f}s")
        
        df = df.dropna(subset=['weekly_sales']).sort_values('event_timestamp')
        split = int(len(df) * 0.8)
        train_df, val_df = df.iloc[:split], df.iloc[split:]
        
        exclude = ['store_id', 'dept_id', 'event_timestamp', 'date', 'weekly_sales']
        feat_cols = [c for c in df.columns if c not in exclude and df[c].dtype in [np.float64, np.int64, np.float32, np.int32]]
        print(f"Features ({len(feat_cols)}): {feat_cols[:5]}...")
        
        X_train, y_train = train_df[feat_cols].fillna(0).values, train_df['weekly_sales'].values
        X_val, y_val = val_df[feat_cols].fillna(0).values, val_df['weekly_sales'].values
        
        scaler, y_scaler = StandardScaler(), StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_val = scaler.transform(X_val)
        y_train_s = y_scaler.fit_transform(np.log1p(y_train).reshape(-1,1)).flatten()
        y_val_s = y_scaler.transform(np.log1p(y_val).reshape(-1,1)).flatten()
        
        joblib.dump({'scaler_X': scaler, 'scaler_y': y_scaler, 'use_log_transform': True}, f'{OUT}/scalers.joblib')
        joblib.dump(feat_cols, f'{OUT}/feature_cols.pkl')
        np.savez(f'{OUT}/.data.npz', X_train=X_train, y_train=y_train_s, X_val=X_val, y_val=y_val_s, y_val_orig=y_val)
        np.save(f'{OUT}/.dim.npy', [X_train.shape[1]])
        mlflow.log_params({'epochs': EPOCHS, 'train_rows': len(train_df), 'val_rows': len(val_df), 'features': len(feat_cols)})
    
    dist.barrier()
    data = np.load(f'{OUT}/.data.npz')
    X_train, y_train, X_val, y_val, y_val_orig = data['X_train'], data['y_train'], data['X_val'], data['y_val'], data['y_val_orig']
    inp_dim = int(np.load(f'{OUT}/.dim.npy')[0])
    dist.barrier()
    
    train_ds = DS(X_train, y_train)
    val_ds = DS(X_val, y_val)
    sampler = DistributedSampler(train_ds, num_replicas=world, rank=rank)
    train_loader = DataLoader(train_ds, batch_size=64, sampler=sampler)
    val_loader = DataLoader(val_ds, batch_size=64)
    
    model = DDP(MLP(inp_dim).to(device), device_ids=[int(os.environ.get('LOCAL_RANK',0))] if torch.cuda.is_available() else None)
    opt = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-3)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=3)
    crit = nn.MSELoss()
    
    best_loss, best_mape = float('inf'), float('inf')
    for ep in range(EPOCHS):
        sampler.set_epoch(ep)
        model.train()
        train_loss = 0.0
        for X_b, y_b in train_loader:
            X_b, y_b = X_b.to(device), y_b.to(device)
            opt.zero_grad()
            loss = crit(model(X_b), y_b)
            loss.backward()
            opt.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)
        
        model.eval()
        with torch.no_grad():
            preds = np.concatenate([model(X.to(device)).cpu().numpy() for X, _ in val_loader])
        val_loss = np.mean((preds - y_val)**2)
        sched.step(val_loss)
        
        if rank == 0:
            y_sc = joblib.load(f'{OUT}/scalers.joblib')['scaler_y']
            pred_orig = np.expm1(y_sc.inverse_transform(preds.reshape(-1,1)).flatten())
            mask = y_val_orig > 1000
            mape = np.mean(np.abs((y_val_orig[mask] - pred_orig[mask]) / y_val_orig[mask])) * 100
            mlflow.log_metrics({'train_loss': train_loss, 'val_loss': val_loss, 'mape': mape, 'lr': opt.param_groups[0]['lr']}, step=ep)
            if (ep + 1) % 10 == 0:
                print(f"Ep {ep+1}/{EPOCHS} | Train: {train_loss:.4f} | Val: {val_loss:.4f} | MAPE: {mape:.1f}%")
            if val_loss < best_loss:
                best_loss, best_mape = val_loss, mape
                torch.save(model.module.state_dict(), f'{OUT}/best_model.pt')
        dist.barrier()
    
    if rank == 0:
        print(f"\n✅ DONE: MAPE {best_mape:.1f}%")
        mlflow.log_metrics({'best_mape': best_mape, 'best_val_loss': best_loss})
        # Rebuild a plain (non-DDP) MLP with the best weights for logging and serving
        m = MLP(inp_dim)
        m.load_state_dict(torch.load(f'{OUT}/best_model.pt', map_location='cpu'))
        m.eval()
        mlflow.pytorch.log_model(m, name='model')  # MLflow 3: 'name' replaces 'artifact_path'
        feat_cols = joblib.load(f'{OUT}/feature_cols.pkl')
        meta = {'model_type': 'SalesMLP', 'input_dim': inp_dim, 'hidden_dims': [512, 256, 128, 64], 'dropout': 0.3,
                'best_mape': float(best_mape), 'feature_columns': feat_cols}
        with open(f'{OUT}/model_metadata.json', 'w') as f:
            json.dump(meta, f)
        mlflow.end_run()
        # Remove the temporary files used to share data between ranks
        for fname in ['.data.npz', '.dim.npy']:
            try:
                os.remove(f'{OUT}/{fname}')
            except OSError:
                pass
    dist.destroy_process_group()
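
`train_fn` sorts by `event_timestamp` before taking the first 80% of rows for training, and fits both scalers on the training rows only, so every validation week is later than every training week and no validation statistics leak into preprocessing. The NumPy sketch below walks through the same steps on synthetic data; `standardize_fit` is a hypothetical stand-in for `StandardScaler.fit` (which likewise uses the population standard deviation), and all values are random.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for the Feast output: 100 rows, 3 numeric features
timestamps = rng.permutation(100)          # unsorted stand-in for event_timestamp
X = rng.normal(size=(100, 3))
y = rng.uniform(500, 5000, size=100)       # positive stand-in for weekly_sales

# Chronological split: sort by time, first 80% train, last 20% validation
order = np.argsort(timestamps)
X, y, timestamps = X[order], y[order], timestamps[order]
split = int(len(X) * 0.8)
X_train, X_val = X[:split], X[split:]
y_train, y_val = y[:split], y[split:]


def standardize_fit(a):
    """Return per-column mean and population std, like StandardScaler.fit."""
    return a.mean(axis=0), a.std(axis=0)


# Fit on training rows only, then apply the same transform to validation rows
mu, sigma = standardize_fit(X_train)
X_train_s, X_val_s = (X_train - mu) / sigma, (X_val - mu) / sigma

# Target: log1p first, then standardize with training statistics only
y_mu, y_sigma = standardize_fit(np.log1p(y_train))
y_train_s = (np.log1p(y_train) - y_mu) / y_sigma

print(X_train_s.shape, X_val_s.shape)
```

Validation features are generally not exactly zero-mean after this transform; that is expected, since the scaler must not see them.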

## Submit TrainJob

The `CustomTrainer` packages the function for distributed execution:

| Parameter | Value | Purpose |
|-----------|-------|----------|
| `func` | `train_fn` | Python function to run |
| `func_args` | `{epochs, use_ray, ...}` | Arguments passed as kwargs |
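| `num_nodes` | 1 | Number of worker pods; raise it for multi-node DDP |
| `resources_per_node` | `{gpu: 1, cpu: 4, memory: 8Gi}` | Resources per pod |
| `packages_to_install` | `[feast, mlflow, ...]` | Pip packages installed in each pod at startup |

**Environment Variables:**
- `MLFLOW_TRACKING_URI` → In-cluster MLflow server used by rank 0
- `RDZV_TIMEOUT=1800` → DDP rendezvous timeout (30 min, to cover Feast retrieval)
- `FEAST_DATA_ROOT=/shared/data` → Path to parquet files
- `NCCL_IB_DISABLE=1` → Disable NCCL's InfiniBand transport on clusters without InfiniBand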

In [None]:
from datetime import datetime
from kubeflow.trainer.options import Name, Labels, PodTemplateOverrides, PodTemplateOverride, PodSpecOverride, ContainerOverride

job_id = datetime.now().strftime('%m%d-%H%M')
func_args = {'epochs': EPOCHS, 'use_ray': USE_RAY, 'ray_cluster': RAY_CLUSTER, 'namespace': NAMESPACE, 'output_dir': '/shared/models', 'feature_repo': '/shared/feature_repo'}

job = trainer.train(
    trainer=CustomTrainer(
        func=train_fn,
        func_args=func_args,  # Pass args here, not to train()
        num_nodes=1,
        resources_per_node={'gpu': 1, 'cpu': 4, 'memory': '8Gi'},  # Use 'gpu' not 'nvidia.com/gpu'
        packages_to_install=['feast[ray,postgres]==0.59.0', 'codeflare-sdk', 'psycopg2-binary', 'scikit-learn', 'pandas', 'pyarrow', 'joblib', 'mlflow>=3.0'],
        env={
            'MLFLOW_TRACKING_URI': MLFLOW_URI,
            'FEAST_DATA_ROOT': '/shared/data',  # Path to parquet files
            'RDZV_TIMEOUT': '1800',  # 30 min for DDP rendezvous (feature retrieval takes time)
            'NCCL_IB_DISABLE': '1'   # Disable InfiniBand if not available
        }
    ),
    runtime=runtime,
    options=[
        Name(f"sales-forecast-{job_id}"),  # unique per submission, so reruns don't collide
        Labels({'app': 'sales-forecasting'}),
        PodTemplateOverrides(PodTemplateOverride(
            target_jobs=['node'],
            spec=PodSpecOverride(
                volumes=[{'name': 'shared', 'persistentVolumeClaim': {'claimName': PVC}}],
                containers=[ContainerOverride(name='node', volume_mounts=[{'name': 'shared', 'mountPath': '/shared'}])]
            )
        ))
    ]
)
print(f"✅ Submitted: {job}")

In [None]:
# trainer.delete_job(job)

## Monitor Training

Block until the TrainJob finishes as `Complete` or `Failed` (timeout: 1 hour):

In [None]:
trainer.wait_for_job_status(name=job, status={'Complete', 'Failed'}, timeout=3600)
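
`wait_for_job_status` blocks until the TrainJob reaches one of the requested states or the timeout expires. The general pattern is a poll-with-deadline loop, sketched below in plain Python; `wait_for` and `get_status` are hypothetical names (this is not the Kubeflow SDK's implementation), and the fake status sequence stands in for repeated status checks against the cluster.

```python
import time


def wait_for(get_status, targets, timeout, poll_interval=2.0):
    """Poll get_status() until it returns a value in targets or the deadline passes."""
    deadline = time.monotonic() + timeout
    while True:
        status = get_status()
        if status in targets:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"still {status!r} after {timeout}s")
        time.sleep(poll_interval)


# Fake status source standing in for repeated job-status lookups
statuses = iter(["Created", "Running", "Running", "Complete"])
final = wait_for(lambda: next(statuses), {"Complete", "Failed"}, timeout=5, poll_interval=0.01)
print(final)
```

In the sketch, including `'Failed'` in the target set lets a crashed job end the wait early instead of running until the timeout, so check which state was reached before reading results from MLflow.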

## View MLflow Results

Check training metrics and artifacts:

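| Metric | Meaning |
|--------|---------|
| `best_mape` | Validation MAPE (%) at the epoch with the lowest `val_loss`, in original sales units, over rows with `weekly_sales > 1000` (lower is better) |
| `mape` | Same validation MAPE, logged every epoch |
| `train_loss` | MSE on the standardized `log1p(weekly_sales)` target, averaged over rank 0's training shard |
| `val_loss` | MSE on the standardized `log1p(weekly_sales)` target over the full validation set |
| `lr` | Current learning rate (halved by `ReduceLROnPlateau` when `val_loss` stops improving, patience 3) |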
In [None]:
import mlflow
mlflow.set_tracking_uri(MLFLOW_URI)
exp = mlflow.get_experiment_by_name('sales-forecasting')
if exp:
    runs = mlflow.search_runs([exp.experiment_id], max_results=5, order_by=['start_time DESC'])
    display(runs[['tags.mlflow.runName', 'metrics.best_mape']])
else:
    print("Experiment not found yet")

---
## ✅ Training Complete!

**Artifacts saved to `/shared/models/`:**
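- `best_model.pt` → PyTorch `state_dict` from the epoch with the lowest validation loss
- `scalers.joblib` → Feature scaler (`scaler_X`), target scaler (`scaler_y`), and the `use_log_transform` flag
- `feature_cols.pkl` → Feature column names, in training order
- `model_metadata.json` → Architecture (input/hidden dims, dropout), best MAPE, and feature columns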
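
Downstream code has to feed features in exactly the order recorded at training time. A minimal standard-library sketch, assuming a `model_metadata.json` with the `feature_columns` and `input_dim` keys that `train_fn` writes; `ordered_features` is a hypothetical helper, and the feature names and values below are made up.

```python
import json
import os
import tempfile


def ordered_features(metadata_path, row):
    """Return row's values in training column order; missing features become 0."""
    with open(metadata_path) as f:
        meta = json.load(f)
    cols = meta["feature_columns"]
    values = [float(row.get(c, 0.0)) for c in cols]  # mirrors fillna(0) in train_fn
    if len(values) != meta["input_dim"]:
        raise ValueError(f"expected {meta['input_dim']} features, got {len(values)}")
    return values


# Write a small example metadata file to a temp dir (made-up columns)
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "model_metadata.json")
with open(path, "w") as f:
    json.dump({"input_dim": 3, "feature_columns": ["temperature", "fuel_price", "cpi"]}, f)

print(ordered_features(path, {"cpi": 211.1, "temperature": 42.3}))
```

The ordered values still need `scaler_X` from `scalers.joblib` applied before they reach the model.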

**Next:** `03-inference.ipynb` → Deploy model with KServe