## Import Modules

In [None]:
import torch
import torch.nn as nn
import pandas as pd
import time

## Hyperparameters

In [None]:
# Default hyperparameters. The training cells further down reassign several of
# these (as plain globals) before each experiment.
LATENT_DIM = 16            # latent size passed to the model constructors
BATCH_SIZE = 2048          # batch size of every DataLoader built in this notebook
EPOCHS = 1000              # number of training epochs; also sizes the OneCycleLR schedule
LEARNING_RATE = 1e-4       # AdamW lr and OneCycleLR max_lr
WEIGHT_DECAY = 1e-5        # AdamW weight decay
GRADIENT_CLIPPING = False  # when True, train_model registers a clamping hook on every parameter
CLIP_VALUE = 100           # gradients are clamped to [-CLIP_VALUE, CLIP_VALUE]
STD_DEV_MEASURE_MOD = 50   # train_model calls plt.clf()/plt.savefig() on epochs divisible by this value
STD_DEV_MEASURE_SCALE = 1  # not read by any visible cell; presumably consumed elsewhere — TODO confirm
LOSS_REDUCTION = "mean"    # reduction passed to the nn loss modules: "sum" or "mean"

TRAIN_RATIO = 0.8          # fraction of the dataset used for training; the remainder is validation

## Definitions

In [None]:
# Load the DX7 CSV, normalize and one-hot encode it, and wrap the result in a
# TensorDataset used by every training run below.
from data_processing import PatchProcessor
import data_processing
params = data_processing.get_params()
processor = PatchProcessor(params)

# NOTE(review): absolute, machine-specific path — the notebook only runs where
# this exact file exists. Consider a configurable data directory.
df = pd.read_csv("E:\\Coding\\vae-main\\old\\dx7_cleaned.csv")
# Drop the first column (presumably a saved index column — TODO confirm).
df = df.drop(columns=df.columns[0])

df = processor.normalize(df)
df_encoded, expanded_types = processor.one_hot_dataframe(df)

# Masks built from the expanded column types; bundled into `masks`, which
# train_model unpacks and forwards to vae_total_loss.
mse_mask, be_mask, ce_mask, alg_mask = processor.make_masks(expanded_types)
masks = (mse_mask, be_mask, ce_mask, alg_mask)
# alibi_distances is passed to the GraphTransformer model constructors below;
# `algorithms` is not used by any visible cell.
algorithms, alibi_distances = data_processing.get_algorithms()

# Convert the encoded frame to a float32 tensor. TensorDataset yields 1-tuples,
# which is why the training loops read `batch[0]`.
encoded_data_array = df_encoded.to_numpy()
x_data = torch.tensor(encoded_data_array, dtype=torch.float32)
dataset = torch.utils.data.TensorDataset(x_data)

In [None]:
import os
from helpers import *

# Loss modules that all share the configured reduction mode. None of them is
# referenced by the other visible cells (training goes through vae_total_loss).
# NOTE: LOSS_REDUCTION is read once, here; reassigning it in a later cell does
# not change these already-constructed objects.
mse_loss = nn.MSELoss(reduction=LOSS_REDUCTION)
bce_loss = nn.BCELoss(reduction=LOSS_REDUCTION)
ce_loss = nn.CrossEntropyLoss(reduction=LOSS_REDUCTION)
algo_loss = nn.CrossEntropyLoss(reduction=LOSS_REDUCTION)

def latent_display(model, name):
    """Visualize the latent space of ``model`` over the full dataset.

    Builds a shuffled DataLoader over the notebook-level ``dataset`` and hands
    it to ``visualize_latent_space``.

    Args:
        model: Model forwarded to ``visualize_latent_space``.
        name: Run name forwarded as the ``NAME`` argument.
    """
    # Pick the device the same way train_loop does instead of hard-coding
    # "cuda", so the helper also works on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    full_dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    visualize_latent_space(model, dataloader=full_dataloader, device=device, NAME=name)

