In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import coremltools as ct

  from .autonotebook import tqdm as notebook_tqdm


In [71]:
## -----------------------------------------------------------------------------
## Network layers
## -----------------------------------------------------------------------------

def Conv(in_channels, out_channels):
  """Build a 3x3 ``nn.Conv2d``; the 1-pixel padding keeps H and W unchanged."""
  return nn.Conv2d(
    in_channels=in_channels,
    out_channels=out_channels,
    kernel_size=3,
    padding=1,
  )

def relu(x):
  """Apply ReLU to ``x`` in place (no new allocation) and return the result."""
  return F.relu(input=x, inplace=True)

def pool(x):
  """Halve H and W with a non-overlapping 2x2 max pool (odd sizes are floored)."""
  return F.max_pool2d(x, kernel_size=2, stride=2)

def upsample(x):
  """Double H and W by repeating each pixel into a 2x2 block (nearest neighbour)."""
  return F.interpolate(input=x, mode='nearest', scale_factor=2)

def concat(a, b):
  """Join ``a`` and ``b`` along dim 1 (channels), with ``a``'s channels first."""
  return torch.cat([a, b], dim=1)

## -----------------------------------------------------------------------------
## U-Net model
## -----------------------------------------------------------------------------

class UNet(nn.Module):
  """U-Net with four 2x down/up-sampling stages and skip connections.

  Submodule names (``enc_conv0`` ... ``dec_conv0``) are also the parameter keys
  used when loading weights, so they (and their registration order) must stay
  stable.

  Args:
    in_channels: channel count of the input tensor.
    out_channels: channel count produced by the final convolution.
  """

  def __init__(self, in_channels=3, out_channels=3):
    super().__init__()

    # Feature widths of the encoder (e*) and decoder (d*) stages
    e1, e2, e3, e4, e5 = 32, 48, 64, 80, 96
    d4, d3, d2, d1a, d1b = 112, 96, 64, 64, 32

    # Encoder convolutions
    self.enc_conv0 = Conv(in_channels, e1)
    self.enc_conv1 = Conv(e1, e1)
    self.enc_conv2 = Conv(e1, e2)
    self.enc_conv3 = Conv(e2, e3)
    self.enc_conv4 = Conv(e3, e4)
    self.enc_conv5a = Conv(e4, e5)
    self.enc_conv5b = Conv(e5, e5)

    # Decoder convolutions: every "a" layer sees the upsampled features
    # concatenated with a skip tensor, so its input width is the sum of both.
    self.dec_conv4a = Conv(e5 + e3, d4)
    self.dec_conv4b = Conv(d4, d4)
    self.dec_conv3a = Conv(d4 + e2, d3)
    self.dec_conv3b = Conv(d3, d3)
    self.dec_conv2a = Conv(d3 + e1, d2)
    self.dec_conv2b = Conv(d2, d2)
    self.dec_conv1a = Conv(d2 + in_channels, d1a)
    self.dec_conv1b = Conv(d1a, d1b)
    self.dec_conv0 = Conv(d1b, out_channels)

    # Four 2x poolings => H and W must be multiples of 2**4. forward() does not
    # pad, so images must be padded to this alignment beforehand.
    self.alignment = 16

  def forward(self, input):
    """Run encoder, bottleneck and decoder; return the raw ``dec_conv0`` output."""
    # Encoder: the pooled outputs of stages 1-3 are kept as skip connections.
    x = relu(self.enc_conv0(input))
    skips = []
    for enc_conv in (self.enc_conv1, self.enc_conv2, self.enc_conv3):
      x = pool(relu(enc_conv(x)))
      skips.append(x)                     # pool1, pool2, pool3
    x = pool(relu(self.enc_conv4(x)))     # pool4 (not used as a skip)

    # Bottleneck
    x = relu(self.enc_conv5a(x))
    x = relu(self.enc_conv5b(x))

    # Decoder: consume skips deepest-first, finishing with the network input.
    stages = (
      (self.dec_conv4a, self.dec_conv4b, skips[2]),
      (self.dec_conv3a, self.dec_conv3b, skips[1]),
      (self.dec_conv2a, self.dec_conv2b, skips[0]),
      (self.dec_conv1a, self.dec_conv1b, input),
    )
    for conv_a, conv_b, skip in stages:
      x = concat(upsample(x), skip)
      x = relu(conv_b(relu(conv_a(x))))

    return self.dec_conv0(x)

