In [1]:
import os
import gc
import math

import pandas as pd
import numpy as np
import miditok
from miditok import REMI, TokenizerConfig
import librosa
import matplotlib.pyplot as plt
import seaborn as sns
from symusic import Score

from pathlib import Path

In [3]:
import torch
from torch import nn
from torch.utils.data import Dataset, random_split
import torchaudio

![dataset](images/Dataset.png)

In [4]:
class ArrangerEmbedding(nn.Module):
  def __init__(self, arranger_ids=256, hidden_size=128):
    super().__init__()
    self.embeddings = nn.Embedding(arranger_ids, hidden_size)

  def forward(self, arranger_id, mel_db):
    return torch.cat([self.embeddings(arranger_id), mel_db], dim=-2)

In [9]:
def preprocess_data(tokenizer, sample_rate=22050, music_folder="song2midi/dataset/song", midi_folder="song2midi/dataset/midi", output_folder="song2midi/dataset/preprocess"):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    if not os.path.exists(os.path.join(output_folder, "mel")):
        os.makedirs(os.path.join(output_folder, "mel"))

    if not os.path.exists(os.path.join(output_folder, "midi")):
        os.makedirs(os.path.join(output_folder, "midi"))
                    
    music_ids = [name.split(".")[0] for name in os.listdir(music_folder)]
    uncheck_midis = [name.split(".")[:2] for name in os.listdir(midi_folder)]
    midis = [midi for midi in uncheck_midis if midi[0] in music_ids] # [music_id, arranger_id]

    skipped_music = []

    for music_id in music_ids:
        music_path = os.path.join(music_folder, f"{music_id}.mp3")
        waveform, sr = torchaudio.load(music_path)
        waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
        waveform = torch.mean(waveform, dim=0, keepdim=True)

        mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=4096, hop_length=1024, n_mels=128)
        mel = mel_transform(waveform)
        mel_db = torchaudio.transforms.AmplitudeToDB()(mel)

        mel_shape = mel_db.shape
        if mel_shape[2] > 6144: # have to skip this music
            print(f"music id: {music_id} is too long skipping")
            skipped_music.append(music_id)
            continue

        mel_db = mel_db.reshape(mel_shape[0], mel_shape[2], mel_shape[1])

        np.save(os.path.join(output_folder, "mel", f"{music_id}.mel.npy"), mel_db.numpy())

    for music_id, arranger_id in midis:
        if music_id in skipped_music:
            continue

        midi_path = os.path.join(midi_folder, f"{music_id}.{arranger_id}.mid")
        score = Score(midi_path)
        miditok.utils.merge_same_program_tracks(score.tracks)
        midi_encoded = tokenizer.encode(score)
        np.save(os.path.join(output_folder, "midi", f"{music_id}.{arranger_id}.midi.npy"), midi_encoded)


In [16]:
def mel_interpolate(mel):
    mel_shape = mel.shape[1]
    return nn.functional.interpolate(mel.unsqueeze(1), size=(math.floor(mel_shape / 4), 128), mode='bilinear', align_corners=False).squeeze(1)

