Skip to content

Commit

Permalink
Update on "[Unification][ALBEF] Replace ALBEF multimodal encoder layer"
Browse files Browse the repository at this point in the history
## Summary
Replace `ALBEFTransformerLayerWithCrossAttention` with generalized `TransformerCrossAttentionLayer`.

## Test plan
`pytest test -vv`
```
======================================================================================== test session starts ========================================================================================
platform linux -- Python 3.9.12, pytest-7.1.1, pluggy-1.0.0 -- /fsx/users/rafiayub/conda/envs/torchmm/bin/python
cachedir: .pytest_cache
rootdir: /data/home/rafiayub/torchmultimodal, configfile: pyproject.toml
plugins: hydra-core-1.1.2, cov-3.0.0, mock-3.8.2
collected 219 items                                                                                                                                                                                 

test/architectures/test_late_fusion.py::TestLateFusion::test_forward PASSED                                                                                                                   [  0%]
test/architectures/test_late_fusion.py::TestLateFusion::test_missing_key_in_modalities PASSED                                                                                                 [  0%]
test/architectures/test_late_fusion.py::TestLateFusion::test_script PASSED                                                                                                                    [  1%]
test/architectures/test_two_tower.py::TestTwoTower::test_shared_two_tower PASSED                                                                                                              [  1%]
test/architectures/test_two_tower.py::TestTwoTower::test_two_tower PASSED                                                                                                                     [  2%]
test/architectures/test_two_tower.py::TestTwoTower::test_two_tower_scripting PASSED                                                                                                           [  2%]
test/models/test_albef.py::test_albef_image_embeddings PASSED                                                                                                                                 [  3%]
test/models/test_albef.py::test_albef_image_embeddings_momentum PASSED                                                                                                                        [  3%]
test/models/test_albef.py::test_albef_text_embeddings PASSED                                                                                                                                  [  4%]
test/models/test_albef.py::test_albef_text_embeddings_momentum PASSED                                                                                                                         [  4%]
test/models/test_albef.py::test_albef_multimodal_embeddings PASSED                                                                                                                            [  5%]
test/models/test_albef.py::test_albef_multimodal_embeddings_momentum PASSED                                                                                                                   [  5%]
test/models/test_albef.py::test_copy_params_momentum_models PASSED                                                                                                                            [  5%]
test/models/test_albef.py::test_dequeue_and_enqueue PASSED                                                                                                                                    [  6%]
test/models/test_albef.py::test_momentum_update PASSED                                                                                                                                        [  6%]
test/models/test_albef.py::test_similarity PASSED                                                                                                                                             [  7%]
test/models/test_albef.py::test_neg_embeddings PASSED                                                                                                                                         [  7%]
test/models/test_clip.py::TestCLIP::test_clip_forward PASSED                                                                                                                                  [  8%]
test/models/test_clip.py::TestCLIP::test_clip_resnet_forward PASSED                                                                                                                           [  8%]
test/models/test_clip.py::TestCLIP::test_clip_vit_forward PASSED                                                                                                                              [  9%]
test/models/test_gpt.py::TestMultimodalTransformerDecoder::test_bad_input PASSED                                                                                                              [  9%]
test/models/test_gpt.py::TestMultimodalTransformerDecoder::test_forward_in_modality PASSED                                                                                                    [ 10%]
test/models/test_gpt.py::TestMultimodalTransformerDecoder::test_forward_out_modality PASSED                                                                                                   [ 10%]
test/models/test_gpt.py::TestMultimodalTransformerDecoder::test_forward_two_modality PASSED                                                                                                   [ 10%]
test/models/test_gpt.py::TestMultimodalTransformerDecoder::test_bad_pos_ids PASSED                                                                                                            [ 11%]
test/models/test_gpt.py::TestMultimodalTransformerDecoder::test_optional_pos_ids PASSED                                                                                                       [ 11%]
test/models/test_gpt.py::TestTransformerDecoder::test_forward PASSED                                                                                                                          [ 12%]
test/models/test_gpt.py::TestTransformerDecoder::test_forward_additional_output PASSED                                                                                                        [ 12%]
test/models/test_gpt.py::TestTransformerDecoderLayer::test_forward PASSED                                                                                                                     [ 13%]
test/models/test_gpt.py::TestTransformerDecoderLayer::test_forward_masked PASSED                                                                                                              [ 13%]
test/models/test_gpt.py::TestTransformerDecoderLayer::test_forward_additional_output PASSED                                                                                                   [ 14%]
test/models/test_gpt.py::test_sigmoid_linear_unit PASSED                                                                                                                                      [ 14%]
test/models/test_gpt.py::test_right_shift PASSED                                                                                                                                              [ 15%]
test/models/test_mdetr.py::TestMDETR::test_transformer_encoder PASSED                                                                                                                         [ 15%]
test/models/test_mdetr.py::TestMDETR::test_transformer_decoder PASSED                                                                                                                         [ 15%]
test/models/test_mdetr.py::TestMDETR::test_full_mdetr_model PASSED                                                                                                                            [ 16%]
test/models/test_omnivore.py::test_omnivore_swin_t_forward PASSED                                                                                                                             [ 16%]
test/models/test_omnivore.py::test_omnivore_swin_s_forward PASSED                                                                                                                             [ 17%]
test/models/test_omnivore.py::test_omnivore_swin_b_forward PASSED                                                                                                                             [ 17%]
test/models/test_omnivore.py::test_omnivore_forward_wrong_input_type PASSED                                                                                                                   [ 18%]
test/models/test_video_vqvae.py::TestAttentionResidualBlock::test_hidden_dim_assertion PASSED                                                                                                 [ 18%]
test/models/test_video_vqvae.py::TestAttentionResidualBlock::test_forward PASSED                                                                                                              [ 19%]
test/models/test_video_vqvae.py::TestVideoEncoder::test_forward PASSED                                                                                                                        [ 19%]
test/models/test_video_vqvae.py::TestVideoDecoder::test_forward PASSED                                                                                                                        [ 20%]
test/models/test_video_vqvae.py::TestVideoVQVAE::test_encode PASSED                                                                                                                           [ 20%]
test/models/test_video_vqvae.py::TestVideoVQVAE::test_decode PASSED                                                                                                                           [ 21%]
test/models/test_video_vqvae.py::TestVideoVQVAE::test_tokenize PASSED                                                                                                                         [ 21%]
test/models/test_video_vqvae.py::TestVideoVQVAE::test_forward PASSED                                                                                                                          [ 21%]
test/models/test_vqvae.py::TestVQVAE::test_encode PASSED                                                                                                                                      [ 22%]
test/models/test_vqvae.py::TestVQVAE::test_decode PASSED                                                                                                                                      [ 22%]
test/models/test_vqvae.py::TestVQVAE::test_tokenize PASSED                                                                                                                                    [ 23%]
test/models/test_vqvae.py::TestVQVAE::test_forward PASSED                                                                                                                                     [ 23%]
test/models/flava/test_flava.py::TestFLAVA::test_forward_classification PASSED                                                                                                                [ 24%]
test/models/flava/test_flava.py::TestFLAVA::test_forward_pretraining PASSED                                                                                                                   [ 24%]
test/models/flava/test_flava.py::TestFLAVAModel::test_forward_image PASSED                                                                                                                    [ 25%]
test/models/flava/test_flava.py::TestFLAVAModel::test_forward_image_text PASSED                                                                                                               [ 25%]
test/models/flava/test_flava.py::TestFLAVAModel::test_forward_masked_image PASSED                                                                                                             [ 26%]
test/models/flava/test_flava.py::TestFLAVAModel::test_forward_masked_image_and_text PASSED                                                                                                    [ 26%]
test/models/flava/test_flava.py::TestFLAVAModel::test_forward_masked_text PASSED                                                                                                              [ 26%]
test/models/flava/test_flava.py::TestFLAVAModel::test_forward_text PASSED                                                                                                                     [ 27%]
test/models/flava/test_flava_checkpoint.py::TestFLAVACheckpoint::test_flava_model_for_classification PASSED                                                                                   [ 27%]
test/models/flava/test_flava_checkpoint.py::TestFLAVACheckpoint::test_flava_model_for_pretraining PASSED                                                                                      [ 28%]
test/models/flava/test_flava_image_encoder.py::TestFlavaImageEncoder::test_embedding PASSED                                                                                                   [ 28%]
test/models/flava/test_flava_image_encoder.py::TestFlavaImageEncoder::test_image_encoder PASSED                                                                                               [ 29%]
test/models/flava/test_flava_text_encoder.py::TestFlavaTextEncoder::test_embedding PASSED                                                                                                     [ 29%]
test/models/flava/test_flava_text_encoder.py::TestFlavaTextEncoder::test_text_transformer PASSED                                                                                              [ 30%]
test/models/flava/test_flava_text_encoder.py::TestFlavaTextEncoder::test_text_transformer_attn_mask PASSED                                                                                    [ 30%]
test/modules/encoders/test_albef_multimodal_encoder.py::test_multimodal_encoder PASSED                                                                                                        [ 31%]
test/modules/encoders/test_albef_multimodal_encoder.py::test_invalid_image_hidden_size PASSED                                                                                                 [ 31%]
test/modules/encoders/test_albef_multimodal_encoder.py::test_invalid_text_hidden_size PASSED                                                                                                  [ 31%]
test/modules/encoders/test_albef_multimodal_encoder.py::test_not_matching_input_batch_size PASSED                                                                                             [ 32%]
test/modules/encoders/test_albef_text_encoder.py::test_text_encoder PASSED                                                                                                                    [ 32%]
test/modules/encoders/test_albef_text_encoder.py::test_invalid_input_length PASSED                                                                                                            [ 33%]
test/modules/encoders/test_albef_text_encoder.py::test_not_matching_attention_mask_shape PASSED                                                                                               [ 33%]
test/modules/encoders/test_albef_vision_encoder.py::TestALBEFVisionEncoder::test_vision_transformer PASSED                                                                                    [ 34%]
test/modules/encoders/test_albef_vision_encoder.py::TestALBEFVisionEncoder::test_invalid_input_length PASSED                                                                                  [ 34%]
test/modules/encoders/test_albef_vision_encoder.py::TestALBEFVisionEncoder::test_invalid_image_channel_dim PASSED                                                                             [ 35%]
test/modules/encoders/test_albef_vision_encoder.py::TestALBEFVisionEncoder::test_invalid_image_height PASSED                                                                                  [ 35%]
test/modules/encoders/test_albef_vision_encoder.py::TestALBEFVisionEncoder::test_invalid_image_width PASSED                                                                                   [ 36%]
test/modules/encoders/test_clip_resnet_encoder.py::TestCLIPModule::test_resnet PASSED                                                                                                         [ 36%]
test/modules/encoders/test_clip_text_encoder.py::TestCLIPTextEncoder::test_clip_parameters PASSED                                                                                             [ 36%]
test/modules/encoders/test_clip_text_encoder.py::TestCLIPTextEncoder::test_attention_mask PASSED                                                                                              [ 37%]
test/modules/encoders/test_clip_text_encoder.py::TestCLIPTextEncoder::test_forward PASSED                                                                                                     [ 37%]
test/modules/encoders/test_clip_text_encoder.py::TestCLIPTextEncoder::test_forward_over_context_length PASSED                                                                                 [ 38%]
test/modules/encoders/test_clip_text_encoder.py::TestCLIPTextEncoder::test_scripting PASSED                                                                                                   [ 38%]
test/modules/encoders/test_embedding_encoder.py::TestEmbeddingEncoder::test_embedding_encoder_hash PASSED                                                                                     [ 39%]
test/modules/encoders/test_embedding_encoder.py::TestEmbeddingEncoder::test_embedding_encoder_invalid_pooling PASSED                                                                          [ 39%]
test/modules/encoders/test_embedding_encoder.py::TestEmbeddingEncoder::test_embedding_encoder_max PASSED                                                                                      [ 40%]
test/modules/encoders/test_embedding_encoder.py::TestEmbeddingEncoder::test_embedding_encoder_mean PASSED                                                                                     [ 40%]
test/modules/encoders/test_embedding_encoder.py::TestEmbeddingEncoder::test_embedding_encoder_sum PASSED                                                                                      [ 41%]
test/modules/encoders/test_mdetr_image_encoder.py::TestMDETRImageEncoder::test_resnet_101_forward PASSED                                                                                      [ 41%]
test/modules/encoders/test_mdetr_text_encoder.py::TestMDETRTextEncoder::test_mdetr_modified_transformer PASSED                                                                                [ 42%]
test/modules/encoders/test_mdetr_text_encoder.py::TestMDETRTextEncoder::test_mdetr_text_embeddings PASSED                                                                                     [ 42%]
test/modules/encoders/test_mdetr_text_encoder.py::TestMDETRTextEncoder::test_mdetr_text_encoder PASSED                                                                                        [ 42%]
test/modules/encoders/test_mil_encoder.py::TestMILEncoder::test_forward PASSED                                                                                                                [ 43%]
test/modules/encoders/test_mil_encoder.py::TestMILEncoder::test_invalid_partitioning PASSED                                                                                                   [ 43%]
test/modules/encoders/test_mil_encoder.py::TestMILEncoder::test_scripting PASSED                                                                                                              [ 44%]
test/modules/encoders/test_mil_encoder.py::TestMILEncoder::test_transformer_pooling PASSED                                                                                                    [ 44%]
test/modules/encoders/test_swin_transformer_3d_encoder.py::TestSwinTransformer3d::test_swin_transformer_3d_encoder PASSED                                                                     [ 45%]
test/modules/encoders/test_swin_transformer_3d_encoder.py::TestSwinTransformer3d::test_swin_transformer_3d_scripting PASSED                                                                   [ 45%]
test/modules/encoders/test_swin_transformer_3d_encoder.py::TestSwinTransformer3dComponents::test_patch_merging_3d PASSED                                                                      [ 46%]
test/modules/encoders/test_swin_transformer_3d_encoder.py::TestSwinTransformer3dComponents::test_shifted_window_attention_3d PASSED                                                           [ 46%]
test/modules/encoders/test_swin_transformer_3d_encoder.py::TestSwinTransformer3dComponents::test_shifted_window_attention_3d_zero_shift PASSED                                                [ 47%]
test/modules/encoders/test_weighted_embedding_encoder.py::TestEmbeddingEncoder::test_forward_max_pooling PASSED                                                                               [ 47%]
test/modules/encoders/test_weighted_embedding_encoder.py::TestEmbeddingEncoder::test_forward_mean_pooling PASSED                                                                              [ 47%]
test/modules/encoders/test_weighted_embedding_encoder.py::TestEmbeddingEncoder::test_forward_sum_pooling PASSED                                                                               [ 48%]
test/modules/encoders/test_weighted_embedding_encoder.py::TestEmbeddingEncoder::test_scripting PASSED                                                                                         [ 48%]
test/modules/fusions/test_attention_fusion.py::TestAttentionFusionModule::test_input_projection_dim PASSED                                                                                    [ 49%]
test/modules/fusions/test_attention_fusion.py::TestAttentionFusionModule::test_no_projection_dim PASSED                                                                                       [ 49%]
test/modules/fusions/test_attention_fusion.py::TestAttentionFusionModule::test_scripted_model PASSED                                                                                          [ 50%]
test/modules/fusions/test_deepset_fusion.py::TestDeepSetFusionModule::test_deepset_apply_attention PASSED                                                                                     [ 50%]
test/modules/fusions/test_deepset_fusion.py::TestDeepSetFusionModule::test_deepset_auto_mapping PASSED                                                                                        [ 51%]
test/modules/fusions/test_deepset_fusion.py::TestDeepSetFusionModule::test_deepset_invalid_pooling PASSED                                                                                     [ 51%]
test/modules/fusions/test_deepset_fusion.py::TestDeepSetFusionModule::test_deepset_max PASSED                                                                                                 [ 52%]
test/modules/fusions/test_deepset_fusion.py::TestDeepSetFusionModule::test_deepset_mean PASSED                                                                                                [ 52%]
test/modules/fusions/test_deepset_fusion.py::TestDeepSetFusionModule::test_deepset_median PASSED                                                                                              [ 52%]
test/modules/fusions/test_deepset_fusion.py::TestDeepSetFusionModule::test_deepset_min PASSED                                                                                                 [ 53%]
test/modules/fusions/test_deepset_fusion.py::TestDeepSetFusionModule::test_deepset_modality_normalize PASSED                                                                                  [ 53%]
test/modules/fusions/test_deepset_fusion.py::TestDeepSetFusionModule::test_deepset_sum PASSED                                                                                                 [ 54%]
test/modules/fusions/test_deepset_fusion.py::TestDeepSetFusionModule::test_deepset_transformer PASSED                                                                                         [ 54%]
test/modules/fusions/test_deepset_fusion.py::TestDeepSetFusionModule::test_get_deepset_transformer PASSED                                                                                     [ 55%]
test/modules/fusions/test_deepset_fusion.py::TestDeepSetFusionModule::test_torchscript PASSED                                                                                                 [ 55%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_multi_head_self_attention PASSED                                                                                          [ 56%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_multi_head_cross_attention PASSED                                                                                         [ 56%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_multi_head_attention_use_cache PASSED                                                                                     [ 57%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_multi_head_attention_causal_use_cache PASSED                                                                              [ 57%]
test/modules/layers/test_attention.py::TestScaledDotProductAttention::test_scaled_dot_product_attention PASSED                                                                                [ 57%]
test/modules/layers/test_attention.py::TestScaledDotProductAttention::test_scaled_dot_product_attention_with_attention_mask PASSED                                                            [ 58%]
test/modules/layers/test_attention.py::TestScaledDotProductAttention::test_scaled_dot_product_attention_with_head_mask PASSED                                                                 [ 58%]
test/modules/layers/test_attention.py::TestScaledDotProductAttention::test_scaled_dot_product_attention_with_dropout PASSED                                                                   [ 59%]
test/modules/layers/test_attention.py::test_self_attention PASSED                                                                                                                             [ 59%]
test/modules/layers/test_attention.py::test_axial_attention PASSED                                                                                                                            [ 60%]
test/modules/layers/test_attention.py::test_split_multihead PASSED                                                                                                                            [ 60%]
test/modules/layers/test_attention.py::test_merge_multihead PASSED                                                                                                                            [ 61%]
test/modules/layers/test_attention.py::TestAxialBlock::test_axial_block_forward PASSED                                                                                                        [ 61%]
test/modules/layers/test_attention.py::TestAxialBlock::test_axial_block_channel_dim PASSED                                                                                                    [ 62%]
test/modules/layers/test_codebook.py::TestCodebook::test_codebook_restart PASSED                                                                                                              [ 62%]
test/modules/layers/test_codebook.py::TestCodebook::test_ema_update_embedding PASSED                                                                                                          [ 63%]
test/modules/layers/test_codebook.py::TestCodebook::test_init_embedding_and_preprocess PASSED                                                                                                 [ 63%]
test/modules/layers/test_codebook.py::TestCodebook::test_init_embedding_smaller_encoded PASSED                                                                                                [ 63%]
test/modules/layers/test_codebook.py::TestCodebook::test_postprocess PASSED                                                                                                                   [ 64%]
test/modules/layers/test_codebook.py::TestCodebook::test_preprocess PASSED                                                                                                                    [ 64%]
test/modules/layers/test_codebook.py::TestCodebook::test_preprocess_channel_dim_assertion PASSED                                                                                              [ 65%]
test/modules/layers/test_codebook.py::TestCodebook::test_quantized_output PASSED                                                                                                              [ 65%]
test/modules/layers/test_codebook.py::TestCodebook::test_register_buffer_tensors PASSED                                                                                                       [ 66%]
test/modules/layers/test_conv.py::TestSamePadConv3d::test_calculate_same_padding_assert PASSED                                                                                                [ 66%]
test/modules/layers/test_conv.py::TestSamePadConv3d::test_calculate_same_padding_output PASSED                                                                                                [ 67%]
test/modules/layers/test_conv.py::TestSamePadConv3d::test_calculate_transpose_padding_assert PASSED                                                                                           [ 67%]
test/modules/layers/test_conv.py::TestSamePadConv3d::test_calculate_transpose_padding_output PASSED                                                                                           [ 68%]
test/modules/layers/test_conv.py::TestSamePadConv3d::test_samepadconv3d_forward PASSED                                                                                                        [ 68%]
test/modules/layers/test_conv.py::TestSamePadConv3d::test_samepadconvtranspose3d_forward PASSED                                                                                               [ 68%]
test/modules/layers/test_mlp.py::TestMLP::test_activation_and_normalization PASSED                                                                                                            [ 69%]
test/modules/layers/test_mlp.py::TestMLP::test_dropout_default PASSED                                                                                                                         [ 69%]
test/modules/layers/test_mlp.py::TestMLP::test_no_dropout PASSED                                                                                                                              [ 70%]
test/modules/layers/test_mlp.py::TestMLP::test_no_hidden_layers PASSED                                                                                                                        [ 70%]
test/modules/layers/test_mlp.py::TestMLP::test_pass_hidden_dims PASSED                                                                                                                        [ 71%]
test/modules/layers/test_mlp.py::TestMLP::test_torchscript PASSED                                                                                                                             [ 71%]
test/modules/layers/test_position_embedding.py::TestBroadcastedPositionEmbedding::test_init_sets_embedding PASSED                                                                             [ 72%]
test/modules/layers/test_position_embedding.py::TestBroadcastedPositionEmbedding::test_init_bad_embedding_dim PASSED                                                                          [ 72%]
test/modules/layers/test_position_embedding.py::TestBroadcastedPositionEmbedding::test_broadcast PASSED                                                                                       [ 73%]
test/modules/layers/test_position_embedding.py::TestBroadcastedPositionEmbedding::test_forward PASSED                                                                                         [ 73%]
test/modules/layers/test_position_embedding.py::TestBroadcastedPositionEmbedding::test_forward_invalid_input PASSED                                                                           [ 73%]
test/modules/layers/test_transformer.py::TestFLAVATransformerEncoder::test_flava_encoder_forward PASSED                                                                                       [ 74%]
test/modules/layers/test_transformer.py::TestTransformerEncoderLayer::test_attention_block PASSED                                                                                             [ 74%]
test/modules/layers/test_transformer.py::TestTransformerEncoderLayer::test_feedforward_block PASSED                                                                                           [ 75%]
test/modules/layers/test_transformer.py::TestTransformerEncoderLayer::test_forward_prenorm PASSED                                                                                             [ 75%]
test/modules/layers/test_transformer.py::TestTransformerEncoderLayer::test_forward_postnorm PASSED                                                                                            [ 76%]
test/modules/layers/test_transformer.py::TestTransformerCrossAttentionLayer::test_self_attention_block PASSED                                                                                 [ 76%]
test/modules/layers/test_transformer.py::TestTransformerCrossAttentionLayer::test_cross_attention_block PASSED                                                                                [ 77%]
test/modules/layers/test_transformer.py::TestTransformerCrossAttentionLayer::test_feedforward_block PASSED                                                                                    [ 77%]
test/modules/layers/test_transformer.py::TestTransformerCrossAttentionLayer::test_forward_prenorm PASSED                                                                                      [ 78%]
test/modules/layers/test_transformer.py::TestTransformerCrossAttentionLayer::test_forward_postnorm PASSED                                                                                     [ 78%]
test/modules/layers/test_transformer.py::test_apply_layernorm PASSED                                                                                                                          [ 78%]
test/modules/losses/test_albef.py::TestImageTextContrastiveLoss::test_itc_loss_invalid_sim PASSED                                                                                             [ 79%]
test/modules/losses/test_albef.py::TestImageTextContrastiveLoss::test_itc_loss_missing_sim_m PASSED                                                                                           [ 79%]
test/modules/losses/test_albef.py::TestImageTextContrastiveLoss::test_itc_loss_invalid_sim_m PASSED                                                                                           [ 80%]
test/modules/losses/test_albef.py::TestImageTextContrastiveLoss::test_itc_loss_invalid_sim_target PASSED                                                                                      [ 80%]
test/modules/losses/test_albef.py::TestImageTextContrastiveLoss::test_itc_loss_without_distillation PASSED                                                                                    [ 81%]
test/modules/losses/test_albef.py::TestImageTextContrastiveLoss::test_itc_loss_with_distillation PASSED                                                                                       [ 81%]
test/modules/losses/test_albef.py::TestImageTextContrastiveLoss::test_itc_loss_with_sim_targets PASSED                                                                                        [ 82%]
test/modules/losses/test_albef.py::TestImageTextMatchingLoss::test_itm_loss_invalid_input_hidden_size PASSED                                                                                  [ 82%]
test/modules/losses/test_albef.py::TestImageTextMatchingLoss::test_itm_loss PASSED                                                                                                            [ 83%]
test/modules/losses/test_albef.py::TestMaskedLanguageModelingLoss::test_mlm_loss_invalid_labels PASSED                                                                                        [ 83%]
test/modules/losses/test_albef.py::TestMaskedLanguageModelingLoss::test_mlm_loss_invalid_embeddings PASSED                                                                                    [ 84%]
test/modules/losses/test_albef.py::TestMaskedLanguageModelingLoss::test_mlm_loss_missing_momentum_embeddings PASSED                                                                           [ 84%]
test/modules/losses/test_albef.py::TestMaskedLanguageModelingLoss::test_mlm_loss PASSED                                                                                                       [ 84%]
test/modules/losses/test_albef.py::TestMaskedLanguageModelingLoss::test_mlm_loss_with_distillation PASSED                                                                                     [ 85%]
test/modules/losses/test_commitment.py::TestCommitment::test_loss_value PASSED                                                                                                                [ 85%]
test/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_local_loss PASSED                                                                     [ 86%]
test/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_multi_gpu_loss SKIPPED (Not enough GPUs to run the test: required 2)                  [ 86%]
test/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_single_gpu_loss SKIPPED (Not enough GPUs to run the test: required 1)                 [ 87%]
test/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_temperature_clamp_invalid PASSED                                                      [ 87%]
test/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_temperature_clamp_max PASSED                                                          [ 88%]
test/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_temperature_clamp_min PASSED                                                          [ 88%]
test/modules/losses/test_mdetr_losses.py::TestMDETRLosses::test_soft_token_prediction_loss PASSED                                                                                             [ 89%]
test/modules/losses/test_mdetr_losses.py::TestMDETRLosses::test_box_losses PASSED                                                                                                             [ 89%]
test/transforms/test_bert_text_transform.py::TestBertTextTransform::test_single_transform PASSED                                                                                              [ 89%]
test/transforms/test_bert_text_transform.py::TestBertTextTransform::test_multi_transform PASSED                                                                                               [ 90%]
test/transforms/test_clip_transform.py::TestCLIPTransform::test_clip_multi_transform PASSED                                                                                                   [ 90%]
test/transforms/test_clip_transform.py::TestCLIPTransform::test_clip_single_transform PASSED                                                                                                  [ 91%]
test/transforms/test_video_transform.py::TestVideoTransform::test_call PASSED                                                                                                                 [ 91%]
test/transforms/test_video_transform.py::TestVideoTransform::test_wrong_channels PASSED                                                                                                       [ 92%]
test/transforms/test_video_transform.py::TestVideoTransform::test_sample_frames PASSED                                                                                                        [ 92%]
test/transforms/test_video_transform.py::TestVideoTransform::test_resize_hw PASSED                                                                                                            [ 93%]
test/transforms/test_video_transform.py::TestVideoTransform::test_normalize PASSED                                                                                                            [ 93%]
test/utils/test_assertion.py::TestAssertEqualLengths::test_different_lengths PASSED                                                                                                           [ 94%]
test/utils/test_assertion.py::TestAssertEqualLengths::test_same_lengths PASSED                                                                                                                [ 94%]
test/utils/test_attention_utils.py::test_get_causal_attention_masks PASSED                                                                                                                    [ 94%]
test/utils/test_ckpt_load.py::test_load_module_from_url PASSED                                                                                                                                [ 95%]
test/utils/test_common.py::test_shift_dim PASSED                                                                                                                                              [ 95%]
test/utils/test_common.py::TestTensorSlice::test_default PASSED                                                                                                                               [ 96%]
test/utils/test_common.py::TestTensorSlice::test_size_minus_one PASSED                                                                                                                        [ 96%]
test/utils/test_common.py::TestTensorSlice::test_uneven_begin_size PASSED                                                                                                                     [ 97%]
test/utils/test_common.py::TestTensorSlice::test_invalid_begin XFAIL (Invalid begin)                                                                                                          [ 97%]
test/utils/test_common.py::TestTensorSlice::test_invalid_size XFAIL (Invalid size)                                                                                                            [ 98%]
test/utils/test_common.py::TestToTupleTuple::test_int PASSED                                                                                                                                  [ 98%]
test/utils/test_common.py::TestToTupleTuple::test_tuple PASSED                                                                                                                                [ 99%]
test/utils/test_common.py::TestCheckpointWrapper::test_training_mode PASSED                                                                                                                   [ 99%]
test/utils/test_common.py::TestCheckpointWrapper::test_eval_model PASSED                                                                                                                      [100%]

========================================================================================= warnings summary ==========================================================================================
test/models/flava/test_flava_checkpoint.py::TestFLAVACheckpoint::test_flava_model_for_pretraining
  /data/home/rafiayub/torchmultimodal/test/models/flava/test_flava_checkpoint.py:81: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
    else torch.tensor(dict_actual[key])

test/modules/encoders/test_clip_text_encoder.py::TestCLIPTextEncoder::test_scripting
test/modules/fusions/test_deepset_fusion.py::TestDeepSetFusionModule::test_torchscript
  /fsx/users/rafiayub/conda/envs/torchmm/lib/python3.9/site-packages/torch/jit/_recursive.py:246: UserWarning: 'batch_first' was found in ScriptModule constants, but was not actually set in __init__. Consider removing it.
    warnings.warn("'{}' was found in ScriptModule constants, "

test/modules/fusions/test_deepset_fusion.py::TestDeepSetFusionModule::test_torchscript
  /fsx/users/rafiayub/conda/envs/torchmm/lib/python3.9/site-packages/torch/jit/_recursive.py:240: UserWarning: 'norm' was found in ScriptModule constants,  but it is a non-constant submodule. Consider removing it.
    warnings.warn("'{}' was found in ScriptModule constants, "

test/modules/layers/test_codebook.py::TestCodebook::test_register_buffer_tensors
  /data/home/rafiayub/torchmultimodal/test/modules/layers/test_codebook.py:169: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at /opt/conda/conda-bld/pytorch_1656918644659/work/build/aten/src/ATen/core/TensorBody.h:478.)
    assert not self.vq.code_avg.grad, msg_has_grad

test/modules/layers/test_codebook.py::TestCodebook::test_register_buffer_tensors
  /data/home/rafiayub/torchmultimodal/test/modules/layers/test_codebook.py:171: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at /opt/conda/conda-bld/pytorch_1656918644659/work/build/aten/src/ATen/core/TensorBody.h:478.)
    assert not self.vq.embedding.grad, msg_has_grad

test/modules/layers/test_conv.py::TestSamePadConv3d::test_samepadconv3d_forward
  /data/home/rafiayub/torchmultimodal/torchmultimodal/modules/layers/conv.py:49: UserWarning: Padding was specified but will not be used in favor of same padding,                 use Conv3d directly for custom padding
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================================================== 215 passed, 2 skipped, 2 xfailed, 7 warnings in 55.69s =======================================================================
```

