In [None]:
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import glob
import pretty_midi
import random
import matplotlib.pyplot as plt

import wandb

### Building WaveNet model

#### Causal Convolution

In [2]:
class CausalConv(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, dilation: int = 1, **kwargs):
        """
        Implementation of Causal Convolution 1d: a "same"-padded 1d convolution whose kernel taps to the
        right of the centre are masked out, so each output step is only influenced by the current and
        preceding input steps

        :param int in_channels: input channels
        :param int out_channels: output channels
        :param int kernel_size: size of filter kernel
        :param int dilation: dilation of kernel, defaults to 1 <- no dilation
        :param kwargs: extra keyword arguments forwarded to nn.Conv1d (e.g. bias=False);
            padding and dilation are set by this class and must not be passed here
        """
        super().__init__()

        # calculating same padding based on kernel_size and dilation
        padding = dilation * (kernel_size-1) // 2

        # creating mask: ones for the centre tap and everything left of it, zeros for the taps to the right
        mask = torch.ones(kernel_size)
        mask[kernel_size//2+1:] = 0
        # stored with shape [1, kernel_size] so it broadcasts over the [out, in, kernel_size] weight
        self.register_buffer("mask", mask[None])

        # kwargs used to be accepted but silently dropped, so bias=False never reached the convolution
        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation, **kwargs)

    def forward(self, x: torch.Tensor):
        # applying mask to filter weights in place (re-applied on every call, since an optimizer step
        # may have written non-zero values back into the masked taps)
        self.conv1d.weight.data *= self.mask

        return self.conv1d(x)

#### Gated Residual Conv Block

In [3]:
class GatedResidualConvBlock(nn.Module):
    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int, kernel_size: int, dilation: int):
        """
        Gated residual block: causal dilated conv -> group norm -> tanh/sigmoid gate -> 1x1 conv

        :param int in_channels: input channels
        :param int hidden_channels: intermediate channels between CausalConv layer and Conv1x1
        :param int out_channels: output channels
        :param int kernel_size: size of filter kernel
        :param int dilation: dilation of kernel
        """
        super().__init__()

        # the causal conv produces value and gate halves at once, hence twice the hidden width
        doubled_channels = 2*hidden_channels

        self.dilated_conv = CausalConv(in_channels, doubled_channels, kernel_size, dilation=dilation, bias=False)
        self.gn = nn.GroupNorm(num_groups=32, num_channels=doubled_channels)
        self.conv_1x1 = nn.Conv1d(hidden_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor):
        # causal conv followed by normalisation
        normalized = self.gn(self.dilated_conv(x))

        # first half of the channels is the value, second half is the gate
        value, gate = torch.chunk(normalized, 2, dim=1)

        # gated activation projected by the 1x1 conv; this is the skip connection output
        skip_output = self.conv_1x1(torch.tanh(value) * torch.sigmoid(gate))

        # (residual connection output, skip connection output)
        return skip_output + x, skip_output

#### Gated Conv Stack

In [4]:
class GatedConvStack(nn.Module):
    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int, kernel_size: int, num_residual_blocks: int):
        """
        Sequence of Gated Residual Conv Blocks whose dilation doubles from block to block

        :param int in_channels: input channels
        :param int hidden_channels: intermediate channels
        :param int out_channels: output channels
        :param int kernel_size: size of filter kernels
        :param int num_residual_blocks: num of conv blocks in stack
        """
        super().__init__()

        # block number `exponent` uses dilation 2**exponent -> 1, 2, 4, 8, 16, ...
        blocks = []
        for exponent in range(num_residual_blocks):
            blocks.append(
                GatedResidualConvBlock(in_channels, hidden_channels, out_channels, kernel_size, 2**exponent)
            )

        self.conv_stack = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor):
        collected_skips = []

        for block in self.conv_stack:
            # residual output feeds the next block, skip output is collected
            x, block_skip = block(x)
            collected_skips.append(block_skip)

        # sum of skip connections from this stack
        summed_skips = torch.stack(collected_skips, dim=-1).sum(dim=-1)

        # residual output of the last block and the summed skip connections
        return x, summed_skips

