In [40]:
import os
import random
import json

import sys 
sys.path.append(".")
sys.path.append("..")

from PIL import Image
from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

from transformers import CLIPProcessor, CLIPTextModel

%load_ext autoreload 
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [192]:
model_name = "openai/clip-vit-large-patch14"  
processor = CLIPProcessor.from_pretrained(model_name)
text_encoder = CLIPTextModel.from_pretrained(model_name)



In [185]:
for param in text_encoder.parameters():
    param.requires_grad = True

for block in text_encoder.text_model.encoder.layers:
    block.self_attn.q_proj.requires_grad_(True)
    block.self_attn.k_proj.requires_grad_(True)

In [193]:
prompts = [' rides on a bicycle',
 ' rides on a bike',
 ' rides on a motorcycle',
 ' rides on a horse']

refer_subject = 'a man'
infer_subject = 'a cat'

refer_prompt = [refer_subject + prompt for prompt in prompts]
infer_prompt = [infer_subject + prompt for prompt in prompts]

In [187]:
def get_prediction(prompt, token_indice):
    inputs = processor(
        prompt,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    embeds = text_encoder(
        inputs.input_ids, output_attentions=True
    )

    attn = embeds.attentions
    last_layer_attn = torch.stack(attn)[-1].mean(dim=1)[:, token_indice, :token_indice+1]
    
    return last_layer_attn

def compute_loss(value1, value2):
    loss_fn = torch.nn.MSELoss()
    return loss_fn(value1, value2)

In [188]:
gt_attn = get_prediction(refer_prompt, 8)

In [189]:
optimizer = torch.optim.AdamW(text_encoder.text_model.embeddings.parameters(), lr=1e-5)

num_epochs = 5  # 设置更新次数
for epoch in range(num_epochs):
    # 获取两个 prompt 的 attention values    
    pre_attn = get_prediction(infer_prompt, 8)

    # 计算损失
    loss = compute_loss(pre_attn, gt_attn)

    # 反向传播并更新嵌入层
    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    optimizer.step()

    # 打印每次更新后的损失
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")

Epoch 1/5, Loss: 0.0034488101955503225
Epoch 2/5, Loss: 0.0034126455429941416
Epoch 3/5, Loss: 0.003376640845090151
Epoch 4/5, Loss: 0.003340776078402996
Epoch 5/5, Loss: 0.003305052174255252


In [194]:
inputs = processor(
    infer_prompt,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)

embeds = text_encoder(
    inputs.input_ids, output_attentions=True
)

In [191]:
torch.save(embeds.last_hidden_state, 'new_embeds_all.pt')

In [109]:
torch.stack(embeds.attentions).mean(dim=0).mean(dim=1)[:,8,:9]

tensor([[0.6926, 0.0215, 0.0319, 0.0274, 0.0333, 0.0320, 0.0206, 0.0767, 0.0640]],
       grad_fn=<SliceBackward0>)

In [112]:
torch.stack(embeds.attentions).mean(dim=0).mean(dim=1)[:,8,:9]

tensor([[0.6700, 0.0226, 0.0474, 0.0256, 0.0325, 0.0308, 0.0217, 0.0872, 0.0623]],
       grad_fn=<SliceBackward0>)

In [195]:
embeds.last_hidden_state[:, :6, :7]

tensor([[[-0.3884,  0.0229, -0.0522, -0.1841, -0.0273, -0.3355, -0.0176],
         [ 0.0290, -1.3258,  0.3085, -0.0615,  0.0398, -0.7107, -0.9693],
         [ 0.1177,  0.9112,  0.6360, -0.3315,  0.8024,  0.2287, -0.9680],
         [-1.3853, -0.8009,  0.0932,  0.7572, -1.0690, -0.0540, -1.2084],
         [-1.7795, -0.1171,  0.4707,  1.1404, -0.4600, -0.3640, -1.0923],
         [-1.8796,  0.2238, -0.0867,  0.4162,  0.3026,  0.5182, -1.9219]],

        [[-0.3884,  0.0229, -0.0522, -0.1841, -0.0273, -0.3355, -0.0176],
         [ 0.0290, -1.3258,  0.3085, -0.0615,  0.0398, -0.7107, -0.9693],
         [ 0.1177,  0.9112,  0.6360, -0.3315,  0.8024,  0.2287, -0.9680],
         [-1.3853, -0.8009,  0.0932,  0.7572, -1.0690, -0.0540, -1.2084],
         [-1.7795, -0.1171,  0.4707,  1.1404, -0.4600, -0.3640, -1.0923],
         [-1.8796,  0.2238, -0.0867,  0.4162,  0.3026,  0.5182, -1.9219]],

        [[-0.3884,  0.0229, -0.0522, -0.1841, -0.0273, -0.3355, -0.0176],
         [ 0.0290, -1.3258,  0.308