# Training

## Load libraries and datasets

In [None]:
# @title Install and load libraries { display-mode: "form" }
# @markdown The following libraries and programs will be installed:
# @markdown - MIDItok (tokeniser)
# @markdown - Fluidsynth (MIDI to wav)
# @markdown - Muspy (Symbolic music libraries handle)
# @markdown ---
# @markdown All necessary libraries are included in this cell

# Install MIDItok for tokenising MIDI files
!pip install miditok
# Library used to handle symbolic music datasets. Used in this case for
# the EMOPIA and MAESTRO datasets
!pip install muspy
# Fluidsynth for producing wav files from midi (using a soundfont).
# midi2audio provides the FluidSynth class imported further down
!pip install midi2audio
!pip install fluidsynth
# Install the fluidsynth system package (the pip packages above do not
# replace it)
!apt install fluidsynth
# Copy the default soundfont (samples of musical instruments) to the current
# directory as font.sf2. The source path is the Colab location; in a local
# Jupyter setup the soundfont has to be added manually
!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 ./font.sf2

# Using REMI scheme tokenisation
from miditok import REMI, TokenizerConfig
# Manage paths
from pathlib import Path
import os
# Ipython.display to show audio files
import IPython.display as ipd
# Librosa for music and audio analysis
import librosa
# Symbolic music datasets library
import muspy
# To calculate information from the midi file
import music21
# To copy objects
import copy
# Pytorch
import torch

# Set datasets path. Every dataset lives in its own sub-folder of
# `datasets_folder`
datasets_folder = 'data'
vgmidi_path  = datasets_folder + '/vgmidi'
emopia_path  = datasets_folder + '/emopia'
maestro_path = datasets_folder + '/maestro'

# Create folder for datasets. `exist_ok=True` makes re-running the cell a
# no-op and avoids the check-then-create race of a separate exists() test
os.makedirs(datasets_folder, exist_ok=True)

In [None]:
# @title Get datasets { display-mode: "form" }
# @markdown Datasets to download:
# @markdown - VGMIDI (from github)
# @markdown - EMOPIA (muspy)
# @markdown - MAESTRO (muspy)
# No need to run these if the files already exist. In a Colab notebook, clone
# and extract the files
#!git clone https://github.com/lucasnfe/vgmidi.git {vgmidi_path}
# Extract unlabelled midi from VGMIDI dataset
#!unzip "{vgmidi_path}/unlabelled/midi.zip" -d "{vgmidi_path}/"

# Use muspy library for emopia and maestro datasets
emopia = muspy.EMOPIADataset(emopia_path, download_and_extract=True)
maestro = muspy.datasets.MAESTRODatasetV3(maestro_path, download_and_extract=True)

# Get paths of unlabelled data in VGMIDI dataset. The paths are derived from
# the dataset path variables instead of repeating the './data/...' literals
vgmidi_unlabelled_midi_paths = list((Path(vgmidi_path) / "midi").glob("*.mid"))
print("VGMIDI unlabelled paths", len(vgmidi_unlabelled_midi_paths))

# Get paths of labelled data in VGMIDI dataset
vgmidi_labelled_midi_paths = list((Path(vgmidi_path) / "labelled" / "midi").glob("*.mid"))
print("VGMIDI labelled paths", len(vgmidi_labelled_midi_paths))

# Get paths of labelled data in EMOPIA dataset
emopia_labelled_midi_paths = list((Path(emopia_path) / "EMOPIA_2.2" / "midis").glob("*.mid"))
print("EMOPIA labelled paths", len(emopia_labelled_midi_paths))

# Root folder of the extracted MAESTRO dataset
maestro_root = Path(maestro_path) / "maestro-v3.0.0"
# Create list to store MAESTRO dataset paths
maestro_unlabelled_midi_paths = list()
# Get paths of unlabelled data in MAESTRO dataset: one glob per entry of the
# root folder ("*.mid*" matches both .mid and .midi)
for folder in os.listdir(maestro_root):
    maestro_unlabelled_midi_paths += list((maestro_root / folder).glob("*.mid*"))
# MAESTRO is used as unlabelled data here, so it is reported as such
print("MAESTRO unlabelled paths", len(maestro_unlabelled_midi_paths))

# Create list of combined labelled paths
combined_labelled_dataset_paths = vgmidi_labelled_midi_paths + emopia_labelled_midi_paths
print("Combined labelled paths", len(combined_labelled_dataset_paths))

# Create list of combined unlabelled paths
combined_unlabelled_dataset_paths = vgmidi_unlabelled_midi_paths + maestro_unlabelled_midi_paths
print("Combined unlabelled paths", len(combined_unlabelled_dataset_paths))

# Create list of combined labelled and unlabelled paths
combined_dataset_paths = combined_unlabelled_dataset_paths + combined_labelled_dataset_paths
print("Combined paths", len(combined_dataset_paths))

