
[AutoMM] Add support for loading pre-trained weights in ft_transformer #3859

Merged: 10 commits merged into autogluon:master on Jan 26, 2024

Conversation

@taoyang1122 (Contributor) commented Jan 13, 2024

Issue #, if available:
Resolves #3847

Description of changes:

  • Add a checkpoint_name argument to ft_transformer so that users can load pre-trained ft_transformer weights (a usage sketch follows below).
  • Fixed a bug in setting layer_id when model_prefix is None.
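
A minimal usage sketch (the training CSV, the label column name, and the checkpoint path below are placeholders; the hyperparameter keys match the test added in this PR):

import pandas as pd
from autogluon.multimodal import MultiModalPredictor

# Minimal sketch: fine-tune ft_transformer starting from a pre-trained checkpoint.
# "train.csv", "label", and "path_to_checkpoint.ckpt" are placeholders.
train_data = pd.read_csv("train.csv")
predictor = MultiModalPredictor(label="label")
predictor.fit(
    train_data,
    hyperparameters={
        "model.names": ["ft_transformer"],
        "model.ft_transformer.checkpoint_name": "path_to_checkpoint.ckpt",
    },
)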

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.


@taoyang1122 added the labels "model list checked" (You have updated the model list after modifying multimodal unit tests/docs) and "run-multi-gpu" (Run multimodal multi-gpu tests) on Jan 13, 2024

Job PR-3859-0b4ccee is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-3859/0b4ccee/index.html

@zhiqiangdon (Contributor) left a comment:

  • How do users have a pretrained ft transformer that aligns with our implementation? We haven't provided the pretraining functionality.
  • Consider adding a test.

# init transformer backbone from provided checkpoint
if checkpoint_name:
    ckpt = torch.load(checkpoint_name)
    self.transformer.load_state_dict(ckpt["state_dict"])
Contributor:

This only loads weights for self.transformer. What if the saved weights also have self.categorical_adapter, self.numerical_adapter, and self.head?
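
For illustration, one possible way to also restore those sub-modules when a checkpoint happens to contain them (a sketch only; the "state_dict" layout and the key prefixes below are assumptions, not the actual XTab checkpoint format):

import torch
from torch import nn

def load_submodule_weights(model: nn.Module, checkpoint_path: str) -> None:
    # Sketch: restore each sub-module whose parameters appear in the checkpoint.
    # Assumes keys look like "transformer.<param>", "head.<param>", etc.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
    for prefix in ("transformer", "categorical_adapter", "numerical_adapter", "head"):
        module = getattr(model, prefix, None)
        sub_state = {
            k[len(prefix) + 1 :]: v
            for k, v in state_dict.items()
            if k.startswith(prefix + ".")
        }
        if module is not None and sub_state:
            module.load_state_dict(sub_state)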

Contributor Author:

Currently, this is mostly for XTab pretraining: the intended use case is that users load XTab pre-trained ft_transformer weights and then fine-tune.

Contributor:

Can we add tests?

Contributor:

Can we benchmark the performance of the pretrained weights and compare it to a randomly initialized ft_transformer?

Contributor Author:

Added unit test. The XTab repo has the benchmarking results.

Contributor:

The readme in the XTab repo only shows results on one toy dataset, so benchmark results on more tabular datasets are unclear. It's also unclear whether we can reproduce the results in the paper due to details such as light finetuning vs. heavy finetuning.


Job PR-3859-9667b72 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-3859/9667b72/index.html

@@ -677,6 +678,31 @@ def test_load_ckpt():
npt.assert_equal(predictions_prob, predictions2_prob)


def test_fttransformer_load_ckpt():
download("https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt", "./")
Contributor:

There is no documentation for how to download a pretrained ft_transformer checkpoint. If the benchmark results show that the pretrained weights beat random initialization, we could use the pretrained weights as the default, i.e., download the checkpoint and initialize the ft_transformer internally. That would make things easier for users.

Contributor:

If you'd like to run a benchmark of pre-trained vs random, we can run it on AutoML Benchmark after the code is in a state where I can specify the pre-trained weights to load via the model hyperparameters through TabularPredictor.

Contributor Author:

The code supports specifying the pre-trained weights in the model hyperparameters via "model.ft_transformer.checkpoint_name": "path_to_checkpoint.ckpt". I am currently benchmarking on multimodal datasets. It would be great if you could run it on AutoML Benchmark.

Contributor:

Can we remove this line since we already support downloading it internally?

@Innixma (Contributor) left a comment:

I would really like to be able to pass a URL path for the weights file as an ease-of-use improvement. Otherwise, looks good!

)
hyperparameters = {
    "model.names": ["ft_transformer"],
    "model.ft_transformer.checkpoint_name": "./iter_2k.ckpt",
Contributor:

It would be a nice ease of use improvement if the user could specify a URL here:

instead of

download(...)
"model.ft_transformer.checkpoint_name": "./iter_2k.ckpt",

Just do:

"model.ft_transformer.checkpoint_name": "https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt"

Internally we can call download and save it to some predefined directory for model weights.

For example, if this was the case then I wouldn't need to edit my benchmarking code to include the download call, which would simplify benchmarking a lot.
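
A minimal sketch of that idea (it reuses the download and is_s3_url helpers already imported in this PR; the cache directory path is a hypothetical choice, not what the PR implements):

import os

from autogluon.common.loaders._utils import download
from autogluon.common.utils.s3_utils import is_s3_url


def resolve_checkpoint(checkpoint_name: str, cache_dir: str = "~/.autogluon/ft_transformer") -> str:
    # Return a local path for checkpoint_name, downloading it into a cache directory
    # when a URL is given, so repeated model loads do not re-download the file.
    if "https://" in checkpoint_name or is_s3_url(checkpoint_name):
        cache_dir = os.path.expanduser(cache_dir)
        os.makedirs(cache_dir, exist_ok=True)
        local_path = os.path.join(cache_dir, os.path.basename(checkpoint_name))
        if not os.path.exists(local_path):
            download(checkpoint_name, local_path)
        return local_path
    return checkpoint_name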

Contributor:

Currently I get the following error if I try to pass the URL:

FileNotFoundError: [Errno 2] No such file or directory: 'https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt'

Contributor:

This also would allow me to add the weights as part of the default hyperparameter configs.

@Innixma (Contributor) left a comment:

LGTM!

@Innixma (Contributor) left a comment:

Upon further testing, I noticed that the pretrained weights are re-downloaded when loading the fitted model from disk, leading to a substantial inference slowdown. Beyond the re-download, I believe the weights are also loaded from disk even when they have already been downloaded, which shouldn't happen after fit.

Comment on lines 609 to 618
# init transformer backbone from provided checkpoint
if checkpoint_name:
    if "https://" in checkpoint_name or is_s3_url(checkpoint_name):
        with tempfile.TemporaryDirectory() as tmpdirname:
            checkpoint_path = os.path.join(tmpdirname, "./ft_transformer_pretrained.ckpt")
            download(checkpoint_name, checkpoint_path)
            ckpt = torch.load(checkpoint_path)
    else:
        ckpt = torch.load(checkpoint_name)
    self.transformer.load_state_dict(ckpt["state_dict"])
Contributor:

This is called every time the model is loaded from disk, even after finetuning. We should ensure we are not using the pretrained weights in any way after we do training, as the pretrained weights are no longer necessary.

Minimal reproducible example:

from autogluon.tabular import TabularPredictor, TabularDataset


if __name__ == '__main__':
    label = 'class'
    train_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')
    test_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')
    subsample_size = 500  # subsample subset of data for faster demo, try setting this to much larger values
    if subsample_size is not None and subsample_size < len(train_data):
        train_data = train_data.sample(n=subsample_size, random_state=0)

    hyperparameters = {
        'FT_TRANSFORMER': [
            {"model.ft_transformer.checkpoint_name": "https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt"},
        ],
    }
    predictor = TabularPredictor(
        label=label,
        eval_metric='roc_auc',
    )

    predictor: TabularPredictor = predictor.fit(train_data, hyperparameters=hyperparameters,)

    # predictor.leaderboard(data=test_data, display=True)

    print("###### Prior to predict  ######")
    predictor.predict(test_data)
    print("###### After Predict (1) ######")
    predictor.predict(test_data)
    print("###### After Predict (2) ######")
    predictor.predict(test_data)
    print("###### After Predict (3) ######")

Output:

TabularPredictor saved. To load, use: predictor = TabularPredictor.load("AutogluonModels/ag-20240122_233921")
###### Prior to predict  ######
Downloading /tmp/tmplmmfeb91/./ft_transformer_pretrained.ckpt from https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt...
100%|██████████| 3.13M/3.13M [00:00<00:00, 5.01MiB/s]
Load pretrained checkpoint: /home/ubuntu/workspace/code/scratch/AutogluonModels/ag-20240122_233921/models/FTTransformer/automm_model/model.ckpt
Predicting DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 47.68it/s]
###### After Predict (1) ######
Downloading /tmp/tmpf5zqen3i/./ft_transformer_pretrained.ckpt from https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt...
100%|██████████| 3.13M/3.13M [00:00<00:00, 6.26MiB/s]
Load pretrained checkpoint: /home/ubuntu/workspace/code/scratch/AutogluonModels/ag-20240122_233921/models/FTTransformer/automm_model/model.ckpt
Predicting DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 30.74it/s]
###### After Predict (2) ######
Downloading /tmp/tmp2479loia/./ft_transformer_pretrained.ckpt from https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt...
100%|██████████| 3.13M/3.13M [00:00<00:00, 4.46MiB/s]
Load pretrained checkpoint: /home/ubuntu/workspace/code/scratch/AutogluonModels/ag-20240122_233921/models/FTTransformer/automm_model/model.ckpt
Predicting DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 46.99it/s]
###### After Predict (3) ######

@Innixma (Contributor) left a comment:

Thanks for the fix to the model load!

Comment on lines 8 to 9
from autogluon.common.loaders._utils import download
from autogluon.common.utils.s3_utils import is_s3_url

if self.numerical_feature_tokenizer:
    self.numerical_adapter.apply(init_weights)
if self.categorical_feature_tokenizer:
    self.categorical_adapter.apply(init_weights)
self.head.apply(init_weights)
# init transformer backbone from provided checkpoint
# (only when pretrained=True, so reloading a fitted model does not re-download the weights)
if pretrained and checkpoint_name:
    if "https://" in checkpoint_name or is_s3_url(checkpoint_name):


"source": [
"### model.ft_transformer.checkpoint_name\n",
"\n",
"Provide a pre-trained weights to initialize ft_transformer backbone."
Contributor:

Providing pre-trained weights is not accurate. Consider using local checkpoint path or a url?

"# by default, AutoMM doesn't use pre-trained weights\n",
"predictor.fit(hyperparameters={\"model.ft_transformer.checkpoint_name\": None})\n",
"# initialize the ft_transformer backbone from the give checkpoint\n",
"predictor.fit(hyperparameters={\"model.ft_transformer.checkpoint_name\": 'my_checkpoint.ckpt'})\n",
Contributor:

Consider providing an example with our s3 url.
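
For instance (a sketch reusing the S3 checkpoint URL from this thread):

# initialize the ft_transformer backbone from the S3-hosted checkpoint used in this PR's test
predictor.fit(
    hyperparameters={
        "model.ft_transformer.checkpoint_name": "https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt",
    }
)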


Job PR-3859-6de89e4 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-3859/6de89e4/index.html

@zhiqiangdon (Contributor) left a comment:

LGTM

@zhiqiangdon zhiqiangdon merged commit f9c64e5 into autogluon:master Jan 26, 2024
8 checks passed

Job PR-3859-d117ad4 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-3859/d117ad4/index.html

LennartPurucker pushed a commit to LennartPurucker/autogluon that referenced this pull request on Jun 1, 2024: [AutoMM] Add support for loading pre-trained weights in ft_transformer (autogluon#3859)

Co-authored-by: Zhiqiang Tang <zhiqiang.tang@rutgers.edu>
Labels: model list checked, run-multi-gpu
Projects: none yet
Linked issues this pull request may close: How to read pre v1 FTTransformer weights with v1 autogluon (#3847)
3 participants