### Transfer Learning on a Pre-Trained Text Embedding Transformer

In [1]:
from transformers import AutoTokenizer, AutoModel
import torch
from torch.utils.data import Dataset, DataLoader
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

In [2]:
# gte-base is a fairly small encoder (ranked #22 on the MTEB leaderboard when this notebook was written)
MODEL_NAME = "thenlper/gte-base"
tokenizer, model = AutoTokenizer.from_pretrained(MODEL_NAME), AutoModel.from_pretrained(MODEL_NAME)

In [3]:
class CustomTextEmbeddingModel(torch.nn.Module):
    """Pre-trained text encoder followed by a linear projection head.

    Token states from ``original_model`` are mean-pooled over positions where
    ``attention_mask`` is non-zero (the pooling recommended on the gte-base model
    card) and then projected to ``output_dim``.

    Args:
        original_model: Encoder whose forward returns an object with a
            ``last_hidden_state`` attribute (e.g. a Hugging Face ``AutoModel``).
        output_dim: Size of the projected embedding.
        input_dim: Width of the encoder's hidden states. When omitted it is read from
            ``original_model.config.hidden_size``, falling back to 768 (gte-base).
    """

    def __init__(self, original_model, output_dim, input_dim=None):
        super().__init__()
        self.original_model = original_model
        if input_dim is None:
            config = getattr(original_model, "config", None)
            input_dim = getattr(config, "hidden_size", 768)
        self.projection = torch.nn.Linear(input_dim, output_dim)

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids``, pool over real tokens, and project.

        When ``attention_mask`` is None every position is treated as a real token
        (all-ones mask), so the ``None`` default no longer crashes the pooling step.
        Assumes ``input_ids`` is (batch, seq) -- TODO confirm for other callers.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        outputs = self.original_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self._average_pool(outputs.last_hidden_state, attention_mask)
        return self.projection(pooled_output)

    # Adapted from https://huggingface.co/thenlper/gte-base
    def _average_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        """Mean of ``last_hidden_states`` over dim 1, counting only unmasked positions."""
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        # clamp keeps a fully-masked row at 0 instead of 0/0 = NaN
        token_counts = attention_mask.sum(dim=1)[..., None].clamp(min=1)
        return last_hidden.sum(dim=1) / token_counts

In [4]:
import json
import os
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm  # for progress bars
import concurrent.futures

MODEL_NAME = "thenlper/gte-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# eval() turns off dropout for inference; eval() and to() both return the module,
# so loading, switching to eval mode and pinning to CPU chain into one expression.
model = AutoModel.from_pretrained(MODEL_NAME).eval().to('cpu')

class CustomTextEmbeddingModel(torch.nn.Module):
    """Pre-trained text encoder plus a linear projection head (700 dims by default).

    This cell redefines the class from the earlier cell. Pooling is a mean over
    non-padding positions only (as ``attention_mask`` marks them). As a result,
    captions padded with ``padding='max_length'`` are not diluted by pad-token states.

    Args:
        original_model: Encoder whose forward returns an object with a
            ``last_hidden_state`` attribute.
        output_dim: Size of the projected embedding.
        input_dim: Width of the encoder's hidden states. When omitted it is read from
            ``original_model.config.hidden_size``, falling back to 768 (gte-base).
    """

    def __init__(self, original_model, output_dim=700, input_dim=None):
        super().__init__()
        self.original_model = original_model
        if input_dim is None:
            input_dim = getattr(getattr(original_model, "config", None), "hidden_size", 768)
        self.projection = torch.nn.Linear(input_dim, output_dim)

    def forward(self, input_ids, attention_mask=None):
        """Encode, mean-pool over real tokens, and project to ``output_dim``."""
        outputs = self.original_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_states = outputs.last_hidden_state
        if attention_mask is None:
            # No mask supplied: every position counts, i.e. a plain mean over dim 1.
            pooled_output = last_hidden_states.mean(dim=1)
        else:
            mask = attention_mask[..., None].to(last_hidden_states.dtype)
            # clamp keeps a fully-masked row at 0 instead of 0/0 = NaN
            pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        return self.projection(pooled_output)

