DeepSeek提供的原始的代码

In [None]:
import torch
from transformers import FlavaModel

def capture_flava_activations(model):
    """Register forward hooks on every layer of the FLAVA image encoder.

    For each encoder layer this captures the output of the self-attention
    query/key projections and the output of the layer's FFN output module,
    and snapshots the two FFN linear weights.

    The module paths follow the FLAVA layout (``encoder.layer[i]`` and
    ``layer.attention.attention.{query,key}``); the CLIP-style names
    (``encoder.layers`` / ``layer.self_attn``) do not exist on FLAVA.

    Args:
        model: object exposing ``image_model.encoder.layer``, where each layer
            has ``attention.attention.query``, ``attention.attention.key``,
            ``intermediate.dense`` and ``output.dense``.

    Returns:
        tuple ``(hooks, attention_data, ffn_outputs, ffn_weights)``:
            hooks: list of hook handles; call ``.remove()`` on each when done.
            attention_data: ``{layer_idx: {'Q': tensor, 'K': tensor}}``,
                filled in when a forward pass runs.
            ffn_outputs: ``{layer_idx: tensor}``, filled in on forward.
            ffn_weights: ``{layer_idx: {'intermediate': w, 'output': w}}``,
                cloned immediately at registration time.
    """
    hooks = []

    attention_data = {}  # per-layer Q and K projections
    ffn_outputs = {}     # per-layer FFN output
    ffn_weights = {}     # per-layer FFN linear weights

    def _attention_recorder(idx, name):
        # Factory binds idx/name now, avoiding the late-binding closure pitfall.
        def hook(module, input, output):
            attention_data.setdefault(idx, {})[name] = output.detach()
        return hook

    def _ffn_recorder(idx):
        def hook(module, input, output):
            ffn_outputs[idx] = output.detach()
        return hook

    encoder = model.image_model.encoder
    for layer_idx, layer in enumerate(encoder.layer):
        self_attention = layer.attention.attention

        # Capture Q and K right after their linear projections.
        hooks.append(
            self_attention.query.register_forward_hook(_attention_recorder(layer_idx, 'Q'))
        )
        hooks.append(
            self_attention.key.register_forward_hook(_attention_recorder(layer_idx, 'K'))
        )

        # Snapshot the FFN weights (independent copies, detached from autograd).
        ffn_weights[layer_idx] = {
            'intermediate': layer.intermediate.dense.weight.detach().clone(),
            'output': layer.output.dense.weight.detach().clone(),
        }

        # Capture whatever the layer's output module returns.
        hooks.append(layer.output.register_forward_hook(_ffn_recorder(layer_idx)))

    return hooks, attention_data, ffn_outputs, ffn_weights

# Usage example
model = FlavaModel.from_pretrained("facebook/flava-full")  # load the pretrained model

# Register the hooks
hooks, attention_data, ffn_outputs, ffn_weights = capture_flava_activations(model)

# Prepare input data (random placeholders, not real samples)
inputs = {
    "pixel_values": torch.randn(1, 3, 224, 224),  # dummy image input
    "input_ids": torch.randint(0, 30522, (1, 77)), # dummy text input
}

# Forward pass; this is what triggers the hooks
outputs = model(**inputs)

# Remove the hooks
for hook in hooks:
    hook.remove()

# Print a few of the captured results
print("Q values for layer 0:", attention_data[0]['Q'].shape)
print("FFN output for layer 0:", ffn_outputs[0].shape)
print("FFN intermediate weights shape:", ffn_weights[0]['intermediate'].shape)

CLIP-Text (Q, K, FFN)

In [None]:
# 输出结果维度（层数，1，样本数，tokens数，hidden_size）
import torch
import numpy as np
from transformers import CLIPModel, CLIPProcessor


