# Get Started with Distributed Training using TensorFlow/Keras
Ray Train’s TensorFlow integration enables you to scale your TensorFlow and Keras training functions to many machines and GPUs by configuring `TF_CONFIG` and managing worker processes for you.

Under the hood, Ray Train schedules your training workers and sets `TF_CONFIG` for each of them, which allows you to run your `MultiWorkerMirroredStrategy` training script. See TensorFlow's "Distributed training with TensorFlow" guide for more information.

Most of the examples in this guide use TensorFlow with Keras, but Ray Train also works with vanilla TensorFlow.

## Quickstart

In [0]:
import ray
import tensorflow as tf

from ray import train
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
from ray.train.tensorflow.keras import ReportCheckpointCallback


# Flip this to True to train on GPUs; it is forwarded to `ScalingConfig` below.
use_gpu = False

# NOTE(review): `a`, `b` and `size` are not referenced anywhere in the visible
# code — confirm whether they are still needed.
a, b, size = 5, 10, 100


def build_model() -> tf.keras.Model:
    """Return an uncompiled Keras `Sequential` regression model.

    The network takes scalar inputs (``input_shape=()``), uses `Flatten` to add
    a feature dimension, expanding ``(batch_size,)`` to ``(batch_size, 1)``,
    and then applies a 10-unit `Dense` layer followed by a 1-unit `Dense`
    output layer.
    """
    layers = tf.keras.layers
    stack = [
        layers.InputLayer(input_shape=()),
        layers.Flatten(),
        layers.Dense(10),
        layers.Dense(1),
    ]
    return tf.keras.Sequential(stack)


def train_func(config: dict):
    """Per-worker training loop executed by `TensorflowTrainer`.

    Args:
        config: Hyperparameters. Recognized keys and their fallbacks are
            ``"lr"`` (1e-3), ``"batch_size"`` (64) and ``"epochs"`` (3).

    Returns:
        A list with one entry per epoch, each being the ``history`` attribute
        of the object returned by that epoch's ``fit`` call.
    """
    learning_rate = config.get("lr", 1e-3)
    per_step_batch = config.get("batch_size", 64)
    num_epochs = config.get("epochs", 3)

    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    # Both building and compiling the model must happen inside the scope.
    with strategy.scope():
        model = build_model()
        model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
            loss=tf.keras.losses.mean_squared_error,
            metrics=[tf.keras.metrics.mean_squared_error],
        )

    shard = train.get_dataset_shard("train")

    histories = []
    for _ in range(num_epochs):
        # A new TensorFlow dataset is derived from the shard on every epoch.
        epoch_data = shard.to_tf(
            feature_columns="x",
            label_columns="y",
            batch_size=per_step_batch,
        )
        report_callback = ReportCheckpointCallback()
        fit_outcome = model.fit(epoch_data, callbacks=[report_callback])
        histories.append(fit_outcome.history)
    return histories


# Hyperparameters handed to `train_func` on every worker.
config = dict(lr=1e-3, batch_size=32, epochs=4)

# Toy regression data: 200 rows with x in [0, 1) and y = 2 * x.
train_dataset = ray.data.from_items(
    [{"x": i / 200, "y": 2 * i / 200} for i in range(200)]
)

# One worker; GPU usage follows the `use_gpu` flag defined above.
scaling_config = ScalingConfig(use_gpu=use_gpu, num_workers=1)

trainer = TensorflowTrainer(
    train_loop_per_worker=train_func,
    train_loop_config=config,
    datasets={"train": train_dataset},
    scaling_config=scaling_config,
)

# Run training and show the metrics attached to the returned result.
result = trainer.fit()
print(result.metrics)

## Update your training function
Wrap your model building and compilation in a `MultiWorkerMirroredStrategy` scope:
```python
# Model building and compiling both need to happen inside `strategy.scope()`.
with tf.distribute.MultiWorkerMirroredStrategy().scope():
    model = build_model()
    model.compile(...)
```
Scale your per-worker batch size up to the global batch size by multiplying it by the number of workers:
```diff
- batch_size = worker_batch_size
+ batch_size = worker_batch_size * train.get_context().get_world_size()
```

