In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchaudio
import numpy as np
import seaborn; seaborn.set_style("darkgrid")
from tqdm.notebook import tqdm

# Control panel:

In [None]:
# Define hyper-parameters
batch_size = 10            # number of random sinusoids per training batch
learning_rate = 0.005      # Adam step size
training_iterations = 200  # number of optimizer steps
dilation_depth = 6         # one block uses dilations 1, 2, 4, ..., 2 ** (dilation_depth - 1)
sample_length = 100        # samples per training sinusoid; also the length of the generation seed
generate_length = 2 * sample_length  # number of samples to generate after the seed
bins = 20                  # mu-law quantization levels = one-hot channels = output classes

# With a single block the model's receptive field is 2 ** dilation_depth samples, and training needs at
# least two output steps per sequence, so sample_length must exceed 2 ** dilation_depth.

# Seed torch's global RNG so data generation, weight initialisation and sampling are reproducible.
# Named `random_seed` (not `seed`) because `seed` is used later for the generation seed tensor.
random_seed = 0
torch.manual_seed(random_seed);

# Model:

In [None]:
class Wavenet(nn.Module):
    """ Class description: WaveNet-style stack of gated, dilated causal convolutions with skip connections.

        Maps a mu- and one-hot-encoded waveform of shape (batch_size, quantization_bins, samples) to unnormalised
        scores (logits) over the quantization bins. Every dilated layer (kernel size 2) shortens the time axis by
        its dilation, so the output has samples - sum(self.dilations) time steps and the receptive field is
        sum(self.dilations) + 1 samples. """

    def __init__(self, quantization_bins, channels, dilation_depth, blocks):
        """ :param quantization_bins: number of one-hot input channels and of output classes
            :param channels: number of hidden channels used throughout the residual stack
            :param dilation_depth: one block uses dilations 1, 2, 4, ..., 2 ** (dilation_depth - 1)
            :param blocks: how many times that block of dilations is repeated
            :raises ValueError: if dilation_depth or blocks leaves the model without any dilated layer """
        super().__init__()

        # Part 1: Define model parameters
        self.C = channels
        self.kernel_size = 2
        self.bins = quantization_bins
        self.dilations = [2 ** i for i in range(dilation_depth)] * blocks

        # forward() sums the skip outputs of the dilated layers, so it needs at least one of them;
        # without this check it would fail later with an opaque error.
        if not self.dilations:
            raise ValueError(f"dilation_depth and blocks must both be >= 1, got {dilation_depth} and {blocks}")

        # Part 2: Define model layers
        # 1x1 convolution lifting the one-hot input to `channels` hidden channels
        self.pre_process_conv = nn.Conv1d(in_channels=self.bins, out_channels=self.C, kernel_size=1)
        self.causal_layers = nn.ModuleList(
            ResidalLayer(in_channels=self.C, out_channels=self.C, dilation=d, kernel_size=self.kernel_size)
            for d in self.dilations
        )

        self.post_process_conv1 = nn.Conv1d(self.C, self.C, kernel_size=1)
        self.post_process_conv2 = nn.Conv1d(self.C, self.bins, kernel_size=1)

    def forward(self, x):
        """ Function: Makes the forward pass/model prediction
            Input: Mu- and one-hot-encoded waveform of shape (batch_size, quantization_bins, samples).
                   'x' must be at least as long as the model's receptive field (sum(self.dilations) + 1 samples).
            Output: Logits for the next sample (softmax is NOT applied here), shape
                    (batch_size, quantization_bins, samples - sum(self.dilations)); this is 1 time step when
                    'x' is exactly one receptive field long, as during generation. """

        # Part 1: Through pre-processing layer
        x = self.pre_process_conv(x)

        # Part 2: Through stack of dilated causal convolutions, collecting every layer's skip output
        skips = []
        for layer in self.causal_layers:
            x, skip = layer(x)
            skips.append(skip)

        # Part 3: Post-processing (without softmax)
        # Each layer shortens the time axis, so the last skip output is the shortest. Crop every skip to that
        # length, keeping the most recent time steps, and add them together.
        out_length = skips[-1].size(2)
        x = sum(s[:, :, -out_length:] for s in skips)

        x = F.relu(x)
        x = self.post_process_conv1(x)  # shape --> (batch_size, channels, samples)
        x = F.relu(x)
        x = self.post_process_conv2(x)  # shape --> (batch_size, quantization_bins, samples)

        return x

