# Model Conversion (TF 2.6.2)

In [None]:
import numpy as np
import tensorflow as tf
from pathlib import Path
from PIL import Image

from utils.data import DIV2K

from utils.srgan_tools import resolve_single
from utils.srgan import generator
#from utils.model.srgan_old import generator as generator_old

from utils.tools import load_image, plot_sample

import numpy as np
import tensorflow as tf
from PIL import Image

import matplotlib.pyplot as plt 

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# gpus = tf.config.experimental.list_physical_devices('GPU')
# tf.config.experimental.set_visible_devices(gpus[1], 'GPU')
# tf.config.experimental.set_memory_growth(gpus[1], True)

In [None]:
# Output directory for the exported .tflite models.
model_dir = Path('weights') / 'bin'

In [None]:
# Fixed generator input shape passed as `shape=` to generator() below.
# Presumably (height, width, channels) per the Keras convention — confirm.
IMAGE_SIZE=(80,60,3)

## Build Model

In [None]:
# Build the x4 SRGAN generator with a fixed batch size of 1 (needed for a static
# TFLite graph) and load the EdgeSRGAN weights into it.
generator_kwargs = dict(
    scale=4,
    num_filters=64,
    num_res_blocks=8,
    shape=IMAGE_SIZE,
    batch_size=1,
    batch_norm=True,
    activation='prelu',
    upsampling='PixelShuffle',
    return_features=False,
)
model = generator(**generator_kwargs)
# NOTE(review): compile() is probably not needed for weight loading / TFLite export — confirm before removing.
model.compile()

# NOTE(review): with by_name=True and skip_mismatch=True, layers whose name or weight
# shape does not match the checkpoint are skipped instead of raising. Check the load
# warnings to make sure every layer was actually restored before exporting.
model.load_weights('weights/srgan/edgesrgan.h5', by_name=True, skip_mismatch=True)

In [None]:
# Sanity check: the Keras input shape (expected to reflect IMAGE_SIZE and batch_size=1).
model.input_shape

In [None]:
# Print the layer-by-layer architecture and parameter counts.
model.summary()

## Convert to TFLite

In [None]:
# Convert the float32 Keras model to a TFLite flatbuffer with the MLIR-based
# ("new") converter and write it into model_dir.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.experimental_new_converter = True
tflite_model = converter.convert()

name_model_tflite = 'edgesrgan.tflite'
tflite_model_file = model_dir / name_model_tflite
# Last expression: shows the number of bytes written.
tflite_model_file.write_bytes(tflite_model)

## Test TFLite Model

In [None]:
# Run the float32 TFLite model (written to model_dir by the conversion cell) on one
# test image and compare the result with the input and with a plain bicubic upscale.
interpreter = tf.lite.Interpreter(model_path=str(model_dir / 'edgesrgan.tflite'))
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test the model on input data.
input_shape = input_details[0]['shape']
print(input_shape)
image = Image.open('./figures/kd.png').convert('RGB')
# PIL sizes are (width, height), so resizing to input_shape[1:3] gives an array of
# shape (input_shape[2], input_shape[1], 3). swapaxes(0, 1) brings it back to
# (input_shape[1], input_shape[2], 3). Note that this transposes the image content.
image = image.resize(input_shape[1:3])
arr = np.swapaxes(np.asarray(image, dtype='float32'), 0, 1)[None, ...]
im_bicubic = image.resize(input_shape[1:3] * 4, resample=Image.Resampling.BICUBIC)
arr_bicubic = np.swapaxes(np.asarray(im_bicubic, dtype='uint8'), 0, 1)
print(arr.shape)
interpreter.set_tensor(input_details[0]['index'], arr)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])

print(output_data.shape)
# The generator output is float and may fall outside [0, 255]. Clip before casting
# so out-of-range values saturate instead of wrapping (or being undefined) in uint8.
sr_uint8 = np.clip(output_data, 0, 255).astype('uint8')[0]
plot_sample(arr[0].astype('uint8'), sr_uint8)
plot_sample(arr_bicubic, sr_uint8)

im = Image.fromarray(sr_uint8)
#im.save('./demo/img_sr.png')

im = Image.fromarray(arr.astype('uint8')[0])
#im.save('./demo/img_lr.png')

#im = Image.open('./demo/0829x4-crop.png')
#im.save('./demo/img_hr.png')

## Quantize to int8

