In [None]:
import ray
from ray.train import ScalingConfig, RunConfig, Checkpoint
import numpy as np
import os
import tempfile
import torch
from torch.optim import Adam
from torch.nn import BCEWithLogitsLoss
from ray.train.torch import TorchTrainer

# Train a Basic PyTorch Model with Ray Train + Ray Data

We'll start with the 2021 NYC taxi trip data, reading only the columns we need from a Parquet file on S3.

In [None]:
# Read only the columns we need. Their order matters downstream: `batch_vectorize`
# stacks a batch's columns in dict order (presumably the order listed here), and the
# training loop treats the last one, `is_big_tip`, as the label.
dataset = ray.data.read_parquet(
    "s3://anonymous@anyscale-training-data/intro-to-ray-air/nyc_taxi_2021.parquet",
    columns=[
        'passenger_count',
        'trip_distance',
        'fare_amount',
        'trip_duration',
        'hour',
        'day_of_week',
        'is_big_tip',
    ],
)

# Display the dataset's schema (column names and types).
dataset.schema()

In order to feed this data to typical model-fitting APIs, we need to collect each row's values (the six predictors plus the `is_big_tip` label, which we keep as the last column) into a single vector. Stacking those vectors gives us batches represented as matrices.

> Note: we could do this within the per-worker training function itself, converting each batch of data close to where we consume it for training. In many cases, though, that would violate separation of concerns and make it hard to change and reuse the data processing logic.

We're going to demonstrate how to do this with Ray Data because:
1. we might want to do other transformations or feature preprocessing in a data-only pipeline
1. it provides better separation of concerns and reuse
1. it allows us to better benefit from Ray's granular resource and scheduling capabilities

By default, `map_batches` hands our function each batch as a dict (column name → array) and expects a dict back.

An easy way to design the batch logic is to take a batch from the dataset and write code that works correctly for that batch; then we'll wrap it in a function.

In [None]:
# Pull a tiny sample batch (a dict of column name -> array) to prototype the batch logic on.
b = dataset.take_batch(batch_size=5)
b

Let's move this logic into a function, apply it to the whole dataset with `map_batches`, and check the output.

In [None]:
def batch_vectorize(batch):
    return { 'vectors' : np.vstack(list(batch.values()), dtype=np.float32).transpose() }

In [None]:
# Apply `batch_vectorize` to every batch of the dataset.
train_ds = dataset.map_batches(fn=batch_vectorize)

In [None]:
# Spot-check: each row should now be a single 7-element float32 'vectors' entry.
train_ds.take_batch(batch_size=4)

Looks good. Next we'll prepare a per-worker training function:
* Start with standard PyTorch code (model, loss, optimizer, loop)
* Use `ray.train.torch.prepare_model` to wrap the model for distributed training
* Use `get_dataset_shard` to obtain a source for data batches
    * If we want to implement validation (or anything else), we can pass additional datasets in the same way
    * The batches from `iter_torch_batches` will be dictionaries whose values are Torch tensors
* Supply arbitrary small values via `train_loop_config`
* Create checkpoints and call `train.report`
    * We can use whatever logic we like to decide on checkpointing (typically only from one worker and only every *n* epochs)
    * We must call `train.report` from each worker (even if that worker is not checkpointing or reporting any stats)

In [None]:
def train_func(config):
    """Per-worker training loop run by ``TorchTrainer`` on every Ray Train worker.

    Builds a small MLP, prepares it for distributed training, streams this
    worker's shard of the ``"train"`` dataset, and after every epoch reports the
    loss of the epoch's last batch (the rank-0 worker also attaches a checkpoint).

    Args:
        config: the ``train_loop_config`` dict. Required keys: ``"H1_width"`` and
            ``"H2_width"`` (hidden-layer widths). Optional keys, defaulting to the
            values this notebook originally hard-coded: ``"num_epochs"`` (5) and
            ``"batch_size"`` (65536).
    """
    from ray.train import get_dataset_shard

    num_epochs = config.get("num_epochs", 5)
    batch_size = config.get("batch_size", 65536)

    # Simple multilayer perceptron model.
    D_in = 6  # feature columns; the 7th (last) column of 'vectors' is the label
    H1 = config["H1_width"]
    H2 = config["H2_width"]
    D_out = 1  # one logit per row, consumed by BCEWithLogitsLoss

    model = torch.nn.Sequential(
        torch.nn.Linear(D_in, H1),
        torch.nn.ReLU(),
        torch.nn.Linear(H1, H2),
        torch.nn.ReLU(),
        torch.nn.Linear(H2, D_out),
    )

    # Prepare the model for distributed training. Depending on the number of
    # workers this may or may not wrap it (e.g. in DDP), so the code below never
    # assumes the returned object has a `.module` attribute.
    model = ray.train.torch.prepare_model(model)

    criterion = BCEWithLogitsLoss()
    optimizer = Adam(model.parameters())

    # This worker's shard of the "train" dataset, plus the device the model lives on,
    # so every batch is placed next to the model explicitly.
    train_shard = get_dataset_shard("train")
    device = ray.train.torch.get_device()

    # Simple torch training loop.
    for epoch in range(num_epochs):
        last_loss = None  # stays None if this worker's shard yields no batches
        # Create the batch iterable per epoch rather than reusing one object across epochs.
        batches = train_shard.iter_torch_batches(
            batch_size=batch_size, dtypes=torch.float, device=device
        )
        for batch in batches:
            vectors = batch['vectors']
            features = vectors[:, :-1]
            targets = vectors[:, -1, None]  # keep a trailing axis: (B, 1) like the model output
            outputs = model(features)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            last_loss = loss

        # Convert to a Python float once per epoch (not per batch) to avoid a
        # host-device sync on every step when training on GPU; NaN means that no
        # batches were seen this epoch.
        loss_value = last_loss.item() if last_loss is not None else float("nan")

        # Report (with metrics) and checkpoint.
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            checkpoint = None

            # In standard DDP training, where the model is the same across all ranks,
            # only the global rank 0 worker needs to save and report the checkpoint.
            if ray.train.get_context().get_world_rank() == 0:
                # Unwrap only when prepare_model actually wrapped the model (wrappers
                # such as DDP expose the original under `.module`); an unwrapped model
                # is saved as-is, so state_dict keys never carry a "module." prefix.
                base_model = getattr(model, "module", model)
                torch.save(
                    base_model.state_dict(),
                    os.path.join(temp_checkpoint_dir, "model.pt"),
                )
                checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            # Every worker must call report, even when it has no checkpoint to attach.
            ray.train.report({'loss': loss_value, 'epoch': epoch}, checkpoint=checkpoint)

In [None]:
# Two GPU workers; each one runs `train_func` on its own shard of `train_ds`.
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)

trainer = TorchTrainer(
    train_func,
    train_loop_config={"H1_width": 10, "H2_width": 5},
    datasets={"train": train_ds},
    scaling_config=scaling_config,
    # `RunConfig` is imported from ray.train; results and checkpoints go under this path.
    run_config=RunConfig(storage_path='/mnt/cluster_storage'),
)

result = trainer.fit()

In [None]:
# Display the Result object returned by `trainer.fit()` (its repr should include the
# last reported metrics and checkpoint — confirm against your Ray version).
result