In [1]:
#install all dependencies
!pip install filelock
!pip install torch --no-cache-dir
!pip install torchvision
!pip install tqdm



In [2]:
import os
from filelock import FileLock
from typing import Dict

In [3]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Normalize
from tqdm import tqdm

In [4]:
!pip install "ray==2.7.1"
!pip install ray[train]



In [5]:
import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

In [6]:
def get_dataloaders(batch_size):
    # Transform to normalize the input images
    transform = transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))])

    with FileLock(os.path.expanduser("~/data.lock")):
        # Download training data from open datasets.
        training_data = datasets.FashionMNIST(
            root="~/data",
            train=True,
            download=True,
            transform=transform,
        )

        # Download test data from open datasets.
        test_data = datasets.FashionMNIST(
            root="~/data",
            train=False,
            download=True,
            transform=transform,
        )

    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    return train_dataloader, test_dataloader

In [7]:
# Model Definition
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


In [8]:
def train_func_per_worker(config: Dict):
    lr = config["lr"]
    epochs = config["epochs"]
    batch_size = config["batch_size_per_worker"]

    # Get dataloaders inside worker training function
    train_dataloader, test_dataloader = get_dataloaders(batch_size=batch_size)

    # [1] Prepare Dataloader for distributed training
    # Shard the datasets among workers and move batches to the correct device
    # =======================================================================
    train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = ray.train.torch.prepare_data_loader(test_dataloader)

    model = NeuralNetwork()

    # [2] Prepare and wrap your model with DistributedDataParallel
    # Move the model the correct GPU/CPU device
    # ============================================================
    model = ray.train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Model training loop
    for epoch in range(epochs):
        model.train()
        for X, y in tqdm(train_dataloader, desc=f"Train Epoch {epoch}"):
            pred = model(X)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        test_loss, num_correct, num_total = 0, 0, 0
        with torch.no_grad():
            for X, y in tqdm(test_dataloader, desc=f"Test Epoch {epoch}"):
                pred = model(X)
                loss = loss_fn(pred, y)

                test_loss += loss.item()
                num_total += y.shape[0]
                num_correct += (pred.argmax(1) == y).sum().item()

        test_loss /= len(test_dataloader)
        accuracy = num_correct / num_total

        # [3] Report metrics to Ray Train
        # ===============================
        ray.train.report(metrics={"loss": test_loss, "accuracy": accuracy})

In [9]:
def train_fashion_mnist(num_workers=2, use_gpu=False):
    global_batch_size = 32

    train_config = {
        "lr": 1e-3,
        "epochs": 10,
        "batch_size_per_worker": global_batch_size // num_workers,
    }

    # Configure computation resources
    scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)

    # Initialize a Ray TorchTrainer
    trainer = TorchTrainer(
        train_loop_per_worker=train_func_per_worker,
        train_loop_config=train_config,
        scaling_config=scaling_config,
    )

    # [4] Start Distributed Training
    # Run `train_func_per_worker` on all workers
    # =============================================
    result = trainer.fit()
    print(f"Training result: {result}")


In [10]:
!pip install ray[client]



In [None]:
if __name__ == "__main__":
    train_fashion_mnist(num_workers=4, use_gpu=False)

0,1
Current time:,2023-10-12 19:34:13
Running for:,00:05:38.02
Memory:,7.3/31.3 GiB

Trial name,status,loc,iter,total time (s),loss,accuracy
TorchTrainer_898ce_00000,RUNNING,192.168.14.59:32528,9,308.181,0.374307,0.8656


