Skip to content

Commit

Permalink
Fix the potential memory leak in NAS-Bench-201 clear_param
Browse files Browse the repository at this point in the history
  • Loading branch information
D-X-Y committed Mar 21, 2020
1 parent b702ddf commit 2202588
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 38 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -121,3 +121,5 @@ lib/NAS-Bench-*-v1_0.pth
others/TF
scripts-search/l2s-algos
TEMP-L.sh

.nfs00*
4 changes: 1 addition & 3 deletions exps/NAS-Bench-201/functions.py
@@ -1,19 +1,17 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
import os, sys, time, torch
import time, torch
from procedures import prepare_seed, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from config_utils import dict2config
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net



__all__ = ['evaluate_for_seed', 'pure_evaluate']



def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
Expand Down
2 changes: 1 addition & 1 deletion exps/NAS-Bench-201/main.py
Expand Up @@ -28,7 +28,7 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c
for dataset, xpath, split in zip(datasets, xpaths, splits):
# train valid data
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
# load the configurature
# load the configuration
if dataset == 'cifar10' or dataset == 'cifar100':
if use_less: config_path = 'configs/nas-benchmark/LESS.config'
else : config_path = 'configs/nas-benchmark/CIFAR.config'
Expand Down
2 changes: 1 addition & 1 deletion exps/NAS-Bench-201/show-best.py
Expand Up @@ -3,7 +3,7 @@
################################################################################################
# python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth #
################################################################################################
import os, sys, time, glob, random, argparse
import sys, argparse
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
Expand Down
15 changes: 8 additions & 7 deletions exps/NAS-Bench-201/test-weights.py
Expand Up @@ -6,7 +6,7 @@
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
###############################################################################################
import os, gc, sys, time, glob, random, argparse
import os, gc, sys, argparse, psutil
import numpy as np
import torch
from pathlib import Path
Expand All @@ -33,7 +33,7 @@ def tostr(accdict, norms):

def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
print('\nEvaluate dataset={:}'.format(data))
norms = []
norms, process = [], psutil.Process(os.getpid())
final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
for idx in range(len(api)):
Expand All @@ -56,16 +56,17 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
with torch.no_grad():
net.load_state_dict(param)
_, summary = weight_watcher.analyze(net, alphas=False)
cur_norms.append( summary['lognorm'] )
cur_norms.append(summary['lognorm'])
norms.append( float(np.mean(cur_norms)) )
api.clear_params(idx, use_12epochs_result)
api.clear_params(idx, None)
if idx % 200 == 199 or idx + 1 == len(api):
head = '{:05d}/{:05d}'.format(idx, len(api))
stem_val = tostr(final_val_accs, norms)
stem_test = tostr(final_test_accs, norms)
print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}'.format(time_string(), head, data, 12 if use_12epochs_result else 200))
print(' -->> {:} || {:}'.format(stem_val, stem_test))
torch.cuda.empty_cache() ; gc.collect()
print('{:} {:} {:} with {:} epochs ({:.2f} MB memory)'.format(time_string(), head, data, 12 if use_12epochs_result else 200, process.memory_info().rss / 1e6))
print(' [Valid] -->> {:}'.format(stem_val))
print(' [Test.] -->> {:}'.format(stem_test))
gc.collect()


def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result):
Expand Down
2 changes: 1 addition & 1 deletion exps/NAS-Bench-201/visualize.py
Expand Up @@ -3,7 +3,7 @@
#####################################################
# python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
#####################################################
import os, sys, time, argparse, collections
import sys, argparse
from tqdm import tqdm
from collections import OrderedDict
import numpy as np
Expand Down
10 changes: 5 additions & 5 deletions exps/NAS-Bench-201/xshapes.py
Expand Up @@ -24,19 +24,19 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
machine_info = get_machine_info()
all_infos = {'info': machine_info}
all_dataset_keys = []
# look all the datasets
# look all the dataset
for dataset, xpath, split in zip(datasets, xpaths, splits):
# train valid data
# the train and valid data
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
# load the configurature
# load the configuration
if dataset == 'cifar10' or dataset == 'cifar100':
split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
elif dataset.startswith('ImageNet16'):
split_info = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None)
else:
raise ValueError('invalid dataset : {:}'.format(dataset))
config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger)
# check whether use splited validation set
# check whether use the splitted validation set
if bool(split):
assert dataset == 'cifar10'
ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
Expand Down Expand Up @@ -92,7 +92,7 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],

log_dir = save_dir / 'logs'
log_dir.mkdir(parents=True, exist_ok=True)
logger = Logger(str(log_dir), 0, False)
logger = Logger(str(log_dir), os.getpid(), False)

logger.log('xargs : seeds = {:}'.format(seeds))
logger.log('xargs : cover_mode = {:}'.format(cover_mode))
Expand Down
39 changes: 20 additions & 19 deletions lib/nas_201_api/api.py
Expand Up @@ -114,15 +114,27 @@ def reload(self, archive_root: Text, index: int):
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
xdata = torch.load(xfile_path, map_location='cpu')
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
if index in self.arch2infos_less: del self.arch2infos_less[index]
if index in self.arch2infos_full: del self.arch2infos_full[index]
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )

