In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import numpy as np

import ray
import model
import random,time
from time import sleep
import copy 

# Shut down any Ray session already running in this process before starting
# a fresh one, so this cell can be executed again in the same kernel.
if ray.is_initialized():
    ray.shutdown()

# Both limits are byte counts: 5e9 bytes for workers, 3e9 for the object store.
ray.init(
    memory=5 * 10 ** 9,
    object_store_memory=3 * 10 ** 9,
)

# @ray.remote(num_gpus=1)
@ray.remote
class ParameterServer():
    """Ray actor holding the shared model and per-worker progress counters.

    Workers push gradients with ``apply_gradients`` and fetch the current
    weights with ``pull_weights``. The actor counts how many gradients each
    worker has contributed so callers can bound how far any worker runs
    ahead of the slowest one.
    """

    def __init__(self,learning_rate,num_workers,stalness_limit):
        """Create the models, the test data handle and zeroed counters.

        Args:
            learning_rate: passed to ``model.SimpleCNN`` for the trained net.
            num_workers: number of workers; sizes ``stalness_table``.
            stalness_limit: threshold used by ``check_stalness``.
        """
        # Net that receives every gradient pushed by the workers.
        self.net = model.SimpleCNN(learning_rate=learning_rate)
        # Number of updates applied overall.
        self.global_step = 0
        # Number of updates applied per worker, indexed by worker id.
        self.stalness_limit = stalness_limit
        self.stalness_table = [0] * num_workers
        # Separate model instance and dataset used only by evaluate().
        self.eva_model = model.SimpleCNN()
        self.mnist = model.download_mnist_retry()

    def apply_gradients(self, gradients, wk_idx):
        """Apply one gradient update from worker ``wk_idx``.

        Increments that worker's counter and the global step, and runs an
        evaluation on every 100th global step.
        """
        self.net.apply_gradients(gradients)
        self.stalness_table[wk_idx] += 1
        self.global_step += 1

        if self.global_step % 100 != 0:
            return
        print("global_step: ",self.global_step," and prepare evaluate")
        self.evaluate()

    def pull_weights(self):
        """Return whatever ``self.net.get_weights()`` returns."""
        current = self.net.get_weights()
        return current

    def check_stalness(self,wk_idx):
        """Return True while worker ``wk_idx`` is within the staleness limit.

        The worker's lead is its own update count minus the smallest count
        in the table; the result is ``lead < stalness_limit``.
        """
        slowest = min(self.stalness_table)
        lead = self.stalness_table[wk_idx] - slowest
        return lead < self.stalness_limit

    def get_stalness(self):
        """Return the smallest per-worker update count."""
        slowest = min(self.stalness_table)
        return slowest

    def evaluate(self):
        """Copy the current weights into the evaluation model and print accuracy.

        Accuracy is computed on one batch of 1000 test examples.
        """
        snapshot = self.net.get_weights()
        # get_weights() is indexed as a pair here; the two entries are
        # passed straight through to set_weights().
        self.eva_model.set_weights(snapshot[0],snapshot[1])
        batch_xs, batch_ys = self.mnist.test.next_batch(1000)
        accuracy = self.eva_model.compute_accuracy(batch_xs, batch_ys)
        print("Iteration {}: accuracy is {}".format(self.global_step, accuracy))
        
@ray.remote
def worker_task(ps,worker_index,stale_limit,batch_size=50,max_steps=None):
    """Compute gradients and push them to the parameter server.

    Each iteration waits until this worker is no more than ``stale_limit``
    steps ahead of the slowest worker (as reported by ``ps.get_stalness``),
    pulls the current weights, computes gradients on one training batch and
    submits them to the server without waiting for them to be applied.

    Args:
        ps: handle to the ``ParameterServer`` actor.
        worker_index: id of this worker; also used as the dataset seed and
            as the index into the server's per-worker counters.
        stale_limit: maximum allowed lead over the slowest worker.
        batch_size: number of training examples per gradient computation.
        max_steps: optional number of gradient pushes after which the task
            returns. ``None`` (the default) keeps the original behaviour of
            running forever. Give every worker the same value: a worker with
            a larger bound would wait indefinitely on workers that have
            already finished.

    Returns:
        The number of gradient pushes submitted. Only reached when
        ``max_steps`` is set.
    """
    mnist = model.download_mnist_retry(seed=worker_index)
    # Local model replica; its weights are overwritten from the server on
    # every iteration, so nothing is read from it up front.
    net = model.SimpleCNN()
    local_step = 0

    while max_steps is None or local_step < max_steps:
        # Throttle: poll the server once per second while this worker is
        # too far ahead of the slowest one.
        while local_step - ray.get(ps.get_stalness.remote()) > stale_limit:
            print(worker_index," works too fast")
            sleep(1)

        # Get the current weights from the parameter server.
        init_wei = ray.get(ps.pull_weights.remote())
        net.set_weights(init_wei[0], init_wei[1])

        # Compute an update and push it to the parameter server.
        xs, ys = mnist.train.next_batch(batch_size)
        gradients = net.compute_gradients(xs, ys)
        local_step += 1
        # Fire-and-forget: the returned object ref is intentionally dropped.
        ps.apply_gradients.remote(gradients,worker_index)

    return local_step
