Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support first-order DARTS on the NASNet search space
- Loading branch information
Showing
10 changed files
with
213 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
{ | ||
"super_type" : ["str", "nasnet-super"], | ||
"name" : ["str", "GDAS"], | ||
"C" : ["int", "16" ], | ||
"N" : ["int", "2" ], | ||
"steps" : ["int", "4" ], | ||
"multiplier" : ["int", "4" ], | ||
"stem_multiplier" : ["int", "3" ] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
{ | ||
"scheduler": ["str", "cos"], | ||
"LR" : ["float", "0.025"], | ||
"eta_min" : ["float", "0.001"], | ||
"epochs" : ["int", "50"], | ||
"warmup" : ["int", "0"], | ||
"optim" : ["str", "SGD"], | ||
"decay" : ["float", "0.0005"], | ||
"momentum" : ["float", "0.9"], | ||
"nesterov" : ["bool", "1"], | ||
"criterion": ["str", "Softmax"], | ||
"batch_size": ["int", "256"] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
#################### | ||
# DARTS, ICLR 2019 # | ||
#################### | ||
import torch | ||
import torch.nn as nn | ||
from copy import deepcopy | ||
from .search_cells import NASNetSearchCell as SearchCell | ||
from .genotypes import Structure | ||
|
||
|
||
# The macro structure is based on NASNet | ||
class NASNetworkDARTS(nn.Module):
    """Searchable network for DARTS whose macro structure follows NASNet.

    The model is a convolutional stem, a stack of search cells, a
    BatchNorm + ReLU head, global average pooling and a linear classifier.
    For ``N >= 1`` the stack holds ``3 * N`` cells; the cells at positions
    ``N`` and ``2 * N`` are built with ``reduction=True`` and double the
    channel count.

    Two architecture-parameter tensors of shape
    ``(num_edge, len(search_space))`` are learned: one is used by every
    non-reduction cell and the other by every reduction cell.
    """

    def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes,
                 search_space, affine, track_running_stats):
        """Build the stem, the search cells, the head and the alphas.

        Args:
            C: base channel count of the first stage.
            N: number of cells per stage (see the class docstring).
            steps: number of intermediate nodes per cell; forwarded to the
                search cell and used by ``genotype``.
            multiplier: a cell is treated as emitting ``multiplier * C_curr``
                channels; also decides which nodes ``genotype`` lists as
                concatenated.
            stem_multiplier: the stem outputs ``C * stem_multiplier`` channels.
            num_classes: output size of the final linear classifier.
            search_space: sequence of candidate operation names.
            affine: forwarded unchanged to every search cell.
            track_running_stats: forwarded unchanged to every search cell.
        """
        super(NASNetworkDARTS, self).__init__()
        self._C = C
        self._layerN = N
        self._steps = steps
        self._multiplier = multiplier
        self.stem = nn.Sequential(
            nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C * stem_multiplier))

        # Per-cell configuration: channel count and whether the cell reduces.
        layer_channels = [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
        layer_reductions = [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)

        num_edge, edge2index = None, None
        C_prev_prev, C_prev, C_curr, reduction_prev = C * stem_multiplier, C * stem_multiplier, C, False

        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
            cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr,
                              reduction, reduction_prev, affine, track_running_stats)
            # Every cell must expose the same edge layout because the
            # architecture parameters are shared across cells.
            if num_edge is None:
                num_edge, edge2index = cell.num_edges, cell.edge2index
            else:
                assert num_edge == cell.num_edges and edge2index == cell.edge2index, \
                    'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
            self.cells.append(cell)
            C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
        self.op_names = deepcopy(search_space)
        self._Layer = len(self.cells)
        self.edge2index = edge2index
        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        # One row of logits per edge, one column per candidate operation.
        self.arch_normal_parameters = nn.Parameter(1e-3 * torch.randn(num_edge, len(search_space)))
        self.arch_reduce_parameters = nn.Parameter(1e-3 * torch.randn(num_edge, len(search_space)))

    def get_weights(self):
        """Return the network weights, i.e. all parameters except the alphas."""
        xlist = list(self.stem.parameters()) + list(self.cells.parameters())
        xlist += list(self.lastact.parameters()) + list(self.global_pooling.parameters())
        xlist += list(self.classifier.parameters())
        return xlist

    def get_alphas(self):
        """Return the architecture parameters as ``[normal, reduce]``."""
        return [self.arch_normal_parameters, self.arch_reduce_parameters]

    def show_alphas(self):
        """Return a printable string with the softmax of both alpha tensors."""
        with torch.no_grad():
            A = 'arch-normal-parameters :\n{:}'.format(
                nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu())
            B = 'arch-reduce-parameters :\n{:}'.format(
                nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu())
        return '{:}\n{:}'.format(A, B)

    def get_message(self):
        """Return ``extra_repr()`` followed by one line per cell."""
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
        return string

    def extra_repr(self):
        """Return a one-line summary of the main hyper-parameters."""
        return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(
            name=self.__class__.__name__, **self.__dict__))

    def genotype(self):
        """Derive the discrete architecture from the current alphas.

        For every intermediate node, the strongest non-'none' operation is
        kept for each input node, and then the two strongest input nodes are
        selected. Ranking per input first guarantees the two selected edges
        come from distinct inputs; ranking all (op, input) pairs together
        could pick two operations on the same input.

        Returns:
            A dict with keys 'normal', 'normal_concat', 'reduce' and
            'reduce_concat'. 'normal' / 'reduce' are lists with one tuple per
            intermediate node; each tuple holds up to two
            ``(op_name, input_index, weight)`` entries sorted by decreasing
            weight. The '*_concat' entries list the node indices
            ``2 + steps - multiplier .. steps + 1``.
        """
        def _parse(weights):
            gene = []
            for i in range(self._steps):
                best_per_input = []
                for j in range(2 + i):
                    node_str = '{:}<-{:}'.format(i, j)
                    ws = weights[self.edge2index[node_str]]
                    candidates = [(op_name, j, ws[k])
                                  for k, op_name in enumerate(self.op_names)
                                  if op_name != 'none']
                    # An input contributes nothing if 'none' is its only op.
                    if candidates:
                        best_per_input.append(max(candidates, key=lambda x: x[-1]))
                best_per_input.sort(key=lambda x: -x[-1])
                gene.append(tuple(best_per_input[:2]))
            return gene

        with torch.no_grad():
            gene_normal = _parse(torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy())
            gene_reduce = _parse(torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy())
        concat = range(2 + self._steps - self._multiplier, self._steps + 2)
        return {'normal': gene_normal, 'normal_concat': list(concat),
                'reduce': gene_reduce, 'reduce_concat': list(concat)}

    def forward(self, inputs):
        """Run the network with softmax-relaxed architecture weights.

        Returns:
            A tuple ``(out, logits)``: the pooled, flattened features that are
            fed to the classifier, and the classifier output.
        """
        normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1)
        reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1)

        s0 = s1 = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            # Reduction cells share one set of alphas, all other cells the other.
            ww = reduce_w if cell.reduction else normal_w
            s0, s1 = s1, cell.forward_darts(s0, s1, ww)
        out = self.lastact(s1)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return out, logits
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#!/bin/bash
# bash ./scripts-search/DARTS1V-search-NASNet-space.sh cifar10 -1
#
# Launches ./exps/algos/DARTS-V1.py on the "darts" search space.
# Arguments:
#   $1  dataset name (cifar10 / cifar100 use $TORCH_HOME/cifar.python,
#       anything else uses $TORCH_HOME/cifar.python/ImageNet16)
#   $2  random seed, passed through as --rand_seed
# Requires the TORCH_HOME environment variable to be set.
echo script name: "$0"
echo $# arguments
if [ "$#" -ne 2 ]; then
  echo "Input illegal number of parameters " $#
  echo "Need 2 parameters for dataset, and seed"
  exit 1
fi
if [ -z "${TORCH_HOME}" ]; then
  echo "Must set TORCH_HOME environment variable for data dir saving"
  exit 1
else
  echo "TORCH_HOME : ${TORCH_HOME}"
fi

dataset=$1
BN=1          # forwarded as --track_running_stats
seed=$2
channel=16
num_cells=5
max_nodes=4
space=darts

if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
  data_path="${TORCH_HOME}/cifar.python"
else
  data_path="${TORCH_HOME}/cifar.python/ImageNet16"
fi

save_dir="./output/search-cell-${space}/DARTS-V1-${dataset}-BN${BN}"

# All expansions are quoted so that paths containing spaces or glob
# characters reach the Python script as single arguments.
OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \
	--save_dir "${save_dir}" --max_nodes "${max_nodes}" --channel "${channel}" --num_cells "${num_cells}" \
	--dataset "${dataset}" --data_path "${data_path}" \
	--search_space_name "${space}" \
	--config_path configs/search-opts/DARTS-NASNet-CIFAR.config \
	--model_config configs/search-archs/GDAS-NASNet-CIFAR.config \
	--track_running_stats "${BN}" \
	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
	--workers 4 --print_freq 200 --rand_seed "${seed}"