# Image Encoding with Custom Transformer Decoder - Training Prototype


## Imports

In [16]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPImageProcessor, CLIPVisionModelWithProjection
from datasets import load_dataset, DatasetDict
import torch.optim as optim

## Load and Prepare Data


In [2]:
# Load dataset
ds = load_dataset("nlphuji/flickr30k", split='test').select(range(10000))

In [3]:
# Load CLIP models and processors
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

# Move models to GPU
text_model.to('cuda')
vision_model.to('cuda')

CLIPVisionModelWithProjection(
  (vision_model): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
      (position_embedding): Embedding(257, 1024)
    )
    (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-23): 24 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=1024, out_featur

## Preprocessing Functions

In [4]:
def process_images(examples):
    images = examples["image"]
    processed_images = image_processor(images, return_tensors="pt")
    processed_images = {k: v.to('cuda') for k, v in processed_images.items()}
    
    with torch.no_grad():
        image_outputs = vision_model(**processed_images)
    
    return {"image_embed": image_outputs.last_hidden_state.cpu().numpy()}

def process_text(examples):
    # Flatten the list of lists of captions
    all_captions = [caption for image_captions in examples["caption"] for caption in image_captions]
    
    # Tokenize all captions at once
    tokenized_text = tokenizer(all_captions, padding=True, truncation=True, return_tensors="pt")
    tokenized_text = {k: v.to('cuda') for k, v in tokenized_text.items()}
    
    with torch.no_grad():
        text_outputs = text_model(**tokenized_text)
    
    # Get the text embeddings
    text_embeds = text_outputs.text_embeds.cpu().numpy()
    
    # Reshape the embeddings to match the original structure (5 captions per image)
    num_images = len(examples["caption"])
    reshaped_embeds = text_embeds.reshape(num_images, 5, -1)
    
    return {"text_embed": reshaped_embeds}


## Preprocess Data

In [5]:
# Process text
ds = ds.map(process_text, batched=True, batch_size=32, remove_columns=["caption"])

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [6]:
# Process images
ds = ds.map(process_images, batched=True, batch_size=32, remove_columns=["image"])

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [7]:
# Set the format for PyTorch
ds.set_format(type="torch")

# Split the dataset
dataset = ds.train_test_split(test_size=0.2, seed=42)
train_val_dataset = dataset['train'].train_test_split(test_size=0.2, seed=42)

final_dataset = DatasetDict({
    'train': train_val_dataset['train'],
    'validation': train_val_dataset['test'],
    'test': dataset['test']
})

# Create DataLoaders
train_loader = DataLoader(final_dataset['train'], batch_size=64, shuffle=True)
val_loader = DataLoader(final_dataset['validation'], batch_size=64, shuffle=False)
test_loader = DataLoader(final_dataset['test'], batch_size=64, shuffle=False)


## Model Definition

In [17]:
d_model = 128
vision_dim = 1024
clip_dim = 768
max_length = 6  # Maximum number of tokens to generate

# Define the custom transformer decoder
model = CustomImageEncoder(d_model, vision_dim, 8, 512, 0.05, 6, None, 8)

# Define the projector
projector = nn.Sequential(
    nn.Linear(d_model, clip_dim),
    nn.LayerNorm(clip_dim),
    nn.ReLU(),
    nn.Linear(clip_dim, clip_dim)
)

# Move models to GPU
model.to('cuda')
projector.to('cuda')

