Abstracting prompting transformer for use in L2P and S-Prompt (#420)
prabhuteja12 committed Sep 19, 2023
1 parent 5bc82b9 commit 864c1a4
Showing 3 changed files with 169 additions and 79 deletions.
220 changes: 143 additions & 77 deletions src/renate/benchmark/models/l2p.py
@@ -108,6 +108,123 @@ def forward(self, x: torch.Tensor, manual_prompt_indices: Optional[torch.LongTen
return selected_prompts, loss_value
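
For context, the hunk above ends PromptPool.forward, which returns the selected prompts together with a similarity-based loss value. The following usage sketch is not part of the commit; it is hypothetical, sets only the constructor arguments visible later in this diff (assuming the remaining ones take defaults), and the shapes follow how the pool is queried further down:

import torch

from renate.benchmark.models.l2p import PromptPool

# Hypothetical instantiation; only arguments visible in this diff are set.
pool = PromptPool(embedding_dim=768, pool_size=10, pool_selection_size=5, prompt_size=5)
query = torch.rand(4, 768)  # e.g. [CLS] features used to query the pool
prompts, similarity = pool(query)  # prompts are later concatenated to patch embeddings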


class PromptedTransformer(nn.Module):
"""This generic module is the basic prompted transformer. It takes in a model string and creates
the appropriate transformer (ViT or Text transformer). If no prompts are provided in the forward
call, image/text features are returned. If a prompt is provided, it is concatenated to the
embedding layer output and the resultant features are returned.
Args:
pretrained_model_name_or_path: A string that denotes which pretrained model from the HF hub
to use.
num_outputs: Size of the output.
prediction_strategy: Continual learning strategies may alter the prediction at train or test
time.
add_icarl_class_means: If ``True``, additional parameters used only by the
``ICaRLModelUpdater`` are added. Only required when using that updater.
"""

def __init__(
self,
pretrained_model_name_or_path="google/vit-base-patch16-224",
image_size: int = 32,
patch_size: int = 4,
num_layers: int = 12,
num_heads: int = 12,
hidden_dim: int = 768,
mlp_dim: int = 3072,
dropout: float = 0.1,
attention_dropout: float = 0.1,
num_outputs: int = 10,
prediction_strategy: Optional[PredictionStrategy] = None,
add_icarl_class_means: bool = True,
) -> None:
super().__init__()
if "vit" in pretrained_model_name_or_path:
self.transformer = VisionTransformer(
pretrained_model_name_or_path=pretrained_model_name_or_path,
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
dropout=dropout,
attention_dropout=attention_dropout,
num_outputs=num_outputs,
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
)
self.is_text_transformer = False
else:
self.transformer = HuggingFaceSequenceClassificationTransformer(
pretrained_model_name_or_path=pretrained_model_name_or_path,
num_outputs=num_outputs,
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
)
# Find the input word-embedding layer by matching its first dimension to the vocab size.
for named_param, value in self.transformer.named_parameters():
if value.shape[0] == self.transformer._backbone.config.vocab_size:
self.word_embeddings = self.transformer.get_submodule(
named_param.replace(".weight", "")
)
break

self.is_text_transformer = True

self.transformer._tasks_params.clear()
self.transformer.eval()
for p in self.transformer.parameters():
p.requires_grad_(False)

def forward(
self, x: torch.Tensor, prompt: Optional[torch.Tensor] = None, cls_feat: bool = True
) -> torch.Tensor:
"""
Args:
x: Input torch tensor.
prompt: Prompt tensor. Defaults to None.
cls_feat: Whether to extract [CLS] token or to return full feature tensor.
Ignored for text transformer. Defaults to True.
"""
if prompt is None:
return (
self.transformer.get_features(x)
if self.is_text_transformer
else self.transformer.get_features(x, cls_feat=cls_feat)
)
# Text transformers don't support cls_feat.
elif self.is_text_transformer:
# The implicit assumption here is that x for text transformers is the tokenized input
# containing the input_ids. This simplified forward pass has 4 steps:
# 1. Get embeddings from the input_ids.
# 2. Expand the prompt to the batch size if needed.
# 3. Concatenate the prompt and the input embeddings.
# 4. Forward propagate inputs_embeds to get the features.
inputs_embeds = self.word_embeddings(x["input_ids"])
if prompt.size(0) != inputs_embeds.size(0):
prompt = prompt.unsqueeze(0).expand(
inputs_embeds.size(0), -1, -1
) # Expand one prompt to batch size
inputs_embeds = torch.cat((prompt, inputs_embeds), dim=1)
return self.transformer.get_features({"inputs_embeds": inputs_embeds})
else:
patch_embeddings = self.transformer.get_submodule("_backbone.embeddings")(x)
if prompt.size(0) != x.size(0):
prompt = prompt.unsqueeze(0).expand(
x.size(0), -1, -1
) # Expand one prompt to batch size
input_concat_prompt = torch.cat([patch_embeddings, prompt], dim=1)

encoded_features = self.transformer.get_submodule("_backbone.encoder")(
input_concat_prompt, return_dict=False
)[0]
encoded_features = self.transformer.get_submodule("_backbone.layernorm")(
encoded_features
)
return encoded_features[:, 0, :] if cls_feat else encoded_features
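
A minimal usage sketch of the new class (not part of the commit; shapes follow the test added at the bottom of this diff, and the batch size and prompt are illustrative):

import torch

from renate.benchmark.models.l2p import PromptedTransformer

# A minimal sketch, assuming the ViT-B/16 defaults above (hidden_dim=768, 224x224 inputs).
model = PromptedTransformer(
    pretrained_model_name_or_path="google/vit-base-patch16-224",
    num_outputs=10,
    prediction_strategy=None,
    add_icarl_class_means=False,
)
images = torch.rand(5, 3, 224, 224)

features = model(images)  # no prompt: pooled [CLS] features, shape (5, 768)

prompt = torch.rand(10, 768)  # a single prompt; forward expands it to the batch
prompted = model(images, prompt, cls_feat=False)  # full features, shape (5, 197 + 10, 768)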


class LearningToPromptTransformer(RenateBenchmarkingModule):
"""
Implements the vision transformer with prompt pool described in
@@ -166,34 +283,22 @@ def __init__(
prompt_embedding_features: str = "cls",
patch_pooler: str = "prompt_mean",
) -> None:
if "vit" in pretrained_model_name_or_path:
transformer = VisionTransformer(
pretrained_model_name_or_path=pretrained_model_name_or_path,
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
dropout=dropout,
attention_dropout=attention_dropout,
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
num_outputs=num_outputs,
)
self._is_text_transformer = False
else:
transformer = HuggingFaceSequenceClassificationTransformer(
pretrained_model_name_or_path=pretrained_model_name_or_path,
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
num_outputs=num_outputs,
)

self._is_text_transformer = True
transformer._tasks_params.clear()
transformer = PromptedTransformer(
pretrained_model_name_or_path=pretrained_model_name_or_path,
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
dropout=dropout,
attention_dropout=attention_dropout,
num_outputs=num_outputs,
add_icarl_class_means=add_icarl_class_means,
prediction_strategy=prediction_strategy,
)
prompter = PromptPool(
-embedding_dim=transformer._embedding_size,
+embedding_dim=transformer.transformer._embedding_size,
pool_size=pool_size,
pool_selection_size=pool_selection_size,
prompt_size=prompt_size,
@@ -204,10 +309,10 @@ def __init__(
)

super().__init__(
-embedding_size=transformer._embedding_size,
+embedding_size=transformer.transformer._embedding_size,
num_outputs=num_outputs,
constructor_arguments=dict(
-**transformer._constructor_arguments,
+**transformer.transformer._constructor_arguments,
pool_size=pool_size,
pool_selection_size=pool_selection_size,
prompt_size=prompt_size,
@@ -221,6 +326,7 @@ def __init__(
)

self._backbone = nn.ModuleDict({"transformer": transformer, "prompter": prompter})
self._is_text_transformer = transformer.is_text_transformer
self.prompt_embedding_features = prompt_embedding_features
self.patch_pooler = patch_pooler
self.similarity_score: Optional[torch.Tensor] = None
@@ -236,56 +342,32 @@ def __init__(
"prompt_mean",
], f"Invalid method to extract prompt embedding features. Got {patch_pooler}"

for n, p in self._backbone["transformer"].named_parameters():
p.requires_grad = False
self._backbone["transformer"].eval()
for p in self._backbone["prompter"].parameters():
p.requires_grad = True

if self._is_text_transformer:
## This is to find the Embedding layer.
for named_param, value in self._backbone["transformer"].named_parameters():
if value.shape[0] == self._backbone["transformer"]._backbone.config.vocab_size:
self.word_embeddings = self._backbone["transformer"].get_submodule(
named_param.replace(".weight", "")
)
break
# The backbone's forward is monkey-patched to allow the parent class' forward to work
# without any manual management.
self._backbone.forward = self.forward_for_monkey_patching

def forward_for_monkey_patching(
self, x: torch.Tensor, task_id: str = defaults.TASK_ID
) -> torch.Tensor:
with torch.no_grad():
prompt_pool_input = self._backbone["transformer"](x, cls_feat=False)
if not self._is_text_transformer:
# The vision transformer forward pass is strapped in manually.
with torch.no_grad():
prompt_pool_input = self._backbone["transformer"].get_features(x, cls_feat=False)
if self.prompt_embedding_features == "cls":
# Retrieve [CLS] token features, as used in the L2P paper.
prompt_pool_input = prompt_pool_input[:, 0, :]
elif self.prompt_embedding_features == "mean":
# compute mean patch features.
prompt_pool_input = prompt_pool_input[:, 1:, :].mean(1)
# Compute the prompts to be stacked
prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input)
# compute patch embeddings
patch_embeddings = self._backbone["transformer"].get_submodule("_backbone.embeddings")(
x
)
# concatenate both.
input_concat_prompt = torch.cat([patch_embeddings, prompts], dim=1)
## Rest of the processing; this code is part of the ViTModel class in HF Transformers.
encoded_features = self._backbone["transformer"].get_submodule("_backbone.encoder")(
input_concat_prompt, return_dict=False
)[0]
encoded_features = self._backbone["transformer"].get_submodule("_backbone.layernorm")(
encoded_features
)

## Save similarity
self.similarity_score = prompt_similarity

prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input)
self.similarity_score = prompt_similarity
encoded_features = self._backbone["transformer"](x, prompts, cls_feat=False)
if self._is_text_transformer:
return encoded_features
else:
if self.patch_pooler == "cls":
seq_cls_token = encoded_features[:, 0, :]
elif self.patch_pooler == "mean":
@@ -294,19 +376,3 @@ def forward_for_monkey_patching
num_prompts = prompts.size(1)
seq_cls_token = encoded_features[:, -num_prompts:, :].mean(1)
return seq_cls_token

else:
## The implicit assumption here is that x for text transformers is the input_ids.
# This simplified forward pass has 4 steps:
# 1. Get prompts
# 2. Get embeddings from inputs.
# 3. Concat prompt and inputs
# 4. Forward propagate inputs_embeds to get the features. The forward of the
#    RenateBenchmarkingModule applies the classifier and gets logits.
with torch.no_grad():
prompt_pool_input = self._backbone["transformer"].get_features(x)
prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input) # 1
self.similarity_score = prompt_similarity
inputs_embeds = self.word_embeddings(x["input_ids"]) # 2
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) # 3
return self._backbone["transformer"].get_features({"inputs_embeds": inputs_embeds}) # 4
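
The monkey-patching in the new __init__ above relies on the fact that nn.Module.__call__ looks up forward on the instance, so an instance attribute shadows the class-level method. A self-contained sketch of the pattern, with hypothetical stand-in classes (not renate code):

import torch
from torch import nn


class Parent(nn.Module):
    # Stands in for RenateBenchmarkingModule: its forward only calls self._backbone(x).
    def __init__(self):
        super().__init__()
        self._backbone = nn.ModuleDict({"linear": nn.Linear(4, 2)})

    def forward(self, x):
        return self._backbone(x)  # dispatches to the patched forward below


class Child(Parent):
    # Stands in for LearningToPromptTransformer: it reroutes the backbone's forward.
    def __init__(self):
        super().__init__()
        # nn.ModuleDict defines no forward of its own; the instance attribute set here
        # is what nn.Module.__call__ picks up when self._backbone(x) is invoked.
        self._backbone.forward = self._patched_forward

    def _patched_forward(self, x):
        return self._backbone["linear"](x)


print(Child()(torch.zeros(3, 4)).shape)  # torch.Size([3, 2])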
2 changes: 1 addition & 1 deletion src/renate/updaters/experimental/l2p.py
@@ -48,7 +48,7 @@ def __init__(
**kwargs,
)
self.prompt_sim_loss_weight = prompt_sim_loss_weight
-self._loss_collections["train_losses"].update({"key_sim_loss": torchmetrics.MeanMetric()})
+self._loss_collections["train_losses"].update({"key_sim": torchmetrics.MeanMetric()})

def training_step(
self, batch: Tuple[NestedTensors, torch.Tensor], batch_idx: int
26 changes: 25 additions & 1 deletion test/renate/benchmark/models/test_l2p.py
@@ -3,7 +3,7 @@
import pytest
import torch

-from renate.benchmark.models.l2p import LearningToPromptTransformer, PromptPool
+from renate.benchmark.models.l2p import LearningToPromptTransformer, PromptPool, PromptedTransformer


def test_prompt_pool():
@@ -62,3 +62,27 @@ def test_prompt_vision_transformer_trainable_parameters(backbone, num_trainable_
model = LearningToPromptTransformer(pretrained_model_name_or_path=backbone)
n = sum(1 for x in model.parameters() if x.requires_grad)
assert n == num_trainable_params


@pytest.mark.parametrize("backbone", ["google/vit-base-patch16-224", "bert-base-uncased"])
@pytest.mark.parametrize("prompt", [None, torch.rand(3, 10, 768)])
@pytest.mark.parametrize("cls_feat", [True, False])
def test_prompted_transformer(backbone, prompt, cls_feat):
model = PromptedTransformer(
pretrained_model_name_or_path=backbone,
num_outputs=10,
prediction_strategy=None,
add_icarl_class_means=False,
)

B, P_len, _ = prompt.shape if prompt is not None else (5, 0, 0)
if "vit" in backbone:
# ViT case: image inputs.
inputs = torch.rand(B, 3, 224, 224)
expected_output_size = (B, 197 + P_len, 768) if not cls_feat else (B, 768)
else:
inputs = {"input_ids": torch.randint(0, 10000, (B, 128))}
expected_output_size = (B, 768)

out = model(inputs, prompt, cls_feat=cls_feat)
assert out.shape == expected_output_size
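
The 197 in the expected ViT output shape is the standard token count for ViT-B/16 at 224x224 resolution: (224 / 16)^2 = 196 patch tokens plus one [CLS] token, with the prompt tokens appended on top. A quick sanity check of that arithmetic:

image_size, patch_size, prompt_length = 224, 16, 10
num_tokens = (image_size // patch_size) ** 2 + 1  # 196 patches + [CLS] = 197
assert num_tokens == 197
assert num_tokens + prompt_length == 207  # matches (B, 197 + P_len, 768) above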
