In [1]:
# Imports
import os
import sys
import numpy as np
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F

import coremltools as ct
from coremltools.models.neural_network import quantization_utils



In [2]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cpu'

In [3]:
# Checkpoint file; its 'weights' entry is loaded into the model further down.
restore_path = "./model_parameter.pt"
# Directory scanned for the ground-truth validation images (*.png).
valid_dir = "./kodak/"
# Directory holding one y_hat .npy array per validation image, named after the image file.
y_hat_dir = "./y_hat/"

In [4]:
class EncoderDecoder(nn.Module):
    """Network consisting of a single 2-D decoder convolution (y_hat -> x_hat).

    Despite the class name, only a decoder layer is defined here and no
    activation is applied after the convolution.
    """

    def __init__(self):
        super().__init__()

        # 12 -> 3 channels with a 37x37 kernel. Stride 1 together with 18
        # pixels of zero padding on every side keeps the output height and
        # width equal to the input's.
        self.Conv_Decoder = nn.Conv2d(
            in_channels=12,
            out_channels=3,
            kernel_size=37,
            stride=1,
            padding=18,
            padding_mode='zeros',
            bias=True,
        )

    def forward(self, y_hat):
        """Apply the decoder convolution to ``y_hat`` and return the result (x_hat)."""
        return self.Conv_Decoder(y_hat)

## Points to Remember

**GT image shape:**  (1, 3, 512, 768)

**y_hat shape:**  torch.Size([1, 12, 512, 768])

**Model prediction shape:**  (1, 3, 512, 768)

In [5]:
# Instantiate the network on the selected device and switch it to eval mode.
model = EncoderDecoder().to(device).eval()

# Read the checkpoint from disk and load its 'weights' entry into the network.
print(f'Restoring weights from {restore_path}')
checkpoint = torch.load(restore_path, map_location=device)
model.load_state_dict(checkpoint['weights'])

Restoring weights from ./model_parameter.pt


<All keys matched successfully>

In [6]:
# Print every parameter tensor of the restored model, one after another.
print(*model.parameters(), sep='\n')