def clear_params(self, index: int, use_12epochs_result: bool):
"""Remove the architecture's weights to save memory."""
if use_12epochs_result: arch2infos = self.arch2infos_less
else : arch2infos = self.arch2infos_full
archresult = arch2infos[index]
archresult.clear_params()
def clear_params(self, index: int, use_12epochs_result: Union[bool, None]):
"""Remove the architecture's weights to save memory.
:arg
index: the index of the target architecture
use_12epochs_result: a flag to controll how to clear the parameters.
-- None: clear all the weights in both `less` and `full`, which indicates the training hyper-parameters.
-- True: clear all the weights in arch2infos_less, which by default is 12-epoch-training result.
-- False: clear all the weights in arch2infos_full, which by default is 200-epoch-training result.
"""
if use_12epochs_result is None:
self.arch2infos_less[index].clear_params()
self.arch2infos_full[index].clear_params()
else:
if use_12epochs_result: arch2infos = self.arch2infos_less
else : arch2infos = self.arch2infos_full
arch2infos[index].clear_params()

# This function is used to query the information of a specific archiitecture
# 'arch' can be an architecture index or an architecture string
Expand Down Expand Up @@ -193,7 +205,6 @@ def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, use_1
best_index, highest_accuracy = idx, accuracy
return best_index, highest_accuracy


def arch(self, index: int):
"""Return the topology structure of the `index`-th architecture."""
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
Expand All @@ -213,7 +224,6 @@ def get_net_param(self, index, dataset, seed, use_12epochs_result=False):
else: arch2infos = self.arch2infos_full
arch_result = arch2infos[index]
return arch_result.get_net_param(dataset, seed)


def get_net_config(self, index: int, dataset: Text):
"""
Expand All @@ -235,15 +245,13 @@ def get_net_config(self, index: int, dataset: Text):
#print ('SEED [{:}] : {:}'.format(seed, result))
raise ValueError('Impossible to reach here!')


def get_cost_info(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> Dict[Text, float]:
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
if use_12epochs_result: arch2infos = self.arch2infos_less
else: arch2infos = self.arch2infos_full
arch_result = arch2infos[index]
return arch_result.get_compute_costs(dataset)


def get_latency(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> float:
"""
To obtain the latency of the network (by default it will return the latency with the batch size of 256).
Expand All @@ -254,7 +262,6 @@ def get_latency(self, index: int, dataset: Text, use_12epochs_result: bool = Fal
cost_dict = self.get_cost_info(index, dataset, use_12epochs_result)
return cost_dict['latency']


# obtain the metric for the `index`-th architecture
# `dataset` indicates the dataset:
# 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
Expand Down Expand Up @@ -388,7 +395,6 @@ def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=Fa
return xifo
"""


def show(self, index: int = -1) -> None:
"""
This function will print the information of a specific (or all) architecture(s).
Expand Down Expand Up @@ -423,7 +429,6 @@ def show(self, index: int = -1) -> None:
else:
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))


def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]:
"""
This function will count the number of total trials.
Expand All @@ -443,7 +448,6 @@ def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]
nums[len(dataset_seed[dataset])] += 1
return dict(nums)


@staticmethod
def str2lists(arch_str: Text) -> List[tuple]:
"""
Expand Down Expand Up @@ -471,7 +475,6 @@ def str2lists(arch_str: Text) -> List[tuple]:
genotypes.append( input_infos )
return genotypes


@staticmethod
def str2matrix(arch_str: Text,
search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
Expand Down Expand Up @@ -511,7 +514,6 @@ def str2matrix(arch_str: Text,
return matrix



class ArchResults(object):

def __init__(self, arch_index, arch_str):
Expand Down Expand Up @@ -752,7 +754,6 @@ def debug_test(self):

def __repr__(self):
return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))



"""
Expand Down Expand Up @@ -872,8 +873,8 @@ def get_train(self, iepoch=None):
'cur_time': xtime,
'all_time': atime}

# get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).
def get_eval(self, name, iepoch=None):
"""Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument)."""
if iepoch is None: iepoch = self.epochs-1
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
Expand All @@ -890,8 +891,8 @@ def get_net_param(self, clone=False):
if clone: return copy.deepcopy(self.net_state_dict)
else: return self.net_state_dict

# This function is used to obtain the config dict for this architecture.
def get_config(self, str2structure):
"""This function is used to obtain the config dict for this architecture."""
if str2structure is None:
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
'N' : self.arch_config['num_cells'],
Expand Down
2 changes: 1 addition & 1 deletion scripts-search/NAS-Bench-201/test-weights.sh
Expand Up @@ -15,7 +15,7 @@ else
echo "TORCH_HOME : $TORCH_HOME"
fi

OMP_NUM_THREADS=4 python exps/NAS-Bench-201/test-weights.py \
CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/NAS-Bench-201/test-weights.py \
--base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 \
--dataset $1 \
--use_12 $2

0 comments on commit 2202588

Please sign in to comment.