# Imports

In [None]:
from __future__ import print_function
import argparse
from tensorflow.keras.layers import Input
import scipy.misc

import tensorflow_model_optimization as tfmot

import numpy as np
import os

import PIL
import tensorflow as tf
import random
import re
from tensorflow.python.framework.ops import enable_eager_execution
enable_eager_execution()

import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.applications.mobilenet import MobileNet
from tensorflow.keras.applications.densenet import DenseNet121
import tensorflow_datasets as tfds

## Environment variables

In [None]:
# Pin the process to one GPU and request deterministic TensorFlow kernels.
# NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if it is set before
# TensorFlow initializes its GPU devices; TF is already imported above, so
# confirm no GPU op has run before this cell (or move it above the imports).
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "5",
    "TF_DETERMINISTIC_OPS": "1",
})

# Load Data

In [None]:
def preprocess_image(features, size=(224, 224)):
    """Resize the ``"image"`` entry of a dataset element.

    Only resizing happens here. No RGB->BGR reordering and no mean
    subtraction are applied (the previous docstring described Keras'
    ResNet ``preprocess_input``, which this function does not call).
    Model-specific normalisation is applied later, separately for each model.

    Args:
        features: dict-like dataset element holding an ``"image"`` tensor.
        size: target ``(height, width)``; defaults to 224x224, the
            ``input_shape`` used for every model in this notebook.

    Returns:
        A shallow copy of ``features`` whose ``"image"`` is the resized image
        (``tf.image.resize`` returns float32 for its default bilinear method).
        All other entries are passed through unchanged.
    """
    resized = dict(features)  # don't mutate the caller's element
    resized["image"] = tf.image.resize(features["image"], size)
    return resized

In [None]:
# Number of images per batch when iterating the evaluation dataset.
BATCH_SIZE: int = 50

In [None]:
# Evaluation pool: the LAST 60% of the 50,000 ImageNet validation images
# (i.e. 30,000 images). NOTE(review): an older comment here said "20%";
# the split string below is what actually runs — confirm the intended size.
tfds_dataset2, tfds_info = tfds.load(
    name='imagenet2012_subset',
    split='validation[-60%:]',
    with_info=True,
    data_dir='../../datasets/ImageNet',
)

In [None]:
# Visual sanity check: plot a grid of example images from the evaluation split.
# (Presumably assigned to `figs` so Jupyter does not echo the figure object.)
figs = tfds.show_examples(tfds_dataset2, tfds_info)

In [None]:
# Resize every element with preprocess_image, then group into BATCH_SIZE batches.
val_ds = (
    tfds_dataset2
    .map(preprocess_image)
    .batch(BATCH_SIZE)
)

## Load Models

In [None]:
def normalize(x):
    """Divide a tensor by its root-mean-square value.

    Computes ``x / (sqrt(mean(x**2)) + 1e-5)``. The ``1e-5`` term avoids
    division by zero for an all-zero tensor.

    Note: this is the root-mean-square, not the L2 norm
    (``sqrt(sum(x**2))``); the two differ by a factor of ``sqrt(x.size)``.
    """
    # Plain tf ops: the Keras backend alias `K` is never imported here.
    return x / (tf.sqrt(tf.reduce_mean(tf.square(x))) + 1e-5)

In [None]:
# input image dimensions
img_rows, img_cols = 224 ,224
input_shape = (img_rows, img_cols, 3)

### ResNet Model

In [None]:
# ResNet50 family: full-precision `model`, quantization-aware `q_model`, and a
# third variant `sb_model` loaded from d_model_resnet50.h5 (its training
# procedure is not shown here).
#
# weights=None: every network below has all of its weights replaced by
# load_weights, so downloading/loading the default ImageNet weights first is
# wasted work.
model_ = ResNet50(input_shape=input_shape, weights=None)
q_model = tfmot.quantization.keras.quantize_model(model_)
# Build the other networks on q_model's input tensor so all three share one
# Input layer.
model = ResNet50(input_tensor=q_model.input, weights=None)
model.load_weights("../../weights/fp_model_resnet50.h5")
q_model.load_weights("../../weights/q_model_resnet50.h5")
model.trainable = False
q_model.trainable = False
sb_model = ResNet50(input_tensor=q_model.input, weights=None)
sb_model.load_weights("../../weights/d_model_resnet50.h5")
sb_model.trainable = False
print("ResNet Done")

