# Full example with the Hugging Face Transformers package

This notebook shows how to train a model (Mistral) and generate music with it, using the Hugging Face Transformers package.

## Setup Environment

In [1]:
#@title Install all dependencies (run only once per session)

# !nvidia-smi

# !pip install miditok
# !pip install symusic
# !pip install torch
# !pip install torchtoolkit
# !pip install transformers
# !pip install accelerate
# !pip install evaluate
# !pip install tqdm
# !pip install scikit-learn
# !pip install tensorboard

# !wget https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip
# !unzip 'maestro-v3.0.0-midi.zip'
# !rm 'maestro-v3.0.0-midi.zip'
# !mv 'maestro-v3.0.0' 'Maestro'

from copy import deepcopy
from pathlib import Path
from random import shuffle

from torch import Tensor, argmax
from torch.utils.data import DataLoader
from torch.cuda import is_available as cuda_available, is_bf16_supported
from torch.backends.mps import is_available as mps_available
from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, MambaConfig, MambaForCausalLM
from transformers.trainer_utils import set_seed
from evaluate import load as load_metric
from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetTok, DataCollator
from tqdm import tqdm

## Tokenizer initialization and dataset loading

In [2]:
# Seed
set_seed(777)

# Our tokenizer's configuration
PITCH_RANGE = (21, 109)
BEAT_RES = {(0, 1): 8, (1, 2): 4, (2, 4): 2, (4, 8): 1}
NUM_VELOCITIES = 24
SPECIAL_TOKENS = ["PAD", "MASK", "BOS", "EOS"]
USE_CHORDS = False
USE_RESTS = False
USE_TEMPOS = True
USE_TIME_SIGNATURE = False
USE_PROGRAMS = False
NUM_TEMPOS = 32
TEMPO_RANGE = (50, 200)  # (min_tempo, max_tempo)
TOKENIZER_PARAMS = {
    "pitch_range": PITCH_RANGE,
    "beat_res": BEAT_RES,
    "num_velocities": NUM_VELOCITIES,
    "special_tokens": SPECIAL_TOKENS,
    "use_chords": USE_CHORDS,
    "use_rests": USE_RESTS,
    "use_tempos": USE_TEMPOS,
    "use_time_signatures": USE_TIME_SIGNATURE,
    "use_programs": USE_PROGRAMS,
    "num_tempos": NUM_TEMPOS,
    "tempo_range": TEMPO_RANGE,
}
config = TokenizerConfig(**TOKENIZER_PARAMS)

# Creates the tokenizer
# tokenizer = REMI(config)
tokenizer = REMI(params='./tokenizer.json')

# Trains the tokenizer with Byte Pair Encoding (BPE) to build the vocabulary, here 10k tokens
midi_paths = list(Path('../data/maestro-v3.0.0').glob('**/*.mid')) + list(Path('../data/maestro-v3.0.0').glob('**/*.midi'))
# tokenizer.learn_bpe(
#     vocab_size=10000,
#     files_paths=midi_paths,
#     start_from_empty_voc=False,
# )
# tokenizer.save_params("tokenizer.json")

# Split MIDI paths in train/valid/test sets
total_num_files = len(midi_paths)
num_files_valid = round(total_num_files * 0.2)
num_files_test = round(total_num_files * 0.1)
shuffle(midi_paths)
midi_paths_valid = midi_paths[:num_files_valid]
midi_paths_test = midi_paths[num_files_valid:num_files_valid + num_files_test]
midi_paths_train = midi_paths[num_files_valid + num_files_test:]

# Loads tokens and create data collator
kwargs_dataset = {"min_seq_len": 256, "max_seq_len": 1024, "tokenizer": tokenizer}
dataset_train = DatasetTok(midi_paths_train, **kwargs_dataset)
dataset_valid = DatasetTok(midi_paths_valid, **kwargs_dataset)
dataset_test = DatasetTok(midi_paths_test, **kwargs_dataset)
collator = DataCollator(
    tokenizer["PAD_None"], tokenizer["BOS_None"], tokenizer["EOS_None"]
)


Loading data: ../data/maestro-v3.0.0/2004: 100%|██████████| 893/893 [00:45<00:00, 19.69it/s]
Loading data: ../data/maestro-v3.0.0/2015: 100%|██████████| 255/255 [00:13<00:00, 19.06it/s]
Loading data: ../data/maestro-v3.0.0/2008: 100%|██████████| 128/128 [00:07<00:00, 17.91it/s]