# 0. ==========定义钩子==========
def capture_flava_activations(model):
    """Hook every text-encoder layer and collect Q, K and FFN outputs as NumPy arrays.

    Both encoder layouts used in this notebook are supported:

    * CLIP style  - ``encoder.layers[i].self_attn.{q_proj,k_proj}`` and
      ``encoder.layers[i].mlp.{fc1,fc2}`` (the FFN hook sits on ``mlp``).
    * FLAVA style - ``encoder.layer[i].attention.attention.{query,key}`` and
      ``encoder.layer[i].{intermediate,output}.dense`` (the FFN hook sits on
      ``layer.output``).

    Every hook appends, so after ``k`` forward passes each per-layer list has
    ``k`` arrays; stacking a store after one pass yields
    ``(num_layers, 1, batch, tokens, hidden)``.

    Args:
        model: object exposing ``text_model.encoder`` in one of the two
            layouts above.

    Returns:
        tuple ``(hooks, q_list, k_list, ffn_outputs, ffn_weights)``:
            hooks: list of hook handles; call ``.remove()`` on each when done.
            q_list / k_list / ffn_outputs: one list per layer, filled with
                ``numpy.ndarray`` copies (moved to CPU) on every forward pass.
            ffn_weights: ``{layer_idx: {'intermediate': w, 'output': w}}``
                tensors cloned at registration time.
    """
    hooks = []

    encoder = model.text_model.encoder
    # CLIP names the ModuleList ``layers``; FLAVA names it ``layer``.
    layers = encoder.layers if hasattr(encoder, "layers") else encoder.layer
    num_layers = len(layers)

    q_list = [[] for _ in range(num_layers)]       # per-layer Q projections
    k_list = [[] for _ in range(num_layers)]       # per-layer K projections
    ffn_outputs = [[] for _ in range(num_layers)]  # per-layer FFN outputs
    ffn_weights = {}                               # per-layer FFN linear weights

    def _appender(store, idx):
        # Factory binds store/idx now, so all three stores share one code path.
        def hook(module, input, output):
            store[idx].append(output.detach().cpu().numpy())
        return hook

    for layer_idx, layer in enumerate(layers):
        if hasattr(layer, "self_attn"):
            # CLIP-style layer
            q_module = layer.self_attn.q_proj
            k_module = layer.self_attn.k_proj
            ffn_in = layer.mlp.fc1
            ffn_out = layer.mlp.fc2
            ffn_module = layer.mlp
        else:
            # FLAVA-style layer
            self_attention = layer.attention.attention
            q_module = self_attention.query
            k_module = self_attention.key
            ffn_in = layer.intermediate.dense
            ffn_out = layer.output.dense
            ffn_module = layer.output

        hooks.append(q_module.register_forward_hook(_appender(q_list, layer_idx)))
        hooks.append(k_module.register_forward_hook(_appender(k_list, layer_idx)))

        # Snapshot the FFN weights (independent copies, detached from autograd).
        ffn_weights[layer_idx] = {
            'intermediate': ffn_in.weight.detach().clone(),
            'output': ffn_out.weight.detach().clone(),
        }

        hooks.append(ffn_module.register_forward_hook(_appender(ffn_outputs, layer_idx)))

    return hooks, q_list, k_list, ffn_outputs, ffn_weights

if __name__ == "__main__":
    model_name = "D:/dev/code/HuggingFace/pretrainedModel/clip-vit-base-patch32"  # local checkpoint directory

    # Pick the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. ========== Load the model and the processor ==========
    model = CLIPModel.from_pretrained(model_name)
    tokenizer = CLIPProcessor.from_pretrained(model_name)  # a processor, used here for text only

    model.eval()
    model.to(device)
    print(f"模型已加载至 {device}")

    # 2. ========== Prepare the input texts ==========
    texts = [
        "一只函数的返回大幅改进企鹅瑞华企鹅舞i意见猫",
        "一只猫和一啊但是发射点发射点只狗",
        "as阿凡达发hpoerujhiopertfasfa",
        "放噶撒旦发射覅殴打事件回顾i哦速度返回结果点"
    ]
    inputs = tokenizer(
        text=texts,
        return_tensors="pt",
        padding=True,
        truncation=True
    ).to(device)

    # 3. ========== Register the hooks ==========
    hooks, q_list, k_list, ffn_outputs, ffn_weights = capture_flava_activations(model)

    # 4. ========== Forward pass (triggers the hooks) ==========
    # Only the text tower is run: the inputs contain no pixel_values, which the
    # full CLIPModel forward requires, and the hooks live on the text encoder.
    try:
        with torch.no_grad():
            outputs = model.text_model(**inputs)
    finally:
        # 5. ========== Remove the hooks, even if the forward pass failed ==========
        for hook in hooks:
            hook.remove()

    # # Print sample results
    # print("Q values for layer 0:", attention_data[0]['Q'].shape)
    # print("FFN output for layer 0:", ffn_outputs[0].shape)
    # print("FFN intermediate weights shape:", ffn_weights[0]['intermediate'].shape)

    # # 6. ========== Save everything as NumPy arrays ==========
    # output_dir = "./flava_full_outputs"
    # import os
    # if not os.path.exists(output_dir):
    #     os.makedirs(output_dir)

    # # Save Q, K and the FFN outputs
    # for i in range(len(model.text_model.encoder.layer)):
    #     # Save Q
    #     np.save(f'{output_dir}/q_layer_{i+1}.npy', np.array(q_list[i]))

    # ffn_outputs_np = np.array(ffn_outputs)
    # print(ffn_outputs_np.shape)


    # import gc   # free memory
    # gc.collect()    # free memory
    # del q_list, k_list, ffn_outputs, ffn_weights
    # torch.cuda.empty_cache()

DeepSeek又一版

In [None]:
import torch
from collections import defaultdict
from transformers import FlavaModel, FlavaProcessor

# Module-level stores filled in by the hooks: {layer_idx: {name: tensor}}
activations = defaultdict(dict)
weights = defaultdict(dict)


