In [1]:
import torch
import clip
from PIL import Image

In [2]:
model, preprocess = clip.load("ViT-B/16", device="cpu", download_root="/vinai/thanhlv19/workspace/clip/")

In [5]:
image = preprocess(Image.open("../optimization/jk.jpg")).unsqueeze(0).to("cpu")

In [12]:
model.visual.ln_post

LayerNorm((768,), eps=1e-05, elementwise_affine=True)

In [13]:
class HookInputOutput:
    def __init__(self):
        self.inputs = []
        self.outputs = []
        
    def __call__(self, module, module_in, module_out):
        self.inputs.append(module_in)
        self.outputs.append(module_out)
        
    def clear(self):
        self.inputs = []
        self.outputs = []

In [34]:
hook = HookInputOutput()

In [35]:
hook.clear()

In [36]:
handle = model.visual.transformer.register_forward_hook(hook)

In [37]:
out = model.visual(image)

In [41]:
hook.inputs[0][0].shape

torch.Size([197, 1, 768])

In [43]:
transformers_output = hook.outputs[0] # L, N, D

In [44]:
out.shape

torch.Size([1, 512])

In [47]:
out2 = model.visual.ln_post(transformers_output.permute(1, 0, 2)) @ model.visual.proj

In [48]:
out2.shape

torch.Size([1, 197, 512])

In [49]:
out.shape

torch.Size([1, 512])

In [87]:
torch.cosine_similarity(out, out2[:, 0, :])

RuntimeError: The size of tensor a (768) must match the size of tensor b (512) at non-singleton dimension 1

In [54]:
hook.clear()
handle.remove()

---

# For resnet

In [55]:
model, preprocess = clip.load("RN50x16", device="cpu", download_root="/vinai/thanhlv19/workspace/clip/")

In [56]:
handle = model.visual.attnpool.register_forward_hook(hook)

In [59]:
image = preprocess(Image.open("../optimization/jk.jpg")).unsqueeze(0).to("cpu")

In [60]:
out = model.visual(image)

In [63]:
hook.inputs[0][0].shape

torch.Size([1, 3072, 12, 12])

In [64]:
hook.outputs[0].shape

torch.Size([1, 768])

In [70]:
inp = hook.inputs[0][0]

In [71]:
inp.shape

torch.Size([1, 3072, 12, 12])

In [72]:
# N C H W => (HW) N C
x = inp.reshape(inp.shape[0], inp.shape[1], inp.shape[2] * inp.shape[3]).permute(2, 0, 1)

In [73]:
x.shape

torch.Size([144, 1, 3072])

In [74]:
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC

In [75]:
x.shape

torch.Size([145, 1, 3072])

In [76]:
x = x + model.visual.attnpool.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC


In [77]:
x.shape

torch.Size([145, 1, 3072])

In [78]:
import torch.nn.functional as F

In [80]:
x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=model.visual.attnpool.num_heads,
            q_proj_weight=model.visual.attnpool.q_proj.weight,
            k_proj_weight=model.visual.attnpool.k_proj.weight,
            v_proj_weight=model.visual.attnpool.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([model.visual.attnpool.q_proj.bias, model.visual.attnpool.k_proj.bias, model.visual.attnpool.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=model.visual.attnpool.c_proj.weight,
            out_proj_bias=model.visual.attnpool.c_proj.bias,
            use_separate_proj_weight=True,
            training=model.visual.attnpool.training,
            need_weights=False
        )


In [81]:
x.shape

torch.Size([145, 1, 768])

In [82]:
out.shape

torch.Size([1, 768])

In [84]:
x[0].shape

torch.Size([1, 768])

In [93]:
import numpy as np

In [95]:
np.allclose(torch.cosine_similarity(x[0], out).item(), 1)

True