In [None]:
def representative_dataset():
    """Yield calibration samples for post-training int8 quantization.

    Builds the DIV2K validation split (x4, bicubic, ``resolution`` set to
    ``IMAGE_SIZE[1::-1]``, i.e. the first two entries of IMAGE_SIZE reversed).
    Iterates its ``dataset(batch_size=1, random_transform=False,
    repeat_count=1)`` and, for every element, yields a one-item list holding
    ``element[1]`` cast to float32. This is the list-of-inputs format that
    ``TFLiteConverter.representative_dataset`` consumes.

    NOTE(review): this yields index 1 of each dataset element. Confirm it is the
    tensor that matches the generator *input* shape. If the dataset yields
    (lr, hr) pairs, the low-resolution input would be index 0.
    NOTE(review): ``data_dir`` is a hard-coded absolute path.
    """
    valid_split = DIV2K(scale=4, resolution=IMAGE_SIZE[1::-1], subset='valid', downgrade='bicubic',
                        data_dir='/home/simone/SR/sr-edge/dataset/div2k/')
    batches = valid_split.dataset(batch_size=1, random_transform=False, repeat_count=1)
    yield from ([tf.cast(sample[1], tf.float32)] for sample in batches)

In [None]:
# def representative_dataset():
#     for _ in range(100):
#         data = np.random.rand(IMAGE_SIZE[0], IMAGE_SIZE[1], IMAGE_SIZE[2])[None,...]*255
#         yield [data.astype(np.float32)]

In [None]:
# Post-training full-integer quantization. `representative_dataset` provides the
# calibration data, only int8 builtin kernels are allowed, and the model's input
# and output tensors are exposed as uint8.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = converter.inference_output_type = tf.uint8
converter.allow_custom_ops = True
# Use the MLIR-based converter and quantizer.
converter.experimental_new_converter = converter.experimental_new_quantizer = True

tflite_quant_model = converter.convert()

name_model_tflite = 'srgan_converted_int8.tflite'
tflite_model_file = model_dir / name_model_tflite
# Last expression: shows the number of bytes written.
tflite_model_file.write_bytes(tflite_quant_model)

## Test Quantized Model

In [None]:
#
# Run the int8 TFLite model (written to model_dir by the quantization cell) on one
# test image. The input is quantized with the input tensor's (scale, zero_point)
# and the output is dequantized back to float.
interpreter = tf.lite.Interpreter(model_path=str(model_dir / 'srgan_converted_int8.tflite'))
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test the model on input data.
input_shape = input_details[0]['shape']
print(input_shape)
# convert('RGB') guarantees exactly 3 channels (grayscale/RGBA files would break set_tensor).
image = Image.open('./demo/0823x4-crop.jpeg').convert('RGB')
# PIL takes (width, height); swapaxes(0, 1) turns the resulting array back into
# (input_shape[1], input_shape[2], C).
image = image.resize(input_shape[1:3])
input_data = np.swapaxes(np.array(image), 0, 1)[None, ...]

# Quantize: q = round(x / scale + zero_point), clipped to the input dtype's range.
# A plain astype() would truncate toward zero and wrap out-of-range values.
input_scale, input_zero_point = input_details[0]['quantization']
input_dtype = input_details[0]['dtype']
input_limits = np.iinfo(input_dtype)
test_image_int = np.round(input_data / input_scale + input_zero_point)
test_image_int = np.clip(test_image_int, input_limits.min, input_limits.max).astype(input_dtype)

interpreter.set_tensor(input_details[0]['index'], test_image_int)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])

# Dequantize: x = (q - zero_point) * scale.
scale, zero_point = output_details[0]['quantization']
output_data = (output_data.astype(np.float32) - zero_point) * scale

print(input_scale)
print(input_details[0]["quantization"])

print(output_data.shape)
# Clip to [0, 255] before casting so out-of-range values saturate instead of wrapping.
sr_uint8 = np.clip(output_data, 0, 255).astype('uint8')[0]
plot_sample(input_data[0].astype('uint8'), sr_uint8)

im = Image.fromarray(sr_uint8)
im.save('./demo/img_sr_quant.png')

In [None]:
#
# NOTE(review): this cell repeats the previous int8 test cell verbatim. Consider
# deleting one copy or factoring the logic into a function.
# Run the int8 TFLite model (written to model_dir by the quantization cell) on one
# test image. The input is quantized with the input tensor's (scale, zero_point)
# and the output is dequantized back to float.
interpreter = tf.lite.Interpreter(model_path=str(model_dir / 'srgan_converted_int8.tflite'))
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test the model on input data.
input_shape = input_details[0]['shape']
print(input_shape)
# convert('RGB') guarantees exactly 3 channels (grayscale/RGBA files would break set_tensor).
image = Image.open('./demo/0823x4-crop.jpeg').convert('RGB')
# PIL takes (width, height); swapaxes(0, 1) turns the resulting array back into
# (input_shape[1], input_shape[2], C).
image = image.resize(input_shape[1:3])
input_data = np.swapaxes(np.array(image), 0, 1)[None, ...]

