# Phase-1 code starter template
### The below code is for your reference; please feel free to change it partially or fully.
### Please make sure it does not have any bugs or mistakes. Code authors DO NOT claim the code is bug-free. It is the student's responsibility to ensure its correctness.
## In all cases you must use a base model which consist of:
- 1 convulosion layer with 16 channels, 3x3 kernel, and a relu activation.
- Fully connected layer with 2 neurons and a relu activation.
- Fully connected layer with num_classes neurons and a softmax activation.

In [1]:
# --- Imports ---
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import fashion_mnist, cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping
import time
import numpy as np
import os
from tqdm import tqdm

print(f'your tensorflow version is {tf.__version__}. It is advised to use tensorflow 2.15.0 to avoid any errors.')
assert tf.__version__=='2.15.0', 'WARNING!!! different TensorFlow version may produce an error while quantizing.\nTo proceed, comment this line.'


# --- Device Detection ---
gpus = tf.config.list_physical_devices('GPU')
device = '/GPU:0' if gpus else '/CPU:0'
dev_name = 'GPU' if gpus else 'CPU'
print(f"Using device: {dev_name}")

# --- Functions ---

def create_base_model(input_shape, num_classes):
    model = models.Sequential([
        layers.Conv2D(16, (3, 3), activation='relu', input_shape=input_shape),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(2, activation='relu'),
        layers.Dense(num_classes, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

def prepare_dataset(dataset_name):
    if dataset_name == 'fashion_mnist':
        (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
        x_train = x_train[..., np.newaxis]
        x_test = x_test[..., np.newaxis]
        input_shape = (28, 28, 1)
    elif dataset_name == 'cifar10':
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
        input_shape = (32, 32, 3)
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0
    y_train = to_categorical(y_train, 10)
    y_test = to_categorical(y_test, 10)

    return (x_train, y_train), (x_test, y_test), input_shape, 10

def evaluate_model(model, x_test, y_test):
    start = time.time()
    loss, acc = model.evaluate(x_test, y_test, verbose=0)
    elapsed = time.time() - start
    return acc, elapsed

def profile_workload(model, image, iterations=30):
    print(f"Profiling {dev_name}...")
    latencies = []

    for _ in tqdm(range(10), desc="Warm-up"):
        _ = model(image, training=False)

    for _ in tqdm(range(iterations), desc="Profiling"):
        start = time.time()
        _ = model(image, training=False)
        latencies.append((time.time() - start) * 1000)

    return np.mean(latencies)

your tensorflow version is 2.15.0. It is advised to use tensorflow 2.15.0 to avoid any errors.
Using device: GPU


# Phase-3 code starter template
### The below code is for your reference; please feel free to change it partially or fully.
### Please make sure it does not have any bugs or mistakes. Code authors DO NOT claim the code is bug-free. It is the student's responsibility to ensure its correctness.

In [2]:
def profile_tflite_model(interpreter, input_tensor, iterations=30):
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    latencies = []

    for _ in range(10):
        interpreter.set_tensor(input_details[0]['index'], input_tensor)
        interpreter.invoke()

    for _ in range(iterations):
        start = time.time()
        interpreter.set_tensor(input_details[0]['index'], input_tensor)
        interpreter.invoke()
        latencies.append((time.time() - start) * 1000)

    return np.mean(latencies)

def quantize_model_to_int8(model, rep_data_gen, save_path):
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = rep_data_gen
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    tflite_model = converter.convert()
    with open(save_path, 'wb') as f:
        f.write(tflite_model)
    print(f"Saved INT8 model at {save_path}")

def quantize_model_to_fp16(model, save_path):
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    tflite_model = converter.convert()
    with open(save_path, 'wb') as f:
        f.write(tflite_model)
    print(f"Saved FP16 model at {save_path}")

def evaluate_tflite_accuracy(interpreter, x_test, y_test, quantized=False):
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    correct = 0
    total = x_test.shape[0]

    for i in range(total):
        input_data = x_test[i:i+1]
        if quantized:
            input_data = np.round(input_data * 255).astype(np.int8)
        else:
            input_data = input_data.astype(np.float32)
        interpreter.set_tensor(input_details[0]['index'], input_data)
        interpreter.invoke()
        output = interpreter.get_tensor(output_details[0]['index'])
        if np.argmax(output) == np.argmax(y_test[i]):
            correct += 1

    return correct / total

def get_file_size(file_path):
    return os.path.getsize(file_path) / 1024  # in KB

# --- Main Loop ---

datasets = ['fashion_mnist', 'cifar10']
EPOCHS = 10
USE_PRETRAINED_MODELS = False # use the model you already trained in previous runs if set to True

for dataset in datasets:
    print(f"\nProcessing {dataset}...")

    base_path = f"{dataset}_base_model.h5"
    int8_base_path = f"{dataset}_base_int8.tflite"
    fp16_base_path = f"{dataset}_base_fp16.tflite"

    (x_train, y_train), (x_test, y_test), input_shape, num_classes = prepare_dataset(dataset)

    if os.path.exists(base_path) and USE_PRETRAINED_MODELS:
        model_base = models.load_model(base_path)
    else:
        model_base = create_base_model(input_shape, num_classes)
        model_base.fit(x_train, y_train, epochs=EPOCHS, batch_size=64, validation_split=0.2,
                       callbacks=[EarlyStopping(monitor='val_loss', patience=2)], verbose=1)
        model_base.save(base_path)

    num_params_base = model_base.count_params()

    acc_base_fp32, time_base_fp32 = evaluate_model(model_base, x_test, y_test)

    test_image = tf.convert_to_tensor(x_test[:1], dtype=tf.float32)
    latency_base_fp32 = profile_workload(model_base, test_image)

    def representative_data_gen():
        for input_value in tf.data.Dataset.from_tensor_slices(x_test).batch(1).take(100):
            yield [tf.cast(input_value, tf.float32)]

    quantize_model_to_int8(model_base, representative_data_gen, int8_base_path)
    quantize_model_to_fp16(model_base, fp16_base_path)

    int8_base_size = get_file_size(int8_base_path)
    fp16_base_size = get_file_size(fp16_base_path)

    interpreter = tf.lite.Interpreter(model_path=int8_base_path)
    interpreter.allocate_tensors()
    acc_base_int8 = evaluate_tflite_accuracy(interpreter, x_test, y_test, quantized=True)
    latency_base_int8 = profile_tflite_model(interpreter, np.round(x_test[:1] * 255).astype(np.int8))

    interpreter = tf.lite.Interpreter(model_path=fp16_base_path)
    interpreter.allocate_tensors()
    acc_base_fp16 = evaluate_tflite_accuracy(interpreter, x_test, y_test, quantized=False)
    latency_base_fp16 = profile_tflite_model(interpreter, x_test[:1].astype(np.float32))

    print("\nSummary:")
    print(f"{'Metric':<30} {'Base Model':<20} {'Enhanced Model'}")
    print(f"{'-'*80}")
    print(f"{'Parameters':<30} {num_params_base:<20} {'deleted'}")
    print(f"{'Accuracy FP32 (%)':<30} {acc_base_fp32*100:.2f}%{'':<12} {'deleted'}%")
    print(f"{'Accuracy FP16 (%)':<30} {acc_base_fp16*100:.2f}%{'':<12} {'deleted'}%")
    print(f"{'Accuracy INT8 (%)':<30} {acc_base_int8*100:.2f}%{'':<12} {'deleted'}%")
    print(f"{'Latency FP32 (ms)':<30} {latency_base_fp32:.2f}{'':<14} {'deleted'}")
    print(f"{'Latency FP16 (ms)':<30} {latency_base_fp16:.2f}{'':<14} {'deleted'}")
    print(f"{'Latency INT8 (ms)':<30} {latency_base_int8:.2f}{'':<14} {'deleted'}")
    print(f"{'Size FP32 (KB)':<30} {get_file_size(base_path):.2f}{'':<14} {'deleted'}")
    print(f"{'Size FP16 (KB)':<30} {fp16_base_size:.2f}{'':<14} {'deleted'}")
    print(f"{'Size INT8 (KB)':<30} {int8_base_size:.2f}{'':<14} {'deleted'}")


Processing fashion_mnist...
Epoch 1/10


2025-04-29 02:26:30.763792: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2 Pro
2025-04-29 02:26:30.763815: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-04-29 02:26:30.763820: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2025-04-29 02:26:30.763846: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-04-29 02:26:30.763859: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


  8/750 [..............................] - ETA: 6s - loss: 2.2923 - accuracy: 0.1133  

2025-04-29 02:26:31.066877: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


  saving_api.save_model(


Profiling GPU...


Warm-up: 100%|██████████| 10/10 [00:00<00:00, 254.08it/s]
Profiling: 100%|██████████| 30/30 [00:00<00:00, 1059.77it/s]


INFO:tensorflow:Assets written to: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmprwfcgil1/assets


INFO:tensorflow:Assets written to: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmprwfcgil1/assets
2025-04-29 02:27:36.095500: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2025-04-29 02:27:36.095518: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2025-04-29 02:27:36.095736: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmprwfcgil1
2025-04-29 02:27:36.096327: I tensorflow/cc/saved_model/reader.cc:51] Reading meta graph with tags { serve }
2025-04-29 02:27:36.096332: I tensorflow/cc/saved_model/reader.cc:146] Reading SavedModel debug info (if present) from: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmprwfcgil1
2025-04-29 02:27:36.097489: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:388] MLIR V1 optimization pass is not enabled
2025-04-29 02:27:36.098121: I tensorflow/cc/saved_model/load

Saved INT8 model at fashion_mnist_base_int8.tflite
INFO:tensorflow:Assets written to: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmpznlw05_c/assets


INFO:tensorflow:Assets written to: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmpznlw05_c/assets
2025-04-29 02:27:36.518261: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2025-04-29 02:27:36.518271: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2025-04-29 02:27:36.518377: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmpznlw05_c
2025-04-29 02:27:36.518982: I tensorflow/cc/saved_model/reader.cc:51] Reading meta graph with tags { serve }
2025-04-29 02:27:36.518987: I tensorflow/cc/saved_model/reader.cc:146] Reading SavedModel debug info (if present) from: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmpznlw05_c
2025-04-29 02:27:36.520575: I tensorflow/cc/saved_model/loader.cc:233] Restoring SavedModel bundle.
2025-04-29 02:27:36.545729: I tensorflow/cc/saved_model/loader.cc:217] Running initialization