def model_size(model):
    """Return the number of trainable parameters in ``model``.

    Args:
        model: A ``torch.nn.Module``.

    Returns:
        int: Total element count over all parameters with
        ``requires_grad=True`` (0 when there are none).
    """
    # Tensor.numel() yields the element count directly, so this no longer
    # depends on ``np`` happening to be in scope through a star import.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def warmup(epoch, max_epoch, rate, max_beta):
    """Linear beta warm-up with two caps.

    Returns the smallest of ``rate * epoch``, ``rate * max_epoch`` and
    ``max_beta``.
    """
    ramp = rate * epoch
    ramp_cap = rate * max_epoch
    return min(ramp, ramp_cap, max_beta)

def cyclic_beta_schedule(epoch, max_beta, period=100):
    """Cyclical beta (KL weight) annealing schedule.

    Within each cycle of ``period`` epochs, beta grows linearly from 0 to
    ``max_beta`` over the first ``period // 2`` epochs and is then held at
    ``max_beta`` for the rest of the cycle.

    See: https://www.microsoft.com/en-us/research/blog/less-pain-more-gain-a-simple-method-for-vae-training-with-less-of-that-kl-vanishing-agony/

    Args:
        epoch: Zero-based epoch index.
        max_beta: Peak beta value.
        period: Cycle length in epochs. A period below 2 leaves no room for a
            ramp, so the schedule degenerates to a constant ``max_beta``
            (previously this raised ``ZeroDivisionError``).

    Returns:
        The beta value to use for ``epoch``.
    """
    half_period = period // 2
    if half_period < 1:
        # period of 0 or 1 (e.g. EPOCHS // BETA_CYCLES with few epochs):
        # nothing to ramp over, hold at the maximum.
        return max_beta

    position = epoch % period
    # hold at max for last half of period
    if position > half_period:
        return max_beta
    # grow to max for first half of period
    return position / half_period * max_beta
    
