In [1]:
import coremltools as ct

In [2]:
# Download the ImageNet class labels (one label per line) from a separate file.
# `import urllib` alone does not guarantee that the `urllib.request` submodule
# is loaded, so import the submodule explicitly.
import urllib.request

label_url = 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'

# Use the response as a context manager so the HTTP connection is closed.
with urllib.request.urlopen(label_url) as response:
    raw_labels = response.read().splitlines()  # list of bytes, one per line

# The first entry in the file is the background class; drop it so the list
# lines up with the 1000 ImageNet categories.
# Core ML expects str labels, and `read()` returns bytes, so decode each line.
class_labels = [label.decode("utf8") for label in raw_labels[1:]]
assert len(class_labels) == 1000, f"expected 1000 ImageNet labels, got {len(class_labels)}"

In [3]:
# Attaching the labels makes the converters emit a classifier: a "classLabel"
# string output plus a per-category probability dictionary.
classifier_config = ct.ClassifierConfig(class_labels=class_labels)

# TensorFlow (tf.keras)

In [4]:
from tensorflow.keras.applications import mobilenet_v2

In [5]:
# Keras MobileNetV2 with its classification head and ImageNet weights.
tf_model = mobilenet_v2.MobileNetV2(weights='imagenet', include_top=True)

In [6]:
# Core ML computes scale * pixel + bias per channel, so pixels in [0, 255]
# become pixel * 2/255 - 1, i.e. the [-1, 1] range that Keras'
# mobilenet_v2.preprocess_input produces.
image_input = ct.ImageType(bias=[-1, -1, -1], scale=2 / 255)

mlmodel_tf = ct.convert(
    tf_model,
    inputs=[image_input],
    classifier_config=classifier_config,
)

Running TensorFlow Graph Passes: 100%|██████████| 5/5 [00:00<00:00,  7.74 passes/s]
Converting Frontend ==> MIL Ops: 100%|██████████| 428/428 [00:01<00:00, 291.40 ops/s] 
Running MIL optimization passes: 100%|██████████| 18/18 [00:01<00:00, 14.93 passes/s]
Translating MIL ==> MLModel Ops: 100%|██████████| 751/751 [00:00<00:00, 1124.17 ops/s]


In [7]:
# Set feature descriptions (these show up as comments in Xcode).
# Read the feature names from the converted spec instead of hard-coding them.
# Keras auto-generates the input layer name: "input_1" in a fresh session, but
# it can get another suffix (e.g. "input_2") if the model cell is re-run in the
# same kernel, and then a hard-coded name no longer matches.
tf_description = mlmodel_tf.get_spec().description  # get_spec() returns a copy; read-only use here
tf_input_name = tf_description.input[0].name  # assumes a single image input (as printed below)
tf_label_name = tf_description.predictedFeatureName
tf_probs_name = tf_description.predictedProbabilitiesName

mlmodel_tf.input_description[tf_input_name] = "Input image to be classified"
mlmodel_tf.output_description[tf_label_name] = "Most likely image category"
mlmodel_tf.output_description[tf_probs_name] = "Probability of each image category"
mlmodel_tf.short_description = "MobileNet v2 converted from TensorFlow"
mlmodel_tf.version = "1.0"

In [8]:
# Serialize the converted TensorFlow model to disk.
tf_mlmodel_path = "MobileNetV2_TF.mlmodel"
mlmodel_tf.save(tf_mlmodel_path)

In [9]:
# Inspect the converted model's interface: inputs, outputs and metadata.
tf_summary = str(mlmodel_tf)
print(tf_summary)

input {
  name: "input_1"
  shortDescription: "Input image to be classified"
  type {
    imageType {
      width: 224
      height: 224
      colorSpace: RGB
      imageSizeRange {
        widthRange {
          lowerBound: 224
          upperBound: 224
        }
        heightRange {
          lowerBound: 224
          upperBound: 224
        }
      }
    }
  }
}
output {
  name: "Identity"
  shortDescription: "Probability of each image category"
  type {
    dictionaryType {
      stringKeyType {
      }
    }
  }
}
output {
  name: "classLabel"
  shortDescription: "Most likely image category"
  type {
    stringType {
    }
  }
}
predictedFeatureName: "classLabel"
predictedProbabilitiesName: "Identity"
metadata {
  shortDescription: "MobileNet v2 converted from TensorFlow"
  versionString: "1.0"
  userDefined {
    key: "com.github.apple.coremltools.source"
    value: "tensorflow==2.1.0"
  }
  userDefined {
    key: "com.github.apple.coremltools.version"
    value: "4.1"
  }
}



# PyTorch

In [10]:
import torch
import torchvision

In [11]:
import torch.nn as nn

