In [1]:
import ai_edge_torch
import numpy as np
import torch
import torchvision
import pickle
import os
import json

I0000 00:00:1763022133.174622 1422696 cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
print(torch.cuda.is_available())

True


In [3]:
print(torch.cuda.get_device_name(0))

Quadro RTX 4000


In [4]:
import sys
from pathlib import Path
import yaml

from model import SignBart

In [5]:
# 3. Load config
config_path = "configs/arabic-asl.yaml"
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

# 4. Create model
signbart = SignBart(config)
signbart.eval()

SignBart(
  (encoder): Encoder(
    (embed_positions): PositionalEmbedding(258, 144)
    (layers): ModuleList(
      (0-1): 2 x EncoderLayer(
        (self_attn): SelfAttention(
          (k_proj): Linear(in_features=144, out_features=144, bias=True)
          (v_proj): Linear(in_features=144, out_features=144, bias=True)
          (q_proj): Linear(in_features=144, out_features=144, bias=True)
          (out_proj): Linear(in_features=144, out_features=144, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELU(approximate='none')
        (fc1): Linear(in_features=144, out_features=144, bias=True)
        (fc2): Linear(in_features=144, out_features=144, bias=True)
        (final_layer_norm): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
      )
    )
    (layernorm_embedding): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): Decoder(
    (embed_positions): PositionalEmbedding(

In [6]:
# Load checkpoint
checkpoint_path = "checkpoints_arabic_asl_LOSO_user08/checkpoints_79_final.pth"  
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

# The checkpoint contains: {'model': ..., 'optimizer': ..., 'epoch': ...}
# Extract just the model state dict
if 'model' in checkpoint:
    state_dict = checkpoint['model']
    print(f"‚úÖ Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
else:
    state_dict = checkpoint

# Load into model
ret = signbart.load_state_dict(state_dict, strict=False)

if ret.missing_keys:
    print(f"‚ö†Ô∏è  Missing keys: {len(ret.missing_keys)} keys")
if ret.unexpected_keys:
    print(f"‚ö†Ô∏è  Unexpected keys: {len(ret.unexpected_keys)} keys")
    
if not ret.missing_keys and not ret.unexpected_keys:
    print("‚úÖ All weights loaded successfully!")

‚úÖ Loaded checkpoint from epoch 79
‚úÖ All weights loaded successfully!


In [7]:
# ============================================================================
# 3. Load Real Test Sample
# ============================================================================
print("\nüì¶ Loading real test sample...")

# Load the actual pickle file
sample_path = ("data/arabic-asl_LOSO_user11/test/G10/user11_G10_R10.pkl")

if not os.path.exists(sample_path):
    print(f"‚ùå Sample not found: {sample_path}")
    # Try relative path
    sample_path = "data/arabic-asl_LOSO_user11/test/G10/user11_G10_R10.pkl"
    if os.path.exists(sample_path):
        print(f"‚úÖ Found at relative path: {sample_path}")
    else:
        print("Please provide correct path to the sample file")
else:
    print(f"‚úÖ Sample found: {sample_path}")

# Load the pickle file
with open(sample_path, "rb") as f:
    sample = pickle.load(f)

print(f"\nüìä Sample contents:")
print(f"   Keys: {sample.keys()}")
print(f"   Class: {sample['class']}")
print(f"   Keypoints shape: {np.array(sample['keypoints']).shape}")

# Extract keypoints and label
keypoints = np.array(sample['keypoints'])[:, :, :2]  # (T, K, 2) - x,y coords only
sample_class = sample['class']  # 'G10'


üì¶ Loading real test sample...
‚úÖ Sample found: data/arabic-asl_LOSO_user11/test/G10/user11_G10_R10.pkl

üìä Sample contents:
   Keys: dict_keys(['keypoints', 'class', 'user', 'num_keypoints', 'keypoint_structure', 'video_file', 'num_frames', 'original_path'])
   Class: G10
   Keypoints shape: (45, 100, 3)


In [8]:
# Extract keypoints and label
keypoints = np.array(sample['keypoints'])[:, :, :2]  # (T, K, 2) - x,y coords only
sample_class = sample['class']  # 'G10'

# Load label mappings from the dataset directory
dataset_root =  "data/arabic-asl_LOSO_user11"

In [9]:
with open(f"{dataset_root}/label2id.json", 'r') as f:
    label2id = json.load(f)

with open(f"{dataset_root}/id2label.json", 'r') as f:
    id2label = json.load(f)

sample_label_id = label2id[sample_class]

print(f"   Label: {sample_class} ‚Üí ID: {sample_label_id}")
print(f"   Label mapping: {id2label}")

   Label: G10 ‚Üí ID: 9
   Label mapping: {'0': 'G01', '1': 'G02', '2': 'G03', '3': 'G04', '4': 'G05', '5': 'G06', '6': 'G07', '7': 'G08', '8': 'G09', '9': 'G10'}


In [10]:
# Filter keypoints to the joints specified in config (if needed)
if 'joint_idxs' in config and config['joint_idxs'] is not None:
    flat_joint_idxs = []
    for group in config['joint_idxs']:
        flat_joint_idxs.extend(group)
    flat_joint_idxs = sorted(flat_joint_idxs)
    keypoints = keypoints[:, flat_joint_idxs, :]
    print(f"   Filtered to {len(flat_joint_idxs)} keypoints")

# Clip keypoints to [0, 1]
keypoints = np.clip(keypoints, 0, 1)

# Clip to max 64 frames (as done in dataset.py)
if keypoints.shape[0] > 64:
    # Simple uniform sampling
    indices = np.linspace(0, keypoints.shape[0] - 1, 64, dtype=int)
    keypoints = keypoints[indices]
    print(f"   Clipped from {keypoints.shape[0]} to 64 frames")

print(f"   Final keypoints shape: {keypoints.shape}")

# Normalize keypoints (grouped normalization as in dataset.py)
def normalize_keypoints(keypoints, joint_idxs):
    """Normalize keypoints by groups."""
    if joint_idxs is None:
        return keypoints
    
    flat_joint_idxs = []
    for group in joint_idxs:
        flat_joint_idxs.extend(group)
    idx_to_pos = {idx: pos for pos, idx in enumerate(sorted(flat_joint_idxs))}
    
    for i in range(keypoints.shape[0]):  # for each frame
        for group in joint_idxs:  # for each group
            filtered_positions = [idx_to_pos[idx] for idx in group if idx in idx_to_pos]
            
            if len(filtered_positions) > 0:
                group_keypoints = keypoints[i, filtered_positions, :]
                
                # Normalize the group
                x_coords = group_keypoints[:, 0]
                y_coords = group_keypoints[:, 1]
                
                min_x, min_y = np.min(x_coords), np.min(y_coords)
                max_x, max_y = np.max(x_coords), np.max(y_coords)
                
                w = max_x - min_x
                h = max_y - min_y
                
                if w > h:
                    delta_x = 0.05 * w
                    delta_y = delta_x + ((w - h) / 2)
                else:
                    delta_y = 0.05 * h
                    delta_x = delta_y + ((h - w) / 2)
                
                s_point = [max(0, min(min_x - delta_x, 1)), max(0, min(min_y - delta_y, 1))]
                e_point = [max(0, min(max_x + delta_x, 1)), max(0, min(max_y + delta_y, 1))]
                
                # Normalize
                if (e_point[0] - s_point[0]) != 0.0:
                    group_keypoints[:, 0] = (group_keypoints[:, 0] - s_point[0]) / (e_point[0] - s_point[0])
                if (e_point[1] - s_point[1]) != 0.0:
                    group_keypoints[:, 1] = (group_keypoints[:, 1] - s_point[1]) / (e_point[1] - s_point[1])
                
                keypoints[i, filtered_positions, :] = group_keypoints
    
    return keypoints

keypoints = normalize_keypoints(keypoints, config.get('joint_idxs'))
print(f"   Normalized keypoints")

# Convert to torch tensors
sample_keypoints = torch.from_numpy(keypoints).float()
sample_label = torch.tensor(sample_label_id, dtype=torch.long)

print(f"\n‚úÖ Sample prepared:")
print(f"   Keypoints: {sample_keypoints.shape} (T={sample_keypoints.shape[0]}, K={sample_keypoints.shape[1]})")
print(f"   Label: {sample_label.item()} ({sample_class})")

   Final keypoints shape: (45, 100, 2)
   Normalized keypoints

‚úÖ Sample prepared:
   Keypoints: torch.Size([45, 100, 2]) (T=45, K=100)
   Label: 9 (G10)


In [11]:
import torch
import torch.nn as nn

class SignBartInference(nn.Module):
    """
    Wrapper for SignBart that only returns logits (no loss)
    This is needed for TFLite conversion since ai_edge_torch doesn't handle None values
    """
    def __init__(self, signbart_model):
        super().__init__()
        self.model = signbart_model
    
    def forward(self, keypoints, attention_mask):
        # Call original model without labels
        loss, logits = self.model(keypoints, attention_mask, labels=None)
        # Return only logits (not loss which is None)
        return logits

In [12]:
# 5. Wrap model to return only logits
signbart_inference = SignBartInference(signbart)
signbart_inference.eval()

SignBartInference(
  (model): SignBart(
    (encoder): Encoder(
      (embed_positions): PositionalEmbedding(258, 144)
      (layers): ModuleList(
        (0-1): 2 x EncoderLayer(
          (self_attn): SelfAttention(
            (k_proj): Linear(in_features=144, out_features=144, bias=True)
            (v_proj): Linear(in_features=144, out_features=144, bias=True)
            (q_proj): Linear(in_features=144, out_features=144, bias=True)
            (out_proj): Linear(in_features=144, out_features=144, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELU(approximate='none')
          (fc1): Linear(in_features=144, out_features=144, bias=True)
          (fc2): Linear(in_features=144, out_features=144, bias=True)
          (final_layer_norm): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
        )
      )
      (layernorm_embedding): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
   

In [13]:
# ============================================================================
# 4. Test PyTorch Inference
# ============================================================================
print("\nüîÑ Running PyTorch inference...")

# Prepare input
keypoints_input = sample_keypoints.unsqueeze(0)  # Add batch dimension: (1, T, K, 2)
attention_mask = torch.ones(1, keypoints_input.shape[1], dtype=torch.long)  # (1, T)
labels = sample_label.unsqueeze(0)  # (1,)

with torch.no_grad():
    loss, logits = signbart(
        keypoints=keypoints_input,
        attention_mask=attention_mask,
        labels=labels
    )

# Get predictions and confidence
predicted_class = logits.argmax(dim=1).item()
predicted_label = id2label[str(predicted_class)]
true_label = id2label[str(sample_label.item())]

# Calculate confidence (softmax probabilities)
probs = torch.softmax(logits, dim=1)
confidence = probs[0, predicted_class].item() * 100  # Confidence in predicted class
true_confidence = probs[0, sample_label.item()].item() * 100  # Confidence in true class

print(f"‚úÖ PyTorch inference successful!")
print(f"   True label: {true_label} (ID: {sample_label.item()})")
print(f"   Predicted: {predicted_label} (ID: {predicted_class})")
print(f"   Confidence: {confidence:.2f}%")
if predicted_class != sample_label.item():
    print(f"   True class confidence: {true_confidence:.2f}%")
print(f"   Loss: {loss.item():.4f}")
print(f"   Correct: {'‚úì' if predicted_class == sample_label.item() else '‚úó'}")

# Show top-5 predictions
top5_probs, top5_indices = torch.topk(probs[0], min(5, len(id2label)))
print(f"\n   Top-5 predictions:")
for i, (prob, idx) in enumerate(zip(top5_probs, top5_indices), 1):
    label = id2label[str(idx.item())]
    print(f"      {i}. {label}: {prob.item()*100:.2f}%")


üîÑ Running PyTorch inference...
‚úÖ PyTorch inference successful!
   True label: G10 (ID: 9)
   Predicted: G10 (ID: 9)
   Confidence: 92.22%
   Loss: 0.0810
   Correct: ‚úì

   Top-5 predictions:
      1. G10: 92.22%
      2. G04: 4.34%
      3. G07: 2.71%
      4. G01: 0.60%
      5. G06: 0.10%


In [14]:
# Prepare sample inputs for conversion
# Note: ai_edge_torch needs fixed sequence length, so we'll pad to 114 (max length)
max_seq_len = 114

# Pad the real sample to max_seq_len
T_current = keypoints_input.shape[1]
if T_current < max_seq_len:
    padding = torch.zeros(1, max_seq_len - T_current, keypoints_input.shape[2], keypoints_input.shape[3])
    keypoints_padded = torch.cat([keypoints_input, padding], dim=1)
    mask_padding = torch.zeros(1, max_seq_len - T_current, dtype=torch.long)
    attention_mask_padded = torch.cat([attention_mask, mask_padding], dim=1)
else:
    keypoints_padded = keypoints_input[:, :max_seq_len, :, :]
    attention_mask_padded = attention_mask[:, :max_seq_len]

print(f"   Padded input shape: {keypoints_padded.shape}")
print(f"   Attention mask shape: {attention_mask_padded.shape}")

# Test wrapper with padded input
with torch.no_grad():
    torch_output = signbart_inference(keypoints_padded, attention_mask_padded)
    print(f"   ‚úì Wrapper output shape: {torch_output.shape}")
    predicted = torch_output.argmax(dim=1).item()
    print(f"   ‚úì Wrapper predicted: {id2label[str(predicted)]} (ID: {predicted})")


   Padded input shape: torch.Size([1, 114, 100, 2])
   Attention mask shape: torch.Size([1, 114])
   ‚úì Wrapper output shape: torch.Size([1, 10])
   ‚úì Wrapper predicted: G10 (ID: 9)


In [15]:
# Convert to TFLite
sample_inputs = (keypoints_padded, attention_mask_padded)

try:
    print("\n   Converting to TFLite (this may take a few minutes)...")
    edge_model = ai_edge_torch.convert(signbart_inference.eval(), sample_inputs)
    
    # Save TFLite model
    output_path = "signbart_arabic_asl_user08_epoch79.tflite"
    edge_model.export(output_path)
    
    # Get file size
    file_size_mb = os.path.getsize(output_path) / (1024 * 1024)
    print(f"\n‚úÖ TFLite conversion successful!")
    print(f"   Saved to: {output_path}")
    print(f"   File size: {file_size_mb:.2f} MB")
    
except Exception as e:
    print(f"\n‚ùå Conversion failed: {e}")
    import traceback
    traceback.print_exc()
    raise


   Converting to TFLite (this may take a few minutes)...


W0000 00:00:1763022153.378200 1422696 cuda_executor.cc:1783] GPU interconnect information not available: INTERNAL: NVML library doesn't have required functions.
W0000 00:00:1763022153.379276 1422696 cuda_executor.cc:1783] GPU interconnect information not available: INTERNAL: NVML library doesn't have required functions.
W0000 00:00:1763022153.379485 1422696 cuda_executor.cc:1783] GPU interconnect information not available: INTERNAL: NVML library doesn't have required functions.
W0000 00:00:1763022153.381264 1422696 cuda_executor.cc:1783] GPU interconnect information not available: INTERNAL: NVML library doesn't have required functions.
W0000 00:00:1763022153.381385 1422696 cuda_executor.cc:1783] GPU interconnect information not available: INTERNAL: NVML library doesn't have required functions.
W0000 00:00:1763022153.381520 1422696 cuda_executor.cc:1783] GPU interconnect information not available: INTERNAL: NVML library doesn't have required functions.
W0000 00:00:1763022153.437593 1422

INFO:tensorflow:Assets written to: /tmp/tmpbg8u4sim/assets


INFO:tensorflow:Assets written to: /tmp/tmpbg8u4sim/assets
W0000 00:00:1763022154.812164 1422696 tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
W0000 00:00:1763022154.812181 1422696 tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
I0000 00:00:1763022154.812587 1422696 reader.cc:83] Reading SavedModel from: /tmp/tmpbg8u4sim
I0000 00:00:1763022154.814062 1422696 reader.cc:52] Reading meta graph with tags { serve }
I0000 00:00:1763022154.814069 1422696 reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmpbg8u4sim
I0000 00:00:1763022154.824683 1422696 mlir_graph_optimization_pass.cc:437] MLIR V1 optimization pass is not enabled
I0000 00:00:1763022154.826349 1422696 loader.cc:236] Restoring SavedModel bundle.
I0000 00:00:1763022154.897224 1422696 loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmpbg8u4sim
I0000 00:00:1763022154.918064 1422696 loader.cc:471] SavedModel load for tags { serve }; Status: success: OK. Too


