### https://docs.ray.io/en/latest/cluster/vms/user-guides/community/spark.html


Config:
- 2 Workers: 448 GB Memory, 24 Cores
- 1 Driver: 224 GB Memory, 12 Cores
Runtime:
- 15.4.x-gpu-ml-scala2.12
Type:
- Standard_NC12s_v3


Notes: 

- We recommend setting the argument num_cpus_worker_node to the number of CPU cores per Apache Spark worker node. Similarly, setting num_gpus_worker_node to the number of GPUs per Apache Spark worker node is optimal. With this configuration, each Apache Spark worker node launches one Ray worker node that will fully utilize the resources of each Apache Spark worker node.
- Set the environment variable RAY_memory_monitor_refresh_ms to 0 within the Databricks cluster configuration when starting your Apache Spark cluster.


- On each Spark worker node, we recommend keeping 'spark_executor_memory + num_Ray_worker_nodes_per_spark_worker * (memory_worker_node + object_store_memory_worker_node)' below 'spark_worker_physical_memory * 0.8'; otherwise the Spark worker may exhaust its physical memory and Ray tasks may fail with OOM errors.

In [0]:
# Spark was configured with 'spark.task.resource.gpu.amount' set to 1.0. We
# recommend setting this value to 0 so that Spark jobs do not reserve GPU
# resources; otherwise Ray-on-Spark workloads cannot have the maximum number
# of GPUs available.
# NOTE(review): this may need to be set in the cluster's Spark config at
# startup rather than at runtime to take effect — confirm.

spark.conf.set("spark.task.resource.gpu.amount", "0")

In [0]:
import ray
from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster

# Start the Ray-on-Spark cluster: at most one Ray worker node with 12 CPUs and
# 2 GPUs, plus a head node with the same resources. Ray logs are collected
# under the given DBFS path.
# `num_cpus_worker_node` / `num_gpus_worker_node` are the current argument
# names; they replace the deprecated `num_cpus_per_node` / `num_gpus_per_node`.
setup_ray_cluster(
    max_worker_nodes=1,
    num_cpus_worker_node=12,
    num_gpus_worker_node=2,
    num_cpus_head_node=12,
    num_gpus_head_node=2,
    collect_log_to_path="/dbfs/tmp/ws_ray_collected_logs",
)

# Pass any custom Ray configuration with ray.init.
# ignore_reinit_error=True keeps re-running this cell from raising when Ray
# is already initialized.
ray.init(ignore_reinit_error=True)

In [0]:
import os
import tempfile

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray.train.torch

def train_func():
    """Training loop run by each Ray Train worker.

    Builds a ResNet-18 whose first convolution takes a single input channel,
    trains it on the FashionMNIST training split for 10 epochs, and at the
    end of every epoch reports the last batch's loss together with a
    checkpoint directory containing the model weights as ``model.pt``.
    """
    # Model, loss, optimizer.
    net = resnet18(num_classes=10)
    net.conv1 = torch.nn.Conv2d(
        in_channels=1,
        out_channels=64,
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )
    # [1] Prepare model. Device placement is handled by `prepare_model`,
    # so there is no explicit `.to("cuda")` here.
    net = ray.train.torch.prepare_model(net)
    loss_fn = CrossEntropyLoss()
    opt = Adam(net.parameters(), lr=1e-3)

    # Data.
    preprocessing = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    dataset_root = os.path.join(tempfile.gettempdir(), "data")
    dataset = FashionMNIST(
        root=dataset_root, train=True, download=True, transform=preprocessing
    )
    # [2] Prepare dataloader. Batches are moved to the device by
    # `prepare_data_loader`, so the loop below does not do it by hand.
    loader = ray.train.torch.prepare_data_loader(
        DataLoader(dataset, batch_size=128, shuffle=True)
    )

    # Wrapper classes whose underlying network lives in `.module`.
    wrapper_types = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    context = ray.train.get_context()

    # Training.
    num_epochs = 10
    for epoch in range(num_epochs):
        # With more than one worker, tell the sampler which epoch this is.
        if context.get_world_size() > 1:
            loader.sampler.set_epoch(epoch)

        for inputs, targets in loader:
            predictions = net(inputs)
            loss = loss_fn(predictions, targets)
            opt.zero_grad()
            loss.backward()
            opt.step()

        # [3] Report metrics and checkpoint.
        metrics = {"loss": loss.item(), "epoch": epoch}
        with tempfile.TemporaryDirectory() as ckpt_dir:
            # Save the underlying network's weights, unwrapping it first if
            # it is wrapped in DataParallel / DistributedDataParallel.
            bare_net = net.module if isinstance(net, wrapper_types) else net
            torch.save(bare_net.state_dict(), os.path.join(ckpt_dir, "model.pt"))

            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(ckpt_dir),
            )
        # Only rank 0 prints, to avoid duplicate output across workers.
        if context.get_world_rank() == 0:
            print(metrics)

In [0]:

from ray.train import RunConfig

# [4] Configure storage, scaling and resource requirements.
# Storage path and run name for this run (/some/path/unique_run_name).
run_config = RunConfig(name="local", storage_path="/dbfs/tmp/ray_ws_logs")

# A single training worker, with GPU enabled so CUDA can be used.
scaling_config = ray.train.ScalingConfig(use_gpu=True, num_workers=1)

# [5] Launch distributed training job.
# [5a] If running in a multi-node cluster, `run_config` is where you should
# configure persistent storage that is accessible across all worker nodes.
trainer = ray.train.torch.TorchTrainer(
    train_loop_per_worker=train_func,
    run_config=run_config,
    scaling_config=scaling_config,
)

result = trainer.fit()

In [0]:
# Show each part of the training result as its own notebook output.
for outcome in (
    result.metrics,     # The metrics reported during training.
    result.checkpoint,  # The latest checkpoint reported during training.
    result.path,        # The path where logs are stored.
    result.error,       # The exception that was raised, if training failed.
):
    display(outcome)

In [0]:
# [6] Load the trained model.
# Rebuild the architecture used in training (ResNet-18 with a single-channel
# first convolution) and restore the weights stored in the checkpoint's
# "model.pt" file.
with result.checkpoint.as_directory() as checkpoint_dir:
    model_state_dict = torch.load(
        os.path.join(checkpoint_dir, "model.pt"),
        # Training ran with use_gpu=True, so the saved tensors may reference a
        # CUDA device; mapping to CPU lets this cell run without a GPU. The
        # model built below lives on the CPU anyway.
        map_location="cpu",
        # The file only holds a state_dict, so restrict unpickling to tensors
        # and plain containers.
        weights_only=True,
    )
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.load_state_dict(model_state_dict)

In [0]:
# Tear down the Ray-on-Spark cluster started earlier in this notebook.
shutdown_ray_cluster()