Parameter containing:
tensor([[[[ 1.9412e-09,  2.1639e-08,  2.3721e-09,  ...,  7.2213e-08,
            3.2567e-08, -6.1179e-10],
          [ 6.4599e-10,  9.9631e-08,  1.3858e-07,  ..., -1.3488e-07,
            1.1892e-09,  1.3725e-08],
          [-3.7552e-09, -1.3304e-08, -6.8318e-08,  ..., -8.3163e-07,
           -1.7339e-07,  2.1125e-08],
          ...,
          [ 3.4972e-08, -4.4061e-07, -1.1244e-07,  ...,  5.1398e-07,
           -1.4637e-07, -3.0399e-08],
          [ 9.2846e-09, -5.7243e-08, -4.6809e-08,  ..., -1.0012e-08,
           -1.3338e-08,  1.1598e-08],
          [ 2.6709e-09, -3.6229e-09,  1.1782e-08,  ..., -4.9326e-08,
            1.9643e-08,  7.2026e-09]],

         [[ 1.0537e-09,  7.2584e-09,  5.9777e-08,  ...,  1.0114e-07,
            2.9283e-08,  3.7069e-09],
          [ 7.3881e-09, -4.2303e-08, -3.8457e-09,  ..., -1.4971e-07,
           -4.3012e-08,  5.2830e-09],
          [ 9.4045e-08, -4.7817e-08, -4.2437e-07,  ..., -8.5943e-07,
           -1.0420e-07,  4.3132e-09]

In [7]:
# Create the dummy input on the same device as the model: torch.onnx.export
# runs the model on this tensor, so a CPU tensor would fail against a model
# that was moved to CUDA. Only the (1, 12, 512, 768) shape matters here.
dummy_input = torch.rand(1, 12, 512, 768, device=device)

# Define input / output names (these become the tensor names in the exported graph)
input_names = ["y_hat"]
output_names = ["pred"]

# Convert the PyTorch model to ONNX
torch.onnx.export(
    model,
    dummy_input,
    "test02_network_op11.onnx",
    verbose=True,  # print the exported graph
    input_names=input_names,
    output_names=output_names,
    opset_version=11,
)

graph(%y_hat : Float(1, 12, 512, 768),
      %Conv_Decoder.weight : Float(3, 12, 37, 37),
      %Conv_Decoder.bias : Float(3)):
  %pred : Float(1, 3, 512, 768) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[37, 37], pads=[18, 18, 18, 18], strides=[1, 1]](%y_hat, %Conv_Decoder.weight, %Conv_Decoder.bias) # /Users/anujdutt/miniconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/conv.py:345:0
  return (%pred)



In [8]:
# Convert from ONNX to Core ML.
# `ct` (coremltools) is already imported in the imports cell at the top of the
# notebook, so it is not re-imported here.
coreml_model = ct.converters.onnx.convert(
    model='test02_network_op11.onnx',
    minimum_ios_deployment_target='13',
)

1/1: Converting Node Type Conv
Translation to CoreML spec completed. Now compiling the CoreML model.
Model Compilation done.


In [9]:
# Attach human-readable metadata to the converted model: an overall summary
# plus one description per named input / output tensor.
coreml_model.short_description = 'FP-32 model.'
coreml_model.input_description['y_hat'] = 'Input Image'
coreml_model.output_description['pred'] = 'Compressed Image Output'

In [10]:
# Display the converted model (inputs, outputs and metadata) as the cell output.
coreml_model

input {
  name: "y_hat"
  shortDescription: "Input Image"
  type {
    multiArrayType {
      shape: 1
      shape: 12
      shape: 512
      shape: 768
      dataType: FLOAT32
    }
  }
}
output {
  name: "pred"
  shortDescription: "Compressed Image Output"
  type {
    multiArrayType {
      shape: 1
      shape: 3
      shape: 512
      shape: 768
      dataType: FLOAT32
    }
  }
}
metadata {
  shortDescription: "FP-32 model."
  userDefined {
    key: "com.github.apple.coremltools.source"
    value: "onnx==1.7.0"
  }
  userDefined {
    key: "com.github.apple.coremltools.version"
    value: "4.0b1"
  }
}

In [11]:
# Persist the unquantized model; test_coreml_model loads this exact path when quantization == 32.
coreml_model.save('./test02_model.mlmodel')

## Model Quantization

Quantize to: 16 and 8 bits

In [12]:
# Function to Quantize and Save the Models
def quantize_model(base_mode=None, quantization_bits=None):
    """Quantize the weights of a Core ML model and save the result to disk.

    The quantized model gets the short description
    ``'<bits> bit quantized model.'`` and is written to
    ``./test02_model_<bits>bit.mlmodel``.

    Args:
        base_mode: The Core ML model to quantize; passed unchanged to
            ``quantization_utils.quantize_weights``. Required despite the
            ``None`` default.
        quantization_bits: Bit width forwarded as ``nbits``. Required despite
            the ``None`` default.

    Raises:
        ValueError: If ``base_mode`` or ``quantization_bits`` is ``None``.
    """
    # Fail early with a clear message instead of an obscure error raised from
    # inside coremltools when an argument was forgotten.
    if base_mode is None:
        raise ValueError("base_mode must be a Core ML model, got None")
    if quantization_bits is None:
        raise ValueError("quantization_bits must be a bit width, got None")

    quantized_model = quantization_utils.quantize_weights(base_mode, nbits=quantization_bits)
    quantized_model.short_description = f'{quantization_bits} bit quantized model.'
    save_path = f'./test02_model_{quantization_bits}bit.mlmodel'
    quantized_model.save(save_path)

In [13]:
# Quantize to 8 bits and save ./test02_model_8bit.mlmodel.
# NOTE(review): the loop in the next cell quantizes to 8 bits again, so this call looks redundant.
quantize_model(base_mode=coreml_model, quantization_bits=8)

Quantizing using linear quantization
Optimizing Neural Network before Quantization:
Finished optimizing network. Quantizing neural network..
Quantizing layer Conv_0


In [14]:
# Quantize the Model
# One quantized .mlmodel file is written per bit width listed here.
quant_bits = [16, 8]

for bit in quant_bits:
    quantize_model(base_mode=coreml_model, quantization_bits=bit)

Quantizing using linear quantization
Quantizing layer Conv_0
Quantizing using linear quantization
Optimizing Neural Network before Quantization:
Finished optimizing network. Quantizing neural network..
Quantizing layer Conv_0


## CoreML Model Testing

1. Load the CoreML Model
2. Loop through the data and test for MSE

**GT image shape:**  (1, 3, 512, 768)

**y_hat shape:**  torch.Size([1, 12, 512, 768])

**Model prediction shape:**  (1, 3, 512, 768)

In [20]:
def test_coreml_model(quantization=32):
    if quantization == 32:
        ml_model = ct.models.MLModel('./test02_model.mlmodel')
    elif quantization == 16:
        ml_model = ct.models.MLModel('./test02_model_16bit.mlmodel')
    elif quantization == 8:
        ml_model = ct.models.MLModel('./test02_model_8bit.mlmodel')
    elif quantization == 4:
        ml_model = ct.models.MLModel('./test02_model_4bit.mlmodel')
    elif quantization == 2:
        ml_model = ct.models.MLModel('./test02_model_2bit.mlmodel')
    elif quantization == 1:
        ml_model = ct.models.MLModel('./test02_model_1bit.mlmodel')
    
    print(ml_model)

    tmp = []
    
    filelist_valid = np.sort([file for file in os.listdir(valid_dir) if file.endswith('.png')])

    for j in range(0, len(filelist_valid)):
        image = cv2.imread(valid_dir + filelist_valid[j]).astype(np.float32) / 255.0
        image = np.expand_dims(np.transpose(image, [2,0,1]), axis=0)
        print("GT image shape: ", image.shape)

        y_hat = np.load(y_hat_dir + filelist_valid[j][:-4] + ".npy")
        print("y_hat shape: ", y_hat.shape)

        pred = ml_model.predict({'y_hat': y_hat})

        model_prediction = np.asarray(pred['pred'])
        print("Model prediction shape: ", model_prediction.shape)

        mse = np.mean((image - model_prediction) ** 2) * 255.0 ** 2
        print(f"Image: {filelist_valid[j]}, MSE: {mse}")
        tmp.append(mse)
        
        print("\n\n")
    
    print("Model Quantization: ", quantization)
    print("MSE Values: ",tmp)

In [21]:
# Evaluate the full-precision model (quantization=32 loads ./test02_model.mlmodel).
# Only this one model is tested here; call again with 16 or 8 for the quantized models.
test_coreml_model(quantization = 32)

input {
  name: "y_hat"
  shortDescription: "Input Image"
  type {
    multiArrayType {
      shape: 1
      shape: 12
      shape: 512
      shape: 768
      dataType: FLOAT32
    }
  }
}
output {
  name: "pred"
  shortDescription: "Compressed Image Output"
  type {
    multiArrayType {
      shape: 1
      shape: 3
      shape: 512
      shape: 768
      dataType: FLOAT32
    }
  }
}
metadata {
  shortDescription: "FP-32 model."
  userDefined {
    key: "com.github.apple.coremltools.source"
    value: "onnx==1.7.0"
  }
  userDefined {
    key: "com.github.apple.coremltools.version"
    value: "4.0b1"
  }
}

GT image shape:  (1, 3, 512, 768)
y_hat shape:  (1, 12, 512, 768)
Model prediction shape:  (1, 3, 512, 768)
Model Predictions:  [[[[0.32592773 0.34472656 0.34277344 ... 0.11071777 0.12188721
    0.13989258]
   [0.35375977 0.37426758 0.3425293  ... 0.11474609 0.12390137
    0.1340332 ]
   [0.38793945 0.37695312 0.34204102 ... 0.12298584 0.12878418
    0.1315918 ]
   ...
   [0.15246

Model prediction shape:  (1, 3, 512, 768)
Model Predictions:  [[[[0.32739258 0.3371582  0.32177734 ... 0.32861328 0.32617188
    0.3474121 ]
   [0.3647461  0.4243164  0.37524414 ... 0.35498047 0.3623047
    0.3684082 ]
   [0.34594727 0.39526367 0.3371582  ... 0.37670898 0.37817383
    0.3461914 ]
   ...
   [0.32836914 0.33081055 0.3635254  ... 0.35913086 0.32080078
    0.2692871 ]
   [0.25024414 0.23303223 0.22729492 ... 0.2475586  0.26586914
    0.2512207 ]
   [0.22167969 0.19970703 0.19445801 ... 0.19726562 0.19506836
    0.18786621]]

  [[0.33520508 0.33911133 0.32128906 ... 0.32202148 0.32617188
    0.3569336 ]
   [0.375      0.43115234 0.3930664  ... 0.35131836 0.3684082
    0.3881836 ]
   [0.3564453  0.40576172 0.36376953 ... 0.38891602 0.39990234
    0.38354492]
   ...
   [0.3305664  0.3544922  0.40014648 ... 0.3942871  0.36132812
    0.31445312]
   [0.26367188 0.26904297 0.27612305 ... 0.2919922  0.31298828
    0.29882812]
   [0.23803711 0.23791504 0.24133301 ... 0.24243164 0.2

Model prediction shape:  (1, 3, 512, 768)
Model Predictions:  [[[[0.32104492 0.3581543  0.35351562 ... 0.38305664 0.3605957
    0.32836914]
   [0.35473633 0.40283203 0.3737793  ... 0.37329102 0.3544922
    0.33276367]
   [0.3737793  0.39086914 0.35888672 ... 0.3815918  0.40185547
    0.41479492]
   ...
   [0.35351562 0.35180664 0.36914062 ... 0.33447266 0.35351562
    0.36010742]
   [0.2709961  0.25854492 0.25219727 ... 0.2397461  0.2553711
    0.25561523]
   [0.24414062 0.19787598 0.18334961 ... 0.19189453 0.2019043
    0.21069336]]

  [[0.32250977 0.35766602 0.35742188 ... 0.38305664 0.3540039
    0.3227539 ]
   [0.3618164  0.4025879  0.3786621  ... 0.38110352 0.35351562
    0.328125  ]
   [0.38598633 0.3955078  0.36694336 ... 0.3857422  0.3942871
    0.39990234]
   ...
   [0.3347168  0.34472656 0.3798828  ... 0.35205078 0.38305664
    0.39794922]
   [0.26367188 0.2668457  0.28051758 ... 0.27026367 0.29589844
    0.30395508]
   [0.24536133 0.2199707  0.22680664 ... 0.23059082 0.24694

Model prediction shape:  (1, 3, 512, 768)
Model Predictions:  [[[[0.3256836  0.37280273 0.37695312 ... 0.35205078 0.34643555
    0.35253906]
   [0.3564453  0.39672852 0.36645508 ... 0.35498047 0.36132812
    0.359375  ]
   [0.3798828  0.39501953 0.35888672 ... 0.37963867 0.39941406
    0.38232422]
   ...
   [0.33496094 0.3227539  0.3527832  ... 0.3779297  0.3359375
    0.27172852]
   [0.24890137 0.2310791  0.22033691 ... 0.26245117 0.27026367
    0.25317383]
   [0.21289062 0.19750977 0.20153809 ... 0.19152832 0.17370605
    0.18286133]]

  [[0.3256836  0.3774414  0.39233398 ... 0.34936523 0.34179688
    0.35058594]
   [0.36157227 0.40014648 0.3798828  ... 0.36010742 0.36645508
    0.36743164]
   [0.39013672 0.40185547 0.37402344 ... 0.39013672 0.41015625
    0.39526367]
   ...
   [0.34448242 0.359375   0.39746094 ... 0.42211914 0.3864746
    0.3293457 ]
   [0.2697754  0.27783203 0.27539062 ... 0.31713867 0.3269043
    0.31054688]
   [0.23498535 0.24072266 0.24597168 ... 0.25       0.23

Model prediction shape:  (1, 3, 512, 768)
Model Predictions:  [[[[0.16955566 0.15222168 0.13635254 ... 0.14782715 0.1665039
    0.19824219]
   [0.29003906 0.33007812 0.30615234 ... 0.3100586  0.27416992
    0.23461914]
   [0.35717773 0.41577148 0.4074707  ... 0.39868164 0.35424805
    0.30541992]
   ...
   [0.3713379  0.37719727 0.3684082  ... 0.37451172 0.38891602
    0.35302734]
   [0.35864258 0.36669922 0.3581543  ... 0.34936523 0.375
    0.3544922 ]
   [0.3322754  0.33398438 0.36108398 ... 0.3564453  0.35083008
    0.31689453]]

  [[0.1652832  0.1463623  0.12683105 ... 0.15686035 0.1932373
    0.23156738]
   [0.29223633 0.32836914 0.30395508 ... 0.32885742 0.3154297
    0.28588867]
   [0.3725586  0.42749023 0.4189453  ... 0.42089844 0.39648438
    0.359375  ]
   ...
   [0.36108398 0.37036133 0.36621094 ... 0.38354492 0.39648438
    0.36987305]
   [0.35620117 0.36669922 0.3635254  ... 0.3623047  0.38720703
    0.37402344]
   [0.33862305 0.3466797  0.37817383 ... 0.37475586 0.3669433