<a href="https://colab.research.google.com/github/JacopoMangiavacchi/animegan2-coreml/blob/master/convert_coreml.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import cv2
import matplotlib.pyplot as plt
import torch
import random
import numpy as np

In [2]:
# Fetch the project repository; later cells chdir into it, import `model`
# from it and load checkpoints from its `weights/` directory.
!git clone https://github.com/jacopomangiavacchi/animegan2-coreml

Cloning into 'animegan2-coreml'...
remote: Enumerating objects: 46, done.[K
remote: Counting objects: 100% (46/46), done.[K
remote: Compressing objects: 100% (40/40), done.[K
remote: Total 46 (delta 11), reused 34 (delta 5), pack-reused 0[K
Unpacking objects: 100% (46/46), done.


In [3]:
# Work from inside the cloned repository so the relative `model` import and
# `weights/` paths used below resolve. The guard keeps the cell re-runnable:
# once we are already inside the checkout, a second relative chdir would fail.
if os.path.basename(os.getcwd()) != 'animegan2-coreml':
    os.chdir('./animegan2-coreml')

In [4]:
from model import Generator

def load_image(path, size=None):
    """Read an image file and return it as a normalized, center-cropped tensor.

    The file is decoded with OpenCV, converted from BGR to RGB and passed
    through ``image2tensor``. A non-square image is center-cropped to a square
    whose side is its shorter edge. If ``size`` is given and differs from that
    side, the result is bilinearly resized to ``size`` x ``size``.

    Args:
        path: Path of the image file to read.
        size: Optional side length of the returned square image.

    Returns:
        The tensor produced by ``image2tensor`` (batch and channel dimensions
        first, spatial dimensions last), cropped and optionally resized.

    Raises:
        OSError: If OpenCV cannot read an image from ``path``.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread reports a missing or undecodable file by returning None
        # instead of raising; fail here with the offending path rather than
        # with an opaque error from cvtColor.
        raise OSError(f"could not read image from {path!r}")
    image = image2tensor(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

    # image2tensor puts the spatial dimensions last, as (height, width).
    h, w = image.shape[-2:]
    if h != w:
        crop_size = min(h, w)
        top = (h - crop_size) // 2
        left = (w - crop_size) // 2
        image = image[:, :, top:top + crop_size, left:left + crop_size]

    if size is not None and image.shape[-1] != size:
        image = torch.nn.functional.interpolate(image, (size, size), mode="bilinear", align_corners=True)

    return image

def image2tensor(image):
    """Convert a height x width x channel image into a batched float tensor.

    The channel axis is moved to the front, a leading batch axis of size 1 is
    added, and values are rescaled so that 0 maps to -1 and 255 maps to 1.
    """
    channels_first = torch.FloatTensor(image).permute(2, 0, 1)
    unit_range = channels_first.unsqueeze(0) / 255.
    return (unit_range - 0.5) / 0.5

def tensor2image(tensor):
    """Convert a channel-first tensor into a channel-last NumPy image.

    Values are clamped to [-1, 1], singleton dimensions are squeezed away, the
    channel axis is moved last, and the result is rescaled so that -1 maps to
    0 and 1 maps to 1.

    Args:
        tensor: Tensor that is three-dimensional (channel first) once its
            singleton dimensions are removed.

    Returns:
        A NumPy array with the channel axis last and values in [0, 1].
    """
    # Detach first and clamp out of place: the in-place clamp_ silently
    # overwrote the caller's tensor and raises on leaf tensors that require grad.
    clamped = tensor.detach().clamp(-1., 1.)
    array = clamped.squeeze().permute(1, 2, 0).cpu().numpy()
    return array * 0.5 + 0.5

def imshow(img, size=5, cmap='jet'):
    """Show ``img`` in a new square matplotlib figure with the axes hidden.

    Args:
        img: Image data accepted by ``plt.imshow``.
        size: Side length of the square figure, passed as ``figsize``.
        cmap: Colormap forwarded to ``plt.imshow``.
    """
    square = (size, size)
    plt.figure(figsize=square)
    plt.imshow(img, cmap=cmap)
    plt.axis('off')
    plt.show()

In [5]:
# Inference-only configuration: gradients are switched off globally, and the
# generator is put in eval mode on the chosen device.
torch.set_grad_enabled(False)
device = 'cpu'
image_size = 512

generator = Generator()
generator.eval()
model = generator.to(device)

In [19]:
# Checkpoint to convert. Names referenced by this notebook:
# "celeba_distill", "face_paint_512_v2", "paprika".
weights_name = "celeba_distill"

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
ckpt = torch.load(f"weights/{weights_name}.pt", map_location=device)
model.load_state_dict(ckpt)

<All keys matched successfully>

In [7]:
!pip install coremltools

Collecting coremltools
  Downloading coremltools-5.1.0-cp37-none-manylinux1_x86_64.whl (1.6 MB)
[?25l[K     |▏                               | 10 kB 21.1 MB/s eta 0:00:01[K     |▍                               | 20 kB 25.5 MB/s eta 0:00:01[K     |▋                               | 30 kB 18.9 MB/s eta 0:00:01[K     |▉                               | 40 kB 15.5 MB/s eta 0:00:01[K     |█                               | 51 kB 5.5 MB/s eta 0:00:01[K     |█▎                              | 61 kB 5.9 MB/s eta 0:00:01[K     |█▌                              | 71 kB 5.5 MB/s eta 0:00:01[K     |█▊                              | 81 kB 6.1 MB/s eta 0:00:01[K     |██                              | 92 kB 6.0 MB/s eta 0:00:01[K     |██                              | 102 kB 5.3 MB/s eta 0:00:01[K     |██▎                             | 112 kB 5.3 MB/s eta 0:00:01[K     |██▌                             | 122 kB 5.3 MB/s eta 0:00:01[K     |██▊                             | 133 kB 

In [21]:
# Random example input with the model's expected layout; it is run through the
# generator once here and reused as the example input for tracing.
dummy_shape = (1, 3, image_size, image_size)
image = torch.randn(dummy_shape)
output = model(image.to(device))

In [9]:
# Trace the generator with the example input. The input is moved to the model's
# device, as in the forward pass above, so tracing keeps working if `device`
# is changed from 'cpu'.
traced = torch.jit.trace(model, image.to(device))

  _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:]))


In [10]:
import coremltools

# Core ML applies `pixel * scale + bias` to each 0-255 input pixel before the
# network sees it. The helpers in this notebook (image2tensor / tensor2image)
# use the [-1, 1] convention, (pixel / 255 - 0.5) / 0.5 == pixel / 127.5 - 1,
# so the converted model must apply that same mapping rather than the
# ImageNet mean/std normalization.
scale = 1 / 127.5
bias = [-1.0, -1.0, -1.0]

# Use image_size (defined with the model setup) instead of a hard-coded 512 so
# the Core ML input always matches the shape the model was traced with.
image_input = coremltools.ImageType(name="input_1",
                                    shape=(1, 3, image_size, image_size),
                                    scale=scale, bias=bias)



In [11]:
# Convert the traced TorchScript module to Core ML, declaring its single input
# as an image so the preprocessing configured above is applied.
mlmodel = coremltools.convert(traced, inputs=[image_input])

Converting Frontend ==> MIL Ops: 100%|█████████▉| 295/296 [00:00<00:00, 464.07 ops/s]
Running MIL Common passes: 100%|██████████| 34/34 [00:00<00:00, 56.21 passes/s]
Running MIL Clean up passes: 100%|██████████| 9/9 [00:00<00:00, 69.52 passes/s]
Translating MIL ==> NeuralNetwork Ops: 100%|██████████| 802/802 [00:00<00:00, 2251.57 ops/s]


In [12]:
# Write the converted model to disk.
mlmodel_path = 'animegan2-celeba-distil-512.mlmodel'
mlmodel.save(mlmodel_path)

In [13]:
# Load the saved model's protobuf spec so its description can be inspected and edited.
saved_model_path = "animegan2-celeba-distil-512.mlmodel"
spec = coremltools.utils.load_spec(saved_model_path)

In [14]:
# Show how the model's input is declared in the spec.
input_description = spec.description.input
print(input_description)

[name: "input_1"
type {
  imageType {
    width: 512
    height: 512
    colorSpace: RGB
  }
}
]


In [15]:
# Show how the model's output is declared in the spec.
output_description = spec.description.output
print(output_description)

[name: "var_457"
type {
  multiArrayType {
    dataType: FLOAT32
  }
}
]


In [16]:
import coremltools.proto.FeatureTypes_pb2 as ft

# Re-declare the model's first output as an RGB image instead of a multi-array.
output = spec.description.output[0]

# Use image_size rather than a hard-coded 512 so the declared output stays in
# step with the shape the model was traced and converted with.
output.type.imageType.colorSpace = ft.ImageFeatureType.RGB
output.type.imageType.height = image_size
output.type.imageType.width = image_size

# NOTE(review): this only changes the declared type. tensor2image clamps the
# generator output to [-1, 1], which suggests the raw values are not in the
# 0-255 pixel range an image output presumably expects - confirm whether an
# output scaling layer is needed.
coremltools.utils.save_spec(spec, "animegan2-celeba-distil-512.mlmodel")

In [17]:
# Reload the spec from disk to check that the edited output type was persisted.
patched_model_path = "animegan2-celeba-distil-512.mlmodel"
spec = coremltools.utils.load_spec(patched_model_path)

In [18]:
# Show the output declaration read back from the saved file.
patched_output_description = spec.description.output
print(patched_output_description)

[name: "var_457"
type {
  imageType {
    width: 512
    height: 512
    colorSpace: RGB
  }
}
]
