Skip to content

Commit

Permalink
Add resource management to scheduler for multi-GPU
Browse files Browse the repository at this point in the history
See #104

Now you can select a specific GPU. The code is there for selecting n
available GPUs, but I need to build multi-GPU caffe to test it.
  • Loading branch information
lukeyeager committed May 18, 2015
1 parent 2e32ebe commit 0089c40
Show file tree
Hide file tree
Showing 10 changed files with 338 additions and 113 deletions.
17 changes: 13 additions & 4 deletions digits/config.py
Expand Up @@ -353,6 +353,16 @@ def validate(cls, value):
raise ValueError('caffe binary not found at "%s"' % value)
cls.validate_version(expected_path)

pythonpath = os.path.join(value, 'python')
sys.path.insert(0, pythonpath)
try:
imp.find_module('caffe')
except ImportError as e:
raise ValueError('Error while importing caffe from "%s": %s' % (pythonpath, e.message))
finally:
# Don't actually add this until apply() is called
sys.path.pop(0)

return value

# Used to validate the version
Expand Down Expand Up @@ -406,9 +416,9 @@ def apply(self):
# Add caffe/python to PATH
sys.path.insert(0, os.path.join(self.value, 'python'))
try:
imp.find_module('caffe')
except ImportError:
print 'ERROR: python module not found at "%s"' % sys.path.pop(0)
import caffe
except ImportError as e:
print 'Did you forget to "make pycaffe"?'
raise