#### WaveNet

In [5]:
class WaveNet(nn.Module):
    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int, kernel_size: int, num_stacks: int, num_residual_blocks_in_stack: int):
        """
        WaveNet: an input causal conv, a series of gated conv stacks, and an output block applied
        to the sum of the stacks' skip connections

        :param int in_channels: input channels
        :param int hidden_channels: intermediate channels
        :param int out_channels: output channels
        :param int kernel_size: size of filter kernels
        :param int num_stacks: num of stacks
        :param int num_residual_blocks_in_stack: num of gated residual conv blocks in each stack
        """
        super().__init__()

        self.causal_conv = CausalConv(in_channels, hidden_channels, kernel_size)

        stacks = [
            GatedConvStack(hidden_channels, hidden_channels, hidden_channels, kernel_size, num_residual_blocks_in_stack)
            for _ in range(num_stacks)
        ]
        self.gated_conv_stacks = nn.ModuleList(stacks)

        # two 1x1 convolutions, each preceded by a ReLU
        self.output_block = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, out_channels, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x: torch.Tensor):
        hidden = self.causal_conv(x)

        stack_skips = []
        for conv_stack in self.gated_conv_stacks:
            # residual output goes into the next stack, skip output is collected
            hidden, stack_skip = conv_stack(hidden)
            stack_skips.append(stack_skip)

        # the output block only sees the sum of all skip connection outputs
        summed_skips = torch.stack(stack_skips, dim=-1).sum(dim=-1)

        return self.output_block(summed_skips)

#### MelodyWaveNet

In [6]:
class OutputBlock(nn.Module):
    def __init__(self, channels: int, num_classes: int):
        """
        Output head for Melody WaveNet: two causal conv / group norm / GELU groups followed by a
        1x1 conv that maps to class logits

        :param int channels: input channels
        :param int num_classes: number of output classes
        """
        super().__init__()

        layers = [
            CausalConv(channels, channels, kernel_size=5, dilation=1, bias=False),
            nn.GroupNorm(32, channels),
            nn.GELU(),
            CausalConv(channels, channels, kernel_size=5, dilation=2, bias=False),
            nn.GroupNorm(32, channels),
            nn.GELU(),
            nn.Conv1d(channels, num_classes, kernel_size=1),
        ]
        self.output = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        logits = self.output(x)
        return logits

class MelodyWaveNet(nn.Module):
    def __init__(self, embedding_dims: int, channels: int, kernel_size: int, num_stacks: int, num_residual_blocks_in_stack: int):
        """
        WaveNet for generating midi sequences: embeds the four note features, runs them through
        a WaveNet and predicts logits for each feature with a separate output block

        :param int embedding_dims: embedding dimension for each feature
        :param int channels: intermediate channels
        :param int kernel_size: size of filter kernels
        :param int num_stacks: num of stacks
        :param int num_residual_blocks_in_stack: num of gated residual conv blocks in each stack
        """
        super().__init__()

        # embeddings for features, pitch and velocity have values between 0-128
        # duration and step between 0-7 corresponding to different note lengths
        # for simplicity dotted notes and triplets are omitted
        # index 0 is the padding index of every embedding
        self.embedding_pitch = nn.Embedding(num_embeddings=129, embedding_dim=embedding_dims, padding_idx=0)
        self.embedding_velocity = nn.Embedding(num_embeddings=129, embedding_dim=embedding_dims, padding_idx=0)
        self.embedding_duration = nn.Embedding(num_embeddings=8, embedding_dim=embedding_dims, padding_idx=0)
        self.embedding_step = nn.Embedding(num_embeddings=8, embedding_dim=embedding_dims, padding_idx=0)

        # the four embeddings are concatenated, hence 4*embedding_dims input channels
        self.wavenet = WaveNet(4*embedding_dims, channels, channels, kernel_size, num_stacks, num_residual_blocks_in_stack)

        # output shape [batch_size, num_classes, seq_len]
        self.output_pitch = OutputBlock(channels, num_classes=129)
        self.output_velocity = OutputBlock(channels, num_classes=129)
        self.output_duration = OutputBlock(channels, num_classes=8)
        self.output_step = OutputBlock(channels, num_classes=8)

    def forward(self, pitch: torch.Tensor, velocity: torch.Tensor, duration: torch.Tensor, step: torch.Tensor):
        # pitch, velocity, duration and step have shapes of [batch_size, seq_len]

        # each embedding has shape [batch_size, seq_len, embedding_dim]
        embedded = [
            self.embedding_pitch(pitch),
            self.embedding_velocity(velocity),
            self.embedding_duration(duration),
            self.embedding_step(step),
        ]

        # concatenated on the embedding dim -> [batch_size, seq_len, 4*embedding_dim]
        stacked = torch.cat(embedded, dim=-1)
        # embedding dim becomes the channel dim -> [batch_size, 4*embedding_dim, seq_len]
        features = stacked.permute(0, 2, 1)

        hidden = self.wavenet(features)

        # one output block per feature, each giving [batch_size, num_classes, seq_len]
        return (
            self.output_pitch(hidden),
            self.output_velocity(hidden),
            self.output_duration(hidden),
            self.output_step(hidden),
        )

