To run and test this you will need to: 

- source the species_classifier kernel
- download the following files: 
    - `01_uk_macro_data_numeric_labels.json`: for labeling the data
    - `turing-macro_v01_efficientnetv2-b3_*.pt`: the model files
    - `test-500-{000000..0000*}.tar` : the test data files

In [2]:
! pip install torchvision

Defaulting to user installation because normal site-packages is not writeable
Collecting torchvision
  Downloading torchvision-0.15.2-cp310-cp310-manylinux1_x86_64.whl (6.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m35.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting torch==2.0.1
  Downloading torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl (619.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m619.9/619.9 MB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting numpy
  Downloading numpy-1.25.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.6/17.6 MB[0m [31m76.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting pillow!=8.3.*,>=5.3.0
  Downloading Pillow-10.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (3.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m 

In [3]:
import csv
import json
import os
import time
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
import onnx
import PIL
import tensorflow as tf
import timm
import torch
import wandb
from PIL import Image
from torchvision import transforms
from typing_extensions import Literal

The history saving thread hit an unexpected error (OperationalError('disk I/O error')).History will not be written to the database.


ModuleNotFoundError: No module named 'torchvision'

In [None]:
import onnx
from onnx_tf.backend import prepare
from data2 import dataloader

In [None]:
# Restore the trained PyTorch model from disk, mapping its tensors onto the CPU
checkpoint_path = "./outputs/turing-macro_v01_efficientnetv2-b3_2023-06-27-10-45.pt"
model_py = torch.load(checkpoint_path, map_location="cpu")

In [None]:
# Build the data pipeline over the test shards
pipeline_options = {
    "sharedurl": "./outputs/test-500-{000000..000013}.tar",
    "input_size": 1000,
    "batch_size": 64,
    "is_training": False,
    "num_workers": 4,
    "preprocess_mode": "tf",
}
test_dataloader = dataloader.build_webdataset_pipeline(**pipeline_options)
print("images loaded")

# Pull a single batch to use as sample data in the cells below
image_dummy, label_dummy = next(iter(test_dataloader))
image_batch, label_batch = (
    tensor.to("cpu", non_blocking=True) for tensor in (image_dummy, label_dummy)
)



In [None]:
# Take the first sample of the batch and show its numeric class label
image, label = image_batch[0], label_batch[0]
print(int(label))

In [None]:
# Label info for the species names.
# The context manager closes the JSON file as soon as it has been parsed
# instead of leaving the handle open for the rest of the session.
with open("./outputs/01_uk_macro_data_numeric_labels.json") as f:
    label_info = json.load(f)

# List of species names, indexed by numeric class label
species_list = label_info["species_list"]
print(len(species_list), " species in total")

In [None]:
# Export the PyTorch model to ONNX, using a single-image batch as example input
onnx_path = "./outputs/onnx_file.onnx"
example_input = image.unsqueeze(0)

torch.onnx.export(
    model=model_py.eval(),  # .eval() switches the model to inference mode
    args=example_input,
    f=onnx_path,
    export_params=True,
    do_constant_folding=False,
    opset_version=12,
    input_names=['input'],
    output_names=['output'],
    verbose=False,
)

In [None]:
# Load and validate the exported ONNX graph
onnx_model = onnx.load("./outputs/onnx_file.onnx")
onnx.checker.check_model(onnx_model)

# Wrap it in a TensorFlow backend representation and save it to disk
tf_rep = prepare(onnx_model, device='CPU')
tf_rep.export_graph("./outputs/tf_file")

We now have a working TensorFlow model. Let's test it on one image.

In [None]:
def plot_predictions(image_index, ax):
    """Draw one image of the sample batch with its true and predicted species.

    The image at ``image_index`` of the module-level ``image_batch`` is run
    through the converted TensorFlow model (``tf_rep``). The axes are then
    annotated with the image index, the true species name and the predicted
    species name, both looked up in ``species_list``.

    Args:
        image_index: Position of the image within ``image_batch`` and
            ``label_batch``.
        ax: Matplotlib axes to draw the image and annotations on.
    """
    sample = image_batch[image_index]

    # The model is run on a batch containing just this image
    output = tf_rep.run(sample.unsqueeze(0))
    true_class = int(label_batch[image_index])
    predicted_class = np.argmax(output)

    captions = (
        ("Image " + str(image_index), (50, 50)),
        ("True: " + species_list[true_class], (50, 850)),
        ("Pred: " + species_list[predicted_class], (50, 900)),
    )

    # Move the first dimension to the end before handing the image to imshow
    ax.imshow(image_batch[image_index].permute(1, 2, 0))

    # Hide all ticks and tick labels
    ax.tick_params(
        axis='both', which='both',
        bottom=False, top=False, left=False, right=False,
        labelbottom=False, labelleft=False,
    )

    # add annotation labels to the plot
    for caption, position in captions:
        ax.annotate(caption, position, color='white')

In [None]:
# Show predictions for batch images 10 to 18 on a 3x3 grid, filled row by row
fig, axs = plt.subplots(3, 3, figsize=(15, 15))

for image_index, ax in zip(range(10, 19), axs.flat):
    plot_predictions(image_index, ax)

Now let's convert this to a TF Lite model

In [None]:
# Convert the saved TensorFlow model to TFLite
saved_model_dir = "./outputs/tf_file"
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)

# Converter configuration: default optimizations, and both the built-in
# TFLite ops and the select TF ops are permitted, along with custom ops
converter.allow_custom_ops = True
converter.experimental_new_converter = True
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]

tflite_model = converter.convert()

In [None]:
# Save the converted model to disk
tflite_path = "./outputs/compressed_model.tflite"
with open(tflite_path, mode='wb') as f:
    f.write(tflite_model)

# Testing 

Load the model

In [None]:
# Start a single W&B run. The start-method setting is passed in the same call
# as the project/entity so that it applies to this run; a second wandb.init()
# would not reconfigure the run that is already active. Tags are given as a
# list of strings, which is the documented form.
wandb.init(
    project="gbif",
    entity="kg-test",
    tags=["tflite"],
    settings=wandb.Settings(start_method="fork"),
)

In [None]:
# Load the saved TFLite file into an interpreter and allocate its tensors
compressed_model_path = "./outputs/compressed_model.tflite"
model = tf.lite.Interpreter(model_path=compressed_model_path)
model.allocate_tensors()

Load the TFLite model into an interpreter and read its input/output details

In [None]:
# Device that the test tensors are moved to
device = "cpu"

# Create the TFLite interpreter from the converted model file
interpreter = tf.lite.Interpreter(model_path="./outputs/compressed_model.tflite")

# Tensor metadata needed to feed inputs and read outputs
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

# Allocate the tensors before the interpreter is used
interpreter.allocate_tensors()
print("tflite model loaded")

In [None]:
# Column names of the results file: one per value written for each image
# (true label, PyTorch prediction, TFLite prediction). No separate TensorFlow
# prediction is written, so there is no column for it; an extra header column
# would shift the TFLite predictions under the wrong name.
headers = ['True_label', 'Pytorch_prediction', 'TFLite_prediction']

# The file is kept open so rows can be appended while the test set is
# processed; it must be closed once the last row has been written.
f = open('myfile.csv', 'w', newline="")

# create the csv writer
writer = csv.writer(f, delimiter=';')
writer.writerow(headers)

🚨🚨 warning: the next cell will take a while to run 🚨🚨

In [None]:
# Compare the PyTorch and TFLite predictions on every test image: one CSV row
# and one set of W&B log entries per image.

model_py.eval()  # make sure the model is in inference mode

# Tensor handles are constant, so look them up once rather than per image
input_index = input_details[0]['index']
output_index = output_details[0]['index']

# Running count of processed images. Used as the logged "epoch" so the value
# keeps increasing across batches instead of restarting at 0 for every batch.
sample_index = 0

try:
    # Gradients are not needed for evaluation
    with torch.no_grad():
        for image_batch, label_batch in test_dataloader:
            image_batch = image_batch.to(device, non_blocking=True)
            label_batch = label_batch.to(device, non_blocking=True)

            for image, label in zip(image_batch, label_batch):
                s_time = time.time()
                model_input = image.unsqueeze(0)  # batch of one

                # For pytorch model
                outputs_py = model_py(model_input)
                prediction_py = int(torch.argmax(outputs_py, dim=1).item())

                # For tflite model: the interpreter is fed a NumPy array
                interpreter.set_tensor(
                    input_index, model_input.detach().cpu().numpy()
                )
                interpreter.invoke()
                outputs_tf = interpreter.get_tensor(output_index)
                prediction_tf = int(np.argmax(outputs_tf))

                true_label = int(label.item())
                print(
                    f"true: {true_label} {species_list[true_label]}, "
                    f"py: {prediction_py} {species_list[prediction_py]}, "
                    f"tf: {prediction_tf} {species_list[prediction_tf]}"
                )
                writer.writerow(
                    [str(true_label), str(prediction_py), str(prediction_tf)]
                )

                # NOTE(review): the metric values below are fixed placeholders,
                # not computed from the predictions.
                wandb.log(
                    {"training loss": 0, "validation loss": 0, "epoch": sample_index}
                )
                wandb.log(
                    {
                        "train_micro_species_top1": 100,
                        "train_micro_genus_top1": 100,
                        "train_micro_family_top1": 100,
                        "val_micro_species_top1": 100,
                        "val_micro_genus_top1": 100,
                        "val_micro_family_top1": 100,
                        "epoch": sample_index,
                    }
                )
                e_time = (time.time() - s_time) / 60  # time taken in minutes
                wandb.log({"time per epoch": e_time, "epoch": sample_index})

                sample_index += 1

    # "~" is not expanded automatically, so resolve the home directory here.
    # NOTE(review): mod_name is not defined anywhere in this notebook; it has
    # to be set before this cell is run.
    wandb.log_artifact(
        os.path.expanduser("~/Desktop/wandblog"), name=mod_name, type="models"
    )

    wandb.log({"final micro accuracy": 100})
    wandb.log({"final macro accuracy": 100})
    wandb.log({"configuration": ""})
    wandb.log({"tax accuracy": 100})
finally:
    # Always flush the results file and end the run, even after an error
    f.close()
    wandb.finish()