In [1]:
def get_arch_str_from_arch_vector(arch_vector):
    """Convert a 6-element vector of operation indices into an architecture string.

    The result has the form
    ``|op~0|+|op~0|op~1|+|op~0|op~1|op~2|``, where the six operation names
    are filled in left to right from ``arch_vector``.

    Args:
        arch_vector: Iterable of exactly six operation indices. Each index
            must be a key of the index-to-name table below (0-7).

    Returns:
        str: The formatted architecture string.

    Raises:
        KeyError: If an element of ``arch_vector`` is not a known operation index.
        ValueError: If ``arch_vector`` does not contain exactly six elements.
    """
    _opname_to_index = {
        'none': 0,
        'skip_connect': 1,
        'nor_conv_1x1': 2,
        'nor_conv_3x3': 3,
        'avg_pool_3x3': 4,
        'input': 5,
        'output': 6,
        'global': 7
    }
    _opindex_to_name = {value: key for key, value in _opname_to_index.items()}
    ops = [_opindex_to_name[opindex] for opindex in arch_vector]
    # The template has exactly six slots: str.format would silently drop extra
    # operations and raise an obscure IndexError for missing ones, so check here.
    if len(ops) != 6:
        raise ValueError(
            'arch_vector must contain exactly 6 operation indices, got {}'.format(len(ops))
        )
    return '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.format(*ops)

# arch_str = get_arch_str_from_arch_vector([1, 2, 0, 2, 4, 1])             # convert the architecture vector to a string

# index = api.query_index_by_arch(arch_str)
# cost_info = api.get_cost_info(index, dataset='cifar10-valid')  # or 'cifar10', 'cifar100'

# flops = cost_info['flops']      # unit: M
# params = cost_info['params']    # unit: MB
# flops, params

In [None]:
import pickle
import numpy as np
import os
from collections import defaultdict

# -------------------------------
# 1. Load the pickled table of architectures.
# NOTE(review): pickle.load can execute arbitrary code — only load files from a trusted source.
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
with open('/home/rtx4090/code/python/current/LightNAS/data/nasbench201/pkl/desktop-gpu-gtx-1080ti-large.pkl', 'rb') as f:
    df = pickle.load(f)

from nas_201_api import NASBench201API as API
api = API('NAS-Bench-201-v1_0-e61699.pth', verbose=False)
# -------------------------------
# 2. Scan every architecture and collect its flops and params.
raw_records = []     # [(key, flops, params)]

for i in range(len(df)):
    # Every column except the last one forms the architecture vector
    # (the last column is not used in this cell).
    key = df[i, :-1].astype(np.int32)
    try:
        arch_str = get_arch_str_from_arch_vector(key)
        index = api.query_index_by_arch(arch_str)
        cost_info = api.get_cost_info(index, dataset='cifar10-valid')
        flops = cost_info['flops']
        params = cost_info['params']
        # The former per-architecture accuracy query was removed: its result was
        # never stored, and a failure there dropped records whose cost info was fine.
        raw_records.append((tuple(key), flops, params))
    except Exception as e:
        # Best effort: report the row and keep going with the remaining ones.
        print(f"[Warning] Skipped index {i} due to: {e}")

# Without any record the statistics below would be nan and an empty dict
# would be produced silently, so fail loudly instead.
if not raw_records:
    raise RuntimeError("No architecture could be resolved; cannot compute normalization statistics.")

# -------------------------------
# 3. Standardization statistics (z-score).
all_flops = np.array([x[1] for x in raw_records])
all_params = np.array([x[2] for x in raw_records])

flops_mean, flops_std = all_flops.mean(), all_flops.std()
params_mean, params_std = all_params.mean(), all_params.std()

# A constant column has zero standard deviation; dividing by it would give
# nan/inf, so fall back to 1.0 (every normalized value then becomes 0).
flops_std = flops_std if flops_std > 0 else 1.0
params_std = params_std if params_std > 0 else 1.0

# -------------------------------
# 4. Build the normalized dictionary: architecture tuple -> [flops_norm, params_norm].
norm_dict = {}

for key, flops, params in raw_records:
    flops_norm = (flops - flops_mean) / flops_std
    params_norm = (params - params_mean) / params_std
    norm_dict[key] = [flops_norm, params_norm]

# -------------------------------
# 5. Show a few entries as a sanity check.
print(f"标准化后样例：\n{list(norm_dict.items())[:3]}")

  file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')


标准化后样例：
[((0, 0, 0, 2, 2, 4), [-1.1737815449198306, -1.1711545212919863]), ((1, 0, 0, 2, 2, 4), [-1.1737815449198306, -1.1711545212919863]), ((2, 0, 0, 2, 2, 4), [-1.0578910710722402, -1.0509160147744498])]


In [None]:
import os

# Working directory of the running kernel.
current_folder = os.getcwd()
# Two levels above the working directory (its grandparent), despite the variable name.
parent_folder = os.path.dirname(os.path.dirname(current_folder))

print("当前目录:", current_folder)
print("父目录:", parent_folder)

# Persist the normalized flops/params dictionary next to that grandparent directory.
output_path = f'{parent_folder}/nasbench201_all_flops_parameter.pkl'
with open(output_path, 'wb') as pkl_file:
    pickle.dump(norm_dict, pkl_file)

当前目录: /home/rtx4090/code/python/current/LightNAS/data/nasbench201/others
父目录: /home/rtx4090/code/python/current/LightNAS/datasets


In [7]:
# NOTE(review): despite its name, `index` ends up holding the multi-line info string
# returned by query_info_str_by_arch (see the cell output), not an architecture index.
# `arch_str` is not defined in this cell; it is whatever value an earlier cell left
# behind — TODO confirm this is the intended architecture.
index = api.query_info_str_by_arch(arch_str)
index

"|avg_pool_3x3~0|+|avg_pool_3x3~0|avg_pool_3x3~1|+|none~0|nor_conv_3x3~1|none~2|\ndatasets : ['cifar10-valid'], extra-info : arch-index=14445\ncifar10-valid  FLOP= 43.17 M, Params=0.316 MB, latency=None ms.\ncifar10-valid  train : [loss = 1.272, top1 = 53.60%], valid : [loss = 1.314, top1 = 52.27%]"