In [5]:
class MusicDataset(Dataset):
    def __init__(self, tokenizer, embbeding, max_output_length=4096, max_input_length=6144, mel_processing=None, preprocess_folder="song2midi/dataset/preprocess"):
        self.tokenizer = tokenizer
        self.embbeding = embbeding
        self.max_output_length = max_output_length
        self.max_input_length = max_input_length
        self.preprocess_folder = preprocess_folder
        self.mel_processing = mel_processing
        
        self.data_files = [name.split(".")[:2] for name in os.listdir(os.path.join(preprocess_folder, "midi"))]

    def __len__(self):
        return len(self.data_files)
    
    def __getitem__(self, idx):
        music_id, arranger_id = self.data_files[idx]

        # load precomputed mel'
        mel_db = np.load(os.path.join(self.preprocess_folder, "mel", f"{music_id}.mel.npy"))
        mel_db = torch.tensor(mel_db)

        if self.mel_processing:
            mel_db = self.mel_processing(mel_db)

        if mel_db.shape[1] < self.max_input_length - 1:
            num_pad = self.max_input_length - mel_db.shape[1] - 1
            mel_padded = torch.cat([mel_db, torch.zeros((1, num_pad, mel_db.shape[2]))], dim=1)
        else:
            num_pad = 0
            mel_padded = mel_db[:, :self.max_input_length - 1]

        input_embed = self.embbeding(torch.tensor([[int(arranger_id)]]), mel_padded)
        attention_mask = torch.cat([torch.ones(mel_db.shape[:2], dtype=torch.int32), torch.zeros((mel_db.shape[0], num_pad + 1))], dim=1)

        global_attention_mask = torch.zeros(input_embed.shape[:2], dtype=torch.int32)
        global_attention_mask[:, 0] = 1

        # load precomputed midi
        midi_encoded = np.load(os.path.join(self.preprocess_folder, "midi", f"{music_id}.{arranger_id}.midi.npy"))
        midi = torch.tensor(midi_encoded, dtype=torch.int32)
        midi = torch.cat([torch.tensor([[self.tokenizer["BOS_None"]]]), midi, torch.tensor([[self.tokenizer["EOS_None"]]])], dim=-1)

        # padding with -100
        if midi.shape[1] < self.max_output_length:
            midi = torch.cat([midi, torch.ones((midi.shape[0], self.max_output_length - midi.shape[1]), dtype=torch.int32) * -100], dim=-1)
        else:
            midi = midi[:, :self.max_output_length]

        return {"inputs_embeds": input_embed.squeeze(0), 
                "attention_mask": attention_mask.squeeze(0), 
                #"global_attention_mask": global_attention_mask.squeeze(0), 
                "labels": midi.squeeze(0)}


# tokenizer

In [20]:
TOKENIZER_PARAMS = {
    "pitch_range": (0, 127),
    "beat_res": {(0, 4): 8, (4, 12): 4},
    "num_velocities": 32,
    "special_tokens": ["PAD", "BOS", "EOS", "MASK"],
    "use_chords": True,
    "use_rests": False,
    "use_tempos": True,
    "use_time_signatures": False,
    "use_programs": False,
    "num_tempos": 32,  # number of tempo bins
    "tempo_range": (40, 250),  # (min, max)
}
remi_config = TokenizerConfig(**TOKENIZER_PARAMS)

  remi_config = TokenizerConfig(**TOKENIZER_PARAMS)


In [21]:
plain_tokenizer = REMI(remi_config)

In [12]:
tokenizer = REMI(remi_config)

In [6]:
tokenizer = REMI(params="song2midi/dataset/tokenizer.json")

  self.config = TokenizerConfig()
  return cls(**input_dict, **kwargs)


In [14]:
midi_paths = list(Path("song2midi", "dataset", "midi_blend").glob("*.mid"))

In [15]:
tokenizer.train(vocab_size=50265, files_paths=midi_paths)






In [16]:
tokenizer.save_params("song2midi/dataset/tokenizer.json")

# preprocessing

In [17]:
preprocess_data(tokenizer, music_folder="song2midi/dataset/song_blend", midi_folder="song2midi/dataset/midi_blend")

music id: -7AMMFFiGLU is too long skipping


music id: 7Y-XlFtao_A is too long skipping
music id: bRBn5EY505E is too long skipping
music id: C5tXGwUCD1c is too long skipping
music id: cBC6c99OrOM is too long skipping
music id: CPLK6L1fq7k is too long skipping
music id: f3YhK-oZmY8 is too long skipping
music id: gwOGlPU63Dw is too long skipping
music id: HKtxkk3JvbQ is too long skipping
music id: hTbeVXuWyaU is too long skipping
music id: k0J1z82kVYQ is too long skipping
music id: KpS-9OcYyqk is too long skipping
music id: LEJlWRpqy2U is too long skipping
music id: lzyl7abaPe0 is too long skipping
music id: nFG3l5zxLdM is too long skipping
music id: Pa4EGtUfDvI is too long skipping
music id: r1k3hpTdc60 is too long skipping
music id: RcQy61-7F7c is too long skipping
music id: S8HAUbBLCPM is too long skipping
music id: T4yW6TwjNLM is too long skipping
music id: VnWo9-Dioik is too long skipping
music id: XbZHt8b7YTY is too long skipping
music id: XKROtTZs8iE is too long skipping
music id: Y0ygWXF-5tI is too long skipping


