In [1]:
import torch
import transformers
from transformers import T5Tokenizer, T5EncoderModel, T5Config

In [2]:
# config: notebook-wide settings shared by the encoding helpers below

MAX_LENGTH = 256                          # max tokens per text; tokenizer calls truncate beyond this
DEFAULT_T5_NAME = 'google/t5-v1_1-base'   # checkpoint used by t5_encode_text when no name is given
T5_CONFIGS = dict()                       # lazy per-checkpoint cache: name -> {'model'/'tokenizer'/'config': ...}

In [3]:
# singleton globals

def get_tokenizer(name):
    """Load and return the T5 tokenizer for the Hugging Face model id ``name``."""
    return T5Tokenizer.from_pretrained(name)

def get_model(name):
    """Load and return only the encoder stack (``T5EncoderModel``) of checkpoint ``name``."""
    return T5EncoderModel.from_pretrained(name)

def get_model_and_tokenizer(name):
    """Return the cached ``(model, tokenizer)`` pair for ``name``, loading lazily.

    Both objects are stored in the module-level ``T5_CONFIGS`` dict, so repeated
    calls with the same name reuse them. The model is loaded before the tokenizer.
    """
    # Only the dict's contents are mutated (never the name itself), so no `global` is needed.
    cache = T5_CONFIGS.setdefault(name, {})

    if "model" not in cache:
        cache["model"] = get_model(name)
    if "tokenizer" not in cache:
        cache["tokenizer"] = get_tokenizer(name)

    return cache['model'], cache['tokenizer']

def get_encoded_dim(name):
    """Return the hidden size (``config.d_model``) of the T5 checkpoint ``name``.

    Prefers information already cached in ``T5_CONFIGS``: a stored config first,
    then the config of an already-loaded model. Only when neither is cached is the
    lightweight ``T5Config`` fetched (and cached), so asking for the dimension never
    forces a full model load.
    """
    entry = T5_CONFIGS.get(name, {})
    if "config" in entry:
        config = entry["config"]
    elif "model" in entry:
        config = entry["model"].config
    else:
        # Covers a brand-new name AND an entry left without a model, e.g. when
        # get_model raised inside get_model_and_tokenizer after the empty dict was
        # created. Previously that case hit `assert False` (or, under `python -O`,
        # an UnboundLocalError on `config`).
        config = T5Config.from_pretrained(name)
        T5_CONFIGS.setdefault(name, {})["config"] = config
    return config.d_model

In [17]:
# encoding text

def t5_encode_text(texts, name = DEFAULT_T5_NAME):
    """Encode a batch of strings with a cached T5 encoder.

    Args:
        texts: list of strings to encode (a single string is treated by the
            tokenizer as a batch of one).
        name: Hugging Face model id. The model/tokenizer pair comes from
            ``get_model_and_tokenizer``, so it is loaded once and reused.

    Returns:
        ``(encoded_text, mask)``. ``encoded_text`` is the encoder's
        ``last_hidden_state`` for the batch, padded to its longest sequence
        (truncated at ``MAX_LENGTH`` tokens). ``mask`` is the tokenizer's
        attention mask as a bool tensor. Both are on the model's device.
    """
    t5, tokenizer = get_model_and_tokenizer(name)

    # nn.Module.cuda() moves the (cached) model in place, so later calls are cheap no-ops.
    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    # Calling the tokenizer directly replaces the deprecated batch_encode_plus.
    encoded = tokenizer(
        texts,
        return_tensors = "pt",
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True,
    )

    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)

    t5.eval()

    with torch.no_grad():
        output = t5(input_ids = input_ids, attention_mask = attn_mask)
        # Computed under no_grad, so no autograd graph is attached; .detach() is unnecessary.
        encoded_text = output.last_hidden_state

    return encoded_text, attn_mask.bool()

In [30]:
class TextEncoderT5Based:
    """T5 text encoder that owns its own model/tokenizer pair.

    Unlike ``t5_encode_text``, which shares models through the ``T5_CONFIGS``
    cache, every instance loads its own copy of the weights onto ``device``.

    Args:
        name: Hugging Face model id of the T5 checkpoint.
        device: device for the encoder weights and the returned tensors.
    """

    def __init__(self, name = 'google/t5-v1_1-small', device='cpu'):
        self.device    = device
        self.model     = T5EncoderModel.from_pretrained(name).to(device)
        self.tokenizer = T5Tokenizer.from_pretrained(name)

    def textEncoder(self, texts, max_length=None):
        """Encode a batch of strings.

        Args:
            texts: list of strings to encode.
            max_length: truncation length in tokens. Resolved to the notebook-wide
                ``MAX_LENGTH`` at call time when None.

        Returns:
            ``(encoded_text, mask)``: the encoder's ``last_hidden_state`` for the
            batch (padded to its longest sequence) and the attention mask as a
            bool tensor, both on ``self.device``.
        """
        if max_length is None:
            max_length = MAX_LENGTH

        # Calling the tokenizer directly replaces the deprecated batch_encode_plus.
        text_encoded = self.tokenizer(
            texts,
            return_tensors = "pt",
            padding = 'longest',
            max_length = max_length,
            truncation = True,
        )

        text_ids = text_encoded.input_ids.to(self.device)
        mask     = text_encoded.attention_mask.to(self.device)

        self.model.eval()

        # Keyword arguments so the call does not depend on forward()'s positional order.
        # Under no_grad the output carries no autograd graph, so no .detach() is needed.
        with torch.no_grad():
            encoded_text = self.model(input_ids = text_ids, attention_mask = mask).last_hidden_state

        return encoded_text, mask.bool()