Skip downloading as the `.muspy.success` file is found.
Skip extracting as the `.muspy.success` file is found.
Skip downloading as the `.muspy.success` file is found.
Skip extracting as the `.muspy.success` file is found.
VGMIDI unlabelled paths 3850
VGMIDI labelled paths 204
EMOPIA labelled paths 1071
MAESTRO labelled paths 1276
Combined labelled paths 1275
Combined unlabelled paths 5126
Combined paths 6401


In [None]:
# @title Clone Emotion Wave github repository { display-mode: "form" }
!git clone https://github.com/JorgePdlR/EmotionWave.git

# Set datasets path
datasets_folder = 'data'
output_folder = 'FolderData'

truncated_folder = 'EmotionWave/vgmidi_unlabelled_truncated.zip'

!unzip "{truncated_folder}" -d "{output_folder}/"
dataset_path = output_folder + "Vgmidi_unlabelled"

In [None]:
# @title Run to update the content from the github repository { display-mode: "form" }
# @markdown This is a no-operation if no changes has been done
%%capture
%cd EmotionWave
!git pull
%cd ..

In [None]:
# @title Load tokeniser { display-mode: "form" }
# Load tokeniser
from miditok import TokSequence
import miditok
import importlib
import EmotionWave.MIDIoperations
import pretty_midi
from midi2audio import FluidSynth

# Reload the module BEFORE pulling names out of it. `from module import name`
# binds the object that exists at import time, so reloading afterwards would
# leave REMItokenizer / MidiWav pointing at the stale, pre-reload classes
importlib.reload(EmotionWave.MIDIoperations)
from EmotionWave.MIDIoperations import REMItokenizer, MidiWav

# Tokeniser configuration handed to REMItokenizer
TOKENIZER_PARAMS = {
    "pitch_range": (21, 109),
    "beat_res": {(0, 4): 8, (4, 12): 4},
    "num_velocities": 32,
    "use_chords": True,
    "use_rests": False,
    "use_tempos": True,
    "use_time_signatures": False,
    "use_programs": False,
    "use_pitchdrum_tokens": False,
    "num_tempos": 32,  # number of tempo bins
    "tempo_range": (40, 250),  # (min, max)
}

remi = REMItokenizer(TOKENIZER_PARAMS, max_bar_embedding=None)

In [None]:
# @title Filter data from dataset { display-mode: "form" }
import re
import pickle


def _drop_sequences(sequences, should_drop):
    """Remove, in place, every entry of `sequences` whose value matches `should_drop`.

    Args:
        sequences: dict mapping a fragment key to its dict of bars.
        should_drop: callable that receives the dict of bars of one fragment
            and returns a truthy value when the fragment must be removed.

    Returns:
        The number of removed entries.
    """
    # Collect the keys first: a dict cannot change size while it is iterated
    keys_to_drop = [key for key, bars in sequences.items() if should_drop(bars)]
    for key in keys_to_drop:
        sequences.pop(key)
    return len(keys_to_drop)


# Load dataset
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only load pickles that come from a trusted source
#with open('data_8bar_ids_dict.pkl', 'rb') as pickle_file:
with open('data_8bar_ids_dict_clean.pkl', 'rb') as pickle_file:
    data_dict = pickle.load(pickle_file)

# Remove all midi fragments that contain at least one bar with 128 or more
# tokens
val = _drop_sequences(
    data_dict,
    lambda bars: any(len(bar['ids']) >= 128 for bar in bars.values()))
print("8-bar sequences with 128 to remove", val)
print("Size of 8-bar sequences", len(data_dict))

# Remove all midi fragments that are not exactly 8 bars. The length is checked
# on the fragment itself (not from inside a loop over its bars), so a fragment
# with no bars at all is removed as well
val = _drop_sequences(data_dict, lambda bars: len(bars) != 8)
print("not 8-bar sequences to remove", val)
print("Size of 8-bar sequences", len(data_dict))

# Remove midi fragments in which at least one bar has no valence
val = _drop_sequences(
    data_dict,
    lambda bars: any('valence' not in bar for bar in bars.values()))
print("no valence sequences to remove", val)
print("Size of 8-bar sequences with valence", len(data_dict))

8-bar sequences with 128 to remove 13595
Size of 8-bar sequences 55468
not 8-bar sequences to remove 4488
Size of 8-bar sequences 50980
no valence sequences to remove 0
Size of 8-bar sequences with valence 50980


In [None]:
# @title Format all the data for training EmotionWave { display-mode: "form" }
import torch

# Convert valence to int. 1 for negative, 2 for positive
def map_valence_to_int(valence_val, max_val):
    """Map a valence score to its integer class.

    Args:
        valence_val: valence score; values strictly above 0.5 count as positive.
        max_val: accepted for the caller's convenience, not used by the mapping.

    Returns:
        2 for positive valence, 1 otherwise.
    """
    # 0 Value is used as padding, so the classes start at 1
    is_positive = valence_val > .5
    return 2 if is_positive else 1

