diff --git a/.gitignore b/.gitignore index 944b9cc..6a8b3e5 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,4 @@ **/log/** *events* *.txt - +*.idea/ diff --git a/experiments/MNIST/YOPO-5-10/config.py b/experiments/MNIST/YOPO-5-10/config.py index 8e933c3..82bdffe 100644 --- a/experiments/MNIST/YOPO-5-10/config.py +++ b/experiments/MNIST/YOPO-5-10/config.py @@ -1,29 +1,31 @@ -from easydict import EasyDict import sys import os import argparse import numpy as np import torch -from loss import CrossEntropyWithWeightPenlty + def add_path(path): if path not in sys.path: print('Adding {}'.format(path)) sys.path.append(path) + abs_current_path = os.path.realpath('./') root_path = os.path.join('/', *abs_current_path.split(os.path.sep)[:-3]) lib_dir = os.path.join(root_path, 'lib') add_path(lib_dir) + from training.config import TrainingConfigBase, SGDOptimizerMaker, \ PieceWiseConstantLrSchedulerMaker, IPGDAttackMethodMaker + class TrainingConfing(TrainingConfigBase): lib_dir = lib_dir - num_epochs = 20 + num_epochs = 40 val_interval = 1 weight_decay = 5e-4 @@ -32,15 +34,11 @@ class TrainingConfing(TrainingConfigBase): sigma = 0.01 eps = 0.3 - create_optimizer = SGDOptimizerMaker(lr =1e-1 / K, momentum = 0.9, weight_decay = 5e-4) - create_lr_scheduler = PieceWiseConstantLrSchedulerMaker(milestones = [15, 19], gamma = 0.1) + create_optimizer = SGDOptimizerMaker(lr =1e-1 / K, momentum = 0.9, weight_decay = weight_decay) + create_lr_scheduler = PieceWiseConstantLrSchedulerMaker(milestones = [30, 35, 39], gamma = 0.1) create_loss_function = None#torch.nn.CrossEntropyLoss - #create_attack_method = \ - # IPGDAttackMethodMaker(eps = 8/255.0, sigma = 2/255.0, nb_iters = 10, norm = np.inf, - # mean = torch.tensor(np.array([0]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis]), - # std = torch.tensor(np.array([1]).astype(np.float32)[np.newaxis, :, np.newaxis, np.newaxis])) create_attack_method = None create_evaluation_attack_method = \ @@ -53,10 +51,6 @@ class TrainingConfing(TrainingConfigBase): config = TrainingConfing() -# About data -# C.inp_chn = 1 -# C.num_class = 10 - parser = argparse.ArgumentParser() parser.add_argument('--resume', default=None, type=str, metavar='PATH',