-
Notifications
You must be signed in to change notification settings - Fork 2
/
run_ewc.py
executable file
·47 lines (36 loc) · 1.26 KB
/
run_ewc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
'''
Implementation of Elastic Weight Consolidation (Kirkpatrick et al., 2017).
Timo Flesch, 2021
'''
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from ewc_lib.model import Nnet
from ewc_lib.trainer import train_nnet
from ewc_lib.visualise import disp_results
# ----------------------------------------------------------------------------------------
# parameters
# ----------------------------------------------------------------------------------------
# define a few variables
# Hyper-parameters and run settings for the experiment; this dict is what the
# __main__ section hands to train_nnet().
# NOTE(review): the per-key remarks below are inferred from the key names and
# values only -- confirm against ewc_lib before relying on them.
params = {
    'n_inputs': 784,        # presumably flattened 28x28 MNIST images
    'n_classes': 10,        # presumably one output per digit class
    'n_hidden': 100,        # presumably hidden-layer width
    'weight_init': 1e-5,    # presumably scale of the initial weights
    'n_iters': 5000,        # presumably training iterations (per task?)
    'lrate': 1e-1,          # presumably the learning rate
    'do_ewc': True,         # also forwarded to disp_results()
    'ewc_lambda': 15,       # presumably strength of the EWC penalty
    'fim_samples': 1000,    # presumably samples for the Fisher estimate
    'mbatch_size': 250,     # presumably the mini-batch size
    'disp_n_steps': 100,    # presumably progress-report interval
    'verbose': True,
    'task': 'permutedMNIST',  # permutedMNIST or splitMNIST
    'device': 'CPU',
}

# For a CPU run, hide every CUDA device from the process.
if params['device'] == 'CPU':
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
# ----------------------------------------------------------------------------------------
# main experiment
# ----------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Run training with the settings in `params`, then pass the returned
    # results, together with the do_ewc flag, to the display routine.
    results = train_nnet(params)
    disp_results(results, params['do_ewc'])