# Gaussian Mixture VAE

Many of the utilities below are adapted from https://github.com/jariasf/GMVAE/blob/master/pytorch/networks/Layers.py.

## Neural Network Implementation

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def get_activation_function(id):
    """Build the activation module named by ``id``.

    Args:
        id: One of ``'relu'``, ``'sigmoid'``, ``'tanh'`` or ``'none'``
            (``'none'`` maps to ``nn.Identity``).

    Returns:
        A newly constructed ``nn.Module`` for the requested activation.

    Raises:
        ValueError: If ``id`` is not one of the supported names.
    """
    # Store constructors, not instances, so every call returns its own module.
    factories = {
        'relu': nn.ReLU,
        'sigmoid': nn.Sigmoid,
        'tanh': nn.Tanh,
        'none': nn.Identity,
    }

    try:
        factory = factories[id]
    except (KeyError, TypeError):
        # TypeError covers unhashable ids, so every bad input still surfaces
        # as the ValueError callers already expect.
        raise ValueError(
            "unknown activation function {!r}; expected one of {}".format(
                id, sorted(factories)
            )
        ) from None

    return factory()

def get_batch_norm(bool, size):
    """Return a 1-D batch-norm layer, or a pass-through when disabled.

    Args:
        bool: Truthy to build ``nn.BatchNorm1d(size)``; falsy to get
            ``nn.Identity()`` instead.
        size: Feature count handed to ``nn.BatchNorm1d``.

    Returns:
        The constructed ``nn.Module``.
    """
    return nn.BatchNorm1d(size) if bool else nn.Identity()

def get_rnn_cell(id):
    """Look up the recurrent layer class named by ``id``.

    Unlike ``get_activation_function`` this returns the *constructor*, not an
    instance, so the caller can pass its own dimensions.

    Args:
        id: ``'gru'``, ``'lstm'`` or ``'basic'`` (plain ``nn.RNN``).

    Returns:
        The matching class: ``nn.GRU``, ``nn.LSTM`` or ``nn.RNN``.

    Raises:
        ValueError: If ``id`` is not a supported cell name. Previously an
            unknown name fell through and returned ``None``, which only failed
            later when the caller tried to call it.
    """
    constructors = {
        'gru': nn.GRU,
        'lstm': nn.LSTM,
        'basic': nn.RNN,
    }

    try:
        return constructors[id]
    except (KeyError, TypeError):
        raise ValueError(
            "unknown rnn cell {!r}; expected one of {}".format(
                id, sorted(constructors)
            )
        ) from None


In [None]:
class ff(nn.Module):
    """Feed-forward stack: zero or more hidden blocks followed by an output layer.

    Each hidden block is ``Linear -> (BatchNorm1d | Identity) -> activation``;
    the output layer is ``Linear -> output activation``.
    """

    def __init__(self, argdict):
        """
        argdict: contains all arguments
          "input_dim": feature size of the input
          "output_dim": feature size of the output
          "output_activation_fn": activation name for the output layer
          "layer_params": list with one dict per hidden block, each holding
            "size", "batch_norm" and "activation_fn". May be empty or absent,
            in which case the input is piped straight to the output layer.
        """
        super(ff, self).__init__()

        self.argdict = argdict

        input_dim = self.argdict["input_dim"]
        output_dim = self.argdict["output_dim"]
        layer_params = self.argdict.get("layer_params", [])

        # nn.ModuleList (not a plain Python list) so the layers are registered
        # as submodules: their parameters then show up in .parameters(),
        # follow .cuda()/.to(), and are saved in state_dict().
        self.layers = nn.ModuleList()

        # Track the running feature size so the first hidden block, the
        # remaining hidden blocks and the output layer share one code path.
        prev_dim = input_dim
        for params in layer_params:
            self.layers.append(
                nn.Sequential(
                    nn.Linear(prev_dim, params["size"]),
                    # batch_norm is a per-layer setting, so it is read from
                    # the layer's own dict rather than from argdict.
                    get_batch_norm(params["batch_norm"], params["size"]),
                    get_activation_function(params["activation_fn"]),
                )
            )
            prev_dim = params["size"]

        # Output layer; with no hidden blocks this maps input_dim -> output_dim.
        self.layers.append(
            nn.Sequential(
                nn.Linear(prev_dim, output_dim),
                get_activation_function(self.argdict["output_activation_fn"]),
            )
        )

    def forward(self, x):
        """Apply every block in order and return the result."""
        for layer in self.layers:
            x = layer(x)

        return x

