## Setup

In [None]:
import sys
import torch
from torch.utils.data import DataLoader

# Import custom modules
from loaders.audio_loader import AudioLoader
from preprocessing.audio_preprocessor import AudioPreprocessor
from models.audio_encoder import AudioEncoder

In [None]:
# ---- Notebook-wide settings ----

# Data location and pretrained checkpoint
DATA_DIR = "D:/florida_coursework/third_sem/multimedia_expert_systems/multimedia_prototype/data/audio/for-norm/for-norm/"
MODEL_CHECKPOINT = "facebook/wav2vec2-base"

# Clip length limit, in seconds
MAX_DURATION = 5.0

# Size of the embedding handed to the fusion stage
PROJECTION_DIM = 512

# Kept small so a batch fits on an 8GB GPU
BATCH_SIZE = 4

# Use the GPU when one is visible to torch, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"Device: {DEVICE}")

## Step 1: Load Raw Audio Data

In [None]:
# Point the loader at the dataset root; audio is requested with target_sr=16000
loader = AudioLoader(target_sr=16000, data_dir=DATA_DIR)

# load() returns a mapping indexed by split name:
# 'training', 'validation' and 'testing'
dataset = loader.load()

print(f"\nDataset splits: {loader.get_splits()}")
print(f"Label mapping: {loader.get_label_mapping()}")

# Report how many examples each split holds
train_split, val_split, test_split = (
    dataset['training'],
    dataset['validation'],
    dataset['testing'],
)
print(f"\nTrain samples: {len(train_split)}")
print(f"Val samples: {len(val_split)}")
print(f"Test samples: {len(test_split)}")

In [None]:
# Pull the first training example and look at what it contains
sample = dataset['training'][0]

# The 'audio' entry carries both the waveform ('array') and its 'sampling_rate'
audio_entry = sample['audio']

print(f"Sample keys: {sample.keys()}")
print(f"Audio array shape: {audio_entry['array'].shape}")
print(f"Sampling rate: {audio_entry['sampling_rate']}")
print(f"Label: {sample['label']}")

## Step 2: Preprocess Audio

In [None]:
# Build the preprocessor from the shared checkpoint name and the
# MAX_DURATION (seconds) limit defined in the configuration cell
preprocessor = AudioPreprocessor(max_duration=MAX_DURATION, model_checkpoint=MODEL_CHECKPOINT)

# Show the preprocessor's printable summary
print(preprocessor)

In [None]:
def _encode_split(split_name):
    """Run the preprocessor over one dataset split in batched mode, dropping the raw 'audio' column."""
    return dataset[split_name].map(
        preprocessor.preprocess_batch,
        batched=True,
        remove_columns=['audio'],
    )


# Encode every split with identical settings
encoded_train = _encode_split('training')
encoded_val = _encode_split('validation')
encoded_test = _encode_split('testing')

# Have all three encoded splits hand back PyTorch tensors
for _encoded_split in (encoded_train, encoded_val, encoded_test):
    _encoded_split.set_format(type='torch')

# Quick look at the first encoded training example
first_encoded = encoded_train[0]
print(f"Encoded train sample keys: {first_encoded.keys()}")
print(f"Input values shape: {first_encoded['input_values'].shape}")

## Step 3: Create DataLoaders with Custom Collate

In [None]:
def _make_loader(encoded_split, shuffle):
    """Wrap an encoded split in a DataLoader that batches through the preprocessor's collate function."""
    return DataLoader(
        encoded_split,
        shuffle=shuffle,
        batch_size=BATCH_SIZE,
        num_workers=0,  # Windows compatibility
        collate_fn=preprocessor.collate_fn,
    )


# Only the training data is shuffled; validation and test keep dataset order
train_loader = _make_loader(encoded_train, shuffle=True)
val_loader = _make_loader(encoded_val, shuffle=False)
test_loader = _make_loader(encoded_test, shuffle=False)

# Number of batches each loader will produce
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

In [None]:
# Smoke test: draw a single batch from the training loader
batch_iterator = iter(train_loader)
batch = next(batch_iterator)

print(f"\nBatch keys: {batch.keys()}")
print(f"Input values shape: {batch['input_values'].shape}")
print(f"Attention mask shape: {batch['attention_mask'].shape}")

# Show both the shape and the actual label values for this batch
batch_labels = batch['labels']
print(f"Labels shape: {batch_labels.shape}")
print(f"Labels: {batch_labels}")

## Step 4: Initialize Audio Encoder

In [None]:
# Build the audio branch of the multimodal model
audio_encoder = AudioEncoder(
    model_checkpoint=MODEL_CHECKPOINT,
    # Embeddings are projected to PROJECTION_DIM (512) for fusion
    projection_dim=PROJECTION_DIM,
    # Frozen to save memory on an 8GB GPU
    freeze_feature_extractor=True,
    # Left unfrozen so the encoder is fine-tuned
    freeze_encoder=False,
)

# Place the encoder on the selected device
audio_encoder = audio_encoder.to(DEVICE)

print(audio_encoder)
print(f"\nOutput embedding dimension: {audio_encoder.get_embedding_dim()}")

In [None]:
# Tally parameter counts in one pass over the encoder's parameters
total_params = 0
trainable_params = 0
for _param in audio_encoder.parameters():
    _count = _param.numel()
    total_params += _count
    # Only parameters that still require gradients count as trainable
    if _param.requires_grad:
        trainable_params += _count

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen parameters: {total_params - trainable_params:,}")

## Step 5: Generate Embeddings

In [None]:
# Switch the encoder to evaluation mode before running it
audio_encoder.eval()

# Run the previously fetched batch through the encoder without tracking gradients
with torch.no_grad():
    input_values = batch['input_values'].to(DEVICE)
    attention_mask = batch['attention_mask'].to(DEVICE)
    embeddings = audio_encoder(input_values, attention_mask)

# Report the result; expected shape is (B, 512)
print(f"Audio embeddings shape: {embeddings.shape}")
print(f"Embeddings sample:\n{embeddings[0, :10]}")

## Step 6: Ready for Multimodal Fusion!

These audio embeddings, of shape `(B, 512)` (batch size × `PROJECTION_DIM`), can now be:
- Concatenated with image embeddings: `torch.cat([audio_emb, image_emb], dim=-1)`
- Fused with video embeddings using attention mechanisms
- Fed into a fusion network for final classification

### Next Steps:
1. Load image/video encoders
2. Create fusion model
3. Train end-to-end, or freeze the audio encoder and train only the fusion model

## Memory Usage Tips for an 8GB GPU

```python
# Option 1: Freeze audio encoder completely (only train fusion layer)
audio_encoder = AudioEncoder(
    projection_dim=512,
    freeze_encoder=True  # Freeze everything
)

# Option 2: Use smaller projection dimension
audio_encoder = AudioEncoder(
    projection_dim=256,  # Smaller embeddings
    freeze_feature_extractor=True
)

# Option 3: Gradient checkpointing (if needed)
audio_encoder.model.gradient_checkpointing_enable()
```