# utils.py
import logging
import os

import numpy as np
import torch


def splitdata(length, fold, index):
    """Split `length` sample indices into train/val/test for the `index`-th of `fold` folds."""
    fold_length = length // fold
    index_list = np.arange(length)
    if index == 1:
        # First fold: validation at the front, test at the back.
        val = index_list[:fold_length]
        test = index_list[fold_length * (fold - 1):]
        train = index_list[fold_length:fold_length * (fold - 1)]
    elif index == fold:
        # Last fold: validation at the back, test just before it.
        val = index_list[fold_length * (fold - 1):]
        test = index_list[fold_length * (fold - 2):fold_length * (fold - 1)]
        train = index_list[:fold_length * (fold - 2)]
    else:
        # Middle folds: fold index-1 holds validation, fold index-2 holds test,
        # and everything else goes to training.
        val = index_list[fold_length * (index - 1):fold_length * index]
        test = index_list[fold_length * (index - 2):fold_length * (index - 1)]
        train = np.concatenate([index_list[:fold_length * (index - 2)],
                                index_list[fold_length * index:]])
    return train, val, test
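
# A quick worked example (illustrative, not part of the original file): with
# length=10, fold=5 and index=3, fold_length is 2, so
#   splitdata(10, 5, 3)
# returns train=[0, 1, 6, 7, 8, 9], val=[4, 5], test=[2, 3] -- the third fold
# is validation, the fold just before it is test, and the rest is training.

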
def printParams(model_params, logger=None):
    """Print model hyper-parameters and optionally mirror them to a logger."""
    print("=========== Parameters ==========")
    for k, v in model_params.items():
        print(f'{k} : {v}')
    print("=================================")
    print()
    if logger:
        for k, v in model_params.items():
            logger.info(f'{k} : {v}')


def applyIndexOnList(lis, idx):
    """Return the elements of `lis` at the positions given in `idx`."""
    ans = []
    for i in idx:
        ans.append(lis[i])
    return ans


def set_seed(seed):
    torch.manual_seed(seed)  # set seed for cpu
    torch.cuda.manual_seed(seed)  # set seed for gpu
    torch.backends.cudnn.deterministic = True  # cudnn
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)  # numpy


def get_logger(save_dir):
    """Create an INFO-level logger that writes to `<save_dir>/log.txt`."""
    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.INFO)
    handler = logging.FileHandler(save_dir + "/log.txt")
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return logger


class CheckpointHandler(object):
    """Keep the `max_save` best checkpoints (lower eval metric is better) in `save_dir`."""

    def __init__(self, save_dir, max_save=5):
        self.save_dir = save_dir
        self.max_save = max_save
        self.init_info()

    def init_info(self):
        # Restore the checkpoint -> metric mapping from a previous run, if any.
        os.makedirs(self.save_dir, exist_ok=True)
        self.metric_dic = {}
        if os.path.exists(self.save_dir + '/eval_log.txt'):
            with open(self.save_dir + '/eval_log.txt', 'r') as f:
                ls = f.readlines()
            for l in ls:
                l = l.strip().split(':')
                assert len(l) == 2
                self.metric_dic[l[0]] = float(l[1])

    def save_model(self, model, model_params, epoch, eval_metric):
        # Only save if the new metric is at least as good as the current worst.
        max_in_dic = max(self.metric_dic.values()) if len(self.metric_dic) else 1e9
        if eval_metric > max_in_dic:
            return
        if len(self.metric_dic) == self.max_save:
            self.remove_last()
        self.metric_dic['model-' + str(epoch) + '.pt'] = eval_metric
        state = {"params": model_params, "epoch": epoch, "model": model.state_dict()}
        torch.save(state, self.save_dir + '/' + 'model-' + str(epoch) + '.pt')
        log_str = '\n'.join(['{}:{:.7f}'.format(k, v) for k, v in self.metric_dic.items()])
        with open(self.save_dir + '/eval_log.txt', 'w') as f:
            f.write(log_str)

    def remove_last(self):
        # Drop the checkpoint with the worst (largest) metric from disk and memory.
        last_model = sorted(self.metric_dic.keys(), key=lambda x: self.metric_dic[x])[-1]
        if os.path.exists(self.save_dir + '/' + last_model):
            os.remove(self.save_dir + '/' + last_model)
        self.metric_dic.pop(last_model)

    def checkpoint_best(self, use_cuda=True):
        # Load the checkpoint with the best (smallest) metric.
        best_model = sorted(self.metric_dic.keys(), key=lambda x: self.metric_dic[x])[0]
        if use_cuda:
            state = torch.load(self.save_dir + '/' + best_model)
        else:
            state = torch.load(self.save_dir + '/' + best_model, map_location='cpu')
        return state

    def checkpoint_avg(self, use_cuda=True):
        # Average the parameters of all saved checkpoints (simple checkpoint averaging).
        return_dic = None
        model_num = 0
        tmp_model_params = None
        for ckpt in os.listdir(self.save_dir):
            if not ckpt.endswith('.pt'):
                continue
            model_num += 1
            if use_cuda:
                state = torch.load(self.save_dir + '/' + ckpt)
            else:
                state = torch.load(self.save_dir + '/' + ckpt, map_location='cpu')
            model, tmp_model_params = state['model'], state['params']
            if not return_dic:
                return_dic = model
            else:
                for k in return_dic:
                    return_dic[k] += model[k]
        # Divide the accumulated parameters once all checkpoints have been summed.
        for k in return_dic:
            return_dic[k] = return_dic[k] / model_num
        return {'params': tmp_model_params, 'model': return_dic}

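
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): how these
# utilities might fit together in a training script. The dummy linear model,
# the demo directory, and the decreasing validation losses below are
# assumptions made purely for demonstration; CheckpointHandler treats a lower
# eval_metric as better, so only improving epochs are kept on disk.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    set_seed(42)
    save_dir = './demo_ckpt'  # hypothetical output directory
    os.makedirs(save_dir, exist_ok=True)
    logger = get_logger(save_dir)

    model_params = {'hidden_dim': 8, 'lr': 1e-3}
    printParams(model_params, logger)

    model = torch.nn.Linear(4, 1)  # stand-in for a real model (assumption)
    handler = CheckpointHandler(save_dir, max_save=3)

    # Pretend validation losses for a few epochs; lower is better.
    for epoch, val_loss in enumerate([0.9, 0.7, 0.8, 0.5], start=1):
        handler.save_model(model, model_params, epoch, val_loss)

    # Reload the best checkpoint (epoch 4 in this toy run) on CPU.
    best_state = handler.checkpoint_best(use_cuda=False)
    model.load_state_dict(best_state['model'])

    # 5-fold split of 10 samples, using fold 3 for validation.
    train_idx, val_idx, test_idx = splitdata(10, 5, 3)
    print(applyIndexOnList(list('abcdefghij'), val_idx))  # -> ['e', 'f']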