# In [1]:
from transformers import BertModel

# (notebook cell output residue, apparently a truncated import-time warning)
#   from .autonotebook import tqdm as notebook_tqdm


# In [7]:
import torch
from torch import nn
from transformers import BertModel

class MultiModalTransformer(nn.Module):
    """Classify a (text, image-feature) pair with a BERT + Transformer fusion head.

    Pipeline:
      1. Token ids are encoded by a pretrained ``bert-base-uncased`` model.
      2. The image feature vector is linearly projected to ``hidden_dim`` and
         replicated once per text token.
      3. Text states and replicated image states are concatenated along the
         sequence axis, giving a fused sequence of ``2 * sequence_length``.
      4. Learned position embeddings are added and the result goes through a
         single ``nn.TransformerEncoderLayer``.
      5. The state at position 0 is fed to a linear classifier.

    Args:
        modality_dim: Width of the incoming image feature vector.
        hidden_dim: Model width. Must equal the hidden size of the BERT
            encoder, because BERT's output is concatenated with the projected
            image features without any projection of its own.
        num_classes: Number of output logits.
        max_positions: Size of the learned position table. The fused sequence
            (``2 * sequence_length``) must not exceed it. Defaults to 512,
            i.e. at most 256 text tokens.

    Raises:
        ValueError: If ``hidden_dim`` differs from the BERT hidden size.
    """

    def __init__(self, modality_dim, hidden_dim, num_classes, max_positions=512):
        super().__init__()
        self.modality_dim = modality_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes

        # Text encoder: pretrained BERT, output (batch, sequence_length, encoder_dim).
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        encoder_dim = self.text_encoder.config.hidden_size
        if hidden_dim != encoder_dim:
            # Fail here instead of with an opaque torch.cat shape error in forward().
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must match the text encoder "
                f"hidden size ({encoder_dim})"
            )

        # Projects image features from modality_dim to the shared model width.
        self.image_projection = nn.Linear(modality_dim, hidden_dim)

        # Learned position table for the fused sequence: (max_positions, hidden_dim).
        self.position_embedding = nn.Embedding(max_positions, hidden_dim)

        # A single encoder layer (attribute name kept for checkpoint compatibility).
        # Built with the default batch_first=False, so it expects
        # (sequence_length, batch, hidden_dim) input.
        self.transformer_layers = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
        )

        # Classification head: hidden_dim -> num_classes.
        self.output_layer = nn.Linear(hidden_dim, num_classes)

    def forward(self, text_inputs, image_features):
        """Compute class logits for a batch.

        Args:
            text_inputs: Integer token ids, expected shape
                (batch, sequence_length). No attention mask is passed to the
                encoder, so every position is attended to.
            image_features: Float tensor, expected shape (batch, modality_dim).

        Returns:
            Logits of shape (batch, num_classes).

        Raises:
            ValueError: If the fused sequence (``2 * sequence_length``) is
                longer than the position table.
        """
        # Validate before the expensive BERT pass: every text token gets a
        # matching image token, so the fused sequence is twice as long.
        text_len = text_inputs.size(-1)
        fused_len = 2 * text_len
        table_size = self.position_embedding.num_embeddings
        if fused_len > table_size:
            raise ValueError(
                f"fused sequence length {fused_len} (2 x {text_len} text tokens) "
                f"exceeds the {table_size} available position embeddings"
            )

        # (batch, sequence_length, hidden_dim)
        encoded = self.text_encoder(text_inputs, return_dict=True)
        sequence_output = encoded.last_hidden_state

        # (batch, modality_dim) -> (batch, hidden_dim)
        image_features = self.image_projection(image_features)

        # Replicate the image vector once per text token; expand() creates a
        # view, and torch.cat materialises the data exactly once.
        image_tokens = image_features.unsqueeze(1).expand(-1, sequence_output.size(1), -1)

        # Concatenate along the sequence axis: (batch, 2 * sequence_length, hidden_dim).
        combined_features = torch.cat((sequence_output, image_tokens), dim=1)

        # Add learned position embeddings; (2 * sequence_length, hidden_dim)
        # broadcasts over the batch axis. Out-of-place to keep autograd simple.
        position_ids = torch.arange(combined_features.size(1), device=combined_features.device)
        combined_features = combined_features + self.position_embedding(position_ids)

        # The encoder layer is sequence-first, so transpose in and back out.
        transformer_output = self.transformer_layers(
            combined_features.transpose(0, 1)
        ).transpose(0, 1)

        # Use the first position (BERT's leading token) as the pooled representation.
        pooled_output = transformer_output[:, 0, :]

        # (batch, num_classes)
        return self.output_layer(pooled_output)

# Build the model.
modality_dim = 512  # width of the image feature vector
hidden_dim = 768    # hidden width, same as BERT's
num_classes = 8     # number of target classes
model = MultiModalTransformer(
    modality_dim=modality_dim,
    hidden_dim=hidden_dim,
    num_classes=num_classes,
)

# Synthetic inputs standing in for real data.
# Token ids drawn from [0, 100): shape (batch_size, sequence_length) = (2, 128).
text_inputs = torch.randint(low=0, high=100, size=(2, 128))
# Random image features: shape (batch_size, modality_dim) = (2, 512).
image_features = torch.rand((2, modality_dim))

# Forward pass; the result has shape (batch_size, num_classes).
outputs = model(text_inputs=text_inputs, image_features=image_features)
print(outputs)

# Output of the cell above:
# tensor([[-0.5067, -0.7265,  0.2073, -0.2031, -0.7495,  0.6879, -0.4810,  0.3496],
#         [-0.4326, -0.6922,  0.2237, -0.2000, -0.6833,  0.6560, -0.4789,  0.1451]],
#        grad_fn=<AddmmBackward0>)
