In [None]:
import torch._inductor.config as config
import torch
import torchaudio
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import os
import json
import warnings

# Hide UserWarning-category warnings whose originating module name starts with
# "torchaudio" so they do not flood cell output; other warnings are unaffected.
warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio")

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"Compute Capability: {major}.{minor}")
    if major >= 8:
        print("✅ TF32 is supported (Ampere or newer).")
    else:
        print("❌ TF32 is not supported.")

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'  # For better debugging

config.triton.cudagraphs = False
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

scaler = torch.amp.GradScaler('cuda', enabled=True)


In [None]:
class SegmentedTokensOnDisk(torch.utils.data.Dataset):
    """
    Map-style dataset over pre-tokenized audio stored in sharded files on disk.

    Samples are addressed by a global index that is translated into a
    (shard, offset) pair. Only the most recently used shard is kept in memory,
    so sequential access loads every shard once while random access across
    shards reloads from disk on every switch.

    Args:
        manifest_path (str): Path to the JSON manifest describing the shards
        root (str, optional): Directory holding the shard files. Defaults to
            the directory containing the manifest
        cache_shards (bool): Stored on the instance for interface
            compatibility. It does not change behavior: exactly one shard
            (the current one) is cached regardless of its value

    Manifest JSON format:
        {
            "shards": ["shard_0.pt", "shard_1.pt", ...],
            "shard_sizes": [1000, 1000, ...],
            "num_samples": 2000,
            "seq_len": 32000,
            "stored_dtype": "long" or "uint8"
        }

    Attributes:
        files (list[Path]): Shard file paths, in manifest order
        sizes (list[int]): Number of samples in each shard
        cum (list[int]): Running total of ``sizes`` (cum[i] = samples in shards 0..i)
        N (int): Total number of samples ("num_samples")
        T (int): Sequence length per sample ("seq_len")
        stored_dtype (str): "uint8" selects the "x_u8"/"y_u8" keys, anything
            else the "x"/"y" keys
        cache (dict): Holds at most one entry, {current_shard: loaded shard}
        current_shard (int): Index of the shard in ``cache`` (-1 before first load)
    """
    def __init__(self, manifest_path, root=None, cache_shards=True):
        import json
        from itertools import accumulate
        from pathlib import Path

        mp = Path(manifest_path)
        man = json.loads(mp.read_text())
        self.root = mp.parent if root is None else Path(root)
        self.files = [self.root / s for s in man["shards"]]
        self.sizes = man["shard_sizes"]  # Number of samples per shard

        # Running totals let _loc find the owning shard with a binary search.
        self.cum = list(accumulate(self.sizes))

        self.N = man["num_samples"]  # Total number of samples
        self.T = man["seq_len"]      # Sequence length per sample
        self.cache = {}              # Cache for the currently loaded shard
        self.cache_shards = cache_shards
        self.stored_dtype = man.get("stored_dtype", "long")
        self.current_shard = -1      # Track currently loaded shard

    def __len__(self):
        """
        Returns:
            int: Total number of samples across all shards
        """
        return self.N

    def _loc(self, idx):
        """
        Translate a global sample index into (shard_index, offset_within_shard).

        Args:
            idx (int): Global sample index. Negative values count from the
                end, as with Python sequences

        Returns:
            tuple: (shard_index, offset_within_shard)

        Raises:
            IndexError: If ``idx`` is outside [-len(self), len(self))
        """
        import bisect

        if idx < 0:
            idx += self.N
        if not 0 <= idx < self.N:
            raise IndexError(
                f"sample index out of range for dataset of size {self.N}")

        # First shard whose running total exceeds idx owns the sample.
        s_idx = bisect.bisect_right(self.cum, idx)
        base = 0 if s_idx == 0 else self.cum[s_idx - 1]
        return s_idx, idx - base

    def get_shard(self, s_idx):
        """
        Return shard ``s_idx``, loading it from disk unless it is the cached one.

        Args:
            s_idx (int): Index of the shard to load

        Returns:
            dict: Loaded shard. ``__getitem__`` reads the keys 'x'/'y' or
                'x_u8'/'y_u8' from it and indexes them by sample offset
        """
        # Switching shards drops the previous one to bound memory use.
        if self.current_shard != s_idx:
            self.cache = {}

        if s_idx not in self.cache:
            print(f"Loading shard {s_idx} from {self.files[s_idx]}")
            # NOTE(security): torch.load unpickles the file; only point the
            # manifest at shard files from a trusted source.
            self.cache[s_idx] = torch.load(self.files[s_idx], map_location="cpu")

        self.current_shard = s_idx
        return self.cache[s_idx]

    def __getitem__(self, idx):
        """
        Get the (input, target) token pair for one sample.

        Args:
            idx (int): Sample index in [-len(dataset), len(dataset))

        Returns:
            tuple: (x, y), each clamped to the token range [0, 255]. With
                stored_dtype "uint8" both are upcast to torch.long; otherwise
                they keep the dtype they were stored with

        Raises:
            IndexError: If ``idx`` is out of range

        Note:
            Clamping is done out of place so the cached shard tensors are
            never modified by a read.
        """
        s_idx, off = self._loc(idx)
        shard = self.get_shard(s_idx)

        if self.stored_dtype == "uint8":
            x = shard["x_u8"][off].to(torch.long)  # upcast once
            y = shard["y_u8"][off].to(torch.long)
        else:
            x = shard["x"][off]
            y = shard["y"][off]

        # Ensure tokens are in valid range for mu-law encoding
        return x.clamp(0, 255), y.clamp(0, 255)


In [None]:
from IPython.display import Audio, display

# Open the pre-tokenized dataset through its manifest; the path is relative to
# the notebook's working directory. The dataset keeps only the most recently
# used shard in memory regardless of the cache_shards value passed here.
dataset = SegmentedTokensOnDisk(
    "segmented_tokens/manifest.json", cache_shards=False)

# Print dataset info.
print(f"Number of samples in dataset: {len(dataset)}")


In [None]:
class MuLawEncoding:
    """
    Thin wrapper pairing torchaudio's mu-law encoder and decoder transforms.

    Mu-law companding compresses the dynamic range of an audio signal
    logarithmically before quantizing it, which gives better perceptual
    quality for speech and audio at a low bit depth. It is commonly used in
    telephony and audio codecs.

    Args:
        quantization_channels (int): Number of discrete quantization levels.
                                   Default is 256 (8-bit quantization).

    Attributes:
        Q (int): Number of quantization channels
        enc (torchaudio.transforms.MuLawEncoding): Encoder transform
        dec (torchaudio.transforms.MuLawDecoding): Decoder transform
    """
    def __init__(self, quantization_channels: int = 256):
        self.Q = quantization_channels
        # Both transforms are built once and reused for every call.
        self.enc = torchaudio.transforms.MuLawEncoding(self.Q)
        self.dec = torchaudio.transforms.MuLawDecoding(self.Q)

    @torch.no_grad()
    def mu_law_encode(self, x: torch.Tensor) -> torch.LongTensor:
        """
        Encode a continuous waveform to discrete mu-law tokens.

        Delegates to ``self.enc``; runs with autograd disabled.

        Args:
            x (torch.Tensor): Input audio waveform, shape [..., T],
                            dtype float. The torchaudio transform expects
                            values already scaled to [-1, 1]

        Returns:
            torch.LongTensor: Encoded tokens, shape [..., T],
                            values in [0, Q-1] for in-range input

        Example:
            >>> codec = MuLawEncoding(256)
            >>> audio = torch.rand(1, 16000) * 2 - 1  # 1 second at 16kHz, in [-1, 1]
            >>> tokens = codec.mu_law_encode(audio)  # shape: [1, 16000]
        """
        return self.enc(x)

    @torch.no_grad()
    def mu_law_decode(self, q: torch.Tensor) -> torch.Tensor:
        """
        Decode discrete mu-law tokens back to a continuous waveform.

        Delegates to ``self.dec``; runs with autograd disabled.

        Args:
            q (torch.Tensor): Encoded tokens, shape [..., T],
                            dtype long/int, values in [0, Q-1]

        Returns:
            torch.Tensor: Decoded audio waveform, shape [..., T],
                        dtype float, values in [-1, 1]

        Example:
            >>> codec = MuLawEncoding(256)
            >>> tokens = torch.randint(0, 256, (1, 16000))
            >>> audio = codec.mu_law_decode(tokens)  # shape: [1, 16000]
        """
        return self.dec(q)


# Single shared codec instance; the playback/plot helpers and AudioProcessor
# in this notebook all read this module-level name.
codec = MuLawEncoding(256)


In [None]:
# Notebook-wide hyper-parameters.
# NOTE: this assignment rebinds the name `config`, which the import cell bound
# to the `torch._inductor.config` module; from here on `config` is this dict.
config = {
    "batch_size": 4,
    "num_workers": 2,  # NOTE(review): the DataLoader cells in this notebook pass num_workers=0 explicitly and do not read this value
    "pin_memory": True,
    "mu": 256,
    "sr": 16000,  # sample rate in Hz used for playback
    "trim_silence_thresh": 1e-3,  # default threshold of AudioProcessor.trim_silence
    "window_size": 32001,  # ~2 seconds at 16kHz (presumably seq_len + 1 for the one-step x/y shift - TODO confirm)

    # Wavenet architecture related.
    "residual_channels": 64,
    "skip_channels": 256,
    "output_dim": 256,
    "n_layers": 10,
    "n_blocks": 5,
    "kernel_size": 2,
    'hop_size': 16000,
}


In [None]:
def display_waveform(audio, title=""):
    """
    Draw a waveform tensor as a line plot in a 14x4 inch figure.

    Args:
        audio (torch.Tensor): 2-D CPU tensor. It is transposed before
            plotting, so a [1, T] input is drawn as one line of T points.
        title (str): Text appended to the "Audio: " figure title.
    """
    fig, ax = plt.subplots(figsize=(14, 4))
    ax.plot(audio.t().numpy())
    ax.set_title(f"Audio: {title}")
    fig.tight_layout()
    plt.show()


