In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass

@dataclass
class Wav2Vec2Config:
    # Feature encoder
    conv_layers: list = None
    dropout: float = 0.1
    layer_drop: float = 0.05
    
    # Transformer
    d_model: int = 768
    nhead: int = 8
    num_encoder_layers: int = 12
    dim_feedforward: int = 3072
    
    # Quantizer
    num_groups: int = 2
    num_vars: int = 320
    temp: float = 2.0
    min_temp: float = 0.5
    temp_decay: float = 0.999995
    
    # Masking
    mask_prob: float = 0.065
    mask_length: int = 10
    
    # Training
    learning_rate: float = 5e-4
    warmup_steps_pct: float = 0.08  # 8% warmup
    num_updates: int = 400_000  # BASE model updates
    l2_weight: float = 0.1  # L2 penalty for encoder activations
    encoder_grad_scale: float = 0.1  # Scale down encoder gradients
    contrastive_temperature: float = 0.1  # κ in the paper
    diversity_weight: float = 0.1  # α in the paper
    num_negatives: int = 100  # K distractors
    
    @classmethod
    def BASE(cls):
        return cls(
            conv_layers=[(512, 10, 5)] + [(512, 3, 2)] * 5 + [(512, 2, 2)],
            layer_drop=0.05,
            d_model=768,  # Ensures d/G = 384 for codebook entries
            min_temp=0.5,
            num_updates=400_000,
        )
    
    @classmethod
    def TINY(cls):
        """Tiny configuration for fast testing"""
        return cls(
            conv_layers=[(256, 10, 5)] + [(256, 3, 2)] * 3 + [(256, 2, 2)],  # Fewer layers, smaller channels
            d_model=256,  # Smaller model dimension
            nhead=4,      # Fewer attention heads
            num_encoder_layers=4,  # Fewer transformer layers
            dim_feedforward=1024,  # Smaller feedforward dimension
            layer_drop=0.05,
            num_groups=2,
            num_vars=160,  # Smaller codebook
            learning_rate=5e-4,  # Same as BASE but with faster schedule
            num_updates=50_000,  # Much fewer updates for testing
            warmup_steps_pct=0.1,  # Slightly faster warmup
            mask_prob=0.1,  # Slightly higher masking probability
            mask_length=5,  # Shorter mask length
            min_temp=0.5,  # Same as BASE
            temp_decay=0.9999,  # Faster temperature decay
            l2_weight=0.01,  # Reduced L2 penalty
            encoder_grad_scale=0.1,  # Same as BASE
        )
    
    @classmethod
    def LARGE(cls):
        return cls(
            conv_layers=[(512, 10, 5)] + [(512, 3, 2)] * 5 + [(512, 2, 2)],
            d_model=1024,
            nhead=16,
            num_encoder_layers=24,
            dim_feedforward=4096,
            layer_drop=0.2,
            min_temp=0.1,
            learning_rate=3e-4,
            num_updates=250_000,
        )

class FeatureEncoder(nn.Module):
    def __init__(self, conv_layers=[(512, 10, 5)] + [(512, 3, 2)] * 5 + [(512, 2, 2)]):
        super().__init__()
        
        layers = []
        in_channels = 1  # raw audio input
        
        # First layer without normalization (as per paper for Librispeech)
        layers.append(
            nn.Sequential(
                nn.Conv1d(in_channels, conv_layers[0][0], conv_layers[0][1], stride=conv_layers[0][2]),
                nn.GELU(),
                nn.Dropout(0.1),
            )
        )
        
        # Normalize output of first layer
        self.layer_norm = nn.LayerNorm(conv_layers[0][0])
        
        # Remaining layers
        in_channels = conv_layers[0][0]
        for out_channels, kernel_size, stride in conv_layers[1:]:
            layers.append(
                nn.Sequential(
                    nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride),
                    nn.GELU(),
                    nn.Dropout(0.1),
                )
            )
            in_channels = out_channels
        
        self.layers = nn.ModuleList(layers)
    
    def forward(self, x):
        # x shape: (batch_size, sequence_length)
        x = x.unsqueeze(1)  # Add channel dimension
        
        # First layer
        x = self.layers[0](x)
        
        # Normalize output of first layer (as per paper)
        x = x.transpose(1, 2)
        x = self.layer_norm(x)
        x = x.transpose(1, 2)
        
        # Remaining layers
        for layer in self.layers[1:]:
            x = layer(x)
            
        return x.transpose(1, 2)  # Return (batch_size, time_steps, channels)