vocab_dict = remi.get_vocab_dict()

pad_ids = vocab_dict['PAD_None']
bars_per_sample = 8
sequence_per_bar_length = 128
# A sample stores all its bars back to back: 8 bars * 128 tokens = 1024
sequence_length = bars_per_sample * sequence_per_bar_length
valence_range = 2

encoder_input_list = []
decoder_input_list = []
decoder_bar_position_list = []
decoder_target_list = []
padding_mask_list = []
valence_cls_list = []
valence_val_list = []

for bar8_path, bar8_dict in data_dict.items():
    # Initialise the tensor with the padding ids
    # Max size of elements in bar, number of bars
    encoder_input = torch.full((sequence_per_bar_length, bars_per_sample), pad_ids)
    # Max size of elements in a sequence
    decoder_input = torch.full((sequence_length,), pad_ids)
    # Max size of number of bars, fixed to 8 + 1 including start-end position
    decoder_bar_position = torch.zeros((bars_per_sample + 1,))
    # Number of bars, max size of elements per bar sequence. True marks
    # padding; positions holding real tokens are cleared below
    padding_mask = torch.ones((bars_per_sample, sequence_per_bar_length), dtype=torch.bool)
    # Shape of Max size of elements in a sequence.
    # NOTE(review): padded positions receive the token padding id; presumably
    # PAD_None is id 0, which is the value reserved for valence padding -
    # confirm
    valence_cls = torch.full((sequence_length,), pad_ids)
    # For statistics of the valence
    valence_val = torch.full((bars_per_sample,), pad_ids)

    offset_decoder_input = 0

    # Go through all the individual bars
    for z, (bar_path, bar_dict) in enumerate(bar8_dict.items()):
        ids_t = torch.tensor(bar_dict['ids'])
        ids_len = len(ids_t)
        # Position right after the last token of this bar in the flat sequence
        bar_end = offset_decoder_input + ids_len
        # Map the valence once and reuse it for both valence tensors
        bar_valence = map_valence_to_int(bar_dict['valence'], valence_range)

        encoder_input[0:ids_len, z] = ids_t

        padding_mask[z, 0:ids_len] = False

        decoder_input[offset_decoder_input:bar_end] = ids_t
        valence_cls[offset_decoder_input:bar_end] = bar_valence
        valence_val[z] = bar_valence
        offset_decoder_input = bar_end
        # Adding one because first starting position is 0, we have 8 ranges
        # that correspond to 9 positions
        decoder_bar_position[z + 1] = offset_decoder_input

    # Store in a list the encoder input tensor
    encoder_input_list.append(encoder_input)

    # Store in a list the decoder input tensor
    decoder_input_list.append(decoder_input)

    # Store the decoder target tensor in a list. The target is the decoder
    # input shifted one position to the left, completed with one padding id
    decoder_target = torch.cat((decoder_input[1:], torch.tensor([pad_ids])))
    decoder_target_list.append(decoder_target)

    # Store in a list the decoder bar start end position
    decoder_bar_position_list.append(decoder_bar_position)

    # Store the padding values per bar in a list
    padding_mask_list.append(padding_mask)

    # Store in a list the valence values per sequence
    valence_cls_list.append(valence_cls)

    # Store in a list the valence values without expansions for statistics
    valence_val_list.append(valence_val)

# Stack the tensors along a new dimension
encoder_input_tensor = torch.stack(encoder_input_list)
decoder_input_tensor = torch.stack(decoder_input_list)
decoder_target_tensor = torch.stack(decoder_target_list)
decoder_bar_position_tensor = torch.stack(decoder_bar_position_list)
padding_mask_tensor = torch.stack(padding_mask_list)
valence_cls_tensor = torch.stack(valence_cls_list)
valence_val_tensor = torch.stack(valence_val_list)

print("Encoder input", encoder_input_tensor.shape)
print("Decoder input", decoder_input_tensor.shape)
print("Decoder target", decoder_target_tensor.shape)
print("Decoder bar position", decoder_bar_position_tensor.shape)
print("Padding mask", padding_mask_tensor.shape)
# In valence per sequence the same valence value is replicated over the complete
# bar. Each bar might have a different valence. Valence is mapped to integers
# in the provided range
print("Valence per sequence", valence_cls_tensor.shape)
print("Valence statistics", valence_val_tensor.shape)

Encoder input torch.Size([50980, 128, 8])
Decoder input torch.Size([50980, 1024])
Decoder target torch.Size([50980, 1024])
Decoder bar position torch.Size([50980, 9])
Padding mask torch.Size([50980, 8, 128])
Valence per sequence torch.Size([50980, 1024])
Valence statistics torch.Size([50980, 8])


In [None]:
# @title Create train, test and validation sets { display-mode: "form" }

