In [4]:
from multimodal_alignment_perceiver import MultimodalAlignmentConfig, MultimodalAlignmentModel, count_parameters
from multimodal_alignment_perceiver import contrastive_loss, matryoshka_loss

In [5]:
import torch

In [6]:
print("="*70)
print("MULTIMODAL ALIGNMENT WITH PERCEIVER RESAMPLER")
print("="*70)

# Seed torch's global RNG so that the model's weight initialization (next cell)
# and the simulated encoder features drawn with torch.randn later on are
# reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Configuration
config = MultimodalAlignmentConfig(
    perceiver_dim=512,
    num_latents=64,
    num_perceiver_layers=4,
    d_align=512,
)


MULTIMODAL ALIGNMENT WITH PERCEIVER RESAMPLER


In [7]:
# Instantiate the alignment model from the configuration built in the previous cell.
model = MultimodalAlignmentModel(config)

In [8]:
# Print architecture: total vs. trainable parameter counts.
print("\n📐 Architecture:")
params = count_parameters(model)
print(f"   Total parameters: {params['total']:,}")
print(f"   Trainable: {params['trainable']:,}")
# Sanity check: there is something to train, and trainable params are a subset of all params.
assert 0 < params['trainable'] <= params['total'], params

# Test forward pass
print("\n🧪 Testing forward pass...")



📐 Architecture:
   Total parameters: 21,621,760
   Trainable: 21,621,760

🧪 Testing forward pass...


In [9]:

# Stand-in encoder outputs: random features shaped (batch, sequence length, feature dim).
batch_size = 4
n_vision_tokens = 50    # CLIP: 50 patches
n_audio_frames = 1500   # Whisper: ~1500 frames
n_text_tokens = 32      # Text: 32 tokens

vision_feats = torch.randn(batch_size, n_vision_tokens, config.d_vision)
audio_feats = torch.randn(batch_size, n_audio_frames, config.d_audio)
text_feats = torch.randn(batch_size, n_text_tokens, config.d_text)



In [10]:
# Encode each modality into the shared alignment space.
z_vision = model.encode_vision(vision_feats)
z_audio = model.encode_audio(audio_feats)
z_text = model.encode_text(text_feats)

# Each modality must reduce to one d_align-dim vector per example, whatever its
# input sequence length (50 / 1500 / 32 above). Assert rather than just comment.
expected_shape = (batch_size, config.d_align)
for name, z in [("Vision", z_vision), ("Audio", z_audio), ("Text", z_text)]:
    print(f"   {name} embedding: {z.shape}")
    assert z.shape == expected_shape, f"{name}: got {tuple(z.shape)}, expected {expected_shape}"



   Vision embedding: torch.Size([4, 512])
   Audio embedding: torch.Size([4, 512])
   Text embedding: torch.Size([4, 512])


In [11]:
# Test LLM projection: map vision features to a sequence of LLM-input (prefix) tokens.
llm_prefix = model.project_to_llm(vision_feats, 'vision')
print(f"   LLM prefix: {llm_prefix.shape}")      # observed: (4, 64, 1536)
# Expect (batch, prefix_len, llm_dim).
# NOTE(review): prefix_len presumably equals config.num_latents (64 here) — confirm, then tighten this check.
assert llm_prefix.ndim == 3 and llm_prefix.shape[0] == batch_size, tuple(llm_prefix.shape)

# Test loss computation
print("\n📉 Testing loss computation...")


   LLM prefix: torch.Size([4, 64, 1536])

📉 Testing loss computation...


In [12]:

# CLIP-style contrastive loss and Matryoshka loss (over config.mrl_dims) between vision and text embeddings.
loss_clip = contrastive_loss(z_vision, z_text)
loss_mrl = matryoshka_loss(z_vision, z_text, dims=config.mrl_dims)

print(f"   CLIP loss: {loss_clip.item():.4f}")
print(f"   MRL loss: {loss_mrl.item():.4f}")

# Only report success if the losses are actually usable (no NaN/inf).
assert torch.isfinite(loss_clip).item(), "CLIP loss is not finite"
assert torch.isfinite(loss_mrl).item(), "MRL loss is not finite"

print("\n✅ All tests passed!")
print("="*70)


   CLIP loss: 1.4866
   MRL loss: 1.5836

✅ All tests passed!