class ProductQuantizer(nn.Module):
    def __init__(self, input_dim, num_groups=2, num_vars=320, temp=2.0, min_temp=0.5, temp_decay=0.999995):
        super().__init__()
        self.num_groups = num_groups
        self.num_vars = num_vars
        self.temp = temp
        self.min_temp = min_temp
        self.temp_decay = temp_decay
        
        self.vars = nn.Parameter(torch.FloatTensor(num_groups * num_vars, input_dim // num_groups))
        nn.init.uniform_(self.vars)
        
        self.weight_proj = nn.Linear(input_dim, num_groups * num_vars)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x):
        x = self.dropout(x)
        bsz, tsz, fsz = x.shape
        
        # Project to G x V logits
        x = self.weight_proj(x)
        x = x.view(bsz * tsz, self.num_groups, self.num_vars)
        
        if self.training:
            # Gumbel noise
            uniform_noise = torch.rand_like(x)
            gumbel = -torch.log(-torch.log(uniform_noise + 1e-10) + 1e-10)
            
            # Apply formula: exp((l_{g,v} + n_v)/τ) / sum_k(exp((l_{g,k} + n_k)/τ))
            logits_with_noise = (x + gumbel) / self.temp
            numerator = torch.exp(logits_with_noise)
            denominator = numerator.sum(dim=-1, keepdim=True)
            x = numerator / denominator
            
            # Update temperature
            self.temp = max(self.temp * self.temp_decay, self.min_temp)
        else:
            # During inference, use straight-through estimator
            logits = x / self.temp
            x = F.softmax(logits, dim=-1)
            
        # Straight-through Gumbel-Softmax
        indices = x.max(dim=-1)[1]
        x_hard = torch.zeros_like(x).scatter_(-1, indices.unsqueeze(-1), 1.0)
        x = (x_hard - x).detach() + x
        
        return x.view(bsz, tsz, -1)

class Wav2Vec2(nn.Module):
    def __init__(self, config: Wav2Vec2Config):
        super().__init__()
        
        self.config = config
        # Feature encoder with layer norm and GELU
        self.feature_encoder = FeatureEncoder(config.conv_layers)
        
        # Calculate the encoder output dimension based on the last conv layer
        last_conv_channels = config.conv_layers[-1][0]
        
        # Add projection layer to match transformer dimensions
        self.proj = nn.Linear(last_conv_channels, config.d_model)
        
        # Add projection for quantized vectors
        self.quantizer_proj = nn.Linear(config.num_groups * config.num_vars, config.d_model)
        
        # Layer normalization
        self.layer_norm = nn.LayerNorm(config.d_model)
        
        # Context network components
        # 1. Convolutional layer for relative positional embedding
        kernel_size = 128
        # Calculate padding to maintain sequence length
        padding = kernel_size  # Full padding on both sides
        self.context_pos_conv = nn.Sequential(
            nn.Conv1d(
                config.d_model,
                config.d_model,
                kernel_size=kernel_size,
                padding=padding,
                groups=16,
                padding_mode='replicate'  # Use replicate padding to avoid edge effects
            ),
            nn.GELU()
        )
        
        # 2. Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, config.num_encoder_layers)
        
        # Quantizer
        self.quantizer = ProductQuantizer(
            config.d_model,
            num_groups=config.num_groups,
            num_vars=config.num_vars,
            temp=config.temp,
            min_temp=config.min_temp,
            temp_decay=config.temp_decay
        )
        
        self.mask_emb = nn.Parameter(torch.FloatTensor(config.d_model).uniform_())
        
    def apply_mask(self, x, mask_prob=0.065, mask_length=10):
        B, T, C = x.shape
        
        # Calculate how many starting indices to sample
        num_mask = int(T * mask_prob)
        
        # Sample starting indices
        mask_starts = torch.randperm(T)[:num_mask]
        
        # Create mask tensor
        mask = torch.zeros(B, T, dtype=torch.bool, device=x.device)
        
        # For each starting index, mask the subsequent M time steps
        for start in mask_starts:
            end = min(start + mask_length, T)
            mask[:, start:end] = True
            
        return mask
        
    def forward(self, x, mask=True):
        # Debug: Print input shape
        # print(f"\nShape tracking:")
        # print(f"Raw input: {x.shape}")
        
        # 1. Feature encoder (provides initial relative positional information)
        x = self.feature_encoder(x)
        #print(f"After feature encoder: {x.shape}")
        
        # 2. Project to transformer dimension
        x = self.proj(x)
        #print(f"After projection: {x.shape}")

        q = self.quantizer(x)
        q = self.quantizer_proj(q)
        
        # 3. Initialize mask_indices
        mask_indices = None
        
        # 4. Apply masking if requested
        if mask:
            mask_indices = self.apply_mask(
                x,
                mask_prob=self.config.mask_prob,
                mask_length=self.config.mask_length
            )
            x = torch.where(
                mask_indices.unsqueeze(-1),
                self.mask_emb.view(1, 1, -1).expand(x.shape[0], -1, -1),
                x
            )
        #print(f"After masking: {x.shape}")
        
        # 5. Layer normalization
        x = self.layer_norm(x)
        #print(f"After layer norm: {x.shape}")
        
        # 6. Context network processing
        # Add relative positional information through convolution
        x_t = x.transpose(1, 2)  # [B, T, C] -> [B, C, T]
        #print(f"Before context conv: {x_t.shape}")
        
        # Save original sequence length
        orig_len = x_t.size(2)
        
        # Apply convolution and ensure output length matches input
        pos_embedding = self.context_pos_conv(x_t)  # Apply conv and GELU
        
        # Ensure we get exactly the sequence length we want
        if pos_embedding.size(2) > orig_len:
            # If too long, trim from both ends equally
            excess = pos_embedding.size(2) - orig_len
            start = excess // 2
            pos_embedding = pos_embedding[:, :, start:start + orig_len]
        elif pos_embedding.size(2) < orig_len:
            # If too short, pad both ends equally
            pad_size = orig_len - pos_embedding.size(2)
            pad_left = pad_size // 2
            pad_right = pad_size - pad_left
            pos_embedding = F.pad(pos_embedding, (pad_left, pad_right), mode='replicate')
            
        #print(f"After context conv: {pos_embedding.shape}")
        
        pos_embedding = pos_embedding.transpose(1, 2)  # [B, C, T] -> [B, T, C]
        #print(f"After transpose: {pos_embedding.shape}")
        
        # Verify shapes match before adding
        assert x.shape == pos_embedding.shape, f"Shape mismatch: x={x.shape}, pos_embedding={pos_embedding.shape}"
        x = x + pos_embedding
        
        # 7. Transformer processing
        c = self.transformer(x)
        #print(f"After transformer: {c.shape}")
        
        #print(f"Final quantized: {q.shape}")
        
        if not mask:
            mask_indices = torch.zeros(x.shape[0], x.shape[1], dtype=torch.bool, device=x.device)
        
        return c, q, mask_indices
        
    def compute_loss(self, c, q, mask_indices, num_negatives=100, temperature=0.1, eps=1e-7):
        """
        Compute contrastive loss:
        L_m = -log(exp(sim(c_t, q_t)/κ) / sum_k(exp(sim(c_t, q_k)/κ)))
        where:
        - c_t is the context network output at masked position t
        - q_t is the correct quantized representation at position t
        - q_k are the distractors (including q_t)
        - κ is the temperature (set to 0.1)
        - sim(a,b) is the cosine similarity between a and b
        """
        # Check if we have any masked indices
        if mask_indices.sum() == 0:
            return torch.tensor(0.0, device=c.device, requires_grad=True)
        
        # Get masked indices in flattened form
        flat_mask = mask_indices.view(-1)
        masked_indices = torch.nonzero(flat_mask).squeeze(-1)
        
        if len(masked_indices) == 0:  # No masked positions
            return torch.tensor(0.0, device=c.device, requires_grad=True)
        
        # Get positive samples (c_t and q_t pairs)
        c_masked = c.view(-1, c.size(-1))[masked_indices]  # c_t
        q_masked = q.view(-1, q.size(-1))[masked_indices]  # q_t
        
        # Sample negative indices for each positive
        with torch.no_grad():
            neg_indices = self._sample_negatives(masked_indices, len(flat_mask), num_negatives)
            negatives = q.view(-1, q.size(-1))[neg_indices]  # q_k distractors
        
        # Compute cosine similarity with numerical stability
        c_masked = F.normalize(c_masked + eps, dim=-1)
        q_masked = F.normalize(q_masked + eps, dim=-1)
        negatives = F.normalize(negatives + eps, dim=-1)
        
        # Compute sim(c_t, q_t) for positives
        pos_logits = torch.sum(c_masked * q_masked, dim=-1, keepdim=True)  # [num_masked, 1]
        
        # Compute sim(c_t, q_k) for negatives
        neg_logits = torch.bmm(c_masked.unsqueeze(1), negatives.transpose(1, 2)).squeeze(1)  # [num_masked, num_negatives]
        
        # Concatenate positive and negative logits
        logits = torch.cat([pos_logits, neg_logits], dim=1)  # [num_masked, 1 + num_negatives]
        
        # Scale by temperature κ
        logits = logits / temperature
        
        # Targets are zeros (positive pair should be selected)
        targets = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
        
        # Compute contrastive loss using cross entropy (equivalent to -log(exp(pos)/sum(exp(all))))
        contrastive_loss = F.cross_entropy(logits, targets)
        
        # Compute diversity loss (weight α = 0.1 as per paper)
        try:
            prob_perplexity = self.compute_prob_perplexity()
            diversity_loss = -torch.log(prob_perplexity + eps) * 0.1  # α = 0.1
            diversity_loss = torch.clamp(diversity_loss, min=-10, max=10)
        except Exception as e:
            print(f"Warning: Error computing diversity loss: {e}")
            diversity_loss = torch.tensor(0.0, device=c.device, requires_grad=True)
        
        # Total loss
        loss = contrastive_loss + diversity_loss
        
        # Print loss components only if they're valid
        # if self.training and not torch.isnan(loss) and not torch.isinf(loss):
        #     print(f"\nLoss components:")
        #     print(f"Contrastive loss: {contrastive_loss.item():.4f}")
        #     print(f"Diversity loss: {diversity_loss.item():.4f}")
        #     print(f"Total loss: {loss.item():.4f}")
        #     print(f"Prob perplexity: {prob_perplexity.item():.2f}")
        #     print(f"Number of masked positions: {len(masked_indices)}")
        #     print(f"Average positive logit: {pos_logits.mean().item():.4f}")
        #     print(f"Average negative logit: {neg_logits.mean().item():.4f}")
        
        return loss
        
    def compute_prob_perplexity(self, eps=1e-7):
        """
        Compute the perplexity of the averaged softmax probability over codebook entries
        This helps ensure even usage of the codebook vectors
        """
        # Get the weight matrix from the quantizer projection
        logits = self.quantizer.weight_proj.weight
        
        # Reshape to (num_groups, num_vars, -1)
        logits = logits.view(
            self.config.num_groups,
            self.config.num_vars,
            -1
        )
        
        # Compute softmax probabilities with numerical stability
        logits = torch.clamp(logits, min=-100, max=100)
        probs = F.softmax(logits, dim=1)  # Along codebook dimension
        
        # Average over feature dimension
        avg_probs = probs.mean(dim=-1)
        
        # Add small epsilon to avoid log(0)
        avg_probs = avg_probs + eps
        
        # Compute perplexity for each group
        perplexities = []
        for g in range(self.config.num_groups):
            p = avg_probs[g]
            # Normalize probabilities to sum to 1
            p = p / p.sum()
            perplexity = torch.exp(-torch.sum(p * torch.log(p)))
            perplexities.append(perplexity)
        
        # Average perplexity across groups
        avg_perplexity = torch.stack(perplexities).mean()
        
        return avg_perplexity
        
    def _sample_negatives(self, pos_indices, num_masked, num_negatives):
        """Sample negative indices from other masked positions."""
        with torch.no_grad():
            # Create a range of all masked indices
            all_indices = torch.arange(num_masked, device=pos_indices.device)
            
            # For each positive, sample K distractors from other masked positions
            neg_indices = []
            for i in range(len(pos_indices)):
                # Exclude the current positive index
                valid_indices = torch.cat([all_indices[:i], all_indices[i+1:]])
                # Sample K indices
                sampled = valid_indices[torch.randperm(len(valid_indices))[:num_negatives]]
                neg_indices.append(sampled)
            
            return torch.stack(neg_indices) 

In [3]:
import os
import numpy as np
import torch
import torchaudio

class LibriSpeech(torch.utils.data.Dataset):
    def __init__(self, split="test-clean", target_length=480000, device='cpu'):
        self.dataset = torchaudio.datasets.LIBRISPEECH(
            root=os.path.expanduser("~/.cache"),
            url=split,
            download=True,
        )
        self.device = device
        self.target_length = target_length
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, item):
        audio, sample_rate, text, _, _, _ = self.dataset[item]
        assert sample_rate == 16000
        audio = audio.flatten().numpy()
        audio_length = len(audio)
        if audio_length < self.target_length:
            padding = np.zeros(self.target_length - audio_length)
            audio = np.concatenate((audio, padding))
        elif audio_length > self.target_length:
            audio = audio[:self.target_length]
        audio = torch.tensor(audio, dtype=torch.float32)
        return audio, text

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
import math
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from datetime import datetime
import json

