From: https://keras.io/guides/writing_a_custom_training_loop_in_torch/

In [1]:
import os

import torch
import keras
import numpy as np

In [2]:
os.environ["KERAS_BACKEND"]

'torch'

In [11]:
# Set default devcice
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Let's consider a simple MNIST model
def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

In [4]:
# Create load up the MNIST dataset and put it in a torch DataLoader
# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


In [5]:
# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

In [27]:
# Create torch Datasets
train_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_train).float().to(device),
    torch.from_numpy(y_train).float().to(device)
)
val_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_val).float().to(device),
    torch.from_numpy(y_val).float().to(device)
)

In [28]:
# Create DataLoaders for the Datasets
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False
)

In [29]:
# Instantiate a torch optimizer
model = get_model()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Instantiate a torch loss function
loss_fn = torch.nn.CrossEntropyLoss()

In [35]:
epochs = 3
for epoch in range(epochs):
    for step, (inputs, targets) in enumerate(train_dataloader):
        
        # Move inputs and targets to the same device as the model
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(logits, targets)

        # Backward pass
        model.zero_grad()
        loss.backward()

        # Optimizer variable updates
        optimizer.step()

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().cpu().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

Training loss (for 1 batch) at step 0: 0.0440
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.2010
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.1639
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.5320
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.4841
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.0258
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.0401
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.1123
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.2103
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.0310
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.1509
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.1254
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.3798
Seen so far: 38432 samples
Training loss (for 1 batch) at

In [37]:
# Keras version
model = get_model()
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply(gradients, trainable_weights)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().cpu().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")


Start of epoch 0
Training loss (for 1 batch) at step 0: 226.0276
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 3.1136
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 4.1769
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.9400
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.6995
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.5839
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.0641
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 1.4072
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.2439
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 1.5329
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.2775
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.2846
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.3568
Seen so far: 38432 samples
Training l