In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import layers

# ==========================================
# 1. Prepare Real Text Data (IMDB Dataset)
# ==========================================
vocab_size = 10000  # We will use the top 10,000 most frequent words
maxlen = 100        # We will truncate/pad reviews to 100 words
batch_size = 128
shuffle_seed = 42   # Fixed seed so the per-epoch shuffle order is reproducible

print("Downloading and processing real IMDB dataset...")
# Load the dataset (Keras handles downloading it automatically)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=vocab_size)

# Subset first so we only pad what we actually use:
# 10,000 reviews for training, and 2,000 reviews from the IMDB *test* split
# that serve as the validation set during training.
x_train, y_train = x_train[:10000], y_train[:10000]
x_test, y_test = x_test[:2000], y_test[:2000]

# Preprocess: make every review exactly `maxlen` tokens long.
# pad_sequences defaults to padding='pre' and truncating='pre': long reviews
# keep their LAST `maxlen` tokens, short ones are left-padded with 0.
x_train = tf.keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = tf.keras.preprocessing.sequence.pad_sequences(x_test, maxlen=maxlen)

# Model.fit ignores its `shuffle` argument for tf.data inputs, so shuffle here.
# reshuffle_each_iteration (the default) draws a new order every epoch.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(len(x_train), seed=shuffle_seed)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
# Validation data is evaluated in a fixed order; no shuffling needed.
test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# ==========================================
# 2. Define Dynamic W4A4 Quantization (Weights & Activations)
# ==========================================
@tf.custom_gradient
def dynamic_quant_ste(x):
    """Fake-quantize ``x`` onto a signed 4-bit grid with a dynamic abs-max scale.

    Forward pass:
        One scale is computed per slice along the LAST axis:
        ``scale = max(|x|) / 7``, floored at 1e-7 so all-zero slices do
        not divide by zero. Values are divided by the scale, rounded to
        the nearest integer level (tf.round: ties to even), clipped to
        [-8, 7], and multiplied back by the scale. The result has the
        same shape as ``x`` but only takes discrete values.

        Because each slice is scaled by its own abs-max, the
        largest-magnitude value is never clipped. The trade-off is that a
        single outlier widens the step size for every other value in that
        slice. Since |x / scale| <= 7, the -8 level is in practice never
        produced; the grid is effectively the 15 symmetric levels [-7, 7].

        What "slice along the last axis" means depends on the caller. For
        a (..., features) activation tensor it is one scale per feature
        vector (per token). For a 2-D (in, out) kernel it is one scale per
        input row.

    Backward pass:
        Straight-through estimator: the incoming gradient is handed to
        ``x`` unchanged. Both the rounding and the dependence of the scale
        on ``x`` are ignored.
    """
    top_level = 7.0      # Largest positive signed-4-bit level; defines the scale
    bottom_level = -8.0  # Most negative signed-4-bit level (unreachable, see docstring)

    peak = tf.reduce_max(tf.abs(x), axis=-1, keepdims=True)
    step = tf.maximum(peak / top_level, 1e-7)

    levels = tf.clip_by_value(tf.round(x / step), bottom_level, top_level)
    fake_quantized = levels * step

    def straight_through(dy):
        # Identity gradient: d(output)/d(x) is treated as 1 everywhere.
        return dy

    return fake_quantized, straight_through