class rnn(nn.Module):
    """Single recurrent layer (GRU / LSTM / plain RNN) selected by name."""

    def __init__(self, argdict):
        """
        argdict:
          "input_dim": feature size of each input step
          "output_dim": hidden size of the recurrent layer
          "rnn_cell": cell name understood by get_rnn_cell
        """
        super(rnn, self).__init__()

        self.argdict = argdict

        # no embedding? assume already embedded

        input_dim = self.argdict["input_dim"]
        output_dim = self.argdict["output_dim"]

        rnn_cell_constructor = get_rnn_cell(self.argdict["rnn_cell"])

        # batch_first=True: the layer is built for batch-major input.
        self.rnn_layer = rnn_cell_constructor(input_dim, output_dim, batch_first=True)

    def forward(self, x):
        """Run the recurrent layer and return only its output sequence.

        nn.RNN / nn.GRU / nn.LSTM return an ``(output, state)`` tuple. This
        module is chained inside ``nn.Sequential`` next to layers that expect
        a tensor, so the state is dropped here.
        """
        output, _state = self.rnn_layer(x)
        return output

In [None]:
# sampling
def gumbel_sampler(x, temperature):
    """Softmax over ``x`` perturbed with Gumbel noise (a relaxed categorical sample).

    Args:
        x: Tensor of unnormalized scores; the softmax runs over the last axis.
        temperature: Divisor applied to the noisy scores before the softmax.

    Returns:
        Tensor shaped like ``x`` whose last axis sums to 1.
    """
    eps = 1e-10  # keeps both logarithms away from log(0)

    # Draw the uniform noise directly on x's device. The earlier `.cuda()` call
    # moved it to the *current* CUDA device, which is not necessarily x's.
    uniform = torch.rand(x.size(), device=x.device)

    # log(-log(u)) is the negative of a Gumbel(0, 1) draw, hence the
    # subtraction below.
    noise = torch.log(-torch.log(uniform + eps) + eps)
    return F.softmax((x - noise) / temperature, dim=-1)

def gaussian_sampler(m, v):
    """Reparameterized draw ``m + eps * sqrt(v + 1e-10)`` with ``eps ~ N(0, 1)``.

    Args:
        m: Mean tensor.
        v: Variance tensor; a 1e-10 offset is added before the square root.

    Returns:
        The sampled tensor.
    """
    sigma = torch.sqrt(v + 1e-10)
    return m + torch.randn_like(sigma) * sigma

# losses
def cross_entropy(logits, labels):
    """Thin wrapper around ``F.cross_entropy`` with its default settings.

    Args:
        logits: Unnormalized class scores.
        labels: Target class indices.
    """
    return F.cross_entropy(input=logits, target=labels)

def mse(pred, labels):
    """Squared error summed over the last axis, then averaged over the rest.

    Args:
        pred: Predicted values.
        labels: Target values, broadcastable against ``pred``.
    """
    diff = pred - labels
    per_item = (diff * diff).sum(-1)
    return per_item.mean()

def entropy(logits, labels):
    """Cross-entropy of ``labels`` against ``softmax(logits)``, batch-averaged.

    Computes ``-sum(labels * log_softmax(logits))`` over the last axis and
    averages the result over the remaining axes.
    """
    weighted = labels * F.log_softmax(logits, dim=-1)
    return -weighted.sum(dim=-1).mean()

def log_normal(z, m, v):
    """Log-density of a diagonal Gaussian, up to the additive constant.

    Args:
        z: Points at which to evaluate the density.
        m: Mean.
        v: Variance; a 1e-10 offset is added before dividing and taking the log.

    Returns:
        ``-0.5 * sum((z - m)^2 / v + log v)`` over the last axis. The
        ``log(2*pi)`` term is omitted.
    """
    # Use the stabilized variance in BOTH places: the earlier version divided
    # by the raw `v` and took the log of an undefined name `var`.
    v_stable = v + 1e-10
    return -0.5 * torch.sum(torch.pow(z - m, 2) / v_stable + torch.log(v_stable), dim=-1)

def gaussian_kl(sample, mu, var, mu_prior, var_prior):
    """Mean of ``log_normal(sample | mu, var) - log_normal(sample | mu_prior, var_prior)``.

    Args:
        sample: Points at which both log-densities are evaluated.
        mu, var: Parameters of the first Gaussian.
        mu_prior, var_prior: Parameters of the second Gaussian.

    Returns:
        The difference of the two log-densities averaged over all entries.
    """
    log_q = log_normal(sample, mu, var)
    log_p = log_normal(sample, mu_prior, var_prior)
    return (log_q - log_p).mean()

