Export CTC Head from TDT-CTC-110M for Unified Custom Vocabulary
Summary
The Parakeet TDT-CTC-110M model contains a CTC head (used during training as auxiliary loss), but the current CoreML export only includes the TDT decoder path. By exporting the CTC logits as an additional output, we can eliminate the need for a separate CTC encoder in the custom vocabulary pipeline, reducing memory usage and simplifying the architecture.
Problem
The current custom vocabulary architecture requires two separate encoder models:

```
TDT Encoder (TDT-CTC-110M) ──→ Primary transcription (3.01% WER)
            +
CTC Encoder (CTC 110M)     ──→ Keyword spotting for vocabulary boosting
            ↓
Combined by VocabularyRescorer ──→ Final transcript with domain terms
```
Current costs:
- Total model size: ~470MB (TDT-CTC-110M: 410MB + CTC 110M: 60MB)
- Peak memory: ~130MB (both encoders loaded)
- Inference cost: 2× encoder runs (separate audio processing)
- Frame alignment: Manual synchronization needed between two encoders
Proposed Solution
Export the CTC head from the TDT-CTC-110M encoder as an additional CoreML output:
```python
# In mobius/parakeet-tdt-ctc-110m/convert.py
encoder_coreml = ct.convert(
    encoder,
    inputs=[...],
    outputs=[
        ct.TensorType(name="encoder", ...),         # TDT path (existing)
        ct.TensorType(name="encoder_length", ...),  # TDT path (existing)
        ct.TensorType(name="ctc_logits", ...),      # CTC head (NEW!)
    ],
)
```

New unified architecture:

```
TDT-CTC-110M Encoder ──→ encoder features (TDT path)
                     ├──→ ctc_logits (CTC path)
                              ↓
            TDT Decoder + CTC Keyword Spotter
                              ↓
                    VocabularyRescorer
                              ↓
          Final transcript with domain terms
```
Benefits
| Metric | Current (Dual Encoder) | With CTC Head Export | Improvement |
|---|---|---|---|
| Models needed | 2 (TDT + CTC) | 1 (TDT-CTC only) | -1 model |
| Total size | ~470MB | ~410MB | -60MB (13%) |
| Encoder runs | 2 (separate) | 1 (shared) | 2× faster |
| Peak memory | ~130MB | ~90MB (est.) | -40MB (31%) |
| Frame alignment | Manual sync | Perfect (same encoder) | Guaranteed |
Implementation Steps
1. Model Conversion (in mobius repo)
File: mobius/parakeet-tdt-ctc-110m/convert.py
- Export CTC logits as additional output from encoder
- Test CTC head quality against standalone CTC 110M
- Verify frame alignment with TDT timestamps
- Upload updated model to HuggingFace
Expected output shape:
`ctc_logits: [1, T, 1024]` (per-frame token log-probabilities)
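As a quick sanity check on the exported output, the logits can be verified to have the expected shape and to be normalized log-probabilities over the vocabulary. A minimal numpy sketch with a synthetic stand-in array (real values would come from the CoreML model's prediction dictionary):

```python
import numpy as np

# Synthetic stand-in for the exported ctc_logits output: [1, T, V]
T, V = 100, 1024
raw = np.random.randn(1, T, V).astype(np.float64)

# CTC heads typically emit log-softmax outputs; normalize the stand-in the same way
ctc_logits = raw - np.log(np.exp(raw).sum(axis=-1, keepdims=True))

# Shape matches the documented output
assert ctc_logits.shape == (1, T, V)

# Each frame's probabilities should sum to ~1 after exponentiation
frame_sums = np.exp(ctc_logits).sum(axis=-1)
assert np.allclose(frame_sums, 1.0, atol=1e-6)
```

The same two checks (shape and per-frame normalization) make a cheap regression test when validating the converted model against the standalone CTC 110M.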
2. FluidAudio Code Changes
File: Sources/FluidAudio/ASR/AsrTranscription.swift
```swift
// Extract CTC logits from fused preprocessor output (if available)
if let ctcLogits = encoderOutputProvider.featureValue(for: "ctc_logits")?.multiArrayValue {
    // Cache for custom vocabulary use
    self.cachedCtcLogits = ctcLogits
    self.cachedCtcFrameDuration = 0.04  // 40ms per frame
}
```

File: Sources/FluidAudio/ASR/AsrManager.swift
```swift
// Add cached CTC logits storage
private var cachedCtcLogits: MLMultiArray?
private var cachedCtcFrameDuration: Double?

// Update vocabulary rescoring to use cached logits
private func applyVocabularyRescoring(result: ASRResult, audioSamples: [Float]) async -> ASRResult {
    guard let rescorer = vocabularyRescorer else { return result }
    guard let ctcLogits = cachedCtcLogits else {
        logger.warning("No CTC logits cached, skipping vocabulary rescoring")
        return result
    }
    // Use cached logits instead of running separate CTC encoder
    return await rescorer.rescore(
        tdtResult: result,
        ctcLogits: ctcLogits,
        frameDuration: cachedCtcFrameDuration ?? 0.04
    )
}
```

File: Sources/FluidAudio/ASR/CustomVocabulary/Rescorer/VocabularyRescorer.swift
```swift
// Add method to accept pre-computed CTC logits
public func rescore(
    tdtResult: ASRResult,
    ctcLogits: MLMultiArray,
    frameDuration: Double
) async -> ASRResult {
    // Convert MLMultiArray to [[Float]] log-probs
    let logProbs = convertToLogProbs(ctcLogits)
    // Run keyword spotting using cached logits
    let detections = spotKeywords(logProbs: logProbs, frameDuration: frameDuration)
    // Apply rescoring...
}
```

3. Testing & Validation
- Unit tests for CTC logits extraction
- Verify frame alignment: TDT timestamps match CTC frame indices
- Benchmark WER with vocabulary rescoring (should match current dual-encoder approach)
- Memory profiling: confirm ~40MB reduction
- Inference speed: confirm 2× speedup on custom vocabulary workloads
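As a concrete basis for such unit tests: keyword spotting over the cached log-probs reduces to standard CTC decoding. A minimal greedy-decode sketch with a toy vocabulary and synthetic posteriors (illustrative only, not the actual spotKeywords implementation):

```python
import numpy as np

BLANK = 0
VOCAB = {1: "hel", 2: "lo"}  # toy subword vocabulary

def greedy_ctc_decode(log_probs: np.ndarray):
    """Greedy CTC decode: argmax per frame, collapse repeats, drop blanks.

    Returns (token_id, first_frame_index) pairs."""
    best = log_probs.argmax(axis=-1)
    out, prev = [], BLANK
    for t, tok in enumerate(best):
        if tok != BLANK and tok != prev:
            out.append((int(tok), t))
        prev = tok
    return out

# Synthetic 6-frame posteriors that spell "hel" "lo" surrounded by blanks
log_probs = np.full((6, 3), -10.0)
for t, tok in enumerate([0, 1, 1, 0, 2, 0]):
    log_probs[t, tok] = 0.0

tokens = greedy_ctc_decode(log_probs)
assert [VOCAB[t] for t, _ in tokens] == ["hel", "lo"]
# Frame indices convert to timestamps at 40 ms per frame
assert [round(f * 0.04, 2) for _, f in tokens] == [0.04, 0.16]
```

A test like this pins down both the decode behavior and the frame-to-timestamp conversion, so a regression in either shows up immediately.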
4. Documentation Updates
- Update Documentation/ASR/CustomVocabulary.md to reflect unified architecture
- Add example code showing simplified API (no separate CTC model loading)
- Update benchmark results with new memory/speed metrics
Technical Details
CTC Head Location in Model
The TDT-CTC-110M model architecture:
```
Audio → MelSpectrogram → FastConformer Encoder
                              ├──→ TDT Path → Decoder LSTM → Joint Network
                              └──→ CTC Head → Token Logits [T, 1024]
```
During training, the CTC head provides auxiliary supervision. The head is already present in the encoder; it just needs to be exported in the CoreML conversion.
Frame Alignment
Both TDT and CTC operate on the same encoder frames:
- Frame rate: ~40ms per frame (25 frames per second)
- TDT timestamps: Frame indices when tokens emitted
- CTC logits: Per-frame probabilities at same indices
- Perfect alignment: Same encoder → guaranteed frame correspondence
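The timestamp-to-frame mapping is simple arithmetic. A minimal sketch assuming the 40 ms frame duration above (helper names are illustrative, not existing API):

```python
FRAME_DURATION = 0.04  # 40 ms per encoder frame (25 fps)

def frame_to_time(frame_index: int) -> float:
    """Convert a shared encoder frame index to seconds."""
    return frame_index * FRAME_DURATION

def time_to_frame(seconds: float) -> int:
    """Convert a timestamp in seconds to the nearest encoder frame index."""
    return round(seconds / FRAME_DURATION)

# A TDT token emitted at frame 50 lands at 2.0 s; the same CTC frame
# index can be recovered from that timestamp without any offset correction.
assert abs(frame_to_time(50) - 2.0) < 1e-9
assert time_to_frame(2.0) == 50
```

Because both paths read the same encoder output, this mapping is identical for TDT timestamps and CTC logit rows, which is what makes the dual-encoder synchronization code unnecessary.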
Backward Compatibility
This change is fully backward compatible:
- Old models without ctc_logits output: use separate CTC encoder (current behavior)
- New models with ctc_logits output: use unified approach (automatic)
Detection logic:
```swift
if encoderOutputProvider.featureValue(for: "ctc_logits") != nil {
    // New unified approach
} else {
    // Fallback to separate CTC encoder
}
```

Example Usage (After Implementation)
Before (current):
```swift
// Load two models
let tdtModels = try await AsrModels.downloadAndLoad(version: .tdtCtc110m)
let ctcModels = try await CtcModels.downloadAndLoad(variant: .ctc110m)

let manager = AsrManager()
try await manager.initialize(models: tdtModels)
try await manager.enableVocabularyBoosting(ctcModels: ctcModels, terms: customTerms)

let result = try await manager.transcribe(audio)
// Two encoders run, manual frame alignment needed
```

After (unified):
```swift
// Load single model
let models = try await AsrModels.downloadAndLoad(version: .tdtCtc110m)

let manager = AsrManager()
try await manager.initialize(models: models)
try await manager.enableVocabularyBoosting(terms: customTerms)  // No CTC models needed!

let result = try await manager.transcribe(audio)
// Single encoder run, CTC logits extracted automatically
```

Open Questions
1. CTC head quality: Does the auxiliary CTC head match the quality of the dedicated CTC 110M model?
   - Need to benchmark on custom vocabulary tasks
   - Compare keyword spotting F1 scores
2. Model file size: Will adding the CTC output significantly increase the .mlmodelc size?
   - Likely negligible (the head already exists; this just exposes its output)
3. ANE compatibility: Do CTC logits run efficiently on the Apple Neural Engine?
   - Need to profile with Instruments
Related Work
- PR feat: Support Parakeet-TDT-CTC-110M hybrid model #433: Add TDT-CTC-110M support
- Issue #XXX: Custom vocabulary documentation (if exists)
- NeMo paper: CTC-based Word Spotter (arXiv:2406.07096)
References
- Model: FluidInference/parakeet-tdt-ctc-110m-coreml
- CTC Model: FluidInference/parakeet-ctc-110m-coreml
- Conversion repo: mobius/parakeet-tdt-ctc-110m/
- Custom vocabulary docs: Documentation/ASR/CustomVocabulary.md