In [1]:
import torch
import clip
from PIL import Image
import numpy as np
import torchvision

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
model_ViT_B_32, preprocess_ViT_B_32 = clip.load("ViT-B/32", device=device)

In [4]:
image = Image.open("/media/lingaoyuan/SATA/test-image.jpg")
print(np.array(image).shape)
image = preprocess_ViT_B_32(image).unsqueeze(0).to(device)
image.shape

(900, 1600, 3)


torch.Size([1, 3, 224, 224])

In [5]:
with torch.no_grad():
    image_features = model_ViT_B_32.encode_image(image)
image_features.shape

torch.Size([1, 512])

In [6]:
model_ViT_L_14, preprocess_ViT_L_14 = clip.load("ViT-L/14@336px", device=device)

In [7]:
image = Image.open("/media/lingaoyuan/SATA/test-image.jpg")
np.array(image).shape

(900, 1600, 3)

In [8]:
image = preprocess_ViT_L_14(image).unsqueeze(0).to(device)
image.shape

torch.Size([1, 3, 336, 336])

In [9]:
with torch.no_grad():
    image_features = model_ViT_L_14.encode_image(image)

In [10]:
image_features.shape

torch.Size([1, 768])

In [11]:
image_features.squeeze().shape

torch.Size([768])

In [12]:
image_features