In [None]:
class ResidalLayer(nn.Module):
    """ Class description: One gated, dilated causal residual layer of the WaveNet stack.

        forward(x) returns (residual, skip):
          * skip     = 1x1 convolution of tanh(conv_f(x)) * sigmoid(conv_g(x))
          * residual = x cropped to the convolution output length, plus skip (the input of the next layer)
        Both outputs are dilation * (kernel_size - 1) samples shorter than x. Because the residual adds the
        layer input to its output, in_channels must equal out_channels. """

    def __init__(self, in_channels:int, out_channels:int, kernel_size:int, dilation:int):
        super().__init__()

        # The residual sum adds x (in_channels) to skip (out_channels). A mismatch would either fail inside
        # forward() or, when in_channels == 1, broadcast silently, so reject it up front.
        if in_channels != out_channels:
            raise ValueError(f"in_channels ({in_channels}) must equal out_channels ({out_channels}) "
                             "for the residual connection")

        # Part 1: Define model parameters
        self.dilation = dilation

        # Part 2: Define model layers
        # Separate dilated convolutions for the filter (f) and the gate (g) of the gated activation.
        self.conv_f = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, dilation=dilation)
        self.conv_g = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, dilation=dilation)

        # One 1x1 convolution whose output is used for both the skip and the residual path
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """ :param x: tensor of shape (batch_size, channels, samples)
            :return: (residual, skip), each of shape (batch_size, channels, samples - dilation * (kernel_size - 1)) """
        # Gated activation
        z = torch.tanh(self.conv_f(x)) * torch.sigmoid(self.conv_g(x))

        # Skip output, collected by the parent model
        skip = self.conv_1x1(z)

        # An unpadded convolution shortens the time axis by dilation * (kernel_size - 1) samples. Drop that many
        # of the oldest input samples so x lines up with the convolution output. Cropping by `dilation` alone
        # is only correct for kernel_size == 2.
        crop = self.conv_f.dilation[0] * (self.conv_f.kernel_size[0] - 1)
        residual = x[:, :, crop:] + skip

        return residual, skip

# Helper functions:

In [None]:
def one_hot(input, bins=256):
    """ Function: Makes a one-hot-encoding with the class dimension on axis 1 (channels-first, as nn.Conv1d expects)
        :param input: tensor of bin indices, shape (batch_size, samples); values are cast with .long(),
                      so float inputs are truncated. Every index must lie in [0, bins).
        :param bins: quantization_bins, i.e. the number of classes
        :return: float32 tensor of shape (batch_size, quantization_bins, samples). """
    encoded = F.one_hot(input.long(), num_classes=bins)  # (batch_size, samples, bins)
    return encoded.transpose(1, 2).float()               # (batch_size, bins, samples)

def make_data(amount=128):
    """ Function: Sample a single sinusoid with a uniformly random phase in [0, 2*pi)
        :param amount: number of samples, spaced evenly over one period [0, 2*pi] (both endpoints included)
        :return: tensor of shape (1, amount) """
    t = torch.linspace(0, 2 * np.pi, amount)
    phase = 2 * np.pi * torch.rand(1)
    wave = torch.sin(t + phase)
    return wave[None, :]

# Generator:

