In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from filelock import FileLock
import numpy as np

import ray

def get_data_loader():
    
    """Safely downloads data. Returns training/validation set dataloader."""
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    # We add FileLock here because multiple workers will want to
    # download data, and this may cause overwrites since
    # DataLoader is not threadsafe.
    
    class MNISTEvenOddDataset(torch.utils.data.Dataset):
        def __init__(self, ready_data):
            self.img_data = ready_data.data
            self.labels = ready_data.targets % 2
        
        def __len__(self):
            return len(self.labels)
    
        def __getitem__(self, ind):
            return torch.true_divide(self.img_data[ind].view(-1, 28 * 28).squeeze(), 255), torch.tensor([self.labels[ind]])


    
    with FileLock(os.path.expanduser("~/data.lock")):
        
        train_dataset = datasets.MNIST(
                "~/data", train=True, download=True, transform=mnist_transforms
            )
        
        test_dataset = datasets.MNIST("~/data", train=False, transform=mnist_transforms)
        
        train_loader = torch.utils.data.DataLoader(
            MNISTEvenOddDataset(train_dataset),
            batch_size=128,
            shuffle=True,
        )
        test_loader = torch.utils.data.DataLoader(
             MNISTEvenOddDataset(test_dataset),
            batch_size=128,
            shuffle=False,
        )
    return train_loader, test_loader


def evaluate(model, test_loader):
    """Evaluates the accuracy of the model on a validation dataset."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            # This is only set to finish evaluation faster.
            if batch_idx * len(data) > 1024:
                break
            outputs = nn.Sigmoid()(model(data))
            #_, predicted = torch.max(outputs.data, 1)
            predicted = outputs > 0.5
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return 100.0 * correct / total

In [2]:
class LinearNet(nn.Module):
    """Small Linear Network for MNIST."""

    def __init__(self):
        super(LinearNet, self).__init__()
        self.fc = nn.Linear(28*28, 1)

    def forward(self, x):
        x = self.fc(x)
        return x
    
    def get_weights(self):
        return {k: v.cpu() for k, v in self.state_dict().items()}

    def set_weights(self, keys, weights):   
        self.load_state_dict({keys[i]:weights[i] for i in range(len(keys))})
        
    def get_gradients(self):
        grads = []
        for p in self.parameters():
            grad = None if p.grad is None else p.grad.data.cpu().numpy()
            grads.append(grad)
        return grads

    def set_gradients(self, gradients):
        for g, p in zip(gradients, self.parameters()):
            if g is not None:
                p.grad = torch.from_numpy(g)

In [3]:
 @ray.remote  
class ParameterServer(object):
    def __init__(self, keys, values):
        values = [value.clone().detach() for value in values]
        self.weights = dict(zip(keys, values))

    def apply_gradients(self, keys, lr, *values):

        summed_gradients = [
            np.stack(gradient_zip).sum(axis=0) for gradient_zip in zip(*values)
        ]
            

        idx = 0
        for key, value in zip(keys, summed_gradients):
            self.weights[key] -= lr * torch.from_numpy(summed_gradients[idx])
            idx+=1
            
        return [self.weights[key] for key in keys]

    def get_weights(self, keys):
        return [self.weights[key] for key in keys]


In [4]:
@ray.remote
class DataWorker(object):
    def __init__(self):
        self.model = LinearNet()
        self.data_iterator = iter(get_data_loader()[0])

    def compute_gradients(self, keys, weights):
        self.model.set_weights(keys, weights)
        try:
            data, target = next(self.data_iterator)
        except StopIteration:  # When the epoch ends, start a new epoch.
            self.data_iterator = iter(get_data_loader()[0])
            data, target = next(self.data_iterator)
        
        self.model.zero_grad()
        output = self.model(data)
        loss = nn.BCEWithLogitsLoss()(output, target.float())
        loss.backward()
        return self.model.get_gradients()

In [5]:
iterations = 100
num_workers = 5

def Scheduler(num_servers):
    model = LinearNet()
    key_values = model.get_weights()
    keys = list(key_values.keys())
    values = [key_values[key] for key in keys]
    
    servers = ParameterServer.remote(keys, values)
    
    #print(keys)
    # assuming keys are ordered
    #approx_partition = len(keys)//num_servers
    #servers = [ParameterServer.remote(1e-2, keys[i*approx_partition,(i+1)*approx_partition], values[i*approx_partition,(i+1)*approx_partition]) for i in range(num_servers)]
    return servers, keys, model

ps, keys, model =  Scheduler(3)
ray.init(ignore_reinit_error=True)
workers = [DataWorker.remote() for i in range(num_workers)]

2022-04-29 11:46:00,575	INFO worker.py:963 -- Calling ray.init() again after it has already been called.


In [6]:
test_loader = get_data_loader()[1]

In [7]:
print("Running synchronous parameter server training.")
lr=0.01
current_weights = ps.get_weights.remote(keys)
for i in range(iterations):
    gradients = [worker.compute_gradients.remote(keys, current_weights) for worker in workers]
    
    # Calculate update after all gradients are available.

    current_weights = ps.apply_gradients.remote(keys, lr, *gradients)

    if i % 10 == 0:
        # Evaluate the current model.

        model.set_weights(keys, ray.get(current_weights))
        accuracy = evaluate(model, test_loader)
        print("Iter {}: \taccuracy is {:.1f}".format(i, accuracy))

print("Final accuracy is {:.1f}.".format(accuracy))
# Clean up Ray resources and processes before the next example.
ray.shutdown()

Running synchronous parameter server training.
Iter 0: 	accuracy is 62.2
Iter 10: 	accuracy is 80.9
Iter 20: 	accuracy is 82.1
Iter 30: 	accuracy is 82.9
Iter 40: 	accuracy is 83.2
Iter 50: 	accuracy is 83.7
Iter 60: 	accuracy is 84.0
Iter 70: 	accuracy is 84.8
Iter 80: 	accuracy is 85.0
Iter 90: 	accuracy is 85.3
Final accuracy is 85.3.