def train_model(vae, masks, train_dataloader, val_dataloader, device, debug=False):
    """Train ``vae`` for ``EPOCHS`` epochs and record per-epoch average losses.

    Reads the notebook-level settings ``EPOCHS``, ``LEARNING_RATE``,
    ``WEIGHT_DECAY``, ``GRADIENT_CLIPPING``, ``CLIP_VALUE``,
    ``STD_DEV_MEASURE_MOD``, ``BETA`` and ``BETA_CYCLES`` at call time.

    Args:
        vae: Model whose forward pass returns
            ``(recon_x, mu, logvar, sparse_loss)``.
        masks: Tuple ``(mse_mask, be_mask, ce_mask, alg_mask)`` forwarded to
            ``vae_total_loss``.
        train_dataloader: Loader whose batches carry the input tensor at
            index 0.
        val_dataloader: Validation loader with the same batch layout.
        device: Device each batch is moved to.
        debug: When True, print a loss summary after every epoch.

    Returns:
        A 13-tuple ``(vae, train_losses, train_mse_recon_losses,
        train_ce_recon_losses, train_kl_losses, train_be_recon_losses,
        train_algo_recon_losses, val_losses, val_mse_recon_losses,
        val_ce_recon_losses, val_be_recon_losses, val_kl_losses,
        val_algo_recon_losses)``; every list holds one value per epoch.
    """
    mse_mask, be_mask, ce_mask, alg_mask = masks

    def _epoch_average(total, num_batches):
        # Mean of the per-batch losses; NaN (instead of ZeroDivisionError)
        # when the loader produced no batches, e.g. an empty validation split.
        return total / num_batches if num_batches else float("nan")

    # register hook for gradient clipping
    if GRADIENT_CLIPPING:
        for param in vae.parameters():
            param.register_hook(lambda grad: torch.clamp(grad, -CLIP_VALUE, CLIP_VALUE))

    # AdamW is used instead of Adam (original author's note: "Adam is broken").
    optimizer = torch.optim.AdamW(vae.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # pct_start=0.0 requests no learning-rate increase phase; the scheduler is
    # stepped once per batch, hence steps_per_epoch=len(train_dataloader).
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE,
        pct_start=0.0,
        steps_per_epoch=len(train_dataloader),
        epochs=EPOCHS
    )

    train_losses, train_mse_recon_losses, train_ce_recon_losses, train_kl_losses, train_be_recon_losses, train_algo_recon_losses = [], [], [], [], [], []
    val_losses, val_mse_recon_losses, val_ce_recon_losses, val_kl_losses, val_be_recon_losses, val_algo_recon_losses = [], [], [], [], [], []

    num_train_batches = len(train_dataloader)
    num_val_batches = len(val_dataloader)

    for epoch in range(EPOCHS):
        start_time = time.time()

        # Running sums of the per-batch losses for this epoch.
        train_epoch_total_loss = 0
        train_epoch_mse_recon_loss = 0
        train_epoch_ce_recon_loss = 0
        train_epoch_be_recon_loss = 0
        train_epoch_kl_loss = 0
        train_epoch_algo_loss = 0
        train_total_mse_loss = 0

        val_epoch_total_loss = 0
        val_epoch_mse_recon_loss = 0
        val_epoch_ce_recon_loss = 0
        val_epoch_be_recon_loss = 0
        val_epoch_kl_loss = 0
        val_epoch_algo_loss = 0
        val_total_mse_loss = 0

        # Start from a clean pyplot figure on the epochs whose figure is saved below.
        if epoch % STD_DEV_MEASURE_MOD == 0:
            plt.clf()

        beta = cyclic_beta_schedule(epoch, BETA, period=(EPOCHS // BETA_CYCLES))

        vae.train()
        for batch in train_dataloader:
            x_batch = batch[0].to(device)

            # sparse_loss is returned by the model but intentionally not added
            # to the objective.
            recon_x, mu, logvar, sparse_loss = vae(x_batch)

            train_total_loss, train_mse_recon_loss, train_ce_recon_loss, train_be_recon_loss, train_kl_loss, total_mse, alg_loss = vae_total_loss(
                recon_x,
                x_batch,
                be_mask,
                ce_mask,
                mse_mask,
                alg_mask,
                mu,
                logvar,
                beta
            )

            optimizer.zero_grad()
            train_total_loss.backward()
            optimizer.step()
            scheduler.step()

            train_epoch_total_loss += train_total_loss.item()
            train_epoch_mse_recon_loss += train_mse_recon_loss.item()
            train_epoch_ce_recon_loss += train_ce_recon_loss.item()
            train_epoch_be_recon_loss += train_be_recon_loss.item()
            train_epoch_kl_loss += train_kl_loss.item()
            train_epoch_algo_loss += alg_loss.item()
            train_total_mse_loss += total_mse.item()

        vae.eval() # <- mode for just evaluating
        with torch.no_grad(): # <- don't track gradients
            for batch in val_dataloader:
                x_batch = batch[0].to(device)
                recon_x, mu, logvar, sparse_loss = vae(x_batch)
                val_total_loss, val_mse_recon_loss, val_ce_recon_loss, val_be_recon_loss, val_kl_loss, val_mse_total, alg_loss = vae_total_loss(
                    recon_x,
                    x_batch,
                    be_mask,
                    ce_mask,
                    mse_mask,
                    alg_mask,
                    mu,
                    logvar,
                    beta,
                )

                val_epoch_total_loss += val_total_loss.item()
                val_epoch_mse_recon_loss += val_mse_recon_loss.item()
                val_epoch_ce_recon_loss += val_ce_recon_loss.item()
                val_epoch_be_recon_loss += val_be_recon_loss.item()
                val_epoch_kl_loss += val_kl_loss.item()
                val_total_mse_loss += val_mse_total.item()
                val_epoch_algo_loss += alg_loss.item()

        if epoch % STD_DEV_MEASURE_MOD == 0:
            # savefig does not create missing directories.
            os.makedirs("graph", exist_ok=True)
            plt.savefig(f"graph/outputs_{epoch//STD_DEV_MEASURE_MOD:08d}.png")

        train_avg_total_loss = _epoch_average(train_epoch_total_loss, num_train_batches)
        train_avg_mse_recon_loss = _epoch_average(train_epoch_mse_recon_loss, num_train_batches)
        train_avg_ce_recon_loss = _epoch_average(train_epoch_ce_recon_loss, num_train_batches)
        train_avg_be_recon_loss = _epoch_average(train_epoch_be_recon_loss, num_train_batches)
        train_avg_kl_loss = _epoch_average(train_epoch_kl_loss, num_train_batches)
        train_avg_algo_loss = _epoch_average(train_epoch_algo_loss, num_train_batches)
        train_avg_epoch_mse_loss = _epoch_average(train_total_mse_loss, num_train_batches)

        train_losses.append(train_avg_total_loss)
        train_mse_recon_losses.append(train_avg_mse_recon_loss)
        train_ce_recon_losses.append(train_avg_ce_recon_loss)
        train_be_recon_losses.append(train_avg_be_recon_loss)
        train_algo_recon_losses.append(train_avg_algo_loss)
        train_kl_losses.append(train_avg_kl_loss)

        val_avg_total_loss = _epoch_average(val_epoch_total_loss, num_val_batches)
        val_avg_mse_recon_loss = _epoch_average(val_epoch_mse_recon_loss, num_val_batches)
        val_avg_ce_recon_loss = _epoch_average(val_epoch_ce_recon_loss, num_val_batches)
        val_avg_be_recon_loss = _epoch_average(val_epoch_be_recon_loss, num_val_batches)
        val_avg_kl_loss = _epoch_average(val_epoch_kl_loss, num_val_batches)
        val_avg_algo_loss = _epoch_average(val_epoch_algo_loss, num_val_batches)
        val_avg_epoch_mse_loss = _epoch_average(val_total_mse_loss, num_val_batches)

        val_losses.append(val_avg_total_loss)
        val_mse_recon_losses.append(val_avg_mse_recon_loss)
        val_ce_recon_losses.append(val_avg_ce_recon_loss)
        val_be_recon_losses.append(val_avg_be_recon_loss)
        val_algo_recon_losses.append(val_avg_algo_loss)
        val_kl_losses.append(val_avg_kl_loss)

        end_time = time.time()
        runtime = end_time - start_time

        if debug:
            print("-" * 50)
            print(f"Epoch {epoch+1}/{EPOCHS} done in {runtime:.4f} seconds")
            print(f"Beta: {beta:.4f}")
            print(f"Learning Rate: {scheduler.get_last_lr()[0]:.8f}")

            print(f"Training Loss: {train_avg_total_loss:.4f}")
            print(f"\tTraining MSE Reconstruction Loss: {train_avg_mse_recon_loss:.4f}")
            print(f"\tTraining CE Reconstruction Loss: {train_avg_ce_recon_loss:.4f}")
            print(f"\tTraining BE Reconstruction Loss: {train_avg_be_recon_loss:.4f}")
            print(f"\tTraining Algorithm Loss: {train_avg_algo_loss:.4f}")
            print(f"\tTraining KL Loss: {train_avg_kl_loss:.4f}")
            print(f"\tTraining Total MSE Reconstruction Loss: {train_avg_epoch_mse_loss:.4f}")

            print(f"Validation Loss: {val_avg_total_loss:.4f}")
            print(f"\tValidation MSE Reconstruction Loss: {val_avg_mse_recon_loss:.4f}")
            print(f"\tValidation CE Reconstruction Loss: {val_avg_ce_recon_loss:.4f}")
            print(f"\tValidation BE Reconstruction Loss: {val_avg_be_recon_loss:.4f}")
            print(f"\tValidation Algorithm Loss: {val_avg_algo_loss:.4f}")
            print(f"\tValidation KL Loss: {val_avg_kl_loss:.4f}")
            print(f"\tValidation Total MSE Reconstruction Loss: {val_avg_epoch_mse_loss:.4f}")

            print("-" * 50)

    return vae, train_losses, train_mse_recon_losses, train_ce_recon_losses, train_kl_losses, train_be_recon_losses, train_algo_recon_losses, val_losses, val_mse_recon_losses, val_ce_recon_losses, val_be_recon_losses, val_kl_losses, val_algo_recon_losses

def train_loop(model):
    """Split the dataset, train ``model``, plot the results and save the model.

    Reads the notebook-level globals ``dataset``, ``masks``, ``NAME``,
    ``TRAIN_RATIO`` and ``BATCH_SIZE`` at call time. Saves the state dict and
    the whole model under ``./models/<NAME>/``.

    Args:
        model: The ``torch.nn.Module`` to train; it is moved to the GPU when
            one is available, otherwise to the CPU.
    """
    # determine device
    has_cuda = torch.cuda.is_available()
    device_name = "cuda" if has_cuda else "cpu"
    device = torch.device(device_name)
    if has_cuda:
        print("Using GPU")
    else:
        print("Using CPU")

    train_data_count = round(len(dataset) * TRAIN_RATIO)
    val_data_count = len(dataset) - train_data_count

    print("Total samples:", len(dataset))
    print("Training samples:", train_data_count)
    print("Validation samples:", val_data_count)

    # NOTE: the split is unseeded, so it differs from run to run.
    train_data, val_data = torch.utils.data.random_split(dataset, [train_data_count, val_data_count])
    train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=True)

    model.to(device=device)

    print(f"{model_size(model)} Parameters")

    # train vae
    print("Training VAE...")
    vae, train_losses, train_mse_recon_losses, train_ce_recon_losses, train_kl_losses, train_be_recon_losses, train_algo_recon_losses, val_losses, val_mse_recon_losses, val_ce_recon_losses, val_be_recon_losses, val_kl_losses, val_algo_recon_loss = train_model(model, masks, train_dataloader, val_dataloader, device, debug=True)
    print("VAE complete.")

    # visualize results
    print("Plotting loss and UMAP of latent space...")
    os.makedirs(NAME, exist_ok=True)
    # NOTE(review): the validation group passes the algorithm losses before the
    # KL losses, while the training group passes KL first — confirm this
    # matches plot_loss's parameter order.
    plot_loss(train_losses, train_mse_recon_losses, train_ce_recon_losses, train_kl_losses, train_be_recon_losses, train_algo_recon_losses, val_losses, val_mse_recon_losses, val_ce_recon_losses, val_be_recon_losses, val_algo_recon_loss, val_kl_losses, NAME)

    # Latent-space plot over the full dataset. The previous code additionally
    # called the helper's return value with (model, NAME), which would raise
    # before the model was saved, and hard-coded the "cuda" device.
    full_dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    visualize_latent_space(model, dataloader=full_dataloader, device=device_name, NAME=NAME)

    # save the vae parameters; torch.save does not create missing directories
    model_dir = "./models/" + NAME
    os.makedirs(model_dir, exist_ok=True)
    torch.save(vae.state_dict(), model_dir + "/model.pth")
    torch.save(vae, model_dir + "/model.model")

## Train Models

In [None]:
from GraphTransformerAlibiGlobalToken import GraphTransformerAutoencoderAlibi

# Experiment: GraphTransformerAutoencoderAlibi with BETA = 0.001.
# The settings below are read as globals by train_loop / train_model.
NAME = "GraphTransformer-Alibi-Full-Beta-0-001"

# model / data
LATENT_DIM = 16
BATCH_SIZE = 2048

# optimisation
EPOCHS = 1000
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
GRADIENT_CLIPPING = False
CLIP_VALUE = 100

# KL weighting: peak beta and number of annealing cycles
BETA = 0.001
BETA_CYCLES = 10

# diagnostics / loss configuration
STD_DEV_MEASURE_MOD = 50
STD_DEV_MEASURE_SCALE = 1
LOSS_REDUCTION = "mean"

model = GraphTransformerAutoencoderAlibi(
    input_size=21,
    latent_space=LATENT_DIM,
    d_model=128,
    depth=6,
    heads=4,
    mlp_dim=256,
    num_algorithms=32,
    num_global_params=67,
    sparsity_weight=1,
    reparameterization=True,
    algorithm_distance_matricies=alibi_distances,
)

train_loop(model)

In [None]:
from GraphTransformerAlibiGlobalToken import GraphTransformerAutoencoderAlibi

# Experiment: GraphTransformerAutoencoderAlibi with BETA = 0.01.
BETA = 0.01
NAME = "GraphTransformer-Alibi-Full-Beta-0-01"
LATENT_DIM = 16
BATCH_SIZE = 2048

model = GraphTransformerAutoencoderAlibi(
    input_size=21,
    latent_space=LATENT_DIM,
    d_model=128,
    depth=6,
    heads=4,
    mlp_dim=256,
    num_algorithms=32,
    num_global_params=67,
    sparsity_weight=1,
    reparameterization=True,
    algorithm_distance_matricies=alibi_distances,
)
train_loop(model)

In [None]:
from GraphTransformerAlibiGlobalToken import GraphTransformerAutoencoderAlibi

# Experiment: GraphTransformerAutoencoderAlibi with BETA = 0.1.
BETA = 0.1
NAME = "GraphTransformer-Alibi-Full-Beta-0-1"
LATENT_DIM = 16
BATCH_SIZE = 2048

model = GraphTransformerAutoencoderAlibi(
    input_size=21,
    latent_space=LATENT_DIM,
    d_model=128,
    depth=6,
    heads=4,
    mlp_dim=256,
    num_algorithms=32,
    num_global_params=67,
    sparsity_weight=1,
    reparameterization=True,
    algorithm_distance_matricies=alibi_distances,
)
train_loop(model)

In [None]:
from GraphTransformerAlibiGlobalToken import GraphTransformerAutoencoderAlibi

# Experiment: GraphTransformerAutoencoderAlibi with BETA = 0.25.
BETA = 0.25
NAME = "GraphTransformer-Alibi-Full-Beta-0-25"
LATENT_DIM = 16
BATCH_SIZE = 2048

model = GraphTransformerAutoencoderAlibi(
    input_size=21,
    latent_space=LATENT_DIM,
    d_model=128,
    depth=6,
    heads=4,
    mlp_dim=256,
    num_algorithms=32,
    num_global_params=67,
    sparsity_weight=1,
    reparameterization=True,
    algorithm_distance_matricies=alibi_distances,
)
train_loop(model)

In [None]:
# Experiment: GraphTransformerAutoencoderAlibi with BETA = 0.3725.
# A stray leading train_loop(model) call was removed: it ran before NAME/BETA
# were updated, retraining the previous cell's model and overwriting the files
# saved under the previous NAME.
NAME = "GraphTransformer-Alibi-Full-Beta-0-3725"
BETA = 0.3725
BATCH_SIZE = 2048
LATENT_DIM = 16
from GraphTransformerAlibiGlobalToken import GraphTransformerAutoencoderAlibi

model = GraphTransformerAutoencoderAlibi(input_size=21, latent_space=LATENT_DIM, d_model=128, depth=6, heads=4,
                                         mlp_dim=256, num_algorithms=32, num_global_params=67, sparsity_weight=1,
                                         reparameterization=True, algorithm_distance_matricies=alibi_distances)
train_loop(model)

In [None]:
from GraphTransformerAlibiGlobalToken import GraphTransformerAutoencoderAlibi

# Experiment: GraphTransformerAutoencoderAlibi with BETA = 0.5.
BETA = 0.5
NAME = "GraphTransformer-Alibi-Full-Beta-0-5"
LATENT_DIM = 16
BATCH_SIZE = 2048

model = GraphTransformerAutoencoderAlibi(
    input_size=21,
    latent_space=LATENT_DIM,
    d_model=128,
    depth=6,
    heads=4,
    mlp_dim=256,
    num_algorithms=32,
    num_global_params=67,
    sparsity_weight=1,
    reparameterization=True,
    algorithm_distance_matricies=alibi_distances,
)
train_loop(model)

In [None]:
from GraphTransformerAlibiGlobalToken import GraphTransformerAutoencoderAlibi

# Experiment: GraphTransformerAutoencoderAlibi with BETA = 1.
BETA = 1
NAME = "GraphTransformer-Alibi-Full-Beta-1"
LATENT_DIM = 16
BATCH_SIZE = 2048

model = GraphTransformerAutoencoderAlibi(
    input_size=21,
    latent_space=LATENT_DIM,
    d_model=128,
    depth=6,
    heads=4,
    mlp_dim=256,
    num_algorithms=32,
    num_global_params=67,
    sparsity_weight=1,
    reparameterization=True,
    algorithm_distance_matricies=alibi_distances,
)
train_loop(model)

In [None]:
from ResidualNet import VariationalAutoencoder

# Experiment: residual VariationalAutoencoder with BETA = 0.001.
# Resets every global that train_loop / train_model read.

# model / data
LATENT_DIM = 16
BATCH_SIZE = 2048

# optimisation
EPOCHS = 1000
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
GRADIENT_CLIPPING = False
CLIP_VALUE = 100

# diagnostics / loss configuration
STD_DEV_MEASURE_MOD = 50
STD_DEV_MEASURE_SCALE = 1
LOSS_REDUCTION = "mean"

# KL weighting: number of annealing cycles and peak beta
BETA_CYCLES = 10
BETA = 0.001

NAME = f"VAE-{LATENT_DIM}-Beta-0-001"

model = VariationalAutoencoder(
    input_dim=225,
    output_dim=225,
    latent_dim=LATENT_DIM,
)
train_loop(model)

In [None]:
# Scratch notes (not executable code):
# dx7 midi manual
# f0
# id
# stuff
# f7
#
# Euclidean Drum Patterns
# Conan Networks

In [None]:
from ResidualNet import VariationalAutoencoder

# Experiment: residual VariationalAutoencoder with BETA = 0.001.
# NOTE(review): same NAME and BETA as the earlier VAE cell, so this run
# overwrites the files saved under ./models/<NAME>/ — confirm it is intended.
BETA = 0.001
NAME = f"VAE-{LATENT_DIM}-Beta-0-001"
model = VariationalAutoencoder(
    input_dim=225,
    output_dim=225,
    latent_dim=LATENT_DIM,
)
train_loop(model)

In [None]:
from ResidualNet import VariationalAutoencoder

# Experiment: residual VariationalAutoencoder with BETA = 0.01.
NAME = f"VAE-{LATENT_DIM}-Beta-0-01"
BETA = 0.01
model = VariationalAutoencoder(
    input_dim=225,
    output_dim=225,
    latent_dim=LATENT_DIM,
)
train_loop(model)

In [None]:
# Experiment: residual VariationalAutoencoder with BETA = 0.1.
# Relies on VariationalAutoencoder having been imported by an earlier cell.
NAME = f"VAE-{LATENT_DIM}-Beta-0-1"
BETA = 0.1
model = VariationalAutoencoder(
    input_dim=225,
    output_dim=225,
    latent_dim=LATENT_DIM,
)
train_loop(model)

In [None]:
# Experiment: residual VariationalAutoencoder with BETA = 0.5.
# Relies on VariationalAutoencoder having been imported by an earlier cell.
NAME = f"VAE-{LATENT_DIM}-Beta-0-5"
BETA = 0.5
model = VariationalAutoencoder(
    input_dim=225,
    output_dim=225,
    latent_dim=LATENT_DIM,
)
train_loop(model)

In [None]:
# Experiment: residual VariationalAutoencoder with BETA = 1.
# Relies on VariationalAutoencoder having been imported by an earlier cell.
NAME = f"VAE-{LATENT_DIM}-Beta-1"
BETA = 1
model = VariationalAutoencoder(
    input_dim=225,
    output_dim=225,
    latent_dim=LATENT_DIM,
)
train_loop(model)