In [None]:
def generate(seed: torch.Tensor, amount: int, model, bins: int):
    """ Function: Inefficient - but intuitive - autoregressive sample generation. The whole receptive field is
        re-run through the model for every new sample.
        :param seed: mu-law-encoded integer samples that start the generation process, shape (1, samples).
                     Must be at least as long as the model's receptive field (sum(model.dilations) + 1 samples,
                     assuming kernel size 2 as in Wavenet).
        :param amount: number of samples to generate
        :param model: model with a `dilations` list that maps one-hot input (1, bins, samples) to logits
        :param bins: number of quantization bins used for the one-hot encoding
        :return: list of ints with [seed + generated_samples]
        :raises ValueError: if the seed is shorter than the receptive field
        Note: leaves the model in eval mode. """
    receptive_field = sum(model.dilations) + 1
    temp = seed[0].tolist()
    if len(temp) < receptive_field:
        raise ValueError(f"seed has {len(temp)} samples but the model's receptive field is {receptive_field}")

    model.eval()
    with torch.no_grad():  # inference only: don't build an autograd graph on every step
        for _ in range(amount):
            # Only the last `receptive_field` samples influence the next prediction (1 output step)
            window = torch.tensor(temp[-receptive_field:]).long()
            logits = model(one_hot(window.unsqueeze(0), bins))
            probs = torch.softmax(logits, dim=1)

            # Sample the next value from the predicted distribution (not the argmax)
            next_sample = torch.multinomial(probs[0, :, 0], 1).item()
            temp.append(next_sample)

    return temp

# Train and test:

In [None]:
# Create model
model = Wavenet(quantization_bins=bins, channels=32, dilation_depth=dilation_depth, blocks=1)
model.train()

# Mu law encoding: float waveform -> integer bins 0 .. bins-1
mu = torchaudio.transforms.MuLawEncoding(quantization_channels=bins)

# Optimizer and loss (CrossEntropyLoss takes the raw logits the model returns)
optim = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()
losses = []

# Each dilated layer shortens the sequence by its dilation. Every scored output step needs a following
# input sample as its target, so the model must emit at least 2 steps per training sequence.
output_length = sample_length - sum(model.dilations)
if output_length < 2:
    raise ValueError(f"sample_length ({sample_length}) must exceed the receptive field "
                     f"({sum(model.dilations) + 1}) by at least one sample")

# Training loop
for _ in tqdm(range(training_iterations)):
    # Create training data: batch of randomly phase-shifted sinusoids, (batch_size, sample_length)
    raw_data = torch.cat([make_data(sample_length) for _ in range(batch_size)], dim=0)

    # Mu-law- and one-hot-encode data
    y_true = mu(raw_data)           # integer bins, (batch_size, sample_length)
    x = one_hot(y_true, bins=bins)  # (batch_size, bins, sample_length)

    # Prediction and loss.
    # Output step j sees inputs up to position (input_len - output_len + j), so it is scored against the
    # following input sample. The last output step has no following sample and is dropped.
    y_preds = model(x)
    first_target = x.size(2) - y_preds.size(2) + 1
    loss = criterion(y_preds[:, :, :-1], y_true[:, first_target:])

    # Updates
    optim.zero_grad()
    loss.backward()
    optim.step()
    losses.append(loss.item())

# Plot results:

In [1]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

# Training curve
ax1.plot(losses, label="Cross entropy loss")
ax1.set(title="Training loss", xlabel="Training iteration", ylabel="Loss")
ax1.legend()

# Seed the generator with a fresh random sinusoid and let the model continue it autoregressively
seed = mu(make_data(sample_length))
gen = generate(seed=seed, amount=generate_length, model=model, bins=bins)
ax2.plot(seed.squeeze(), '.-', label="Seed")
ax2.plot(np.arange(sample_length, len(gen)), gen[sample_length:], '.-', label="Generated")
ax2.set(title="Seed and generated continuation", xlabel="Sample index", ylabel="Mu-law bin")
ax2.legend()

fig.tight_layout()
plt.show()

NameError: name 'plt' is not defined