def print_dataset_shapes(dataset_dict, name):
    """Print a header containing `name`, then one line with the shape of every entry.

    Args:
        dataset_dict: dict whose values expose a `.shape` attribute.
        name: label shown in the header line.
    """
    print(f"\n{name} dataset shapes:")
    for label in dataset_dict:
        shape = dataset_dict[label].shape
        print(f"  {label}: {shape}")

def shuffle_dataset(dataset):
    """Return a new dict in which every tensor is reordered by one shared random permutation.

    The permutation length is the first dimension of the first tensor in
    `dataset`. The same permutation indexes every tensor, so rows that belong
    together stay aligned across keys.

    Args:
        dataset: dict mapping a name to a tensor.

    Returns:
        A new dict with the same keys and the shuffled tensors.
    """
    # The first tensor decides how many samples there are
    first_tensor = next(iter(dataset.values()))
    permutation = torch.randperm(first_tensor.shape[0])

    # Index every tensor with the same permutation
    shuffled_dataset = {}
    for name, tensor in dataset.items():
        shuffled_dataset[name] = tensor[permutation]
    return shuffled_dataset

num_files = encoder_input_tensor.shape[0]

# Create a dictionary of the dataset
emotionwave_dataset = {
    'encoder_input': encoder_input_tensor,
    'decoder_input': decoder_input_tensor,
    'decoder_target': decoder_target_tensor,
    'decoder_bar_position': decoder_bar_position_tensor,
    'padding_mask': padding_mask_tensor,
    'valence_cls': valence_cls_tensor,
}

# NOTE(review): the permutation is not seeded, so the split differs between
# runs; seed torch before this call if a reproducible split is needed
emotionwave_dataset = shuffle_dataset(emotionwave_dataset)

# Calculate from which file the validation set starts
valid_start = round(0.8 * num_files)
# Calculate from which file the test set starts
test_start = round(0.9 * num_files)

# Create train, validation, and test dictionaries.
# The splits are cut from the SHUFFLED dictionary. Slicing the original
# tensors instead would ignore the shuffle above and split the data in its
# original order
train_data = {key: tensor[:valid_start]
              for key, tensor in emotionwave_dataset.items()}

valid_data = {key: tensor[valid_start:test_start]
              for key, tensor in emotionwave_dataset.items()}

test_data = {key: tensor[test_start:]
             for key, tensor in emotionwave_dataset.items()}

# Print information about train, validation and test
print_dataset_shapes(train_data, "Train")
print_dataset_shapes(valid_data, "Validation")
print_dataset_shapes(test_data, "Test")


Train dataset shapes:
  encoder_input: torch.Size([40784, 128, 8])
  decoder_input: torch.Size([40784, 1024])
  decoder_target: torch.Size([40784, 1024])
  decoder_bar_position: torch.Size([40784, 9])
  padding_mask: torch.Size([40784, 8, 128])
  valence_cls: torch.Size([40784, 1024])

Validation dataset shapes:
  encoder_input: torch.Size([5098, 128, 8])
  decoder_input: torch.Size([5098, 1024])
  decoder_target: torch.Size([5098, 1024])
  decoder_bar_position: torch.Size([5098, 9])
  padding_mask: torch.Size([5098, 8, 128])
  valence_cls: torch.Size([5098, 1024])

Test dataset shapes:
  encoder_input: torch.Size([5098, 128, 8])
  decoder_input: torch.Size([5098, 1024])
  decoder_target: torch.Size([5098, 1024])
  decoder_bar_position: torch.Size([5098, 9])
  padding_mask: torch.Size([5098, 8, 128])
  valence_cls: torch.Size([5098, 1024])


In [None]:
# @title Create loaders { display-mode: "form" }
from torch.utils.data import Dataset, DataLoader
import pickle

class DictDataset(Dataset):
    """Dataset that serves one sample as a dict built from a dict of tensors."""

    def __init__(self, tensor_dict):
        """Keep a reference to `tensor_dict` and a list of its keys."""
        self.tensor_dict = tensor_dict
        self.keys = [*tensor_dict.keys()]

    def __getitem__(self, index):
        """Return a dict with element `index` of every stored tensor."""
        sample = {}
        for name, tensor in self.tensor_dict.items():
            sample[name] = tensor[index]
        return sample

    def __len__(self):
        """Return the length of the first stored tensor."""
        first_tensor = next(iter(self.tensor_dict.values()))
        return len(first_tensor)

def create_dataloader(data_dict, batch_size, shuffle=False):
    """Wrap a dict of array-like values in a DataLoader that yields dict batches.

    Args:
        data_dict: dict mapping a name to a tensor, or to data that
            `torch.tensor` can convert.
        batch_size: number of samples per batch.
        shuffle: forwarded to the DataLoader.

    Returns:
        A DataLoader over a DictDataset built from the converted values.
    """
    # Values that already are tensors are kept as they are; anything else is
    # converted
    tensor_dict = {}
    for name, value in data_dict.items():
        tensor_dict[name] = value if isinstance(value, torch.Tensor) else torch.tensor(value)

    return DataLoader(DictDataset(tensor_dict), batch_size=batch_size, shuffle=shuffle)

