# Training a Torch Image Classifier

Source: https://docs.ray.io/en/latest/ray-air/examples/torch_image_example.html

In [1]:
# Requirements
# !pip install 'ray[air]'
# !pip install requests torch torchvision

## Load and normalize CIFAR-10

In [1]:
import ray
import torchvision
import torchvision.transforms as transforms

if ray.is_initialized():
    ray.shutdown()
ray.init()

train_dataset = torchvision.datasets.CIFAR10("data", download=True, train=True)
test_dataset = torchvision.datasets.CIFAR10("data", download=True, train=False)

train_dataset: ray.data.Dataset = ray.data.from_torch(train_dataset)
test_dataset: ray.data.Dataset = ray.data.from_torch(test_dataset)

2023-09-05 21:53:50,487	INFO worker.py:1612 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m


Files already downloaded and verified
Files already downloaded and verified


Next, let’s represent our data using a dictionary of ndarrays instead of tuples. 

In [2]:
from typing import Dict, Tuple
import numpy as np
from PIL.Image import Image
import torch


def convert_batch_to_numpy(batch) -> Dict[str, np.ndarray]:
    images = np.stack([np.array(image) for image, _ in batch["item"]])
    labels = np.array([label for _, label in batch["item"]])
    return {"image": images, "label": labels}


train_dataset = train_dataset.map_batches(convert_batch_to_numpy).materialize()
test_dataset = test_dataset.map_batches(convert_batch_to_numpy).materialize()

2023-09-05 21:54:09,468	INFO streaming_executor.py:92 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(convert_batch_to_numpy)]
2023-09-05 21:54:09,469	INFO streaming_executor.py:93 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
2023-09-05 21:54:09,470	INFO streaming_executor.py:95 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`


Running 0:   0%|          | 0/200 [00:00<?, ?it/s]

2023-09-05 21:54:10,678	INFO streaming_executor.py:92 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(convert_batch_to_numpy)]
2023-09-05 21:54:10,678	INFO streaming_executor.py:93 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
2023-09-05 21:54:10,678	INFO streaming_executor.py:95 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`


Running 0:   0%|          | 0/200 [00:00<?, ?it/s]



In [3]:
train_dataset

MaterializedDataset(
   num_blocks=200,
   num_rows=50000,
   schema={image: numpy.ndarray(shape=(32, 32, 3), dtype=uint8), label: int64}
)

## Train a convolutional neural network

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [5]:
from ray import train
from ray.air import session, Checkpoint
from ray.train.torch import TorchCheckpoint
import torch.nn as nn
import torch.optim as optim
import torchvision


def train_loop_per_worker(config):
    model = train.torch.prepare_model(Net())

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    train_dataset_shard = session.get_dataset_shard("train")

    for epoch in range(2):
        running_loss = 0.0
        train_dataset_batches = train_dataset_shard.iter_torch_batches(
            batch_size=config["batch_size"],
        )
        for i, batch in enumerate(train_dataset_batches):
            # get the inputs and labels
            inputs, labels = batch["image"], batch["label"]

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                running_loss = 0.0

        metrics = dict(running_loss=running_loss)
        checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())
        session.report(metrics, checkpoint=checkpoint)

In [6]:
from ray.data.preprocessors import TorchVisionPreprocessor

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
preprocessor = TorchVisionPreprocessor(columns=["image"], transform=transform)

In [7]:
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"batch_size": 2},
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=2),
    preprocessor=preprocessor
)
result = trainer.fit()
latest_checkpoint = result.checkpoint

0,1
Current time:,2023-09-05 21:55:55
Running for:,00:00:57.55
Memory:,9.7/16.0 GiB

Trial name,status,loc,iter,total time (s),running_loss
TorchTrainer_2eba9_00000,TERMINATED,127.0.0.1:27979,2,53.3442,625.659