# Initialize the custom text embedding model (700-dim projection head).
# The projection layer is randomly initialised and is neither trained nor saved here.
# Seeding first makes the embeddings written to JSON reproducible across runs.
torch.manual_seed(0)
test_model = CustomTextEmbeddingModel(model, 700).to('cpu')

def process_caption(file_path, text_tokenizer=None, embed_model=None):
    """Embed the first caption line of a single caption file.

    Args:
        file_path: Path to a caption file. The image id is the part of the file name
            before the first underscore (assumes "<image_id>_captions.txt").
        text_tokenizer: Tokenizer to use; defaults to the notebook-level ``tokenizer``.
        embed_model: Model to use; defaults to the notebook-level ``test_model``.

    Returns:
        ``(image_id, embedding)``, where ``embedding`` is the model output converted
        with ``tolist()`` (so it keeps the leading batch dimension).
    """
    text_tokenizer = tokenizer if text_tokenizer is None else text_tokenizer
    embed_model = test_model if embed_model is None else embed_model

    image_id = os.path.basename(file_path).split('_')[0]
    # Captions can contain non-ASCII text; don't depend on the platform default encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        first_caption = file.readline().strip()  # only the first caption is used

    # A single caption needs no padding. Padding to 512 tokens only adds compute,
    # and with an unmasked mean pool it would dilute the embedding with pad states.
    inputs = text_tokenizer(first_caption, return_tensors="pt", truncation=True, max_length=512)

    with torch.no_grad():
        embeddings = embed_model(inputs.input_ids.to('cpu'), inputs.attention_mask.to('cpu')).detach().numpy()

    return image_id, embeddings.tolist()

def generate_text_embeddings(base_dir, max_workers=None):
    """Embed every ``.txt`` caption file in ``base_dir`` using a thread pool.

    Args:
        base_dir: Directory containing caption files.
        max_workers: Thread count for the pool; ``None`` uses the
            ``ThreadPoolExecutor`` default.

    Returns:
        Dict mapping image id to embedding (as returned by ``process_caption``),
        filled in sorted file-name order. If two files share an image id, the one
        later in sorted order wins.
    """
    caption_paths = [
        os.path.join(base_dir, name)
        for name in sorted(os.listdir(base_dir))
        if name.endswith('.txt')
    ]

    embeddings = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map yields results in submission order. Unlike as_completed, this
        # makes the dict (and the JSON written from it) deterministic.
        results = executor.map(process_caption, caption_paths)
        for image_id, embedding in tqdm(results, total=len(caption_paths)):
            embeddings[image_id] = embedding

    return embeddings

# Directory holding the caption .txt files (adjust this path as necessary).
base_dir = "../data/flickr30k_images/"

embeddings = generate_text_embeddings(base_dir)

# Persist the {image_id: embedding} mapping so it can be reused without recomputing.
embeddings_file = "text_embeddings.json"
with open(embeddings_file, 'w') as f:
    json.dump(embeddings, f)

print(f"Embeddings saved to {embeddings_file}.")


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
  0%|          | 83/30000 [00:10<44:30, 11.20it/s]  

In [9]:
# Sanity check that the wrapper runs end to end. The projection head is untrained,
# so the 10-dim outputs are meaningless and only their shape matters.
# Distinct names keep this cell from overwriting the 700-dim `test_model` used by
# `process_caption`, or the `embeddings` dict built by `generate_text_embeddings`.
sanity_model = CustomTextEmbeddingModel(model, 10)

# Tokenize test inputs
test_inputs = ["This is test sentence 1", "This is test sentence 2"]
batch_dict = tokenizer(test_inputs, max_length=512, padding=True, truncation=True, return_tensors='pt')

with torch.no_grad():
    sanity_embeddings = sanity_model(batch_dict['input_ids'], batch_dict['attention_mask'])

assert sanity_embeddings.shape == (len(test_inputs), 10)
print(sanity_embeddings)