def get_hooks(layer_idx):
    """Build the three forward hooks that record data for layer ``layer_idx``.

    Returns:
        ``(attention_hook, ffn_intermediate_hook, ffn_output_hook)``. Each one
        writes into the module-level ``activations`` / ``weights`` stores under
        the key ``layer_idx`` when it fires.
    """

    def attention_hook(module, input, output):
        # Recompute Q and K from the hooked module's first positional input and
        # reshape them with the module's own ``transpose_for_scores``.
        hidden_states = input[0]
        with torch.no_grad():
            raw_query = module.query(hidden_states)
            raw_key = module.key(hidden_states)
            per_head_query, per_head_key = (
                module.transpose_for_scores(projected)
                for projected in (raw_query, raw_key)
            )

        activations[layer_idx].update(
            query=per_head_query.detach(),
            key=per_head_key.detach(),
        )

    def _make_ffn_hook(weight_name, output_name):
        # Both FFN hooks do the same thing under different keys: clone the
        # module's ``dense`` weight and keep the detached output.
        def ffn_hook(module, input, output):
            weights[layer_idx][weight_name] = module.dense.weight.detach().clone()
            activations[layer_idx][output_name] = output.detach()

        return ffn_hook

    ffn_intermediate_hook = _make_ffn_hook(
        'ffn_intermediate_weight', 'ffn_intermediate_output'
    )
    ffn_output_hook = _make_ffn_hook('ffn_output_weight', 'ffn_output_output')

    return attention_hook, ffn_intermediate_hook, ffn_output_hook

def register_hooks(model):
    """Register the capture hooks on every layer of the image encoder.

    For each layer in ``model.image_model.encoder.layer`` this attaches the
    three hooks built by ``get_hooks(layer_idx)`` to
    ``layer.attention.attention``, ``layer.intermediate`` and ``layer.output``.

    Args:
        model: object exposing ``image_model.encoder.layer``.

    Returns:
        list of hook handles, three per layer in registration order. Call
        ``.remove()`` on each to detach the hooks again. Callers that ignore
        the return value behave exactly as before.
    """
    encoder = model.image_model.encoder

    handles = []
    for layer_idx, layer in enumerate(encoder.layer):
        # Build the three closures once per layer instead of once per hook.
        attention_hook, ffn_intermediate_hook, ffn_output_hook = get_hooks(layer_idx)

        handles.append(layer.attention.attention.register_forward_hook(attention_hook))
        handles.append(layer.intermediate.register_forward_hook(ffn_intermediate_hook))
        handles.append(layer.output.register_forward_hook(ffn_output_hook))

    return handles


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Usage example (model_name is a local checkpoint directory)
model_name = "C:/Users/xinlong/Desktop/code/python/flava_use/model/facebook/flava-full"

model = FlavaModel.from_pretrained(model_name)
tokenizer = FlavaProcessor.from_pretrained(model_name)

model.eval()
model.to(device)
print(f"模型已加载至 {device}")

# Register the hooks on the image encoder.
# NOTE(review): the return value is not kept here, so these hooks are never removed.
register_hooks(model)

# Dummy input data
pixel_values = torch.randn(1, 3, 224, 224).to(device)  # dummy image input
input_ids = torch.randint(0, 30522, (1, 77)).to(device) # dummy text input

# Run the forward pass
with torch.no_grad():
    outputs = model(pixel_values=pixel_values, input_ids=input_ids)

# Inspect the captured weights
for layer_idx in weights.keys():
    print(f"Layer {layer_idx} FFN Intermediate Weight Shape:", weights[layer_idx]['ffn_intermediate_weight'].shape)
    print(f"Layer {layer_idx} FFN Output Weight Shape:", weights[layer_idx]['ffn_output_weight'].shape)
print(weights[0]['ffn_intermediate_weight'])
print(weights[0]['ffn_output_weight'])
# Example output:
# Layer 0 FFN Intermediate Weight Shape: torch.Size([3072, 768])  # assuming hidden_size=768, intermediate_size=3072
# Layer 0 FFN Output Weight Shape: torch.Size([768, 3072])

In [None]:
import torch
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer


# Stores for intermediate results of the text and vision encoders
text_attention_outputs = []  # per-layer {'key_layer': [...], 'query_layer': [...]}
vision_attention_outputs = []
text_ffn_outputs = []        # FFN (mlp) outputs, appended in call order
vision_ffn_outputs = []
text_ffn_weights = []        # FFN weights, appended alongside each output
vision_ffn_weights = []
hooks = []                    # hook handles, kept so they can be removed


