In [None]:
import os, sys;
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
sys.path.append(os.path.join(os.getcwd(),'./'))

import json
import numpy as np
import pandas as pd
import xgboost as xgb
from copy import deepcopy
from tqdm import tqdm

from ofa.utils import list_mean
from ofa.imagenet_classification.data_providers.imagenet import ImagenetDataProvider
from ofa.imagenet_classification.run_manager import ImagenetRunConfig, RunManager
from ofa.model_zoo import ofa_net
from ofa.nas.efficiency_predictor import MBv3LatencyTable
from ofa.nas.accuracy_predictor.arch_encoder import MobileNetArchEncoder

In [None]:
# Export list for `from <module> import *`.
# NOTE(review): __all__ only has an effect if this code is moved into an importable module;
# inside the notebook kernel it is inert. It names AccuracyDataset, which is defined below.
__all__ = ['net_setting2id', 'net_id2setting', 'AccuracyDataset']

def net_setting2id(net_setting):
  """Encode a sub-network setting dict as its JSON string id.

  The encoding keeps the dict's key order, so equal settings built with a
  different key order produce different ids.
  """
  encoded_id = json.dumps(net_setting)
  return encoded_id

def net_id2setting(net_id):
  """Decode a JSON string id (as produced by net_setting2id) back into a dict."""
  decoded_setting = json.loads(net_id)
  return decoded_setting

