In [10]:
import os
import random
import requests
import torch
from torchvision import models
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import timeit
import numpy as np
import tvm
from tvm import relay, autotvm
import tvm.relay.testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
import tvm.contrib.graph_executor as runtime
from tvm.contrib import graph_executor
import tvm.runtime
import pickle
import sqlite3
from pathlib import Path
from torchvision import transforms
import onnx
from transformers import AutoImageProcessor, FlaxResNetForImageClassification


In [2]:
transform = transforms.Compose([            
     transforms.Resize(256),                    
     transforms.CenterCrop(224),                
     transforms.ToTensor(),                     
     transforms.Normalize(                      
     mean=[0.485, 0.456, 0.406],                
     std=[0.229, 0.224, 0.225]                  
)])

In [3]:
def display_images_with_labels(imgs, labels):
    if len(imgs.shape) == 3:  # If there's only one image, reshape it to add a batch dimension
        imgs = imgs[np.newaxis, :]
        labels = [labels]

    num_images = len(imgs)
    fig, axes = plt.subplots(1, num_images, figsize=(12, 4))

    # If there's only one image, `axes` is not a list. Convert it to a list for consistency.
    if num_images == 1:
        axes = [axes]

    for i, (img, label) in enumerate(zip(imgs, labels)):
        img = img.squeeze(0)  # Remove the batch dimension if it exists
        img = np.transpose(img, (1, 2, 0))  # Change the dimension order from CxHxW to HxWxC
        img = img - img.min()
        img = img / img.max()

        axes[i].imshow(img)
        axes[i].axis('off')
        axes[i].set_title(label, fontsize=10, pad=5)  # Display label on top of the image

    plt.show()
    
def load_random_images(batch_size):
    directory = "/home1/public/misampson/dataset/ILSVRC2015/Data/DET/test"
    files = os.listdir(directory)
    image_files = [f for f in files if f.endswith('.JPEG')]

    if not image_files:
        print("No image files found in the directory.")
        return None
    
    imgs = []
    chosen_image_files = []
    for _ in range(batch_size):
        random_image = random.choice(image_files)
        img_path = os.path.join(directory, random_image)
        chosen_image_files.append(img_path)  # Append the chosen image file path
        img = Image.open(img_path).convert("RGB")  # Convert to RGB format
        img_reshape = img.resize((224, 224))
        img_t = transform(img_reshape)
        imgs.append(img_t)
    
    imgs = torch.stack(imgs)
    
    with open("image_files.txt", "w") as f:
        f.write("\n".join(chosen_image_files))
    
    return imgs

def get_images():
    directory = "/home1/public/misampson/dataset/ILSVRC2015/Data/DET/test"
    file_path = "image_files.txt"  # Changed to the relative path of image_files.txt
    with open(file_path, "r") as f:
        image_files = f.read().splitlines()
    
    imgs = []
    for image_file in image_files:
        img = Image.open(image_file).convert("RGB")  # Load the image using the file path
        img_reshape = img.resize((224, 224))
        img_t = transform(img_reshape)
        imgs.append(img_t)
    
    imgs = torch.stack(imgs)
    return imgs

In [4]:
def prediction_to_class(predictions):
    with open('imagenet_classes.txt') as f:
        classes = [line.strip() for line in f.readlines()]

    synsets_to_names = {}
    with open('imagenet_synsets.txt') as f:
        for line in f:
            parts = line.strip().split(' ', 1)
            synsets_to_names[parts[0]] = parts[1]

    batch_classes = []
    for prediction in predictions:
        class_name = synsets_to_names[classes[prediction]]
        batch_classes.append(class_name)

    return batch_classes