tensor([[ 1.2705e+00,  9.8096e-01, -1.7273e-01,  9.5410e-01,  1.1683e-03,
         -5.1660e-01,  6.2549e-01, -4.0601e-01,  6.5796e-02, -6.9678e-01,
         -2.4231e-01, -2.2302e-01,  5.1575e-02,  3.3838e-01, -2.2327e-01,
          1.9739e-01,  4.9243e-01, -5.0293e-01, -1.3525e-01, -1.1462e-01,
         -4.3530e-01,  1.6223e-01,  4.1943e-01, -1.5857e-01, -6.4746e-01,
         -2.3145e-01, -1.7957e-01,  5.1074e-01, -4.3237e-01,  2.2852e-01,
          3.9612e-02,  8.3008e-02,  1.1157e-01, -1.0150e-01, -5.2002e-01,
          6.8066e-01,  8.5876e-02, -3.6206e-01, -4.0918e-01,  7.3425e-02,
          1.4636e-01,  7.0361e-01,  4.6167e-01, -1.4990e-01,  6.3477e-01,
          3.0664e-01,  1.4258e-01,  1.6626e-01,  3.4766e-01, -1.3416e-01,
         -1.6614e-01, -5.3320e-01, -3.5065e-02,  5.2686e-01, -2.7026e-01,
          5.5957e-01, -2.7734e-01,  1.0199e-01,  1.2262e-01, -7.7734e-01,
         -2.0911e-01,  2.4670e-01,  4.0588e-02,  1.2390e-01,  6.4990e-01,
          7.3242e-01, -2.5122e-01,  1.

In [13]:
import torch
import torch.nn as nn

class mapper(nn.Module):
    def __init__(self,
                 input_dim = 512,
                 hid_layers = [128, 256, 128],
                 activation='Relu',
                 ):
        super().__init__()
        self.activation_dict = {
            'relu': nn.ReLU,
            'silu': nn.SiLU,
            'softplus': nn.Softplus,
            }

        self.activation = self.activation_dict[activation.lower()]

        mapper_net = []
        mapper_net.append(nn.Linear(input_dim, hid_layers[0]))
        for i in range(len(hid_layers)-1):
            mapper_net.append(nn.Linear(hid_layers[i], hid_layers[i + 1]))
            mapper_net.append(self.activation())
        self.mapper_net = nn.Sequential(*mapper_net)

    def forward(self, clip_feature):
        if len(clip_feature.shape) == 2:
            clip_feature = clip_feature.squeeze()
        latent_code = self.mapper_net(clip_feature)

        return latent_code

In [14]:
shape_mapper = mapper(input_dim = 768)
shape_mapper.to(device)

mapper(
  (mapper_net): Sequential(
    (0): Linear(in_features=768, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=256, bias=True)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
  )
)

In [15]:
output = shape_mapper(image_features.to(torch.float32))
print(output.shape)
output

torch.Size([128])


tensor([0.0108, 0.2289, 0.0000, 0.0000, 0.0000, 0.0000, 0.0076, 0.1661, 0.0000,
        0.1132, 0.0179, 0.0000, 0.0000, 0.0551, 0.0000, 0.0841, 0.0210, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0413, 0.0810, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.1825, 0.0000, 0.0840, 0.0000, 0.1518, 0.0383,
        0.0000, 0.0000, 0.0000, 0.0000, 0.1815, 0.0000, 0.1182, 0.0567, 0.0185,
        0.0035, 0.1285, 0.0000, 0.0000, 0.0000, 0.1876, 0.0920, 0.0000, 0.0000,
        0.0208, 0.0000, 0.2494, 0.0697, 0.0408, 0.0914, 0.0631, 0.1497, 0.0361,
        0.0490, 0.0000, 0.0419, 0.0000, 0.0698, 0.0000, 0.0825, 0.0000, 0.0176,
        0.0149, 0.0258, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.1017, 0.0100, 0.0390, 0.0529, 0.0078, 0.0566, 0.0000, 0.0000,
        0.1117, 0.0000, 0.0000, 0.0305, 0.0000, 0.0000, 0.0946, 0.1997, 0.0173,
        0.0799, 0.0000, 0.0000, 0.0677, 0.0054, 0.0310, 0.1111, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0808, 0.1024, 

In [16]:
class deformation(nn.Module):
    def __init__(self,
                 input_dim = 128,
                 hid_layers = [256, 256, 256, 256],
                 output_dim = 128,
                 activation='Relu',
                 output_activation = 'Tanh',
                 ):
        super().__init__()
        self.activation_dict = {
            'relu': nn.ReLU,
            'silu': nn.SiLU,
            'softplus': nn.Softplus,
            'tanh' : nn.Tanh,
            }

        self.activation = self.activation_dict[activation.lower()]
        self.output_activation = self.activation_dict[output_activation.lower()]

        deformation_net = []
        deformation_net.append(nn.Linear(input_dim, hid_layers[0]))
        for i in range(len(hid_layers)-1):
            deformation_net.append(nn.Linear(hid_layers[i], hid_layers[i + 1]))
            deformation_net.append(self.activation())

        deformation_net.append(nn.Linear(hid_layers[-1], output_dim))
        deformation_net.append(self.output_activation())
        self.deformation_net = nn.Sequential(*deformation_net)

    def forward(self, latent_code):
        if len(latent_code.shape) == 2:
            latent_code = latent_code.squeeze()
        output = self.deformation_net(latent_code)

        return output

In [17]:
deformation_network = deformation()
deformation_network.to(device)

deformation(
  (deformation_net): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=256, bias=True)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=256, bias=True)
    (4): ReLU()
    (5): Linear(in_features=256, out_features=256, bias=True)
    (6): ReLU()
    (7): Linear(in_features=256, out_features=128, bias=True)
    (8): Tanh()
  )
)

In [18]:
output_deformation = deformation_network(output)
output_deformation

tensor([ 0.0283, -0.0365,  0.0250, -0.0368,  0.0027, -0.0009,  0.0228,  0.0508,
        -0.0129, -0.0164, -0.0502,  0.0285,  0.0417, -0.0256,  0.0231, -0.0117,
        -0.0558, -0.0499,  0.0050, -0.0275, -0.0596, -0.0211, -0.0227, -0.0257,
        -0.0586,  0.0084, -0.0027,  0.0443,  0.0100,  0.0266,  0.0461, -0.0069,
        -0.0004, -0.0035, -0.0524, -0.0364,  0.0395,  0.0332,  0.0479,  0.0414,
         0.0657, -0.0392,  0.0461,  0.0698,  0.0598, -0.0751,  0.0441,  0.0392,
         0.0080,  0.0249, -0.0393, -0.0741, -0.0428,  0.0406,  0.0383,  0.0016,
         0.0155,  0.0422, -0.0449, -0.0756,  0.0221, -0.0288,  0.0422,  0.0178,
         0.0167,  0.0595, -0.0371,  0.0108, -0.0096,  0.0701, -0.0331, -0.0076,
        -0.0233,  0.0282,  0.0208,  0.0191, -0.0227, -0.0210,  0.0696, -0.0345,
        -0.0086,  0.0337,  0.0684, -0.0017, -0.0382,  0.0026, -0.0328, -0.0279,
         0.0106, -0.0891, -0.0353,  0.0575,  0.0434, -0.0307, -0.0493, -0.0065,
         0.0345, -0.0299, -0.0284, -0.01

In [19]:

class mapper_attention(nn.Module):
    def __init__(self, dim_input = 512, dim_embed = None, input_embedding = False, num_head = 1, dropout=0.1):
        super().__init__()
        self.dim_input = dim_input
        self.num_head = num_head
        if dim_embed == None:
            self.dim_embed = self.dim_input
        else:
            self.dim_embed = dim_embed
        self.dim_head = self.dim_embed // self.num_head

        if input_embedding is True:
            self.embed_fc = nn.Linear(self.dim_input, self.dim_embed)

        self.q_fc = nn.Linear(self.dim_input, self.dim_embed, bias=False)
        self.k_fc = nn.Linear(self.dim_input, self.dim_embed, bias=False)
        self.v_fc = nn.Linear(self.dim_input, self.dim_embed, bias=False)

        self.layer_norm = nn.LayerNorm(self.dim_input)

        self.out_fc = nn.Linear(self.dim_embed, self.dim_input, bias=False)

        self.dropout = nn.Dropout(dropout)

        self.softmax = nn.Softmax(dim=2)

    def forward(self, input, mask = None):

        if len(input.shape) == 2:
            input = input.unsqueeze(0)
        residual = input
        input = self.layer_norm(input)

        '每张图片经过clip的image_encoder之后输出的feature的特征为(1,512)或(1,768), 所以这里input的shape为(batch_size, 1, 512)或(batch_size, 1, 768)'
        q = self.q_fc(input)
        q = q.view(input.shape[0], input.shape[1], self.num_head, self.dim_head)
        # q = q.permute(2, 0, 1, 3)  # (self.num_head * batch_input) x len_input x self.dim_head
        q = q.permute(0, 2, 1, 3)  # (batch_input * self.num_head) x len_input x self.dim_head

        k = self.k_fc(input)
        k = k.view(input.shape[0], input.shape[1], self.num_head, self.dim_head)
        # k = k.permute(2, 0, 1, 3)  # (self.num_head * batch_input) x len_input x self.dim_head
        k = k.permute(0, 2, 1, 3)  # (batch_input * self.num_head) x len_input x self.dim_head

        v = self.v_fc(input)
        v = v.view(input.shape[0], input.shape[1], self.num_head, self.dim_head)
        # v = v.permute(2, 0, 1, 3)  # (self.num_head * batch_input) x len_input x self.dim_head
        k = k.permute(0, 2, 1, 3)  # (batch_input * self.num_head) x len_input x self.dim_head
        
        # attn = torch.bmm(q, k.transpose(1, 2)) / np.power(self.dim_head, 0.5)
        attn = torch.matmul(q, k.transpose(-2, -1)) / np.power(self.dim_head, 0.5)
        
        if mask is not None:
            mask = mask.repeat(self.num_head, 1, 1)  # (self.num_head * batch_input) x .. x ..
            attn = attn.masked_fill(mask, -np.inf)
        attn = self.softmax(attn)
        attn = self.dropout(attn)

        # output = torch.bmm(attn, v)
        output = torch.matmul(attn, v)
        
        output = output.view(self.num_head, input.shape[0], input.shape[1], self.dim_head)
        # output = output.permute(1, 2, 0, 3)
        output = output.permute(0, 2, 1, 3)

        
        output = self.dropout(self.out_fc(output))
        output = self.layer_norm(output + residual)

        return output

In [20]:
shape_mapper_transformer = mapper_attention(dim_input = image_features.shape[1])
shape_mapper_transformer.to(device)

mapper_attention(
  (q_fc): Linear(in_features=768, out_features=768, bias=False)
  (k_fc): Linear(in_features=768, out_features=768, bias=False)
  (v_fc): Linear(in_features=768, out_features=768, bias=False)
  (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (out_fc): Linear(in_features=768, out_features=768, bias=False)
  (dropout): Dropout(p=0.1, inplace=False)
  (softmax): Softmax(dim=2)
)

In [21]:
output_transformer = shape_mapper_transformer(image_features.to(torch.float32))
print(output_transformer.shape)
print(output_transformer.squeeze().shape)
output_transformer

torch.Size([1, 1, 1, 768])
torch.Size([768])


tensor([[[[ 2.3863e+00,  1.8939e+00, -2.7535e-01,  1.4526e+00, -5.1584e-02,
           -6.8054e-01,  8.2854e-01, -5.9795e-01, -1.0896e-01, -6.4381e-01,
           -4.6264e-02,  1.0272e+00,  2.8933e-01,  1.1122e+00, -8.8840e-01,
            2.9629e-02,  1.0457e+00, -1.6589e+00,  3.7985e-01,  4.2173e-01,
           -6.7273e-01,  1.1385e-01,  8.0799e-02,  2.2171e-01, -4.4727e-01,
           -3.5655e-01, -2.0901e-01,  1.0120e+00, -6.3442e-01, -1.6664e-01,
            1.5671e+00,  7.8321e-02, -1.8155e-01, -5.1609e-01, -7.5563e-01,
            2.4758e-01,  6.5301e-01,  1.6820e-01, -4.2333e-01, -2.1068e-01,
            8.6223e-02,  9.3658e-01,  7.3140e-01, -2.4378e-01,  1.2127e+00,
            4.7243e-01, -3.0084e-01,  7.5521e-02,  1.0269e+00,  6.0849e-01,
           -2.4779e-01, -2.0495e-01, -2.5628e-01,  6.2257e-01, -3.5658e-01,
            1.5190e+00, -2.2870e-01, -4.5595e-01,  1.3303e-01, -5.9026e-01,
           -3.2565e-01,  4.3671e-01, -1.6875e-01,  7.4379e-01,  1.3026e+00,
            

In [22]:
image_features.shape[1]

768

In [23]:
image_features

tensor([[ 1.2705e+00,  9.8096e-01, -1.7273e-01,  9.5410e-01,  1.1683e-03,
         -5.1660e-01,  6.2549e-01, -4.0601e-01,  6.5796e-02, -6.9678e-01,
         -2.4231e-01, -2.2302e-01,  5.1575e-02,  3.3838e-01, -2.2327e-01,
          1.9739e-01,  4.9243e-01, -5.0293e-01, -1.3525e-01, -1.1462e-01,
         -4.3530e-01,  1.6223e-01,  4.1943e-01, -1.5857e-01, -6.4746e-01,
         -2.3145e-01, -1.7957e-01,  5.1074e-01, -4.3237e-01,  2.2852e-01,
          3.9612e-02,  8.3008e-02,  1.1157e-01, -1.0150e-01, -5.2002e-01,
          6.8066e-01,  8.5876e-02, -3.6206e-01, -4.0918e-01,  7.3425e-02,
          1.4636e-01,  7.0361e-01,  4.6167e-01, -1.4990e-01,  6.3477e-01,
          3.0664e-01,  1.4258e-01,  1.6626e-01,  3.4766e-01, -1.3416e-01,
         -1.6614e-01, -5.3320e-01, -3.5065e-02,  5.2686e-01, -2.7026e-01,
          5.5957e-01, -2.7734e-01,  1.0199e-01,  1.2262e-01, -7.7734e-01,
         -2.0911e-01,  2.4670e-01,  4.0588e-02,  1.2390e-01,  6.4990e-01,
          7.3242e-01, -2.5122e-01,  1.

In [24]:
image_features.unsqueeze(0).shape

torch.Size([1, 1, 768])

In [25]:
multihead_attention_test = nn.MultiheadAttention(768, 1)
multihead_attention_test.to(device)

MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)

In [26]:
x = image_features.to(torch.float32).unsqueeze(0)
output_test = multihead_attention_test(x,x,x)
output_test[0].shape

torch.Size([1, 1, 768])

In [27]:
output_test

(tensor([[[ 1.7690e-01,  6.9132e-02, -2.9283e-01,  3.3537e-02,  2.0485e-01,
           -3.6781e-01,  3.0815e-01, -4.1029e-01,  2.7129e-02, -1.0344e-01,
           -3.2819e-01,  3.8913e-01,  1.6188e-01, -1.8534e-01, -4.2892e-01,
           -3.0699e-02,  5.5403e-01, -2.9108e-01,  1.2456e-01, -6.8365e-02,
           -1.0143e+00, -3.4592e-01,  1.6917e-01, -1.6696e-01,  6.8395e-02,
           -7.6003e-02,  2.9042e-02,  2.7340e-01,  3.1190e-01,  5.4855e-01,
           -2.0090e-01,  7.8710e-02, -6.2051e-01,  2.6042e-01, -5.0313e-01,
           -1.9233e-01, -4.4314e-01, -1.0226e-01, -1.8998e-03, -2.5906e-01,
           -1.0640e-01, -1.2601e-02,  2.0295e-01,  2.8617e-01,  4.4771e-02,
           -1.7724e-01,  9.0346e-03, -3.1555e-02, -1.5884e-01, -1.1444e-01,
           -1.6712e-01,  1.5985e-01,  2.6544e-03,  1.3330e-01,  3.6648e-01,
            3.3451e-02, -2.9068e-01,  7.4172e-02,  3.0254e-01, -1.0031e-01,
            7.8646e-03,  2.9873e-01,  1.4951e-01,  3.1604e-01,  8.4016e-02,
           -

In [2]:
import torch

In [3]:
k = torch.ones((1,4,90,160,3))
q = torch.ones((1,4,90,160))

In [6]:
k.shape[1]

4

In [4]:
k[:,0,:,:,:].shape

torch.Size([1, 90, 160, 3])

In [7]:
k.transpose(-2, -1).shape

torch.Size([1, 1536, 128, 2])

In [8]:
torch.matmul(q, k.transpose(-2, -1))

RuntimeError: The size of tensor a (2) must match the size of tensor b (1536) at non-singleton dimension 1

In [9]:
k = torch.ones((1536, 128, 2))
q = torch.ones((2,1536,128))

In [10]:
torch.matmul(q, k)

RuntimeError: The size of tensor a (2) must match the size of tensor b (1536) at non-singleton dimension 0