In [22]:
from typing import Callable

import torch
import torch.nn as nn
from torchaudio.transforms import MuLawEncoding

from net import *

In [23]:
# Run on Apple's Metal (MPS) backend when PyTorch reports it as available;
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [24]:
def mu_law(audio: torch.Tensor, quantization_channels: int = 256) -> torch.Tensor:
    """Quantize a waveform with torchaudio's mu-law companding transform.

    Args:
        audio: Waveform tensor. torchaudio documents ``MuLawEncoding`` as
            expecting a signal scaled to ``[-1, 1]``; this function neither
            rescales nor validates the input.
        quantization_channels: Number of discrete levels to encode into.
            Defaults to 256, the value that used to be hard-coded here.

    Returns:
        Whatever ``MuLawEncoding`` produces for ``audio``; torchaudio documents
        this as values in ``[0, quantization_channels - 1]``.
        NOTE(review): an earlier comment claimed the result has shape
        ``(1, T)`` and dtype long. The shape presumably follows ``audio``
        (the encoding is element-wise) -- confirm against the caller's input.
    """
    encoder = MuLawEncoding(quantization_channels=quantization_channels)
    return encoder(audio)

In [25]:

def create_windows(seq: torch.Tensor, window_size=512):
    """Slice a sequence into overlapping windows, each paired with the next element.

    Window ``i`` is ``seq[i : i + window_size]`` and its target is
    ``seq[i + window_size]``, i.e. the element that immediately follows the
    window. Windows advance one position at a time along dimension 0.

    Args:
        seq: Tensor with at least one dimension; windows are taken along dim 0.
        window_size: Number of consecutive elements per window (must be >= 1).

    Returns:
        A tuple ``(examples, targets)`` with ``N = len(seq) - window_size``:
        ``examples`` has shape ``(N, window_size, *seq.shape[1:])`` and
        ``targets`` has shape ``(N, *seq.shape[1:])``. Both own their storage
        (they are copies, not views of ``seq``).

    Raises:
        ValueError: If ``window_size`` is smaller than 1, or ``seq`` is too
            short to produce even one (window, target) pair.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    num_windows = len(seq) - window_size
    if num_windows < 1:
        raise ValueError(
            f"sequence of length {len(seq)} is too short for window_size="
            f"{window_size}; need at least {window_size + 1} elements"
        )

    # unfold produces len(seq) - window_size + 1 windows; the last one has no
    # following element to serve as its target, so it is dropped.
    windows = seq.unfold(0, window_size, 1)[:num_windows]
    if seq.dim() > 1:
        # unfold appends the window axis last; move it next to the leading
        # axis so the layout is (N, window_size, ...), as stacking slices gives.
        windows = windows.movedim(-1, 1)

    # unfold and slicing return views into seq; clone so callers get
    # independent, contiguous tensors.
    examples = windows.clone(memory_format=torch.contiguous_format)
    targets = seq[window_size:].clone(memory_format=torch.contiguous_format)
    return examples, targets  # (N, window_size, ...), (N, ...)

In [26]:
# Build the dataset: waveform -> mu-law tokens -> (window, next-token) pairs,
# then move everything to the selected device.
sines = generate_sine_wave()
mu_sines = mu_law(sines)
examples, targets = create_windows(mu_sines)
examples = examples.to(device)
targets = targets.to(device)

# 80/20 split by position: the first 80% of the windows are used for training
# and the remaining 20% are held out for testing.
# NOTE(review): consecutive windows overlap (stride 1), so the test windows
# right after the boundary share samples with the last training windows.
b = int(targets.shape[0] * 0.8)
train_examples = examples[:b]
test_examples = examples[b:]

train_targets = targets[:b]
test_targets = targets[b:]

In [27]:
# Instantiate the network. num_embs=256 / emb_dim=64 correspond to the
# Embedding(256, 64) table shown in the printed module below; 256 is presumably
# the number of mu-law levels produced by mu_law -- TODO confirm.
model = WaveNet(
    num_embs=256,
    emb_dim=64,
    channels=(64, 64, 64, 64),
    kernel_size=(2, 2),
    skip_size=64,
    out_size=(64, 128, 256),
)
# Kept as the final expression so the notebook displays the value .to() returns.
model.to(device)

WaveNet(
  (emb): Emb(
    (emb): Embedding(256, 64)
  )
  (blocks): ModuleList(
    (0): WaveNetBlock(
      (filter_conv): DilatedCausalConv1d(
        (conv): Conv1d(64, 64, kernel_size=(2,), stride=(1,))
      )
      (gated_conv): DilatedCausalConv1d(
        (conv): Conv1d(64, 64, kernel_size=(2,), stride=(1,))
      )
      (gated_act): GatedActivation(
        (gate_act): Sigmoid()
        (filter_act): Tanh()
      )
      (out_proj): Conv1x1(
        (conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
      )
      (skip_proj): Conv1x1(
        (conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
      )
    )
    (1): WaveNetBlock(
      (filter_conv): DilatedCausalConv1d(
        (conv): Conv1d(64, 64, kernel_size=(2,), stride=(1,), dilation=(2,))
      )
      (gated_conv): DilatedCausalConv1d(
        (conv): Conv1d(64, 64, kernel_size=(2,), stride=(1,), dilation=(2,))
      )
      (gated_act): GatedActivation(
        (gate_act): Sigmoid()
        (filter_act): Tanh(

In [28]:
def adam(
    model: nn.Module,
    loss_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    xs: torch.Tensor,
    ys: torch.Tensor,
    lr=0.1,
    lambda_=1e-5,
    batch_size: int = 30,
    steps: int = 1000,
):
    """Train ``model`` in place with Adam on randomly sampled minibatches.

    Each step draws ``batch_size`` indices uniformly at random (with
    replacement) from ``xs``/``ys``, runs a forward and backward pass and
    applies one optimizer update. The model's train/eval mode is not changed.

    Args:
        model: Module to optimize; its parameters are updated in place.
        loss_func: Called as ``loss_func(model(x_batch), y_batch)``; must
            return a scalar tensor that supports ``backward()``.
        xs: Inputs, indexed along dimension 0.
        ys: Targets, indexed along dimension 0; must have the same length as ``xs``.
        lr: Learning rate handed to ``torch.optim.Adam``.
            NOTE(review): 0.1 is far above the optimizer's library default of
            1e-3; kept unchanged for compatibility.
        lambda_: Passed to ``torch.optim.Adam`` as ``weight_decay``.
        batch_size: Number of examples sampled per step.
        steps: Number of optimizer updates to perform.

    Returns:
        A list with one Python float per step: the minibatch loss of that step.

    Raises:
        ValueError: If ``xs`` is empty or ``xs`` and ``ys`` differ in length.
    """
    num_examples = len(xs)
    if num_examples == 0:
        raise ValueError("xs is empty; there is nothing to sample minibatches from")
    if len(ys) != num_examples:
        raise ValueError(
            f"xs and ys must have the same length, got {num_examples} and {len(ys)}"
        )

    lossi = []
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=lambda_)
    for step in range(steps):
        # Create the indices directly on the data's device so indexing does not
        # need a host-to-device copy on every step.
        idx = torch.randint(0, num_examples, (batch_size,), device=xs.device)
        # .to() is a no-op when ys already lives on the same device as xs.
        x_batch, y_batch = xs[idx], ys[idx.to(ys.device)]

        optimizer.zero_grad()
        preds = model(x_batch)
        loss = loss_func(preds, y_batch)

        loss.backward()
        optimizer.step()

        # Read the scalar once; it is reused for the history and the log line.
        loss_value = loss.item()
        lossi.append(loss_value)

        if step % 100 == 0:
            print(f"Loss: {loss_value} on step: {step + 1}")
    return lossi