In [None]:
# Compiled without an optimizer or loss: the notebook only runs inference
# (predict) with these models.
model.compile()
q_model.compile()
sb_model.compile()

### MobileNet Model

In [None]:
# MobileNet family: full-precision `mob_model`, quantization-aware
# `mob_q_model`, and a third variant `sb_mob_model` loaded from
# d_model_mobilenet.h5 (its training procedure is not shown here).
#
# weights=None: every network below has all of its weights replaced by
# load_weights, so downloading/loading the default ImageNet weights first is
# wasted work.
mob_model_ = MobileNet(input_shape=input_shape, weights=None)
mob_q_model = tfmot.quantization.keras.quantize_model(mob_model_)
# Build the other networks on mob_q_model's input tensor so all three share
# one Input layer.
mob_model = MobileNet(input_tensor=mob_q_model.input, weights=None)
mob_model.load_weights("../../weights/fp_model_mobilenet.h5")
mob_q_model.load_weights("../../weights/q_model_mobilenet.h5")
mob_model.trainable = False
mob_q_model.trainable = False
sb_mob_model = MobileNet(input_tensor=mob_q_model.input, weights=None)
sb_mob_model.load_weights("../../weights/d_model_mobilenet.h5")
sb_mob_model.trainable = False
print("MobileNet Done")

In [None]:
# Compiled without an optimizer or loss: the notebook only runs inference
# (predict) with these models.
mob_model.compile()
mob_q_model.compile()
sb_mob_model.compile()

### DenseNet Model

In [None]:
# Custom quantization configs used to make DenseNet121 quantization-aware.

class DefaultBNQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    """QuantizeConfig that apply_quantization assigns to layers named '*bn*'.

    No weights and no activations are quantized. Only the layer output is
    fake-quantized, with an 8-bit, per-tensor, asymmetric
    MovingAverageQuantizer.
    """

    def get_weights_and_quantizers(self, layer):
        """Return no (weight, quantizer) pairs: weights are left unquantized."""
        return list()

    def get_activations_and_quantizers(self, layer):
        """Return no (activation, quantizer) pairs."""
        return list()

    def set_quantize_weights(self, layer, quantize_weights):
        """Nothing to write back, since no weights are quantized."""

    def set_quantize_activations(self, layer, quantize_activations):
        """Nothing to write back, since no activations are quantized."""

    def get_output_quantizers(self, layer):
        """Fake-quantize the layer output to 8 bits."""
        output_quantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer(
            num_bits=8,
            per_axis=False,
            symmetric=False,
            narrow_range=False,
        )
        return [output_quantizer]

    def get_config(self):
        """The config takes no constructor arguments, so serialize nothing."""
        return dict()
    
    
class NoOpQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    """QuantizeConfig that quantizes nothing at all.

    Use it for layers that must pass through quantization-aware training
    untouched: no weights, no activations and no outputs get quantizers.
    """

    def get_weights_and_quantizers(self, layer):
        """Return no (weight, quantizer) pairs."""
        return list()

    def get_activations_and_quantizers(self, layer):
        """Return no (activation, quantizer) pairs."""
        return list()

    def set_quantize_weights(self, layer, quantize_weights):
        """Nothing to write back, since no weights are quantized."""

    def set_quantize_activations(self, layer, quantize_activations):
        """Nothing to write back, since no activations are quantized."""

    def get_output_quantizers(self, layer):
        """Return no output quantizers, so the output stays in float."""
        return list()

    def get_config(self):
        """The config takes no constructor arguments, so serialize nothing."""
        return dict()
    
    
def apply_quantization(layer):
    """Annotate ``layer`` for ``quantize_apply``, choosing a config by name.

    * name contains 'bn'     -> DefaultBNQuantizeConfig (output-only)
    * name contains 'concat' -> NoOpQuantizeConfig (nothing quantized)
    * anything else          -> tfmot's default config for the layer type

    'bn' is tested first, so a name containing both substrings gets the
    BN config.
    """
    annotate = tfmot.quantization.keras.quantize_annotate_layer
    if 'bn' in layer.name:
        return annotate(layer, DefaultBNQuantizeConfig())
    if 'concat' in layer.name:
        return annotate(layer, NoOpQuantizeConfig())
    return annotate(layer)

