In [2]:
import torch
import clip
from PIL import Image
import numpy as np
import torchvision

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
model_ViT_B_32, preprocess_ViT_B_32 = clip.load("ViT-B/32", device=device)

In [4]:
image = Image.open("test_car.png")
print(np.array(image).shape)
image = preprocess_ViT_B_32(image).unsqueeze(0).to(device)
image.shape

(900, 1600, 3)


torch.Size([1, 3, 224, 224])

In [5]:
with torch.no_grad():
    image_features = model_ViT_B_32.encode_image(image)
image_features.shape

torch.Size([1, 512])

In [6]:
model_ViT_L_14, preprocess_ViT_L_14 = clip.load("ViT-L/14@336px", device=device)

In [7]:
image = Image.open("test_car.png")
np.array(image).shape

(900, 1600, 3)

In [8]:
image = preprocess_ViT_L_14(image).unsqueeze(0).to(device)
image.shape

torch.Size([1, 3, 336, 336])

In [9]:
with torch.no_grad():
    image_features = model_ViT_L_14.encode_image(image)

In [10]:
image_features.shape

torch.Size([1, 768])

In [11]:
image_features.squeeze().shape

torch.Size([768])

In [12]:
image_features

tensor([[ 3.2666e-01,  9.3115e-01,  1.8420e-01,  6.3281e-01, -1.4626e-02,
          6.6260e-01, -1.6492e-01, -2.3376e-01,  5.4688e-01, -3.8428e-01,
         -3.5498e-01, -1.2783e+00,  1.3281e-01,  1.9934e-01,  2.9053e-01,
         -8.6121e-02,  1.2524e-01, -4.2358e-01,  1.9934e-01, -4.0771e-01,
          1.6309e-01, -8.0200e-02, -8.9893e-01, -3.1177e-01, -5.9131e-01,
         -1.9482e-01,  1.6553e-01, -5.3516e-01, -8.1482e-02, -2.3340e-01,
          2.0190e-01,  1.1124e-02, -1.1581e-02, -3.4595e-01,  9.6875e-01,
          5.7227e-01,  3.0762e-01,  6.3354e-02,  7.9651e-02, -5.5859e-01,
          9.1846e-01, -2.6904e-01, -5.1709e-01, -2.9419e-01,  2.8979e-01,
          1.0742e+00, -4.2773e-01,  4.4800e-01, -2.0593e-01, -3.2178e-01,
          5.6982e-01,  3.4088e-02,  3.2690e-01,  3.5425e-01, -1.8738e-01,
          2.2446e-02,  6.2207e-01, -5.7129e-01, -6.0974e-02, -7.4658e-01,
         -6.6553e-01, -4.3262e-01, -6.6992e-01,  8.6609e-02,  1.5063e-01,
          1.7542e-01, -1.4679e-02,  6.

In [13]:
import torch
import torch.nn as nn

class mapper(nn.Module):
    def __init__(self,
                 input_dim = 512,
                 hid_layers = [128, 256, 128],
                 activation='Relu',
                 ):
        super().__init__()
        self.activation_dict = {
            'relu': nn.ReLU,
            'silu': nn.SiLU,
            'softplus': nn.Softplus,
            }

        self.activation = self.activation_dict[activation.lower()]

        mapper_net = []
        mapper_net.append(nn.Linear(input_dim, hid_layers[0]))
        for i in range(len(hid_layers)-1):
            mapper_net.append(nn.Linear(hid_layers[i], hid_layers[i + 1]))
            mapper_net.append(self.activation())
        self.mapper_net = nn.Sequential(*mapper_net)

    def forward(self, clip_feature):
        if len(clip_feature.shape) == 2:
            clip_feature = clip_feature.squeeze()
        latent_code = self.mapper_net(clip_feature)

        return latent_code

In [14]:
shape_mapper = mapper(input_dim = 768)
shape_mapper.to(device)

mapper(
  (mapper_net): Sequential(
    (0): Linear(in_features=768, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=256, bias=True)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
  )
)

In [15]:
output = shape_mapper(image_features.to(torch.float32))
print(output.shape)
output

torch.Size([128])


tensor([0.0000, 0.0311, 0.0000, 0.0322, 0.0567, 0.0512, 0.0000, 0.0579, 0.1202,
        0.0689, 0.0000, 0.0329, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0348, 0.0000, 0.1086, 0.0516, 0.0282, 0.1362, 0.0000, 0.0000, 0.1037,
        0.1329, 0.1313, 0.0000, 0.0000, 0.1590, 0.0000, 0.0000, 0.0588, 0.0000,
        0.0000, 0.1345, 0.0124, 0.0000, 0.0000, 0.0000, 0.0749, 0.0493, 0.0000,
        0.0000, 0.0944, 0.0000, 0.0231, 0.0194, 0.2232, 0.0559, 0.0487, 0.1052,
        0.0330, 0.1222, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0290, 0.0000,
        0.0867, 0.0000, 0.0000, 0.0649, 0.0000, 0.0000, 0.0000, 0.1980, 0.0414,
        0.0954, 0.1463, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0673, 0.0000, 0.0000, 0.1515, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.1243, 0.0135, 0.0539, 0.0501, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0559, 0.0128, 0.0611, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0900, 0.0828, 0.0000, 

In [16]:
class deformation(nn.Module):
    def __init__(self,
                 input_dim = 128,
                 hid_layers = [256, 256, 256, 256],
                 output_dim = 128,
                 activation='Relu',
                 output_activation = 'Tanh',
                 ):
        super().__init__()
        self.activation_dict = {
            'relu': nn.ReLU,
            'silu': nn.SiLU,
            'softplus': nn.Softplus,
            'tanh' : nn.Tanh,
            }

        self.activation = self.activation_dict[activation.lower()]
        self.output_activation = self.activation_dict[output_activation.lower()]

        deformation_net = []
        deformation_net.append(nn.Linear(input_dim, hid_layers[0]))
        for i in range(len(hid_layers)-1):
            deformation_net.append(nn.Linear(hid_layers[i], hid_layers[i + 1]))
            deformation_net.append(self.activation())

        deformation_net.append(nn.Linear(hid_layers[-1], output_dim))
        deformation_net.append(self.output_activation())
        self.deformation_net = nn.Sequential(*deformation_net)

    def forward(self, latent_code):
        if len(latent_code.shape) == 2:
            latent_code = latent_code.squeeze()
        output = self.deformation_net(latent_code)

        return output

In [17]:
deformation_network = deformation()
deformation_network.to(device)

deformation(
  (deformation_net): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=256, bias=True)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=256, bias=True)
    (4): ReLU()
    (5): Linear(in_features=256, out_features=256, bias=True)
    (6): ReLU()
    (7): Linear(in_features=256, out_features=128, bias=True)
    (8): Tanh()
  )
)

In [18]:
output_deformation = deformation_network(output)
output_deformation

tensor([ 4.6048e-02, -1.4516e-02, -3.8180e-02,  6.6768e-02, -3.9768e-02,
         7.1250e-02, -7.7884e-02, -6.6190e-02, -4.9468e-02, -2.3871e-02,
         3.6018e-04,  4.9383e-02,  1.6978e-02, -1.0703e-02,  3.1735e-02,
        -3.6205e-02,  9.4390e-03,  3.4777e-02, -4.5799e-02,  8.6324e-03,
        -6.9536e-03,  5.2328e-02,  1.0397e-02, -1.4017e-02, -3.9382e-02,
        -5.1413e-02, -8.3394e-04,  2.3591e-02,  6.0519e-02, -5.6688e-02,
        -4.8738e-02, -9.2779e-03, -4.4122e-02, -2.6135e-02,  1.9319e-03,
         4.5568e-02,  4.6879e-02,  7.4487e-02, -4.6337e-02,  2.9474e-02,
         5.0590e-02, -1.3055e-02,  6.8743e-04,  6.9240e-02,  4.4935e-02,
         4.8148e-02, -1.4407e-02, -1.3300e-02,  1.0922e-02,  3.6657e-02,
        -6.4424e-02,  1.3925e-03, -6.1156e-02,  1.1899e-02, -3.6294e-02,
         7.0185e-02, -5.7816e-02,  2.2234e-02,  5.1034e-02, -5.7347e-02,
         1.0293e-02, -1.2130e-02,  5.0040e-03,  2.4692e-02,  4.6353e-02,
         8.3159e-04,  2.3148e-02,  1.3143e-02,  4.4

In [49]:

class mapper_attention(nn.Module):
    def __init__(self, dim_input = 512, dim_embed = None, input_embedding = False, num_head = 1, dropout=0.1):
        super().__init__()
        self.dim_input = dim_input
        self.num_head = num_head
        if dim_embed == None:
            self.dim_embed = self.dim_input
        else:
            self.dim_embed = dim_embed
        self.dim_head = self.dim_embed // self.num_head

        if input_embedding is True:
            self.embed_fc = nn.Linear(self.dim_input, self.dim_embed)

        self.q_fc = nn.Linear(self.dim_input, self.dim_embed, bias=False)
        self.k_fc = nn.Linear(self.dim_input, self.dim_embed, bias=False)
        self.v_fc = nn.Linear(self.dim_input, self.dim_embed, bias=False)

        self.layer_norm = nn.LayerNorm(self.dim_input)

        self.out_fc = nn.Linear(self.dim_embed, self.dim_input, bias=False)

        self.dropout = nn.Dropout(dropout)

        self.softmax = nn.Softmax(dim=2)

    def forward(self, input, mask = None):

        if len(input.shape) == 2:
            input = input.unsqueeze(0)
        residual = input
        input = self.layer_norm(input)

        '每张图片经过clip的image_encoder之后输出的feature的特征为(1,512)或(1,768), 所以这里input的shape为(batch_size, 1, 512)或(batch_size, 1, 768)'
        q = self.q_fc(input)
        q = q.view(input.shape[0], input.shape[1], self.num_head, self.dim_head)
        # q = q.permute(2, 0, 1, 3)  # (self.num_head * batch_input) x len_input x self.dim_head
        q = q.permute(0, 2, 1, 3)  # (batch_input * self.num_head) x len_input x self.dim_head

        k = self.k_fc(input)
        k = k.view(input.shape[0], input.shape[1], self.num_head, self.dim_head)
        # k = k.permute(2, 0, 1, 3)  # (self.num_head * batch_input) x len_input x self.dim_head
        k = k.permute(0, 2, 1, 3)  # (batch_input * self.num_head) x len_input x self.dim_head

        v = self.v_fc(input)
        v = v.view(input.shape[0], input.shape[1], self.num_head, self.dim_head)
        # v = v.permute(2, 0, 1, 3)  # (self.num_head * batch_input) x len_input x self.dim_head
        k = k.permute(0, 2, 1, 3)  # (batch_input * self.num_head) x len_input x self.dim_head
        
        # attn = torch.bmm(q, k.transpose(1, 2)) / np.power(self.dim_head, 0.5)
        attn = torch.matmul(q, k.transpose(-2, -1)) / np.power(self.dim_head, 0.5)
        
        if mask is not None:
            mask = mask.repeat(self.num_head, 1, 1)  # (self.num_head * batch_input) x .. x ..
            attn = attn.masked_fill(mask, -np.inf)
        attn = self.softmax(attn)
        attn = self.dropout(attn)

        # output = torch.bmm(attn, v)
        output = torch.matmul(attn, v)
        
        output = output.view(self.num_head, input.shape[0], input.shape[1], self.dim_head)
        # output = output.permute(1, 2, 0, 3)
        output = output.permute(0, 2, 1, 3)

        
        output = self.dropout(self.out_fc(output))
        output = self.layer_norm(output + residual)

        return output

In [50]:
shape_mapper_transformer = mapper_attention(dim_input = image_features.shape[1])
shape_mapper_transformer.to(device)

mapper_attention(
  (q_fc): Linear(in_features=768, out_features=768, bias=False)
  (k_fc): Linear(in_features=768, out_features=768, bias=False)
  (v_fc): Linear(in_features=768, out_features=768, bias=False)
  (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (out_fc): Linear(in_features=768, out_features=768, bias=False)
  (dropout): Dropout(p=0.1, inplace=False)
  (softmax): Softmax(dim=2)
)

In [52]:
output_transformer = shape_mapper_transformer(image_features.to(torch.float32))
print(output_transformer.shape)
print(output_transformer.squeeze().shape)
output_transformer

torch.Size([1, 1, 1, 768])
torch.Size([768])


tensor([[[[ 3.3413e-01,  8.8698e-01,  7.7593e-01,  1.0596e+00, -1.5391e+00,
            2.4874e-02,  4.3457e-01,  6.2107e-02,  1.0272e+00,  1.0651e-01,
           -8.1953e-01, -1.4415e+00,  1.2380e+00,  3.8789e-01,  2.4002e-01,
            5.7627e-01,  1.7972e-01, -5.0310e-01, -7.8736e-01, -4.9860e-01,
           -4.6066e-01, -7.5757e-01, -9.0517e-01, -3.6410e-01, -1.3448e+00,
           -2.5748e-01,  2.2928e-01, -1.4861e+00,  3.2395e-01, -2.6667e-01,
           -8.8642e-02,  3.9246e-01,  7.2979e-01, -8.4601e-01,  1.6857e+00,
            1.7084e+00,  4.6110e-01, -6.6204e-01, -4.4706e-01,  2.9738e-01,
            1.2236e+00, -6.1285e-01, -1.0494e+00,  5.5557e-02,  1.0206e+00,
            1.6448e+00, -1.2919e+00,  7.6506e-01, -3.9065e-01, -1.0837e+00,
            3.3674e-01, -5.8128e-01,  6.9055e-01,  8.2792e-01, -2.0414e-02,
           -6.8276e-01,  9.2447e-01, -3.1239e-01, -3.2025e-02, -1.0134e+00,
           -2.6134e-01, -5.1886e-02, -8.0935e-01,  8.2384e-01, -7.0373e-01,
            

In [30]:
image_features.shape[1]

768

In [34]:
image_features

tensor([[ 3.2666e-01,  9.3115e-01,  1.8420e-01,  6.3281e-01, -1.4626e-02,
          6.6260e-01, -1.6492e-01, -2.3376e-01,  5.4688e-01, -3.8428e-01,
         -3.5498e-01, -1.2783e+00,  1.3281e-01,  1.9934e-01,  2.9053e-01,
         -8.6121e-02,  1.2524e-01, -4.2358e-01,  1.9934e-01, -4.0771e-01,
          1.6309e-01, -8.0200e-02, -8.9893e-01, -3.1177e-01, -5.9131e-01,
         -1.9482e-01,  1.6553e-01, -5.3516e-01, -8.1482e-02, -2.3340e-01,
          2.0190e-01,  1.1124e-02, -1.1581e-02, -3.4595e-01,  9.6875e-01,
          5.7227e-01,  3.0762e-01,  6.3354e-02,  7.9651e-02, -5.5859e-01,
          9.1846e-01, -2.6904e-01, -5.1709e-01, -2.9419e-01,  2.8979e-01,
          1.0742e+00, -4.2773e-01,  4.4800e-01, -2.0593e-01, -3.2178e-01,
          5.6982e-01,  3.4088e-02,  3.2690e-01,  3.5425e-01, -1.8738e-01,
          2.2446e-02,  6.2207e-01, -5.7129e-01, -6.0974e-02, -7.4658e-01,
         -6.6553e-01, -4.3262e-01, -6.6992e-01,  8.6609e-02,  1.5063e-01,
          1.7542e-01, -1.4679e-02,  6.

In [27]:
image_features.unsqueeze(0).shape

torch.Size([1, 1, 768])

In [56]:
multihead_attention_test = nn.MultiheadAttention(768, 1)
multihead_attention_test.to(device)

MultiheadAttention(
  (out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)

In [61]:
x = image_features.to(torch.float32).unsqueeze(0)
output_test = multihead_attention_test(x,x,x)
output_test[0].shape

torch.Size([1, 1, 768])

In [62]:
output_test

(tensor([[[ 5.7019e-01,  1.2763e-01,  3.3382e-02, -1.8587e-01,  6.6203e-02,
           -1.2252e-01,  3.9429e-03, -1.8167e-02,  5.2320e-01, -3.1655e-01,
           -6.1578e-01,  2.1887e-01, -1.4394e-03, -5.7475e-02,  6.2321e-02,
           -2.5742e-02,  1.0146e-01,  1.3037e-01,  1.8281e-01,  3.3499e-01,
            2.1980e-03, -1.7023e-01,  1.1006e-01, -6.9741e-02, -6.2091e-02,
            4.9941e-02,  1.0687e-01, -1.9786e-02,  2.5288e-02,  3.8842e-01,
            2.3025e-01, -3.5966e-01, -2.2405e-01, -1.9422e-01,  4.8852e-02,
            5.8021e-01, -1.4068e-02, -1.0356e-01,  4.6665e-01,  2.4847e-02,
            2.6535e-01, -7.1577e-01, -1.3875e-02,  2.1711e-01,  9.2100e-02,
            2.4983e-01, -2.1090e-01,  1.5366e-01,  3.7736e-01,  4.2582e-01,
            1.5394e-01,  2.1856e-02,  9.6780e-02,  2.4154e-01, -9.4067e-02,
            1.1988e-01,  2.5306e-01,  1.0430e-01, -2.8797e-01, -3.2132e-02,
           -2.8696e-01,  2.6184e-01, -3.5610e-02,  7.3354e-02, -6.7447e-02,
           -