<a href="https://colab.research.google.com/github/AndrewKruszka/NeuralMachineLearning/blob/main/wavenet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#WaveNet Audio Generation with PyTorch  
### Andrew Kruszka

This notebook implements a WaveNet-style autoregressive model using PyTorch.  
It loads audio, trains a model to learn its waveform distribution using mu-law quantization,
and generates new audio samples sample-by-sample

---

### 🧠 Highlights:
- Raw audio waveform modeling with 8-bit mu-law encoding
- Dilated causal convolutions (WaveNet architecture)
- Autoregressive sampling with temperature scaling
- Real-time waveform synthesis (playable directly in notebook)

---

#Step 1: Environment Setup

This step prepares the notebook environment for training and generating audio with WaveNet.

-Clones the official [pytorch-wavenet](https://github.com/Vichoko/pytorch-wavenet) repository
- Installs required Python packages:
  - `librosa` for audio processing
  - `soundfile` for saving/loading audio files
  - `einops` for tensor reshaping (used by some model variants)
- Adds the cloned repo to Python's import path
- Verifies the directory structure and loads core libraries:
  - PyTorch, NumPy, librosa, matplotlib
  - Loads the WaveNet model architecture


In [2]:
 # STEP 1: Setup
!git clone https://github.com/Vichoko/pytorch-wavenet.git
%cd pytorch-wavenet
!pip install librosa soundfile
!pip install einops
import einops
import sys
sys.path.append('/content/pytorch-wavenet')

!ls /content/pytorch-wavenet

import torch
import torch.nn as nn
import torch.optim as optim
import librosa
import soundfile as sf
import numpy as np
import os
from wavenet_model import WaveNetModel
from IPython.display import Audio, display, clear_output
import matplotlib.pyplot as plt
import glob
import torch.nn.functional as F

Cloning into 'pytorch-wavenet'...
remote: Enumerating objects: 1168, done.[K
remote: Counting objects: 100% (10/10), done.[K
remote: Compressing objects: 100% (7/7), done.[K
remote: Total 1168 (delta 3), reused 6 (delta 3), pack-reused 1158 (from 1)[K
Receiving objects: 100% (1168/1168), 268.95 MiB | 29.49 MiB/s, done.
Resolving deltas: 100% (720/720), done.
/content/pytorch-wavenet
audio_data.py		 model_logging.py  snapshots	    WaveNet_demo.ipynb
demo.ipynb		 notebooks	   tests	    wavenet_model.py
generated_samples	 optimizers.py	   test_script.py   wavenet_modules.py
Generated_Samples.ipynb  profiling.ipynb   train_samples    wavenet_training.py
generate_script.py	 __pycache__	   train_script.py
LICENSE			 README.md	   visualize.py


## Step 2: Download Sample Audio

This step creates a `data/` directory (if it doesn't already exist) and downloads sample `.wav` files from the Free Spoken Digit Dataset (FSDD).  
The audio file will be used to train and test the WaveNet model.

In [2]:
# STEP 2: Download Sample Audio
import os

os.makedirs('data', exist_ok=True)

# Expanded list of .wav files from FSDD (2 samples per digit, george)
sample_urls = [
    f"https://github.com/Jakobovski/free-spoken-digit-dataset/raw/master/recordings/{digit}_george_{take}.wav"
    for digit in range(10)
    for take in range(2)
]

# Download each file and print status
for i, url in enumerate(sample_urls):
    filename = f"data/sample_{i}.wav"
    print(f"Downloading to {filename}...")
    os.system(f"wget -q -O {filename} {url}")

print("All files downloaded for training.")


Downloading to data/sample_0.wav...
Downloading to data/sample_1.wav...
Downloading to data/sample_2.wav...
Downloading to data/sample_3.wav...
Downloading to data/sample_4.wav...
Downloading to data/sample_5.wav...
Downloading to data/sample_6.wav...
Downloading to data/sample_7.wav...
Downloading to data/sample_8.wav...
Downloading to data/sample_9.wav...
Downloading to data/sample_10.wav...
Downloading to data/sample_11.wav...
Downloading to data/sample_12.wav...
Downloading to data/sample_13.wav...
Downloading to data/sample_14.wav...
Downloading to data/sample_15.wav...
Downloading to data/sample_16.wav...
Downloading to data/sample_17.wav...
Downloading to data/sample_18.wav...
Downloading to data/sample_19.wav...
All files downloaded for training.


## Step 3: Load and Normalize Audio Samples

This step loads the audio samples from the `data/` directory and prepares them for training:

- Loads each `.wav` file at a 16kHz sample rate
- Adds a small sine wave if the audio is nearly silent to prevent collapse during encoding
- Normalizes each waveform to the range [-1, 1]
- Pads or trims each waveform to exactly 1 second (16,000 samples)
- Converts each waveform into a PyTorch tensor and stacks them into a single batch

The final output is a tensor of shape `(N, 16000)`, where `N` is the number of audio samples.


In [4]:
sample_rate = 16000
max_length = sample_rate   # 1 second = 16 000 samples
waveforms = []

for file_path in sorted(glob.glob("data/sample_*.wav")):
    waveform, sr = librosa.load(file_path, sr=sample_rate)

    # Add tiny sine wave or noise to prevent flat zero
    if np.max(np.abs(waveform)) < 1e-4:
        t = np.linspace(0, 1, len(waveform))
        waveform += 0.05 * np.sin(2 * np.pi * 220 * t)

    # Normalize to [-1, 1]
    max_amp = np.max(np.abs(waveform))
    if max_amp > 0:
        waveform = waveform / max_amp

    # Pad or trim to exactly 1 second
    if len(waveform) > max_length:
        waveform = waveform[:max_length]
    else:
        waveform = np.pad(waveform, (0, max_length - len(waveform)))

    waveforms.append(torch.tensor(waveform).unsqueeze(0))  # shape: (1, T)

# Stack into (N, T)
waveforms = torch.cat(waveforms, dim=0)
print("Waveforms shape:", waveforms.shape)

Waveforms shape: torch.Size([20, 16000])


## Step 4: Encode Input and Define the Model

This step prepares the input data and defines the WaveNet model for training:

- **Mu-law encoding**: Compresses each waveform into 256 discrete values to better represent audio dynamics.
- **One-hot encoding**: Converts the mu-law encoded data into one-hot vectors with shape `(N, 256, T)`, where `N` is the number of samples and `T` is the number of time steps.
- Prints the input and target tensor shapes, value ranges, and unique classes to verify correct encoding.

Next, a **WaveNet model** is initialized with:
- 10 layers per block
- 2 blocks
- 32 channels for dilation and residual connections
- 64 channels for skip and end connections
- 256 output classes (matching mu-law encoding)

The model's **receptive field** (i.e., the number of past time steps it can see when predicting) is printed to verify it fits within the length of the audio samples.


In [5]:
# STEP 4: Encode input and define the model

# Convert to 8-bit mu-law encoding for training
def mu_law_encode(audio, quantization_channels=256):
    mu = quantization_channels - 1
    safe_audio = torch.clamp(audio, -1.0, 1.0)
    magnitude = torch.log1p(mu * torch.abs(safe_audio)) / torch.log1p(torch.tensor(mu, dtype=torch.float32))
    signal = torch.sign(safe_audio) * magnitude
    encoded = ((signal + 1) / 2 * mu + 0.5).long()
    return encoded

# One-hot encoder for batch input
def one_hot_encode(indices, num_classes=256):
    # (N, T) → (N, 256, T)
    return torch.nn.functional.one_hot(indices, num_classes).float().permute(0, 2, 1)

# Apply encoding to batch
target = mu_law_encode(waveforms)  # shape: (N, T)
input_audio = one_hot_encode(target, num_classes=256)  # shape: (N, 256, T)

print("input_audio shape:", input_audio.shape)
print("target shape:", target.shape)
print("Target min/max (mu-law encoded):", target.min().item(), target.max().item())
print("Unique target classes:", torch.unique(target))


model = WaveNetModel(
    layers=10,
    blocks=2,
    dilation_channels=32,
    residual_channels=32,
    skip_channels=64,
    end_channels=64,
    classes=256
)

print("Receptive field:", model.receptive_field)  # should now be 2047



input_audio shape: torch.Size([20, 256, 16000])
target shape: torch.Size([20, 16000])
Target min/max (mu-law encoded): 0 255
Unique target classes: tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152,

## Step 5: Train the WaveNet Model

This step trains the WaveNet model using the prepared audio data:

- **Optimizer**: Adam optimizer is used with a learning rate of `5e-4`.
- **Crop size**: A random crop of each audio sample is taken, with a minimum size slightly larger than the model's receptive field (receptive field + 512 samples).
- **Random cropping**: For each input waveform, a random non-silent crop is selected to ensure the model learns meaningful patterns rather than silence.
- **Training loop**:
  - For each epoch, each sample is cropped, passed through the model, and the loss is calculated.
  - Cross-entropy loss is used between the model's predicted output and the true mu-law encoded labels.
  - Gradients are computed and optimizer steps are taken.
  - The average loss per epoch is printed along with the number of valid (non-silent) crops used.

The training ensures the model gradually learns to predict the next audio samples based on past context from the receptive field.


In [7]:
model.train()
optimizer = optim.Adam(model.parameters(), lr=5e-4)

# Smaller receptive field model
receptive_field = model.receptive_field  # should be 2047 now
max_available = input_audio.shape[-1]
crop_size = min(receptive_field + 512, max_available)  # e.g. 2047 + 512 = 2559

print(f"Using receptive field = {receptive_field}, crop size = {crop_size}")

for epoch in range(20):
    total_loss = 0.0
    valid = 0

    for i in range(input_audio.size(0)):
        x = input_audio[i:i+1]   # (1, 256, 16000)
        y = target[i:i+1]        # (1, 16000)

        # Skip if too short
        if x.shape[-1] < crop_size:
            continue

        # Random crop with non-silent check
        max_start = x.shape[-1] - crop_size
        for _ in range(10):
            start = np.random.randint(0, max_start + 1)
            x_crop = x[:, :, start : start + crop_size]
            y_crop = y[:, start : start + crop_size]  # still integer
            if torch.std(y_crop.float()) > 1.0:       # cast to float here
                break
        else:
            continue  # no valid crop found

        optimizer.zero_grad()
        out = model(x_crop)          # (T_out, 256)
        T_out = out.shape[0]

        y_crop = y_crop[:, -T_out:].squeeze(0)  # (T_out,)
        loss = nn.CrossEntropyLoss()(out, y_crop)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        valid += 1

    avg_loss = total_loss/valid if valid else float('nan')
    print(f"Epoch {epoch+1:2d} — Avg Loss: {avg_loss:.4f} (on {valid} crops)")

Using receptive field = 2047, crop size = 2559
Epoch  1 — Avg Loss: 5.5411 (on 20 crops)
Epoch  2 — Avg Loss: 5.3864 (on 20 crops)
Epoch  3 — Avg Loss: 4.0153 (on 20 crops)
Epoch  4 — Avg Loss: 3.0511 (on 20 crops)
Epoch  5 — Avg Loss: 4.6724 (on 20 crops)
Epoch  6 — Avg Loss: 3.5333 (on 20 crops)
Epoch  7 — Avg Loss: 3.8498 (on 20 crops)
Epoch  8 — Avg Loss: 4.6657 (on 20 crops)
Epoch  9 — Avg Loss: 3.8172 (on 20 crops)
Epoch 10 — Avg Loss: 3.7017 (on 20 crops)
Epoch 11 — Avg Loss: 3.6322 (on 20 crops)
Epoch 12 — Avg Loss: 3.1660 (on 20 crops)
Epoch 13 — Avg Loss: 3.8605 (on 20 crops)
Epoch 14 — Avg Loss: 4.5765 (on 20 crops)
Epoch 15 — Avg Loss: 3.9234 (on 20 crops)
Epoch 16 — Avg Loss: 2.6993 (on 20 crops)
Epoch 17 — Avg Loss: 3.6316 (on 20 crops)
Epoch 18 — Avg Loss: 2.8498 (on 20 crops)
Epoch 19 — Avg Loss: 3.5250 (on 20 crops)
Epoch 20 — Avg Loss: 2.7145 (on 20 crops)


## Step 6: Generate a 1-Second Audio Sample from the Trained Model

This step uses the trained WaveNet model to generate new audio:

- **Seed selection**: A short initial seed (1024 samples) is taken from one training example to start the generation.
- **Autoregressive generation**:
  - The model repeatedly predicts the next audio sample based on the most recent receptive field of context.
  - Each prediction is sampled from the model’s probability distribution over 256 mu-law classes.
  - The sampled class is one-hot encoded and appended to the generated audio.
- **Mu-law decoding**:
  - After generation, the mu-law encoded sequence is converted back into a real-valued waveform using inverse mu-law decoding.
- **Playback**:
  - The final waveform (1 second long) is played back using IPython’s Audio display.

This process generates a realistic-sounding waveform based on patterns the model learned during training.


In [12]:
# STEP 6: Generate a 1-second sample from the trained model

# 1) Select one example from your batch (e.g. index 0)
idx = 0
seed = input_audio[idx : idx+1, :, :1024]  # shape: (1, 256, 1024)
generated = seed.clone()

model.eval()
with torch.no_grad():
    # We want to generate 16000 total samples, we already have 1024
    to_generate = 16000 - generated.shape[-1]
    for _ in range(to_generate):
        # a) extract last receptive_field samples, pad if needed
        rf = model.receptive_field
        context = generated[:, :, -rf+1 :]
        padded = F.pad(context, (rf-1, 0))  # now (1,256,rf)

        # b) run through model
        out = model(padded)

        # c) pick logits for the next sample
        if out.ndim == 3:        # out shape: (1,256,T)
            logits = out[0, :, -1]   # shape (256,)
        elif out.ndim == 2:      # out shape: (T,256)
            logits = out[-1, :]     # shape (256,)
        else:
            raise RuntimeError(f"Unexpected out shape {out.shape}")

        # d) sample from the 256-way distribution
        probs = F.softmax(logits, dim=0)
        idx_sample = torch.multinomial(probs, num_samples=1).item()  # scalar in [0..255]

        # e) convert to one-hot and append
        one_hot = F.one_hot(torch.tensor([idx_sample]), num_classes=256).float()
        one_hot = one_hot.view(1, 256, 1)  # shape (1,256,1)
        generated = torch.cat([generated, one_hot], dim=2)

# 2) Decode mu-law back to waveform
def mu_law_decode(encoded, quantization_channels=256):
    mu = quantization_channels - 1
    signal = 2 * (encoded.float() / mu) - 1
    magnitude = (1.0 / mu) * ((1.0 + mu)**signal.abs() - 1)
    return signal.sign() * magnitude

# pick the batch and channel dims away, leaving a (16000,) tensor
gen_indices = generated.argmax(dim=1).squeeze(0)   # shape: (16000,)
waveform_out = mu_law_decode(gen_indices)

# 3) Play it!
Audio(waveform_out.cpu().numpy(), rate=16000)


## Step 7: Decode and Play the Generated Audio

This step processes the generated one-hot encoded audio and plays it:

- **One-hot decoding**:
  - Converts the generated tensor from one-hot format `(1, 256, T)` back into mu-law class indices `(1, T)`.
- **Mu-law decoding**:
  - Transforms the discrete mu-law indices back into a continuous audio waveform in the range [-1, 1] using the inverse mu-law formula.
- **Playback**:
  - Uses IPython’s `Audio` class to play the resulting waveform at a 16kHz sample rate.

This final step allows listening to the 1-second audio sample produced by the WaveNet model.


In [3]:
# STEP 7: Play the generated audio
# Convert one-hot (1, 256, T) → (1, T) mu-law indices
generated_indices = torch.argmax(generated, dim=1)  # (1, T)

# Mu-law decode
def mu_law_decode(encoded, quantization_channels=256):
    mu = quantization_channels - 1
    signal = 2 * (encoded.float() / mu) - 1
    magnitude = (1.0 / mu) * ((1.0 + mu)**torch.abs(signal) - 1)
    return torch.sign(signal) * magnitude

waveform_out = mu_law_decode(generated_indices.squeeze())  # (T,)

# === STEP 7: Play the audio ===
from IPython.display import Audio
Audio(waveform_out.cpu().numpy(), rate=16000)

NameError: name 'generated' is not defined

### Suggestions to Improve Audio Quality:

- **Train with a larger dataset**: Use more spoken samples, multiple speakers, and more variation in pronunciation to allow the model to generalize better.
- **Train for more epochs**: Extending training to 100+ epochs with early stopping can help the model capture finer structures without overfitting.
- **Use data augmentation**: Introduce pitch shifting, time stretching, and noise injection during training to simulate more variability and prevent overfitting.
- **Increase model capacity**: Gradually increasing the number of layers, dilation channels, and blocks can give the model more representational power.
- **Use a two-stage generation**: First predict coarse features (like phonemes or spectrograms) and then use a second model to refine into waveform details (common in modern WaveNet-based systems).
- **Fine-tune learning rate schedules**: Using learning rate decay or schedulers can allow finer adjustments toward convergence in later stages of training.

Overall, the generation proves the model has learned meaningful structure, but larger datasets, longer training, and model refinements would lead to much clearer and more natural audio output.