def play_encoded_sample(tokens: torch.Tensor, sr: int = config['sr']) -> None:
    """
    Decode mu-law tokens with the shared codec and show an audio player.

    Args:
        tokens (torch.Tensor): Token sequence, shape [T] or [1, T], dtype long.
        sr (int): Playback sample rate in Hz. The default is the value of
            config['sr'] at the time this function was defined.
    """
    # tokens: [T] or [1,T] long
    if tokens.dim() == 1:
        tokens = tokens.unsqueeze(0)
    wav = codec.mu_law_decode(tokens).squeeze(0).cpu().numpy()  # in [-1, 1]
    # Honour the caller's sample rate instead of always reading config['sr'].
    display(Audio(wav, rate=sr))


def play_batch(batch):
    """
    Play the first sample of a two-element (inputs, targets) batch.

    Args:
        batch: Pair whose first element is indexable by sample; only
            element 0 of that first element is played.
    """
    inputs, _targets = batch
    first_sample = inputs[0]
    play_encoded_sample(first_sample)


def plot_encoded_sample(tokens: torch.Tensor, sr: int, title=""):
    """
    Decode mu-law tokens with the shared codec and plot the waveform.

    Args:
        tokens (torch.Tensor): Token sequence; a 2-D tensor is decoded as-is,
            anything else gets a leading axis added first.
        sr (int): Accepted for signature symmetry; not used by the plot.
        title (str): Forwarded to display_waveform.
    """
    batched = tokens if tokens.dim() == 2 else tokens.unsqueeze(0)
    waveform = codec.mu_law_decode(batched).squeeze(0).cpu()
    # display_waveform expects a [1, T] tensor
    display_waveform(waveform.unsqueeze(0), title)

In [None]:
class AudioProcessor:
    """
    Audio preprocessing helpers: normalization, resampling, silence trimming
    and fixed-size segmentation.

    Attributes:
        mu_law_encoding (MuLawEncoding): The notebook-wide ``codec`` instance
        resamplers (dict): Resample transforms cached per "<orig>_<target>"
            sample-rate pair
    """
    def __init__(self):
        self.mu_law_encoding = codec
        self.resamplers = {}  # Cache resamplers for efficiency

    def normalize(self, x):
        """
        Scale audio to unit peak, then rescale it to an RMS level of 0.1.

        Args:
            x (torch.Tensor): Input audio waveform, any shape and dtype
                (converted to float32 when it is not already)

        Returns:
            torch.Tensor: float32 tensor of the same shape whose RMS is
                ~0.1. No clipping is applied, so samples of very peaky
                signals can end up outside [-1, 1]

        Example:
            >>> processor = AudioProcessor()
            >>> audio = torch.randn(1, 16000) * 10  # Loud audio
            >>> normalized = processor.normalize(audio)  # RMS around 0.1
        """
        if x.dtype != torch.float32:
            x = x.float()

        # Peak normalization: scale to maximum absolute value of 1.0
        max_val = torch.max(torch.abs(x))
        if max_val > 0:  # Avoid division by zero
            x = x / max_val

        # RMS normalization: standardize energy level. The epsilon keeps the
        # square root strictly positive for all-zero (silent) input.
        target_rms = 0.1
        current_rms = torch.sqrt(torch.mean(x**2) + 1e-8)
        if current_rms > 0:  # False only for NaN input, which is passed through
            x = x * (target_rms / current_rms)

        return x

    def resample_audio(self, audio, orig_sr, target_sr):
        """
        Resample audio to ``target_sr``, reusing one transform per rate pair.

        Args:
            audio (torch.Tensor): Input audio waveform, shape [channels, time]
            orig_sr (int): Original sample rate in Hz
            target_sr (int): Target sample rate in Hz

        Returns:
            torch.Tensor: The input itself when both rates are equal,
                otherwise the resampled audio, shape [channels, new_time]

        Example:
            >>> processor = AudioProcessor()
            >>> audio_44k = torch.randn(1, 44100)  # 1 second at 44.1kHz
            >>> audio_16k = processor.resample_audio(audio_44k, 44100, 16000)
            >>> print(audio_16k.shape)  # torch.Size([1, 16000])
        """
        if orig_sr == target_sr:
            return audio

        # Building a Resample transform is the expensive part, so cache it.
        resampler_key = f"{orig_sr}_{target_sr}"
        if resampler_key not in self.resamplers:
            self.resamplers[resampler_key] = torchaudio.transforms.Resample(
                orig_freq=orig_sr,
                new_freq=target_sr
            )

        return self.resamplers[resampler_key](audio)

    def trim_silence(self, sig, thresh=config['trim_silence_thresh']):
        """
        Remove leading and trailing silence along the last (time) axis.

        A time step counts as non-silent when the absolute amplitude of any
        channel exceeds ``thresh``. Everything before the first and after the
        last non-silent step is dropped; interior silence is kept.

        Args:
            sig (torch.Tensor): Audio signal, shape [time] or [..., time]
                (e.g. [1, time] or [channels, time])
            thresh (float): Amplitude threshold. The default is the value of
                config['trim_silence_thresh'] when the class was defined

        Returns:
            torch.Tensor: View of ``sig`` trimmed on the last axis, or ``sig``
                unchanged when no sample exceeds the threshold

        Example:
            >>> processor = AudioProcessor()
            >>> audio = torch.cat([torch.zeros(1, 1000), torch.ones(1, 8000),
            ...                    torch.zeros(1, 1000)], dim=1)
            >>> processor.trim_silence(audio, thresh=0.01).shape
            torch.Size([1, 8000])
        """
        energy = sig.abs()
        if energy.dim() > 1:
            # Collapse all leading axes so one loud channel keeps the frame.
            energy = energy.flatten(0, -2).amax(dim=0)

        loud = torch.nonzero(energy > thresh).flatten()
        if loud.numel() == 0:
            return sig  # Return original if no samples above threshold

        first, last = loud[0].item(), loud[-1].item()
        return sig[..., first:last + 1]

    def segment_audio(self, audio, drop_last=True, hop_size=None):
        """
        Split audio into fixed-size windows of config['window_size'] samples.

        Args:
            audio (torch.Tensor): Input audio, shape [1, time]
            drop_last (bool): If True, samples after the last full window are
                discarded. If False, they are covered by one extra window:
                the final ``window_size`` samples of the audio (overlapping
                the previous window), or the zero-padded audio when it is
                shorter than one window. Default True.
            hop_size (int, optional): Step between window starts. None or 0
                means ``window_size`` (non-overlapping); smaller values give
                overlapping windows.

        Returns:
            list[torch.Tensor]: Segments, each with shape [1, window_size]

        Example:
            >>> processor = AudioProcessor()
            >>> audio = torch.randn(1, 100000)  # Long audio
            >>> segments = processor.segment_audio(audio, hop_size=16000)
            >>> print(f"Each segment shape: {segments[0].shape}")  # [1, 32001]
        """
        window_size = config['window_size']
        hop = hop_size or window_size
        T = audio.shape[1]

        # Every window that fits entirely inside the audio.
        segments = [
            audio[:, start:start + window_size]
            for start in range(0, T - window_size + 1, hop)
        ]

        # A tail exists when the audio is shorter than one window or the last
        # full window does not end exactly at the final sample.
        if not drop_last and (T < window_size or (T - window_size) % hop != 0):
            last_start = max(0, T - window_size)
            tail = audio[:, last_start:]
            if tail.shape[1] < window_size:
                tail = F.pad(tail, (0, window_size - tail.shape[1]))
            segments.append(tail)

        return segments



In [None]:
from torch.utils.data import Subset

# Contiguous 90/10 split: the leading 90% of samples train, the rest test.
train_size = int(len(dataset) * 0.9)
test_size = len(dataset) - train_size

train_indices = list(range(train_size))
test_indices = list(range(train_size, train_size + test_size))

train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset, test_indices)

print(f"Training set size: {len(train_dataset)}")
print(f"Test set size: {len(test_dataset)}")

# One-sample subsets (both are sample 0) for overfitting smoke tests.
fake_train_dataset = Subset(dataset, [0])
fake_test_dataset = Subset(dataset, [0])

print(f"Fake training set size: {len(fake_train_dataset)}")
print(f"Fake test set size: {len(fake_test_dataset)}")


In [None]:
# Audio processor.
audio_processor = AudioProcessor()

# Create dataloaders.
# Loading is single-process (num_workers=0) and unshuffled, so batches are
# read in on-disk order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    num_workers=0,
    persistent_workers=False,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=config['batch_size'],
    num_workers=0,
    persistent_workers=False,
)
print(f"Number of training batches: {len(train_loader)}")
print(f"Number of test batches: {len(test_loader)}")

# Fake dataloaders to test model training/overfitting on just 1 sample.
# They wrap the one-sample subsets, not the full training split, so an
# "overfit on one sample" run really sees a single sample.
fake_train_loader = torch.utils.data.DataLoader(
    fake_train_dataset,
    batch_size=1,
    num_workers=0,
    persistent_workers=False,
)
fake_test_loader = torch.utils.data.DataLoader(
    fake_test_dataset,
    batch_size=1,
    num_workers=0,
    persistent_workers=False,
)


In [None]:
# Preview the first training batch: print the input shape and play one sample.
for batch in train_loader:
    print(batch[0].shape)
    # batch[0] is a whole [B, T] batch, but play_encoded_sample expects a
    # single [T] / [1, T] sequence; play_batch picks the first sample.
    play_batch(batch)
    break