Differential Revision: [D38392153](https://our.internmc.facebook.com/intern/diff/D38392153)

[ghstack-poisoned]
  • Loading branch information
RdoubleA committed Aug 4, 2022
2 parents a0a1f7f + 68223a4 commit cdd699d
Show file tree
Hide file tree
Showing 4 changed files with 390 additions and 48 deletions.
368 changes: 368 additions & 0 deletions examples/mugen/retrieval/evaluation.ipynb

Large diffs are not rendered by default.

22 changes: 5 additions & 17 deletions examples/mugen/retrieval/video_clip.py
Expand Up @@ -12,7 +12,7 @@
from torch import nn

from torchmultimodal.models.clip import CLIP
from torchmultimodal.utils.common import PretrainedMixin
from torchmultimodal.utils.common import load_module_from_url
from transformers import DistilBertConfig, DistilBertModel


Expand All @@ -21,19 +21,7 @@
)


class HuggingFaceMixin(PretrainedMixin):
"""Interface to loading pretrained model from HuggingFace.
Inputs:
model_name (str): name of pretrained model.
"""

def load_model(self, model_name: str):
self.model = DistilBertModel.from_pretrained(model_name)


class TextEncoder(nn.Module, HuggingFaceMixin):
class TextEncoder(nn.Module):
"""Encode tokenized text to the last hidden state representation of the CLS token using
DistilBERT. DistilBERT prepends a CLS (classification) token to every text so the
token's hidden state represents the entire text.
Expand Down Expand Up @@ -81,7 +69,7 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
return last_hidden_state[:, self.target_token_idx, :]


class VideoEncoder(nn.Module, PretrainedMixin):
class VideoEncoder(nn.Module):
"""Encode videos to the last layer before the fully-connected layer of S3D.
Adapted from MUGEN's video encoder
Expand Down Expand Up @@ -192,7 +180,7 @@ def videoclip(
)
if text_pretrained:
print(f"Loading pretrained DistilBERT from {text_model_name}.")
text_model.load_model(text_model_name)
text_model.model = DistilBertModel.from_pretrained(text_model_name)
if text_pretrained and not text_trainable:
# check `text_pretrained` because if model isn't pretrained, then it should be trainable
for p in text_model.model.parameters():
Expand All @@ -208,7 +196,7 @@ def videoclip(
video_model = VideoEncoder()
if video_pretrained:
print(f"Loading pretrained video encoder from {video_pretrain_path}.")
video_model.load_model(video_pretrain_path)
load_module_from_url(video_model, video_pretrain_path)
if video_pretrained and not video_trainable:
# check `video_pretrained` because if model isn't pretrained, then it should be trainable
for p in video_model.model.parameters():
Expand Down
46 changes: 16 additions & 30 deletions examples/mugen/test/retrieval/test_video_clip.py
Expand Up @@ -4,9 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Optional

import pytest
import torch
from examples.mugen.retrieval.video_clip import (
Expand All @@ -17,39 +14,28 @@
)

from test.test_utils import assert_expected, get_asset_path, set_rng_seed
from torchmultimodal import _PATH_MANAGER
from torchmultimodal.utils.common import shift_dim


def patch_load_model(mocker):
"""Mock the ``load_model`` function of ``VideoEncoder`` to allow loading truncated
state dicts with ``strict=False``.
def patch_load_module_from_url(mocker):
"""Mock the ``load_module_from_url`` utility function used in ``videoclip()`` to allow
loading truncated state dicts with ``strict=False``.
"""

def patched_load_model(
cls,
pretrained_url: Optional[str],
load_state_dict: bool = True,
state_dict_key: Optional[str] = None,
):
assert isinstance(
cls, torch.nn.Module
), "load_model can only be called on an nn.Module instance"
if os.path.exists(pretrained_url):
state_dict = torch.load(pretrained_url)
def patched_load_module_from_url(
model: torch.nn.Module, url: str, strict: bool = True, progress: bool = True
) -> None:
local_path = _PATH_MANAGER.get_local_path(url)
if not torch.cuda.is_available():
state_dict = torch.load(local_path, map_location=torch.device("cpu"))
else:
state_dict = torch.hub.load_state_dict_from_url(
pretrained_url, model_dir=cls.get_model_dir(pretrained_url)
)
if state_dict_key:
state_dict = state_dict[state_dict_key]

if load_state_dict:
cls.load_state_dict(state_dict, strict=False)
return state_dict
state_dict = torch.load(local_path)
model.load_state_dict(state_dict, strict=False)

return mocker.patch(
"examples.mugen.retrieval.video_clip.VideoEncoder.load_model",
new=patched_load_model,
"examples.mugen.retrieval.video_clip.load_module_from_url",
new=patched_load_module_from_url,
)


Expand Down Expand Up @@ -166,7 +152,7 @@ def utils(self, set_seed):

def test_forward_pretrained_trainable(self, utils, mocker):
input_text, input_video = utils
patch_load_model(mocker)
patch_load_module_from_url(mocker)
model = videoclip(
video_pretrain_path=get_asset_path("S3D_sample.pt"), proj_out_dim=3
)
Expand All @@ -193,7 +179,7 @@ def test_forward_pretrained_trainable(self, utils, mocker):
)

def test_pretrained_untrainable(self, mocker):
patch_load_model(mocker)
patch_load_module_from_url(mocker)
model = videoclip(
text_trainable=False,
video_trainable=False,
Expand Down
2 changes: 1 addition & 1 deletion torchmultimodal/models/flava/flava_model.py
Expand Up @@ -62,7 +62,7 @@
FLAVA_FOR_PRETRAINED_MAPPING = {
# This will no longer load with the updated model, but keeping here just in case
# "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
"flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_encoder_layer.pt",
"flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified.pt",
}


Expand Down

0 comments on commit cdd699d

Please sign in to comment.