In [12]:
# Wrapper that passes the pre-trained model's output through softmax, so the
# traced graph ends in class probabilities rather than raw scores.
class WrappedMobileNetV2(nn.Module):
    """torchvision MobileNetV2 whose forward pass returns softmax probabilities.

    The converted Core ML classifier exposes this output as its per-category
    probability dictionary, so it must contain probabilities, not logits.

    Args:
        pretrained: Load torchvision's ImageNet weights (default ``True``, the
            original behaviour). ``False`` builds a randomly initialised
            network and skips the weight download.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.model = torchvision.models.mobilenet_v2(pretrained=pretrained).eval()

    def forward(self, x):
        """Return class probabilities for the image batch ``x``.

        Softmax is taken over dim 1; assumes ``self.model`` returns
        (batch, num_classes) scores.
        """
        # Functional softmax: no need to construct a new nn.Softmax module on
        # every call. It traces to the same softmax op.
        return torch.softmax(self.model(x), dim=1)

In [13]:
# Build the wrapper and switch it (and every submodule) to inference mode
# before tracing; eval() modifies the module in place.
traceable_model = WrappedMobileNetV2()
traceable_model.eval()

In [14]:
# Dummy batch of one 3-channel 224x224 image, used only to record the graph
# with torch.jit.trace. A seeded local generator makes re-runs reproducible
# without touching the global RNG state.
example_input = torch.rand(1, 3, 224, 224, generator=torch.Generator().manual_seed(0))
traced_model = torch.jit.trace(traceable_model, example_input)

In [24]:
# torchvision's pretrained weights expect ImageNet normalisation,
# (pixel / 255 - mean) / std, not the [-1, 1] scaling used for the Keras model.
# Core ML applies y = scale * pixel + bias with a single scalar scale, so the
# bias is exact per channel and the three stds are approximated by their
# mean (~0.226).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
torch_scale = 1 / (255.0 * (sum(IMAGENET_STD) / len(IMAGENET_STD)))
torch_bias = [-mean / std for mean, std in zip(IMAGENET_MEAN, IMAGENET_STD)]

# Name the input "input_1" to match the TensorFlow model's input name.
image_input = ct.ImageType(name="input_1", shape=example_input.shape,
                           scale=torch_scale, bias=torch_bias)

mlmodel_torch = ct.convert(traced_model,
                           inputs=[image_input],
                           classifier_config=classifier_config)

Converting Frontend ==> MIL Ops: 100%|█████████▉| 386/387 [00:01<00:00, 377.46 ops/s]
Running MIL optimization passes: 100%|██████████| 18/18 [00:00<00:00, 51.25 passes/s]
Translating MIL ==> MLModel Ops: 100%|██████████| 706/706 [00:00<00:00, 765.32 ops/s] 


In [25]:
print(mlmodel_torch)

input {
  name: "input_1"
  type {
    imageType {
      width: 224
      height: 224
      colorSpace: RGB
    }
  }
}
output {
  name: "649"
  type {
    dictionaryType {
      stringKeyType {
      }
    }
  }
}
output {
  name: "classLabel"
  type {
    stringType {
    }
  }
}
predictedFeatureName: "classLabel"
predictedProbabilitiesName: "649"
metadata {
  userDefined {
    key: "com.github.apple.coremltools.source"
    value: "torch==1.6.0"
  }
  userDefined {
    key: "com.github.apple.coremltools.version"
    value: "4.1"
  }
}



In [39]:
# Read the converter's auto-generated name for the probabilities output
# (e.g. "649") directly from the model spec, rather than parsing the
# printed text representation.
node_name = mlmodel_torch.get_spec().description.predictedProbabilitiesName
node_name

'649'

In [40]:
# Rename the auto-generated probabilities output to "Identity" so the PyTorch
# model exposes the same output name as the TensorFlow model, then rebuild the
# MLModel from the edited spec.
spec = mlmodel_torch.get_spec()
ct.utils.rename_feature(spec, current_name=node_name, new_name="Identity")
mlmodel_torch = ct.models.MLModel(spec)

In [41]:
# Attach human-readable feature descriptions (Xcode shows them as comments).
torch_output_notes = {
    "classLabel": "Most likely image category",
    "Identity": "Probability of each image category",
}
mlmodel_torch.input_description["input_1"] = "Input image to be classified"
for feature_name, note in torch_output_notes.items():
    mlmodel_torch.output_description[feature_name] = note
mlmodel_torch.short_description = "MobileNet v2 converted from PyTorch"
mlmodel_torch.version = "1.0"

In [42]:
# Serialize the converted PyTorch model to disk.
torch_mlmodel_path = "MobileNetV2_Torch.mlmodel"
mlmodel_torch.save(torch_mlmodel_path)

In [43]:
# Inspect the final interface: input, outputs and metadata.
torch_summary = str(mlmodel_torch)
print(torch_summary)

input {
  name: "input_1"
  shortDescription: "Input image to be classified"
  type {
    imageType {
      width: 224
      height: 224
      colorSpace: RGB
    }
  }
}
output {
  name: "Identity"
  shortDescription: "Probability of each image category"
  type {
    dictionaryType {
      stringKeyType {
      }
    }
  }
}
output {
  name: "classLabel"
  shortDescription: "Most likely image category"
  type {
    stringType {
    }
  }
}
predictedFeatureName: "classLabel"
predictedProbabilitiesName: "Identity"
metadata {
  shortDescription: "MobileNet v2 converted from PyTorch"
  versionString: "1.0"
  userDefined {
    key: "com.github.apple.coremltools.source"
    value: "torch==1.6.0"
  }
  userDefined {
    key: "com.github.apple.coremltools.version"
    value: "4.1"
  }
}