‚úÖ TFLite conversion successful!
   Saved to: signbart_arabic_asl_user08_epoch79.tflite
   File size: 2.94 MB


In [17]:
# ============================================================================
# 6. Test TFLite Model with Real Sample
# ============================================================================
print("\n" + "="*80)
print("Testing TFLite Model")
print("="*80)

import tensorflow as tf

# Load TFLite model
interpreter = tf.lite.Interpreter(model_path=output_path)
interpreter.allocate_tensors()

# Get input/output details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print("\nüìã TFLite Model Details:")
for i, detail in enumerate(input_details):
    print(f"   Input {i}: shape={detail['shape']}, dtype={detail['dtype']}, name={detail['name']}")
for i, detail in enumerate(output_details):
    print(f"   Output {i}: shape={detail['shape']}, dtype={detail['dtype']}, name={detail['name']}")

# Run inference with the same padded sample
print(f"\nüîÑ Running TFLite inference on real sample...")
interpreter.set_tensor(input_details[0]['index'], keypoints_padded.numpy().astype(np.float32))
interpreter.set_tensor(input_details[1]['index'], attention_mask_padded.numpy().astype(np.int64))
interpreter.invoke()

# Get output
tflite_logits = interpreter.get_tensor(output_details[0]['index'])[0]  # Remove batch dim

