# Convert TFLite model to PyTorch

This uses the model **iris_landmark.tflite** from [MediaPipe](https://github.com/google/mediapipe/tree/master/mediapipe/models).

Prerequisites:

1) Clone the MediaPipe repo:

```
git clone https://github.com/google/mediapipe.git
```

2) Install **flatbuffers**:

```
git clone https://github.com/google/flatbuffers.git
cd flatbuffers
cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release
make -j

cd python
python setup.py install
cd ../..
```

3) Clone the TensorFlow repo. We only need it for the FlatBuffers schema file (alternatively, you could just download [schema.fbs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs)).

```
git clone https://github.com/tensorflow/tensorflow.git
```

4) Compile the schema file to Python with **flatc**. This generates a `tflite` package (one module per schema type) in the current directory:

```
./flatbuffers/flatc --python tensorflow/tensorflow/lite/schema/schema.fbs
```

Now we can use the Python FlatBuffer API to read the TFLite file!

In [1]:
# !git clone https://github.com/google/mediapipe.git
# !git clone https://github.com/google/flatbuffers.git
# !cd flatbuffers ; cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release ; make -j
# !cd flatbuffers/python ; python setup.py install
# !git clone https://github.com/tensorflow/tensorflow.git
# !./flatbuffers/flatc --python tensorflow/tensorflow/lite/schema/schema.fbs

Now restart the notebook kernel so the newly installed packages can be imported.

In [2]:
import os
import numpy as np
from collections import OrderedDict

## Get the weights from the TFLite file

Load the TFLite model using the FlatBuffers library:

In [3]:
from tflite import Model

# taken from arcore pod
data = open("../mediapipe/mediapipe/models/iris_landmark.tflite", "rb").read()
model = Model.GetRootAsModel(data, 0)
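As far as I know, the Python FlatBuffers bindings don't verify the buffer handed to `GetRootAsModel`, so if `data` isn't actually a TFLite file the accessors below return garbage or fail in confusing ways. The TFLite schema declares the file identifier `TFL3`, which FlatBuffers stores in bytes 4 to 7 of the file, so a cheap sanity check is possible. A minimal sketch; `looks_like_tflite` is a hypothetical helper, not part of any library:

```python
def looks_like_tflite(data: bytes) -> bool:
    """Hypothetical helper: check the FlatBuffers file identifier of a TFLite model.

    A FlatBuffer starts with a 4-byte offset to the root table; when the schema
    declares a file_identifier (TFLite uses "TFL3"), it occupies bytes 4..7.
    """
    return len(data) >= 8 and data[4:8] == b"TFL3"


# Synthetic examples (not real model files):
fake_model = b"\x1c\x00\x00\x00TFL3" + b"\x00" * 16
not_a_model = b"\x89PNG\r\n\x1a\n" + b"\x00" * 16
print(looks_like_tflite(fake_model))   # True
print(looks_like_tflite(not_a_model))  # False
```

In the notebook you would call `assert looks_like_tflite(data)` right after reading the file.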

In [4]:
subgraph = model.Subgraphs(0)
subgraph.Name()

b'faceeyenbrow_iris_full_2019_09_13_v0'

In [5]:
def get_shape(tensor):
    return [tensor.Shape(i) for i in range(tensor.ShapeLength())]

List all the tensors in the graph:

In [6]:
for i in range(0, subgraph.TensorsLength()):
    tensor = subgraph.Tensors(i)
    print("%3d %30s %d %2d %s" % (i, tensor.Name(), tensor.Type(), tensor.Buffer(), 
                                  get_shape(subgraph.Tensors(i))))

  0                     b'input_1' 0  0 [1, 64, 64, 3]
  1               b'conv2d/Kernel' 0  1 [64, 3, 3, 3]
  2                 b'conv2d/Bias' 0  2 [64]
  3                      b'conv2d' 0  0 [1, 32, 32, 64]
  4               b'p_re_lu/Alpha' 0  3 [1, 1, 64]
  5                     b'p_re_lu' 0  0 [1, 32, 32, 64]
  6             b'conv2d_1/Kernel' 0  4 [32, 1, 1, 64]
  7               b'conv2d_1/Bias' 0  5 [32]
  8                    b'conv2d_1' 0  0 [1, 32, 32, 32]
  9             b'p_re_lu_1/Alpha' 0  6 [1, 1, 32]
 10                   b'p_re_lu_1' 0  0 [1, 32, 32, 32]
 11     b'depthwise_conv2d/Kernel' 0  7 [1, 3, 3, 32]
 12       b'depthwise_conv2d/Bias' 0  8 [32]
 13            b'depthwise_conv2d' 0  0 [1, 32, 32, 32]
 14             b'conv2d_2/Kernel' 0  9 [64, 1, 1, 32]
 15               b'conv2d_2/Bias' 0 10 [64]
 16                    b'conv2d_2' 0  0 [1, 32, 32, 64]
 17         b'add__xeno_compat__1' 0  0 [1, 32, 32, 64]
 18             b'p_re_lu_2/Alpha' 0 11 [1, 1, 64]
 ...

Make a look-up table that lets us get the tensor index based on the tensor name:

In [7]:
tensor_dict = {(subgraph.Tensors(i).Name().decode("utf8")): i 
               for i in range(subgraph.TensorsLength())}

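Grab only the tensors that have data stored in a buffer: the weights, biases, and PReLU alphas. Activation tensors such as `input_1` or `conv2d` point to buffer 0, which by convention is an empty sentinel buffer in TFLite files.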

In [8]:
parameters = {}
for i in range(subgraph.TensorsLength()):
    tensor = subgraph.Tensors(i)
    if tensor.Buffer() > 0:
        name = tensor.Name().decode("utf8")
        parameters[name] = tensor.Buffer()

len(parameters)

216

The buffers are simply arrays of bytes. As the comments in the TFLite schema say,

> The data_buffer itself is an opaque container, with the assumption that the
> target device is little-endian. In addition, all builtin operators assume
> the memory is ordered such that if `shape` is [4, 3, 2], then index
> [i, j, k] maps to `data_buffer[i*3*2 + j*2 + k]`.

For weights and biases, we need to interpret every 4 bytes as a single float32 value. On my machine, the native byte ordering is already little-endian, so we don't need to do anything special for that.
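To make both rules concrete, here is a small numpy sketch on synthetic data (not a real TFLite buffer). It decodes the bytes with an explicit little-endian dtype (`"<f4"`), which would also work on a big-endian host, and checks the row-major index formula from the quote:

```python
import numpy as np

# Synthetic stand-in for a TFLite weight buffer: 24 float32 values,
# serialized little-endian exactly as the schema assumes.
values = np.arange(24, dtype=np.float32)
raw = values.astype("<f4").tobytes()

# "<f4" = little-endian float32, independent of the host's native byte order.
flat = np.frombuffer(raw, dtype="<f4")
W = flat.reshape([4, 3, 2])  # numpy's default C (row-major) order matches TFLite

# Index [i, j, k] of a [4, 3, 2] tensor lives at flat offset i*3*2 + j*2 + k.
i, j, k = 2, 1, 1
print(W[i, j, k], flat[i * 3 * 2 + j * 2 + k])  # 15.0 15.0
```

In `get_weights` below, using `"<f4"` instead of `np.float32` in the `.view(...)` call would make the decoding portable to big-endian machines at no cost.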

In [9]:
def get_weights(tensor_name):
    i = tensor_dict[tensor_name]
    tensor = subgraph.Tensors(i)
    buffer = tensor.Buffer()
    shape = get_shape(tensor)
    # tensor types are here: https://github.com/jackwish/tflite/blob/master/tflite/TensorType.py
    assert tensor.Type() == 0, "expected a FLOAT32 tensor"

    W = model.Buffers(buffer).DataAsNumpy()  # raw bytes as a uint8 array
    W = W.view(dtype=np.float32)             # reinterpret every 4 bytes as a float32
    W = W.reshape(shape)
    return W
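`get_weights` only accepts FLOAT32 tensors (type code 0). If you ever need one of the other buffer-backed tensors, the type code has to be mapped to a numpy dtype first. A sketch covering a few codes from the `TensorType` enum in schema.fbs (the generated `tflite/TensorType.py` defines the same constants); `TFLITE_DTYPES` and `decode_buffer` are hypothetical names:

```python
import numpy as np

# Subset of the TensorType enum from TFLite's schema.fbs, mapped to
# little-endian numpy dtypes (TFLite buffers are little-endian).
TFLITE_DTYPES = {
    0: np.dtype("<f4"),  # FLOAT32
    1: np.dtype("<f2"),  # FLOAT16
    2: np.dtype("<i4"),  # INT32
    3: np.dtype("u1"),   # UINT8
    4: np.dtype("<i8"),  # INT64
    7: np.dtype("<i2"),  # INT16
    9: np.dtype("i1"),   # INT8
}

def decode_buffer(raw: bytes, type_code: int, shape):
    """Hypothetical helper: turn raw buffer bytes into a typed, shaped array."""
    if type_code not in TFLITE_DTYPES:
        raise ValueError(f"unsupported TFLite tensor type {type_code}")
    return np.frombuffer(raw, dtype=TFLITE_DTYPES[type_code]).reshape(shape)


# Synthetic example: an INT32 tensor of shape [2, 2].
raw = np.array([1, -2, 3, 40000], dtype="<i4").tobytes()
print(decode_buffer(raw, 2, [2, 2]))
```

With the notebook's objects, the call would look like `decode_buffer(model.Buffers(tensor.Buffer()).DataAsNumpy().tobytes(), tensor.Type(), get_shape(tensor))`.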

In [10]:
W = get_weights("conv2d_1/Kernel")
b = get_weights("conv2d_1/Bias")
W.shape, b.shape

((32, 1, 1, 64), (32,))

Now we can get the weights for all the layers and copy them into our PyTorch model.

## Convert the weights to PyTorch format

In [11]:
import torch
import torch.nn as nn
from irislandmarks import IrisLandmarks

In [12]:
net = IrisLandmarks()

In [13]:
net

IrisLandmarks(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))
    (1): PReLU(num_parameters=64)
    (2): IrisBlock(
      (convAct): Sequential(
        (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
        (1): PReLU(num_parameters=32)
      )
      (dwConvConv): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
        (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (act): PReLU(num_parameters=64)
    )
    (3): IrisBlock(
      (convAct): Sequential(
        (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
        (1): PReLU(num_parameters=32)
      )
      (dwConvConv): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
        (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (act): PReLU(num_parameters=64)
    )
    (4): IrisBlock(
      (convAct): Sequential(
        (0): Conv2d(64, 32, ...

In [14]:
net(torch.randn(2,3,64,64))[0].shape

torch.Size([2, 213])

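Make a lookup table that maps the parameter names between the two models. We're going to assume here that the tensors appear in the same order in both models. If they don't, we should get an error because the shapes don't match. In practice, positional matching is only used for the first 85 parameters below; the rest come from a hand-written mapping.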
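Waiting for a shape error tells you *that* the orderings diverged, not *where*. A sketch of a check that walks both parameter lists in order and reports the first position whose shapes disagree, using the same layout rules as the conversion code further down. The helper names are hypothetical, and they work on plain `(name, shape)` lists so they can be tried without either model. Note that shape checks can't catch two parameters with identical shapes (say, a bias and a PReLU alpha with 64 channels each) that are swapped:

```python
def expected_torch_shape(tflite_shape):
    """PyTorch shape a TFLite parameter should end up with (hypothetical helper)."""
    s = tuple(tflite_shape)
    if len(s) == 4:
        # same assumption as the notebook: no regular conv has out_channels == 1
        if s[0] == 1:                      # depthwise kernel (1, kh, kw, C)
            return (s[3], 1, s[1], s[2])   # -> (C, 1, kh, kw)
        return (s[0], s[3], s[1], s[2])    # conv kernel (O, kh, kw, I) -> (O, I, kh, kw)
    if len(s) == 3:                        # PReLU alpha (1, 1, C) -> (C,)
        return (s[0] * s[1] * s[2],)
    return s                               # biases are already 1-D

def first_shape_mismatch(tflite_params, torch_params):
    """Return (index, tflite_name, torch_name) of the first misaligned pair, or None."""
    for i, ((src, src_shape), (dst, dst_shape)) in enumerate(zip(tflite_params, torch_params)):
        if expected_torch_shape(src_shape) != tuple(dst_shape):
            return i, src, dst
    return None


# Toy lists modelled on the first few parameters of this network.
tflite_params = [("conv2d/Kernel", [64, 3, 3, 3]), ("conv2d/Bias", [64]),
                 ("p_re_lu/Alpha", [1, 1, 64]), ("conv2d_1/Kernel", [32, 1, 1, 64])]
torch_ok = [("backbone.0.weight", [64, 3, 3, 3]), ("backbone.0.bias", [64]),
            ("backbone.1.weight", [64]), ("backbone.2.convAct.0.weight", [32, 64, 1, 1])]
torch_bad = [("backbone.0.weight", [64, 3, 3, 3]), ("backbone.0.bias", [64]),
             ("backbone.2.convAct.0.weight", [32, 64, 1, 1]), ("backbone.1.weight", [64])]

print(first_shape_mismatch(tflite_params, torch_ok))   # None
print(first_shape_mismatch(tflite_params, torch_bad))  # (2, 'p_re_lu/Alpha', 'backbone.2.convAct.0.weight')
```

For the real models, the inputs would be `[(n, get_shape(subgraph.Tensors(tensor_dict[n]))) for n in probable_names]` and `[(k, tuple(v.shape)) for k, v in net.state_dict().items()]`.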

In [15]:
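probable_names = []
for i in range(0, subgraph.TensorsLength()):
    tensor = subgraph.Tensors(i)
    # keep buffer-backed FLOAT32 tensors only; this drops the one buffer-backed
    # tensor that isn't FLOAT32 (216 entries in `parameters` vs. 215 here)
    if tensor.Buffer() > 0 and tensor.Type() == 0:
        probable_names.append(tensor.Name().decode("utf-8"))

probable_names[:5]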

['conv2d/Kernel',
 'conv2d/Bias',
 'p_re_lu/Alpha',
 'conv2d_1/Kernel',
 'conv2d_1/Bias']

In [16]:
len(probable_names)

215

In [17]:
from pprint import pprint

In [18]:
pprint(list(zip(probable_names, net.state_dict())))

[('conv2d/Kernel', 'backbone.0.weight'),
 ('conv2d/Bias', 'backbone.0.bias'),
 ('p_re_lu/Alpha', 'backbone.1.weight'),
 ('conv2d_1/Kernel', 'backbone.2.convAct.0.weight'),
 ('conv2d_1/Bias', 'backbone.2.convAct.0.bias'),
 ('p_re_lu_1/Alpha', 'backbone.2.convAct.1.weight'),
 ('depthwise_conv2d/Kernel', 'backbone.2.dwConvConv.0.weight'),
 ('depthwise_conv2d/Bias', 'backbone.2.dwConvConv.0.bias'),
 ('conv2d_2/Kernel', 'backbone.2.dwConvConv.1.weight'),
 ('conv2d_2/Bias', 'backbone.2.dwConvConv.1.bias'),
 ('p_re_lu_2/Alpha', 'backbone.2.act.weight'),
 ('conv2d_3/Kernel', 'backbone.3.convAct.0.weight'),
 ('conv2d_3/Bias', 'backbone.3.convAct.0.bias'),
 ('p_re_lu_3/Alpha', 'backbone.3.convAct.1.weight'),
 ('depthwise_conv2d_1/Kernel', 'backbone.3.dwConvConv.0.weight'),
 ('depthwise_conv2d_1/Bias', 'backbone.3.dwConvConv.0.bias'),
 ('conv2d_4/Kernel', 'backbone.3.dwConvConv.1.weight'),
 ('conv2d_4/Bias', 'backbone.3.dwConvConv.1.bias'),
 ('p_re_lu_4/Alpha', 'backbone.3.act.weight'),
 ...

In [19]:
len(net.state_dict()), len(probable_names)

(215, 215)

In [20]:
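# Match the first 85 PyTorch parameters to TFLite tensors by position.
# The remaining entries come from the hand-written mapping in the next cell.
convert = {}
for i, name in enumerate(net.state_dict()):
    if i >= 85:
        break
    convert[name] = probable_names[i]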

In [21]:
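import ast

# conversion_dict.txt holds a Python dict literal that maps
# PyTorch parameter names to TFLite tensor names
with open("conversion_dict.txt", "r") as f:
    manual_mapping = ast.literal_eval(f.read())

convert.update(manual_mapping)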

Copy the weights into the layers.

Note that the axis ordering of the weights differs between PyTorch and TFLite, so we need to transpose them (PReLU alphas just need to be flattened).

Convolution weights:

    TFLite:  (out_channels, kernel_height, kernel_width, in_channels)
    PyTorch: (out_channels, in_channels, kernel_height, kernel_width)

Depthwise convolution weights:

    TFLite:  (1, kernel_height, kernel_width, channels)
    PyTorch: (channels, 1, kernel_height, kernel_width)
    
PReLU:

    TFLite:  (1, 1, num_channels)
    PyTorch: (num_channels, )
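That `transpose((0, 3, 1, 2))` is the right permutation for a regular convolution can be checked numerically: run a naive NHWC convolution with the TFLite-layout kernel and a naive NCHW convolution with the transposed kernel, then compare. A sketch on random data (single image, stride 1, no padding); `conv_nhwc` and `conv_nchw` are illustrative helpers, not library functions:

```python
import numpy as np

def conv_nhwc(x, k):
    """x: (H, W, Cin); k in TFLite layout (Cout, kh, kw, Cin). Stride 1, no padding."""
    cout, kh, kw, cin = k.shape
    H, W, _ = x.shape
    out = np.zeros((H - kh + 1, W - kw + 1, cout), dtype=x.dtype)
    for r in range(H - kh + 1):
        for c in range(W - kw + 1):
            patch = x[r:r + kh, c:c + kw, :]                   # (kh, kw, Cin)
            out[r, c, :] = np.tensordot(k, patch, axes=([1, 2, 3], [0, 1, 2]))
    return out

def conv_nchw(x, k):
    """x: (Cin, H, W); k in PyTorch layout (Cout, Cin, kh, kw). Stride 1, no padding."""
    cout, cin, kh, kw = k.shape
    _, H, W = x.shape
    out = np.zeros((cout, H - kh + 1, W - kw + 1), dtype=x.dtype)
    for r in range(H - kh + 1):
        for c in range(W - kw + 1):
            patch = x[:, r:r + kh, c:c + kw]                   # (Cin, kh, kw)
            out[:, r, c] = np.tensordot(k, patch, axes=([1, 2, 3], [0, 1, 2]))
    return out

rng = np.random.default_rng(0)
x_hwc = rng.standard_normal((6, 6, 4)).astype(np.float32)
k_tflite = rng.standard_normal((5, 3, 3, 4)).astype(np.float32)  # (Cout, kh, kw, Cin)

ref = conv_nhwc(x_hwc, k_tflite)                                  # (4, 4, 5)
k_torch = k_tflite.transpose((0, 3, 1, 2))                        # (Cout, Cin, kh, kw)
out = conv_nchw(x_hwc.transpose((2, 0, 1)), k_torch)              # (5, 4, 4)

print(np.allclose(ref, out.transpose((1, 2, 0)), atol=1e-5))  # True
```

The same reasoning applies to the depthwise case: moving the channel axis from last to first, `(1, kh, kw, C)` becomes `(C, 1, kh, kw)`, which is what `transpose((3, 0, 1, 2))` does.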


In [22]:
new_state_dict = OrderedDict()

for dst, src in convert.items():
    W = get_weights(src)
    print(dst, src, W.shape, net.state_dict()[dst].shape)

    if W.ndim == 4:
        if W.shape[0] == 1: # no conv2d with out_channel == 1 in this net
            W = W.transpose((3, 0, 1, 2))  # depthwise conv
        else:
            W = W.transpose((0, 3, 1, 2))  # regular conv
    elif W.ndim == 3:
        W = W.reshape(-1)
    
    new_state_dict[dst] = torch.from_numpy(W)

backbone.0.weight conv2d/Kernel (64, 3, 3, 3) torch.Size([64, 3, 3, 3])
backbone.0.bias conv2d/Bias (64,) torch.Size([64])
backbone.1.weight p_re_lu/Alpha (1, 1, 64) torch.Size([64])
backbone.2.convAct.0.weight conv2d_1/Kernel (32, 1, 1, 64) torch.Size([32, 64, 1, 1])
backbone.2.convAct.0.bias conv2d_1/Bias (32,) torch.Size([32])
backbone.2.convAct.1.weight p_re_lu_1/Alpha (1, 1, 32) torch.Size([32])
backbone.2.dwConvConv.0.weight depthwise_conv2d/Kernel (1, 3, 3, 32) torch.Size([32, 1, 3, 3])
backbone.2.dwConvConv.0.bias depthwise_conv2d/Bias (32,) torch.Size([32])
backbone.2.dwConvConv.1.weight conv2d_2/Kernel (64, 1, 1, 32) torch.Size([64, 32, 1, 1])
backbone.2.dwConvConv.1.bias conv2d_2/Bias (64,) torch.Size([64])
backbone.2.act.weight p_re_lu_2/Alpha (1, 1, 64) torch.Size([64])
backbone.3.convAct.0.weight conv2d_3/Kernel (32, 1, 1, 64) torch.Size([32, 64, 1, 1])
backbone.3.convAct.0.bias conv2d_3/Bias (32,) torch.Size([32])
backbone.3.convAct.1.weight p_re_lu_3/Alpha (1, 1, 32) ...

In [23]:
net.load_state_dict(new_state_dict, strict=True)

<All keys matched successfully>

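No errors? Then the weights were copied over successfully! Keep in mind that `load_state_dict` only checks that the keys and shapes match; it can't detect two same-shaped tensors that were mapped the wrong way round.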
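A stronger check is to feed the same input to the original and the converted model and compare the outputs. Below is a sketch of the comparison side only, assuming you obtain the reference output yourself (for example with TensorFlow's `tf.lite.Interpreter` on the `.tflite` file). The TFLite model takes NHWC input (`[1, 64, 64, 3]`) while `IrisLandmarks` takes NCHW, so the input has to be transposed. `nhwc_to_nchw` and `compare_outputs` are hypothetical helpers, shown here on synthetic arrays:

```python
import numpy as np

def nhwc_to_nchw(x):
    """Convert a TFLite-style (N, H, W, C) input to PyTorch's (N, C, H, W)."""
    return np.ascontiguousarray(x.transpose((0, 3, 1, 2)))

def compare_outputs(reference, candidate, atol=1e-4):
    """Return (max absolute difference, whether it is within atol).

    Both arrays are flattened first, which assumes the two models emit
    their values in the same element order.
    """
    reference = np.asarray(reference, dtype=np.float64).reshape(-1)
    candidate = np.asarray(candidate, dtype=np.float64).reshape(-1)
    if reference.shape != candidate.shape:
        raise ValueError(f"size mismatch: {reference.size} vs {candidate.size}")
    max_diff = float(np.max(np.abs(reference - candidate)))
    return max_diff, max_diff <= atol


# Synthetic stand-ins for a model input and output (not real results).
x = np.random.default_rng(0).random((1, 64, 64, 3), dtype=np.float32)
print(nhwc_to_nchw(x).shape)  # (1, 3, 64, 64)
ref = np.linspace(0.0, 1.0, 213, dtype=np.float32)
print(compare_outputs(ref, ref + 1e-6))
```

On the real models, the candidate would be `net(torch.from_numpy(nhwc_to_nchw(x)))[0].detach().numpy()` and the reference the corresponding TFLite output for the same `x`.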

## Save the checkpoint

In [24]:
torch.save(net.state_dict(), "irislandmarks.pth")