In [66]:
from transformers import AutoTokenizer, T5ForConditionalGeneration
# from soft_prompt_tuning import SoftEmbedding
import torch
import torch.nn as nn

In [81]:
tokenizer = AutoTokenizer.from_pretrained("checkpoint/t5-base")
model = T5ForConditionalGeneration.from_pretrained("checkpoint/t5-base")

In [82]:
class SoftEmbedding(nn.Module):
    def __init__(self, wte=nn.Embedding, n_tokens=10, random_range=0.5, initialize_from_vocab=True):
        """appends learned embedding to
        Args:
            wte (nn.Embedding): original transformer word embedding
            n_tokens (int, optional): number of tokens for task. Defaults to 10.
            random_range (float, optional): range to init embedding (if not initialize from vocab). Defaults to 0.5.
            initialize_from_vocab (bool, optional): initialize from default vocab. Defaults to True.
        """
        super(SoftEmbedding, self).__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.random_range = random_range
        self.initialize_from_vocab = initialize_from_vocab
        parameters = self.initialize_embedding()
        self.learned_embedding = nn.parameter.Parameter(parameters)

    def initialize_embedding(self):
        """initializes learned embedding
        Args:
            same as __init__
        Returns:
            torch.float: initialized using original schemes
        Description:
            # if initialize from the base model vocab, return the num:self.n_tokens weights from the base model embedding.
            # if not, random generate the soft embedding with the same size as above
        """
        if self.initialize_from_vocab:
            return self.wte.weight[:self.n_tokens].clone().detach() #.clone().detach() refers to create a new memory, and remove the tensor from the computational graph
        return torch.FloatTensor(self.n_tokens, self.wte.weight.size(1)).uniform_(-self.random_range, self.random_range) # torch.FloatTensor(a,b), create a random tensor according to the shape of a and b;

    def forward(self, tokens):
        """run forward pass
        Args:
            tokens (torch.long): input tokens before encoding
        Returns:
            torch.float: encoding of text concatenated with learned task specific embedding
        """
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)
        return torch.cat([learned_embedding, input_embedding], 1)

In [83]:
n_tokens = 5
s_wte = SoftEmbedding(wte=model.get_input_embeddings(),
                      n_tokens=n_tokens,
                      random_range=0.5,
                      initialize_from_vocab=True)
model.set_input_embeddings(s_wte)

In [94]:
"""
1.need to pad attention_mask and input_ids to be full seq_len + n_learned_tokens
2.even though it does not matter what you pad input_ids with, it's just to make HF happy, here, we use -1 to label these stuffs
"""
inputs = ["what is the name of the Movie1s distributed by Company1?", "what is the name of the company distributed the Movie1?"]
labels = ["distributor_to_movie", "movie_to_distributor"]
task_prefix = "summarize: " # add the task prefix summarize: <s1> to improve the model performance
max_source_length = 50
max_target_length = 15

# encode the inputs
encoding = tokenizer([task_prefix + sequence for sequence in inputs],
                     padding='longest',
                     max_length=max_source_length,
                     truncation=True,
                     return_tensors="pt")
input_ids, attention_mask = encoding.input_ids, encoding.attention_mask

# encode the targets
target_encoding = tokenizer(labels,
                            padding='longest',
                            max_length=max_target_length,
                            truncation=True,
                            return_tensors='pt')
decoder_attention_mask = target_encoding.attention_mask
labels = target_encoding.input_ids


# torch.full(input_size, fill_value), fill the input_size with the fill_value
inputs['input_ids'] = torch.cat([torch.full((1,n_tokens), -1), inputs['input_ids']], 1)
inputs['attention_mask'] = torch.cat([torch.full((1,n_tokens), -1), inputs['attention_mask']], 1)

In [96]:
loss = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, labels=labels.input_ids).loss

tensor(6.0733, grad_fn=<NllLossBackward0>)

In [None]:
# input_query = "which are the directors of the films written by the writer of [The Green Mile]"
input_query = "which are the directors"

inputs = tokenizer(input_query, return_tensors="pt", truncation=True)
inputs['input_ids'] = torch.cat([torch.full((1,n_tokens), 50256), inputs['input_ids']], 1)
inputs['attention_mask'] = torch.cat([torch.full((1,n_tokens), 1), inputs['attention_mask']], 1)

print(inputs['input_ids'].shape)
tokens_to_generate = 10

# outputs = model.generate(**inputs, max_length=inputs['input_ids'].size(1)+tokens_to_generate, use_cache=False)
# outputs = model.generate(**inputs, max_length=inputs['input_ids'].size(1)+tokens_to_generate, use_cache=False)
outputs = model.generate(**inputs, max_length=inputs['input_ids'].size(1)+tokens_to_generate, use_cache=False)

In [None]:
outputs = model.generate(
    inputs["input_ids"], max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True
)

In [None]:
inputs = tokenizer("may the force", return_tensors="pt")

# need to pad attention_mask and input_ids to be full seq_len + n_learned_tokens
# even though it does not matter what you pad input_ids with, it's just to make HF happy
inputs['input_ids'] = torch.cat([torch.full((1,n_tokens), 50256), inputs['input_ids']], 1)
inputs['attention_mask'] = torch.cat([torch.full((1,n_tokens), 1), inputs['attention_mask']], 1)

tokens_to_generate = 10

outputs = model.generate(**inputs, max_length=inputs['input_ids'].size(1)+tokens_to_generate, use_cache=False)

In [None]:
print(outputs)
print(tokenizer.decode(outputs[0]))