class AccuracyDataset:
  """Samples OFA MobileNetV3 sub-networks under a FLOPs window and measures their accuracy.

  Candidates are identified by the JSON string of their setting (see
  net_setting2id). Accuracy is measured on demand by activating the setting on
  the OFA super-network, resetting running statistics and running test-split
  validation through a RunManager.
  """

  def __init__(self, flops_constraint, image_size=236, batch_size=100,
               data_path='/ssd1/DATASET/nasbench201/imagenet', n_worker=8,
               net_name='ofa_mbv3_d234_e346_k357_w1.2'):
    """Build the super-network and validation pipeline.

    Args:
      flops_constraint: inclusive ``[min_flops, max_flops]`` window (as returned by
        MBv3LatencyTable.count_flops_given_config) that random_sample enforces.
      image_size: input resolution used both for FLOPs counting and validation.
      batch_size: train/test batch size of the run config.
      data_path: root of the full ImageNet dataset. NOTE: this is written to the
        class attribute ImagenetDataProvider.DEFAULT_PATH, so it is process-wide.
      n_worker: number of data-loader workers.
      net_name: OFA model-zoo name of the pretrained super-network.
    """
    self.image_size = image_size
    ImagenetDataProvider.DEFAULT_PATH = data_path  # Set Full ImageNet Path here
    self.ofa_network = ofa_net(net_name, pretrained=True)
    self.run_config = ImagenetRunConfig(train_batch_size=batch_size,
                                        test_batch_size=batch_size,
                                        n_worker=n_worker,
                                        image_size=self.image_size)
    self.run_manager = RunManager('.tmp/', self.ofa_network, self.run_config, init=False)
    self.run_manager.run_config.data_provider.assign_active_img_size(self.image_size)
    self.flops_constraint = flops_constraint
    # net_id -> FLOPs of that sub-network; (re)filled by random_sample.
    self.flops_dict = dict()

  def _count_flops(self):
    """FLOPs of the currently active sub-network at self.image_size, truncated to int."""
    net_config = self.ofa_network.get_active_net_config()
    return int(MBv3LatencyTable.count_flops_given_config(net_config, image_size=self.image_size))

  def _within_constraint(self, flops):
    """True when flops lies inside the inclusive flops_constraint window."""
    return self.flops_constraint[0] <= flops <= self.flops_constraint[1]

  def random_sample(self, num, net_setting_list=None, max_trials=None):
    """Return a sorted list of distinct net ids that satisfy the FLOPs constraint.

    Random mode (net_setting_list is None): draw sub-networks with
    ``sample_active_subnet`` until ``num`` distinct ids are collected. Each id is
    the JSON of the drawn setting extended with ``'r'`` (image size) and
    ``'flops'``.

    List mode: activate every given setting and keep the ones inside the
    constraint. The id is the JSON of the setting exactly as given, and ``num``
    is ignored.

    In both modes self.flops_dict is reset to map each returned id to its FLOPs.

    Args:
      num: number of distinct ids to collect in random mode.
      net_setting_list: optional explicit settings to filter instead of sampling.
      max_trials: optional cap on the number of draws in random mode. None
        (default) samples until ``num`` ids are found, which never terminates
        if fewer than ``num`` distinct sub-networks satisfy the constraint.

    Raises:
      RuntimeError: if ``max_trials`` draws were made without collecting ``num`` ids.
    """
    self.flops_dict = dict()
    net_ids = set()
    if net_setting_list is None:
      trials = 0
      while len(net_ids) < num:
        if max_trials is not None and trials >= max_trials:
          raise RuntimeError(
              f'Only {len(net_ids)}/{num} sub-networks found within FLOPs window '
              f'{self.flops_constraint} after {trials} draws.')
        trials += 1
        net_setting = self.ofa_network.sample_active_subnet()
        flops = self._count_flops()
        if self._within_constraint(flops):
          net_id = net_setting2id({**net_setting, 'r': self.image_size, 'flops': flops})
          net_ids.add(net_id)
          self.flops_dict[net_id] = flops
    else:
      for net_setting in net_setting_list:
        net_id = net_setting2id(net_setting)
        self.ofa_network.set_active_subnet(**net_setting)
        flops = self._count_flops()
        if self._within_constraint(flops):
          net_ids.add(net_id)
          self.flops_dict[net_id] = flops
    return sorted(net_ids)

  def get_acc(self, net_id):
    """Measure and return the top-1 accuracy of the sub-network encoded by net_id.

    Activates the setting on the super-network, resets running statistics,
    validates on the test split and prints a one-line summary (top-1 and top-5).

    Args:
      net_id: JSON id as produced by random_sample / net_setting2id.

    Returns:
      The top-1 value reported by ``RunManager.validate``.
    """
    net_setting = net_id2setting(net_id)
    # Short run label: list-valued entries are summarised by their mean.
    net_setting_str = ','.join(
        '%s_%s' % (key, '%.1f' % list_mean(val) if isinstance(val, list) else val)
        for key, val in net_setting.items())
    # NOTE(review): ids from random mode also contain 'r' and 'flops'; this relies on
    # set_active_subnet tolerating those extra keyword arguments (unchanged from before).
    self.ofa_network.set_active_subnet(**net_setting)
    self.run_manager.reset_running_statistics(self.ofa_network)
    _, (top1, top5) = self.run_manager.validate(
        is_test=True, run_str=net_setting_str, net=self.ofa_network,
        data_loader=None, no_logs=True)
    # Ids built in list mode carry no 'flops' key; fall back to the value recorded by
    # random_sample instead of raising KeyError.
    flops = net_setting.get('flops', self.flops_dict.get(net_id))
    print(f'net: {net_setting_str}, r:{self.image_size}, flops:{flops}, top1:{top1}, top5:{top5}')
    return top1

In [None]:
# Candidate pool: NUM_CANDIDATES distinct sub-networks whose FLOPs fall inside the
# inclusive FLOPS_CONSTRAINT window, ordered by their (sorted) JSON ids.
FLOPS_CONSTRAINT = [550, 600]
NUM_CANDIDATES = 10000

accuracy_dataset = AccuracyDataset(flops_constraint=FLOPS_CONSTRAINT)
net_id_list = accuracy_dataset.random_sample(num=NUM_CANDIDATES)