class GpuListOption(ConfigOption):
Expand Down Expand Up @@ -977,7 +987,6 @@ def load_option(option, mode, newConfig,
systemConfig -- the current SystemConfigFile
"""
if 'DIGITS_MODE_TEST' in os.environ and option.has_test_value():
print 'Setting %s to test value ...' % option.name()
option.value = option.test_value()
return option.value

Expand Down
12 changes: 11 additions & 1 deletion digits/dataset/tasks/create_db.py
Expand Up @@ -109,7 +109,17 @@ def html_id(self):
return super(CreateDbTask, self).html_id()

@override
def task_arguments(self, **kwargs):
def offer_resources(self, resources):
    """
    Claim one unit from the create_db task pool, if one is free

    Arguments:
    resources -- dict mapping a resource-pool key to a list of resource
        objects; each object used here provides remaining() and identifier

    Returns {key: [(identifier, 1)]} for the first resource in the pool
    with at least one unit remaining, or None if the pool is not part of
    the offer or nothing in it has capacity left
    """
    key = 'create_db_task_pool'
    if key not in resources:
        # This pool was not offered at all
        return None
    for resource in resources[key]:
        if resource.remaining() >= 1:
            # First resource with spare capacity wins
            return {key: [(resource.identifier, 1)]}
    return None

@override
def task_arguments(self, resources):
args = [sys.executable, os.path.join(os.path.dirname(os.path.dirname(digits.__file__)), 'tools', 'create_db.py'),
self.path(self.input_file),
self.path(self.db_name),
Expand Down
12 changes: 11 additions & 1 deletion digits/dataset/tasks/parse_folder.py
Expand Up @@ -105,7 +105,17 @@ def html_id(self):
return 'task-parse-folder-%s' % ('-'.join(sets))

@override
def task_arguments(self, **kwargs):
def offer_resources(self, resources):
    """
    Claim one unit from the parse_folder task pool, if one is free

    Arguments:
    resources -- dict mapping a resource-pool key to a list of resource
        objects; each object used here provides remaining() and identifier

    Returns {key: [(identifier, 1)]} for the first resource in the pool
    with at least one unit remaining, or None if the pool is not part of
    the offer or nothing in it has capacity left
    """
    key = 'parse_folder_task_pool'
    if key not in resources:
        # This pool was not offered at all
        return None
    for resource in resources[key]:
        if resource.remaining() >= 1:
            # First resource with spare capacity wins
            return {key: [(resource.identifier, 1)]}
    return None

@override
def task_arguments(self, resources):
args = [sys.executable, os.path.join(os.path.dirname(os.path.dirname(digits.__file__)), 'tools', 'parse_folder.py'),
self.folder,
self.path(utils.constants.LABELS_FILE),
Expand Down
36 changes: 34 additions & 2 deletions digits/model/forms.py
Expand Up @@ -8,7 +8,8 @@
from wtforms import validators
from caffe.proto import caffe_pb2

from digits import utils
from digits import utils, config
from digits.device_query import get_devices

class ModelForm(Form):

Expand Down Expand Up @@ -195,11 +196,42 @@ def validate_custom_network_snapshot(form, field):
if not os.path.exists(snapshot):
raise validators.ValidationError('File does not exist')

# Select one of several GPUs
select_gpu = wtforms.RadioField('Select which GPU you would like to use',
choices = [('next', 'Next available')] + [(
index,
'#%s - %s' % (index, get_devices()[int(index)].name),
) for index in config.config_option('gpu_list').split(',') if index],
default = 'next',
)

# Select N of several GPUs
select_gpus = wtforms.SelectMultipleField('Select which GPU[s] you would like to use',
choices = [(
index,
'#%s - %s' % (index, get_devices()[int(index)].name),
) for index in config.config_option('gpu_list').split(',') if index]
)

# Use next available N GPUs
select_gpu_count = wtforms.IntegerField('Use this many GPUs (next available)',
validators = [
validators.NumberRange(min=1, max=len(config.config_option('gpu_list').split(',')))
],
default = 1,
)

def validate_select_gpu_count(form, field):
    """
    Make select_gpu_count optional when explicit GPUs were chosen

    Arguments:
    form -- the form being validated; form.select_gpus.data is read
    field -- the select_gpu_count field

    If the count is missing but select_gpus has data, any errors already
    recorded on this field are cleared in place and StopValidation is
    raised. In every other case this validator does nothing.
    """
    if field.data is None and form.select_gpus.data:
        # Make this field optional
        field.errors[:] = []
        raise validators.StopValidation()

model_name = wtforms.StringField('Model Name',
validators = [
validators.DataRequired()
]
)



32 changes: 29 additions & 3 deletions digits/model/images/classification/views.py
Expand Up @@ -22,7 +22,8 @@
from job import ImageClassificationModelJob
from digits.status import Status

NAMESPACE = '/models/images/classification'
NAMESPACE = '/models/images/classification'
MULTI_GPU = False

@app.route(NAMESPACE + '/new', methods=['GET'])
def image_classification_model_new():
Expand All @@ -34,7 +35,11 @@ def image_classification_model_new():

prev_network_snapshots = get_previous_network_snapshots()

return render_template('models/images/classification/new.html', form=form, previous_network_snapshots=prev_network_snapshots, has_datasets=(len(get_datasets())==0))
return render_template('models/images/classification/new.html',
form = form,
previous_network_snapshots = prev_network_snapshots,
multi_gpu = MULTI_GPU,
)

@app.route(NAMESPACE, methods=['POST'])
def image_classification_model_create():
Expand All @@ -47,7 +52,11 @@ def image_classification_model_create():
prev_network_snapshots = get_previous_network_snapshots()

if not form.validate_on_submit():
return render_template('models/images/classification/new.html', form=form, previous_network_snapshots=prev_network_snapshots), 400
return render_template('models/images/classification/new.html',
form = form,
previous_network_snapshots=prev_network_snapshots,
multi_gpu = MULTI_GPU,
), 400

datasetJob = scheduler.get_job(form.dataset.data)
if not datasetJob:
Expand Down Expand Up @@ -131,6 +140,21 @@ def image_classification_model_create():
else:
return 'Invalid policy', 404

if MULTI_GPU:
if form.select_gpu_count.data:
gpu_count = form.select_gpu_count.data
selected_gpus = None
else:
selected_gpus = [str(gpu) for gpu in form.select_gpus.data]
gpu_count = None
else:
if form.select_gpu.data == 'next':
gpu_count = 1
selected_gpus = None
else:
selected_gpus = [str(form.select_gpu.data)]
gpu_count = None

job.tasks.append(
tasks.CaffeTrainTask(
job_dir = job.dir(),
Expand All @@ -139,6 +163,8 @@ def image_classification_model_create():
snapshot_interval = form.snapshot_interval.data,
learning_rate = form.learning_rate.data,
lr_policy = policy,
gpu_count = gpu_count,
selected_gpus = selected_gpus,
batch_size = form.batch_size.data,
val_interval = form.val_interval.data,
pretrained_model= pretrained_model,
Expand Down
14 changes: 9 additions & 5 deletions digits/model/tasks/caffe_train.py
Expand Up @@ -368,9 +368,7 @@ def iteration_to_epoch(self, it):
return float(it * self.train_epochs) / self.solver.max_iter

@override
def task_arguments(self, **kwargs):
gpu_id = kwargs.pop('gpu_id', None)

def task_arguments(self, resources):
if config_option('caffe_root') == '<PATHS>':
caffe_bin = 'caffe'
else:
Expand All @@ -381,8 +379,14 @@ def task_arguments(self, **kwargs):
'--solver=%s' % self.path(self.solver_file),
]

if gpu_id:
args.append('--gpu=%d' % gpu_id)
if 'gpus' in resources:
identifiers = []
for identifier, value in resources['gpus']:
identifiers.append(identifier)
if len(identifiers) == 1:
args.append('--gpu=%s' % identifiers[0])
elif len(identifiers) > 1:
args.append('--gpus=%s' % ','.join(identifiers))
if self.pretrained_model:
args.append('--weights=%s' % self.path(self.pretrained_model))

Expand Down
39 changes: 39 additions & 0 deletions digits/model/tasks/train.py
Expand Up @@ -6,6 +6,7 @@

from digits import utils
from digits.task import Task
from digits.utils import override

# NOTE: Increment this every time the pickled object changes
PICKLE_VERSION = 2
Expand All @@ -28,13 +29,17 @@ def __init__(self, dataset, train_epochs, snapshot_interval, learning_rate, lr_p
lr_policy -- a hash of options to be used for the learning rate policy
Keyword arguments:
gpu_count -- how many GPUs to use for training (integer)
selected_gpus -- a list of GPU indexes to be used for training
batch_size -- if set, override any network specific batch_size with this value
val_interval -- how many epochs between validating the model with an epoch of validation data
pretrained_model -- filename for a model to use for fine-tuning
crop_size -- crop each image down to a square of this size
use_mean -- subtract the dataset's mean file
random_seed -- optional random seed
"""
self.gpu_count = kwargs.pop('gpu_count', None)
self.selected_gpus = kwargs.pop('selected_gpus', None)
self.batch_size = kwargs.pop('batch_size', None)
self.val_interval = kwargs.pop('val_interval', None)
self.pretrained_model = kwargs.pop('pretrained_model', None)
Expand Down Expand Up @@ -94,6 +99,40 @@ def __setstate__(self, state):
self.snapshots = []
self.dataset = None

@override
def offer_resources(self, resources):
    """
    Choose which GPU(s) to request from the offered resources

    Arguments:
    resources -- dict mapping a resource-pool key to a list of resource
        objects; only the 'gpus' entry is consulted, and each object in it
        provides remaining() and identifier

    Returns:
    None -- the request cannot be satisfied with this offer
    {} -- a 'gpus' pool was offered but it is empty (don't use a GPU)
    {'gpus': [(identifier, 1), ...]} -- the GPUs being requested

    Uses self.gpu_count (take the next N GPUs with capacity) if it is set,
    otherwise self.selected_gpus (a specific list of identifiers).
    """
    if 'gpus' not in resources:
        return None
    gpus = resources['gpus']
    if not gpus:
        return {}  # don't use a GPU at all

    if self.gpu_count is not None:
        # Collect GPUs with spare capacity until we have enough of them
        identifiers = []
        for gpu in gpus:
            if gpu.remaining() >= 1:
                identifiers.append(gpu.identifier)
                if len(identifiers) == self.gpu_count:
                    break
        if len(identifiers) == self.gpu_count:
            return {'gpus': [(i, 1) for i in identifiers]}
        return None

    if self.selected_gpus is not None:
        # Every requested GPU must be in the offer AND have capacity left.
        # Checking remaining() keeps this branch consistent with the
        # gpu_count branch, so a busy GPU is never requested.
        for wanted in self.selected_gpus:
            if not any(gpu.identifier == wanted and gpu.remaining() >= 1
                       for gpu in gpus):
                return None
        return {'gpus': [(i, 1) for i in self.selected_gpus]}

    return None

def send_progress_update(self, epoch):
"""
Sends socketio message about the current progress
Expand Down

0 comments on commit 0089c40

Please sign in to comment.