[AutoMM] Add support for loading pre-trained weights in ft_transformer #3859
Conversation
- How would users obtain a pretrained ft_transformer that aligns with our implementation? We haven't provided the pretraining functionality.
- Consider adding a test.
# init transformer backbone from provided checkpoint
if checkpoint_name:
    ckpt = torch.load(checkpoint_name)
    self.transformer.load_state_dict(ckpt["state_dict"])
This only loads weights for self.transformer. What if the saved weights also have self.categorical_adapter, self.numerical_adapter, and self.head?
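If the checkpoint were extended to cover the adapters and head, one way to handle it would be to partition the flat state_dict by submodule prefix. A minimal sketch, not part of this PR; the helper name and prefix list are hypothetical:

```python
# Hypothetical helper: split a flat checkpoint state_dict by submodule prefix
# so each part can be loaded into its matching submodule (backbone, adapters, head).
def split_state_dict(state_dict, prefixes=("transformer", "categorical_adapter", "numerical_adapter", "head")):
    parts = {p: {} for p in prefixes}
    for key, value in state_dict.items():
        for p in prefixes:
            if key.startswith(p + "."):
                # strip the prefix so keys match the submodule's own state_dict
                parts[p][key[len(p) + 1:]] = value
                break
    return parts
```

Each non-empty part could then be passed to the corresponding submodule's load_state_dict, so only the components the checkpoint actually contains get loaded.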
Currently, this is mostly for XTab pretraining. The use case is that users load XTab pre-trained ft_transformer weights and then finetune.
Can we add tests?
Can we benchmark the performance of the pretrained weights and compare against a randomly initialized ft_transformer?
Added unit test. The XTab repo has the benchmarking results.
The README in the XTab repo only shows the result on one toy dataset; benchmark results on more tabular datasets are unclear. It's also unclear whether we can reproduce the results in the paper, due to details such as light vs. heavy finetuning.
@@ -677,6 +678,31 @@ def test_load_ckpt():
    npt.assert_equal(predictions_prob, predictions2_prob)


def test_fttransformer_load_ckpt():
    download("https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt", "./")
There is no documentation for how to download a pretrained ft_transformer checkpoint. If the benchmark results show that the pretrained weights beat random initialization, we could use the pretrained checkpoint as the default, i.e., download the checkpoint and initialize the ft_transformer internally. That would make it easier for users.
If you'd like to run a benchmark of pre-trained vs random, we can run it on AutoML Benchmark after the code is in a state where I can specify the pre-trained weights to load via the model hyperparameters through TabularPredictor.
The code supports specifying the pre-trained weights in the model hyperparameters via "model.ft_transformer.checkpoint_name": "path_to_checkpoint.ckpt". I am currently benchmarking on multimodal datasets. It would be great if you could run it on AutoML Benchmark.
Can we remove this line since we already support downloading it internally?
Being able to pass a URL path for the weights file would be a nice ease-of-use improvement. Otherwise, looks good!
)
hyperparameters = {
    "model.names": ["ft_transformer"],
    "model.ft_transformer.checkpoint_name": "./iter_2k.ckpt",
It would be a nice ease-of-use improvement if the user could specify a URL here. Instead of:

download(...)
"model.ft_transformer.checkpoint_name": "./iter_2k.ckpt",

just do:

"model.ft_transformer.checkpoint_name": "https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt"

Internally we can call download and save the file to some predefined directory for model weights. For example, if this were the case, I wouldn't need to edit my benchmarking code to include the download call, which would simplify benchmarking a lot.
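One way to implement the "predefined directory" idea is a small resolver that caches URL checkpoints locally so repeated loads skip the download. A sketch only; the function name and cache location are hypothetical, and s3:// URLs would need a separate handler:

```python
import hashlib
import os
from urllib.request import urlretrieve

def resolve_checkpoint_path(checkpoint_name, cache_dir="~/.cache/ft_transformer"):
    """Return a local path for checkpoint_name; HTTPS checkpoints are
    downloaded once into cache_dir and reused on later calls."""
    if not checkpoint_name.startswith("https://"):
        return checkpoint_name  # already a local path
    cache_dir = os.path.expanduser(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)
    # derive a stable file name from the URL so the same checkpoint is reused
    fname = hashlib.sha1(checkpoint_name.encode()).hexdigest() + ".ckpt"
    local_path = os.path.join(cache_dir, fname)
    if not os.path.exists(local_path):
        urlretrieve(checkpoint_name, local_path)
    return local_path
```

With something like this, "model.ft_transformer.checkpoint_name" could accept either form, and benchmarking configs would not need a separate download step.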
Currently I get the following error if I try to pass the URL:
FileNotFoundError: [Errno 2] No such file or directory: 'https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt'
This also would allow me to add the weights as part of the default hyperparameter configs.
LGTM!
Upon further testing, I noticed that the pretrained weights are re-downloaded when loading the fitted model from disk, leading to a substantial inference slowdown. Beyond the re-downloading, I believe the weights are also loaded from disk even when already downloaded, which shouldn't happen after fit.
# init transformer backbone from provided checkpoint
if checkpoint_name:
    if "https://" in checkpoint_name or is_s3_url(checkpoint_name):
        with tempfile.TemporaryDirectory() as tmpdirname:
            checkpoint_path = os.path.join(tmpdirname, "./ft_transformer_pretrained.ckpt")
            download(checkpoint_name, checkpoint_path)
            ckpt = torch.load(checkpoint_path)
    else:
        ckpt = torch.load(checkpoint_name)
    self.transformer.load_state_dict(ckpt["state_dict"])
This is called every time the model is loaded from disk, even after finetuning. We should ensure we are not using the pretrained weights in any way after we do training, as the pretrained weights are no longer necessary.
Minimal reproducible example:
from autogluon.tabular import TabularPredictor, TabularDataset

if __name__ == '__main__':
    label = 'class'
    train_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')
    test_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')
    subsample_size = 500  # subsample subset of data for faster demo, try setting this to much larger values
    if subsample_size is not None and subsample_size < len(train_data):
        train_data = train_data.sample(n=subsample_size, random_state=0)
    hyperparameters = {
        'FT_TRANSFORMER': [
            {"model.ft_transformer.checkpoint_name": "https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt"},
        ],
    }
    predictor = TabularPredictor(
        label=label,
        eval_metric='roc_auc',
    )
    predictor: TabularPredictor = predictor.fit(train_data, hyperparameters=hyperparameters)
    # predictor.leaderboard(data=test_data, display=True)
    print("###### Prior to predict ######")
    predictor.predict(test_data)
    print("###### After Predict (1) ######")
    predictor.predict(test_data)
    print("###### After Predict (2) ######")
    predictor.predict(test_data)
    print("###### After Predict (3) ######")
Output:
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("AutogluonModels/ag-20240122_233921")
###### Prior to predict ######
Downloading /tmp/tmplmmfeb91/./ft_transformer_pretrained.ckpt from https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt...
100%|██████████| 3.13M/3.13M [00:00<00:00, 5.01MiB/s]
Load pretrained checkpoint: /home/ubuntu/workspace/code/scratch/AutogluonModels/ag-20240122_233921/models/FTTransformer/automm_model/model.ckpt
Predicting DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 47.68it/s]
###### After Predict (1) ######
Downloading /tmp/tmpf5zqen3i/./ft_transformer_pretrained.ckpt from https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt...
100%|██████████| 3.13M/3.13M [00:00<00:00, 6.26MiB/s]
Load pretrained checkpoint: /home/ubuntu/workspace/code/scratch/AutogluonModels/ag-20240122_233921/models/FTTransformer/automm_model/model.ckpt
Predicting DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 30.74it/s]
###### After Predict (2) ######
Downloading /tmp/tmp2479loia/./ft_transformer_pretrained.ckpt from https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt...
100%|██████████| 3.13M/3.13M [00:00<00:00, 4.46MiB/s]
Load pretrained checkpoint: /home/ubuntu/workspace/code/scratch/AutogluonModels/ag-20240122_233921/models/FTTransformer/automm_model/model.ckpt
Predicting DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 46.99it/s]
###### After Predict (3) ######
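The re-download can be avoided by gating checkpoint loading on a pretrained flag, so that a model restored after fit skips the external weights entirely. The standalone helper below is just an illustration of that condition, not the PR's actual code:

```python
# Illustration of the gating condition for pretrained-weight loading:
# external weights are only needed when building a fresh model for training
# (pretrained=True); a fitted model restores its own saved weights instead.
def should_load_pretrained_ckpt(pretrained: bool, checkpoint_name) -> bool:
    return bool(pretrained) and bool(checkpoint_name)
```

Applied inside the model constructor, this makes loading from disk after fit a no-op with respect to the pretrained checkpoint.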
Thanks for the fix to the model load!
from autogluon.common.loaders._utils import download
from autogluon.common.utils.s3_utils import is_s3_url
if self.numerical_feature_tokenizer:
    self.numerical_adapter.apply(init_weights)
if self.categorical_feature_tokenizer:
    self.categorical_adapter.apply(init_weights)
self.head.apply(init_weights)
# init transformer backbone from provided checkpoint
if pretrained and checkpoint_name:
    if "https://" in checkpoint_name or is_s3_url(checkpoint_name):
This can be simplified by using https://github.com/autogluon/autogluon/blob/master/multimodal/src/autogluon/multimodal/utils/download.py#L31
"source": [
    "### model.ft_transformer.checkpoint_name\n",
    "\n",
    "Provide a pre-trained weights to initialize ft_transformer backbone."
"Provide a pre-trained weights" is not accurate. Consider describing it as a local checkpoint path or a URL?
"# by default, AutoMM doesn't use pre-trained weights\n",
"predictor.fit(hyperparameters={\"model.ft_transformer.checkpoint_name\": None})\n",
"# initialize the ft_transformer backbone from the give checkpoint\n",
"predictor.fit(hyperparameters={\"model.ft_transformer.checkpoint_name\": 'my_checkpoint.ckpt'})\n",
Consider providing an example with our s3 url.
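For instance, the docs snippet could use the checkpoint URL already exercised in this PR's test (the URL is taken from the test above; whether it stays the recommended example is up to the maintainers):

```python
# Example hyperparameters using the S3 checkpoint URL from this PR's unit test
hyperparameters = {
    "model.names": ["ft_transformer"],
    "model.ft_transformer.checkpoint_name": "https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt",
}
```

These hyperparameters would then be passed to predictor.fit, with the download handled internally.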
LGTM
Co-authored-by: Zhiqiang Tang <zhiqiang.tang@rutgers.edu>
Issue #, if available:
Resolves #3847
Description of changes:
- Add the checkpoint_name argument in ft_transformer so that users could load pre-trained ft_transformer weights.
- Fix layer_id when model_prefix is None.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.