#         print(worker_index,"has finished update")

# @ray.remote
# class Worker():
#     def __init__(self,worker_index,init_wei, batch_size=50):
#         self.worker_index = worker_index
#         self.batch_size = batch_size
#         self.mnist = model.download_mnist_retry(seed=worker_index)
#         # the init weight is randomly assigned
#         self.net = model.SimpleCNN()
#         self.net.set_weights(init_wei[0],init_wei[1])
#         self.iter_count = 0

#     def iterater(self,ps_handler):
#         while not ray.get(ps_handler.check_stalness.remote(self.worker_index)):
# #             print(self.worker_index," works too fast")
#             sleep(1)
#         ps_handler.apply_gradients.remote(self._compute_gradients(),self.worker_index)
#         if self.iter_count % 5 == 0:
#             self.sync_wei(ps_handler)
    
#     def _compute_gradients(self):
#         # simulate network delay
# #         sleep(random.randint(0,3))
# #         print(self.worker_index,"is going to compute gradients")
#         xs,ys=self.mnist.train.next_batch(self.batch_size)
#         gradients = self.net.compute_gradients(xs,ys)
#         self.net.apply_gradients(gradients)
#         self.iter_count += 1 
#         return gradients
    
#     def set_weights(self,keys,weights):
#         self.net.set_weights(keys,weights)
    
#     def get_weights(self):
#         return self.net.get_weights()
    
#     def sync_wei(self,ps_handler):
#          # sync wei from ps 
#         sync_wei = ray.get(ps_handler.pull_weights.remote())
#         self.net.set_weights(sync_wei[0],sync_wei[1])
            

2019-10-28 20:19:17,493	INFO resource_spec.py:205 -- Starting Ray with 4.64 GiB memory available for workers and up to 2.79 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).


In [2]:
# Run configuration: worker count and the staleness bound that is handed to
# both the parameter server and every worker task.
num_worker = 3
stalness_limit = 4
# Only the commented-out helper cells further down refer to this table.
stalness_table = [0] * num_worker

# Start the parameter-server actor with learning rate 1e-3.
ps = ParameterServer.remote(
    1e-3,
    num_worker,
    stalness_limit,
)

# Earlier actor-based worker variant (disabled):
# init_wei = ray.get(ps.pull_weights.remote())
# workers = [Worker.remote(index,init_wei) for index in range(num_worker)]

# Launch one remote training task per worker; the returned object refs are
# kept but never waited on in this cell.
worker_tasks = [
    worker_task.remote(ps, worker_id, stalness_limit)
    for worker_id in range(num_worker)
]