# dataset

In [7]:
dataset = MusicDataset(tokenizer, ArrangerEmbedding(), max_output_length=4096, max_input_length=6144, mel_processing=None)

In [20]:
for a in dataset:
    if a["inputs_embeds"].shape[0] != 6144 or a["attention_mask"].shape[0] != 6144 or a["labels"].shape[0] != 4096:
        print(a["inputs_embeds"].shape, a["attention_mask"].shape, a["labels"].shape)

In [8]:
train_size = int(0.8 * len(dataset))  # 80% for training
val_size = len(dataset) - train_size  # 20% for validation

In [9]:
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Model

In [10]:
from transformers import PretrainedConfig, PreTrainedModel

In [11]:
from reformer_pytorch import ReformerLM, Reformer
from axial_positional_embedding import AxialPositionalEmbedding

In [12]:
class ReformerEncoderDecoderConfig(PretrainedConfig):
    def __init__(self,
                  vocab_size=tokenizer.vocab_size, 
                  d_model=128,
                  num_heads=8, 
                  encoder_layers=6, 
                  decoder_layers=6, 
                  encoder_max_seq_len=6144,
                  decoder_max_seq_len=4096,
                  encoder_axial_position_shape=(96, 64),
                  decoder_axial_position_shape=(64, 64),
                  pad_token_id=0,
                  bos_token_id=1,
                  eos_token_id=2,
                  **kwargs):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.encoder_max_seq_len = encoder_max_seq_len
        self.decoder_max_seq_len = decoder_max_seq_len
        self.encoder_axial_position_shape = encoder_axial_position_shape
        self.decoder_axial_position_shape = decoder_axial_position_shape
        super().__init__(**kwargs)
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