# Quantize: q = round(x / scale + zero_point), clipped to the input dtype's range.
# A plain astype() would truncate toward zero and wrap out-of-range values.
input_scale, input_zero_point = input_details[0]['quantization']
input_dtype = input_details[0]['dtype']
input_limits = np.iinfo(input_dtype)
test_image_int = np.round(input_data / input_scale + input_zero_point)
test_image_int = np.clip(test_image_int, input_limits.min, input_limits.max).astype(input_dtype)

interpreter.set_tensor(input_details[0]['index'], test_image_int)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])

# Dequantize: x = (q - zero_point) * scale.
scale, zero_point = output_details[0]['quantization']
output_data = (output_data.astype(np.float32) - zero_point) * scale

print(input_scale)
print(input_details[0]["quantization"])

print(output_data.shape)
# Clip to [0, 255] before casting so out-of-range values saturate instead of wrapping.
sr_uint8 = np.clip(output_data, 0, 255).astype('uint8')[0]
plot_sample(input_data[0].astype('uint8'), sr_uint8)

im = Image.fromarray(sr_uint8)
im.save('./demo/img_sr_quant.png')

In [None]:
# Inspect the int8 model's input tensor metadata (index, shape, dtype, quantization, ...).
input_details

## Compile for EdgeTPU

In [None]:
!edgetpu_compiler bin/srgan_converted_int8.tflite -o ./bin/ -sad

# Test TFLite float32

In [None]:
import numpy as np
import tensorflow as tf
import time

# Load the float32 TFLite model and allocate tensors.
# The float32 export cell saves this model as model_dir / 'edgesrgan.tflite'.
interpreter = tf.lite.Interpreter(model_path=str(model_dir / 'edgesrgan.tflite'))
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed a real test image (not random data), resized to the model's input size.
input_shape = input_details[0]['shape']
image = Image.open('./demo/0823x4-crop.jpeg')
image = image.resize(input_shape[1:3])
input_data = np.swapaxes(np.array(image, dtype=np.float32), 0, 1)[None, ...]
interpreter.set_tensor(input_details[0]['index'], input_data)

# Run inference repeatedly and report the latency of each run.
N_RUNS = 100
print('----INFERENCE TIME----')
lat = []
for _ in range(N_RUNS):
    start = time.perf_counter()
    interpreter.invoke()
    inference_time = time.perf_counter() - start
    lat.append(inference_time)
    print('%.1fms' % (inference_time * 1000))
# Discard the FIRST run. It typically carries one-off warm-up cost and would bias
# the average; the last run is a normal steady-state sample.
lat.pop(0)
print(f'Average Speed: {1/np.mean(np.array(lat))} fps')

# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])


# LR Scheduler

In [None]:
import tensorflow as tf
from tensorflow.keras.optimizers import schedules
import matplotlib.pyplot as plt

In [None]:
N = 1000

# Plot a few Keras learning-rate schedules over N steps (one figure each) and
# print the final learning rate each one reaches.
scheds = [
    schedules.ExponentialDecay(3e-4, N, 0.01),
    schedules.InverseTimeDecay(1e-3, N, 0.1),
    schedules.PolynomialDecay(1e-3, N, 1e-5, power=2),
    schedules.PiecewiseConstantDecay([0], [1e-3, 1e-3]),
]

for sched in scheds:
    lr = [sched(step) for step in range(N)]
    plt.plot(lr)
    plt.show()
    print(lr[-1])

# Random Crop

In [None]:
import numpy as np
import tensorflow as tf
from pathlib import Path
from PIL import Image
    
from utils.srgan_tools import resolve_single
from utils.srgan import generator
from utils.tools import load_image, plot_sample

import numpy as np
import tensorflow as tf
from PIL import Image

import matplotlib.pyplot as plt 

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
from utils.data import *

