# Load dataset
#### Generated from preprocess.py

In [1]:
import pickle as pkl
with open("dataset.pkl", "rb") as fr:
    dataset = pkl.load(fr)
print(f"len(dataset) = {len(dataset):,}")

len(dataset) = 3,897


# model.py

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, lstm_first, lstm_second, latent_Wmu, latent_Wsig):
        super().__init__()
        self.lstm_first = nn.LSTM(**lstm_first)
        self.lstm_second = nn.LSTM(**lstm_second)
        self.latent_Wmu = nn.Linear(**latent_Wmu)
        self.latent_Wsig = nn.Linear(**latent_Wsig)
    
    def forward(self, x):
        B = x.shape[0]
        o, (h_1, c) = self.lstm_first(x)
        o, (h, c) = self.lstm_second(o)
        h = h.permute(1, 0, 2).reshape(B, -1) # 2 B H > B 2 H > B (2H)
        mu = self.latent_Wmu(h)
        sig = F.softplus(self.latent_Wsig(h), threshold=5) # (7)

        return mu, sig

class Decoder(nn.Module):
    def __init__(self, 
                 W_cond_init_state, lstm_conductor, 
                 W_dec_init_state, lstm_decoder,
                 num_conductor_out, num_decoder_out,
                 token_embedding, W_dist_out,
                 sampling_strategy = "argmax",
                 training_strategy = "scheduled"):
        super().__init__()
        self.W_cond_init_state = nn.Sequential(nn.Linear(**W_cond_init_state), nn.Tanh())
        self.lstm_conductor = nn.LSTM(**lstm_conductor)
        self.W_dec_init_state = nn.Sequential(nn.Linear(**W_dec_init_state), nn.Tanh())
        self.lstm_decoder = nn.LSTM(**lstm_decoder)
        self.embedding = nn.Embedding(**token_embedding)
        self.dist_out_layer = nn.Sequential(nn.Linear(**W_dist_out), nn.Softmax())

        self.conductor_hidden_size = lstm_conductor["hidden_size"]
        self.num_conductor_out = num_conductor_out
        self.num_decoder_out = num_decoder_out
        self.sampling_strategy = sampling_strategy
        self.training_strategy = training_strategy

    def forward(self, z, x=None, eps=0.):
        conductor_in = self.W_cond_init_state(z)
        conductor_out = self.get_conductor_out(conductor_in)
        states = self.W_dec_init_state(conductor_out)
        if x is not None: # teacher forcing
            if self.training_strategy == "scheduled":
                out = self.decoder_rnn_ss(states, x, eps) # distributions
            else: # teacher forcing
                out = self.decoder_rnn_tf(states, x) # distributions
        else: # feedback generation
            out = self.decoder_rnn_fg(states) # indices
        return out

    def get_conductor_out(self, conductor_in):
        B, D = conductor_in.shape
        h = conductor_in[:, :D//2]
        h1, h2 = h[:, :D//4], h[:, D//4:]
        h = torch.stack((h1, h2), 0)
        c = conductor_in[:, D//2:]
        c1, c2 = c[:, :D//4], c[:, D//4:]
        c = torch.stack((c1, c2), 0)
        outs = []
        start = torch.zeros((B, 1, D//4), 
                            dtype=conductor_in.dtype,
                            device=conductor_in.device)
        o = start
        for i in range(self.num_conductor_out):
            o, (h, c) = self.lstm_conductor(o, (h, c))
            outs.append(o)
        out = torch.cat(outs, 1)
        return out
    
    def decoder_rnn_tf(self, decoder_in, x):
        B, U, D = decoder_in.shape # U is the number of bars
        Bx, L = x.shape 
        assert D % 4 == 0
        assert L % U == 0
        decoder_in = decoder_in.view(-1, D) # B U D > (BU) D
        x = x.view(Bx*U, L//U) # B L > (BU) (L/U) # (B*4) (64/4)
        h = decoder_in[:, :D//2]
        h1, h2 = h[:, :D//4], h[:, D//4:]
        h = torch.stack((h1, h2), 0)
        c = decoder_in[:, D//2:]
        c1, c2 = c[:, :D//4], c[:, D//4:]
        c = torch.stack((c1, c2), 0)
        outs = []
        emb_init = torch.zeros([Bx*U, 1, D//4],
                               dtype=decoder_in.dtype,
                               device=decoder_in.device)
        emb = self.embedding(x)
        emb = torch.cat((emb_init, emb[:, :-1]), 1)
        h_init = h.clone().sum(0)[:, None] # 2 (BU) D > (BU) 1 D # no info about sum
        emb = torch.cat((emb, h_init.repeat(1, emb.shape[1], 1)), -1)
        o, (h, c) = self.lstm_decoder(emb, (h, c))
        o = self.dist_out_layer(o)

        o_unfolded = o.reshape(B, L, -1)
        
        return o_unfolded

    def decoder_rnn_fg(self, decoder_in):
        B, U, D = decoder_in.shape # U is the number of bars
        assert D % 4 == 0
        decoder_in = decoder_in.view(-1, D) # B U D > (BU) D
        h = decoder_in[:, :D//2]
        h1, h2 = h[:, :D//4], h[:, D//4:]
        h = torch.stack((h1, h2), 0)
        c = decoder_in[:, D//2:]
        c1, c2 = c[:, :D//4], c[:, D//4:]
        c = torch.stack((c1, c2), 0)
        outs = []
        emb = torch.zeros([B*U, 1, D//4], 
                          dtype=decoder_in.dtype,
                          device=decoder_in.device)
        h_init = h.clone().sum(0)[:, None] # 2 (BU) D > (BU) 1 D # no info about sum
        for i in range(self.num_decoder_out):
            o_in = torch.cat((emb, h_init), -1)
            o, (h, c) = self.lstm_decoder(o_in, (h, c))
            o = self.dist_out_layer(o)
            emb, indices = self.sample_from_dist(o)
            outs.append(indices)
        out = torch.cat(outs, 1)
        out_unfolded = out.view(B, U*self.num_decoder_out, -1)
        return out_unfolded

    def decoder_rnn_ss(self, decoder_in, x, eps):
        # 1) with torch.no_grad(), get x_hat indices
        # 2) random mixing
        # 3) decoder_rnn_tf with mixedsample
        with torch.no_grad():
            hat_dist = self.decoder_rnn_tf(decoder_in, x)
            _, hat_indices = self.sample_from_dist(hat_dist)
            mixed_x = self.scheduled_sampling(x, hat_indices, eps)

        out = self.decoder_rnn_tf(decoder_in, mixed_x)
        return out

    def scheduled_sampling(self, x, x_hat, eps):
        B, L = x.shape
        device = x.device
        candidate = torch.cat((x, x_hat), -1) # 0: True, 1: hat
        coins = torch.empty(B, L, device=device).fill_(eps) # eps==0: all-True, eps==1: all-hat
        result = torch.bernoulli(coins).to(int)
        
        sampled = torch.gather(candidate, -1, result)
        return sampled

    def sample_from_dist(self, dist):
        if self.sampling_strategy == "argmax":
            indices = torch.argmax(dist, dim=-1)
            emb = self.embedding(indices)
            return emb, indices
        elif self.sampling_strategy == "multinomial":
            B, L, D = dist.shape
            dist_folded = dist.view(B*L, D)
            indices = torch.multinomial(dist_folded, 1) # (BL) 1
            indices = indices.view(B, L)
            emb = self.embdding(indices)
            return emb, indices


class MusicVAE(nn.Module):
    def __init__(self, embedding, encoder, decoder):
        super().__init__()
        self.emb = nn.Embedding(**embedding)
        self.enc = Encoder(**encoder)
        self.dec = Decoder(**decoder)
    
    def forward(self, x=None, z=None, mode="train", eps=0.5):
        if mode == "train":
            embed_out = self.emb(x)
            mu, sig = self.enc(embed_out)
            z_reparam = self.z_reparam(mu, sig)
            x_hat = self.dec(z_reparam, x=x) # x_hat: dist
        elif mode == "generate":
            x_hat = self.dec(z) # x_hat: indices(sampled from dist)
        else: raise TypeError("Unknown mode")
        
        out = x_hat if mode == "generate" else (x_hat, mu, sig)
        return out

    def z_reparam(self, mu, sig):
        eps = torch.randn_like(mu, requires_grad=False)
        
        z_reparam = mu + sig*eps
        return z_reparam

# For Test/Debug

* If you do not want to test model, do "test_for_debug := False"

In [3]:
if test_for_debug := False:
    device = torch.device("cuda") if torch.cuda.is_available else torch.device("cpu")
    test_x = torch.randint(128, [10, 256], device=device)
    test_z = torch.randn([10, 512], device=device)
    %load_ext autoreload
    %autoreload 2
    from config import model_config
    model = MusicVAE(**model_config).to(device)
    out_tf, mu, sig = model(x = test_x, eps=0.5)
    out_fg = model(z = test_z, mode="generate")

# Dataloader, Model, Optimizer, LRscheduler

In [4]:
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
%load_ext autoreload
%autoreload 2
from config import model_config

def col_fn(batch_list):
    batch = torch.stack([torch.tensor(item) for item in batch_list])
    return batch

device = torch.device("cuda") if torch.cuda.is_available else torch.device("cpu")
model = MusicVAE(**model_config)
dl = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=col_fn) # b 512 but
opt = Adam(model.parameters(), lr=1e-3)
lr_sch = ExponentialLR(opt, gamma=0.9999) 
lr_sch.last_epoch = 46049 # 5.1. 3. 1 (0.9999**n==0.01)



# Loss function

In [5]:
from torch.distributions.normal import Normal
from torch.distributions.kl import kl_divergence

def recon_loss(x_hat, x):
    return F.nll_loss(torch.log(x_hat).permute(0,2,1), x)

def kl_loss(mu, sigma, beta, free_bits): # 5.2. 1. 2
    device = mu.device
    mu_p  = torch.tensor([0.], device=device)
    sig_p = torch.tensor([1.], device=device)
    p_dist = Normal(mu_p, sig_p)
    q_dist = Normal(mu, sigma)

    kl_div = kl_divergence(q_dist, p_dist)
    free_bits_tensor = torch.tensor([free_bits], device=device)
    zero = torch.tensor([0.], device=device)
    kl_loss = - beta * torch.max(torch.mean(kl_div)-free_bits, zero)

    return kl_loss

def loss_fn(x_hat, x, mu, sigma, beta, free_bits=48.):
    reconstruction_loss = recon_loss(x_hat, x)
    kl_div_loss = kl_loss(mu, sigma, beta, free_bits)

    return reconstruction_loss + kl_div_loss




In [6]:
import math
# beta for beta-VAE 
def get_beta(idx):
    return 2. - 2.0 * 0.99999 ** idx # 5.2. 1. 2

# epsilon for scheduled sampling coefficient
def get_epsilon(idx, k_rate=2000):
    # inverse sigmoid
    eps = k_rate / (k_rate+math.exp(idx/k_rate))
    return eps

In [7]:
if is_train := False:
    model = model.to(device)
    idx = 0
    for e in range(10):
        model.train()
        for x in dl:
            opt.zero_grad()
            x = x.to(device)
            eps = get_epsilon(idx)
            x_hat, mu, sigma = model(x, mode="train", eps=eps)

            beta = get_beta(idx)
            loss = loss_fn(x_hat, x, mu, sigma, beta)
            loss.backward()
            opt.step()
            lr_sch.step()
            idx += 1
            print(f"\riter: {idx:,} / loss = {loss.item():.3f}", end='')
    torch.save(model.state_dict(), "checkpoint.ckpt")

In [8]:
if is_generation := True:
    ckpt_path = "drive/MyDrive/Colab Notebooks/1850.ckpt"
    model.load_state_dict(torch.load(ckpt_path))
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        z = torch.randn([1, 512], device=device)
        generated = model(z=z, mode="generate")

  input = module(input)


In [9]:
generated = generated.detach().cpu().numpy()

In [10]:
print(generated[0, :, 0].tolist())

[384, 192, 384, 192, 192, 384, 192, 128, 80, 272, 192, 384, 80, 384, 192, 128, 328, 17, 336, 273, 465, 273, 65, 273, 273, 81, 273, 81, 17, 17, 273, 273, 325, 256, 5, 69, 161, 64, 161, 69, 225, 261, 352, 481, 0, 385, 0, 65, 6, 98, 2, 102, 100, 274, 98, 98, 98, 98, 98, 98, 98, 98, 2, 100]


In [11]:
!pip install pretty_midi
!pip install pyfluidsynth
from to_midi import get_midi_from_profile

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [12]:
midi = get_midi_from_profile(generated[0, :, 0].tolist())
midi.write("gen.mid")