In [14]:
class ReformerEncoderDecoder(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.pad_token_id = config.pad_token_id
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id

        self.encoder = Reformer(
            dim=config.d_model,
            depth=config.encoder_layers,
            heads=config.num_heads,
        )

        self.decoder = ReformerLM(
            dim=config.d_model,
            depth=config.decoder_layers,
            heads=config.num_heads,
            max_seq_len=config.decoder_max_seq_len,
            num_tokens=config.vocab_size,
            axial_position_emb=True,
            axial_position_shape=config.decoder_axial_position_shape,
            causal=True
        )

        self.position_embedding = AxialPositionalEmbedding(
            config.d_model,
            axial_shape=config.encoder_axial_position_shape
        )

    # https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/autopadder.py
    def pad_to_multiple(self, tensor, seq_len, multiple, dim=-1):
        m = seq_len / multiple
        if m.is_integer():
            return tensor
        
        remainder = math.ceil(m) * multiple - seq_len
        pad_offset = (0,) * (-1 - dim) * 2
        return nn.functional.pad(tensor, (*pad_offset, 0, remainder), value=self.pad_token_id)

    # https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/autopadder.py
    # pad_dim = -1 if its LM model else -2
    def auto_paddding(self, input_ids, pad_dim, bucket_size, num_mem_kv, full_attn_thres, keys=None, input_mask=None, input_attn_mask=None):
        device = input_ids.device

        batch_size, t = input_ids.shape[:2]

        keys_len = 0 if keys is None else keys.shape[1]
        seq_len = t + num_mem_kv + keys_len
        

        if seq_len > full_attn_thres:
            if input_mask is None:
                input_mask = torch.full((batch_size, t), True, dtype=torch.bool, device=device)

            input_ids = self.pad_to_multiple(input_ids, seq_len, bucket_size * 2, dim=pad_dim)

            if input_mask is not None:
                input_mask = nn.functional.pad(input_mask, (0, input_ids.shape[1] - input_mask.shape[1]), value=False)

            if input_attn_mask is not None:
                offset = input_ids.shape[1] - input_attn_mask.shape[1]
                input_attn_mask = nn.functional.pad(input_attn_mask, (0, offset, 0, offset), value=False)

        return input_ids, input_mask, input_attn_mask


    def shift_tokens_right(self, input_ids):
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = self.eos_token_id

        if self.pad_token_id is None:
            raise ValueError("config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)

        return shifted_input_ids


    def forward(self, inputs_embeds, attention_mask=None, decoder_input=None, labels=None):
        if decoder_input is None:
            decoder_input = self.shift_tokens_right(labels)

        # encoder
        encoder_input = inputs_embeds + self.position_embedding(inputs_embeds)

        encoder_output = self.encoder(encoder_input, input_mask=attention_mask.bool())

        # decoder
        decoder_output = self.decoder(decoder_input, keys=encoder_output)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(decoder_output.view(-1, self.config.vocab_size), labels.view(-1))
            return {"loss": masked_lm_loss, "logits": decoder_output}
        
        return {"logits": decoder_output}


    @torch.no_grad()
    def generate(self, inputs_embeds, attention_mask=None, max_length=4096, temperature=1.0, top_k=50, top_p=1):
        is_training = self.training
        device = inputs_embeds.device

        # padding settings
        pad_dim = -1
        bucket_size = self.decoder.reformer.bucket_size
        num_mem_kv = self.decoder.reformer.num_mem_kv
        full_attn_thres = self.decoder.reformer.full_attn_thres

        self.eval()

        # encoder
        encoder_input = inputs_embeds + self.position_embedding(inputs_embeds)

        encoder_keys = self.encoder(encoder_input, input_mask=attention_mask.bool())

        # decoder
        generated = torch.tensor([self.bos_token_id], device=device).unsqueeze(0)

        decoder_mask = torch.full_like(generated, True, dtype=torch.bool, device=device)

        for _ in range(max_length):
            generated = generated[:, -self.config.decoder_max_seq_len:]
            decoder_mask = decoder_mask[:, -self.config.decoder_max_seq_len:]

            generated, decoder_mask, _ = self.auto_paddding(generated, 
                                                             pad_dim, 
                                                             bucket_size, 
                                                             num_mem_kv, 
                                                             full_attn_thres, 
                                                             keys=encoder_keys, 
                                                             input_mask=decoder_mask)
            
            logits = self.decoder(generated, input_mask=decoder_mask, keys=encoder_keys)[:, -1, :]  / temperature

            if top_k > 0:
                top_k_values, top_k_indices = torch.topk(logits, top_k)
                filtered_logits = torch.full_like(logits, -float('Inf'))
                logits = filtered_logits.scatter(1, top_k_indices, top_k_values)

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
                
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = 0

                sorted_logits[sorted_indices_to_remove] = -float('Inf')
                logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)

            probs = nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=-1)

            if next_token == self.eos_token_id:
                break

        self.train(is_training)
        return generated





## Absolute pos No use

In [65]:
class AbsolutePositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x):
        t = torch.arange(x.shape[1], device=x.device)
        return self.emb(t)