arch_raw = list(net_id_list)
net_settings = [json.loads(net_id) for net_id in arch_raw]
# Feature vector per candidate: the 'ks', 'e' and 'd' lists concatenated.
arch_array = np.array([s['ks'] + s['e'] + s['d'] for s in net_settings])
# Accuracies are not known up front: run_traj queries them lazily through
# accuracy_dataset.get_acc, so this label array is only a -1 placeholder.
acc_array = np.full(len(arch_raw), -1)
flops_array = np.array([s['flops'] for s in net_settings])

In [None]:
def run_traj(accuracy_dataset, ARCH_RAW, ARCH, LABEL, seed,
             num_sample_train_list, num_sample_method,
             top_list_acc, keep_old='all', n_trees=1000, max_depth=20):
  """Run one predictor-guided search trajectory with an XGBoost surrogate.

  Each iteration:
    1. samples new candidates from the current search space;
    2. queries their accuracy through ``accuracy_dataset.get_acc``, caching
       results per index;
    3. fits an XGBoost regressor on every labelled candidate;
    4. predicts the whole pool;
    5. shrinks the search space to the ``top_acc`` highest-predicted candidates.

  Args:
    accuracy_dataset: object exposing ``get_acc(net_id) -> float``.
    ARCH_RAW: sequence of net ids (passed to get_acc), aligned with ARCH.
    ARCH: 2-D array-like of numeric architecture encodings, one row per candidate.
    LABEL: array-like with one entry per candidate. It fixes the pool size and
      is attached as the label of the prediction DMatrix; true labels are
      queried lazily.
    seed: base seed. Iteration 0 samples with ``seed``; iteration z > 0 samples
      with ``seed + z + 100``.
    num_sample_train_list: number of new candidates to query per iteration.
      A 0 after the first iteration stops the trajectory.
    num_sample_method: only 'uniform' is supported.
    top_list_acc: size of the search space kept after each iteration.
    keep_old: previously queried candidates that stay in the next training set.
      'all' keeps every one, 'top' keeps only those inside the new top set,
      and 'none' keeps none.
    n_trees: number of XGBoost boosting rounds.
    max_depth: XGBoost tree depth.

  Returns:
    ``(best_acc, best_net_id)`` for the candidate with the highest prediction of
    the last trained model. Its accuracy comes from the cache or, if it was
    never queried, from one extra get_acc call.

  Raises:
    ValueError: unknown num_sample_method / keep_old, empty pool, or no
      iteration trained a model.
    AssertionError: an iteration sampled nothing, or left no unqueried
      candidate in the pool.
  """
  if num_sample_method != 'uniform':
    raise ValueError(f'Unsupported num_sample_method: {num_sample_method!r}')
  if keep_old not in ('all', 'top', 'none'):
    raise ValueError(f"keep_old must be 'all', 'top' or 'none', got {keep_old!r}")

  all_index = np.arange(len(LABEL))
  if len(all_index) == 0:
    raise ValueError('LABEL is empty: there are no candidates to search.')

  # Features are taken from the ARCH argument (not a notebook global).
  arch_matrix = np.asarray(ARCH, dtype=float)
  df_dict_all = pd.DataFrame(
      arch_matrix, columns=[f'arch_{i + 1}' for i in range(arch_matrix.shape[-1])])
  # Loop-invariant: the full pool is predicted with the same DMatrix every iteration.
  dall = xgb.DMatrix(data=df_dict_all, label=LABEL)
  params = {'booster': 'gbtree',
            'max_depth': max_depth,
            'objective': 'reg:squarederror'}

  all_index_selected = all_index.copy()
  keep_index_train = np.array([], dtype=int)
  train_acc_dict = {}  # candidate index -> queried accuracy (avoids re-querying)
  index_optimal_pred = None

  for z, (num_sample_train, top_acc) in enumerate(zip(num_sample_train_list, top_list_acc)):
    if len(all_index_selected) == 0:
      break
    if z == 0:
      # A local RandomState yields exactly the samples np.random.seed + np.random.choice
      # would, without mutating the global RNG.
      rng = np.random.RandomState(seed + z)
      train_index = rng.choice(all_index_selected,
                               size=min(num_sample_train, len(all_index_selected)),
                               replace=False)
    else:
      if keep_index_train.size != 0:
        all_index_sample = all_index_selected[~np.isin(all_index_selected, keep_index_train)]
      else:
        all_index_sample = all_index_selected
      if num_sample_train == 0:
        break
      rng = np.random.RandomState(seed + z + 100)
      train_index_sample = rng.choice(all_index_sample,
                                      size=min(num_sample_train, len(all_index_sample)),
                                      replace=False)
      if keep_index_train.size != 0:
        train_index = np.concatenate((keep_index_train, train_index_sample))
      else:
        train_index = train_index_sample

    assert len(train_index) != 0
    test_index = np.setdiff1d(all_index, train_index)
    assert len(test_index) + len(train_index) == len(all_index)
    assert len(test_index) != 0

    label_list = []
    for index in train_index:
      if index not in train_acc_dict:
        train_acc_dict[index] = accuracy_dataset.get_acc(ARCH_RAW[index])
      label_list.append(train_acc_dict[index])

    dtrain = xgb.DMatrix(data=df_dict_all.iloc[train_index], label=label_list)
    bst = xgb.train(params=params, dtrain=dtrain, num_boost_round=n_trees)
    pred_all = bst.predict(dall)
    del bst
    index_optimal_pred = np.argmax(pred_all)
    # Indices ordered by descending prediction; a stable sort keeps ties in index order.
    all_index_by_acc = np.argsort(-pred_all, kind='stable').astype(int)
    all_index_selected = all_index_by_acc[:top_acc]
    if keep_old == 'none':
      keep_index_train = np.array([], dtype=int)
    elif keep_old == 'top':
      keep_index_train = all_index_selected[np.isin(all_index_selected, train_index)].astype(int)
    else:  # 'all'
      keep_index_train = train_index.astype(int)

  if index_optimal_pred is None:
    raise ValueError('No surrogate model was trained; check num_sample_train_list and top_list_acc.')
  # Look up the predicted-best candidate itself (not the last training index).
  if index_optimal_pred in train_acc_dict:
    best_acc = train_acc_dict[index_optimal_pred]
  else:
    best_acc = accuracy_dataset.get_acc(ARCH_RAW[index_optimal_pred])
  return best_acc, ARCH_RAW[index_optimal_pred]