In [5]:
def timit(func, *args, **kwargs):
    timing_number = 10
    timing_repeat = 10
    
    warmup_results = timeit.repeat(lambda: func(*args, **kwargs), repeat=timing_repeat, number=timing_number)
    timing_results = timeit.repeat(lambda: func(*args, **kwargs), repeat=timing_repeat, number=timing_number)
    
    timing_summary = {
        "mean": sum(timing_results) / len(timing_results),
        "median": sorted(timing_results)[len(timing_results)//2],
        "std": np.std(timing_results),
    }
    
    print("Timing Summary:")
    print(timing_summary)
    return timing_summary


In [6]:
def model_exists_in_db(device, network, batch_size, db_path):
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    
    query = '''
    SELECT COUNT(*) FROM device_models
    WHERE device = ? AND model = ? AND batch_size = ?
    '''
    
    cursor.execute(query, (device, network, batch_size))
    result = cursor.fetchone()
    
    conn.close()
    
    return result[0] > 0

In [7]:
def fetch_lib():
    filepath = f'/home1/public/misampson/resnet-50/git/ITE-Forth-CARV/tvm_report/automated_database/{device}/{network}/{batch_size}'
    print("Load lib from database...")
    so_files = list(Path(filepath).glob('*.so'))
    if not so_files:
        raise ValueError(f"No .so files found in {filepath}")
    path_lib = str(so_files[0])
    lib = tvm.runtime.load_module(path_lib)
    
    return lib

In [13]:
def get_model_params(name, batch_size, input_data):
    input_shape = (batch_size, 3, 224, 224)
    dtype = "float32"

    if "resnet" in name:
        
        shape_list = [('data', input_shape)]
        torch_model = models.resnet18(weights='ResNet18_Weights.IMAGENET1K_V1').eval()
        scripted_model = torch.jit.trace(torch_model, input_data).eval()
        mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
    elif name == "vit":
        model_path = "/home1/public/misampson/resnet-50/git/ITE-Forth-CARV/tvm_report/model.onnx"
        onnx_model = onnx.load(model_path)
        input_names = [input.name for input in onnx_model.graph.input]
        print("Input names in ONNX model:", input_names)

        shape_dict = {'pixel_values': input_shape}  # Corrected line
        mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

        return mod, params
    else:
        raise ValueError("Unsupported network: " + name)

    return mod, params

In [14]:
def run_module(mod):
    mod.run()
    return mod

from tvm.ir import IRModule
def create_module(tuned_lib, imgs, device, network, target):
    dtype = "float32"
    
    if(device=="cuda"):
        dev = tvm.cuda(0)
    elif(device=="llvm"):
        dev = tvm.cpu()
    mod,params = get_model_params( network, batch_size, imgs)
    
    with tvm.transform.PassContext(opt_level=3):
        if isinstance(tuned_lib, (IRModule, _function.Function)):
            lib = relay.build_module.build(tuned_lib, target=target, params=params)
    
    module = graph_executor.GraphModule(lib["default"](dev))

    images = np.array(imgs).astype(dtype)
    input_name="data"
    if(network == "vit"):
        input_name = "pixel_values"
    module.set_input(input_name, tvm.nd.array(images.astype("float32")), **params)
    # module.set_input(input_name, tvm.nd.array(images.astype("float32")))
    
    mod = run_module(module)
    output = mod.get_output(0).asnumpy()

    print("Max values per output:", np.max(output, axis=1))
    prediction = np.argmax(output, axis=1)
    print("Predictions:", prediction)

    classes = prediction_to_class(prediction)
    return classes, module



In [15]:
if __name__ == "__main__":
    # Load random images
    batch_size = 2
    target = tvm.target.Target("cuda")
    device = "cuda"
    # network = "resnet-18"
    network = "vit"
    # network = "vgg-11"
    # network = "inception_v3"
    #network = "squeezenet_v1.1"
    
    db_path = '/home1/public/misampson/resnet-50/git/ITE-Forth-CARV/tvm_report/automate_tvm.db'
    
    # Check if model exists in the database
    if model_exists_in_db(device, network, batch_size, db_path):
        lib= fetch_lib()

        imgs = load_random_images(batch_size)
        
        classes,module = create_module(lib, imgs, device, network, target)
        print(lib)

        display_images_with_labels(imgs, classes)

        timing_summary = timit(run_module, module)
    
    else:
        print("Model does not exist in the database or path is missing.")

Load lib from database...
Input names in ONNX model: ['pixel_values']


NameError: name '_function' is not defined