In [66]:
class ReformerEncoderDecoder(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.pad_token_id = config.pad_token_id
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id

        self.projection = nn.Linear(128, config.d_model, bias=False)

        self.encoder = Reformer(
            dim=config.d_model,
            depth=config.encoder_layers,
            heads=config.num_heads,
        )

        self.decoder = ReformerLM(
            dim=config.d_model,
            depth=config.decoder_layers,
            heads=config.num_heads,
            max_seq_len=config.decoder_max_seq_len,
            num_tokens=config.vocab_size,
            absolute_position_emb=True,
            causal=True
        )

        self.position_embedding = AbsolutePositionalEmbedding(
            config.d_model,
            config.encoder_max_seq_len
        )

    # https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/autopadder.py
    def pad_to_multiple(self, tensor, seq_len, multiple, dim=-1):
        m = seq_len / multiple
        if m.is_integer():
            return tensor
        
        remainder = math.ceil(m) * multiple - seq_len
        pad_offset = (0,) * (-1 - dim) * 2
        return nn.functional.pad(tensor, (*pad_offset, 0, remainder), value=self.pad_token_id)

    # https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/autopadder.py
    # pad_dim = -1 if its LM model else -2
    def auto_paddding(self, input_ids, pad_dim, bucket_size, num_mem_kv, full_attn_thres, keys=None, input_mask=None, input_attn_mask=None):
        device = input_ids.device

        batch_size, t = input_ids.shape[:2]

        keys_len = 0 if keys is None else keys.shape[1]
        seq_len = t + num_mem_kv + keys_len
        

        if seq_len > full_attn_thres:
            if input_mask is None:
                input_mask = torch.full((batch_size, t), True, dtype=torch.bool, device=device)

            input_ids = self.pad_to_multiple(input_ids, seq_len, bucket_size * 2, dim=pad_dim)

            if input_mask is not None:
                input_mask = nn.functional.pad(input_mask, (0, input_ids.shape[1] - input_mask.shape[1]), value=False)

            if input_attn_mask is not None:
                offset = input_ids.shape[1] - input_attn_mask.shape[1]
                input_attn_mask = nn.functional.pad(input_attn_mask, (0, offset, 0, offset), value=False)

        return input_ids, input_mask, input_attn_mask


    def shift_tokens_right(self, input_ids):
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = self.eos_token_id

        if self.pad_token_id is None:
            raise ValueError("config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)

        return shifted_input_ids


    def forward(self, inputs_embeds, attention_mask=None, decoder_input=None, labels=None):
        if decoder_input is None:
            decoder_input = self.shift_tokens_right(labels)

        # encoder
        projected_input = self.projection(inputs_embeds)

        encoder_input = projected_input + self.position_embedding(inputs_embeds)

        encoder_output = self.encoder(encoder_input, input_mask=attention_mask.bool())

        # decoder
        decoder_output = self.decoder(decoder_input, keys=encoder_output)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(decoder_output.view(-1, self.config.vocab_size), labels.view(-1))
            return {"loss": masked_lm_loss, "logits": decoder_output}
        
        return {"logits": decoder_output}


    @torch.no_grad()
    def generate(self, inputs_embeds, attention_mask=None, max_length=4096, temperature=1.0, top_k=50, top_p=1):
        is_training = self.training
        device = inputs_embeds.device

        # padding settings
        pad_dim = -1
        bucket_size = self.decoder.reformer.bucket_size
        num_mem_kv = self.decoder.reformer.num_mem_kv
        full_attn_thres = self.decoder.reformer.full_attn_thres

        self.eval()

        # encoder
        projected_input = self.projection(inputs_embeds)

        encoder_input = projected_input + self.position_embedding(inputs_embeds)

        encoder_keys = self.encoder(encoder_input, input_mask=attention_mask.bool())

        # decoder
        generated = torch.tensor([self.bos_token_id], device=device).unsqueeze(0)

        decoder_mask = torch.full_like(generated, True, dtype=torch.bool, device=device)

        for _ in range(max_length):
            generated = generated[:, -self.config.decoder_max_seq_len:]
            decoder_mask = decoder_mask[:, -self.config.decoder_max_seq_len:]

            generated, decoder_mask, _ = self.auto_paddding(generated, 
                                                             pad_dim, 
                                                             bucket_size, 
                                                             num_mem_kv, 
                                                             full_attn_thres, 
                                                             keys=encoder_keys, 
                                                             input_mask=decoder_mask)
            
            logits = self.decoder(generated, input_mask=decoder_mask, keys=encoder_keys)[:, -1, :]  / temperature

            if top_k > 0:
                top_k_values, top_k_indices = torch.topk(logits, top_k)
                filtered_logits = torch.full_like(logits, -float('Inf'))
                logits = filtered_logits.scatter(1, top_k_indices, top_k_values)

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
                
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = 0

                sorted_logits[sorted_indices_to_remove] = -float('Inf')
                logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)

            probs = nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=-1)

            if next_token == self.eos_token_id:
                break

        self.train(is_training)
        return generated