# Get predictions
tflite_predicted_class = tflite_logits.argmax()
tflite_predicted_label = id2label[str(tflite_predicted_class)]

# Calculate confidence
tflite_probs = np.exp(tflite_logits) / np.exp(tflite_logits).sum()
tflite_confidence = tflite_probs[tflite_predicted_class] * 100
tflite_true_confidence = tflite_probs[sample_label.item()] * 100

print(f"‚úÖ TFLite inference successful!")
print(f"   True label: {true_label} (ID: {sample_label.item()})")
print(f"   Predicted: {tflite_predicted_label} (ID: {tflite_predicted_class})")
print(f"   Confidence: {tflite_confidence:.2f}%")
if tflite_predicted_class != sample_label.item():
    print(f"   True class confidence: {tflite_true_confidence:.2f}%")
print(f"   Correct: {'‚úì' if tflite_predicted_class == sample_label.item() else '‚úó'}")

# Show top-5 for TFLite
top5_indices = np.argsort(tflite_probs)[-5:][::-1]
print(f"\n   Top-5 predictions:")
for i, idx in enumerate(top5_indices, 1):
    label = id2label[str(idx)]
    print(f"      {i}. {label}: {tflite_probs[idx]*100:.2f}%")


Testing TFLite Model

üìã TFLite Model Details:
   Input 0: shape=[  1 114 100   2], dtype=<class 'numpy.float32'>, name=serving_default_args_0:0
   Input 1: shape=[  1 114], dtype=<class 'numpy.int64'>, name=serving_default_args_1:0
   Output 0: shape=[ 1 10], dtype=<class 'numpy.float32'>, name=StatefulPartitionedCall:0