In [None]:
def sanity(x, y):
    """
    Print dtype/min/max of an (input, target) batch and assert token validity.

    Args:
        x (torch.Tensor): Input tokens (non-empty).
        y (torch.Tensor): Target tokens (non-empty).

    Raises:
        AssertionError: If either tensor is not torch.long or holds a value
            outside [0, 255].
    """
    named = (("x", x), ("y", y))

    # Report both tensors before asserting so a failure still shows the stats.
    for label, tokens in named:
        print(f"{label} dtype:", tokens.dtype,
              "min:", int(tokens.min()), "max:", int(tokens.max()))

    assert all(tokens.dtype == torch.long for _, tokens in named)
    for _, tokens in named:
        in_range = (tokens >= 0) & (tokens <= 255)
        assert in_range.all()


# Run the sanity check over the leading training batches. The break fires once
# i exceeds 100, i.e. after batches 0..101 (102 in total) have been checked,
# or earlier if the loader is exhausted first.
# NOTE: i, x and y remain bound as notebook globals after this cell.
for i, (x, y) in enumerate(train_loader):
    sanity(x, y)
    if i > 100:
        break


In [None]:
class CausalDilatedConvolution(nn.Module):
    """
    Dilated 1D convolution that only looks at the present and the past.

    The input is zero-padded on the left by (kernel_size - 1) * dilation
    steps before an unpadded Conv1d is applied, so the output keeps the
    input's temporal length and position t never depends on inputs after t.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
        kernel_size (int): Size of the convolving kernel. Default: 2
        dilation (int): Dilation factor. Default: 1

    Attributes:
        kernel_size (int): Stored kernel size
        dilation (int): Stored dilation factor
        conv1d (nn.Conv1d): The underlying convolution layer
    """
    def __init__(self, in_channels, out_channels, kernel_size=2, dilation=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        # padding=0: causality is enforced by hand in forward().
        conv_options = dict(
            kernel_size=kernel_size, dilation=dilation, padding=0, bias=True)
        self.conv1d = nn.Conv1d(in_channels, out_channels, **conv_options)

    def forward(self, x):
        """
        Apply the causal convolution.

        Args:
            x (torch.Tensor): Input, shape [batch_size, in_channels, time]

        Returns:
            torch.Tensor: Output, shape [batch_size, out_channels, time]
        """
        # A kernel of size 1 has no temporal extent, so nothing to pad.
        if self.kernel_size <= 1:
            return self.conv1d(x)

        history = (self.kernel_size - 1) * self.dilation
        # Zeros go on the left (past) side only.
        return self.conv1d(F.pad(x, (history, 0)))


In [None]:
class ResidualBlock(nn.Module):
    """
    One WaveNet layer: causal dilated conv, gated activation, two 1x1 heads.

    The dilated convolution produces 2*C_res channels that are split into a
    filter half and a gate half. tanh(filter) * sigmoid(gate) is then
    projected once back to C_res channels (added to the input as the residual
    path) and once to C_skip channels (the skip path).

    Args:
        C_res (int): Number of residual channels (internal feature dimension)
        C_skip (int): Number of skip connection channels
        dilation (int): Dilation factor of the causal convolution. Default: 1

    Attributes:
        dilated_conv (CausalDilatedConvolution): Kernel-size-2 causal conv
        skip_conv1x1 (nn.Conv1d): 1x1 conv producing the skip output
        res_conv1x1 (nn.Conv1d): 1x1 conv producing the residual update
    """
    def __init__(self, C_res, C_skip, dilation=1):
        super().__init__()
        # Twice the channels: one half feeds tanh, the other feeds sigmoid.
        self.dilated_conv = CausalDilatedConvolution(
            in_channels=C_res,
            out_channels=2 * C_res,
            kernel_size=2,
            dilation=dilation,
        )
        self.skip_conv1x1 = nn.Conv1d(C_res, C_skip, kernel_size=1)
        self.res_conv1x1 = nn.Conv1d(C_res, C_res, kernel_size=1)

    def forward(self, x):
        """
        Run the layer.

        Args:
            x (torch.Tensor): Input, shape [batch_size, C_res, time]

        Returns:
            tuple: (residual_out, skip_out)
                - residual_out: [batch_size, C_res, time], input plus update
                - skip_out: [batch_size, C_skip, time]
        """
        filt, gate = self.dilated_conv(x).chunk(2, dim=1)  # each [B, C_res, T]
        activated = filt.tanh() * gate.sigmoid()           # [B, C_res, T]

        skip = self.skip_conv1x1(activated)                # [B, C_skip, T]
        residual = x + self.res_conv1x1(activated)         # [B, C_res, T]
        return residual, skip
   

In [None]:
class Wavenet(nn.Module):
    """
    WaveNet: autoregressive model over discrete audio tokens.

    Architecture:
    - Input embedding layer: maps discrete tokens to continuous features
    - Initial 1x1 convolution
    - n_blocks stacks of n_layers ResidualBlock modules with dilations
      1, 2, 4, ..., 2^(n_layers-1)
    - Skip connections: summed over every layer
    - Output head: ReLU -> 1x1 conv -> ReLU -> 1x1 conv, producing logits

    Args:
        config (dict): Configuration dictionary containing:
            - residual_channels (int): Number of channels in residual paths
            - skip_channels (int): Number of channels in skip connections
            - output_dim (int): Vocabulary size (typically 256 for mu-law)
            - n_layers (int): Number of layers per block (>= 1)
            - n_blocks (int): Number of blocks (>= 1)
            - kernel_size (int): Requested convolution kernel size. The
              layers are built by ResidualBlock; if the kernel size they
              actually use differs from this value a warning is emitted

    Raises:
        ValueError: If n_blocks or n_layers is smaller than 1

    Attributes:
        config (dict): Stored configuration
        C_res (int): Residual channel dimension
        C_output (int): Output vocabulary size
        C_skip (int): Skip connection channel dimension
        kernel_size (int): Kernel size as given in the configuration
        _rf (int): Receptive field of the layers that were actually built
        embedding (nn.Embedding): Token embedding layer
        conv1d (nn.Conv1d): Initial 1x1 convolution
        residual_blocks (nn.ModuleList): ModuleList of per-block ModuleLists
        output_head (nn.Sequential): Final output projection layers
    """
    def __init__(self, config):
        super().__init__()
        # Without at least one layer there is no skip signal for the head.
        if config['n_blocks'] < 1 or config['n_layers'] < 1:
            raise ValueError("Wavenet needs n_blocks >= 1 and n_layers >= 1")

        self.config = config
        self.C_res = config['residual_channels']
        self.C_output = config['output_dim']
        self.C_skip = config['skip_channels']
        self.kernel_size = config['kernel_size']

        # Token embedding: converts discrete tokens to continuous features
        self.embedding = nn.Embedding(
            self.C_output, self.C_res, dtype=torch.float32)

        # Initial 1x1 convolution applied to the embedded sequence
        self.conv1d = nn.Conv1d(self.C_res, self.C_res, kernel_size=1)

        # Each block holds n_layers with dilations 1, 2, 4, ..., 2^(n_layers-1)
        self.residual_blocks = nn.ModuleList([
            self._create_residual_block() for _ in range(config['n_blocks'])
        ])

        # Output head: turns the summed skip connections into logits
        self.output_head = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(self.C_skip, self.C_skip, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(self.C_skip, self.C_output, kernel_size=1))

        # The configured kernel size is not forwarded to the layers, so make
        # a mismatch visible instead of silently reporting a wrong geometry.
        built_kernel = self.residual_blocks[0][0].dilated_conv.kernel_size
        if built_kernel != self.kernel_size:
            warnings.warn(
                f"config['kernel_size']={self.kernel_size} is not used by the "
                f"residual layers, which were built with kernel size "
                f"{built_kernel}; the receptive field reflects the built layers.")

        # Measured on the built layers, hence computed last.
        self._rf = self._compute_receptive_field()

    def _create_residual_block(self):
        """
        Build one block of n_layers ResidualBlock modules.

        Returns:
            nn.ModuleList: Layers with dilations 1, 2, 4, ..., 2^(n_layers-1)

        Example:
            For n_layers=5: dilations = [1, 2, 4, 8, 16]
            For n_layers=10: dilations = [1, 2, 4, ..., 512]
        """
        return nn.ModuleList([
            ResidualBlock(self.C_res, self.C_skip, dilation=2**i)
            for i in range(self.config['n_layers'])
        ])

    def forward(self, x):
        """
        Compute next-token logits for every position.

        Args:
            x (torch.Tensor): Input token sequence, shape [batch_size, time],
                            dtype long, values in [0, output_dim-1]

        Returns:
            torch.Tensor: Logits, shape [batch_size, time, output_dim]
        """
        # [B, T] -> [B, T, C_res] -> [B, C_res, T] (Conv1d wants channels first)
        features = self.embedding(x).permute(0, 2, 1)
        features = self.conv1d(features)

        # Every layer refines the residual stream and contributes one skip term.
        skip_total = None
        for residual_block in self.residual_blocks:
            for layer in residual_block:
                features, skip = layer(features)
                skip_total = skip if skip_total is None else skip_total + skip

        logits = self.output_head(skip_total)  # [B, C_skip, T] -> [B, C_output, T]
        return logits.permute(0, 2, 1)         # [B, C_output, T] -> [B, T, C_output]

    def _compute_receptive_field(self):
        """
        Compute the receptive field of the residual stack as built.

        Each causal layer extends the visible history by
        (kernel_size - 1) * dilation steps; the 1x1 convolutions add none.

        Returns:
            int: Receptive field size in time steps

        Formula:
            RF = 1 + sum over all layers of (k - 1) * dilation
            For kernel size 2: RF = 1 + n_blocks * (2^n_layers - 1)

        Example:
            n_blocks=5, n_layers=10, kernel size 2:
            RF = 1 + 5 * (2^10 - 1) = 1 + 5 * 1023 = 5116 time steps
        """
        history = sum(
            (layer.dilated_conv.kernel_size - 1) * layer.dilated_conv.dilation
            for block in self.residual_blocks
            for layer in block
        )
        return 1 + history

    def get_receptive_field(self):
        """
        Get the receptive field size of the model.

        Returns:
            int: Number of time steps (current one included) that can
                influence a single output position
        """
        return self._rf


In [None]:
# Figure out dataloaders, etc. based on training requirements.
# Fake dataset setup to test if training is working and is able to overfit on 1 sample.
USE_FAKE = False
if USE_FAKE:
    # NOTE: the DataLoaders were already constructed in an earlier cell, so
    # these two updates do not change them; they only affect code that reads
    # config after this point.
    config['batch_size'] = 1
    config['num_workers'] = 0

# Fake runs get their own checkpoint directory, separate from the real one.
checkpoint_dir = './wavenet_checkpoints_fake' if USE_FAKE else './wavenet_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# NOTE(review): these lines rebind train_loader/test_loader in place. After a
# run with USE_FAKE = True, flipping it back to False and re-running only this
# cell keeps the fake loaders; re-run the DataLoader cell first.
train_loader = fake_train_loader if USE_FAKE else train_loader
test_loader = fake_test_loader if USE_FAKE else test_loader


In [None]:
class Model:
    """
    Shared plumbing for the WaveNet wrappers: checkpoint paths and compilation.

    Subclasses decide which checkpoint to load and which compile helper to
    call; this class only knows where the files live and how to wrap
    ``base_model`` with ``torch.compile``.

    Args:
        config (dict): Model configuration dictionary
        base_model (Wavenet): The underlying (uncompiled) WaveNet model
        checkpoint_dir (str): Directory for saving/loading checkpoints

    Attributes:
        config (dict): Model configuration
        base_model (Wavenet): The underlying WaveNet model
        model_train: Initialised to None; not assigned anywhere in this class
        model_eval: Initialised to None; not assigned anywhere in this class
        checkpoint_dir (str): Directory for checkpoints
    """
    def __init__(self, config, base_model, checkpoint_dir):
        self.checkpoint_dir = checkpoint_dir
        self.config = config
        self.base_model = base_model
        self.model_train = self.model_eval = None

    def _checkpoint_file(self, filename):
        """Join ``filename`` onto the checkpoint directory."""
        return os.path.join(self.checkpoint_dir, filename)

    def _get_checkpoint_paths(self):
        """
        Locate the checkpoints that are actually present on disk.

        Returns:
            tuple: (latest_path, best_path); an entry is None when the
                   corresponding file does not exist
        """
        candidates = (self._get_latest_checkpoint_path(),
                      self._get_best_checkpoint_path())
        latest_path, best_path = (
            path if os.path.exists(path) else None for path in candidates
        )
        return latest_path, best_path

    def _get_best_checkpoint_path(self):
        """
        Path of the best-model checkpoint (the file may not exist yet).

        Returns:
            str: ``<checkpoint_dir>/best_model.pth``
        """
        return self._checkpoint_file('best_model.pth')

    def _get_latest_checkpoint_path(self):
        """
        Path of the most recent training checkpoint (the file may not exist yet).

        Returns:
            str: ``<checkpoint_dir>/latest_checkpoint.pth``
        """
        return self._checkpoint_file('latest_checkpoint.pth')

    def _compile_for_training(self):
        """
        Switch ``base_model`` to train mode and compile it.

        The mode is toggled on ``base_model`` itself before compiling.

        Returns:
            The compiled wrapper around ``base_model`` returned by
            ``torch.compile(..., mode="reduce-overhead")``
        """
        net = self.base_model
        net.train()
        return torch.compile(net, mode="reduce-overhead")

    def _compile_for_eval(self):
        """
        Switch ``base_model`` to eval mode and compile it.

        The mode is toggled on ``base_model`` itself before compiling.

        Returns:
            The compiled wrapper around ``base_model`` returned by
            ``torch.compile(..., mode="reduce-overhead")``
        """
        net = self.base_model
        net.eval()
        return torch.compile(net, mode="reduce-overhead")

    def get_receptive_field(self):
        """
        Forward the receptive-field query to the wrapped model.

        Returns:
            Whatever ``base_model.get_receptive_field()`` returns
        """
        net = self.base_model
        return net.get_receptive_field()


class TrainableModel(Model):
    """
    Trainable WaveNet model wrapper with training state management and checkpointing.

    Extends the base Model class with optimizer/scheduler state, a history of
    training metrics, and checkpoint save/load so training can be resumed.

    Args:
        config (dict): Model configuration dictionary
        checkpoint_dir (str): Directory for saving/loading checkpoints
        base_model (Wavenet): The underlying WaveNet model
        optimizer (torch.optim.Optimizer): Optimizer for training
        scheduler: Learning rate scheduler (anything with state_dict/load_state_dict)
        load_from_checkpoint (bool): Whether to load from existing checkpoint. Default: False

    Attributes:
        base_model (Wavenet): The underlying WaveNet model (set by Model.__init__)
        optimizer (torch.optim.Optimizer): Training optimizer
        scheduler: Learning rate scheduler
        training_stats (dict): Training metrics history containing:
            - train_losses (list): Training loss per epoch
            - val_losses (list): Validation loss per epoch
            - train_accuracies (list): Training accuracy per epoch
            - val_accuracies (list): Validation accuracy per epoch
            - best_val_loss (float): Best validation loss achieved
        trained_till_epoch_index (int): Last completed epoch (-1 if untrained);
            kept in sync by save()
        compiled_model: Wrapper returned by torch.compile, in train mode
    """
    def __init__(self, config, checkpoint_dir, base_model, optimizer, scheduler, load_from_checkpoint=False):
        # Model.__init__ already stores config, base_model and checkpoint_dir.
        super().__init__(config, base_model, checkpoint_dir)

        self.optimizer = optimizer
        self.scheduler = scheduler
        self.training_stats = {
            'train_losses': [],
            'val_losses': [],
            'train_accuracies': [],
            'val_accuracies': [],
            'best_val_loss': float('inf'),
        }
        # Track training progress: -1 = untrained, 0+ = last completed epoch index
        self.trained_till_epoch_index = -1
        if load_from_checkpoint:
            self._load_from_checkpoint()

        # Compile only after the weights are restored.
        self.compiled_model = self._compile_for_training()

    @staticmethod
    def _atomic_save(checkpoint, path):
        """
        Write ``checkpoint`` to ``path`` without ever exposing a partial file.

        The data is first written to ``<path>.tmp`` and then moved over the
        destination with ``os.replace``. If the write is interrupted (for
        example by stopping the kernel), the previous checkpoint at ``path``
        is left untouched.
        """
        tmp_path = f"{path}.tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)

    def save(self, epoch, training_stats, learning_rate, is_best=False):
        """
        Save a complete training checkpoint with model and optimizer state.

        Stores everything needed to resume training: model weights, optimizer
        state, scheduler state, training metrics, and configuration. After a
        successful write, ``trained_till_epoch_index`` is set to ``epoch``.

        Args:
            epoch (int): Index (0-based) of the epoch that just finished
            training_stats (dict): Dictionary containing training metrics:
                - train_losses, val_losses, train_accuracies, val_accuracies, best_val_loss
            learning_rate (float): Learning rate to record in the checkpoint
            is_best (bool): Whether this is the best model so far (lowest val loss)

        Saves:
            - Latest checkpoint: Always saved for training resumption
            - Best checkpoint: Only saved when is_best=True for evaluation
        """
        checkpoint = {
            'config': self.config,
            # Taken from the uncompiled module, so the keys match what
            # load_state_dict() on a fresh Wavenet expects.
            'model_state_dict': self.base_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),

            'train_losses': training_stats['train_losses'],
            'val_losses': training_stats['val_losses'],
            'train_accuracies': training_stats['train_accuracies'],
            'val_accuracies': training_stats['val_accuracies'],
            'best_val_loss': training_stats['best_val_loss'],

            'trained_till_epoch_index': epoch,
            'learning_rate': learning_rate
        }

        # Always save latest checkpoint for resumable training
        self._atomic_save(checkpoint, self._get_latest_checkpoint_path())

        # Save best model checkpoint for evaluation/inference
        if is_best:
            self._atomic_save(checkpoint, self._get_best_checkpoint_path())
            best_val_loss = training_stats['best_val_loss']
            print(f"💾 New best model saved! Val Loss: {best_val_loss:.4f}")

        # Keep the in-memory progress marker consistent with what is on disk.
        self.trained_till_epoch_index = epoch

    def _load_from_checkpoint(self):
        """
        Load training state from the latest checkpoint for resumable training.

        Restores model weights, optimizer state, scheduler state (when the
        checkpoint has one), the metrics history and the epoch counter. If no
        latest checkpoint exists, a message is printed and nothing changes.

        Updates:
            - base_model: Loads saved model weights
            - optimizer: Restores optimizer state
            - scheduler: Restores scheduler state
            - training_stats: Replaced with the saved metrics history
            - trained_till_epoch_index: Sets last completed epoch
        """
        latest_checkpoint_path, _ = self._get_checkpoint_paths()

        if not latest_checkpoint_path:
            print("🆕 No existing checkpoints found. Starting training from scratch.")
            return

        print(f"🔄 Loading training checkpoint from: {latest_checkpoint_path}")
        # `device` is the notebook-level global set in the setup cell.
        checkpoint = torch.load(latest_checkpoint_path, map_location=device)

        # Restore model weights
        self.base_model.load_state_dict(checkpoint['model_state_dict'])

        # Restore optimizer state
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Restore scheduler state if available
        if 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Restore training progress
        self.trained_till_epoch_index = checkpoint.get('trained_till_epoch_index', -1)

        # Restore training metrics history
        self.training_stats = {
            'train_losses': checkpoint.get('train_losses', []),
            'val_losses': checkpoint.get('val_losses', []),
            'train_accuracies': checkpoint.get('train_accuracies', []),
            'val_accuracies': checkpoint.get('val_accuracies', []),
            'best_val_loss': checkpoint.get('best_val_loss', float('inf')),
        }
        epochs_done = self.trained_till_epoch_index + 1
        print(f"✅ Training checkpoint loaded successfully. Trained for {epochs_done} epochs!")