def print_dataset_info(data_dict, loader, name):
    """Print the tensor shapes of a split, the shapes of its first batch and loader statistics.

    Args:
        data_dict: dict whose values expose a `.shape` attribute (the full split).
        loader: iterable of dict batches that also provides `len()`,
            `.batch_size` and `.dataset`.
        name: label shown in the header line.
    """
    print(f"\n{name} Dataset Information:")

    # Print full dataset shapes
    print("Full dataset shapes:")
    for label in data_dict:
        print(f"  {label}: {data_dict[label].shape}")

    # Print shapes of the first batch only; nothing is printed for an empty
    # loader
    nothing = object()
    first_batch = next(iter(loader), nothing)
    if first_batch is not nothing:
        print("\nShapes of tensors in the first batch:")
        for label in first_batch:
            print(f"  {label}: {first_batch[label].shape}")

    n_batches = len(loader)
    print(f"\nNumber of batches: {n_batches}")
    print(f"Batch size: {loader.batch_size}")
    total_samples = len(loader.dataset)
    print(f"Total number of samples: {total_samples}")

# Size of each batch
batch_size = 4

# Create DataLoaders for each set. Only the training loader shuffles its
# samples
train_loader = create_dataloader(data_dict=train_data, batch_size=batch_size, shuffle=True)
valid_loader = create_dataloader(data_dict=valid_data, batch_size=batch_size, shuffle=False)
test_loader = create_dataloader(data_dict=test_data, batch_size=batch_size, shuffle=False)

# Print information for each dataset
for split_name, split_data, split_loader in (
        ("Training", train_data, train_loader),
        ("Validation", valid_data, valid_loader),
        ("Test", test_data, test_loader)):
    print_dataset_info(split_data, split_loader, split_name)


# Store datasets
#with open('./pickle_data/train_data.pkl', 'wb') as pickle_file:
#    pickle.dump(train_data, pickle_file)

#with open('./pickle_data/valid_data.pkl', 'wb') as pickle_file:
#    pickle.dump(valid_data, pickle_file)

#with open('./pickle_data/test_data.pkl', 'wb') as pickle_file:
#    pickle.dump(test_data, pickle_file)

#with open('./pickle_data/train_loader.pkl', 'wb') as pickle_file:
#    pickle.dump(train_loader, pickle_file)

#with open('./pickle_data/valid_loader.pkl', 'wb') as pickle_file:
#    pickle.dump(valid_loader, pickle_file)

#with open('./pickle_data/test_loader.pkl', 'wb') as pickle_file:
#    pickle.dump(test_loader, pickle_file)


Training Dataset Information:
Full dataset shapes:
  encoder_input: torch.Size([40784, 128, 8])
  decoder_input: torch.Size([40784, 1024])
  decoder_target: torch.Size([40784, 1024])
  decoder_bar_position: torch.Size([40784, 9])
  padding_mask: torch.Size([40784, 8, 128])
  valence_cls: torch.Size([40784, 1024])

Shapes of tensors in the first batch:
  encoder_input: torch.Size([4, 128, 8])
  decoder_input: torch.Size([4, 1024])
  decoder_target: torch.Size([4, 1024])
  decoder_bar_position: torch.Size([4, 9])
  padding_mask: torch.Size([4, 8, 128])
  valence_cls: torch.Size([4, 1024])

Number of batches: 10196
Batch size: 4
Total number of samples: 40784

Validation Dataset Information:
Full dataset shapes:
  encoder_input: torch.Size([5098, 128, 8])
  decoder_input: torch.Size([5098, 1024])
  decoder_target: torch.Size([5098, 1024])
  decoder_bar_position: torch.Size([5098, 9])
  padding_mask: torch.Size([5098, 8, 128])
  valence_cls: torch.Size([5098, 1024])

Shapes of tensors in 

## EmotionWave training
Time to train EmotionWave. Log files will be created with the training information.

In [None]:
# @title Training functions { display-mode: "form" }
import logging
import os
from logging.handlers import RotatingFileHandler
import numpy as np

# Get information about the tensor, debugging purposes
def print_tensor_info(ptensor):
    """Print, one per line, the type, dtype, shape and device of `ptensor`."""
    # Type of the tensor object
    print(type(ptensor))

    # Data type of the elements, shape, and device (CPU or GPU), in this order
    for attribute in ("dtype", "shape", "device"):
        print(getattr(ptensor, attribute))