In [31]:
# Standalone encoder instance (defaults spelled out: small checkpoint on CPU).
T5 = TextEncoderT5Based(name='google/t5-v1_1-small', device='cpu')

Some weights of the model checkpoint at google/t5-v1_1-small were not used when initializing T5EncoderModel: ['decoder.embed_tokens.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.0.SelfAttention.v.weight', 'decoder.block.0.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'decoder.block.0.layer.0.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.1.EncDecAttention.k.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.0.layer.1.EncDecAttention.o.weight', 'decoder.block.0.layer.1.layer_norm.weight', 'decoder.block.0.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.0.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.0.layer.2.DenseReluDense.wo.weight', 'decoder.block.0.layer.2.layer_norm.weight', 'decoder.block.1.layer.0.SelfAttention.q.weight', 'decoder.block.1.layer.0.SelfAtten

In [32]:
# Returns a (last_hidden_state, bool attention mask) tuple.
t5 = T5.textEncoder(texts=['I', 'you', 'yes my'])

In [33]:
# Display the (last_hidden_state, bool attention mask) tuple returned by T5.textEncoder.
t5

(tensor([[[ 7.7179e-02, -9.5068e-03,  4.9421e-04,  ..., -1.7528e-01,
           -2.7930e-03,  7.0717e-01],
          [-3.6950e-01,  8.6838e-02, -1.3701e-01,  ..., -3.6052e-01,
           -1.9034e-01,  8.7760e-02],
          [ 8.3804e-02, -2.0349e-02,  5.8009e-03,  ..., -1.3579e-01,
           -1.9433e-02,  5.5632e-01]],
 
         [[ 7.4668e-02, -2.6701e-02, -3.7860e-03,  ..., -1.2131e-01,
           -2.6960e-02,  5.9868e-01],
          [-3.0119e-01, -4.8806e-01,  2.6033e-01,  ...,  1.0381e-01,
           -7.0362e-01,  9.3447e-01],
          [ 2.4214e-03, -1.6723e-01,  1.2286e-01,  ..., -3.0876e-01,
           -6.1628e-01,  6.4991e-01]],
 
         [[ 8.1621e-02, -2.6020e-02, -4.7646e-03,  ..., -1.2855e-01,
           -2.7603e-02,  6.2340e-01],
          [ 4.4333e-01,  2.6897e-01, -2.7354e-01,  ..., -3.6821e-01,
           -9.6742e-02, -1.9053e-01],
          [-1.2980e-01, -4.0838e-01,  1.0294e-01,  ..., -1.4140e-01,
           -4.8038e-01, -3.9026e-01]]]),
 tensor([[ True,  True, Fals

In [24]:
# Same texts and checkpoint through the cached functional API, for comparison with the class-based result.
t = t5_encode_text(texts=['I', 'you', 'yes my'], name='google/t5-v1_1-small')
t

cpu


(tensor([[[ 7.7179e-02, -9.5068e-03,  4.9421e-04,  ..., -1.7528e-01,
           -2.7930e-03,  7.0717e-01],
          [-3.6950e-01,  8.6838e-02, -1.3701e-01,  ..., -3.6052e-01,
           -1.9034e-01,  8.7760e-02],
          [ 8.3804e-02, -2.0349e-02,  5.8009e-03,  ..., -1.3579e-01,
           -1.9433e-02,  5.5632e-01]],
 
         [[ 7.4668e-02, -2.6701e-02, -3.7860e-03,  ..., -1.2131e-01,
           -2.6960e-02,  5.9868e-01],
          [-3.0119e-01, -4.8806e-01,  2.6033e-01,  ...,  1.0381e-01,
           -7.0362e-01,  9.3447e-01],
          [ 2.4214e-03, -1.6723e-01,  1.2286e-01,  ..., -3.0876e-01,
           -6.1628e-01,  6.4991e-01]],
 
         [[ 8.1621e-02, -2.6020e-02, -4.7646e-03,  ..., -1.2855e-01,
           -2.7603e-02,  6.2340e-01],
          [ 4.4333e-01,  2.6897e-01, -2.7354e-01,  ..., -3.6821e-01,
           -9.6742e-02, -1.9053e-01],
          [-1.2980e-01, -4.0838e-01,  1.0294e-01,  ..., -1.4140e-01,
           -4.8038e-01, -3.9026e-01]]]),
 tensor([[ True,  True, Fals

In [23]:
embeddings, mask = t
# last_hidden_state is (batch, seq, d_model); the attention mask must cover the same (batch, seq).
assert embeddings.shape[:2] == mask.shape, (embeddings.shape, mask.shape)
embeddings.shape

torch.Size([2, 3, 512])