In [None]:
import torch
import numpy as np

import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

from src.constants import LATENT_DIM, BATCH_SIZE, NUMBER_OF_PITCHES, LOWEST_PITCH, MUSIC_LENGTH, MUSIC_LENGTH, LPD_PATH, MAESTRO_PATH, LPD_FILE_EXTENSION, \
    MAESTRO_FILE_EXTENSION, NUM_EPOCHS, INFO_INTERVAL_EPOCHS

from src.utils import array_to_midi, get_device, show_learning_process, write_model_params_to_tensorboard, write_models_architecture_to_tensorboard, \
    write_losses_to_tensorboard, write_samples, save_models, generate_random_midi_array, array_to_midi

from src.data_preparation import prepare_data

from src.train_eval import train_one_step, Metrics

from src.models import Generator, Discriminator, SequenceBarGenerator, TemporalVectors

## Get device

In [None]:
# Ask the project helper for the torch device; the models and tensors in the
# cells below are moved onto it with `.to(device)`.
device = get_device()

## Prepare data

In [None]:
# `pianoroll_idx` picks the piano track: index 0 for MAESTRO, index 1 for LPD.
# The result is kept in `data_stacked`; its shape is shown as the cell output.
data_stacked = prepare_data(
    file_path=MAESTRO_PATH,
    file_extension=MAESTRO_FILE_EXTENSION,
    length=MUSIC_LENGTH,
    music_info_threshold=0.02,
    pianoroll_idx=0,
    do_filtration=True,
)
data_stacked.shape

In [None]:
# Persist the prepared array so it can be reloaded with np.load instead of
# re-running the preparation step.
np.save(file='../data/data_maestro_192_0.02.npy', arr=data_stacked)

In [None]:
# data_lpd = np.load(f'../data/data_192_0.04.npy')

In [None]:
# data_maestro = np.load(f'../data/data_maestro_192_0.04.npy')
# data_stacked = np.load(f'../data/data_maestro_192_0.04.npy')

In [None]:
# data_stacked = np.concatenate((data_lpd, data_maestro), axis=0)
# data_stacked.shape

In [None]:
# Convert the prepared array to a float32 tensor, wrap it in a dataset and
# serve it in shuffled mini-batches of BATCH_SIZE.
training_data = torch.as_tensor(data_stacked, dtype=torch.float32)
dataset = torch.utils.data.TensorDataset(training_data)
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

## Define models, optimizers etc.

In [None]:
# TensorBoard writer shared by the logging helpers used later in the notebook.
tensorboard_writer = SummaryWriter()

# Build the networks: `generator` wraps the temporal-vector model together with
# the bar generator; the discriminator is created on its own.
vectors_model = TemporalVectors(
    latent_vector_size=LATENT_DIM,
    hidden_size=LATENT_DIM,
    num_layers=2,
    sequence_length=3,
    device=device,
)
bar_generator = Generator()
generator = SequenceBarGenerator(
    vectors_generator=vectors_model,
    bar_generator=bar_generator,
)
discriminator = Discriminator()

# Alternative: load previously saved models instead of fresh ones.
# generator = torch.load('generators/99_generator.pt')
# discriminator = torch.load('discriminators/99_discriminator.pt')

# Move both networks to the selected device before building the optimizers.
generator = generator.to(device)
discriminator = discriminator.to(device)

# One Adam optimizer per network, both with the same learning rate and betas.
generator_optimizer = torch.optim.Adam(
    generator.parameters(),
    lr=0.001,
    betas=(0.5, 0.9),
)
discriminator_optimizer = torch.optim.Adam(
    discriminator.parameters(),
    lr=0.001,
    betas=(0.5, 0.9),
)

# Fixed latent input reused during training to record comparable samples.
# FOR ONE BAR
# sample_latent = torch.randn(8, LATENT_DIM, 1, 1).to(device)

# FOR SEQUENCE
sample_latent = torch.randn(1, 1, LATENT_DIM).to(device)

In [None]:
# Hand both networks to the TensorBoard architecture helper together with one
# real sample (reshaped to 1 x 1 x MUSIC_LENGTH x NUMBER_OF_PITCHES) and one
# random latent tensor, both placed on the training device.
write_models_architecture_to_tensorboard(
    generator=generator,
    discriminator=discriminator,
    real_images_sample=training_data[0].reshape(1, 1, MUSIC_LENGTH, NUMBER_OF_PITCHES).to(device),
    noise=torch.randn(1, 1, LATENT_DIM).to(device),
)

## Train

In [None]:
# Training loop: runs NUM_EPOCHS passes over `data_loader` and, every
# INFO_INTERVAL_EPOCHS epochs, records a preview sample, the latest losses,
# TensorBoard summaries, model checkpoints, written samples and metrics.

# Metric frames collected at each logging epoch, oldest first.
metrics_frames = []
training_evaluation = pd.DataFrame()
metrics = Metrics(bar_generator=generator, resolution=6, threshold=0.9, probe=50, device=device)

# Number of batches processed so far across all epochs.
step = 0

# step -> array generated from the fixed `sample_latent`, plus one loss value
# per logging epoch (plain floats so they can be plotted directly).
history_samples = {}
discriminator_losses = []
generator_losses = []