Saved FP16 model at fashion_mnist_base_fp16.tflite

Summary:
Metric                         Base Model           Enhanced Model
--------------------------------------------------------------------------------
Parameters                     5600                 deleted
Accuracy FP32 (%)              65.06%             deleted%
Accuracy FP16 (%)              21.22%             deleted%
Accuracy INT8 (%)              13.28%             deleted%
Latency FP32 (ms)              0.94               deleted
Latency FP16 (ms)              0.01               deleted
Latency INT8 (ms)              0.01               deleted
Size FP32 (KB)                 100.20               deleted
Size FP16 (KB)                 14.45               deleted
Size INT8 (KB)                 9.02               deleted

Processing cifar10...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


  saving_api.save_model(


Profiling GPU...


Warm-up: 100%|██████████| 10/10 [00:00<00:00, 498.01it/s]
Profiling: 100%|██████████| 30/30 [00:00<00:00, 1044.42it/s]


INFO:tensorflow:Assets written to: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmpswwv3wbd/assets


INFO:tensorflow:Assets written to: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmpswwv3wbd/assets
2025-04-29 02:28:37.088910: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2025-04-29 02:28:37.088922: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2025-04-29 02:28:37.089111: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmpswwv3wbd
2025-04-29 02:28:37.089716: I tensorflow/cc/saved_model/reader.cc:51] Reading meta graph with tags { serve }
2025-04-29 02:28:37.089721: I tensorflow/cc/saved_model/reader.cc:146] Reading SavedModel debug info (if present) from: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmpswwv3wbd
2025-04-29 02:28:37.091419: I tensorflow/cc/saved_model/loader.cc:233] Restoring SavedModel bundle.
2025-04-29 02:28:37.116019: I tensorflow/cc/saved_model/loader.cc:217] Running initialization