### Lightning Wrapper for model

In [7]:
class MelodyWaveNetLit(pl.LightningModule):
    """
    Lightning wrapper around MelodyWaveNet that trains it on next-step prediction of
    pitch, velocity, duration and step with a weighted sum of cross entropy losses
    """

    # feature order used by batches, model outputs and loss_lambdas
    FEATURE_NAMES = ("pitch", "velocity", "duration", "step")

    def __init__(
                self, embedding_dims: int, channels: int, kernel_size: int, 
                num_stacks: int, num_residual_blocks_in_stack: int, lr: float, loss_lambdas: List[float], l2: float
                ):
        """
        :param int embedding_dims: embedding dimension for each feature
        :param int channels: intermediate channels
        :param int kernel_size: size of filter kernels
        :param int num_stacks: num of stacks
        :param int num_residual_blocks_in_stack: num of gated residual conv blocks in each stack
        :param float lr: learning rate for Adam
        :param List[float] loss_lambdas: loss weights in the order pitch, velocity, duration, step
        :param float l2: weight decay for Adam
        """
        super().__init__()

        # stores all constructor arguments under self.hparams
        self.save_hyperparameters()

        # model
        self.melody_wavenet = MelodyWaveNet(embedding_dims, channels, kernel_size, num_stacks, num_residual_blocks_in_stack)

    def forward(self, pitch: torch.Tensor, velocity: torch.Tensor, duration: torch.Tensor, step: torch.Tensor):
        return self.melody_wavenet(pitch, velocity, duration, step)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.melody_wavenet.parameters(), lr=self.hparams.lr, betas=(0.9, 0.999), weight_decay = self.hparams.l2)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 30, gamma=0.1)

        # Lightning expects "list of optimizers, list of schedulers"; a bare (optimizer, scheduler)
        # tuple is not one of the accepted return formats
        return [optimizer], [lr_scheduler]

    def _shared_step(self, batch, stage: str):
        """
        Computes and logs the per-feature and total loss for one batch

        :param batch: pitch, velocity, duration and step tensors, each of shape [N, seq_len]
        :param str stage: prefix of the logged metric names ("train" or "val")
        :return: weighted sum of the four cross entropy losses
        """
        # inputs are all steps but the last, targets are the same sequences shifted by one step
        # batch tensors are 2-D ([N, seq_len]), so the time axis is the second index
        inputs = [feature[:, :-1] for feature in batch]
        targets = [feature[:, 1:] for feature in batch]

        # predictions have shapes [N, num_classes, seq_len-1]
        predictions = self(*inputs)

        metrics = {}
        loss = 0.0
        for name, weight, prediction, target in zip(self.FEATURE_NAMES, self.hparams.loss_lambdas, predictions, targets):
            nll = F.cross_entropy(prediction, target)
            metrics[f"{stage}/loss_{name}"] = nll
            loss = loss + weight * nll

        metrics[f"{stage}/total_loss"] = loss
        self.log_dict(metrics)

        return loss

    def training_step(self, batch, batch_idx):
        # logged under "train/..." so the values do not overwrite the validation metrics
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val")

