In [1]:
# This notebook takes the NAS-Bench-201 search space as an example to show how to use AdaProxy to boost the SRCC (to > 0.9)
# between devices whose original SRCC is low

In [2]:
# first read and process some latency data, which is stored as a Python list in a .pickle file

import pickle


def load_pickle(path):
    """Return the object unpickled from the file at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    the object is read, even if unpickling raises.

    NOTE: ``pickle.load`` can execute arbitrary code while unpickling; only
    load files that come from a trusted source.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# all the latency data here are measured and released by recent work

# nn-Meter: https://github.com/microsoft/nn-Meter
# actual latency of Adreno 630 GPU, Adreno 640 GPU, CortexA76 CPU, Myriad VPU
true_630 = load_pickle('./true_630.pickle')
true_640 = load_pickle('./true_640.pickle')
true_a76 = load_pickle('./true_a76.pickle')
true_vpu = load_pickle('./true_vpu.pickle')

# HW-NAS-Bench: https://github.com/RICE-EIC/HW-NAS-Bench
# actual latency of Pixel3, EdgeGPU, EdgeTPU, eyeriss, FPGA, Raspi4
rice_pixel3 = load_pickle('./rice_pixel3.pickle')
rice_edgegpu = load_pickle('./rice_edgegpu.pickle')
rice_edgetpu = load_pickle('./rice_edgetpu.pickle')
rice_eyeriss = load_pickle('./rice_eyeriss.pickle')
rice_fpga = load_pickle('./rice_fpga.pickle')
rice_raspi4 = load_pickle('./rice_raspi4.pickle')

# Eagle: https://github.com/zheng-ningxin/brp-nas
# actual latency of GTX, CPU 855, DSP 855, GPU 855, i7, Jetson, Jetson (int16 models)
eagle_gtx = load_pickle('./eagle_gtx.pickle')
eagle_cpu_855 = load_pickle('./eagle_cpu_855.pickle')
eagle_dsp_855 = load_pickle('./eagle_dsp_855.pickle')
eagle_gpu_855 = load_pickle('./eagle_gpu_855.pickle')
eagle_i7 = load_pickle('./eagle_i7.pickle')
eagle_jetson = load_pickle('./eagle_jetson.pickle')
eagle_jetson_16 = load_pickle('./eagle_jetson_16.pickle')

In [3]:
# process the latency data
# convert ms to s so that the fitted weights are smaller
# NOTE(review): the variables below are rebound to their rescaled values, so
# re-running this cell without re-running the loading cell divides by 1000 again.


def to_seconds(latencies_ms):
    """Return a new list containing each latency of ``latencies_ms`` divided by 1000."""
    return [lat / 1000 for lat in latencies_ms]


# nn-Meter latency
true_vpu = to_seconds(true_vpu)
true_630 = to_seconds(true_630)
true_640 = to_seconds(true_640)
true_a76 = to_seconds(true_a76)

# eagle latency
eagle_gtx = to_seconds(eagle_gtx)
eagle_cpu_855 = to_seconds(eagle_cpu_855)
eagle_dsp_855 = to_seconds(eagle_dsp_855)
eagle_gpu_855 = to_seconds(eagle_gpu_855)
eagle_i7 = to_seconds(eagle_i7)
eagle_jetson = to_seconds(eagle_jetson)
eagle_jetson_16 = to_seconds(eagle_jetson_16)

# rice latency
rice_pixel3 = to_seconds(rice_pixel3)
rice_edgegpu = to_seconds(rice_edgegpu)
rice_edgetpu = to_seconds(rice_edgetpu)
rice_eyeriss = to_seconds(rice_eyeriss)
rice_fpga = to_seconds(rice_fpga)
rice_raspi4 = to_seconds(rice_raspi4)

In [4]:
# read the architecture encoding
# for the NAS-Bench-201 models, we use a simple one-hot encoding, which is 31-dim

# encode is a list stored in .pickle; the `with` block closes the file handle
# NOTE: only unpickle files from a trusted source (pickle can execute code)
with open('./encode.pickle', 'rb') as fh:
    encode = pickle.load(fh)
# nn-Meter only releases the latency of ~2000 models, so we actually use only part of NAS-Bench-201
print(len(encode), len(encode[0]))

1999 31


In [14]:
import numpy as np
from scipy import stats

# SRCC (Spearman rank correlation) between proxy and target at or above which
# the proxy is considered good enough and no adaptation is performed.
SRCC_THRESHOLD = 0.9

# select proxy and target device
proxy = true_vpu
target = eagle_gtx

# convert lists to numpy column vectors of shape (n_models, 1)
target = np.array(target).reshape(-1, 1)
proxy = np.array(proxy).reshape(-1, 1)

# check the original SRCC between proxy and target
print("="*100)
srcc = stats.spearmanr(proxy, target)[0]
print("SRCC between target and proxy: ", srcc)

if srcc >= SRCC_THRESHOLD:
    # stop the notebook here: adaptation is unnecessary for this device pair
    raise Exception(f"SRCC between target and proxy meets the threshold {SRCC_THRESHOLD}, there is no need to do proxy adaptation!")
else:
    print(f"SRCC between target and proxy does not meet the threshold {SRCC_THRESHOLD}, proxy adaptation is needed")

print("="*100)

# reload the encoding from disk (so re-running this cell starts from the raw
# list) and convert it to a numpy array; the `with` block closes the file
with open('./encode.pickle', 'rb') as fh:
    encode = np.array(pickle.load(fh))
print(encode.shape)

SRCC between target and proxy:  0.690526859861756
SRCC between target and proxy does not meet the threshold 0.9, proxy adaptation is needed!
(1999, 31)


In [11]:
from sklearn.metrics import mean_squared_error

# Fit proxy latency as a linear function of the encoding: the pseudo-inverse
# gives the least-squares weight vector `weight_linear` with encode @ w ≈ proxy.
weight_linear = np.linalg.pinv(encode).dot(proxy)
pred_linear = encode.dot(weight_linear)

# report how well the linear fit reproduces the proxy latencies
print("="*100)
print("MSE for linear fitting on proxy:", mean_squared_error(pred_linear, proxy))
print("SRCC between pred and true for linear:", stats.spearmanr(pred_linear, proxy))
print("="*100)

MSE for linear fitting on proxy: 8.778622360177611e-08
SRCC between pred and true for linear: SpearmanrResult(correlation=0.9676540304712693, pvalue=0.0)


In [6]:
# transfer on target
import random
import copy
from functions.solver import Solver   # the AdaProxy algorithm is in Solver class

# number of models used on target
n_val_target = 10
n_train_target = 60

# random selec models
random.seed(1)
index = random.sample(range(encode.shape[0]), n_train_target+n_val_target)

x_train_target = np.zeros(shape=(n_train_target, encode.shape[1]))
y_train_target = np.zeros(shape=(n_train_target, 1))
for i in range(n_train_target):
	x_train_target[i, :] = encode[index[i]]
	y_train_target[i, :] = target[index[i]]

x_val_target = np.zeros(shape=(n_val_target, encode.shape[1]))
y_val_target = np.zeros(shape=(n_val_target, 1))
for i in range(n_val_target):
	x_val_target[i, :] = encode[index[i+n_train_target]]
	y_val_target[i, :] = target[index[i+n_train_target]]
	
x_all = np.vstack((x_train_target, x_val_target))
y_all = np.vstack((y_train_target, y_val_target))

# x and y are all the architecture encoding and target latency
x = copy.deepcopy(encode)
y = copy.deepcopy(target)

# initial weight on proxy
w = copy.deepcopy(weight_linear)

solver = Solver(w, x_train_target, y_train_target, x_val_target, y_val_target, x, y, n_dim=len(encode[0]))

# check the SRCC between training models, validation models, and training & val together
print("="*100)
print("SRCC between y_train: ", stats.spearmanr(x_train_target.dot(w), y_train_target))
print("SRCC between y_val: ", stats.spearmanr(x_val_target.dot(w), y_val_target))
print("SRCC between y_all: ", stats.spearmanr(x_all.dot(w), y_all))
print("="*100)

SRCC between y_train:  SpearmanrResult(correlation=0.6601833842734095, pvalue=9.558357117232713e-09)
SRCC between y_val:  SpearmanrResult(correlation=0.7939393939393938, pvalue=0.0060999233136969115)
SRCC between y_all:  SpearmanrResult(correlation=0.6680255445717784, pvalue=2.656765492968506e-10)


In [7]:
# Sweep the hyperparameter `lamb` passed to solver.solve() and keep the
# estimates produced by the value with the highest validation SRCC.

# we use l2 norm
norm = 2

lamb_dic = dict()
max_srcc_val = float('-inf')
max_srcc_all = None
max_srcc = None
max_lamb = None
MSE_val = []
MSE_all = []
MSE = []
SRCC_val = []
SRCC_all = []
SRCC = []
final_lat_val = None
final_lat_all = None
final_lat = None

# lamb is a hyperparameter that can be tuned.
# Grid: LAMB_START, LAMB_START + LAMB_STEP, ... (< LAMB_STOP). With the values
# below that is ~1e8 points (~800 MB as a float64 array), and every point calls
# the solver and prints several lines, so the sweep does not finish in practice.
# Shrink the range or enlarge the step for a run that completes
# (a smaller grid used previously: start 0, stop 10.01, step 0.01).
LAMB_START = 1
LAMB_STOP = 100000
LAMB_STEP = 0.001
lamb_range = np.arange(LAMB_START, LAMB_STOP, LAMB_STEP)

for lamb in lamb_range:
    (estimate_target_train, estimate_target_val, estimate_target_all, estimate_target,
     srcc_train, srcc_val, srcc_all, srcc) = solver.solve(lamb, norm)

    # model selection uses only the validation SRCC
    if srcc_val > max_srcc_val:
        max_srcc_val = srcc_val
        max_srcc_all = srcc_all
        max_srcc = srcc
        max_lamb = lamb
        final_lat_val = estimate_target_val
        final_lat_all = estimate_target_all
        final_lat = estimate_target
    print("lamb:", lamb, "srcc_val:", srcc_val, "srcc_all:", srcc_all, "srcc_train:",  srcc_train, "srcc", srcc)

    mse_val = mean_squared_error(y_val_target, estimate_target_val)
    MSE_val.append(mse_val)
    mse_all = mean_squared_error(y_all, estimate_target_all)
    MSE_all.append(mse_all)   # must go to MSE_all, not MSE_val
    mse = mean_squared_error(target, estimate_target)
    MSE.append(mse)

    SRCC_val.append(srcc_val)
    SRCC_all.append(srcc_all)
    SRCC.append(srcc)

    # NOTE: "Max SRCC val", "lambda", "SRCC all" and "SRCC 2000 models" below are
    # best-so-far values, while "SRCC train" and all MSEs are for the current lamb.
    print("Max SRCC val:", max_srcc_val, "lambda:", max_lamb, "MSE train:",  mean_squared_error(y_train_target, estimate_target_train), "SRCC train:", srcc_train, "MSE val:", mse_val, "SRCC all:", max_srcc_all, "MSE all:", mse_all, "SRCC 2000 models:", max_srcc, "MSE 2000 models:", mse)
    print()

1.0
lamb: 1.0 srcc_val: 0.6969696969696969 srcc_all: 0.946076458752515 srcc_train: 0.9565434843011951 srcc 0.8688365729411253
Max SRCC val: 0.6969696969696969 lambda: 1.0 MSE train: 1.58742488281726e-07 SRCC train: 0.9565434843011951 MSE val: 1.577527844062006e-07 SRCC all: 0.946076458752515 MSE all: 1.586011020137938e-07 SRCC 15625: 0.8688365729411253 MSE 15625: 4.7706686492197e-07

1.001
lamb: 1.001 srcc_val: 0.6969696969696969 srcc_all: 0.946076458752515 srcc_train: 0.9565434843011951 srcc 0.8688365729411253
Max SRCC val: 0.6969696969696969 lambda: 1.0 MSE train: 1.5874249227460823e-07 SRCC train: 0.9565434843011951 MSE val: 1.577528676365699e-07 SRCC all: 0.946076458752515 MSE all: 1.5860111732631707e-07 SRCC 15625: 0.8688365729411253 MSE 15625: 4.770671401059421e-07

1.0019999999999998
lamb: 1.0019999999999998 srcc_val: 0.6969696969696969 srcc_all: 0.946076458752515 srcc_train: 0.9565434843011951 srcc 0.8688365729411253
Max SRCC val: 0.6969696969696969 lambda: 1.0 MSE train: 1.587

1.0219999999999976
lamb: 1.0219999999999976 srcc_val: 0.6969696969696969 srcc_all: 0.946076458752515 srcc_train: 0.9565434843011951 srcc 0.8688357091458643
Max SRCC val: 0.6969696969696969 lambda: 1.0 MSE train: 1.5874232510581095e-07 SRCC train: 0.9565434843011951 MSE val: 1.5775552222621832e-07 SRCC all: 0.946076458752515 MSE all: 1.5860135326586913e-07 SRCC 15625: 0.8688365729411253 MSE 15625: 4.770727307873882e-07

1.0229999999999975
lamb: 1.0229999999999975 srcc_val: 0.6969696969696969 srcc_all: 0.946076458752515 srcc_train: 0.9565434843011951 srcc 0.8688357091458643
Max SRCC val: 0.6969696969696969 lambda: 1.0 MSE train: 1.5874232070190995e-07 SRCC train: 0.9565434843011951 MSE val: 1.5775572015428836e-07 SRCC all: 0.946076458752515 MSE all: 1.5860137776653546e-07 SRCC 15625: 0.8688365729411253 MSE 15625: 4.770730327734665e-07

1.0239999999999974
lamb: 1.0239999999999974 srcc_val: 0.6969696969696969 srcc_all: 0.946076458752515 srcc_train: 0.9565434843011951 srcc 0.868835709145864

1.045999999999995
lamb: 1.045999999999995 srcc_val: 0.6969696969696969 srcc_all: 0.946076458752515 srcc_train: 0.9565434843011951 srcc 0.8688353425962232
Max SRCC val: 0.6969696969696969 lambda: 1.0 MSE train: 1.5874217224141812e-07 SRCC train: 0.9565434843011951 MSE val: 1.5775816251603146e-07 SRCC all: 0.946076458752515 MSE all: 1.5860159942350568e-07 SRCC 15625: 0.8688365729411253 MSE 15625: 4.770781608484377e-07

1.0469999999999948
lamb: 1.0469999999999948 srcc_val: 0.6969696969696969 srcc_all: 0.946076458752515 srcc_train: 0.9565434843011951 srcc 0.8688353425962232
Max SRCC val: 0.6969696969696969 lambda: 1.0 MSE train: 1.5874216057260095e-07 SRCC train: 0.9565434843011951 MSE val: 1.5775829919234432e-07 SRCC all: 0.946076458752515 MSE all: 1.5860160894685002e-07 SRCC 15625: 0.8688365729411253 MSE 15625: 4.770783989072288e-07

1.0479999999999947
lamb: 1.0479999999999947 srcc_val: 0.6969696969696969 srcc_all: 0.946076458752515 srcc_train: 0.9565434843011951 srcc 0.8688353425962232


1.0689999999999924
lamb: 1.0689999999999924 srcc_val: 0.6969696969696969 srcc_all: 0.946076458752515 srcc_train: 0.9565434843011951 srcc 0.8688353365872126
Max SRCC val: 0.6969696969696969 lambda: 1.0 MSE train: 1.587420588920461e-07 SRCC train: 0.9565434843011951 MSE val: 1.5776018151735759e-07 SRCC all: 0.946076458752515 MSE all: 1.5860179069566203e-07 SRCC 15625: 0.8688365729411253 MSE 15625: 4.770821749244082e-07

1.0699999999999923
lamb: 1.0699999999999923 srcc_val: 0.6969696969696969 srcc_all: 0.946076458752515 srcc_train: 0.9565434843011951 srcc 0.8688353365872126
Max SRCC val: 0.6969696969696969 lambda: 1.0 MSE train: 1.5874206485757803e-07 SRCC train: 0.9565434843011951 MSE val: 1.5776027844150127e-07 SRCC all: 0.946076458752515 MSE all: 1.5860180965528129e-07 SRCC 15625: 0.8688365729411253 MSE 15625: 4.770823367763694e-07

1.0709999999999922
lamb: 1.0709999999999922 srcc_val: 0.6969696969696969 srcc_all: 0.946076458752515 srcc_train: 0.9565434843011951 srcc 0.8688353365872126

1.0919999999999899
lamb: 1.0919999999999899 srcc_val: 0.6969696969696969 srcc_all: 0.946076458752515 srcc_train: 0.9565434843011951 srcc 0.8688356340332329
Max SRCC val: 0.6969696969696969 lambda: 1.0 MSE train: 1.5874198560214175e-07 SRCC train: 0.9565434843011951 MSE val: 1.5776172893983782e-07 SRCC all: 0.946076458752515 MSE all: 1.5860194893609833e-07 SRCC 15625: 0.8688365729411253 MSE 15625: 4.770848930082891e-07

1.0929999999999898
lamb: 1.0929999999999898 srcc_val: 0.6969696969696969 srcc_all: 0.946076458752515 srcc_train: 0.9565434843011951 srcc 0.8688356340332329
Max SRCC val: 0.6969696969696969 lambda: 1.0 MSE train: 1.5874198804301735e-07 SRCC train: 0.9565434843011951 MSE val: 1.577617715054464e-07 SRCC all: 0.946076458752515 MSE all: 1.5860195710907866e-07 SRCC 15625: 0.8688365729411253 MSE 15625: 4.770849841458664e-07

1.0939999999999896
lamb: 1.0939999999999896 srcc_val: 0.6969696969696969 srcc_all: 0.946076458752515 srcc_train: 0.9565434843011951 srcc 0.8688356340332329

KeyboardInterrupt: 