In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from filelock import FileLock
import numpy as np
import ray
import sys
sys.path.append("../")
from consistent_hashing import ConsistentHash
import math 
from time import time
def get_data_loader():
    
    """Safely downloads data. Returns training/validation set dataloader."""
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    # We add FileLock here because multiple workers will want to
    # download data, and this may cause overwrites since
    # DataLoader is not threadsafe.
    
    class MNISTEvenOddDataset(torch.utils.data.Dataset):
        def __init__(self, ready_data):
            self.img_data = ready_data.data
            self.labels = ready_data.targets % 2
        
        def __len__(self):
            return len(self.labels)
    
        def __getitem__(self, ind):
            return torch.true_divide(self.img_data[ind].view(-1, 28 * 28).squeeze(), 255), torch.tensor([self.labels[ind]])


    
    with FileLock(os.path.expanduser("~/data.lock")):
        
        train_dataset = datasets.MNIST(
                "~/data", train=True, download=True, transform=mnist_transforms
            )
        
        test_dataset = datasets.MNIST("~/data", train=False, transform=mnist_transforms)
        
        train_loader = torch.utils.data.DataLoader(
            MNISTEvenOddDataset(train_dataset),
            batch_size=128,
            shuffle=True,
        )
        test_loader = torch.utils.data.DataLoader(
             MNISTEvenOddDataset(test_dataset),
            batch_size=128,
            shuffle=False,
        )
    return train_loader, test_loader


