-
Notifications
You must be signed in to change notification settings - Fork 38
/
cyclegan.py
315 lines (256 loc) · 12.7 KB
/
cyclegan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
# -*- coding: utf-8 -*-
import logging
from functools import partial
import keras
import numpy as np
import pandas as pd
from keras import backend as K
from keras.layers import Input
from keras.layers.merge import _Merge
from keras.models import Model
from scipy import integrate, stats
from mlprimitives.adapters.keras import build_layer
from mlprimitives.utils import import_object
LOGGER = logging.getLogger(__name__)
class RandomWeightedAverage(_Merge):
    """Merge layer returning a random per-sample interpolation of two tensors.

    For every sample in the batch a weight ``alpha`` is drawn uniformly from [0, 1) and the
    layer returns ``alpha * inputs[0] + (1 - alpha) * inputs[1]``. In this module it builds
    the interpolated samples on which the gradient penalty is evaluated.
    """

    def _merge_function(self, inputs):
        # The batch dimension is read from the input tensor instead of being hard-coded
        # (previously 64), so the layer works for any ``batch_size``. The remaining axes
        # are 1 so that ``alpha`` broadcasts over every non-batch dimension.
        batch_size = K.shape(inputs[0])[0]
        broadcast_dims = (1,) * (K.ndim(inputs[0]) - 1)
        alpha = K.random_uniform((batch_size,) + broadcast_dims)
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])
class CycleGAN():
    """CycleGAN class.

    An encoder maps input sequences ``x`` to latent vectors ``z`` and a generator maps
    latent vectors back to sequences. Two critics, ``critic_x`` (input space) and
    ``critic_z`` (latent space), are trained with a Wasserstein loss plus a gradient
    penalty. The encoder/generator pair is trained against both critics together with a
    mean-squared-error reconstruction loss.
    """

    def _build_model(self, hyperparameters, layers, input_shape):
        """Build a functional ``Model`` wrapping a ``Sequential`` stack of layers.

        Args:
            hyperparameters (dict):
                Values used by ``build_layer`` to instantiate each layer.
            layers (list):
                Layer specifications, converted one by one with ``build_layer``.
            input_shape (tuple):
                Shape of the model input (without the batch dimension).

        Returns:
            keras.models.Model:
                Model mapping an ``Input`` of ``input_shape`` through the stacked layers.
        """
        x = Input(shape=input_shape)
        model = keras.models.Sequential()
        for layer in layers:
            built_layer = build_layer(layer, hyperparameters)
            model.add(built_layer)

        return Model(x, model(x))

    def _wasserstein_loss(self, y_true, y_pred):
        """Wasserstein loss: the mean of ``y_true * y_pred``.

        ``_fit`` uses ``-1`` targets for outputs whose score should be pushed up and ``1``
        targets for outputs whose score should be pushed down.
        """
        return K.mean(y_true * y_pred)

    def _gradient_penalty_loss(self, y_true, y_pred, averaged_samples):
        """Gradient penalty: the mean of ``(1 - ||d y_pred / d averaged_samples||_2) ** 2``.

        ``y_true`` is unused. It is only present because Keras calls every loss as
        ``loss(y_true, y_pred)``. The L2 norm is taken over all non-batch axes.
        """
        gradients = K.gradients(y_pred, averaged_samples)[0]
        gradients_sqr = K.square(gradients)
        gradients_sqr_sum = K.sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape)))
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        gradient_penalty = K.square(1 - gradient_l2_norm)
        return K.mean(gradient_penalty)

    def __init__(self, shape, encoder_input_shape, generator_input_shape, critic_x_input_shape,
                 critic_z_input_shape, layers_encoder, layers_generator, layers_critic_x,
                 layers_critic_z, optimizer, learning_rate=0.0005, epochs=2000, latent_dim=20,
                 batch_size=64, iterations_critic=5, **hyperparameters):
        """Initialize the CycleGAN object.

        Args:
            shape (tuple):
                Tuple denoting the shape of an input sample.
            encoder_input_shape (tuple):
                Shape of encoder input.
            generator_input_shape (tuple):
                Shape of generator input.
            critic_x_input_shape (tuple):
                Shape of critic_x input.
            critic_z_input_shape (tuple):
                Shape of critic_z input.
            layers_encoder (list):
                List containing layers of encoder.
            layers_generator (list):
                List containing layers of generator.
            layers_critic_x (list):
                List containing layers of critic_x.
            layers_critic_z (list):
                List containing layers of critic_z.
            optimizer (str):
                String denoting the keras optimizer. It is resolved with ``import_object``
                and instantiated with ``learning_rate`` as its first argument.
            learning_rate (float):
                Optional. Float denoting the learning rate of the optimizer. Default 0.0005.
            epochs (int):
                Optional. Integer denoting the number of epochs. Default 2000.
            latent_dim (int):
                Optional. Integer denoting dimension of latent space. Default 20.
            batch_size (int):
                Optional. Integer denoting the batch size. Default 64.
            iterations_critic (int):
                Optional. Integer denoting the number of critic training steps per one
                Generator/Encoder training step. Must be at least 1. Default 5.
            hyperparameters (dictionary):
                Optional. Dictionary containing any additional inputs.
        """
        self.shape = shape
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        self.iterations_critic = iterations_critic
        self.epochs = epochs
        self.hyperparameters = hyperparameters

        self.encoder_input_shape = encoder_input_shape
        self.generator_input_shape = generator_input_shape
        self.critic_x_input_shape = critic_x_input_shape
        self.critic_z_input_shape = critic_z_input_shape

        self.layers_encoder, self.layers_generator = layers_encoder, layers_generator
        self.layers_critic_x, self.layers_critic_z = layers_critic_x, layers_critic_z

        # NOTE(review): ``optimizer`` is presumably a fully qualified class path; confirm.
        self.optimizer = import_object(optimizer)(learning_rate)

    def _build_cyclegan(self, **kwargs):
        """Build and compile the critic models and the encoder/generator model.

        Args:
            **kwargs:
                Hyperparameters overriding those given to the constructor.
        """
        hyperparameters = self.hyperparameters.copy()
        hyperparameters.update(kwargs)

        self.encoder = self._build_model(hyperparameters, self.layers_encoder,
                                         self.encoder_input_shape)
        self.generator = self._build_model(hyperparameters, self.layers_generator,
                                           self.generator_input_shape)
        self.critic_x = self._build_model(hyperparameters, self.layers_critic_x,
                                          self.critic_x_input_shape)
        self.critic_z = self._build_model(hyperparameters, self.layers_critic_z,
                                          self.critic_z_input_shape)

        # Freeze encoder and generator while the critic models are compiled, so that the
        # critic training steps only update the critics' weights.
        self.generator.trainable = False
        self.encoder.trainable = False

        z = Input(shape=(self.latent_dim, 1))
        x = Input(shape=self.shape)
        x_ = self.generator(z)
        z_ = self.encoder(x)

        # critic_x model: scores real x, generated x and interpolations between them.
        fake_x = self.critic_x(x_)
        valid_x = self.critic_x(x)
        interpolated_x = RandomWeightedAverage()([x, x_])
        validity_interpolated_x = self.critic_x(interpolated_x)
        partial_gp_loss_x = partial(self._gradient_penalty_loss, averaged_samples=interpolated_x)
        partial_gp_loss_x.__name__ = 'gradient_penalty'
        self.critic_x_model = Model(inputs=[x, z], outputs=[valid_x, fake_x,
                                                            validity_interpolated_x])
        # NOTE(review): the gradient-penalty weight here (5) differs from the one used for
        # critic_z_model (10); confirm the asymmetry is intentional.
        self.critic_x_model.compile(loss=[self._wasserstein_loss, self._wasserstein_loss,
                                          partial_gp_loss_x], optimizer=self.optimizer,
                                    loss_weights=[1, 1, 5])

        # critic_z model: scores prior samples z, encoded z and interpolations between them.
        fake_z = self.critic_z(z_)
        valid_z = self.critic_z(z)
        interpolated_z = RandomWeightedAverage()([z, z_])
        validity_interpolated_z = self.critic_z(interpolated_z)
        partial_gp_loss_z = partial(self._gradient_penalty_loss, averaged_samples=interpolated_z)
        partial_gp_loss_z.__name__ = 'gradient_penalty'
        self.critic_z_model = Model(inputs=[x, z], outputs=[valid_z, fake_z,
                                                            validity_interpolated_z])
        self.critic_z_model.compile(loss=[self._wasserstein_loss, self._wasserstein_loss,
                                          partial_gp_loss_z], optimizer=self.optimizer,
                                    loss_weights=[1, 1, 10])

        # Now freeze the critics and train encoder and generator against them.
        self.critic_x.trainable = False
        self.critic_z.trainable = False
        self.generator.trainable = True
        self.encoder.trainable = True

        z_gen = Input(shape=(self.latent_dim, 1))
        x_gen_ = self.generator(z_gen)
        x_gen = Input(shape=self.shape)
        z_gen_ = self.encoder(x_gen)
        x_gen_rec = self.generator(z_gen_)
        fake_gen_x = self.critic_x(x_gen_)
        fake_gen_z = self.critic_z(z_gen_)

        # Outputs: critic_x(G(z)), critic_z(E(x)) and the reconstruction G(E(x)).
        self.encoder_generator_model = Model([x_gen, z_gen], [fake_gen_x, fake_gen_z, x_gen_rec])
        self.encoder_generator_model.compile(loss=[self._wasserstein_loss, self._wasserstein_loss,
                                                   'mse'], optimizer=self.optimizer,
                                             loss_weights=[1, 1, 50])

    def _fit(self, X):
        """Run the adversarial training loop for ``self.epochs`` epochs.

        Each epoch performs ``self.iterations_critic`` critic updates on batches sampled
        with replacement from ``X``, then one encoder/generator update on the last sampled
        batch.

        Raises:
            ValueError:
                If ``self.iterations_critic`` is smaller than 1, since there would be no
                batch for the encoder/generator update.
        """
        if self.iterations_critic < 1:
            raise ValueError('iterations_critic must be at least 1')

        fake = np.ones((self.batch_size, 1))
        valid = -np.ones((self.batch_size, 1))
        # Dummy targets for the gradient-penalty outputs; that loss ignores ``y_true``.
        delta = np.ones((self.batch_size, 1)) * 10

        for epoch in range(self.epochs):
            for _ in range(self.iterations_critic):
                idx = np.random.randint(0, X.shape[0], self.batch_size)
                x = X[idx]
                z = np.random.normal(size=(self.batch_size, self.latent_dim, 1))

                cx_loss = self.critic_x_model.train_on_batch([x, z], [valid, fake, delta])
                cz_loss = self.critic_z_model.train_on_batch([x, z], [valid, fake, delta])

            # The encoder/generator is trained to make both critics score its outputs as
            # valid, while reconstructing ``x``.
            g_loss = self.encoder_generator_model.train_on_batch([x, z], [valid, valid, x])

            if epoch % 100 == 0:
                LOGGER.info('Epoch: %s, [Dx loss: %s] [Dz loss: %s] [G loss: %s]',
                            epoch, cx_loss, cz_loss, g_loss)

    def fit(self, X, **kwargs):
        """Fit the CycleGAN.

        Args:
            X (ndarray):
                N-dimensional array containing the input training sequences for the model.
                It is reshaped to ``(-1, shape[0], 1)`` before training.
            **kwargs:
                Hyperparameters overriding those given to the constructor when the
                models are built.
        """
        self._build_cyclegan(**kwargs)
        X = X.reshape((-1, self.shape[0], 1))
        self._fit(X)

    def predict(self, X):
        """Predict values using the initialized object.

        Args:
            X (ndarray):
                N-dimensional array containing the input sequences for the model.
                It is reshaped to ``(-1, shape[0], 1)`` first.

        Returns:
            ndarray:
                N-dimensional array containing the reconstructions for each input sequence.
            ndarray:
                N-dimensional array containing the critic scores for each input sequence.
        """
        X = X.reshape((-1, self.shape[0], 1))
        z_ = self.encoder.predict(X)
        y_hat = self.generator.predict(z_)
        critic = self.critic_x.predict(X)

        return y_hat, critic
