This repository has been archived by the owner on Feb 16, 2023. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 217
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Distributed PyTorch on Grid (via IPFS/PubSub) (#166)
* First round of torch hooks integration (#152) * Finished HookService, linked it with TorchService (#154) * feat: modify pubsub_peers to handle newer IPFS api. (#153) * finished minimal transfer of overloading code * found an untested bug * WIP for #130 and #132 (#155) * feat: modify pubsub_peers to handle newer IPFS api. (#153) * finished minimal transfer of overloading code * found an untested bug * adjust comments * this round of work sponsored by parallel jalebi * in the middle of fixing #130 and #132 * resolved #132, #130 will take a bit more effort than I'd planned for * completes #130, prepares #129 and #131; almost took care of #148 in the process * Worker side command processing and execution (#156) resolved #129 * Finished implementing IPFS into torch services (#161) * laptop sync * finished up ipfs integration, yet to test * syncing with colab notebooks * renamed channels.openmined to channels.om * found a worker node error * bug in Tensor.send_ * fixed two client side bugs * keyerror in receive_obj message * register tensors before sending * well that was rough * more bug fixes * premerge * fix utils import in hook_worker_service * fix return_result for worker * premerge * premerge2 * BOOM * multinode demo (#162) * lots o' comments (#164) * Reorganizing notebooks (#165) * lots o' comments * reorganize notebooks
- Loading branch information
Showing
27 changed files
with
3,475 additions
and
632 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,60 +1,58 @@ | ||
# Channels | ||
|
||
# Main channel that high level commands are broadcasted on | ||
openmined = 'openmined' | ||
om = 'openmined' | ||
|
||
list_tasks = 'openmined:list_tasks' | ||
add_task = 'openmined:add_task' | ||
list_models = 'openmined:list_models' | ||
list_workers = 'openmined:list_workers' | ||
list_tasks = f'{om}:list_tasks' | ||
add_task = f'{om}:task:add' | ||
list_models = f'{om}:list_models' | ||
list_workers = f'{om}:list_workers' | ||
|
||
|
||
def list_tasks_callback(id): | ||
return f'openmined:list_tasks:{id}' | ||
return f'{list_tasks}:{id}' | ||
|
||
|
||
def list_workers_callback(id): | ||
return f'openmined:list_workers:{id}' | ||
return f'{list_workers}:{id}' | ||
|
||
|
||
def add_model(name): | ||
return f'openmined:task:add:{name}' | ||
return f'{add_task}:{name}' | ||
|
||
|
||
# Whoami Channels | ||
|
||
whoami_listener = 'openmined:whoami' | ||
|
||
|
||
whoami_listener = f'{om}:whoami' | ||
def whoami_listener_callback(id): | ||
return f'{whoami_listener}:{id}' | ||
|
||
|
||
# Torch Channels | ||
|
||
torch_listen_for_obj = 'openmined:torch_listen_for_obj' | ||
|
||
|
||
torch_listen_for_obj = f'{om}:torch_listen_for_obj' | ||
def torch_listen_for_obj_callback(id): | ||
return f'openmined:torch_listen_for_obj:{id}' | ||
|
||
|
||
torch_listen_for_obj_response = 'openmined:torch_listen_for_obj_res' | ||
return f'{torch_listen_for_obj}:{id}' | ||
|
||
|
||
torch_listen_for_obj_response = f'{om}:torch_listen_for_obj_res' | ||
def torch_listen_for_obj_response_callback(id): | ||
return f'openmined:torch_listen_for_obj_res:{id}' | ||
|
||
|
||
torch_listen_for_obj_req = 'openmined:torch_listen_for_obj_req' | ||
return f'{torch_listen_for_obj_response}:{id}' | ||
|
||
|
||
torch_listen_for_obj_req = f'{om}:torch_listen_for_obj_req' | ||
def torch_listen_for_obj_req_callback(id): | ||
return f'openmined:torch_listen_for_obj_req:{id}' | ||
return f'{torch_listen_for_obj_req}:{id}' | ||
|
||
|
||
torch_listen_for_obj_req_response = 'openmined:torch_listen_for_obj_req_res' | ||
torch_listen_for_obj_req_response = f'{om}:torch_listen_for_obj_req_res' | ||
def torch_listen_for_obj_req_response_callback(id): | ||
return f'{torch_listen_for_obj_req_response}:{id}' | ||
|
||
torch_listen_for_command = f'{om}:torch_listen_for_command' | ||
def torch_listen_for_command_callback(id): | ||
return f'{torch_listen_for_command}:{id}' | ||
|
||
def torch_listen_for_obj_req_response_callback(id): | ||
return f'openmined:torch_listen_for_obj_req_res:{id}' | ||
torch_listen_for_command_response = f'{om}:torch_listen_for_command_response' | ||
def torch_listen_for_command_response_callback(id): | ||
return f'{torch_listen_for_command_response}:{id}' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,65 +1,158 @@ | ||
import os | ||
import json | ||
import re | ||
|
||
from pathlib import Path | ||
|
||
from . import utils | ||
|
||
|
||
def torch2ipfs(model): | ||
pass | ||
|
||
|
||
def ipfs2torch(model_addr): | ||
import torch | ||
|
||
|
||
# Helpers for HookService and TorchService | ||
def check_workers(self, workers):
    """Normalize a worker specification into an iterable of worker IDs.

    Args:
        self: The object being sent. Only its ``__name__`` is read, and only
            when building the error message.
        workers: A single worker ID string, or an iterable of worker IDs.

    Returns:
        ``[workers]`` when ``workers`` is a string; otherwise ``workers``
        itself, unchanged.

    Raises:
        TypeError: If ``workers`` is neither a string nor an iterable.
    """
    if isinstance(workers, str):
        return [workers]
    if not hasattr(workers, '__iter__'):
        # The message used to format ``type(owners)``, an undefined name, so
        # this path raised NameError instead of the intended TypeError.
        # Fall back to the class name when ``self`` has no ``__name__`` so the
        # TypeError is never masked by an AttributeError.
        sender = getattr(self, '__name__', type(self).__name__)
        raise TypeError(
            'Can only send {} to a string worker ID or an iterable of '
            'string worker IDs, not {}'.format(sender, type(workers))
        )
    return workers
|
||
|
||
def get_tensorvars(self, command):
    """Collect the tensor/variable values referenced by a command message.

    ``command`` is read through its ``'args'``, ``'kwargs'``, ``'arg_types'``
    and ``'kwarg_types'`` entries. A value is kept when the type string at
    the same position is listed in ``self.tensorvar_types_strs``.

    Returns:
        A list with the matching positional values first, followed by the
        matching keyword values.
    """
    positional = command['args']
    keyword = command['kwargs']
    positional_types = command['arg_types']
    keyword_types = command['kwarg_types']

    selected = []
    for pos, value in enumerate(positional):
        if positional_types[pos] in self.tensorvar_types_strs:
            selected.append(value)
    # ``kwargs`` is indexed by position and element [1] of each entry is
    # taken -- presumably (name, value) pairs; TODO confirm with the sender.
    for pos in range(len(keyword)):
        if keyword_types[pos] in self.tensorvar_types_strs:
            selected.append(keyword[pos][1])
    return selected
|
||
|
||
def check_remote(tensorvars):
    """Return True when at least one item has a truthy ``is_pointer``."""
    # Every item's flag is read before the reduction, as in the original.
    pointer_flags = [item.is_pointer for item in tensorvars]
    return any(pointer_flags)
|
||
|
||
def get_owners(tensorvars):
    """Gather the distinct owners across ``tensorvars``.

    Returns:
        A ``(multiple_owners, owners)`` pair: ``owners`` is a list of the
        distinct values found in each item's ``owners`` attribute, and
        ``multiple_owners`` is True when that list has more than one entry.
    """
    distinct = {
        owner
        for tensorvar in tensorvars
        for owner in tensorvar.owners
    }
    owners = list(distinct)
    return len(owners) > 1, owners
|
||
|
||
def replace_tensorvar(x):
    """Replace tensors/variables inside ``x`` with ``'_fl.<id>'`` strings.

    Args:
        x: A tensor, a ``torch.autograd.Variable``, an iterable that may
            contain them, or any other value.

    Returns:
        ``'_fl.{id}'`` built from ``x.id`` when ``x`` is a tensor or
        Variable; a list with each element converted recursively when ``x``
        is iterable; otherwise ``x`` unchanged (this includes strings, bytes
        and tensors that have no ``id`` attribute).
    """
    # Strings are iterable, but each character is itself a string, so
    # recursing into them never terminates. Return text untouched.
    if isinstance(x, (str, bytes)):
        return x
    # Use ``torch.old_is_tensor`` when present -- presumably the un-hooked
    # original saved by the hook service; TODO confirm.
    check = getattr(torch, 'old_is_tensor', torch.is_tensor)
    try:
        if check(x) or isinstance(x, torch.autograd.Variable):
            return '_fl.{}'.format(x.id)
        # The converted list used to be built and then discarded, so every
        # iterable argument came back as None.
        return [replace_tensorvar(item) for item in x]
    except (AttributeError, TypeError):
        # Not iterable, or a tensor without ``id``: pass through as-is.
        return x
|
||
|
||
def replace_in_command(command_msg):
    """Rewrite a command message in place with ``replace_tensorvar``.

    The ``'args'`` entry is rebuilt with ``map_tuple`` and the ``'kwargs'``
    entry with ``map_dict``. When a ``'self'`` entry exists it is converted
    as well; a missing ``'self'`` key is ignored.

    Returns:
        The same ``command_msg`` object that was passed in.
    """
    for field, mapper in (('args', map_tuple), ('kwargs', map_dict)):
        command_msg[field] = mapper(
            None, command_msg[field], replace_tensorvar)
    try:
        command_msg['self'] = replace_tensorvar(command_msg['self'])
    except KeyError:
        # Commands without a 'self' entry have nothing more to convert.
        pass
    return command_msg
|
||
# Client needs to identify a tensor before sending commands that use it | ||
def id_tensorvar(x):
    """Turn ``'_fl.<id>'`` placeholders back into integer IDs.

    A string is searched for the placeholder pattern and the captured text
    is converted with ``int``; a string without a match is returned as-is.
    Any non-string value is iterated and each element converted recursively
    into a list.

    Note: only ``AttributeError`` is handled here. Iterating a non-iterable
    value raises ``TypeError`` to the caller, and non-numeric captured text
    raises ``ValueError``.
    """
    try:
        if not isinstance(x, str):
            return [id_tensorvar(element) for element in x]
        found = re.search('_fl.(.*)', x)
        # ``found`` is None when nothing matched; ``.group`` then raises
        # AttributeError, which is handled below.
        return int(found.group(1))
    except AttributeError:
        return x
|
||
|
||
# Safety checks for serializing and deserializing torch objects | ||
# Desperately needs stress testing before going out in the wild | ||
# Lookup table from a serialized type name to the torch class it stands for.
map_torch_type = dict((
    ('torch.FloatTensor', torch.FloatTensor),
    ('torch.DoubleTensor', torch.DoubleTensor),
    ('torch.HalfTensor', torch.HalfTensor),
    ('torch.ByteTensor', torch.ByteTensor),
    ('torch.CharTensor', torch.CharTensor),
    ('torch.ShortTensor', torch.ShortTensor),
    ('torch.IntTensor', torch.IntTensor),
    ('torch.LongTensor', torch.LongTensor),
    ('torch.autograd.variable.Variable', torch.autograd.variable.Variable),
    ('torch.nn.parameter.Parameter', torch.nn.parameter.Parameter),
))
|
||
|
||
def types_guard(obj_type):
    """Look up ``obj_type`` in ``map_torch_type`` and return the entry.

    Raises:
        KeyError: If ``obj_type`` is not a key of ``map_torch_type``.
    """
    resolved = map_torch_type[obj_type]
    return resolved
|
||
|
||
def tensor_contents_guard(contents):
    """Return ``contents`` unchanged; no validation is performed yet."""
    # TODO: check to make sure the incoming list isn't dangerous to use for
    #       constructing a tensor (likely non-trivial)
    checked = contents
    return checked
|
||
|
||
def command_guard(command, allowed):
    """Return ``command`` if it is contained in ``allowed``.

    Raises:
        RuntimeError: If ``command`` is not in ``allowed``.
    """
    if command in allowed:
        return command
    message = 'Command "{}" is not a supported Torch operation.'.format(
        command)
    raise RuntimeError(message)
|
||
|
||
# Worker needs to retrieve tensor by ID before computing with it | ||
def retrieve_tensor(self, x): | ||
try: | ||
return self.worker.objects[id_tensorvar(x)] | ||
except TypeError: | ||
try: | ||
return [self.worker.objects[i] for i in id_tensorvar(x)] | ||
except TypeError: | ||
return x | ||
except KeyError: | ||
return x | ||
|
||
|
||
def map_tuple(service, args, func):
    """Apply ``func`` to every element of ``args`` and return a tuple.

    When ``service`` is truthy it is passed to ``func`` as the first
    argument (``func(service, x)``); otherwise ``func(x)`` is called.
    """
    if not service:
        return tuple(map(func, args))
    return tuple(func(service, element) for element in args)
|
||
|
||
def map_dict(service, kwargs, func):
    """Apply ``func`` to every value of ``kwargs`` and return a new dict.

    Keys are kept as they are. When ``service`` is truthy it is passed to
    ``func`` as the first argument (``func(service, value)``); otherwise
    ``func(value)`` is called.
    """
    if not service:
        return {name: func(value) for name, value in kwargs.items()}
    return {name: func(service, value) for name, value in kwargs.items()}
|
||
|
||
def hook_tensor_ser(service_self, tensor_type):
    """Attach a JSON ``ser`` method to ``tensor_type``.

    Args:
        service_self: Not used by this function.
        tensor_type: Class that receives the ``ser`` attribute.

    The installed ``ser(self, include_data=True)`` returns a JSON string with
    the keys ``torch_type`` (from ``self.type()``), ``id`` and ``owners``,
    plus ``data`` (from ``self.tolist()``) when ``include_data`` is true.
    """
    def ser(self, include_data=True):
        tensor_msg = {'torch_type': self.type()}
        if include_data:
            tensor_msg['data'] = self.tolist()
        tensor_msg['id'] = self.id
        tensor_msg['owners'] = self.owners
        return json.dumps(tensor_msg)

    # A ``"...".format(...)`` call as the first statement of a function is an
    # ordinary expression, not a docstring, so the text was evaluated and
    # thrown away on every call. Set ``__doc__`` explicitly instead.
    ser.__doc__ = """Serializes a {} object to JSON.""".format(tensor_type)
    tensor_type.ser = ser
|
||
|
||
def hook_var_ser(service_self):
    """Placeholder with no behavior yet; always returns None."""
    # TODO
    return None
|
||
|
||
def serialize_torch_model(model, **kwargs):
    """Serialize a model's weights and constructor arguments to bytes.

    Args:
        model: Object exposing ``state_dict()``.
        **kwargs: The arguments needed to instantiate the model; stored
            alongside the weights under the ``'kwargs'`` key.

    Returns:
        The bytes produced by ``torch.save`` for
        ``{'state_dict': ..., 'kwargs': ...}``.
    """
    import io  # local import keeps this block self-contained

    state = {'state_dict': model.state_dict(), 'kwargs': kwargs}
    # Serialize in memory. The fixed-name 'temp_model.pth.tar' in the working
    # directory was left behind after every call and collided when two calls
    # ran at the same time.
    buffer = io.BytesIO()
    torch.save(state, buffer)
    return buffer.getvalue()
|
||
|
||
def deserialize_torch_model(model_bin, model_class, **kwargs):
    """Rebuild a model from bytes made by ``serialize_torch_model``.

    model_class is needed since PyTorch uses pickle for serialization
    see https://discuss.pytorch.org/t/loading-pytorch-model-without-a-code/12469/2 for details

    Args:
        model_bin: Bytes containing a saved dict with ``'state_dict'`` and
            ``'kwargs'`` entries.
        model_class: Callable used to instantiate the model; it is called
            with the stored ``'kwargs'`` entry.
        **kwargs: Accepted for interface compatibility. The constructor
            arguments are taken from the serialized state, not from here.

    Returns:
        A ``model_class`` instance with the stored state dict loaded.
    """
    import io  # local import keeps this block self-contained

    # ``torch.load()`` used to be called with no argument, which always
    # raised TypeError. Load straight from the bytes instead of going
    # through a leftover 'temp_model2.pth.tar' file.
    # NOTE(security): torch.load is pickle-based; only pass trusted bytes.
    state = torch.load(io.BytesIO(model_bin))
    model = model_class(**state['kwargs'])
    model.load_state_dict(state['state_dict'])
    return model
|
||
|
||
def save_best_torch_model_for_task(task, model):
    """Record ``torch2ipfs(model)`` under ``task`` in the models registry.

    The registry is the JSON file ``~/.openmined/models.json``. It is
    prepared with ``utils.ensure_exists`` (called with ``{}``), read,
    updated for ``task`` and written back.
    """
    registry_path = f'{Path.home()}/.openmined/models.json'
    utils.ensure_exists(registry_path, {})

    with open(registry_path, "r") as model_file:
        registry = json.loads(model_file.read())

    registry[task] = torch2ipfs(model)

    with open(registry_path, "w") as model_file:
        json.dump(registry, model_file)
|
||
|
||
def best_torch_model_for_task(task, return_model=False):
    """Look up the entry stored for ``task`` in the models registry.

    The registry is the JSON file ``~/.openmined/models.json``.

    Returns:
        None when the registry file does not exist or has no entry for
        ``task``. Otherwise the stored entry, or ``ipfs2torch(entry)`` when
        ``return_model`` is true.
    """
    registry_path = f'{Path.home()}/.openmined/models.json'
    if not os.path.exists(registry_path):
        return None

    with open(registry_path, 'r') as model_file:
        registry = json.loads(model_file.read())

    if task not in registry.keys():
        return None
    entry = registry[task]
    return ipfs2torch(entry) if return_model else entry
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.