In [97]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from filelock import FileLock
import numpy as np
import ray
import sys
sys.path.append("../")
from consistent_hashing import ConsistentHash

def get_data_loader():
    
    """Safely downloads data. Returns training/validation set dataloader."""
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    # We add FileLock here because multiple workers will want to
    # download data, and this may cause overwrites since
    # DataLoader is not threadsafe.
    
    class MNISTEvenOddDataset(torch.utils.data.Dataset):
        def __init__(self, ready_data):
            self.img_data = ready_data.data
            self.labels = ready_data.targets % 2
        
        def __len__(self):
            return len(self.labels)
    
        def __getitem__(self, ind):
            return torch.true_divide(self.img_data[ind].view(-1, 28 * 28).squeeze(), 255), torch.tensor([self.labels[ind]])


    
    with FileLock(os.path.expanduser("~/data.lock")):
        
        train_dataset = datasets.MNIST(
                "~/data", train=True, download=True, transform=mnist_transforms
            )
        
        test_dataset = datasets.MNIST("~/data", train=False, transform=mnist_transforms)
        
        train_loader = torch.utils.data.DataLoader(
            MNISTEvenOddDataset(train_dataset),
            batch_size=128,
            shuffle=True,
        )
        test_loader = torch.utils.data.DataLoader(
             MNISTEvenOddDataset(test_dataset),
            batch_size=128,
            shuffle=False,
        )
    return train_loader, test_loader