def register_hooks(encoder, is_text=True):
    """Register capture hooks on every layer of a CLIP-style encoder.

    Per layer, three forward hooks are attached:

    * ``self_attn.k_proj`` / ``self_attn.q_proj``: the projection output is
      reshaped to ``(batch, num_heads, seq_len, head_dim)`` and appended to
      ``key_layer`` / ``query_layer`` of that layer's entry.
    * ``mlp``: the output and clones of ``fc1`` / ``fc2`` weights are appended.

    Everything stored is detached from autograd. The head geometry is read
    from the attention module when the hooks are registered; nothing is
    attached to the model itself, so the module tree is left unchanged
    (storing the attention module as an attribute of its own projection would
    register it as a child and create a cycle in the module graph).

    Args:
        encoder: object exposing ``layers``, each with ``self_attn`` (having
            ``k_proj``, ``q_proj``, ``num_heads``, ``head_dim``) and ``mlp``
            (having ``fc1``, ``fc2``).
        is_text: write to the ``text_*`` stores when true, else ``vision_*``.

    Returns:
        list of hook handles, three per layer.
    """

    def _make_projection_hook(layer_idx, num_heads, head_dim, slot):
        def hook(module, input, output):
            bsz, seq_len, _ = output.shape
            states = (
                output.detach()
                .view(bsz, seq_len, num_heads, head_dim)
                .transpose(1, 2)
                .contiguous()
            )
            target = text_attention_outputs if is_text else vision_attention_outputs
            # Grow the store lazily so that layer_idx is always a valid index.
            while len(target) <= layer_idx:
                target.append({'key_layer': [], 'query_layer': []})
            target[layer_idx][slot].append(states)
        return hook

    def hook_mlp(module, input, output):
        target_outputs = text_ffn_outputs if is_text else vision_ffn_outputs
        target_weights = text_ffn_weights if is_text else vision_ffn_weights
        target_outputs.append(output.detach())
        target_weights.append({
            'fc1': module.fc1.weight.detach().clone(),
            'fc2': module.fc2.weight.detach().clone(),
        })

    encoder_hooks = []
    for layer_idx, layer in enumerate(encoder.layers):
        attn = layer.self_attn
        num_heads = attn.num_heads
        head_dim = attn.head_dim

        encoder_hooks.append(attn.k_proj.register_forward_hook(
            _make_projection_hook(layer_idx, num_heads, head_dim, 'key_layer')
        ))
        encoder_hooks.append(attn.q_proj.register_forward_hook(
            _make_projection_hook(layer_idx, num_heads, head_dim, 'query_layer')
        ))
        encoder_hooks.append(layer.mlp.register_forward_hook(hook_mlp))
    return encoder_hooks



if __name__ == "__main__":
    # Pick the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Model location (local checkpoint directory)
    model_name = "C:\\Users\\xinlong\\Desktop\\code\\python\\HuggingFace\\pretrainedModel\\clip-vit-base-patch32"
    # Load the model
    model = CLIPModel.from_pretrained(model_name)
    # Tokenizer (text) and processor (used below for the image)
    tokenizer = CLIPTokenizer.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    # Move the model to the device
    model.to(device)
    print(f"模型已加载至 {device}")


    # Register hooks on the text encoder
    text_encoder = model.text_model.encoder
    hooks += register_hooks(text_encoder, is_text=True)

    # Register hooks on the vision encoder
    vision_encoder = model.vision_model.encoder
    hooks += register_hooks(vision_encoder, is_text=False)


    # Prepare the input data
    # Text input
    text_inputs = tokenizer(
        ["a photo of a cat", "a photo of a dog", "asdaasdasggg", "的噶啥分割后i的后果偶i"], 
        padding=True, 
        return_tensors="pt"
    ).to(device)
    # Image input (example image fetched over the network)
    from PIL import Image
    import requests
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

    # Clear anything captured by earlier runs
    text_attention_outputs.clear()
    vision_attention_outputs.clear()
    text_ffn_outputs.clear()
    vision_ffn_outputs.clear()
    text_ffn_weights.clear()
    vision_ffn_weights.clear()


    # Forward pass; the hooks capture the data.
    # Only input_ids is forwarded from text_inputs (no attention_mask).
    # outputs = model(input_ids=text_inputs.input_ids, pixel_values=pixel_values)
    outputs = model(
        input_ids=text_inputs.input_ids,
        pixel_values=pixel_values
    )

    # Remove all hooks
    for hook in hooks:
        hook.remove()

In [None]:
# Example: key and query captured for the first text-encoder layer (first recorded call)
text_layer0_key = text_attention_outputs[0]['key_layer'][0]
text_layer0_query = text_attention_outputs[0]['query_layer'][0]

# Example: key and query captured for the first vision-encoder layer (first recorded call)
vision_layer0_key = vision_attention_outputs[0]['key_layer'][0]
vision_layer0_query = vision_attention_outputs[0]['query_layer'][0]

# Example: first recorded vision FFN output and the matching fc1 weight
vision_ffn_output = vision_ffn_outputs[0]
vision_ffn_fc1_weight = vision_ffn_weights[0]['fc1']
In [None]:
# Display the last hidden state of the vision model output for comparison
outputs.vision_model_output.last_hidden_state

