<a href="https://colab.research.google.com/github/DerekLiu35/CS485-project/blob/main/interpretable_prompts.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Uses the soft-prompt-tuning repository (https://github.com/kipgparker/soft-prompt-tuning/tree/main) as starter code.

In [None]:
import torch
import torch.nn as nn

class InterptEmbedding(nn.Module):
    """Word-embedding wrapper that prepends ``n_tokens`` learned prompt vectors."""

    def __init__(self,
                 wte: nn.Embedding,
                 n_tokens: int = 10,
                 random_range: float = 0.5,
                 initialize_from_vocab: bool = True):
        """Wrap ``wte`` and create the learned prompt embedding.

        Args:
            wte (nn.Embedding): original transformer word embedding.
            n_tokens (int, optional): number of learned prompt vectors prepended
                to every input. Defaults to 10.
            random_range (float, optional): the learned vectors are drawn from
                U(-random_range, random_range) when not initializing from the
                vocab. Defaults to 0.5.
            initialize_from_vocab (bool, optional): initialize the learned
                vectors from the first ``n_tokens`` rows of ``wte``.
                Defaults to True.

        Raises:
            ValueError: if ``initialize_from_vocab`` is True and ``n_tokens`` is
                negative or larger than the number of rows in ``wte``.
        """
        super().__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.learned_embedding = nn.Parameter(
            self.initialize_embedding(wte, n_tokens, random_range, initialize_from_vocab))

    def initialize_embedding(self,
                             wte: nn.Embedding,
                             n_tokens: int = 10,
                             random_range: float = 0.5,
                             initialize_from_vocab: bool = True) -> torch.Tensor:
        """Build the initial value of the learned prompt embedding.

        Args:
            same as __init__

        Returns:
            torch.Tensor: ``(n_tokens, embedding_dim)`` tensor with the same
            dtype and device as ``wte.weight``, not attached to any autograd
            graph.

        Raises:
            ValueError: if ``initialize_from_vocab`` is True and ``n_tokens`` is
                outside ``[0, wte.weight.size(0)]``.
        """
        weight = wte.weight
        if initialize_from_vocab:
            num_rows = weight.size(0)
            # Slicing past the end would silently return fewer than n_tokens
            # rows and break the output length promised by forward().
            if not 0 <= n_tokens <= num_rows:
                raise ValueError(
                    f"n_tokens={n_tokens} must be between 0 and the vocabulary "
                    f"size ({num_rows}) when initialize_from_vocab=True")
            # Copy rather than alias, so training the prompt never edits wte.
            return weight[:n_tokens].detach().clone()
        # Match wte's dtype/device so torch.cat in forward() needs no casting.
        return torch.empty(n_tokens, weight.size(1),
                           dtype=weight.dtype,
                           device=weight.device).uniform_(-random_range, random_range)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed ``tokens`` and prepend the learned prompt vectors.

        Args:
            tokens (torch.long): 2-D ``(batch, seq_len)`` token ids before
                encoding.

        Returns:
            torch.Tensor: ``(batch, n_tokens + seq_len, embedding_dim)``, the
            learned task-specific embedding followed by the text embedding.
        """
        # NOTE(review): every id in `tokens` is embedded, so the result is
        # n_tokens positions longer than the input. An earlier variant embedded
        # only tokens[:, self.n_tokens:], treating the first n_tokens ids as
        # placeholders for the learned vectors and keeping the length
        # unchanged. Confirm which contract callers expect.
        input_embedding = self.wte(tokens)
        # expand() broadcasts the (n_tokens, dim) parameter across the batch
        # without copying; gradients from every batch row sum into it.
        learned_embedding = self.learned_embedding.unsqueeze(0).expand(
            input_embedding.size(0), -1, -1)
        return torch.cat([learned_embedding, input_embedding], 1)

## Test

In [1]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

In [None]:
# Number of learned prompt vectors InterptEmbedding prepends to each input.
n_tokens: int = 10

In [None]:
# Pretrained GPT-2 language model and its matching fast tokenizer.
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# Wrap the model's existing token-embedding module with n_tokens learned
# prompt vectors (installed into the model in the next cell).
s_wte = InterptEmbedding(model.get_input_embeddings(), n_tokens=n_tokens)

In [None]:
# Install the wrapper as the model's input-embedding module, replacing the
# original word embedding that s_wte now holds internally.
model.set_input_embeddings(s_wte)

In [None]:
inputs = tokenizer("May the force be", return_tensors="pt")

# Prepend n_tokens placeholder positions to both input_ids and attention_mask,
# so each is n_tokens longer than the tokenized text. The placeholder id is the
# tokenizer's end-of-text id (50256 for the "gpt2" tokenizer); the batch size is
# taken from the tokenized inputs rather than assumed to be 1.
# NOTE(review): InterptEmbedding.forward embeds every id it receives, these
# placeholders included, and then prepends its own n_tokens learned vectors.
# The placeholder ids therefore reach the model as ordinary tokens, and the
# embedded sequence ends up n_tokens longer than attention_mask. Confirm this is
# the intended setup for the installed transformers version.
inputs['input_ids'] = torch.cat(
    [torch.full((inputs['input_ids'].size(0), n_tokens),
                tokenizer.eos_token_id,
                dtype=inputs['input_ids'].dtype),
     inputs['input_ids']],
    1)
inputs['attention_mask'] = torch.cat(
    [torch.ones((inputs['attention_mask'].size(0), n_tokens),
                dtype=inputs['attention_mask'].dtype),
     inputs['attention_mask']],
    1)

outputs = model(**inputs)

In [None]:
# Display the model output (logits, past_key_values, ...) for the padded prompt.
outputs

CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[-37.7550, -38.2652, -40.4861,  ..., -47.3888, -46.8042, -36.8606],
         [-69.7059, -68.8304, -68.4345,  ..., -75.0572, -74.2558, -64.0604],
         [-50.3636, -47.9540, -51.2161,  ..., -63.0637, -60.9918, -52.0684],
         ...,
         [-69.5300, -69.3791, -71.4374,  ..., -75.3412, -76.0725, -70.3825],
         [-50.3486, -51.7672, -56.7453,  ..., -64.1998, -60.3861, -55.5473],
         [-83.9399, -84.7640, -89.9478,  ..., -99.2800, -89.1253, -89.7127]]],
       grad_fn=<UnsafeViewBackward0>), past_key_values=((tensor([[[[-1.2244e+00,  2.2710e+00,  7.1464e-01,  ..., -1.3720e+00,
           -7.8290e-01,  1.6593e+00],
          [-1.7840e+00,  2.5542e+00,  2.1433e+00,  ..., -1.2502e+00,
           -1.3120e+00,  2.1512e+00],
          [-1.8715e+00,  2.1623e+00,  2.7442e+00,  ..., -2.1212e-01,
           -1.7059e+00,  1.5354e+00],
          ...,
          [-2.5883e+00,  2.3094e+00,  2.2216e+00,  ..., -6.5281e-01,
        