In [None]:
class softmax_with_gumbel(nn.Module):
    """Linear projection to category scores, returned as softmax output and as a Gumbel-softmax sample."""

    def __init__(self, argdict):
        """
        argdict:
          "input_dim": feature size of the input
          "output_dim": number of categories
        """
        super(softmax_with_gumbel, self).__init__()

        self.argdict = argdict

        self.input_dim = argdict["input_dim"]
        self.output_dim = argdict["output_dim"]
        # Use the attributes set just above; bare `input_dim` / `output_dim`
        # were never defined in this scope.
        self.layer = nn.Linear(self.input_dim, self.output_dim)
        self.activation = nn.Softmax(dim=-1)

    def forward(self, x, temperature=1.0):
        """Return ``(softmax(scores), gumbel_sample)`` for input ``x``.

        Args:
            x: Input tensor whose last axis has size ``input_dim``.
            temperature: Passed to ``gumbel_sampler``.
        """
        scores = self.layer(x)
        y = gumbel_sampler(scores, temperature)
        # NOTE(review): the first value is softmax(scores), i.e. normalized
        # probabilities, not raw logits. Callers that feed it to a loss
        # expecting logits should confirm this is intended.
        return self.activation(scores), y

class gaussian(nn.Module):
    """Maps features to a diagonal Gaussian (mean, variance) and draws a sample from it."""

    def __init__(self, argdict):
        """
        argdict:
          "input_dim": feature size of the input
          "output_dim": dimensionality of the Gaussian
        """
        super(gaussian, self).__init__()

        self.argdict = argdict
        self.input_dim = argdict["input_dim"]
        self.output_dim = argdict["output_dim"]

        self.mu_layer = nn.Linear(self.input_dim, self.output_dim)
        self.var_layer = nn.Sequential(
            nn.Linear(self.input_dim, self.output_dim),
            nn.Softplus(),  # keeps the predicted variance positive
        )

    def forward(self, x):
        """Return ``(mu, var, z)`` where ``z`` is sampled via ``gaussian_sampler``."""
        mu = self.mu_layer(x)
        # The variance head is stored as `var_layer`; `self.var` does not exist.
        var = self.var_layer(x)
        z = gaussian_sampler(mu, var)
        return mu, var, z

In [None]:
class encoder(nn.Module):
    """Inference network: q(y|x) for the category and q(z|y,x) for the latent code."""

    def __init__(self, argdict):
        """
        argdict: one sub-dict per component, keyed
          "q_y_rnn", "q_y_linear", "q_y_gumbel" for q(y|x) and
          "q_z_rnn", "q_z_linear", "q_z_gaussian" for q(z|y,x).
        The dimensions in consecutive sub-dicts must line up.
        """
        super(encoder, self).__init__()

        self.argdict = argdict

        # q(y|x): feature extractor ...
        self.q_y_network = torch.nn.Sequential(
            rnn(argdict["q_y_rnn"]),
            ff(argdict["q_y_linear"]),
        )

        # ... and the categorical head, kept outside the Sequential because
        # its forward takes the extra temperature argument.
        self.q_y = softmax_with_gumbel(argdict["q_y_gumbel"])

        # q(z|y,x)
        self.q_z = torch.nn.Sequential(
            rnn(argdict["q_z_rnn"]),
            ff(argdict["q_z_linear"]),
            gaussian(argdict["q_z_gaussian"]),
        )

    def forward_fixed_y(self, x, y_fixed):
        """Encode ``x`` with a caller-supplied category ``y_fixed`` (style transfer).

        Returns:
            dict with keys 'mu', 'var', 'z'.
        """
        # Concatenate on the feature (last) axis, exactly as forward() does.
        # dim=1 only coincides with the feature axis for 2-D input.
        # NOTE(review): y_fixed must match x on every axis but the last;
        # confirm callers expand it accordingly.
        mu, var, z = self.q_z(torch.cat((x, y_fixed), dim=-1))
        return_dict = {'mu': mu, 'var': var, 'z': z}
        return return_dict

    def forward(self, x, temperature=1.0):
        """Infer the category from ``x``, then the latent code from ``(x, y)``.

        Returns:
            dict with keys 'pi', 'y', 'mu', 'var', 'z'.
        """
        pi, y = self.q_y(self.q_y_network(x), temperature=temperature)

        mu, var, z = self.q_z(torch.cat((x, y), dim=-1))

        return_dict = {'pi': pi, 'y': y, 'mu': mu, 'var': var, 'z': z}
        return return_dict