class WarmupLinearSchedule(optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_steps, total_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        super(WarmupLinearSchedule, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            # Linear warmup
            return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs]
        else:
            # Linear decay
            return [base_lr * (self.total_steps - step) / (self.total_steps - self.warmup_steps)
                   for base_lr in self.base_lrs]

class Trainer:
    def __init__(
        self,
        model: Wav2Vec2,
        train_dataset: LibriSpeech,
        val_dataset: LibriSpeech,
        config: Wav2Vec2Config,
        device: torch.device,
        is_librispeech: bool = True,
        patience: int = 10,
        log_dir: str = "runs",
        batch_size: int = 8,
        checkpoint_dir: str = "checkpoints",
        loader_kwargs: dict = None
    ):
        self.model = model.to(device)
        self.device = device
        self.config = config
        self.is_librispeech = is_librispeech
        self.patience = patience
        self.log_dir = log_dir
        self.checkpoint_dir = checkpoint_dir
        
        # Create log directory
        self.run_dir = os.path.join(log_dir, datetime.now().strftime("%Y%m%d_%H%M%S"))
        os.makedirs(self.run_dir, exist_ok=True)
        
        # Initialize metrics tracking
        self.train_losses = []
        self.val_losses = []
        self.learning_rates = []
        self.current_epoch = 0
        
        # Early stopping variables
        self.patience_counter = 0
        self.best_val_loss = float('inf')
        self.early_stop = False
        
        # Librispeech specific regularization
        if is_librispeech:
            self.encoder_grad_scale = 0.1
            self.l2_regularization = True
        else:
            self.encoder_grad_scale = 1.0
            self.l2_regularization = False
            
        # Use provided loader kwargs or default
        if loader_kwargs is None:
            loader_kwargs = {
                'batch_size': batch_size,
                'num_workers': 2 if device.type == 'cuda' else 0,
                'pin_memory': device.type == 'cuda',
            }
            
        # Create data loaders
        self.train_loader = DataLoader(
            train_dataset,
            shuffle=True,
            **loader_kwargs
        )
        
        self.val_loader = DataLoader(
            val_dataset,
            shuffle=False,
            **loader_kwargs
        )
        
        # Setup optimizer
        self.optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
        
        # Setup learning rate scheduler
        total_steps = 400000 if config.d_model == 768 else 250000  # BASE vs LARGE
        warmup_steps = int(0.08 * total_steps)  # 8% warmup
        self.scheduler = WarmupLinearSchedule(
            self.optimizer,
            warmup_steps=warmup_steps,
            total_steps=total_steps
        )
    
    def save_checkpoint(self, is_best=False):
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'learning_rates': self.learning_rates,
            'config': self.config,
        }
        
        # Save latest checkpoint
        latest_path = os.path.join(self.checkpoint_dir, "latest_checkpoint.pt")
        torch.save(checkpoint, latest_path)
        
        # Save best model if needed
        if is_best:
            best_path = os.path.join(self.checkpoint_dir, "best_model.pt")
            torch.save(checkpoint, best_path)
            print(f"Saved best model with validation loss: {self.best_val_loss:.4f}")
    
    def load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.current_epoch = checkpoint['epoch']
        self.best_val_loss = checkpoint['best_val_loss']
        self.train_losses = checkpoint['train_losses']
        self.val_losses = checkpoint['val_losses']
        self.learning_rates = checkpoint['learning_rates']
        print(f"Loaded checkpoint from epoch {self.current_epoch}")
    
    def save_metrics(self):
        metrics = {
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'learning_rates': self.learning_rates
        }
        with open(os.path.join(self.run_dir, 'metrics.json'), 'w') as f:
            json.dump(metrics, f)
            
    def plot_metrics(self):
        plt.figure(figsize=(12, 8))
        
        # Plot losses
        plt.subplot(2, 1, 1)
        plt.plot(self.train_losses, label='Train Loss')
        plt.plot(self.val_losses, label='Validation Loss')
        plt.title('Training and Validation Losses')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        
        # Plot learning rate
        plt.subplot(2, 1, 2)
        plt.plot(self.learning_rates, label='Learning Rate')
        plt.title('Learning Rate Schedule')
        plt.xlabel('Step')
        plt.ylabel('Learning Rate')
        plt.legend()
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.run_dir, 'training_metrics.png'))
        plt.close()
        
    def train_epoch(self):
        self.model.train()
        total_loss = 0
        
        for batch_idx, batch in enumerate(tqdm(self.train_loader, desc="Training")):
            audio = batch[0].to(self.device)
            
            # Print shapes for debugging (only first batch)
            if batch_idx == 0:
                print(f"\nInput audio shape: {audio.shape}")
            
            try:
                # Forward pass
                c, q, mask_indices = self.model(audio)
                
                # Print shapes for debugging (only first batch)
                if batch_idx == 0:
                    print(f"Context output shape: {c.shape}")
                    print(f"Quantized output shape: {q.shape}")
                    print(f"Mask indices shape: {mask_indices.shape}")
                    print(f"Number of masked positions: {mask_indices.sum().item()}")
                
                # Compute loss
                loss = self.model.compute_loss(c, q, mask_indices)
                
                if batch_idx == 0:
                    print(f"Initial loss: {loss.item():.4f}")
                
                # Add L2 regularization for Librispeech
                if self.l2_regularization:
                    l2_loss = 0.0
                    for name, param in self.model.feature_encoder.named_parameters():
                        if 'weight' in name:
                            l2_loss += torch.norm(param)
                    loss += 0.01 * l2_loss
                    
                    if batch_idx == 0:
                        print(f"L2 loss: {l2_loss.item():.4f}")
                        print(f"Total loss: {loss.item():.4f}")
                
                # Backward pass
                self.optimizer.zero_grad()
                loss.backward()
                
                # Scale gradients for feature encoder if using Librispeech
                if self.is_librispeech:
                    for param in self.model.feature_encoder.parameters():
                        param.grad *= self.encoder_grad_scale
                
                # Clip gradients
                grad_norm = clip_grad_norm_(self.model.parameters(), max_norm=5.0)
                if batch_idx == 0:
                    print(f"Gradient norm: {grad_norm:.4f}")
                
                self.optimizer.step()
                self.scheduler.step()
                
                # Track learning rate
                self.learning_rates.append(self.scheduler.get_last_lr()[0])
                
                total_loss += loss.item()
                
                # Print batch statistics (only first few batches)
                if batch_idx < 5:
                    print(f"Batch {batch_idx}, Loss: {loss.item():.4f}")
                
            except RuntimeError as e:
                print(f"\nError in batch {batch_idx}:")
                print(f"Input shape: {audio.shape}")
                raise e
            
        return total_loss / len(self.train_loader)
    
    def validate(self):
        self.model.eval()
        total_loss = 0
        
        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(self.val_loader, desc="Validation")):
                audio = batch[0].to(self.device)
                
                try:
                    # Forward pass with masking enabled (same as training)
                    c, q, mask_indices = self.model(audio, mask=True)
                    
                    # Compute loss
                    loss = self.model.compute_loss(c, q, mask_indices)
                    
                    # Add L2 regularization if using Librispeech (same as training)
                    if self.l2_regularization:
                        l2_loss = 0.0
                        for name, param in self.model.feature_encoder.named_parameters():
                            if 'weight' in name:
                                l2_loss += torch.norm(param)
                        loss += 0.01 * l2_loss
                    
                    total_loss += loss.item()
                    
                    # Print validation statistics (only first batch)
                    if batch_idx == 0:
                        print(f"\nValidation batch statistics:")
                        print(f"Loss: {loss.item():.4f}")
                        print(f"Number of masked positions: {mask_indices.sum().item()}")
                        
                except RuntimeError as e:
                    print(f"\nError in validation batch {batch_idx}:")
                    print(f"Input shape: {audio.shape}")
                    raise e
                
        return total_loss / len(self.val_loader)
        
    def train(self, num_epochs):
        for epoch in range(self.current_epoch, num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            
            # Training
            train_loss = self.train_epoch()
            self.train_losses.append(train_loss)
            
            # Validation
            val_loss = self.validate()
            self.val_losses.append(val_loss)
            
            print(f"\nEpoch {epoch + 1} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            
            # Check for improvement
            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss
                self.patience_counter = 0
            else:
                self.patience_counter += 1
            
            # Save checkpoint
            self.save_checkpoint(is_best)
            
            # Save and plot metrics
            self.save_metrics()
            self.plot_metrics()
            
            # Early stopping
            if self.patience_counter >= self.patience:
                print(f"\nEarly stopping triggered after {epoch + 1} epochs")
                break
            
            self.current_epoch = epoch + 1


# Wav2Vec 2.0 Training

This notebook provides an interface to train the Wav2Vec 2.0 model on LibriSpeech dataset.

In [5]:
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt
%matplotlib inline

## 1. Setup Device and Configuration

In [6]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory Available: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Set batch size based on device
batch_size = 32 if torch.cuda.is_available() else 2

Using device: cuda
GPU: NVIDIA RTX A6000
Memory Available: 50.93 GB


## 2. Create Datasets

In [7]:
# Create training and validation datasets
train_dataset = LibriSpeech(split="train-clean-100", target_length=48000)
val_dataset = LibriSpeech(split="dev-clean", target_length=48000)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

Training dataset size: 28539
Validation dataset size: 2703


## 3. Create and Configure Model

In [8]:
# Configure dataset size and model size for fast testing
FAST_DEV = False  # Set to False for full training

if FAST_DEV:
    # Use smaller model
    config = Wav2Vec2Config.TINY()

    # Use smaller subset of data
    train_subset_size = 1000  # Adjust this number as needed
    val_subset_size = 100

    # Create subset indices
    train_indices = torch.randperm(len(train_dataset))[:train_subset_size]
    val_indices = torch.randperm(len(val_dataset))[:val_subset_size]

    # Create subset datasets
    from torch.utils.data import Subset
    train_dataset = Subset(train_dataset, train_indices)
    val_dataset = Subset(val_dataset, val_indices)

    # Use smaller batch size
    batch_size = 4

    print(f"Fast dev mode enabled:")
    print(f"Training on {len(train_dataset)} examples")
    print(f"Validating on {len(val_dataset)} examples")
else:
    # Use original BASE configuration
    config = Wav2Vec2Config.BASE()

# Reduce model size if using CPU
if device.type == 'cpu':
    config.d_model = 256
    config.dim_feedforward = 1024
    config.num_encoder_layers = 4

# Create model and move to device
model = Wav2Vec2(config)
model = model.to(device)

# Enable multi-GPU if available
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)

print(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

Model parameters: 95.87M


## 4. Setup Training

In [9]:
# Create checkpoint directory
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Create trainer
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    config=config,
    device=device,
    is_librispeech=True,
    patience=10,
    log_dir="wav2vec_runs",
    batch_size=batch_size,
    checkpoint_dir=checkpoint_dir
)

## 5. Load Checkpoint (if exists)

In [10]:
# Load checkpoint if it exists
checkpoint_path = os.path.join(checkpoint_dir, "latest_checkpoint.pt")
if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}")
    trainer.load_checkpoint(checkpoint_path)