generator.train()

for epoch in range(NUM_EPOCHS):
    for real_samples in data_loader:
        # The loader wraps a single-tensor TensorDataset, so each batch is a
        # one-element sequence; add the channel axis expected by the call.
        d_loss, g_loss = train_one_step(
            discriminator_optimizer=discriminator_optimizer,
            generator_optimizer=generator_optimizer,
            discriminator=discriminator,
            generator=generator,
            real_samples=real_samples[0].reshape(-1, 1, MUSIC_LENGTH, NUMBER_OF_PITCHES),
            device=device
        )
        step += 1

    if epoch % INFO_INTERVAL_EPOCHS == 0:
        # Losses reported here are those of the last batch of the epoch.
        print(f'EPOCH [{epoch}] | STEP [{step}] ---> Critic loss: {d_loss:.4f}  Generator loss: {g_loss:.4f}')

        # Preview sample from the fixed latent: no autograd graph is needed,
        # so build it under no_grad and switch back to train mode afterwards.
        generator.eval()
        with torch.no_grad():
            samples = generator(sample_latent).cpu().numpy()
        history_samples[step] = samples
        generator.train()

        d_loss = d_loss.detach().cpu()
        g_loss = g_loss.detach().cpu()
        discriminator_losses.append(d_loss.item())
        generator_losses.append(g_loss.item())

        write_model_params_to_tensorboard(tb_writer=tensorboard_writer, model=generator, epoch=epoch, prefix='generator_')
        write_model_params_to_tensorboard(tb_writer=tensorboard_writer, model=discriminator, epoch=epoch, prefix='discriminator_')

        write_losses_to_tensorboard(writer=tensorboard_writer, critic_loss=d_loss, generator_loss=g_loss, step=epoch)

        save_models(discriminator=discriminator, generator=generator, prefix=f'{epoch}_epoch')
        write_samples(generator=generator, device=device, name=f'{epoch}_epoch', threshold=0.5)

        # Append (not prepend) the newest metrics so rows stay in chronological
        # order; the CSV is rewritten every logging epoch so progress survives
        # an interrupted run.
        metrics_frames.append(metrics.create_metrics_df())
        training_evaluation = pd.concat(metrics_frames, ignore_index=True)
        training_evaluation.to_csv('../reports/metrics_eval.csv', index_label='epoch')

## Quality evaluation

In [None]:
# Plot the generator loss history and save the figure under ../reports/figures.
# Losses are appended only on epochs where `epoch % INFO_INTERVAL_EPOCHS == 0`
# (starting at epoch 0), so the i-th entry belongs to epoch i * INFO_INTERVAL_EPOCHS.
generator_loss_epochs = range(
    0,
    len(generator_losses) * INFO_INTERVAL_EPOCHS,
    INFO_INTERVAL_EPOCHS,
)

plt.figure()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Generator losses')
plt.plot(generator_loss_epochs, generator_losses)
plt.savefig('../reports/figures/generator_losses.png')

In [None]:
# Plot the discriminator loss history and save the figure under ../reports/figures.
# Losses are appended only on epochs where `epoch % INFO_INTERVAL_EPOCHS == 0`
# (starting at epoch 0), so the i-th entry belongs to epoch i * INFO_INTERVAL_EPOCHS.
discriminator_loss_epochs = range(
    0,
    len(discriminator_losses) * INFO_INTERVAL_EPOCHS,
    INFO_INTERVAL_EPOCHS,
)

plt.figure()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Discriminator losses')
plt.plot(discriminator_loss_epochs, discriminator_losses)
plt.savefig('../reports/figures/discriminator_losses.png')

In [None]:
# Pass the samples recorded from the fixed latent at each logging epoch
# (dict values, in insertion order) to the project's learning-process helper.
show_learning_process(list(history_samples.values()))

In [None]:
# Set the sequence length on the generator held by `metrics`, draw one random
# array from it and export that array to 'test.midi' (with plotting enabled).
metrics.bar_generator.vectors_generator.sequence_length = 3
array_example = metrics.generate_random_midi_array()
array_to_midi(
    music_array=array_example,
    midi_path='test.midi',
    plot=True,
    resolution=6,
)

In [None]:
# One figure per metric column: plot the column against the frame index and
# save it as ../reports/figures/<column>.png.
for column_name in training_evaluation.columns:
    fig, ax = plt.subplots()
    ax.set_xlabel('Epoch')
    ax.set_title(column_name)
    ax.plot(training_evaluation[column_name])
    fig.savefig(f'../reports/figures/{column_name}.png')

In [None]:
# Read music from the samples location with muspy and write it out as audio.
# NOTE(review): `muspy` does not appear in the visible import cell and
# `output_path` is not assigned in any visible cell; on a fresh kernel this
# cell looks like it would raise NameError — confirm and define both.
# NOTE(review): the path given to `muspy.read` ends with '/', i.e. it looks
# like a directory rather than a single file — verify this is what is intended.
music = muspy.read('../samples/midi/')
muspy.outputs.write_audio(path=output_path, music=music)