üîÑ Running TFLite inference on real sample...
‚úÖ TFLite inference successful!
   True label: G10 (ID: 9)
   Predicted: G10 (ID: 9)
   Confidence: 92.22%
   Correct: ‚úì

   Top-5 predictions:
      1. G10: 92.22%
      2. G04: 4.34%
      3. G07: 2.71%
      4. G01: 0.60%
      5. G06: 0.10%


    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


In [18]:
# ============================================================================
# 7. Compare PyTorch vs TFLite
# ============================================================================
print("\n" + "="*80)
print("PyTorch vs TFLite Comparison")
print("="*80)

# Get PyTorch output for padded input (for fair comparison)
with torch.no_grad():
    pytorch_logits_padded = signbart_inference(keypoints_padded, attention_mask_padded)[0]
    pytorch_probs_padded = torch.softmax(pytorch_logits_padded, dim=0).numpy()

# Compare logits
max_logit_diff = np.abs(pytorch_logits_padded.numpy() - tflite_logits).max()
mean_logit_diff = np.abs(pytorch_logits_padded.numpy() - tflite_logits).mean()

# Compare probabilities
max_prob_diff = np.abs(pytorch_probs_padded - tflite_probs).max()
mean_prob_diff = np.abs(pytorch_probs_padded - tflite_probs).mean()

