In [None]:
import numpy as np
import tensorflow as tf

In [2]:
# Load the Fashion-MNIST train/test splits bundled with Keras.
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Divide the raw pixel values by 255 and store them as float32, the dtype
# declared for images in the model's tf.function input signatures.
train_images = (train_images / 255.0).astype(np.float32)
test_images = (test_images / 255.0).astype(np.float32)

# One-hot encode the labels. The cast is explicit because `Model.train`
# declares its labels as float32 and the dtype returned by `to_categorical`
# is not the same in every Keras version; a float64 label batch would be
# rejected by the input signature.
train_labels = tf.keras.utils.to_categorical(train_labels).astype(np.float32)
test_labels = tf.keras.utils.to_categorical(test_labels).astype(np.float32)

In [5]:
IMG_SIZE = 28

class Model(tf.Module):
  """Small dense classifier exposing train/infer/save/restore signatures.

  The wrapped Keras model flattens an IMG_SIZE x IMG_SIZE image, applies a
  128-unit ReLU layer and a 10-unit output layer that emits raw logits.
  Each public method is a `tf.function` with a fixed input signature so it
  can be exported as a named signature.
  """

  def __init__(self):
    # Initialise tf.Module's own state (module name / name scope) before
    # attaching attributes to the instance.
    super().__init__()

    self.model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(IMG_SIZE, IMG_SIZE), name='flatten'),
        tf.keras.layers.Dense(128, activation='relu', name='dense_1'),
        tf.keras.layers.Dense(10, name='dense_2')
    ])

    # `from_logits=True` because the last Dense layer has no activation;
    # softmax is only applied in `infer`.
    self.model.compile(
        optimizer='adam',
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True))

  # The `train` function takes a batch of input images and labels.
  @tf.function(input_signature=[
      tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
      tf.TensorSpec([None, 10], tf.float32),
  ])
  def train(self, x, y):
    """Runs one optimizer step on a batch.

    Args:
      x: float32 images of shape [batch, IMG_SIZE, IMG_SIZE].
      y: float32 label rows of shape [batch, 10].

    Returns:
      A dict with the batch loss under the key "loss".
    """
    with tf.GradientTape() as tape:
      prediction = self.model(x)
      loss = self.model.loss(y, prediction)
    gradients = tape.gradient(loss, self.model.trainable_variables)
    self.model.optimizer.apply_gradients(
        zip(gradients, self.model.trainable_variables))
    result = {"loss": loss}
    return result

  @tf.function(input_signature=[
      tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
  ])
  def infer(self, x):
    """Runs a forward pass.

    Args:
      x: float32 images of shape [batch, IMG_SIZE, IMG_SIZE].

    Returns:
      A dict with the raw model outputs under "logits" and their softmax
      over the last axis under "output".
    """
    logits = self.model(x)
    probabilities = tf.nn.softmax(logits, axis=-1)
    return {
        "output": probabilities,
        "logits": logits
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def save(self, checkpoint_path):
    """Writes all model weights to `checkpoint_path` with `tf.raw_ops.Save`.

    Tensors are stored under the corresponding weight's `name`.

    Args:
      checkpoint_path: scalar string tensor naming the checkpoint file.

    Returns:
      A dict echoing the path under the key "checkpoint_path".
    """
    tensor_names = [weight.name for weight in self.model.weights]
    tensors_to_save = [weight.read_value() for weight in self.model.weights]
    tf.raw_ops.Save(
        filename=checkpoint_path, tensor_names=tensor_names,
        data=tensors_to_save, name='save')
    return {
        "checkpoint_path": checkpoint_path
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def restore(self, checkpoint_path):
    """Reads every model weight from `checkpoint_path` and assigns it.

    Each weight is looked up by its `name` and dtype with
    `tf.raw_ops.Restore`.

    Args:
      checkpoint_path: scalar string tensor naming the checkpoint file.

    Returns:
      A dict mapping each weight name to the tensor restored for it.
    """
    restored_tensors = {}
    for var in self.model.weights:
      restored = tf.raw_ops.Restore(
          file_pattern=checkpoint_path, tensor_name=var.name, dt=var.dtype,
          name='restore')
      var.assign(restored)
      restored_tensors[var.name] = restored
    return restored_tensors



In [7]:
def evaluate_interpreter(infer_fn, images=None, labels=None):
    """Computes classification accuracy by calling `infer_fn` one sample at a time.

    Args:
        infer_fn: callable invoked as `infer_fn(x=batch)`; it must return a
            mapping whose "output" entry holds the class scores for the
            single sample in `batch`.
        images: array of samples with the sample axis first. Defaults to the
            notebook-level `test_images`.
        labels: one-hot label rows aligned with `images`. Defaults to the
            notebook-level `test_labels`.

    Returns:
        Fraction of samples whose arg-max score equals the arg-max label,
        as a Python float.

    Raises:
        ValueError: if there are no samples, or `images` and `labels` have
            different lengths.
    """
    if images is None:
        images = test_images
    if labels is None:
        labels = test_labels
    images = np.asarray(images)
    labels = np.asarray(labels)

    num_samples = len(images)
    if num_samples == 0:
        raise ValueError("Cannot compute accuracy on an empty evaluation set.")
    if len(labels) != num_samples:
        raise ValueError(
            f"images and labels differ in length: {num_samples} vs {len(labels)}")

    true_labels = np.argmax(labels, axis=1)
    correct_predictions = 0
    for i in range(num_samples):
        # Slicing keeps a leading batch axis of size 1 for any image size,
        # instead of reshaping to a hard-coded (1, 28, 28).
        input_data = images[i:i + 1]
        output_data = infer_fn(x=input_data)
        predicted_label = np.argmax(output_data["output"])
        if predicted_label == true_labels[i]:
            correct_predictions += 1
    return correct_predictions / num_samples

In [9]:
def evaluate_model(model, images=None, labels=None):
    """Computes classification accuracy with a single batched `model.infer` call.

    Args:
        model: object whose `infer(images)` returns a mapping with an
            "output" entry exposing `.numpy()` (one row of class scores per
            sample).
        images: samples passed to `model.infer`. Defaults to the
            notebook-level `test_images`.
        labels: one-hot label rows aligned with `images`. Defaults to the
            notebook-level `test_labels`.

    Returns:
        Mean of the element-wise match between predicted and true class
        indices, as returned by `np.mean`.

    Raises:
        ValueError: if the number of predictions differs from the number of
            labels.
    """
    if images is None:
        images = test_images
    if labels is None:
        labels = test_labels

    result = model.infer(images)
    predictions = result["output"].numpy()
    predicted_labels = np.argmax(predictions, axis=1)
    true_labels = np.argmax(labels, axis=1)
    # Guard against a silent mis-score: comparing arrays of different shapes
    # does not yield a per-sample match vector.
    if predicted_labels.shape != true_labels.shape:
        raise ValueError(
            f"predictions and labels differ in shape: "
            f"{predicted_labels.shape} vs {true_labels.shape}")
    accuracy = np.mean(predicted_labels == true_labels)
    return accuracy

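There are two ways to fully restore a TFLite model:
1. Load the `.tflite` file into an interpreter (this works when the `.tflite` file was converted from a Keras model).
2. Load the `.tflite` file into an interpreter and then restore a checkpoint into it (the checkpoint cannot be one that was written by an interpreter).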

In [8]:
# Load the previously exported TFLite model into an interpreter and measure
# its accuracy through the exported "infer" signature.
interpreter = tf.lite.Interpreter(
    model_path="models/orig_model_epochs_30_batch_100.tflite"
)
interpreter.allocate_tensors()
initial_accuracy = evaluate_interpreter(
    interpreter.get_signature_runner("infer")
)
initial_accuracy

0.8792

In [None]:
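# How to save a checkpoint from the model and restore it into the interpreter (optional).
# np.bytes_ is used because the np.string_ alias was removed in NumPy 2.0.
# model.save('tmp/model.ckpt')
# restore = interpreter.get_signature_runner("restore")
# restore(checkpoint_path=np.array("tmp/model.ckpt", dtype=np.bytes_))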

In [11]:
# Create a new model with freshly initialised (untrained) weights and print
# its test accuracy as a baseline before any weights are transferred.
model = Model()
print(evaluate_model(model))

0.1513


In [None]:
# Copy the variable tensors held by the loaded interpreter into the Keras model.
tensor_details = interpreter.get_tensor_details()
# Take each tensor's own "index" field instead of its position in the list:
# `get_tensor` expects a tensor index, and the list position is not
# guaranteed to equal it.
trainable_indices = [d["index"] for d in tensor_details if "variable" in d["name"].lower()]
weights = [interpreter.get_tensor(i) for i in trainable_indices]

# `set_weights` on a Dense layer takes [kernel, bias]. For this model file the
# interpreter lists each bias before its kernel, hence the swapped pairs.
# A shape mismatch makes `set_weights` raise.
model.model.layers[1].set_weights([weights[1], weights[0]])
model.model.layers[2].set_weights([weights[3], weights[2]])
print(evaluate_model(model))

0.8772


In [13]:
# Compare the shapes of the tensors read from the interpreter with the
# shapes the two Dense layers hold, to confirm the pairing used when
# copying weights.
for idx, w in enumerate(weights):
    print(f"TFLite weight {idx} shape: {w.shape}")

# layers[1:3] are the two Dense layers that follow the Flatten layer.
for layer in model.model.layers[1:3]:
    layer_w = layer.get_weights()
    print(f"Layer {layer.name} weights shapes: {[w.shape for w in layer_w]}")

TFLite weight 0 shape: (128,)
TFLite weight 1 shape: (784, 128)
TFLite weight 2 shape: (10,)
TFLite weight 3 shape: (128, 10)
Layer dense_1 weights shapes: [(784, 128), (128,)]
Layer dense_2 weights shapes: [(128, 10), (10,)]


In [14]:
# Fine-tuning configuration.
NUM_EPOCHS = 50
BATCH_SIZE = 100
epochs = np.arange(1, NUM_EPOCHS + 1)
losses = np.zeros(NUM_EPOCHS)

# Batched view over the training set, in its stored order.
train_ds = tf.data.Dataset.from_tensor_slices(
    (train_images, train_labels)
).batch(BATCH_SIZE)

for i in range(NUM_EPOCHS):
  for x, y in train_ds:
    result = model.train(x, y)

  # The recorded value is the loss of the last batch of the epoch.
  losses[i] = result['loss']
  if (i + 1) % 10 != 0:
    continue
  # Progress report every tenth epoch.
  print(f"Finished {i+1} epochs")
  print(f"  loss: {losses[i]:.3f}")

# Test accuracy after fine-tuning.
print(evaluate_model(model))

Finished 10 epochs
  loss: 0.127
Finished 20 epochs
  loss: 0.075
Finished 30 epochs
  loss: 0.089
Finished 40 epochs
  loss: 0.076
Finished 50 epochs
  loss: 0.054
0.8752


In [15]:
SAVED_MODEL_DIR = "saved_model"

# Export the module as a SavedModel with one named signature per tf.function.
tf.saved_model.save(
    model,
    SAVED_MODEL_DIR,
    signatures={
        signature_name: getattr(model, signature_name).get_concrete_function()
        for signature_name in ('train', 'infer', 'save', 'restore')
    })

# Convert the SavedModel to a TFLite flatbuffer.
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.experimental_enable_resource_variables = True
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # TensorFlow Lite built-in ops
    tf.lite.OpsSet.SELECT_TF_OPS,  # selected TensorFlow ops
]
tflite_model = converter.convert()

INFO:tensorflow:Assets written to: saved_model\assets




In [16]:
# Write the converted model bytes to disk; the file name records the
# fine-tuning configuration.
tflite_model_path = f"models/finetuned_model_epochs_{NUM_EPOCHS}_batch_{BATCH_SIZE}.tflite"
with open(tflite_model_path, mode="wb") as tflite_file:
    tflite_file.write(tflite_model)
print(f"✅ Fine-tuned TFLite model saved as: {tflite_model_path}")

✅ Fine-tuned TFLite model saved as: models/finetuned_model_epochs_50_batch_100.tflite