In [3]:
# The tracing and Core ML conversion cells below use a 9-channel input, so the
# network must be built with 9 input channels (UNet() defaults to 3 and would
# make torch.jit.trace fail on a fresh kernel).
net = UNet(in_channels=9)

In [26]:
# Trace the network with a random (1, 9, 512, 512) input to obtain a TorchScript
# module for coremltools. The channel count (9) must match net's in_channels.
example_input = torch.rand(1, 9, 512, 512)
traced_model = torch.jit.trace(net, example_input)
# Sanity-check run of the traced module; `out` is not referenced elsewhere here.
out = traced_model(example_input)

In [29]:
# Core ML input: batch 1, 9 channels, height/width each anywhere in [256, 4096].
# NOTE: the UNet pools four times (alignment = 16), so H and W should be
# multiples of 16. RangeDim only bounds the size and does not enforce that;
# other sizes will not line up at the skip-connection concatenations.
input_shape = ct.Shape(shape=(1, 9, ct.RangeDim(256,4096), ct.RangeDim(256,4096)))
model_input = ct.TensorType(shape=input_shape)

In [30]:
# Convert the traced TorchScript module into a Core ML model with the flexible input.
model = ct.convert(traced_model, inputs=[model_input])

Converting Frontend ==> MIL Ops: 100%|███▉| 246/247 [00:00<00:00, 4733.78 ops/s]
Running MIL Common passes: 100%|█████████| 34/34 [00:00<00:00, 1019.51 passes/s]
Running MIL Clean up passes: 100%|██████████| 9/9 [00:00<00:00, 414.27 passes/s]
Translating MIL ==> NeuralNetwork Ops: 100%|█| 191/191 [00:00<00:00, 311.63 ops/


In [31]:
# Write the converted model to a path relative to the notebook's working directory.
model.save("unet.mlmodel")

In [4]:
from coremltools.models.neural_network import flexible_shape_utils
# Load the saved model's protobuf spec so its input description can be edited.
spec = ct.utils.load_spec('unet.mlmodel')

In [5]:
# Name of the model's first input feature (the model was converted with a
# single input); needed to address that input in the spec.
input_name = spec.description.input[0].name

In [8]:
# Re-declare the allowed input shape in the spec: batch 1, 9 channels (the
# channel count the model was traced and converted with; enc_conv0 cannot
# accept any other), H and W at least 256 with no upper limit (-1).
# NOTE: this edits the in-memory `spec` only; it is not saved back to disk here.
flexible_shape_utils.set_multiarray_ndshape_range(
    spec,
    feature_name=input_name,
    lower_bounds=[1, 9, 256, 256],
    upper_bounds=[1, 9, -1, -1],
)

In [4]:
from training import tza

In [23]:
# Open a .tza weights archive for inspection.
# NOTE(review): the variable name says "hdr" but this opens the *ldr* file
# (rt_ldr_calb_cnrm.tza), while reload_unet is later called on the hdr file.
# Confirm which weights were meant to be inspected here.
rt_hdr_alb_nrm = tza.Reader("weights/rt_ldr_calb_cnrm.tza")

In [69]:
# Peek at the reader's (private) table of stored tensors. Judging by the output,
# entries carry per-tensor metadata such as shape, layout and dtype.
# Relies on a private attribute, so it may break if tza.Reader changes.
rt_hdr_alb_nrm._table

((32,), 'x', numpy.float32, 10432)

In [59]:
# Current enc_conv0 bias of the PyTorch model, for comparison with the .tza value.
net.get_submodule("enc_conv0").bias

Parameter containing:
tensor([ 1.2497e-02, -3.9386e-01,  1.4464e-04,  3.6471e-04, -5.2961e-01,
        -5.7039e-01,  4.1093e-04,  1.0836e-01,  1.7407e-02, -1.4313e-02,
         2.4525e-02,  2.9364e-02,  1.7879e-02,  1.5135e-02,  3.4147e-02,
         6.8869e-02,  9.8098e-02,  1.9367e-01,  1.6307e-02,  1.6447e-02,
         2.3181e-02,  6.5963e-04,  4.2232e-02,  3.1785e-02,  3.0503e-02,
         8.6258e-04,  2.5121e-02,  9.0851e-04,  8.1616e-04,  1.2892e-02,
         1.7959e-01,  8.9735e-02], requires_grad=True)

In [60]:
# Indexing the reader by parameter name returns an (array, layout) pair.
rt_hdr_alb_nrm["enc_conv0.bias"]

(array([ 1.24973757e-02, -3.93858135e-01,  1.44641715e-04,  3.64705629e-04,
        -5.29607296e-01, -5.70392549e-01,  4.10931942e-04,  1.08356625e-01,
         1.74072701e-02, -1.43134808e-02,  2.45251376e-02,  2.93642748e-02,
         1.78788304e-02,  1.51347220e-02,  3.41467597e-02,  6.88686296e-02,
         9.80980322e-02,  1.93671897e-01,  1.63065661e-02,  1.64473429e-02,
         2.31808424e-02,  6.59631507e-04,  4.22322974e-02,  3.17853317e-02,
         3.05032786e-02,  8.62584915e-04,  2.51205117e-02,  9.08514368e-04,
         8.16164771e-04,  1.28915962e-02,  1.79587409e-01,  8.97354633e-02],
       dtype=float32),
 'x')

In [58]:
# Load the stored enc_conv0 bias into the model by replacing the Parameter.
# NOTE: rebinding the attribute does not check the new tensor's shape or dtype
# against the layer. reload_unet below loads every stored tensor instead.
net.get_submodule("enc_conv0").bias = torch.nn.Parameter(torch.tensor(rt_hdr_alb_nrm["enc_conv0.bias"][0]))

In [113]:
import torch
from training import tza

def reload_unet(filepath: str) -> nn.Module:
    """Build a ``UNet`` and load its parameters from a ``.tza`` weights file.

    The input/output channel counts are inferred from the stored
    ``enc_conv0.weight`` and ``dec_conv0.weight`` tensors.

    Args:
        filepath: Path to the ``.tza`` archive.

    Returns:
        A ``UNet`` whose parameters hold the tensors stored in ``filepath``.

    Raises:
        RuntimeError: if the file lacks parameters the network has, contains
            entries the network does not have, or any stored tensor's shape
            disagrees with the corresponding layer.
    """
    data = tza.Reader(filepath)
    # Each reader entry is an (array, layout) pair; only the array is needed.
    # torch.tensor copies, so the result does not alias the reader's buffer.
    state_dict = {key: torch.tensor(data[key][0]) for key in data._table.keys()}

    in_channels = state_dict["enc_conv0.weight"].shape[1]
    out_channels = state_dict["dec_conv0.weight"].shape[0]
    net = UNet(in_channels, out_channels)

    # Strict loading copies into the existing parameters and raises on missing
    # or unexpected keys and on size mismatches, instead of silently leaving
    # layers randomly initialised or rebinding wrongly-shaped tensors. The
    # copy also casts to the parameters' dtype.
    net.load_state_dict(state_dict)
    return net

In [152]:
# Build a UNet whose channel counts and parameters all come from this weights file.
net = reload_unet("weights/rt_hdr_calb_cnrm.tza")

In [159]:
import coremltools as ct
import torch

def save_to_mlmodel(filepath: str, net: nn.Module, example_size: int = 512,
                    min_size: int = 256, max_size: int = 4096):
    """Trace ``net`` and save it to ``filepath`` as a Core ML model.

    The exported input has shape ``(1, C, H, W)``:
    - ``C`` is read from ``net.enc_conv0``.
    - ``H`` and ``W`` may each vary independently within
      ``[min_size, max_size]``.

    NOTE: a UNet needs H and W to be multiples of ``net.alignment`` (16).
    ``ct.RangeDim`` only bounds the size and does not enforce that, so callers
    must pad inputs before running the exported model.

    Args:
        filepath: Destination passed to ``model.save``.
        net: Network to export; must expose an ``enc_conv0`` convolution.
            It is switched to eval mode before tracing.
        example_size: Height and width of the random tensor used for tracing.
        min_size: Smallest height/width the Core ML model accepts.
        max_size: Largest height/width the Core ML model accepts.

    Returns:
        The converted coremltools model (the object that was saved).

    Raises:
        ValueError: if ``example_size`` is outside ``[min_size, max_size]`` or
            is not a multiple of the network's ``alignment``.
    """
    alignment = getattr(net, "alignment", 1)
    if not min_size <= example_size <= max_size:
        raise ValueError(
            f"example_size={example_size} must lie within [{min_size}, {max_size}]")
    if example_size % alignment:
        raise ValueError(
            f"example_size={example_size} must be a multiple of {alignment}")

    in_channels = net.enc_conv0.weight.shape[1]
    example_input = torch.rand(1, in_channels, example_size, example_size)

    # Export the inference graph. This makes no difference for layers without
    # train/eval-specific behaviour, but matters for e.g. dropout or batch norm.
    net.eval()
    traced_model = torch.jit.trace(net, example_input)

    # Use two distinct RangeDim objects so H and W are independent dimensions
    # rather than one shared symbol (which would force H == W).
    input_shape = ct.Shape(shape=(1, in_channels,
                                  ct.RangeDim(min_size, max_size),
                                  ct.RangeDim(min_size, max_size)))
    model_input = ct.TensorType(shape=input_shape)

    model = ct.convert(traced_model, inputs=[model_input])
    model.save(filepath)
    return model

In [160]:
# Export the reloaded network to Core ML; the file name mirrors the weights file.
save_to_mlmodel("rt_hdr_calb_cnrm.mlmodel", net)

Converting Frontend ==> MIL Ops: 100%|████████████████████████████████████████████▊| 246/247 [00:00<00:00, 4783.22 ops/s]
Running MIL Common passes: 100%|███████████████████████████████████████████████████| 34/34 [00:00<00:00, 972.04 passes/s]
Running MIL Clean up passes: 100%|███████████████████████████████████████████████████| 9/9 [00:00<00:00, 412.90 passes/s]
Translating MIL ==> NeuralNetwork Ops: 100%|████████████████████████████████████████| 191/191 [00:00<00:00, 314.50 ops/s]


CHANGELOG.md                    [1m[36mmkl-dnn[m[m
CMakeLists.txt                  readme.pdf
LICENSE.txt                     requirements.txt
README.md                       [1m[36mscripts[m[m
SECURITY.md                     third-party-programs-oneDNN.txt
[1m[36mapps[m[m                            third-party-programs-oneTBB.txt
[1m[36mcmake[m[m                           third-party-programs.txt
[1m[36mcmake-build-debug[m[m               torchtocoreml.ipynb
[1m[36mcommon[m[m                          [1m[36mtraining[m[m
[1m[36mcore[m[m                            unet.mlmodel
[1m[36mdoc[m[m                             [1m[36mweights[m[m
[1m[36minclude[m[m