[2m[36m(TorchTrainer pid=32528)[0m Starting distributed worker processes: ['32592 (192.168.14.59)', '32593 (192.168.14.59)', '32594 (192.168.14.59)', '32595 (192.168.14.59)']
[2m[36m(RayTrainWorker pid=32592)[0m Setting up process group for: env:// [rank=0, world_size=4]


[2m[36m(RayTrainWorker pid=32593)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
[2m[36m(RayTrainWorker pid=32593)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /usr/local/google/home/ryanaoleary/data/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]
  0%|          | 32768/26421880 [00:00<01:51, 235700.28it/s]
  0%|          | 65536/26421880 [00:00<01:53, 232615.43it/s]
  0%|          | 131072/26421880 [00:00<01:17, 337292.21it/s]
  1%|          | 196608/26421880 [00:00<01:08, 384727.19it/s]
  1%|▏         | 393216/26421880 [00:00<00:34, 744231.59it/s]
  3%|▎         | 786432/26421880 [00:00<00:17, 1432676.55it/s]
  6%|▌         | 1572864/26421880 [00:00<00:08, 2773498.33it/s]
 12%|█▏        | 3145728/26421880 [00:01<00:04, 5411421.91it/s]
 24%|██▎       | 6258688/26421880 [00:01<00:01, 10568976.93it/s]
 35%|███▍      | 9175040/26421880 [00:01<00:01, 13593256.86it/s]
 46%|████▌     | 12124160/26421880 [00:01<00:00, 15776674.70it/s]
 57%|█████▋    | 14974976/26421880 [00:01<00:00, 17067137.08it/s]
 68%|██████▊   | 17924096/26421880 [00:01<00:00, 18168399.72it/s]
 79%|███████▉  | 20873216/26421880 [00:01<00:00, 18948893.35it/s]
100%|██████████| 26421880/26421880 [00:02<00:00, 12075432.14it

[2m[36m(RayTrainWorker pid=32593)[0m Extracting /usr/local/google/home/ryanaoleary/data/FashionMNIST/raw/train-images-idx3-ubyte.gz to /usr/local/google/home/ryanaoleary/data/FashionMNIST/raw
[2m[36m(RayTrainWorker pid=32593)[0m 
[2m[36m(RayTrainWorker pid=32593)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
[2m[36m(RayTrainWorker pid=32593)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /usr/local/google/home/ryanaoleary/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]
100%|██████████| 29515/29515 [00:00<00:00, 211426.89it/s]


[2m[36m(RayTrainWorker pid=32593)[0m Extracting /usr/local/google/home/ryanaoleary/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /usr/local/google/home/ryanaoleary/data/FashionMNIST/raw
[2m[36m(RayTrainWorker pid=32593)[0m 
[2m[36m(RayTrainWorker pid=32593)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
[2m[36m(RayTrainWorker pid=32593)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /usr/local/google/home/ryanaoleary/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]
  1%|          | 32768/4422102 [00:00<00:18, 232292.24it/s]
  1%|▏         | 65536/4422102 [00:00<00:18, 231932.98it/s]
  3%|▎         | 131072/4422102 [00:00<00:12, 337310.67it/s]
  5%|▌         | 229376/4422102 [00:00<00:08, 478954.65it/s]
 11%|█         | 491520/4422102 [00:00<00:04, 974134.95it/s]
 21%|██▏       | 950272/4422102 [00:00<00:01, 1745714.91it/s]
 43%|████▎     | 1900544/4422102 [00:00<00:00, 3369487.90it/s]


[2m[36m(RayTrainWorker pid=32593)[0m Extracting /usr/local/google/home/ryanaoleary/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /usr/local/google/home/ryanaoleary/data/FashionMNIST/raw
[2m[36m(RayTrainWorker pid=32593)[0m 
[2m[36m(RayTrainWorker pid=32593)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 3902105.51it/s]


[2m[36m(RayTrainWorker pid=32593)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /usr/local/google/home/ryanaoleary/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
[2m[36m(RayTrainWorker pid=32593)[0m Extracting /usr/local/google/home/ryanaoleary/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /usr/local/google/home/ryanaoleary/data/FashionMNIST/raw
[2m[36m(RayTrainWorker pid=32593)[0m 


100%|██████████| 5148/5148 [00:00<00:00, 37617207.30it/s]
Train Epoch 0:   0%|          | 0/1875 [00:00<?, ?it/s]
[2m[36m(RayTrainWorker pid=32592)[0m Moving model to device: cpu
[2m[36m(RayTrainWorker pid=32592)[0m Wrapping provided model in DistributedDataParallel.
Train Epoch 0:   0%|          | 2/1875 [00:00<01:52, 16.60it/s]
Train Epoch 0:   0%|          | 0/1875 [00:00<?, ?it/s][32m [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)[0m
Train Epoch 0:  17%|█▋        | 316/1875 [00:05<00:25, 60.66it/s][32m [repeated 180x across cluster][0m
Train Epoch 0:  34%|███▍      | 637/1875 [00:10<00:19, 64.22it/s][32m [repeated 184x across cluster][0m
Train Epoch 0:  34%|███▍      | 645/1875 [00:10<00:18, 65.59it/s]
Train Epoch 0:  51%|█████     | 953/1875 [00:15<00:14, 63.03it/s][32m [repeated 178x across cl