print(f"\nüìä Numerical Accuracy:")
print(f"   Max logit difference: {max_logit_diff:.6f}")
print(f"   Mean logit difference: {mean_logit_diff:.6f}")
print(f"   Max probability difference: {max_prob_diff:.6f}")
print(f"   Mean probability difference: {mean_prob_diff:.6f}")

print(f"\nüéØ Prediction Accuracy:")
pytorch_pred = pytorch_logits_padded.argmax().item()
print(f"   PyTorch predicted: {id2label[str(pytorch_pred)]} (ID: {pytorch_pred})")
print(f"   TFLite predicted: {tflite_predicted_label} (ID: {tflite_predicted_class})")
print(f"   Same prediction: {'‚úì' if pytorch_pred == tflite_predicted_class else '‚úó'}")

# Overall assessment
print(f"\nüìà Conversion Quality:")
if max_prob_diff < 0.001:
    print("   ‚úÖ Excellent! Nearly identical outputs.")
elif max_prob_diff < 0.01:
    print("   ‚úÖ Very good! Minimal differences.")
elif max_prob_diff < 0.05:
    print("   ‚ö†Ô∏è  Acceptable. Small differences detected.")
else:
    print("   ‚ö†Ô∏è  Significant differences. May need investigation.")

print("\n" + "="*80)
print(f"‚úÖ TFLite model ready for deployment: {output_path}")
print(f"   Input: (1, 114, 100, 2) keypoints + (1, 114) attention_mask")
print(f"   Output: (1, {len(id2label)}) class logits")
print(f"   File size: {file_size_mb:.2f} MB")
print("="*80)