### Dataset

In [8]:
class MidiDataset(Dataset):
    """
    Dataset of note sequences stored as csv files, one file per item.
    Every item is a random window of seq_len rows from the file, zero padded if the file is shorter
    """
    def __init__(self, midi_file_list: list, seq_len: int):
        """
        :param list midi_file_list: paths of csv files whose first four columns are pitch, velocity, duration and step
        :param int seq_len: number of rows returned for each item
        """
        self.midi_file_list = midi_file_list
        self.seq_len = seq_len

    def __len__(self):
        return len(self.midi_file_list)
    
    def __getitem__(self, index):
        # loading data
        data = pd.read_csv(self.midi_file_list[index]).to_numpy(dtype=np.int64)

        # grab random starting index; the upper bound (inclusive) is the last start that still
        # yields a full window and is clamped to 0 so files with <= seq_len rows do not make
        # random.randint raise on an empty range
        max_start_idx = max(len(data) - self.seq_len, 0)
        start_idx = random.randint(0, max_start_idx)
        # get slice of data
        data = data[start_idx:start_idx+self.seq_len]

        # padding short files with zero rows at the end
        data_len = len(data)
        if data_len < self.seq_len:
            data = np.pad(data, ((0, self.seq_len-data_len), (0, 0)))

        data_torch = torch.from_numpy(data)

        # returning pitch, velocity, duration and step, each of shape [seq_len]
        return data_torch[:, 0], data_torch[:, 1], data_torch[:, 2], data_torch[:, 3]

#### Hyperparameters

In [17]:
# training hyperparameters
config = {
    "epochs": 100,
    "batch_size": 32,
    "lr": 3e-4,
    "l2": 0.01,
    # loss weights in the order pitch, velocity, duration, step
    "loss_lambdas": [2.0, 1.0, 1.0, 1.0],
    "seq_len": 512,
}

#### Initializing model

In [19]:
# model: architecture sizes are fixed here, optimisation settings come from config
model = MelodyWaveNetLit(
    embedding_dims=32,
    channels=128,
    kernel_size=5,
    num_stacks=3,
    num_residual_blocks_in_stack=4,
    lr=config["lr"],
    loss_lambdas=config["loss_lambdas"],
    l2=config["l2"],
)

In [None]:
def init(model: nn.Module):
    """
    Re-initialises convolution and embedding weights of a model in place with Xavier uniform init

    :param nn.Module model: model whose submodules are initialised
    """
    for m in model.modules():
        # Conv1d is included because the WaveNet blocks in this notebook are built from nn.Conv1d;
        # checking only nn.Conv2d left every convolution untouched
        if isinstance(m, (nn.Conv1d, nn.Conv2d)):
            nn.init.xavier_uniform_(m.weight.data)
        elif isinstance(m, nn.Embedding):
            nn.init.xavier_uniform_(m.weight.data)
            # xavier init also overwrites the padding row, which nn.Embedding creates as zeros
            # and never updates, so it is reset to keep padding embedded as a zero vector
            if m.padding_idx is not None:
                m.weight.data[m.padding_idx].zero_()

#### Training pipeline