class EvalModel(Model):
    """
    Evaluation-only WaveNet model wrapper for inference and generation.

    Loads the weights from the best saved checkpoint (when one exists) into
    ``base_model`` and then compiles the model in eval mode.

    Args:
        config (dict): Model configuration dictionary
        checkpoint_dir (str): Directory containing saved checkpoints
        base_model (Wavenet): The underlying WaveNet model

    Attributes:
        base_model (Wavenet): The underlying WaveNet model (set by Model.__init__)
        compiled_model: Wrapper returned by torch.compile, in eval mode
    """
    def __init__(self, config, checkpoint_dir, base_model):
        # Model.__init__ already stores config, base_model and checkpoint_dir.
        super().__init__(config, base_model, checkpoint_dir)
        self._load_from_checkpoint()

        # Compile only after the weights are restored.
        self.compiled_model = self._compile_for_eval()

    def _load_from_checkpoint(self):
        """
        Load the best model checkpoint for evaluation.

        Only the model weights are read; optimizer and scheduler state in the
        checkpoint are ignored. If no best checkpoint exists, a message is
        printed and ``base_model`` keeps whatever weights it already has.

        Updates:
            - base_model: Loads best model weights for evaluation
        """
        _, best_checkpoint_path = self._get_checkpoint_paths()

        if not best_checkpoint_path:
            print("🆕 No existing best checkpoint found for evaluation")
            return

        print(f"🔄 Loading eval model checkpoint from: {best_checkpoint_path}")
        # `device` is the notebook-level global set in the setup cell.
        checkpoint = torch.load(best_checkpoint_path, map_location=device)

        # Load only model weights (no optimizer state needed for evaluation)
        self.base_model.load_state_dict(checkpoint['model_state_dict'])

        # The stored index is 0-based; report it as a 1-based epoch number.
        epoch_number = checkpoint.get('trained_till_epoch_index', 0) + 1
        print(f"✅ Eval model checkpoint loaded successfully. Trained till epoch {epoch_number}!")