Saved INT8 model at cifar10_base_int8.tflite
INFO:tensorflow:Assets written to: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmpv_qb4yow/assets


INFO:tensorflow:Assets written to: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmpv_qb4yow/assets


Saved FP16 model at cifar10_base_fp16.tflite


2025-04-29 02:28:37.754970: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2025-04-29 02:28:37.754981: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2025-04-29 02:28:37.755099: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmpv_qb4yow
2025-04-29 02:28:37.755667: I tensorflow/cc/saved_model/reader.cc:51] Reading meta graph with tags { serve }
2025-04-29 02:28:37.755672: I tensorflow/cc/saved_model/reader.cc:146] Reading SavedModel debug info (if present) from: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmpv_qb4yow
2025-04-29 02:28:37.757296: I tensorflow/cc/saved_model/loader.cc:233] Restoring SavedModel bundle.
2025-04-29 02:28:37.782877: I tensorflow/cc/saved_model/loader.cc:217] Running initialization op on SavedModel bundle at path: /var/folders/7f/r2ps86p562s640d2zx6tvwm40000gq/T/tmpv_qb4yow
2025-04-


Summary:
Metric                         Base Model           Enhanced Model
--------------------------------------------------------------------------------
Parameters                     7680                 deleted
Accuracy FP32 (%)              15.98%             deleted%
Accuracy FP16 (%)              10.00%             deleted%
Accuracy INT8 (%)              10.00%             deleted%
Latency FP32 (ms)              0.95               deleted
Latency FP16 (ms)              0.02               deleted
Latency INT8 (ms)              0.04               deleted
Size FP32 (KB)                 123.20               deleted
Size FP16 (KB)                 18.60               deleted
Size INT8 (KB)                 11.10               deleted
