# SignBART Functional Model - QAT Experimentation

This notebook lets you experiment with:
1. Loading and inspecting the functional model
2. Selective layer annotation for QAT
3. Applying quantization with different strategies
4. Comparing model sizes and accuracy

Once you find a good strategy, we'll integrate it into the training script.


In [None]:
import os
import yaml
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot

from model_functional import build_signbart_functional_with_dict_inputs
from layers import Projection, ClassificationHead, PositionalEmbedding
from encoder import Encoder, EncoderLayer
from decoder import Decoder, DecoderLayer
from attention import SelfAttention, CrossAttention, CausalSelfAttention

# Reproducibility: later cells draw dummy inputs, comparison batches and fake
# labels from tf.random; seeding once here makes Restart & Run All repeatable.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {keras.__version__}")


## 1. Load Model Configuration


In [None]:
# Load the YAML configuration. yaml.safe_load (not yaml.load) prevents the file
# from constructing arbitrary Python objects.
config_path = "configs/arabic-asl-90kpts.yaml"

# Explicit encoding: without it open() uses the platform locale encoding.
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

print("Model Configuration:")
for key in ('d_model', 'encoder_layers', 'decoder_layers', 'num_labels'):
    print(f"  {key}: {config[key]}")
print(f"  joint_idx: {len(config['joint_idx'])} keypoints")


## 2. Build Functional Model


In [None]:
# Build the full SignBART model from the config (encoder/decoder are built as
# functional sub-models at this stage of the experiment).
print("Building functional model (with functional Encoder/Decoder)...")
model = build_signbart_functional_with_dict_inputs(config)

# Smoke-test with a dummy batch: 1 sample, 10 frames, one entry per keypoint
# in config['joint_idx'], 2 values per keypoint (presumably x/y - confirm).
# The attention mask is all ones for every frame.
num_keypoints = len(config['joint_idx'])
dummy_input = {
    'keypoints': tf.random.normal((1, 10, num_keypoints, 2)),
    'attention_mask': tf.ones((1, 10))
}
output = model(dummy_input, training=False)

print("✓ Model built successfully")
print(f"  Output shape: {output.shape}")
print("\n📊 Model uses FULLY FUNCTIONAL Encoder & Decoder")
print("   This means each component has proper .summary() and Netron visualization!")


## 3.5 Component-Level Inspection

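Now let's look at each component separately. Since Encoder and Decoder are now **Functional Models**, they have proper `.summary()` methods! (This is true at this stage of the experiment. The "CRITICAL FINDING" section further down reverts them to plain Layers for QAT compatibility.)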


In [None]:
# Build the encoder and decoder standalone so each can be summarised.
# NOTE(review): relies on encoder.py / decoder.py exposing build_*_functional;
# after the Layer-based revert described later in this notebook these builders
# may no longer exist - confirm before re-running.
from encoder import build_encoder_functional, EncoderLayer
from decoder import build_decoder_functional, DecoderLayer

print("="*80)
print("COMPONENT 1: ENCODER (Functional Model)")
print("="*80)
encoder_model = build_encoder_functional(config)
encoder_model.summary(line_length=100)

print("\n" + "="*80)
print("Drilling into Encoder's EncoderLayer (Custom Layer):")
print("="*80)
# EncoderLayer is a custom Layer (not a Model), so it has no .summary();
# describe the first one by reading its sub-layer attributes directly.
# NOTE(review): the LayerNorm/Dropout counts are hard-coded, not read from the layer.
for layer in encoder_model.layers:
    if isinstance(layer, EncoderLayer):
        print(f"\n📦 {layer.name} contains:")
        print(f"  - self_attn: {type(layer.self_attn).__name__}")
        print(f"  - fc1 (Dense): {layer.fc1.units} units")
        print(f"  - fc2 (Dense): {layer.fc2.units} units")
        print("  - 2 LayerNorms, 2 Dropouts")
        print("  → We can access layer.fc1 and layer.fc2 for selective quantization!")
        break
else:
    # for/else: only reached when no EncoderLayer triggered the `break`.
    print("⚠ No EncoderLayer found in encoder_model.layers")

print("\n\n" + "="*80)
print("COMPONENT 2: DECODER (Functional Model)")
print("="*80)
decoder_model = build_decoder_functional(config)
decoder_model.summary(line_length=100)

print("\n" + "="*80)
print("Drilling into Decoder's DecoderLayer (Custom Layer):")
print("="*80)
for layer in decoder_model.layers:
    if isinstance(layer, DecoderLayer):
        print(f"\n📦 {layer.name} contains:")
        print(f"  - self_attn (causal): {type(layer.self_attn).__name__}")
        print(f"  - encoder_attn (cross): {type(layer.encoder_attn).__name__}")
        print(f"  - fc1 (Dense): {layer.fc1.units} units")
        print(f"  - fc2 (Dense): {layer.fc2.units} units")
        print("  - 3 LayerNorms, 3 Dropouts")
        print("  → We can access layer.fc1 and layer.fc2 for selective quantization!")
        break
else:
    print("⚠ No DecoderLayer found in decoder_model.layers")

print("\n\n" + "="*80)
print("ARCHITECTURE HIERARCHY")
print("="*80)
print("✓ Level 1: SignBART (Functional Model)")
print("  ✓ Level 2: Encoder (Functional Model)")
print("    • Level 3: EncoderLayer (Custom Layer with Dense layers inside)")
print("  ✓ Level 2: Decoder (Functional Model)")
print("    • Level 3: DecoderLayer (Custom Layer with Dense layers inside)")
print("  • Level 2: Projection, ClassificationHead (Custom Layers)")
print("\nThis is STANDARD Keras pattern - Model → Model → Layer")
print("="*80)


## 3.7 Verifying the Encoder-Decoder Connection

Let's verify that the Encoder and Decoder are properly connected in the computational graph:


In [None]:
print("="*100)
print("ENCODER-DECODER CONNECTION IN SIGNBART")
print("="*100)


def _input_names(sub_model):
    """Return the names of a model's symbolic inputs.

    Handles `.input` being a dict, a list/tuple, or a single tensor (the
    original inline expression crashed on the single-tensor case).
    """
    inputs = sub_model.input
    if isinstance(inputs, dict):
        inputs = list(inputs.values())
    elif not isinstance(inputs, (list, tuple)):
        inputs = [inputs]
    return [inp.name for inp in inputs]


# Extract the encoder/decoder sub-models from the full model. Only nested
# keras.Model instances qualify (plain Layers have no .summary()); if several
# layers match a name, the last one wins.
encoder = None
decoder = None

for layer in model.layers:
    if not isinstance(layer, keras.Model):
        continue
    lowered = layer.name.lower()
    if 'encoder' in lowered:
        encoder = layer
    if 'decoder' in lowered:
        decoder = layer

if encoder is not None and decoder is not None:
    print("\n✓ Found Encoder and Decoder as functional Models within SignBART")
    print(f"\nEncoder: {encoder.name}")
    print(f"  Inputs: {_input_names(encoder)}")
    print(f"  Output: {encoder.output.name}, shape: {encoder.output.shape}")

    print(f"\nDecoder: {decoder.name}")
    print(f"  Inputs: {_input_names(decoder)}")
    print(f"  Output: {decoder.output.name}, shape: {decoder.output.shape}")

    print("\n" + "="*100)
    print("CONNECTION FLOW:")
    print("="*100)
    print("1. Input keypoints → Projection → x_embed, y_embed")
    print("2. x_embed → ENCODER → encoder_hidden_states")
    print("3. encoder_hidden_states → DECODER (cross-attention) ← CONNECTED!")
    print("4. y_embed → DECODER (self-attention)")
    print("5. decoder_outputs → ExtractLastToken → ClassificationHead → logits")
    print("="*100)

    print("\n✓ Encoder and Decoder are CONNECTED in the computational graph!")
    print("✓ Decoder receives encoder outputs via 'encoder_hidden_states' input")
    print("✓ This is the standard Transformer encoder-decoder architecture")
else:
    print("⚠ Could not find encoder/decoder models")

print("\n" + "="*100)
print("HIERARCHY:")
print("="*100)
print("SignBART (Functional Model)")
print("├── Projection (Custom Layer)")
print("├── Encoder (Functional Model) ← can call .summary()")
print("│   └── encoder_layer_0, encoder_layer_1... (Custom Layers)")
print("│       └── fc1, fc2 (Dense layers) ← can annotate for QAT")
print("├── Decoder (Functional Model) ← can call .summary()")
print("│   └── decoder_layer_0, decoder_layer_1... (Custom Layers)")
print("│       └── fc1, fc2 (Dense layers) ← can annotate for QAT")
print("├── ExtractLastValidToken (Custom Layer)")
print("└── ClassificationHead (Custom Layer)")
print("    └── out_proj (Dense layer) ← can annotate for QAT")
print("="*100)


## 3.8 Full Model Summary

Now let's see how everything connects together in the complete SignBART model:


In [None]:
# Show the whole SignBART graph, including how the sub-components are wired.
print("="*80)
print("FULL SIGNBART MODEL SUMMARY")
print("="*80)
model.summary(line_length=100)

print("\n" + "="*80)
print("KEY OBSERVATIONS")
print("="*80)
print("✓ Encoder and Decoder are now MODELS (not just Layers)")
print("✓ Each component can be inspected independently")
print("✓ Perfect for Netron visualization")
print("✓ Better for selective quantization strategies")
print("✓ More standard Keras architecture")
print("="*80)


## 4. Inspect Nested Layers


In [None]:
# Get nested layers
projection = model.get_layer('projection')
encoder = model.get_layer('encoder')
decoder = model.get_layer('decoder')
clf_head = model.get_layer('classification_head')

def print_layer_info(layer, name):
    """Print a titled description of `layer`.

    Objects exposing ``summary`` (Keras Models) print their own summary. Any
    other layer prints its class name, then, if it has ``trainable_variables``,
    the total trainable parameter count and a per-weight breakdown.
    """
    rule = "=" * 80
    print("\n" + rule)
    print(name)
    print(rule)

    if hasattr(layer, 'summary'):
        layer.summary()
        return

    print(f"Type: {layer.__class__.__name__}")
    if not hasattr(layer, 'trainable_variables'):
        return

    weights = layer.trainable_variables
    sizes = [tf.size(w).numpy() for w in weights]
    print(f"Trainable parameters: {sum(sizes):,}")
    print("\nWeights:")
    for w, n_params in zip(weights, sizes):
        print(f"  - {w.name}: {w.shape} ({n_params:,} params)")

# Describe each top-level component in turn.
for target, title in (
    (projection, "PROJECTION LAYER"),
    (encoder, "ENCODER"),
    (decoder, "DECODER"),
    (clf_head, "CLASSIFICATION HEAD"),
):
    print_layer_info(target, title)


## 5. Find All Dense Layers (What Can Be Quantized)

_The `find_all_dense_layers` scan for this section is in the code cell that follows section 6.2._


## 6. QAT Approaches - Finding What Works

Let's test different QAT approaches to see which one works with our nested architecture.


In [None]:
print("="*100)
print("APPROACH 1: Using quantize_model() directly (simplest)")
print("="*100)

try:
    # Try the simplest approach - let tfmot handle everything
    qat_model_simple = tfmot.quantization.keras.quantize_model(model)
    print("✓ SUCCESS! quantize_model() works directly")
    print(f"  Total params: {qat_model_simple.count_params():,}")

    # A forward pass confirms the quantized graph actually executes.
    test_output = qat_model_simple(dummy_input, training=False)
    print(f"✓ Forward pass works, output shape: {test_output.shape}")

except Exception as e:
    # Deliberately broad: this is an experiment cell, and the summary cell
    # below reads qat_model_simple = None as "this approach failed".
    print(f"❌ FAILED: {type(e).__name__}: {e}")
    qat_model_simple = None

print("\n" + "="*100)


In [None]:
print("="*100)
print("APPROACH 2: Using quantize_annotate_model() + quantize_apply()")
print("="*100)

try:
    # Annotate all quantizable layers automatically
    annotated_model = tfmot.quantization.keras.quantize_annotate_model(model)
    print("✓ Model annotated")

    # Materialise the QAT model from the annotations.
    qat_model_annotate = tfmot.quantization.keras.quantize_apply(annotated_model)
    print("✓ SUCCESS! quantize_annotate_model() + quantize_apply() works")
    print(f"  Total params: {qat_model_annotate.count_params():,}")

    # A forward pass confirms the quantized graph actually executes.
    test_output = qat_model_annotate(dummy_input, training=False)
    print(f"✓ Forward pass works, output shape: {test_output.shape}")

except Exception as e:
    # Broad on purpose: the summary cell treats None as "this approach failed".
    print(f"❌ FAILED: {type(e).__name__}: {e}")
    qat_model_annotate = None

print("\n" + "="*100)


In [None]:
print("="*100)
print("APPROACH 3: Manual recursive annotation (for selective quantization)")
print("="*100)
print("This is needed when you want to quantize ONLY specific layers")
print("(e.g., only Dense layers in FFN, not in attention)")


def annotate_layer_recursive(layer):
    """clone_function that annotates Dense layers and recurses into nested Models.

    - Dense          -> wrapped with quantize_annotate_layer
    - keras.Model    -> cloned with this same function, so its own top-level
                        Dense layers are annotated too
    - anything else  -> returned unchanged. clone_model does not look inside
                        custom Layer subclasses, so Dense layers held as
                        attributes of those (EncoderLayer, Projection, ...)
                        are NOT reached here.
    """
    if isinstance(layer, keras.layers.Dense):
        print(f"  Annotating: {layer.name}")
        return tfmot.quantization.keras.quantize_annotate_layer(layer)

    if isinstance(layer, keras.Model):
        print(f"  Traversing Model: {layer.name}")
        return keras.models.clone_model(
            layer,
            clone_function=annotate_layer_recursive
        )

    return layer


try:
    print("\nStep 1: Annotating model...")
    # Exercise the manual strategy itself. quantize_annotate_model() is what
    # Approach 2 already tests, so calling it here would duplicate Approach 2.
    annotated = keras.models.clone_model(
        model,
        clone_function=annotate_layer_recursive
    )

    print("\nStep 2: Applying quantization...")
    qat_model_manual = tfmot.quantization.keras.quantize_apply(annotated)

    print("✓ SUCCESS! Manual approach works")
    print(f"  Total params: {qat_model_manual.count_params():,}")

    # A forward pass confirms the quantized graph actually executes.
    test_output = qat_model_manual(dummy_input, training=False)
    print(f"✓ Forward pass works, output shape: {test_output.shape}")

except Exception as e:
    # Broad on purpose: the summary cell treats None as "this approach failed";
    # the traceback is printed so the actual failure point is visible.
    print(f"❌ FAILED: {type(e).__name__}: {e}")
    import traceback
    traceback.print_exc()
    qat_model_manual = None

print("\n" + "="*100)


In [None]:
print("="*100)
print("SUMMARY: Which QAT Approach Works?")
print("="*100)

approaches = [
    ("Approach 1: quantize_model()", qat_model_simple),
    ("Approach 2: quantize_annotate_model() + quantize_apply()", qat_model_annotate),
    ("Approach 3: Manual recursive", qat_model_manual),
]

# Each approach cell leaves its model variable as None when it failed.
for name, model_obj in approaches:
    status = "✓ WORKS" if model_obj is not None else "❌ FAILED"
    print(f"{status}: {name}")

print("\n" + "="*100)
print("RECOMMENDATION:")
print("="*100)

# Prefer the simplest approach that succeeded; Approach 3 is now considered
# too, instead of being reported as "custom logic needed" when it worked.
if qat_model_simple is not None:
    print("✓ Use Approach 1: tfmot.quantization.keras.quantize_model(model)")
    print("  - Simplest (one line!)")
    print("  - Automatically handles nested layers")
    print("  - Quantizes all quantizable layers (Dense, Conv2D, etc.)")
    selected_qat_model = qat_model_simple
elif qat_model_annotate is not None:
    print("✓ Use Approach 2: quantize_annotate_model() + quantize_apply()")
    print("  - Simple (two lines)")
    print("  - Good for when quantize_model() doesn't work")
    selected_qat_model = qat_model_annotate
elif qat_model_manual is not None:
    print("✓ Use Approach 3: manual recursive clone_model() annotation")
    print("  - Needed when the one-call tfmot APIs fail on this architecture")
    selected_qat_model = qat_model_manual
else:
    print("⚠ All approaches failed - you'll need custom logic")
    print("  This means tfmot doesn't recognize your layer structure")
    print("  You may need to refactor custom layers or use a custom QuantizeConfig")
    selected_qat_model = None

print("="*100)


## 6.1 Why `clone_model()` Alone Doesn't Work

**The Problem:**
```python
# This FAILS ❌
def annotate_dense(layer):
    if isinstance(layer, keras.layers.Dense):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
    return layer

annotated = keras.models.clone_model(model, clone_function=annotate_dense)
qat_model = tfmot.quantization.keras.quantize_apply(annotated)  # Error: no layers annotated!
```

**Why it fails:**
- `clone_model()` with `clone_function` only applies to **top-level layers**
- Dense layers are **inside** custom layers (EncoderLayer, Projection, etc.)
- `clone_function` never sees the nested Dense layers!

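**The Solution:**
- Use `tfmot.quantization.keras.quantize_model()` - it traverses nested layers automatically
- Or use `quantize_annotate_model()` - also handles nesting

**That's why the test_export.py code was so complicated** - it was trying to manually solve this nesting problem. But we don't need to - tfmot's built-in functions handle it!

> **NOTE (review):** check the claims above against the actual cell outputs. The TFMOT QAT comprehensive guide describes `quantize_model()` / `quantize_annotate_model()` as annotating the model's *top-level* layers. Custom `Layer` subclasses (e.g. `Projection`, `Encoder`) generally need a custom `QuantizeConfig`; otherwise `quantize_apply()` raises a "Layer ... is not supported" error. Neither API is documented to descend into Dense layers held inside custom layers.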


## ⚠️ CRITICAL FINDING: Model→Model Breaks QAT!

**The Error:**
```
ValueError: Quantizing a keras Model inside another keras Model is not supported.
```

**What Happened:**
- We made Encoder and Decoder into **Functional Models** for better visualization
- This created Model→Model nesting: SignBART (Model) containing Encoder (Model) and Decoder (Model)
- **TensorFlow Model Optimization does NOT support nested Models!**

**The Fix:**
- ✅ **REVERTED** Encoder and Decoder back to `layers.Layer` subclasses
- Now: SignBART (Model) containing Encoder (Layer) and Decoder (Layer) ✓

**Trade-offs:**
- ❌ Lost: Individual `.summary()` for Encoder/Decoder
- ❌ Lost: Better Netron visualization
- ✅ Gained: **QAT compatibility** (most important!)
- ✅ Gained: Can still access nested Dense layers for selective quantization

**Architecture that works with QAT:**
```
SignBART (Functional Model)
├── Projection (Layer)
├── Encoder (Layer)           ← Layer, not Model!
│   └── EncoderLayer (Layer)
│       ├── fc1 (Dense)
│       └── fc2 (Dense)
├── Decoder (Layer)           ← Layer, not Model!
│   └── DecoderLayer (Layer)
│       ├── fc1 (Dense)
│       └── fc2 (Dense)
└── ClassificationHead (Layer)
```


In [None]:
print("="*100)
print("REBUILDING MODEL WITH LAYER-BASED ENCODER/DECODER")
print("="*100)

# Rebuild the model so it picks up the Layer-based Encoder/Decoder.
# importlib.reload() only re-executes the module it is given. If
# model_functional imports from encoder/decoder (presumably it does), reloading
# it alone would re-bind the STALE classes cached in sys.modules, so the
# dependency modules are reloaded first.
# NOTE(review): order assumes attention/layers are leaf modules used by
# encoder/decoder, which are in turn used by model_functional - confirm.
import importlib
from importlib import reload
import model_functional

for _module_name in ("attention", "layers", "encoder", "decoder"):
    reload(importlib.import_module(_module_name))
reload(model_functional)

model = model_functional.build_signbart_functional_with_dict_inputs(config)

# Test it still works
output = model(dummy_input, training=False)
print("✓ Model rebuilt successfully")
print(f"  Output shape: {output.shape}")

# Check the architecture and report what is actually found instead of
# declaring success unconditionally.
print("\n✓ Verifying architecture:")
nested_models = []
for layer in model.layers:
    is_model = isinstance(layer, keras.Model)
    if is_model:
        nested_models.append(layer.name)
    layer_type = "Model" if is_model else "Layer"
    print(f"  - {layer.name:30s} ({layer_type})")

if nested_models:
    print(f"\n⚠ Nested Models still present: {nested_models} - QAT will reject these")
else:
    print("\n✓ No nested Models - QAT should work now!")
print("="*100)


In [None]:
print("="*100)
print("TESTING QAT AGAIN (with Layer-based Encoder/Decoder)")
print("="*100)

print("\nApproach 1: quantize_model() - ONE LINE!")
try:
    qat_model = tfmot.quantization.keras.quantize_model(model)
    print("✓ SUCCESS!")
    print(f"  Total params: {qat_model.count_params():,}")

    # A forward pass confirms the quantized graph actually executes.
    test_output = qat_model(dummy_input, training=False)
    print(f"✓ Forward pass works, output shape: {test_output.shape}")

    print("\n🎉 QAT WORKS with Layer-based architecture!")

except Exception as e:
    # Truncated: tfmot error messages can be very long.
    print(f"❌ FAILED: {type(e).__name__}: {str(e)[:200]}")
    qat_model = None

print("="*100)


## 📚 Lessons Learned

### ✅ What Works for QAT:
1. **Architecture:** Model → Layer → Layer (NO nested Models!)
2. **Method:** `tfmot.quantization.keras.quantize_model(model)` - one line!
3. **Simplicity:** tfmot handles nested layers automatically

### ❌ What Breaks QAT:
1. **Nested Models:** Model → Model → Layer
2. **Manual clone_model():** Doesn't traverse nested custom layers
3. **Complex workarounds:** Not needed!

### 🎯 Final Answer to "Is test_export.py code correct?"

**Answer:** No - it's **too complicated** and uses an outdated approach.

**The simple version:**
```python
# That's it! One line!
qat_model = tfmot.quantization.keras.quantize_model(model)
```

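**What it does automatically:**
- Finds all Dense layers (even nested ones)
- Annotates them for quantization
- Skips non-quantizable layers
- Handles custom Layer subclasses properly

> **NOTE (review):** confirm these four points against the "TESTING QAT AGAIN" output. Per the TFMOT QAT guide, custom layers without a `QuantizeConfig` are reported as unsupported rather than skipped, and Dense layers nested inside custom layers are not annotated automatically.

**When you need more control:**
Write a custom `tfmot.quantization.keras.QuantizeConfig` and pass it to `quantize_annotate_layer` to customize what gets quantized (we'll explore this next).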


## 6.2 Working Strategy: Quantize Only Dense Layers

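We'll annotate only the `Dense` layers (fc1/fc2, classification head, etc.) while everything else (Projection, Attention, LayerNorm, etc.) stays FP32. Keep section 6.1 in mind: `clone_model()` only hands the model's top-level layers to the clone function, so check how many Dense layers actually get annotated.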


In [None]:
print("="*100)
print("SELECTIVE QAT: Annotate ONLY Dense layers")
print("="*100)

# Re-import the custom layer classes (the modules may have been reloaded above)
# and register them by name, so Keras can resolve them while cloning/quantizing.
from layers import Projection, ClassificationHead, PositionalEmbedding
from encoder import Encoder, EncoderLayer
from decoder import Decoder, DecoderLayer
from attention import SelfAttention, CrossAttention, CausalSelfAttention
from model_functional import ExtractLastValidToken

custom_objects = dict(
    Projection=Projection,
    ClassificationHead=ClassificationHead,
    PositionalEmbedding=PositionalEmbedding,
    Encoder=Encoder,
    EncoderLayer=EncoderLayer,
    Decoder=Decoder,
    DecoderLayer=DecoderLayer,
    SelfAttention=SelfAttention,
    CrossAttention=CrossAttention,
    CausalSelfAttention=CausalSelfAttention,
    ExtractLastValidToken=ExtractLastValidToken,
)

def annotate_only_dense(layer):
    """clone_function: wrap Dense layers for QAT, return every other layer unchanged.

    The names of annotated layers are appended to `annotated_dense_names` so the
    cell can report what was actually annotated. clone_model only calls this on
    the model's top-level layers (see section 6.1), so Dense layers nested
    inside custom layers are never seen here.
    """
    if isinstance(layer, keras.layers.Dense):
        print(f"  Annotating Dense layer: {layer.name}")
        annotated_dense_names.append(layer.name)
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
    return layer


annotated_dense_names = []
print("\nStep 1: Annotating Dense layers (others untouched)...")
with keras.utils.custom_object_scope(custom_objects):
    annotated_dense_model = keras.models.clone_model(
        model,
        clone_function=annotate_only_dense
    )
print(f"✓ Annotation complete: {len(annotated_dense_names)} Dense layer(s) annotated")

if not annotated_dense_names:
    # quantize_apply() rejects a model with no annotated layers; explain why
    # instead of failing with a generic error, and skip the remaining steps.
    qat_dense_model = None
    print("⚠ None of the model's top-level layers is a Dense layer, so nothing was annotated.")
    print("  clone_model() does not descend into custom layers (see section 6.1);")
    print("  Dense layers inside Projection/EncoderLayer/etc. need another approach.")
else:
    print("\nStep 2: Applying quantization...")
    with keras.utils.custom_object_scope(custom_objects):
        qat_dense_model = tfmot.quantization.keras.quantize_apply(annotated_dense_model)
    print("✓ QAT model created with Dense-only quantization")

    # Quick sanity check
    output = qat_dense_model(dummy_input, training=False)
    print(f"\n✓ Forward pass OK, output shape: {output.shape}")

    # Report what was actually quantized rather than a fixed claim.
    print("="*100)
    print("RESULT")
    print("="*100)
    print(f"- Quantized Dense layers: {annotated_dense_names}")
    print("- All other layers remain FP32")
    print("- No custom QuantizeConfig needed")
    print("="*100)


In [None]:
# Attributes skipped by the attribute scan: `submodules` is tf.Module's
# flattened tuple of *all* descendants and would bypass the tree structure.
_DENSE_SCAN_SKIP_ATTRS = frozenset({"submodules"})


def _iter_child_layers(layer):
    """Yield the direct sub-layers of `layer`.

    Objects with a non-empty `.layers` (Keras Models) yield those. Other layers
    are scanned for public attributes holding a Layer, or a list/tuple of
    Layers. This is how custom layers such as Projection or ClassificationHead
    keep their Dense layers.
    """
    if hasattr(layer, 'layers') and layer.layers:
        yield from layer.layers
        return
    for attr_name in dir(layer):
        if attr_name.startswith('_') or attr_name in _DENSE_SCAN_SKIP_ATTRS:
            continue
        try:
            attr = getattr(layer, attr_name)
        except Exception:
            # Some properties (e.g. input/output on a layer that was never
            # connected) raise on access; they cannot be sub-layers anyway.
            continue
        candidates = attr if isinstance(attr, (list, tuple)) else (attr,)
        for candidate in candidates:
            if isinstance(candidate, keras.layers.Layer):
                yield candidate


def find_all_dense_layers(layer, path="", indent=0, _visited=None):
    """Recursively print and count all Dense layers in a model or layer.

    Args:
        layer: Keras Model or Layer to scan.
        path: Slash-separated path of the parent ("" for the root).
        indent: Indentation depth used when printing the tree.
        _visited: Internal set of ids of layers already visited. It keeps a
            shared layer from being counted twice and stops infinite recursion
            when layers reference each other.

    Returns:
        The number of distinct Dense layers at or below `layer`.
    """
    if _visited is None:
        _visited = set()
    if id(layer) in _visited:
        return 0
    _visited.add(id(layer))

    prefix = "  " * indent
    # Same path rule for printing and for children (no leading "/" at the root).
    full_path = f"{path}/{layer.name}" if path else layer.name

    if isinstance(layer, keras.layers.Dense):
        params = sum(tf.size(w).numpy() for w in layer.trainable_variables)
        print(f"{prefix}✓ Dense: {full_path} ({params:,} params)")
        return 1

    print(f"{prefix}📦 {layer.__class__.__name__}: {full_path}")
    return sum(
        find_all_dense_layers(child, full_path, indent + 1, _visited)
        for child in _iter_child_layers(layer)
    )

print("\nFinding all Dense layers in the model...")
print("="*80)
total_dense = find_all_dense_layers(model)
print("="*80)
print(f"\nTotal Dense layers found: {total_dense}")

# Also show what layers can be quantized.
# NOTE(review): these names are hard-coded, not derived from the scan above;
# cross-check them against the printed tree.
print("\n" + "="*80)
print("QUANTIZABLE LAYERS SUMMARY")
print("="*80)
print("These Dense layers can be selectively quantized:")
print("  • proj_x1, proj_y1 - Projection layers (keypoint embedding)")
print("  • q_proj, k_proj, v_proj, out_proj - Attention projections")
print("  • fc1, fc2 - Feed-forward network (FFN)")
print("  • out_proj (in clf_head) - Classification output")


## 6.3 QAT Strategy 1: Quantize ALL Dense Layers


In [None]:
print("\n" + "="*80)
print("STRATEGY 1: Quantize ALL Dense Layers")
print("="*80)

# Names annotated by annotate_all_dense, used to check that the clone step
# actually found something to quantize.
strategy1_annotated = []


def annotate_all_dense(layer):
    """clone_function: annotate every Dense layer it receives; return others unchanged.

    clone_model() only passes the model's top-level layers here (see 6.1), so
    Dense layers held inside custom layers are not annotated.
    """
    if isinstance(layer, keras.layers.Dense):
        print(f"  ✓ Annotating: {layer.name}")
        strategy1_annotated.append(layer.name)
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
    return layer


print("\nAnnotating Dense layers...")
# Same custom-object scope as the selective-QAT cell above, so custom layer
# classes can be resolved while cloning/quantizing.
with keras.utils.custom_object_scope(custom_objects):
    annotated_model_1 = keras.models.clone_model(
        model,
        clone_function=annotate_all_dense
    )

if not strategy1_annotated:
    raise RuntimeError(
        "Strategy 1: no Dense layer was annotated - none of the model's top-level "
        "layers is Dense, and nested Dense layers never reach clone_function (see 6.1)."
    )

print("\nApplying quantization...")
with keras.utils.custom_object_scope(custom_objects):
    qat_model_1 = tfmot.quantization.keras.quantize_apply(annotated_model_1)

print(f"\n✓ QAT Model 1 created ({len(strategy1_annotated)} Dense layers annotated)")
print(f"  Total params: {qat_model_1.count_params():,}")

# Test inference
qat_output_1 = qat_model_1(dummy_input, training=False)
print(f"  ✓ Inference works! Output shape: {qat_output_1.shape}")


## 7. QAT Strategy 2: Quantize Only FFN (fc1/fc2)


In [None]:
print("\n" + "="*80)
print("STRATEGY 2: Quantize Only FFN (fc1/fc2) Dense Layers")
print("="*80)

# Names annotated by annotate_ffn_only, used to check that the clone step
# actually found something to quantize.
strategy2_annotated = []


def annotate_ffn_only(layer):
    """clone_function: annotate only Dense layers named fc1/fc2 (the FFN).

    Matching is by substring of the layer name. clone_model() only passes the
    model's top-level layers here (see 6.1), so FFN layers held inside custom
    EncoderLayer/DecoderLayer objects are not seen.
    """
    if isinstance(layer, keras.layers.Dense):
        # Only quantize fc1 and fc2 (feed-forward network)
        if 'fc1' in layer.name or 'fc2' in layer.name:
            print(f"  ✓ Annotating: {layer.name}")
            strategy2_annotated.append(layer.name)
            return tfmot.quantization.keras.quantize_annotate_layer(layer)
        print(f"  ⊗ Skipping: {layer.name} (not FFN)")
    return layer


print("\nAnnotating FFN Dense layers only...")
with keras.utils.custom_object_scope(custom_objects):
    annotated_model_2 = keras.models.clone_model(
        model,
        clone_function=annotate_ffn_only
    )

if not strategy2_annotated:
    raise RuntimeError(
        "Strategy 2: no fc1/fc2 Dense layer was annotated - clone_model() only "
        "visits top-level layers, and nested Dense layers never reach clone_function (see 6.1)."
    )

print("\nApplying quantization...")
with keras.utils.custom_object_scope(custom_objects):
    qat_model_2 = tfmot.quantization.keras.quantize_apply(annotated_model_2)

print(f"\n✓ QAT Model 2 created ({len(strategy2_annotated)} FFN layers annotated)")
print(f"  Total params: {qat_model_2.count_params():,}")

# Test inference
qat_output_2 = qat_model_2(dummy_input, training=False)
print(f"  ✓ Inference works! Output shape: {qat_output_2.shape}")


## 8. Compare Outputs - Which Strategy Maintains Accuracy?


In [None]:
print("\n" + "="*80)
print("OUTPUT COMPARISON - Numerical Equivalence Check")
print("="*80)

# A fresh batch, larger than the dummy input: 2 samples x 15 frames.
cmp_batch, cmp_frames = 2, 15
test_input = {
    'keypoints': tf.random.normal((cmp_batch, cmp_frames, num_keypoints, 2)),
    'attention_mask': tf.ones((cmp_batch, cmp_frames)),
}

# Run the FP32 baseline and both QAT variants in inference mode, in that order.
base_output, qat1_output, qat2_output = (
    candidate(test_input, training=False).numpy()
    for candidate in (model, qat_model_1, qat_model_2)
)

table_header = f"{'Comparison':<25} {'Max Diff':<15} {'Mean Diff':<15} {'Status'}"
print("\n" + table_header)
print("-" * 80)

def compare_outputs(name, qat_out, base_out, atol=1e-3):
    """Print one table row comparing a QAT model's output to the FP32 baseline.

    Args:
        name: Row label.
        qat_out: Output array of the quantized model.
        base_out: Output array of the baseline model; must have the same shape.
        atol: Largest max-absolute-difference still reported as "Good"
            (default 1e-3, the threshold this notebook has always used).

    Returns:
        Tuple (max_diff, mean_diff) of the element-wise absolute difference.

    Raises:
        ValueError: if the shapes differ. NumPy would otherwise broadcast them
            and report a meaningless difference.
    """
    qat_out = np.asarray(qat_out)
    base_out = np.asarray(base_out)
    if qat_out.shape != base_out.shape:
        raise ValueError(f"{name}: shape mismatch {qat_out.shape} vs {base_out.shape}")

    abs_diff = np.abs(qat_out - base_out)  # computed once, reused for both stats
    max_diff = abs_diff.max()
    mean_diff = abs_diff.mean()
    status = "✓ Good" if max_diff < atol else "⚠ Large diff"
    print(f"{name:<25} {max_diff:<15.6e} {mean_diff:<15.6e} {status}")
    return max_diff, mean_diff

# One row per QAT strategy, each measured against the FP32 baseline.
compare_outputs("QAT-1 vs Base", qat1_output, base_output)
compare_outputs("QAT-2 vs Base", qat2_output, base_output)

print("\n⚠ Note: Small differences are expected due to fake quantization")
print("   These simulate INT8 quantization effects during training")


## 8.5 IMPORTANT: Correct TFLite Conversion

**⚠️ Common Mistake:**
```python
# WRONG ❌ - Don't do this!
converter = tf.lite.TFLiteConverter.from_keras_model(model)
```

**Why it's wrong:**
1. Doesn't handle dict inputs properly
2. Can't handle dynamic shapes (None, None, ...)
3. TFLite needs FIXED input shapes

**✅ Correct approach (from main_functional.py):**


In [None]:
print("="*100)
print("CORRECT TFLITE CONVERSION")
print("="*100)

# Step 1: Define FIXED input shape
MAX_SEQ_LEN = 64  # Fixed sequence length for TFLite (noted as covering 99%+ of data - verify on the dataset)
num_keypoints = len(config['joint_idx'])

print("\nFixed input shape for TFLite:")
print(f"  keypoints: [1, {MAX_SEQ_LEN}, {num_keypoints}, 2]")
print(f"  attention_mask: [1, {MAX_SEQ_LEN}]")

# Step 2: Create a tf.function with a fixed input signature (batch size 1).
@tf.function(input_signature=[
    {
        'keypoints': tf.TensorSpec(shape=[1, MAX_SEQ_LEN, num_keypoints, 2], dtype=tf.float32),
        'attention_mask': tf.TensorSpec(shape=[1, MAX_SEQ_LEN], dtype=tf.float32)
    }
])
def model_predict(inputs):
    # `model` is a global looked up when this function is traced (at
    # get_concrete_function() below), so whichever model is bound then is converted.
    return model(inputs, training=False)

print("\n✓ Created concrete function with fixed input signature")

# Step 3: Convert from concrete function (NOT from_keras_model!)
# NOTE(review): recent TF releases also accept a `trackable_obj` argument here
# (e.g. `model`) and warn when it is omitted - consider passing it.
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [model_predict.get_concrete_function()]
)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS  # Allows select TF ops for unsupported operations
]

print("\n✓ Converter configured")
print("  Supported ops: TFLITE_BUILTINS + SELECT_TF_OPS")

# Step 4: Convert
print("\nConverting to TFLite...")
tflite_model = converter.convert()

# Step 5: Save
fp32_tflite_path = "test_fp32.tflite"
with open(fp32_tflite_path, 'wb') as f:
    f.write(tflite_model)

tflite_size_mb = len(tflite_model) / (1024**2)
print(f"✓ TFLite model saved: {fp32_tflite_path}")
print(f"  Size: {tflite_size_mb:.2f} MB")

print("\n" + "="*100)
print("KEY DIFFERENCES:")
print("="*100)
print("❌ from_keras_model(model)        → Fails with dict inputs + dynamic shapes")
print("✅ from_concrete_functions([...]) → Works with fixed signature")
# Use the configured keypoint count rather than a hard-coded 90.
print(f"\n❌ Dynamic shape: (None, None, {num_keypoints}, 2)  → TFLite can't handle")
print(f"✅ Fixed shape: (1, {MAX_SEQ_LEN}, {num_keypoints}, 2)      → TFLite compatible")
print("="*100)


## 9. Test Training - Can We Actually Train With QAT?


In [None]:
print("\n" + "="*80)
print("TESTING QAT MODEL TRAINING")
print("="*80)

# Pick strategy 2 (conservative)
test_qat_model = qat_model_2

# Compile. from_logits=True: the model is treated as emitting raw logits.
print("\nCompiling QAT model...")
test_qat_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=2e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
print("✓ Model compiled")

# Create fake training data: random keypoints, all-ones mask, integer labels
# in [0, num_labels).
print("\nCreating fake training batch...")
fake_batch_size = 4
fake_seq_len = 20

fake_inputs = {
    'keypoints': tf.random.normal((fake_batch_size, fake_seq_len, num_keypoints, 2)),
    'attention_mask': tf.ones((fake_batch_size, fake_seq_len))
}
fake_labels = tf.random.uniform((fake_batch_size,), minval=0, maxval=config['num_labels'], dtype=tf.int32)

print("\nTesting training step...")
history = test_qat_model.fit(
    fake_inputs,
    fake_labels,
    batch_size=fake_batch_size,
    epochs=2,
    verbose=1
)

# fit() returning is not enough: a NaN/inf loss would still "run".
losses = history.history['loss']
if np.all(np.isfinite(losses)):
    print("\n✓✓✓ QAT MODEL TRAINING WORKS! ✓✓✓")
    print("\nYou can now integrate this strategy into main_functional.py")
else:
    print(f"\n⚠ Training ran but produced non-finite losses: {losses}")


## 10. Code Template - Copy This Into main_functional.py

Once you've chosen your strategy, use this template to integrate QAT into training:


In [None]:
# Print a copy-paste template for adding QAT to main_functional.py.
# The template guards against clone_model() annotating nothing (see 6.1).
print("""
CODE TO ADD TO main_functional.py:
================================================================================

# 1. Add import at top:
import tensorflow_model_optimization as tfmot

# 2. Add argument to parser:
parser.add_argument("--qat", action="store_true",
                    help="Enable Quantization-Aware Training")

# 3. After building model, before compiling (around line 110):
if args.qat:
    print("\\nApplying QAT...")

    annotated_names = []

    def annotate_for_qat(layer):
        # Strategy 2: FFN only (conservative, best accuracy)
        # NOTE: clone_model only passes top-level layers to this function.
        if isinstance(layer, keras.layers.Dense):
            if 'fc1' in layer.name or 'fc2' in layer.name:
                annotated_names.append(layer.name)
                return tfmot.quantization.keras.quantize_annotate_layer(layer)
        return layer

    annotated_model = keras.models.clone_model(
        model,
        clone_function=annotate_for_qat
    )
    if not annotated_names:
        raise RuntimeError("QAT: no fc1/fc2 Dense layers were annotated")
    model = tfmot.quantization.keras.quantize_apply(annotated_model)
    print(f"✓ QAT applied ({len(annotated_names)} FFN layers quantized)")

# Then continue with model.compile() as usual...

================================================================================

USAGE:
python train_loso_functional.py \\
    --config_path configs/arabic-asl-90kpts.yaml \\
    --base_data_path ~/signbart_tf/data/arabic-asl-90kpts \\
    --holdout_only user01 \\
    --epochs 2 \\
    --lr 2e-4 \\
    --qat  # <-- Add this flag!
""")