In [None]:
# Index 11 is the 12th recorded vision FFN output; display it
vision_ffn_12_output = vision_ffn_outputs[11]
vision_ffn_12_output

使用hook获取FFN输出并实现交叉注意力机制

In [None]:
import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
from PIL import Image
import requests


# Module-level slots; each holds the most recent output seen by its hook
# (None until the corresponding hook has fired).
text_ffn_outputs = None
vision_ffn_outputs = None


# Forward hooks. Each one simply rebinds its module-level slot.
def text_ffn_hook(module, input, output):
    """Store ``output`` in the module-level ``text_ffn_outputs`` slot.

    ``module`` and ``input`` are accepted because the forward-hook call
    signature supplies them; they are not used.
    """
    global text_ffn_outputs
    text_ffn_outputs = output


def vision_ffn_hook(module, input, output):
    """Store ``output`` in the module-level ``vision_ffn_outputs`` slot.

    ``module`` and ``input`` are accepted because the forward-hook call
    signature supplies them; they are not used.
    """
    global vision_ffn_outputs
    vision_ffn_outputs = output

# Cross-attention module
class CrossAttention(nn.Module):
    """Multi-head scaled dot-product cross-attention.

    Queries, keys and values may have different feature widths; each is
    projected to ``output_dim`` and split into ``num_head`` heads.

    Args:
        query_dim: feature width of ``query``.
        key_dim: feature width of ``key``.
        value_dim: feature width of ``value``.
        output_dim: width of the attention space and of the result.
        num_head: number of heads; must divide ``output_dim``.

    Raises:
        ValueError: if ``output_dim`` is not divisible by ``num_head``.
    """

    def __init__(self, query_dim, key_dim, value_dim, output_dim, num_head):
        super().__init__()
        if num_head <= 0 or output_dim % num_head != 0:
            raise ValueError(
                f"output_dim ({output_dim}) must be divisible by num_head ({num_head})"
            )
        self.num_head = num_head
        self.head_dim = output_dim // num_head
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(query_dim, output_dim)
        self.k_proj = nn.Linear(key_dim, output_dim)
        self.v_proj = nn.Linear(value_dim, output_dim)
        self.out_proj = nn.Linear(output_dim, output_dim)

    def _split_heads(self, tensor, batch_size):
        # (batch, seq, output_dim) -> (batch, num_head, seq, head_dim)
        return tensor.view(batch_size, -1, self.num_head, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value):
        """Attend from ``query`` positions over ``key`` / ``value`` positions.

        Args:
            query: ``(batch, q_len, query_dim)``.
            key: ``(batch, kv_len, key_dim)``.
            value: ``(batch, kv_len, value_dim)``.

        Returns:
            Tensor of shape ``(batch, q_len, output_dim)``.
        """
        batch_size = query.size(0)

        q = self._split_heads(self.q_proj(query), batch_size)
        k = self._split_heads(self.k_proj(key), batch_size)
        v = self._split_heads(self.v_proj(value), batch_size)

        # Scaled dot-product scores, normalised over the key positions.
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = torch.softmax(attn_scores, dim=-1)

        # Weighted sum of values, then merge the heads back together.
        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).reshape(
            batch_size, -1, self.num_head * self.head_dim
        )
        return self.out_proj(context)
    


正在完善的代码

In [None]:
import torch
import torch.nn as nn
from transformers import CLIPPreTrainedModel, CLIPModel, CLIPProcessor, CLIPTokenizer, CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from PIL import Image
from typing import Optional, Tuple, Union, Any
import requests