Loading checkpoint from checkpoints/latest_checkpoint.pt


  checkpoint = torch.load(checkpoint_path, map_location=self.device)


Loaded checkpoint from epoch 6


## 6. Start Training

The training will automatically:
- Save checkpoints
- Plot training metrics
- Implement early stopping
- Handle GPU memory efficiently

In [11]:
# Start training
trainer.train(num_epochs=25)  # Will stop early if no improvement


Epoch 7/25


Training:   0%|          | 0/892 [00:00<?, ?it/s]


Input audio shape: torch.Size([32, 48000])
Raw input: torch.Size([32, 48000])
Context output shape: torch.Size([32, 149, 768])
Quantized output shape: torch.Size([32, 149, 768])
Mask indices shape: torch.Size([32, 149])
Number of masked positions: 2240
Initial loss: 4.0382
L2 loss: 11.0398
Total loss: 4.1486


Training:   0%|          | 1/892 [00:05<1:24:40,  5.70s/it]

Gradient norm: 0.3450
Batch 0, Loss: 4.1486
Raw input: torch.Size([32, 48000])


Training:   0%|          | 2/892 [00:06<43:09,  2.91s/it]  

Batch 1, Loss: 4.1488
Raw input: torch.Size([32, 48000])


Training:   0%|          | 3/892 [00:07<28:08,  1.90s/it]

Batch 2, Loss: 4.1485
Raw input: torch.Size([32, 48000])


Training:   0%|          | 4/892 [00:08<23:16,  1.57s/it]

Batch 3, Loss: 4.1489
Raw input: torch.Size([32, 48000])


Training:   1%|          | 5/892 [00:09<19:01,  1.29s/it]

Batch 4, Loss: 4.1485
Raw input: torch.Size([32, 48000])


Training:   1%|          | 6/892 [00:10<24:38,  1.67s/it]


Raw input: torch.Size([32, 48000])


KeyboardInterrupt: 

## 7. Visualize Training Results

In [None]:
# Plot final training metrics
plt.figure(figsize=(12, 8))

# Plot losses
plt.subplot(2, 1, 1)
plt.plot(trainer.train_losses, label='Train Loss')
plt.plot(trainer.val_losses, label='Validation Loss')
plt.title('Training and Validation Losses')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Plot learning rate
plt.subplot(2, 1, 2)
plt.plot(trainer.learning_rates, label='Learning Rate')
plt.title('Learning Rate Schedule')
plt.xlabel('Step')
plt.ylabel('Learning Rate')
plt.legend()

plt.tight_layout()
plt.show()