# HD-NAS





In [None]:
import torch
import json
import numpy as np
from tqdm.notebook import tqdm
from scipy.stats import kendalltau
# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU so
# the notebook still runs (slowly) on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Parse both datasets; the `with` blocks close the file handles as soon as
# parsing is done instead of leaving them to the garbage collector.
with open('./test.json') as fp:
  test_json = json.load(fp)
with open('./train.json') as fp:
  train_json = json.load(fp)

#Task information
n_test = 99500   # rows allocated for the test-set score matrix
n_train = 500    # rows allocated for the train-set score matrix
n_ranks = 8      # one column per *_rank key handled by the notebook

#HDC dimension
hv_dim = 100000  # length of every hypervector

In [None]:
# HDC Model

def im_gen(hv_dim):
  """Build the item memory: 72 hypervectors of length ``hv_dim``.

  Three banks of random vectors with entries in [-1, 1) are drawn (12 layer
  vectors, 3 num_heads vectors, 3 mlp_ratio vectors). Row
  ``layer * 6 + config * 3 + setting`` of the result is the element-wise
  product of the layer vector with the num_heads vector (``config == 0``) or
  the mlp_ratio vector (``config == 1``) for that setting.

  The result lives on the module-level ``device``.
  """
  layer_bank = 2 * torch.rand((12, hv_dim)).to(device) - 1
  heads_bank = 2 * torch.rand((3, hv_dim)).to(device) - 1
  ratio_bank = 2 * torch.rand((3, hv_dim)).to(device) - 1

  # Within one layer the six rows are: three num_heads settings, then three
  # mlp_ratio settings.
  per_layer = torch.cat((heads_bank, ratio_bank), dim=0)

  # Broadcast every layer vector against the six per-layer vectors, then
  # flatten (12, 6, hv_dim) -> (72, hv_dim) so rows are grouped by layer.
  bound = layer_bank.unsqueeze(1) * per_layer.unsqueeze(0)
  return bound.reshape(12 * 2 * 3, hv_dim)

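# Alternative item memory (disabled): 72 independent random hypervectors,
# without the layer/setting products used by im_gen above.
# def im_gen(hv_dim):
#   return 2 * torch.rand((72, hv_dim)).to(device) - 1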
  
# Encoding
def encoding(target_json, all_hvs, all_weights = None, n_sample = 0, mode = 'encode', am = None, scores = None, bipolar = False, lr = 1):
  """Encode every architecture in ``target_json`` and apply the step chosen by ``mode``.

  Encoding: the first element of each entry's ``'arch'`` sequence is dropped,
  every third remaining element is skipped, and each kept digit (0-3) becomes a
  3-wide one-hot group (0 maps to all zeros). The concatenated vector is
  multiplied with the module-level item memory ``im`` to give the architecture
  hypervector. (The kept digits are presumably num_heads / mlp_ratio choices,
  matching the item-memory row layout - confirm against the dataset.)

  Modes:
    'encode'              append the hypervector to ``all_hvs`` and the targets
                          ``1 - 2 * rank / n_sample`` (one per ``*_rank`` key)
                          to ``all_weights``.
    'train'               ``am[k] += all_weights[i][k] * hv`` for every row k.
    'retrain'             ``am[k] += lr * (all_weights[i][k] - scores[i][k]) * hv``.
    'test' / 'evaluate'   ``scores[i]`` = cosine similarity of hv with each row
                          of ``am``.

  Args:
    target_json: mapping of architecture id -> dict with an ``'arch'`` entry
      (plus the eight ``*_rank`` entries when ``mode == 'encode'``).
    all_hvs: list extended in 'encode' mode; unused otherwise.
    all_weights: list extended in 'encode' mode; indexable per-architecture
      targets in 'train' / 'retrain' mode; unused otherwise.
    n_sample: rank normaliser for 'encode' mode; must be non-zero there.
    mode: one of 'encode', 'train', 'retrain', 'test', 'evaluate'.
    am: associative-memory tensor, updated in place by 'train' / 'retrain'.
    scores: tensor read by 'retrain' and written by 'test' / 'evaluate'.
    bipolar: when truthy, clip the hypervector to [-1, 1].
    lr: step size for 'retrain'.

  Raises:
    ValueError: unknown ``mode``, or an architecture digit outside 0-3.

  Reads the module-level names ``im``, ``device`` and ``tqdm``.
  """
  if mode not in ('encode', 'train', 'retrain', 'test', 'evaluate'):
    # An unknown mode used to run the whole loop without doing anything.
    raise ValueError(f'Unknown mode: {mode!r}')

  rank_keys = ('cplfw_rank', 'dukemtmc_rank', 'market1501_rank', 'msmt17_rank',
               'sop_rank', 'vehicleid_rank', 'veri_rank', 'veriwild_rank')
  # Index = architecture digit; 0 deliberately contributes nothing.
  onehot_groups = ((0., 0., 0.), (1., 0., 0.), (0., 1., 0.), (0., 0., 1.))

  for arch_idx, each_arch in enumerate(tqdm(target_json)):
    entry = target_json[each_arch]

    arch_onehot = []
    for position, symbol in enumerate(entry['arch'][1:]):
      if position % 3 == 2:
        continue  # the third element of every triple is not encoded
      value = int(symbol)
      if not 0 <= value <= 3:
        raise ValueError(f'Wrong value {value} (expect 0, 1, 2, or 3)')
      arch_onehot.extend(onehot_groups[value])
    arch_onehot = torch.tensor(arch_onehot, dtype = torch.float32, device = device)

    arch_hv = arch_onehot @ im
    if bipolar:
      arch_hv = arch_hv.clamp(-1, 1)

    if mode == 'encode':
      # Rank labels are only needed here, so label-free data can still be
      # scored in 'test' / 'evaluate' mode.
      targets = [1 - 2 * entry[key] / n_sample for key in rank_keys]
      all_hvs.append(arch_hv)
      all_weights.append(torch.tensor(targets, dtype = torch.float32, device = device))
    elif mode == 'train':
      # One broadcast update replaces the per-row Python loop.
      am += all_weights[arch_idx][:len(am)].unsqueeze(1) * arch_hv
    elif mode == 'retrain':
      retrain_weights = lr * (all_weights[arch_idx] - scores[arch_idx])
      am += retrain_weights[:len(am)].unsqueeze(1) * arch_hv
    else:  # 'test' or 'evaluate'
      # expand_as broadcasts the hypervector against every row of am without
      # copying it, so the row count comes from am itself.
      scores[arch_idx] = torch.cosine_similarity(arch_hv.expand_as(am), am, dim = 1)

In [None]:
# Training
bipolar = True  # clip architecture hypervectors to [-1, 1]

# Item memory first, then the associative memory, so the random draws happen
# in the same order as before.
im = im_gen(hv_dim)
am = torch.rand((n_ranks, hv_dim)).to(device)  # one row per rank task, uniform in [0, 1)

# Per-architecture similarity scores and their rank-scaled counterparts.
scores_train = torch.zeros(n_train, n_ranks, device = device)
scores_rank = torch.zeros_like(scores_train)

def update_all_scores(scores_train, scores_rank, n_sample): #propagate error back
  """Turn similarity scores into rank-scaled targets and report agreement.

  Each column of ``scores_train`` is ranked (0 = highest score) and mapped to
  ``1 - 2 * rank / n_sample``. The result is compared against the module-level
  golden targets ``all_weights_`` with cosine similarity and Kendall's tau;
  both metrics are printed.

  ``scores_rank`` is accepted but not read.

  Returns:
    (rank-scaled scores, cosine similarity, Kendall tau).
  """
  # Double argsort: position of every row in the descending order, per column.
  descending_order = torch.argsort(scores_train, dim = 0, descending = True)
  predicted = 1 - 2 * torch.argsort(descending_order, dim = 0) / n_sample

  golden_flat = all_weights_.flatten()
  predicted_flat = predicted.flatten()

  metric_sim = torch.cosine_similarity(golden_flat, predicted_flat, dim = 0)
  # kendalltau runs on the host, hence the .cpu() copies; the p-value is unused.
  metric_kendall, _ = kendalltau(golden_flat.cpu(), predicted_flat.cpu())
  print(metric_sim, metric_kendall)
  return predicted, metric_sim, metric_kendall

# Pass 1: encode the training set, collecting hypervectors and rank targets.
all_hvs, all_weights = [], []
encoding(train_json, all_hvs, all_weights = all_weights, n_sample = n_train,
         mode = 'encode', am = am, bipolar = bipolar)

# golden weights/ranks: stack the per-architecture target vectors into one
# (architectures x ranks) tensor.
all_weights_ = torch.stack(all_weights, dim = 0)

# Pass 2: accumulate weighted hypervectors into the associative memory.
encoding(train_json, all_hvs, all_weights = all_weights_, n_sample = n_train,
         mode = 'train', am = am, bipolar = bipolar)
# Pass 3: score the training set against the freshly trained memory.
encoding(train_json, all_hvs, all_weights = all_weights_, n_sample = n_train,
         mode = 'test', am = am, scores = scores_train, bipolar = bipolar)
scores_rank, _, _ = update_all_scores(scores_train, scores_rank, n_train)


In [None]:
# Retraining
retrain_epochs = 20
for epoch in range(retrain_epochs):
  # Linearly decaying step size, from 0.6 down to 0.6 / retrain_epochs.
  # (For a constant step size use lr = 1 instead.)
  lr = 0.6 * (retrain_epochs - epoch) / retrain_epochs

  # Nudge the associative memory by the gap between golden and predicted ranks.
  encoding(
    train_json, all_hvs,
    all_weights = all_weights_, n_sample = n_train, mode = 'retrain',
    am = am, scores = scores_rank, bipolar = bipolar, lr = lr,
  )
  # Re-score the training set and refresh the rank-scaled predictions.
  encoding(
    train_json, all_hvs,
    all_weights = all_weights_, n_sample = n_train, mode = 'test',
    am = am, scores = scores_train, bipolar = bipolar,
  )
  scores_rank, _, _ = update_all_scores(scores_train, scores_rank, n_train)

In [None]:
# Prediction
bipolar = True
scores_test = torch.zeros(n_test, n_ranks)
encoding(
  test_json, None,
  all_weights = None, n_sample = n_test, mode = 'evaluate',
  am = am, scores = scores_test, bipolar = bipolar,
)
# similarity -> rank: the double argsort gives every architecture its position
# in the descending score order, column by column.
# NOTE(review): the name `sorted` shadows the builtin; it is kept because the
# JSON-generation cell reads it.
sorted = torch.argsort(
  torch.argsort(scores_test, dim = 0, descending = True), dim = 0
).cpu()

In [None]:
# Generating Prediction json
# Column order of the rank matrix `sorted`, one key per column.
rank_keys = ('cplfw_rank', 'dukemtmc_rank', 'market1501_rank', 'msmt17_rank',
             'sop_rank', 'vehicleid_rank', 'veri_rank', 'veriwild_rank')

for arch_idx, each_arch in enumerate(tqdm(test_json)):
  # .int().tolist() yields plain Python ints, which json.dump can serialise.
  predicted_ranks = sorted[arch_idx].int().tolist()
  test_json[each_arch].update(zip(rank_keys, predicted_ranks))

with open('./prediction.json', 'w') as fp:
  json.dump(test_json, fp)