In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50
from transformers import BertModel, BertTokenizer
from torchvision import transforms
from PIL import Image
import random
import json
import numpy as np


In [2]:
class TextEncoder(nn.Module):
    def __init__(self, projection_dim=512):
        super(TextEncoder, self).__init__()

        # Load pre-trained BERT model
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # Freeze all BERT layers (optional: unfreeze some for fine-tuning)
        for param in self.bert.parameters():
            param.requires_grad = False

        # Projection head to align with ImgEncoder_CNN
        self.projection_head = nn.Sequential(
            nn.Linear(768, projection_dim),  # BERT output dim (768) → projection_dim (512)
            nn.LayerNorm(projection_dim)
        )

    def forward(self, input_ids, attention_mask):
        # Get BERT outputs
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # Shape: (batch_size, seq_length, 768)

        # Mask padding tokens for proper averaging
        mask = attention_mask.unsqueeze(-1).expand_as(hidden_states)  # Shape: (batch_size, seq_length, 768)
        masked_hidden_states = hidden_states * mask  # Zero out padding tokens
        token_counts = mask.sum(dim=1, keepdim=True)  # Count non-padding tokens
        avg_embedding = masked_hidden_states.sum(dim=1) / (token_counts + 1e-6)  # Avoid division by zero

        # Project embeddings to match image encoder output size
        z = self.projection_head(avg_embedding)  # Shape: (batch_size, 512)

        # Normalize the output embeddings
        return nn.functional.normalize(z, dim=1)

In [3]:
from torchinfo import summary


text_encoder = TextEncoder()

batch_size = 32
seq_length = 64

sample_cap = torch.randint(0, 30522, (batch_size, seq_length)) 
sample_mask = torch.ones(batch_size, seq_length) 

# Use torchinfo to summarize the model
summary(text_encoder, input_data=(sample_cap, sample_mask))

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Layer (type:depth-idx)                                       Output Shape              Param #
TextEncoder                                                  [32, 32, 512]             --
├─BertModel: 1-1                                             [32, 768]                 --
│    └─BertEmbeddings: 2-1                                   [32, 64, 768]             --
│    │    └─Embedding: 3-1                                   [32, 64, 768]             (23,440,896)
│    │    └─Embedding: 3-2                                   [32, 64, 768]             (1,536)
│    │    └─Embedding: 3-3                                   [1, 64, 768]              (393,216)
│    │    └─LayerNorm: 3-4                                   [32, 64, 768]             (1,536)
│    │    └─Dropout: 3-5                                     [32, 64, 768]             --
│    └─BertEncoder: 2-2                                      [32, 64, 768]             --
│    │    └─ModuleList: 3-6                                  --     