In [46]:
import torch
import clip
import coremltools as ct
import numpy as np
from PIL import Image

# 1. Load ViT-B/32 CLIP model

In [47]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("IMG_3628.jpg")).unsqueeze(0).to(device)
#text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
text = clip.tokenize("a diagram").to(device)


with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

traced = torch.jit.trace(model, (image, text))
print("Label probs:", probs)
#print("traced:", traced)

Label probs: [[1.]]


# 2. Export TextEncoder

In [48]:
import torch.nn as nn
from collections import OrderedDict

class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
    
class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)

class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)

class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)

In [49]:
import torch.nn as nn

class TextEncoder(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        self.transformer = Transformer(
                width=transformer_width,
                layers=transformer_layers,
                heads=transformer_heads,
                attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.temperature = nn.Parameter(torch.tensor(0.07))

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

        print(f"text_projection shape: {self.text_projection.shape}")
        self.dtype = torch.float32

        self.initialize_parameters()
    
    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
        else:
            nn.init.normal_(self.text_projection, std=self.custom_text_config['text_rep_size'] ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def forward(self, text):
        # print(f'text: {text}')
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

In [50]:
text_encoder = TextEncoder(embed_dim=512, context_length=77, vocab_size=49408, 
                           transformer_width=512, transformer_heads=8, transformer_layers=12)

text_projection shape: torch.Size([512, 512])


In [51]:
text_encoder.load_state_dict(model.state_dict(), strict=False)

_IncompatibleKeys(missing_keys=['temperature'], unexpected_keys=['visual.class_embedding', 'visual.positional_embedding', 'visual.proj', 'visual.conv1.weight', 'visual.ln_pre.weight', 'visual.ln_pre.bias', 'visual.transformer.resblocks.0.attn.in_proj_weight', 'visual.transformer.resblocks.0.attn.in_proj_bias', 'visual.transformer.resblocks.0.attn.out_proj.weight', 'visual.transformer.resblocks.0.attn.out_proj.bias', 'visual.transformer.resblocks.0.ln_1.weight', 'visual.transformer.resblocks.0.ln_1.bias', 'visual.transformer.resblocks.0.mlp.c_fc.weight', 'visual.transformer.resblocks.0.mlp.c_fc.bias', 'visual.transformer.resblocks.0.mlp.c_proj.weight', 'visual.transformer.resblocks.0.mlp.c_proj.bias', 'visual.transformer.resblocks.0.ln_2.weight', 'visual.transformer.resblocks.0.ln_2.bias', 'visual.transformer.resblocks.1.attn.in_proj_weight', 'visual.transformer.resblocks.1.attn.in_proj_bias', 'visual.transformer.resblocks.1.attn.out_proj.weight', 'visual.transformer.resblocks.1.attn.ou

In [52]:
import coremltools as ct

text_encoder.eval()

example_input = clip.tokenize("a diagram").to(device)
traced_model = torch.jit.trace(text_encoder, example_input)
out = traced_model(example_input)

In [53]:
text_encoder_model = ct.convert(
            traced_model,
            convert_to="mlprogram",
            minimum_deployment_target=ct.target.iOS16,
            inputs=[ct.TensorType(name="prompt",
                                 shape=example_input.shape)],
            outputs=[ct.TensorType(name="embOutput", dtype=np.float32)],
#             compute_units=ct.ComputeUnit[args.compute_unit],
            # skip_model_load=True,
        )

Converting PyTorch Frontend ==> MIL Ops:   0%|                  | 0/972 [00:00<?, ? ops/s]Saving value type of int64 into a builtin type of int32, might lose precision!
Converting PyTorch Frontend ==> MIL Ops: 100%|████▉| 971/972 [00:00<00:00, 10912.19 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 192.92 passes/s]
Running MIL default pipeline: 100%|██████████████████| 57/57 [00:01<00:00, 28.90 passes/s]
Running MIL backend_mlprogram pipeline: 100%|███████| 10/10 [00:00<00:00, 748.90 passes/s]


In [54]:
text_encoder_model.save("TextEncoder_float32.mlpackage")

## Validate export  precision

In [55]:
import coremltools as ct

# Load the model
model = ct.models.MLModel('TextEncoder_float32.mlpackage')
text = clip.tokenize("a diagram").to(device)
predictions = model.predict({'prompt': text})

print("PyTorch TextEncoder ckpt out for \"a diagram\":\n>>>", out[0, :10])
print("\nCoreML TextEncoder ckpt out for \"a diagram\":\n>>>", predictions['embOutput'][0, :10])

PyTorch TextEncoder ckpt out for "a diagram":
>>> tensor([ 0.0547, -0.0061,  0.0495,  0.0106,  0.1107, -0.2575, -0.2108, -1.3542,
         0.4390, -0.1328], grad_fn=<SliceBackward0>)

CoreML TextEncoder ckpt out for "a diagram":
>>> [ 0.02310181  0.06121826 -0.10656738  0.07983398  0.26879883  0.09619141
 -0.07788086  0.06005859  0.03533936  0.01570129]


**You can see that there is some loss in precision, but it is still acceptable.**

# 3. Export ImageEncoder

In [56]:
import torch
import clip
import coremltools as ct
import numpy as np
from PIL import Image

In [57]:
device="cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
i = Image.open("IMG_3628.jpg")
image_orig = preprocess(i).unsqueeze(0).to(device)

In [58]:
traced_image_only = torch.jit.trace(model.visual, image_orig)
out = traced_image_only(image_orig)

In [59]:
import coremltools as ct
# Set the image scale and bias for input image preprocessing
scale = 1/(0.2685697*255.0)
bias = [- 0.48145466/(0.26862954) , - 0.4578275/(0.26130258), - 0.40821073/(0.27577711)]

# imgPIL = Image.open("4111670639918_.pic.png")

image_input_scale = ct.ImageType(name="colorImage",
                           color_layout=ct.colorlayout.RGB,
                           shape=image_orig.shape,
                           scale=scale, bias=bias)


image_encoder_model = ct.convert(
            traced_image_only,
            convert_to="mlprogram",
            minimum_deployment_target=ct.target.iOS16,
            inputs=[image_input_scale],
            outputs=[ct.TensorType(name="embOutput", dtype=np.float32)],
        )


Converting PyTorch Frontend ==> MIL Ops: 100%|████▉| 970/971 [00:00<00:00, 10550.50 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 195.10 passes/s]
Running MIL default pipeline: 100%|██████████████████| 57/57 [00:02<00:00, 26.89 passes/s]
Running MIL backend_mlprogram pipeline: 100%|███████| 10/10 [00:00<00:00, 754.19 passes/s]


In [60]:
image_encoder_model.save("ImageEncoder_float32.mlpackage")

## Validate export

In [62]:
import coremltools as ct

# Load the model
image_encoder = ct.models.MLModel('ImageEncoder_float32.mlpackage')

from torchvision import transforms
imgPIL = Image.open("IMG_3628.jpg")
imgPIL = imgPIL.resize((224, 224), Image.BICUBIC)
predictions = image_encoder.predict({'colorImage': imgPIL})

print("PyTorch ImageEncoder ckpt out for IMG_3628.jpg:\n>>>", out[0, :10])
print("\nCoreML ImageEncoder ckpt out for IMG_3628.jpg:\n>>>", predictions['embOutput'][0, :10])

PyTorch ImageEncoder ckpt out for IMG_3628.jpg:
>>> tensor([ 5.0707e-01,  3.0469e-04,  9.7929e-03, -1.6034e-01,  2.0017e-01,
        -2.5040e-01,  1.2528e-01, -1.4575e-01,  7.5148e-01,  7.2960e-02],
       grad_fn=<SliceBackward0>)

CoreML ImageEncoder ckpt out for IMG_3628.jpg:
>>> [ 0.06161072 -0.06282079  0.37438184 -0.25163764  0.23998612 -0.24717407
  0.71466494  0.2927076   0.64987004  0.10319111]


This time <span style='color:red'> the precision error is larger.</span> This may be caused by the wrong norm. 

## What if no norm?

In [45]:
image_input_scale = ct.ImageType(name="colorImage",
                           color_layout=ct.colorlayout.RGB,
                           shape=image_orig.shape)


image_encoder_model = ct.convert(
            traced_image_only,
            convert_to="mlprogram",
            minimum_deployment_target=ct.target.iOS16,
            inputs=[image_input_scale],
            outputs=[ct.TensorType(name="embOutput", dtype=np.float32)],
        )

image_encoder_model.save("ImageEncoder_float32.mlpackage")

image_encoder = ct.models.MLModel('ImageEncoder_float32.mlpackage')

from torchvision import transforms
imgPIL = Image.open("IMG_3628.jpg")
imgPIL = imgPIL.resize((224, 224), Image.BICUBIC)
predictions = image_encoder.predict({'colorImage': imgPIL})

print("PyTorch ImageEncoder ckpt out for IMG_3628.jpg:\n>>>", out[0, :10])
print("\nCoreML ImageEncoder ckpt out for IMG_3628.jpg:\n>>>", predictions['embOutput'][0, :10])

Converting PyTorch Frontend ==> MIL Ops: 100%|████▉| 970/971 [00:00<00:00, 10112.61 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 189.16 passes/s]
Running MIL default pipeline: 100%|██████████████████| 57/57 [00:02<00:00, 26.70 passes/s]
Running MIL backend_mlprogram pipeline: 100%|███████| 10/10 [00:00<00:00, 876.81 passes/s]


PyTorch ImageEncoder ckpt out for IMG_3628.jpg:
>>> tensor([ 5.0707e-01,  3.0469e-04,  9.7929e-03, -1.6034e-01,  2.0017e-01,
        -2.5040e-01,  1.2528e-01, -1.4575e-01,  7.5148e-01,  7.2960e-02],
       grad_fn=<SliceBackward0>)

CoreML ImageEncoder ckpt out for IMG_3628.jpg:
>>> [-0.06622314  0.02362061 -0.09057617  0.07165527  0.55322266 -0.4963379
  0.26123047  0.7270508   0.71484375  0.10101318]


**The error is even worse.**