In [1]:
import sys

from numpy import isin
# from torch._C import device
import utils
import argparse
import torch.nn as nn
import torch.utils
import torchvision.datasets as dset
import torch
if torch.cuda.is_available():
    import torch.backends.cudnn as cudnn
from collections import namedtuple
from model import NetworkCIFAR as Network
from operations import Conv2d, NSTPConv2d, NConv2d, NLinear
from utils import *
from torch.utils.data.dataset import Subset
import logging
from nasnet_set import *
from tqdm.notebook import tqdm

# Flat architecture encoding of the searched cell. The helpers below are given
# len(net) / 4 nodes, i.e. 4 entries per node, so the length must divide by 4.
# A plain list literal replaces the former eval() of a string: eval would execute
# arbitrary code if this value were ever pasted from or read out of an external source.
net = [2, 2, 0, 2, 1, 2, 0, 2, 2, 3, 2, 1, 2, 0, 0, 1, 1, 1, 2, 1, 1, 0, 3, 4, 3, 0, 3, 1]
assert len(net) % 4 == 0, f"architecture encoding length must be a multiple of 4, got {len(net)}"
num_nodes = len(net) // 4

code = gen_code_from_list(net, node_num=num_nodes)
# The same cell code is passed twice (presumably normal and reduction cells - confirm in nasnet_set).
genotype = translator([code, code], max_node=num_nodes)

In [3]:
# Use the first CUDA GPU when one is present, otherwise run on the CPU.
# `cudnn` is only imported in the first cell when CUDA is available, so it is
# touched exclusively inside the CUDA branch.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    cudnn.benchmark = True
    cudnn.enabled = True
else:
    device = torch.device("cpu")

# Positional args are forwarded unchanged to model.NetworkCIFAR; presumably
# (init channels=128, classes=10, layers=24, auxiliary head=True, genotype) - confirm.
model = Network(128, 10, 24, True, genotype)
# NOTE(review): no logging level/handler is configured in the visible cells; unless
# configured elsewhere, this INFO record is suppressed at the default WARNING level.
logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

# Tensors are mapped to CPU on load; the model is moved to `device` later.
# torch.load unpickles the file - only load trusted checkpoints.
checkpoint_dir = "./lanas_128_99.03"
checkpoint = torch.load(checkpoint_dir + '/top1.pt', map_location="cpu")

In [4]:
# Remap checkpoint keys onto the current model's state dict. For each saved key,
# the variant with "op." inserted before the last dotted component is preferred
# when the model has it (e.g. "a.b.weight" -> "a.b.op.weight"); otherwise the
# tensor is stored under the key as saved.
state_dict = checkpoint['model_state_dict']
new_state_dict = model.state_dict()
for ckpt_key, tensor in state_dict.items():
    parent, sep, leaf = ckpt_key.rpartition(".")
    wrapped_key = f"{parent}{sep}op.{leaf}"
    target_key = wrapped_key if wrapped_key in new_state_dict else ckpt_key
    new_state_dict[target_key] = tensor

In [5]:
# Install the remapped weights (strict by default: missing/unexpected keys raise),
# then move the network and the loss onto the selected device.
model.load_state_dict(new_state_dict)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

In [6]:
def infer(valid_queue, model, criterion, device):
    """Run ``model`` over ``valid_queue`` with gradients disabled and report averages.

    Args:
        valid_queue: iterable yielding ``(inputs, targets)`` batches, e.g. a
            ``torch.utils.data.DataLoader``.
        model: callable returning a 2-tuple; the first element is used as the
            logits and the second is ignored.
        criterion: loss callable invoked as ``criterion(logits, targets)``.
        device: where each batch is moved before the forward pass.

    Returns:
        ``(top1_avg, loss_avg)``: the ``.avg`` of ``utils.AverageMeter`` instances
        fed the per-batch top-1 score (from ``utils.accuracy``) and the per-batch
        loss, each with the batch size as its count. Top-5 is accumulated but
        not returned.

    Note:
        The model's train/eval mode is left untouched (the ``model.eval()`` call
        is commented out in this notebook); callers must set it beforehand.
    """
    loss_meter = utils.AverageMeter()
    top1_meter = utils.AverageMeter()
    top5_meter = utils.AverageMeter()

    with torch.no_grad():
        for inputs, targets in tqdm(valid_queue):
            inputs, targets = inputs.to(device), targets.to(device)

            # The model yields a pair; only the first element feeds loss and metrics.
            logits, _unused = model(inputs)
            batch_loss = criterion(logits, targets)
            top1, top5 = utils.accuracy(logits, targets, topk=(1, 5))

            # Batch size is passed as the update count; AverageMeter presumably
            # weights each value by it (verify in utils).
            batch_size = inputs.size(0)
            loss_meter.update(batch_loss.item(), batch_size)
            top1_meter.update(top1.item(), batch_size)
            top5_meter.update(top5.item(), batch_size)

    return top1_meter.avg, loss_meter.avg

In [7]:

# Explicit import: `transforms` was previously only reachable through a star import.
import torchvision.transforms as transforms

# Per-channel normalization statistics applied to CIFAR-10 inputs
# (TODO confirm they match the training pipeline that produced the checkpoint).
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

DATA_ROOT = "~/Private/data"
EVAL_BATCH_SIZE = 128
EVAL_NUM_WORKERS = 2

valid_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# Evaluation needs no shuffling: a fixed order keeps runs reproducible.
# Pinned host memory only helps host->GPU copies, so enable it only on CUDA.
valid_queue = torch.utils.data.DataLoader(
    dset.CIFAR10(root=DATA_ROOT, train=False, transform=valid_transform),
    batch_size=EVAL_BATCH_SIZE,
    shuffle=False,
    num_workers=EVAL_NUM_WORKERS,
    pin_memory=(device.type == "cuda"),
)

model.eval()
# Project-specific hook on the network (presumably disables injected noise - confirm in model.py).
model.clear_noise()

valid_acc, valid_obj = infer(valid_queue, model, criterion, device)
print(valid_acc)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))




99.04