def evaluate(model, test_loader):
    """Evaluates the accuracy of the model on a validation dataset."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            # This is only set to finish evaluation faster.
            if batch_idx * len(data) > 1024:
                break
            outputs = nn.Sigmoid()(model(data))
            #_, predicted = torch.max(outputs.data, 1)
            predicted = outputs > 0.5
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return 100.0 * correct / total

In [98]:
class LinearNet(nn.Module):
    """Small Linear Network for MNIST."""

    def __init__(self):
        super(LinearNet, self).__init__()
#         self.fc = nn.Linear(28*28, 1)
        self.fc_weights = nn.ParameterList([nn.Parameter(torch.empty(1)) for weight in range(784)])
        init_fc = [nn.init.normal_(x) for x in self.fc_weights]
#         self.fc_weights = torch.cat([nn.Parameter])
        self.fc_bias = nn.Parameter(torch.empty(1))
        nn.init.normal_(self.fc_bias)

    def forward(self, x):
        fc_layer = torch.cat(tuple(self.fc_weights)).unsqueeze(0)
        x = x @ fc_layer.T + self.fc_bias
        return x
    
    def get_weights(self):
        return {k: v.cpu() for k, v in self.state_dict().items()}

    def set_weights(self, keys, weights):  
        self.load_state_dict({keys[i]:weights[i] for i in range(len(keys))})
        
    def get_gradients(self, keys):
        grads = []
        for name, p in self.named_parameters():
            if name in keys: # this will return everything in the order of named_params we want this in order of keys
                grad = None if p.grad is None else p.grad.data.cpu().numpy()
                grads.append(grad)
        return grads

    def set_gradients(self, gradients):
        for g, p in zip(gradients, self.parameters()):
            if g is not None:
                p.grad = torch.from_numpy(g)

In [99]:
@ray.remote  
class ParameterServer(object):
    def __init__(self, keys, values):
        values = [value.clone().detach() for value in values]
        self.weights = dict(zip(keys, values))

    def apply_gradients(self, keys, lr, *values):

        summed_gradients = [
            np.stack(gradient_zip).sum(axis=0) for gradient_zip in zip(*values)
        ]
    
        idx = 0
        for key, value in zip(keys, summed_gradients):
            self.weights[key] -= lr * torch.from_numpy(summed_gradients[idx])
            idx+=1
            
        return [self.weights[key] for key in keys]

    def get_weights(self, keys):
        return [self.weights[key] for key in keys]


In [100]:
@ray.remote
class DataWorker(object):
    def __init__(self):
        self.model = LinearNet()
        self.data_iterator = iter(get_data_loader()[0])
        
    def update_weights(self, keys, weights):
        self.model.set_weights(keys, weights)

    def compute_gradients(self, keys):
        #self.model.set_weights(keys, weights)
        try:
            data, target = next(self.data_iterator)
        except StopIteration:  # When the epoch ends, start a new epoch.
            self.data_iterator = iter(get_data_loader()[0])
            data, target = next(self.data_iterator)
        
        self.model.zero_grad()
        output = self.model(data)
        loss = nn.BCEWithLogitsLoss()(output, target.float())
        loss.backward()
        
        return self.model.get_gradients(keys)

In [103]:
iterations = 500
num_workers = 1 # number of workers per server
num_servers = 1 # number of servers
hashes_per_server = 50

def Scheduler(num_servers, hashes_per_server=50):
    
    model = LinearNet()
    key_values = model.get_weights()
    keys = np.array(list(key_values.keys()))
    values = np.array([key_values[key] for key in keys])
    key_indices = {key: x for x, key in enumerate(keys)}
    # assuming keys are ordered
    #approx_partition = len(keys)//num_servers
    #servers = [ParameterServer.remote(1e-2, keys[i*approx_partition,(i+1)*approx_partition], values[i*approx_partition,(i+1)*approx_partition]) for i in range(num_servers)]

    # distributing weights across servers - do this using consistency hashing
    server_ids = ["server" + str(ind) for ind in range(num_servers)]
    hasher = ConsistentHash(keys, server_ids, hashes_per_server)
    servers = [ParameterServer.remote(keys[[key_indices[key] for key in hasher.get_keys_per_node()[serv]]], 
                                      values[[key_indices[key] for key in hasher.get_keys_per_node()[serv]]]) for serv in server_ids]
#     servers = [ParameterServer.remote(keys[0:1], values[0:1]), ParameterServer.remote(keys[1:2], values[1:2])]
    
    return servers, keys, model, hasher.get_keys_per_node()

servers, keys, model, weight_assignments =  Scheduler(num_servers, hashes_per_server)
ray.init(ignore_reinit_error=True)

# creating equal workers per server

workers = [[DataWorker.remote() for i in range(num_workers)] for j in range(num_servers)]





  # This is added back by InteractiveShellApp.init_path()
  # This is added back by InteractiveShellApp.init_path()
2022-05-06 12:37:50,250	INFO worker.py:963 -- Calling ray.init() again after it has already been called.


In [104]:
len(weight_assignments["server0"])

785

In [105]:
test_loader = get_data_loader()[1]

In [None]:
print("Running synchronous parameter server training.")
lr=0.01

# we need to get a new keys order because we are not assuming a ordering in keys
current_weights = []
keys_order = []

for j in range(num_servers):
    keys_order.extend(weight_assignments["server" + str(j)])
    current_weights.extend(ray.get(servers[j].get_weights.remote(weight_assignments["server" + str(j)]))) 
                       
for i in range(iterations):
    
    # sync all weights on workers
    if i % 20 == 0:
        current_weights = []
        keys_order = []

        # get weights from server
        for j in range(num_servers):
            keys_order.extend(weight_assignments["server" + str(j)])
            current_weights.extend(ray.get(servers[j].get_weights.remote(weight_assignments["server" + str(j)]))) 
   
        # update weights on all workers
        for j in range(num_servers):
            for  idx  in range(num_workers):
                workers[j][idx].update_weights.remote(keys_order, current_weights)
    
    
    # use local cache of weights and get gradients from workers
    gradients = [[workers[j][idx].compute_gradients.remote(weight_assignments["server" + str(j)]) for  idx  in range(num_workers)] for j in range(num_servers)]
    
    # Updates gradients to specfic parameter servers
    [servers[j].apply_gradients.remote(weight_assignments["server" + str(j)], lr, *gradients[j]) for j in range(num_servers)]
           
    if i % 10 == 0:
        # Evaluate the current model.

        for j in range(num_servers):
            keys_order.extend(weight_assignments["server" + str(j)])
            current_weights.extend(ray.get(servers[j].get_weights.remote(weight_assignments["server" + str(j)]))) 
   
        # we are once again using the server to key mapping to set the weight back
        model.set_weights(keys_order, current_weights)
        accuracy = evaluate(model, test_loader)
        print("Iter {}: \taccuracy is {:.1f}".format(i, accuracy))

print("Final accuracy is {:.1f}.".format(accuracy))
# Clean up Ray resources and processes before the next example.
ray.shutdown()


Running synchronous parameter server training.
Iter 0: 	accuracy is 50.0
Iter 10: 	accuracy is 49.5
Iter 20: 	accuracy is 48.5
Iter 30: 	accuracy is 49.5
Iter 40: 	accuracy is 50.3
Iter 50: 	accuracy is 51.0
Iter 60: 	accuracy is 52.0
Iter 70: 	accuracy is 52.8
Iter 80: 	accuracy is 53.3
Iter 90: 	accuracy is 54.0
Iter 100: 	accuracy is 55.0
Iter 110: 	accuracy is 55.9
Iter 120: 	accuracy is 56.9
Iter 130: 	accuracy is 57.0
Iter 140: 	accuracy is 58.1
Iter 150: 	accuracy is 58.2
Iter 160: 	accuracy is 58.8
Iter 170: 	accuracy is 59.0
Iter 180: 	accuracy is 60.0
Iter 190: 	accuracy is 60.4
Iter 200: 	accuracy is 60.9
Iter 210: 	accuracy is 61.3
Iter 220: 	accuracy is 61.9
Iter 230: 	accuracy is 62.7
Iter 240: 	accuracy is 62.9
Iter 250: 	accuracy is 63.3
Iter 260: 	accuracy is 63.5
Iter 270: 	accuracy is 63.9
Iter 280: 	accuracy is 64.4
Iter 290: 	accuracy is 64.6
Iter 300: 	accuracy is 64.8
Iter 310: 	accuracy is 64.9
Iter 320: 	accuracy is 65.0
Iter 330: 	accuracy is 65.2
Iter 340: 	a

In [38]:
ray.get(ps[0].get_weights.remote(keys))

RayTaskError(KeyError): [36mray::ParameterServer.get_weights()[39m (pid=37281, ip=10.112.2.8, repr=<__main__.ParameterServer object at 0x7efed8235c90>)
  File "/tmp/ipykernel_38318/1549726583.py", line 21, in get_weights
  File "/tmp/ipykernel_38318/1549726583.py", line 21, in <listcomp>
KeyError: 0

In [72]:
model = LinearNet()

In [83]:
torch.cat(tuple(model.fc_weights)).unsqueeze(0).shape

torch.Size([1, 784])

In [50]:
for x in test_loader:
    print(x[0].shape)

torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])
torch.Size(

In [25]:
weight_assignments

{'server0': ['fc_bias',
  'fc_weights.1',
  'fc_weights.2',
  'fc_weights.3',
  'fc_weights.6',
  'fc_weights.7',
  'fc_weights.9',
  'fc_weights.14',
  'fc_weights.15',
  'fc_weights.16',
  'fc_weights.18',
  'fc_weights.20',
  'fc_weights.21',
  'fc_weights.25',
  'fc_weights.26',
  'fc_weights.28',
  'fc_weights.30',
  'fc_weights.32',
  'fc_weights.38',
  'fc_weights.39',
  'fc_weights.40',
  'fc_weights.41',
  'fc_weights.43',
  'fc_weights.45',
  'fc_weights.46',
  'fc_weights.47',
  'fc_weights.48',
  'fc_weights.49',
  'fc_weights.50',
  'fc_weights.51',
  'fc_weights.52',
  'fc_weights.53',
  'fc_weights.54',
  'fc_weights.56',
  'fc_weights.57',
  'fc_weights.60',
  'fc_weights.62',
  'fc_weights.63',
  'fc_weights.64',
  'fc_weights.65',
  'fc_weights.66',
  'fc_weights.72',
  'fc_weights.73',
  'fc_weights.74',
  'fc_weights.75',
  'fc_weights.76',
  'fc_weights.77',
  'fc_weights.80',
  'fc_weights.81',
  'fc_weights.83',
  'fc_weights.84',
  'fc_weights.85',
  'fc_weights