In [None]:
%matplotlib inline


# Parameter Server

The parameter server is a framework for distributed machine learning training.

In the parameter server framework, a centralized server (or group of server nodes) maintains the globally shared parameters of a machine-learning model (e.g., a neural network), while the data and the computation of updates (i.e., gradient descent updates) are distributed over worker nodes.

<img src="https://docs.ray.io/en/master/_images/param_actor1.png" align="center">

Parameter servers are a core part of many machine learning applications. This document walks through how to implement simple synchronous and asynchronous parameter servers using the Lithops multiprocessing API (a Manager in the first part, Processes and Queues in the second).

Let's first define some helper functions and import some dependencies.

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np

from lithops import multiprocessing as mp
from lithops.multiprocessing.managers import SyncManager

In [None]:
def get_data_loader():
    """Build the MNIST dataloaders, downloading the training split if needed.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Both loaders shuffle their
        dataset and yield batches of 128 normalized samples.
    """
    data_root = "~/data"
    loader_options = dict(batch_size=128, shuffle=True)
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])

    # Only the training dataset is created with download=True; the test
    # dataset is opened afterwards from the same root directory.
    train_set = datasets.MNIST(
        data_root, train=True, download=True, transform=preprocessing)
    train_loader = torch.utils.data.DataLoader(train_set, **loader_options)

    test_set = datasets.MNIST(
        data_root, train=False, transform=preprocessing)
    test_loader = torch.utils.data.DataLoader(test_set, **loader_options)

    return train_loader, test_loader


def evaluate(model, test_loader):
    """Measure the classification accuracy of ``model`` on ``test_loader``.

    Evaluation stops early once roughly 1024 samples have been processed so
    that it finishes quickly. The model is switched to eval mode and is left
    in that mode afterwards.

    Args:
        model: Module mapping a batch of inputs to per-class scores of shape
            ``(batch, num_classes)``.
        test_loader: Iterable yielding ``(data, target)`` batches.

    Returns:
        The accuracy as a percentage. ``0.0`` is returned when the loader
        yields no samples, instead of failing with ``ZeroDivisionError``.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            # This is only set to finish evaluation faster.
            if batch_idx * len(data) > 1024:
                break
            outputs = model(data)
            # Gradients are disabled here, so the raw outputs can be used
            # directly; the predicted class is the index of the top score.
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    if total == 0:
        # Nothing was evaluated: report zero accuracy rather than dividing by 0.
        return 0.0
    return 100. * correct / total

We define a small neural network to use in training. We provide some helper functions for obtaining data, including getter/setter methods for gradients and weights.

In [None]:
class ConvNet(nn.Module):
    """Compact convolutional classifier used for the MNIST examples."""

    def __init__(self):
        super().__init__()
        # A single 3x3 convolution (1 -> 3 channels) feeding a linear layer
        # that maps the 192 pooled features to 10 classes.
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of inputs."""
        pooled = F.max_pool2d(self.conv1(x), 3)
        activated = F.relu(pooled)
        flattened = activated.view(-1, 192)
        logits = self.fc(flattened)
        return F.log_softmax(logits, dim=1)

    def get_weights(self):
        """Return the state dict with every tensor moved to the CPU."""
        cpu_state = {}
        for name, tensor in self.state_dict().items():
            cpu_state[name] = tensor.cpu()
        return cpu_state

    def set_weights(self, weights):
        """Load ``weights`` (a state dict) into this model."""
        self.load_state_dict(weights)

    def get_gradients(self):
        """Return one NumPy gradient per parameter (``None`` where unset)."""
        return [
            None if param.grad is None else param.grad.data.cpu().numpy()
            for param in self.parameters()
        ]

    def set_gradients(self, gradients):
        """Assign NumPy ``gradients`` to the parameters, skipping ``None``."""
        for grad, param in zip(gradients, self.parameters()):
            if grad is None:
                continue
            param.grad = torch.from_numpy(grad)

