# PyTorch Fashion MNIST Training Example with the Local Process Backend

This example demonstrates training on your local machine using **native Python processes** (no containers required).

## Prerequisites

- Python 3.8+ installed
- No Docker or Podman required! 

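The notebook demonstrates how to train a convolutional neural network (CNN) to classify images from the [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset with PyTorch. It selects the `torch-distributed` runtime, which targets PyTorch distributed training such as [Distributed Data Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html). Because the Local Process Backend runs the training function in a single process, the example itself does not use DDP.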

## Install Kubeflow Trainer

In [None]:
%pip install kubeflow-trainer

## Define Training Function

This function trains a simple CNN on the Fashion MNIST dataset using PyTorch.

**Note:** LocalProcessBackend runs in a **single process** (no distributed training), so we don't use `torch.distributed`.
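If you later want the same training function to run under a multi-process launcher as well, one option is to branch on the environment variables that `torchrun` exports (`WORLD_SIZE`, `RANK`, `LOCAL_RANK`) and fall back to single-process defaults when they are absent. The sketch below assumes exactly that convention; `get_dist_info` and `DistInfo` are hypothetical helpers, not part of Kubeflow Trainer or PyTorch.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class DistInfo:
    world_size: int
    rank: int
    local_rank: int

    @property
    def is_distributed(self) -> bool:
        return self.world_size > 1


def get_dist_info(env: Optional[Mapping[str, str]] = None) -> DistInfo:
    """Read the variables torchrun exports; default to a single process (hypothetical helper)."""
    env = os.environ if env is None else env
    return DistInfo(
        world_size=int(env.get("WORLD_SIZE", "1")),
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
    )


info = get_dist_info()
if info.is_distributed:
    # Only in this branch would you call torch.distributed.init_process_group(...)
    # and wrap the model in DistributedDataParallel.
    print(f"rank {info.rank} of {info.world_size}")
else:
    print("single process: skipping torch.distributed setup")
```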

In [None]:
def train_fashion_mnist():
    """Train a CNN on Fashion MNIST using PyTorch (single process)."""
    import torch
    import torch.nn.functional as F
    from torch import nn, optim
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    print("Starting training...")

    # Simple CNN model
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5, 1)
            self.conv2 = nn.Conv2d(20, 50, 5, 1)
            self.fc1 = nn.Linear(4 * 4 * 50, 500)
            self.fc2 = nn.Linear(500, 10)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(x, 2, 2)
            x = F.relu(self.conv2(x))
            x = F.max_pool2d(x, 2, 2)
            x = x.view(-1, 4 * 4 * 50)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=1)

    # Create model
    model = Net()
    
    # Load Fashion MNIST dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    
    dataset = datasets.FashionMNIST(
        './data',
        train=True,
        download=True,
        transform=transform
    )
    
    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

    # Train for 2 epochs
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    
    for epoch in range(1, 3):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

    # Save model
    torch.save(model.state_dict(), "fashion_mnist_cnn.pt")
    print("Model saved!")
    
    print("Training complete!")
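
The `4 * 4 * 50` input size of `fc1` follows from the layer arithmetic: Fashion MNIST images are 28x28, each 5x5 convolution with stride 1 and no padding shrinks the side length by 4, and each 2x2 max-pool with stride 2 halves it. The plain-Python snippet below (no PyTorch needed) recomputes that number with the standard output-size formula `floor((n + 2*padding - kernel) / stride) + 1`; `out_size` is an illustrative helper, not a library function.

```python
def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d or MaxPool2d layer (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


size = 28                           # Fashion MNIST images are 28x28
size = out_size(size, kernel=5)     # conv1: 5x5, stride 1 -> 24
size = out_size(size, 2, stride=2)  # max_pool2d(x, 2, 2)  -> 12
size = out_size(size, kernel=5)     # conv2: 5x5, stride 1 -> 8
size = out_size(size, 2, stride=2)  # max_pool2d(x, 2, 2)  -> 4

conv2_channels = 50
fc1_in_features = size * size * conv2_channels
print(size, fc1_in_features)  # 4 800
```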


## Initialize TrainerClient

Create a client that uses the Local Process Backend by passing it a `LocalProcessBackendConfig`:

In [22]:
from kubeflow.trainer import CustomTrainer, TrainerClient, LocalProcessBackendConfig

# Create backend config
backend_config = LocalProcessBackendConfig(
    cleanup_venv=True  # Automatically clean up virtual environments after jobs complete
)

# Initialize client
client = TrainerClient(backend_config=backend_config)


## List the Training Runtimes

List the available Training Runtimes and select the one your TrainJob will use; this example uses `torch-distributed`:

In [23]:
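torch_runtime = None
for runtime in client.list_runtimes():
    print(runtime)
    if runtime.name == "torch-distributed":
        torch_runtime = runtime

if torch_runtime is None:
    raise RuntimeError("The 'torch-distributed' runtime was not found")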

Runtime(name='torch-distributed', trainer=RuntimeTrainer(trainer_type=<TrainerType.CUSTOM_TRAINER: 'CustomTrainer'>, framework='torch', num_nodes=1, device='Unknown', device_count='Unknown'), pretrained_model=None)


## Start Training Job

Launch the training job as a local subprocess. This will:
- Create a temporary virtual environment
- Install required packages (torch, torchvision)
- Execute your training function in a subprocess
- Clean up the venv automatically when done (because `cleanup_venv=True` was set in the backend config)

In [24]:
job_name = client.train(
    trainer=CustomTrainer(
        func=train_fashion_mnist,
        packages_to_install=["torch", "torchvision"],  # Required packages
    ),
    runtime=torch_runtime,
)
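
For intuition, the steps listed before the `client.train(...)` cell can be approximated with the standard library alone. This is a rough, hypothetical sketch: `run_in_temp_venv` is not part of Kubeflow Trainer, and the real LocalProcessBackend may work differently. It creates a throwaway virtual environment, optionally installs packages with pip, runs a script with the venv's interpreter in a subprocess, and deletes the directory afterwards.

```python
import os
import shutil
import subprocess
import tempfile
import venv
from pathlib import Path
from typing import Sequence


def run_in_temp_venv(script: str, packages: Sequence[str]) -> str:
    """Run `script` in a fresh virtual environment and return its stdout (hypothetical helper)."""
    env_dir = Path(tempfile.mkdtemp(prefix="train-venv-"))
    try:
        # 1. Create the virtual environment (bootstrap pip only if we need it).
        venv.create(env_dir, with_pip=bool(packages), symlinks=(os.name != "nt"))
        bin_dir = env_dir / ("Scripts" if os.name == "nt" else "bin")
        python = bin_dir / ("python.exe" if os.name == "nt" else "python")

        # 2. Install the requested packages into the venv.
        if packages:
            subprocess.run([str(python), "-m", "pip", "install", *packages], check=True)

        # 3. Execute the code in a separate process using the venv's interpreter.
        result = subprocess.run(
            [str(python), "-c", script], check=True, capture_output=True, text=True
        )
        return result.stdout
    finally:
        # 4. Remove the environment whether or not the run succeeded.
        shutil.rmtree(env_dir, ignore_errors=True)


# Inside a venv, sys.prefix differs from sys.base_prefix.
print(run_in_temp_venv("import sys; print(sys.prefix != sys.base_prefix)", packages=[]))
```

Installing `torch` and `torchvision` this way downloads large wheels, which is why a backend that caches or reuses environments can save noticeable time.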

## Monitor Job Status

Check the status of your training job:

In [None]:
# Get job status
job = client.get_job(job_name)

print("\n📊 Job Status:")
print(f"   Name: {job.name}")
print(f"   Status: {job.status}")
print(f"   Created: {job.creation_timestamp}")
print(f"   Steps:")
for step in job.steps:
    print(f"     • {step.name}: {step.status}")


📊 Job Status:
   Name: zf5194b02611
   Status: Running
   Created: 2025-10-21 09:57:03.917775
   Steps:
     • train: Running


## Stream Training Logs

Watch the training progress in real-time:

In [None]:
print("Streaming logs (Ctrl+C to stop):\n")
print("="*80)

try:
    for log_line in client.get_job_logs(job_name, follow=True):
        print(log_line, end='')
except KeyboardInterrupt:
    print("\n\nLog streaming stopped by user")

Streaming logs (Ctrl+C to stop):

Operating inside /var/folders/r3/kwn1z7n15nq3rh54ykdsy73r0000gn/T/zf5194b02611womr11v5
Looking in links: /tmp/tmpdcvyyznz
Processing /tmp/tmpdcvyyznz/pip-24.2-py3-none-any.whl
Installing collected packages: pip
Successfully installed pip-24.2
Collecting torch
  Using cached torch-2.9.0-cp313-none-macosx_11_0_arm64.whl.metadata (30 kB)
Collecting torchvision
  Using cached torchvision-0.24.0-cp313-cp313-macosx_12_0_arm64.whl.metadata (5.9 kB)
Collecting filelock (from torch)
  Using cached filelock-3.20.0-py3-none-any.whl.metadata (2.1 kB)
Collecting typing-extensions>=4.10.0 (from torch)
  Using cached typing_extensions-4.15.0-py3-none-any.whl.metadata (3.3 kB)
Collecting setuptools (from torch)
  Using cached setuptools-80.9.0-py3-none-any.whl.metadata (6.6 kB)
Collecting sympy>=1.13.3 (from torch)
  Using cached sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx>=2.5.1 (from torch)
  Using cached networkx-3.5-py3-none-any.whl.metadata
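
The training function prints a progress line every 100 batches in the form `Epoch {epoch}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss:.4f}`, interleaved with pip and other output. If you want to collect the loss values rather than just print them, a small regex filter over the log lines is enough. `parse_progress` and `Progress` are hypothetical helpers, not part of the SDK, and the sketch assumes each item yielded by `get_job_logs` is one log line.

```python
import re
from typing import Iterable, Iterator, NamedTuple

# Matches the progress lines printed by train_fashion_mnist().
# A non-numeric loss such as "nan" would not match and is skipped.
PROGRESS_RE = re.compile(r"Epoch (\d+), Batch (\d+)/(\d+), Loss: ([0-9.]+)")


class Progress(NamedTuple):
    epoch: int
    batch: int
    num_batches: int
    loss: float


def parse_progress(lines: Iterable[str]) -> Iterator[Progress]:
    """Yield a Progress record for every training progress line; skip everything else."""
    for line in lines:
        match = PROGRESS_RE.search(line)
        if match:
            epoch, batch, total, loss = match.groups()
            yield Progress(int(epoch), int(batch), int(total), float(loss))
```

For example, once the job has finished, `losses = [p.loss for p in parse_progress(client.get_job_logs(job_name))]` would gather the logged losses in order.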

## Wait for Completion

Wait for the training job to complete:

In [None]:

try:
    completed_job = client.wait_for_job_status(
        name=job_name,
        status={"Complete"},
        timeout=600,  # 10 minutes
        polling_interval=5  # Check every 5 seconds
    )
    
    print("\n✅ Training job completed successfully!")
except TimeoutError:
    print("\nJob did not complete within the timeout")
except RuntimeError as e:
    print(f"\nJob failed: {e}")


✅ Training job completed successfully!
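
The `timeout` and `polling_interval` arguments describe a poll-until-deadline loop. The generic sketch below illustrates that pattern under the assumption that the SDK behaves similarly; it is not the SDK's implementation, and `wait_until` and its `check` callable are hypothetical. In this notebook, `check` could be something like `lambda: client.get_job(job_name).status == "Complete"`.

```python
import time
from typing import Callable


def wait_until(check: Callable[[], bool], timeout: float, polling_interval: float) -> None:
    """Call `check` every `polling_interval` seconds until it returns True.

    Raises TimeoutError if `timeout` seconds elapse first.
    """
    deadline = time.monotonic() + timeout
    while True:
        if check():
            return
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(f"condition not met within {timeout} seconds")
        # Never sleep past the deadline.
        time.sleep(min(polling_interval, remaining))
```

Using `time.monotonic()` rather than `time.time()` keeps the deadline correct even if the system clock is adjusted while waiting.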


## Clean Up

Delete the training job to free up resources:

In [None]:
# Delete the job (kills subprocess and cleans up venv)
client.delete_job(job_name)