refactor: declare epoch_train as static method
AFAgarap committed May 28, 2020
1 parent ce3d970 commit 8cd86ad
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions cnn_svm/models/cnn.py
@@ -60,19 +60,19 @@ def call(self, features):
     def fit(self, data_loader, epochs):
         train_loss = []
         for epoch in range(epochs):
-            epoch_loss = epoch_train(self, data_loader)
+            epoch_loss = self.epoch_train(self, data_loader)
             train_loss.append(epoch_loss)
             print(f"epoch {epoch + 1}/{epochs} : mean loss = {train_loss[-1]:.6f}")
 
-
-def epoch_train(model, data_loader):
-    epoch_loss = 0
-    for batch_features, batch_labels in data_loader:
-        with tf.GradientTape() as tape:
-            outputs = model(batch_features)
-            train_loss = model.loss_fn(batch_labels, outputs)
-        gradients = tape.gradient(train_loss, model.trainable_variables)
-        model.optimizer.apply_gradients(zip(gradients, model.trainable_variables))
-        epoch_loss += train_loss
-    epoch_loss = tf.reduce_mean(epoch_loss)
-    return epoch_loss
+    @staticmethod
+    def epoch_train(model, data_loader):
+        epoch_loss = 0
+        for batch_features, batch_labels in data_loader:
+            with tf.GradientTape() as tape:
+                outputs = model(batch_features)
+                train_loss = model.loss_fn(batch_labels, outputs)
+            gradients = tape.gradient(train_loss, model.trainable_variables)
+            model.optimizer.apply_gradients(zip(gradients, model.trainable_variables))
+            epoch_loss += train_loss
+        epoch_loss = tf.reduce_mean(epoch_loss)
+        return epoch_loss
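The refactor moves the free function epoch_train into the model class as a @staticmethod, so it can be invoked through either the instance or the class; unlike an instance method, the model must still be passed explicitly as the first argument, which is why fit() calls self.epoch_train(self, data_loader). Note also that epoch_loss accumulates a sum of per-batch scalar losses, so the final tf.reduce_mean over that scalar is an identity and the printed "mean loss" is really the summed loss for the epoch.

A minimal usage sketch follows. The class name CNN, its constructor, and the toy dataset are assumptions for illustration; only fit(), epoch_train(), loss_fn, and optimizer appear in this diff.

    import tensorflow as tf

    from cnn_svm.models.cnn import CNN  # assumed class name; not shown in this diff

    # Toy stand-in data: 256 grayscale 28x28 images with 10 classes.
    features = tf.random.normal([256, 28, 28, 1])
    labels = tf.random.uniform([256], maxval=10, dtype=tf.int32)
    data_loader = tf.data.Dataset.from_tensor_slices((features, labels)).batch(32)

    model = CNN()  # assumed to set model.loss_fn and model.optimizer internally

    # fit() drives the epoch loop; each epoch delegates to the static method,
    # passing the model explicitly as its first argument.
    model.fit(data_loader, epochs=5)

    # Equivalent direct calls, since epoch_train is a @staticmethod:
    epoch_loss = model.epoch_train(model, data_loader)
    epoch_loss = CNN.epoch_train(model, data_loader)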