## Parameter Server using Manager

We will define a class, shared through a Manager, that holds a copy of the model. During training, it receives gradients from the workers, applies them to its model, and sends the updated weights back to the workers.



In [None]:
class ParameterServer:
    """Owns the global model and updates it with gradients from workers."""

    def __init__(self, lr):
        """Create a fresh ConvNet and an SGD optimizer with learning rate ``lr``."""
        self.model = ConvNet()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)

    def apply_gradients(self, *gradients):
        """Sum the given gradient lists and take one optimizer step.

        Args:
            *gradients: One gradient list per contributor; each list holds
                one array per model parameter, in parameter order.

        Returns:
            The model weights after the step.
        """
        summed_gradients = []
        # zip(*gradients) groups together the arrays that belong to the same
        # parameter across all contributors.
        for per_parameter in zip(*gradients):
            summed_gradients.append(np.stack(per_parameter).sum(axis=0))

        self.optimizer.zero_grad()
        self.model.set_gradients(summed_gradients)
        self.optimizer.step()
        return self.model.get_weights()

    def get_weights(self):
        """Return the current weights of the global model."""
        return self.model.get_weights()

In [None]:
# Make ParameterServer available as a type that SyncManager instances can create.
SyncManager.register('ParameterServer', ParameterServer)

man = SyncManager()
# Shared parameter server (learning rate 1e-2) and the lock that callers hold
# around every call to it.
ps = man.ParameterServer(1e-2)
ps_lock = man.Lock()

# NOTE(review): start() runs after the shared objects were created. The
# standard-library multiprocessing managers expect start() to come first;
# confirm that this order is what the Lithops manager intends.
man.start()

The worker will also hold a copy of the model. During training, it will continuously evaluate data and send gradients to the parameter server. The worker will synchronize its model with the Parameter Server model weights.



In [None]:
def compute_gradients(parameter_server, lock):
    """Run one training step and push the resulting gradients to the server.

    The function builds a local ConvNet, loads the server's current weights
    into it, computes the NLL loss on a single training batch and sends the
    gradients of that loss to the parameter server.

    Args:
        parameter_server: Object exposing ``get_weights()`` and
            ``apply_gradients(gradients)``.
        lock: Context manager held around every call to ``parameter_server``.
    """
    model = ConvNet()
    train_loader, _ = get_data_loader()

    with lock:
        weights = parameter_server.get_weights()
    model.set_weights(weights)

    # Exactly one batch is read from a freshly created iterator, so there is
    # no epoch boundary to handle here: the iterator can only be exhausted if
    # the dataset itself is empty, and rebuilding it would not help.
    data, target = next(iter(train_loader))

    model.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    gradients = model.get_gradients()

    with lock:
        parameter_server.apply_gradients(gradients)

We will now use a Pool to spawn workers that interact with the manager:

In [None]:
# Number of training rounds run by the driver loop.
iterations = 5
# Number of compute_gradients tasks submitted to the pool in each round.
num_workers = 25

In [None]:
print("Running synchronous parameter server training.")

# Local model used only to evaluate the weights held by the parameter server.
model = ConvNet()
test_loader = get_data_loader()[1]

# The context manager releases the pool once training is over instead of
# leaving it open for the rest of the session.
with mp.Pool() as pool:
    for i in range(iterations):
        # Each round submits num_workers gradient computations and waits for
        # all of them to finish before reading the weights back.
        pool.starmap(compute_gradients, [(ps, ps_lock)] * num_workers)
        with ps_lock:
            current_weights = ps.get_weights()

        # Evaluate the current model.
        model.set_weights(current_weights)
        accuracy = evaluate(model, test_loader)
        print("Iter {}: \taccuracy is {:.1f}".format(i, accuracy))

print("Final accuracy is {:.1f}.".format(accuracy))

## Parameter Server using Processes and Queues

