# VibeVoice Model - Complete Layer-by-Layer Walkthrough

This notebook walks through **every layer** of the VibeVoice model, showing tensor dimensions, annotations, and purpose.

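**Walkthrough order**: Input Embeddings → Language Model (28 layers) → Tokenizers → Connectors → Prediction Head (this is the order of the sections below, not the order in which data flows through the model)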


In [25]:
# Setup: Load model
import sys
import os
project_root = os.path.abspath('../')
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference

model = VibeVoiceForConditionalGenerationInference.from_pretrained(
    "vibevoice/VibeVoice-1.5B",
    torch_dtype="auto",
    device_map="auto",
)


Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  5.37it/s]


## Part 1: Language Model - Input Embedding Layer


In [26]:
# 1. Language Model Embedding Layer
layer = model.model.language_model.embed_tokens
print(layer)
print(f"\nWeight tensor:")
print(f"  layer.weight.shape = {layer.weight.shape}")
print(f"  layer.weight.dtype = {layer.weight.dtype}")

print(f"\nTensor shape breakdown:")
print(f"  shape[0] = {layer.weight.shape[0]} → vocab_size (number of tokens in vocabulary)")
print(f"  shape[1] = {layer.weight.shape[1]} → hidden_size (embedding dimension)")

print(f"\nPurpose: Maps token IDs to dense embeddings")
print(f"  Input: token_ids [B] (integers) → Output: embedding vectors [B, {layer.weight.shape[1]}]")
print(f"  When used: input_ids [B, L] → embeddings [B, L, {layer.weight.shape[1]}]")
print(f"\nThis is the FIRST layer - converts discrete tokens to continuous embeddings")


Embedding(151936, 1536)

Weight tensor:
  layer.weight.shape = torch.Size([151936, 1536])
  layer.weight.dtype = torch.bfloat16

Tensor shape breakdown:
  shape[0] = 151936 → vocab_size (number of tokens in vocabulary)
  shape[1] = 1536 → hidden_size (embedding dimension)

Purpose: Maps token IDs to dense embeddings
  Input: token_ids [B] (integers) → Output: embedding vectors [B, 1536]
  When used: input_ids [B, L] → embeddings [B, L, 1536]

This is the FIRST layer - converts discrete tokens to continuous embeddings
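An embedding layer is a row lookup into its `[vocab_size, hidden_size]` weight matrix. The stdlib-only sketch below is a toy illustration, not VibeVoice code. `embedding_lookup` and `embedding_param_count` are hypothetical helpers. The sketch performs the lookup on a tiny table, then sizes the real table from the shapes printed above.

```python
def embedding_lookup(weight, input_ids):
    """Return weight[token_id] for every id in a [B, L] nested list -> [B, L, hidden]."""
    return [[list(weight[token_id]) for token_id in row] for row in input_ids]

def embedding_param_count(vocab_size, hidden_size):
    """Number of scalars in an embedding table of shape [vocab_size, hidden_size]."""
    return vocab_size * hidden_size

# Toy table: vocab_size=4, hidden_size=3 (row i is the embedding of token i)
toy_weight = [[0.0, 0.1, 0.2],
              [1.0, 1.1, 1.2],
              [2.0, 2.1, 2.2],
              [3.0, 3.1, 3.2]]
out = embedding_lookup(toy_weight, [[3, 0], [1, 1]])  # input_ids with shape [B=2, L=2]
print(len(out), len(out[0]), len(out[0][0]))          # 2 2 3 -> [B, L, hidden]

n = embedding_param_count(151936, 1536)
print(n)      # 233373696 parameters
print(n * 2)  # 466747392 bytes in bfloat16 (2 bytes per value)
```

The table alone holds roughly 233M of the language model's parameters.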


## Part 2: Language Model - Transformer Layers (28 Layers)

The model has **28 transformer decoder layers with identical structure** (each layer has its own weights). We'll examine the first layer in detail; the same pattern repeats in the other 27 layers.


In [27]:
# 2. Language Model Layer 0 - Self-Attention (Q, K, V, O projections)
layer = model.model.language_model.layers[0]
print(f"=== Qwen2DecoderLayer (Layer 0 of {len(model.model.language_model.layers)}) ===")
print(f"\nThis layer pattern repeats for all {len(model.model.language_model.layers)} layers")
print("\n--- Self-Attention Projections ---")

# Q projection
q_proj = layer.self_attn.q_proj
print(f"\nq_proj: {q_proj}")
print(f"  q_proj.weight.shape = {q_proj.weight.shape}")
print(f"  q_proj.weight.dtype = {q_proj.weight.dtype}")
print(f"  q_proj.bias.shape = {q_proj.bias.shape if q_proj.bias is not None else None}")
print(f"  Tensor shape: [out_dim={q_proj.weight.shape[0]}, in_dim={q_proj.weight.shape[1]}]")
print(f"  Purpose: Projects hidden states [B, L, {q_proj.weight.shape[1]}] → queries [B, L, {q_proj.weight.shape[0]}]")

# K projection (GQA - smaller dimension)
k_proj = layer.self_attn.k_proj
print(f"\nk_proj: {k_proj}")
print(f"  k_proj.weight.shape = {k_proj.weight.shape}")
print(f"  k_proj.weight.dtype = {k_proj.weight.dtype}")
print(f"  k_proj.bias.shape = {k_proj.bias.shape if k_proj.bias is not None else None}")
print(f"  Tensor shape: [key_dim={k_proj.weight.shape[0]}, in_dim={k_proj.weight.shape[1]}] (smaller for GQA)")
print(f"  Purpose: Projects hidden states [B, L, {k_proj.weight.shape[1]}] → keys [B, L, {k_proj.weight.shape[0]}]")

# V projection (GQA - smaller dimension)
v_proj = layer.self_attn.v_proj
print(f"\nv_proj: {v_proj}")
print(f"  v_proj.weight.shape = {v_proj.weight.shape}")
print(f"  v_proj.weight.dtype = {v_proj.weight.dtype}")
print(f"  v_proj.bias.shape = {v_proj.bias.shape if v_proj.bias is not None else None}")
print(f"  Tensor shape: [value_dim={v_proj.weight.shape[0]}, in_dim={v_proj.weight.shape[1]}] (smaller for GQA)")
print(f"  Purpose: Projects hidden states [B, L, {v_proj.weight.shape[1]}] → values [B, L, {v_proj.weight.shape[0]}]")

# O projection
o_proj = layer.self_attn.o_proj
print(f"\no_proj: {o_proj}")
print(f"  o_proj.weight.shape = {o_proj.weight.shape}")
print(f"  o_proj.weight.dtype = {o_proj.weight.dtype}")
print(f"  o_proj.bias = {o_proj.bias is not None}")
print(f"  Tensor shape: [out_dim={o_proj.weight.shape[0]}, in_dim={o_proj.weight.shape[1]}]")
print(f"  Purpose: Projects attention output [B, L, {o_proj.weight.shape[1]}] → [B, L, {o_proj.weight.shape[0]}]")

print(f"\n⚠️  This attention block (q_proj, k_proj, v_proj, o_proj) repeats in all {len(model.model.language_model.layers)} layers")


=== Qwen2DecoderLayer (Layer 0 of 28) ===

This layer pattern repeats for all 28 layers

--- Self-Attention Projections ---

q_proj: Linear(in_features=1536, out_features=1536, bias=True)
  q_proj.weight.shape = torch.Size([1536, 1536])
  q_proj.weight.dtype = torch.bfloat16
  q_proj.bias.shape = torch.Size([1536])
  Tensor shape: [out_dim=1536, in_dim=1536]
  Purpose: Projects hidden states [B, L, 1536] → queries [B, L, 1536]

k_proj: Linear(in_features=1536, out_features=256, bias=True)
  k_proj.weight.shape = torch.Size([256, 1536])
  k_proj.weight.dtype = torch.bfloat16
  k_proj.bias.shape = torch.Size([256])
  Tensor shape: [key_dim=256, in_dim=1536] (smaller for GQA)
  Purpose: Projects hidden states [B, L, 1536] → keys [B, L, 256]

v_proj: Linear(in_features=1536, out_features=256, bias=True)
  v_proj.weight.shape = torch.Size([256, 1536])
  v_proj.weight.dtype = torch.bfloat16
  v_proj.bias.shape = torch.Size([256])
  Tensor shape: [value_dim=256, in_dim=1536] (smaller for GQA)
  Purpose: Projects hidden states [B, L, 1536] → values [B, L, 256]
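The q/k/v widths above imply grouped-query attention (GQA). This sketch assumes a head dimension of 128, the usual Qwen2-1.5B setting, which this notebook does not print. From that, it derives the head counts and compares the per-token KV cache with full multi-head attention. `gqa_layout` is a hypothetical helper.

```python
def gqa_layout(q_dim, kv_dim, head_dim):
    """Derive head counts for grouped-query attention from projection widths."""
    assert q_dim % head_dim == 0 and kv_dim % head_dim == 0
    num_q_heads = q_dim // head_dim
    num_kv_heads = kv_dim // head_dim
    assert num_q_heads % num_kv_heads == 0  # each KV head serves a whole group of query heads
    return {
        "num_q_heads": num_q_heads,
        "num_kv_heads": num_kv_heads,
        "queries_per_kv_head": num_q_heads // num_kv_heads,
        # K and V are both cached, so per token and per layer we store 2 * kv_dim values
        "kv_cache_per_token_per_layer": 2 * kv_dim,
        "mha_kv_cache_per_token_per_layer": 2 * q_dim,  # if K/V were as wide as Q
    }

layout = gqa_layout(q_dim=1536, kv_dim=256, head_dim=128)  # head_dim=128 is an assumption
print(layout)
```

Under that assumption there are 12 query heads sharing 2 KV heads, 6 per group. Summed over 28 layers, the cache holds 28 × 512 = 14,336 K/V values per token. Full multi-head attention would need 86,016.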

In [28]:
# 2b. Language Model Layer 0 - MLP (Feed-Forward Network)
layer = model.model.language_model.layers[0]

print("--- MLP (Feed-Forward Network) ---")

# Gate projection
gate_proj = layer.mlp.gate_proj
print(f"\ngate_proj: {gate_proj}")
print(f"  gate_proj.weight.shape = {gate_proj.weight.shape}")
print(f"  gate_proj.weight.dtype = {gate_proj.weight.dtype}")
print(f"  Tensor shape: [out={gate_proj.weight.shape[0]}, in={gate_proj.weight.shape[1]}]")
print(f"  Purpose: Projects [B, L, {gate_proj.weight.shape[1]}] → gate [B, L, {gate_proj.weight.shape[0]}] (GLU gate signal)")

# Up projection
up_proj = layer.mlp.up_proj
print(f"\nup_proj: {up_proj}")
print(f"  up_proj.weight.shape = {up_proj.weight.shape}")
print(f"  up_proj.weight.dtype = {up_proj.weight.dtype}")
print(f"  Tensor shape: [out={up_proj.weight.shape[0]}, in={up_proj.weight.shape[1]}]")
print(f"  Purpose: Projects [B, L, {up_proj.weight.shape[1]}] → value [B, L, {up_proj.weight.shape[0]}] (GLU value signal)")

# Down projection
down_proj = layer.mlp.down_proj
print(f"\ndown_proj: {down_proj}")
print(f"  down_proj.weight.shape = {down_proj.weight.shape}")
print(f"  down_proj.weight.dtype = {down_proj.weight.dtype}")
print(f"  Tensor shape: [out={down_proj.weight.shape[0]}, in={down_proj.weight.shape[1]}]")
print(f"  Purpose: Projects GLU output [B, L, {down_proj.weight.shape[1]}] → [B, L, {down_proj.weight.shape[0]}]")

print(f"\nMLP Flow: hidden_size [{gate_proj.weight.shape[1]}] → ffn_dim [{gate_proj.weight.shape[0]}] → hidden_size [{down_proj.weight.shape[0]}]")
print(f"⚠️  This MLP block (gate_proj, up_proj, down_proj) repeats in all {len(model.model.language_model.layers)} layers")


--- MLP (Feed-Forward Network) ---

gate_proj: Linear(in_features=1536, out_features=8960, bias=False)
  gate_proj.weight.shape = torch.Size([8960, 1536])
  gate_proj.weight.dtype = torch.bfloat16
  Tensor shape: [out=8960, in=1536]
  Purpose: Projects [B, L, 1536] → gate [B, L, 8960] (GLU gate signal)

up_proj: Linear(in_features=1536, out_features=8960, bias=False)
  up_proj.weight.shape = torch.Size([8960, 1536])
  up_proj.weight.dtype = torch.bfloat16
  Tensor shape: [out=8960, in=1536]
  Purpose: Projects [B, L, 1536] → value [B, L, 8960] (GLU value signal)

down_proj: Linear(in_features=8960, out_features=1536, bias=False)
  down_proj.weight.shape = torch.Size([1536, 8960])
  down_proj.weight.dtype = torch.bfloat16
  Tensor shape: [out=1536, in=8960]
  Purpose: Projects GLU output [B, L, 8960] → [B, L, 1536]

MLP Flow: hidden_size [1536] → ffn_dim [8960] → hidden_size [1536]
⚠️  This MLP block (gate_proj, up_proj, down_proj) repeats in all 28 layers
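The gate/up/down trio implements a gated MLP. This sketch assumes the SiLU activation from Qwen2's default config (`hidden_act="silu"`), which is not printed here. Under that assumption the forward pass is `down(silu(gate(x)) * up(x))`. The stdlib-only code below uses toy weights; `matvec`, `silu` and `swiglu_mlp` are hypothetical helpers. The last line gives the per-layer weight count from the shapes above.

```python
import math

def matvec(W, x):
    """Multiply a weight matrix stored as [out, in] (PyTorch Linear layout) by a vector x[in]."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def silu(v):
    """SiLU / swish: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))

def swiglu_mlp(x, W_gate, W_up, W_down):
    gate = [silu(g) for g in matvec(W_gate, x)]  # [ffn_dim], activated gate signal
    up = matvec(W_up, x)                         # [ffn_dim], value signal
    hidden = [g * u for g, u in zip(gate, up)]   # element-wise gating
    return matvec(W_down, hidden)                # back to [hidden_size]

# Toy sizes: hidden_size=2, ffn_dim=3, no biases (matching bias=False above)
W_gate = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_up   = [[1.0, 1.0], [1.0, -1.0], [0.5, 0.5]]
W_down = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]
y = swiglu_mlp([0.0, 0.0], W_gate, W_up, W_down)
print(y)  # [0.0, 0.0]: silu(0) = 0, so a zero input gives a zero output

hidden_size, ffn_dim = 1536, 8960
print(3 * hidden_size * ffn_dim)  # 41287680 MLP weights per layer
```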


In [29]:
# 2c. Language Model Layer 0 - Layer Normalizations
layer = model.model.language_model.layers[0]

print("--- Layer Normalizations ---")

# Input layer norm
input_norm = layer.input_layernorm
print(f"\ninput_layernorm: {input_norm}")
print(f"  input_norm.weight.shape = {input_norm.weight.shape}")
print(f"  input_norm.weight.dtype = {input_norm.weight.dtype}")
print(f"  Tensor shape: [{input_norm.weight.shape[0]}] (1D tensor, hidden_size)")
print(f"  Purpose: Normalizes input [B, L, {input_norm.weight.shape[0]}] before self-attention (pre-norm)")

# Post-attention layer norm
post_norm = layer.post_attention_layernorm
print(f"\npost_attention_layernorm: {post_norm}")
print(f"  post_norm.weight.shape = {post_norm.weight.shape}")
print(f"  post_norm.weight.dtype = {post_norm.weight.dtype}")
print(f"  Tensor shape: [{post_norm.weight.shape[0]}] (1D tensor, hidden_size)")
print(f"  Purpose: Normalizes [B, L, {post_norm.weight.shape[0]}] after attention, before MLP")

print(f"\n⚠️  These normalization layers repeat in all {len(model.model.language_model.layers)} transformer layers")
print(f"\n📊 Summary: Each of the {len(model.model.language_model.layers)} layers contains:")
print(f"   - input_layernorm → self_attn (q/k/v/o) → residual add → post_attention_layernorm → MLP (gate/up/down) → residual add")


--- Layer Normalizations ---

input_layernorm: Qwen2RMSNorm((1536,), eps=1e-06)
  input_norm.weight.shape = torch.Size([1536])
  input_norm.weight.dtype = torch.bfloat16
  Tensor shape: [1536] (1D tensor, hidden_size)
  Purpose: Normalizes input [B, L, 1536] before self-attention (pre-norm)

post_attention_layernorm: Qwen2RMSNorm((1536,), eps=1e-06)
  post_norm.weight.shape = torch.Size([1536])
  post_norm.weight.dtype = torch.bfloat16
  Tensor shape: [1536] (1D tensor, hidden_size)
  Purpose: Normalizes [B, L, 1536] after attention, before MLP

⚠️  These normalization layers repeat in all 28 transformer layers

📊 Summary: Each of the 28 layers contains:
   - input_layernorm → self_attn (q/k/v/o) → residual add → post_attention_layernorm → MLP (gate/up/down) → residual add
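`Qwen2RMSNorm` rescales each hidden vector by its root-mean-square, then multiplies by a learned per-channel weight. It does no mean subtraction and adds no bias. The sketch below shows that formula and the pre-norm residual wiring of one decoder layer. `identity` is a placeholder for the real attention and MLP sub-blocks, and all helper names are hypothetical.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm: y_i = weight_i * x_i / sqrt(mean(x^2) + eps). No mean subtraction, no bias."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for w, v in zip(weight, x)]

def decoder_layer(x, attn, mlp, w_in, w_post, eps=1e-6):
    """Pre-norm decoder layer: normalize, transform, then add the residual back."""
    # input_layernorm -> self_attn -> residual add
    h = [a + b for a, b in zip(x, attn(rms_norm(x, w_in, eps)))]
    # post_attention_layernorm -> MLP -> residual add
    return [a + b for a, b in zip(h, mlp(rms_norm(h, w_post, eps)))]

def identity(v):
    """Placeholder for the real attention / MLP sub-blocks."""
    return v

x = [3.0, 4.0]
print(rms_norm(x, [1.0, 1.0]))  # about [0.8485, 1.1314], since RMS([3, 4]) = sqrt(12.5) ≈ 3.5355
print(decoder_layer(x, attn=identity, mlp=identity, w_in=[1.0, 1.0], w_post=[1.0, 1.0]))
```

Because the residual path bypasses the norms, the unnormalized hidden state flows from layer to layer. Only the input to each sub-block is normalized.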


In [30]:
# 3. Language Model Final Norm
layer = model.model.language_model.norm
print(layer)
print(f"\n  norm.weight.shape = {layer.weight.shape}")
print(f"  norm.weight.dtype = {layer.weight.dtype}")
print(f"  Tensor shape: [{layer.weight.shape[0]}] (1D tensor)")
print(f"\nPurpose: Final normalization layer after all {len(model.model.language_model.layers)} transformer layers")
print(f"  Input: [B, L, {layer.weight.shape[0]}] → Output: [B, L, {layer.weight.shape[0]}] (normalized)")
print(f"\n✅ The output is the final hidden state [B, L, {layer.weight.shape[0]}] that feeds the output heads (e.g. the prediction head)")


Qwen2RMSNorm((1536,), eps=1e-06)

  norm.weight.shape = torch.Size([1536])
  norm.weight.dtype = torch.bfloat16
  Tensor shape: [1536] (1D tensor)

Purpose: Final normalization layer after all 28 transformer layers
  Input: [B, L, 1536] → Output: [B, L, 1536] (normalized)

✅ The output is the final hidden state [B, L, 1536] that feeds the output heads (e.g. the prediction head)
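Combining the shapes from Parts 1 and 2 gives a rough parameter count for the language model. The sketch assumes `o_proj` has no bias; its printout was cut off above, and Qwen2 normally omits it. The count covers the input embeddings, the 28 layers and the final norm, and excludes any output head. `linear_params` and `lm_param_count` are hypothetical helpers.

```python
def linear_params(in_f, out_f, bias):
    """Weights plus optional bias of an nn.Linear(in_f, out_f)."""
    return in_f * out_f + (out_f if bias else 0)

def lm_param_count(vocab=151936, hidden=1536, kv_dim=256, ffn=8960, layers=28,
                   o_proj_bias=False):  # o_proj_bias=False is an assumption
    attn = (linear_params(hidden, hidden, True)          # q_proj (bias=True above)
            + 2 * linear_params(hidden, kv_dim, True)    # k_proj, v_proj (bias=True above)
            + linear_params(hidden, hidden, o_proj_bias))  # o_proj
    mlp = 3 * linear_params(hidden, ffn, False)          # gate/up/down, bias=False
    norms = 2 * hidden                                   # input_layernorm + post_attention_layernorm
    per_layer = attn + mlp + norms
    embed = vocab * hidden
    final_norm = hidden
    return per_layer, embed + layers * per_layer + final_norm

per_layer, total = lm_param_count()
print(per_layer)  # 46797824
print(total)      # 1543714304, about 1.54B (excluding any output head)
```

The MLP accounts for about 88% of each layer (41.3M of 46.8M), so it dominates the backbone's size.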


## Part 3: Acoustic Tokenizer - Encoder (Audio → 64D Latents)

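The encoder has **7 downsample layers**, each followed by a processing stage of convolutional blocks. The first downsample layer is a stride-1 stem convolution, so only six of them actually reduce the time resolution. We'll examine each one.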


In [None]:
# 4. Acoustic Tokenizer Encoder - Overview
encoder = model.model.acoustic_tokenizer.encoder
print("=== TokenizerEncoder ===")
print(f"Number of downsample stages: {len(encoder.downsample_layers)}")
print(f"Number of processing stages: {len(encoder.stages)}")
print(f"\nPurpose: Encodes audio waveform → 64D acoustic latents")
print(f"Flow: Audio [1 channel] → 7 × (downsample layer → processing stage) → head → 64D output")
print(f"\nWe'll examine each downsample layer and stage block below")


=== TokenizerEncoder ===
Number of downsample stages: 7
Number of processing stages: 7

Purpose: Encodes audio waveform → 64D acoustic latents
Flow: Audio [1 channel] → 7 × (downsample layer → processing stage) → head → 64D output

We'll examine each downsample layer and stage block below


In [32]:
# 4a. Acoustic Tokenizer Encoder - All Downsample Layers
print(f"=== All {len(encoder.downsample_layers)} Downsample Layers ===")
print(f"\nThese layers progressively downsample and increase channels:")
print(f"\nStage | Weight Shape | In→Out Channels | Kernel | Stride | Output Shape")
print(f"------|--------------|-----------------|--------|--------|-------------")

total_stride = 1  # cumulative downsampling factor after each layer
for i, downsample_layer in enumerate(encoder.downsample_layers):
    conv = downsample_layer[0].conv.conv
    out_ch, in_ch, kernel = conv.weight.shape  # Conv1d weight layout: [out_ch, in_ch, kernel]
    stride = conv.stride[0]
    total_stride *= stride  # output length is T / (product of all strides so far)

    print(f"  {i}    | {conv.weight.shape} | {in_ch:2}→{out_ch:2}          | {kernel:6} | {stride:6} | [B, {out_ch}, T/{total_stride}]")

print(f"\n⚠️  These {len(encoder.downsample_layers)} downsample layers progressively reduce temporal resolution")
print(f"   and increase channel depth from 1 → 32 → 64 → 128 → 256 → 512 → 1024 → 2048")
print(f"   Where B=batch, T=time samples; Output Shape uses the cumulative stride")


=== All 7 Downsample Layers ===

These layers progressively downsample and increase channels:

Stage | Weight Shape | In→Out Channels | Kernel | Stride | Output Shape
------|--------------|-----------------|--------|--------|-------------
  0    | torch.Size([32, 1, 7]) |  1→32          |      7 |      1 | [B, 32, T/1]
  1    | torch.Size([64, 32, 4]) | 32→64          |      4 |      2 | [B, 64, T/2]
  2    | torch.Size([128, 64, 4]) | 64→128          |      4 |      2 | [B, 128, T/4]
  3    | torch.Size([256, 128, 8]) | 128→256          |      8 |      4 | [B, 256, T/16]
  4    | torch.Size([512, 256, 10]) | 256→512          |     10 |      5 | [B, 512, T/80]
  5    | torch.Size([1024, 512, 10]) | 512→1024          |     10 |      5 | [B, 1024, T/400]
  6    | torch.Size([2048, 1024, 16]) | 1024→2048          |     16 |      8 | [B, 2048, T/3200]

⚠️  These 7 downsample layers progressively reduce temporal resolution
   and increase channel depth from 1 → 32 → 64 → 128 → 256 → 512 → 1024 → 2048
   Where B=batch, T=time samples; Output Shape uses the cumulative stride
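Multiplying the strides gives the encoder's total downsampling factor. The sketch below recomputes the cumulative factor from the stride column above and converts it into a latent frame rate. The 24 kHz sample rate is an assumption, since this notebook does not print it. `cumulative_strides` is a hypothetical helper.

```python
import operator
from itertools import accumulate

def cumulative_strides(strides):
    """Running product of per-layer strides: the T/N factor after each layer."""
    return list(accumulate(strides, operator.mul))

strides = [1, 2, 2, 4, 5, 5, 8]  # stride column of the table above
cum = cumulative_strides(strides)
print(cum)                       # [1, 2, 4, 16, 80, 400, 3200]

sample_rate = 24_000             # assumption: 24 kHz input audio
hop = cum[-1]                    # samples per latent frame
print(sample_rate / hop)         # 7.5 latent frames per second
print(-(-10 * sample_rate // hop))  # 75 frames for 10 s of audio (ceiling division)
```

Under that assumption, each 64D latent vector summarizes about 133 ms of audio.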

In [33]:
# 4b. Acoustic Tokenizer Encoder - Downsample Layer 0 (Detailed)
downsample_0 = encoder.downsample_layers[0][0]
conv_0 = downsample_0.conv.conv
print(f"=== Downsample Stage 0 (First of {len(encoder.downsample_layers)}) ===")
print(f"Layer: {downsample_0}")
print(f"\nConv1d:")
print(f"  conv.weight.shape = {conv_0.weight.shape}")
print(f"  conv.weight.dtype = {conv_0.weight.dtype}")
print(f"  Tensor shape: [out_ch={conv_0.weight.shape[0]}, in_ch={conv_0.weight.shape[1]}, kernel={conv_0.weight.shape[2]}]")
print(f"  Purpose: Converts [B, {conv_0.weight.shape[1]}, T] → [B, {conv_0.weight.shape[0]}, T']")
print(f"  Where B=batch, T=time samples, T'=downsampled time")
print(f"\n⚠️  Similar structure for all {len(encoder.downsample_layers)} downsample layers (different channels/kernels)")


=== Downsample Stage 0 (First of 7) ===
Layer: SConv1d(
  (conv): NormConv1d(
    (conv): Conv1d(1, 32, kernel_size=(7,), stride=(1,))
    (norm): Identity()
  )
)

Conv1d:
  conv.weight.shape = torch.Size([32, 1, 7])
  conv.weight.dtype = torch.bfloat16
  Tensor shape: [out_ch=32, in_ch=1, kernel=7]
  Purpose: Converts [B, 1, T] → [B, 32, T']
  Where B=batch, T=time samples, T'=downsampled time

⚠️  Similar structure for all 7 downsample layers (different channels/kernels)
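The raw `Conv1d` in the repr shows no padding, so on its own it would shorten the sequence. PyTorch's documented output-length formula makes this concrete. The sketch assumes that the `SConv1d` wrapper pads the input before calling the conv; its exact padding scheme is not shown in this notebook. `conv1d_out_len` is a hypothetical helper.

```python
def conv1d_out_len(length, kernel, stride=1, padding=0, dilation=1):
    """Output length of torch.nn.Conv1d (formula from the PyTorch docs)."""
    return (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

T = 24_000
print(conv1d_out_len(T, kernel=7))  # 23994: an unpadded k=7 conv drops kernel - 1 = 6 samples
# Any padding scheme that adds kernel - 1 = 6 samples in total (e.g. all on the left for a
# causal conv) restores the original length; here it is modeled as 3 per side.
print(conv1d_out_len(T, kernel=7, padding=3))  # 24000
```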


In [34]:
# 4c. Acoustic Tokenizer Encoder - Processing Stages Overview
print(f"=== All {len(encoder.stages)} Processing Stages ===")
print(f"\nEach stage contains Block1D layers that process the downsampled features:")
print(f"\nStage | Block1D Count | Channel Dim | Purpose")
print(f"------|---------------|-------------|---------")

for i, stage in enumerate(encoder.stages):
    num_blocks = len(stage)
    # Get channel dim from first block's norm
    if len(stage) > 0:
        channel_dim = stage[0].norm.weight.shape[0]
        print(f"  {i}    | {num_blocks:13} | {channel_dim:11} | Feature processing at this resolution")
    
print(f"\n⚠️  Each stage has multiple Block1D layers for feature extraction")
print(f"   Total Block1D layers across all stages: {sum(len(stage) for stage in encoder.stages)}")


=== All 7 Processing Stages ===

Each stage contains Block1D layers that process the downsampled features:

Stage | Block1D Count | Channel Dim | Purpose
------|---------------|-------------|---------
  0    |             3 |          32 | Feature processing at this resolution
  1    |             3 |          64 | Feature processing at this resolution
  2    |             3 |         128 | Feature processing at this resolution
  3    |             3 |         256 | Feature processing at this resolution
  4    |             3 |         512 | Feature processing at this resolution
  5    |             3 |        1024 | Feature processing at this resolution
  6    |             8 |        2048 | Feature processing at this resolution

⚠️  Each stage has multiple Block1D layers for feature extraction
   Total Block1D layers across all stages: 26


In [35]:
# 4d. Acoustic Tokenizer Encoder - Block1D Structure (Stage 0, Block 0)
stage_0 = encoder.stages[0]
block_0 = stage_0[0]
print(f"=== Block1D Structure (Stage 0, Block 0 of {len(stage_0)}) ===")
print(f"\nEach Block1D contains:")
print(f"  - ConvRMSNorm (normalization)")
print(f"  - mixer (depthwise Conv1d)")
print(f"  - FFN (feed-forward network)")
print(f"\n⚠️  This Block1D structure repeats in all stages")

# Norm
norm = block_0.norm
print(f"\n--- norm (ConvRMSNorm) ---")
print(f"  norm.weight.shape = {norm.weight.shape}")
print(f"  norm.weight.dtype = {norm.weight.dtype}")
print(f"  Tensor shape: [{norm.weight.shape[0]}] (1D tensor, channel_dim)")
print(f"  Purpose: Normalizes [B, {norm.weight.shape[0]}, T] before convolution")

# Conv in mixer (NormConv1d wraps Conv1d in .conv)
conv_mixer = block_0.mixer.conv.conv
conv_weight = conv_mixer.conv.weight  # Access the actual Conv1d weight
print(f"\n--- mixer.conv (Depthwise Conv1d) ---")
print(f"  conv.weight.shape = {conv_weight.shape}")
print(f"  conv.weight.dtype = {conv_weight.dtype}")
print(f"  Tensor shape: [out_ch={conv_weight.shape[0]}, in_ch_per_group={conv_weight.shape[1]}, kernel={conv_weight.shape[2]}]")
print(f"  Groups: {conv_mixer.conv.groups} (depthwise - processes each channel separately)")
in_channels = conv_weight.shape[1] * conv_mixer.conv.groups  # actual input channels = in_ch_per_group * groups
print(f"  Purpose: Temporal feature extraction [B, {in_channels}, T] → [B, {conv_weight.shape[0]}, T]")

# FFN
ffn = block_0.ffn
print(f"\n--- ffn (Feed-Forward Network) ---")
print(f"  ffn.linear1.weight.shape = {ffn.linear1.weight.shape}, dtype: {ffn.linear1.weight.dtype}")
print(f"  ffn.linear2.weight.shape = {ffn.linear2.weight.shape}, dtype: {ffn.linear2.weight.dtype}")
print(f"  Tensor shapes: [{ffn.linear1.weight.shape[1]}] → [{ffn.linear1.weight.shape[0]}] → [{ffn.linear2.weight.shape[0]}]")
print(f"  Purpose: Channel-wise transformation [B, T, {ffn.linear1.weight.shape[1]}] → [B, T, {ffn.linear2.weight.shape[0]}]")


=== Block1D Structure (Stage 0, Block 0 of 3) ===

Each Block1D contains:
  - ConvRMSNorm (normalization)
  - mixer (depthwise Conv1d)
  - FFN (feed-forward network)

⚠️  This Block1D structure repeats in all stages

--- norm (ConvRMSNorm) ---
  norm.weight.shape = torch.Size([32])
  norm.weight.dtype = torch.bfloat16
  Tensor shape: [32] (1D tensor, channel_dim)
  Purpose: Normalizes [B, 32, T] before convolution

--- mixer.conv (Depthwise Conv1d) ---
  conv.weight.shape = torch.Size([32, 1, 7])
  conv.weight.dtype = torch.bfloat16
  Tensor shape: [out_ch=32, in_ch_per_group=1, kernel=7]
  Groups: 32 (depthwise - processes each channel separately)
  Purpose: Temporal feature extraction [B, 32, T] → [B, 32, T]

--- ffn (Feed-Forward Network) ---
  ffn.linear1.weight.shape = torch.Size([128, 32]), dtype: torch.bfloat16
  ffn.linear2.weight.shape = torch.Size([32, 128]), dtype: torch.bfloat16
  Tensor shapes: [32] → [128] → [32]
  Purpose: Channel-wise transformation [B, T, 32] → [B, T, 32]
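With `groups == channels`, each channel is convolved with its own length-k kernel. That is why the weight is `[32, 1, 7]` instead of `[32, 32, 7]`. The stdlib-only sketch below implements a depthwise 1-D convolution as cross-correlation, the way PyTorch computes it, with no padding and stride 1. It then compares the parameter count against a dense conv. `depthwise_conv1d` is a hypothetical helper.

```python
def depthwise_conv1d(x, weight):
    """x: [C][T] input, weight: [C][k] kernels (PyTorch shape [C, 1, k] squeezed).
    Each output channel only sees its own input channel. No padding, stride 1."""
    out = []
    for channel, kernel in zip(x, weight):
        k = len(kernel)
        out.append([sum(channel[t + j] * kernel[j] for j in range(k))
                    for t in range(len(channel) - k + 1)])
    return out

x = [[1.0, 2.0, 3.0, 4.0],      # channel 0
     [10.0, 20.0, 30.0, 40.0]]  # channel 1
w = [[1.0, -1.0],               # channel 0: difference filter
     [0.5, 0.5]]                # channel 1: averaging filter
print(depthwise_conv1d(x, w))   # [[-1.0, -1.0, -1.0], [15.0, 25.0, 35.0]]

C, k = 32, 7
print(C * 1 * k, C * C * k)     # 224 7168 (depthwise vs dense Conv1d weights)
```

Channel mixing is left to the FFN (`linear1`/`linear2`), which acts on every time step independently.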

In [36]:
# 4e. Acoustic Tokenizer Encoder - Head Layer (Final projection to 64D)
head = encoder.head.conv.conv
print(f"=== Encoder Head Layer (Final Layer) ===")
print(f"Layer: {encoder.head}")
print(f"\nConv1d:")
print(f"  head.weight.shape = {head.weight.shape}")
print(f"  head.weight.dtype = {head.weight.dtype}")
print(f"  Tensor shape: [out_ch={head.weight.shape[0]}, in_ch={head.weight.shape[1]}, kernel={head.weight.shape[2]}]")
print(f"  Purpose: Final projection to 64D acoustic latent space")
print(f"  Input: [B, {head.weight.shape[1]}, T] → Output: [B, {head.weight.shape[0]}, T]")
print(f"  Note: '64D' means shape [B, 64, T] where B=batch, T=time steps")
print(f"\n✅ This is the final layer of the encoder - outputs 64D acoustic latents")


=== Encoder Head Layer (Final Layer) ===
Layer: SConv1d(
  (conv): NormConv1d(
    (conv): Conv1d(2048, 64, kernel_size=(7,), stride=(1,))
    (norm): Identity()
  )
)

Conv1d:
  head.weight.shape = torch.Size([64, 2048, 7])
  head.weight.dtype = torch.bfloat16
  Tensor shape: [out_ch=64, in_ch=2048, kernel=7]
  Purpose: Final projection to 64D acoustic latent space
  Input: [B, 2048, T] → Output: [B, 64, T]
  Note: '64D' means shape [B, 64, T] where B=batch, T=time steps

✅ This is the final layer of the encoder - outputs 64D acoustic latents


## Part 4: Acoustic Tokenizer - Decoder (64D Latents → Audio)

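The decoder has **7 upsample layers**, each followed by a processing stage, and ends with a head that projects to a single audio channel. The first upsample layer is a stride-1 `Conv1d` that expands the 64D latents to 2048 channels. The other six are transposed convolutions that restore the time resolution.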


In [37]:
# 5. Acoustic Tokenizer Decoder - Overview
decoder = model.model.acoustic_tokenizer.decoder
print("=== TokenizerDecoder ===")
print(f"Number of upsample stages: {len(decoder.upsample_layers)}")
print(f"Number of processing stages: {len(decoder.stages)}")
print(f"\nPurpose: Decodes 64D acoustic latents → audio waveform")
print(f"Flow: 64D latents → Upsample stages → Processing blocks → 1 channel audio")
print(f"\nThis is the reverse of the encoder - reconstructs audio from compressed representation")


=== TokenizerDecoder ===
Number of upsample stages: 7
Number of processing stages: 7

Purpose: Decodes 64D acoustic latents → audio waveform
Flow: 64D latents → Upsample stages → Processing blocks → 1 channel audio

This is the reverse of the encoder - reconstructs audio from compressed representation


In [38]:
# 5a. Acoustic Tokenizer Decoder - All Upsample Layers
print(f"=== All {len(decoder.upsample_layers)} Upsample Layers ===")
print(f"\nThese layers progressively upsample and decrease channels:")
print(f"\nStage | Out Channels | Type | Purpose")
print(f"------|--------------|------|---------")

for i, upsample_layer in enumerate(decoder.upsample_layers):
    if len(upsample_layer) > 0:
        layer = upsample_layer[0]
        if hasattr(layer, 'conv'):
            # Regular conv: Conv1d weight is [out_ch, in_ch, kernel]
            conv = layer.conv.conv
            out_ch = conv.weight.shape[0]
            in_ch = conv.weight.shape[1]
            print(f"  {i}    | {out_ch:12} | Conv1d | {in_ch}→{out_ch} channels")
        elif hasattr(layer, 'convtr'):
            # Transposed conv (upsampling): ConvTranspose1d weight is [in_ch, out_ch, kernel],
            # i.e. the first two dims are swapped relative to Conv1d
            convtr = layer.convtr.convtr
            in_ch = convtr.weight.shape[0]
            out_ch = convtr.weight.shape[1]
            stride = convtr.stride[0]
            print(f"  {i}    | {out_ch:12} | ConvTranspose1d | {in_ch}→{out_ch}, stride={stride} (upsample)")

print(f"\n⚠️  These {len(decoder.upsample_layers)} upsample layers progressively increase temporal resolution")
print(f"   and decrease channel depth from 2048 → 1024 → 512 → 256 → 128 → 64 → 32 → 1 (32 → 1 is done by the head)")


=== All 7 Upsample Layers ===

These layers progressively upsample and decrease channels:

Stage | Out Channels | Type | Purpose
------|--------------|------|---------
  0    |         2048 | Conv1d | 64→2048 channels
  1    |         1024 | ConvTranspose1d | 2048→1024, stride=8 (upsample)
  2    |          512 | ConvTranspose1d | 1024→512, stride=5 (upsample)
  3    |          256 | ConvTranspose1d | 512→256, stride=5 (upsample)
  4    |          128 | ConvTranspose1d | 256→128, stride=4 (upsample)
  5    |           64 | ConvTranspose1d | 128→64, stride=2 (upsample)
  6    |           32 | ConvTranspose1d | 64→32, stride=2 (upsample)

⚠️  These 7 upsample layers progressively increase temporal resolution
   and decrease channel depth from 2048 → 1024 → 512 → 256 → 128 → 64 → 32 → 1 (32 → 1 is done by the head)
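Two details make the decoder table easy to misread. First, `ConvTranspose1d` stores its weight as `[in_channels, out_channels/groups, kernel]`, the reverse of `Conv1d`'s `[out_channels, in_channels/groups, kernel]`. Second, its output length follows a different formula in the PyTorch docs. The sketch applies that formula with zero padding and confirms that the upsampling strides undo the encoder's 3200× downsampling. Zero padding is an assumption, because the wrapper around the raw layer may trim extra samples. `channels_from_weight` and `conv_transpose1d_out_len` are hypothetical helpers.

```python
import math

def channels_from_weight(kind, shape):
    """Return (in_channels, out_channels) from a PyTorch conv weight shape, assuming groups=1."""
    if kind == "conv":            # nn.Conv1d weight: [out, in, kernel]
        return shape[1], shape[0]
    if kind == "conv_transpose":  # nn.ConvTranspose1d weight: [in, out, kernel]
        return shape[0], shape[1]
    raise ValueError(f"unknown kind: {kind}")

def conv_transpose1d_out_len(length, kernel, stride=1, padding=0,
                             output_padding=0, dilation=1):
    """Output length of nn.ConvTranspose1d (formula from the PyTorch docs)."""
    return ((length - 1) * stride - 2 * padding + dilation * (kernel - 1)
            + output_padding + 1)

# Kernel size 16 is hypothetical; the decoder table above prints only channels and stride.
print(channels_from_weight("conv_transpose", (2048, 1024, 16)))  # (2048, 1024): 2048 in, 1024 out

up_strides = [8, 5, 5, 4, 2, 2]  # stride column of the ConvTranspose1d rows above
print(math.prod(up_strides))     # 3200, the same factor as the encoder's downsampling

# 75 latent frames, kernel 16, stride 8, no padding: (75 - 1) * 8 + 15 + 1 = 608 samples,
# i.e. 75 * 8 plus kernel - stride = 8 extra samples that a wrapper could trim (an assumption).
print(conv_transpose1d_out_len(75, kernel=16, stride=8))  # 608
```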


In [39]:
# 5b. Acoustic Tokenizer Decoder - Head Layer (Final projection to 1 channel)
head = decoder.head.conv.conv
print(f"=== Decoder Head Layer (Final Layer) ===")
print(f"Layer: {decoder.head}")
print(f"\nConv1d:")
print(f"  head.weight.shape = {head.weight.shape}")
print(f"  head.weight.dtype = {head.weight.dtype}")
print(f"  Tensor shape: [out_ch={head.weight.shape[0]}, in_ch={head.weight.shape[1]}, kernel={head.weight.shape[2]}]")
print(f"  Purpose: Final projection to 1-channel audio waveform")
print(f"  Input: [B, {head.weight.shape[1]}, T] → Output: [B, {head.weight.shape[0]}, T]")
print(f"  Where B=batch, T=time samples (audio waveform)")
print(f"\n✅ This is the final layer of the decoder - outputs reconstructed audio")


=== Decoder Head Layer (Final Layer) ===
Layer: SConv1d(
  (conv): NormConv1d(
    (conv): Conv1d(32, 1, kernel_size=(7,), stride=(1,))
    (norm): Identity()
  )
)

Conv1d:
  head.weight.shape = torch.Size([1, 32, 7])
  head.weight.dtype = torch.bfloat16
  Tensor shape: [out_ch=1, in_ch=32, kernel=7]
  Purpose: Final projection to 1-channel audio waveform
  Input: [B, 32, T] → Output: [B, 1, T]
  Where B=batch, T=time samples (audio waveform)

✅ This is the final layer of the decoder - outputs reconstructed audio


In [40]:
# 6. Semantic Tokenizer Encoder - Overview
semantic_encoder = model.model.semantic_tokenizer.encoder
print("=== Semantic TokenizerEncoder ===")
print(f"Number of downsample stages: {len(semantic_encoder.downsample_layers)}")
print(f"Number of processing stages: {len(semantic_encoder.stages)}")
print(f"\nPurpose: Encodes audio waveform → 128D semantic latents")
print(f"Key difference: Outputs 128D (semantic) vs 64D (acoustic)")
print(f"Captures: Linguistic content (what is said) not acoustic properties (how it sounds)")


=== Semantic TokenizerEncoder ===
Number of downsample stages: 7
Number of processing stages: 7

Purpose: Encodes audio waveform → 128D semantic latents
Key difference: Outputs 128D (semantic) vs 64D (acoustic)
Captures: Linguistic content (what is said) not acoustic properties (how it sounds)


In [41]:
# 6a. Semantic Tokenizer Encoder - All Downsample Layers
print(f"=== All {len(semantic_encoder.downsample_layers)} Downsample Layers ===")
print(f"\nStage | Out Channels | In→Out | Purpose")
print(f"------|--------------|--------|---------")

for i, downsample_layer in enumerate(semantic_encoder.downsample_layers):
    conv = downsample_layer[0].conv.conv
    out_ch = conv.weight.shape[0]
    in_ch = conv.weight.shape[1]
    print(f"  {i}    | {out_ch:12} | {in_ch:2}→{out_ch:2}  | Downsample + channel expansion")

print(f"\n⚠️  Same structure as acoustic encoder: {len(semantic_encoder.downsample_layers)} downsample stages")
print(f"   Channel progression: 1 → 32 → 64 → 128 → 256 → 512 → 1024 → 2048")


=== All 7 Downsample Layers ===

Stage | Out Channels | In→Out | Purpose
------|--------------|--------|---------
  0    |           32 |  1→32  | Downsample + channel expansion
  1    |           64 | 32→64  | Downsample + channel expansion
  2    |          128 | 64→128  | Downsample + channel expansion
  3    |          256 | 128→256  | Downsample + channel expansion
  4    |          512 | 256→512  | Downsample + channel expansion
  5    |         1024 | 512→1024  | Downsample + channel expansion
  6    |         2048 | 1024→2048  | Downsample + channel expansion

⚠️  Same structure as acoustic encoder: 7 downsample stages
   Channel progression: 1 → 32 → 64 → 128 → 256 → 512 → 1024 → 2048


In [42]:
# 6b. Semantic Tokenizer Encoder - Head Layer (Final projection to 128D)
semantic_head = semantic_encoder.head.conv.conv
print(f"=== Semantic Encoder Head Layer (Final Layer) ===")
print(f"Layer: {semantic_encoder.head}")
print(f"\nConv1d:")
print(f"  Weight shape: {semantic_head.weight.shape}, dtype: {semantic_head.weight.dtype}")
print(f"  Dimensions: [0]={semantic_head.weight.shape[0]} (out_channels=128, semantic_latent_dim), [1]={semantic_head.weight.shape[1]} (in_channels), [2]={semantic_head.weight.shape[2]} (kernel)")
print(f"  Purpose: Final projection to 128D semantic latent space")
print(f"  Input: [batch, {semantic_head.weight.shape[1]}, time] → Output: [batch, {semantic_head.weight.shape[0]}, time]")
print(f"\n✅ Outputs 128D semantic latents (vs 64D for acoustic)")


=== Semantic Encoder Head Layer (Final Layer) ===
Layer: SConv1d(
  (conv): NormConv1d(
    (conv): Conv1d(2048, 128, kernel_size=(7,), stride=(1,))
    (norm): Identity()
  )
)

Conv1d:
  Weight shape: torch.Size([128, 2048, 7]), dtype: torch.bfloat16
  Dimensions: [0]=128 (out_channels=128, semantic_latent_dim), [1]=2048 (in_channels), [2]=7 (kernel)
  Purpose: Final projection to 128D semantic latent space
  Input: [batch, 2048, time] → Output: [batch, 128, time]

✅ Outputs 128D semantic latents (vs 64D for acoustic)


In [43]:
# 7. Acoustic Connector - Complete Structure
connector = model.model.acoustic_connector
print("=== SpeechConnector (Acoustic) ===")
print(connector)
print(f"\nPurpose: Maps 64D acoustic latents → 1536D language model space")
print(f"\n--- Layer 1: fc1 ---")
fc1 = connector.fc1
print(f"  Weight shape: {fc1.weight.shape}, dtype: {fc1.weight.dtype}")
print(f"  Bias shape: {fc1.bias.shape if fc1.bias is not None else None}")
print(f"  Dimensions: [0]={fc1.weight.shape[0]} (output=1536, hidden_size), [1]={fc1.weight.shape[1]} (input=64, acoustic_latent_dim)")
print(f"  Purpose: Projects 64D acoustic latents to 1536D")

print(f"\n--- Layer 2: norm (RMSNorm) ---")
norm = connector.norm
print(f"  Weight shape: {norm.weight.shape}, dtype: {norm.weight.dtype}")
print(f"  Dimensions: [0]={norm.weight.shape[0]} (hidden_size=1536)")
print(f"  Purpose: Normalizes after first projection")

print(f"\n--- Layer 3: fc2 ---")
fc2 = connector.fc2
print(f"  Weight shape: {fc2.weight.shape}, dtype: {fc2.weight.dtype}")
print(f"  Bias shape: {fc2.bias.shape if fc2.bias is not None else None}")
print(f"  Dimensions: [0]={fc2.weight.shape[0]} (output=1536, hidden_size), [1]={fc2.weight.shape[1]} (input=1536, hidden_size)")
print(f"  Purpose: Second projection (1536→1536) for refinement")

print(f"\n✅ Flow: 64D → fc1 → 1536D → norm → fc2 → 1536D (ready for LM)")


=== SpeechConnector (Acoustic) ===
SpeechConnector(
  (fc1): Linear(in_features=64, out_features=1536, bias=True)
  (norm): LlamaRMSNorm((1536,), eps=1e-06)
  (fc2): Linear(in_features=1536, out_features=1536, bias=True)
)

Purpose: Maps 64D acoustic latents → 1536D language model space

--- Layer 1: fc1 ---
  Weight shape: torch.Size([1536, 64]), dtype: torch.bfloat16
  Bias shape: torch.Size([1536])
  Dimensions: [0]=1536 (output=1536, hidden_size), [1]=64 (input=64, acoustic_latent_dim)
  Purpose: Projects 64D acoustic latents to 1536D

--- Layer 2: norm (RMSNorm) ---
  Weight shape: torch.Size([1536]), dtype: torch.bfloat16
  Dimensions: [0]=1536 (hidden_size=1536)
  Purpose: Normalizes after first projection

--- Layer 3: fc2 ---
  Weight shape: torch.Size([1536, 1536]), dtype: torch.bfloat16
  Bias shape: torch.Size([1536])
  Dimensions: [0]=1536 (output=1536, hidden_size), [1]=1536 (input=1536, hidden_size)
  Purpose: Second projection (1536→1536) for refinement

✅ Flow: 64D → fc1 → 1536D → norm → fc2 → 1536D (ready for LM)
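The printed module shows no activation between its layers, and the flow string in the code above reads fc1 → norm → fc2. The forward pass is therefore presumably just those three steps, applied to each latent frame. Below is a minimal NumPy sketch under that assumption. `rms_norm`, `linear` and `speech_connector` are hypothetical helpers, and the random weights are placeholders rather than the checkpoint's bfloat16 values.

```python
import numpy as np

def rms_norm(x, weight=None, eps=1e-6):
    """RMSNorm over the last axis: x / sqrt(mean(x^2) + eps), optionally scaled by weight."""
    y = x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return y if weight is None else y * weight

def linear(x, weight, bias=None):
    """PyTorch-style Linear: weight has shape [out_features, in_features]."""
    y = x @ weight.T
    return y if bias is None else y + bias

def speech_connector(x, p):
    """Assumed forward pass: fc1 -> RMSNorm -> fc2, with no activation in between."""
    h = linear(x, p["fc1_w"], p["fc1_b"])      # [B, T, 64]   -> [B, T, 1536]
    h = rms_norm(h, p["norm_w"], eps=1e-6)     # normalize each 1536-D vector
    return linear(h, p["fc2_w"], p["fc2_b"])   # [B, T, 1536] -> [B, T, 1536]

rng = np.random.default_rng(0)
latent_dim, hidden = 64, 1536
params = {  # placeholder weights with the shapes printed above
    "fc1_w": rng.standard_normal((hidden, latent_dim)).astype(np.float32) * 0.02,
    "fc1_b": np.zeros(hidden, dtype=np.float32),
    "norm_w": np.ones(hidden, dtype=np.float32),
    "fc2_w": rng.standard_normal((hidden, hidden)).astype(np.float32) * 0.02,
    "fc2_b": np.zeros(hidden, dtype=np.float32),
}
acoustic_latents = rng.standard_normal((2, 10, latent_dim)).astype(np.float32)  # [B=2, T=10, 64]
out = speech_connector(acoustic_latents, params)
print(out.shape)  # (2, 10, 1536)
```

A `[B, T, 64]` latent sequence comes out as `[B, T, 1536]`, the same width as the token embeddings, which is what "ready for LM" refers to.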

In [44]:
# 8. Semantic Connector - Complete Structure
connector = model.model.semantic_connector
print("=== SpeechConnector (Semantic) ===")
print(connector)
print(f"\nPurpose: Maps 128D semantic latents → 1536D language model space")
print(f"\n--- Layer 1: fc1 ---")
fc1 = connector.fc1
print(f"  Weight shape: {fc1.weight.shape}, dtype: {fc1.weight.dtype}")
print(f"  Bias shape: {fc1.bias.shape if fc1.bias is not None else None}")
print(f"  Dimensions: [0]={fc1.weight.shape[0]} (output=1536, hidden_size), [1]={fc1.weight.shape[1]} (input=128, semantic_latent_dim)")
print(f"  Purpose: Projects 128D semantic latents to 1536D")

print(f"\n--- Layer 2: norm (RMSNorm) ---")
norm = connector.norm
print(f"  Weight shape: {norm.weight.shape}, dtype: {norm.weight.dtype}")
print(f"  Dimensions: [0]={norm.weight.shape[0]} (hidden_size=1536)")
print(f"  Purpose: Normalizes after first projection")

print(f"\n--- Layer 3: fc2 ---")
fc2 = connector.fc2
print(f"  Weight shape: {fc2.weight.shape}, dtype: {fc2.weight.dtype}")
print(f"  Bias shape: {fc2.bias.shape if fc2.bias is not None else None}")
print(f"  Dimensions: [0]={fc2.weight.shape[0]} (output=1536, hidden_size), [1]={fc2.weight.shape[1]} (input=1536, hidden_size)")
print(f"  Purpose: Second projection (1536→1536) for refinement")

print(f"\n✅ Flow: 128D → fc1 → 1536D → norm → fc2 → 1536D (ready for LM)")
print(f"\n⚠️  Note: Same structure as acoustic connector, but input is 128D instead of 64D")


=== SpeechConnector (Semantic) ===
SpeechConnector(
  (fc1): Linear(in_features=128, out_features=1536, bias=True)
  (norm): LlamaRMSNorm((1536,), eps=1e-06)
  (fc2): Linear(in_features=1536, out_features=1536, bias=True)
)

Purpose: Maps 128D semantic latents → 1536D language model space

--- Layer 1: fc1 ---
  Weight shape: torch.Size([1536, 128]), dtype: torch.bfloat16
  Bias shape: torch.Size([1536])
  Dimensions: [0]=1536 (output=1536, hidden_size), [1]=128 (input=128, semantic_latent_dim)
  Purpose: Projects 128D semantic latents to 1536D

--- Layer 2: norm (RMSNorm) ---
  Weight shape: torch.Size([1536]), dtype: torch.bfloat16
  Dimensions: [0]=1536 (hidden_size=1536)
  Purpose: Normalizes after first projection

--- Layer 3: fc2 ---
  Weight shape: torch.Size([1536, 1536]), dtype: torch.bfloat16
  Bias shape: torch.Size([1536])
  Dimensions: [0]=1536 (output=1536, hidden_size), [1]=1536 (input=1536, hidden_size)
  Purpose: Second projection (1536→1536) for refinement

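✅ Flow: 128D → fc1 → 1536D → norm → fc2 → 1536D (ready for LM)

⚠️  Note: Same structure as acoustic connector, but input is 128D instead of 64D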
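The two connectors differ only in fc1's input width, so their parameter counts should differ by exactly 64 × 1536. The plain-Python check below uses only the weight and bias shapes printed for both connectors. `connector_params` is a hypothetical helper.

```python
def connector_params(in_dim, hidden=1536):
    """Count parameters of fc1 (weight + bias) -> RMSNorm (weight only) -> fc2 (weight + bias)."""
    fc1 = in_dim * hidden + hidden
    norm = hidden
    fc2 = hidden * hidden + hidden
    return fc1 + norm + fc2

acoustic = connector_params(64)    # acoustic_connector: 64D input
semantic = connector_params(128)   # semantic_connector: 128D input
print(acoustic, semantic, semantic - acoustic)  # 2462208 2560512 98304
```

On the loaded model, `sum(p.numel() for p in model.model.acoustic_connector.parameters())` should give the same total, assuming the module holds no parameters beyond the three printed layers.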

## Part 7: Prediction Head (Diffusion Head)

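Generates acoustic latents via diffusion, conditioned on language model hidden states. At each denoising step, the head receives three inputs: a noisy 64D acoustic latent, the diffusion timestep, and an LM hidden state. It outputs a 64D prediction that the sampler uses to update the latent. It has **4 HeadLayer blocks** plus a final layer.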


In [45]:
# 9. Prediction Head - Input Projections
head = model.model.prediction_head
print("=== VibeVoiceDiffusionHead ===")
print(f"Number of HeadLayer blocks: {len(head.layers)}")
print(f"\n--- Input Projection 1: noisy_images_proj ---")
noisy_proj = head.noisy_images_proj
print(f"  Weight shape: {noisy_proj.weight.shape}, dtype: {noisy_proj.weight.dtype}")
print(f"  Dimensions: [0]={noisy_proj.weight.shape[0]} (output=1536, hidden_size), [1]={noisy_proj.weight.shape[1]} (input=64, acoustic_latent_dim)")
print(f"  Purpose: Projects noisy acoustic latents to hidden_size for diffusion")

print(f"\n--- Input Projection 2: cond_proj ---")
cond_proj = head.cond_proj
print(f"  Weight shape: {cond_proj.weight.shape}, dtype: {cond_proj.weight.dtype}")
print(f"  Dimensions: [0]={cond_proj.weight.shape[0]} (output=1536, cond_dim), [1]={cond_proj.weight.shape[1]} (input=1536, hidden_size)")
print(f"  Purpose: Projects condition (LM hidden state) to cond_dim")

print(f"\n--- Timestep Embedder ---")
t_embedder = head.t_embedder
print(f"  {t_embedder}")
print(f"  Purpose: Embeds diffusion timestep for conditioning")
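# The MLP's first Linear has in_features=256 (see the output below), so the scalar timestep is
# presumably first expanded into a 256-D sinusoidal frequency embedding, as in DiT-style heads
print(f"  Structure: MLP that processes timestep → cond_dim")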


=== VibeVoiceDiffusionHead ===
Number of HeadLayer blocks: 4

--- Input Projection 1: noisy_images_proj ---
  Weight shape: torch.Size([1536, 64]), dtype: torch.bfloat16
  Dimensions: [0]=1536 (output=1536, hidden_size), [1]=64 (input=64, acoustic_latent_dim)
  Purpose: Projects noisy acoustic latents to hidden_size for diffusion

--- Input Projection 2: cond_proj ---
  Weight shape: torch.Size([1536, 1536]), dtype: torch.bfloat16
  Dimensions: [0]=1536 (output=1536, cond_dim), [1]=1536 (input=1536, hidden_size)
  Purpose: Projects condition (LM hidden state) to cond_dim

--- Timestep Embedder ---
  TimestepEmbedder(
  (mlp): Sequential(
    (0): Linear(in_features=256, out_features=1536, bias=False)
    (1): SiLU()
    (2): Linear(in_features=1536, out_features=1536, bias=False)
  )
)
  Purpose: Embeds diffusion timestep for conditioning
  Structure: MLP that processes timestep → cond_dim
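The timestep MLP's first Linear takes 256 inputs, so the scalar timestep must be expanded into a 256-D vector before it reaches the MLP. A common choice in DiT-style diffusion heads is a sinusoidal frequency embedding. The NumPy sketch below makes two assumptions, neither of which is visible in the printout:

- The embedding is that sinusoidal form, with the cos half first, then sin, and `max_period=10000`.
- The conditioning vector passed to the HeadLayers is the sum `cond_proj(h) + t_embedder(t)`.

All helper names and weights are hypothetical.

```python
import numpy as np

def timestep_frequency_embedding(t, dim=256, max_period=10000.0):
    """Assumed DiT-style sinusoidal embedding: [B] timesteps -> [B, dim]."""
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half, dtype=np.float32) / half)
    args = np.asarray(t, dtype=np.float32)[:, None] * freqs[None, :]  # [B, half]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)     # [B, dim]

def silu(x):
    return x / (1.0 + np.exp(-x))

def timestep_embedder(t, w0, w2):
    """Mirrors the printed MLP: Linear(256->1536, no bias) -> SiLU -> Linear(1536->1536, no bias)."""
    return silu(timestep_frequency_embedding(t) @ w0.T) @ w2.T

rng = np.random.default_rng(0)
hidden = 1536
w0 = rng.standard_normal((hidden, 256)).astype(np.float32) * 0.02
w2 = rng.standard_normal((hidden, hidden)).astype(np.float32) * 0.02
w_cond = rng.standard_normal((hidden, hidden)).astype(np.float32) * 0.02  # stands in for cond_proj

t = np.array([0, 500, 999])                                      # one diffusion timestep per example
lm_hidden = rng.standard_normal((3, hidden)).astype(np.float32)  # [B, 1536] LM hidden states
c = lm_hidden @ w_cond.T + timestep_embedder(t, w0, w2)          # assumed: cond_proj(h) + t_emb
print(c.shape)  # (3, 1536)
```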


In [46]:
# 10. Prediction Head - HeadLayer 0 (Processing Blocks)
prediction_head = model.model.prediction_head  # Use full path to avoid variable overwrite
head_layer_0 = prediction_head.layers[0]
print(f"=== HeadLayer 0 (of {len(prediction_head.layers)}) ===")
print(f"\n⚠️  This layer pattern repeats for all {len(prediction_head.layers)} HeadLayer blocks")
print(f"\n--- FFN (Feed-Forward Network) ---")
ffn = head_layer_0.ffn
print(f"\ngate_proj:")
print(f"  ffn.gate_proj.weight.shape = {ffn.gate_proj.weight.shape}")
print(f"  ffn.gate_proj.weight.dtype = {ffn.gate_proj.weight.dtype}")
print(f"  Tensor shape: [out={ffn.gate_proj.weight.shape[0]}, in={ffn.gate_proj.weight.shape[1]}]")
print(f"\nup_proj:")
print(f"  ffn.up_proj.weight.shape = {ffn.up_proj.weight.shape}")
print(f"  ffn.up_proj.weight.dtype = {ffn.up_proj.weight.dtype}")
print(f"  Tensor shape: [out={ffn.up_proj.weight.shape[0]}, in={ffn.up_proj.weight.shape[1]}]")
print(f"\ndown_proj:")
print(f"  ffn.down_proj.weight.shape = {ffn.down_proj.weight.shape}")
print(f"  ffn.down_proj.weight.dtype = {ffn.down_proj.weight.dtype}")
print(f"  Tensor shape: [out={ffn.down_proj.weight.shape[0]}, in={ffn.down_proj.weight.shape[1]}]")

print(f"\n--- Norm ---")
norm = head_layer_0.norm
print(f"  norm.weight.shape = {norm.weight.shape}")
print(f"  norm.weight.dtype = {norm.weight.dtype}")
print(f"  Tensor shape: [{norm.weight.shape[0]}] (1D tensor)")

print(f"\n--- Adaptive Layer Norm Modulation (adaLN) ---")
adaLN = head_layer_0.adaLN_modulation
print(f"  {adaLN}")
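adaLN_linear = adaLN[1]  # adaLN[0] is SiLU (no weights); adaLN[1] is the Linear
# out_features = 4608 = 3 × hidden_size, consistent with a DiT-style split into shift, scale and gate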
print(f"  adaLN[1].weight.shape = {adaLN_linear.weight.shape}")
print(f"  adaLN[1].weight.dtype = {adaLN_linear.weight.dtype}")
print(f"  Tensor shape: [out={adaLN_linear.weight.shape[0]}, in={adaLN_linear.weight.shape[1]}]")
print(f"  Purpose: Modulates layer norm based on condition (timestep + text context)")

print(f"\n⚠️  This HeadLayer structure repeats in all {len(prediction_head.layers)} layers")


=== HeadLayer 0 (of 4) ===

⚠️  This layer pattern repeats for all 4 HeadLayer blocks

--- FFN (Feed-Forward Network) ---

gate_proj:
  ffn.gate_proj.weight.shape = torch.Size([4608, 1536])
  ffn.gate_proj.weight.dtype = torch.bfloat16
  Tensor shape: [out=4608, in=1536]

up_proj:
  ffn.up_proj.weight.shape = torch.Size([4608, 1536])
  ffn.up_proj.weight.dtype = torch.bfloat16
  Tensor shape: [out=4608, in=1536]

down_proj:
  ffn.down_proj.weight.shape = torch.Size([1536, 4608])
  ffn.down_proj.weight.dtype = torch.bfloat16
  Tensor shape: [out=1536, in=4608]

--- Norm ---
  norm.weight.shape = torch.Size([1536])
  norm.weight.dtype = torch.bfloat16
  Tensor shape: [1536] (1D tensor)

--- Adaptive Layer Norm Modulation (adaLN) ---
  Sequential(
  (0): SiLU()
  (1): Linear(in_features=1536, out_features=4608, bias=False)
)
  adaLN[1].weight.shape = torch.Size([4608, 1536])
  adaLN[1].weight.dtype = torch.bfloat16
  Tensor shape: [out=4608, in=1536]
  Purpose: Modulates layer norm based on condition (timestep + text context)

⚠️  This HeadLayer structure repeats in all 4 layers
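The shapes above suggest a DiT-style adaptive-norm block with two parts:

- **adaLN:** its Linear outputs 4608 = 3 × 1536 values, which fits a split into shift, scale and gate.
- **FFN:** the gate_proj/up_proj/down_proj naming (1536 → 4608 → 1536) matches a SwiGLU feed-forward.

The NumPy sketch below implements one HeadLayer under those assumptions. It also assumes the DiT convention `modulate(x, shift, scale) = x * (1 + scale) + shift` and a residual branch scaled by the gate. It is an illustration, not VibeVoice's own code. Small dimensions keep the demo light; the real block uses H=1536 and F=4608.

```python
import numpy as np

def rms_norm(x, weight=None, eps=1e-6):
    y = x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return y if weight is None else y * weight

def silu(x):
    return x / (1.0 + np.exp(-x))

def modulate(x, shift, scale):
    # DiT convention (assumed): scale is applied as (1 + scale)
    return x * (1.0 + scale) + shift

def swiglu_ffn(x, gate_w, up_w, down_w):
    # down_proj(silu(gate_proj(x)) * up_proj(x)); weights are [out, in] as in PyTorch, biases omitted
    return (silu(x @ gate_w.T) * (x @ up_w.T)) @ down_w.T

def head_layer(x, c, p):
    """One HeadLayer: x [B, H] is the noisy-latent stream, c [B, H] is the condition."""
    mod = silu(c) @ p["adaLN_w"].T                  # adaLN = SiLU -> Linear(H -> 3H)
    shift, scale, gate = np.split(mod, 3, axis=-1)  # each [B, H]
    h = modulate(rms_norm(x, p["norm_w"]), shift, scale)
    return x + gate * swiglu_ffn(h, p["gate_w"], p["up_w"], p["down_w"])

rng = np.random.default_rng(0)
H, F, B = 16, 48, 2   # real model: H=1536, F=4608 (same 3x ratio)
p = {
    "gate_w": rng.standard_normal((F, H)) * 0.1,
    "up_w": rng.standard_normal((F, H)) * 0.1,
    "down_w": rng.standard_normal((H, F)) * 0.1,
    "norm_w": np.ones(H),
    "adaLN_w": rng.standard_normal((3 * H, H)) * 0.1,
}
x = rng.standard_normal((B, H))
c = rng.standard_normal((B, H))
y = head_layer(x, c, p)
print(y.shape)  # (2, 16)
```

If the adaLN weight is all zeros, the gate is zero and the layer reduces to the identity. This is why DiT-style models often zero-initialize this projection ("adaLN-Zero").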

In [51]:
# 11. Prediction Head - Final Layer (Output Projection)
prediction_head = model.model.prediction_head  # Use full path to avoid variable overwrite
final_layer = prediction_head.final_layer
print("=== FinalLayer ===")
print(f"\n--- Final Norm ---")
norm_final = final_layer.norm_final
print(f"  norm_final.weight = {norm_final.weight if norm_final.weight is not None else 'None'}")
print(f"  Purpose: Final normalization (no learnable params)")

print(f"\n--- Output Projection (linear) ---")
linear = final_layer.linear
print(f"  linear.weight.shape = {linear.weight.shape}")
print(f"  linear.weight.dtype = {linear.weight.dtype}")
print(f"  Tensor shape: [out={linear.weight.shape[0]}, in={linear.weight.shape[1]}]")
print(f"  Purpose: Projects [B, {linear.weight.shape[1]}] → [B, {linear.weight.shape[0]}]")
print(f"  Note: '64D' output means shape [B, 64] where B=batch size")

print(f"\n--- Final adaLN Modulation ---")
adaLN_final = final_layer.adaLN_modulation
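adaLN_final_linear = adaLN_final[1]  # index 1 is the Linear (index 0 is presumably SiLU, as in the HeadLayers)
# out_features = 3072 = 2 × hidden_size, consistent with a shift/scale pair and no gate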
print(f"  adaLN_final[1].weight.shape = {adaLN_final_linear.weight.shape}")
print(f"  adaLN_final[1].weight.dtype = {adaLN_final_linear.weight.dtype}")
print(f"  Tensor shape: [out={adaLN_final_linear.weight.shape[0]}, in={adaLN_final_linear.weight.shape[1]}]")
print(f"  Purpose: Final adaptive modulation before output")

print(f"\n✅ This is the final layer - outputs 64D acoustic latents for diffusion")
print(f"   Flow: [B, 1536] → linear → [B, 64]")
print(f"   Where '64D' means tensor shape [B, 64]")


=== FinalLayer ===

--- Final Norm ---
  norm_final.weight = None
  Purpose: Final normalization (no learnable params)

--- Output Projection (linear) ---
  linear.weight.shape = torch.Size([64, 1536])
  linear.weight.dtype = torch.bfloat16
  Tensor shape: [out=64, in=1536]
  Purpose: Projects [B, 1536] → [B, 64]
  Note: '64D' output means shape [B, 64] where B=batch size

--- Final adaLN Modulation ---
  adaLN_final[1].weight.shape = torch.Size([3072, 1536])
  adaLN_final[1].weight.dtype = torch.bfloat16
  Tensor shape: [out=3072, in=1536]
  Purpose: Final adaptive modulation before output

✅ This is the final layer - outputs 64D acoustic latents for diffusion
   Flow: [B, 1536] → linear → [B, 64]
   Where '64D' means tensor shape [B, 64]
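In the final layer, adaLN outputs 3072 = 2 × 1536 values, consistent with a shift/scale pair and no gate, and `norm_final` has no weight. Below is a NumPy sketch under those assumptions: a DiT-style `modulate` and an affine-free RMSNorm, with hypothetical helper names and small dimensions standing in for 1536 → 64. The printout does not show whether the 64-D output is a noise, velocity or clean-latent prediction; that depends on the diffusion scheduler's prediction type.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # No learnable weight, matching norm_final.weight = None in the output above
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def silu(x):
    return x / (1.0 + np.exp(-x))

def final_layer(x, c, adaLN_w, linear_w):
    """x, c: [B, H] -> [B, D]; adaLN_w: [2H, H]; linear_w: [D, H]."""
    shift, scale = np.split(silu(c) @ adaLN_w.T, 2, axis=-1)  # each [B, H]
    h = rms_norm(x) * (1.0 + scale) + shift                    # assumed DiT-style modulate
    return h @ linear_w.T                                      # bias omitted (none shown in the printout)

rng = np.random.default_rng(0)
H, D, B = 16, 4, 2   # real model: H=1536, D=64
x = rng.standard_normal((B, H))
c = rng.standard_normal((B, H))
adaLN_w = rng.standard_normal((2 * H, H)) * 0.1
linear_w = rng.standard_normal((D, H)) * 0.1
out = final_layer(x, c, adaLN_w, linear_w)
print(out.shape)  # (2, 4)
```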
