## MNIST Distributed Training

### Overview
Distributed training of a CNN on the MNIST handwritten digit dataset using PyTorch's DistributedDataParallel (DDP).
### Training Setup
- **Backend**: NCCL (GPU) / Gloo (CPU)
- **Data**: MNIST test set (train=False)
- **Distributed**: Multi-Node Multi-CPU/GPU with DistributedSampler
- **Device Flexibility**: Auto-fallback GPU→CPU
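
The backend choice and the GPU→CPU fallback listed above can be sketched roughly as follows. This is a minimal illustration, not the code in `kfto_mnist`: the helper name `select_device_and_backend` and its arguments are hypothetical, and in a real job `cuda_available` would typically come from `torch.cuda.is_available()`.

```python
def select_device_and_backend(requested_backend: str, cuda_available: bool) -> tuple[str, str]:
    """Pick a (device, backend) pair, downgrading NCCL to Gloo when no GPU exists.

    Hypothetical helper for illustration only. NCCL requires CUDA devices,
    so a request for "nccl" on a CPU-only pod is replaced with "gloo".
    """
    device = "cuda" if cuda_available else "cpu"
    if requested_backend == "nccl" and not cuda_available:
        backend = "gloo"
    else:
        backend = requested_backend
    return device, backend


# Example: a pod without a GPU that was asked to use NCCL.
print(select_device_and_backend("nccl", cuda_available=False))
```

Keeping the decision in one small function makes it easy to log which device and backend each rank ended up with.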
### Key Parameters
- `epochs`: Number of training epochs (full passes over the dataset)
- `batch_size`: Samples per batch on each device
- `lr`: Learning rate
- `save_every`: Checkpoint frequency
- `backend`: Communication backend
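
To see how these parameters interact, the sketch below works out the per-rank workload for the values used later in this notebook. It rests on several assumptions that are not confirmed by the source: a world size of 2, the 10,000-image MNIST test split, `DistributedSampler` and `DataLoader` left at their default `drop_last=False`, and a checkpoint written whenever `epoch % save_every == 0` with zero-based epochs. The function name `per_rank_plan` is hypothetical.

```python
import math


def per_rank_plan(dataset_size: int, world_size: int, batch_size: int,
                  epochs: int, save_every: int) -> dict:
    """Estimate the workload of each rank under the assumptions stated above."""
    # With drop_last=False, DistributedSampler pads the index list so that
    # every rank receives the same number of samples.
    samples_per_rank = math.ceil(dataset_size / world_size)
    # With drop_last=False, the DataLoader keeps the final partial batch.
    steps_per_epoch = math.ceil(samples_per_rank / batch_size)
    return {
        "global_batch_size": batch_size * world_size,
        "samples_per_rank": samples_per_rank,
        "steps_per_epoch": steps_per_epoch,
        # Assumption: zero-based epochs, checkpoint when epoch % save_every == 0.
        "checkpoint_epochs": [e for e in range(epochs) if e % save_every == 0],
    }


plan = per_rank_plan(dataset_size=10_000, world_size=2, batch_size=2,
                     epochs=5, save_every=2)
print(plan)
```

Because `batch_size` is per device, the effective global batch grows with the number of workers, which is worth keeping in mind when choosing `lr`.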

In [None]:
%pip install kubeflow
%pip install -U kubeflow-training

In [2]:
%pip show kubeflow-training

Name: kubeflow-training
Version: 1.9.3
Summary: Training Operator Python SDK
Home-page: https://github.com/kubeflow/training-operator/tree/master/sdk/python
Author: Kubeflow Authors
Author-email: hejinchi@cn.ibm.com
License: Apache License Version 2.0
Location: /opt/app-root/lib64/python3.12/site-packages
Requires: certifi, kubernetes, retrying, setuptools, six, urllib3
Required-by: 
Note: you may need to restart the kernel to use updated packages.


### Initialise Training Client

In [3]:
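# Import the kubernetes module under an alias so that the `client` variable
# below can hold the TrainingClient without shadowing the module.
from kubernetes import client as k8s_client
from kubeflow.training import TrainingClient

# Fill in the cluster API server URL and a bearer token before running.
api_server = ""
token = ""

# Configure the API client with the server and token
configuration = k8s_client.Configuration()
configuration.host = api_server
configuration.api_key = {"authorization": f"Bearer {token}"}
# Disables TLS certificate verification; only acceptable on trusted test clusters
configuration.verify_ssl = False

# Initialize API client and TrainingClient with the configuration
api_client = k8s_client.ApiClient(configuration)
client = TrainingClient(client_configuration=api_client.configuration)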

print("successfully authenticated!")

successfully authenticated!
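
Rather than pasting the API server URL and token into the notebook, they could be read from the environment. The sketch below shows one way to do that; the variable names `K8S_API_SERVER` and `K8S_TOKEN` and the helper `load_cluster_credentials` are hypothetical and not part of any SDK.

```python
import os


def load_cluster_credentials(env=None) -> tuple[str, str]:
    """Return (api_server, token) from environment variables.

    The variable names are illustrative assumptions. A ValueError is raised
    if either value is missing or empty, so the failure is visible early.
    """
    env = os.environ if env is None else env
    api_server = env.get("K8S_API_SERVER", "")
    token = env.get("K8S_TOKEN", "")
    missing = [
        name
        for name, value in (("K8S_API_SERVER", api_server), ("K8S_TOKEN", token))
        if not value
    ]
    if missing:
        raise ValueError("Missing environment variables: " + ", ".join(missing))
    return api_server, token
```

The returned values would then be assigned to `api_server` and `token` before building the configuration, which keeps credentials out of saved notebook cells.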


### Submit a PyTorchJob using the Kubeflow Training SDK, managed by Kubeflow Trainer V1

In [4]:
from kfto_mnist import main
from kubeflow.training.models import V1Volume, V1VolumeMount, V1PersistentVolumeClaimVolumeSource

# Start a PyTorchJob with 2 workers, requesting 1 GPU per worker (a multi-node job).
client.create_job(
    name="pytorch-ddp",
    train_func=main,
    base_image="quay.io/modh/training:py311-cuda121-torch241",
    num_workers=2,
    resources_per_worker={"gpu": "1"},
    packages_to_install=["torchvision==0.19.0"],
    parameters={
        "epochs": 5,
        "save_every": 2,
        "batch_size": 2,
        "backend": "gloo",
        "lr": 0.001,
        "dataset_path": "/shared/data",
        "snapshot_path": "/shared/checkpoints/snapshot_mnist.pt",
    },
    env_vars={
        "NCCL_DEBUG": "INFO", 
        "TORCH_DISTRIBUTED_DEBUG": "DETAIL",
    },
    volumes=[
        V1Volume(
            name="shared",
            persistent_volume_claim=V1PersistentVolumeClaimVolumeSource(claim_name="shared")
        ),
    ],
    volume_mounts=[
        V1VolumeMount(name="shared", mount_path="/shared"),
    ],
)
print("PyTorchJob submitted!")

PyTorchJob submitted!
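
The SDK is understood to call `train_func` with the `parameters` dictionary as its single argument, so the entry point has to unpack and validate that dictionary itself. The sketch below shows one way a function such as `main` could do this. It is an illustration under that assumption, not the actual contents of `kfto_mnist`; the names `TrainConfig` and `parse_parameters` are hypothetical.

```python
from dataclasses import dataclass, fields


@dataclass(frozen=True)
class TrainConfig:
    """Typed view of the parameters dictionary passed to the training function."""
    epochs: int
    save_every: int
    batch_size: int
    lr: float
    backend: str
    dataset_path: str
    snapshot_path: str


def parse_parameters(parameters: dict) -> TrainConfig:
    """Validate the dictionary and convert each value to its declared type."""
    missing = sorted(f.name for f in fields(TrainConfig) if f.name not in parameters)
    if missing:
        raise KeyError(f"Missing training parameters: {missing}")
    return TrainConfig(
        epochs=int(parameters["epochs"]),
        save_every=int(parameters["save_every"]),
        batch_size=int(parameters["batch_size"]),
        lr=float(parameters["lr"]),
        backend=str(parameters["backend"]),
        dataset_path=str(parameters["dataset_path"]),
        snapshot_path=str(parameters["snapshot_path"]),
    )
```

Failing fast on a missing key surfaces configuration mistakes in the first lines of the pod log instead of partway through training.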


### Get PyTorchJob Logs

In [7]:
# The first element of the returned value maps pod names to their log text.
logs = client.get_job_logs(name="pytorch-ddp")
print("pytorch-ddp-master-0:\n\n" + logs[0]["pytorch-ddp-master-0"])
if client.is_job_succeeded(name="pytorch-ddp"):
    print("PytorchJob succeeded!")

pytorch-ddp-master-0:

2025-10-05T14:31:33Z INFO     No GPU available, falling back to CPU.
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /shared/data/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 128286441.04it/s]
Extracting /shared/data/MNIST/raw/train-images-idx3-ubyte.gz to /shared/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /shared/data/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 25566841.25it/s]
Extracting /shared/data/MNIST/raw/train-labels-idx1-ubyte.gz to /shared/data/MNIST/raw

Downloading htt
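
Since `logs[0]` is indexed by pod name, the same structure can be printed for every pod it contains instead of hard-coding the master pod. The helper below is a small sketch of that idea; `format_pod_logs` is a hypothetical name, and which pods appear in the dictionary depends on the arguments passed to `get_job_logs`.

```python
def format_pod_logs(pod_logs: dict, max_lines: int = 20) -> str:
    """Render the last `max_lines` lines of each pod's log, one section per pod.

    `pod_logs` is expected to map pod names to log text; `max_lines` should
    be at least 1.
    """
    sections = []
    for pod_name in sorted(pod_logs):
        lines = pod_logs[pod_name].splitlines()
        tail = lines[-max_lines:]
        sections.append(f"{pod_name}:\n" + "\n".join(tail))
    return "\n\n".join(sections)


# Example with made-up log text for two pods.
sample = {"pod-b": "line 1\nline 2\nline 3", "pod-a": "only line"}
print(format_pod_logs(sample, max_lines=2))
```

In this notebook the call would look like `print(format_pod_logs(logs[0]))`; trimming to the last few lines keeps long download progress output from flooding the cell.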

In [6]:
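import time

print("waiting.")
# Poll until the job succeeds; stop early if it fails so the loop cannot hang.
while not client.is_job_succeeded(name="pytorch-ddp"):
    if client.is_job_failed(name="pytorch-ddp"):
        raise RuntimeError("PyTorchJob failed; check the job logs.")
    print(".", end="")
    time.sleep(1)
print("\nPytorchJob succeeded!")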

waiting.
................................................................................................................................................................................................................................................................................................................................................................................................
PytorchJob succeeded!
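
A polling loop like the one above runs for as long as the job does, so a timeout and a longer interval can be useful. The helper below is a generic sketch of bounded polling; `wait_until` and its parameters are hypothetical and not part of the Kubeflow SDK.

```python
import time


def wait_until(predicate, timeout_s: float = 600.0, interval_s: float = 5.0,
               sleep=time.sleep, clock=time.monotonic) -> bool:
    """Call `predicate` repeatedly until it returns True or the timeout expires.

    Returns True if the predicate succeeded and False on timeout. `sleep` and
    `clock` are injectable so the helper can be exercised without real delays.
    """
    deadline = clock() + timeout_s
    while True:
        if predicate():
            return True
        if clock() >= deadline:
            return False
        sleep(interval_s)
```

With the client from this notebook, a call could look like `wait_until(lambda: client.is_job_succeeded(name="pytorch-ddp"))`, followed by a check of the return value to distinguish success from a timeout.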


### Clean up created resources

In [8]:
client.delete_job(name="pytorch-ddp", namespace="abdhumal-test")
print("PytorchJob deleted gracefully!")

PytorchJob deleted gracefully!
