### Transfer Learning on a Pre-Trained Text Embedding Transformer

In [15]:
from transformers import AutoTokenizer, AutoModel
import torch
from torch.utils.data import Dataset, DataLoader
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [58]:
# gte-base is a fairly small sentence-embedding model (it was #22 on the MTEB
# leaderboard when this was written; leaderboard positions change over time).
MODEL_NAME = "thenlper/gte-base"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)
base_text_model = AutoModel.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)

In [59]:
class CustomTextEmbeddingModel(torch.nn.Module):
    '''
    Wrap a pretrained transformer encoder as a sentence-embedding model.

    The encoder's last hidden states are mean-pooled over the non-padding tokens
    (the pooling recipe published with thenlper/gte-base) and then, optionally,
    linearly projected to ``output_dim``.

    Args:
        original_model: pretrained encoder. It is called with ``input_ids`` and
            ``attention_mask`` and must return an object with a
            ``last_hidden_state`` attribute.
        output_dim: size of the returned embeddings. When it equals ``hidden_dim``
            no projection is applied (identity); otherwise a trainable
            ``Linear(hidden_dim, output_dim)`` is appended.
        hidden_dim: width of ``original_model``'s hidden states (768 for gte-base).
    '''
    def __init__(self, original_model, output_dim, hidden_dim: int = 768):
        super(CustomTextEmbeddingModel, self).__init__()
        self.original_model = original_model
        if output_dim == hidden_dim:
            # Identity has no parameters, so there is nothing to freeze here.
            self.projection = torch.nn.Identity()
        else:
            # Extra layer on the end to project the pooled embedding to the output dim.
            self.projection = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, input_ids, attention_mask=None):
        '''
        Embed a batch of token-id sequences.

        Args:
            input_ids: token ids, indexed as (batch, seq).
            attention_mask: 1 for real tokens and 0 for padding, same shape as
                ``input_ids``. When omitted, every position is treated as a real token.

        Returns:
            Tensor of shape (batch, output_dim).
        '''
        if attention_mask is None:
            # The pooling below needs a concrete mask; "no mask" means "attend to everything".
            attention_mask = torch.ones_like(input_ids)
        outputs = self.original_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self._average_pool(outputs.last_hidden_state, attention_mask)
        # Project to the new output dim (identity when no projection is needed).
        return self.projection(pooled_output)

    # Adapted from https://huggingface.co/thenlper/gte-base
    def _average_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        '''Average ``last_hidden_states`` over dim 1, counting only positions whose mask is non-zero.'''
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        # Clamp the token count so an all-padding row yields zeros instead of NaN (0 / 0).
        token_counts = attention_mask.sum(dim=1)[..., None].clamp(min=1)
        return last_hidden.sum(dim=1) / token_counts

In [60]:
# Sanity check: push two short sentences through the wrapper and inspect the result.
text_model = CustomTextEmbeddingModel(base_text_model, output_dim=768)

# Tokenize the sample sentences into one padded batch.
test_inputs = ["This is test sentence 1", "This is test sentence 2"]
batch_dict = tokenizer(
    test_inputs,
    max_length=512,
    padding=True,
    truncation=True,
    return_tensors='pt',
)

# Inference only, so skip building the autograd graph.
with torch.no_grad():
    embeddings = text_model(
        input_ids=batch_dict['input_ids'],
        attention_mask=batch_dict['attention_mask'],
    )

print(embeddings)
print(embeddings.shape)

tensor([[ 0.1915, -0.0581,  0.1463,  ...,  0.2232,  0.2082, -0.5509],
        [ 0.0026, -0.1361,  0.1179,  ..., -0.0378,  0.2154, -0.3237]])
torch.Size([2, 768])


#### Freeze Pre-Trained Parameters