In [None]:
def training_pipeline(model, config, train_filepaths, val_filepaths):
    """
    Trains the model with a Lightning Trainer, logs to Weights & Biases and saves the trained network

    :param model: Lightning module to train; its melody_wavenet attribute is saved after training
    :param config: dict with the keys "epochs", "seq_len" and "batch_size"
    :param train_filepaths: csv files used for training
    :param val_filepaths: csv files used for validation
    """
    # logger
    logger = WandbLogger(name="melody_wavenet", project="melody_wavenet")

    try:
        # setting up trainer
        trainer = pl.Trainer(logger=logger, log_every_n_steps=50, accelerator='gpu', devices=-1, max_epochs=config["epochs"], precision=16)

        # datasets
        train_set = MidiDataset(train_filepaths, seq_len=config["seq_len"])
        val_set = MidiDataset(val_filepaths, seq_len=config["seq_len"])

        # dataloaders; only the training data is shuffled, validation keeps a fixed file order
        train_loader = DataLoader(train_set, batch_size=config["batch_size"], shuffle=True)
        val_loader = DataLoader(val_set, batch_size=config["batch_size"], shuffle=False)

        # training
        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

        # saves the whole wrapped nn.Module (not just a state dict)
        torch.save(model.melody_wavenet, "melody_generator.pt")
    finally:
        # close the wandb run even when training fails or is interrupted
        wandb.finish()

#### Splitting data

In [None]:
# randomly shuffling data
# glob returns files in arbitrary order, so the list is sorted first and shuffled with a fixed
# seed: the train / validation split below is then the same after every kernel restart
SPLIT_SEED = 42
data_filepaths = sorted(glob.glob("../../wavenet/extracted_data/*"))
random.Random(SPLIT_SEED).shuffle(data_filepaths)

In [3]:
# roughly 1200 files in total: the first 1000 are used for training, the remainder for validation
train_file_list, val_file_list = data_filepaths[:1000], data_filepaths[1000:]

#### Running training

In [None]:
# run the full training on the split prepared above
training_pipeline(
    model=model,
    config=config,
    train_filepaths=train_file_list,
    val_filepaths=val_file_list,
)

#### Melody Generator

