In [15]:
import os
import sys
import glob
import numpy as np
import torch
import utils
import logging
import argparse
import torch.nn as nn
import genotypes
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn

from torch.autograd import Variable
from model import NetworkCIFAR as Network


parser = argparse.ArgumentParser("cifar")
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
parser.add_argument('--batch_size', type=int, default=96, help='batch size')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
parser.add_argument('--layers', type=int, default=20, help='total number of layers')
parser.add_argument('--model_path', type=str, default='/data/taoliu/ENNAS_master/CNN/ENNAS_CIFAR_RESULT/cifar10_600.pt', help='path of pretrained model')
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--arch', type=str, default='ENNAS', help='which architecture to use')
args = parser.parse_args([])

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
    format=log_format, datefmt='%m/%d %I:%M:%S %p')

CLASSES = 10

In [16]:
np.random.seed(args.seed)
torch.cuda.set_device(args.gpu)
cudnn.benchmark = True
torch.manual_seed(args.seed)
cudnn.enabled=True
torch.cuda.manual_seed(args.seed)
logging.info('gpu device = %d' % args.gpu)
logging.info("args = %s", args)
device = torch.device("cuda:0")



genotype = eval("genotypes.%s" % args.arch)
model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
model = torch.nn.DataParallel(model)
model = model.cuda()
model.to(device)

04/14 10:27:08 PM gpu device = 0
04/14 10:27:08 PM args = Namespace(arch='ENNAS', auxiliary=False, batch_size=96, cutout=False, cutout_length=16, data='../data', drop_path_prob=0.2, gpu=0, init_channels=36, layers=20, model_path='/data/taoliu/ENNAS_master/CNN/ENNAS_CIFAR_RESULT/cifar10_600.pt', report_freq=50, seed=0)
108 108 36
108 144 36
144 144 36
144 144 36
144 144 36
144 144 36
144 144 72
144 288 72
288 288 72
288 288 72
288 288 72
288 288 72
288 288 72
288 288 144
288 576 144
576 576 144
576 576 144
576 576 144
576 576 144
576 576 144


