### Reference (Ray documentation): https://docs.ray.io/en/latest/cluster/vms/user-guides/community/spark.html


Config:

- 2 Workers (in total): 448 GB Memory, 24 Cores
  - 2 GPUs per Worker: V100
- 1 Driver: 224 GB Memory, 12 Cores

Runtime:

- 15.4.x-gpu-ml-scala2.12

Type:

- Standard_NC12s_v3

- Cost: 30 DBUs/h + Provider Cost


Notes: 

- We recommend setting the argument num_cpus_worker_node to the number of CPU cores per Apache Spark worker node. Similarly, setting num_gpus_worker_node to the number of GPUs per Apache Spark worker node is optimal. With this configuration, each Apache Spark worker node launches one Ray worker node that will fully utilize the resources of each Apache Spark worker node.
- Set the environment variable RAY_memory_monitor_refresh_ms to 0 within the Databricks cluster configuration when starting your Apache Spark cluster.


- On each Spark worker node, we recommend keeping the sum `spark_executor_memory + num_Ray_worker_nodes_per_spark_worker * (memory_worker_node + object_store_memory_worker_node)` below `spark_worker_physical_memory * 0.8`; otherwise the Spark worker may exhaust its physical memory and Ray tasks may fail with OOM errors.

In [0]:
%pip install gym torch numpy matplotlib

In [0]:
# Restart the Python process; presumably needed so the packages installed by
# the preceding %pip cell become importable — confirm against Databricks docs.
dbutils.library.restartPython()

In [0]:
# Verbose flag read by train_func: when truthy, the result of an environment
# step is printed on every iteration of the training loop.
DEBUG = True

We recommend setting the argument num_cpus_worker_node to the number of CPU cores per Apache Spark worker node. Similarly, setting num_gpus_worker_node to the number of GPUs per Apache Spark worker node is optimal. With this configuration, each Apache Spark worker node launches one Ray worker node that will fully utilize the resources of each Apache Spark worker node.

Set the environment variable RAY_memory_monitor_refresh_ms to 0 within the Databricks cluster configuration when starting your Apache Spark cluster.

In [0]:
# You configured 'spark.task.resource.gpu.amount' to 1.0, we recommend setting this value to 0 so that Spark jobs do not reserve GPU resources, preventing Ray-on-Spark workloads from having the maximum number of GPUs available. In each spark worker node, we recommend making the sum of 'spark_executor_memory + num_Ray_worker_nodes_per_spark_worker * (memory_worker_node + object_store_memory_worker_node)' to be less than 'spark_worker_physical_memory * 0.8', otherwise it might lead to spark worker physical memory exhaustion and Ray task OOM errors.

In [0]:
import os 

# Stop Spark tasks from reserving GPUs so that Ray-on-Spark can claim them.
spark.conf.set("spark.task.resource.gpu.amount", "0")

# Refresh interval of 0 for the Ray memory monitor, as recommended in the notes above.
os.environ.update({"RAY_memory_monitor_refresh_ms": "0"})

In [0]:
from mlflow.utils.databricks_utils import get_databricks_env_vars
import mlflow 

# Databricks connection variables for the "databricks" profile.
mlflow_db_creds = get_databricks_env_vars("databricks")

# The experiment lives under the user's workspace folder.
username = "will.smith@databricks.com"
experiment_name = "/Users/{}/ray_dqn".format(username)
mlflow.set_experiment(experiment_name)

In [0]:
import ray
from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster
import os 

# Export the Databricks host and token before calling setup_ray_cluster.
# NOTE(review): presumably these let Ray worker processes authenticate against
# the workspace (e.g. for MLflow) — confirm.
os.environ["DATABRICKS_HOST"] = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
os.environ["DATABRICKS_TOKEN"] = dbutils.secrets.get(scope="william_smith_secrets", key="WS_PAT")

