Abstracting prompting transformer for use in L2P and S-Prompt #420

Merged (3 commits) on Sep 19, 2023
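
For context, a minimal usage sketch of the new PromptedTransformer abstraction introduced in this PR. This is illustrative only: the checkpoint names and shapes are the defaults exercised in the new test (ViT-Base at 224x224 with 16x16 patches, hidden size 768), and the prompt length and batch size are arbitrary.

import torch
from renate.benchmark.models.l2p import PromptedTransformer

# Vision path: a "vit" model string selects the VisionTransformer backbone.
vit = PromptedTransformer(pretrained_model_name_or_path="google/vit-base-patch16-224")
images = torch.rand(2, 3, 224, 224)
feats = vit(images)  # No prompt: plain [CLS] features of shape (2, 768).
prompt = torch.rand(10, 768)  # A single prompt; it is expanded to the batch size internally.
prompted = vit(images, prompt, cls_feat=False)  # Full token sequence incl. prompts: (2, 197 + 10, 768).

# Text path: any other model string goes through the HF sequence-classification backbone.
bert = PromptedTransformer(pretrained_model_name_or_path="bert-base-uncased")
tokens = {"input_ids": torch.randint(0, 10000, (2, 128))}
text_feats = bert(tokens, prompt)  # The prompt is concatenated to the word embeddings: (2, 768).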
Changes from all commits
220 changes: 143 additions & 77 deletions src/renate/benchmark/models/l2p.py
@@ -108,6 +108,123 @@ def forward(self, x: torch.Tensor, manual_prompt_indices: Optional[torch.LongTen
return selected_prompts, loss_value


class PromptedTransformer(nn.Module):
"""This generic module is the basic prompted transformer. It takes in a model string and creates
the appropriate transformer (ViT or Text transformer). If no prompts are provided in the forward
call, image/text features are returned. If a prompt is provided, it is concatenated to the
embedding layer output and the resultant features are returned.

Args:
pretrained_model_name_or_path: A string that denotes which pretrained model from the HF hub
to use.
image_size: Input image size. Only used when a vision transformer backbone is created.
patch_size: Patch size. Only used for the vision transformer backbone.
num_layers: Number of encoder layers. Only used for the vision transformer backbone.
num_heads: Number of attention heads. Only used for the vision transformer backbone.
hidden_dim: Hidden dimension of the encoder. Only used for the vision transformer backbone.
mlp_dim: Dimension of the MLP blocks. Only used for the vision transformer backbone.
dropout: Dropout probability. Only used for the vision transformer backbone.
attention_dropout: Attention dropout probability. Only used for the vision transformer backbone.
num_outputs: Size of the output.
prediction_strategy: Continual learning strategies may alter the prediction at train or test
time.
add_icarl_class_means: If ``True``, additional parameters used only by the
``ICaRLModelUpdater`` are added. Only required when using that updater.

def __init__(
self,
pretrained_model_name_or_path: str = "google/vit-base-patch16-224",
image_size: int = 32,
patch_size: int = 4,
num_layers: int = 12,
num_heads: int = 12,
hidden_dim: int = 768,
mlp_dim: int = 3072,
dropout: float = 0.1,
attention_dropout: float = 0.1,
num_outputs: int = 10,
prediction_strategy: Optional[PredictionStrategy] = None,
add_icarl_class_means: bool = True,
) -> None:
super().__init__()
if "vit" in pretrained_model_name_or_path:
self.transformer = VisionTransformer(
pretrained_model_name_or_path=pretrained_model_name_or_path,
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
dropout=dropout,
attention_dropout=attention_dropout,
num_outputs=num_outputs,
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
)
self.is_text_transformer = False
else:
self.transformer = HuggingFaceSequenceClassificationTransformer(
pretrained_model_name_or_path=pretrained_model_name_or_path,
num_outputs=num_outputs,
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
)
for named_param, value in self.transformer.named_parameters():
if value.shape[0] == self.transformer._backbone.config.vocab_size:
self.word_embeddings = self.transformer.get_submodule(
named_param.replace(".weight", "")
)
break

self.is_text_transformer = True

self.transformer._tasks_params.clear()
self.transformer.eval()
for p in self.transformer.parameters():
p.requires_grad_(False)

def forward(
self, x: torch.Tensor, prompt: Optional[torch.Tensor] = None, cls_feat: bool = True
) -> torch.Tensor:
"""
Args:
x: Input torch tensor.
prompt: Prompt tensor. Defaults to None.
cls_feat: Whether to extract the [CLS] token features or to return the full feature tensor.
Ignored for text transformers. Defaults to True.
"""
if prompt is None:
return (
self.transformer.get_features(x)
if self.is_text_transformer
else self.transformer.get_features(x, cls_feat=cls_feat)
)
# Text transformers don't support cls_feat.
elif self.is_text_transformer:
# The implicit assumption here is that x for text transformers is the input_ids.
# This simplified forward pass has 4 steps:
# 1. Get prompts
# 2. Get embeddings from inputs.
# 3. Concat prompt and inputs
# 4. Forward prop inputs_embeds to get the features.
inputs_embeds = self.word_embeddings(x["input_ids"])
if prompt.size(0) != inputs_embeds.size(0):
prompt = prompt.unsqueeze(0).expand(
inputs_embeds.size(0), -1, -1
) # Expand one prompt to batch size
inputs_embeds = torch.cat((prompt, inputs_embeds), dim=1)
return self.transformer.get_features({"inputs_embeds": inputs_embeds})
else:
patch_embeddings = self.transformer.get_submodule("_backbone.embeddings")(x)
if prompt.size(0) != x.size(0):
prompt = prompt.unsqueeze(0).expand(
x.size(0), -1, -1
) # Expand one prompt to batch size
input_concat_prompt = torch.cat([patch_embeddings, prompt], dim=1)

encoded_features = self.transformer.get_submodule("_backbone.encoder")(
input_concat_prompt, return_dict=False
)[0]
encoded_features = self.transformer.get_submodule("_backbone.layernorm")(
encoded_features
)
return encoded_features[:, 0, :] if cls_feat else encoded_features


class LearningToPromptTransformer(RenateBenchmarkingModule):
"""
Implements the vision transformer with prompt pool described in
Expand Down Expand Up @@ -166,34 +283,22 @@ def __init__(
prompt_embedding_features: str = "cls",
patch_pooler: str = "prompt_mean",
) -> None:
if "vit" in pretrained_model_name_or_path:
transformer = VisionTransformer(
pretrained_model_name_or_path=pretrained_model_name_or_path,
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
dropout=dropout,
attention_dropout=attention_dropout,
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
num_outputs=num_outputs,
)
self._is_text_transformer = False
else:
transformer = HuggingFaceSequenceClassificationTransformer(
pretrained_model_name_or_path=pretrained_model_name_or_path,
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
num_outputs=num_outputs,
)

self._is_text_transformer = True
transformer._tasks_params.clear()
transformer = PromptedTransformer(
pretrained_model_name_or_path=pretrained_model_name_or_path,
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
dropout=dropout,
attention_dropout=attention_dropout,
num_outputs=num_outputs,
add_icarl_class_means=add_icarl_class_means,
prediction_strategy=prediction_strategy,
)
prompter = PromptPool(
embedding_dim=transformer._embedding_size,
embedding_dim=transformer.transformer._embedding_size,
Contributor:

in the following, we access a lot of protected attributes of the transformer. do we want to keep it that way or rather make them public?

Contributor Author:

Fixed in 3896135

pool_size=pool_size,
pool_selection_size=pool_selection_size,
prompt_size=prompt_size,
@@ -204,10 +309,10 @@ def __init__(
)

super().__init__(
embedding_size=transformer._embedding_size,
embedding_size=transformer.transformer._embedding_size,
num_outputs=num_outputs,
constructor_arguments=dict(
**transformer._constructor_arguments,
**transformer.transformer._constructor_arguments,
pool_size=pool_size,
pool_selection_size=pool_selection_size,
prompt_size=prompt_size,
@@ -221,6 +326,7 @@
)

self._backbone = nn.ModuleDict({"transformer": transformer, "prompter": prompter})
self._is_text_transformer = transformer.is_text_transformer
self.prompt_embedding_features = prompt_embedding_features
self.patch_pooler = patch_pooler
self.similarity_score: Optional[torch.Tensor] = None
@@ -236,56 +342,32 @@ def __init__(
"prompt_mean",
], f"Invalid method to extract prompt embedding features. Got {patch_pooler}"

for n, p in self._backbone["transformer"].named_parameters():
p.requires_grad = False
self._backbone["transformer"].eval()
for p in self._backbone["prompter"].parameters():
p.requires_grad = True

if self._is_text_transformer:
## This is to find the Embedding layer.
for named_param, value in self._backbone["transformer"].named_parameters():
if value.shape[0] == self._backbone["transformer"]._backbone.config.vocab_size:
self.word_embeddings = self._backbone["transformer"].get_submodule(
named_param.replace(".weight", "")
)
break
# The backbone's forward is monkey-patched to allow the parent class' forward to work
# without any manual management.
self._backbone.forward = self.forward_for_monkey_patching

def forward_for_monkey_patching(
self, x: torch.Tensor, task_id: str = defaults.TASK_ID
) -> torch.Tensor:
with torch.no_grad():
prompt_pool_input = self._backbone["transformer"](x, cls_feat=False)
if not self._is_text_transformer:
# The vision transformer forward pass is assembled manually here.
with torch.no_grad():
prompt_pool_input = self._backbone["transformer"].get_features(x, cls_feat=False)
if self.prompt_embedding_features == "cls":
# Retrieve [CLS] token features, as used in the L2P paper.
prompt_pool_input = prompt_pool_input[:, 0, :]
elif self.prompt_embedding_features == "mean":
# compute mean patch features.
prompt_pool_input = prompt_pool_input[:, 1:, :].mean(1)
# Compute the prompts to be stacked
prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input)
# compute patch embeddings
patch_embeddings = self._backbone["transformer"].get_submodule("_backbone.embeddings")(
x
)
# concatenate both.
input_concat_prompt = torch.cat([patch_embeddings, prompts], dim=1)
## Rest of the processing. This code is part of the ViTModel class in HF Transformers.
encoded_features = self._backbone["transformer"].get_submodule("_backbone.encoder")(
input_concat_prompt, return_dict=False
)[0]
encoded_features = self._backbone["transformer"].get_submodule("_backbone.layernorm")(
encoded_features
)

## Save similarity
self.similarity_score = prompt_similarity

prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input)
self.similarity_score = prompt_similarity
encoded_features = self._backbone["transformer"](x, prompts, cls_feat=False)
if self._is_text_transformer:
return encoded_features
else:
if self.patch_pooler == "cls":
seq_cls_token = encoded_features[:, 0, :]
elif self.patch_pooler == "mean":
@@ -294,19 +376,3 @@ def forward_for_monkey_patching(
num_prompts = prompts.size(1)
seq_cls_token = encoded_features[:, -num_prompts:, :].mean(1)
return seq_cls_token

else:
## The implicit assumption here is that x for text transformers is the input_ids.
# This simplified forward pass has 4 steps:
# 1. Get prompts
# 2. Get embeddings from inputs.
# 3. Concat prompt and inputs
# 4. Forward prop inputs_embeds to get the features. The forward of the RenateBenchmarkingModule
# applies the classifier and gets the logits.
with torch.no_grad():
prompt_pool_input = self._backbone["transformer"].get_features(x)
prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input) # 1
self.similarity_score = prompt_similarity
inputs_embeds = self.word_embeddings(x["input_ids"]) # 2
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) # 3
return self._backbone["transformer"].get_features({"inputs_embeds": inputs_embeds}) # 4
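
For the reviewer thread above, a small sketch of how the refactored LearningToPromptTransformer composes the two modules after this change. This mirrors the diff rather than adding code to the PR; the backbone name is the default used in the tests, and the forward call assumes the default task id:

import torch
from renate.benchmark.models.l2p import LearningToPromptTransformer

model = LearningToPromptTransformer(pretrained_model_name_or_path="google/vit-base-patch16-224")
# The backbone is now nn.ModuleDict({"transformer": PromptedTransformer, "prompter": PromptPool});
# the transformer's parameters were already frozen inside PromptedTransformer.__init__.
trainable = [name for name, param in model.named_parameters() if param.requires_grad]
print(trainable)  # Expected to contain only prompt pool and task/classifier parameters.

x = torch.rand(4, 3, 224, 224)
logits = model(x)  # forward_for_monkey_patching selects prompts and feeds them to the frozen transformer.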
2 changes: 1 addition & 1 deletion src/renate/updaters/experimental/l2p.py
@@ -48,7 +48,7 @@ def __init__(
**kwargs,
)
self.prompt_sim_loss_weight = prompt_sim_loss_weight
self._loss_collections["train_losses"].update({"key_sim_loss": torchmetrics.MeanMetric()})
self._loss_collections["train_losses"].update({"key_sim": torchmetrics.MeanMetric()})

def training_step(
self, batch: Tuple[NestedTensors, torch.Tensor], batch_idx: int
26 changes: 25 additions & 1 deletion test/renate/benchmark/models/test_l2p.py
@@ -3,7 +3,7 @@
import pytest
import torch

from renate.benchmark.models.l2p import LearningToPromptTransformer, PromptPool
from renate.benchmark.models.l2p import LearningToPromptTransformer, PromptPool, PromptedTransformer


def test_prompt_pool():
@@ -62,3 +62,27 @@ def test_prompt_vision_transformer_trainable_parameters(backbone, num_trainable_
model = LearningToPromptTransformer(pretrained_model_name_or_path=backbone)
n = sum(1 for x in model.parameters() if x.requires_grad)
assert n == num_trainable_params


@pytest.mark.parametrize("backbone", ["google/vit-base-patch16-224", "bert-base-uncased"])
@pytest.mark.parametrize("prompt", [None, torch.rand(3, 10, 768)])
@pytest.mark.parametrize("cls_feat", [True, False])
def test_prompted_transformer(backbone, prompt, cls_feat):
model = PromptedTransformer(
pretrained_model_name_or_path=backbone,
num_outputs=10,
prediction_strategy=None,
add_icarl_class_means=False,
)

B, P_len, _ = prompt.shape if prompt is not None else (5, 0, 0)
if "vit" in backbone:
# we are doing ViT.
inputs = torch.rand(B, 3, 224, 224)
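# 197 = 1 [CLS] token + (224 / 16) ** 2 = 196 patch tokens for ViT-Base at 224x224; the P_len prompt tokens are concatenated after them.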
expected_output_size = (B, 197 + P_len, 768) if not cls_feat else (B, 768)
else:
inputs = {"input_ids": torch.randint(0, 10000, (B, 128))}
expected_output_size = (B, 768)

out = model(inputs, prompt, cls_feat=cls_feat)
assert out.shape == expected_output_size