PyTorch vs TFLite Comparison

üìä Numerical Accuracy:
   Max logit difference: 0.000001
   Mean logit difference: 0.000000
   Max probability difference: 0.000000
   Mean probability difference: 0.000000

üéØ Prediction Accuracy:
   PyTorch predicted: G10 (ID: 9)
   TFLite predicted: G10 (ID: 9)
   Same prediction: ‚úì

üìà Conversion Quality:
   ‚úÖ Excellent! Nearly identical outputs.

‚úÖ TFLite model ready for deployment: signbart_arabic_asl_user08_epoch79.tflite
   Input: (1, 114, 100, 2) keypoints + (1, 114) attention_mask
   Output: (1, 10) class logits
   File size: 2.94 MB


In [19]:
# ============================================================================
# Correct TFLite Parameter Analysis
# ============================================================================
print("\n" + "="*80)
print("TFLite Model Parameter Analysis (Corrected)")
print("="*80)

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path=output_path)
interpreter.allocate_tensors()

# Get tensor details
tensor_details = interpreter.get_tensor_details()

# Only count actual model weights (stored in the flatbuffer)
# These are the tensors that are initialized from the model file
total_params = 0
weight_tensors = []

for tensor in tensor_details:
    # Check if this tensor is a constant (stored in the model file)
    # Activation tensors are allocated at runtime and don't contribute to model size
    quantization = tensor.get('quantization_parameters', {})
    
    # A tensor is a model parameter if it's used as input to ops but not an input/output
    # and has data stored in the model file
    shape = tensor['shape']
    name = tensor['name']
    
    # Skip input/output tensors
    if 'serving_default' in name:
        continue
    
    # Check if it's a weight by trying to read it
    try:
        # Try to get the tensor data - only works for constant tensors (weights)
        tensor_idx = tensor['index']
        tensor_data = interpreter.tensor(tensor_idx)()
        
        # If we can read it and it's not empty, it's a weight
        if tensor_data is not None and tensor_data.size > 0:
            num_elements = tensor_data.size
            total_params += num_elements
            weight_tensors.append({
                'name': name,
                'shape': shape,
                'params': num_elements,
                'dtype': tensor['dtype']
            })
    except:
        # If we can't read it, it's an intermediate tensor, not a weight
        pass