In [7]:
class MelodyGenerator:
    """
    Autoregressive sampler: extends a note sequence one step at a time by sampling pitch,
    velocity, duration and step from the model's predictions for the last position
    """
    def __init__(self, model: nn.Module, temperature: list):
        """
        :param nn.Module model: called as model(pitch, velocity, duration, step) with long tensors of
            shape [1, seq_len]; must return four logit tensors of shape [1, num_classes, seq_len]
        :param list temperature: four sampling temperatures in the order pitch, velocity, duration, step
        """
        self.model = model.eval()
        self.temperature = temperature

    @torch.no_grad()
    def generate_sequence(self, sequence_len: int, seq: np.ndarray = None):
        """
        Generates a sequence of sequence_len notes, optionally continuing a given start

        :param int sequence_len: total length of the returned sequence (including the conditioning part)
        :param np.ndarray seq: optional conditioning values, indexed as seq[0..3] = pitch, velocity,
            duration, step; all four are expected to have the same length
        :return: long tensor of shape [4, sequence_len] on cpu, rows are pitch, velocity, duration, step
        :raises ValueError: if the conditioning sequence is longer than sequence_len
        """
        context_len = 0 if seq is None else len(seq[0])
        if context_len > sequence_len:
            raise ValueError(f"conditioning sequence has {context_len} steps, more than sequence_len={sequence_len}")

        # the model predicts step t+1 from steps <= t and cannot be called on an empty sequence,
        # so without conditioning a single padding (0) start token is prepended and removed again
        # before returning
        offset = 1 if context_len == 0 else 0
        total_len = sequence_len + offset

        # tensors are created on the model's device (cpu for a model without parameters)
        first_param = next(self.model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

        # each feature has shape [1, total_len]
        if context_len == 0:
            features = [torch.zeros((1, total_len), dtype=torch.long, device=device) for _ in range(4)]
        else:
            # conditioned values followed by zero padding up to sequence_len
            features = [
                torch.as_tensor(np.pad(np.asarray(seq[k]), (0, sequence_len-len(seq[k]))), dtype=torch.long).unsqueeze(0).to(device)
                for k in range(4)
            ]

        for i in tqdm(range(context_len + offset, total_len)):

            # passing only previous values to speed up computation, out shape [batch_size, classes, seq_len]
            predictions = self.model(*(feature[:, :i] for feature in features))

            # the prediction at the last position is the distribution of step i
            for feature, prediction, temperature in zip(features, predictions, self.temperature):
                feature[:, i] = self._sample_with_temperature(prediction[:, :, -1], temperature).squeeze(-1)

        # dropping the start token (if any) -> [4, sequence_len]
        return torch.cat(features, dim=0)[:, offset:].cpu()
        
    def _sample_with_temperature(self, values: torch.Tensor, temperature: float):
        """
        Samples one class index per row from temperature scaled logits, never returning class 0

        :param torch.Tensor values: logits of shape [batch_size, num_classes]
        :param float temperature: divisor applied to the logits before softmax
        :return: sampled class indices of shape [batch_size, 1]
        """
        # division creates a new tensor, the caller's logits are left untouched
        predictions = values / temperature
        # class 0 is the padding index; -inf gives it exactly zero probability at any temperature
        predictions[:, 0] = float("-inf")

        probabilities = torch.softmax(predictions, dim=1)

        return torch.multinomial(probabilities, num_samples=1)

#### Encoding generated data to MIDI

In [8]:
class Notes2Midi:
    """
    Converts sequences of pitch, velocity, duration and step indices into a midi file
    """
    def __init__(self):
        # index -> note length, expressed in quarter notes
        self.mapping = {
            1: 0.0,
            2: 0.125,
            3: 0.25,
            4: 0.5,
            5: 1.0,
            6: 2.0,
            7: 4.0,
        }

    def save_sequence_as_midi(self, sequence: np.ndarray, tempo: float, save_path: str):
        """
        Writes the sequence to save_path as a single-instrument midi file

        :param np.ndarray sequence: rows 0-3 hold pitch, velocity, duration and step indices
        :param float tempo: tempo in quarter notes per minute
        :param str save_path: path of the written midi file
        """
        midi = pretty_midi.PrettyMIDI()
        instrument = pretty_midi.Instrument(program=0)

        # length of a quarter note in seconds
        seconds_per_quarter = 60.0 / tempo

        pitch, velocity, duration, step = (sequence[row] for row in range(4))

        # start time of the previously added note; step is measured from it
        last_start = 0.0

        for p, v, d, s in zip(pitch, velocity, duration, step):
            note_start = last_start + self.mapping[s] * seconds_per_quarter
            note_end = note_start + self.mapping[d] * seconds_per_quarter

            # pitch and velocity indices are shifted down by one before being written
            instrument.notes.append(
                pretty_midi.Note(velocity=v-1, pitch=p-1, start=note_start, end=note_end)
            )

            last_start = note_start

        midi.instruments.append(instrument)
        midi.write(save_path)

In [12]:
# NOTE: torch.load unpickles a complete nn.Module here, which can execute arbitrary code from
# the file - only load checkpoints you created yourself. On PyTorch >= 2.6 loading a pickled
# module additionally requires passing weights_only=False.
# NOTE(review): the training pipeline saves "melody_generator.pt"; confirm this file name is intended.
# map_location places all tensors on cpu during loading, so a checkpoint that holds CUDA
# tensors also loads on a machine without a GPU
model = torch.load("melody_generator_3.pt", map_location="cpu")

In [13]:
# temperatures in the order pitch, velocity, duration, step
mg = MelodyGenerator(model=model, temperature=[1, 1, 1, 1])
n2m = Notes2Midi()

In [17]:
# generate 512 steps and convert the result to a numpy array
gen_seq = mg.generate_sequence(sequence_len=512).numpy()

  0%|          | 0/480 [00:00<?, ?it/s]

In [18]:
# write the generated sequence as midi at tempo 120
n2m.save_sequence_as_midi(sequence=gen_seq, tempo=120, save_path="generated/14.midi")