# Evaluate the model provided the loader and the number of rounds for the evaluation
def evaluate(model, data_loader, n_rounds=8, valence_cnd=True):
    """Run the model over `data_loader` `n_rounds` times and collect the per-batch losses.

    The model is switched to eval mode and left in it; gradients are disabled.
    Batches are moved to the module-level `device`, which is not a parameter
    of this function.

    Args:
        model: model called as in `train`; must provide `compute_loss`.
        data_loader: iterable of dict batches with the keys 'encoder_input',
            'decoder_input', 'decoder_target', 'decoder_bar_position',
            'padding_mask' and, when `valence_cnd` is true, 'valence_cls'.
        n_rounds: number of full passes over `data_loader`.
        valence_cnd: when false, None is passed to the model as valence input.

    Returns:
        Tuple `(recons_losses, kl_raw_losses)`: two lists of floats with one
        entry per processed batch.
    """
    model.eval()
    recons_history, kl_history = [], []

    with torch.no_grad():
        for _ in range(n_rounds):

            # Go through all the data
            for batch in data_loader:
                model.zero_grad()

                # Put data in the GPU and accommodate the columns
                enc_in = batch['encoder_input'].permute(1,0,2).to(device, dtype=torch.long)
                dec_in = batch['decoder_input'].permute(1,0).to(device, dtype=torch.long)
                dec_target = batch['decoder_target'].permute(1,0).to(device, dtype=torch.long)
                bar_positions = batch['decoder_bar_position'].to(device, dtype=torch.int)
                pad_mask = batch['padding_mask'].to(device, dtype=torch.bool)

                # Valence is only read from the batch when conditioning is on
                valence = None
                if valence_cnd:
                    valence = batch['valence_cls'].permute(1,0).to(device, dtype=torch.long)

                # Get the model outputs: mu (mean), logvar (log variance), and decoder logits
                mu, logvar, logits = model(enc_in, dec_in, bar_positions, valence,
                                           padding_mask=pad_mask)

                # Same call as in training, with the KL beta and free-bit
                # lambda arguments both set to 0.0
                losses = model.compute_loss(mu, logvar, 0.0, 0.0, logits, dec_target)

                recons_history.append(losses['recons_loss'].item())
                kl_history.append(losses['kldiv_raw'].item())

    return recons_history, kl_history

# Compute the exponential moving average loss
def compute_loss_ema(ema, batch_loss, decay=0.95):
    """Update an exponential moving average (EMA) of a loss.

    Args:
        ema: previous EMA value; 0 is treated as "no history yet".
        batch_loss: loss of the current batch.
        decay: weight kept from the previous EMA.

    Returns:
        `batch_loss` when `ema` equals 0, otherwise
        `batch_loss * (1 - decay) + ema * decay`.
    """
    if ema != 0.:
        # Weighted average of the current batch loss and the previous EMA
        return batch_loss * (1 - decay) + ema * decay
    # Initial condition: start the average from the current batch loss
    return batch_loss

# Compute beta cyclical scheduler
def beta_cyclical_schedulder(step, kl_cycle_steps=5000, no_kl_steps=10000,
                             kl_max_beta=1.0):
    """Return the KL weight (beta) for training step `step`.

    Beta is 0 while `step < no_kl_steps`. Afterwards each cycle of
    `kl_cycle_steps` steps ramps beta linearly from 0 to `kl_max_beta` over
    its first half and holds `kl_max_beta` over its second half.

    Args:
        step: training step; the position in the cycle is `(step - 1) % kl_cycle_steps`.
        kl_cycle_steps: length of one cycle, in steps.
        no_kl_steps: steps below this value return 0.
        kl_max_beta: largest beta value.

    Returns:
        The beta value as a float.
    """
    # Fraction of the current cycle that is already completed
    cycle_progress = ((step - 1) % kl_cycle_steps) / kl_cycle_steps

    # No KL term during the initial steps
    if step < no_kl_steps:
        return 0.

    # First half of the cycle ramps up, second half stays at the maximum
    ramping_up = cycle_progress < 0.5
    return kl_max_beta * cycle_progress * 2. if ramping_up else kl_max_beta