In [None]:
# Search configuration. Each trajectory queries num_sample_each_iterations new
# candidates per iteration and keeps the `top` highest-predicted candidates as
# the next search space. The experiment is repeated with seeds 0 .. repeat-1.
ARCH = arch_array
LABEL = acc_array
ARCH_RAW = arch_raw
num_sample_each_iterations = 100 # Set Number of Samples each Iterations
num_iterations = 10 # Set Number of Iterations, 800 Queries: Set to 8, 1000 Queries: Set to 10
repeat = 5  # number of independent trajectories
top = 1000  # search-space size kept after each iteration
num_sample_train_list = [num_sample_each_iterations] * num_iterations
top_list_acc = [top] * num_iterations

acc_list = []
best_net_id_list = []
for seed in tqdm(range(repeat)):
  acc, best_net_id = run_traj(accuracy_dataset, ARCH_RAW, ARCH, LABEL, seed,
                              num_sample_train_list=num_sample_train_list, num_sample_method='uniform',
                              top_list_acc=top_list_acc, keep_old='all', n_trees=1000, max_depth=20)
  acc_list.append(acc)
  best_net_id_list.append(best_net_id)
  net_setting = net_id2setting(best_net_id)
  # "Queries" counts the training queries; run_traj may spend one more get_acc call
  # on the final best candidate if it was never sampled.
  print(f'Queries: {np.sum(num_sample_train_list)}, Best Arch: {net_setting}, Best SuperNet Acc: {acc}')

acc_mean = np.mean(acc_list)
acc_std = np.std(acc_list)
print(f'Top {num_sample_each_iterations}/{top}, SuperNet Acc: {acc_mean:0.4f} +_({acc_std:0.4f})')