class Decoder(nn.Module):
    """Generative network: p(z|y) parameters from the category and p(x|z) reconstruction."""

    def __init__(self, argdict):
        """
        argdict:
          "input_dim": size of y
          "output_dim": size of z
          "p_x_linear", "p_x_rnn": sub-dicts for the reconstruction network
        Dimensions must match those used when constructing the encoder args.
        """
        # super() must name this class; `decoder` was not defined here.
        super(Decoder, self).__init__()

        self.argdict = argdict
        input_dim = self.argdict["input_dim"]  # y_dim
        output_dim = self.argdict["output_dim"]  # z_dim

        # p(z|y): mean and (softplus-positive) variance heads
        self.p_z_mu_nn = nn.Linear(input_dim, output_dim)
        self.p_z_var_nn = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.Softplus(),
        )

        # p(x|z)
        self.p_x = torch.nn.Sequential(
            ff(argdict["p_x_linear"]),
            rnn(argdict["p_x_rnn"]),
        )

    def forward(self, z, y):
        """Compute the p(z|y) parameters from ``y`` and reconstruct ``x`` from ``z``.

        Returns:
            dict with keys 'mu' and 'var' (the p(z|y) parameters) and 'x'.
        """
        z_mu = self.p_z_mu_nn(y)
        z_var = self.p_z_var_nn(y)

        x = self.p_x(z)

        # Return the values computed above; `y_mu` / `y_var` never existed.
        return_dict = {'mu': z_mu, 'var': z_var, 'x': x}
        return return_dict


# Lowercase alias: the other modules in this file use lowercase class names
# and the model is constructed via `decoder(...)`.
decoder = Decoder

class GMVAE(nn.Module):
    """Gaussian-mixture VAE: an encoder and a decoder plus weight initialization."""

    def __init__(self, argdict):
        """
        argdict:
          "encoder": arguments forwarded to the encoder
          "decoder": arguments forwarded to the decoder
        """
        super(GMVAE, self).__init__()

        self.encoder = encoder(argdict["encoder"])
        # The decoder class is declared as `Decoder` in this file.
        self.decoder = Decoder(argdict["decoder"])

        # Weight initialization. self.modules() only visits registered
        # submodules, so layers held in plain Python containers are skipped.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_normal_(m.weight)
            elif isinstance(m, (nn.RNN, nn.GRU, nn.LSTM)):
                # Recurrent layers have no single `.weight`; their matrices
                # are named weight_ih_l*, weight_hh_l*. He initialization is
                # `kaiming_normal_` in torch.nn.init (`he_normal` does not
                # exist). Biases keep their defaults.
                for name, param in m.named_parameters():
                    if "weight" in name:
                        torch.nn.init.kaiming_normal_(param)

    def style_transfer(self, x, y_fixed):
        """Encode ``x`` under a fixed target category and decode with the same category.

        Args:
            x: Input batch.
            y_fixed: Target-style vector (described as binary by the author).

        Returns:
            dict with keys "encoder" and "decoder" holding each sub-network's outputs.
        """
        encoder_returns = self.encoder.forward_fixed_y(x, y_fixed)
        z = encoder_returns['z']
        decoder_returns = self.decoder(z, y_fixed)

        return_dict = {"encoder": encoder_returns, "decoder": decoder_returns}
        return return_dict

    def forward(self, x, temperature=1.0):
        """Standard pass: encode ``x``, then decode from the sampled ``z`` and ``y``.

        Returns:
            dict with keys "encoder" and "decoder" holding each sub-network's outputs.
        """
        encoder_returns = self.encoder(x, temperature=temperature)
        z, y = encoder_returns['z'], encoder_returns['y']
        decoder_returns = self.decoder(z, y)

        return_dict = {"encoder": encoder_returns, "decoder": decoder_returns}
        return return_dict