# One Ray worker node sized to a full Spark worker (12 cores, 2 GPUs), plus a
# head node with the same resources. The worker-node keywords follow the
# `num_cpus_worker_node` / `num_gpus_worker_node` naming recommended in the
# notes at the top of this notebook instead of the legacy `*_per_node` names.
setup_ray_cluster(
  max_worker_nodes=1,
  num_cpus_worker_node=12,
  num_gpus_worker_node=2,
  num_cpus_head_node=12,
  num_gpus_head_node=2,
  collect_log_to_path="/dbfs/tmp/ws_ray_collected_logs"
)

# Pass any custom Ray configuration with ray.init
ray.init(ignore_reinit_error=True)

In [0]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque


# Create the CartPole environment.
# This module-level instance is read directly by train_func (for the
# observation/action space sizes, reset, step and random action sampling).
env = gym.make("CartPole-v1")

# Neural network model for approximating Q-values
class DQN(nn.Module):
    """Fully connected Q-network: input -> 128 -> 128 -> one value per action.

    Args:
        input_dim: Size of the observation vector fed to the first layer.
        output_dim: Number of outputs produced by the last layer.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_units = 128
        self.fc1 = nn.Linear(input_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, output_dim)

    def forward(self, x):
        """Apply two ReLU-activated hidden layers, then the linear output layer."""
        first_hidden = torch.relu(self.fc1(x))
        second_hidden = torch.relu(self.fc2(first_hidden))
        return self.fc3(second_hidden)
    
def train_func():
    """Train a DQN agent with experience replay and a target network.

    Reads the module-level ``env`` (environment), ``DQN`` (network class) and
    ``DEBUG`` (verbose flag). Takes no arguments so it can be handed to a
    trainer as-is.

    Returns:
        list: Total reward collected in each episode, in episode order.
    """

    # Function to choose action using epsilon-greedy policy
    def select_action(state, epsilon):
        if random.random() < epsilon:
            return env.action_space.sample()  # Explore
        state = torch.FloatTensor(state).unsqueeze(0)
        # Action selection never needs gradients.
        with torch.no_grad():
            q_values = policy_net(state)
        return torch.argmax(q_values).item()  # Exploit

    # Function to optimize the model using experience replay
    def optimize_model():
        # Wait until the replay buffer can supply a full batch.
        if len(memory) < batch_size:
            return

        batch = random.sample(memory, batch_size)
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*batch)

        # Stack the per-step observations into one array before building the tensor.
        state_batch = torch.FloatTensor(np.array(state_batch))
        action_batch = torch.LongTensor(action_batch).unsqueeze(1)
        reward_batch = torch.FloatTensor(reward_batch)
        next_state_batch = torch.FloatTensor(np.array(next_state_batch))
        done_batch = torch.tensor(done_batch, dtype=torch.float32)

        # Compute Q-values for current states; squeeze only the action axis so
        # the result stays 1-D regardless of batch size.
        q_values = policy_net(state_batch).gather(1, action_batch).squeeze(1)

        # Compute target Q-values using the target network
        with torch.no_grad():
            max_next_q_values = target_net(next_state_batch).max(1)[0]
            # (1 - done) removes the bootstrap term for terminal transitions.
            target_q_values = reward_batch + gamma * max_next_q_values * (1 - done_batch)

        loss = nn.MSELoss()(q_values, target_q_values)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Hyperparameters
    learning_rate = 0.001
    gamma = 0.99
    epsilon = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.995
    batch_size = 64
    target_update_freq = 1000
    memory_size = 10000
    episodes = 1000

    # Initialize Q-networks
    input_dim = env.observation_space.shape[0]
    output_dim = env.action_space.n
    policy_net = DQN(input_dim, output_dim)
    target_net = DQN(input_dim, output_dim)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)
    memory = deque(maxlen=memory_size)

    # Main training loop
    rewards_per_episode = []
    steps_done = 0

    for episode in range(episodes):
        # With the 5-tuple step API used below, reset() returns (observation, info).
        state, _reset_info = env.reset()
        episode_reward = 0
        # Reset per episode; a flag set once outside this loop would stay True
        # after the first episode and skip every later one.
        done = False

        while not done:
            # Select action
            action = select_action(state, epsilon)

            # Step exactly once; the debug print reuses this result instead of
            # advancing the environment a second time.
            next_state, reward, terminated, truncated, info = env.step(action)
            if DEBUG:
                print((next_state, reward, terminated, truncated, info))

            # Store transition in memory. Only true termination is recorded as
            # "done" so that time-limit truncation still bootstraps.
            memory.append((state, action, reward, next_state, terminated))

            # Update state
            state = next_state
            episode_reward += reward

            # Optimize model
            optimize_model()

            # Update target network periodically
            if steps_done % target_update_freq == 0:
                target_net.load_state_dict(policy_net.state_dict())

            steps_done += 1

            # The episode ends on either termination or truncation.
            done = terminated or truncated

        # Decay epsilon
        epsilon = max(epsilon_min, epsilon_decay * epsilon)

        rewards_per_episode.append(episode_reward)

    return rewards_per_episode

In [0]:

from ray.train import RunConfig
from ray.train.torch import TorchTrainer

# [4] Scaling and resource requirements: two workers, each with a GPU so CUDA is usable.
scaling_config = ray.train.ScalingConfig(use_gpu=True, num_workers=2)

# Run output location (/some/local/path/unique_run_name).
run_config = RunConfig(name="local", storage_path="/dbfs/tmp/ray_ws_logs")

# [5] Build the distributed training job.
# [5a] On a multi-node cluster, run_config is where the run's persistent
# storage (reachable from every worker node) is configured.
trainer = TorchTrainer(
    train_func,
    run_config=run_config,
    scaling_config=scaling_config,
)

In [0]:
import mlflow 

# Run training exactly once. Previously any failure inside the try block
# (including an MLflow error raised after training had finished) fell into a
# bare `except:` that started the whole training job a second time.
results = trainer.fit()

# Best-effort MLflow logging: a tracking failure is reported but does not
# discard the training result.
try:
  with mlflow.start_run() as run:
    # `results` is a single result object, not an iterable of numbers, so log
    # the entries of its metrics dict, and do so while the run is still active.
    for metric_name, metric_value in (results.metrics or {}).items():
      # log_metric only accepts numbers; skip nested or non-numeric entries.
      if isinstance(metric_value, (int, float)) and not isinstance(metric_value, bool):
        mlflow.log_metric(metric_name, float(metric_value))
except Exception as exc:
  print(f"MLflow logging failed: {exc!r}")

In [0]:
# Show, in order: the metrics reported during training, the latest reported
# checkpoint, the path where logs are stored, and the exception raised if
# training failed.
for result_field in ("metrics", "checkpoint", "path", "error"):
    display(getattr(results, result_field))

In [0]:
checkpoint = results.checkpoint

if checkpoint is None:
    raise RuntimeError("No checkpoint found")

with checkpoint.as_directory() as checkpoint_dir:
    # Change as needed for different DL frameworks
    checkpoint_path = f"{checkpoint_dir}/checkpoint.ckpt"

    # DQN is a plain torch.nn.Module and defines no `load_from_checkpoint`,
    # so rebuild the network and restore its weights from the saved file.
    # NOTE(review): assumes the file holds a state_dict, either directly or
    # under a "state_dict" key — confirm against whatever writes the checkpoint.
    payload = torch.load(checkpoint_path, map_location="cpu")
    state_dict = payload.get("state_dict", payload)

    model = DQN(env.observation_space.shape[0], env.action_space.n)
    model.load_state_dict(state_dict)
    model.eval()

with mlflow.start_run() as run:
    # Change the MLflow flavor as needed
    mlflow.pytorch.log_model(model, "model")

#### When finished, be sure to shut down the Ray cluster

In [0]:
# Tear down the Ray-on-Spark cluster started earlier in this notebook.
shutdown_ray_cluster()