# Script to compress existing models

In [None]:
import onnx
from onnx_tf.backend import prepare

import torchvision
import torch
import os
import numpy as np
import tensorflow as tf
import json
from PIL import Image
import onnx
import matplotlib.pyplot as plt

In [None]:
# set up the files you are interested in
image_file='/bask/homes/f/fspo1218/amber/data/gbif_download_standalone/gbif_images/Noctuidae/Spodoptera/Spodoptera exigua/1211977745.jpg'

region = 'costarica'

# Label info: the species names for the selected region.
# A context manager is used so the file handle is closed once the JSON is parsed.
with open(f"/bask/homes/f/fspo1218/amber/data/gbif_{region}/02_{region}_data_numeric_labels.json") as f:
    label_info = json.load(f)
label_info = label_info['species_list']
species_list_mila = list(label_info)
print(len(species_list_mila), " species in total")

num_classes = len(species_list_mila)


# Pick the trained model: a file for this region with 'resnet50' in its name,
# skipping files whose name contains 'state'.
# NOTE(review): os.listdir returns entries in arbitrary order, so taking
# element [1] of the matches is not guaranteed to select the same file on
# every machine/run — confirm which checkpoint is intended.
files = os.listdir("/bask/homes/f/fspo1218/amber/projects/on_device_classifier/outputs/")
PATH = os.path.join("/bask/homes/f/fspo1218/amber/projects/on_device_classifier/outputs/",
               [file for file in files if region in file and 'resnet50' in file and 'state' not in file][1])
print(PATH)

device = torch.device('cpu')

# Compressed models for this region are written here
output_dir = f'/bask/homes/f/fspo1218/amber/data/compressed_models/gbif_{region}/'
os.makedirs(output_dir, exist_ok=True)

In [None]:
def pytorch_to_tflite(model, output_dir, image, output_model_prefix="model"):
    """Convert a PyTorch model to TensorFlow Lite via ONNX and a TF SavedModel.

    Three artefacts are written into ``output_dir``:
    ``<prefix>.onnx``, the SavedModel directory ``tf_<prefix>`` and
    ``<prefix>.tflite``.

    Args:
        model: PyTorch module to convert; it is switched to eval mode for export.
        output_dir: Directory that receives the intermediate and final files.
        image: Example input tensor *without* a batch dimension; a batch
            dimension is added with ``unsqueeze(0)`` before tracing.
        output_model_prefix: Base name used for all files written.

    Returns:
        The converted model as returned by ``TFLiteConverter.convert()``.
    """
    # convert the model to onnx
    print("Converting to onnx")

    # os.path.join avoids the doubled separator when output_dir already ends in "/"
    onnx_path = os.path.join(output_dir, output_model_prefix + ".onnx")
    torch.onnx.export(
            model=model.eval(),
            args=image.unsqueeze(0),
            f=onnx_path,
            verbose=False,
            export_params=True,
            do_constant_folding=False,
            input_names=['input'],
            opset_version=12,
            output_names=['output']
    )

    # Convert to tf
    print("Converting to tensorflow...")
    tf_path = os.path.join(output_dir, "tf_" + output_model_prefix)
    onnx_model = onnx.load(onnx_path)
    # Fail early on a malformed export instead of inside the TF conversion
    onnx.checker.check_model(onnx_model)
    tf_rep = prepare(onnx_model, device='CPU')
    tf_rep.export_graph(tf_path)

    # Convert to tfLite
    print("Converting to tensorflowlite")
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_path)
    converter.experimental_new_converter = True
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # Allow falling back to (select) TensorFlow ops and custom ops for
    # operations that have no TFLite builtin.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converter.allow_custom_ops=True
    tflite_model = converter.convert()

    print("Saving converted model")
    tflite_path = os.path.join(output_dir, output_model_prefix + ".tflite")
    with open(tflite_path, 'wb') as f:
        f.write(tflite_model)

    return tflite_model

In [None]:
# Force three channels: the Normalize step below is configured for exactly
# three channels, so a greyscale, palette or RGBA image would otherwise fail.
image = Image.open(image_file).convert("RGB")

