In [None]:
import time
import copy
import random
from params import *
import torch_geometric
import utils.model_utils as m_util
from model_src.demo_functions import *
from utils.misc_utils import RunningStatMeter
from model_src.model_helpers import BookKeeper
from model_src.comp_graph.tf_comp_graph import OP2I
from model_src.comp_graph.tf_comp_graph_models import make_cg_regressor, make_embedding_model
from model_src.predictor.gpi_family_data_manager import FamilyDataManager
from model_src.comp_graph.tf_comp_graph_dataloaders import CGRegressDataLoader
from utils.model_utils import set_random_seed, device, add_weight_decay, get_activ_by_name
from model_src.predictor.model_perf_predictor import train_predictor, run_predictor_demo, train_embedding_model

In [None]:
def prepare_local_params(parser, ext_args=None):
    """Register this notebook's options on ``parser`` and parse them.

    Args:
        parser: An ``argparse.ArgumentParser``-like object; the options below
            are added to it in place.
        ext_args: Optional list of argument strings handed to
            ``parser.parse_args``. When ``None``, argparse falls back to
            ``sys.argv[1:]``.
            NOTE(review): inside a notebook kernel ``sys.argv`` usually holds
            the kernel's own arguments, so an explicit list (e.g. ``[]``) is
            probably what callers want -- confirm.

    Returns:
        The namespace produced by ``parser.parse_args(ext_args)``.
    """
    add = parser.add_argument

    # Default test families, written as "+"-separated "name#count" specs.
    default_test_families = "+".join((
        "nb201c10#50",
        "nb301#50",
        "ofa_pn#50",
        "ofa_mbv3#50",
        "ofa_resnet#50",
        "hiaml#50",
        "inception#50",
        "two_path#50",
    ))

    # --- run identity and data selection ---
    add("-model_name", type=str, default="CL_dropout_encoder_model", required=False)
    add("-family_train", type=str, default="nb101", required=False)
    add("-family_test", type=str, default=default_test_families, required=False)
    add("-dev_ratio", type=float, default=0.1, required=False)
    add("-test_ratio", type=float, default=0.1, required=False)

    # --- optimisation schedule ---
    add("-epochs", type=int, default=60, required=False)
    add("-fine_tune_epochs", type=int, default=100, required=False)
    add("-batch_size", type=int, default=64, required=False)
    add("-initial_lr", type=float, default=0.0001, required=False)

    # --- network dimensions ---
    add("-in_channels", type=int, default=128, help="", required=False)
    add("-hidden_size", type=int, default=128, help="", required=False)
    add("-out_channels", type=int, default=128, help="", required=False)
    add("-num_layers", type=int, default=6, help="", required=False)
    add("-dropout_prob", type=float, default=0.4, help="", required=False)

    # --- GNN / regressor configuration ---
    add("-aggr_method", type=str, default="mean", required=False)
    add("-gnn_activ", type=str, default="relu", required=False)
    add("-reg_activ", type=str, default=None, required=False)
    add("-gnn_type", default="GINConv", required=False)
    add("-normalize_HW_per_family", action="store_true", default=False, required=False)

    # --- optional checkpoint path (unset by default) ---
    add("-e_chk", type=str, default=None, required=False)

    return parser.parse_args(ext_args)

In [None]:
def get_family_train_size_dict(args):
    """Turn a sequence of ``"family#size"`` specs into a ``{family: size}`` dict.

    Args:
        args: Iterable of strings, or ``None``. Each string is either
            ``"name#size"`` or a bare ``"name"``.

    Returns:
        A dict mapping each family name to an ``int`` size. A bare name maps
        to 0; the size text is parsed as a float and then truncated, so
        ``"1e2"`` becomes 100 and ``"2.9"`` becomes 2. ``None`` yields an
        empty dict. When a family repeats, the last spec wins.

    Raises:
        ValueError: If a spec contains more than one ``#`` or its size part
            is not a valid number.
    """
    family_sizes = {}
    if args is not None:
        for spec in args:
            # A spec without "#" carries no explicit size and defaults to 0.
            family, raw_size = spec.split("#") if "#" in spec else (spec, 0)
            family_sizes[family] = int(float(raw_size))
    return family_sizes

In [None]:
# `prepare_global_params` is not defined in this notebook; presumably it comes from
# one of the star imports at the top and returns an argument parser -- TODO confirm.
# NOTE(review): in the cells visible here `_parser` is never passed to
# `prepare_local_params`, so the options declared there are not parsed.
_parser = prepare_global_params()


