# Distributed training

<div align="left">
<a target="_blank" href="https://console.anyscale.com/"><img src="https://img.shields.io/badge/🚀 Run_on-Anyscale-9hf"></a>&nbsp;
<a href="https://github.com/anyscale/foundational-ray-app" role="button"><img src="https://img.shields.io/static/v1?label=&amp;message=View%20On%20GitHub&amp;color=586069&amp;logo=github&amp;labelColor=2f363d"></a>&nbsp;
</div>

In this tutorial, we'll execute a distributed training workload that connects the following heterogeneous workloads:
- preprocess the dataset prior to training
- distributed training with Ray Train and PyTorch (with observability)
- evaluation (batch inference + eval logic)
- save model artifacts to a model registry (MLOps)

**Note**: we won't be tuning our model in this tutorial but be sure to check out [Ray Tune](https://docs.ray.io/en/latest/tune/index.html) for experiment execution and hyperparameter tuning at any scale.

<img src="https://raw.githubusercontent.com/anyscale/foundational-ray-app/refs/heads/main/images/distributed_training.png" width=800>

In [None]:
%load_ext autoreload
%autoreload all

In [None]:
import os
import ray
import sys
sys.path.append(os.path.abspath(".."))

In [None]:
# Enable Ray Train v2 (it's too good to wait for public release!)
ray.init(
    runtime_env={
        "env_vars": {"RAY_TRAIN_V2_ENABLED": "1"}, 
        "working_dir": "/home/ray/default",  # to import doggos (default working_dir=".")
    },
)

2025-04-03 12:05:01,633	INFO worker.py:1665 -- Connecting to existing Ray cluster at address: 10.0.60.75:6379...
2025-04-03 12:05:01,644	INFO worker.py:1849 -- Connected to Ray cluster. View the dashboard at https://session-fyhrc759flh928h7czptpn79mb.i.anyscaleuserdata.com
2025-04-03 12:05:01,690	INFO packaging.py:575 -- Creating a file package for local module '/home/ray/default'.
2025-04-03 12:05:01,747	INFO packaging.py:367 -- Pushing file package 'gcs://_ray_pkg_8d4e1742b75472ad.zip' (8.50MiB) to Ray cluster...
2025-04-03 12:05:01,786	INFO packaging.py:380 -- Successfully pushed file package 'gcs://_ray_pkg_8d4e1742b75472ad.zip'.


| | |
|---|---|
| Python version: | 3.12.9 |
| Ray version: | 3.0.0.dev0 |
| Dashboard: | http://session-fyhrc759flh928h7czptpn79mb.i.anyscaleuserdata.com |


In [None]:
%%bash
# This will be removed once Ray Train v2 is part of latest Ray version
echo "RAY_TRAIN_V2_ENABLED=1" > /home/ray/default/.env

In [None]:
# Load env vars in notebooks
from dotenv import load_dotenv
load_dotenv()

True

### Preprocess

We need to convert our classes to labels (unique integers) so that we can train a classifier that can correctly predict the class given an input image. But before we do this, we'll quickly apply the same data ingestion and preprocessing as the previous notebook.

In [None]:
def add_class(row):
    row["class"] = row["path"].rsplit("/", 3)[-2]
    return row
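
As a quick sanity check of the path parsing above, here is a minimal sketch on a plain string. It assumes the `<bucket>/<split>/<class>/<file>` layout seen in this dataset; the helper name `class_from_path` is hypothetical and not part of the notebook.

```python
def class_from_path(path: str) -> str:
    """Return the parent directory name, which is the class in this layout."""
    # Split from the right at most 3 times, then take the second-to-last piece.
    return path.rsplit("/", 3)[-2]


example_path = "doggos-dataset/train/basset/basset_9900.jpg"
example_class = class_from_path(example_path)
print(example_class)  # basset
```

Because the split starts from the right, the result does not depend on how many path components precede the split directory.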

In [None]:
# Preprocess data splits
train_ds = ray.data.read_images("s3://doggos-dataset/train", include_paths=True, shuffle="files")
train_ds = train_ds.map(add_class)
val_ds = ray.data.read_images("s3://doggos-dataset/val", include_paths=True)
val_ds = val_ds.map(add_class)

We'll define a `Preprocessor` class that will:
- create an embedding. We move the embedding layer outside of the model because we freeze its weights, so computing the embeddings once up front avoids repeating that work in every forward pass (unnecessary compute).
- convert our classes into labels for the classifier.

While we could've just done this as a simple operation, we're taking the time to organize it as a class so that we can save and load for inference later.

In [None]:
from doggos.embed import EmbeddingGenerator

In [None]:
class Preprocessor:
    """Preprocessor class."""
    def __init__(self, class_to_label=None):
        self.class_to_label = class_to_label or {}  # mutable defaults
        self.label_to_class = {v: k for k, v in self.class_to_label.items()}
        
    def fit(self, ds, column):
        self.classes = ds.unique(column=column)
        self.class_to_label = {tag: i for i, tag in enumerate(self.classes)}
        self.label_to_class = {v: k for k, v in self.class_to_label.items()}
        return self

    def convert_to_label(self, row, class_to_label):
        if "class" in row:
            row["label"] = class_to_label[row["class"]]
        return row
    
    def transform(self, ds, concurrency=4, batch_size=64, num_gpus=1):
        ds = ds.map(
            self.convert_to_label, 
            fn_kwargs={"class_to_label": self.class_to_label},
        )
        ds = ds.map_batches(
            EmbeddingGenerator,
            fn_constructor_kwargs={"model_id": "openai/clip-vit-base-patch32"},
            fn_kwargs={"device": "cuda"}, 
            concurrency=concurrency, 
            batch_size=batch_size,
            num_gpus=num_gpus,
        )
        ds = ds.drop_columns(["image"])
        return ds

    def save(self, fp):
        with open(fp, "w") as f:
            json.dump(self.class_to_label, f)
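
To make the label mapping and its save/load round trip concrete, here is a minimal standard-library sketch. The class names, the helper `fit_class_to_label`, and the sorting step are assumptions for illustration; the `Preprocessor` above uses whatever order `ds.unique` returns.

```python
import json
import os
import tempfile


def fit_class_to_label(classes):
    """Assign each distinct class a unique integer label."""
    # Sorting is an assumption made here so the mapping is reproducible.
    return {name: i for i, name in enumerate(sorted(set(classes)))}


class_to_label = fit_class_to_label(["basset", "pug", "basset", "collie"])
label_to_class = {v: k for k, v in class_to_label.items()}

# Round trip through JSON, as `save` does for `class_to_label`.
with tempfile.TemporaryDirectory() as dp:
    fp = os.path.join(dp, "class_to_label.json")
    with open(fp, "w") as f:
        json.dump(class_to_label, f)
    with open(fp) as f:
        restored = json.load(f)

print(class_to_label)
print(restored == class_to_label)
```

Saving `class_to_label` rather than `label_to_class` is convenient because JSON object keys are always strings: string class names survive the round trip unchanged, whereas integer keys would come back as strings.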

In [None]:
# Preprocess
preprocessor = Preprocessor()
preprocessor = preprocessor.fit(train_ds, column="class")
train_ds = preprocessor.transform(ds=train_ds)
val_ds = preprocessor.transform(ds=val_ds)
train_ds.take(1)

2025-04-03 12:05:09,005	INFO dataset.py:2798 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.


2025-04-03 12:05:09,029	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-03_11-37-08_695372_142551/logs/ray-data
2025-04-03 12:05:09,030	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadImage->Map(add_class)] -> AllToAllOperator[Aggregate] -> LimitOperator[limit=1]


Running 0: 0.00 row [00:00, ? row/s]

- ReadImage->Map(add_class) 1: 0.00 row [00:00, ? row/s]

- Aggregate 2: 0.00 row [00:00, ? row/s]

Sort Sample 3:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 4:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 5:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- limit=1 6: 0.00 row [00:00, ? row/s]

(autoscaler +15s) Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.
(autoscaler +15s) [autoscaler] [4xT4:48CPU-192GB] Attempting to add 1 node(s) to the cluster (increasing from 3 to 4).


2025-04-03 12:05:19,100	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-03_11-37-08_695372_142551/logs/ray-data
2025-04-03 12:05:19,101	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadImage->Map(add_class)->Map(Preprocessor.convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbeddingGenerator)] -> TaskPoolMapOperator[MapBatches(drop_columns)] -> LimitOperator[limit=1]


Running 0: 0.00 row [00:00, ? row/s]

(autoscaler +20s) [autoscaler] [4xT4:48CPU-192GB] Launched 1 instances.


- ReadImage->Map(add_class)->Map(Preprocessor.convert_to_label) 1: 0.00 row [00:00, ? row/s]

- MapBatches(EmbeddingGenerator) 2: 0.00 row [00:00, ? row/s]

- MapBatches(drop_columns) 3: 0.00 row [00:00, ? row/s]

- limit=1 4: 0.00 row [00:00, ? row/s]



[{'path': 'doggos-dataset/train/basset/basset_9900.jpg',
  'class': 'basset',
  'label': 23,
  'embedding': array([-2.07168400e-01,  1.72517151e-01,  1.06291771e-02,  3.91304970e-01,
          4.22165036e-01, -2.29526266e-01,  4.86278594e-01, -7.37224400e-01,
         -2.12697357e-01,  9.64940041e-02, -4.45656598e-01,  1.91988591e-02,
          1.59223646e-01,  4.70588207e-02,  5.04390359e-01,  4.46884111e-02,
          8.77077401e-01, -1.18519142e-01, -2.73600221e-02,  1.21952325e-01,
         -1.81658298e-01, -8.20441172e-02,  4.58503455e-01, -2.75700241e-01,
         -1.21452257e-01, -1.05347462e-01,  5.08147657e-01,  8.92426074e-02,
         -8.58309567e-02,  1.97094947e-01,  2.54393816e-01,  2.57087588e-01,
         -7.34195113e-04,  7.72692822e-03,  3.71548086e-01,  1.72115996e-01,
          4.44463849e-01, -3.11355114e-01, -1.99511334e-01,  1.66943169e+00,
         -6.59029603e-01, -3.57044078e-02,  1.50953978e-01,  3.47819507e-01,
          2.39400923e-01,  4.90695834e-01,  2.7

<div class="alert alert-block alert"> <b> Data Processing</b> 

Be sure to check out this extensive guide on [data loading and preprocessing](https://docs.ray.io/en/latest/train/user-guides/data-loading-preprocessing.html) for the last-mile preprocessing we'll need to do prior to training our models. Ray Data also supports performant joins, filters, aggregations, etc. for the more structured data processing your workloads may need.

<div class="alert alert-block alert"> <b> Store often, Save compute</b> 

We'll now store our preprocessed data in shared cloud storage because we want to:
- save a record of what this preprocessed data looks like
- avoid triggering the entire preprocessing for each batch our model will process
- avoid having to [`materialize`](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.materialize.html) the preprocessed data (we shouldn't force large data to fit in memory)
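
The storage location used in this step is assembled from environment variables with empty-string defaults. The sketch below, with a hypothetical helper `build_preprocessed_path` and made-up bucket and user names, shows how that path comes together and what happens when the variables are unset (POSIX path semantics assumed).

```python
import os


def build_preprocessed_path(artifact_storage: str, username: str) -> str:
    """Join storage root, user folder and dataset folder into one path."""
    return os.path.join(
        artifact_storage,
        username.replace(" ", "_"),  # spaces in usernames become underscores
        "doggos/preprocessed_data",
    )


configured = build_preprocessed_path("s3://example-bucket/artifacts", "jane doe")
unconfigured = build_preprocessed_path("", "")  # both env vars missing
print(configured)
print(unconfigured)
```

When both values are empty, `os.path.join` silently drops them and the result is a relative path, so it may be worth checking that the environment variables are set before deleting or writing anything.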

In [None]:
from doggos.utils import delete_s3_objects

In [None]:
# Write processed data to cloud storage
preprocessed_data_path = os.path.join(
    os.getenv("ANYSCALE_ARTIFACT_STORAGE", ""), 
    os.getenv("ANYSCALE_USERNAME", "").replace(" ", "_"), 
    "doggos/preprocessed_data",
)
delete_s3_objects(s3_path=preprocessed_data_path)
preprocessed_train_path = os.path.join(preprocessed_data_path, "preprocessed_train")
preprocessed_val_path = os.path.join(preprocessed_data_path, "preprocessed_val")
train_ds.write_parquet(preprocessed_train_path)
val_ds.write_parquet(preprocessed_val_path)

Deleted 65 objects from s3://anyscale-test-data-cld-i2w99rzq8b6lbjkke9y94vi5/org_7c1Kalm9WcX2bNIjW53GUT/cld_kvedZWag2qA8i5BjxUevf5i7/artifact_storage/goku_mohandas/doggos/preprocessed_data


2025-04-03 12:05:42,432	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-03_11-37-08_695372_142551/logs/ray-data
2025-04-03 12:05:42,433	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadImage->Map(add_class)->Map(Preprocessor.convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbeddingGenerator)] -> TaskPoolMapOperator[MapBatches(drop_columns)->Write]


Running 0: 0.00 row [00:00, ? row/s]

- ReadImage->Map(add_class)->Map(Preprocessor.convert_to_label) 1: 0.00 row [00:00, ? row/s]

- MapBatches(EmbeddingGenerator) 2: 0.00 row [00:00, ? row/s]

- MapBatches(drop_columns)->Write 3: 0.00 row [00:00, ? row/s]

2025-04-03 12:06:00,911	INFO dataset.py:4167 -- Data sink Parquet finished. 2880 rows and 5.9MB data written.
2025-04-03 12:06:01,145	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-03_11-37-08_695372_142551/logs/ray-data
2025-04-03 12:06:01,146	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadImage->Map(add_class)->Map(Preprocessor.convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbeddingGenerator)] -> TaskPoolMapOperator[MapBatches(drop_columns)->Write]


Running 0: 0.00 row [00:00, ? row/s]

- ReadImage->Map(add_class)->Map(Preprocessor.convert_to_label) 1: 0.00 row [00:00, ? row/s]

- MapBatches(EmbeddingGenerator) 2: 0.00 row [00:00, ? row/s]

- MapBatches(drop_columns)->Write 3: 0.00 row [00:00, ? row/s]

(autoscaler +1m15s) [autoscaler] Cluster upscaled to {192 CPU, 16 GPU}.


2025-04-03 12:06:29,894	INFO dataset.py:4167 -- Data sink Parquet finished. 720 rows and 1.5MB data written.


### Model

Let's define our model: a simple two-layer neural net that outputs one logit per class, with a softmax applied at prediction time to get class probabilities. You'll notice that it's all just base PyTorch and nothing else.

In [None]:
import json
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class ClassificationModel(torch.nn.Module):
    def __init__(self, embedding_dim, hidden_dim, dropout_p, num_classes):
        super().__init__()
        # Hyperparameters
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.dropout_p = dropout_p
        self.num_classes = num_classes

        # Define layers
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.batch_norm = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_p)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, batch):
        z = self.fc1(batch["embedding"])
        z = self.batch_norm(z)
        z = self.relu(z)
        z = self.dropout(z)
        z = self.fc2(z)
        return z

    @torch.inference_mode()
    def predict(self, batch):
        z = self(batch)
        y_pred = torch.argmax(z, dim=1).cpu().numpy()
        return y_pred

    @torch.inference_mode()
    def predict_probabilities(self, batch):
        z = self(batch)
        y_probs = F.softmax(z, dim=1).cpu().numpy()
        return y_probs

    def save(self, dp):
        Path(dp).mkdir(parents=True, exist_ok=True)
        with open(Path(dp, "args.json"), "w") as fp:
            json.dump({
                "embedding_dim": self.embedding_dim,
                "hidden_dim": self.hidden_dim,
                "dropout_p": self.dropout_p,
                "num_classes": self.num_classes,
            }, fp, indent=4)
        torch.save(self.state_dict(), Path(dp, "model.pt"))

    @classmethod
    def load(cls, args_fp, state_dict_fp, device="cpu"):
        with open(args_fp, "r") as fp:
            model = cls(**json.load(fp))
        model.load_state_dict(torch.load(state_dict_fp, map_location=device))
        return model
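
Note that `forward` returns raw logits and the softmax is only applied in `predict_probabilities`. The pure-Python sketch below, with made-up logit values, illustrates why `predict` can take the `argmax` of the logits directly: softmax preserves the ordering of its inputs.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)  # subtract the max to avoid overflow in exp
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 0.5, -1.0]  # made-up values for one sample with 3 classes
probs = softmax(logits)

pred_from_logits = max(range(len(logits)), key=logits.__getitem__)
pred_from_probs = max(range(len(probs)), key=probs.__getitem__)
print(probs, pred_from_logits, pred_from_probs)
```

This also fits with training on logits: `torch.nn.CrossEntropyLoss` applies a log-softmax internally, so the model should not apply softmax before the loss.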

In [None]:
# Initialize model
num_classes = len(preprocessor.classes)
model = ClassificationModel(
    embedding_dim=512, 
    hidden_dim=256, 
    dropout_p=0.3, 
    num_classes=num_classes,
)
print(model)

ClassificationModel(
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (fc2): Linear(in_features=256, out_features=36, bias=True)
)


### Batching

Let's take a look at a sample batch of data and ensure that the tensors have the proper data types.

In [None]:
from ray.train.torch import get_device

In [None]:
def collate_fn(batch):
    dtypes = {"embedding": torch.float32, "label": torch.int64}
    tensor_batch = {}
    for key in dtypes.keys():
        if key in batch:
            tensor_batch[key] = torch.as_tensor(
                batch[key],
                dtype=dtypes[key],
                device=get_device(),
            )
    return tensor_batch

In [None]:
# Sample batch
sample_batch = train_ds.take_batch(batch_size=3)
collate_fn(batch=sample_batch)

2025-04-03 12:06:30,677	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-03_11-37-08_695372_142551/logs/ray-data
2025-04-03 12:06:30,678	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadImage->Map(add_class)->Map(Preprocessor.convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbeddingGenerator)] -> TaskPoolMapOperator[MapBatches(drop_columns)] -> LimitOperator[limit=3]


Running 0: 0.00 row [00:00, ? row/s]

- ReadImage->Map(add_class)->Map(Preprocessor.convert_to_label) 1: 0.00 row [00:00, ? row/s]

- MapBatches(EmbeddingGenerator) 2: 0.00 row [00:00, ? row/s]

- MapBatches(drop_columns) 3: 0.00 row [00:00, ? row/s]

- limit=3 4: 0.00 row [00:00, ? row/s]



{'embedding': tensor([[-0.0119,  0.3011,  0.1426,  ...,  0.5759, -0.0689,  0.1088],
         [-0.2397,  0.1221,  0.3727,  ...,  0.5081, -0.1322,  0.3646],
         [ 0.2526,  0.2565, -0.4290,  ...,  0.4673,  0.4384,  0.2572]]),
 'label': tensor([10, 33, 26])}

### Model registry

We'll create a model registry in our [Anyscale user storage](https://docs.anyscale.com/configuration/storage/#user-storage) to save our model checkpoints to. We'll use open-source MLflow, but we can easily [set up other experiment trackers](https://docs.ray.io/en/latest/train/user-guides/experiment-tracking.html) with Ray.

In [None]:
import shutil

In [None]:
model_registry = "/mnt/user_storage/mlflow/doggos"
os.path.isdir(model_registry) and shutil.rmtree(model_registry)  # clean up
os.makedirs(model_registry, exist_ok=True)
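
The clean-up line above relies on short-circuit evaluation of `and`. An equivalent, more explicit version could look like the sketch below; the helper name `reset_dir` is hypothetical, and the example runs against a temporary directory instead of the real registry path.

```python
import os
import shutil
import tempfile


def reset_dir(path):
    """Remove `path` if it exists, then recreate it empty."""
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


base = tempfile.mkdtemp()
registry = os.path.join(base, "mlflow", "doggos")
os.makedirs(registry)
open(os.path.join(registry, "stale.txt"), "w").close()  # leftover from a previous run

reset_dir(registry)
remaining = os.listdir(registry)
print(remaining)  # []
shutil.rmtree(base)  # tidy up the temporary directory
```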

### Training

We'll define our training workload by specifying our:
- experiment and model parameters
- compute scaling configuration
- forward pass for batches of training and validation data
- train loop for each epoch of data (and checkpointing)

<img src="https://raw.githubusercontent.com/anyscale/foundational-ray-app/refs/heads/main/images/trainer.png" width=500>

In [None]:
# Train loop config
experiment_name = "doggos"
train_loop_config = {
    "model_registry": model_registry,
    "experiment_name": experiment_name,
    "embedding_dim": 512,
    "hidden_dim": 256,
    "dropout_p": 0.3,
    "lr": 1e-3,
    "lr_factor": 0.8,
    "lr_patience": 3,
    "num_epochs": 20,
    "batch_size": 256,
}

In [None]:
# Scaling config
num_workers = 4
scaling_config = ray.train.ScalingConfig(
    num_workers=num_workers,
    use_gpu=True,
    resources_per_worker={"CPU": 8, "GPU": 1},
)
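
As a rough capacity check, the total request is the per-worker resources multiplied by the number of workers. The sketch below uses the figures reported in this notebook's logs (a worker group of 4 with 8 CPUs and 1 GPU each, on a cluster upscaled to 192 CPUs and 16 GPUs); the helper `total_resources` is hypothetical.

```python
def total_resources(num_workers, resources_per_worker):
    """Multiply each per-worker resource by the number of workers."""
    return {name: amount * num_workers for name, amount in resources_per_worker.items()}


requested = total_resources(4, {"CPU": 8, "GPU": 1})
cluster = {"CPU": 192, "GPU": 16}  # from the autoscaler log line in this notebook
fits = all(requested[name] <= cluster[name] for name in requested)
print(requested, fits)  # {'CPU': 32, 'GPU': 4} True
```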

In [None]:
import tempfile
import mlflow
import numpy as np
from ray.train.torch import TorchTrainer

In [None]:
def train_epoch(ds, batch_size, model, num_classes, loss_fn, optimizer):
    model.train()
    loss = 0.0
    ds_generator = ds.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn)
    for i, batch in enumerate(ds_generator):
        optimizer.zero_grad()  # reset gradients
        z = model(batch)  # forward pass
        targets = F.one_hot(batch["label"], num_classes=num_classes).float()
        J = loss_fn(z, targets)  # define loss
        J.backward()  # backward pass
        optimizer.step()  # update weights
        loss += (J.detach().item() - loss) / (i + 1)  # running mean of the batch losses
    return loss
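
The loss update in `train_epoch` is an incremental mean: after batch `i`, `loss` equals the average of all batch losses seen so far. A small pure-Python check with made-up loss values:

```python
def running_mean(values):
    """Incrementally average values without storing them all."""
    mean = 0.0
    for i, value in enumerate(values):
        mean += (value - mean) / (i + 1)  # same update rule as in train_epoch
    return mean


batch_losses = [0.9, 0.7, 0.8, 0.6]  # made-up per-batch losses
incremental = running_mean(batch_losses)
direct = sum(batch_losses) / len(batch_losses)
print(incremental, direct)
```

Both values agree up to floating-point rounding. Note that this averages per-batch losses with equal weight, so a smaller final batch counts as much as a full one.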

In [None]:
def eval_epoch(ds, batch_size, model, num_classes, loss_fn):
    model.eval()
    loss = 0.0
    y_trues, y_preds = [], []
    ds_generator = ds.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn)
    with torch.inference_mode():
        for i, batch in enumerate(ds_generator):
            z = model(batch)
            targets = F.one_hot(batch["label"], num_classes=num_classes).float()  # one-hot (for loss_fn)
            J = loss_fn(z, targets).item()
            loss += (J - loss) / (i + 1)
            y_trues.extend(batch["label"].cpu().numpy())
            y_preds.extend(torch.argmax(z, dim=1).cpu().numpy())
    return loss, np.vstack(y_trues), np.vstack(y_preds)

In [None]:
def train_loop_per_worker(config):
    # Hyperparameters
    model_registry = config["model_registry"]
    experiment_name = config["experiment_name"]
    embedding_dim = config["embedding_dim"]
    hidden_dim = config["hidden_dim"]
    dropout_p = config["dropout_p"]
    lr = config["lr"]
    lr_factor = config["lr_factor"]
    lr_patience = config["lr_patience"]
    num_epochs = config["num_epochs"]
    batch_size = config["batch_size"]
    num_classes = config["num_classes"]

    # Experiment tracking
    if ray.train.get_context().get_world_rank() == 0:
        mlflow.set_tracking_uri(f"file:{model_registry}")
        mlflow.set_experiment(experiment_name)
        mlflow.start_run()
        mlflow.log_params(config)

    # Datasets
    train_ds = ray.train.get_dataset_shard("train")
    val_ds = ray.train.get_dataset_shard("val")

    # Model
    model = ClassificationModel(
        embedding_dim=embedding_dim, 
        hidden_dim=hidden_dim, 
        dropout_p=dropout_p, 
        num_classes=num_classes,
    )
    model = ray.train.torch.prepare_model(model)

    # Training components
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode="min", 
        factor=lr_factor, 
        patience=lr_patience,
    )

    # Training
    best_val_loss = float("inf")
    for epoch in range(num_epochs):
        # Steps
        train_loss = train_epoch(train_ds, batch_size, model, num_classes, loss_fn, optimizer)
        val_loss, _, _ = eval_epoch(val_ds, batch_size, model, num_classes, loss_fn)
        scheduler.step(val_loss)

        # Checkpoint (metrics, preprocessor and model artifacts)
        with tempfile.TemporaryDirectory() as dp:
            model.module.save(dp=dp)
            metrics = dict(lr=optimizer.param_groups[0]["lr"], train_loss=train_loss, val_loss=val_loss)
            with open(os.path.join(dp, "class_to_label.json"), "w") as fp:
                json.dump(config["class_to_label"], fp, indent=4)
            if ray.train.get_context().get_world_rank() == 0:  # only on main worker 0
                mlflow.log_metrics(metrics, step=epoch)
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    mlflow.log_artifacts(dp)

    # End experiment tracking
    if ray.train.get_context().get_world_rank() == 0:
        mlflow.end_run()
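
To build intuition for how `lr_factor` and `lr_patience` interact, here is a simplified pure-Python simulation of a reduce-on-plateau schedule in `"min"` mode. It is a sketch, not PyTorch's implementation: it assumes the default relative threshold of `1e-4` and ignores cooldown, `min_lr` and `eps`. The validation losses are made up.

```python
def simulate_reduce_on_plateau(val_losses, lr, factor, patience, threshold=1e-4):
    """Return the learning rate after each epoch for a 'min'-mode plateau schedule."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best * (1 - threshold):  # meaningful improvement
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:  # plateau lasted longer than `patience` epochs
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history


val_losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9]  # made-up values for illustration
lr_history = simulate_reduce_on_plateau(val_losses, lr=1e-3, factor=0.8, patience=3)
print(lr_history)
```

With `patience=3`, the learning rate is only reduced on the fourth consecutive epoch without improvement, at which point it is multiplied by `factor`.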

<div class="alert alert-block alert"> <b> Minimal change to your training code</b> 

You'll notice that there isn't much new Ray Train code on top of our base PyTorch code. We specified how we want to scale out our training workload, load the Ray datasets and then checkpoint on our main worker node... and that's it! Check out these guides ([PyTorch](https://docs.ray.io/en/latest/train/getting-started-pytorch.html), [PyTorch Lightning](https://docs.ray.io/en/latest/train/getting-started-pytorch-lightning.html), [HuggingFace Transformers](https://docs.ray.io/en/latest/train/getting-started-transformers.html)) to see the minimal delta code needed to distribute our training workloads and check out this extensive list of [Ray Train user guides](https://docs.ray.io/en/latest/train/user-guides.html).

In [None]:
# Load preprocessed datasets
preprocessed_train_ds = ray.data.read_parquet(preprocessed_train_path)
preprocessed_val_ds = ray.data.read_parquet(preprocessed_val_path)

Metadata Fetch Progress 0:   0%|          | 0.00/8.00 [00:00<?, ? task/s]

Parquet Files Sample 0:   0%|          | 0.00/2.00 [00:00<?, ? file/s]

Parquet Files Sample 0:   0%|          | 0.00/2.00 [00:00<?, ? file/s]

In [None]:
# Trainer
train_loop_config["class_to_label"] = preprocessor.class_to_label
train_loop_config["num_classes"] = len(preprocessor.class_to_label)
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    datasets={"train": preprocessed_train_ds, "val": preprocessed_val_ds},
)

<div class="alert alert-block alert"> <b> Ray Train</b> 

**🎛️ Multi-node orchestration made easy**

- Ray Train automatically handles multi-node, multi-GPU setup with no manual SSH setup or hostfile configs.
- It also integrates with Ray's cluster launcher for cloud (AWS, GCP, K8s) and on-prem clusters.
- Solutions like PyTorch DDP require manually setting up your own process group, ranks, networking, etc.

**🩹 Built-in fault tolerance**
- Ray Train supports automatic retry of failed workers.
- It can continue training from the last checkpoint in case of failure.

**✂️ Flexible training strategies** (not just DDP)
- Ray Train supports Data Parallel, Model Parallel, Parameter Server, and even custom strategies.
- You can also use Torch DDP, FSDP, DeepSpeed, etc. under the hood if you want.
- [Ray Compiled graphs](https://docs.ray.io/en/latest/ray-core/compiled-graph/ray-compiled-graph.html) even allow us to define different parallelism for jointly optimizing multiple models (Megatron, DeepSpeed, etc. only allow for one global setting).

**🔥 Better support for heterogeneous clusters**
- Ray Train lets you define per-worker resource requirements (e.g., 2 CPUs and 1 GPU per worker).
- It can run on heterogeneous machines and scale flexibly (e.g., CPU for preprocessing and GPU for training).

**🌍 Integrations**

<img src="https://raw.githubusercontent.com/anyscale/foundational-ray-app/refs/heads/main/images/train_integrations.png" width=500>

[RayTurbo Train](https://docs.anyscale.com/rayturbo/rayturbo-train) offers even more improvement to the price-performance ratio, performance monitoring and more:
- **elastic training** to scale to a dynamic number of workers, continue training on fewer resources (even on spot instances).
- **purpose-built dashboard** designed to streamline the debugging of Ray Train workloads
    - Monitoring: View the status of training runs and train workers.
    - Metrics: See insights on training throughput, training system operation time.
    - Profiling: Investigate bottlenecks, hangs, or errors from individual training worker processes.

<img src="https://raw.githubusercontent.com/anyscale/foundational-ray-app/refs/heads/main/images/train_dashboard.png" width=700>

In [None]:
# Train
results = trainer.fit()

(TrainController pid=157502) Attempting to start training worker group of size 4 with the following resources: [{'CPU': 8, 'GPU': 1}] * 4

(RayTrainWorker pid=34863, ip=10.0.25.230) Setting up process group for: env:// [rank=0, world_size=4]
(RayTrainWorker pid=34863, ip=10.0.25.230) 2025/04/03 12:07:08 INFO mlflow.tracking.fluent: Experiment with name 'doggos' does not exist. Creating a new experiment.
(RayTrainWorker pid=34863, ip=10.0.25.230) Moving model to device: cuda:0
(RayTrainWorker pid=34863, ip=10.0.25.230) Wrapping provided model in DistributedDataParallel.
(SplitCoordinator pid=157639) Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-03_11-37-08_695372_142551/logs/ray-data
(SplitCoordinator pid=157639) Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadParquet] -> OutputSplitter[split(4, equal=True)]


(pid=157639) Running 0: 0.00 row [00:00, ? row/s]

(pid=157639) - ReadParquet->SplitBlocks(8) 1: 0.00 row [00:00, ? row/s]

(pid=157639) - split(4, equal=True) 2: 0.00 row [00:00, ? row/s]





(pid=157681) Running 0: 0.00 row [00:00, ? row/s]

(pid=157681) - ReadParquet->SplitBlocks(32) 1: 0.00 row [00:00, ? row/s]

(pid=157681) - split(4, equal=True) 2: 0.00 row [00:00, ? row/s]

[36m(SplitCoordinator pid=157639)[0m Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-03_11-37-08_695372_142551/logs/ray-data[32m [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m
[36m(SplitCoordinator pid=157639)[0m Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadParquet] -> OutputSplitter[split(4, equal=True)][32m [repeated 2x across cluster][0m


(pid=157639) Running 0: 0.00 row [00:00, ? row/s]

(pid=157639) - ReadParquet->SplitBlocks(8) 1: 0.00 row [00:00, ? row/s]

(pid=157639) - split(4, equal=True) 2: 0.00 row [00:00, ? row/s]

(pid=157681) Running 0: 0.00 row [00:00, ? row/s]





(pid=157681) - ReadParquet->SplitBlocks(32) 1: 0.00 row [00:00, ? row/s]

(pid=157681) - split(4, equal=True) 2: 0.00 row [00:00, ? row/s]

(pid=157639) Running 0: 0.00 row [00:00, ? row/s]

(pid=157639) - ReadParquet->SplitBlocks(8) 1: 0.00 row [00:00, ? row/s]

(pid=157639) - split(4, equal=True) 2: 0.00 row [00:00, ? row/s]

(pid=157681) Running 0: 0.00 row [00:00, ? row/s]

(pid=157681) - ReadParquet->SplitBlocks(32) 1: 0.00 row [00:00, ? row/s]

(pid=157681) - split(4, equal=True) 2: 0.00 row [00:00, ? row/s]

[36m(SplitCoordinator pid=157681)[0m Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-03_11-37-08_695372_142551/logs/ray-data[32m [repeated 3x across cluster][0m
[36m(SplitCoordinator pid=157681)[0m Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadParquet] -> OutputSplitter[split(4, equal=True)][32m [repeated 3x across cluster][0m


(pid=157639) Running 0: 0.00 row [00:00, ? row/s]

(pid=157639) - ReadParquet->SplitBlocks(8) 1: 0.00 row [00:00, ? row/s]

(pid=157639) - split(4, equal=True) 2: 0.00 row [00:00, ? row/s]

(pid=157681) Running 0: 0.00 row [00:00, ? row/s]

(pid=157681) - ReadParquet->SplitBlocks(32) 1: 0.00 row [00:00, ? row/s]

(pid=157681) - split(4, equal=True) 2: 0.00 row [00:00, ? row/s]

(pid=157639) Running 0: 0.00 row [00:00, ? row/s]

(pid=157639) - ReadParquet->SplitBlocks(8) 1: 0.00 row [00:00, ? row/s]

(pid=157639) - split(4, equal=True) 2: 0.00 row [00:00, ? row/s]

(pid=157681) Running 0: 0.00 row [00:00, ? row/s]

[36m(SplitCoordinator pid=157681)[0m Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-03_11-37-08_695372_142551/logs/ray-data[32m [repeated 4x across cluster][0m
[36m(SplitCoordinator pid=157681)[0m Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadParquet] -> OutputSplitter[split(4, equal=True)][32m [repeated 4x across cluster][0m


(pid=157681) - ReadParquet->SplitBlocks(32) 1: 0.00 row [00:00, ? row/s]

(pid=157681) - split(4, equal=True) 2: 0.00 row [00:00, ? row/s]

(pid=157639) Running 0: 0.00 row [00:00, ? row/s]

(pid=157639) - ReadParquet->SplitBlocks(8) 1: 0.00 row [00:00, ? row/s]

(pid=157639) - split(4, equal=True) 2: 0.00 row [00:00, ? row/s]

(pid=157681) Running 0: 0.00 row [00:00, ? row/s]

(pid=157681) - ReadParquet->SplitBlocks(32) 1: 0.00 row [00:00, ? row/s]

(pid=157681) - split(4, equal=True) 2: 0.00 row [00:00, ? row/s]

[36m(SplitCoordinator pid=157639)[0m Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-03_11-37-08_695372_142551/logs/ray-data[32m [repeated 3x across cluster][0m
[36m(SplitCoordinator pid=157639)[0m Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadParquet] -> OutputSplitter[split(4, equal=True)][32m [repeated 3x across cluster][0m



We can view our experiment metrics and model artifacts in our model registry. We're using OSS MLflow, so we can start the server by pointing it at our model registry location:

```bash
mlflow server -h 0.0.0.0 -p 8080 --backend-store-uri /mnt/user_storage/mlflow/doggos
```

We can view the dashboard by going to the **Overview tab** up top → **Open Ports**. 

<img src="https://raw.githubusercontent.com/anyscale/foundational-ray-app/refs/heads/main/images/mlflow.png" width=685>

We also have our Ray Dashboard and Train workload-specific dashboards above.

<img src="https://raw.githubusercontent.com/anyscale/foundational-ray-app/refs/heads/main/images/train_metrics.png" width=700>


In [None]:
# Sorted runs
mlflow.set_tracking_uri(f"file:{model_registry}")
sorted_runs = mlflow.search_runs(
    experiment_names=[experiment_name], 
    order_by=["metrics.val_loss ASC"])
best_run = sorted_runs.iloc[0]
best_run

run_id                                      7a95842997b04a59ba9b7336fcae611d
experiment_id                                             986917759896566964
status                                                              FINISHED
artifact_uri               file:///mnt/user_storage/mlflow/doggos/9869177...
start_time                                  2025-04-03 19:07:08.787000+00:00
end_time                                    2025-04-03 19:08:19.551000+00:00
metrics.lr                                                             0.001
metrics.val_loss                                                    0.599019
metrics.train_loss                                                  0.387353
params.embedding_dim                                                     512
params.experiment_name                                                doggos
params.num_epochs                                                         20
params.model_registry                        /mnt/user_storage/mlflow/doggos

And we can easily wrap our training workload as a production-grade [Anyscale Job](https://docs.anyscale.com/platform/jobs/) ([API ref](https://docs.anyscale.com/reference/job-api/)).

**Note**: 
- we're using a `containerfile` to define our dependencies, but we could easily use a pre-built image as well.
- we can specify the compute as a [compute config](https://docs.anyscale.com/configuration/compute-configuration/) or inline in a [job config](https://docs.anyscale.com/reference/job-api#job-cli) file.
- when we launch from a workspace without specifying compute, the job defaults to the compute configuration of the Workspace.

In [None]:
%%bash
# Production batch job
anyscale job submit --name=train-doggos-model \
  --containerfile="/home/ray/default/containerfile" \
  --working-dir="/home/ray/default" \
  --exclude="" \
  --max-retries=0 \
  -- python doggos/train.py

Output
(anyscale +2.9s) Submitting job with config JobConfig(name='train-doggos-model', image_uri=None, compute_config=None, env_vars=None, py_modules=None, py_executable=None, cloud=None, project=None, ray_version=None, job_queue_config=None).
(anyscale +4.3s) Building image. View it in the UI: https://console.anyscale.com/v2/container-images/apt_udgl8twk396mta6cawdqsr8ds7/versions/bld_jxtzi5qtafpn445rld8cydln88
(anyscale +48.3s) Waiting for image build to complete. Elapsed time: 42 seconds.
(anyscale +48.3s) Image build succeeded.
(anyscale +48.5s) Uploading local dir '/home/ray/default' to cloud storage.
(anyscale +50.3s) Including workspace-managed pip dependencies.
(anyscale +50.6s) Job 'train-doggos-model' submitted, ID: 'prodjob_t66b5774u83lwcugkw4999prj4'.
(anyscale +50.6s) View the job in the UI: https://console.anyscale.com/jobs/prodjob_t66b5774u83lwcugkw4999prj4
(anyscale +50.6s) Use `--wait` to wait for the job to run and stream logs.


<img src="https://raw.githubusercontent.com/anyscale/foundational-ray-app/refs/heads/main/images/train_job.png" width=700>

### Evaluation

We'll conclude by evaluating our trained model on our test dataset. Evaluation is essentially the same as our batch inference workload: we apply the model to batches of data and then calculate metrics by comparing the predictions against the true labels. Ray Data is heavily optimized for throughput, so preserving row order is not a priority. For evaluation, though, every prediction must be matched with its true label. We achieve this by preserving the entire row and adding the predicted label as another column to each row.
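
To see why keeping the label and the prediction in the same row matters, here is a minimal sketch in plain Python (no Ray involved). The `predict` function is a hypothetical stand-in for our model and the rows are toy data; the shuffle simulates an engine that returns rows in arbitrary order.

```python
import random

# Hypothetical stand-in for the model: correct except on every 4th row.
def predict(row):
    return row["label"] if row["id"] % 4 else (row["label"] + 1) % 3

rows = [{"id": i, "label": i % 3} for i in range(12)]

# Attach the prediction to the row itself so label and prediction travel together.
scored = [{**row, "prediction": predict(row)} for row in rows]

# Simulate an execution engine that does not preserve row order.
shuffled = scored[:]
random.Random(0).shuffle(shuffled)

def accuracy(scored_rows):
    correct = sum(r["label"] == r["prediction"] for r in scored_rows)
    return correct / len(scored_rows)

# Row-wise pairing: the metric does not depend on ordering.
assert accuracy(scored) == accuracy(shuffled)

# Separate lists paired by position: no longer reliable once one side is reordered.
labels_in_order = [r["label"] for r in scored]
preds_reordered = [r["prediction"] for r in shuffled]
misaligned = sum(l == p for l, p in zip(labels_in_order, preds_reordered)) / len(rows)
print(accuracy(scored), misaligned)
```

The same idea carries over to our pipeline: because each output row holds both `label` and `prediction`, any per-batch metric can be computed without assuming anything about the order in which batches arrive.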

In [None]:
import json  # used by TorchPredictor.from_artifacts_dir below
from urllib.parse import urlparse
from sklearn.metrics import multilabel_confusion_matrix

In [None]:
class TorchPredictor:
    def __init__(self, preprocessor, model):
        self.preprocessor = preprocessor
        self.model = model
        self.model.eval()

    def __call__(self, batch, device="cuda"):
        self.model.to(device)
        batch["prediction"] = self.model.predict(collate_fn(batch))
        return batch

    def predict_probabilities(self, batch, device="cuda"):
        self.model.to(device)
        predicted_probabilities = self.model.predict_probabilities(collate_fn(batch))
        batch["probabilities"] = [
            {self.preprocessor.label_to_class[i]: prob for i, prob in enumerate(probabilities)}
            for probabilities in predicted_probabilities
        ]
        return batch
    
    @classmethod
    def from_artifacts_dir(cls, artifacts_dir):
        with open(os.path.join(artifacts_dir, "class_to_label.json"), "r") as fp:
            class_to_label = json.load(fp)
        preprocessor = Preprocessor(class_to_label=class_to_label)
        model = ClassificationModel.load(
            args_fp=os.path.join(artifacts_dir, "args.json"), 
            state_dict_fp=os.path.join(artifacts_dir, "model.pt"),
        )
        return cls(preprocessor=preprocessor, model=model)

In [None]:
# Load and preprocess eval dataset
artifacts_dir = urlparse(best_run.artifact_uri).path
predictor = TorchPredictor.from_artifacts_dir(artifacts_dir=artifacts_dir)
test_ds = ray.data.read_images("s3://doggos-dataset/test", include_paths=True)
test_ds = test_ds.map(add_class)
test_ds = predictor.preprocessor.transform(ds=test_ds)
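
The `urlparse(...).path` call above works because our registry uses a `file:` URI. A minimal sketch of what it does, using a hypothetical URI and a hypothetical helper `local_path_from_uri` (not part of the tutorial code) that guards against non-local schemes:

```python
from urllib.parse import urlparse

def local_path_from_uri(uri):
    """Return the filesystem path of a file: URI; reject other schemes."""
    parsed = urlparse(uri)
    if parsed.scheme != "file":
        # For e.g. s3://bucket/key, .path would silently drop the bucket name.
        raise ValueError(f"expected a file: URI, got scheme {parsed.scheme!r}")
    return parsed.path

# Hypothetical artifact URI; the real value comes from best_run.artifact_uri.
example_uri = "file:///tmp/mlflow/example-experiment/example-run/artifacts"
print(local_path_from_uri(example_uri))
```

If the registry were moved to remote storage, the artifacts would need to be downloaded first rather than read via a local path.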


In [None]:
# y_pred (batch inference)
pred_ds = test_ds.map_batches(
    predictor,
    fn_kwargs={"device": "cuda"},
    concurrency=4,
    batch_size=64,
    num_gpus=1,
)
pred_ds.take(1)

2025-04-03 12:09:36,476	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-03_11-37-08_695372_142551/logs/ray-data
2025-04-03 12:09:36,477	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadImage->Map(add_class)->Map(Preprocessor.convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbeddingGenerator)] -> TaskPoolMapOperator[MapBatches(drop_columns)] -> TaskPoolMapOperator[MapBatches(TorchPredictor)] -> LimitOperator[limit=1]


Running 0: 0.00 row [00:00, ? row/s]

- ReadImage->Map(add_class)->Map(Preprocessor.convert_to_label) 1: 0.00 row [00:00, ? row/s]

- MapBatches(EmbeddingGenerator) 2: 0.00 row [00:00, ? row/s]

- MapBatches(drop_columns) 3: 0.00 row [00:00, ? row/s]

- MapBatches(TorchPredictor) 4: 0.00 row [00:00, ? row/s]

- limit=1 5: 0.00 row [00:00, ? row/s]



[{'path': 'doggos-dataset/test/bernese_mountain_dog/bernese_mountain_dog_4200.jpg',
  'class': 'bernese_mountain_dog',
  'label': 14,
  'embedding': array([ 1.21616945e-01,  2.49908268e-01, -9.42120180e-02, -1.01999938e-01,
          2.50217021e-01, -7.53532588e-01,  1.37271792e-01, -2.63006270e-01,
          4.47464064e-02,  4.74804267e-02, -6.01408541e-01, -1.59543782e-01,
          4.15230095e-01, -2.00223476e-02,  8.17479253e-01,  2.42886394e-01,
          3.10201585e-01, -1.92650370e-02,  3.53697807e-01, -1.75084591e-01,
         -7.65018463e-01,  1.78362057e-01,  4.84163761e-01, -5.49955904e-01,
          2.23341316e-01,  7.51222894e-02,  3.73767614e-01, -1.23567298e-01,
         -2.52771586e-01,  1.15096115e-01,  1.27928227e-01,  3.01279724e-02,
          8.95194411e-02, -6.99338242e-02,  5.82968295e-01,  1.54279739e-01,
          1.18630141e-01,  4.64450747e-01, -1.26711771e-01,  1.43133807e+00,
         -7.22994626e-01, -1.34238645e-01,  1.28221631e-01,  7.86256194e-02,
      

In [None]:
def batch_metric(batch):
    labels = batch["label"]
    preds = batch["prediction"]
    # One 2x2 one-vs-rest confusion matrix per class present in this batch,
    # each laid out as [[TN, FP], [FN, TP]].
    mcm = multilabel_confusion_matrix(labels, preds)
    tn, fp, fn, tp = [], [], [], []
    for i in range(mcm.shape[0]):
        tn.append(mcm[i, 0, 0])  # True negatives
        fp.append(mcm[i, 0, 1])  # False positives
        fn.append(mcm[i, 1, 0])  # False negatives
        tp.append(mcm[i, 1, 1])  # True positives
    # One output row per class; the rows are summed across all batches later.
    return {"TN": tn, "FP": fp, "FN": fn, "TP": tp}


In [None]:
# Aggregated metrics after processing all batches
metrics_ds = pred_ds.map_batches(batch_metric)
aggregate_metrics = metrics_ds.sum(["TN", "FP", "FN", "TP"])

# Aggregate the confusion matrix components across all batches
tn = aggregate_metrics["sum(TN)"]
fp = aggregate_metrics["sum(FP)"]
fn = aggregate_metrics["sum(FN)"]
tp = aggregate_metrics["sum(TP)"]

# Calculate metrics
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
accuracy = (tp + tn) / (tp + tn + fp + fn)

2025-04-03 12:09:53,530	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-04-03_11-37-08_695372_142551/logs/ray-data
2025-04-03 12:09:53,531	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadImage->Map(add_class)->Map(Preprocessor.convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbeddingGenerator)] -> TaskPoolMapOperator[MapBatches(drop_columns)] -> TaskPoolMapOperator[MapBatches(TorchPredictor)] -> TaskPoolMapOperator[MapBatches(batch_metric)] -> AllToAllOperator[Aggregate] -> LimitOperator[limit=1]


Running 0: 0.00 row [00:00, ? row/s]

- ReadImage->Map(add_class)->Map(Preprocessor.convert_to_label) 1: 0.00 row [00:00, ? row/s]

- MapBatches(EmbeddingGenerator) 2: 0.00 row [00:00, ? row/s]

- MapBatches(drop_columns) 3: 0.00 row [00:00, ? row/s]

- MapBatches(TorchPredictor) 4: 0.00 row [00:00, ? row/s]

- MapBatches(batch_metric) 5: 0.00 row [00:00, ? row/s]

- Aggregate 6: 0.00 row [00:00, ? row/s]

Sort Sample 7:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 8:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 9:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- limit=1 10: 0.00 row [00:00, ? row/s]



In [None]:
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1: {f1:.2f}")
print(f"Accuracy: {accuracy:.2f}")

Precision: 0.84
Recall: 0.84
F1: 0.84
Accuracy: 0.96
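
Precision, recall, and F1 come out identical here because we sum the per-class counts before dividing (micro-averaging). In a single-label problem every misclassified sample is one false positive for the predicted class and one false negative for the true class, so the summed FP and FN are equal. The sketch below illustrates this on toy data with a hypothetical helper `per_class_counts`; it is not the tutorial's pipeline.

```python
from collections import Counter

def per_class_counts(labels, preds, classes):
    """One-vs-rest TN/FP/FN/TP for each class."""
    counts = []
    for c in classes:
        tp = sum(l == c and p == c for l, p in zip(labels, preds))
        fp = sum(l != c and p == c for l, p in zip(labels, preds))
        fn = sum(l == c and p != c for l, p in zip(labels, preds))
        tn = len(labels) - tp - fp - fn
        counts.append({"TN": tn, "FP": fp, "FN": fn, "TP": tp})
    return counts

# Toy data (hypothetical): 3 classes, 6 samples, 4 predicted correctly.
labels = [0, 0, 1, 1, 2, 2]
preds = [0, 1, 1, 1, 2, 0]

# Sum the per-class counts, as the aggregation step above does.
total = Counter()
for row in per_class_counts(labels, preds, classes=[0, 1, 2]):
    total.update(row)

precision = total["TP"] / (total["TP"] + total["FP"])
recall = total["TP"] / (total["TP"] + total["FN"])
one_vs_rest_accuracy = (total["TP"] + total["TN"]) / sum(total.values())
fraction_correct = sum(l == p for l, p in zip(labels, preds)) / len(labels)

print(precision, recall, one_vs_rest_accuracy, fraction_correct)
```

In this toy case precision and recall both equal the fraction of correctly classified samples, while the accuracy formula that includes true negatives is higher, since every sample counts as a true negative for most classes. That likely explains why the accuracy above is noticeably higher than the other three metrics.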


In [None]:
import IPython
IPython.get_ipython().kernel.do_shutdown(restart=True)

{'status': 'ok', 'restart': True}
