From 2667aa009d1cb228fa589d79c5615d0b66af1118 Mon Sep 17 00:00:00 2001 From: Prabhu Teja Date: Sun, 10 Sep 2023 16:30:46 +0000 Subject: [PATCH 1/3] abstracting prompting transformer --- src/renate/benchmark/models/l2p.py | 220 +++++++++++++++-------- src/renate/updaters/experimental/l2p.py | 2 +- test/renate/benchmark/models/test_l2p.py | 26 ++- 3 files changed, 169 insertions(+), 79 deletions(-) diff --git a/src/renate/benchmark/models/l2p.py b/src/renate/benchmark/models/l2p.py index cd0f8379..0c2ac001 100644 --- a/src/renate/benchmark/models/l2p.py +++ b/src/renate/benchmark/models/l2p.py @@ -108,6 +108,123 @@ def forward(self, x: torch.Tensor, manual_prompt_indices: Optional[torch.LongTen return selected_prompts, loss_value +class PromptedTransformer(nn.Module): + """This generic module is the basic prompted transformer. It takes in a model string and creates + the appropriate transformer. If not prompted, it returns features, and if prompted, it returns + the full feature using those prompts and the input image/text. + + Args: + pretrained_model_name_or_path: A string that denotes which pretrained model from the HF hub + to use. + num_outputs: Size of the output. + prediction_strategy: Continual learning strategies may alter the prediction at train or test + time. + add_icarl_class_means: If ``True``, additional parameters used only by the + ``ICaRLModelUpdater`` are added. Only required when using that updater. + """ + + def __init__( + self, + pretrained_model_name_or_path="google/vit-base-patch16-224", + image_size: int = 32, + patch_size: int = 4, + num_layers: int = 12, + num_heads: int = 12, + hidden_dim: int = 768, + mlp_dim: int = 3072, + dropout: float = 0.1, + attention_dropout: float = 0.1, + num_outputs: int = 10, + prediction_strategy: Optional[PredictionStrategy] = None, + add_icarl_class_means: bool = True, + ) -> None: + super().__init__() + if "vit" in pretrained_model_name_or_path: + self.transformer = VisionTransformer( + pretrained_model_name_or_path=pretrained_model_name_or_path, + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + num_outputs=num_outputs, + prediction_strategy=prediction_strategy, + add_icarl_class_means=add_icarl_class_means, + ) + self._is_text_transformer = False + else: + self.transformer = HuggingFaceSequenceClassificationTransformer( + pretrained_model_name_or_path=pretrained_model_name_or_path, + num_outputs=num_outputs, + prediction_strategy=prediction_strategy, + add_icarl_class_means=add_icarl_class_means, + ) + for named_param, value in self.transformer.named_parameters(): + if value.shape[0] == self.transformer._backbone.config.vocab_size: + self.word_embeddings = self.transformer.get_submodule( + named_param.replace(".weight", "") + ) + break + + self._is_text_transformer = True + + self.transformer._tasks_params.clear() + self.transformer.eval() + for p in self.transformer.parameters(): + p.requires_grad_(False) + + def forward( + self, x: torch.Tensor, prompt: Optional[torch.Tensor] = None, cls_feat: bool = True + ) -> torch.Tensor: + """ + Args: + x: Input torch tensor. + prompt: Prompt tensor. Defaults to None. + cls_feat: Whether to extract [CLS] token or to return full feature tensor. + Ignored for text transformer. Defaults to True. 
+ """ + if prompt is None: + return ( + self.transformer.get_features(x) + if self._is_text_transformer + else self.transformer.get_features(x, cls_feat=cls_feat) + ) + # text transformers dont support cls_feat. + else: + if self._is_text_transformer: + # The implicit assumption here is that x for text transformers is the input_ids. + # This simplified forward pass has 4 steps: + # 1. Get prompts + # 2. Get embeddings from inputs. + # 3. Concat prompt and inputs + # 4. Forward prop inputs_embeds to get the features. + inputs_embeds = self.word_embeddings(x["input_ids"]) + if prompt.size(0) != inputs_embeds.size(0): + prompt = prompt.unsqueeze(0).expand( + inputs_embeds.size(0), -1, -1 + ) # Expand one prompt to batch size + inputs_embeds = torch.cat((prompt, inputs_embeds), dim=1) + return self.transformer.get_features({"inputs_embeds": inputs_embeds}) + else: + patch_embeddings = self.transformer.get_submodule("_backbone.embeddings")(x) + if prompt.size(0) != x.size(0): + prompt = prompt.unsqueeze(0).expand( + x.size(0), -1, -1 + ) # Expand one prompt to batch size# Expand one prompt to batch size + input_concat_prompt = torch.cat([patch_embeddings, prompt], dim=1) + + encoded_features = self.transformer.get_submodule("_backbone.encoder")( + input_concat_prompt, return_dict=False + )[0] + encoded_features = self.transformer.get_submodule("_backbone.layernorm")( + encoded_features + ) + return encoded_features[:, 0, :] if cls_feat else encoded_features + + class LearningToPromptTransformer(RenateBenchmarkingModule): """ Implements the vision transformer with prompt pool described in @@ -166,34 +283,22 @@ def __init__( prompt_embedding_features: str = "cls", patch_pooler: str = "prompt_mean", ) -> None: - if "vit" in pretrained_model_name_or_path: - transformer = VisionTransformer( - pretrained_model_name_or_path=pretrained_model_name_or_path, - image_size=image_size, - patch_size=patch_size, - num_layers=num_layers, - num_heads=num_heads, - hidden_dim=hidden_dim, - mlp_dim=mlp_dim, - dropout=dropout, - attention_dropout=attention_dropout, - prediction_strategy=prediction_strategy, - add_icarl_class_means=add_icarl_class_means, - num_outputs=num_outputs, - ) - self._is_text_transformer = False - else: - transformer = HuggingFaceSequenceClassificationTransformer( - pretrained_model_name_or_path=pretrained_model_name_or_path, - prediction_strategy=prediction_strategy, - add_icarl_class_means=add_icarl_class_means, - num_outputs=num_outputs, - ) - - self._is_text_transformer = True - transformer._tasks_params.clear() + transformer = PromptedTransformer( + pretrained_model_name_or_path=pretrained_model_name_or_path, + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + num_outputs=num_outputs, + add_icarl_class_means=add_icarl_class_means, + prediction_strategy=prediction_strategy, + ) prompter = PromptPool( - embedding_dim=transformer._embedding_size, + embedding_dim=transformer.transformer._embedding_size, pool_size=pool_size, pool_selection_size=pool_selection_size, prompt_size=prompt_size, @@ -204,10 +309,10 @@ def __init__( ) super().__init__( - embedding_size=transformer._embedding_size, + embedding_size=transformer.transformer._embedding_size, num_outputs=num_outputs, constructor_arguments=dict( - **transformer._constructor_arguments, + **transformer.transformer._constructor_arguments, pool_size=pool_size, 
pool_selection_size=pool_selection_size, prompt_size=prompt_size, @@ -221,6 +326,7 @@ def __init__( ) self._backbone = nn.ModuleDict({"transformer": transformer, "prompter": prompter}) + self._is_text_transformer = transformer._is_text_transformer self.prompt_embedding_features = prompt_embedding_features self.patch_pooler = patch_pooler self.similarity_score: Optional[torch.Tensor] = None @@ -236,20 +342,9 @@ def __init__( "prompt_mean", ], f"Invalid method to extract prompt embedding features. Got {patch_pooler}" - for n, p in self._backbone["transformer"].named_parameters(): - p.requires_grad = False - self._backbone["transformer"].eval() for p in self._backbone["prompter"].parameters(): p.requires_grad = True - if self._is_text_transformer: - ## This is to find the Embedding layer. - for named_param, value in self._backbone["transformer"].named_parameters(): - if value.shape[0] == self._backbone["transformer"]._backbone.config.vocab_size: - self.word_embeddings = self._backbone["transformer"].get_submodule( - named_param.replace(".weight", "") - ) - break # The backbone's forward is monkey-patched to allow the parent class' forward to work # without any manual management. self._backbone.forward = self.forward_for_monkey_patching @@ -257,10 +352,9 @@ def __init__( def forward_for_monkey_patching( self, x: torch.Tensor, task_id: str = defaults.TASK_ID ) -> torch.Tensor: + with torch.no_grad(): + prompt_pool_input = self._backbone["transformer"](x, cls_feat=False) if not self._is_text_transformer: - # The vision transformer code is manual strapping in. - with torch.no_grad(): - prompt_pool_input = self._backbone["transformer"].get_features(x, cls_feat=False) if self.prompt_embedding_features == "cls": # retrieve cls token features. This is used in L2P paper. prompt_pool_input = prompt_pool_input[:, 0, :] @@ -268,24 +362,12 @@ def forward_for_monkey_patching( # compute mean patch features. prompt_pool_input = prompt_pool_input[:, 1:, :].mean(1) # Compute the prompts to be stacked - prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input) - # compute patch embeddings - patch_embeddings = self._backbone["transformer"].get_submodule("_backbone.embeddings")( - x - ) - # concatenate both. - input_concat_prompt = torch.cat([patch_embeddings, prompts], dim=1) - ## rest of processing. this code is part of the ViTModel class in HF Transformers. - encoded_features = self._backbone["transformer"].get_submodule("_backbone.encoder")( - input_concat_prompt, return_dict=False - )[0] - encoded_features = self._backbone["transformer"].get_submodule("_backbone.layernorm")( - encoded_features - ) - - ## Save similarity - self.similarity_score = prompt_similarity - + prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input) + self.similarity_score = prompt_similarity + encoded_features = self._backbone["transformer"](x, prompts, cls_feat=False) + if self._is_text_transformer: + return encoded_features + else: if self.patch_pooler == "cls": seq_cls_token = encoded_features[:, 0, :] elif self.patch_pooler == "mean": @@ -294,19 +376,3 @@ def forward_for_monkey_patching( num_prompts = prompts.size(1) seq_cls_token = encoded_features[:, -num_prompts:, :].mean(1) return seq_cls_token - - else: - ## The implicit assumption here is that x for text transformers is the input_ids. - # This simplified forward pass has 4 steps: - # 1. Get prompts - # 2. Get embeddings from inputs. - # 3. Concat prompt and inputs - # 4. Forward prop inputs_embeds to get the features. 
The forward of the RenateBM applies - # the classifier and gets logits. - with torch.no_grad(): - prompt_pool_input = self._backbone["transformer"].get_features(x) - prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input) # 1 - self.similarity_score = prompt_similarity - inputs_embeds = self.word_embeddings(x["input_ids"]) # 2 - inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) # 3 - return self._backbone["transformer"].get_features({"inputs_embeds": inputs_embeds}) # 4 diff --git a/src/renate/updaters/experimental/l2p.py b/src/renate/updaters/experimental/l2p.py index 92044813..ceed6ecc 100644 --- a/src/renate/updaters/experimental/l2p.py +++ b/src/renate/updaters/experimental/l2p.py @@ -48,7 +48,7 @@ def __init__( **kwargs, ) self.prompt_sim_loss_weight = prompt_sim_loss_weight - self._loss_collections["train_losses"].update({"key_sim_loss": torchmetrics.MeanMetric()}) + self._loss_collections["train_losses"].update({"key_sim": torchmetrics.MeanMetric()}) def training_step( self, batch: Tuple[NestedTensors, torch.Tensor], batch_idx: int diff --git a/test/renate/benchmark/models/test_l2p.py b/test/renate/benchmark/models/test_l2p.py index a52b0a3b..44e0be0b 100644 --- a/test/renate/benchmark/models/test_l2p.py +++ b/test/renate/benchmark/models/test_l2p.py @@ -3,7 +3,7 @@ import pytest import torch -from renate.benchmark.models.l2p import LearningToPromptTransformer, PromptPool +from renate.benchmark.models.l2p import LearningToPromptTransformer, PromptPool, PromptedTransformer def test_prompt_pool(): @@ -62,3 +62,27 @@ def test_prompt_vision_transformer_trainable_parameters(backbone, num_trainable_ model = LearningToPromptTransformer(pretrained_model_name_or_path=backbone) n = sum(1 for x in model.parameters() if x.requires_grad) assert n == num_trainable_params + + +@pytest.mark.parametrize("backbone", ["google/vit-base-patch16-224", "bert-base-uncased"]) +@pytest.mark.parametrize("prompt", [None, torch.rand(3, 10, 768)]) +@pytest.mark.parametrize("cls_feat", [True, False]) +def test_prompted_transformer(backbone, prompt, cls_feat): + model = PromptedTransformer( + pretrained_model_name_or_path=backbone, + num_outputs=10, + prediction_strategy=None, + add_icarl_class_means=False, + ) + + B, P_len, _ = prompt.shape if prompt is not None else (5, 0, 0) + if "vit" in backbone: + # we are doing ViT. + inputs = torch.rand(B, 3, 224, 224) + expected_output_size = (B, 197 + P_len, 768) if not cls_feat else (B, 768) + else: + inputs = {"input_ids": torch.randint(0, 10000, (B, 128))} + expected_output_size = (B, 768) + + out = model(inputs, prompt, cls_feat=cls_feat) + assert out.shape == expected_output_size From d8ef1c81da1e7e1e714a7ef96ab8beb775f0bad4 Mon Sep 17 00:00:00 2001 From: Prabhu Teja Date: Tue, 19 Sep 2023 08:57:19 +0000 Subject: [PATCH 2/3] un-nesting if --- src/renate/benchmark/models/l2p.py | 57 +++++++++++++++--------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/src/renate/benchmark/models/l2p.py b/src/renate/benchmark/models/l2p.py index 0c2ac001..f3a4588a 100644 --- a/src/renate/benchmark/models/l2p.py +++ b/src/renate/benchmark/models/l2p.py @@ -193,36 +193,35 @@ def forward( else self.transformer.get_features(x, cls_feat=cls_feat) ) # text transformers dont support cls_feat. + elif self._is_text_transformer: + # The implicit assumption here is that x for text transformers is the input_ids. + # This simplified forward pass has 4 steps: + # 1. Get prompts + # 2. Get embeddings from inputs. + # 3. 
Concat prompt and inputs + # 4. Forward prop inputs_embeds to get the features. + inputs_embeds = self.word_embeddings(x["input_ids"]) + if prompt.size(0) != inputs_embeds.size(0): + prompt = prompt.unsqueeze(0).expand( + inputs_embeds.size(0), -1, -1 + ) # Expand one prompt to batch size + inputs_embeds = torch.cat((prompt, inputs_embeds), dim=1) + return self.transformer.get_features({"inputs_embeds": inputs_embeds}) else: - if self._is_text_transformer: - # The implicit assumption here is that x for text transformers is the input_ids. - # This simplified forward pass has 4 steps: - # 1. Get prompts - # 2. Get embeddings from inputs. - # 3. Concat prompt and inputs - # 4. Forward prop inputs_embeds to get the features. - inputs_embeds = self.word_embeddings(x["input_ids"]) - if prompt.size(0) != inputs_embeds.size(0): - prompt = prompt.unsqueeze(0).expand( - inputs_embeds.size(0), -1, -1 - ) # Expand one prompt to batch size - inputs_embeds = torch.cat((prompt, inputs_embeds), dim=1) - return self.transformer.get_features({"inputs_embeds": inputs_embeds}) - else: - patch_embeddings = self.transformer.get_submodule("_backbone.embeddings")(x) - if prompt.size(0) != x.size(0): - prompt = prompt.unsqueeze(0).expand( - x.size(0), -1, -1 - ) # Expand one prompt to batch size# Expand one prompt to batch size - input_concat_prompt = torch.cat([patch_embeddings, prompt], dim=1) - - encoded_features = self.transformer.get_submodule("_backbone.encoder")( - input_concat_prompt, return_dict=False - )[0] - encoded_features = self.transformer.get_submodule("_backbone.layernorm")( - encoded_features - ) - return encoded_features[:, 0, :] if cls_feat else encoded_features + patch_embeddings = self.transformer.get_submodule("_backbone.embeddings")(x) + if prompt.size(0) != x.size(0): + prompt = prompt.unsqueeze(0).expand( + x.size(0), -1, -1 + ) # Expand one prompt to batch size# Expand one prompt to batch size + input_concat_prompt = torch.cat([patch_embeddings, prompt], dim=1) + + encoded_features = self.transformer.get_submodule("_backbone.encoder")( + input_concat_prompt, return_dict=False + )[0] + encoded_features = self.transformer.get_submodule("_backbone.layernorm")( + encoded_features + ) + return encoded_features[:, 0, :] if cls_feat else encoded_features class LearningToPromptTransformer(RenateBenchmarkingModule): From 3896135f37f3e71f424265c36d75025819986c8b Mon Sep 17 00:00:00 2001 From: Prabhu Teja Date: Tue, 19 Sep 2023 09:23:02 +0000 Subject: [PATCH 3/3] comments, cleanup --- src/renate/benchmark/models/l2p.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/renate/benchmark/models/l2p.py b/src/renate/benchmark/models/l2p.py index f3a4588a..175eb3e8 100644 --- a/src/renate/benchmark/models/l2p.py +++ b/src/renate/benchmark/models/l2p.py @@ -110,8 +110,9 @@ def forward(self, x: torch.Tensor, manual_prompt_indices: Optional[torch.LongTen class PromptedTransformer(nn.Module): """This generic module is the basic prompted transformer. It takes in a model string and creates - the appropriate transformer. If not prompted, it returns features, and if prompted, it returns - the full feature using those prompts and the input image/text. + the appropriate transformer (ViT or Text transformer). If no prompts are provided in the forward + call, image/text features are returned. If a prompt is provided, it is concatenated to the + embedding layer output and the resultant features are returned. 
Args: pretrained_model_name_or_path: A string that denotes which pretrained model from the HF hub @@ -154,7 +155,7 @@ def __init__( prediction_strategy=prediction_strategy, add_icarl_class_means=add_icarl_class_means, ) - self._is_text_transformer = False + self.is_text_transformer = False else: self.transformer = HuggingFaceSequenceClassificationTransformer( pretrained_model_name_or_path=pretrained_model_name_or_path, @@ -169,7 +170,7 @@ def __init__( ) break - self._is_text_transformer = True + self.is_text_transformer = True self.transformer._tasks_params.clear() self.transformer.eval() @@ -189,11 +190,11 @@ def forward( if prompt is None: return ( self.transformer.get_features(x) - if self._is_text_transformer + if self.is_text_transformer else self.transformer.get_features(x, cls_feat=cls_feat) ) # text transformers dont support cls_feat. - elif self._is_text_transformer: + elif self.is_text_transformer: # The implicit assumption here is that x for text transformers is the input_ids. # This simplified forward pass has 4 steps: # 1. Get prompts @@ -325,7 +326,7 @@ def __init__( ) self._backbone = nn.ModuleDict({"transformer": transformer, "prompter": prompter}) - self._is_text_transformer = transformer._is_text_transformer + self._is_text_transformer = transformer.is_text_transformer self.prompt_embedding_features = prompt_embedding_features self.patch_pooler = patch_pooler self.similarity_score: Optional[torch.Tensor] = None
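
The new test in test/renate/benchmark/models/test_l2p.py exercises this abstraction end to end. A minimal usage sketch along the same lines (an illustration only — it assumes the renate package with these patches applied is importable and that the default "google/vit-base-patch16-224" checkpoint can be fetched from the HF hub; the shapes in the comments follow the expectations encoded in the new test):

    import torch

    from renate.benchmark.models.l2p import PromptedTransformer

    # Frozen backbone wrapper introduced in PATCH 1/3; constructor arguments mirror the new unit test.
    model = PromptedTransformer(
        pretrained_model_name_or_path="google/vit-base-patch16-224",
        num_outputs=10,
        prediction_strategy=None,
        add_icarl_class_means=False,
    )

    images = torch.rand(3, 3, 224, 224)  # batch of 3 RGB images at 224x224
    prompt = torch.rand(3, 10, 768)      # 10 prompt tokens of hidden size 768 per example

    # Without a prompt: plain [CLS] features, shape (3, 768).
    features = model(images)

    # With a prompt: the prompt tokens are appended to the 197 patch/CLS embeddings and the
    # full encoded sequence is returned when cls_feat=False, shape (3, 207, 768).
    prompted = model(images, prompt, cls_feat=False)

For a text backbone such as "bert-base-uncased", the same forward accepts {"input_ids": ...} and concatenates the prompt to the word embeddings instead, always returning pooled features of shape (batch, hidden).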