Sequential(
  (0): Linear(in_features=128, out_features=768, bias=True)
  (1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (2): ReLU()
  (3): Linear(in_features=768, out_features=768, bias=True)
)

In [18]:
print(model)

CustomImageEncoder(
  (decoder): CustomTransformerDecoder(
    (layers): ModuleList(
      (0-5): 6 x CustomTransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=512, bias=True)
        (dropout): Dropout(p=0.05, inplace=False)
        (linear2): Linear(in_features=512, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.05, inplace=False)
        (dropout2): Dropout(p=0.05, inplace=False)
        (dropout3): Dropout(p=0.05, inplace=False)


## Loss Function and Optimizer

In [20]:
from transformers import get_cosine_schedule_with_warmup
# Define hyperparameters
num_epochs = 10
warmup_steps = 1000  # Adjust as needed
total_steps = num_epochs * len(train_loader)

criterion = nn.CosineEmbeddingLoss()
optimizer = optim.AdamW(list(model.parameters()) + list(projector.parameters()), lr=5e-4)

# Create the learning rate scheduler
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

## Training Loop

In [21]:
import tqdm

In [22]:
print(CustomImageEncoder)

<class '__main__.CustomImageEncoder'>


In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.transformer import TransformerDecoderLayer

class CustomTransformerDecoderLayer(TransformerDecoderLayer):
    def __init__(self, d_model, encoded_img_dim, nhead, dim_feedforward=2048,
                 dropout=0.1, norm_first=False, batch_first=True):
        super(CustomTransformerDecoderLayer, self).__init__(d_model, nhead, dim_feedforward,
                                                            dropout, batch_first=batch_first,
                                                            norm_first=norm_first)
        # Override the multihead_attn layer to attend to encoded image with different dim
        self.multihead_attn  = nn.MultiheadAttention(d_model, nhead, dropout, kdim=encoded_img_dim,
                                                     vdim=encoded_img_dim, batch_first=batch_first)

    def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: torch.Tensor = None,
                memory_mask: torch.Tensor = None, tgt_key_padding_mask: torch.Tensor = None, 
                memory_key_padding_mask: torch.Tensor = None, tgt_is_causal: bool = False, 
                memory_is_causal: bool = False) -> torch.Tensor:
        x = tgt
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
            x = x + self._mha_block(self.norm2(x), memory, memory_mask,
                                    memory_key_padding_mask, memory_is_causal)
            x = x + self._ff_block(self.norm3(x))
        else:
            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
            x = self.norm2(x + self._mha_block(x, memory, memory_mask,
                                               memory_key_padding_mask, memory_is_causal))
            x = self.norm3(x + self._ff_block(x))

        return x

    def _mha_block(self, x: torch.Tensor, mem: torch.Tensor,
                   attn_mask: torch.Tensor, key_padding_mask: torch.Tensor,
                   is_causal: bool = False) -> torch.Tensor:
        x = self.multihead_attn(x, mem, mem,
                                attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask,
                                is_causal=is_causal,
                                need_weights=False)[0]
        return self.dropout2(x)

class CustomTransformerDecoder(nn.TransformerDecoder):
    def __init__(self, d_model, d_encoding, nhead, dim_feedforward, dropout, num_layers, norm=None):
        decoder_layer = CustomTransformerDecoderLayer(d_model, d_encoding, nhead, dim_feedforward, dropout)
        super(CustomTransformerDecoder, self).__init__(decoder_layer, num_layers, norm)

    def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: torch.Tensor = None,
                memory_mask: torch.Tensor = None, tgt_key_padding_mask: torch.Tensor = None,
                memory_key_padding_mask: torch.Tensor = None, tgt_is_causal: bool = False,
                memory_is_causal: bool = False) -> torch.Tensor:
        output = tgt

        for mod in self.layers:
            output = mod(output, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask,
                         tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)

        if self.norm is not None:
            output = self.norm(output)

        return output

In [24]:
class CustomImageEncoder(nn.Module):
    def __init__(self, d_model, d_encoding, nhead, dim_feedforward, dropout,
                 num_layers, output_dim, max_length, norm=None):
        super().__init__()
        self.decoder = CustomTransformerDecoder(d_model, d_encoding, nhead, dim_feedforward,
                                                dropout, num_layers, norm)
        self.max_length = max_length
        self.d_model = d_model

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return output

    def generate(self, memory, max_length):
        device = next(self.parameters()).device
        batch_size = memory.size(0)

        # Initialize the input sequence with start tokens
        input_seq = torch.zeros(batch_size, 1, self.d_model, device=device)

        for _ in range(max_length - 1):  # -1 because we already have one token
            # Create causal mask
            tgt_mask = self.generate_square_subsequent_mask(input_seq.size(1)).to(device)

            # Generate the next token
            output = self.forward(input_seq, memory, tgt_mask=tgt_mask)
            next_token = output[:, -1, :].unsqueeze(1)

            # Append the next token to the input sequence
            input_seq = torch.cat([input_seq, next_token], dim=1)

        return input_seq

    @staticmethod
    def generate_square_subsequent_mask(sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

In [25]:
import wandb
import torch
import torch.nn.functional as F
from tqdm import tqdm

# Initialize wandb
wandb.init(project="image-encoding-project", name="experiment-1")

# Log hyperparameters
wandb.config.update({
    "epochs": num_epochs,
    "batch_size": train_loader.batch_size,
    "initial_learning_rate": optimizer.param_groups[0]['lr'],
    "warmup_steps": warmup_steps,
    "model": model.__class__.__name__,
    "max_length": max_length,
    "d_model": d_model
})

# Training loop
best_val_loss = float('inf')
for epoch in range(num_epochs):
    # Training phase
    model.train()
    projector.train()
    total_train_loss = 0
    
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        image_embed = batch["image_embed"].to('cuda')
        text_embeds = batch["text_embed"].to('cuda')  # Assume this contains 5 text embeddings per image
        
        # Forward pass
        generated_sequence = model.generate(image_embed, max_length)
        
        # Project each token and compute loss
        loss = 0
        for i in range(1, generated_sequence.size(1)):
            projected_token = projector(generated_sequence[:, i, :])
            
            # Compute similarity between projected token and each of the 5 captions
            for j in range(5):
                similarity = F.cosine_similarity(projected_token, text_embeds[:, j, :], dim=1)
                # We want to maximize similarity, so we minimize negative similarity
                loss -= similarity.mean() * i / (max_length - 1)
        
        loss /= 5  # Average loss over the 5 captions
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        scheduler.step()
        
        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_loader)
    
    # Validation phase
    model.eval()
    projector.eval()
    total_val_loss = 0
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
            image_embed = batch["image_embed"].to('cuda')
            text_embeds = batch["text_embed"].to('cuda')
            
            # Forward pass
            generated_sequence = model.generate(image_embed, max_length)
            
            # Project each token and compute loss
            loss = 0
            for i in range(1, generated_sequence.size(1)):
                projected_token = projector(generated_sequence[:, i, :])
                
                # Compute similarity between projected token and each of the 5 captions
                for j in range(5):
                    similarity = F.cosine_similarity(projected_token, text_embeds[:, j, :], dim=1)
                    loss -= similarity.mean() * i / (max_length - 1)
            
            loss /= 5  # Average loss over the 5 captions
            total_val_loss += loss.item()
    
    avg_val_loss = total_val_loss / len(val_loader)
    
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
    
    # Log metrics to wandb
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": avg_train_loss,
        "val_loss": avg_val_loss,
        "learning_rate": optimizer.param_groups[0]['lr']
    })
    
    # Save the best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'projector_state_dict': projector.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
        }, 'best_model.pth')
        print(f"New best model saved with validation loss: {best_val_loss:.4f}")

