This notebook shows how the controller behaves when it communicates with the _workers_.

The communication pattern is __REQUEST - REPLY__. Each worker binds a REPLY socket (`zmq.REP`) on its own port, and the controller connects to every worker through REQUEST sockets (`zmq.REQ`). The controller uses:  
+ in synchronous mode, one socket per worker,
+ in asynchronous mode, a single socket.
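A minimal sketch of the synchronous layout described above. It assumes threads and `inproc://` endpoints in place of separate processes and TCP ports, so that it runs on its own; `rep_worker` and `sync_round` are hypothetical names. Each worker binds a REP socket, and the controller holds one REQ socket per worker:

```python
import threading
import zmq

def rep_worker(ctx, w_id, ready):
    # Each worker binds its own REP socket (inproc endpoint stands in for tcp://*:port)
    sock = ctx.socket(zmq.REP)
    sock.bind("inproc://worker-{}".format(w_id))
    ready.set()
    req = sock.recv_json()            # a REP socket must receive before it can reply
    sock.send_json({"w_id": w_id, "echo": req["center"]})
    sock.close()

def sync_round(n_workers=2, center=0.5):
    ctx = zmq.Context()
    threads, sockets = [], []
    for w_id in range(n_workers):
        ready = threading.Event()
        t = threading.Thread(target=rep_worker, args=(ctx, w_id, ready))
        t.start()
        ready.wait()                  # make sure the endpoint is bound before connecting
        threads.append(t)
        # Synchronous layout: one REQ socket per worker
        req = ctx.socket(zmq.REQ)
        req.connect("inproc://worker-{}".format(w_id))
        sockets.append(req)
    for req in sockets:               # fan the request out to every worker...
        req.send_json({"center": center})
    replies = [req.recv_json() for req in sockets]  # ...then collect each reply in order
    for t in threads:
        t.join()
    for req in sockets:
        req.close()
    ctx.term()
    return replies
```

A REQ socket enforces a strict send/recv alternation, which is why the synchronous controller needs one socket per worker to keep several requests in flight at the same time.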

In [3]:
import zmq
import time
import multiprocessing
from multiprocessing import Process

In [4]:
import numpy as np

# Synchronous approach

In the target application, the network weights are transferred with `zmq.Socket.send_pyobj()` and `zmq.Socket.recv_pyobj()`. The demo below exchanges a single scalar with `send_json()` / `recv_json()` instead.

(See `./small_mnist_multigpu.py` and `./cifar_multigpu.py` for the Keras application.)
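A minimal sketch of a weight exchange with `send_pyobj()` / `recv_pyobj()`, which pickle and unpickle arbitrary Python objects. It assumes the weights are a list of NumPy arrays (the form returned by Keras `model.get_weights()`), and it uses a thread with an `inproc://` endpoint instead of a worker process; `weights_worker`, `exchange_weights`, and the "shrink by 0.9" update are hypothetical stand-ins for a real training step:

```python
import threading
import numpy as np
import zmq

def weights_worker(ctx, ready):
    sock = ctx.socket(zmq.REP)
    sock.bind("inproc://weights")
    ready.set()
    weights = sock.recv_pyobj()       # unpickled back into a list of ndarrays
    # hypothetical "training step": scale every weight array by 0.9
    sock.send_pyobj([0.9 * w for w in weights])
    sock.close()

def exchange_weights(weights):
    ctx = zmq.Context()
    ready = threading.Event()
    t = threading.Thread(target=weights_worker, args=(ctx, ready))
    t.start()
    ready.wait()                      # endpoint bound before we connect
    req = ctx.socket(zmq.REQ)
    req.connect("inproc://weights")
    req.send_pyobj(weights)           # pickles the whole list, shapes and dtypes included
    updated = req.recv_pyobj()
    t.join()
    req.close()
    ctx.term()
    return updated
```

Because `recv_pyobj()` unpickles whatever it receives, this transport is only appropriate between trusted processes.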

In [22]:
def worker_process(port=9000, w_id=0, eta=0.9, rho=0.5):
    print("Initializing worker on port {}".format(port))
    context = zmq.Context()
    socket = context.socket(zmq.REP)
    socket.bind("tcp://*:{}".format(port))
    
    # Careful: seeds must differ between workers
    # if they are expected to behave differently.
    np.random.seed(w_id*7)
    
    x = None
    while True:
        # wait for a controller request
        req = socket.recv_json()
        
        if 'stop' in req:
            break
        
        center = req['center']
        if 'init' in req or x is None:
            x = center
        
        # stand-in for the real gradient computation: a random "gradient" in [0, 1)
        grad = np.random.rand()
        # gradient step plus an elastic pull toward the center variable
        x -= eta * (grad + rho*(x - center))
        
        # simulate a variable compute time
        time.sleep(1+grad)
        
        socket.send_json({'w_id':w_id, 'value':x, 'time':1+grad})
    
    # release the socket once the controller has sent 'stop'
    socket.close()
    context.term()
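The worker update `x -= eta * (grad + rho*(x - center))` is a gradient step plus an elastic term that pulls the worker's copy $x_i$ toward the center variable $\tilde{x}$ sent by the controller, where $g_i$ is the (here random) gradient:

$$
x_i \leftarrow x_i - \eta\,\bigl(g_i + \rho\,(x_i - \tilde{x})\bigr)
$$

On the first step the worker sets $x_i = \tilde{x}$, so the elastic term vanishes and the update reduces to $x_i = \tilde{x} - \eta g_i$. For worker 0 in the log below, $\tilde{x} \approx 0.11$ and $g_0 \approx 0.55$ (it reports `took 1.55`, i.e. $1 + g_0$), which gives $x_0 \approx 0.11 - 0.9 \times 0.55 \approx -0.38$, the value it prints.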

In [23]:
def controller(port=9000, n_workers=2, n_steps=5):
    print("Initializing controller...")
    
    # initialize the center variable
    x = np.random.rand()
    print("(init) center var = {:.2f}".format(x))
    
    # parameters (shared with the workers)
    eta, rho = 0.9, 0.5
    
    context = zmq.Context()
    
    ports = [port+k for k in range(n_workers)]
    
    # start the workers and create the matching sockets (one REQ socket per worker)
    workers = [context.socket(zmq.REQ) for k in range(n_workers)]
    
    for w_id in range(n_workers):
        w_port = ports[w_id]
        Process(target=worker_process, args=(w_port, w_id, eta, rho)).start()
        w_sock = workers[w_id]
        w_sock.connect('tcp://localhost:{}'.format(w_port))
        
    # begin steps
    for step in range(n_steps):
        values = np.zeros(n_workers)
        print('\n-- Step {} --\n'.format(step))
        
        for w_id in range(n_workers):
            # send a request with the current center to each worker
            worker = workers[w_id]
            to_send = {'center': x}
            
            # extra flag telling the workers to initialize on the first step
            if step == 0:
                to_send['init'] = True
            worker.send_json(to_send)
            
        for w_id in range(n_workers):
            # wait for every answer before updating the center
            worker = workers[w_id]
            x_worker = worker.recv_json()
            
            print("(worker {w_id}) updated var = {value:.2f}".format(**x_worker))
            print("\ttook {time:.2f}".format(**x_worker))
            values[w_id] = x_worker['value']
        
        # compute the new center value
        diffs = values - x
        x += eta * rho * np.sum(diffs)
        print("(update) center var = {:.2f}".format(x))
        
    # after the last step, stop the workers, then release the sockets
    # (giving the 'stop' messages up to 1 s to be delivered)
    for worker in workers:
        worker.send_json({'stop': True})
    context.destroy(linger=1000)
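The controller's update `x += eta * rho * np.sum(diffs)` moves the center toward the workers' new values:

$$
\tilde{x} \leftarrow \tilde{x} + \eta\rho \sum_i \bigl(x_i - \tilde{x}\bigr)
$$

Together with the worker rule, this resembles the synchronous elastic averaging SGD (EASGD) update with $\alpha = \eta\rho$, except that here the center is computed from the workers' post-step values. The sketch below replays the same loop without sockets or processes, which makes the arithmetic easy to check. `simulate_sync` is a hypothetical helper, and the random gradients are passed in explicitly through its `grads` argument:

```python
import numpy as np

def simulate_sync(x0, grads, eta=0.9, rho=0.5):
    """Replay the synchronous controller/worker loop without sockets.

    grads[step][w_id] is the stand-in gradient drawn by worker w_id at that step.
    Returns the center value after each step.
    """
    n_workers = len(grads[0])
    center = float(x0)
    workers = np.full(n_workers, center)   # 'init' step: every worker starts at the center
    history = []
    for step_grads in grads:
        g = np.asarray(step_grads, dtype=float)
        # worker rule: gradient step plus elastic pull toward the current center
        workers = workers - eta * (g + rho * (workers - center))
        # controller rule: move the center toward the workers' new values
        center = center + eta * rho * np.sum(workers - center)
        history.append(center)
    return history
```

For instance, step 0 of the log below corresponds roughly to `simulate_sync(0.11, [[0.55, 0.08]])` (gradients read off the `took` lines minus 1). Its center comes out at about -0.145, consistent with the printed -0.14 given that the logged inputs are rounded.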

In [24]:
controller()

Initializing controller...
(init) center var = 0.11
Initializing worker on port 9000
Initializing worker on port 9001

-- Step 0 --

(worker 0) updated var = -0.38
	took 1.55
(worker 1) updated var = 0.04
	took 1.08
(update) center var = -0.14

-- Step 1 --

(worker 0) updated var = -0.92
	took 1.72
(worker 1) updated var = -0.74
	took 1.78
(update) center var = -0.76

-- Step 2 --

(worker 0) updated var = -1.39
	took 1.60
(worker 1) updated var = -1.15
	took 1.44
(update) center var = -1.22

-- Step 3 --

(worker 0) updated var = -1.80
	took 1.54
(worker 1) updated var = -1.83
	took 1.72
(update) center var = -1.76

-- Step 4 --

(worker 0) updated var = -2.16
	took 1.42
(worker 1) updated var = -2.68
	took 1.98
(update) center var = -2.35


# Asynchronous approach

In [None]:
# TODO
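This section is still a TODO. What follows is one possible sketch, not the single-socket design planned in the introduction. A REQ socket must alternate strictly between send and recv, so a single REQ socket cannot keep a request in flight to every worker. A single-socket design would typically rely on DEALER or ROUTER sockets instead. This sketch therefore keeps one REQ socket per worker and uses `zmq.Poller` to serve whichever worker answers first. It updates the center from that worker alone, $\tilde{x} \leftarrow \tilde{x} + \eta\rho\,(x_i - \tilde{x})$. That per-reply rule, the fixed per-worker gradients, the thread and `inproc://` setup, and the names `async_worker` and `async_controller` are all assumptions for illustration:

```python
import threading
import zmq

def async_worker(ctx, w_id, grad, ready, eta=0.9, rho=0.5):
    # Hypothetical worker: same elastic update as above, with a fixed stand-in gradient
    sock = ctx.socket(zmq.REP)
    sock.bind("inproc://async-{}".format(w_id))
    ready.set()
    x = None
    while True:
        req = sock.recv_json()
        if "stop" in req:
            break                      # no reply; the controller just closes its REQ socket
        center = req["center"]
        if x is None:
            x = center                 # the first request initializes the worker
        x -= eta * (grad + rho * (x - center))
        sock.send_json({"w_id": w_id, "value": x})
    sock.close()

def async_controller(grads, n_updates=6, x0=0.0, eta=0.9, rho=0.5):
    ctx = zmq.Context()
    poller = zmq.Poller()
    socks, threads = {}, []
    for w_id, grad in enumerate(grads):
        ready = threading.Event()
        t = threading.Thread(target=async_worker, args=(ctx, w_id, grad, ready, eta, rho))
        t.start()
        ready.wait()                   # endpoint bound before we connect
        threads.append(t)
        s = ctx.socket(zmq.REQ)
        s.connect("inproc://async-{}".format(w_id))
        poller.register(s, zmq.POLLIN)
        socks[s] = w_id

    x = x0
    for s in socks:                    # one request in flight per worker
        s.send_json({"center": x})

    order, stopped = [], 0
    while stopped < len(socks):
        for s, _event in poller.poll():
            reply = s.recv_json()
            if len(order) < n_updates:
                # apply this worker's contribution as soon as it arrives,
                # then send it the fresh center right away
                x += eta * rho * (reply["value"] - x)
                order.append(reply["w_id"])
                s.send_json({"center": x})
            else:
                # update budget exhausted: stop this worker
                s.send_json({"stop": True})
                poller.unregister(s)
                stopped += 1

    for t in threads:                  # every worker has consumed its 'stop' message
        t.join()
    for s in socks:
        s.close(linger=0)
    ctx.term()
    return x, order
```

With real workers of uneven speed, `order` records which worker contributed each update, so faster workers appear more often instead of everyone waiting for the slowest one.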

***

## Notes

When using `multiprocessing`, if a process is still running in the background, the following loop terminates every child process that is still active:

In [20]:
for child in multiprocessing.active_children():
    child.terminate()