"""
MNIST Optimization Lab (TensorFlow + TF-MOT + TFLite)
====================================================


A reproducible experimental harness to compare:
- Baselines: ConvNet & Tiny DS-CNN
- Pruning (magnitude, polynomial schedule)
- Weight sharing (clustering)
- Knowledge distillation (teacher: ConvNet, student: Tiny DS-CNN)
- Quantization: PTQ (dynamic & full-int8) and QAT


It logs accuracy, size, params, FLOPs, sparsity, and latency (TFLite) for each run.
It also generates tables, plots, and a Markdown report.


REQUIREMENTS
------------
- tensorflow >= 2.10
- tensorflow-model-optimization >= 0.7
- numpy, pandas, matplotlib, scikit-learn

In [1]:
#imports
import os
import io
import re
import gc
import time
import math
import json
import random
import argparse
import statistics as stats
from pathlib import Path


import numpy as np

os.environ["TF_USE_LEGACY_KERAS"] = "1"
import tensorflow as tf
assert tf.__version__ >= '2.10', f"TensorFlow >=2.10 required, found {tf.__version__}"


try:
    import tensorflow_model_optimization as tfmot
except Exception as e:
    raise RuntimeError("tensorflow-model-optimization (tfmot) is required. pip install tensorflow-model-optimization") from e


import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
from adjustText import adjust_text




In [2]:
# ------------------------------------------------------------
# Global config
# ------------------------------------------------------------
OUTDIR = Path("runs_mnist_opt")
(OUTDIR / "models").mkdir(parents=True, exist_ok=True)
(OUTDIR / "tflite").mkdir(parents=True, exist_ok=True)
(OUTDIR / "figures").mkdir(parents=True, exist_ok=True)


DEFAULT_EPOCHS = 6
DEFAULT_BATCH = 128
DEFAULT_SEEDS = 1
WARMUP_RUNS = 5
TIMED_RUNS = 100

In [None]:
# ------------------------------------------------------------
# Helpers
# ------------------------------------------------------------

def set_global_seed(seed: int):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    try:
        tf.config.experimental.enable_op_determinism()
    except Exception:
        pass

def ensure_compiled(m):
    if not getattr(m, "_is_compiled", False):
        m.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return m

def tflite_accuracy(tflite_bytes, x_test, y_test):
    interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    correct = 0
    total = x_test.shape[0]

    for i in range(total):
        sample = x_test[i:i+1]  # (1, 28, 28, 1)
        # set input
        idx = input_details['index']
        if input_details['dtype'] == np.float32:
            interpreter.set_tensor(idx, sample.astype(np.float32))
        elif input_details['dtype'] == np.int8:
            scale, zero_point = input_details['quantization']
            q = (sample / scale + zero_point).astype(np.int8)
            interpreter.set_tensor(idx, q)
        elif input_details['dtype'] == np.uint8:
            scale, zero_point = input_details['quantization']
            q = (sample / scale + zero_point).astype(np.uint8)
            interpreter.set_tensor(idx, q)
        else:
            raise RuntimeError(f"Unsupported input dtype: {input_details['dtype']}")

        interpreter.invoke()
        logits = interpreter.get_tensor(output_details['index'])
        pred = int(np.argmax(logits, axis=1)[0])
        correct += (pred == int(y_test[i]))

    return correct / total

In [None]:
# ------------------------------------------------------------
# Data: MNIST tf.data pipelines
# ------------------------------------------------------------

def load_mnist(batch_size=DEFAULT_BATCH):
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train = x_train.astype(np.float32) / 255.0
    x_test = x_test.astype(np.float32) / 255.0


    # train/val split
    val_count = 5000
    x_val, y_val = x_train[-val_count:], y_train[-val_count:]
    x_train, y_train = x_train[:-val_count], y_train[:-val_count]


    def _prep(x, y):
        x = np.expand_dims(x, -1) # (H,W,1)
        return x, y.astype(np.int64)


    x_train, y_train = _prep(x_train, y_train)
    x_val, y_val = _prep(x_val, y_val)
    x_test, y_test = _prep(x_test, y_test)


    ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size).prefetch(tf.data.AUTOTUNE)


    return ds_train, ds_val, ds_test, (x_train, y_train, x_val, y_val, x_test, y_test)

In [None]:
# ------------------------------------------------------------
# Model builders
# ------------------------------------------------------------

# Convolutional model
def build_conv_model(dropout=0.5):
    m = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(28,28,1)),
        tf.keras.layers.Conv2D(32, 3, activation='relu'),
        tf.keras.layers.MaxPooling2D(2),
        tf.keras.layers.Conv2D(64, 3, activation='relu'),
        tf.keras.layers.MaxPooling2D(2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dropout(dropout),
        tf.keras.layers.Dense(10, activation='softmax'),
    ], name='convnet')
    return m


# DSCNN model
def build_tiny_dscnn():
    inputs = tf.keras.Input(shape=(28, 28, 1))
    x = tf.keras.layers.Conv2D(12, 3, strides=2, padding='same', use_bias=False)(inputs) # 14x14
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    # block 1
    x = tf.keras.layers.DepthwiseConv2D(3, padding='same', use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Conv2D(16, 1, padding='same', use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    # block 2
    x = tf.keras.layers.DepthwiseConv2D(3, strides=2, padding='same', use_bias=False)(x) # 7x7
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Conv2D(24, 1, padding='same', use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
    return tf.keras.Model(inputs, outputs, name='tiny_dscnn')

In [None]:
# ------------------------------------------------------------
# Training, evaluation, and metrics
# ------------------------------------------------------------


def compile_model(model):
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

def train_model(model, ds_train, ds_val, epochs=DEFAULT_EPOCHS, seed=0, outprefix="exp"):
    callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=3, restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-5),
    tf.keras.callbacks.CSVLogger(str(OUTDIR / f"{outprefix}_history.csv"), append=False),
    ]
    history = model.fit(ds_train, validation_data=ds_val, epochs=epochs, verbose=2, callbacks=callbacks)
    return history

def evaluate_model(model, ds_test, x_test, y_test, prefix="exp"):
    test_loss, test_acc = model.evaluate(ds_test, verbose=0)
    # Confusion matrix
    y_prob = model.predict(ds_test, verbose=0)
    y_pred = np.argmax(y_prob, axis=1)
    cm = confusion_matrix(y_test, y_pred)
    report = classification_report(y_test, y_pred, digits=4)
    np.save(OUTDIR / f"{prefix}_cm.npy", cm)
    with open(OUTDIR / f"{prefix}_cls_report.txt", 'w') as f:
        f.write(report)
    return float(test_acc), cm, report

In [None]:
# ------------------------------------------------------------
# FLOPs, params, sparsity
# ------------------------------------------------------------


def get_model_params(model):
    return int(model.count_params())

def get_model_flops(model):
    try:
        concrete = tf.function(model).get_concrete_function(tf.TensorSpec([1, 28, 28, 1], tf.float32))
        frozen_func = tf.graph_util.convert_variables_to_constants_v2(concrete)
        graph_def = frozen_func.graph.as_graph_def()
        with tf.Graph().as_default() as graph:
            tf.graph_util.import_graph_def(graph_def, name='')
            run_meta = tf.compat.v1.RunMetadata()
            opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
            flops = tf.compat.v1.profiler.profile(graph=graph, run_meta=run_meta, cmd='op', options=opts)
            return int(flops.total_float_ops)
    except Exception:
        try:
            from tensorflow.keras.utils import get_flops
            return int(get_flops(model, batch_size=1))
        except Exception:
            return None

def estimate_sparsity(model):
    total = 0
    zeros = 0
    for w in model.get_weights():
        total += w.size
        zeros += np.count_nonzero(w == 0)
    return zeros / total if total else 0.0

In [None]:
# ------------------------------------------------------------
# TFLite conversion + latency benchmarking
# ------------------------------------------------------------


def representative_data_gen(x_train, num_samples=100):
    for i in range(min(num_samples, x_train.shape[0])):
        yield [x_train[i:i+1].astype(np.float32)]


def to_tflite(model, kind: str, x_train=None, int8_io=True):
    if kind == 'float':
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        tflite = converter.convert()
    elif kind == 'dynamic':
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        tflite = converter.convert()
    elif kind == 'int8':
        assert x_train is not None, "x_train required for full-int8 PTQ"
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = lambda: representative_data_gen(x_train)
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        if int8_io:
            converter.inference_input_type = tf.int8
            converter.inference_output_type = tf.int8
        tflite = converter.convert()
    else:
        raise ValueError(f"Unknown tflite kind: {kind}")
    return tflite


def _set_input(interpreter, sample):
    input_details = interpreter.get_input_details()[0]
    tensor_index = input_details['index']
    if input_details['dtype'] == np.float32:
        interpreter.set_tensor(tensor_index, sample.astype(np.float32))
    elif input_details['dtype'] == np.int8:
        scale, zero_point = input_details['quantization']
        quantized = (sample / scale + zero_point).astype(np.int8)
        interpreter.set_tensor(tensor_index, quantized)
    elif input_details['dtype'] == np.uint8:
        scale, zero_point = input_details['quantization']
        quantized = (sample / scale + zero_point).astype(np.uint8)
        interpreter.set_tensor(tensor_index, quantized)
    else:
        raise RuntimeError(f"Unsupported input dtype: {input_details['dtype']}")
    

def tflite_latency(tflite_bytes, x_test, warmup=WARMUP_RUNS, runs=TIMED_RUNS):
    interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()


    # prepare one sample repeatedly
    sample = x_test[0:1]


    # warmup
    for _ in range(warmup):
        _set_input(interpreter, sample)
    interpreter.invoke()


    times = []
    for _ in range(runs):
        _set_input(interpreter, sample)
    t0 = time.perf_counter_ns()
    interpreter.invoke()
    t1 = time.perf_counter_ns()
    times.append((t1 - t0) / 1e6) # ms


    return {
    'mean_ms': float(np.mean(times)),
    'median_ms': float(np.median(times)),
    'p90_ms': float(np.percentile(times, 90)),
    'p99_ms': float(np.percentile(times, 99)),
    'min_ms': float(np.min(times)),
    'max_ms': float(np.max(times)),
    'runs': len(times),
    }

In [9]:
# ------------------------------------------------------------
# Pruning (TF-MOT)
# ------------------------------------------------------------

def apply_pruning(model, ds_train, epochs, target_sparsity=0.8):
    """Wrap model with pruning and fine-tune for 'epochs'."""
    steps_per_epoch = int(tf.data.experimental.cardinality(ds_train))
    schedule = tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0, final_sparsity=target_sparsity,
        begin_step=0, end_step=epochs * steps_per_epoch)

    pruned = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=schedule)
    pruned = compile_model(pruned)

    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=2, restore_best_weights=True),
    ]
    return pruned, callbacks
          

In [None]:
# ------------------------------------------------------------
# Weight clustering
# ------------------------------------------------------------


def apply_clustering(model, clusters=8):
    try:
        clustering = tfmot.clustering.keras
        centroids_init = clustering.CentroidInitialization.LINEAR
        clustered = clustering.cluster_weights(model, number_of_clusters=clusters, cluster_centroids_init=centroids_init)
        clustered.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        return clustered
    except Exception as e:
        print(f"Clustering unavailable or failed: {e}. Proceeding without clustering.")
        clone = tf.keras.models.clone_model(model)
        clone.build((None, 28, 28, 1))
        clone.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        return clone

In [11]:
# ------------------------------------------------------------
# Quantization-aware training (QAT)
# ------------------------------------------------------------


def apply_qat(model):
    try:
        quantize_model = tfmot.quantization.keras.quantize_model
        # Ensure we pass a standard Keras model type
        if not isinstance(model, (tf.keras.Sequential, tf.keras.Model)):
            model = tf.keras.models.clone_model(model)
        qat = quantize_model(model)
        qat.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        return qat
    except Exception as e:
        print(f"QAT unavailable or failed: {e}. Proceeding without QAT.")
        clone = tf.keras.models.clone_model(model)
        clone.build((None, 28, 28, 1))
        clone.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        return clone

In [None]:
# ------------------------------------------------------------
# Knowledge Distillation (KD)
# ------------------------------------------------------------

class Distiller(tf.keras.Model):
    def __init__(self, student, teacher, temperature=2.0, alpha=0.1):
        super().__init__()
        self.student = student
        self.teacher = teacher
        self.temperature = temperature
        self.alpha = alpha
        self.student_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
        self.distill_loss_fn = tf.keras.losses.KLDivergence()
        self.metrics_tracker = [tf.keras.metrics.SparseCategoricalAccuracy(name='acc')]


    def compile(self, optimizer):
        super().compile()
        self.optimizer = optimizer


    @property
    def metrics(self):
        return self.metrics_tracker


    def train_step(self, data):
        x, y = data
        teacher_probs = tf.nn.softmax(self.teacher(x, training=False) / self.temperature)
        with tf.GradientTape() as tape:
            student_logits = self.student(x, training=True)
            student_loss = self.student_loss_fn(y, student_logits)
            distill_loss = self.distill_loss_fn(teacher_probs, tf.nn.softmax(student_logits / self.temperature))
            loss = self.alpha * student_loss + (1 - self.alpha) * distill_loss
        grads = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.student.trainable_variables))
        for m in self.metrics:
            m.update_state(y, student_logits)
        return {"loss": loss, **{m.name: m.result() for m in self.metrics}}


    def test_step(self, data):
        x, y = data
        y_pred = self.student(x, training=False)
        loss = self.student_loss_fn(y, y_pred)
        for m in self.metrics:
            m.update_state(y, y_pred)
        return {"loss": loss, **{m.name: m.result() for m in self.metrics}}

In [None]:
# ------------------------------------------------------------
# Orchestrator
# ------------------------------------------------------------


def save_tflite(path, tflite_bytes):
    Path(path).write_bytes(tflite_bytes)
    return Path(path).stat().st_size

def run_one_experiment(kind: str, base: str, seed: int, epochs: int, batch: int, ds_pack, arrays_pack):
    ds_train, ds_val, ds_test = ds_pack
    x_train, y_train, x_val, y_val, x_test, y_test = arrays_pack


    set_global_seed(seed)
    if base == 'conv':
        base_model = build_conv_model()
    elif base == 'dscnn':
        base_model = build_tiny_dscnn()
    else:
        raise ValueError(f"Unknown base model: {base}")


    exp_id = f"{base}_{kind}_s{seed}"
    print(f"\n=== Running {exp_id} ===")

    # Build variant
    if kind == 'baseline':
        model = compile_model(base_model)
        history = train_model(model, ds_train, ds_val, epochs=epochs, seed=seed, outprefix=exp_id)
    elif kind.startswith('prune'):
        target = float(kind.split('-')[1]) if '-' in kind else 0.8
        pruned, callbacks = apply_pruning(base_model, ds_train, epochs, target_sparsity=target)
        history = pruned.fit(ds_train, validation_data=ds_val, epochs=epochs, verbose=2, callbacks=callbacks)
        model = tfmot.sparsity.keras.strip_pruning(pruned)
        model.compile(optimizer='adam',
                    loss='sparse_categorical_crossentropy',
                    metrics=['accuracy'])
    elif kind.startswith('cluster'):
        clusters = int(kind.split('-')[1]) if '-' in kind else 8
        clustered = apply_clustering(base_model, clusters=clusters)
        history = clustered.fit(ds_train, validation_data=ds_val,
                                epochs=max(2, epochs // 2), verbose=2)
        model = tfmot.clustering.keras.strip_clustering(clustered)
        model.compile(optimizer='adam',
                    loss='sparse_categorical_crossentropy',
                    metrics=['accuracy'])
    elif kind == 'qat':
        qat = apply_qat(base_model)
        history = qat.fit(ds_train, validation_data=ds_val, epochs=epochs, verbose=2)
        model = qat
    elif kind == 'kd':
        # teacher: trained ConvNet
        teacher = build_conv_model()
        teacher = compile_model(teacher)
        teacher.fit(ds_train, validation_data=ds_val, epochs=epochs, verbose=0)

        student = base_model
        distiller = Distiller(student=student, teacher=teacher, temperature=2.0, alpha=0.2)
        distiller.compile(optimizer=tf.keras.optimizers.Adam())

        history = distiller.fit(ds_train, validation_data=ds_val, epochs=epochs, verbose=2)
        student.compile(optimizer='adam',
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
        model = student

    # Evaluate FP32 Keras
    model = ensure_compiled(model)
    acc, cm, cls_rep = evaluate_model(model, ds_test, x_test, y_test, prefix=exp_id)

    # FLOPs / params / sparsity
    params = get_model_params(model)
    flops = get_model_flops(model)
    sparsity_est = estimate_sparsity(model)

    # TFLite variants
    sizes = {}
    latencies = {}

    # float32 tflite
    tfl_float = to_tflite(model, kind='float')
    sizes['tflite_float_bytes'] = save_tflite(OUTDIR / 'tflite' / f'{exp_id}_float.tflite', tfl_float)
    latencies['tflite_float'] = tflite_latency(tfl_float, x_test)

    # dynamic range PTQ
    tfl_dyn = to_tflite(model, kind='dynamic')
    sizes['tflite_dynamic_bytes'] = save_tflite(OUTDIR / 'tflite' / f'{exp_id}_dynamic.tflite', tfl_dyn)
    latencies['tflite_dynamic'] = tflite_latency(tfl_dyn, x_test)

    # full int8 PTQ
    tfl_int8 = to_tflite(model, kind='int8', x_train=x_train)
    sizes['tflite_int8_bytes'] = save_tflite(OUTDIR / 'tflite' / f'{exp_id}_int8.tflite', tfl_int8)
    latencies['tflite_int8'] = tflite_latency(tfl_int8, x_test)

    # TFLite accuracies
    acc_float   = tflite_accuracy(tfl_float, x_test, y_test)
    acc_dynamic = tflite_accuracy(tfl_dyn,   x_test, y_test)
    acc_int8    = tflite_accuracy(tfl_int8,  x_test, y_test)

    # Save Keras model
    model_path = OUTDIR / 'models' / f'{exp_id}.keras'
    model.save(model_path)
    keras_bytes = sum(p.stat().st_size for p in model_path.rglob('*')) if model_path.is_dir() else model_path.stat().st_size

    row = {
        'exp_id': exp_id,
        'base': base,
        'kind': kind,
        'seed': seed,
        'epochs': epochs,
        'batch': batch,
        'test_acc': acc,                 
        'params': params,
        'flops': flops,
        'sparsity_est': sparsity_est,
        'keras_bytes': keras_bytes,

        # TFLite sizes
        'tflite_float_bytes':  sizes['tflite_float_bytes'],
        'tflite_dynamic_bytes': sizes['tflite_dynamic_bytes'],
        'tflite_int8_bytes':   sizes['tflite_int8_bytes'],

        # TFLite latencies
        'tflite_float_median_ms':   latencies['tflite_float']['median_ms'],
        'tflite_dynamic_median_ms': latencies['tflite_dynamic']['median_ms'],
        'tflite_int8_median_ms':    latencies['tflite_int8']['median_ms'],

        # TFLite accuracies
        'tflite_float_acc':   acc_float,
        'tflite_dynamic_acc': acc_dynamic,
        'tflite_int8_acc':    acc_int8,

        # Comparison deltas/ratios
        'acc_drop_int8':  acc - acc_int8,  # positive means int8 lost acc
        'size_ratio_int8': keras_bytes / max(1, sizes['tflite_int8_bytes']),
    }

    row.update(sizes)
    for k, v in latencies.items():
        for mname, mval in v.items():
            row[f'{k}_{mname}'] = mval


    # Persist confusion matrix
    return row


def run_suite(which: str, seeds: int, epochs: int, batch: int):
    ds_train, ds_val, ds_test, arrays = load_mnist(batch)


    # Define experiment recipes
    recipes = []
    if which in ('baseline', 'all'):
        recipes += [('baseline', 'conv'), ('baseline', 'dscnn')]
    if which in ('prune', 'all'):
        recipes += [(f'prune-0.8', 'conv'), (f'prune-0.8', 'dscnn')]
    if which in ('cluster', 'all'):
         recipes += [(f'cluster-8', 'conv'), (f'cluster-8', 'dscnn')]
    if which in ('qat', 'all'):
        recipes += [('qat', 'conv'), ('qat', 'dscnn')]
    if which in ('kd', 'all'):
        recipes += [('kd', 'dscnn')]

    rows = []
    for (kind, base) in recipes:
        for s in range(seeds):
            row = run_one_experiment(kind=kind, base=base, seed=s, epochs=epochs, batch=batch, ds_pack=(ds_train, ds_val, ds_test), arrays_pack=arrays)
            rows.append(row)
            df = pd.DataFrame(rows)
            df.to_csv(OUTDIR / 'results.csv', index=False)
            gc.collect()

    df = pd.DataFrame(rows)
    df.to_csv(OUTDIR / 'results.csv', index=False)
    return df

In [None]:
# ------------------------------------------------------------
# Reporting: figures + Markdown
# ------------------------------------------------------------


def plot_scatter(df, x, y, label='exp_id', fname='scatter.png'):
    plt.figure()
    plt.scatter(df[x].values, df[y].values)
    texts = []
    if label:
        for _, row in df.iterrows():
            texts.append(
                plt.text(row[x], row[y], str(row[label]), fontsize=8)
            )
        try:
            adjust_text(texts, arrowprops=dict(arrowstyle="->", color='gray', lw=0.5))
        except Exception:
            pass
    plt.xlabel(x)
    plt.ylabel(y)
    plt.tight_layout()
    plt.savefig(OUTDIR / 'figures' / fname, dpi=180)
    plt.close()


def pareto_front(df, x_col, y_col):
    pts = df[[x_col, y_col, 'exp_id']].values
    pts = sorted(pts, key=lambda r: (r[0], -r[1]))
    front = []
    best_y = -1
    for x, y, name in pts:
        if y > best_y:
            front.append((x, y, name))
            best_y = y
    return front

def build_report(csv_path=OUTDIR / 'results.csv'):
    df = pd.read_csv(csv_path)
    if 'tflite_int8_median_ms' in df.columns:
        plot_scatter(df, 'tflite_int8_median_ms', 'test_acc', label='exp_id', fname='acc_vs_latency_int8.png')
    plot_scatter(df, 'params', 'test_acc', label='exp_id', fname='acc_vs_params.png')
    if 'tflite_int8_bytes' in df.columns:
        plot_scatter(df, 'tflite_int8_bytes', 'test_acc', label='exp_id', fname='acc_vs_size_int8.png')


    # Pareto summary (latency vs acc using int8 median)
    pareto = []
    if 'tflite_int8_median_ms' in df.columns:
        pf = pareto_front(df, 'tflite_int8_median_ms', 'test_acc')
        pareto = [{'exp_id': n, 'latency_ms': x, 'acc': y} for (x, y, n) in pf]


    # Markdown
    md = io.StringIO()
    md.write("# MNIST Optimization Report\n\n")
    md.write("## Summary\n\n")
    md.write(f"Experiments: {len(df)}\n\n")
    if pareto:
        md.write("### Pareto Frontier (latency vs accuracy, int8)\n\n")
        for p in pareto:
            md.write(f"- {p['exp_id']}: {p['acc']:.4f} acc @ {p['latency_ms']:.3f} ms\n")
        md.write("\n")


    md.write("## Results Table (top 10 by accuracy)\n\n")
    top = df.sort_values('test_acc', ascending=False).head(10)
    md.write(top.to_markdown(index=False))
    md.write("\n\n## Figures\n\n")
    for fig in ['acc_vs_latency_int8.png', 'acc_vs_params.png', 'acc_vs_size_int8.png']:
        fig_path = OUTDIR / 'figures' / fig
        if fig_path.exists():
            md.write(f"![{fig}]({fig_path.as_posix()})\n\n")


    # --- TFLite vs Keras comparison table ---
    cols = [
        'exp_id','kind','test_acc',
        'tflite_float_acc','tflite_dynamic_acc','tflite_int8_acc',
        'acc_drop_int8',
        'keras_bytes','tflite_float_bytes','tflite_int8_bytes',
        'tflite_float_median_ms','tflite_int8_median_ms',
        'size_ratio_int8'
    ]
    have = [c for c in cols if c in df.columns]
    comp = df[have].sort_values('test_acc', ascending=False).head(12)

    md.write("## Keras vs TFLite comparison\n\n")
    try:
        md.write(comp.to_markdown(index=False))
    except Exception:
        md.write(comp.to_csv(index=False))
    md.write("\n\n")

    # --- Figure: grouped latency bars (float vs int8) ---
    try:
        sub = df.sort_values('test_acc', ascending=False).head(12)
        labels = sub['exp_id'].tolist()
        x = np.arange(len(labels))
        width = 0.35

        plt.figure(figsize=(max(8, len(labels)*0.6), 4.5))
        plt.bar(x - width/2, sub['tflite_float_median_ms'], width, label='TFLite float median ms')
        plt.bar(x + width/2, sub['tflite_int8_median_ms'],  width, label='TFLite int8 median ms')
        plt.xticks(x, labels, rotation=45, ha='right')
        plt.ylabel('median latency (ms)')
        plt.legend()
        plt.tight_layout()
        plt.savefig(OUTDIR / 'figures' / 'tflite_latency_bar.png', dpi=180)
        plt.close()

        md.write("### Latency: float vs int8 (median)\n\n")
        md.write("![tflite_latency_bar](figures/tflite_latency_bar.png)\n\n")
    except Exception as e:
        md.write(f"Latency bar plot skipped: {e}\n\n")

    # --- Figure: accuracy drop vs size ratio (int8) ---
    if set(['acc_drop_int8','size_ratio_int8']).issubset(df.columns):
        try:
            plt.figure(figsize=(6.5, 4.5))
            plt.scatter(df['size_ratio_int8'], df['acc_drop_int8'])
            texts = []
            for _, r in df.iterrows():
                texts.append(plt.text(r['size_ratio_int8'], r['acc_drop_int8'], r['exp_id'], fontsize=8))
            try:
                adjust_text(texts, arrowprops=dict(arrowstyle="->", color='gray', lw=0.5))
            except Exception:
                pass
            plt.xlabel('size ratio (keras_bytes / tflite_int8_bytes)')
            plt.ylabel('accuracy drop (Keras acc - TFLite int8 acc)')
            plt.tight_layout()
            plt.savefig(OUTDIR / 'figures' / 'tflite_accdrop_vs_sizeratio.png', dpi=180)
            plt.close()

            md.write("### Accuracy drop vs size ratio (int8)\n\n")
            md.write("![tflite_accdrop_vs_sizeratio](figures/tflite_accdrop_vs_sizeratio.png)\n\n")
        except Exception as e:
            md.write(f"Acc vs size plot skipped: {e}\n\n")

    
    md.write("## Training accuracy curves\n\n")
    for _, r in df.sort_values('test_acc', ascending=False).iterrows():
        figp = OUTDIR / 'figures' / f"{r['exp_id']}_acc.png"
        if figp.exists():
            md.write(f"**{r['exp_id']}**\n\n")
            md.write(f"![{r['exp_id']} accuracy](figures/{figp.name})\n\n")


    report_path = OUTDIR / 'MNIST_Optimization_Report.md'
    with open(report_path, 'w') as f:
        f.write(md.getvalue())
    print(f"Report written to {report_path}")

In [None]:
# ------------------------------------------------------------
# CLI
# ------------------------------------------------------------

def list_recipes():
    print("Available experiment groups (use with --run):")
    print(" baseline : ConvNet + Tiny DS-CNN")
    print(" prune : Magnitude pruning to 80% sparsity")
    print(" cluster : Weight clustering to 8 centroids")
    print(" qat : Quantization-aware training")
    print(" kd : Distill ConvNet -> Tiny DS-CNN")
    print(" all : Everything above")

def main():
    p = argparse.ArgumentParser()
    p.add_argument('--list', action='store_true', help='List available experiment groups and exit')
    p.add_argument('--run', type=str, default='baseline', help='Which group to run: baseline|prune|cluster|qat|kd|all (comma-separated allowed)')
    p.add_argument('--epochs', type=int, default=DEFAULT_EPOCHS)
    p.add_argument('--batch', type=int, default=DEFAULT_BATCH)
    p.add_argument('--seeds', type=int, default=DEFAULT_SEEDS)
    p.add_argument('--report', action='store_true', help='Build figures + Markdown from existing CSV')
    args = p.parse_args(args=[])

    if args.list:
        list_recipes()
        return

    if args.report:
        build_report()
        return

    which_list = [w.strip() for w in args.run.split(',') if w.strip()]
    all_rows = []
    for which in which_list:
        # df = run_suite(which=which, seeds=args.seeds, epochs=args.epochs, batch=args.batch)
        df = run_suite(which='all', seeds=1, epochs=10, batch=128)
        all_rows.append(df)
    if all_rows:
        df_all = pd.concat(all_rows).drop_duplicates(subset=['exp_id'])
        df_all.to_csv(OUTDIR / 'results.csv', index=False)
        print("Saved:", OUTDIR / 'results.csv')
        build_report()


if __name__ == '__main__':
    main()




=== Running conv_baseline_s0 ===

Epoch 1/10


430/430 - 4s - loss: 0.3643 - accuracy: 0.8902 - val_loss: 0.0798 - val_accuracy: 0.9776 - lr: 0.0010 - 4s/epoch - 8ms/step
Epoch 2/10
430/430 - 3s - loss: 0.1081 - accuracy: 0.9666 - val_loss: 0.0525 - val_accuracy: 0.9850 - lr: 0.0010 - 3s/epoch - 6ms/step
Epoch 3/10
430/430 - 3s - loss: 0.0833 - accuracy: 0.9742 - val_loss: 0.0471 - val_accuracy: 0.9872 - lr: 0.0010 - 3s/epoch - 6ms/step
Epoch 4/10
430/430 - 3s - loss: 0.0700 - accuracy: 0.9786 - val_loss: 0.0407 - val_accuracy: 0.9884 - lr: 0.0010 - 3s/epoch - 6ms/step
Epoch 5/10
430/430 - 3s - loss: 0.0604 - accuracy: 0.9815 - val_loss: 0.0377 - val_accuracy: 0.9892 - lr: 0.0010 - 3s/epoch - 6ms/step
Epoch 6/10
430/430 - 3s - loss: 0.0536 - accuracy: 0.9830 - val_loss: 0.0326 - val_accuracy: 0.9908 - lr: 0.0010 - 3s/epoch - 6ms/step
Epoch 7/10
430/430 - 3s - loss: 0.0517 - accuracy: 0.9835 - val_loss: 0.0361 - val_accuracy: 0.9900 - lr: 0.0010 - 3s/epoch - 6ms/step
Epoch 8/10
430/4

INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmp547wcie3\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpxo7dbebo\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpxo7dbebo\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpq8sw_hji\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpq8sw_hji\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    



=== Running dscnn_baseline_s0 ===
Epoch 1/10
430/430 - 3s - loss: 1.8223 - accuracy: 0.4022 - val_loss: 2.3739 - val_accuracy: 0.1140 - lr: 0.0010 - 3s/epoch - 8ms/step
Epoch 2/10
430/430 - 2s - loss: 1.0152 - accuracy: 0.7338 - val_loss: 0.9305 - val_accuracy: 0.6786 - lr: 0.0010 - 2s/epoch - 5ms/step
Epoch 3/10
430/430 - 2s - loss: 0.5757 - accuracy: 0.8711 - val_loss: 0.4988 - val_accuracy: 0.8682 - lr: 0.0010 - 2s/epoch - 5ms/step
Epoch 4/10
430/430 - 2s - loss: 0.3890 - accuracy: 0.9084 - val_loss: 0.3920 - val_accuracy: 0.8904 - lr: 0.0010 - 2s/epoch - 5ms/step
Epoch 5/10
430/430 - 2s - loss: 0.3094 - accuracy: 0.9224 - val_loss: 0.3673 - val_accuracy: 0.8834 - lr: 0.0010 - 2s/epoch - 5ms/step
Epoch 6/10
430/430 - 2s - loss: 0.2665 - accuracy: 0.9291 - val_loss: 0.2177 - val_accuracy: 0.9464 - lr: 0.0010 - 2s/epoch - 5ms/step
Epoch 7/10
430/430 - 2s - loss: 0.2400 - accuracy: 0.9352 - val_loss: 0.2167 - val_accuracy: 0.9418 - lr: 0.0010 - 2s/epoch - 5ms/step
Epoch 8/10
430/430 -

INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmp0atxwdnn\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpyinex82j\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpyinex82j\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpfkhrrg2t\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpfkhrrg2t\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    



=== Running conv_prune-0.8_s0 ===
Epoch 1/10
430/430 - 4s - loss: 0.3674 - accuracy: 0.8897 - val_loss: 0.0815 - val_accuracy: 0.9768 - 4s/epoch - 10ms/step
Epoch 2/10
430/430 - 3s - loss: 0.1171 - accuracy: 0.9643 - val_loss: 0.0590 - val_accuracy: 0.9826 - 3s/epoch - 7ms/step
Epoch 3/10
430/430 - 3s - loss: 0.0944 - accuracy: 0.9715 - val_loss: 0.0531 - val_accuracy: 0.9856 - 3s/epoch - 6ms/step
Epoch 4/10
430/430 - 3s - loss: 0.0821 - accuracy: 0.9753 - val_loss: 0.0489 - val_accuracy: 0.9880 - 3s/epoch - 7ms/step
Epoch 5/10
430/430 - 3s - loss: 0.0769 - accuracy: 0.9767 - val_loss: 0.0453 - val_accuracy: 0.9888 - 3s/epoch - 6ms/step
Epoch 6/10
430/430 - 3s - loss: 0.0751 - accuracy: 0.9770 - val_loss: 0.0438 - val_accuracy: 0.9894 - 3s/epoch - 6ms/step
Epoch 7/10
430/430 - 3s - loss: 0.0733 - accuracy: 0.9786 - val_loss: 0.0455 - val_accuracy: 0.9892 - 3s/epoch - 6ms/step
Epoch 8/10
430/430 - 3s - loss: 0.0702 - accuracy: 0.9786 - val_loss: 0.0410 - val_accuracy: 0.9900 - 3s/epoch

INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmp1qhtwj9w\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmp__j3vzhf\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmp__j3vzhf\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpgff5ww7x\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpgff5ww7x\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    



=== Running dscnn_prune-0.8_s0 ===
Epoch 1/10
430/430 - 4s - loss: 1.8512 - accuracy: 0.3751 - val_loss: 2.1970 - val_accuracy: 0.1148 - 4s/epoch - 10ms/step
Epoch 2/10
430/430 - 2s - loss: 1.1678 - accuracy: 0.6534 - val_loss: 1.1423 - val_accuracy: 0.6308 - 2s/epoch - 5ms/step
Epoch 3/10
430/430 - 2s - loss: 0.7757 - accuracy: 0.8021 - val_loss: 0.8077 - val_accuracy: 0.7378 - 2s/epoch - 5ms/step
Epoch 4/10
430/430 - 2s - loss: 0.6257 - accuracy: 0.8325 - val_loss: 1.0352 - val_accuracy: 0.6308 - 2s/epoch - 5ms/step
Epoch 5/10
430/430 - 2s - loss: 0.5871 - accuracy: 0.8414 - val_loss: 0.8202 - val_accuracy: 0.7052 - 2s/epoch - 5ms/step
INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmp6ky_dbc5\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmp6ky_dbc5\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpj77_18g4\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpj77_18g4\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpwdkas6u9\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpwdkas6u9\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    



=== Running conv_cluster-8_s0 ===
Epoch 1/5
430/430 - 3s - loss: 0.3062 - accuracy: 0.9145 - val_loss: 0.0917 - val_accuracy: 0.9750 - 3s/epoch - 7ms/step
Epoch 2/5
430/430 - 3s - loss: 0.0866 - accuracy: 0.9735 - val_loss: 0.0647 - val_accuracy: 0.9820 - 3s/epoch - 6ms/step
Epoch 3/5
430/430 - 3s - loss: 0.0666 - accuracy: 0.9794 - val_loss: 0.0559 - val_accuracy: 0.9858 - 3s/epoch - 7ms/step
Epoch 4/5
430/430 - 3s - loss: 0.0560 - accuracy: 0.9825 - val_loss: 0.0551 - val_accuracy: 0.9858 - 3s/epoch - 7ms/step
Epoch 5/5
430/430 - 3s - loss: 0.0498 - accuracy: 0.9845 - val_loss: 0.0480 - val_accuracy: 0.9864 - 3s/epoch - 7ms/step
INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmp6k47ytiy\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmp6k47ytiy\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmppteakjh1\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmppteakjh1\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpek6yd40z\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpek6yd40z\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    



=== Running dscnn_cluster-8_s0 ===
Epoch 1/5
430/430 - 3s - loss: 2.0778 - accuracy: 0.2245 - val_loss: 1.6313 - val_accuracy: 0.3520 - 3s/epoch - 6ms/step
Epoch 2/5
430/430 - 2s - loss: 1.4077 - accuracy: 0.4960 - val_loss: 1.1571 - val_accuracy: 0.5216 - 2s/epoch - 4ms/step
Epoch 3/5
430/430 - 2s - loss: 1.1112 - accuracy: 0.6145 - val_loss: 1.0240 - val_accuracy: 0.6634 - 2s/epoch - 4ms/step
Epoch 4/5
430/430 - 2s - loss: 0.8855 - accuracy: 0.7104 - val_loss: 0.8823 - val_accuracy: 0.6886 - 2s/epoch - 4ms/step
Epoch 5/5
430/430 - 2s - loss: 0.8192 - accuracy: 0.7306 - val_loss: 0.6273 - val_accuracy: 0.8148 - 2s/epoch - 4ms/step
INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmplr940wfq\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmplr940wfq\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmp7qeh48jn\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmp7qeh48jn\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpnf43gho8\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpnf43gho8\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    



=== Running conv_qat_s0 ===
Epoch 1/10
430/430 - 4s - loss: 0.3746 - accuracy: 0.8857 - val_loss: 0.0799 - val_accuracy: 0.9780 - 4s/epoch - 10ms/step
Epoch 2/10
430/430 - 4s - loss: 0.1138 - accuracy: 0.9658 - val_loss: 0.0600 - val_accuracy: 0.9844 - 4s/epoch - 9ms/step
Epoch 3/10
430/430 - 4s - loss: 0.0864 - accuracy: 0.9741 - val_loss: 0.0464 - val_accuracy: 0.9882 - 4s/epoch - 10ms/step
Epoch 4/10
430/430 - 4s - loss: 0.0723 - accuracy: 0.9781 - val_loss: 0.0437 - val_accuracy: 0.9880 - 4s/epoch - 10ms/step
Epoch 5/10
430/430 - 4s - loss: 0.0636 - accuracy: 0.9800 - val_loss: 0.0378 - val_accuracy: 0.9884 - 4s/epoch - 10ms/step
Epoch 6/10
430/430 - 4s - loss: 0.0565 - accuracy: 0.9825 - val_loss: 0.0340 - val_accuracy: 0.9898 - 4s/epoch - 9ms/step
Epoch 7/10
430/430 - 4s - loss: 0.0521 - accuracy: 0.9839 - val_loss: 0.0362 - val_accuracy: 0.9902 - 4s/epoch - 10ms/step
Epoch 8/10
430/430 - 4s - loss: 0.0478 - accuracy: 0.9854 - val_loss: 0.0318 - val_accuracy: 0.9912 - 4s/epoch -

INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpgoue8w94\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpntltd6ep\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpntltd6ep\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmp7aqx3se7\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmp7aqx3se7\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    



=== Running dscnn_qat_s0 ===
Epoch 1/10
430/430 - 4s - loss: 1.8155 - accuracy: 0.3985 - val_loss: 1.6702 - val_accuracy: 0.4696 - 4s/epoch - 9ms/step
Epoch 2/10
430/430 - 3s - loss: 1.0596 - accuracy: 0.7185 - val_loss: 1.5637 - val_accuracy: 0.3660 - 3s/epoch - 7ms/step
Epoch 3/10
430/430 - 3s - loss: 0.6326 - accuracy: 0.8588 - val_loss: 0.5269 - val_accuracy: 0.8610 - 3s/epoch - 7ms/step
Epoch 4/10
430/430 - 3s - loss: 0.4411 - accuracy: 0.8951 - val_loss: 0.3837 - val_accuracy: 0.9124 - 3s/epoch - 7ms/step
Epoch 5/10
430/430 - 3s - loss: 0.3481 - accuracy: 0.9131 - val_loss: 0.3732 - val_accuracy: 0.9030 - 3s/epoch - 7ms/step
Epoch 6/10
430/430 - 3s - loss: 0.2976 - accuracy: 0.9210 - val_loss: 0.2893 - val_accuracy: 0.9236 - 3s/epoch - 7ms/step
Epoch 7/10
430/430 - 3s - loss: 0.2625 - accuracy: 0.9287 - val_loss: 0.3073 - val_accuracy: 0.9122 - 3s/epoch - 7ms/step
Epoch 8/10
430/430 - 3s - loss: 0.2397 - accuracy: 0.9348 - val_loss: 0.2110 - val_accuracy: 0.9422 - 3s/epoch - 7ms

INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpj3h_zqyq\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmppw1k5qb1\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmppw1k5qb1\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpqv5gjtg7\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpqv5gjtg7\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    



=== Running dscnn_kd_s0 ===
Epoch 1/10
430/430 - 4s - loss: 0.2811 - acc: 0.4032 - val_loss: 2.5530 - val_acc: 0.1300 - 4s/epoch - 9ms/step
Epoch 2/10
430/430 - 3s - loss: 0.1723 - acc: 0.7232 - val_loss: 0.9087 - val_acc: 0.5826 - 3s/epoch - 6ms/step
Epoch 3/10
430/430 - 3s - loss: 0.1299 - acc: 0.8537 - val_loss: 0.4626 - val_acc: 0.8364 - 3s/epoch - 6ms/step
Epoch 4/10
430/430 - 3s - loss: 0.0739 - acc: 0.8924 - val_loss: 0.3191 - val_acc: 0.9088 - 3s/epoch - 6ms/step
Epoch 5/10
430/430 - 3s - loss: 0.0619 - acc: 0.9190 - val_loss: 0.1898 - val_acc: 0.8950 - 3s/epoch - 7ms/step
Epoch 6/10
430/430 - 3s - loss: 0.0814 - acc: 0.9295 - val_loss: 0.2084 - val_acc: 0.8736 - 3s/epoch - 7ms/step
Epoch 7/10
430/430 - 3s - loss: 0.0438 - acc: 0.9362 - val_loss: 0.0641 - val_acc: 0.9532 - 3s/epoch - 7ms/step
Epoch 8/10
430/430 - 3s - loss: 0.0332 - acc: 0.9407 - val_loss: 0.0667 - val_acc: 0.9360 - 3s/epoch - 7ms/step
Epoch 9/10
430/430 - 3s - loss: 0.0467 - acc: 0.9440 - val_loss: 0.0578 - v

INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpyyo1nfp5\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmptwdrfnby\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmptwdrfnby\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpg12yijdw\assets


INFO:tensorflow:Assets written to: C:\Users\ceulea\AppData\Local\Temp\tmpg12yijdw\assets
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


Saved: runs_mnist_opt\results.csv
Report written to runs_mnist_opt\MNIST_Optimization_Report.md