# Close wandb run
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Epoch 1/10 - Training: 100%|██████████| 100/100 [06:31<00:00,  3.91s/it]
Epoch 1/10 - Validation: 100%|██████████| 25/25 [00:55<00:00,  2.22s/it]


Epoch 1/10, Train Loss: -0.6688, Val Loss: -1.5888
New best model saved with validation loss: -1.5888


Epoch 2/10 - Training: 100%|██████████| 100/100 [06:39<00:00,  4.00s/it]
Epoch 2/10 - Validation: 100%|██████████| 25/25 [00:53<00:00,  2.14s/it]


Epoch 2/10, Train Loss: -1.6566, Val Loss: -1.7085
New best model saved with validation loss: -1.7085


Epoch 3/10 - Training: 100%|██████████| 100/100 [06:00<00:00,  3.60s/it]
Epoch 3/10 - Validation: 100%|██████████| 25/25 [00:48<00:00,  1.95s/it]


Epoch 3/10, Train Loss: -1.6966, Val Loss: -1.7137
New best model saved with validation loss: -1.7137


Epoch 4/10 - Training: 100%|██████████| 100/100 [05:48<00:00,  3.48s/it]
Epoch 4/10 - Validation: 100%|██████████| 25/25 [00:45<00:00,  1.80s/it]


Epoch 4/10, Train Loss: -1.7227, Val Loss: -1.7677
New best model saved with validation loss: -1.7677


Epoch 5/10 - Training: 100%|██████████| 100/100 [05:47<00:00,  3.48s/it]
Epoch 5/10 - Validation: 100%|██████████| 25/25 [00:44<00:00,  1.78s/it]


Epoch 5/10, Train Loss: -1.7950, Val Loss: -1.8530
New best model saved with validation loss: -1.8530


Epoch 6/10 - Training: 100%|██████████| 100/100 [05:48<00:00,  3.49s/it]
Epoch 6/10 - Validation: 100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


Epoch 6/10, Train Loss: -1.8927, Val Loss: -1.9432
New best model saved with validation loss: -1.9432


Epoch 7/10 - Training: 100%|██████████| 100/100 [05:48<00:00,  3.49s/it]
Epoch 7/10 - Validation: 100%|██████████| 25/25 [00:43<00:00,  1.75s/it]


Epoch 7/10, Train Loss: -1.9667, Val Loss: -2.0012
New best model saved with validation loss: -2.0012


Epoch 8/10 - Training: 100%|██████████| 100/100 [05:48<00:00,  3.49s/it]
Epoch 8/10 - Validation: 100%|██████████| 25/25 [00:44<00:00,  1.78s/it]


Epoch 8/10, Train Loss: -2.0197, Val Loss: -2.0392
New best model saved with validation loss: -2.0392


Epoch 9/10 - Training: 100%|██████████| 100/100 [05:44<00:00,  3.44s/it]
Epoch 9/10 - Validation: 100%|██████████| 25/25 [00:43<00:00,  1.72s/it]


Epoch 9/10, Train Loss: -2.0578, Val Loss: -2.0698
New best model saved with validation loss: -2.0698


Epoch 10/10 - Training: 100%|██████████| 100/100 [05:41<00:00,  3.42s/it]
Epoch 10/10 - Validation: 100%|██████████| 25/25 [00:43<00:00,  1.73s/it]


