# Notebook for experimenting with amp modeling

## Tasks
1. Basic WaveNet implementation
2. LSTM? Not sure yet how it handles time series, but I've seen it used for this online
3. Audio as input
4. Audio as output
5. Fix any audio bugs — probably volume issues, pops, etc.

## Goal for this notebook
Make a *capture* of an amp. Model vs. profile: profiling means modeling an amp completely, including its EQ and gain knobs and the way it responds to the input signal. Profiling is a lot harder. Capturing just takes a snapshot of the amp at one fixed setting. For this project, I will be capturing the Dumble clone amp from Neural DSP with these settings (TODO: list the settings).

In [1]:
# Quick smoke test (presumably to confirm the kernel executes cells).
print("hi")

hi


In [2]:
# import dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F

  cpu = _conversion_method_template(device=torch.device("cpu"))


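## WaveNet
WaveNet is an autoregressive generative model for raw audio: it predicts each sample from the samples before it, using a stack of dilated causal convolutions.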

In [None]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

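# ----------------------
# mu-law companding
# Log-compresses amplitudes in [-1, 1] (finer steps near zero, coarser near
# full scale) and quantizes them to MU + 1 integer classes, which is what the
# WaveNet below predicts. mu_law_decode expands the classes back, so for
# inputs in [-1, 1] the only lasting change is quantization error.
# Might not use this for the amp model: don't want to compress dynamics.
# ----------------------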
MU = 255


def mu_law_encode(x, mu=MU):
    """Mu-law compand ``x`` and quantize it to integer classes ``0..mu``.

    Both branches clip the input to [-1, 1] first, so the result always fits
    an ``nn.Embedding(mu + 1, ...)`` lookup.

    Args:
        x: waveform with nominal range [-1, 1]. A torch tensor, or anything
            ``np.clip`` accepts (NumPy array, list, scalar).
        mu: companding parameter; the output has ``mu + 1`` levels.

    Returns:
        ``torch.long`` tensor for tensor input, otherwise an ``np.int64``
        array, with values in ``[0, mu]``.
    """
    if isinstance(x, torch.Tensor):
        # Clip like the NumPy branch does; otherwise |x| > 1 produces codes
        # above mu, which would be out of range for the embedding.
        x = x.clamp(-1, 1)
        mag = torch.log1p(mu * x.abs()) / math.log1p(mu)
        encoded = torch.sign(x) * mag
        # Round to the nearest level (+0.5, then truncate a non-negative
        # value). Plain truncation biases every code down by up to one step.
        return ((encoded + 1) / 2 * mu + 0.5).long()

    x = np.clip(x, -1, 1)
    mag = np.log1p(mu * np.abs(x)) / np.log1p(mu)
    encoded = np.sign(x) * mag
    return ((encoded + 1) / 2 * mu + 0.5).astype(np.int64)


def mu_law_decode(encoded, mu=MU):
    """Map integer mu-law classes back to float amplitudes.

    Class ``k`` is first mapped linearly to ``2 * (k / mu) - 1`` (so
    ``0..mu`` spans [-1, 1]) and then expanded with the inverse mu-law curve.

    Args:
        encoded: torch tensor or NumPy array of integer classes.
        mu: companding parameter used when encoding.

    Returns:
        A float tensor for tensor input, otherwise a float NumPy array, with
        the same shape as ``encoded``.
    """
    if not isinstance(encoded, torch.Tensor):
        companded = 2 * (encoded.astype(np.float32) / mu) - 1
        expanded = (1 / mu) * ((1 + mu) ** np.abs(companded) - 1)
        return np.sign(companded) * expanded

    companded = 2 * (encoded.float() / mu) - 1
    expanded = (1 / mu) * ((1 + mu) ** companded.abs() - 1)
    return torch.sign(companded) * expanded


# ----------------------
# Causal Conv1d wrapper
# A 1-D conv that is causal: the output at time t depends only on inputs at times <= t, so it can never see into the future.
# ----------------------
class CausalConv1d(nn.Module):
    """1-D convolution whose output at step t sees only inputs at steps <= t.

    The input is left-padded with ``(kernel_size - 1) * dilation`` zeros, so
    the output has the same number of timesteps as the input.

    Padding only on the left gives exactly the same values as letting
    ``nn.Conv1d`` pad both sides and trimming the tail. It avoids computing
    ``self.pad`` extra output steps that would be thrown away. Parameter
    names (``conv.weight`` / ``conv.bias``) are unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super().__init__()
        # Receptive span of one output step, minus the current sample.
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              dilation=dilation)

    def forward(self, x):
        """Map (batch, in_channels, T) to (batch, out_channels, T) causally."""
        if self.pad:
            x = F.pad(x, (self.pad, 0))  # zeros in the past only, none in the future
        return self.conv(x)

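# Residual block: a gated, dilated causal convolution followed by 1x1 convs
# that produce a residual output (added back to the input) and a skip output.
# Stacking these with growing dilations is the core of WaveNet.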
class ResidualBlock(nn.Module):
    """One WaveNet layer: gated dilated causal conv with residual and skip paths.

    ``forward(x)`` takes a (batch, residual_channels, timesteps) tensor and
    returns ``(x + res_conv(z), skip_conv(z))``, where
    ``z = tanh(filter_conv(x)) * sigmoid(gate_conv(x))``.
    """

    def __init__(self, residual_channels, skip_channels, kernel_size, dilation):
        super().__init__()
        dilated_args = (residual_channels, residual_channels, kernel_size, dilation)
        # Two parallel dilated causal convs: one "filter", one "gate".
        self.filter_conv = CausalConv1d(*dilated_args)
        self.gate_conv = CausalConv1d(*dilated_args)
        # 1x1 convs projecting the gated activations onto each output path.
        self.res_conv = nn.Conv1d(residual_channels, residual_channels, kernel_size=1)
        self.skip_conv = nn.Conv1d(residual_channels, skip_channels, kernel_size=1)

    def forward(self, x):
        # Gated activation unit: tanh branch modulated by a sigmoid gate.
        gated = torch.tanh(self.filter_conv(x)) * torch.sigmoid(self.gate_conv(x))
        skip_out = self.skip_conv(gated)
        return self.res_conv(gated) + x, skip_out


# ----------------------
# WaveNet model
# ----------------------
class WaveNet(nn.Module):
    """Stack of gated residual blocks that predicts the next quantized sample.

    Input: LongTensor of class indices with shape (batch, time) and values in
    ``[0, n_quantize)``. Output: logits with shape (batch, n_quantize, time).
    Every convolution is causal, so ``out[:, :, t]`` depends only on
    ``x[:, :t + 1]``.

    Args:
        n_quantize: number of input/output classes (``MU + 1`` by default).
        residual_channels: width of the residual path and input embedding.
        skip_channels: width of the skip path and the output head.
        kernel_size: kernel size of every dilated causal conv.
        dilations: one dilation per residual block; defaults to
            ``[1, 2, 4, 8, 16, 32]`` repeated twice.

    Raises:
        ValueError: if ``dilations`` is empty (there would be no skip outputs
            to sum).
    """

    def __init__(self, n_quantize=MU + 1, residual_channels=32, skip_channels=64,
                 kernel_size=2, dilations=None):
        super().__init__()
        if dilations is None:
            # Standard WaveNet schedule: dilation doubles every layer and the
            # 1..32 cycle repeats, as drawn in the paper's figures.
            dilations = [1, 2, 4, 8, 16, 32] * 2
        dilations = list(dilations)
        if not dilations:
            raise ValueError("WaveNet needs at least one dilation (got none)")

        self.n_quantize = n_quantize
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        # Number of samples (current one included) each output step can see.
        # Every dilated causal conv adds (kernel_size - 1) * dilation; the
        # kernel_size=1 input conv adds nothing. Useful when choosing the
        # context length for generation.
        self.receptive_field = 1 + (kernel_size - 1) * sum(dilations)

        # Each quantized class index becomes a learned residual_channels vector.
        self.embedding = nn.Embedding(n_quantize, residual_channels)
        self.causal_in = CausalConv1d(residual_channels, residual_channels, kernel_size=1)

        self.res_blocks = nn.ModuleList([
            ResidualBlock(residual_channels, skip_channels, kernel_size, d)
            for d in dilations
        ])

        # Output head: ReLU -> 1x1 conv -> ReLU -> 1x1 conv to class logits.
        self.relu = nn.ReLU()
        self.post1 = nn.Conv1d(skip_channels, skip_channels, kernel_size=1)
        self.post2 = nn.Conv1d(skip_channels, n_quantize, kernel_size=1)

    def forward(self, x):
        """Return (batch, n_quantize, time) logits for (batch, time) class indices."""
        # (batch, time) -> (batch, time, res) -> (batch, res, time) for Conv1d.
        h = self.embedding(x).permute(0, 2, 1).contiguous()
        h = self.causal_in(h)

        # Sum the skip outputs of all blocks. Start from the first skip tensor
        # rather than an int 0, so the accumulator is always a tensor.
        skip_sum = None
        for block in self.res_blocks:
            h, skip = block(h)
            skip_sum = skip if skip_sum is None else skip_sum + skip

        out = self.post1(self.relu(skip_sum))
        return self.post2(self.relu(out))


# ----------------------
# Dataset helper (takes raw waveform arrays already quantized)
# ----------------------
class WaveDataset(torch.utils.data.Dataset):
    """Next-sample prediction windows over a 1-D sequence of quantized samples.

    Item ``i`` is ``(data[i:i + seq_len], data[i + 1:i + seq_len + 1])``: the
    target is the input shifted one step into the future.

    Args:
        quantized_wave: torch tensor, NumPy array, or sequence of ints.
            Non-tensor input is copied into a new ``torch.long`` tensor.
        seq_len: window length; must be >= 1.

    Raises:
        ValueError: if ``seq_len`` < 1.
    """

    def __init__(self, quantized_wave, seq_len):
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        if isinstance(quantized_wave, torch.Tensor):
            self.data = quantized_wave.long()
        else:
            # np.asarray handles both arrays and plain sequences; torch.tensor copies.
            self.data = torch.tensor(np.asarray(quantized_wave), dtype=torch.long)

        self.seq_len = seq_len

    def __len__(self):
        # Each window needs seq_len + 1 samples (input plus shifted target).
        # Clamp at 0: a negative __len__ makes len() raise.
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, idx):
        n = len(self)
        pos = idx + n if idx < 0 else idx
        # Tensor slicing never raises, so bounds-check explicitly. Without an
        # IndexError, iterating over the dataset directly never terminates.
        if not 0 <= pos < n:
            raise IndexError(f"index {idx} is out of range for WaveDataset of length {n}")
        x = self.data[pos:pos + self.seq_len]
        y = self.data[pos + 1:pos + self.seq_len + 1]
        return x, y

# ----------------------
# Simple autoregressive generation (slow, sample-by-sample)
# ----------------------
@torch.no_grad()
def generate(model, device, initial_sequence, gen_len=16000, temperature=1.0,
             context_len=1024):
    """Autoregressively sample ``gen_len`` new quantized samples from ``model``.

    Args:
        model: module mapping a (1, T) LongTensor to (1, n_quantize, T) logits.
        device: device to run generation on.
        initial_sequence: non-empty 1-D tensor of quantized values used as
            the prompt (length L).
        gen_len: number of samples to generate; a negative value generates
            nothing.
        temperature: softmax temperature. It is clamped to at least 1e-8, so
            values near 0 approach greedy decoding.
        context_len: how many of the most recent samples are fed to the model
            at each step. Should be at least the model's receptive field.

    Returns:
        CPU LongTensor of length ``L + gen_len``: the prompt followed by the
        sampled values.

    Raises:
        ValueError: if ``initial_sequence`` is empty or ``context_len`` < 1.

    The model's train/eval mode is restored before returning.
    """
    if initial_sequence.numel() == 0:
        raise ValueError("initial_sequence must contain at least one sample")
    if context_len < 1:
        raise ValueError(f"context_len must be >= 1, got {context_len}")

    prompt_len = initial_sequence.shape[0]
    n_new = max(0, gen_len)
    # Preallocate the full output. torch.cat-ing one sample per step would
    # re-copy the whole growing sequence every iteration.
    generated = torch.empty(prompt_len + n_new, dtype=torch.long, device=device)
    generated[:prompt_len] = initial_sequence.long().to(device)

    was_training = model.training
    model.eval()
    try:
        for pos in range(prompt_len, prompt_len + n_new):
            # Feed only the most recent context_len samples. This assumes the
            # model is causal/convolutional, so the last step only needs a
            # bounded history.
            x = generated[max(0, pos - context_len):pos].unsqueeze(0)  # (1, T)
            logits_last = model(x)[0, :, -1]  # (n_quantize,)
            probs = F.softmax(logits_last / max(1e-8, temperature), dim=-1)
            # multinomial returns shape (1,); store its single element.
            generated[pos] = torch.multinomial(probs, 1)[0]
    finally:
        model.train(was_training)  # don't leave the caller's model in eval mode
    return generated.cpu()


In [None]:
# Toy training data: a 5 s, 220 Hz sine at 8 kHz, mu-law quantized into the
# integer classes that the WaveNet predicts one step ahead.
sr = 8000  # sample rate (Hz)
sec = 5  # duration (seconds)
t = np.linspace(0, sec, sr * sec, endpoint=False)  # sample times in seconds
sine = 0.6 * np.sin(2 * np.pi * 220 * t)  # 220 Hz tone
quant = mu_law_encode(sine)  # ints 0..255


# Every start index yields one window, so there are
# len(quant) - seq_len = 40000 - 512 = 39488 overlapping windows. With
# batch_size=8 that is 4936 optimizer steps per epoch.
# NOTE(review): that step count is probably why the training cell was
# interrupted; consider a larger batch size or fewer windows.
seq_len = 512
dataset = WaveDataset(quant, seq_len=seq_len)
dl = DataLoader(dataset, batch_size=8, shuffle=True, drop_last=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# n_quantize = MU + 1 = 256 output classes; default dilation schedule.
model = WaveNet(n_quantize=MU + 1, residual_channels=32, skip_channels=64).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)


In [None]:
# Training loop: next-sample cross-entropy over every timestep of each window.
n_epochs = 5
criterion = nn.CrossEntropyLoss()  # stateless, so build it once
for epoch in range(n_epochs):
    model.train()
    total_loss = 0.0
    n_seen = 0
    for x, y in dl:
        x = x.to(device)
        y = y.to(device)
        logits = model(x)  # (B, n_quantize, T)
        # CrossEntropyLoss takes (B, C, T) logits with (B, T) targets
        # directly; the mean over B*T matches the flattened (B*T, C) form.
        loss = criterion(logits, y)
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        # loss is a per-timestep mean; weight it by batch size so the epoch
        # figure is an average over windows.
        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        n_seen += batch_size
    # Divide by the windows actually trained on. len(dl.dataset) also counts
    # windows that drop_last discards, which would under-report the loss.
    epoch_loss = total_loss / max(n_seen, 1)

    print(f"Epoch {epoch + 1} loss: {epoch_loss:.4f}")

KeyboardInterrupt: 

In [None]:
# Prime the model with the first 256 quantized samples, sample one second
# (sr samples) of continuation, then map the classes back to amplitudes.
seed = torch.as_tensor(quant[:256], dtype=torch.long)
out_quant = generate(model, device, initial_sequence=seed, gen_len=sr)
out_wave = mu_law_decode(out_quant.numpy())
# Saving / playback (e.g. soundfile or scipy) is not done here.
print("Generated waveform (first 10 samples):", out_wave[:10])