## MNIST Distributed Training

### Overview
Distributed training of a CNN on the MNIST handwritten-digit dataset using PyTorch's DistributedDataParallel (DDP).

### Training Setup
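- **Backend**: NCCL (GPU) / Gloo (CPU)
- **Data**: MNIST test split (`train=False`, 10,000 images), used here as the training data
- **Distributed**: Multi-node, multi-CPU/GPU, with `DistributedSampler` giving each process its own shard of the data
- **Device Flexibility**: Falls back automatically from GPU to CPU when no GPU is available

### Key Parameters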
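- `epochs`: Number of full passes over the training data
- `batch_size`: Samples per device (per process) in each step
- `lr`: Learning rate
- `save_every`: Checkpoint frequency
- `backend`: `torch.distributed` communication backend (`nccl` or `gloo`)
- `dataset_path`: Directory holding the MNIST data (on the shared volume in the job below)
- `snapshot_path`: Path of the checkpoint (snapshot) file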

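The NCCL/Gloo choice and the GPU→CPU fallback described above typically come down to one decision made in each training process. Because the `kfto_mnist` module is not included here, the sketch below only illustrates that decision: `select_device_and_backend` is a hypothetical helper, and in real training code `cuda_available` would come from `torch.cuda.is_available()` and `local_rank` from the launcher (for example the `LOCAL_RANK` variable that `torchrun` sets). An explicitly requested `gloo` is kept on GPU nodes because Gloo also works with CUDA tensors, although NCCL is normally the faster option there.

```python
from typing import Optional, Tuple


def select_device_and_backend(
    cuda_available: bool,
    requested_backend: Optional[str] = None,
    local_rank: int = 0,
) -> Tuple[str, str]:
    """Hypothetical helper: pick a device string and a process-group backend.

    NCCL requires CUDA devices, so it is the default when a GPU is present;
    on CPU-only nodes everything falls back to the CPU with Gloo.
    """
    if cuda_available:
        device = f"cuda:{local_rank}"
        backend = requested_backend or "nccl"
    else:
        device = "cpu"
        # NCCL cannot run without GPUs, so replace it (or a missing choice) with Gloo.
        backend = "gloo" if requested_backend in (None, "nccl") else requested_backend
    return device, backend


print(select_device_and_backend(True))                        # ('cuda:0', 'nccl')
print(select_device_and_backend(False, "nccl"))               # ('cpu', 'gloo')
print(select_device_and_backend(True, "gloo", local_rank=1))  # ('cuda:1', 'gloo')
```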
In [None]:
%pip install kubeflow
%pip install -U kubeflow-training

In [None]:
%pip show kubeflow-training

### Initialise Training Client

In [None]:
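# Import the kubernetes module under another name so that `client` can hold the
# TrainingClient instance used by all later cells without shadowing it.
from kubernetes import client as k8s_client
from kubeflow.training import TrainingClient

api_server = ""  # URL of the Kubernetes API server
token = ""       # bearer token of a user or service account allowed to create PyTorchJobs

# Configure the Kubernetes API client with the server and token
configuration = k8s_client.Configuration()
configuration.host = api_server
configuration.api_key = {"authorization": f"Bearer {token}"}
configuration.verify_ssl = False  # Disables TLS certificate verification; use only with trusted test clusters

# Initialise the TrainingClient with this configuration
client = TrainingClient(client_configuration=configuration)

# No request has been sent to the cluster yet, so this does not prove the token is valid.
print("TrainingClient initialised.")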

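Typing `api_server` and `token` directly into the notebook makes it easy to commit a credential by accident. One option, sketched here under the assumption that both values are exported beforehand as the hypothetical environment variables `KUBE_API_SERVER` and `KUBE_TOKEN`, is to read them at runtime and fail early when either one is missing:

```python
import os
from typing import Mapping, Tuple


def load_cluster_credentials(env: Mapping[str, str] = os.environ) -> Tuple[str, str]:
    """Return (api_server, token) read from the hypothetical KUBE_API_SERVER / KUBE_TOKEN variables."""
    api_server = env.get("KUBE_API_SERVER", "").strip()
    token = env.get("KUBE_TOKEN", "").strip()
    missing = [
        name
        for name, value in (("KUBE_API_SERVER", api_server), ("KUBE_TOKEN", token))
        if not value
    ]
    if missing:
        # Stop before building a Configuration with an empty host or token.
        raise ValueError(f"missing required environment variables: {', '.join(missing)}")
    return api_server, token


# Usage in the notebook, after exporting both variables:
#   api_server, token = load_cluster_credentials()
```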
### Submit a PyTorchJob with the Kubeflow Training SDK (Managed by Kubeflow Trainer V1)

In [None]:
from kubernetes.client import V1PersistentVolumeClaimVolumeSource, V1Volume, V1VolumeMount

from kfto_mnist import main

# Start a PyTorchJob with num_workers=2 replicas in total (the SDK counts the master,
# so this is 1 master + 1 worker pod) and 1 GPU per replica: a multi-node, multi-worker job.
client.create_job(
    name="pytorch-ddp",
    train_func=main,
    base_image="quay.io/modh/training:py311-cuda121-torch241",
    num_workers=2,
    resources_per_worker={"gpu": "1"},
    packages_to_install=["torchvision==0.19.0", "minio"],
    parameters={
        "epochs": 5,
        "save_every": 2,
        "batch_size": 2,
        "backend": "gloo",  # "nccl" is the usual choice when every replica has a GPU
        "lr": 0.001,
        "dataset_path": "/shared/data",
        "snapshot_path": "/shared/checkpoints/snapshot_mnist.pt",
    },
    env_vars={
        "NCCL_DEBUG": "INFO",  # only takes effect with the nccl backend
        "TORCH_DISTRIBUTED_DEBUG": "DETAIL",
    },
    # Mount the PVC "shared" at /shared in every replica so all of them see the same
    # dataset and checkpoint paths. If replicas land on different nodes, the claim
    # must support the ReadWriteMany access mode.
    volumes=[
        V1Volume(
            name="shared",
            persistent_volume_claim=V1PersistentVolumeClaimVolumeSource(claim_name="shared"),
        ),
    ],
    volume_mounts=[
        V1VolumeMount(name="shared", mount_path="/shared"),
    ],
)

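With `num_workers=2`, the job runs two training processes (world size 2, assuming one process per pod), and `DistributedSampler` gives each one a disjoint slice of the 10,000-image MNIST test split. The pure-Python sketch below mirrors the non-shuffled behaviour of `torch.utils.data.distributed.DistributedSampler` with `drop_last=False`: pad the index list by wrapping around until it divides evenly, then let rank `r` take every `world_size`-th index starting at `r`. `shard_indices` is a hypothetical illustration, not PyTorch's code, and the world size is an assumption about how the pods are launched.

```python
import math
from typing import List


def shard_indices(dataset_len: int, world_size: int, rank: int) -> List[int]:
    """Mimic the index split of DistributedSampler(shuffle=False, drop_last=False)."""
    num_samples = math.ceil(dataset_len / world_size)  # samples each rank receives
    total_size = num_samples * world_size
    indices = list(range(dataset_len))
    padding = total_size - len(indices)
    if padding:
        # Repeat indices from the start so the list divides evenly across ranks.
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Interleaved split: rank r takes positions r, r + world_size, r + 2 * world_size, ...
    return indices[rank:total_size:world_size]


DATASET_LEN = 10_000  # MNIST test split (train=False)
WORLD_SIZE = 2        # assumption: 2 pods (num_workers=2) with one process each
BATCH_SIZE = 2        # "batch_size" parameter: samples per process per step

shards = [shard_indices(DATASET_LEN, WORLD_SIZE, r) for r in range(WORLD_SIZE)]
steps_per_epoch = math.ceil(len(shards[0]) / BATCH_SIZE)  # DataLoader default drop_last=False
global_batch = BATCH_SIZE * WORLD_SIZE  # samples consumed per synchronized step

print(len(shards[0]), len(shards[1]))  # 5000 5000
print(shards[0][:3], shards[1][:3])    # [0, 2, 4] [1, 3, 5]
print(steps_per_epoch, global_batch)   # 2500 4
```

With the default `shuffle=True`, the sampler first permutes the indices using a generator seeded from `seed + epoch`, which is why DDP training loops call `sampler.set_epoch(epoch)` at the start of each epoch.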
### Get PyTorchJob Logs

In [None]:
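# get_job_logs returns a tuple whose first element maps pod names to their log text.
# The master pod may have no logs yet if it is still starting.
logs = client.get_job_logs(name="pytorch-ddp")
print("pytorch-ddp-master-0:\n\n" + logs[0]["pytorch-ddp-master-0"])
if client.is_job_succeeded(name="pytorch-ddp"):
    print("PyTorchJob succeeded!")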

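The cell above prints only the master's logs. In the kubeflow-training SDK, `get_job_logs` defaults to the master replica (`is_master=True`); passing `is_master=False` should also return the worker pods' logs, but check the signature in your installed version. The helpers below are hypothetical: they assume the SDK counts the master in `num_workers` and that the operator names pods `<job name>-<replica type>-<index>`, as the `pytorch-ddp-master-0` key above suggests.

```python
from typing import Dict, List


def expected_pod_names(job_name: str, num_workers: int) -> List[str]:
    """Hypothetical helper: pod names for 1 master + (num_workers - 1) workers."""
    workers = [f"{job_name}-worker-{i}" for i in range(num_workers - 1)]
    return [f"{job_name}-master-0"] + workers


def format_all_logs(job_name: str, num_workers: int, pod_logs: Dict[str, str]) -> str:
    """Join the logs of every expected pod, noting pods that returned no logs."""
    sections = []
    for pod in expected_pod_names(job_name, num_workers):
        text = pod_logs.get(pod, "<no logs returned for this pod>")
        sections.append(f"{pod}:\n\n{text}")
    return "\n\n".join(sections)


print(expected_pod_names("pytorch-ddp", 2))  # ['pytorch-ddp-master-0', 'pytorch-ddp-worker-0']

# In the notebook (assumes the is_master flag described above):
#   pod_logs = client.get_job_logs(name="pytorch-ddp", is_master=False)[0]
#   print(format_all_logs("pytorch-ddp", 2, pod_logs))
```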
In [None]:
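import time

print("waiting", end="")
while not client.is_job_succeeded(name="pytorch-ddp"):
    # Stop polling instead of looping forever if the job fails
    if client.is_job_failed(name="pytorch-ddp"):
        raise RuntimeError("PyTorchJob pytorch-ddp failed; check the logs above.")
    print(".", end="", flush=True)
    time.sleep(5)
print("\nPyTorchJob succeeded!")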

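A polling loop with no deadline can keep running indefinitely if the job stalls, for example when a pod stays unschedulable. The sketch below adds a failure check and a timeout around two status callbacks; `wait_until_done` is a hypothetical helper, and the timeout and interval values are arbitrary assumptions. Depending on your SDK version, `TrainingClient.wait_for_job_conditions` may already cover the same need.

```python
import time
from typing import Callable


def wait_until_done(
    succeeded: Callable[[], bool],
    failed: Callable[[], bool],
    timeout_s: float = 1800.0,
    poll_interval_s: float = 10.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> None:
    """Poll until succeeded() is true; raise on failure or once timeout_s has elapsed."""
    deadline = clock() + timeout_s
    while not succeeded():
        if failed():
            raise RuntimeError("job reported a Failed condition")
        if clock() >= deadline:
            raise TimeoutError(f"job did not finish within {timeout_s} s")
        sleep(poll_interval_s)


# Usage in the notebook:
#   wait_until_done(
#       succeeded=lambda: client.is_job_succeeded(name="pytorch-ddp"),
#       failed=lambda: client.is_job_failed(name="pytorch-ddp"),
#   )
```

Injecting `sleep` and `clock` keeps the helper independent of the SDK and lets it be exercised with fake callables instead of a live cluster.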
### Clean Up Created Resources

In [None]:
# Delete the job from the TrainingClient's default namespace, where create_job created it
client.delete_job(name="pytorch-ddp")