In [None]:
def _batch_fwd_func(_model, _batch):
    """Define how a batch is handled by the model.

    Pulls the eight computation-graph fields out of ``_batch`` and passes
    them to ``_model`` as positional arguments.

    Args:
        _model: Callable that accepts the eight batch fields positionally.
        _batch: Mapping indexed by the ``DK_BATCH_*`` keys listed below.

    Returns:
        Whatever ``_model`` returns for this batch.
    """
    # Keys in the exact positional order the model is called with.
    field_order = (
        DK_BATCH_CG_REGULAR_IDX,
        DK_BATCH_CG_REGULAR_SHAPES,
        DK_BATCH_CG_WEIGHTED_IDX,
        DK_BATCH_CG_WEIGHTED_SHAPES,
        DK_BATCH_CG_WEIGHTED_KERNELS,
        DK_BATCH_CG_WEIGHTED_BIAS,
        DK_BATCH_EDGE_TSR_LIST,
        DK_BATCH_LAST_NODE_IDX_LIST,
    )
    model_inputs = [_batch[field] for field in field_order]
    return _model(*model_inputs)

In [None]:
def gnn_constructor(in_channels, out_channels):
    """Build one ``GINConv`` layer whose inner network is two stacked Linear layers.

    Args:
        in_channels: Input feature size; also the width of the first Linear.
        out_channels: Output feature size of the second Linear.

    Returns:
        A ``torch_geometric.nn.GINConv`` wrapping the two-layer network.
    """
    # NOTE(review): there is no non-linearity between the two Linear layers, so
    # together they amount to a single affine map. Left unchanged because the
    # saved checkpoint's parameter names presumably depend on this exact
    # layout -- confirm before altering.
    layers = [
        torch.nn.Linear(in_channels, in_channels),
        torch.nn.Linear(in_channels, out_channels),
    ]
    update_net = torch.nn.Sequential(*layers)
    return torch_geometric.nn.GINConv(nn=update_net)

In [None]:
# Build the embedding model and move it to the device returned by `device()`.
# The GNN settings below (128 hidden/out channels, 6 layers, dropout 0.4, "mean"
# aggregation, "relu" GNN activation, no regressor activation) repeat the defaults
# declared in `prepare_local_params`; they are hard-coded here, not read from a parser.
model = make_embedding_model(
    # Label count is the len() of whatever OP2I().build_from_file() returns
    # (presumably the op-to-index vocabulary -- TODO confirm).
    n_unique_labels=len(OP2I().build_from_file()),
    # Embedding sizes and vocabulary bounds.
    out_embed_size=128,
    shape_embed_size=8,
    kernel_embed_size=8,
    n_unique_kernels=8,
    n_shape_vals=6,
    # GNN body.
    hidden_size=128,
    out_channels=128,
    gnn_constructor=gnn_constructor,
    gnn_activ=get_activ_by_name("relu"),
    n_gnn_layers=6,
    dropout_prob=0.4,
    aggr_method="mean",
    # Regressor head.
    regressor_activ=get_activ_by_name(None),
).to(device())

In [None]:
# Bare expression as the cell's last line: Jupyter displays the model's repr,
# which lets the reader inspect the architecture that was just built.
model

In [None]:
# `os` is imported explicitly so this cell does not depend on a star import leaking it.
import os

from utils.model_utils import model_load

# NOTE(review): machine-specific absolute path; on any other host the file will not
# exist and the `else` branch below reports that no weights were loaded.
checkpoint_file = "/home/ec2-user/nas-rec-engine/saved_models/gpi_acc_predictor_CL_dropout_encoder_model_seed262_best.pt"
# strict=True makes load_state_dict reject missing or unexpected parameter names.
strict = True
if os.path.isfile(checkpoint_file):
    print("Found checkpoint: {}, loading".format(checkpoint_file))
    sd = model_load(checkpoint_file)
    try:
        model.load_state_dict(sd['model'], strict=strict)
    except RuntimeError:
        # Handles the thop bug: drop the `total_ops` / `total_params` entries
        # (presumably left in the saved state dict by thop profiling -- confirm)
        # and retry. load_state_dict reports key mismatches as RuntimeError, so
        # only that error triggers the retry; the retry is still strict, so any
        # other mismatch is raised to the caller.
        state_dict = {n: p for n, p in sd['model'].items()
                      if "total_ops" not in n and "total_params" not in n}
        model.load_state_dict(state_dict, strict=strict)
    print("Found best_eval_perf: {}, best_eval_iter: {}".format(sd[CHKPT_BEST_EVAL_RESULT],
                                                                sd[CHKPT_BEST_EVAL_ITERATION]))
else:
    # Previously a missing file was skipped silently, leaving the model without the
    # checkpoint weights and no indication for whoever runs the later cells.
    print("Checkpoint not found: {}, model weights were NOT loaded".format(checkpoint_file))