# Train function, time to train the model!! Using log file to avoid issues with console output
def train(model, train_loader, valid_loader, optimizer, num_epochs, saved_model,
          log_file, evaluate_every_n_epochs=1, scheduler=None, constant_kl=False,
          lr_warmup_steps=1000, kl_max_beta=1.0, free_bit_lambda=0.25, max_lr=1.0e-4,
          verbose=False, valence_cnd=True):
    """Train `model`, writing progress to `log_file` and saving a checkpoint after every epoch.

    Besides its parameters the function reads two module-level names:
    `device` (where the batches are moved to) and `test_loader` (evaluated
    once after the last epoch).

    Args:
        model: model to train; must provide `compute_loss` and accept the
            `verbose` keyword.
        train_loader: iterable of dict batches used for training.
        valid_loader: loader passed to `evaluate` for validation.
        optimizer: optimizer updating the model parameters.
        num_epochs: number of passes over `train_loader`.
        saved_model: filename PREFIX of the checkpoints. Files are named
            `<saved_model>_<epoch>_<recons_ema><kl_raw_ema>.pt`, plus an
            `_optim.pt` twin for the optimizer state.
        log_file: path of the rotating log file.
        evaluate_every_n_epochs: validation is run every this many epochs.
        scheduler: optional learning-rate scheduler stepped once per batch
            after the warmup; with None the rate stays where the warmup left it.
        constant_kl: when true beta is always `kl_max_beta`, otherwise the
            cyclical schedule is used.
        lr_warmup_steps: number of steps of linear learning-rate warmup.
        kl_max_beta: maximum KL weight.
        free_bit_lambda: forwarded to `model.compute_loss`.
        max_lr: learning rate targeted by the warmup.
        verbose: forwarded to the model call.
        valence_cnd: when false, None is passed to the model as valence input.
    """
    # Setting training logger
    handler = RotatingFileHandler(log_file, maxBytes=10**6, backupCount=5)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)

    # The root logger is used
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Remove any existing handlers to avoid printing to console. They are
    # closed as well, so that repeated calls do not leak the log files opened
    # by earlier runs
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    # Adding file handler
    logger.addHandler(handler)

    # Log the start of training
    logger.info('Training started')

    # Indicate that we are training the model, gradients will be calculated
    model.train()

    # Keep count of trained steps
    trained_steps = 0

    # Initialize the exponentially moving averages (EMAs) for different loss
    # components. Initially set to 0, indicating no prior information
    # EMA for reconstruction loss
    recons_loss_ema = 0
    # EMA for KL divergence loss (weighted by beta)
    kl_loss_ema = 0
    # EMA for raw KL divergence loss (unweighted)
    kl_raw_ema = 0

    # Train for the provided number of epochs
    for epoch in range(num_epochs):
        # Go through all the data
        for batch_id, batch_inputs in enumerate(train_loader):
            # Put data in the GPU and accommodate the columns
            encoder_input = batch_inputs['encoder_input'].permute(1,0,2).to(device, dtype=torch.long)
            decoder_input = batch_inputs['decoder_input'].permute(1,0).to(device, dtype=torch.long)
            decoder_target = batch_inputs['decoder_target'].permute(1,0).to(device, dtype=torch.long)
            decoder_bar_position = batch_inputs['decoder_bar_position'].to(device, dtype=torch.int)
            padding_mask = batch_inputs['padding_mask'].to(device, dtype=torch.bool)

            if valence_cnd:
                valence_cls = batch_inputs['valence_cls'].permute(1,0).to(device, dtype=torch.long)
            else:
                valence_cls = None

            # Forward + backward + optimize
            # Get the model outputs: mu (mean), logvar (log variance), and decoder logits
            mu, logvar, decoder_logits = model(encoder_input, decoder_input,
                                               decoder_bar_position, valence_cls,
                                               padding_mask=padding_mask, verbose=verbose)

            # Increment the number of training steps
            trained_steps += 1

            # Determine the KL beta value
            if constant_kl:
                # If constant KL is used, set kl_beta to the maximum value
                kl_beta = kl_max_beta
            else:
                # Otherwise, use a cyclical scheduler to get the current
                # kl_beta. The maximum is forwarded so that the kl_max_beta
                # argument is honoured in this mode too
                kl_beta = beta_cyclical_schedulder(trained_steps, kl_max_beta=kl_max_beta)

            # Compute the loss using the model's loss function
            loss = model.compute_loss(mu, logvar, kl_beta, free_bit_lambda,
                                      decoder_logits, decoder_target)

            # Adjust the learning rate based on the number of training steps
            in_warmup = trained_steps < lr_warmup_steps
            if in_warmup:
                # During the warmup phase, linearly increase the learning rate
                curr_lr = max_lr * trained_steps / lr_warmup_steps
                # Set the current learning rate for the optimizer
                optimizer.param_groups[0]['lr'] = curr_lr

            # Set gradients of model parameters to zero before propagating
            optimizer.zero_grad()
            # Propagate the loss
            loss['total_loss'].backward()

            # Clip the gradients to prevent exploding gradients
            # This limits the norm of the gradients to a maximum value of 0.5
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

            # Update model parameters
            optimizer.step()

            # After the warmup phase, use the scheduler to adjust the learning
            # rate. PyTorch expects scheduler.step() AFTER optimizer.step(),
            # and the scheduler is optional (default None)
            if not in_warmup and scheduler is not None:
                scheduler.step()

            recons_loss_ema = compute_loss_ema(recons_loss_ema, loss['recons_loss'].item())
            kl_loss_ema = compute_loss_ema(kl_loss_ema, loss['kldiv_loss'].item())
            kl_raw_ema = compute_loss_ema(kl_raw_ema, loss['kldiv_raw'].item())

            logger.info(f'kl_beta {kl_beta}, batch_id {batch_id}, trained_steps {trained_steps}')
            logger.info(f'loss_recons {recons_loss_ema}, kl_loss_ema {kl_loss_ema}, kl_raw_ema {kl_raw_ema}')


        logger.info(f'[{epoch+1}]')

        # Evaluate the network on the validation data
        if((epoch+1) % evaluate_every_n_epochs == 0):
            valloss = evaluate(model, valid_loader, valence_cnd=valence_cnd)
            logger.info(f'Validation loss: train_steps {trained_steps}, loss_recons {np.mean(valloss[0])}, kl_raw_ema {np.mean(valloss[1])}')
            # evaluate() leaves the model in eval mode; switch back
            model.train()

        # Save the model and optimizer to a file.
        # NOTE(review): the two EMA values are concatenated without a
        # separator; kept as is because existing checkpoint names follow
        # this format
        model_name = saved_model + '_' + str(epoch+1) + '_' + str(recons_loss_ema) + str(kl_raw_ema) + '.pt'
        optimizer_name = saved_model + '_' + str(epoch+1) + '_' + str(recons_loss_ema) + str(kl_raw_ema) + '_optim.pt'
        torch.save(model.state_dict(), model_name)
        torch.save(optimizer.state_dict(), optimizer_name)

    # NOTE(review): test_loader is a module-level name, not a parameter; it
    # has to exist before train() is called
    testloss = evaluate(model, test_loader, valence_cnd=valence_cnd)
    logger.info(f'Test loss: loss_recons {np.mean(testloss[0])}, kl_raw_ema {np.mean(testloss[1])}')

    logger.info('Training completed')