Epoch 10/10, Train Loss: -2.0890, Val Loss: -2.0900
New best model saved with validation loss: -2.0900


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▂▃▃▄▅▆▆▇█
learning_rate,▁▂▃▃▄▅▆▆▇█
train_loss,█▃▃▃▂▂▂▁▁▁
val_loss,█▆▆▆▄▃▂▂▁▁

0,1
epoch,10.0
learning_rate,0.0005
train_loss,-2.089
val_loss,-2.08999


## Save the Model

In [26]:
torch.save({
    'model_state_dict': model.state_dict(),
    'projector_state_dict': projector.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'custom_transformer_decoder_model.pth')

print("Model saved successfully.")

Model saved successfully.


In [30]:
import requests
from PIL import Image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"

image = Image.open(requests.get(url, stream=True).raw)
image.show()

In [31]:
image = image_processor(image, return_tensors="pt")

In [32]:
image

{'pixel_values': tensor([[[[ 0.5873,  0.5873,  0.6165,  ...,  0.0617,  0.0471, -0.0259],
          [ 0.5727,  0.5727,  0.6603,  ...,  0.1201,  0.0763,  0.0909],
          [ 0.5873,  0.5435,  0.6165,  ...,  0.0325,  0.1201,  0.0617],
          ...,
          [ 1.8719,  1.8573,  1.8719,  ...,  1.3902,  1.4340,  1.4194],
          [ 1.8281,  1.8719,  1.8427,  ...,  1.4486,  1.4340,  1.5070],
          [ 1.8573,  1.9011,  1.8281,  ...,  1.3756,  1.3610,  1.4486]],

         [[-1.3169, -1.3019, -1.3169,  ..., -1.4970, -1.4369, -1.4820],
          [-1.2418, -1.2718, -1.2268,  ..., -1.4369, -1.4669, -1.4519],
          [-1.2568, -1.3169, -1.2268,  ..., -1.4669, -1.4069, -1.4519],
          ...,
          [ 0.1239,  0.1089,  0.1239,  ..., -0.7016, -0.6865, -0.6865],
          [ 0.0789,  0.0939,  0.0488,  ..., -0.6565, -0.6865, -0.6115],
          [ 0.0939,  0.1089,  0.0038,  ..., -0.7766, -0.7316, -0.6115]],

         [[-0.4848, -0.4137, -0.3853,  ..., -0.9541, -0.8545, -0.8545],
          [-0

In [40]:
image = image['pixel_values'].to('cuda')

In [41]:
input = vision_model(image)

In [45]:
output = model.generate(input.last_hidden_state, 8)

In [46]:
output.shape

torch.Size([1, 8, 128])

In [47]:
projected_output = projector(output)

In [48]:
projected_output.shape

torch.Size([1, 8, 768])

In [49]:
captions = ["two cats on a red background", "cat", "red", "remotes"]
text_input = tokenizer(captions, padding=True, truncation=True, return_tensors="pt")
text_input = {k: v.to('cuda') for k, v in text_input.items()}
text_output = text_model(**text_input)

In [51]:
text_output.text_embeds.shape

torch.Size([4, 768])

In [53]:
# compute cosine similarity for each token with each caption
# text_out: torch.Size([4, 768])
# image_out: torch.Size([1, 8, 768])

similarity = F.cosine_similarity(projected_output, text_output.text_embeds.unsqueeze(1), dim=2)


In [54]:
print(similarity)

tensor([[0.3084, 0.5018, 0.5014, 0.5012, 0.5012, 0.5011, 0.5011, 0.5011],
        [0.3872, 0.6365, 0.6353, 0.6351, 0.6350, 0.6350, 0.6349, 0.6349],
        [0.3236, 0.5169, 0.5157, 0.5154, 0.5153, 0.5152, 0.5152, 0.5151],
        [0.2977, 0.5974, 0.5964, 0.5963, 0.5962, 0.5962, 0.5961, 0.5961]],
       device='cuda:0', grad_fn=<SumBackward1>)


## Test the Model

In [27]:
model.eval()
test_loss = 0
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch["input_ids"].to('cuda')
        attention_mask = batch["attention_mask"].to('cuda')
        image_embed = batch["image_embed"].to('cuda')
        
        start_token = torch.zeros(image_embed.shape[0], 1, d_model).to('cuda')
        output = model(start_token, image_embed, tgt_mask=None, memory_mask=None)
        output = projector(output[:,1:,:])
        
        target = torch.ones(output.shape[0]).to('cuda')
        loss = criterion(output.view(-1, clip_dim), image_embed.view(-1, clip_dim), target)
        test_loss += loss.item()

avg_test_loss = test_loss / len(test_loader)
print(f"Test Loss: {avg_test_loss:.4f}")

KeyError: 'input_ids'