tensor([[-0.2934,  0.1638,  0.3081,  0.6756,  0.0679,  0.6542, -0.0653,  0.1022,
         -0.3658, -0.3653],
        [-0.1512,  0.1145,  0.2403,  0.6169, -0.0582,  0.5726, -0.0173,  0.0687,
         -0.3651, -0.3757]])


#### Freeze Pre-Trained Parameters

In [5]:
def freeze_pretrained_weights(model: nn.Module):
    '''
    Stop gradient updates to the pre-trained encoder of a CustomTextEmbeddingModel.

    Only ``model.original_model`` is frozen. Other submodules (e.g. the projection
    head) keep their current ``requires_grad`` setting. Returns None.
    '''
    # Module.requires_grad_ sets requires_grad on every parameter of the submodule.
    model.original_model.requires_grad_(False)

#### Set Up a Dataset Class

In [6]:
class TextDataset(Dataset):
    """Pairs captions with target vectors (image embeddings) for training.

    Args:
        texts: Sequence of caption strings.
        labels: Sequence of target vectors, one per caption (lists or tensors).
        tokenizer: Hugging Face-style tokenizer callable.
        max_length: Length every caption is padded/truncated to.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        if len(texts) != len(labels):
            raise ValueError(f"got {len(texts)} texts but {len(labels)} labels")
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(encoded_text, label)`` for one example.

        Tokenizing a single string with ``return_tensors='pt'`` yields tensors with a
        leading dimension of 1. That dimension is squeezed away so the DataLoader
        collates ``input_ids`` to ``(batch, max_length)``, not ``(batch, 1, max_length)``.
        """
        encoded_text = self.tokenizer(
            self.texts[idx], return_tensors='pt', padding='max_length',
            truncation=True, max_length=self.max_length,
        )
        for key in list(encoded_text.keys()):
            encoded_text[key] = encoded_text[key].squeeze(0)
        # Labels are embedding vectors; keep them floating point (a long cast truncates them).
        label = torch.as_tensor(self.labels[idx], dtype=torch.float32)
        return encoded_text, label


In [7]:
# Example usage of the TextDataset class with dummy data.

# texts: image captions
texts = [f"Caption for image {i}" for i in range(1, 4)]
# labels: the targets are image embeddings from the image embedding model
# (dummy 10-element vectors here, one independent tensor per caption)
labels = [Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]) for _ in range(3)]
# tokenizer: the one belonging to the base pre-trained text embedding model, gte-base
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

dataset = TextDataset(texts, labels, tokenizer)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

## Extracting Embeddings from a Pre-Trained ViT

In [9]:
# Possible models
# oschamp/vit-artworkclassifier        probably not this one... its finetuning seems to be somewhat poor
# google/vit-base-patch16-224

In [22]:
# Load the google ViT model
from transformers import ViTImageProcessor, ViTForImageClassification

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# convert("RGB") normalises grayscale/RGBA files to the 3 channels the ViT
# preprocessing expects.
test_image = Image.open("../data/test_art.jpg").convert("RGB")

inputs = processor(images=test_image, return_tensors="pt")
# Inference only: no_grad avoids building an autograd graph.
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

Predicted class: book jacket, dust cover, dust jacket, dust wrapper


In [25]:
# Extract the embeddings from the model by replacing the final classification layer
# with an identity, so `.logits` now returns the features the classifier consumed.
# Note: this mutates `model` in place.
model.classifier = nn.Identity()

# no_grad: without it the embedding requires grad, and calling .numpy() on it would raise.
with torch.no_grad():
    embedding = model(**inputs).logits

print(embedding.shape)

torch.Size([3, 768])


In [26]:
# Example of how to create embedding labels for our image dataset: load a fresh ViT
# and replace its classification head with an identity so `.logits` holds the features.
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.classifier = nn.Identity()
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')


images = [test_image, test_image, test_image]
inputs = processor(images=images, return_tensors="pt")
# These are fixed training targets, so compute them without an autograd graph
# (that also lets them be converted with .numpy() directly).
with torch.no_grad():
    embeddings = model(**inputs).logits
print(embeddings.shape)

torch.Size([3, 768])