In [65]:
def freeze_pretrained_weights(model: nn.Module, layer_indices=(10, 11), unfreeze_pooler=True):
    '''
    Freeze the pretrained backbone of a CustomTextEmbeddingModel except for a few top modules.

    Every parameter of ``model.original_model`` is frozen. Gradients are then turned back
    on for the ``.output`` sub-module of each encoder layer in ``layer_indices`` and,
    optionally, for the pooler. Parameters outside ``original_model`` (e.g. the projection
    head) are left untouched.

    Args:
        model: module exposing ``original_model.encoder.layer`` (indexable) and,
            optionally, ``original_model.pooler``.
        layer_indices: encoder layers whose ``.output`` block stays trainable. The default
            (10, 11) is presumably the last two layers of a 12-layer encoder such as
            gte-base; verify for other backbones.
        unfreeze_pooler: whether to make the pooler trainable too (skipped if the backbone
            has no pooler).
    '''
    # Freeze the whole backbone first.
    model.original_model.requires_grad_(False)
    # Turn gradients back on for the selected layers. This must go through
    # Module.requires_grad_(): assigning `module.requires_grad = True` only stores a plain
    # Python attribute on the Module and leaves all of its parameters frozen.
    for layer_idx in layer_indices:
        model.original_model.encoder.layer[layer_idx].output.requires_grad_(True)
    pooler = getattr(model.original_model, "pooler", None)
    if unfreeze_pooler and pooler is not None:
        # NOTE(review): CustomTextEmbeddingModel.forward mean-pools last_hidden_state and does
        # not appear to use the pooler output, so these weights likely never receive gradients.
        pooler.requires_grad_(True)


#### Set Up a Dataset Class

In [29]:
class TextDataset(Dataset):
    '''
    Map-style dataset pairing raw texts with per-text labels, tokenizing lazily.

    Args:
        texts: sequence of strings.
        labels: sequence of labels, one per text (e.g. image-embedding tensors).
        tokenizer: Hugging Face-style tokenizer; called on a single string with
            ``return_tensors='pt'`` and fixed-length padding.
        max_length: pad/truncate every text to this many tokens.
        label_dtype: optional dtype to cast labels to. By default the label's own dtype
            is kept, so float embeddings stay float and Python ints become int64.

    Raises:
        ValueError: if ``texts`` and ``labels`` differ in length.
    '''
    def __init__(self, texts, labels, tokenizer, max_length=512, label_dtype=None):
        if len(texts) != len(labels):
            raise ValueError(f"got {len(texts)} texts but {len(labels)} labels")
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_dtype = label_dtype

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        '''Return ``(encoding, label)`` where each encoding tensor has shape (max_length,).'''
        encoded_text = self.tokenizer(
            self.texts[idx],
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )
        # For a single string the tokenizer returns (1, max_length) tensors; drop that axis
        # so the DataLoader's default collate produces (batch, max_length), not (batch, 1, max_length).
        encoded_text = {key: value.squeeze(0) for key, value in encoded_text.items()}
        # Don't force an integer dtype: labels here are float embeddings and casting them
        # to long would truncate their values.
        label = torch.as_tensor(self.labels[idx], dtype=self.label_dtype)
        return encoded_text, label


In [7]:
# Toy example wiring captions, embedding labels and the tokenizer into TextDataset.
# The texts stand in for image captions.
texts = ["Caption for image 1", "Caption for image 2", "Caption for image 3"]
# In the real pipeline each label is the embedding of the matching image from the image
# embedding model; these 10-element tensors are placeholders (one independent tensor per caption).
labels = [Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]) for _ in range(len(texts))]
# Tokenizer for the base pretrained text embedding model (gte-base).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

dataset = TextDataset(texts=texts, labels=labels, tokenizer=tokenizer)
dataloader = DataLoader(dataset, shuffle=True, batch_size=2)

## Extracting Embeddings from Pre-Trained ViT

In [8]:
# Possible models
# oschamp/vit-artworkclassifier        probably not this one; its fine-tuning seems somewhat poor
# google/vit-base-patch16-224

In [30]:
# Load Google's ViT-Base (patch 16, 224px) classifier together with its preprocessor.
from transformers import ViTImageProcessor, ViTForImageClassification

VIT_CHECKPOINT = 'google/vit-base-patch16-224'
processor = ViTImageProcessor.from_pretrained(VIT_CHECKPOINT)
image_model = ViTForImageClassification.from_pretrained(VIT_CHECKPOINT)

test_image = Image.open("../data/test_art.jpg")

inputs = processor(images=test_image, return_tensors="pt")
outputs = image_model(**inputs)
logits = outputs.logits
# The checkpoint's head scores the 1000 ImageNet classes; report the highest-scoring one.
predicted_class_idx = logits.argmax(-1).item()
predicted_label = image_model.config.id2label[predicted_class_idx]
print("Predicted class:", predicted_label)

Predicted class: book jacket, dust cover, dust jacket, dust wrapper