In [None]:
class Trainer:
    """
    Trainer for WaveNet autoregressive audio modeling.

    Runs the training loop: forward passes, loss computation, validation,
    checkpointing, progress printing and plotting. Uses FP16 autocast with the
    notebook-level ``scaler`` (a ``torch.amp.GradScaler``), gradient clipping,
    and a validation-loss-driven learning rate scheduler.

    Args:
        base_model (Wavenet): The underlying WaveNet model
        trainable_model (TrainableModel): Wrapper with training state management
        config (dict): Training configuration dictionary (accepted but not stored)
        learning_rate (float): Initial learning rate
        optimizer (torch.optim.Optimizer): Optimizer for training
        scheduler: Learning rate scheduler; ``step(val_loss)`` is called once per epoch
        checkpoint_dir (str): Directory for saving plots and the training summary
        train_loader (torch.utils.data.DataLoader): Training data loader
        val_loader (torch.utils.data.DataLoader): Validation data loader
        device (torch.device): Device batches are moved to

    Attributes:
        device (torch.device): Training device
        base_model (Wavenet): Base WaveNet model
        trainable_model (TrainableModel): Training wrapper
        compiled_model: ``trainable_model.compiled_model``, used for training forward passes
        learning_rate (float): Initial learning rate as passed in
        optimizer (torch.optim.Optimizer): Training optimizer
        scheduler: LR scheduler
        checkpoint_dir (str): Output directory
        train_loader (torch.utils.data.DataLoader): Training data
        val_loader (torch.utils.data.DataLoader): Validation data
        training_stats (dict): Training metrics history (same dict object as
            ``trainable_model.training_stats``)
        start_epoch (int): Next epoch index to train; advanced after every
            completed epoch so repeated ``train()`` calls continue instead of
            restarting
    """
    def __init__(self,
                 base_model,
                 trainable_model,
                 config,
                 learning_rate,
                 optimizer,
                 scheduler,
                 checkpoint_dir,
                 train_loader,
                 val_loader,
                 device):
        self.device = device
        self.base_model = base_model
        self.trainable_model = trainable_model
        self.compiled_model = self.trainable_model.compiled_model

        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.checkpoint_dir = checkpoint_dir
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Reference to training metrics (shared with trainable_model)
        self.training_stats = self.trainable_model.training_stats
        # Determine starting epoch (resume from checkpoint or start fresh)
        self.start_epoch = self.trainable_model.trained_till_epoch_index + 1

    def prepare_batch(self, batch, device):
        """
        Move a batch to the given device.

        Args:
            batch (tuple): ``(audio_x, audio_y)`` pair from the DataLoader
                (input and target token tensors)
            device (torch.device): Target device (cuda/cpu)

        Returns:
            tuple: ``(audio_x, audio_y)`` moved to ``device``
        """
        audio_x, audio_y = batch
        return audio_x.to(device), audio_y.to(device)

    def calculate_accuracy(self, output, target):
        """
        Fraction of positions where argmax over the last axis of ``output`` equals ``target``.

        Args:
            output (torch.Tensor): Logits, shape [N, vocab_size]
            target (torch.Tensor): Target class ids, shape [N]

        Returns:
            float: Accuracy in [0, 1]
        """
        pred = torch.argmax(output, dim=-1)  # Get predicted token indices
        correct = (pred == target).sum().item()  # Count correct predictions
        total = target.numel()  # Total number of tokens
        return correct / total

    def _forward_loss(self, model, batch):
        """
        Run one forward pass under FP16 autocast and compute cross-entropy.

        Args:
            model: Callable mapping input tokens to logits shaped [B, T, vocab_size]
            batch (tuple): ``(audio_x, audio_y)`` pair from the DataLoader

        Returns:
            tuple: ``(loss, output_flat, target_flat)`` where ``output_flat``
            is [B*T, vocab_size] and ``target_flat`` is [B*T]
        """
        audio_x, audio_y = self.prepare_batch(batch, self.device)
        with torch.amp.autocast("cuda", dtype=torch.float16):
            output = model(audio_x)  # [B, T, vocab_size]
            output_flat = output.reshape(-1, output.shape[-1])  # [B*T, vocab_size]
            target_flat = audio_y.reshape(-1)                   # [B*T]
            loss = F.cross_entropy(output_flat, target_flat)
        return loss, output_flat, target_flat

    @staticmethod
    def _epoch_averages(phase, total_loss, total_accuracy, num_batches, num_failed, last_error):
        """
        Turn per-epoch accumulators into ``(mean_loss, mean_accuracy)``.

        Raises RuntimeError when batches were attempted but none succeeded.
        Reporting the fallback 0.0 loss in that case would look like a perfect
        epoch, and for validation it would be saved as the best model.
        An empty loader still yields ``(0.0, 0.0)``.
        """
        if num_batches == 0 and num_failed > 0:
            raise RuntimeError(
                f"All {num_failed} {phase} batches failed; no metrics to report."
            ) from last_error
        if num_failed:
            print(f"⚠️ Skipped {num_failed} {phase} batch(es) because of errors.")
        denom = max(num_batches, 1)
        return total_loss / denom, total_accuracy / denom

    def train_epoch(self):
        """
        Train the model for one epoch with mixed precision and gradient clipping.

        A batch that raises is reported and skipped. If every attempted batch
        fails, a RuntimeError is raised instead of returning a 0.0 loss.

        Returns:
            tuple: (average_loss, average_accuracy) over the successful batches

        Training pipeline per batch:
            1. Forward pass with mixed precision (FP16) on the compiled model
            2. Cross-entropy loss on flattened sequences
            3. Backward pass with gradient scaling
            4. Unscale, then clip gradients to norm 1.0
            5. Optimizer step and scaler update
        """
        self.base_model.train()
        self.compiled_model.train()

        total_loss = 0.0
        total_accuracy = 0.0
        num_batches = 0
        num_failed = 0
        last_error = None

        pbar = tqdm(self.train_loader, desc="Training", leave=False)

        for batch in pbar:
            try:
                self.optimizer.zero_grad(set_to_none=True)
                loss, output_flat, target_flat = self._forward_loss(
                    self.compiled_model, batch)

                # Calculate accuracy for monitoring
                accuracy = self.calculate_accuracy(output_flat, target_flat)

                # `scaler` is the notebook-level GradScaler from the setup cell.
                scaler.scale(loss).backward()
                # Unscale first so clipping sees the true gradient norm.
                scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.base_model.parameters(), 1.0)
                scaler.step(self.optimizer)
                scaler.update()

                # Accumulate metrics
                loss_value = loss.item()
                total_loss += loss_value
                total_accuracy += accuracy
                num_batches += 1

                # Update progress bar
                pbar.set_postfix({
                    'Loss': f'{loss_value:.4f}',
                    'Acc': f'{accuracy:.3f}'
                })

            except Exception as e:
                print(f"Error in training batch: {e}")
                num_failed += 1
                last_error = e
                continue

        return self._epoch_averages(
            "training", total_loss, total_accuracy, num_batches, num_failed, last_error)

    def validate_epoch(self):
        """
        Evaluate the model on the validation loader without gradients.

        Uses the uncompiled ``base_model`` in eval mode under FP16 autocast.
        A batch that raises is reported and skipped. If every attempted batch
        fails, a RuntimeError is raised instead of returning a 0.0 loss.

        Returns:
            tuple: (average_loss, average_accuracy) over the successful batches
        """
        self.base_model.eval()

        total_loss = 0.0
        total_accuracy = 0.0
        num_batches = 0
        num_failed = 0
        last_error = None

        with torch.no_grad():  # Disable gradient computation for efficiency
            pbar = tqdm(self.val_loader, desc="Validation", leave=False)

            for batch in pbar:
                try:
                    loss, output_flat, target_flat = self._forward_loss(
                        self.base_model, batch)

                    # Calculate accuracy for monitoring
                    accuracy = self.calculate_accuracy(output_flat, target_flat)

                    # Accumulate metrics
                    loss_value = loss.item()
                    total_loss += loss_value
                    total_accuracy += accuracy
                    num_batches += 1

                    # Update progress bar
                    pbar.set_postfix({
                        'Val Loss': f'{loss_value:.4f}',
                        'Val Acc': f'{accuracy:.3f}'
                    })

                except Exception as e:
                    print(f"Error in validation batch: {e}")
                    num_failed += 1
                    last_error = e
                    continue

        return self._epoch_averages(
            "validation", total_loss, total_accuracy, num_batches, num_failed, last_error)

    def plot_training_progress(self, epoch):
        """
        Plot and save the training history.

        Draws three panels (loss curves, accuracy curves, and validation minus
        training loss), saves them under ``self.checkpoint_dir``, shows the
        figure, and closes it. Does nothing when no epoch has been recorded.

        Args:
            epoch (int): Epoch index (0-based) used in titles and the filename

        Saves:
            - PNG file: progress_epoch_{epoch+1}.png in the checkpoint directory
        """
        train_losses = self.training_stats['train_losses']
        val_losses = self.training_stats['val_losses']
        train_accuracies = self.training_stats['train_accuracies']
        val_accuracies = self.training_stats['val_accuracies']

        if len(train_losses) == 0:
            return

        fig = plt.figure(figsize=(15, 5))

        # Loss plot
        plt.subplot(1, 3, 1)
        epochs_range = range(1, len(train_losses) + 1)
        plt.plot(epochs_range, train_losses, 'b-',
                 label='Training Loss', linewidth=2)
        plt.plot(epochs_range, val_losses, 'r-',
                 label='Validation Loss', linewidth=2)
        plt.title(f'Training Progress (Epoch {epoch+1})')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Accuracy plot
        plt.subplot(1, 3, 2)
        plt.plot(epochs_range, train_accuracies, 'b-',
                 label='Training Accuracy', linewidth=2)
        plt.plot(epochs_range, val_accuracies, 'r-',
                 label='Validation Accuracy', linewidth=2)
        plt.title(f'Accuracy Progress (Epoch {epoch+1})')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Loss difference (overfitting indicator)
        plt.subplot(1, 3, 3)
        loss_diff = [v - t for t, v in zip(train_losses, val_losses)]
        plt.plot(epochs_range, loss_diff, 'g-', linewidth=2)
        plt.title('Overfitting Indicator (Val - Train Loss)')
        plt.xlabel('Epoch')
        plt.ylabel('Loss Difference')
        plt.axhline(y=0, color='k', linestyle='--', alpha=0.5)
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        # Use the directory given to this trainer, not the notebook global.
        plot_path = os.path.join(self.checkpoint_dir, f'progress_epoch_{epoch+1}.png')
        plt.savefig(plot_path, dpi=150)
        plt.show()
        # Release the figure so long runs do not accumulate open figures.
        plt.close(fig)

    def train(self, num_epochs):
        """
        Run the training loop up to ``num_epochs`` total epochs.

        Starts at ``self.start_epoch`` and returns immediately if that is
        already >= ``num_epochs``. ``start_epoch`` advances after each
        completed epoch, so calling ``train`` again with a larger
        ``num_epochs`` continues where the previous call stopped.

        Args:
            num_epochs (int): Total number of epochs to train to.

        Per epoch:
            1. Training phase, then validation phase
            2. ``scheduler.step(val_loss)``
            3. Append metrics and update ``best_val_loss``
            4. Save a checkpoint (with the optimizer's current learning rate)
            5. Plot on the first epoch of this call and every 5th epoch

        After the loop: final plot and ``training_summary.json``.
        """
        # Check if training is already complete
        if self.start_epoch >= num_epochs:
            print(f"Model is already trained to {num_epochs} epochs.")
            return

        print(f"\n{'='*60}")
        print("🚀 Training...")
        print(f"📊 Epochs: {self.start_epoch} → {num_epochs}")
        print(f"{'='*60}")

        # Remember where this call started; self.start_epoch moves as we go.
        first_epoch = self.start_epoch

        for epoch in range(first_epoch, num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")

            # Training
            train_loss, train_acc = self.train_epoch()

            # Validation
            val_loss, val_acc = self.validate_epoch()

            # Update scheduler
            self.scheduler.step(val_loss)

            # Save metrics
            self.training_stats['train_losses'].append(train_loss)
            self.training_stats['val_losses'].append(val_loss)
            self.training_stats['train_accuracies'].append(train_acc)
            self.training_stats['val_accuracies'].append(val_acc)

            # Check for best model
            is_best = val_loss < self.training_stats['best_val_loss']
            if is_best:
                self.training_stats['best_val_loss'] = val_loss

            # Record the learning rate actually in effect: the scheduler may
            # have lowered it since self.learning_rate was set.
            current_lr = self.optimizer.param_groups[0]['lr']

            # Save checkpoint every epoch.
            self.trainable_model.save(
                epoch, self.training_stats, current_lr, is_best)
            self.start_epoch = epoch + 1

            # Print epoch summary
            improvement = "🔥" if is_best else ""
            print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f}")
            print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.3f} {improvement}")
            print(f"LR: {current_lr:.6f}")

            # Plot every 5 epochs or if it's the first resumed epoch
            if (epoch + 1) % 5 == 0 or epoch == first_epoch:
                self.plot_training_progress(epoch)

        epochs_recorded = len(self.training_stats['train_losses'])
        best_val_loss = self.training_stats['best_val_loss']
        print("\n🎉 Training completed!")
        print(f"Best validation loss: {best_val_loss:.4f}")
        print(f"Total epochs trained: {epochs_recorded}")
        print(f"Checkpoints saved in: {self.checkpoint_dir}")

        # Final comprehensive plot.
        self.plot_training_progress(epochs_recorded - 1)

        self._save_training_summary()
        print("Training completed!")

    def _save_training_summary(self):
        """
        Write the final metrics to ``training_summary.json`` in the checkpoint directory.

        Saves:
            training_summary.json containing:
                - total_epochs: Number of recorded epochs
                - best_val_loss: Best validation loss achieved
                - final_train_loss / final_val_loss: Last recorded losses
                - final_train_acc / final_val_acc: Last recorded accuracies
            A ``final_*`` entry is None when its history is empty.
        """
        stats = self.training_stats

        def last_or_none(key):
            history = stats[key]
            return history[-1] if history else None

        final_results = {
            'total_epochs': len(stats['train_losses']),
            'best_val_loss': stats['best_val_loss'],
            'final_train_loss': last_or_none('train_losses'),
            'final_val_loss': last_or_none('val_losses'),
            'final_train_acc': last_or_none('train_accuracies'),
            'final_val_acc': last_or_none('val_accuracies')
        }
        summary_path = os.path.join(self.checkpoint_dir, 'training_summary.json')
        with open(summary_path, 'w') as f:
            json.dump(final_results, f, indent=2)
        print(f"📋 Training summary saved to: {summary_path}")