In [None]:
# Build a quantization-aware DenseNet121 by annotating layers manually rather
# than calling quantize_model(). Presumably this is because the default
# registry does not support DenseNet's BatchNormalization/Concatenate usage.
# TODO confirm.
#
# weights=None: dense_q_model's weights are all replaced by load_weights
# before use, so the default ImageNet weights would never be used.
dense_model_ = tf.keras.applications.DenseNet121(
    input_shape=(img_rows, img_cols, 3), weights=None)
# Create a base model
base_model = dense_model_

# Quantizer aliases. They are not referenced in the visible notebook and are
# kept only in case other code relies on these names.
LastValueQuantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer
MovingAverageQuantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer

# clone_model calls apply_quantization on EVERY layer, so all layers are
# annotated. 'bn'/'concat' layers get the custom configs; all others get
# tfmot's defaults.
annotated_model = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_quantization,
)

# quantize_apply must be able to resolve the custom configs by name.
with tfmot.quantization.keras.quantize_scope({'DefaultBNQuantizeConfig': DefaultBNQuantizeConfig, 'NoOpQuantizeConfig': NoOpQuantizeConfig}):
    dense_q_model = tfmot.quantization.keras.quantize_apply(annotated_model)

In [None]:
# DenseNet121 family: full-precision `dense_model`, quantization-aware
# `dense_q_model`, and a third variant `sb_dense_model` loaded from
# d_model_densenet.h5 (its training procedure is not shown here).
#
# Both float models are built on dense_q_model's own input tensor, matching
# the ResNet/MobileNet cells. This was previously ResNet's `q_model.input` by
# mistake. weights=None: every network is immediately overwritten by
# load_weights, so the ImageNet weights would be wasted.
dense_model = DenseNet121(input_tensor=dense_q_model.input, weights=None)
dense_model.load_weights("../../weights/fp_model_densenet.h5")
dense_q_model.load_weights("../../weights/q_model_densenet.h5")
dense_model.trainable = False
dense_q_model.trainable = False
sb_dense_model = DenseNet121(input_tensor=dense_q_model.input, weights=None)
sb_dense_model.load_weights("../../weights/d_model_densenet.h5")
sb_dense_model.trainable = False
print("Dense Done")

In [None]:
# Compiled without an optimizer or loss: the notebook only runs inference
# (predict) with these models.
dense_model.compile()
dense_q_model.compile()
sb_dense_model.compile()

## Generate New Data

In [None]:
number = 4  # target number of accepted images per class
# Accepted samples keyed by ground-truth class id. Each value is a list of
# (image, file_name, label) tuples that every model classified correctly.
label_map = dict()

In [None]:
def check(label_map, number, num_classes=1000):
    """Print and return the fraction of the target sample count collected.

    Args:
        label_map: dict mapping class id -> list of collected samples.
        number: samples wanted per class.
        num_classes: number of classes (1000 for ImageNet).

    Returns:
        ``collected / (number * num_classes)``; 1.0 once every class is full.

    Raises:
        ZeroDivisionError: if ``number`` or ``num_classes`` is 0.
    """
    collected = sum(len(samples) for samples in label_map.values())
    progress = collected / (number * num_classes)
    print(progress)
    return progress

In [None]:
def _predict_labels(image, preprocess, models):
    """Return the arg-max class index each model in ``models`` predicts.

    ``image`` is copied before ``preprocess`` runs because Keras'
    ``preprocess_input`` may write its result over a NumPy input in place.
    """
    batch = np.expand_dims(preprocess(np.copy(image)), axis=0)
    return [np.argmax(m.predict(batch)) for m in models]


# Filter for the new dataset: keep only images on which all nine models agree
# with the ground truth.
def work(image, file, label):
    """Report whether ``image`` should be REJECTED from the new dataset.

    Runs the full-precision, quantized and "sb" variants of ResNet50,
    MobileNet and DenseNet121 on ``image``, each family with its own
    ``preprocess_input``. The nine predicted labels are then compared with
    ``label``.

    Args:
        image: one resized image as a NumPy array (HxWx3).
        file: file name of the image (unused; kept for the call signature).
        label: ground-truth class index.

    Returns:
        True if any prediction differs from ``label``; in that case the
        per-family predictions are printed. False if all nine equal ``label``.
    """
    res_labels = _predict_labels(
        image, tf.keras.applications.resnet.preprocess_input,
        (model, q_model, sb_model))
    mob_labels = _predict_labels(
        image, tf.keras.applications.mobilenet.preprocess_input,
        (mob_model, mob_q_model, sb_mob_model))
    dense_labels = _predict_labels(
        image, tf.keras.applications.densenet.preprocess_input,
        (dense_model, dense_q_model, sb_dense_model))

    all_labels = set(res_labels + mob_labels + dense_labels + [label])
    if len(all_labels) != 1:
        print("Res", res_labels)
        print("Mob", mob_labels)
        print("Dense", dense_labels)
        print("Correct", label)
        return True

    return False
    

