Skip to content

Commit

Permalink
support first-order DARTS on the NASNet search space
Browse files Browse the repository at this point in the history
  • Loading branch information
D-X-Y committed Jan 17, 2020
1 parent 56f2161 commit db2760c
Show file tree
Hide file tree
Showing 10 changed files with 213 additions and 13 deletions.
9 changes: 9 additions & 0 deletions configs/search-archs/DARTS-NASNet-CIFAR.config
@@ -0,0 +1,9 @@
{
"super_type" : ["str", "nasnet-super"],
"name" : ["str", "GDAS"],
"C" : ["int", "16" ],
"N" : ["int", "2" ],
"steps" : ["int", "4" ],
"multiplier" : ["int", "4" ],
"stem_multiplier" : ["int", "3" ]
}
13 changes: 13 additions & 0 deletions configs/search-opts/DARTS-NASNet-CIFAR.config
@@ -0,0 +1,13 @@
{
"scheduler": ["str", "cos"],
"LR" : ["float", "0.025"],
"eta_min" : ["float", "0.001"],
"epochs" : ["int", "50"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"decay" : ["float", "0.0005"],
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "256"]
}
6 changes: 3 additions & 3 deletions docs/CVPR-2019-GDAS.md
Expand Up @@ -46,13 +46,13 @@ If you want to train the searched architecture found by the above scripts, you n
### Searching on a small search space (NAS-Bench-201)
The GDAS searching codes on a small search space:
```
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 1 -1
```

The baseline searching codes are DARTS:
```
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 1 -1
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 1 -1
```

**After searching**, if you want to train the searched architecture found by the above scripts, please use the following codes:
Expand Down
2 changes: 1 addition & 1 deletion docs/ICCV-2019-SETN.md
Expand Up @@ -32,7 +32,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN

The searching codes of SETN on a small search space (NAS-Bench-201).
```
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 -1
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 1 -1
```


Expand Down
5 changes: 5 additions & 0 deletions docs/ICLR-2019-DARTS.md
Expand Up @@ -10,6 +10,11 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 1 -1
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 1 -1
```

**Run the first-order DARTS on the NASNet search space**:
```
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/DARTS1V-search-NASNet-space.sh cifar10 -1
```

# Citation

```
Expand Down
15 changes: 10 additions & 5 deletions exps/algos/DARTS-V1.py
Expand Up @@ -112,10 +112,14 @@ def main(xargs):
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

search_space = get_search_spaces('cell', xargs.search_space_name)
model_config = dict2config({'name': 'DARTS-V1', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space,
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
if xargs.model_config is None:
model_config = dict2config({'name': 'DARTS-V1', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space,
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
else:
model_config = load_config(xargs.model_config, {'num_classes': class_num, 'space' : search_space,
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
search_model = get_cell_based_tiny_net(model_config)
logger.log('search-model :\n{:}'.format(search_model))

Expand Down Expand Up @@ -213,12 +217,13 @@ def main(xargs):
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells
parser.add_argument('--config_path', type=str, help='The config path.')
parser.add_argument('--search_space_name', type=str, help='The search space name.')
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
parser.add_argument('--config_path', type=str, help='The config path.')
parser.add_argument('--model_config', type=str, help='The path of the model configuration. When this arg is set, it will cover max_nodes / channels / num_cells.')
# architecture leraning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
Expand Down
4 changes: 3 additions & 1 deletion lib/models/cell_searchs/__init__.py
Expand Up @@ -10,6 +10,7 @@
from .genotypes import Structure as CellStructure, architectures as CellArchitectures
# NASNet-based macro structure
from .search_model_gdas_nasnet import NASNetworkGDAS
from .search_model_darts_nasnet import NASNetworkDARTS


nas201_super_nets = {'DARTS-V1': TinyNetworkDarts,
Expand All @@ -19,4 +20,5 @@
'ENAS' : TinyNetworkENAS,
'RANDOM' : TinyNetworkRANDOM}

nasnet_super_nets = {'GDAS' : NASNetworkGDAS}
nasnet_super_nets = {'GDAS' : NASNetworkGDAS,
'DARTS': NASNetworkDARTS}
24 changes: 21 additions & 3 deletions lib/models/cell_searchs/search_cells.py
Expand Up @@ -131,10 +131,12 @@ def __init__(self, space, C, stride, affine, track_running_stats):
op = OPS[primitive](C, C, stride, affine, track_running_stats)
self._ops.append(op)

def forward(self, x, weights, index):
#return sum(w * op(x) for w, op in zip(weights, self._ops))
def forward_gdas(self, x, weights, index):
return self._ops[index](x) * weights[index]

def forward_darts(self, x, weights):
return sum(w * op(x) for w, op in zip(weights, self._ops))


# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetSearchCell(nn.Module):
Expand Down Expand Up @@ -173,7 +175,23 @@ def forward_gdas(self, s0, s1, weightss, indexs):
op = self.edges[ node_str ]
weights = weightss[ self.edge2index[node_str] ]
index = indexs[ self.edge2index[node_str] ].item()
clist.append( op(h, weights, index) )
clist.append( op.forward_gdas(h, weights, index) )
states.append( sum(clist) )

return torch.cat(states[-self._multiplier:], dim=1)

def forward_darts(self, s0, s1, weightss):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)

states = [s0, s1]
for i in range(self._steps):
clist = []
for j, h in enumerate(states):
node_str = '{:}<-{:}'.format(i, j)
op = self.edges[ node_str ]
weights = weightss[ self.edge2index[node_str] ]
clist.append( op.forward_darts(h, weights) )
states.append( sum(clist) )

return torch.cat(states[-self._multiplier:], dim=1)
107 changes: 107 additions & 0 deletions lib/models/cell_searchs/search_model_darts_nasnet.py
@@ -0,0 +1,107 @@
####################
# DARTS, ICLR 2019 #
####################
import torch
import torch.nn as nn
from copy import deepcopy
from .search_cells import NASNetSearchCell as SearchCell
from .genotypes import Structure


# The macro structure is based on NASNet
class NASNetworkDARTS(nn.Module):

def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats):
super(NASNetworkDARTS, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C*stem_multiplier))

# config for each layer
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1)
layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)

num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False

self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier*C_curr, reduction
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )

