Support dump and load timm and hf_text models on MultiModalPredictor #2682
Conversation
Job PR-2682-1adf0a6 is done.
```python
if isinstance(self._model, MultimodalFusionMLP) and isinstance(
    self._model.model, torch.nn.modules.container.ModuleList
):
    for per_model in self._model.model:
        if isinstance(per_model, TimmAutoModelForImagePrediction):
            model = per_model
            break
```
Saving a timm image backbone from a fusion model seems not useful, since it can't work individually. We can consider handling only a single timm model for now.
We considered the fusion-model scenario because there might be a use case where a user wants to save specifically the timm_image or hf_text part of a trained fusion model, to use on downstream tasks that only need a single model. Do you feel that would be useful?
For example, people who trained on a mixture of image, text, and tabular data may want to extract the image part from the fusion model. The API `predictor.dump_timm_image` should still work in this case.
This logic has the limitation that, if multiple timm image models are available, it would only dump the first one.
I've updated the logic to support saving multiple models under timm_image and hf_text.
```python
model = None
if isinstance(self._model, MultimodalFusionMLP) and isinstance(
    self._model.model, torch.nn.modules.container.ModuleList
):
    for per_model in self._model.model:
        if isinstance(per_model, HFAutoModelForTextPrediction):
            model = per_model
            break
```
Ditto. We can consider only the single Hugging Face text model for now.
```python
os.makedirs(path)
model.model.save_pretrained(path)
logger.info(f"Model saved to {path}.")
if TEXT in self._data_processors.keys():
```
For the hf_text model, we can assert that `self._data_processors` has only one text processor.
Do you mean there is no need to check the prefix of the data processor to get the tokenizer?
Since we consider dumping an hf model from a fusion model, this logic is OK.
Do we need to raise an error if there is no hf model available?
@@ -2812,6 +2817,79 @@ def load(
```python
        return predictor

    def dump_timm_image(
```
There would be too many APIs if each model has its own dump function. How about using only one API, `dump_model()`? Inside the function, we can check whether the model type is timm image or hf text.
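A minimal sketch of what such a unified API could look like. The class names mirror the ones in this PR, but their bodies and the `dump_model` function here are illustrative stand-ins, not the PR's actual implementation:

```python
import os

# Hypothetical stand-ins for the real backbone classes in this PR.
class TimmAutoModelForImagePrediction:
    def save(self, path):
        pass  # placeholder for the real serialization logic

class HFAutoModelForTextPrediction:
    def save(self, path):
        pass  # placeholder for the real serialization logic

def dump_model(model, path):
    """Dispatch on the backbone type instead of exposing one API per model."""
    if isinstance(model, TimmAutoModelForImagePrediction):
        sub_dir = os.path.join(path, "timm_image")
    elif isinstance(model, HFAutoModelForTextPrediction):
        sub_dir = os.path.join(path, "hf_text")
    else:
        raise ValueError(f"Dumping {type(model).__name__} is not supported.")
    os.makedirs(sub_dir)
    model.save(sub_dir)
    return sub_dir
```

The dispatch keeps one public entry point while the per-backbone save paths stay internal.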
```python
    return filtered_cfg


def save_timm_config(
```
docstrings are missing.
```python
model = self._model if model is None else model
if isinstance(model, HFAutoModelForTextPrediction) and model.model is not None:
    os.makedirs(path)
```
Try to use `os.makedirs(path, exist_ok=True)`.
I originally had this flag set, but after discussing with Zhiqiang offline, he suggested that we might not want to save to an existing (or non-empty) directory to avoid accidentally overwriting the model.
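For reference, the behavioral difference being discussed. Without `exist_ok=True`, `os.makedirs` raises `FileExistsError` on an existing directory, which is exactly what makes the stricter default act as a guard against accidentally overwriting a previous dump:

```python
import os
import tempfile

target = os.path.join(tempfile.mkdtemp(), "dump")

os.makedirs(target)  # first call succeeds
try:
    os.makedirs(target)  # second call: directory already exists
except FileExistsError:
    print("refused to reuse existing directory")

os.makedirs(target, exist_ok=True)  # silently accepts the existing directory
```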
Job PR-2682-303d91d is done.
Job PR-2682-b289c4a is done.
@@ -523,3 +524,47 @@ def modify_duplicate_model_names(
```python
def list_timm_models(pretrained=True):
    return timm.list_models(pretrained=pretrained)


def _filter_timm_pretrained_cfg(cfg, remove_source=False, remove_null=True):
```
Since `_filter_timm_pretrained_cfg` and `save_timm_config` are about config, is it better to put them into `utils/config.py`?
```python
timm_image_dir = f"{model_dump_path}/timm_image"
assert os.path.exists(hf_text_dir) and (len(os.listdir(hf_text_dir)) > 2) == True
assert os.path.exists(timm_image_dir) and (len(os.listdir(timm_image_dir)) == 2) == True
print("done")
```
Remove the `print`?
```python
    path : str
        Path to directory where models and configs should be saved.
    """
    models = {}
```
`models` is a list in `dump_timm_image` and a dict in `dump_hf_text`, respectively. Can we use a list for both of them?
```python
        Path to directory where models and configs should be saved.
    """
    models = []
    if isinstance(self._model, MultimodalFusionMLP) and isinstance(
```
A fusion model may also be a `MultimodalFusionTransformer`.
Job PR-2682-f7801b7 is done.
LGTM! Awesome feature! Consider unifying the dumping functions into one API in follow-up PRs.
```python
        Path to directory where models and configs should be saved.
    """
    models = []
    if isinstance(self._model, (MultimodalFusionMLP, MultimodalFusionTransformer)) and isinstance(
```
@zhiqiangdon We can later add a `BaseMultimodalFusionModel` class, and ensure that `MultimodalFusionMLP` and `MultimodalFusionTransformer` inherit from this class.
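A sketch of the suggested refactor. The class bodies here are placeholders; only the inheritance relationship matters, since a shared base class lets the dump logic use a single `isinstance` check instead of enumerating fusion variants:

```python
class BaseMultimodalFusionModel:
    """Common base for fusion models, as suggested in the review."""

class MultimodalFusionMLP(BaseMultimodalFusionModel):
    pass  # placeholder body

class MultimodalFusionTransformer(BaseMultimodalFusionModel):
    pass  # placeholder body

def is_fusion_model(model):
    # One check covers all current and future fusion variants.
    return isinstance(model, BaseMultimodalFusionModel)
```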
@suzhoum Maybe we can add a TODO item here.
Sure. We can add a base class.
@@ -679,3 +681,104 @@ def test_image_bytearray():
```python
    npt.assert_array_equal(
        [prediction_prob_1, prediction_prob_2, prediction_prob_3, prediction_prob_4], [prediction_prob_1] * 4
    )


def test_dump_timm_image():
```
Can we separate these tests into another file? The reason is that `test_predictor.py` is growing too huge and too slow. We can consider adding them under the name `unittests/predictor/test_predictor_dump_third_party.py`.
Good idea. Done!
Two minor comments!! Overall LGTM!!!! Thanks!!!
Job PR-2682-771fb04 is done.
Issue #, if available:

Description of changes:
- Support dumping `timm_image` and `hf_text` models
- Support loading dumped `timm_image` models

Usage:
To dump models from a fine-tuned MultiModalPredictor:
To load models from saved model path:
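The original usage snippets did not survive extraction. A hedged sketch of the intended flow: the `dump_timm_image`/`dump_hf_text` method names come from this PR's review discussion, but the paths, training data, and exact signatures here are hypothetical:

```python
from autogluon.multimodal import MultiModalPredictor
from transformers import AutoModel, AutoTokenizer

# Fine-tune a predictor (train_data is hypothetical).
predictor = MultiModalPredictor(label="label")
predictor.fit(train_data)

# Dump the backbones; exact signatures may differ from this sketch.
predictor.dump_timm_image("./model_dump")
predictor.dump_hf_text("./model_dump")

# The hf_text backbone is saved via save_pretrained, so it can be
# reloaded with the standard Hugging Face APIs.
model = AutoModel.from_pretrained("./model_dump/hf_text")
tokenizer = AutoTokenizer.from_pretrained("./model_dump/hf_text")
```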
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.