print(f"\nüìä Model Statistics:")
print(f"   Total model parameters: {total_params:,}")
print(f"   Weight tensors: {len(weight_tensors)}")

# Model size analysis
file_size_bytes = os.path.getsize(output_path)
file_size_mb = file_size_bytes / (1024 * 1024)
print(f"   TFLite file size: {file_size_mb:.2f} MB")

# Actual parameter size (most are quantized/compressed in TFLite)
actual_param_bytes = sum(w['params'] * np.dtype(w['dtype']).itemsize for w in weight_tensors)
actual_param_mb = actual_param_bytes / (1024 * 1024)
print(f"   Actual parameter size: {actual_param_mb:.2f} MB")

# Show largest actual weight tensors
print(f"\nüìã Top 20 Largest Weight Tensors:")
weight_tensors_sorted = sorted(weight_tensors, key=lambda x: x['params'], reverse=True)
for i, tensor in enumerate(weight_tensors_sorted[:20], 1):
    size_kb = tensor['params'] * np.dtype(tensor['dtype']).itemsize / 1024
    dtype_name = str(tensor['dtype']).split('.')[-1]
    print(f"   {i:2d}. {str(tensor['shape']):<20} {tensor['params']:>10,} params ({size_kb:>8.2f} KB) [{dtype_name}]")
    if len(tensor['name']) < 80:
        print(f"       {tensor['name']}")

# Compare with PyTorch
print(f"\nüîÑ PyTorch vs TFLite Comparison:")
pytorch_total = sum(p.numel() for p in signbart.parameters())
pytorch_size_mb = pytorch_total * 4 / (1024 * 1024)  # float32

print(f"   PyTorch parameters: {pytorch_total:,}")
print(f"   TFLite parameters: {total_params:,}")
print(f"   PyTorch size (float32): {pytorch_size_mb:.2f} MB")
print(f"   TFLite file size: {file_size_mb:.2f} MB")

