### Transfer Learning on a Pre-Trained Text Embedding Transformer

In [18]:
from transformers import AutoTokenizer, AutoModel
import torch
from torch.utils.data import Dataset, DataLoader
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

In [14]:
# Pretrained text-embedding backbone. gte-base is a comparatively small model
# (it was ranked #22 on the MTEB leaderboard when this notebook was written).
MODEL_NAME = "thenlper/gte-base"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)  # tokenizer matching the backbone
model = AutoModel.from_pretrained(MODEL_NAME)          # pretrained encoder weights

In [15]:
class CustomTextEmbeddingModel(torch.nn.Module):
    """Wrap a pretrained transformer encoder and add a trainable linear projection head.

    The encoder's last hidden states are mean-pooled over the non-padding
    tokens (pooling function taken from the gte-base model card), then
    projected to ``output_dim``.

    Args:
        original_model: Encoder whose forward accepts ``input_ids`` and
            ``attention_mask`` keywords and returns an object with a
            ``last_hidden_state`` attribute.
        output_dim: Size of the projected embedding.
        input_dim: Hidden size of ``original_model``. When omitted it is read
            from ``original_model.config.hidden_size`` if present, otherwise
            it falls back to 768 (the embedding size of gte-base).
    """

    # Embedding size of the original gte-base model; used only when the
    # wrapped model exposes no ``config.hidden_size``.
    _DEFAULT_HIDDEN_SIZE = 768

    def __init__(self, original_model, output_dim, input_dim=None):
        super().__init__()
        self.original_model = original_model
        if input_dim is None:
            config = getattr(original_model, "config", None)
            input_dim = getattr(config, "hidden_size", None) or self._DEFAULT_HIDDEN_SIZE
        # Extra layer on the end that projects the pooled embedding to output_dim.
        self.projection = torch.nn.Linear(input_dim, output_dim)

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids``, mean-pool the token states and project them.

        Args:
            input_ids: Token ids, presumably (batch, seq_len) — TODO confirm
                for non-HF encoders.
            attention_mask: 1 for real tokens, 0 for padding. If ``None``,
                every token is treated as real.

        Returns:
            The projected pooled embeddings (last dim is ``output_dim``).
        """
        outputs = self.original_model(input_ids=input_ids, attention_mask=attention_mask)
        # Pooling needs an explicit mask; with none given, all positions count.
        pool_mask = torch.ones_like(input_ids) if attention_mask is None else attention_mask
        pooled_output = self._average_pool(outputs.last_hidden_state, pool_mask)
        return self.projection(pooled_output)

    # This function is adapted from https://huggingface.co/thenlper/gte-base
    def _average_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        """Average ``last_hidden_states`` over dim 1, counting only unmasked positions.

        Assumes ``last_hidden_states`` is (batch, seq, hidden) and
        ``attention_mask`` is (batch, seq) — the mask is broadcast over the
        last dim via ``[..., None]``.
        """
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        # clamp avoids 0 / 0 = NaN for a row whose mask is entirely zero.
        token_counts = attention_mask.sum(dim=1)[..., None].clamp(min=1)
        return last_hidden.sum(dim=1) / token_counts

In [16]:
# Smoke test: run two sentences through the wrapper and check the output shape.
# The projection head is randomly initialised, so the values themselves are meaningless.
TEST_OUTPUT_DIM = 10
test_model = CustomTextEmbeddingModel(model, TEST_OUTPUT_DIM)
test_model.eval()  # eval mode: disables training-only layers such as dropout

# Tokenize test inputs
test_inputs = ["This is test sentence 1", "This is test sentence 2"]
batch_dict = tokenizer(test_inputs, max_length=512, padding=True, truncation=True, return_tensors='pt')

# Pass the tokenized inputs through the custom model
with torch.no_grad():
    embeddings = test_model(batch_dict['input_ids'], batch_dict['attention_mask'])

# One projected vector per input sentence.
assert embeddings.shape == (len(test_inputs), TEST_OUTPUT_DIM), embeddings.shape
print(embeddings)


tensor([[-0.2958, -0.4906,  0.2907, -0.6942, -0.1876,  0.2642, -0.3741, -0.1587,
          0.1651,  0.3889],
        [-0.2686, -0.2305,  0.3190, -0.6691, -0.1300,  0.2702, -0.2727, -0.2254,
          0.2705,  0.3965]])


#### Freeze Pre-Trained Parameters

In [17]:
def freeze_pretrained_weights(model: nn.Module):
    """Turn off gradient tracking for every parameter of ``model.original_model``.

    Intended for a CustomTextEmbeddingModel instance: only the wrapped
    pretrained encoder is frozen, while the other submodules (such as the
    projection head) keep their current ``requires_grad`` setting.
    """
    pretrained = model.original_model
    pretrained.requires_grad_(False)

#### Set Up a Dataset Class

In [19]:
class TextDataset(Dataset):
    """Map-style dataset pairing input texts with their target labels.

    Each text is tokenized on access (padded/truncated to ``max_length``) and
    returned together with its label as a tensor, so a default ``DataLoader``
    yields ``({"input_ids": (B, max_length), ...}, labels)`` batches.

    Args:
        texts: Sequence of input strings.
        labels: Sequence of targets, one per text (e.g. image-embedding
            tensors or lists of numbers).
        tokenizer: Hugging Face-style tokenizer callable.
        max_length: Pad/truncate length in tokens. Defaults to 512.
        label_dtype: dtype of the returned label tensor. Defaults to
            ``torch.float32`` because the targets are continuous embeddings;
            an integer dtype would truncate them.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512, label_dtype=torch.float32):
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_dtype = label_dtype

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded_text = self.tokenizer(
            self.texts[idx],
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )
        # return_tensors='pt' on a single string adds a leading batch dim of 1.
        # Drop it so the default collate produces (batch, max_length) rather
        # than (batch, 1, max_length), which encoders reject.
        for key in list(encoded_text.keys()):
            encoded_text[key] = encoded_text[key].squeeze(0)
        # as_tensor accepts tensors and plain sequences alike (torch.tensor on a
        # tensor emits a copy-construct warning).
        label = torch.as_tensor(self.labels[idx], dtype=self.label_dtype)
        return encoded_text, label


In [20]:
# Example usage of the TextDataset class.

# Image captions: the text inputs to the model.
texts = ["Caption for image 1", "Caption for image 2", "Caption for image 3"]
# Targets: the image embeddings produced by the image embedding model
# (placeholder 10-dim vectors here, one per caption).
labels = [Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]) for _ in texts]
# `tokenizer` is the gte-base tokenizer already loaded next to the pretrained
# model, so it is reused here instead of being fetched a second time.

dataset = TextDataset(texts, labels, tokenizer)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)