In [None]:
# @title Train the model and print general information about it. Check log file for progress { display-mode: "form" }
import EmotionWave.Model.EmotionWave as emw
import torch.optim.lr_scheduler as lr_scheduler
import gc
import numpy as np


def count_parameters(model):
    """Return the total number of elements in the parameters of `model` that require gradients."""
    total = 0
    for param in model.parameters():
        # Parameters that do not require gradients are not counted
        if param.requires_grad:
            total += param.numel()
    return total

gc.collect()

# Empty the CUDA cache to free up GPU memory
torch.cuda.empty_cache()

# `device` is also read as a module-level name by train() and evaluate()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

model_dir = 'Models'
log_dir = 'Logs'
log_file = 'log'

vocab_size = len(remi.get_vocab_dict())
print("Vocabulary size:", vocab_size)

valence_cls=True

# Model
emw_model = emw.EmotionWave(8,          # Encoder number of layers
                            512,        # Encoder dimensions
                            8,          # Encoder number of heads
                            2048,       # Encoder dimensions feedforward
                            128,        # Dimensions latent VAE
                            512,        # Dimensions embedding
                            vocab_size, # Number of embeddings
                            4,          # Decoder number of layers
                            512,        # Decoder dimensions
                            8,          # Decoder number of heads
                            2048,       # NOTE(review): presumably decoder dimensions feedforward, not a valence setting - confirm against the EmotionWave signature
                            valence_dim_embeddings=64, # Valence embedding dimensions
                            valence_num_cls=3, # Valence vocabulary
                            valence_cls=valence_cls
                            )

print(emw_model)
# Get total number of parameters
total_params = count_parameters(emw_model)
print(f"Total trainable parameters: {total_params}")
emw_model.to(device)

# Optimizer
optimizer = torch.optim.Adam(emw_model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08)
# Scheduler
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, 70000, eta_min=5.0e-6)

# Training epochs
num_epochs = 20

# Create the output folders; exist_ok makes re-running the cell a no-op.
# NOTE(review): train() uses `model_dir` as a filename prefix
# ('Models_<epoch>_...pt') and `log_file` is the bare name 'log', so nothing
# is written inside these two folders - confirm whether that is intended
os.makedirs(model_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

#pretrained_params_path = './Models_20_0.72904897893097350.199615425882635.pt'
#emw_model.load_state_dict(torch.load(pretrained_params_path))

#optimizer_path = './Models_20_0.72904897893097350.199615425882635_optim.pt'
#optimizer.load_state_dict(torch.load(optimizer_path))

# How often the network will be evaluated during training
evaluate_every_n_epochs = 5

train(emw_model, train_loader, valid_loader, optimizer, num_epochs, model_dir, log_file,
      evaluate_every_n_epochs=evaluate_every_n_epochs, scheduler=scheduler, constant_kl=False,
      lr_warmup_steps=200, kl_max_beta=1.0, free_bit_lambda=0.25, max_lr=1.0e-4, verbose=False, valence_cnd=valence_cls)

cuda:0
Vocabulary size: 267
EmotionWave(
  (input_embedding): EmbeddingWithProjection(
    (embedding_lookup): Embedding(267, 512)
  )
  (positional_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): VAETransformerEncoder(
    (transformer_encoder_layer): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
      )
      (linear1): Linear(in_features=512, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=512, bias=True)
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
    (transformer_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-7): 8 x Transforme