Skip to content

Commit

Permalink
Wz/fix/handle gpu setting (#29)
Browse files Browse the repository at this point in the history
* fix: updated model_args handling of gpu setting variable

* added related tests

* updated setup.py

* refined GPU status message

* Update huggingface.py
  • Loading branch information
wanoz committed May 25, 2023
1 parent e0c7f27 commit 22f7423
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 6 deletions.
15 changes: 10 additions & 5 deletions panml/core/llm/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,20 @@ def __init__(self, model: str, input_block_size: int, padding_length: int, token
self.input_block_size = input_block_size
self.tokenizer_batch = tokenizer_batch
self.device = 'cpu'
if 'gpu' in model_args: # set model processing on GPU else defaults on CPU
if not isinstance(model_args['gpu'], bool):
raise TypeError('Input model args, gpu needs to be of type: boolean')
set_gpu = model_args.pop('gpu')
if torch.cuda.is_available() and set_gpu:
self.device = 'cuda'
else:
print('CUDA is not available')
print(f'Model processing is set on {self.device.upper()}')
self.train_default_args = ['title', 'num_train_epochs', 'optimizer', 'mlm',
'per_device_train_batch_size', 'per_device_eval_batch_size',
'warmup_steps', 'weight_decay', 'logging_steps',
'output_dir', 'logging_dir', 'save_model']

if source == 'huggingface':
if 'flan' in self.model_name:
self.model_hf = AutoModelForSeq2SeqLM.from_pretrained(self.model_name, **model_args)
Expand All @@ -42,11 +52,6 @@ def __init__(self, model: str, input_block_size: int, padding_length: int, token
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, mirror='https://huggingface.co')

# Set model on GPU if available and specified
if 'gpu' in model_args:
self.device = 'cuda' if torch.cuda.is_available() and model_args['gpu'] else 'cpu'
print('Model processing is set on GPU')
else:
print('Model processing is set on CPU')
self.model_hf.to(torch.device(self.device))

# Embed text
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
setup(
name = 'panml', # package name
packages = find_packages(exclude=['test']), # package name
version = '0.0.16', # version
version = '0.0.17', # version
license = 'MIT', # license
description = 'PanML is a high level generative AI/ML development library designed for ease of use and fast experimentation.', # short description about the package
long_description = 'PanML aims to make analysis and experimentation of generative AI/ML models more accessible, by providing a simple and consistent interface to foundation models for Data Scientists, Machine Learning Engineers and Software Developers. \
Expand Down
13 changes: 13 additions & 0 deletions test/test_ModelPack.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,17 @@ def test_modelpack_correct_model_source_flan_input(self):
# test valid model and source match combo 2
m = ModelPack(model='google/flan-t5-small', source='huggingface')

# Test case: handle model GPU invalid type
print('Setup test_modelpack_incorrect_model_gpu_input')
def test_modelpack_incorrect_model_gpu_input(self):
    """Check that a non-boolean ``gpu`` entry in ``model_args`` is rejected.

    The integer ``1`` is passed for ``gpu`` and the constructor is
    expected to raise ``TypeError``.
    """
    # test invalid GPU setting input
    with self.assertRaises(TypeError):
        # The constructed object is never used, so it is not bound to a name.
        ModelPack(model='google/flan-t5-small', source='huggingface', model_args={'gpu': 1})

# Test case: handle model GPU correct input
print('Setup test_modelpack_correct_model_gpu_input')
def test_modelpack_correct_model_gpu_input(self):
    """Check that a boolean ``gpu`` entry in ``model_args`` is accepted.

    There is no explicit assertion: the test passes when constructing the
    ModelPack with ``gpu=True`` does not raise.
    """
    # test valid GPU setting input
    # The constructed object is never used, so it is not bound to a name.
    ModelPack(model='google/flan-t5-small', source='huggingface', model_args={'gpu': True})


0 comments on commit 22f7423

Please sign in to comment.