[2m[36m(pid=13201)[0m Instructions for updating:
[2m[36m(pid=13201)[0m Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
[2m[36m(pid=13201)[0m Instructions for updating:
[2m[36m(pid=13201)[0m Please write your own downloading logic.
[2m[36m(pid=13201)[0m Instructions for updating:
[2m[36m(pid=13201)[0m Please use tf.data to implement this functionality.
[2m[36m(pid=13201)[0m Extracting MNIST_data/train-images-idx3-ubyte.gz
[2m[36m(pid=13207)[0m Instructions for updating:
[2m[36m(pid=13207)[0m Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
[2m[36m(pid=13207)[0m Instructions for updating:
[2m[36m(pid=13207)[0m Please write your own downloading logic.
[2m[36m(pid=13207)[0m Instructions for updating:
[2m[36m(pid=13207)[0m Please use tf.data to implement this functionality.
[2m[36m(pid=13207)[0m Extracting MNIST_data/train-images-idx3-ubyte.gz
[2m[36m(pid=13201)[0m Instructions fo

[2m[36m(pid=13207)[0m Instructions for updating:
[2m[36m(pid=13207)[0m 
[2m[36m(pid=13207)[0m Future major versions of TensorFlow will allow gradients to flow
[2m[36m(pid=13207)[0m into the labels input on backprop by default.
[2m[36m(pid=13207)[0m 
[2m[36m(pid=13207)[0m See @{tf.nn.softmax_cross_entropy_with_logits_v2}.
[2m[36m(pid=13207)[0m 
[2m[36m(pid=13200)[0m Instructions for updating:
[2m[36m(pid=13200)[0m Please use tf.data to implement this functionality.
[2m[36m(pid=13200)[0m Instructions for updating:
[2m[36m(pid=13200)[0m Please use tf.one_hot on tensors.
[2m[36m(pid=13200)[0m Extracting MNIST_data/train-labels-idx1-ubyte.gz
[2m[36m(pid=13200)[0m Extracting MNIST_data/t10k-images-idx3-ubyte.gz
[2m[36m(pid=13200)[0m Instructions for updating:
[2m[36m(pid=13200)[0m Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
[2m[36m(pid=13200)[0m Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
[2m[36m(pid=

[2m[36m(pid=13202)[0m global_step:  1900  and prepare evaluate
[2m[36m(pid=13202)[0m Iteration 1900: accuracy is 0.9810000061988831
[2m[36m(pid=13202)[0m global_step:  2000  and prepare evaluate
[2m[36m(pid=13202)[0m Iteration 2000: accuracy is 0.9800000190734863
[2m[36m(pid=13202)[0m global_step:  2100  and prepare evaluate
[2m[36m(pid=13202)[0m Iteration 2100: accuracy is 0.9800000190734863
[2m[36m(pid=13202)[0m global_step:  2200  and prepare evaluate
[2m[36m(pid=13202)[0m Iteration 2200: accuracy is 0.9890000224113464
[2m[36m(pid=13201)[0m 0  works too fast
[2m[36m(pid=13207)[0m 2  works too fast
[2m[36m(pid=13202)[0m global_step:  2300  and prepare evaluate
[2m[36m(pid=13202)[0m Iteration 2300: accuracy is 0.9789999723434448
[2m[36m(pid=13202)[0m global_step:  2400  and prepare evaluate
[2m[36m(pid=13202)[0m Iteration 2400: accuracy is 0.9789999723434448
[2m[36m(pid=13202)[0m global_step:  2500  and prepare evaluate
[2m[36m(pid=13202)

In [None]:
# for _ in range(2000):
#     [worker.iterater.remote(ps) for idx,worker in enumerate(workers)]

In [None]:
# def compute_grad_apply(worker,worker_index,ps):
#     sleep(random.randint(0, 3))
#     grad = ray.get(worker.compute_gradients.remote())
#     ps.apply_gradients.remote(grad)
#     stalness_table[worker_index] += 1

In [None]:
# def check_stalness(stalness_table,worker_index,stalness_limit):
#     max_iter = max(stalness_table.values())
#     return max_iter - stalness_table[worker_index] < stalness_limit


In [None]:
# def worker_iter(idx,worker):
#     print(idx,"starts to update")
#     while not check_stalness(stalness_table,idx,stalness_limit):
#         print("meet stalness upper limit")
#         sleep(1)
#     compute_grad_apply(worker,idx,ps)
#     print(idx,"ends update")

In [None]:
# for(_ in range(100)):
#     for wk_idx in len(workers):
#         while not check_stalness(stalness_table,wk_idx,stalness_limit=10):
#             print("meet stalness limit")
#             sleep(1)
#         worker.compute_gradients.remote()
        

In [None]:
# # Check: do the actors actually run in parallel?
# from time import sleep
# import random

# @ray.remote
# class test():
#     def __init__(self):
#         self.count = 0
#     def sleep_inc(self,idx):
#         sleep(random.randint(0,3))
#         self.count += 1
#         print(idx,"has incresed")
#         self.help_fun()
#         return self.count
    
#     def help_fun(self):
#         print('helping')
    
# tests = [test.remote() for _ in range(5)]
# [t.sleep_inc.remote(idx) for idx,t in enumerate(tests)]