# Distributed training

<div align="left">
<a target="_blank" href="https://console.anyscale.com/"><img src="https://img.shields.io/badge/🚀 Run_on-Anyscale-9hf"></a>&nbsp;
<a href="https://github.com/anyscale/foundational-ray-app" role="button"><img src="https://img.shields.io/static/v1?label=&amp;message=View%20On%20GitHub&amp;color=586069&amp;logo=github&amp;labelColor=2f363d"></a>&nbsp;
</div>

In this tutorial, we'll execute a distributed training workload that connects the following heterogeneous workloads:
- preprocess the dataset prior to training
- distributed training with Ray Train and PyTorch (with observability)
- evaluation (batch inference + eval logic)
- save model artifacts to a model registry (MLOps)

**Note**: we won't be tuning our model in this tutorial but be sure to check out [Ray Tune](https://docs.ray.io/en/latest/tune/index.html) for experiment execution and hyperparameter tuning at any scale.

<img src="../images/distributed_training.png" width=800>

In [None]:
%load_ext autoreload
%autoreload all

In [None]:
import os
import ray
import sys
sys.path.append(os.path.abspath(".."))

In [None]:
# Enable Ray Train v2 (it's too good to wait for public release!)
ray.init(runtime_env={"env_vars": {"RAY_TRAIN_V2_ENABLED": "1"}})

2025-04-02 07:27:45,966	INFO worker.py:1660 -- Connecting to existing Ray cluster at address: 10.0.20.47:6379...
2025-04-02 07:27:45,976	INFO worker.py:1843 -- Connected to Ray cluster. View the dashboard at https://session-fyhrc759flh928h7czptpn79mb.i.anyscaleuserdata.com
2025-04-02 07:27:45,989	INFO packaging.py:367 -- Pushing file package 'gcs://_ray_pkg_0b08aeed72810f81292cb115f1a529e07b7fdfd5.zip' (1.82MiB) to Ray cluster...
2025-04-02 07:27:45,996	INFO packaging.py:380 -- Successfully pushed file package 'gcs://_ray_pkg_0b08aeed72810f81292cb115f1a529e07b7fdfd5.zip'.


| | |
|---|---|
| Python version: | 3.12.9 |
| Ray version: | 2.44.1 |
| Dashboard: | http://session-fyhrc759flh928h7czptpn79mb.i.anyscaleuserdata.com |


In [None]:
%%bash
# This will be removed once Ray Train v2 is part of latest Ray version
echo "RAY_TRAIN_V2_ENABLED=1" > /home/ray/default/.env

In [None]:
# Load env vars in notebooks
from dotenv import load_dotenv
load_dotenv()

True

### Preprocess

We need to convert our classes to labels (unique integers) so that we can train a classifier that can correctly predict the class given an input image. But before we do this, we'll quickly apply the same data ingestion and preprocessing as the previous notebook.

In [None]:
def add_class(row):
    row["class"] = row["path"].rsplit("/", 3)[-2]
    return row
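
To make the path parsing concrete, here is a small self-contained sketch of what `add_class` extracts. It assumes the dataset follows a `<bucket>/<split>/<class>/<file>` layout; the function is re-defined here only so the snippet runs on its own.

```python
def add_class(row):
    # The class name is the parent directory of the image file.
    row["class"] = row["path"].rsplit("/", 3)[-2]
    return row

# Sample path in the same layout as the dataset used in this notebook.
row = add_class({"path": "doggos-dataset/train/malamute/malamute_12269.jpg"})
print(row["class"])  # malamute
```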

In [None]:
# Preprocess data splits
train_ds = ray.data.read_images("s3://doggos-dataset/train", include_paths=True, shuffle="files")
train_ds = train_ds.map(add_class)
val_ds = ray.data.read_images("s3://doggos-dataset/val", include_paths=True)
val_ds = val_ds.map(add_class)

We'll define a `Preprocessor` class that will:
- create an embedding for each image. Since we won't update the embedding model's weights, computing the embeddings once up front avoids repeating that work in every forward pass (unnecessary compute)
- convert our classes into labels for the classifier.

While we could've just done this as a simple operation, we're taking the time to organize it as a class so that we can save and load for inference later.

In [None]:
import json  # used by Preprocessor.save
import numpy as np
from PIL import Image
import torch
from transformers import CLIPModel, CLIPProcessor

In [None]:
class EmbeddingGenerator(object):
    def __init__(self, model_id):
        # Load CLIP model and processor
        self.model = CLIPModel.from_pretrained(model_id)
        self.processor = CLIPProcessor.from_pretrained(model_id)

    def __call__(self, batch, device="cpu"):
        # Load and preprocess images
        images = [Image.fromarray(np.uint8(img)).convert("RGB") for img in batch["image"]]
        inputs = self.processor(images=images, return_tensors="pt", padding=True).to(device)

        # Generate embeddings
        self.model.to(device)
        with torch.inference_mode():
            batch["embedding"] = self.model.get_image_features(**inputs).cpu().numpy()

        return batch

In [None]:
class Preprocessor:
    """Preprocessor class."""
    def __init__(self, class_to_label=None):
        self.class_to_label = class_to_label or {}  # avoid a shared mutable default
        self.label_to_class = {v: k for k, v in self.class_to_label.items()}
        
    def fit(self, ds, column):
        self.classes = ds.unique(column=column)
        self.class_to_label = {tag: i for i, tag in enumerate(self.classes)}
        self.label_to_class = {v: k for k, v in self.class_to_label.items()}
        return self

    def convert_to_label(self, row, class_to_label):
        if "class" in row:
            row["label"] = class_to_label[row["class"]]
        return row
    
    def transform(self, ds, concurrency=4, batch_size=64, num_gpus=1):
        ds = ds.map(
            self.convert_to_label, 
            fn_kwargs={"class_to_label": self.class_to_label},
        )
        ds = ds.map_batches(
            EmbeddingGenerator,
            fn_constructor_kwargs={"model_id": "openai/clip-vit-base-patch32"},
            fn_kwargs={"device": "cuda"}, 
            concurrency=concurrency, 
            batch_size=batch_size,
            num_gpus=num_gpus,
        )
        ds = ds.drop_columns(["image"])
        return ds

    def save(self, fp):
        with open(fp, "w") as f:
            json.dump(self.class_to_label, f)
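
As a rough illustration of why only `class_to_label` is written to disk, the sketch below builds a mapping the same way `fit` does, round-trips it through JSON, and rebuilds the inverse mapping after loading. The class names are hypothetical stand-ins for what `ds.unique` would return, and the temporary file path is just for the example.

```python
import json
import os
import tempfile

# Hypothetical class names; in the notebook these come from ds.unique(column="class").
classes = ["basset", "malamute", "yorkshire_terrier"]
class_to_label = {name: i for i, name in enumerate(classes)}

# class_to_label has string keys, so JSON round-trips it exactly.
with tempfile.TemporaryDirectory() as dp:
    fp = os.path.join(dp, "class_to_label.json")
    with open(fp, "w") as f:
        json.dump(class_to_label, f)
    with open(fp) as f:
        restored = json.load(f)

# Rebuild the inverse mapping after loading (JSON would turn integer keys into strings).
label_to_class = {v: k for k, v in restored.items()}
print(label_to_class[1])  # malamute
```

Passing the restored dictionary to `Preprocessor(class_to_label=restored)` would recreate both mappings at inference time.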

In [None]:
# Preprocess
preprocessor = Preprocessor()
preprocessor = preprocessor.fit(train_ds, column="class")
train_ds = preprocessor.transform(ds=train_ds)
val_ds = preprocessor.transform(ds=val_ds)
train_ds.take(1)

2025-04-02 07:27:54,232	INFO dataset.py:2809 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
2025-04-02 07:27:54,243	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-02_01-36-15_100194_2258/logs/ray-data
2025-04-02 07:27:54,244	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)] -> AllToAllOperator[Aggregate] -> LimitOperator[limit=1]


(autoscaler +15s) Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.
(autoscaler +15s) [autoscaler] [4xT4:48CPU-192GB] Attempting to add 1 node(s) to the cluster (increasing from 0 to 1).
(autoscaler +15s) [autoscaler] [4xT4:48CPU-192GB] Launched 1 instances.

2025-04-02 07:29:26,022	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-02_01-36-15_100194_2258/logs/ray-data
2025-04-02 07:29:26,022	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)->Map(Preprocessor.convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbeddingGenerator)] -> TaskPoolMapOperator[MapBatches(drop_columns)] -> LimitOperator[limit=1]


[{'path': 'doggos-dataset/train/malamute/malamute_12269.jpg',
  'class': 'malamute',
  'label': 26,
  'embedding': array([-3.20934802e-02,  2.38295987e-01, -2.82098830e-01,  3.15648079e-01,
          4.98129278e-02, -3.21801215e-01,  1.72749728e-01,  2.98801422e-01,
         -1.46350980e-01,  3.36772591e-01, -2.07972020e-01,  8.03879276e-02,
          3.31074893e-01,  2.07380429e-01,  3.22806120e-01, -2.45789498e-01,
          4.45166469e-01,  7.23792464e-02, -2.30011567e-01, -3.13075364e-01,
         -4.44811493e-01,  7.88827240e-03,  2.61921108e-01, -5.02690732e-01,
         -7.88383186e-03, -1.80014670e-01,  3.15743566e-01, -4.99399513e-01,
          1.14164159e-01,  3.44607234e-01,  1.55417800e-01, -1.19305663e-01,
          1.82588041e-01, -8.12833011e-02,  1.02748406e+00,  1.45590588e-01,
          1.71805650e-01,  2.43361276e-02, -4.99892712e-01,  1.35278070e+00,
         -4.73846406e-01, -6.83804229e-02,  4.67592418e-01, -4.83125627e-01,
         -2.01296866e-01, -5.84190071e-0

<div class="alert alert-block alert"> <b> Data Processing</b> 

Be sure to check out this extensive guide on [data loading and preprocessing](https://docs.ray.io/en/latest/train/user-guides/data-loading-preprocessing.html) for the last-mile preprocessing we'll need to do prior to training our models. Ray Data also supports performant joins, filters, aggregations, etc. for the more structured data processing your workloads may need.

<div class="alert alert-block alert"> <b> Store often, Save compute</b> 

We'll now store our preprocessed data in shared cloud storage because we want to:
- save a record of what this preprocessed data looks like
- avoid re-running the entire preprocessing pipeline for each batch our model will process
- avoid having to [`materialize`](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.materialize.html) the preprocessed data (we shouldn't force large data to fit in memory)

In [None]:
from doggos.utils import delete_s3_objects

In [None]:
# Write processed data to cloud storage
preprocessed_data_path = os.path.join(
    os.getenv("ANYSCALE_ARTIFACT_STORAGE", ""), 
    os.getenv("ANYSCALE_USERNAME", "").replace(" ", "_"), 
    "doggos/preprocessed_data",
)
delete_s3_objects(s3_path=preprocessed_data_path)
preprocessed_train_path = os.path.join(preprocessed_data_path, "preprocessed_train")
preprocessed_val_path = os.path.join(preprocessed_data_path, "preprocessed_val")
train_ds.write_parquet(preprocessed_train_path)
val_ds.write_parquet(preprocessed_val_path)

Deleted 65 objects from s3://anyscale-test-data-cld-i2w99rzq8b6lbjkke9y94vi5/org_7c1Kalm9WcX2bNIjW53GUT/cld_kvedZWag2qA8i5BjxUevf5i7/artifact_storage/goku_mohandas/doggos/preprocessed_data


2025-04-02 07:29:49,790	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-02_01-36-15_100194_2258/logs/ray-data
2025-04-02 07:29:49,792	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)->Map(Preprocessor.convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbeddingGenerator)] -> TaskPoolMapOperator[MapBatches(drop_columns)->Write]


2025-04-02 07:30:09,377	INFO dataset.py:4178 -- Data sink Parquet finished. 2880 rows and 5.9MB data written.
2025-04-02 07:30:09,602	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-02_01-36-15_100194_2258/logs/ray-data
2025-04-02 07:30:09,603	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)->Map(Preprocessor.convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbeddingGenerator)] -> TaskPoolMapOperator[MapBatches(drop_columns)->Write]


2025-04-02 07:30:28,105	INFO dataset.py:4178 -- Data sink Parquet finished. 720 rows and 1.5MB data written.
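
The storage path used for `write_parquet` is assembled from two environment variables. The sketch below mirrors that logic in a small helper so the behavior is easy to inspect; `build_preprocessed_path` and the sample values are hypothetical, and the helper takes a plain dictionary in place of `os.environ`.

```python
import os

def build_preprocessed_path(env):
    # `env` stands in for os.environ so the example runs anywhere.
    return os.path.join(
        env.get("ANYSCALE_ARTIFACT_STORAGE", ""),
        env.get("ANYSCALE_USERNAME", "").replace(" ", "_"),
        "doggos/preprocessed_data",
    )

# Hypothetical values, for illustration only.
with_env = build_preprocessed_path({
    "ANYSCALE_ARTIFACT_STORAGE": "s3://my-bucket/artifact_storage",
    "ANYSCALE_USERNAME": "jane doe",
})
print(with_env)

# With nothing set, os.path.join drops the empty parts and returns a relative path.
without_env = build_preprocessed_path({})
print(without_env)  # doggos/preprocessed_data
```

One consequence worth keeping in mind: if neither variable is set, the path silently becomes a relative local path rather than a cloud storage URI, so it may be worth checking the variables before writing.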


### Model

Let's define our model: a simple two-layer neural net that outputs logits, with a softmax applied at prediction time to produce class probabilities. You'll notice that it's all just base PyTorch and nothing else.

In [None]:
import json
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class ClassificationModel(torch.nn.Module):
    def __init__(self, embedding_dim, hidden_dim, dropout_p, num_classes):
        super().__init__()
        # Hyperparameters
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.dropout_p = dropout_p
        self.num_classes = num_classes

        # Define layers
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.batch_norm = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_p)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, batch):
        z = self.fc1(batch["embedding"])
        z = self.batch_norm(z)
        z = self.relu(z)
        z = self.dropout(z)
        z = self.fc2(z)
        return z

    @torch.inference_mode()
    def predict(self, batch):
        z = self(batch)
        y_pred = torch.argmax(z, dim=1).cpu().numpy()
        return y_pred

    @torch.inference_mode()
    def predict_probabilities(self, batch):
        z = self(batch)
        y_probs = F.softmax(z, dim=1).cpu().numpy()
        return y_probs

    def save(self, dp):
        Path(dp).mkdir(parents=True, exist_ok=True)
        with open(Path(dp, "args.json"), "w") as fp:
            json.dump({
                "embedding_dim": self.embedding_dim,
                "hidden_dim": self.hidden_dim,
                "dropout_p": self.dropout_p,
                "num_classes": self.num_classes,
            }, fp, indent=4)
        torch.save(self.state_dict(), Path(dp, "model.pt"))

    @classmethod
    def load(cls, args_fp, state_dict_fp, device="cpu"):
        with open(args_fp, "r") as fp:
            model = cls(**json.load(fp))
        model.load_state_dict(torch.load(state_dict_fp, map_location=device))
        return model
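
The relationship between `forward`, `predict` and `predict_probabilities` can be sketched without PyTorch: `forward` returns raw logits, softmax turns them into probabilities, and the argmax is the same either way. The snippet below is a pure-Python illustration of that math with made-up logits, not the model's actual implementation.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability; the result is unchanged.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 0.5, -1.0]  # made-up logits for three classes
probs = softmax(logits)

# The probabilities sum to 1 and the argmax matches the argmax of the logits.
predicted_from_logits = max(range(len(logits)), key=logits.__getitem__)
predicted_from_probs = max(range(len(probs)), key=probs.__getitem__)
print(predicted_from_logits, predicted_from_probs)  # 0 0
```

This is why `predict` can take the argmax of the logits directly and skip the softmax.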

In [None]:
# Initialize model
num_classes = len(preprocessor.classes)
model = ClassificationModel(
    embedding_dim=512, 
    hidden_dim=256, 
    dropout_p=0.3, 
    num_classes=num_classes,
)
print(model.named_parameters)

<bound method Module.named_parameters of ClassificationModel(
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (fc2): Linear(in_features=256, out_features=36, bias=True)
)>


### Batching

Let's take a look at a sample batch of data and ensure that the tensors have the proper data types.

In [None]:
from ray.train.torch import get_device

In [None]:
def collate_fn(batch):
    dtypes = {"embedding": torch.float32, "label": torch.int64}
    tensor_batch = {}
    for key in dtypes.keys():
        if key in batch:
            tensor_batch[key] = torch.as_tensor(
                batch[key],
                dtype=dtypes[key],
                device=get_device(),
            )
    return tensor_batch

In [None]:
# Sample batch
sample_batch = train_ds.take_batch(batch_size=3)
collate_fn(batch=sample_batch)

2025-04-02 07:30:31,488	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-02_01-36-15_100194_2258/logs/ray-data
2025-04-02 07:30:31,489	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)->Map(Preprocessor.convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbeddingGenerator)] -> TaskPoolMapOperator[MapBatches(drop_columns)] -> LimitOperator[limit=3]


{'embedding': tensor([[-0.2528,  0.3721, -0.3783,  ...,  1.1563,  0.2734, -0.0609],
         [-0.0251, -0.2169, -0.0802,  ...,  0.5675,  0.2192, -0.0173],
         [-0.0288, -0.0551,  0.0742,  ...,  0.5049, -0.2767,  0.0743]]),
 'label': tensor([29,  5, 21])}

### Model registry

We'll create a model registry in our [Anyscale user storage](https://docs.anyscale.com/configuration/storage/#user-storage) to save our model checkpoints to. We'll use OSS MLflow, but we can easily [set up other experiment trackers](https://docs.ray.io/en/latest/train/user-guides/experiment-tracking.html) with Ray.

In [None]:
import shutil

In [None]:
model_registry = "/mnt/user_storage/mlflow/doggos"
if os.path.isdir(model_registry):
    shutil.rmtree(model_registry)  # clean up
os.makedirs(model_registry, exist_ok=True)

### Training

We'll define our training workload by specifying our:
- experiment and model parameters
- compute scaling configuration
- forward pass for batches of training and validation data
- train loop for each epoch of data (and checkpointing)

<img src="../images/trainer.png" width=500>

In [None]:
# Train loop config
experiment_name = "doggos"
train_loop_config = {
    "model_registry": model_registry,
    "experiment_name": experiment_name,
    "embedding_dim": 512,
    "hidden_dim": 256,
    "dropout_p": 0.3,
    "lr": 1e-3,
    "lr_factor": 0.8,
    "lr_patience": 3,
    "num_epochs": 20,
    "batch_size": 256,
}
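
The `lr_factor` and `lr_patience` entries control a reduce-on-plateau learning rate schedule. As a simplified sketch of how those two values interact (it ignores the improvement threshold and cooldown that PyTorch's `ReduceLROnPlateau` also supports, and the validation losses are made up):

```python
def simulate_plateau_schedule(val_losses, lr, factor, patience):
    """Simplified reduce-on-plateau: no threshold, no cooldown."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # Reduce only once the number of bad epochs exceeds the patience.
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history

# Made-up validation losses: an early improvement followed by a plateau.
losses = [1.0, 0.9, 0.95, 0.93, 0.92, 0.91, 0.8]
history = simulate_plateau_schedule(losses, lr=1e-3, factor=0.8, patience=3)
print([round(x, 6) for x in history])
```

With a patience of 3, the rate stays at `1e-3` through three non-improving epochs and is multiplied by 0.8 on the fourth.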

In [None]:
# Scaling config
num_workers = 4
scaling_config = ray.train.ScalingConfig(
    num_workers=num_workers,
    use_gpu=True,
    resources_per_worker={"CPU": 8, "GPU": 1})
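
A quick back-of-the-envelope check of what this configuration requests may help when sizing a cluster. The sketch below only does the arithmetic; the global batch size assumes each worker draws its own batch of `batch_size` rows from its dataset shard on every step, and it does not account for any resources used outside the training workers.

```python
num_workers = 4
resources_per_worker = {"CPU": 8, "GPU": 1}
per_worker_batch_size = 256  # "batch_size" in train_loop_config

# Total resources requested by the training workers.
total_resources = {k: v * num_workers for k, v in resources_per_worker.items()}

# Rows processed per optimizer step across all workers (under the assumption above).
global_batch_size = per_worker_batch_size * num_workers

print(total_resources)    # {'CPU': 32, 'GPU': 4}
print(global_batch_size)  # 1024
```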

In [None]:
import tempfile
import mlflow
from ray.train.torch import TorchTrainer

In [None]:
def train_step(ds, batch_size, model, num_classes, loss_fn, optimizer):
    model.train()
    loss = 0.0
    ds_generator = ds.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn)
    for i, batch in enumerate(ds_generator):
        optimizer.zero_grad()  # reset gradients
        z = model(batch)  # forward pass
        targets = F.one_hot(batch["label"], num_classes=num_classes).float()
        J = loss_fn(z, targets)  # define loss
        J.backward()  # backward pass
        optimizer.step()  # update weights
        loss += (J.detach().item() - loss) / (i + 1)  # running mean of batch losses
    return loss
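
The loss update in `train_step` is an incremental mean: after each batch, the running value equals the average of all batch losses seen so far. A minimal check with made-up per-batch losses:

```python
batch_losses = [0.9, 0.7, 0.8, 0.6]  # made-up per-batch losses

running = 0.0
for i, J in enumerate(batch_losses):
    running += (J - running) / (i + 1)  # incremental mean update

mean = sum(batch_losses) / len(batch_losses)
print(abs(running - mean) < 1e-12)  # True
```

Note that this averages per-batch losses without weighting by batch size, so a smaller final batch counts as much as a full one.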

In [None]:
def eval_step(ds, batch_size, model, num_classes, loss_fn):
    model.eval()
    loss = 0.0
    y_trues, y_preds = [], []
    ds_generator = ds.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn)
    with torch.inference_mode():
        for i, batch in enumerate(ds_generator):
            z = model(batch)
            targets = F.one_hot(batch["label"], num_classes=num_classes).float()  # one-hot (for loss_fn)
            J = loss_fn(z, targets).item()
            loss += (J - loss) / (i + 1)
            y_trues.extend(batch["label"].cpu().numpy())
            y_preds.extend(torch.argmax(z, dim=1).cpu().numpy())
    return loss, np.vstack(y_trues), np.vstack(y_preds)
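
Both steps convert integer labels to one-hot float targets before calling the loss function. For a one-hot target, the probability-target form of cross-entropy reduces to the familiar class-index form, which the pure-Python sketch below illustrates with made-up logits. It shows the math only, not PyTorch's implementation.

```python
import math

def log_softmax(logits):
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

def cross_entropy_with_probs(logits, target_probs):
    # Soft-label form: -sum_c target[c] * log_softmax(logits)[c]
    return -sum(t * lp for t, lp in zip(target_probs, log_softmax(logits)))

def cross_entropy_with_index(logits, label):
    # Hard-label form: -log_softmax(logits)[label]
    return -log_softmax(logits)[label]

logits = [1.5, -0.3, 0.2]  # made-up logits for three classes
label = 2
one_hot = [1.0 if c == label else 0.0 for c in range(len(logits))]

soft = cross_entropy_with_probs(logits, one_hot)
hard = cross_entropy_with_index(logits, label)
print(abs(soft - hard) < 1e-12)  # True
```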

In [None]:
def train_loop_per_worker(config):
    # Hyperparameters
    model_registry = config["model_registry"]
    experiment_name = config["experiment_name"]
    embedding_dim = config["embedding_dim"]
    hidden_dim = config["hidden_dim"]
    dropout_p = config["dropout_p"]
    lr = config["lr"]
    lr_factor = config["lr_factor"]
    lr_patience = config["lr_patience"]
    num_epochs = config["num_epochs"]
    batch_size = config["batch_size"]
    num_classes = config["num_classes"]

    # Experiment tracking
    if ray.train.get_context().get_world_rank() == 0:
        mlflow.set_tracking_uri(f"file:{model_registry}")
        mlflow.set_experiment(experiment_name)
        mlflow.start_run()
        mlflow.log_params(config)

    # Datasets
    train_ds = ray.train.get_dataset_shard("train")
    val_ds = ray.train.get_dataset_shard("val")

    # Model
    model = ClassificationModel(
        embedding_dim=embedding_dim, 
        hidden_dim=hidden_dim, 
        dropout_p=dropout_p, 
        num_classes=num_classes,
    )
    model = ray.train.torch.prepare_model(model)

    # Training components
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode="min", 
        factor=lr_factor, 
        patience=lr_patience,
    )

    # Training
    best_val_loss = float("inf")
    for epoch in range(num_epochs):
        # Steps
        train_loss = train_step(train_ds, batch_size, model, num_classes, loss_fn, optimizer)
        val_loss, _, _ = eval_step(val_ds, batch_size, model, num_classes, loss_fn)
        scheduler.step(val_loss)

        # Checkpoint (metrics, preprocessor and model artifacts)
        with tempfile.TemporaryDirectory() as dp:
            model.module.save(dp=dp)
            metrics = dict(lr=optimizer.param_groups[0]["lr"], train_loss=train_loss, val_loss=val_loss)
            with open(os.path.join(dp, "class_to_label.json"), "w") as fp:
                json.dump(config["class_to_label"], fp, indent=4)
            if ray.train.get_context().get_world_rank() == 0:  # only on main worker 0
                mlflow.log_metrics(metrics, step=epoch)
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    mlflow.log_artifacts(dp)

    # End experiment tracking
    if ray.train.get_context().get_world_rank() == 0:
        mlflow.end_run()
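
The checkpointing logic in the loop uploads artifacts only when the validation loss reaches a new best. As a small sketch of that rule in isolation (the helper name and the loss values are made up for illustration):

```python
def epochs_that_log_artifacts(val_losses):
    """Return the epochs at which a new best validation loss triggers an artifact upload."""
    best_val_loss = float("inf")
    logged = []
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            logged.append(epoch)
    return logged

logged_epochs = epochs_that_log_artifacts([1.2, 0.9, 0.95, 0.88, 0.91])
print(logged_epochs)  # [0, 1, 3]
```

Assuming same-named files in the run's artifact location are overwritten on each upload, the artifacts left at the end of the run should correspond to the best epoch rather than the last one.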

<div class="alert alert-block alert"> <b> Minimal change to your training code</b> 

You'll notice that there isn't much new Ray Train code on top of our base PyTorch code. We specified how we want to scale out our training workload, loaded the Ray datasets and then checkpointed on our main worker node... and that's it! Check out these guides ([PyTorch](https://docs.ray.io/en/latest/train/getting-started-pytorch.html), [PyTorch Lightning](https://docs.ray.io/en/latest/train/getting-started-pytorch-lightning.html), [HuggingFace Transformers](https://docs.ray.io/en/latest/train/getting-started-transformers.html)) to see the minimal code changes needed to distribute our training workloads, and see this extensive list of [Ray Train user guides](https://docs.ray.io/en/latest/train/user-guides.html).

In [None]:
# Load preprocessed datasets
preprocessed_train_ds = ray.data.read_parquet(preprocessed_train_path)
preprocessed_val_ds = ray.data.read_parquet(preprocessed_val_path)

In [None]:
# Trainer
train_loop_config["class_to_label"] = preprocessor.class_to_label
train_loop_config["num_classes"] = len(preprocessor.class_to_label)
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    datasets={"train": preprocessed_train_ds, "val": preprocessed_val_ds},
)

<div class="alert alert-block alert"> <b> Ray Train</b> 

**🎛️ Multi-node orchestration made easy**

- Ray Train automatically handles multi-node, multi-GPU setup with no manual SSH setup or hostfile configs.
- It also integrates with Ray's cluster launcher for cloud (AWS, GCP, K8s) and on-prem clusters.
- Solutions like PyTorch DDP require manually setting up your own process group, ranks, networking, etc.

**🩹 Built-in fault tolerance**
- Ray Train supports automatic retry of failed workers.
- It can continue training from the last checkpoint in case of failure.


**✂️ Flexible training strategies** (not just DDP)
- Ray Train supports Data Parallel, Model Parallel, Parameter Server, and even custom strategies.
- You can also use Torch DDP, FSDP, DeepSpeed, etc. under the hood if you want.
- [Ray Compiled graphs](https://docs.ray.io/en/latest/ray-core/compiled-graph/ray-compiled-graph.html) even allow us to define different parallelism for jointly optimizing multiple models (Megatron, DeepSpeed, etc. only allow for one global setting).

**🔥 Better support for heterogeneous clusters**
- Ray Train lets you define per-worker resource requirements (e.g., 2 CPUs, 1 GPU per worker).
- It can run on heterogeneous machines and scale flexibly.

**🌍 Integrations**

<img src="../images/train_integrations.png" width=500>

[RayTurbo Train]() offers even more improvements to the price-performance ratio, performance monitoring and more:
- **elastic training** to enable jobs to seamlessly adapt to changes in resource availability
- **purpose-built dashboard** designed to streamline the debugging of Ray Train workloads
    - Monitoring: View the status of training runs and train workers.
    - Metrics: See insights on training throughput, training system operation time.
    - Profiling: Investigate bottlenecks, hangs, or errors from individual training worker processes.

<img src="../images/train_dashboard.png" width=700>

In [None]:
# Train
results = trainer.fit()

(autoscaler +9m6s) [autoscaler] Downscaling node i-008ce49342ee21cd1 (node IP: 10.0.117.255) due to node idle termination.


We can view our experiment metrics and model artifacts in our model registry. We're using OSS MLflow, so we can run the server by pointing it to our model registry location:

```bash
mlflow server -h 0.0.0.0 -p 8080 --backend-store-uri /mnt/user_storage/mlflow/doggos
```

We can view the dashboard by going to the **Overview tab** up top → **Open Ports**. 

<img src="../images/mlflow.png" width=685>

We also have our Ray Dashboard and Train workload-specific dashboards above.

<img src="../images/train_metrics.png" width=700>


In [None]:
# Sorted runs
mlflow.set_tracking_uri(f"file:{model_registry}")
sorted_runs = mlflow.search_runs(
    experiment_names=[experiment_name], 
    order_by=["metrics.val_loss ASC"])
best_run = sorted_runs.iloc[0]
best_run

run_id                                      0f1ed542edda4cff8ecf731efe23893b
experiment_id                                             415276455561301916
status                                                               RUNNING
artifact_uri               file:///mnt/user_storage/mlflow/doggos/4152764...
start_time                                  2025-04-02 05:28:46.736000+00:00
end_time                                                                None
metrics.val_loss                                                    0.886611
metrics.lr                                                             0.001
metrics.train_loss                                                  0.620973
params.lr_factor                                                         0.8
params.embedding_dim                                                     512
params.num_classes                                                        36
params.dropout_p                                                         0.3

And we can easily wrap our training workload as a production-grade Job:

In [None]:
%%bash
anyscale job submit --name=train-doggos-model \
  --containerfile="/home/ray/default/containerfile" \
  --config-file="/home/ray/default/compute.yaml" \
  --working-dir="/home/ray/default" \
  --exclude="" \
  --max-retries=0 \
  -- python doggos/train.py

Output
(anyscale +0.8s) Submitting job with config JobConfig(name='train-doggos-model', image_uri=None, compute_config=ComputeConfig(cloud='anyscale_v2_default_cloud'), env_vars=None, py_modules=None, py_executable=None, cloud=None, project=None, ray_version=None, job_queue_config=None).
(anyscale +1.4s) Created compute config: 'compute-v1-038176bbd05fe5f40d47c502a03ae789:1'
(anyscale +1.4s) View the compute config in the UI: 'https://console.anyscale.com/v2/cld_kvedZWag2qA8i5BjxUevf5i7/compute-configs/cpt_wmtzjwgcli4ftbuveyuujqtg4b'
(anyscale +2.4s) Uploading local dir '/home/ray/default' to cloud storage.
(anyscale +3.3s) Including workspace-managed pip dependencies.
(anyscale +3.7s) Job 'train-doggos-model' submitted, ID: 'prodjob_a9sg2ernkhnr62vklzdx39unm8'.
(anyscale +3.7s) View the job in the UI: https://console.anyscale.com/jobs/prodjob_a9sg2ernkhnr62vklzdx39unm8
(anyscale +3.7s) Use `--wait` to wait for the job to run and stream logs.


<img src="../images/train_job.png" width=700>

### Evaluation

We'll conclude by evaluating our trained model on the test dataset. Evaluation is essentially the same as our batch inference workload: we apply the model to batches of data and then calculate metrics by comparing the predictions with the true labels. Ray Data is optimized for throughput, so preserving row order is not a priority. For evaluation, however, each prediction must be matched to its true label, so we preserve the entire row and add the predicted label to each row as another column.
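
To make the row-preservation idea concrete, here is a minimal sketch in plain Python. It does not use Ray; the batch layout (a dict of equal-length lists), the `feature` column, and the `threshold_model` stand-in are all hypothetical and only illustrate why carrying the label alongside the prediction makes the comparison independent of batch order.

```python
import random


def threshold_model(x):
    """Stand-in for a real model (hypothetical): predicts 1 when x >= 0.5."""
    return int(x >= 0.5)


def predict_batch(batch, predict_fn):
    """Return a copy of the batch with a "prediction" column; other columns are kept."""
    out = dict(batch)
    out["prediction"] = [predict_fn(x) for x in batch["feature"]]
    return out


# Toy batches in column format (hypothetical data).
batches = [
    {"label": [0, 1], "feature": [0.1, 0.9]},
    {"label": [1, 0], "feature": [0.8, 0.2]},
    {"label": [1, 1], "feature": [0.7, 0.4]},
]

# Simulate an engine that hands back batches in arbitrary order.
random.seed(0)
random.shuffle(batches)

results = [predict_batch(b, threshold_model) for b in batches]

# Each label travels with its prediction, so batch order does not matter.
pairs = [(y, p) for r in results for y, p in zip(r["label"], r["prediction"])]
correct = sum(y == p for y, p in pairs)
```

Because every output row still holds its own `label`, the number of correct predictions is the same no matter how the batches are shuffled.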

In [None]:
import json
from urllib.parse import urlparse

In [None]:
class TorchPredictor:
    def __init__(self, preprocessor, model):
        self.preprocessor = preprocessor
        self.model = model
        self.model.eval()

    def __call__(self, batch, device="cuda"):
        self.model.to(device)
        batch["prediction"] = self.model.predict(collate_fn(batch))
        return batch

    def predict_probabilities(self, batch, device="cuda"):
        self.model.to(device)
        predicted_probabilities = self.model.predict_probabilities(collate_fn(batch))
        batch["probabilities"] = [
            {self.preprocessor.label_to_class[i]: prob for i, prob in enumerate(probabilities)}
            for probabilities in predicted_probabilities
        ]
        return batch
    
    @classmethod
    def from_artifacts_dir(cls, artifacts_dir):
        with open(os.path.join(artifacts_dir, "class_to_label.json"), "r") as fp:
            class_to_label = json.load(fp)
        preprocessor = Preprocessor(class_to_label=class_to_label)
        model = ClassificationModel.load(
            args_fp=os.path.join(artifacts_dir, "args.json"), 
            state_dict_fp=os.path.join(artifacts_dir, "model.pt"),
        )
        return cls(preprocessor=preprocessor, model=model)
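
`predict_probabilities` turns each probability vector into a dict keyed by class name through `label_to_class`. The `Preprocessor` itself is not shown in this section, so the sketch below assumes that `label_to_class` is simply the inverse of the `class_to_label` mapping loaded from JSON. The class names and label indices are hypothetical.

```python
import json

# Hypothetical contents of class_to_label.json.
class_to_label = json.loads('{"basset": 0, "beagle": 1, "collie": 2}')

# Assumed inverse mapping: label index -> class name.
label_to_class = {label: name for name, label in class_to_label.items()}


def name_probabilities(probabilities, label_to_class):
    """Map a probability vector (indexed by label) to a {class name: probability} dict."""
    return {label_to_class[i]: prob for i, prob in enumerate(probabilities)}


named = name_probabilities([0.7, 0.2, 0.1], label_to_class)
top_class = max(named, key=named.get)
```

This mapping only works when position `i` of the probability vector corresponds to label `i`, which is why the same `class_to_label.json` saved at training time is reloaded here.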

In [None]:
# Load and preprocess eval dataset
artifacts_dir = urlparse(best_run.artifact_uri).path
predictor = TorchPredictor.from_artifacts_dir(artifacts_dir=artifacts_dir)
test_ds = ray.data.read_images("s3://doggos-dataset/test", include_paths=True)
test_ds = test_ds.map(add_class)
test_ds = predictor.preprocessor.transform(ds=test_ds)
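
The cell above derives a local directory from the run's artifact URI with `urlparse(...).path`. A small sketch of what that call returns, using hypothetical URIs (the real one is truncated in the output shown earlier):

```python
from urllib.parse import urlparse

# Hypothetical artifact URI for a local, file-based MLflow store.
artifact_uri = "file:///tmp/mlflow/example-experiment/example-run/artifacts"
parsed = urlparse(artifact_uri)
artifacts_dir = parsed.path  # the part after the scheme and (empty) host

# For a remote store the path alone is not enough: the bucket sits in netloc.
remote = urlparse("s3://example-bucket/prefix/run")
```

The `.path` shortcut therefore assumes a `file:` tracking URI, as set with `mlflow.set_tracking_uri(f"file:{model_registry}")` earlier in this notebook.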



In [None]:
# y_pred (batch inference)
pred_ds = test_ds.map_batches(
    predictor,
    fn_kwargs={"device": "cuda"},
    concurrency=4,
    batch_size=64,
    num_gpus=1,
)
pred_ds.take(1)

2025-04-02 05:38:34,078	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-02_01-36-15_100194_2258/logs/ray-data
2025-04-02 05:38:34,078	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)->Map(Preprocessor.convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbeddingGenerator)] -> TaskPoolMapOperator[MapBatches(drop_columns)] -> TaskPoolMapOperator[MapBatches(TorchPredictor)] -> LimitOperator[limit=1]


(autoscaler +43m42s) [autoscaler] [4xT4:48CPU-192GB] Attempting to add 1 node(s) to the cluster (increasing from 0 to 1).
(autoscaler +43m42s) [autoscaler] [4xT4:48CPU-192GB] Launched 1 instances.
(autoscaler +45m2s) [autoscaler] [4xT4:48CPU-192GB] Attempting to add 1 node(s) to the cluster (increasing from 1 to 2).
(autoscaler +45m2s) [autoscaler] [4xT4:48CPU-192GB] Launched 1 instances.

[{'path': 'doggos-dataset/test/basset/basset_3579.jpg',
  'class': 'basset',
  'label': 20,
  'embedding': array([ 2.20798776e-02, -1.71993062e-01, -9.32805091e-02,  2.52691805e-01,
         -2.29116872e-01, -7.94908166e-01,  4.60279137e-01,  2.67838597e-01,
         -1.26107395e-01, -2.50836074e-01, -1.45977423e-01, -1.36666849e-01,
          2.27473646e-01, -2.00646162e-01,  6.81822419e-01,  9.98634398e-02,
          4.29645002e-01,  6.83942139e-02, -4.37991731e-02, -1.69624895e-01,
         -5.87508202e-01, -1.87354431e-01,  9.97562930e-02, -3.60309660e-01,
         -4.44779992e-01,  4.32465971e-02,  4.66175020e-01,  1.49571478e-01,
         -1.19229004e-01, -8.18487182e-02,  2.32132941e-01,  1.93196714e-01,
          2.08487064e-01, -2.43074715e-01, -9.16214734e-02,  1.96457177e-01,
         -1.50599152e-01, -2.74054796e-01, -6.15260229e-02,  2.02351093e+00,
         -4.40336466e-01,  2.25521371e-01,  1.24245100e-02,  2.49176830e-01,
         -5.37308306e-02, -8.19697261e-01,  2.31

In [None]:
from sklearn.metrics import precision_recall_fscore_support

In [None]:
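def batch_metric(batch):
    """Compute support-weighted precision, recall, and F1 for a single batch."""
    labels = batch["label"]
    preds = batch["prediction"]
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="weighted", zero_division=0)
    # One row per batch; "count" lets us average the per-batch scores afterwards.
    return {"precision": [precision], "recall": [recall], "f1": [f1], "count": [1]}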

In [None]:
# Calculate metrics
metrics_ds = pred_ds.map_batches(batch_metric)
aggregate_metrics = metrics_ds.sum(["precision", "recall", "f1", "count"])
precision = aggregate_metrics["sum(precision)"] / aggregate_metrics["sum(count)"]
recall = aggregate_metrics["sum(recall)"] / aggregate_metrics["sum(count)"]
f1 = aggregate_metrics["sum(f1)"] / aggregate_metrics["sum(count)"]

2025-04-02 05:49:35,580	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-02_01-36-15_100194_2258/logs/ray-data
2025-04-02 05:49:35,581	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)->Map(Preprocessor.convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbeddingGenerator)] -> TaskPoolMapOperator[MapBatches(drop_columns)] -> TaskPoolMapOperator[MapBatches(TorchPredictor)] -> TaskPoolMapOperator[MapBatches(batch_metric)] -> AllToAllOperator[Aggregate] -> LimitOperator[limit=1]


(autoscaler +57m2s) [autoscaler] Downscaling node i-0fbb05b1d1e6d2559 (node IP: 10.0.107.124) due to node idle termination.
(autoscaler +57m2s) [autoscaler] Downscaling node i-0155b391f7ed6818b (node IP: 10.0.127.226) due to node idle termination.
(autoscaler +57m2s) [autoscaler] Cluster resized to {48 CPU, 4 GPU}.


In [None]:
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1: {f1:.2f}")

Precision: 0.91
Recall: 0.83
F1: 0.85
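
Note that these scores are unweighted means of per-batch scores, which can differ from scores computed over the whole test set (for example, when the last batch is smaller or when classes are spread unevenly across batches). Below is a minimal sketch of one alternative in plain Python, with hypothetical helper names (`batch_counts`, `weighted_scores`) and toy labels: accumulate per-class counts in each batch, sum them, and compute the support-weighted scores once at the end.

```python
from collections import Counter


def batch_counts(labels, preds):
    """Per-class true positives, predicted counts, and actual counts for one batch."""
    tp, predicted, actual = Counter(), Counter(), Counter()
    for y, p in zip(labels, preds):
        actual[y] += 1
        predicted[p] += 1
        if y == p:
            tp[y] += 1
    return tp, predicted, actual


def weighted_scores(tp, predicted, actual):
    """Support-weighted precision, recall, and F1 from summed per-class counts."""
    total = sum(actual.values())
    precision = recall = f1 = 0.0
    for cls, support in actual.items():
        p = tp[cls] / predicted[cls] if predicted[cls] else 0.0  # 0 when never predicted
        r = tp[cls] / support
        f = 2 * p * r / (p + r) if (p + r) else 0.0
        weight = support / total
        precision += weight * p
        recall += weight * r
        f1 += weight * f
    return precision, recall, f1


# Toy batches of (labels, predictions); hypothetical data.
batches = [([0, 0, 1], [0, 1, 1]), ([1, 2], [1, 2])]

tp, predicted, actual = Counter(), Counter(), Counter()
for labels, preds in batches:
    b_tp, b_predicted, b_actual = batch_counts(labels, preds)
    tp += b_tp
    predicted += b_predicted
    actual += b_actual

exact_precision, exact_recall, exact_f1 = weighted_scores(tp, predicted, actual)
```

Counts are additive across batches, so the result does not depend on how rows are grouped. The same idea could carry over to the pipeline above by emitting counts per batch and summing them, though that wiring is not shown here.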


(autoscaler +1h11m57s) [autoscaler] Downscaling node i-0a05c37a45a557c0d (node IP: 10.0.99.208) due to node idle termination.


In [None]:
# Restart the kernel to release resources held by this notebook session
import IPython
IPython.get_ipython().kernel.do_shutdown(restart=True)