## Model initialization

We will use the [Mistral implementation of Hugging Face](https://huggingface.co/docs/transformers/model_doc/mistral).
Feel free to explore the documentation and source code to dig deeper.

**You may need to adjust the model's configuration, the training configuration and the maximum input sequence length (cell above) depending on your hardware.**

In [3]:
# Creates model
model_config = MistralConfig(
    vocab_size=len(tokenizer),
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=8,
    num_attention_heads=8,
    num_key_value_heads=4,
    sliding_window=256,
    max_position_embeddings=8192,
    pad_token_id=tokenizer['PAD_None'],
    bos_token_id=tokenizer['BOS_None'],
    eos_token_id=tokenizer['EOS_None'],
)
model = AutoModelForCausalLM.from_config(model_config)

In [3]:
model_config = MambaConfig(
    vocab_size=len(tokenizer),
    hidden_size=192,
    state_size=8,
    max_position_embeddings=8192,
    num_hidden_layers=16,
    pad_token_id=tokenizer['PAD_None'],
    bos_token_id=tokenizer['BOS_None'],
    eos_token_id=tokenizer['EOS_None'],
)

model = MambaForCausalLM(model_config)
# model = AutoModelForCausalLM.from_pretrained(model_config)

In [4]:
model.num_parameters()

5800128

## Model training

In [5]:
def preprocess_logits(logits: Tensor, _: Tensor) -> Tensor:
    """
    Preprocess the logits before accumulating them during evaluation.

    This allows to significantly reduce the memory usage and make the training tractable.
    """
    pred_ids = argmax(logits, dim=-1)  # long dtype
    return pred_ids

# Create config for the Trainer
USE_CUDA = cuda_available()
if not cuda_available():
    FP16 = FP16_EVAL = BF16 = BF16_EVAL = False
elif is_bf16_supported():
    BF16 = BF16_EVAL = True
    FP16 = FP16_EVAL = False
else:
    BF16 = BF16_EVAL = False
    FP16 = FP16_EVAL = True
USE_MPS = not USE_CUDA and mps_available()
training_config = TrainingArguments(
    "runs", False, True, False, False, "steps",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=48,
    gradient_accumulation_steps=3,
    eval_accumulation_steps=None,
    eval_steps=250,
    learning_rate=1e-4,
    weight_decay=0.01,
    max_grad_norm=3.0,
    max_steps=2000,
    lr_scheduler_type="cosine_with_restarts",
    warmup_ratio=0.3,
    log_level="debug",
    logging_strategy="steps",
    logging_steps=20,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=5,
    no_cuda=not USE_CUDA,
    seed=444,
    fp16=FP16,
    fp16_full_eval=FP16_EVAL,
    bf16=BF16,
    bf16_full_eval=BF16_EVAL,
    load_best_model_at_end=True,
    label_smoothing_factor=0.,
    optim="adamw_torch",
    report_to=["tensorboard"],
    gradient_checkpointing=True,
)

collator = DataCollator(tokenizer["PAD_None"], tokenizer["BOS_None"], tokenizer["EOS_None"], copy_inputs_as_labels=True)
trainer = Trainer(
    model=model,
    args=training_config,
    data_collator=collator,
    train_dataset=dataset_train,
    eval_dataset=dataset_valid,
    callbacks=None,
    preprocess_logits_for_metrics=preprocess_logits,
)

# Training
train_result = trainer.train()
trainer.save_model()  # Saves the tokenizer too
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
max_steps is given, it will override any value given in num_train_epochs
Using auto half precision backend
Currently training with a batch size of: 16
***** Running training *****
  Num examples = 7,658
  Num Epochs = 13
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 48
  Gradient Accumulation steps = 3
  Total optimization steps = 2,000
  Number of trainable paramet

  0%|          | 0/2000 [00:00<?, ?it/s]



{'loss': 10.8543, 'grad_norm': 1.1173369884490967, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.13}
{'loss': 10.7975, 'grad_norm': 1.087349772453308, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.25}
{'loss': 10.6804, 'grad_norm': 1.0429495573043823, 'learning_rate': 1e-05, 'epoch': 0.38}
{'loss': 10.5176, 'grad_norm': 1.0284433364868164, 'learning_rate': 1.3333333333333333e-05, 'epoch': 0.5}
{'loss': 10.3007, 'grad_norm': 1.1212576627731323, 'learning_rate': 1.6666666666666667e-05, 'epoch': 0.63}
{'loss': 9.9798, 'grad_norm': 1.0334638357162476, 'learning_rate': 2e-05, 'epoch': 0.75}
{'loss': 9.6176, 'grad_norm': 0.8386845588684082, 'learning_rate': 2.3333333333333336e-05, 'epoch': 0.88}
{'loss': 9.339, 'grad_norm': 0.6770079731941223, 'learning_rate': 2.6666666666666667e-05, 'epoch': 1.0}
{'loss': 9.1099, 'grad_norm': 0.5555304884910583, 'learning_rate': 3e-05, 'epoch': 1.13}
{'loss': 8.9373, 'grad_norm': 0.4903627634048462, 'learning_rate': 3.3333333333333335e-05, 'epoch

***** Running Evaluation *****
  Num examples = 2244
  Batch size = 48


  0%|          | 0/47 [00:00<?, ?it/s]

{'eval_loss': 8.493900299072266, 'eval_runtime': 9.1385, 'eval_samples_per_second': 245.553, 'eval_steps_per_second': 5.143, 'epoch': 1.57}
{'loss': 8.4852, 'grad_norm': 0.3856232762336731, 'learning_rate': 4.3333333333333334e-05, 'epoch': 1.63}
{'loss': 8.396, 'grad_norm': 0.36467087268829346, 'learning_rate': 4.666666666666667e-05, 'epoch': 1.75}
{'loss': 8.2906, 'grad_norm': 0.31746095418930054, 'learning_rate': 5e-05, 'epoch': 1.88}
{'loss': 8.1911, 'grad_norm': 0.3036254048347473, 'learning_rate': 5.333333333333333e-05, 'epoch': 2.0}
{'loss': 8.1044, 'grad_norm': 0.266647070646286, 'learning_rate': 5.666666666666667e-05, 'epoch': 2.13}
{'loss': 8.0477, 'grad_norm': 0.237630695104599, 'learning_rate': 6e-05, 'epoch': 2.25}
{'loss': 8.0059, 'grad_norm': 0.2264871895313263, 'learning_rate': 6.333333333333333e-05, 'epoch': 2.38}
{'loss': 7.9548, 'grad_norm': 0.19846822321414948, 'learning_rate': 6.666666666666667e-05, 'epoch': 2.51}
{'loss': 7.9146, 'grad_norm': 0.18411993980407715, '

***** Running Evaluation *****
  Num examples = 2244
  Batch size = 48


{'loss': 7.7709, 'grad_norm': 0.281116783618927, 'learning_rate': 8.333333333333334e-05, 'epoch': 3.13}


  0%|          | 0/47 [00:00<?, ?it/s]

{'eval_loss': 7.728194236755371, 'eval_runtime': 9.062, 'eval_samples_per_second': 247.628, 'eval_steps_per_second': 5.187, 'epoch': 3.13}
{'loss': 7.7053, 'grad_norm': 0.41264280676841736, 'learning_rate': 8.666666666666667e-05, 'epoch': 3.26}
{'loss': 7.6575, 'grad_norm': 0.4622027277946472, 'learning_rate': 9e-05, 'epoch': 3.38}
{'loss': 7.5817, 'grad_norm': 0.4035986661911011, 'learning_rate': 9.333333333333334e-05, 'epoch': 3.51}
{'loss': 7.5178, 'grad_norm': 0.34891533851623535, 'learning_rate': 9.666666666666667e-05, 'epoch': 3.63}
{'loss': 7.4341, 'grad_norm': 0.6543633937835693, 'learning_rate': 0.0001, 'epoch': 3.76}
{'loss': 7.4323, 'grad_norm': 0.49501627683639526, 'learning_rate': 9.994965332706573e-05, 'epoch': 3.88}
{'loss': 7.3983, 'grad_norm': 0.4102728068828583, 'learning_rate': 9.979871469976196e-05, 'epoch': 4.01}
{'loss': 7.3499, 'grad_norm': 0.7887152433395386, 'learning_rate': 9.954748808839674e-05, 'epoch': 4.13}
{'loss': 7.2932, 'grad_norm': 0.5180826783180237,

***** Running Evaluation *****
  Num examples = 2244
  Batch size = 48


  0%|          | 0/47 [00:00<?, ?it/s]

{'eval_loss': 7.122819423675537, 'eval_runtime': 9.2801, 'eval_samples_per_second': 241.808, 'eval_steps_per_second': 5.065, 'epoch': 4.7}
{'loss': 7.132, 'grad_norm': 0.8340454697608948, 'learning_rate': 9.681174353198687e-05, 'epoch': 4.76}
{'loss': 7.091, 'grad_norm': 0.5852659344673157, 'learning_rate': 9.597638862757255e-05, 'epoch': 4.89}
{'loss': 7.0676, 'grad_norm': 0.361488401889801, 'learning_rate': 9.504844339512095e-05, 'epoch': 5.01}
{'loss': 7.0335, 'grad_norm': 0.46568670868873596, 'learning_rate': 9.40297765928369e-05, 'epoch': 5.14}
{'loss': 6.9799, 'grad_norm': 0.5836212635040283, 'learning_rate': 9.292243968009331e-05, 'epoch': 5.26}
{'loss': 6.9266, 'grad_norm': 0.6494702100753784, 'learning_rate': 9.172866268606513e-05, 'epoch': 5.39}
{'loss': 6.9011, 'grad_norm': 0.7492420077323914, 'learning_rate': 9.045084971874738e-05, 'epoch': 5.51}
{'loss': 6.8344, 'grad_norm': 0.4563840329647064, 'learning_rate': 8.90915741234015e-05, 'epoch': 5.64}
{'loss': 6.8082, 'grad_no

***** Running Evaluation *****
  Num examples = 2244
  Batch size = 48


{'loss': 6.6563, 'grad_norm': 0.7212328910827637, 'learning_rate': 8.117449009293668e-05, 'epoch': 6.26}


  0%|          | 0/47 [00:00<?, ?it/s]

Saving model checkpoint to runs/checkpoint-1000
Configuration saved in runs/checkpoint-1000/config.json
Configuration saved in runs/checkpoint-1000/generation_config.json
Model weights saved in runs/checkpoint-1000/model.safetensors


{'eval_loss': 6.6417436599731445, 'eval_runtime': 9.2805, 'eval_samples_per_second': 241.797, 'eval_steps_per_second': 5.064, 'epoch': 6.26}




{'loss': 6.6352, 'grad_norm': 0.7302793264389038, 'learning_rate': 7.938926261462366e-05, 'epoch': 6.39}
{'loss': 6.6108, 'grad_norm': 0.49132946133613586, 'learning_rate': 7.754484907260513e-05, 'epoch': 6.51}
{'loss': 6.5829, 'grad_norm': 0.7991334795951843, 'learning_rate': 7.564496387029532e-05, 'epoch': 6.64}
{'loss': 6.5513, 'grad_norm': 0.6123929619789124, 'learning_rate': 7.369343312364993e-05, 'epoch': 6.76}
{'loss': 6.5168, 'grad_norm': 0.8323113918304443, 'learning_rate': 7.169418695587791e-05, 'epoch': 6.89}
{'loss': 6.4985, 'grad_norm': 0.6493431329727173, 'learning_rate': 6.965125158269619e-05, 'epoch': 7.01}
{'loss': 6.475, 'grad_norm': 0.947213888168335, 'learning_rate': 6.756874120406714e-05, 'epoch': 7.14}
{'loss': 6.4686, 'grad_norm': 0.7997612953186035, 'learning_rate': 6.545084971874738e-05, 'epoch': 7.27}
{'loss': 6.4473, 'grad_norm': 1.0834264755249023, 'learning_rate': 6.330184227833376e-05, 'epoch': 7.39}
{'loss': 6.434, 'grad_norm': 0.9170778393745422, 'learni

***** Running Evaluation *****
  Num examples = 2244
  Batch size = 48


  0%|          | 0/47 [00:00<?, ?it/s]

{'eval_loss': 6.379364490509033, 'eval_runtime': 9.1896, 'eval_samples_per_second': 244.19, 'eval_steps_per_second': 5.114, 'epoch': 7.83}
{'loss': 6.3945, 'grad_norm': 0.5491718649864197, 'learning_rate': 5.448196544517168e-05, 'epoch': 7.89}
{'loss': 6.368, 'grad_norm': 0.4666157066822052, 'learning_rate': 5.2243241517525754e-05, 'epoch': 8.02}
{'loss': 6.3646, 'grad_norm': 0.5591413378715515, 'learning_rate': 5e-05, 'epoch': 8.14}
{'loss': 6.3454, 'grad_norm': 0.5666479468345642, 'learning_rate': 4.775675848247427e-05, 'epoch': 8.27}
{'loss': 6.3303, 'grad_norm': 0.5538848638534546, 'learning_rate': 4.551803455482833e-05, 'epoch': 8.39}
{'loss': 6.3255, 'grad_norm': 0.5233960151672363, 'learning_rate': 4.328833670911724e-05, 'epoch': 8.52}
{'loss': 6.3087, 'grad_norm': 0.6081569194793701, 'learning_rate': 4.107215526006817e-05, 'epoch': 8.64}
{'loss': 6.3241, 'grad_norm': 0.37131088972091675, 'learning_rate': 3.887395330218429e-05, 'epoch': 8.77}
{'loss': 6.3025, 'grad_norm': 0.5819

***** Running Evaluation *****
  Num examples = 2244
  Batch size = 48


{'loss': 6.2747, 'grad_norm': 0.47911709547042847, 'learning_rate': 2.8305813044122097e-05, 'epoch': 9.39}


  0%|          | 0/47 [00:00<?, ?it/s]

{'eval_loss': 6.2642974853515625, 'eval_runtime': 9.2038, 'eval_samples_per_second': 243.812, 'eval_steps_per_second': 5.107, 'epoch': 9.39}
{'loss': 6.2627, 'grad_norm': 0.6638137698173523, 'learning_rate': 2.630656687635007e-05, 'epoch': 9.52}
{'loss': 6.2618, 'grad_norm': 0.6825015544891357, 'learning_rate': 2.43550361297047e-05, 'epoch': 9.65}
{'loss': 6.2671, 'grad_norm': 0.4511934816837311, 'learning_rate': 2.245515092739488e-05, 'epoch': 9.77}
{'loss': 6.2483, 'grad_norm': 0.48987722396850586, 'learning_rate': 2.061073738537635e-05, 'epoch': 9.9}
{'loss': 6.2506, 'grad_norm': 0.48589959740638733, 'learning_rate': 1.8825509907063327e-05, 'epoch': 10.02}
{'loss': 6.2253, 'grad_norm': 0.3956669569015503, 'learning_rate': 1.7103063703014372e-05, 'epoch': 10.15}
{'loss': 6.2215, 'grad_norm': 0.49434563517570496, 'learning_rate': 1.544686755065677e-05, 'epoch': 10.27}
{'loss': 6.2498, 'grad_norm': 0.3422397971153259, 'learning_rate': 1.3860256808630428e-05, 'epoch': 10.4}
{'loss': 6.2

***** Running Evaluation *****
  Num examples = 2244
  Batch size = 48


  0%|          | 0/47 [00:00<?, ?it/s]

{'eval_loss': 6.223368167877197, 'eval_runtime': 9.1226, 'eval_samples_per_second': 245.981, 'eval_steps_per_second': 5.152, 'epoch': 10.96}
{'loss': 6.2283, 'grad_norm': 0.420021116733551, 'learning_rate': 7.077560319906695e-06, 'epoch': 11.02}
{'loss': 6.2214, 'grad_norm': 0.369074285030365, 'learning_rate': 5.9702234071631e-06, 'epoch': 11.15}
{'loss': 6.2199, 'grad_norm': 0.3708091378211975, 'learning_rate': 4.951556604879048e-06, 'epoch': 11.27}
{'loss': 6.2219, 'grad_norm': 0.39810821413993835, 'learning_rate': 4.023611372427471e-06, 'epoch': 11.4}
{'loss': 6.221, 'grad_norm': 0.3980230391025543, 'learning_rate': 3.18825646801314e-06, 'epoch': 11.52}
{'loss': 6.2061, 'grad_norm': 0.3656119108200073, 'learning_rate': 2.4471741852423237e-06, 'epoch': 11.65}
{'loss': 6.2133, 'grad_norm': 0.31287604570388794, 'learning_rate': 1.8018569652073381e-06, 'epoch': 11.77}
{'loss': 6.223, 'grad_norm': 0.4212947189807892, 'learning_rate': 1.2536043909088191e-06, 'epoch': 11.9}
{'loss': 6.2281

***** Running Evaluation *****
  Num examples = 2244
  Batch size = 48


{'loss': 6.2218, 'grad_norm': 0.41447681188583374, 'learning_rate': 0.0, 'epoch': 12.53}


  0%|          | 0/47 [00:00<?, ?it/s]

Saving model checkpoint to runs/checkpoint-2000
Configuration saved in runs/checkpoint-2000/config.json
Configuration saved in runs/checkpoint-2000/generation_config.json
Model weights saved in runs/checkpoint-2000/model.safetensors


Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from runs/checkpoint-2000 (score: 6.217372417449951).
There were missing keys in the checkpoint model loaded: ['lm_head.weight'].
Saving model checkpoint to runs
Configuration saved in runs/config.json
Configuration saved in runs/generation_config.json
Model weights saved in runs/model.safetensors


{'eval_loss': 6.217372417449951, 'eval_runtime': 9.069, 'eval_samples_per_second': 247.436, 'eval_steps_per_second': 5.182, 'epoch': 12.53}
{'train_runtime': 1813.704, 'train_samples_per_second': 52.93, 'train_steps_per_second': 1.103, 'train_loss': 7.168868602752686, 'epoch': 12.53}
***** train metrics *****
  epoch                    =      12.53
  train_loss               =     7.1689
  train_runtime            = 0:30:13.70
  train_samples_per_second =      52.93
  train_steps_per_second   =      1.103


## Generate music

In [8]:
(gen_results_path := Path('gen_res')).mkdir(parents=True, exist_ok=True)
generation_config = GenerationConfig(
    max_new_tokens=512,  # extends samples by 512 tokens
    num_beams=1,        # no beam search
    do_sample=True,     # but sample instead
    temperature=0.9,
    top_k=15,
    top_p=0.95,
    epsilon_cutoff=3e-4,
    eta_cutoff=1e-3,
    # pad_token_id=config.padding_token_id,
)

# Here the sequences are padded to the left, so that the last token along the time dimension
# is always the last token of each seq, allowing to efficiently generate by batch
collator.pad_on_left = True
collator.eos_token = None
dataloader_test = DataLoader(dataset_test, batch_size=16, collate_fn=collator)

model = MambaForCausalLM.from_pretrained('./runs/checkpoint-2000')

model.eval()
count = 0
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):  # (N,T)
    res = model.generate(
        inputs=batch["input_ids"].to(model.device),
        attention_mask=batch["attention_mask"].to(model.device),
        generation_config=generation_config)  # (N,T)

    # Saves the generated music, as MIDI files and tokens (json)
    for prompt, continuation in zip(batch["input_ids"], res):
        generated = continuation[len(prompt):]
        midi = tokenizer.tokens_to_midi([deepcopy(generated.tolist())])
        tokens = [generated, prompt, continuation]  # list compr. as seqs of dif. lengths
        tokens = [seq.tolist() for seq in tokens]
        # for tok_seq in tokens[1:]:
        #     _midi = tokenizer.tokens_to_midi([deepcopy(tok_seq)])
        #     midi.instruments.append(_midi.instruments[0])
        # midi.instruments[0].name = f'Continuation of original sample ({len(generated)} tokens)'
        # midi.instruments[1].name = f'Original sample ({len(prompt)} tokens)'
        # midi.instruments[2].name = f'Original sample and continuation'
        midi.dump_midi(gen_results_path / f'{count}.mid')
        tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json') 

        count += 1
    break

loading configuration file ./runs/checkpoint-2000/config.json
Model config MambaConfig {
  "architectures": [
    "MambaForCausalLM"
  ],
  "bos_token_id": 2,
  "conv_kernel": 4,
  "eos_token_id": 3,
  "expand": 2,
  "hidden_act": "silu",
  "hidden_size": 192,
  "initializer_range": 0.1,
  "intermediate_size": 384,
  "layer_norm_epsilon": 1e-05,
  "max_position_embeddings": 8192,
  "model_type": "mamba",
  "num_hidden_layers": 16,
  "pad_token_id": 0,
  "rescale_prenorm_residual": false,
  "residual_in_fp32": true,
  "state_size": 8,
  "time_step_floor": 0.0001,
  "time_step_init_scheme": "random",
  "time_step_max": 0.1,
  "time_step_min": 0.001,
  "time_step_rank": 12,
  "time_step_scale": 1.0,
  "torch_dtype": "float32",
  "transformers_version": "4.40.0.dev0",
  "use_bias": false,
  "use_cache": true,
  "use_conv_bias": true,
  "vocab_size": 10000
}

loading weights file ./runs/checkpoint-2000/model.safetensors
Generate config GenerationConfig {
  "bos_token_id": 2,
  "eos_token_id

In [8]:
dataloader_test = DataLoader(dataset_test, batch_size=16, collate_fn=collator)
model.eval()

for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):
    res = model.generate(batch['input_ids'].to(model.device), max_new_tokens=100)
    break

res

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Testing model / Generating results:   0%|          | 0/74 [00:00<?, ?it/s]


RuntimeError: Expected conv_state.scalar_type() == input_type to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)