In [31]:
# Replace the classification head with an identity, so the returned "logits" are the
# features that were fed into the head, i.e. an image embedding.
image_model.classifier = nn.Identity()

vit_output = image_model(**inputs)
embedding = vit_output.logits
print(embedding.shape)

torch.Size([1, 768])


In [35]:
# Recipe for producing embedding "labels" for a batch of images: a fresh ViT whose
# classification head is replaced by an identity, plus the matching preprocessor.
image_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
image_model.classifier = nn.Identity()
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

images = [test_image]
inputs = processor(images=images, return_tensors="pt")
model_output = image_model(**inputs)
embeddings = model_output.logits
print(embeddings.shape)

torch.Size([1, 768])


In [40]:
# Compare a hand-written caption against its image in the shared embedding space.
old_man_caption = "A older man with a white shirt, golf cap and a cane walks away from an outdoor flower shop."
old_man_image = Image.open("../data/old_man.jpg")
# Resize the image to 224x224
old_man_image = old_man_image.resize((224, 224))

# Embed the image with the frozen ViT on whatever device it currently lives on, then bring
# the result back to the CPU so this cell works whether it runs before or after training.
image_device = next(image_model.parameters()).device
images = [old_man_image]
inputs = processor(images=images, return_tensors="pt").to(image_device)
with torch.no_grad():
    old_man_embeddings = image_model(**inputs).logits.cpu()
print(old_man_embeddings.shape)

# Load the fine-tuned text weights. map_location="cpu" lets a checkpoint that was saved from
# a CUDA model be read on a CPU-only machine; load_state_dict then copies the values onto
# whatever device text_model's parameters are on.
# NOTE(review): machine-specific path; point this at your own copy of the checkpoint.
FINETUNED_TEXT_MODEL_PATH = "C:/Users/nickj/Downloads/finetuned_text_model.pth"
text_model.load_state_dict(torch.load(FINETUNED_TEXT_MODEL_PATH, map_location="cpu"))
text_model.eval()
text_device = next(text_model.parameters()).device

# Tokenize the caption and pass it through the fine-tuned text model.
batch_dict = tokenizer(old_man_caption, max_length=512, padding=True, truncation=True, return_tensors='pt').to(text_device)
with torch.no_grad():
    text_old_man_embedding = text_model(batch_dict['input_ids'], batch_dict['attention_mask']).cpu()

print(text_old_man_embedding.shape)

torch.Size([1, 768])
torch.Size([1, 768])


In [44]:
# Write a function to calculate cosine similarity
def cosine_similarity(embedding1, embedding2, dim=1):
    '''
    Calculate the cosine similarity between two embeddings.

    Args:
        embedding1, embedding2: tensors of matching (or broadcastable) shape.
        dim: axis that holds the embedding features. The default of 1 compares rows of
            (batch, features) tensors; pass 0 or -1 to compare two 1-D vectors.

    Returns:
        Tensor of cosine similarities with ``dim`` reduced away.
    '''
    return F.cosine_similarity(embedding1, embedding2, dim=dim)

# Score the caption's text embedding against the image embedding (1.0 = same direction).
similarity = cosine_similarity(embedding1=old_man_embeddings, embedding2=text_old_man_embedding)
print(similarity.item())



0.04225935786962509


### Fine-Tune the Final Layers of the Text Embedding Model

In [12]:
from datasets import load_dataset
from torchvision import transforms

In [67]:
class Flickr30kDataset(torch.utils.data.Dataset):
    '''
    Flickr30k image/caption pairs exposed as a flat map-style dataset.

    Each image contributes ``cap_per_image`` consecutive items, one per caption, so item
    ``idx`` is caption ``idx % cap_per_image`` of image ``idx // cap_per_image``.

    Args:
        split: dataset split to read (defaults to "test").
        cap_per_image: number of captions used per image.
        cache_dir: cache directory passed to ``load_dataset``.

    Items are dicts ``{"image": tensor, "caption": str}``. Images are resized to 224x224
    and converted with ``transforms.ToTensor()``, which already scales pixel values to
    [0, 1]; downstream preprocessors must not rescale them a second time.
    '''
    def __init__(self, split="test", cap_per_image=2, cache_dir="./huggingface_data"):
        self.dataset = load_dataset(path="nlphuji/flickr30k", cache_dir=cache_dir)
        self.split = split
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.cap_per_image = cap_per_image

    def __len__(self):
        return len(self.dataset[self.split]) * self.cap_per_image

    def __getitem__(self, idx):
        image_idx, caption_idx = divmod(idx, self.cap_per_image)
        # Fetch the row once. Indexing a Hugging Face dataset returns the full row with the
        # image decoded, so indexing it separately for image and caption decoded the image twice.
        row = self.dataset[self.split][image_idx]
        image = self.transform(row["image"].convert("RGB"))
        return {"image": image, "caption": row["caption"][caption_idx]}