In [None]:
def random_crop(lr_img, hr_img, hr_crop_size=(96,128), scale=4):
    """Crop an aligned random patch from a low-/high-resolution image pair.

    A random top-left corner is drawn on the LR image so that an LR patch of
    size ``hr_crop_size // scale`` fits entirely inside it. The HR patch is
    taken at the same location multiplied by ``scale``.

    Args:
        lr_img: Low-resolution image, indexed as [row, col, ...].
        hr_img: High-resolution image. Assumed to be exactly ``scale`` times
            the LR size — TODO confirm for all callers.
        hr_crop_size: (rows, cols) of the HR patch. Each entry should be
            divisible by ``scale`` so the LR and HR patches stay aligned.
        scale: Upscaling factor between ``lr_img`` and ``hr_img``.

    Returns:
        Tuple ``(lr_img_cropped, hr_img_cropped)``.
    """
    # No debug prints / .numpy() calls here: they fail inside tf.function or
    # Dataset.map (graph mode) and flood the output in eager loops.
    lr_crop_size = tuple(s // scale for s in hr_crop_size)
    lr_img_shape = tf.shape(lr_img)[:2]

    # maxval is exclusive, hence +1 so the patch may touch the right/bottom border.
    lr_w = tf.random.uniform(shape=(), maxval=lr_img_shape[1] - lr_crop_size[1] + 1, dtype=tf.int32)
    lr_h = tf.random.uniform(shape=(), maxval=lr_img_shape[0] - lr_crop_size[0] + 1, dtype=tf.int32)

    hr_w = lr_w * scale
    hr_h = lr_h * scale

    lr_img_cropped = lr_img[lr_h:lr_h + lr_crop_size[0], lr_w:lr_w + lr_crop_size[1]]
    hr_img_cropped = hr_img[hr_h:hr_h + hr_crop_size[0], hr_w:hr_w + hr_crop_size[1]]

    return lr_img_cropped, hr_img_cropped

In [None]:
# NOTE(review): despite its name, `hr` is loaded from the DIV2K *LR* (bicubic x4)
# folder; `lr` is a further 4x downscale of it.
hr = Image.open('../../super_resolution/div2k/images/DIV2K_valid_LR_bicubic/X4/0898x4.png')
lr = hr.resize(tuple(dim // 4 for dim in hr.size))
display(hr, lr)
print(hr.size, lr.size)

# Visualise JPEG compression artefacts. Round-trip the image through JPEG once,
# then show the round-tripped image, the original and their absolute difference.
hr_arr = np.asarray(hr)
hr_jpeg = tf.io.decode_jpeg(tf.io.encode_jpeg(hr_arr)).numpy()

plt.imshow(hr_jpeg)
plt.show()
plt.imshow(hr)
plt.show()
# Subtract in a signed dtype. uint8 subtraction wraps around (e.g. 3 - 5 -> 254),
# which would make small negative errors appear as bright pixels.
jpeg_error = np.abs(hr_arr.astype(np.int16) - hr_jpeg.astype(np.int16)).astype(np.uint8)
plt.imshow(jpeg_error)
plt.show()

In [None]:
hr_t = tf.keras.preprocessing.image.img_to_array(hr)
lr_t = tf.keras.preprocessing.image.img_to_array(lr)

# Stress-test random_crop with its default hr_crop_size=(96, 128), scale=4.
# Every crop must be exactly 24x32 (LR) and 96x128 (HR).
EXPECTED_LR_HW = (24, 32)
EXPECTED_HR_HW = (96, 128)

for _ in range(100):
    crop_lr, crop_hr = random_crop(lr_t, hr_t)
    lr_shape = tf.shape(crop_lr).numpy()
    hr_shape = tf.shape(crop_hr).numpy()
    if tuple(lr_shape[:2]) != EXPECTED_LR_HW or tuple(hr_shape[:2]) != EXPECTED_HR_HW:
        # A bare `raise` outside an except block only yields
        # "RuntimeError: No active exception to reraise"; report the shapes instead.
        raise AssertionError(f'unexpected crop shapes: lr={lr_shape}, hr={hr_shape}')

In [None]:
# Show the last HR crop followed by its LR counterpart.
for crop in (crop_hr, crop_lr):
    plt.imshow(crop.astype('uint8'))
    plt.show()

In [None]:
# DIV2K validation split, x4 bicubic downgrade, resolution=(160*4, 120*4).
# NOTE(review): hard-coded absolute paths. This call passes images_dir/caches_dir while
# representative_dataset() passes data_dir — confirm which signature utils.data.DIV2K
# currently expects.
a = DIV2K(scale=4, resolution=(160*4,120*4), downgrade='bicubic', subset='valid',
          images_dir='/home/simone/SR/sr-edge/dataset/div2k/images',
          caches_dir='/home/simone/SR/sr-edge/dataset/div2k/caches')

In [None]:
# Batched pipeline: 16 samples per batch, random_transform enabled, repeat_count=1.
ds = a.dataset(random_transform=True, batch_size=16, repeat_count=1)

In [None]:
# Number of batches; presumably requires the dataset's cardinality to be known — confirm.
len(ds)

In [None]:
# Print the shapes of the first two components of the first 20 batches
# (presumably the LR and HR batches — confirm the element order).
for i in ds.take(20):
    print(i[0].shape, i[1].shape)

In [None]:
# Peek at the first sample of the first batch: print its shape and display it.
for batch in ds.take(1):
    first_sample = batch[0][0]
    print(first_sample.shape)
    plt.imshow(first_sample)

# Scratch experiments

In [None]:
from utils.tools import read_yaml
import pprint

In [None]:
# Load the training configuration; keys such as 'MODE' and 'PATCH_SIZE' are read below.
config = read_yaml('utils/config.yaml')

In [None]:
def pretty(d, indent=0):
    """Print a (possibly nested) dict as a tab-indented outline.

    Every key is printed on its own line at depth ``indent``. A dict value is
    expanded recursively one level deeper; any other value is printed via
    ``str`` on the next line, one tab deeper than its key.
    """
    key_prefix = '\t' * indent
    value_prefix = key_prefix + '\t'
    for key, value in d.items():
        print(key_prefix + str(key))
        if not isinstance(value, dict):
            print(value_prefix + str(value))
            continue
        pretty(value, indent + 1)

In [None]:
# Pretty-printer that shows only the top level of nested containers (depth=1).
pp = pprint.PrettyPrinter(depth=1)

In [None]:
# One-element list holding the active mode name. NOTE(review): not used in any visible cell.
k = [config['MODE']]

In [None]:
# Print the config section selected by config['MODE'] (top level only).
pp.pprint(config[config['MODE']])

In [None]:
# Same section as a string (shown as the cell output), e.g. for logging.
pp.pformat(config[config['MODE']])

In [None]:
# Check which Python type the YAML loader produced for BATCH_SIZE.
type(config[config['MODE']]['BATCH_SIZE'])

In [None]:
# PATCH_SIZE is read from the top level of the config (not the MODE section) and converted to a tuple.
tuple(config['PATCH_SIZE'])

In [None]:
# float() presumably because the YAML loader may return values like 1e-4 as strings — confirm.
float(config[config['MODE']]['LR'])

In [None]:
from utils.srgan import generator
import tensorflow as tf
# Small generator for quick experiments. NOTE(review): the default return_features
# presumably makes it return (output, features), since a later cell unpacks two values — confirm.
d = generator(num_filters=32, num_res_blocks=4, shape=(96,96,3))

In [None]:
# Print the small generator's architecture and parameter counts.
d.summary()

In [None]:
# Forward a random batch of 16 images of shape 96x96x3 through the generator and
# unpack its two outputs (o, f). A seeded generator makes the experiment
# reproducible, and float32 input avoids np.random.rand's float64 default.
rng = np.random.default_rng(0)
o, f = d(rng.random((16, 96, 96, 3), dtype=np.float32))

In [None]:
# Shape of the second generator output.
f.shape

In [None]:
# Select entries 1, 2 and 4 along axis 0 (tf.gather's default axis).
ff = tf.gather(f,[1,2,4])

In [None]:
# Axis 0 should now have length 3.
ff.shape

In [None]:
def should_save(m, metric='PSNR'):
    """Decide whether a checkpoint with metric value ``m`` should be saved.

    Compares ``m`` with the module-level ``best_metr``, if it has been defined:

    * ``best_metr`` not defined yet -> always save;
    * ``m > best_metr`` -> save only when ``metric == 'PSNR'``;
    * otherwise (``m <= best_metr``) -> save only when ``metric == 'NIQE'``.

    Args:
        m: Metric value of the current checkpoint.
        metric: Metric name. The default ``'PSNR'`` reproduces the previously
            hard-coded behaviour.

    Returns:
        bool: True if the checkpoint should be saved.
    """
    if 'best_metr' not in globals():
        return True
    if m > best_metr:
        return metric == 'PSNR'
    return metric == 'NIQE'

In [None]:
# Returns True when best_metr is not yet defined in the kernel. On a fresh run that
# is the case, because best_metr is only assigned in the next cell.
should_save(11)

In [None]:
# Best metric value consulted by should_save(). NOTE(review): it is defined after
# should_save(11) is called above, so re-running cells out of order changes that result.
best_metr = 10