In [None]:
def generate_new_data():
    """Build and save a 3,000-image subset on which every model is correct.

    1. Scan ``val_ds``. For each class, keep up to ``number`` images for which
       ``work`` accepts the image (``work`` returns True to reject).
    2. Take up to 3 images per class. Then backfill with each class's 4th
       image, in class order, until 3,000 images are collected.
    3. Save the result with ``tf.data.experimental.save`` under
       ``../../datasets/Imagenet/quantisation/3KImagePerClass``.

    Reads and mutates the module-level ``label_map``; reads ``number`` and
    ``val_ds``.

    Returns:
        The saved ``tf.data.Dataset`` (callers may ignore it).
    """
    num_classes = 1000
    per_class = 3                      # images per class in the first pass
    target = per_class * num_classes   # 3,000 images in total

    # ---- Collect accepted candidates --------------------------------------
    for i, images in enumerate(val_ds):
        # 600 batches = 30,000 images (last 60% of the 50,000 validation
        # images) / 50 images per batch.
        print("% OF IMAGES SEEN: " + str(100 * i / 600))

        for image_t, file_t, label_t in zip(
            images['image'], images['file_name'], images['label']
        ):
            image = image_t.numpy()
            file_name = file_t.numpy()
            label = label_t.numpy()

            samples = label_map.get(label)
            if samples is not None and len(samples) >= number:
                continue  # this class already has enough images
            if work(image, file_name, label):
                continue  # at least one model got this image wrong

            if samples is None:
                print("found:" + str(label))
                label_map[label] = [(image, file_name, label)]
            else:
                samples.append((image, file_name, label))
                print("count:" + str(label) + "," + str(len(samples)))

    check(label_map, number)

    image_data, file_data, label_data = [], [], []

    def _add(sample):
        # Append one (image, file_name, label) tuple to the parallel lists.
        image, file_name, label = sample
        image_data.append(image)
        file_data.append(file_name)
        label_data.append(label)

    # ---- Pass 1: up to `per_class` images from every class ----------------
    for i in range(num_classes):
        if i not in label_map:
            print(i)  # no accepted image for this class
            continue
        for sample in label_map[i][:per_class]:
            _add(sample)

    # ---- Pass 2: backfill with the next image of each class ---------------
    # The size check runs BEFORE appending. Checking only after the append
    # (with ==) let a full first pass overshoot to 3,001 and never stop.
    for i in range(num_classes):
        if len(image_data) >= target:
            print("DONE")
            break
        samples = label_map.get(i, [])
        if len(samples) > per_class:
            _add(samples[per_class])
        else:
            print(i)  # class missing or has no spare image
        print(len(image_data))

    KImagePerClass = tf.data.Dataset.from_tensor_slices({
        "file_name": np.array(file_data),
        "image": np.array(image_data),
        "label": np.array(label_data),
    })
    tf.data.experimental.save(KImagePerClass, "../../datasets/Imagenet/quantisation/3KImagePerClass", compression=None, shard_func=None)
    print(KImagePerClass.element_spec)
    return KImagePerClass
     

In [None]:
# Run the selection (nine model predictions per candidate image) and save the
# resulting dataset to disk.
generate_new_data()

## Test loading new data

In [None]:
# Element spec of the dataset written by generate_new_data(), passed
# explicitly so load() can rebuild the elements.
es = {
    'file_name': tf.TensorSpec(shape=(), dtype=tf.string, name=None),
    'image': tf.TensorSpec(shape=(224, 224, 3), dtype=tf.float32, name=None),
    'label': tf.TensorSpec(shape=(), dtype=tf.int64, name=None),
}
# Batch with the same BATCH_SIZE used for val_ds (was a duplicated literal 50).
mydataset = tf.data.experimental.load(
    "../../datasets/Imagenet/quantisation/3KImagePerClass", es
).batch(BATCH_SIZE)