forked from keras-team/keras
torch_custom_fit_test.py
import numpy as np
import torch

import keras


def test_custom_fit():
    class CustomModel(keras.Model):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.loss_tracker = keras.metrics.Mean(name="loss")
            self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
            self.loss_fn = keras.losses.MeanSquaredError()

        def train_step(self, data):
            x, y = data

            # Clear any leftover gradients, then run the forward pass.
            self.zero_grad()
            y_pred = self(x, training=True)
            loss = self.loss_fn(y, y_pred)

            # Backward pass via torch autograd.
            loss.backward()

            trainable_weights = [v for v in self.trainable_weights]
            gradients = [v.value.grad for v in trainable_weights]

            # Apply gradients with the Keras optimizer; no autograd tracking needed.
            with torch.no_grad():
                self.optimizer.apply(gradients, trainable_weights)

            # Update the metrics reported by fit().
            self.loss_tracker.update_state(loss)
            self.mae_metric.update_state(y, y_pred)
            return {
                "loss": self.loss_tracker.result(),
                "mae": self.mae_metric.result(),
            }

        @property
        def metrics(self):
            # Listing the metrics here lets Keras reset them automatically
            # at the start of each epoch.
            return [self.loss_tracker, self.mae_metric]

    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1)(inputs)
    model = CustomModel(inputs, outputs)
    # No loss is passed to compile(): train_step computes it explicitly.
    model.compile(optimizer="adam")

    x = np.random.random((64, 32))
    y = np.random.random((64, 1))
    history = model.fit(x, y, epochs=1)

    assert "loss" in history.history
    assert "mae" in history.history
    print("History:")
    print(history.history)


if __name__ == "__main__":
    test_custom_fit()
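
Note: because train_step calls loss.backward() and torch.no_grad() directly, this test only runs on Keras 3's torch backend. A minimal sketch of selecting that backend, assuming the standard KERAS_BACKEND environment variable (it must be set before keras is first imported):

import os

os.environ["KERAS_BACKEND"] = "torch"  # must be set before `import keras`

import keras  # keras ops now execute on the torch backend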