# Transform: resize to 300x300, convert to a tensor, then normalise each
# channel with mean 0.5 and std 0.5.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((300, 300)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)
img = transform(image)


## MILA species classifier

In [None]:
import torch, torchvision

import sys
sys.path.append('/bask/homes/f/fspo1218/amber/projects/species_classifier/')
sys.path.append('/bask/homes/f/fspo1218/amber/projects/species_classifier/models/')
sys.path.append('/bask/homes/f/fspo1218/amber/projects/species_classifier/data2/')
sys.path.append('/bask/homes/f/fspo1218/amber/projects/species_classifier/evaluation/')

from data2 import dataloader
import evaluation

In [None]:
# Load the species classifier selected by PATH.
# NOTE: torch.load unpickles arbitrary Python objects; only load checkpoints
# that come from a trusted source.
if 'efficientnet' in PATH:
    # `models` was never imported on its own; go through torchvision.models.
    model_py_mila = torchvision.models.efficientnet_b0(pretrained=True)
    model_py_mila = model_py_mila.to(device)
    # NOTE(review): the checkpoint is loaded but never applied to the model, so
    # this branch still runs torchvision's pretrained weights — confirm how the
    # efficientnet checkpoint is meant to be restored.
    checkpoint = torch.load(PATH, map_location=device)
    model_py_mila.eval()

elif 'resnet' in PATH:
    # torch.load returns the model object itself here (eval() is called on
    # the result), so no architecture needs to be constructed beforehand.
    model_py_mila = torch.load(PATH, map_location=device)
    model_py_mila.eval()

else:
    # Stop here: continuing would report a model as loaded when none was.
    raise ValueError(f"clarify model type: cannot infer architecture from {PATH}")

print("loaded MILA model")

In [None]:
# Show which checkpoint file was selected
print(PATH)

In [None]:
# save the model state_dict
#torch.save(model_py_mila.state_dict(), PATH.replace('resnet50', 'state_resnet50'))

In [None]:
# Convert the species classifier; every artefact is named "resnet_<region>"
pref = f'resnet_{region}'

tflite_model = pytorch_to_tflite(
    model_py_mila,
    output_dir=output_dir,
    image=img,
    output_model_prefix=pref,
)

In [None]:
# Size of the original PyTorch checkpoint on disk, in megabytes (bytes / 1e6)
os.path.getsize(PATH) / 1e6

# Localisation Models

In [None]:
# load in the localisation model
weights_path = "/bask/homes/f/fspo1218/amber/data/mila_models/v1_localizmodel_2021-08-17-12-06.pt"

model_loc = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
# Dedicated name: `num_classes` already holds the species count of the
# classifier and must not be overwritten by the detector's class count.
num_loc_classes = 2  # 1 class (object) + background
in_features = model_loc.roi_heads.box_predictor.cls_score.in_features
# Replace the box predictor head so it matches the two-class checkpoint
model_loc.roi_heads.box_predictor = (
    torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_loc_classes
    )
)

# NOTE: torch.load unpickles arbitrary Python objects; only load checkpoints
# that come from a trusted source.
checkpoint = torch.load(weights_path, map_location=device)
# Use the weights stored under "model_state_dict" when present, otherwise
# treat the loaded object itself as the state_dict.
state_dict = checkpoint.get("model_state_dict") or checkpoint
model_loc.load_state_dict(state_dict)
model_loc = model_loc.to(device)
# eval mode is switched on later, right before the ONNX export

In [None]:
# Everything in output_dir before the first 'gbif_' (the whole string if absent)
output_dir2 = output_dir.partition('gbif_')[0]

print(output_dir2)

In [None]:
import tensorflow as tf
import onnx
from onnx_tf.backend_rep import TensorflowRep

In [None]:
# Prefix for localisation artefacts. It lives in its own variable so that
# `pref` (the classifier prefix, used later to locate the classifier's
# .tflite file) is not overwritten.
loc_pref = 'localisation_' + region

model = model_loc

# convert the model to onnx
print("Converting to onnx")


model.eval()

# Example input: the transformed sample image with a batch dimension added
input_tensor = img.unsqueeze(0)