[2m[36m(TorchTrainer pid=27979)[0m The `preprocessor` arg to Trainer is deprecated. Apply preprocessor transformations ahead of time by calling `preprocessor.transform(ds)`. Support for the preprocessor arg will be dropped in a future release.
[2m[36m(TorchTrainer pid=27979)[0m Starting distributed worker processes: ['27981 (127.0.0.1)', '27982 (127.0.0.1)']
[2m[36m(RayTrainWorker pid=27981)[0m Setting up process group for: env:// [rank=0, world_size=2]
[2m[36m(TorchTrainer pid=27979)[0m Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(TorchVisionPreprocessor._transform_numpy)] -> AllToAllOperator[RandomizeBlockOrder]
[2m[36m(TorchTrainer pid=27979)[0m Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
[2m[36m(TorchTrainer pid=27979)[0m Tip: For detailed progress reporting, run `r

(pid=27979) - RandomizeBlockOrder 1:   0%|          | 0/200 [00:00<?, ?it/s]

(pid=27979) Running 0:   0%|          | 0/200 [00:00<?, ?it/s]

[2m[36m(RayTrainWorker pid=27981)[0m Moving model to device: cpu
[2m[36m(RayTrainWorker pid=27981)[0m Wrapping provided model in DistributedDataParallel.


[2m[36m(RayTrainWorker pid=27982)[0m [1,  2000] loss: 2.244
[2m[36m(RayTrainWorker pid=27982)[0m [1,  6000] loss: 1.721[32m [repeated 4x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)[0m
[2m[36m(RayTrainWorker pid=27982)[0m [1, 10000] loss: 1.503[32m [repeated 4x across cluster][0m
[2m[36m(RayTrainWorker pid=27982)[0m [2,  2000] loss: 1.376[32m [repeated 4x across cluster][0m
[2m[36m(RayTrainWorker pid=27982)[0m [2,  6000] loss: 1.354[32m [repeated 4x across cluster][0m
[2m[36m(RayTrainWorker pid=27982)[0m [2, 10000] loss: 1.273[32m [repeated 4x across cluster][0m


2023-09-05 21:55:55,858	INFO tune.py:1148 -- Total run time: 57.58 seconds (57.55 seconds for the tuning loop).


In [8]:
from ray.train.torch import TorchPredictor
from ray.train.batch_predictor import BatchPredictor

batch_predictor = BatchPredictor.from_checkpoint(
    checkpoint=latest_checkpoint,
    predictor_cls=TorchPredictor,
    model=Net(),
)

outputs: ray.data.Dataset = batch_predictor.predict(
    data=test_dataset,
    dtype=torch.float,
    feature_columns=["image"],
    keep_columns=["label"],
)

In [9]:
import numpy as np

def convert_logits_to_classes(df):
    best_class = df["predictions"].map(lambda x: x.argmax())
    df["prediction"] = best_class
    return df[["prediction", "label"]]

predictions = outputs.map_batches(convert_logits_to_classes, batch_format="pandas")

predictions.show(1)

2023-09-05 22:00:18,508	INFO dataset.py:2180 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
2023-09-05 22:00:18,511	INFO streaming_executor.py:92 -- Executing DAG InputDataBuffer[Input] -> ActorPoolMapOperator[MapBatches(TorchVisionPreprocessor._transform_numpy)->MapBatches(ScoringWrapper)] -> TaskPoolMapOperator[MapBatches(convert_logits_to_classes)]
2023-09-05 22:00:18,511	INFO streaming_executor.py:93 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
2023-09-05 22:00:18,512	INFO streaming_executor.py:95 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`
2023-09-05 22:00:18,531	INFO actor_pool_map_operator.py:117 -- MapBatches(TorchVisionPreprocessor._transform_numpy)->MapBatches(Sco

Running 0:   0%|          | 0/200 [00:00<?, ?it/s]

{'prediction': 3, 'label': 3}


In [10]:
def calculate_prediction_scores(df):
    df["correct"] = df["prediction"] == df["label"]
    return df


scores = predictions.map_batches(calculate_prediction_scores, batch_format="pandas")

scores.show(1)

2023-09-05 22:00:50,895	INFO streaming_executor.py:92 -- Executing DAG InputDataBuffer[Input] -> ActorPoolMapOperator[MapBatches(TorchVisionPreprocessor._transform_numpy)->MapBatches(ScoringWrapper)] -> TaskPoolMapOperator[MapBatches(convert_logits_to_classes)->MapBatches(calculate_prediction_scores)]
2023-09-05 22:00:50,900	INFO streaming_executor.py:93 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
2023-09-05 22:00:50,903	INFO streaming_executor.py:95 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`
2023-09-05 22:00:50,928	INFO actor_pool_map_operator.py:117 -- MapBatches(TorchVisionPreprocessor._transform_numpy)->MapBatches(ScoringWrapper): Waiting for 1 pool actors to start...


Running 0:   0%|          | 0/200 [00:00<?, ?it/s]



{'prediction': 3, 'label': 3, 'correct': True}


In [11]:
scores.sum(on="correct") / scores.count()

2023-09-05 22:01:14,592	INFO streaming_executor.py:92 -- Executing DAG InputDataBuffer[Input] -> ActorPoolMapOperator[MapBatches(TorchVisionPreprocessor._transform_numpy)->MapBatches(ScoringWrapper)] -> TaskPoolMapOperator[MapBatches(convert_logits_to_classes)->MapBatches(calculate_prediction_scores)] -> AllToAllOperator[Aggregate]
2023-09-05 22:01:14,592	INFO streaming_executor.py:93 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
2023-09-05 22:01:14,593	INFO streaming_executor.py:95 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`
2023-09-05 22:01:14,612	INFO actor_pool_map_operator.py:117 -- MapBatches(TorchVisionPreprocessor._transform_numpy)->MapBatches(ScoringWrapper): Waiting for 1 pool actors to start...


- Aggregate 1:   0%|          | 0/200 [00:00<?, ?it/s]

Shuffle Map 2:   0%|          | 0/200 [00:00<?, ?it/s]

Shuffle Reduce 3:   0%|          | 0/200 [00:00<?, ?it/s]

Running 0:   0%|          | 0/200 [00:00<?, ?it/s]

2023-09-05 22:01:17,571	INFO streaming_executor.py:92 -- Executing DAG InputDataBuffer[Input] -> ActorPoolMapOperator[MapBatches(TorchVisionPreprocessor._transform_numpy)->MapBatches(ScoringWrapper)] -> TaskPoolMapOperator[MapBatches(convert_logits_to_classes)->MapBatches(calculate_prediction_scores)]
2023-09-05 22:01:17,580	INFO streaming_executor.py:93 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
2023-09-05 22:01:17,587	INFO streaming_executor.py:95 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`
2023-09-05 22:01:17,606	INFO actor_pool_map_operator.py:117 -- MapBatches(TorchVisionPreprocessor._transform_numpy)->MapBatches(ScoringWrapper): Waiting for 1 pool actors to start...


Running 0:   0%|          | 0/200 [00:00<?, ?it/s]

0.5506

## Deploy the network and make a prediction

In [12]:
from ray import serve
from ray.serve import PredictorDeployment
from ray.serve.http_adapters import json_to_ndarray


serve.run(
    PredictorDeployment.bind(
        TorchPredictor,
        latest_checkpoint,
        model=Net(),
        http_adapter=json_to_ndarray,
    )
)

[2m[36m(HTTPProxyActor pid=29621)[0m INFO:     Started server process [29621]
[2m[36m(ServeController pid=29619)[0m INFO 2023-09-05 22:02:40,557 controller 29619 deployment_state.py:1308 - Deploying new version of deployment default_PredictorDeployment.
[2m[36m(ServeController pid=29619)[0m INFO 2023-09-05 22:02:40,660 controller 29619 deployment_state.py:1571 - Adding 1 replica to deployment default_PredictorDeployment.
2023-09-05 22:02:42,475	INFO router.py:853 -- Using PowerOfTwoChoicesReplicaScheduler.
2023-09-05 22:02:42,483	INFO router.py:329 -- Got updated replicas for deployment default_PredictorDeployment: {'default_PredictorDeployment#DUcTXi'}.


RayServeSyncHandle(deployment='default_PredictorDeployment')

In [13]:
image = test_dataset.take(1)[0]["image"]

In [14]:
import requests

payload = {"array": image.tolist(), "dtype": "float32"}
response = requests.post("http://localhost:8000/", json=payload)
response.json()

{'predictions': [-110.8222427368164,
  1.1426030397415161,
  -103.1728286743164,
  165.62962341308594,
  -172.87025451660156,
  196.0765380859375,
  28.147117614746094,
  -22.12794303894043,
  92.32986450195312,
  -194.28038024902344]}

[2m[36m(ServeReplica:default_PredictorDeployment pid=29622)[0m INFO 2023-09-05 22:04:48,846 default_PredictorDeployment default_PredictorDeployment#DUcTXi fvIHieyWEi / default replica.py:723 - __CALL__ OK 15.5ms