# Instantiate the Flickr30k image/caption dataset (load_dataset caches it under ./huggingface_data).
flickr30k_custom_dataset = Flickr30kDataset()

In [68]:
# Fine-tune the unfrozen top of the text embedding model so that each caption's embedding
# points in the same direction as the (frozen) ViT embedding of its image.
text_model = text_model.to(device)
freeze_pretrained_weights(text_model)
image_model = image_model.to(device)
# The ViT only supplies fixed targets; keep it in inference mode.
image_model.eval()
# NOTE(review): text_model.train() is never called, so dropout inside the backbone stays in
# whatever mode it was loaded in — confirm that is intended.

# Define the optimizer (over trainable parameters only), loss, and data loader.
trainable_params = [p for p in text_model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-3, weight_decay=1e-5)
loss_fn = nn.CosineEmbeddingLoss()
dataloader = DataLoader(flickr30k_custom_dataset, batch_size=4, shuffle=True)

training_losses = []

# Training loop
c = 0
for epoch in range(1):
    for batch in dataloader:
        captions = batch["caption"]

        # The dataset already resized the images and scaled them to [0, 1] with ToTensor, so
        # tell the ViT processor not to divide by 255 again (it still applies its own
        # normalization). Preprocess on the CPU, then move the result to the device.
        image_inputs = processor(images=batch["image"], return_tensors="pt", do_rescale=False).to(device)
        # Image embeddings are fixed targets: don't build a graph through the ViT.
        with torch.no_grad():
            image_embeddings = image_model(**image_inputs).logits

        # Pad only to the longest caption in the batch rather than to 512 tokens. Padded
        # positions are masked out of attention and of the mean pooling, so the embeddings are
        # equivalent (up to floating-point noise) but far cheaper to compute.
        text_inputs = tokenizer(captions, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
        text_embeddings = text_model(text_inputs['input_ids'], text_inputs['attention_mask'])

        # Every (caption, image) pair is a positive match: one +1 target per example.
        target = torch.ones(text_embeddings.size(0), device=device)
        loss = loss_fn(text_embeddings, image_embeddings, target)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        training_losses.append(loss.item())

        c += 1
        if c % 100 == 0:
            print(f"Batch {c}, Loss: {loss.item()}")
        if c == 600:
            torch.save(text_model.state_dict(), "finetuned_text_model.pth")

    print(f"Epoch: {epoch}, Loss: {loss.item()}")

       


In [66]:
# Dump every sub-module of the text model with its class, e.g. to choose layers to unfreeze.
for name, module in text_model.named_modules():
    module_type = type(module)
    print(f"{name}: {module_type}")

: <class '__main__.CustomTextEmbeddingModel'>
original_model: <class 'transformers.models.bert.modeling_bert.BertModel'>
original_model.embeddings: <class 'transformers.models.bert.modeling_bert.BertEmbeddings'>
original_model.embeddings.word_embeddings: <class 'torch.nn.modules.sparse.Embedding'>
original_model.embeddings.position_embeddings: <class 'torch.nn.modules.sparse.Embedding'>
original_model.embeddings.token_type_embeddings: <class 'torch.nn.modules.sparse.Embedding'>
original_model.embeddings.LayerNorm: <class 'torch.nn.modules.normalization.LayerNorm'>
original_model.embeddings.dropout: <class 'torch.nn.modules.dropout.Dropout'>
original_model.encoder: <class 'transformers.models.bert.modeling_bert.BertEncoder'>
original_model.encoder.layer: <class 'torch.nn.modules.container.ModuleList'>
original_model.encoder.layer.0: <class 'transformers.models.bert.modeling_bert.BertLayer'>
original_model.encoder.layer.0.attention: <class 'transformers.models.bert.modeling_bert.BertAtte