# Display the shape of the input tensor
print("Input tensor shape:", input_tensor.shape)

# Export the PyTorch model to ONNX format
onnx_path = output_dir2 + "faster_rcnn.onnx"
torch.onnx.export(model, input_tensor, onnx_path, verbose=True)

# Reload the exported file
onnx_model = onnx.load(onnx_path)

# # Prepare the ONNX model for conversion to TensorFlow
# tf_rep = prepare(onnx_model)

# # Convert the ONNX model to TensorFlow Lite
# tflite_path = "faster_rcnn.tflite"
# converter = tf.lite.TFLiteConverter.from_concrete_functions([tf_rep.graph.as_graph_def(add_shapes=True)])
# tflite_model = converter.convert()

# # Save the TFLite model to a file
# with open(tflite_path, 'wb') as f:
#     f.write(tflite_model)


### Example Inference

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(5, 5))

# `img` was normalised with mean 0.5 / std 0.5, i.e. mapped into [-1, 1].
# Undo that so values are back in [0, 1], the range imshow expects for float
# RGB data; otherwise everything below 0 is clipped and the preview is too dark.
axs.imshow((img * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0))
axs.axis('off')

In [None]:
import datetime

def pytorch_inference(image, test_model, print_time=False):
    """Classify a single image with a PyTorch model.

    Args:
        image: Image tensor without a batch dimension; one is added with
            ``unsqueeze(0)`` before calling the model.
        test_model: Callable returning class scores with classes on dim 1.
        print_time: When true, print the elapsed time in microseconds.

    Returns:
        Index of the highest-probability class for the image.
    """
    a = datetime.datetime.now()
    # Inference only, so skip autograd bookkeeping
    with torch.no_grad():
        output = test_model(image.unsqueeze(0))
        predictions = torch.nn.functional.softmax(output, dim=1)
    predictions = predictions.detach().numpy()

    categories = predictions.argmax(axis=1)
    b = datetime.datetime.now()
    c = b - a
    if print_time:
        # timedelta.microseconds is only the sub-second component; dividing by
        # one microsecond gives the full duration, also for runs over a second.
        print(str(c // datetime.timedelta(microseconds=1)) + "\u03bcs")
    return(categories[0])

def tflite_inference(image, interpreter, print_time=False):
    """Classify a single image with a TFLite interpreter.

    Args:
        image: Image tensor without a batch dimension; one is added with
            ``unsqueeze(0)`` before it is handed to the interpreter.
        interpreter: Interpreter exposing ``get_input_details``,
            ``get_output_details``, ``set_tensor``, ``invoke`` and
            ``get_tensor``. # NOTE(review): presumably tensors must already
            be allocated by the caller — confirm.
        print_time: When true, print the elapsed time in microseconds.

    Returns:
        Index of the highest-scoring class in the first output tensor.
    """
    # Ask the interpreter that was passed in for its tensor indices instead of
    # relying on module-level `input_details` / `output_details`, which may
    # belong to a different interpreter or not exist at all.
    input_index = interpreter.get_input_details()[0]['index']
    output_index = interpreter.get_output_details()[0]['index']

    a = datetime.datetime.now()
    interpreter.set_tensor(input_index, image.unsqueeze(0))
    interpreter.invoke()
    outputs_tf = interpreter.get_tensor(output_index)
    prediction_tf = np.squeeze(outputs_tf)
    # Sort ascending, reverse, take the first entry: index of the top score
    prediction_tf = prediction_tf.argsort()[::-1][0]
    b = datetime.datetime.now()
    c = b - a
    if print_time:
        # timedelta.microseconds is only the sub-second component; dividing by
        # one microsecond gives the full duration, also for runs over a second.
        print(str(c // datetime.timedelta(microseconds=1)) + "\u03bcs")
    return(prediction_tf)


In [None]:
# Predicted class index from the original PyTorch model for the sample image
pytorch_inf = pytorch_inference(img, model_py_mila)

In [None]:
# Load the TFLite model and allocate tensors.
# The file name is built from the classifier prefix explicitly rather than
# from `pref`, which other cells may reassign (e.g. to the localisation
# prefix, for which no .tflite file is written).
pref2 = 'resnet_' + region + '.tflite'
interpreter = tf.lite.Interpreter(model_path=os.path.join(output_dir, pref2))

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

interpreter.allocate_tensors()

In [None]:
# Predicted class index from the converted TFLite model for the same image
tflite_inf = tflite_inference(img, interpreter)

In [None]:
# The true species is the name of the folder that contains the sample image
true_species = os.path.basename(os.path.dirname(image_file))

print('TFlite says', species_list_mila[tflite_inf])
print('Pytorch says', species_list_mila[pytorch_inf])
print('Truth says', true_species)

# Inference on Test Data

In [None]:
config_file = f'/bask/homes/f/fspo1218/amber/projects/on_device_classifier/configs/01_{region}_data_config.json'
# Context manager so the config file handle is closed after parsing
with open(config_file) as f:
    config_data = json.load(f)

In [None]:
test_dir_entries = os.listdir(f'/bask/homes/f/fspo1218/amber/data/gbif_{region}/test/')

# keep only the shard archives: names containing both 'test-' and '.tar'
test_shards = [name for name in test_dir_entries if 'test-' in name and '.tar' in name]

# number of matching shards minus one, zero-padded to 6 digits
last_shard_index = len(test_shards) - 1
len_test = str(last_shard_index).zfill(6)

print(len_test)

In [None]:
# Pull the settings used below out of the config, section by section
training_cfg = config_data["training"]
dataset_cfg = config_data["dataset"]

image_resize = training_cfg["image_resize"]
batch_size = training_cfg["batch_size"]
preprocess_mode = config_data["model"]["preprocess_mode"]
# label_list and label_info both read the same config entry
label_list = dataset_cfg["label_info"]
label_info = dataset_cfg["label_info"]
taxon_hierar = dataset_cfg["taxon_hierarchy"]

# Brace-range pattern covering shards 000000..<len_test>
pass_str = (
    '/bask/homes/f/fspo1218/amber/data/gbif_' + region
    + '/test/test-500-{000000..' + len_test + '}.tar'
)

# Load in the test data
test_dataloader = dataloader.build_webdataset_pipeline(
    sharedurl=pass_str,
    input_size=image_resize,
    batch_size=batch_size,
    is_training=False,
    num_workers=4,
    preprocess_mode=preprocess_mode,
)
print("images loaded")

In [None]:
# Fresh interpreter for the test-set evaluation
tflite_path = os.path.join(output_dir, pref2)
interpreter = tf.lite.Interpreter(model_path=tflite_path)

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

interpreter.allocate_tensors()

In [None]:
# Define your preprocess_for_tflite and postprocess_for_tflite functions accordingly
def preprocess_for_tflite(image_batch):
    """Return *image_batch* as a new NumPy array."""
    batch_array = np.array(image_batch)
    return batch_array

def postprocess_for_tflite(output_data):
    """Return *output_data* as a new torch tensor."""
    output_tensor = torch.tensor(output_data)
    return output_tensor

In [None]:
from evaluation import micro_accuracy_batch
from evaluation import macro_accuracy_batch
from evaluation import confusion_data_conversion
from evaluation import confusion_matrix_data

In [None]:
# Count the batches by iterating over the loader once
no_iterations = sum(1 for _batch in test_dataloader)

print('There are ', no_iterations, ' iterations')

Note: the next cell takes around 20 minutes to run for 94 batches (64 images each).

In [None]:
# Running totals for the PyTorch model; None means "nothing accumulated yet"
py_global_microacc_data = None
py_global_macroacc_data = None
py_global_confusion_data_sp = None
py_global_confusion_data_g = None
py_global_confusion_data_f = None

# Running totals for the TFLite model
tf_global_microacc_data = None
tf_global_macroacc_data = None
tf_global_confusion_data_sp = None
tf_global_confusion_data_g = None
tf_global_confusion_data_f = None

# Look the interpreter's tensor indices up once instead of for every image
tflite_input_index = interpreter.get_input_details()[0]['index']
tflite_output_index = interpreter.get_output_details()[0]['index']

for i, (image_batch, label_batch) in enumerate(test_dataloader, start=1):
    print(i, '/', no_iterations)

    image_batch, label_batch = image_batch.to(device), label_batch.to(device)
    # Evaluation only: no autograd graph needs to be built or kept per batch
    with torch.no_grad():
        py_predictions = model_py_mila(image_batch)

    # Preprocess the input image_batch for TensorFlow Lite model
    # You need to replace this preprocessing logic based on your TensorFlow Lite model requirements
    input_data = preprocess_for_tflite(image_batch)

    # Run inference using TensorFlow Lite model for each image in the batch
    predictions_tflite_batch = []
    for single_input_data in input_data:
        single_input_data = np.expand_dims(single_input_data, axis=0)  # Add batch dimension
        interpreter.set_tensor(tflite_input_index, single_input_data)
        interpreter.invoke()
        output_data = interpreter.get_tensor(tflite_output_index)
        predictions_tflite_batch.append(output_data)

    # Stack predictions for the entire batch
    predictions_tflite_batch = np.vstack(predictions_tflite_batch)

    # Assuming `postprocess_for_tflite` is a function to postprocess the output_data
    # You need to replace this postprocessing logic based on your TensorFlow Lite model requirements
    tf_predictions = postprocess_for_tflite(predictions_tflite_batch)

    # Pytorch Metrics
    # micro-accuracy calculation
    py_micro_accuracy = micro_accuracy_batch.MicroAccuracyBatch(
        py_predictions, label_batch, label_info, taxon_hierar
    ).batch_accuracy()
    py_global_microacc_data = micro_accuracy_batch.add_batch_microacc(
        py_global_microacc_data, py_micro_accuracy
    )
    # macro-accuracy calculation
    py_macro_accuracy = macro_accuracy_batch.MacroAccuracyBatch(
        py_predictions, label_batch, label_info, taxon_hierar
    ).batch_accuracy()
    py_global_macroacc_data = macro_accuracy_batch.add_batch_macroacc(
        py_global_macroacc_data, py_macro_accuracy
    )

    # confusion matrix
    (
        py_sp_label_batch,
        py_sp_predictions,
        py_g_label_batch,
        py_g_predictions,
        py_f_label_batch,
        py_f_predictions,
    ) = confusion_data_conversion.ConfusionDataConvert(
        py_predictions, label_batch, label_info, taxon_hierar
    ).converted_data()

    py_global_confusion_data_sp = confusion_matrix_data.confusion_matrix_data(
        py_global_confusion_data_sp, [py_sp_label_batch, py_sp_predictions]
    )
    py_global_confusion_data_g = confusion_matrix_data.confusion_matrix_data(
        py_global_confusion_data_g, [py_g_label_batch, py_g_predictions]
    )
    py_global_confusion_data_f = confusion_matrix_data.confusion_matrix_data(
        py_global_confusion_data_f, [py_f_label_batch, py_f_predictions]
    )

    # TFLite Metrics
    # micro-accuracy calculation
    tf_micro_accuracy = micro_accuracy_batch.MicroAccuracyBatch(
        tf_predictions, label_batch, label_info, taxon_hierar
    ).batch_accuracy()
    tf_global_microacc_data = micro_accuracy_batch.add_batch_microacc(
        tf_global_microacc_data, tf_micro_accuracy
    )
    # macro-accuracy calculation
    tf_macro_accuracy = macro_accuracy_batch.MacroAccuracyBatch(
        tf_predictions, label_batch, label_info, taxon_hierar
    ).batch_accuracy()
    tf_global_macroacc_data = macro_accuracy_batch.add_batch_macroacc(
        tf_global_macroacc_data, tf_macro_accuracy
    )

    # confusion matrix
    (
        tf_sp_label_batch,
        tf_sp_predictions,
        tf_g_label_batch,
        tf_g_predictions,
        tf_f_label_batch,
        tf_f_predictions,
    ) = confusion_data_conversion.ConfusionDataConvert(
        tf_predictions, label_batch, label_info, taxon_hierar
    ).converted_data()

    tf_global_confusion_data_sp = confusion_matrix_data.confusion_matrix_data(
        tf_global_confusion_data_sp, [tf_sp_label_batch, tf_sp_predictions]
    )
    tf_global_confusion_data_g = confusion_matrix_data.confusion_matrix_data(
        tf_global_confusion_data_g, [tf_g_label_batch, tf_g_predictions]
    )
    tf_global_confusion_data_f = confusion_matrix_data.confusion_matrix_data(
        tf_global_confusion_data_f, [tf_f_label_batch, tf_f_predictions]
    )

In [None]:
# Inspect the accumulated PyTorch confusion data (the `_f` set — presumably family level)
py_global_confusion_data_f

In [None]:
# Context manager so the label file handle is closed after parsing
with open(label_list) as label_file:
    label_read = json.load(label_file)

In [None]:
import pandas as pd

# Overall accuracies for the PyTorch model
py_final_micro_accuracy = micro_accuracy_batch.final_microacc(py_global_microacc_data)
py_final_macro_accuracy, py_taxon_acc = macro_accuracy_batch.final_macroacc(
    py_global_macroacc_data
)

# Overall accuracies for the TFLite model
tf_final_micro_accuracy = micro_accuracy_batch.final_microacc(tf_global_microacc_data)
tf_final_macro_accuracy, tf_taxon_acc = macro_accuracy_batch.final_macroacc(
    tf_global_macroacc_data
)

# Per-taxon accuracies, looked up against the label file contents
tf_tax_accuracy = macro_accuracy_batch.taxon_accuracy(tf_taxon_acc, label_read)
py_tax_accuracy = macro_accuracy_batch.taxon_accuracy(py_taxon_acc, label_read)

print(py_final_micro_accuracy, py_final_macro_accuracy)
print(tf_final_micro_accuracy, tf_final_macro_accuracy)

# saving evaluation data to file
# One frame per set (F, G, S): flattened truth plus both models' predictions.
# The truth column is taken from the PyTorch data.
confdata_pd_f = pd.DataFrame(
    dict(
        F_Truth=py_global_confusion_data_f[0].reshape(-1),
        F_Py_Prediction=py_global_confusion_data_f[1].reshape(-1),
        F_Tf_Prediction=tf_global_confusion_data_f[1].reshape(-1),
    )
)
confdata_pd_g = pd.DataFrame(
    dict(
        G_Truth=py_global_confusion_data_g[0].reshape(-1),
        G_Py_Prediction=py_global_confusion_data_g[1].reshape(-1),
        G_Tf_Prediction=tf_global_confusion_data_g[1].reshape(-1),
    )
)
confdata_pd_sp = pd.DataFrame(
    dict(
        S_Truth=py_global_confusion_data_sp[0].reshape(-1),
        S_Py_Prediction=py_global_confusion_data_sp[1].reshape(-1),
        S_Tf_Prediction=tf_global_confusion_data_sp[1].reshape(-1),
    )
)

# Place the three frames side by side
all_levels = [confdata_pd_f, confdata_pd_g, confdata_pd_sp]
confdata_pd = pd.concat(all_levels, axis=1)


In [None]:
# save the outputs

# Neither to_csv nor open() creates a missing folder, so make sure it exists
out_dir = './outputs/'
os.makedirs(out_dir, exist_ok=True)

# Common file-name stem shared by all four output files
out_prefix = out_dir + region + '_resnet' + "_v2.0"

confdata_pd.to_csv(out_prefix + "_confusion-data.csv", index=False)

with open(out_prefix + "_micro-accuracy.json", "w") as outfile:
    json.dump({'Pytorch': py_final_micro_accuracy, 'TFLite': tf_final_micro_accuracy}, outfile)

with open(out_prefix + "_macro-accuracy.json", "w") as outfile:
    json.dump({'Pytorch': py_final_macro_accuracy, 'TFLite': tf_final_macro_accuracy}, outfile)

with open(out_prefix + "_taxon-accuracy.json", "w") as outfile:
    json.dump({'Pytorch': py_tax_accuracy, 'TFLite': tf_tax_accuracy}, outfile)