In [None]:
class Model():
    """Training wrapper: stores the hyperparameters, owns the GMVAE and computes the loss.

    Only construction and the loss computation are implemented. The
    train/test/run/transfer entry points are placeholders that raise
    ``NotImplementedError``.
    """

    def __init__(self, argdict):
        """Read hyperparameters from ``argdict`` and build the GMVAE.

        Besides the keys read below, ``argdict`` is passed whole to ``GMVAE``,
        which reads its "encoder" and "decoder" entries.
        """
        # learning rate and its decay schedule
        self.learning_rate = argdict["learning_rate"]
        self.decay_epoch = argdict["decay_epoch"]
        self.lr_decay = argdict["lr_decay"]

        # weights of the latent-variable loss terms
        self.weight_style = argdict["weight_style"]
        self.weight_entropy = argdict["weight_entropy"]
        self.weight_sampling = argdict["weight_sampling"]

        # weights of the reconstruction loss terms
        self.weight_pitch = argdict["weight_pitch"]
        self.weight_instrument = argdict["weight_instrument"]
        self.weight_velocity = argdict["weight_velocity"]

        # widths used to split the reconstruction back into its three parts;
        # they must match the input feature layout
        self.pitch_size = argdict["pitch_size"]
        self.instrument_size = argdict["instrument_size"]
        self.velocity_size = argdict["velocity_size"]

        # Gumbel-softmax temperature schedule
        self.init_temp = argdict["init_temp"]
        self.decay_temp = argdict["decay_temp"]
        self.min_temp = argdict["min_temp"]
        self.decay_temp_rate = argdict["decay_temp_rate"]

        # current temperature, starts at the initial value
        self.gumbel_temp = self.init_temp

        self.model = GMVAE(argdict)
        if argdict["cuda"]:
            self.model = self.model.cuda()

    def _elbo(self, pitch, instrument, velocity, style_label):
        """Run the model on one batch and compute the weighted loss terms.

        Args:
            pitch, instrument, velocity: Feature tensors, concatenated on the
                last axis to form the model input. ``pitch`` and ``instrument``
                are reduced with argmax over the last axis to get class targets.
            style_label: Reduced with argmax over the last axis to get the
                style target (NOTE(review): presumably one-hot; confirm).

        Returns:
            dict with keys 'kl', 'entropy', 'ce_style', 'ce_pitch',
            'ce_instrument', 'mse_velocity' and 'total'.
        """
        # torch.cat takes the tensors as ONE sequence argument.
        x = torch.cat((pitch, instrument, velocity), dim=-1)
        return_dict = self.model(x, temperature=self.gumbel_temp)

        x_pred = return_dict["decoder"]["x"]
        pitch_pred, instrument_pred, velocity_pred = torch.split(
            x_pred,
            [self.pitch_size, self.instrument_size, self.velocity_size],
            dim=-1,
        )

        # Rescaling: (p + 1) / 2 maps [-1, 1] onto [0, 1].
        # NOTE(review): this presumably assumes a tanh-range decoder output;
        # confirm.
        pitch_pred = (pitch_pred + 1) / 2
        instrument_pred = (instrument_pred + 1) / 2
        # NOTE(review): the extra +0.5 was already questioned by the author
        # ("is this correct?"); left unchanged pending confirmation.
        velocity_pred = (velocity_pred + 1) / 2 + 0.5

        loss_pitch = cross_entropy(pitch_pred, torch.argmax(pitch, dim=-1))
        loss_instrument = cross_entropy(instrument_pred, torch.argmax(instrument, dim=-1))
        loss_velocity = mse(velocity, velocity_pred)

        # NOTE(review): "pi" is used as logits by both losses below; verify
        # that the encoder returns unnormalized scores under this key.
        style_logits = return_dict["encoder"]["pi"]
        y_pred = return_dict["encoder"]["y"]

        loss_style = cross_entropy(style_logits, torch.argmax(style_label, dim=-1))
        loss_entropy = entropy(style_logits, y_pred)

        z_pred = return_dict["encoder"]["z"]
        # "new" = encoder-side Gaussian, "old" = decoder-side Gaussian
        new_mu, new_var = return_dict["encoder"]["mu"], return_dict["encoder"]["var"]
        old_mu, old_var = return_dict["decoder"]["mu"], return_dict["decoder"]["var"]

        loss_kl = gaussian_kl(z_pred, new_mu, new_var, old_mu, old_var)

        loss_total = (
            loss_pitch * self.weight_pitch
            + loss_instrument * self.weight_instrument
            + loss_velocity * self.weight_velocity
            + loss_style * self.weight_style
            + loss_entropy * self.weight_entropy
            + loss_kl * self.weight_sampling
        )

        stats_dict = {'kl': loss_kl, 'entropy': loss_entropy, 'ce_style': loss_style,
                      'ce_pitch': loss_pitch, 'ce_instrument': loss_instrument,
                      'mse_velocity': loss_velocity, 'total': loss_total}
        # Return the result; previously the dict was built and then discarded.
        return stats_dict

    def _step(self):
        """Placeholder: get _elbo, optimize."""
        raise NotImplementedError("Model._step is not implemented yet")

    def train(self, data_loader):
        """Placeholder: iterate on the data loader using _step, decaying lr and temperature."""
        raise NotImplementedError("Model.train is not implemented yet")

    def test(self, data_loader):
        """Placeholder: iterate over the data loader for evaluation."""
        raise NotImplementedError("Model.test is not implemented yet")

    def run(self, train_loader, test_loader):
        """Placeholder: train, test, then plot outputs and save the model."""
        raise NotImplementedError("Model.run is not implemented yet")

    def transfer(self, data_loader):
        """Placeholder: use the style transfer function in GMVAE."""
        raise NotImplementedError("Model.transfer is not implemented yet")


