Skip to content

Commit

Permalink
[partition] keep 'option' in properties (deepjavalibrary#819)
Browse files: browse the repository at this point in the history
* [partition] keep options in properties

* remove option for model_dir and intermediate
  • Loading branch information
sindhuvahinis authored and KexinFeng committed Aug 16, 2023
1 parent 827688c commit a622ffb
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 34 deletions.
30 changes: 17 additions & 13 deletions serving/docker/partition/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, props_manager):
self.download_model_from_s3()

def download_model_from_s3(self):
model_id = self.properties.get("model_id")
model_id = self.properties.get("option.model_id")
if not model_id or not model_id.startswith("s3://"):
return

Expand Down Expand Up @@ -71,7 +71,7 @@ def download_model_from_s3(self):
if not glob.glob(os.path.join(download_dir, '*')):
raise Exception('Model download from s3url failed')

self.properties['model_id'] = download_dir
self.properties['option.model_id'] = download_dir

def install_requirements_file(self):
req_file_dir = self.properties_manager.properties_dir
Expand Down Expand Up @@ -106,13 +106,13 @@ def set_environmental_vars(self):

def download_config_from_hf(self):
# checks if model_id is a path
if glob.glob(self.properties['model_id']):
return self.properties['model_id']
if glob.glob(self.properties['option.model_id']):
return self.properties['option.model_id']

download_dir = os.environ.get("SERVING_DOWNLOAD_DIR",
'/tmp/download/model/')