class CLIPWithProjectedCrossAttention(CLIPPreTrainedModel):
    """Frozen CLIP encoders followed by trainable projections and a shared cross-attention layer.

    Token-level hidden states of the text and vision towers are projected to
    a common width, L2-normalised, and attended in both directions
    (text -> vision and vision -> text) by one ``nn.MultiheadAttention``.
    """

    def __init__(self, config: CLIPConfig):
        """Build the frozen CLIP backbone and the trainable heads.

        Args:
            config: CLIP configuration; ``text_config.hidden_size`` and
                ``vision_config.hidden_size`` set the projection input widths.
        """
        super().__init__(config)

        # CLIP backbone built from ``config``; every parameter of it is frozen.
        self.clip = CLIPModel(config)
        for param in self.clip.parameters():
            param.requires_grad = False

        # Encoder widths and the shared projection width.
        text_hidden_size = config.text_config.hidden_size
        vision_hidden_size = config.vision_config.hidden_size
        common_dim = 768

        # Two-layer MLP projections (a single nn.Linear per modality is the
        # simpler alternative).
        self.text_proj = nn.Sequential(
            nn.Linear(text_hidden_size, common_dim),
            nn.ReLU(),
            nn.Linear(common_dim, common_dim)
        )
        self.vision_proj = nn.Sequential(
            nn.Linear(vision_hidden_size, common_dim),
            nn.ReLU(),
            nn.Linear(common_dim, common_dim)
        )

        # Cross-attention shared by both directions; inputs are batch-first.
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=common_dim,
            num_heads=12,  # adjust as needed; must divide common_dim
            dropout=0.1,
            batch_first=True,
        )

        # Xavier initialisation for the projection weights.
        for projection in (self.text_proj, self.vision_proj):
            nn.init.xavier_uniform_(projection[0].weight)
            nn.init.xavier_uniform_(projection[2].weight)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run both encoders and cross-attend their projected token states.

        Args:
            input_ids: token ids for the text encoder.
            pixel_values: images for the vision encoder.
            attention_mask: optional text mask, non-zero for real tokens. It
                is passed to the text encoder and used as the key padding
                mask when text tokens act as keys.
            position_ids: optional position ids for the text encoder.
            return_dict: accepted for signature compatibility only; this
                method always returns a plain tuple.

        Returns:
            ``(text_to_vision, vision_to_text)``: text queries attending over
            vision tokens, shaped ``(batch, text_len, common_dim)``, and
            vision queries attending over text tokens, shaped
            ``(batch, vision_len, common_dim)``.
        """
        # The sub-models must return ModelOutput objects because
        # ``.last_hidden_state`` is read below, whatever the caller passed.
        text_outputs = self.clip.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True
        )

        vision_outputs = self.clip.vision_model(
            pixel_values=pixel_values,
            return_dict=True
        )

        # Last-layer hidden states, before CLIP's own projection heads.
        text_features = text_outputs.last_hidden_state
        vision_features = vision_outputs.last_hidden_state

        # Project to the common width and L2-normalise; ``normalize`` clamps
        # the norm so an all-zero vector does not produce NaNs.
        projected_text = nn.functional.normalize(self.text_proj(text_features), dim=-1)
        projected_vision = nn.functional.normalize(self.vision_proj(vision_features), dim=-1)

        # The padding mask describes TEXT positions, so it is only valid when
        # the text tokens are the keys.
        text_padding_mask = (attention_mask == 0) if attention_mask is not None else None

        # Text as query, vision as key/value: no padding on the vision side.
        text_to_vision, _ = self.cross_attn(
            query=projected_text,
            key=projected_vision,
            value=projected_vision,
        )
        # Vision as query, text as key/value: mask padded text tokens.
        vision_to_text, _ = self.cross_attn(
            query=projected_vision,
            key=projected_text,
            value=projected_text,
            key_padding_mask=text_padding_mask
        )

        return text_to_vision, vision_to_text

# Contrastive learning loss
class ContrastiveLoss(nn.Module):
    """Symmetric cross-entropy contrastive loss over a batch of paired embeddings.

    Row ``i`` of ``text_emb`` is treated as the positive match of row ``i`` of
    ``vision_emb``; every other row in the batch acts as a negative.

    Args:
        temperature: divisor applied to the similarity logits.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        # Must be an instance: calling the class itself would build a new
        # loss module from the logits instead of computing a loss.
        # (The attribute keeps its original spelling.)
        self.cross_entroy = nn.CrossEntropyLoss()

    def forward(self, text_emb, vision_emb):
        """Return the mean of the text->vision and vision->text losses.

        Args:
            text_emb: ``(batch, dim)`` text embeddings.
            vision_emb: ``(batch, dim)`` vision embeddings.

        Returns:
            Scalar tensor.
        """
        # Similarity matrix; the vision->text logits are its transpose.
        logits_per_text = text_emb @ vision_emb.t() / self.temperature
        logits_per_vision = logits_per_text.t()
        labels = torch.arange(logits_per_text.size(0), device=text_emb.device)
        loss_t = self.cross_entroy(logits_per_text, labels)
        loss_v = self.cross_entroy(logits_per_vision, labels)
        return (loss_t + loss_v) / 2

# Usage example
if __name__ == "__main__":
    # Build the configuration (the two towers deliberately use different widths)
    text_config = CLIPTextConfig(hidden_size=512)
    vision_config = CLIPVisionConfig(hidden_size=768)

    text_config_dict = text_config.to_dict()
    vision_config_dict = vision_config.to_dict()
    config = CLIPConfig(
        text_config=text_config_dict, 
        vision_config=vision_config_dict
    )
    
    model = CLIPWithProjectedCrossAttention(config)
    

    # ==================== Training setup ====================
    # Print the trainable parameters
    print(f"{'='*10}Trainable parameters:{'='*10}")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{'-'*10}{name}")

    # Select only the parameters that should be trained
    trainable_params = [
        {'params': model.text_proj.parameters()},
        {'params': model.vision_proj.parameters()},
        {'params': model.cross_attn.parameters()},
    ]

    # Define the optimizer
    optimizer = torch.optim.Adam(trainable_params, lr=1e-4)

    # # Define the task-specific loss (contrastive learning loss)
    # criterion = nn.CosineEmbeddingLoss()

    # 
    
    # Dummy inputs
    text_inputs = torch.randint(0, 49408, (2, 77))  # text input (random token ids)
    image_inputs = torch.randn(2, 3, 224, 224)     # image input (random pixels)
    
    # Forward pass
    text_to_vision, vision_to_text = model(input_ids=text_inputs, pixel_values=image_inputs)
    # print("Output shape:", output.shape)  # expected output shape [2, 768]