In [None]:
def parameter_server(lr, server_queue, worker_queues, iterations):
    """Run the queue-based parameter server loop.

    In every iteration the current weights are sent to each worker queue, one
    gradient list per worker is collected from ``server_queue``, the gradients
    are summed and applied with SGD, and the accuracy is printed.

    Args:
        lr: Learning rate of the SGD optimizer.
        server_queue: Queue on which workers put their gradient lists.
        worker_queues: One queue per worker, used to send weights and the
            final ``None`` stop signal.
        iterations: Number of training rounds to run.

    Raises:
        Whatever ``server_queue.get(timeout=15)`` raises when no gradient
        arrives in time; the stop signal is still sent to the workers first.
    """
    try:
        model = ConvNet()
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        test_loader = get_data_loader()[1]

        for i in range(iterations):
            print(f'Iteration {i}', flush=True)
            weights = model.get_weights()

            for q in worker_queues:
                print('put', flush=True)
                q.put(weights)

            gradients = []
            for _ in range(len(worker_queues)):
                gradient = server_queue.get(timeout=15)
                # Log only after a gradient has actually been received.
                print('Received gradient', flush=True)
                gradients.append(gradient)

            # Group the arrays by parameter and add up the contributions of
            # all workers.
            summed_gradients = [
                np.stack(gradient_zip).sum(axis=0)
                for gradient_zip in zip(*gradients)
            ]

            optimizer.zero_grad()
            model.set_gradients(summed_gradients)
            optimizer.step()
            accuracy = evaluate(model, test_loader)
            print(f'Accuracy is {accuracy}', flush=True)
    finally:
        # Always tell the workers to stop, even if the loop above failed;
        # otherwise they would keep blocking on their queues.
        for q in worker_queues:
            print('terminate', flush=True)
            q.put(None)

In [None]:
def worker(queue, param_server_queue):
    """Compute gradients for every set of weights received on ``queue``.

    Args:
        queue: Queue delivering model weights; ``None`` means stop.
        param_server_queue: Queue on which the computed gradients are put.
    """
    model = ConvNet()
    batches = iter(get_data_loader()[0])

    received = queue.get(timeout=35)
    while received is not None:
        model.set_weights(received)

        try:
            data, target = next(batches)
        except StopIteration:  # Epoch exhausted: rebuild the loader and go on.
            batches = iter(get_data_loader()[0])
            data, target = next(batches)

        model.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        param_server_queue.put(model.get_gradients())

        # Wait for the next weights (or the stop signal).
        received = queue.get(timeout=35)

In [None]:
# Settings for the queue-based run.
# Number of worker queues (one worker task per queue).
num_workers = 25
# Number of rounds executed by parameter_server.
iterations = 5
# Learning rate handed to the SGD optimizer in parameter_server.
lr = 1e-2

In [None]:
# Enable the STREAM_STDOUT option of Lithops multiprocessing. Presumably this
# forwards print() output of remote processes to the notebook; confirm in the
# Lithops documentation.
mp.config.set_parameter(mp.config.STREAM_STDOUT, True)

In [None]:
# Single queue on which every worker puts its gradients for the server.
server_q = mp.Queue()
# One queue per worker; the server uses it to send weights and, at the end,
# the None stop signal to that worker.
worker_queues = [mp.Queue() for _ in range(num_workers)]

In [None]:
# Run the parameter server loop in its own process so that the workers can be
# launched from the next cell while it is running.
param_server = mp.Process(target=parameter_server, args=(lr, server_q, worker_queues, iterations))
param_server.start()

In [None]:
# Launch one worker task per worker queue and block until all of them return.
# NOTE(review): the server waits for one gradient per queue in every round, so
# all worker tasks need to run at the same time; confirm the default pool
# size allows num_workers concurrent tasks.
with mp.Pool() as p:
    p.starmap(worker, [(q, server_q) for q in worker_queues])

In [None]:
# Wait for the parameter server process to finish.
param_server.join()