In [None]:
# Total number of epochs to train to.
# If the latest checkpoint has completed fewer than num_epochs, training
# starts or resumes from it. Otherwise training is skipped; the next cell
# reloads the best checkpoint for evaluation.
# 180 is enough to overfit completely.
num_epochs = 50 if not USE_FAKE else 180

# Initialize base model and optimizer.
base_model = Wavenet(config).to(device)
learning_rate = 1e-3
optimizer = torch.optim.AdamW(base_model.parameters(), lr=learning_rate)
# Halve the LR when validation loss stops improving by the 1e-3 threshold
# for more than one epoch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=1, threshold=1e-3)

# Load the latest trainable model.
trainable_model = TrainableModel(config, checkpoint_dir, base_model, optimizer, scheduler,
                                 load_from_checkpoint=True)
num_parameters = sum(p.numel() for p in trainable_model.base_model.parameters())
print(f"Model parameters: {num_parameters:,}")

if trainable_model.trained_till_epoch_index + 1 < num_epochs:
    run_kind = 'Resuming' if trainable_model.trained_till_epoch_index > -1 else 'Starting'
    print(f"🚀 {run_kind} WaveNet Training")
    # Create the trainer.
    trainer = Trainer(base_model, trainable_model, config, learning_rate,
                      optimizer, scheduler, checkpoint_dir, train_loader, test_loader, device)
    trainer.train(num_epochs)
else:
    print(f"Model is already trained to {num_epochs} epochs.")


In [None]:
# Reload model for evals and generation.
# Builds a fresh Wavenet on `device` and wraps it in EvalModel, which loads the
# best checkpoint from `checkpoint_dir` (if one exists) and compiles the model
# in eval mode. Note that this rebinds the global `base_model`.
base_model = Wavenet(config).to(device)
model = EvalModel(config, checkpoint_dir, base_model)


In [None]:
import torch
import torch.nn.functional as F
from math import prod