In [None]:
print(text_to_vision.shape)
# Take the token at position 0 as the final representation.
# `output` was never defined in this notebook; the tensor printed above is
# `text_to_vision`, so slice that one.
# NOTE(review): confirm position 0 is the token you want to pool on.
cls_output = text_to_vision[:, 0, :]  # [batch, common_dim]
print(cls_output.shape)

In [None]:
import numpy as np

# Scratch check of the matrix-product shape: (2, 3) times (3, 6) gives (2, 6)
n1 = np.random.randn(2, 3)
n2 = np.random.randn(6, 3)
n3 = np.matmul(n1, n2.transpose())


In [None]:
# Show the three shapes, one per line
print(n1.shape, n2.shape, n3.shape, sep="\n")

In [None]:
import torch
# Scratch check: the label vector used for a batch of four pairs
labels = torch.arange(0, 4)
labels

In [None]:
import torch
import torch.nn as nn
from transformers import CLIPPreTrainedModel, CLIPModel, CLIPProcessor, CLIPTokenizer, CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from PIL import Image
from typing import Optional, Tuple, Union, Any
import requests

class DecoderBlock(nn.Module):
    """A single decoder block: cross-attention followed by a feed-forward network.

    Both sub-layers use a residual connection followed by layer normalisation.

    Args:
        hidden_dim: model width (embedding size of queries and memory).
        num_heads: number of attention heads; must divide ``hidden_dim``.
        ff_dim: inner width of the feed-forward network.
        dropout: dropout probability inside the attention layer.
    """

    def __init__(self, hidden_dim, num_heads, ff_dim, dropout=0.1):
        super(DecoderBlock, self).__init__()
        self.cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.ff = nn.Sequential(
            # Expand to ff_dim, then project back to hidden_dim.
            nn.Linear(hidden_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, hidden_dim)
        )
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x, memory, key_padding_mask=None):
        """Attend from ``x`` over ``memory`` and apply the feed-forward network.

        ``self.cross_attn`` is built with the default ``batch_first=False``,
        so tensors are sequence-first.

        Args:
            x: queries, ``(target_len, batch, hidden_dim)``.
            memory: keys/values, ``(source_len, batch, hidden_dim)``.
            key_padding_mask: optional ``(batch, source_len)`` boolean mask,
                true at memory positions that must be ignored.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attn_out, _ = self.cross_attn(
            query=x, key=memory, value=memory, key_padding_mask=key_padding_mask
        )
        x = self.norm1(x + attn_out)
        x = self.norm2(x + self.ff(x))
        return x

class CLIPWithProjectedCrossAttention(CLIPPreTrainedModel):
    """Frozen CLIP encoders followed by trainable projections and a shared cross-attention layer.

    Token-level hidden states of the text and vision towers are projected to
    a common width, L2-normalised, and attended in both directions
    (text -> vision and vision -> text) by one ``nn.MultiheadAttention``.
    """

    def __init__(self, config: CLIPConfig):
        """Build the frozen CLIP backbone and the trainable heads.

        Args:
            config: CLIP configuration; ``text_config.hidden_size`` and
                ``vision_config.hidden_size`` set the projection input widths.
        """
        super().__init__(config)

        # CLIP backbone built from ``config``; every parameter of it is frozen.
        self.clip = CLIPModel(config)
        for param in self.clip.parameters():
            param.requires_grad = False

        # Encoder widths and the shared projection width.
        text_hidden_size = config.text_config.hidden_size
        vision_hidden_size = config.vision_config.hidden_size
        common_dim = 768

        # Two-layer MLP projections (a single nn.Linear per modality is the
        # simpler alternative).
        self.text_proj = nn.Sequential(
            nn.Linear(text_hidden_size, common_dim),
            nn.ReLU(),
            nn.Linear(common_dim, common_dim)
        )
        self.vision_proj = nn.Sequential(
            nn.Linear(vision_hidden_size, common_dim),
            nn.ReLU(),
            nn.Linear(common_dim, common_dim)
        )

        # Cross-attention shared by both directions; inputs are batch-first.
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=common_dim,
            num_heads=12,  # adjust as needed; must divide common_dim
            dropout=0.1,
            batch_first=True,
        )

        # Xavier initialisation for the projection weights.
        for projection in (self.text_proj, self.vision_proj):
            nn.init.xavier_uniform_(projection[0].weight)
            nn.init.xavier_uniform_(projection[2].weight)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run both encoders and cross-attend their projected token states.

        Args:
            input_ids: token ids for the text encoder.
            pixel_values: images for the vision encoder.
            attention_mask: optional text mask, non-zero for real tokens. It
                is passed to the text encoder and used as the key padding
                mask when text tokens act as keys.
            position_ids: optional position ids for the text encoder.
            return_dict: accepted for signature compatibility only; this
                method always returns a plain tuple.

        Returns:
            ``(text_to_vision, vision_to_text)``: text queries attending over
            vision tokens, shaped ``(batch, text_len, common_dim)``, and
            vision queries attending over text tokens, shaped
            ``(batch, vision_len, common_dim)``.
        """
        # The sub-models must return ModelOutput objects because
        # ``.last_hidden_state`` is read below, whatever the caller passed.
        text_outputs = self.clip.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True
        )

        vision_outputs = self.clip.vision_model(
            pixel_values=pixel_values,
            return_dict=True
        )

        # Last-layer hidden states, before CLIP's own projection heads.
        text_features = text_outputs.last_hidden_state
        vision_features = vision_outputs.last_hidden_state

        # Project to the common width and L2-normalise; ``normalize`` clamps
        # the norm so an all-zero vector does not produce NaNs.
        projected_text = nn.functional.normalize(self.text_proj(text_features), dim=-1)
        projected_vision = nn.functional.normalize(self.vision_proj(vision_features), dim=-1)

        # The padding mask describes TEXT positions, so it is only valid when
        # the text tokens are the keys.
        text_padding_mask = (attention_mask == 0) if attention_mask is not None else None

        # Text as query, vision as key/value: no padding on the vision side.
        text_to_vision, _ = self.cross_attn(
            query=projected_text,
            key=projected_vision,
            value=projected_vision,
        )
        # Vision as query, text as key/value: mask padded text tokens.
        vision_to_text, _ = self.cross_attn(
            query=projected_vision,
            key=projected_text,
            value=projected_text,
            key_padding_mask=text_padding_mask
        )

        return text_to_vision, vision_to_text