DataParallel(
  (module): NetworkCIFAR(
    (stem): Sequential(
      (0): Conv2d(3, 108, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(108, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (cells): ModuleList(
      (0): Cell(
        (preprocess0): ReLUConvBN(
          (op): Sequential(
            (0): ReLU()
            (1): Conv2d(108, 36, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (2): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (preprocess1): ReLUConvBN(
          (op): Sequential(
            (0): ReLU()
            (1): Conv2d(108, 36, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (2): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (_ops): ModuleList(
          (0): SepConv(
            (op): Sequential(
              (0): ReLU()
              (1): Conv2d(3

In [17]:
torch.save(model.state_dict(), 'test_model.pt')

In [6]:
model.load_state_dict(torch.load('test_model.pt'))

<All keys matched successfully>

In [18]:
new_model = Network(args.init_channels, 2, args.layers, args.auxiliary, genotype)
new_model = torch.nn.DataParallel(new_model)
new_model = new_model.cuda()
new_model.to(device)

108 108 36
108 144 36
144 144 36
144 144 36
144 144 36
144 144 36
144 144 72
144 288 72
288 288 72
288 288 72
288 288 72
288 288 72
288 288 72
288 288 144
288 576 144
576 576 144
576 576 144
576 576 144
576 576 144
576 576 144


DataParallel(
  (module): NetworkCIFAR(
    (stem): Sequential(
      (0): Conv2d(3, 108, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(108, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (cells): ModuleList(
      (0): Cell(
        (preprocess0): ReLUConvBN(
          (op): Sequential(
            (0): ReLU()
            (1): Conv2d(108, 36, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (2): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (preprocess1): ReLUConvBN(
          (op): Sequential(
            (0): ReLU()
            (1): Conv2d(108, 36, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (2): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (_ops): ModuleList(
          (0): SepConv(
            (op): Sequential(
              (0): ReLU()
              (1): Conv2d(3

In [33]:
pretrained_dict

{'module.stem.0.weight': tensor([[[[-1.4408e-03,  1.0324e-01, -1.5840e-01],
           [-1.4163e-01, -7.4123e-02,  5.1607e-02],
           [-3.8130e-03,  1.5259e-01, -1.7079e-02]],
 
          [[ 5.0925e-02, -5.8161e-02, -3.7829e-02],
           [-1.8386e-01, -1.2746e-01, -7.9332e-02],
           [ 7.1290e-03,  7.6082e-02,  1.1547e-01]],
 
          [[-1.3047e-01, -8.3805e-02,  6.9901e-02],
           [ 1.5981e-01, -3.9606e-02,  1.4401e-01],
           [-3.1020e-02,  2.0364e-02,  1.7426e-01]]],
 
 
         [[[-1.7853e-01, -1.2115e-01, -4.8722e-02],
           [-7.5017e-02,  1.6628e-01, -1.2474e-01],
           [-8.8591e-02, -1.3445e-01, -1.8024e-01]],
 
          [[-1.1234e-01,  1.6543e-01,  8.5875e-02],
           [ 9.3275e-02,  1.0121e-02, -9.8666e-02],
           [ 3.2560e-02, -1.7969e-01, -1.3906e-01]],
 
          [[-9.9214e-02,  1.2142e-01,  1.1284e-01],
           [-8.5351e-02, -6.9441e-03,  1.2308e-01],
           [ 1.9132e-01,  7.6380e-02,  2.5999e-02]]],
 
 
         [[[ 1.2

In [30]:
pretrained_dict=torch.load('test_model.pt')

In [51]:
model_dict=new_model.state_dict()
new_list=list(model_dict.keys())
pre_list=list(pretrained_dict.keys())
for i in range(1340):
    model_dict[new_list[i]]=pretrained_dict[pre_list[i]]
#pretrained_dict={k: v for k, v in model.state_dict().items() if k < 1340}
#model_dict.update(pretrained_dict)
new_model.load_state_dict(model_dict)

<All keys matched successfully>

In [49]:
dict_name=list(model_dict)
for i,p in enumerate(dict_name):
    print(i,p)

0 module.stem.0.weight
1 module.stem.1.weight
2 module.stem.1.bias
3 module.stem.1.running_mean
4 module.stem.1.running_var
5 module.stem.1.num_batches_tracked
6 module.cells.0.preprocess0.op.1.weight
7 module.cells.0.preprocess0.op.2.weight
8 module.cells.0.preprocess0.op.2.bias
9 module.cells.0.preprocess0.op.2.running_mean
10 module.cells.0.preprocess0.op.2.running_var
11 module.cells.0.preprocess0.op.2.num_batches_tracked
12 module.cells.0.preprocess1.op.1.weight
13 module.cells.0.preprocess1.op.2.weight
14 module.cells.0.preprocess1.op.2.bias
15 module.cells.0.preprocess1.op.2.running_mean
16 module.cells.0.preprocess1.op.2.running_var
17 module.cells.0.preprocess1.op.2.num_batches_tracked
18 module.cells.0._ops.0.op.1.weight
19 module.cells.0._ops.0.op.2.weight
20 module.cells.0._ops.0.op.3.weight
21 module.cells.0._ops.0.op.3.bias
22 module.cells.0._ops.0.op.3.running_mean
23 module.cells.0._ops.0.op.3.running_var
24 module.cells.0._ops.0.op.3.num_batches_tracked
25 module.cells

In [41]:
dict_name=list(model.state_dict())
for i,p in enumerate(pretrained_dict):
    print(i,p)

0 module.stem.0.weight
1 module.stem.1.weight
2 module.stem.1.bias
3 module.stem.1.running_mean
4 module.stem.1.running_var
5 module.stem.1.num_batches_tracked
6 module.cells.0.preprocess0.op.1.weight
7 module.cells.0.preprocess0.op.2.weight
8 module.cells.0.preprocess0.op.2.bias
9 module.cells.0.preprocess0.op.2.running_mean
10 module.cells.0.preprocess0.op.2.running_var
11 module.cells.0.preprocess0.op.2.num_batches_tracked
12 module.cells.0.preprocess1.op.1.weight
13 module.cells.0.preprocess1.op.2.weight
14 module.cells.0.preprocess1.op.2.bias
15 module.cells.0.preprocess1.op.2.running_mean
16 module.cells.0.preprocess1.op.2.running_var
17 module.cells.0.preprocess1.op.2.num_batches_tracked
18 module.cells.0._ops.0.op.1.weight
19 module.cells.0._ops.0.op.2.weight
20 module.cells.0._ops.0.op.3.weight
21 module.cells.0._ops.0.op.3.bias
22 module.cells.0._ops.0.op.3.running_mean
23 module.cells.0._ops.0.op.3.running_var
24 module.cells.0._ops.0.op.3.num_batches_tracked
25 module.cells

RuntimeError: Error(s) in loading state_dict for DataParallel:
	size mismatch for module.classifier.weight: copying a param with shape torch.Size([10, 576]) from checkpoint, the shape in current model is torch.Size([2, 576]).
	size mismatch for module.classifier.bias: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([2]).