@torch.no_grad()
def generate_continuation(
    model,
    seed_tokens: torch.LongTensor,   # [B, T_seed], class IDs in [0..255]
    n_steps: int,                    # how many new samples to generate
    temperature: float = 1.0,
    top_k: int | None = None,
    device: str | torch.device = "cuda",
    rf_override: int | None = None,  # optionally pass model RF
):
    """
    Generate audio continuation using autoregressive sampling from WaveNet.
    
    Performs autoregressive generation by iteratively predicting the next token
    based on the current sequence context. Uses temperature and top-k sampling
    for controllable generation quality vs diversity trade-off.
    
    Args:
        model: WaveNet model with get_receptive_field() method
        seed_tokens (torch.LongTensor): Initial sequence, shape [batch_size, seed_length],
                                      dtype long, values in [0, 255]
        n_steps (int): Number of new tokens to generate
        temperature (float): Sampling temperature controlling randomness.
                           1.0 = normal, < 1.0 = more deterministic, > 1.0 = more random
        top_k (int, optional): If specified, only sample from top-k most likely tokens
        device (str | torch.device): Device for computation
        rf_override (int, optional): Override model's receptive field size
        
    Returns:
        torch.LongTensor: Extended sequence, shape [batch_size, seed_length + n_steps],
                        dtype long, values in [0, 255]
        
    Note:
        - Uses mixed precision (FP16) for forward passes but FP32 for sampling stability
        - Crops context to receptive field size for efficiency
        - Supports only batch_size=1 for autoregressive generation
        
    Example:
        >>> seed = torch.randint(0, 256, (1, 1000))  # 1000 token seed
        >>> generated = generate_continuation(model, seed, n_steps=5000, temperature=0.8)
        >>> print(generated.shape)  # torch.Size([1, 6000])
    """
    model.eval()
    seq = seed_tokens.to(device)  # [B, T_seed]
    B = seq.size(0)
    assert B == 1, "start with batch size 1 for autoregressive generation"

    # try to infer RF from model config if not provided
    if rf_override is not None:
        rf = rf_override
    else:
        rf = model.get_receptive_field()

    for _ in tqdm(range(n_steps), desc='Generating'):
        # crop to receptive field context to save compute
        ctx = seq[:, -rf:] if seq.size(1) > rf else seq

        # Generate in fp32 for stability, torch.amp.autocast("cuda", dtype=torch.float16):
        with torch.inference_mode():
            logits = model(ctx).float()          # [1, T_ctx, 256]
            logits = logits[:, -1, :]    # [1, 256] last step

        # temperature / top-k sampling (avoid argmax for better naturalness)
        logits = logits / max(temperature, 1e-6)
        if top_k is not None and top_k > 0:
            topk_vals, topk_idx = torch.topk(logits, k=top_k, dim=-1)
            probs = torch.zeros_like(
                logits).scatter(-1, topk_idx, F.softmax(topk_vals, dim=-1))
        else:
            probs = F.softmax(logits, dim=-1)

        next_tok = torch.multinomial(probs, num_samples=1)  # [1,1] Long
        seq = torch.cat([seq, next_tok], dim=1)             # append token

    return seq  # [1, T_seed + n_steps] (tokens)


@torch.no_grad()
def generate_greedy(model, seed, n_steps, device="cuda", rf=None):
    """
    Extend a token sequence deterministically with greedy (argmax) decoding.

    Each step feeds the most recent context to the model and appends the
    single most likely next token, so repeated calls with the same inputs
    yield the same continuation.

    Args:
        model: Callable mapping tokens [B, T] to logits [B, T, V]; must offer
            ``eval()`` and ``get_receptive_field()``.
        seed (torch.LongTensor): Initial sequence, shape [batch_size, seed_length].
        n_steps (int): Number of tokens to append.
        device (str): Device the sequence is moved to, default "cuda".
        rf (int, optional): Context length; when falsy (None or 0) it is read
            from ``model.get_receptive_field()``.

    Returns:
        torch.LongTensor: Seed followed by the decoded tokens, shape
        [batch_size, seed_length + n_steps].

    Example:
        >>> seed = torch.randint(0, 256, (1, 1000))
        >>> generated = generate_greedy(model, seed, n_steps=2000)
        >>> print(generated.shape)  # torch.Size([1, 3000])
    """
    model.eval()
    tokens = seed.to(device)
    if not rf:
        rf = model.get_receptive_field()

    for _ in tqdm(range(n_steps), desc='Generating'):
        # Only the trailing `rf` tokens are fed to the model.
        if tokens.size(1) > rf:
            window = tokens[:, -rf:]
        else:
            window = tokens
        last_step_logits = model(window)[:, -1, :]
        # Cast to FP32 before argmax.
        best = last_step_logits.float().argmax(dim=-1, keepdim=True)
        tokens = torch.cat((tokens, best), dim=1)

    return tokens


In [None]:
import torch
import torch.nn.functional as F
from tqdm import tqdm


@torch.no_grad()
def generate_continuation_fixed(
    model,
    seed_tokens: torch.LongTensor,   # [B, T_seed]
    n_steps: int,
    temperature: float = 1.0,
    top_k: int | None = None,
    device: str | torch.device = "cuda",
    use_fp32: bool = True,  # Use FP32 for generation stability
):
    """
    Enhanced generation function with improved numerical stability and debugging.
    
    An improved version of generate_continuation with better FP32 handling,
    detailed progress logging, and numerical stability improvements for
    high-quality audio generation.
    
    Args:
        model: WaveNet model with get_receptive_field() method
        seed_tokens (torch.LongTensor): Initial sequence, shape [batch_size, seed_length],
                                      dtype long, values in [0, 255]
        n_steps (int): Number of new tokens to generate
        temperature (float): Sampling temperature. Default: 1.0
        top_k (int, optional): Top-k filtering for sampling. Default: None (no filtering)
        device (str | torch.device): Computation device. Default: "cuda"
        use_fp32 (bool): Whether to use FP32 for generation stability. Default: True
        
    Returns:
        torch.LongTensor: Extended sequence, shape [batch_size, seed_length + n_steps],
                        dtype long, values in [0, 255]
        
    Features:
        - Improved numerical stability with proper FP32 handling
        - Progress logging every 1000 steps with entropy and probability metrics
        - Robust top-k filtering implementation
        - Better error handling and debugging output
        
    Example:
        >>> seed = torch.randint(0, 256, (1, 2000))
        >>> generated = generate_continuation_fixed(
        ...     model, seed, n_steps=8000, temperature=0.9, top_k=100
        ... )
        >>> print(generated.shape)  # torch.Size([1, 10000])
    """
    model.eval()
    seq = seed_tokens.to(device)
    B = seq.size(0)
    assert B == 1, "Use batch size 1 for generation"

    rf = model.get_receptive_field()
    print(f"Using receptive field: {rf}")

    for step in tqdm(range(n_steps), desc='Generating'):
        # Crop to receptive field
        ctx = seq[:, -rf:] if seq.size(1) > rf else seq

        if use_fp32:
            # Use FP32 for stability
            with torch.inference_mode():
                logits = model(ctx).float()  # Force FP32
        else:
            # Use mixed precision
            with torch.inference_mode(), torch.amp.autocast("cuda", dtype=torch.float16):
                logits = model(ctx)
                logits = logits.float()  # Convert to FP32 for sampling

        logits = logits[:, -1, :]  # [1, 256] - last timestep

        # Apply temperature
        if temperature != 1.0:
            logits = logits / max(temperature, 1e-6)

        # Apply top-k filtering if specified
        if top_k is not None and top_k > 0:
            topk_vals, topk_idx = torch.topk(
                logits, k=min(top_k, logits.size(-1)), dim=-1)
            # Zero out non-top-k logits
            logits_filtered = torch.full_like(logits, float('-inf'))
            logits_filtered.scatter_(-1, topk_idx, topk_vals)
            logits = logits_filtered

        # Sample next token
        probs = F.softmax(logits, dim=-1)
        next_tok = torch.multinomial(probs, num_samples=1)  # [1, 1]

        # Append to sequence
        seq = torch.cat([seq, next_tok], dim=1)

        # Debug: Print some info every 1000 steps
        if step % 1000 == 0:
            prob_max = probs.max().item()
            entropy = -(probs * torch.log(probs + 1e-8)).sum().item()
            print(f"Step {step}: max_prob={prob_max:.3f}, entropy={
                  entropy:.3f}, token={next_tok.item()}")

    return seq