model_name = self.properties['model_id']
model_name = self.properties['option.model_id']
downloaded_dir = snapshot_download(
repo_id=model_name,
cache_dir=download_dir,
Expand All @@ -122,22 +122,24 @@ def download_config_from_hf(self):

def copy_config_files(self):
model_dir = self.properties['model_dir']
if 'model_id' in self.properties:
if 'option.model_id' in self.properties:
model_dir = self.download_config_from_hf()

config_files = []
for pattern in CONFIG_FILES_PATTERNS:
config_files += glob.glob(os.path.join(model_dir, pattern))

for file in config_files:
shutil.copy(file, dst=self.properties['save_mp_checkpoint_path'])
shutil.copy(file,
dst=self.properties['option.save_mp_checkpoint_path'])

def upload_checkpoints_to_s3(self):
if 'upload_checkpoints_s3url' not in self.properties:
return

s3url = self.properties['upload_checkpoints_s3url']
saved_checkpoints_dir = self.properties["save_mp_checkpoint_path"]
saved_checkpoints_dir = self.properties[
"option.save_mp_checkpoint_path"]

if not saved_checkpoints_dir.endswith('/'):
saved_checkpoints_dir = saved_checkpoints_dir + '/'
Expand All @@ -154,7 +156,7 @@ def upload_checkpoints_to_s3(self):
commands = ["aws", "s3", "sync", saved_checkpoints_dir, s3url]

subprocess.run(commands)
shutil.rmtree(self.properties["save_mp_checkpoint_path"])
shutil.rmtree(self.properties["option.save_mp_checkpoint_path"])

def cleanup(self):
"""
Expand Down Expand Up @@ -184,16 +186,18 @@ def run_partition(self):

def load_the_generated_checkpoints(self):
if self.properties['engine'] == 'DeepSpeed':
saved_checkpoints_dir = self.properties["save_mp_checkpoint_path"]
saved_checkpoints_dir = self.properties[
"option.save_mp_checkpoint_path"]
properties = utils.load_properties(saved_checkpoints_dir)
properties['model_dir'] = saved_checkpoints_dir
properties['entryPoint'] = self.properties['entryPoint']
properties['option.entryPoint'] = self.properties[
'option.entryPoint']
properties['partition_handler'] = 'handle'

entry_point_file = None
if properties['entryPoint'] == 'model.py':
if properties['option.entryPoint'] == 'model.py':
entry_point_file = os.path.join(
self.properties['properties_dir'], 'model.py')
self.properties_manager.properties_dir, 'model.py')
shutil.copy(entry_point_file, saved_checkpoints_dir)

commands = get_partition_cmd(True, properties)
Expand Down
36 changes: 16 additions & 20 deletions serving/docker/partition/properties_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,25 @@
from utils import is_engine_mpi_mode, get_engine_configs, get_download_dir, load_properties

EXCLUDE_PROPERTIES = [
'option.model_id', 'save_mp_checkpoint_path', 'model_dir',
'option.model_id', 'option.save_mp_checkpoint_path', 'model_dir',
'upload_checkpoints_s3url', 'properties_dir'
]

PARTITION_SUPPORTED_ENGINES = ['DeepSpeed', 'FasterTransformer']

CHUNK_SIZE = 4096 # 4MB chunk size


class PropertiesManager(object):

def __init__(self, args):
self.entry_point_url = None
self.properties = {}
self.properties_dir = args.properties_dir
load_properties(self.properties_dir)
self.properties = load_properties(self.properties_dir)

if args.model_id:
self.properties['option.model_id'] = args.model_id
if args.engine:
self.properties['engine'] = args.engine
if args.save_mp_checkpoint_path:
self.properties[
'save_mp_checkpoint_path'] = args.save_mp_checkpoint_path
'option.save_mp_checkpoint_path'] = args.save_mp_checkpoint_path
if args.tensor_parallel_degree:
self.properties[
'option.tensor_parallel_degree'] = args.tensor_parallel_degree
Expand Down Expand Up @@ -86,7 +81,7 @@ def validate_and_correct_checkpoints_json(self):
"""
if self.properties.get('engine') == 'DeepSpeed':
config_file = os.path.join(
self.properties['save_mp_checkpoint_path'],
self.properties['option.save_mp_checkpoint_path'],
'ds_inference_config.json')
if not os.path.exists(config_file):
raise ValueError("Checkpoints json file was not generated."
Expand All @@ -103,17 +98,17 @@ def validate_and_correct_checkpoints_json(self):
json.dump(configs, f)

def generate_properties_file(self):
checkpoint_path = self.properties.get('save_mp_checkpoint_path')
checkpoint_path = self.properties.get('option.save_mp_checkpoint_path')
configs = get_engine_configs(self.properties)

for key, value in self.properties.items():
if key not in EXCLUDE_PROPERTIES:
if key == "entryPoint":
entry_point = self.properties.get("entryPoint")
if key == "option.entryPoint":
entry_point = self.properties.get("option.entryPoint")
if entry_point == "model.py":
continue
elif self.entry_point_url:
configs["entryPoint"] = self.entry_point_url
configs["option.entryPoint"] = self.entry_point_url
else:
configs[key] = value

Expand All @@ -135,14 +130,14 @@ def validate_tp_degree(self):
)

def set_and_validate_entry_point(self):
entry_point = self.properties.get('entryPoint')
entry_point = self.properties.get('option.entryPoint')
if entry_point is None:
entry_point = os.environ.get("DJL_ENTRY_POINT")
if entry_point is None:
entry_point_file = glob.glob(
os.path.join(self.properties_dir, 'model.py'))
if entry_point_file:
self.properties['entryPoint'] = 'model.py'
self.properties['option.entryPoint'] = 'model.py'
else:
engine = self.properties.get('engine')
if engine == "DeepSpeed":
Expand All @@ -151,7 +146,7 @@ def set_and_validate_entry_point(self):
entry_point = "djl_python.fastertransformer"
else:
raise ValueError("Please specify engine")
self.properties['entryPoint'] = entry_point
self.properties['option.entryPoint'] = entry_point
elif entry_point.lower().startswith('http'):
logging.info(f'Downloading entrypoint file.')
self.entry_point_url = entry_point
Expand All @@ -161,16 +156,17 @@ def set_and_validate_entry_point(self):
with requests.get(entry_point) as r:
with open(model_file, 'wb') as f:
f.write(r.content)
self.properties['entryPoint'] = model_file
self.properties['option.entryPoint'] = model_file
logging.info(f'Entrypoint file downloaded successfully')

def set_and_validate_save_mp_checkpoint_path(self):
save_mp_checkpoint_path = self.properties.get(
"save_mp_checkpoint_path")
"option.save_mp_checkpoint_path")
if not save_mp_checkpoint_path:
raise ValueError("Please specify save_mp_checkpoint_path")
if save_mp_checkpoint_path.startswith("s3://"):
self.properties[
"upload_checkpoints_s3url"] = save_mp_checkpoint_path
self.properties["save_mp_checkpoint_path"] = get_download_dir(
self.properties_dir, "partition-model")
self.properties[
"option.save_mp_checkpoint_path"] = get_download_dir(
self.properties_dir, "partition-model")
2 changes: 1 addition & 1 deletion serving/docker/partition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,5 @@ def load_properties(properties_dir):
if line.startswith("#") or not line.strip():
continue
key, value = line.strip().split('=', 1)
properties[key.split(".", 1)[-1]] = value
properties[key] = value
return properties

0 comments on commit a622ffb

Please sign in to comment.