In [None]:
def load_data():
    """Placeholder for the data-loading hook; not implemented, so it returns ``None``."""
    return None

In [None]:
# shared constants

# TODO(review): these four sizes were left unspecified ("???") in the original
# cell, which does not parse. The values below are PLACEHOLDERS so the
# notebook can be executed; replace them with the real dataset dimensions.
PITCH_DIM = 128
INSTRUMENT_DIM = 128
VELOCITY_DIM = 1

NUM_STYLES = 2
LATENT_DIM = 128 # ??

RNN_CELL = "gru"
RNN_CELL_NUMBER = LATENT_DIM # this is a lot

# All hyperparameters live in the dicts below.
# The feed-forward sub-dicts use the key names the `ff` module reads:
# "output_activation_fn" (not "output_activation_function") and "layer_params"
# (an empty list means no hidden layers, just the output projection).
encoder_dict = {
    "q_y_rnn": {"input_dim": PITCH_DIM + INSTRUMENT_DIM + VELOCITY_DIM,
                "output_dim": 3 * LATENT_DIM,
                "rnn_cell": RNN_CELL},
    "q_y_linear": {"input_dim": 3 * LATENT_DIM,
                   "output_dim": LATENT_DIM,
                   "layer_params": [],
                   "output_activation_fn": "relu"},
    "q_y_gumbel": {"input_dim": LATENT_DIM,
                   "output_dim": NUM_STYLES},

    "q_z_rnn": {"input_dim": PITCH_DIM + INSTRUMENT_DIM + VELOCITY_DIM + NUM_STYLES,
                "output_dim": 3 * LATENT_DIM,
                "rnn_cell": RNN_CELL},
    "q_z_linear": {"input_dim": 3 * LATENT_DIM,
                   "output_dim": LATENT_DIM,
                   "layer_params": [],
                   "output_activation_fn": "relu"},
    "q_z_gaussian": {"input_dim": LATENT_DIM,
                     "output_dim": LATENT_DIM},
}

decoder_dict = {
    "input_dim": NUM_STYLES,
    "output_dim": LATENT_DIM,

    "p_x_linear": {"input_dim": LATENT_DIM,
                   "output_dim": 3 * LATENT_DIM,
                   "layer_params": [],
                   "output_activation_fn": "tanh"}, # tanh better for rnn?
    "p_x_rnn": {"input_dim": 3 * LATENT_DIM,
                "output_dim": PITCH_DIM + INSTRUMENT_DIM + VELOCITY_DIM,
                "rnn_cell": RNN_CELL},
}

argdict = {
    "learning_rate": 1e-4,
    "decay_epoch": 100,
    "lr_decay": 1e-1,

    # tune these
    "weight_style": 1,
    "weight_entropy": 0.5,
    "weight_sampling": 1,

    "weight_pitch": 1,
    "weight_velocity": 1,
    "weight_instrument": 1,

    # must match the placeholder sizes above
    "pitch_size": PITCH_DIM,
    "instrument_size": INSTRUMENT_DIM,
    "velocity_size": VELOCITY_DIM,

    "init_temp": 1e-1,
    "decay_temp": 1e-1,
    "min_temp": 1e-5,
    "decay_temp_rate": 5, # every N epochs

    "cuda": True, # use GPU

    "encoder": encoder_dict,
    "decoder": decoder_dict,
}

In [None]:
# Algorithm pipeline

data_loader = load_data() # x and y's

model = Model(argdict)
# Model.run takes (train_loader, test_loader); calling it with no arguments
# raises TypeError.
# TODO(review): load_data() yields a single object and no train/test split
# exists yet, so the same loader is passed for both roles. Replace this with
# separate loaders.
model.run(data_loader, data_loader)