## training

In [15]:
RedConfig = ReformerEncoderDecoderConfig()

model = ReformerEncoderDecoder(RedConfig).cuda()

In [19]:
a = model(inputs_embeds=torch.rand((1, 6144, 128)).cuda(), attention_mask=torch.zeros(1, 6144).bool().cuda(), labels=torch.randint(0, 50265, (1, 4096)).cuda())

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [16]:
model.load_state_dict(torch.load("song2midi/model.pth"))

<All keys matched successfully>

In [16]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

In [19]:
model

ReformerEncoderDecoder(
  (encoder): Reformer(
    (layers): ReversibleSequence(
      (blocks): ModuleList(
        (0-5): 6 x ReversibleBlock(
          (f): Deterministic(
            (net): PreNorm(
              (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (fn): LSHSelfAttention(
                (toqk): Linear(in_features=128, out_features=128, bias=False)
                (tov): Linear(in_features=128, out_features=128, bias=False)
                (to_out): Linear(in_features=128, out_features=128, bias=True)
                (lsh_attn): LSHAttention(
                  (dropout): Dropout(p=0.0, inplace=False)
                  (dropout_for_hash): Dropout(p=0.0, inplace=False)
                )
                (full_attn): FullQKAttention(
                  (dropout): Dropout(p=0.0, inplace=False)
                )
                (post_attn_dropout): Dropout(p=0.0, inplace=False)
                (local_attn): LocalAttention(
                  (dropout

In [22]:
dataloader_train = DataLoader(dataset, batch_size=4, shuffle=True)
dataloader_val = DataLoader(val_dataset, batch_size=4, shuffle=False)

num_epochs = 100
 
optimizer = AdamW(model.parameters(), lr=0.0003125)
scheduler = ReduceLROnPlateau(optimizer, "min", patience=5, factor=0.3125, threshold=0.01)

old_loss = 1000000
time = 3

for epoch in range(num_epochs):

    # training
    train_loss = 0
    model.train()
    progess_bar = tqdm(dataloader_train)
    for batch in progess_bar:
        optimizer.zero_grad()
        outp = model(
            inputs_embeds=batch["inputs_embeds"].cuda(),
            attention_mask=batch['attention_mask'].to('cuda'),
            labels=batch["labels"].cuda()
            )
        
        loss = outp["loss"]
        loss.backward()

        progess_bar.set_description(f"lr: {scheduler.get_last_lr()[0]} training epoch: {epoch+1} loss: {loss.item()}")

        train_loss += loss.item()
        optimizer.step()

    avg_train_loss = train_loss / len(dataloader_train)
    print(f"training epoch: {epoch+1} avg_loss: {avg_train_loss}")

    # evaluation
    val_loss = 0
    tok_error = 0
    model.eval()
    progess_bar = tqdm(dataloader_val)
    for batch in progess_bar:
        with torch.no_grad():
            outp = model(
                inputs_embeds=batch["inputs_embeds"].cuda(),
                attention_mask=batch['attention_mask'].to('cuda'),
                labels=batch["labels"].cuda()
                )
            
            loss = outp["loss"]
            val_loss += loss.item()

            tok_err = torch.mean(torch.tensor(tokenizer.tokens_errors(torch.argmax(outp["logits"], dim=-1).cpu()))).item()
            tok_error += tok_err

            progess_bar.set_description(f"validation epoch: {epoch+1} loss: {loss.item()} tok_error: {tok_err}")

    avg_val_loss = val_loss / len(dataloader_val)
    avg_tok_error = tok_error / len(dataloader_val)
    print(f"validation epoch: {epoch+1} avg_loss: {avg_val_loss} avg_tok_error: {avg_tok_error}")

    scheduler.step(avg_val_loss)

    if avg_val_loss < old_loss:
        old_loss = avg_val_loss
        torch.save(model.state_dict(), f"song2midi/check/reformer_encoder_decoder_{time}_{epoch+1}.pth")


lr: 0.0003125 training epoch: 1 loss: 3.001537322998047: 100%|████████████| 200/200 [13:25<00:00,  4.03s/it]


training epoch: 1 avg_loss: 3.292593777179718


validation epoch: 1 loss: 2.7438926696777344 tok_error: 0.1187744140625: 100%|█| 40/40 [00:21<00:00,  1.87it


validation epoch: 1 avg_loss: 3.214534121751785 avg_tok_error: 0.136553955078125


lr: 0.0003125 training epoch: 2 loss: 3.1599314212799072: 100%|███████████| 200/200 [13:14<00:00,  3.97s/it]


training epoch: 2 avg_loss: 3.263686016201973


validation epoch: 2 loss: 2.758176326751709 tok_error: 0.09735107421875: 100%|█| 40/40 [00:21<00:00,  1.90it


validation epoch: 2 avg_loss: 3.206604093313217 avg_tok_error: 0.116510009765625


lr: 0.0003125 training epoch: 3 loss: 3.634474992752075: 100%|████████████| 200/200 [13:13<00:00,  3.97s/it]


training epoch: 3 avg_loss: 3.2527723687887193


validation epoch: 3 loss: 2.7613515853881836 tok_error: 0.114013671875: 100%|█| 40/40 [00:21<00:00,  1.89it/


validation epoch: 3 avg_loss: 3.1994731426239014 avg_tok_error: 0.13108978271484376


lr: 0.0003125 training epoch: 4 loss: 3.4382355213165283: 100%|███████████| 200/200 [12:39<00:00,  3.80s/it]


training epoch: 4 avg_loss: 3.251042654514313


validation epoch: 4 loss: 2.770627975463867 tok_error: 0.1146240234375: 100%|█| 40/40 [00:20<00:00,  1.98it/


validation epoch: 4 avg_loss: 3.1884994506835938 avg_tok_error: 0.13695068359375


lr: 0.0003125 training epoch: 5 loss: 2.9310972690582275: 100%|███████████| 200/200 [12:38<00:00,  3.79s/it]


training epoch: 5 avg_loss: 3.2381991171836852


validation epoch: 5 loss: 2.7556653022766113 tok_error: 0.1142578125: 100%|█| 40/40 [00:20<00:00,  1.97it/s]


validation epoch: 5 avg_loss: 3.179256671667099 avg_tok_error: 0.1314971923828125


lr: 0.0003125 training epoch: 6 loss: 3.547208786010742:  94%|███████████▎| 189/200 [11:57<00:41,  3.79s/it]

# Compute Score

In [61]:
import torcheval

In [78]:
def bleu_no_padding(labels, logits, pad_index=-100):
    avg_score = 0
    num = 0
    for label, logit in zip(labels, logits):
        #output = logit.argmax(dim=-1)
        output = logit

        valid = label != pad_index
        bpe_label = label[valid].unsqueeze(0).numpy()
        bpe_output = output[valid].unsqueeze(0).numpy()

        #print(np.array(plain_tokenizer.encode(tokenizer.decode(bpe_label))).tolist()[0])
        no_bpe_label = " ".join([str(num) for num in np.array(plain_tokenizer.encode(tokenizer.decode(bpe_label))).tolist()[0]])
        no_bpe_output = " ".join([str(num) for num in np.array(plain_tokenizer.encode(tokenizer.decode(bpe_output))).tolist()[0]])

        avg_score += torcheval.metrics.functional.bleu_score(no_bpe_output, [no_bpe_label]).item()
        num += 1

    return avg_score / num

In [83]:
# compute cosine similarity between two labels

avg_acc = 0
for batch in DataLoader(val_dataset, batch_size=4, shuffle=False):
    with torch.no_grad():
        outp = model(
            inputs_embeds=batch["inputs_embeds"].cuda(),
            attention_mask=batch['attention_mask'].to('cuda'),
            labels=batch["labels"].cuda()
            )

        avg_acc += bleu_no_padding(batch["labels"], torch.argmax(outp["logits"], dim=-1).cpu())


print("avg_bleu_score", avg_acc / len(val_dataset))

avg_bleu_score 0.10628176755271852