def get_weights(self):
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
xlist+= list( self.classifier.parameters() )
return xlist

def get_alphas(self):
return [self.arch_normal_parameters, self.arch_reduce_parameters]

def show_alphas(self):
with torch.no_grad():
A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() )
B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() )
return '{:}\n{:}'.format(A, B)

def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string

def extra_repr(self):
return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

def genotype(self):
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2+i):
node_str = '{:}<-{:}'.format(i, j)
ws = weights[ self.edge2index[node_str] ]
for k, op_name in enumerate(self.op_names):
if op_name == 'none': continue
edges.append( (op_name, j, ws[k]) )
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append( tuple(selected_edges) )
return gene
with torch.no_grad():
gene_normal = _parse(torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy())
gene_reduce = _parse(torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy())
return {'normal': gene_normal, 'normal_concat': list(range(2+self._steps-self._multiplier, self._steps+2)),
'reduce': gene_reduce, 'reduce_concat': list(range(2+self._steps-self._multiplier, self._steps+2))}

def forward(self, inputs):

normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1)
reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1)

s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
if cell.reduction: ww = reduce_w
else : ww = normal_w
s0, s1 = s1, cell.forward_darts(s0, s1, ww)
out = self.lastact(s1)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)

return out, logits
41 changes: 41 additions & 0 deletions scripts-search/DARTS1V-search-NASNet-space.sh
@@ -0,0 +1,41 @@
#!/bin/bash
# bash ./scripts-search/DARTS1V-search-NASNet-space.sh cifar10 -1
echo script name: $0
echo $# arguments
if [ "$#" -ne 2 ] ;then
echo "Input illegal number of parameters " $#
echo "Need 2 parameters for dataset, and seed"
exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
echo "Must set TORCH_HOME envoriment variable for data dir saving"
exit 1
else
echo "TORCH_HOME : $TORCH_HOME"
fi

dataset=$1
BN=1
seed=$2
channel=16
num_cells=5
max_nodes=4
space=darts

if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
data_path="$TORCH_HOME/cifar.python"
else
data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi

save_dir=./output/search-cell-${space}/DARTS-V1-${dataset}-BN${BN}

OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
--dataset ${dataset} --data_path ${data_path} \
--search_space_name ${space} \
--config_path configs/search-opts/DARTS-NASNet-CIFAR.config \
--model_config configs/search-archs/GDAS-NASNet-CIFAR.config \
--track_running_stats ${BN} \
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
--workers 4 --print_freq 200 --rand_seed ${seed}

0 comments on commit db2760c

Please sign in to comment.