-
Notifications
You must be signed in to change notification settings - Fork 1
/
global_var.py
97 lines (89 loc) · 2.41 KB
/
global_var.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Module-level mutable registry. Every entry starts out as None; the values
# are presumably assigned by other modules at runtime (not shown here).
GLOBAL = dict.fromkeys(
    (
        "continual_split_to_tags",
        "continual_tag_to_split",
        "infer_model",
    )
)
def get_epoch_map(big_model: bool):
    """Build the per-dataset epoch-count table.

    Args:
        big_model: When truthy, use the large-model schedule; otherwise use
            the small-model schedule.

    Returns:
        A new dict mapping each dataset name to its number of epochs.
        A fresh dict is built on every call, so callers may mutate it.
    """
    dataset_names = ('fewnerd', 'ontonotes', 'bbn', 'fewrel', 'tacred', 'ace', 'chent')
    # Epoch counts are listed positionally, in the same order as dataset_names.
    if big_model:
        epoch_counts = (10, 10, 10, 10, 20, 20, 10)
    else:
        epoch_counts = (10, 20, 20, 10, 20, 20, 20)
    return dict(zip(dataset_names, epoch_counts))
def get_batch_limit(dataset_name: str):
    """Look up the batch limit configured for one dataset.

    What the limit caps (e.g. batches per epoch) is not visible here;
    confirm against the caller.

    Args:
        dataset_name: One of 'fewnerd', 'ontonotes', 'bbn', 'fewrel',
            'tacred', 'ace', 'chent'.

    Returns:
        The integer limit for that dataset.

    Raises:
        KeyError: If ``dataset_name`` is not in the table.
    """
    limits_by_dataset = dict(
        fewnerd=2500,
        ontonotes=1250,
        bbn=500,
        fewrel=400,
        tacred=100,
        ace=100,
        chent=1250,
    )
    return limits_by_dataset[dataset_name]
def get_batch_size_map(big_model: bool):
    """Build the per-dataset batch-size table.

    Every dataset shares one batch size: 8 for the large-model setting and
    16 otherwise.

    Args:
        big_model: When truthy, use the large-model batch size.

    Returns:
        A new dict mapping each dataset name to its batch size.
    """
    shared_size = 8 if big_model else 16
    return dict.fromkeys(
        ('fewnerd', 'ontonotes', 'bbn', 'fewrel', 'tacred', 'ace', 'chent'),
        shared_size,
    )
def get_learning_rate(big_model: bool, dataset_name: str, split_id: int, method: str, is_adapter=False):
    """Pick the learning rate for a dataset, model size and training method.

    Args:
        big_model: When truthy, use the large-model rates; otherwise use the
            flat small-model table.
        dataset_name: Dataset key such as 'fewnerd' or 'ontonotes'.
        split_id: Current split index. Only consulted in the large-model
            setting for 'fewnerd' and 'ontonotes'; lower split ids get the
            higher rate.
        method: Training-method name. Only consulted in the large-model
            setting for 'fewnerd'.
        is_adapter: Only consulted in the large-model setting for 'fewnerd'.

    Returns:
        The learning rate as a float.

    Raises:
        KeyError: Small-model setting with a ``dataset_name`` missing from
            the table.
        NotImplementedError: Large-model setting with a ``dataset_name`` that
            has no rule.
    """
    if not big_model:
        small_model_rates = {
            'fewnerd': 0.00002,
            'ontonotes': 0.0001,
            'bbn': 0.00005,
            'fewrel': 0.001,
            'tacred': 0.001,
            'ace': 0.002,
            'chent': 0.0001,
        }
        return small_model_rates[dataset_name]

    if dataset_name == 'fewnerd':
        # These methods get a much smaller rate than all other methods.
        low_rate_methods = ('emr', 'emar', 'eaemr', 'our_abl', 'emr_abl')
        if method in low_rate_methods:
            return 1.0e-5 if is_adapter else 2.0e-5
        if is_adapter:
            return 0.0001
        return 0.0005 if split_id <= 2 else 0.0001

    if dataset_name == 'ontonotes':
        return 0.0005 if split_id <= 3 else 0.0001

    if dataset_name in ('bbn', 'fewrel', 'tacred', 'ace'):
        return 0.0002

    # NOTE(review): 'chent' appears in every other table in this module but
    # has no large-model rate, so it lands here. Confirm that is intended.
    raise NotImplementedError(f'invalid dataset_name: {dataset_name}')