Distributed PyTorch on Grid (via IPFS/PubSub) (#166)
* First round of torch hooks integration (#152)

* Finished HookService, linked it with TorchService (#154)

* feat: modify pubsub_peers to handle newer IPFS api. (#153)

* finished minimal transfer of overloading code

* found an untested bug

* WIP for #130 and #132 (#155)

* feat: modify pubsub_peers to handle newer IPFS api. (#153)

* finished minimal transfer of overloading code

* found an untested bug

* adjust comments

* this round of work sponsored by parallel jalebi

* in the middle of fixing #130 and #132

* resolved #132, #130 will take a bit more effort than I'd planned for

* completes #130, prepares #129 and #131; almost took care of #148 in the process

* Worker side command processing and execution (#156)

resolved #129

* Finished implementing IPFS into torch services (#161)

* laptop sync

* finished up ipfs integration, yet to test

* syncing with colab notebooks

* renamed channels.openmined to channels.om

* found a worker node error

* bug in Tensor.send_

* fixed two client side bugs

* keyerror in receive_obj message

* register tensors before sending

* well that was rough

* more bug fixes

* premerge

* fix utils import in hook_worker_service

* fix return_result for worker

* premerge

* premerge2

* BOOM

* multinode demo (#162)

* lots o' comments (#164)

* Reorganizing notebooks (#165)

* lots o' comments

* reorganize notebooks
jvmancuso authored and iamtrask committed Mar 31, 2018
1 parent 05b5c11 commit d2aa4b4
Showing 27 changed files with 3,475 additions and 632 deletions.
grid/channels.py: 50 changes (24 additions, 26 deletions)
@@ -1,60 +1,58 @@
 # Channels
 
 # Main channel that high level commands are broadcasted on
-openmined = 'openmined'
+om = 'openmined'
 
-list_tasks = 'openmined:list_tasks'
-add_task = 'openmined:add_task'
-list_models = 'openmined:list_models'
-list_workers = 'openmined:list_workers'
+list_tasks = f'{om}:list_tasks'
+add_task = f'{om}:task:add'
+list_models = f'{om}:list_models'
+list_workers = f'{om}:list_workers'
 
 
 def list_tasks_callback(id):
-    return f'openmined:list_tasks:{id}'
+    return f'{list_tasks}:{id}'
 
 
 def list_workers_callback(id):
-    return f'openmined:list_workers:{id}'
+    return f'{list_workers}:{id}'
 
 
 def add_model(name):
-    return f'openmined:task:add:{name}'
+    return f'{add_task}:{name}'
 
 
 # Whoami Channels
 
-whoami_listener = 'openmined:whoami'
-
-
+whoami_listener = f'{om}:whoami'
 def whoami_listener_callback(id):
     return f'{whoami_listener}:{id}'
 
 
 # Torch Channels
 
-torch_listen_for_obj = 'openmined:torch_listen_for_obj'
-
-
+torch_listen_for_obj = f'{om}:torch_listen_for_obj'
 def torch_listen_for_obj_callback(id):
-    return f'openmined:torch_listen_for_obj:{id}'
-
-
-torch_listen_for_obj_response = 'openmined:torch_listen_for_obj_res'
+    return f'{torch_listen_for_obj}:{id}'
+
+
+torch_listen_for_obj_response = f'{om}:torch_listen_for_obj_res'
 def torch_listen_for_obj_response_callback(id):
-    return f'openmined:torch_listen_for_obj_res:{id}'
-
-
-torch_listen_for_obj_req = 'openmined:torch_listen_for_obj_req'
+    return f'{torch_listen_for_obj_response}:{id}'
+
+
+torch_listen_for_obj_req = f'{om}:torch_listen_for_obj_req'
 def torch_listen_for_obj_req_callback(id):
-    return f'openmined:torch_listen_for_obj_req:{id}'
+    return f'{torch_listen_for_obj_req}:{id}'
 
 
-torch_listen_for_obj_req_response = 'openmined:torch_listen_for_obj_req_res'
+torch_listen_for_obj_req_response = f'{om}:torch_listen_for_obj_req_res'
+def torch_listen_for_obj_req_response_callback(id):
+    return f'{torch_listen_for_obj_req_response}:{id}'
+
+torch_listen_for_command = f'{om}:torch_listen_for_command'
+def torch_listen_for_command_callback(id):
+    return f'{torch_listen_for_command}:{id}'
 
-def torch_listen_for_obj_req_response_callback(id):
-    return f'openmined:torch_listen_for_obj_req_res:{id}'
+torch_listen_for_command_response = f'{om}:torch_listen_for_command_response'
+def torch_listen_for_command_response_callback(id):
+    return f'{torch_listen_for_command_response}:{id}'
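The refactored helpers above simply compose the base channel constants with an ID suffix. A minimal sketch of how a caller would derive a per-worker callback channel (the worker ID here is made up for illustration):

    from grid import channels

    worker_id = 'QmWorkerPeerId'  # hypothetical IPFS peer ID
    channel = channels.torch_listen_for_obj_callback(worker_id)
    # -> 'openmined:torch_listen_for_obj:QmWorkerPeerId'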
grid/clients/base.py: 7 changes (2 additions, 5 deletions)
@@ -54,13 +54,10 @@ def refresh(self, refresh_known_nodes=True, refresh_network_stats=True):
         self.stats = self.refresh_network_stats()
 
     def get_stats(self, worker_id, timeout=10):
-        def ret(msg):
-            return json.loads(msg['data'])
-
         return self.request_response(
             channel=channels.whoami_listener_callback(worker_id),
             message=[],
-            response_handler=ret,
+            response_handler=utils.unpack,
             timeout=10)
 
     def print_network_stats(self):
@@ -174,7 +171,7 @@ def __len__(self):
         """
 
     def found_task(self, message):
-        tasks = json.loads(message['data'])
+        tasks = utils.unpack(message)
         for task in tasks:
             # utils.store_task(task['name'], task['address'])
             name = task['name']
grid/clients/keras.py: 2 changes (1 addition, 1 deletion)
@@ -136,7 +136,7 @@ def load_model(self, addr):
         return keras_utils.ipfs2keras(self.api, addr)
 
     def receive_model(self, message, verbose=True):
-        msg = json.loads(message['data'])
+        msg = utils.unpack(message)
 
         if (msg is not None):
             if (msg['type'] == 'transact'):
grid/clients/torch.py: 2 changes (2 additions, 0 deletions)
@@ -1,5 +1,6 @@
 from . import base
 from ..services.torch.torch_service import TorchService
+from ..services.torch.hook_service import HookService
 
 
 class TorchClient(base.BaseClient):
@@ -14,4 +15,5 @@ def __init__(self,
             include_github_known_workers=include_github_known_workers,
             verbose=verbose)
 
+        self.services['hook_service'] = HookService(self)
         self.services['torch_service'] = TorchService(self)
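With this change a TorchClient wires up both services at construction time. A rough sketch of how that surfaces to a caller, assuming the constructor arguments not shown in this diff all have defaults:

    from grid.clients.torch import TorchClient

    client = TorchClient(verbose=True)  # hypothetical invocation
    hook = client.services['hook_service']       # HookService instance
    service = client.services['torch_service']   # TorchService instance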
grid/lib/torch_utils.py: 209 changes (151 additions, 58 deletions)
@@ -1,65 +1,158 @@
import os
import json
import re

from pathlib import Path

import torch

from . import utils


def torch2ipfs(model):
    pass


def ipfs2torch(model_addr):
    pass


# Helpers for HookService and TorchService
def check_workers(self, workers):
    if type(workers) is str:
        workers = [workers]
    elif not hasattr(workers, '__iter__'):
        raise TypeError(
            """Can only send {} to a string worker ID or an iterable of
            string worker IDs, not {}""".format(self.__name__, type(workers))
        )
    return workers


def get_tensorvars(self, command):
    args = command['args']
    kwargs = command['kwargs']
    arg_types = command['arg_types']
    kwarg_types = command['kwarg_types']
    tensorvar_args = [args[i] for i in range(len(args))
                      if arg_types[i] in self.tensorvar_types_strs]
    tensorvar_kwvals = [kwargs[i][1] for i in range(len(kwargs))
                        if kwarg_types[i] in self.tensorvar_types_strs]
    return tensorvar_args + tensorvar_kwvals
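# For illustration (values hypothetical): get_tensorvars expects a command
# message shaped roughly like
#   {'args': [x, 2], 'arg_types': ['torch.FloatTensor', 'int'],
#    'kwargs': [('other', y)], 'kwarg_types': ['torch.FloatTensor']}
# and would return [x, y] here, since only those entries carry tensor types.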


def check_remote(tensorvars):
    return any([tensorvar.is_pointer for tensorvar in tensorvars])


def get_owners(tensorvars):
    owners = list(set([owner
                       for tensorvar in tensorvars
                       for owner in tensorvar.owners]))
    multiple_owners = len(owners) > 1
    return multiple_owners, owners


def replace_tensorvar(x):
    if hasattr(torch, 'old_is_tensor'):
        check = torch.old_is_tensor
    else:
        check = torch.is_tensor
    try:
        if check(x) or isinstance(x, torch.autograd.Variable):
            return '_fl.{}'.format(x.id)
        else:
            return [replace_tensorvar(i) for i in x]
    except (AttributeError, TypeError):
        return x


def replace_in_command(command_msg):
    command_msg['args'] = map_tuple(
        None, command_msg['args'], replace_tensorvar)
    command_msg['kwargs'] = map_dict(
        None, command_msg['kwargs'], replace_tensorvar)
    try:
        command_msg['self'] = replace_tensorvar(command_msg['self'])
    except KeyError:
        pass
    return command_msg

# Client needs to identify a tensor before sending commands that use it
def id_tensorvar(x):
    pat = re.compile('_fl.(.*)')
    try:
        if isinstance(x, str):
            return int(pat.search(x).group(1))
        else:
            return [id_tensorvar(i) for i in x]
    except AttributeError:
        return x
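# For illustration: with a tensor whose .id is 42, the client-side
# replace_tensorvar(t) yields '_fl.42', and the worker-side
# id_tensorvar('_fl.42') recovers the integer id 42.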


# Safety checks for serializing and deserializing torch objects
# Desperately needs stress testing before going out in the wild
map_torch_type = {
    'torch.FloatTensor': torch.FloatTensor,
    'torch.DoubleTensor': torch.DoubleTensor,
    'torch.HalfTensor': torch.HalfTensor,
    'torch.ByteTensor': torch.ByteTensor,
    'torch.CharTensor': torch.CharTensor,
    'torch.ShortTensor': torch.ShortTensor,
    'torch.IntTensor': torch.IntTensor,
    'torch.LongTensor': torch.LongTensor,
    'torch.autograd.variable.Variable': torch.autograd.variable.Variable,
    'torch.nn.parameter.Parameter': torch.nn.parameter.Parameter
}


def types_guard(obj_type):
    return map_torch_type[obj_type]
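# For illustration: types_guard('torch.FloatTensor') returns the
# torch.FloatTensor class; an unrecognized type string raises KeyError
# rather than eval-ing untrusted input.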


def tensor_contents_guard(contents):
    # TODO: check to make sure the incoming list isn't dangerous to use for
    # constructing a tensor (likely non-trivial)
    return contents


def command_guard(command, allowed):
    if command not in allowed:
        raise RuntimeError(
            'Command "{}" is not a supported Torch operation.'.format(command))
    return command


# Worker needs to retrieve tensor by ID before computing with it
def retrieve_tensor(self, x):
    try:
        return self.worker.objects[id_tensorvar(x)]
    except TypeError:
        try:
            return [self.worker.objects[i] for i in id_tensorvar(x)]
        except TypeError:
            return x
    except KeyError:
        return x
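# For illustration: if self.worker.objects maps {42: t}, then
# retrieve_tensor(self, '_fl.42') returns t; a list of id strings is
# unhashable as a key, so it falls through to the inner comprehension
# and resolves element-wise.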


def map_tuple(service, args, func):
    if service:
        return tuple(func(service, x) for x in args)
    else:
        return tuple(func(x) for x in args)


def map_dict(service, kwargs, func):
    if service:
        return {key: func(service, val) for key, val in kwargs.items()}
    else:
        return {key: func(val) for key, val in kwargs.items()}


def hook_tensor_ser(service_self, tensor_type):
    def ser(self, include_data=True):
        tensor_msg = {}
        tensor_msg['torch_type'] = self.type()
        if (include_data):
            tensor_msg['data'] = self.tolist()
        tensor_msg['id'] = self.id
        tensor_msg['owners'] = self.owners
        return json.dumps(tensor_msg)

    ser.__doc__ = 'Serializes a {} object to JSON.'.format(tensor_type)
    tensor_type.ser = ser
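# For illustration (values hypothetical): once hooked, a FloatTensor with
# id 42 owned by 'worker1' would serialize to something like
#   '{"torch_type": "torch.FloatTensor", "data": [1.0, 2.0], "id": 42, "owners": ["worker1"]}'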


def hook_var_ser(service_self):
    # TODO
    pass


def serialize_torch_model(model, **kwargs):
    """
    kwargs are the arguments needed to instantiate the model
    """
    state = {'state_dict': model.state_dict(), 'kwargs': kwargs}
    torch.save(state, 'temp_model.pth.tar')
    with open('temp_model.pth.tar', 'rb') as f:
        model_bin = f.read()
    return model_bin


def deserialize_torch_model(model_bin, model_class, **kwargs):
    """
    model_class is needed since PyTorch uses pickle for serialization
    see https://discuss.pytorch.org/t/loading-pytorch-model-without-a-code/12469/2 for details
    kwargs are the arguments needed to instantiate the model from model_class
    """
    with open('temp_model2.pth.tar', 'wb') as g:
        g.write(model_bin)
    state = torch.load('temp_model2.pth.tar')
    model = model_class(**state['kwargs'])
    model.load_state_dict(state['state_dict'])
    return model
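# Example round trip (illustrative; Net stands in for any nn.Module subclass
# whose __init__ takes the kwargs given here):
#   model_bin = serialize_torch_model(Net(hidden=8), hidden=8)
#   model = deserialize_torch_model(model_bin, Net)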


def save_best_torch_model_for_task(task, model):
    utils.ensure_exists(f'{Path.home()}/.openmined/models.json', {})
    with open(f"{Path.home()}/.openmined/models.json", "r") as model_file:
        models = json.loads(model_file.read())

    models[task] = torch2ipfs(model)

    with open(f"{Path.home()}/.openmined/models.json", "w") as model_file:
        json.dump(models, model_file)


def best_torch_model_for_task(task, return_model=False):
    if not os.path.exists(f'{Path.home()}/.openmined/models.json'):
        return None

    with open(f'{Path.home()}/.openmined/models.json', 'r') as model_file:
        models = json.loads(model_file.read())
        if task in models.keys():
            if return_model:
                return ipfs2torch(models[task])
            else:
                return models[task]

    return None
grid/lib/utils.py: 4 changes (4 additions, 0 deletions)
@@ -8,6 +8,10 @@
 import numpy as np
 
 
+def unpack(message):
+    return json.loads(message['data'])
+
+
 def get_ipfs_api(mode, ipfs_addr='127.0.0.1', port=5001, max_tries=25):
     print(
         f'\n{Fore.BLUE}UPDATE: {Style.RESET_ALL}Connecting to IPFS... this can take a few seconds...'
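Since pubsub payloads arrive as JSON strings under a 'data' key, the new unpack helper centralizes the decoding the clients above were doing inline. A quick illustrative call (message contents hypothetical):

    from grid.lib import utils

    msg = {'data': '{"type": "transact", "id": 7}'}
    utils.unpack(msg)  # -> {'type': 'transact', 'id': 7}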