class QuantizedDense(layers.Layer):
    """Dense layer with on-the-fly W4A4 fake quantization.

    On every call, both the kernel and the incoming activations are passed
    through ``dynamic_quant_ste`` before the matmul. The bias is added in
    full precision. The trainable variables are still stored as floating
    point (no dtype is passed to ``add_weight``). Quantization is only
    simulated ("fake quant"), so this layer does not reduce parameter memory.

    Args:
        units: Number of output features.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Kernel layout is (in_features, units), matching tf.matmul(x, w) in call().
        self.w = self.add_weight(name="kernel",
                                 shape=(input_shape[-1], self.units),
                                 initializer="glorot_uniform", trainable=True)
        self.b = self.add_weight(name="bias", shape=(self.units,),
                                 initializer="zeros", trainable=True)

    def call(self, inputs):
        # Quantize weights on the fly. dynamic_quant_ste reduces over the last
        # axis, so each kernel row (input feature) gets its own scale.
        w_q = dynamic_quant_ste(self.w)
        # Quantize incoming activations on the fly: one scale per vector along
        # the last axis (per token for (batch, seq, features) inputs).
        x_q = dynamic_quant_ste(inputs)
        # Only the matmul operands are on the 4-bit grid. The product and the
        # full-precision bias make the output continuous again.
        return tf.matmul(x_q, w_q) + self.b

    def get_config(self):
        # Needed so models containing this layer can be saved and re-created
        # (the base config does not know about `units`).
        config = super().get_config()
        config.update({"units": self.units})
        return config

# ==========================================
# 3. Build the Transformer Networks
# ==========================================
def create_standard_transformer(embed_dim=64, num_heads=2, ff_dim=128):
    """Build the full-precision baseline: one Transformer-style encoder block + classifier.

    Architecture:
    1. Token Embedding.
    2. Multi-head self-attention, with residual connection and LayerNorm.
    3. Two-layer ReLU feed-forward network, with residual connection and LayerNorm.
    4. Global average pooling over positions.
    5. Sigmoid output.

    Every layer is an ordinary floating-point Keras layer. This is the
    reference the W4A4 variant is compared against.

    NOTE(review): no positional embedding is added and padding (token 0) is
    not masked. The model is therefore insensitive to word order, and the
    average pool includes padded positions. The quantized builder in this
    file shares this design.

    Args:
        embed_dim: Model width; used for the embedding, the attention key_dim
            and the FFN output. Default matches the original hard-coded 64.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping (batch, maxlen) token ids to
        (batch, 1) sigmoid probabilities.
    """
    inputs = layers.Input(shape=(maxlen,))
    x = layers.Embedding(vocab_size, embed_dim)(inputs)

    attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)(x, x)
    x = layers.LayerNormalization(epsilon=1e-6)(x + attn)

    # Full-precision feed-forward block
    ffn = layers.Dense(ff_dim, activation='relu')(x)
    ffn = layers.Dense(embed_dim)(ffn)

    x = layers.LayerNormalization(epsilon=1e-6)(x + ffn)
    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(1, activation='sigmoid')(x)
    return tf.keras.Model(inputs, outputs)

def create_quantized_transformer(embed_dim=64, num_heads=2, ff_dim=128):
    """Build the W4A4 variant: the baseline network with a fake-quantized FFN.

    The network matches ``create_standard_transformer`` except that the two
    feed-forward Dense layers are ``QuantizedDense``. Those layers round both
    their kernel and their input activations to 4-bit levels on every
    forward pass.

    The embedding, attention, LayerNorms and classifier head stay in floating
    point. The quantized kernels are also still stored as floats. This model
    therefore measures the *accuracy* impact of 4-bit FFN arithmetic; it does
    not shrink the model's memory footprint.

    NOTE(review): like the baseline, there is no positional embedding and no
    padding mask.

    Args:
        embed_dim: Model width; used for the embedding, the attention key_dim
            and the FFN output. Default matches the original hard-coded 64.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping (batch, maxlen) token ids to
        (batch, 1) sigmoid probabilities.
    """
    inputs = layers.Input(shape=(maxlen,))
    x = layers.Embedding(vocab_size, embed_dim)(inputs)

    attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)(x, x)
    x = layers.LayerNormalization(epsilon=1e-6)(x + attn)

    # Quantized W4A4 feed-forward block
    ffn = QuantizedDense(ff_dim)(x)
    # ReLU acts on the continuous matmul output; only the matmul operands
    # inside QuantizedDense are on the 4-bit grid.
    ffn = layers.ReLU()(ffn)
    # NOTE(review): after ReLU every activation is >= 0, so the signed [-8, 7]
    # grid spends its negative half on values that never occur here.
    ffn = QuantizedDense(embed_dim)(ffn)

    x = layers.LayerNormalization(epsilon=1e-6)(x + ffn)
    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(1, activation='sigmoid')(x)
    return tf.keras.Model(inputs, outputs)

# ==========================================
# 4. Compile and Train Side-by-Side
# ==========================================
seed = 42

# Re-seed (Python, NumPy and TensorFlow) right before building each model.
# Weight initialization is then reproducible across runs, and both models
# start from the same RNG state. GPU kernels may still introduce small
# run-to-run nondeterminism.
tf.keras.utils.set_random_seed(seed)
model_std = create_standard_transformer()
tf.keras.utils.set_random_seed(seed)
model_qnt = create_quantized_transformer()

model_std.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model_qnt.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# 10 epochs deliberately runs past the best validation point on this
# 10k-review subset. In the captured run, the FP32 model's val_accuracy
# peaked around epoch 2 and then declined as it overfit. The curves show
# that trend rather than convergence.
epochs = 10

print("--- Training Standard Transformer (Continuous FP32) ---")
history_std = model_std.fit(train_ds, validation_data=test_ds, epochs=epochs)

print("\n--- Training Quantized Transformer (Dynamic W4A4) ---")
history_qnt = model_qnt.fit(train_ds, validation_data=test_ds, epochs=epochs)

# ==========================================
# 5. Visualize Results
# ==========================================
print("\nGenerating performance plots...")
plt.figure(figsize=(10, 5))
# Take the x-axis from each recorded history, not from the global `epochs`.
# This keeps the plot valid if a run stops early or `epochs` is edited
# after training.
runs = (
    (history_std, "Standard (FP32)", {"color": 'blue'}),
    (history_qnt, "Quantized (W4A4)", {"color": 'red', "linestyle": '--'}),
)
for history, label, line_style in runs:
    val_acc = history.history['val_accuracy']
    plt.plot(range(1, len(val_acc) + 1), val_acc, label=label, linewidth=2, **line_style)
plt.title("IMDB Accuracy: Continuous vs. 4-bit Discrete (W4A4)")
plt.xlabel("Epoch")
plt.ylabel("Validation Accuracy")
plt.legend()
plt.grid(True, linestyle=':', alpha=0.7)
plt.tight_layout()
plt.savefig("imdb_tf_accuracy.png", dpi=150)
print("Saved plot -> imdb_tf_accuracy.png")

Downloading and processing real IMDB dataset...
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
[1m17464789/17464789[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 0us/step


2026-02-20 14:21:52.393111: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3
2026-02-20 14:21:52.393137: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2026-02-20 14:21:52.393144: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.92 GB
2026-02-20 14:21:52.393157: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2026-02-20 14:21:52.393164: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


--- Training Standard Transformer (Continuous FP32) ---
Epoch 1/10


2026-02-20 14:21:53.694931: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 75ms/step - accuracy: 0.7227 - loss: 0.5358 - val_accuracy: 0.8010 - val_loss: 0.4133
Epoch 2/10
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 67ms/step - accuracy: 0.8861 - loss: 0.2781 - val_accuracy: 0.8180 - val_loss: 0.4357
Epoch 3/10
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 68ms/step - accuracy: 0.9257 - loss: 0.1885 - val_accuracy: 0.7595 - val_loss: 0.7048
Epoch 4/10
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 68ms/step - accuracy: 0.9427 - loss: 0.1446 - val_accuracy: 0.8050 - val_loss: 0.5227
Epoch 5/10
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 66ms/step - accuracy: 0.9479 - loss: 0.1311 - val_accuracy: 0.7975 - val_loss: 0.7285
Epoch 6/10
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 69ms/step - accuracy: 0.9690 - loss: 0.0817 - val_accuracy: 0.7875 - val_loss: 0.9921
Epoch 7/10
[1m79/79[0m [32m━━━━━━━━━━━━━━