## Create a TensorflowTrainer
Instantiate a `TensorflowTrainer` with the desired scaling configuration:

In [0]:
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer, TensorflowConfig

# For GPU training, set use_gpu=True.
trainer = TensorflowTrainer(
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(use_gpu=False, num_workers=1),
    # `train_func` fetches its data with `train.get_dataset_shard("train")`,
    # so the trainer must be given a dataset registered under the "train" key.
    datasets={"train": train_dataset},
    tensorflow_backend=TensorflowConfig(),  # optional custom backend config
)

## Run a training function

In [0]:
# Start the training run. The quickstart keeps the object returned by `fit()`
# so that it can read `result.metrics` afterwards.
trainer.fit()

## Load and preprocess data
Convert a Ray Dataset shard into a TensorFlow dataset:
```python
from ray import train
from ray.train.tensorflow import prepare_dataset_shard

def train_func(config: dict):
    """Build this worker's TensorFlow dataset from its Ray Dataset shard.

    Args:
        config: Training configuration; ``"batch_size"`` is read from it and
            falls back to 64 when absent.
    """
    batch_size = config.get("batch_size", 64)
    # Same accessor the quickstart training function uses.
    dataset_shard = train.get_dataset_shard('train')

    def to_tf_dataset(ds, batch_size):
        """Convert the shard `ds` to a TensorFlow dataset and prepare it."""
        tf_ds = ds.to_tf(
            feature_columns='image', label_columns='label', batch_size=batch_size
        )
        return prepare_dataset_shard(tf_ds)

    # Actually invoke the helper so that `tf_ds` exists in this scope.
    tf_ds = to_tf_dataset(dataset_shard, batch_size)
    # ... use tf_ds ...
```

## Report results
Use `ReportCheckpointCallback` to automatically report metrics and checkpoints:

In [0]:
from ray.train.tensorflow.keras import ReportCheckpointCallback

def train_func(config: dict):
    """Training loop that reports to Ray Train through a Keras callback.

    Args:
        config: Training configuration; ``config['epochs']`` sets how many
            times ``model.fit`` is called.
    """
    # ... (elided setup; presumably defines `model` and `dataset` — TODO confirm)
    for epoch in range(config['epochs']):
        # A fresh ReportCheckpointCallback is passed to every fit() call.
        history = model.fit(dataset, callbacks=[ReportCheckpointCallback()])
        # ...

## Save and load checkpoints

In [0]:
import tempfile
from ray import train
from ray.train import Checkpoint, ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

def train_func(config):
    """Train for ``config['num_epochs']`` epochs, checkpointing after each one.

    After every epoch the model is written to a temporary directory, wrapped
    in a `Checkpoint`, and reported to Ray Train together with the epoch loss.

    Args:
        config: Training configuration; must contain ``'num_epochs'``.
    """
    # ...
    for epoch in range(config['num_epochs']):
        # Train for one epoch so that `history` holds this epoch's loss.
        # `model` and `dataset` come from the elided setup above.
        history = model.fit(dataset)

        # Save the model and report it as this epoch's checkpoint.
        with tempfile.TemporaryDirectory() as tmp:
            model.save(f"{tmp}/model.keras")
            checkpoint = Checkpoint.from_directory(tmp)
            # Report inside the `with` block, while the directory still exists.
            train.report({'loss': history.history['loss'][0]}, checkpoint=checkpoint)

In [0]:
# Loading from checkpoint
def train_func(config):
    """Resume from the latest checkpoint if there is one, else start fresh.

    When `train.get_checkpoint()` returns a truthy checkpoint, the model is
    loaded from ``model.keras`` inside the checkpoint directory; otherwise a
    new model is created with `build_model()`. The model is compiled in
    either case.
    """
    restored = train.get_checkpoint()
    if not restored:
        # Nothing to resume from: build a brand-new model.
        model = build_model()
    else:
        with restored.as_directory() as ckpt_dir:
            model = tf.keras.models.load_model(f"{ckpt_dir}/model.keras")
    model.compile(...)
    # ...

## Further reading
- [Experiment tracking]
- [Fault tolerance and spot instances]
- [Hyperparameter optimization]