def evaluate(model, test_loader):
    """Evaluates the accuracy of the model on a validation dataset."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            # This is only set to finish evaluation faster.
            if batch_idx * len(data) > 1024:
                break
            outputs = nn.Sigmoid()(model(data))
            #_, predicted = torch.max(outputs.data, 1)
            predicted = outputs > 0.5
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return 100.0 * correct / total

In [2]:
class LinearNet(nn.Module):
    """Small Linear Network for MNIST."""

    def __init__(self):
        super(LinearNet, self).__init__()
        self.fc_weights = nn.ParameterList([nn.Parameter(torch.empty(1)) for weight in range(784)])
        init_fc = [nn.init.uniform_(x) for x in self.fc_weights]
        
        self.fc_bias = nn.Parameter(torch.empty(1))
        nn.init.uniform_(self.fc_bias)
        
    #def __init__(self):
    #    super(LinearNet, self).__init__()
    #    self.fc = nn.Linear(28*28, 1)
    #    nn.init.normal(self.fc.weight)

    #def forward(self, x):
    #    x = self.fc(x)
    #    return x
    
    def forward(self, x):
        #fc_layer = torch.cat(tuple(self.fc_weights)).unsqueeze(0)
        #x = x @ fc_layer.T + self.fc_bias
        for i, param in enumerate(self.fc_weights):
            if i==0:
                p=x[:,i]*param
            else:
                p += x[:,i]*param
        x = p.unsqueeze(1) + self.fc_bias
        return x
    
    def get_weights(self):
        return {k: v.cpu() for k, v in self.state_dict().items()}

    def set_weights(self, keys, weights): 
        flatten_weights =  [item for sublist in weights for item in sublist]
        self.load_state_dict({keys[i]:flatten_weights[i] for i in range(len(keys))})
        
    def get_gradients(self, keys):
        grads = {}

        for name, p in self.named_parameters():
            if name in keys:
                grad = None if p.grad is None else p.grad.data.cpu().numpy()
                grads[name] = grad

        return [grads[key] for key in keys]

    def set_gradients(self, gradients):
        for g, p in zip(gradients, self.parameters()):
            if g is not None:
                p.grad = torch.from_numpy(g)

In [63]:
@ray.remote  
class ParameterServer(object):
    def __init__(self, keys, values):
        self.weights = dict(zip(keys, values))

    def apply_gradients(self, keys, lr, *values):
        summed_gradients = [
            np.stack(gradient_zip).sum(axis=0) for gradient_zip in zip(*values)
        ]
    
        idx = 0
        for key, value in zip(keys, summed_gradients):
            self.weights[key] -= lr * torch.from_numpy(summed_gradients[idx])
            idx+=1

        return [self.weights[key] for key in keys]
    
    def add_weight(self, key, value):
        self.weights[key] = value
    
    def get_weights(self, keys):
        return [self.weights[key] for key in keys]


In [68]:
@ray.remote
class DataWorker(object):
    def __init__(self, keys):
        self.model = LinearNet()
        self.data_iterator = iter(get_data_loader()[0])
        self.keys = keys
        self.key_set = set(self.keys)
        for key, value in dict(self.model.named_parameters()).items():
            if key not in self.key_set:
                value.requires_grad=False

        
    def update_weights(self, keys, *weights):
        self.model.set_weights(keys, weights)
        
    def update_trainable(self, keys):
        self.keys = keys
        self.key_set = set(self.keys)
        for key, value in dict(self.model.named_parameters()).items():
            if key in self.key_set:
                value.requires_grad = True
            else:
                value.requires_grad = False
       

    def compute_gradients(self):
        #self.model.set_weights(keys, weights)
        try:
            data, target = next(self.data_iterator)
        except StopIteration:  # When the epoch ends, start a new epoch.
            self.data_iterator = iter(get_data_loader()[0])
            data, target = next(self.data_iterator)
        
        self.model.zero_grad()
        output = self.model(data)
        loss = nn.BCEWithLogitsLoss()(output, target.float())
        loss.backward()
        
        return self.model.get_gradients(self.keys)

In [114]:
iterations = 500
num_workers = 1 # number of workers per server
num_servers = 5 # number of servers
hashes_per_server = 100

def Scheduler(num_servers, hashes_per_server=50):
    
    model = LinearNet()
    key_values = model.get_weights()
    keys = np.array(list(key_values.keys()))
    #print(keys)
    #print(key_values) z
    values = [key_values[key] for key in keys]
    #values = [key_values[key] for key in keys]
    
    key_indices = {key: x for x, key in enumerate(keys)}
   
    # distributing weights across servers - do this using consistency hashing
    server_ids = ["server" + str(ind) for ind in range(num_servers)]
    hasher = ConsistentHash(keys, server_ids, hashes_per_server)
    servers = [ParameterServer.remote(keys[[key_indices[key] for key in hasher.get_keys_per_node()[serv]]], 
                                      [values[key_indices[key]] for key in hasher.get_keys_per_node()[serv]]) for serv in server_ids]
    # servers = [ParameterServer.remote(keys[0:1], values[0:1]), ParameterServer.remote(keys[1:2], values[1:2])]
    
    return hasher, servers, keys, model, hasher.get_keys_per_node(), server_ids.copy()

hasher, servers, keys, model, weight_assignments, server_ids =  Scheduler(num_servers, hashes_per_server)
ray.init(ignore_reinit_error=True)

# creating equal workers per server

workers = [[DataWorker.remote(weight_assignments["server" + str(j)]) for i in range(num_workers)] for j in range(num_servers)]





2022-05-20 19:22:53,475	INFO worker.py:963 -- Calling ray.init() again after it has already been called.


In [75]:
len(weight_assignments["server0"])

785

In [115]:
test_loader = get_data_loader()[1]


In [116]:
print("Running synchronous parameter server training.")
lr=0.1
failure_iter=60
failure_server="server4"

# we need to get a new keys order because we are not assuming a ordering in keys
current_weights = []
keys_order = []

for j in range(num_servers):
    keys_order.extend(weight_assignments["server" + str(j)])
    current_weights.extend(ray.get(servers[j].get_weights.remote(weight_assignments["server" + str(j)]))) 
curr_weights_ckpt = current_weights.copy()

time_per_iteration = []
for i in range(iterations):
 
    #start = time()
    
    if i == failure_iter:
        #Define parameters that will need to be moved
        failure_params = weight_assignments[failure_server]
        #Delete server from hash ring and reassign params
        hasher.delete_node_and_reassign_to_others(failure_server)
        weight_assignments = hasher.get_keys_per_node()
        #Update servers and workers
        num_servers -= 1
        server_ind = server_ids.index(failure_server)
        server_ids = server_ids[0 : server_ind] + server_ids[server_ind + 1 : ]
        servers = servers[0 : server_ind] + servers[server_ind + 1 : ]
        workers = workers[0 : server_ind] + workers[server_ind + 1 : ]
        #Add each relevant parameter to its new server
        server_dict = {server_ids[x]:servers[x] for x in range(len(server_ids))}
        for ind, param in enumerate(failure_params):
            server_dict[hasher.get_key_to_node_map()[param]].add_weight.remote(param, curr_weights_ckpt[server_ind][ind])
        #Update these parameters for each worker to make them trainable
        [workers[j][idx].update_trainable.remote(weight_assignments["server" + str(j)]) for  idx  in range(num_workers) for j in range(num_servers)]
        keys_order = []
        for j in range(num_servers):
            keys_order.extend(weight_assignments["server" + str(j)])

    
    # sync all weights on workers
    if i % 10 == 0:
        curr_weights_ckpt = current_weights.copy()
        # get weights from server
        #current_weights = [servers[j].get_weights.remote(weight_assignments["server" + str(j)]) for j in range(num_servers)] 

        # update weights on all workers
        [workers[j][idx].update_weights.remote(keys_order, *current_weights) for  idx  in range(num_workers) for j in range(num_servers)]
    
        
    # use local cache of weights and get gradients from workers
    gradients = [[workers[j][idx].compute_gradients.remote() for  idx  in range(num_workers)] for j in range(num_servers)]

    start = time()
    # Updates gradients to specfic parameter servers
    current_weights_t = [servers[j].apply_gradients.remote(weight_assignments["server" + str(j)], lr, *gradients[j]) for j in range(num_servers)]
    current_weights = ray.get(current_weights_t)
    
    end = time()
    time_per_iteration.append(end-start)

    if i % 10 == 0:
        # Evaluate the current model.
        # current_weights = [servers[j].get_weights.remote(weight_assignments["server" + str(j)]) for j in range(num_servers)] 
      
        # we are once again using the server to key mapping to set the weight back
        model.set_weights(keys_order, current_weights)
        accuracy = evaluate(model, test_loader)
        print("Iter {}: \taccuracy is {:.1f}".format(i, accuracy))
    #rint("\n")

#print("Final accuracy is {:.1f}.".format(accuracy))
# Clean up Ray resources and processes before the next example.
ray.shutdown()


Running synchronous parameter server training.
Iter 0: 	accuracy is 51.9
Iter 10: 	accuracy is 51.9
Iter 20: 	accuracy is 52.4
Iter 30: 	accuracy is 48.2
Iter 40: 	accuracy is 65.4
Iter 50: 	accuracy is 51.6
Iter 60: 	accuracy is 75.2
Iter 70: 	accuracy is 48.6
Iter 80: 	accuracy is 77.4
Iter 90: 	accuracy is 78.1
Iter 100: 	accuracy is 78.6
Iter 110: 	accuracy is 79.4
Iter 120: 	accuracy is 80.2
Iter 130: 	accuracy is 81.3
Iter 140: 	accuracy is 80.6
Iter 150: 	accuracy is 81.9
Iter 160: 	accuracy is 81.8
Iter 170: 	accuracy is 82.2
Iter 180: 	accuracy is 81.8
Iter 190: 	accuracy is 82.0
Iter 200: 	accuracy is 81.8
Iter 210: 	accuracy is 82.3
Iter 220: 	accuracy is 82.9
Iter 230: 	accuracy is 82.9
Iter 240: 	accuracy is 83.1
Iter 250: 	accuracy is 83.2
Iter 260: 	accuracy is 83.2
Iter 270: 	accuracy is 83.2
Iter 280: 	accuracy is 82.2
Iter 290: 	accuracy is 81.7
Iter 300: 	accuracy is 81.0
Iter 310: 	accuracy is 80.2
Iter 320: 	accuracy is 79.3
Iter 330: 	accuracy is 79.9
Iter 340: 	a

KeyboardInterrupt: 

In [78]:
np.mean(time_per_iteration[1:])

0.28572969966464573

In [79]:
np.std(time_per_iteration[1:])

0.014149284267918704

In [None]:
time_per_iteration

In [13]:
server_ids

['server0', 'server1', 'server2', 'server3', 'server4']

In [14]:
server_ids.index("server4")

4

In [40]:
hasher.get_key_to_node_map()['fc_bias']

'server4'

In [46]:
testmodel=LinearNet()

In [48]:
testmodel.get_weights()

{'fc_bias': tensor([0.6657]),
 'fc_weights.0': tensor([0.8101]),
 'fc_weights.1': tensor([0.9490]),
 'fc_weights.2': tensor([0.1687]),
 'fc_weights.3': tensor([0.3932]),
 'fc_weights.4': tensor([0.7757]),
 'fc_weights.5': tensor([0.7350]),
 'fc_weights.6': tensor([0.7638]),
 'fc_weights.7': tensor([0.0539]),
 'fc_weights.8': tensor([0.7448]),
 'fc_weights.9': tensor([0.7037]),
 'fc_weights.10': tensor([0.9956]),
 'fc_weights.11': tensor([0.3894]),
 'fc_weights.12': tensor([0.9956]),
 'fc_weights.13': tensor([0.0574]),
 'fc_weights.14': tensor([0.7612]),
 'fc_weights.15': tensor([0.6381]),
 'fc_weights.16': tensor([0.7507]),
 'fc_weights.17': tensor([0.7583]),
 'fc_weights.18': tensor([0.9456]),
 'fc_weights.19': tensor([0.5211]),
 'fc_weights.20': tensor([0.3994]),
 'fc_weights.21': tensor([0.7367]),
 'fc_weights.22': tensor([0.7428]),
 'fc_weights.23': tensor([0.2396]),
 'fc_weights.24': tensor([0.4197]),
 'fc_weights.25': tensor([0.5078]),
 'fc_weights.26': tensor([0.0744]),
 'fc_wei

In [51]:
len(current_weights[0])

160

In [81]:
servers[0].weights

AttributeError: 'ActorHandle' object has no attribute 'weights'

In [96]:
sum([184, 198, 191, 212])

785