ops.py
import numpy as np
import dynet as dy
import nn


def cat(xs, dim=-1):
    """Concatenate a list of expressions along `dim` (-1 selects the trailing position)."""
    head_shape, batch_size = xs[0].dim()
    if dim > len(head_shape):
        raise RuntimeError('Bad dim %d for shape %s, you can '
                           'use -1 to indicate the last dim'
                           % (dim, str(head_shape)))
    if dim == -1:
        dim = len(head_shape)
    return dy.concatenate(xs, d=dim)


def expand_dims(x, dim=-1):
    """Insert a new axis of size 1 at position `dim` (-1 appends a trailing axis)."""
    head_shape, batch_size = x.dim()
    if dim > len(head_shape):
        raise RuntimeError('Bad dim %d for shape %s, you can '
                           'use -1 to indicate the last dim'
                           % (dim, str(head_shape)))
    if dim == -1:
        dim = len(head_shape)
    ex_shape = list(head_shape)
    ex_shape.insert(dim, 1)
    return dy.reshape(x, tuple(ex_shape))


def layer_norm(xs):
    """Apply layer normalization with fixed gain 1 and bias 0 to each expression in `xs`."""
    head_shape, batch_size = xs[0].dim()
    g = dy.ones(head_shape)
    b = dy.zeros(head_shape)
    return [dy.layer_norm(x, g, b) for x in xs]


def squeeze(x, dim=None):
    """Remove axes of size 1; if `dim` is given, remove only that axis."""
    head_shape, batch_size = x.dim()
    if dim is None:
        sq_shape = [d for d in head_shape if d != 1]
    else:
        if dim > len(head_shape):
            raise RuntimeError('Bad dim %d for shape %s, you can '
                               'use -1 to indicate the last dim. Hint: '
                               'you cannot squeeze the batch dim due to the dynet mechanism'
                               % (dim, str(head_shape)))
        if head_shape[dim] != 1:
            raise RuntimeError('You cannot squeeze dim %d for shape %s' % (dim, str(head_shape)))
        sq_shape = list(head_shape)
        sq_shape.pop(dim)
    return dy.reshape(x, tuple(sq_shape))
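

# Illustrative sketch (not part of the original module): how the shape helpers
# above compose on a toy (3, 4) input. Shapes in the comments are the expected
# results, assuming a fresh DyNet computation graph.
#
#   x = dy.inputTensor(np.zeros((3, 4)))   # shape (3, 4)
#   y = expand_dims(x, dim=0)              # shape (1, 3, 4)
#   z = squeeze(y, dim=0)                  # shape (3, 4) again
#   c = cat([x, x], dim=1)                 # shape (3, 8)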


def sum(x, dim=None, include_batch_dim=False):
    """Sum a list of expressions, or reduce one expression over `dim` (None = all dims)."""
    if isinstance(x, list):
        return dy.esum(x)
    head_shape, batch_size = x.dim()
    if dim is None:
        x = dy.sum_elems(x)
        if include_batch_dim and batch_size > 1:
            return dy.sum_batches(x)
        else:
            return x
    else:
        if dim == -1:
            dim = len(head_shape) - 1
        return dy.sum_dim(x, d=[dim], b=include_batch_dim)


def mean(x, dim=None, include_batch_dim=False):
    """Average a list of expressions, or reduce one expression over `dim` (None = all dims)."""
    if isinstance(x, list):
        return dy.average(x)
    head_shape, batch_size = x.dim()
    if dim is None:
        # warning: dynet only implements mean_elems for tensors with at most 2 dims
        x = dy.mean_elems(x)
        if include_batch_dim and batch_size > 1:
            return dy.mean_batches(x)
        else:
            return x
    else:
        if dim == -1:
            dim = len(head_shape) - 1
        return dy.mean_dim(x, d=[dim], b=include_batch_dim)
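

# Illustrative sketch (not part of the original module): reductions on a toy
# (3, 4) input of ones. `dim=None` reduces over all elements; `include_batch_dim`
# additionally folds the batch dimension when the expression is batched.
#
#   x = dy.inputTensor(np.ones((3, 4)))
#   sum(x)           # scalar expression with value 12.0
#   sum(x, dim=0)    # length-4 vector of column sums, each 3.0
#   mean(x, dim=1)   # length-3 vector of row means, each 1.0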


def split(x, dim=1):
    """Split a matrix into a list of its rows (dim=0) or columns (dim=1)."""
    head_shape, batch_size = x.dim()
    res = []
    if dim == 0:
        for i in range(head_shape[0]):
            res.append(dy.select_rows(x, [i]))
    elif dim == 1:
        for i in range(head_shape[1]):
            res.append(dy.select_cols(x, [i]))
    return res


def pick_mat(x, row_idx, col_idx):
    """Pick the single entry at (row_idx, col_idx) from a matrix expression."""
    return x[row_idx][col_idx]


def logsumexp_dim(x, dim=0):
    """Log-sum-exp reduction of `x` along dimension `dim`."""
    return dy.logsumexp_dim(x, d=dim)


# def logsumexp(x):
#     return dy.logsumexp(x)


def log_sum_exp(scores, n_tags):
    """Numerically stable log-sum-exp over a score vector of length `n_tags`:
    the max score is subtracted before exponentiating and added back afterwards."""
    npval = scores.npvalue()
    argmax_score = np.argmax(npval)
    max_score_expr = dy.pick(scores, argmax_score)
    max_score_expr_broadcast = dy.concatenate([max_score_expr] * n_tags)
    return max_score_expr + dy.log(
        dy.sum_cols(dy.transpose(dy.exp(scores - max_score_expr_broadcast))))


def dropout_list(rep_list, dp_rate):
    """Apply dropout with rate `dp_rate` to every expression in `rep_list`."""
    return [dy.dropout(rep, dp_rate) for rep in rep_list]


def dropout_dim_list(rep_list, dp_rate, dim=0):
    """Apply dropout along dimension `dim` to every expression in `rep_list`."""
    return [dy.dropout_dim(rep, dim, dp_rate) for rep in rep_list]


def cat_list(rep_list_a, rep_list_b, dim=0):
    """Concatenate corresponding expressions from two lists along `dim`."""
    return [dy.concatenate([rep_a, rep_b], d=dim) for rep_a, rep_b in zip(rep_list_a, rep_list_b)]


def add_list(rep_list_a, rep_list_b):
    """Element-wise sum of corresponding expressions from two lists."""
    return [rep_a + rep_b for rep_a, rep_b in zip(rep_list_a, rep_list_b)]


def sum_list(rep_list_a, rep_list_b):
    """Alias of `add_list`: element-wise sum of corresponding expressions from two lists."""
    return [rep_a + rep_b for rep_a, rep_b in zip(rep_list_a, rep_list_b)]


def binary_cross_entropy(x, y):
    """Numerically stable binary cross-entropy on logits `x` against targets `y`
    (the max trick avoids overflow in the exponentials)."""
    max_val = nn.relu(-x)
    loss = x - dy.cmult(x, y) + max_val + dy.log(dy.exp(-max_val) + dy.exp(-x - max_val))
    return nn.mean(loss)


def max_np(np_vec):
    """Return (max value, argmax index) of a flattened numpy vector."""
    np_vec = np_vec.flatten()
    return np.max(np_vec), np.argmax(np_vec)
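

# Minimal smoke-test sketch (an assumption, not part of the original module):
# exercises a few of the helpers above on toy inputs when the file is run
# directly. It only assumes a working DyNet installation.
if __name__ == '__main__':
    dy.renew_cg()
    a = dy.inputTensor(np.arange(12, dtype=np.float32).reshape(3, 4))
    b = dy.inputTensor(np.ones((3, 4)))
    print(cat([a, b], dim=0).dim())               # ((6, 4), 1)
    print(expand_dims(a, dim=0).dim())            # ((1, 3, 4), 1)
    print(squeeze(expand_dims(a), dim=-1).dim())  # ((3, 4), 1)
    print(sum(a).value())                         # 66.0
    print(max_np(a.npvalue()))                    # (11.0, 11)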