https://www.tensorflow.org/datasets/keras_example

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import ModelCheckpoint

# Get train and test data: flatten each 28x28 image into a 784-long row and
# divide by 255.0 (presumably mapping 0-255 pixel values into [0, 1]).
(train_X,train_Y),(test_X,test_Y) = tf.keras.datasets.mnist.load_data()
train_X = train_X.reshape(-1, 28 * 28) / 255.0
test_X  = test_X.reshape(-1, 28 * 28) / 255.0

# Form model: one hidden ReLU layer followed by a 10-unit output layer with
# no activation, so the model emits raw logits rather than probabilities.
model = tf.keras.models.Sequential([
  tf.keras.layers.Dense(128,input_shape=(784,),activation='relu'),
  tf.keras.layers.Dense(10)
])

# Compile model. from_logits=True matches the activation-free output layer.
model.compile(
    optimizer=tf.keras.optimizers.SGD(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Callback to checkpoint models: save_best_only=False keeps a weights-only
# file for every epoch. The "{epoch}" placeholder in the path is what the
# later cell relies on when it loads "model/mnist/model{epoch_iter}.hdf5".
checkpoint_callback = ModelCheckpoint(
    filepath="model/mnist/model{epoch}.hdf5",
    save_weights_only=True,
    save_best_only=False)

# `callbacks` is documented as a list of Callback instances, so the single
# checkpoint callback is wrapped in a list instead of being passed bare.
model.fit(
    train_X,
    train_Y,
    epochs=20,
    callbacks=[checkpoint_callback],
)

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import ModelCheckpoint

# Fetch MNIST again and turn every image into a flat 784-element row,
# dividing by 255.0 exactly as was done before training.
(train_X, train_Y), (test_X, test_Y) = keras.datasets.mnist.load_data()
train_X = train_X.reshape(-1, 28 * 28) / 255.0
test_X = test_X.reshape(-1, 28 * 28) / 255.0

# Recreate the untrained network with the same layer layout as the trained
# one (128-unit ReLU hidden layer, 10-unit linear output) so that saved
# weights can be loaded into it afterwards.
model = Sequential(
    [
        layers.Dense(128, input_shape=(784,), activation='relu'),
        layers.Dense(10),
    ]
)

In [3]:
# Loss whose gradient is wanted; from_logits=True because the model's last
# layer has no activation.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Restore the weights that were saved for the chosen epoch.
epoch_iter = 1
model.load_weights(f"model/mnist/model{epoch_iter}.hdf5")

# A batch holding only the first training example and its label.
images, labels = train_X[0:1], train_Y[0:1]

# Record the forward pass so it can be differentiated afterwards.
with tf.GradientTape() as tape:
    logits   = model(images)
    loss_val = loss_fn(labels, logits)

# One gradient tensor per trainable variable, in the same order.
grads = tape.gradient(
    target=loss_val,
    sources=model.trainable_variables,
)

In [5]:
# Display the shape of each gradient tensor (one per trainable variable).
[g.shape for g in grads]

[TensorShape([784, 128]),
 TensorShape([128]),
 TensorShape([128, 10]),
 TensorShape([10])]