# Contrastive learning loss
class ContrastiveLoss(nn.Module):
    """Symmetric cross-entropy contrastive loss over a batch of paired embeddings.

    Row ``i`` of ``text_emb`` is treated as the positive match of row ``i`` of
    ``vision_emb``; every other row in the batch acts as a negative.

    Args:
        temperature: divisor applied to the similarity logits.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        # Must be an instance: calling the class itself would build a new
        # loss module from the logits instead of computing a loss.
        # (The attribute keeps its original spelling.)
        self.cross_entroy = nn.CrossEntropyLoss()

    def forward(self, text_emb, vision_emb):
        """Return the mean of the text->vision and vision->text losses.

        Args:
            text_emb: ``(batch, dim)`` text embeddings.
            vision_emb: ``(batch, dim)`` vision embeddings.

        Returns:
            Scalar tensor.
        """
        # Similarity matrix; the vision->text logits are its transpose.
        logits_per_text = text_emb @ vision_emb.t() / self.temperature
        logits_per_vision = logits_per_text.t()
        labels = torch.arange(logits_per_text.size(0), device=text_emb.device)
        loss_t = self.cross_entroy(logits_per_text, labels)
        loss_v = self.cross_entroy(logits_per_vision, labels)
        return (loss_t + loss_v) / 2

# Usage example
if __name__ == "__main__":
    # Build the configuration (the two towers deliberately use different widths)
    text_config = CLIPTextConfig(hidden_size=512)
    vision_config = CLIPVisionConfig(hidden_size=768)

    text_config_dict = text_config.to_dict()
    vision_config_dict = vision_config.to_dict()
    config = CLIPConfig(
        text_config=text_config_dict, 
        vision_config=vision_config_dict
    )
    
    model = CLIPWithProjectedCrossAttention(config)
    

    # ==================== Training setup ====================
    # Print the trainable parameters
    print(f"{'='*10}Trainable parameters:{'='*10}")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{'-'*10}{name}")

    # Select only the parameters that should be trained
    trainable_params = [
        {'params': model.text_proj.parameters()},
        {'params': model.vision_proj.parameters()},
        {'params': model.cross_attn.parameters()},
    ]

    # Define the optimizer
    optimizer = torch.optim.Adam(trainable_params, lr=1e-4)

    # # Define the task-specific loss (contrastive learning loss)
    # criterion = nn.CosineEmbeddingLoss()

    # 
    
    # Dummy inputs
    text_inputs = torch.randint(0, 49408, (2, 77))  # text input (random token ids)
    image_inputs = torch.randn(2, 3, 224, 224)     # image input (random pixels)
    
    # Forward pass
    text_to_vision, vision_to_text = model(input_ids=text_inputs, pixel_values=image_inputs)
    # print("Output shape:", output.shape)  # expected output shape [2, 768]