def score_anomalies(y, y_hat, critic, score_window=10, smooth_window=200,
                    critic_smooth_window=100):
    """Compute an array of anomaly scores.

    Anomaly scores are calculated using a combination of reconstruction error and critic score.

    Row ``i`` of ``y_hat`` (and the critic score of row ``i``) is treated as covering
    timestamps ``i`` to ``i + y_hat.shape[1] - 1``, so a timestamp can receive several
    predictions and critic scores. Each timestamp gets two collapsed values:

    - the median of its predictions;
    - the critic score at the peak of a Gaussian KDE, falling back to the median when
      there is a single score or the KDE cannot be built.

    Args:
        y (ndarray):
            Ground truth.
        y_hat (ndarray):
            Predicted values. Each timestamp has multiple predictions.
        critic (ndarray):
            Critic score. Each timestamp has multiple critic scores.
        score_window (int):
            Optional. Size of the window over which the scores are calculated.
            If not given, 10 is used.
        smooth_window (int):
            Optional. Size of window over which smoothing is applied.
            If not given, 200 is used.
        critic_smooth_window (int):
            Optional. Size of the window over which the critic z-scores are smoothed.
            If not given, 100 is used.

    Returns:
        ndarray:
            Array of anomaly scores, one per timestamp.
    """
    y = np.asarray(y)
    # Convert before reading ``.shape`` so that list inputs are accepted.
    y_hat = np.asarray(y_hat)
    pred_length = y_hat.shape[1]
    num_errors = y_hat.shape[0] + pred_length - 1

    # Flatten the ground-truth windows back into one series: the first value of every
    # window followed by the remaining values of the last window.
    true = [item[0] for item in y.reshape((y.shape[0], -1))]
    for item in y[-1][1:]:
        true.extend(item)

    # One critic score per window, repeated for every position the window covers. A single
    # ``np.repeat`` replaces the former quadratic list concatenation.
    critic_extended = np.repeat(np.asarray(critic, dtype=float), pred_length)
    critic_extended = critic_extended.reshape((-1, pred_length))

    predictions = []
    critic_kde_max = []
    for i in range(num_errors):
        # Row ``i - j`` covers timestamp ``i`` at position ``j``.
        positions = range(max(0, i - num_errors + pred_length), min(i + 1, pred_length))
        if not positions:
            continue

        intermediate = np.asarray([y_hat[i - j, j] for j in positions])
        critic_intermediate = np.asarray([critic_extended[i - j, j] for j in positions])

        predictions.append(np.median(intermediate))

        if len(critic_intermediate) > 1:
            try:
                density = stats.gaussian_kde(critic_intermediate)(critic_intermediate)
                critic_kde_max.append(critic_intermediate[np.argmax(density)])
            except np.linalg.LinAlgError:
                # gaussian_kde raises this when the data covariance is singular,
                # e.g. when all the scores are identical.
                critic_kde_max.append(np.median(critic_intermediate))
        else:
            critic_kde_max.append(np.median(critic_intermediate))

    predictions = np.asarray(predictions)

    # Reconstruction score: absolute difference between the areas under the true and the
    # predicted curves over a centered sliding window. ``integrate.trapz`` was superseded
    # by ``integrate.trapezoid`` and is missing from recent SciPy releases.
    trapezoid = getattr(integrate, 'trapezoid', None) or integrate.trapz
    score_window_min = score_window // 2
    pd_true = pd.Series(np.asarray(true).flatten())
    pd_pred = pd.Series(predictions.flatten())
    score_measure_true = pd_true.rolling(
        score_window, center=True, min_periods=score_window_min).apply(trapezoid, raw=True)
    score_measure_pred = pd_pred.rolling(
        score_window, center=True, min_periods=score_window_min).apply(trapezoid, raw=True)
    scores = abs(score_measure_true - score_measure_pred)

    scores_smoothed = pd.Series(scores).rolling(
        smooth_window, center=True, min_periods=smooth_window // 2,
        win_type='triang').mean().values
    z_score_scores = stats.zscore(scores_smoothed)
    z_score_scores_clip = np.clip(z_score_scores, a_min=0, a_max=None) + 1

    # Critic score: absolute distance from the mean of the inter-quartile critic values,
    # divided by the standard deviation of all critic values, then smoothed.
    critic_kde_max = np.asarray(critic_kde_max)
    l_quantile = np.quantile(critic_kde_max, 0.25)
    u_quantile = np.quantile(critic_kde_max, 0.75)
    in_range = np.logical_and(critic_kde_max >= l_quantile, critic_kde_max <= u_quantile)
    critic_mean = np.mean(critic_kde_max[in_range])
    critic_std = np.std(critic_kde_max)

    z_score_critic = np.absolute((critic_kde_max - critic_mean) / critic_std) + 1
    z_score_critic = pd.Series(z_score_critic).rolling(
        critic_smooth_window, center=True,
        min_periods=critic_smooth_window // 2).mean().values

    return np.multiply(z_score_scores_clip, z_score_critic)