@torch.no_grad()
def generate_with_teacher_forcing_test(model, seed_tokens, target_tokens, device="cuda"):
    """
    Compare greedy autoregressive decoding with teacher-forced predictions.

    The model is run once on the full ``target_tokens`` (teacher forcing) and
    once autoregressively (greedy argmax) starting from ``seed_tokens`` until
    the target length is reached. The fraction of generated tokens that agree
    with the teacher-forced argmax for the same position is printed.

    Alignment: the logits at time step ``t`` predict the token at ``t + 1``.
    The token generated for position ``p`` is therefore compared with the
    teacher-forced prediction taken at time step ``p - 1``.

    Args:
        model: Callable mapping tokens [1, T] to logits [1, T, V]; must offer
            ``eval()`` and ``get_receptive_field()``.
        seed_tokens (torch.LongTensor): Initial sequence, shape [1, seed_length].
            Expected to be a prefix of ``target_tokens`` for the comparison to
            be meaningful.
        target_tokens (torch.LongTensor): Full reference sequence,
            shape [1, target_length].
        device (str): Computation device. Default: "cuda"

    Returns:
        torch.LongTensor: Seed followed by the greedily decoded tokens, shape
        [1, max(seed_length, target_length)], on the device of ``seed_tokens``.

    Prints:
        - Match fraction between teacher forcing and autoregressive decoding
        - "Poor match" if the fraction is below 0.5, otherwise "Good match"
        Nothing is printed when there is no generated region to compare.

    Example:
        >>> seed = target[:, :1000]  # First 1000 tokens as seed
        >>> generated = generate_with_teacher_forcing_test(model, seed, target)
        Teacher forcing vs autoregressive match: 0.847
        ✅ Good match between teacher forcing and autoregressive generation
    """
    model.eval()

    # Teacher forcing: next-token predictions given the true history.
    with torch.inference_mode():
        teacher_logits = model(target_tokens.to(device))  # [1, T, V]
        teacher_preds = teacher_logits.argmax(dim=-1)     # [1, T]

    # Autoregressive generation. The running sequence stays on `device` so
    # the result does not depend on where the seed tensor happens to live.
    generated = seed_tokens.clone().to(device)
    rf = model.get_receptive_field()
    seed_len = seed_tokens.size(1)

    for _ in tqdm(range(target_tokens.size(1) - seed_len)):
        ctx = generated[:, -rf:] if generated.size(1) > rf else generated

        with torch.inference_mode():
            logits = model(ctx)[:, -1, :].float()  # Last timestep
            # Use argmax for deterministic comparison
            next_tok = logits.argmax(dim=-1, keepdim=True)

        generated = torch.cat([generated, next_tok], dim=1)

    # Compare the generated region with the teacher-forced predictions.
    overlap_start = seed_len
    overlap_end = min(generated.size(1), teacher_preds.size(1))

    # overlap_start > 0 is required because position p is scored against the
    # teacher prediction made at time step p - 1.
    if overlap_start > 0 and overlap_end > overlap_start:
        generated_overlap = generated[:, overlap_start:overlap_end]
        teacher_overlap = teacher_preds[:, overlap_start - 1:overlap_end - 1]

        matches = (generated_overlap == teacher_overlap).float().mean().item()
        print(f"Teacher forcing vs autoregressive match: {matches:.3f}")

        if matches < 0.5:
            print("⚠️  Poor match between teacher forcing and autoregressive generation")
        else:
            print("✅ Good match between teacher forcing and autoregressive generation")

    return generated.to(seed_tokens.device)

# Example usage: an end-to-end check that exercises the generation helpers above.


def test_generation_quality(model, dataset, audio_processor, device):
    """
    Run three quick diagnostics on a trained model's generation behaviour.

    The first item of ``dataset`` is resampled to 16 kHz, normalised and
    mu-law encoded with the module-level ``codec``. Its first half (at most
    8000 tokens) serves as the seed for every check.

    Args:
        model: Trained WaveNet model under test.
        dataset: Indexable dataset whose items unpack as
            ``(audio, sample_rate, _, _)``.
        audio_processor: Object providing ``resample_audio(audio, sr, 16000)``
            and ``normalize(audio)``.
        device (torch.device): Device passed on to the generation helpers.

    Returns:
        torch.Tensor: The newly generated tokens (seed removed) of the first
        of the five mode-collapse runs.

    Checks performed (results are printed):
        1. Teacher forcing consistency via ``generate_with_teacher_forcing_test``
           on a target that extends the seed by up to 1000 tokens.
        2. Temperature sweep over 0.1, 0.8, 1.0 and 1.2: number of distinct
           tokens among 100 sampled steps; warns when fewer than 10 appear at
           a temperature above 1.0.
        3. Mode collapse: five 200-step samples at temperature 1.0 are
           compared for exact equality.

    Example:
        >>> test_result = test_generation_quality(model, dataset, audio_processor, device)
        === Testing Generation Quality ===
        1. Testing teacher forcing consistency...
        2. Testing different temperatures...
        3. Testing for mode collapse...
        ✅ Generations show diversity
    """
    print("=== Testing Generation Quality ===")

    # Fetch and preprocess one clip.
    waveform, sample_rate, _, _ = dataset[0]
    waveform = audio_processor.normalize(
        audio_processor.resample_audio(waveform, sample_rate, 16000))

    # Mu-law encode into a [1, T] token tensor.
    tokens = codec.mu_law_encode(waveform.squeeze()).unsqueeze(0)

    # Seed = first half of the clip (capped at 8000 tokens); the target adds
    # up to 1000 further tokens for the teacher-forcing comparison.
    seed_len = min(8000, tokens.size(1) // 2)
    seed = tokens[:, :seed_len]
    target = tokens[:, :seed_len + 1000]

    print(f"Seed length: {seed.size(1)}")
    print(f"Target length: {target.size(1)}")

    # Check 1: teacher forcing vs autoregressive decoding (prints its verdict).
    print("\n1. Testing teacher forcing consistency...")
    generate_with_teacher_forcing_test(model, seed, target, device)

    # Check 2: token diversity across temperatures.
    print("\n2. Testing different temperatures...")
    for temp in (0.1, 0.8, 1.0, 1.2):
        print(f"Temperature {temp}:")
        sampled = generate_continuation_fixed(
            model, seed, n_steps=100, temperature=temp, device=device
        )
        continuation = sampled[:, seed.size(1):]
        unique_tokens = continuation.unique().numel()
        print(f"  Unique tokens in 100 steps: {unique_tokens}")

        low_diversity = unique_tokens < 10
        if low_diversity and temp > 1.0:
            print(f"  ⚠️  Low diversity at temperature {temp}")

    # Check 3: repeated sampling must not always give the same continuation.
    print("\n3. Testing for mode collapse...")
    continuations = [
        generate_continuation_fixed(
            model, seed, n_steps=200, temperature=1.0, device=device
        )[:, seed.size(1):]
        for _ in range(5)
    ]

    first, *others = continuations
    if all(torch.equal(first, other) for other in others):
        print("⚠️  Mode collapse detected - all generations identical")
    else:
        print("✅ Generations show diversity")

    return first


In [None]:
def gen(i, seconds_to_seed=0.3, sec_to_generate=1.7, temperature=1.0, top_k=100,
        sample_rate=16000):
    """
    Play and plot a test clip, its seed prefix and a sampled continuation.

    Takes clip ``i`` of the first batch drawn from the module-level
    ``test_loader``, uses its first ``seconds_to_seed`` seconds as the seed and
    lets ``model.compiled_model`` sample ``sec_to_generate`` seconds of new
    tokens via ``generate_continuation``. The original clip, the seed and the
    full result are each passed to ``play_encoded_sample`` and
    ``display_waveform``.

    Args:
        i (int): Index of the clip inside the batch.
        seconds_to_seed (float): Seed duration in seconds. Default: 0.3
        sec_to_generate (float): Duration to generate in seconds. Default: 1.7
        temperature (float): Sampling temperature (try 0.7-1.2). Default: 1.0
        top_k (int, optional): Top-k truncation (try None or 50-200).
            Default: 100
        sample_rate (int): Tokens per second of audio. Default: 16000

    Raises:
        ValueError: If ``seconds_to_seed`` yields an empty seed.
    """
    # Get one batch from test_loader; only the inputs are needed here.
    audio, _ = next(iter(test_loader))     # x: [B, T-1] tokens (Long)
    audio = audio[i, :].unsqueeze(0).to(device)   # [1, T-1]

    print(audio.shape)

    audio_cpu = audio.cpu()
    play_encoded_sample(audio_cpu)
    display_waveform(audio_cpu, "Original")

    # Number of seed tokens; slicing clamps this to the available length.
    length_to_seed = int(sample_rate * seconds_to_seed - 1)
    if length_to_seed < 1:
        raise ValueError(
            f"seconds_to_seed={seconds_to_seed} gives an empty seed at "
            f"sample_rate={sample_rate}")
    seed = audio[:1, :length_to_seed]   # [1, seed_len]

    # Number of new tokens to generate.
    n_new = int(sample_rate * sec_to_generate)

    seed_cpu = seed.cpu()
    play_encoded_sample(seed_cpu)
    display_waveform(seed_cpu, "Seed")

    audio_full = generate_continuation(
        model.compiled_model,
        seed_tokens=seed,
        n_steps=n_new,
        temperature=temperature,
        top_k=top_k,
        device=device,
    )

    audio_full_cpu = audio_full.cpu()
    play_encoded_sample(audio_full_cpu)
    display_waveform(audio_full_cpu, "audio_full")


# Generate and display continuations for the first four clips of the batch.
for sample_index in range(4):
    gen(sample_index)


In [None]:
# Free-running (greedy) accuracy inside one training window: seed the model
# with the start of the window and check how many of the following tokens it
# reproduces when it only ever sees its own predictions.
rf = model.get_receptive_field()

# use the training loader for the overfit window
x_batch, y_batch = next(iter(train_loader))
i = 0  # or pick the exact index you overfit on
# Optionally pick a prefix length N to avoid pads; else use full length
# example: 0.5s @16k if tokens==samples
N = min(8000, x_batch.size(1))
q = x_batch[i:i+1, :N].to(device)             # [1, N]
y = y_batch[i:i+1, :N].to(device)             # [1, N]

# q: [1, W] tokens of the SAME training window you overfit
S = max(rf, (q.size(1)-1)//2)            # seed ≥ RF, inside the same window
seq = q[:, :S]

# Never negative, even when the window is shorter than the receptive field.
n_generate = max((q.size(1) - 1) - S, 0)

with torch.no_grad():
    for _ in tqdm(range(n_generate)):
        ctx = seq[:, -rf:] if seq.size(1) > rf else seq
        logits = model.base_model(ctx)[:, -1, :].float()  # cast logits to fp32
        nxt = logits.argmax(-1, keepdim=True)  # greedy
        seq = torch.cat([seq, nxt], 1)

pred = seq[:, S:]                # [1, n_generate]
# The token generated right after the seed q[:, :S] is the prediction for
# position S, so the reference starts at q[:, S] (not S + 1).
tgt = q[:, S:S + pred.size(1)]
if pred.numel() == 0:
    # Nothing was generated; the mean of an empty tensor would be NaN anyway.
    print("Window too short for the receptive field; no tokens were generated.")
    fr_acc = float('nan')
else:
    fr_acc = (pred == tgt).float().mean().item()
print("FR acc within-window:", fr_acc)