if abs(pytorch_total - total_params) / pytorch_total < 0.01:
    print(f"   ‚úÖ Parameter counts match!")
else:
    diff_pct = abs(pytorch_total - total_params) / pytorch_total * 100
    print(f"   ‚ö†Ô∏è  Difference: {diff_pct:.1f}%")

compression_ratio = pytorch_size_mb / file_size_mb
print(f"   Compression: {compression_ratio:.2f}x")

print("="*80)


TFLite Model Parameter Analysis (Corrected)

üìä Model Statistics:
   Total model parameters: 1,494,289
   Weight tensors: 160
   TFLite file size: 2.94 MB
   Actual parameter size: 5.63 MB

üìã Top 20 Largest Weight Tensors:
    1. [  8 114 114]           103,968 params (  406.12 KB) [float32'>]
    2. [  1   8 114 114]       103,968 params (  406.12 KB) [float32'>]
    3. [  1   8 114 114]       103,968 params (  406.12 KB) [float32'>]
    4. [  1   8 114 114]       103,968 params (  406.12 KB) [float32'>]
    5. [  1   8 114 114]       103,968 params (  406.12 KB) [float32'>]
    6. [144 144]                20,736 params (   81.00 KB) [float32'>]
    7. [144 144]                20,736 params (   81.00 KB) [float32'>]
    8. [144 144]                20,736 params (   81.00 KB) [float32'>]
    9. [144 144]                20,736 params (   81.00 KB) [float32'>]
   10. [144 144]                20,736 params (   81.00 KB) [float32'>]
   11. [144 144]                20,736 params (   8

In [20]:
# ============================================================================
# Correct Parameter Count - Direct from File
# ============================================================================
print("\n" + "="*80)
print("TFLite Model Analysis (File-Based)")
print("="*80)

# The most reliable way: file size divided by bytes per parameter
file_size_bytes = os.path.getsize(output_path)
file_size_mb = file_size_bytes / (1024 * 1024)

# TFLite has overhead (model structure, metadata), but most is weights
# For float32 models: roughly file_size / 4 = number of parameters
estimated_params_from_size = file_size_bytes // 4

print(f"\nüìä File-Based Analysis:")
print(f"   TFLite file size: {file_size_mb:.2f} MB ({file_size_bytes:,} bytes)")
print(f"   Estimated parameters (file_size / 4): ~{estimated_params_from_size:,}")

# Compare with PyTorch
pytorch_total = sum(p.numel() for p in signbart.parameters())
pytorch_size_mb = pytorch_total * 4 / (1024 * 1024)

print(f"\nüîÑ Comparison:")
print(f"   PyTorch parameters: {pytorch_total:,}")
print(f"   PyTorch size (float32): {pytorch_size_mb:.2f} MB")
print(f"   TFLite file size: {file_size_mb:.2f} MB")

# The file sizes match almost exactly, so the parameter counts should too
size_ratio = file_size_mb / pytorch_size_mb
print(f"   Size ratio: {size_ratio:.3f}")

if 0.95 <= size_ratio <= 1.05:
    print(f"   ‚úÖ Sizes match! TFLite likely has ~{pytorch_total:,} parameters")
    print(f"   (The extra {file_size_bytes - pytorch_total*4:,} bytes are model metadata/structure)")
else:
    print(f"   ‚ö†Ô∏è  Size mismatch - investigating...")

print("="*80)


TFLite Model Analysis (File-Based)

üìä File-Based Analysis:
   TFLite file size: 2.94 MB (3,080,628 bytes)
   Estimated parameters (file_size / 4): ~770,157

üîÑ Comparison:
   PyTorch parameters: 776,458
   PyTorch size (float32): 2.96 MB
   TFLite file size: 2.94 MB
   Size ratio: 0.992
   ‚úÖ Sizes match! TFLite likely has ~776,458 parameters
   (The extra -25,204 bytes are model metadata/structure)
