In [None]:
# -*- coding: utf-8 -*-
"""SymphonicAI.ipynb"""

In [None]:
# Extract the augmented MIDI dataset.
#   -q : quiet, so the cell output is not flooded with one line per extracted file.
#   -n : never overwrite existing files, so re-running the cell neither prompts
#        for confirmation (which stalls a notebook cell) nor redoes the extraction.
!unzip -q -n /content/DATASET_AUGMENTED.zip -d /content/DATASET_AUGMENTED/

In [None]:
!pip install -qq torch midi_neural_processor pretty_midi tqdm utils

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import tqdm
from pathlib import Path
import midi_neural_processor.processor as midi_tokenizer

In [None]:
"""Defining model and dataset configuration"""

# Number of distinct event ids produced by the tokenizer: the four event
# families (note-on, note-off, time-shift, velocity) laid out back to back.
event_range = (
    midi_tokenizer.RANGE_NOTE_ON
    + midi_tokenizer.RANGE_NOTE_OFF
    + midi_tokenizer.RANGE_TIME_SHIFT
    + midi_tokenizer.RANGE_VEL
)

# Model / dataset hyper-parameters. The three special tokens are given the
# first ids past the tokenizer's own range, so the vocabulary is event_range + 3.
CONFIG = dict(
    max_sequence_len=1024,
    block_size=512,
    embedding_dim=256,
    num_heads=8,
    num_layers=6,  # number of transformer blocks
    batch_size=16,
    token_pad=event_range,
    token_sos=event_range + 1,
    token_eos=event_range + 2,
    vocab_size=event_range + 3,
    seed=42,
    model_out="midi_transformer_model.pth",
)

In [None]:
"""Splitting dataset into train and evaluation files"""

DATASET_DIR = "/content/DATASET_AUGMENTED/DATASET_AUGMENTED"

# Apply the configured seed before the first stochastic operation so that the
# shuffle below (and later torch-driven randomness such as weight init and
# DataLoader shuffling) is reproducible under Restart & Run All.
random.seed(CONFIG["seed"])
np.random.seed(CONFIG["seed"])
torch.manual_seed(CONFIG["seed"])

# os.listdir returns entries in arbitrary order; sort first so the seeded
# shuffle always starts from the same sequence.
all_midis = sorted(f for f in os.listdir(DATASET_DIR) if f.endswith(".mid"))
random.shuffle(all_midis)

n = len(all_midis)
n_train = int(0.9 * n)
# The validation split is "everything that is not training"; computing it as
# int(0.1 * n) can be off by one from the actual slice length.
n_validation = n - n_train

train_files = all_midis[:n_train]
validation_files = all_midis[n_train:]
assert len(train_files) + len(validation_files) == n

print(f"{n} files: {len(train_files)} train / {len(validation_files)} validation")

In [None]:
# Tokenize (at most) the first 3000 training files and record how many
# tokens each one produces, to get a feel for the length distribution.
number_of_tokens = []

for f in train_files[:3000]:
    midi_path = os.path.join(DATASET_DIR, f)
    tokens = midi_tokenizer.encode_midi(midi_path)
    number_of_tokens.append(len(tokens))

print(number_of_tokens)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Pair every sampled file with its token count and print summary statistics.
df = pd.DataFrame({
    "file": train_files[:3000],
    "token_count": number_of_tokens
})
print(df["token_count"].describe())

# Histogram of token counts, drawn through the explicit figure/axes interface.
fig, ax = plt.subplots(figsize=(8,4))
ax.hist(number_of_tokens, bins=50)
ax.set_title("Distribution of MIDI Token Counts")
ax.set_xlabel("Number of Tokens")
ax.set_ylabel("Number of Files")
fig.tight_layout()
plt.show()

In [None]:
"""ENCODE EACH MIDI FILE AND DIVIDE IT INTO CHUNKS OF 512 TOKENS WITH SOS, EOS AND PAD"""

# Every training sequence is exactly `block_size` tokens long:
#   SOS + up to (block_size - 2) MIDI events + EOS + PAD... (PAD only if short)
seq_len = CONFIG['block_size']
events_per_seq = seq_len - 2  # room left once SOS and EOS are accounted for

list_of_512_seqs = []

for f in tqdm.tqdm(all_midis, desc="Encoding MIDI files into 512 sequences"):
    tokens = midi_tokenizer.encode_midi(os.path.join(DATASET_DIR, f))

    for i in range(0, len(tokens), events_per_seq):
        chunk = tokens[i : i + events_per_seq]

        # EOS goes directly after the last real event and padding follows it.
        # Placing EOS after the PAD run (as before) means that, with PAD targets
        # ignored by the loss, the model is never trained to predict EOS after
        # actual music in a short chunk -- only at the very last position.
        seq512 = [CONFIG['token_sos']] + chunk + [CONFIG['token_eos']]
        seq512 += [CONFIG['token_pad']] * (seq_len - len(seq512))
        assert len(seq512) == seq_len

        list_of_512_seqs.append(seq512)

# NOTE(review): chunks cut from the middle of a file are also wrapped in
# SOS/EOS; confirm that this is intended.
print(f"Total 512-token sequences: {len(list_of_512_seqs)}")

In [None]:
# 90/10 positional split of the fixed-length sequences, plus a few sanity prints.
n = len(list_of_512_seqs)
print(n)

split_at = int(0.9 * n)
train_seqs, val_seqs = list_of_512_seqs[:split_at], list_of_512_seqs[split_at:]

for part in (train_seqs, val_seqs):
    print(len(part))

print(len(list_of_512_seqs))
# Show the first three sequences as a spot check.
for k in range(3):
    print(list_of_512_seqs[k])

print(type(train_seqs))
print(len(train_seqs[0]))

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Materialise both splits as int64 tensors on `device` and echo their shape,
# dtype and first 100 rows.
train_data = torch.tensor(train_seqs, dtype=torch.long, device=device)
print(train_data.shape, train_data.dtype)
print(train_data[:100])

test_data = torch.tensor(val_seqs, dtype=torch.long, device=device)
print(test_data.shape, test_data.dtype)
print(test_data[:100])

# Next-token prediction: the input is every token but the last, the target is
# the same sequence shifted left by one position.
X_train, Y_train = train_data[..., :-1], train_data[..., 1:]
X_val, Y_val = test_data[..., :-1], test_data[..., 1:]

train_ds = TensorDataset(X_train, Y_train)
val_ds = TensorDataset(X_val, Y_val)

# Options shared by both loaders; only the training loader shuffles.
loader_options = dict(batch_size=CONFIG['batch_size'], drop_last=True)
train_loader = DataLoader(train_ds, shuffle=True, **loader_options)
val_loader = DataLoader(val_ds, shuffle=False, **loader_options)

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class Head(nn.Module):
    """A single head of causal (masked) scaled dot-product self-attention.

    Args:
        head_size: output width of the key/query/value projections.
        n_embd: width of the incoming embeddings.
        block_size: side length of the lower-triangular causal mask buffer.
        dropout: dropout probability applied to the head's output.
    """

    def __init__(self, head_size, n_embd, block_size, dropout=0.1):
        super().__init__()
        # Bias-free projections; created in key, query, value order.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones: position t may look at positions <= t.
        causal_mask = torch.ones(block_size, block_size).tril()
        self.register_buffer("tril", causal_mask)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over `x` of shape (B, T, C) and return a (B, T, head_size) tensor."""
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)

        # Scaled dot-product scores, shape (B, T, T).
        scale = keys.size(-1) ** 0.5
        scores = queries @ keys.transpose(-2, -1) / scale

        # Hide future positions before normalising.
        is_future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(is_future, float("-inf"))
        attention = F.softmax(scores, dim=-1)

        return self.dropout(attention @ values)

class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side, concatenated and projected.

    Each head has width ``n_embd // num_heads``; the concatenated result is
    passed through a linear projection of width ``n_embd`` and then dropout.
    """

    def __init__(self, num_heads, n_embd, block_size, dropout=0.1):
        super().__init__()
        per_head = n_embd // num_heads
        heads = []
        for _ in range(num_heads):
            heads.append(Head(per_head, n_embd, block_size, dropout))
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply every head to `x`, join on the feature axis, project, drop out."""
        head_outputs = [head(x) for head in self.heads]
        merged = torch.cat(head_outputs, dim=-1)
        return self.dropout(self.proj(merged))

class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x width, GELU, project back, dropout."""

    def __init__(self, n_embd, dropout=0.1):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` through the MLP; the last dimension stays `n_embd`."""
        out = self.net(x)
        return out

class Block(nn.Module):
    """Pre-norm transformer block: self-attention then feed-forward, each residual."""

    def __init__(self, n_embd, n_head, block_size, dropout=0.1):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd, block_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return `x` after the attention and feed-forward residual updates."""
        # LayerNorm is applied to the input of each sub-layer, and the
        # sub-layer's output is added back onto the un-normalised stream.
        after_attention = x + self.sa(self.ln1(x))
        return after_attention + self.ffwd(self.ln2(after_attention))

class TransformerMIDILanguageModel(nn.Module):
    """Decoder-only transformer language model over MIDI event tokens.

    Args:
        config: mapping that provides ``block_size``, ``vocab_size``,
            ``token_sos``, ``token_pad``, ``token_eos``, ``embedding_dim``,
            ``num_heads`` and ``num_layers``.
    """

    def __init__(self, config):
        super().__init__()

        self.block_size    = config['block_size']
        self.vocab_size    = config['vocab_size']
        self.sos_token_id  = config['token_sos']
        self.pad_token_id  = config['token_pad']
        self.eos_token_id  = config['token_eos']

        # padding_idx keeps the PAD embedding row at zero and out of the gradient.
        self.token_embedding_table = nn.Embedding(
            self.vocab_size,
            config['embedding_dim'],
            padding_idx=self.pad_token_id
        )

        # One learned vector per position, so inputs can be at most block_size long.
        self.position_embedding_table = nn.Embedding(self.block_size, config['embedding_dim'])

        self.blocks = nn.Sequential(*[
            Block(n_embd=config['embedding_dim'],
                  n_head=config['num_heads'],
                  block_size=config['block_size'],
                  dropout=0.1)
            for _ in range(config['num_layers'])
            ])

        self.ln_f = nn.LayerNorm(config['embedding_dim'])
        self.lm_head = nn.Linear(config['embedding_dim'], self.vocab_size)


    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: integer tensor of token ids, shape (B, T) with T <= block_size.
            targets: optional integer tensor of the same shape as `idx`.

        Returns:
            ``(logits, loss)`` where logits has shape (B, T, vocab_size) and
            loss is the cross-entropy over all positions whose target is not
            PAD, or ``None`` when `targets` is not given.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))

        x = self.blocks(tok_emb + pos_emb)
        x = self.ln_f(x)

        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, self.vocab_size), targets.view(-1), ignore_index=self.pad_token_id)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, eos_token_id, temperature=1.0, top_k=None):
        """Autoregressively sample a continuation of `idx`.

        Args:
            idx: integer tensor of prompt token ids, shape (B, T). Only the
                last `block_size` tokens are fed to the model at each step.
            max_new_tokens: upper bound on the number of sampled tokens.
            eos_token_id: id that ends a sequence, or ``None`` to always
                sample `max_new_tokens` tokens.
            temperature: divisor applied to the logits before sampling.
            top_k: if given, sample only among the k most likely tokens
                (values larger than the vocabulary are clamped).

        Returns:
            Tensor of shape (B, T + n) holding the prompt followed by the n
            sampled tokens. Sampling stops once every row has produced EOS;
            a row that finished earlier than the others is filled with PAD
            after its EOS.
        """
        # Rows that have already produced EOS. The previous implementation only
        # stopped when *all* rows sampled EOS on the same step, so in a batch the
        # finished rows kept generating tokens past their EOS.
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)

        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)

            # Only the prediction for the final position is needed.
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                # torch.topk raises when k exceeds the vocabulary size.
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k, dim=-1)
                min_vals = v[:, -1].unsqueeze(1)
                logits = logits.masked_fill(logits < min_vals, float("-inf"))

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            if eos_token_id is not None:
                # Freeze rows that ended on an earlier step, then record new EOS hits.
                idx_next = idx_next.masked_fill(finished.unsqueeze(1), self.pad_token_id)
                finished = finished | idx_next.eq(eos_token_id).view(-1)

            idx = torch.cat([idx, idx_next], dim=1)

            if eos_token_id is not None and finished.all():
                break

        return idx

In [None]:
# Build the network on the selected device, then report its size and layout.
model = TransformerMIDILanguageModel(CONFIG).to(device)
parameter_count = sum(p.numel() for p in model.parameters())
print(parameter_count / 1e6, "M parameters")
print(model)

In [None]:
@torch.no_grad()
def estimate_loss():
    """Average the loss over up to `eval_iters` batches of each split.

    Reads the module-level `model`, `train_loader`, `val_loader`,
    `eval_iters` and `device`.

    Returns:
        dict with keys ``"train"`` and ``"val"`` mapping to the mean batch
        loss (a float). A split that yields no batches maps to ``nan``
        instead of raising ZeroDivisionError.
    """
    out = {}
    model.eval()
    try:
        for split, loader in [("train", train_loader), ("val", val_loader)]:
            losses = []
            # zip stops as soon as range() is exhausted, so no batch beyond
            # `eval_iters` is fetched from the loader.
            for _, (xb, yb) in zip(range(eval_iters), loader):
                xb, yb = xb.to(device), yb.to(device)
                _, loss = model(xb, yb)
                losses.append(loss.item())
            out[split] = sum(losses) / len(losses) if losses else float("nan")
    finally:
        # Always hand the model back in training mode, even if evaluation fails.
        model.train()
    return out

In [None]:
# If an earlier (since removed) cell left a global named `iter` in a live
# kernel, drop it so the builtin is visible again. A bare `del iter` raises
# NameError on a fresh kernel, which breaks Restart & Run All.
globals().pop("iter", None)

# Optimisation hyper-parameters.
learning_rate = 1e-4
max_iters = 50000
eval_interval = 100   # evaluate every this many training steps
eval_iters = 200      # batches per split used by estimate_loss()

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Persistent iterator so the training loop can pull one batch per step.
train_iterator = iter(train_loader)

In [None]:
# Step-based training: one batch per iteration, periodic loss reports, and a
# single checkpoint of the weights once all iterations are done.
for step in range(max_iters):
    is_last_step = step == max_iters - 1
    if is_last_step or step % eval_interval == 0:
        losses = estimate_loss()
        print(f"Iteration {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # Pull the next batch; when the epoch is exhausted, start a new pass.
    batch = next(train_iterator, None)
    if batch is None:
        train_iterator = iter(train_loader)
        batch = next(train_iterator)
    xb, yb = batch

    xb = xb.to(device)
    yb = yb.to(device)

    optimizer.zero_grad(set_to_none=True)
    _, loss = model(xb, yb)
    loss.backward()
    optimizer.step()

torch.save(model.state_dict(), CONFIG['model_out'])

In [None]:
# Load the trained model. map_location lets a checkpoint saved on a GPU be
# restored on whatever device this session is using (for example CPU only).
model = TransformerMIDILanguageModel(CONFIG).to(device)
model.load_state_dict(torch.load(CONFIG['model_out'], map_location=device))
model.eval()

# Generate a new MIDI sequence from nothing but the SOS token
context = torch.tensor([[CONFIG['token_sos']]], dtype=torch.long, device=device)
generated_sequence = model.generate(context, max_new_tokens=512, eos_token_id=CONFIG['token_eos'], temperature=1.0)

# SOS/EOS/PAD are ids added on top of the tokenizer's own range (every id
# >= CONFIG['token_pad']), so they are removed before handing the sequence
# back to the tokenizer for decoding.
event_tokens = [t for t in generated_sequence[0].tolist() if t < CONFIG['token_pad']]

# Convert the generated sequence back to MIDI data
generated_midi = midi_tokenizer.decode_midi(event_tokens)

# Save the generated MIDI file
output_midi_path = "generated_midi.mid"
generated_midi.write(output_midi_path)
print(f"Generated MIDI file saved to: {output_midi_path}")

# prompt: give the model MIDI context (list of tokens) and make it generate its continuation

# Assuming 'generated_sequence' is already defined from the previous code

# Example context (replace with your actual MIDI context)
# context_tokens = [CONFIG['token_sos'], 10, 25, 50, 75] # Example tokens, replace with your actual MIDI data

context_tokens_fur_elise = [CONFIG['token_sos'], 342, 365, 76, 277, 204, 364, 75, 277, 203, 365, 76, 276, 204, 365, 75, 276, 203, 367, 76, 276, 204, 366, 71, 277, 199, 367, 74, 277, 202, 365, 72, 279, 200, 365, 45, 368, 69, 277, 173, 364, 52, 277, 197, 364, 57, 277, 364, 60, 276, 365, 64, 276, 366, 69, 278, 367, 71, 366, 40, 258, 180, 185, 188, 192, 197, 273, 168, 364, 52, 277, 199, 364, 56, 276, 365, 64, 276, 365, 68, 276, 366, 71, 279, 367, 72, 365, 45, 256, 180, 184, 192, 196, 199, 276, 173, 364, 52, 277, 200, 364, 57, 277, 364, 64, 261, 180, 185, 271, 192, 365, 76, 277, 204, 364, 75, 276, 203, 366, 76, 276, 204, 365, 75, 276, 203, 367, 76, 276, 204, 365, 71, 277, 199, 368, 74, 277, 202, 365, 72, 278, 200, 365, 45, 368, 69, 277, 173, 364, 52, 277, 197, 364, 57, 277, 364, 60, 276, 365, 64, 276, 366, 69, 278, 367, 71, 365, 40, 257, 180, 185, 188, 192, 197, 275, 168, 364, 52, 277, 199, 365, 56, 276, 365, 64, 277, 366, 72, 277, 365, 71, 278, 366, 69, 365, 45, 258, 180, 184, 192, 200, 199, 275, 173, 364, 52, 278, 197, 364, 57, 283, 180, 185, 277, 365, 76, 278, 204, 364, 75, 277, 203, 365, 76, 276, 204, 365, 75, 276, 203, 367, 76, 277, 204, 366, 71, 276, 199, 367, 74, 276, 202, 365, 72, 279, 200, 365, 45, 368, 69, 277, 173, 364, 52, 277, 197, 364, 57, 277, 364, 60, 276, 365, 64, 276, 366, 69, 278, 367, 71, 366, 40, 258, 180, 185, 188, 192, 197, 273, 168, 364, 52, 277, 199, 364, 56, 276, 365, 64, 276, 365, 68, 276, 366, 71, 279, 367, 72, 365, 45, 256, 180, 184, 192, 196, 199, 276, 173, 364, 52, 277, 200, 364, 57, 277, 364, 64, 261, 180, 185, 271, 192, 365, 76, 277, 204, 364, 75, 276, 203, 366, 76, 276, 204, 365, 75, 276, 203, 367, 76, 276, 204, 365, 71, 277, 199, 368, 74, 277, 202, 365, 72, 278, 200, 365, 45, 368, 69, 277, 173, 364, 52, 277, 197, 364, 57, 277, 364, 60, 276, 365, 64, 276, 366, 69, 278, 367, 71, 365, 40, 257, 180, 185, 188, 192, 197, 275, 168, 365, 52, 277, 199, 365, 56, 277, 365, 64, 278, 366, 72, 278, 365, 71, 278, 366, 69, 365, 45, 259, 180, 184, 192, 
200, 199, 275, 173, 364, 52, 278, 197, 364, 57, 277, 366, 71, 277, 367, 72, 277, 368, 74, 278, 368, 76, 367, 48, 257, 180, 185, 199, 200, 202, 274, 176, 365, 55, 276, 365, 60, 277, 204, 366, 67, 277, 367, 77, 277, 366, 76, 277, 368, 74, 367, 43, 257, 183, 188, 195, 205, 204, 274, 171, 365, 55, 276, 365, 59, 277, 202, 366, 65, 277, 367, 76, 277, 366, 74, 277, 368, 72, 367, 45, 257, 183, 187, 193, 204, 202, 274, 173, 365, 52, 276, 365, 57, 277, 200, 366, 64, 278, 367, CONFIG['token_eos']]

# The prompt must not end with EOS: asking for a continuation after the
# end-of-sequence marker puts the model in a state it never saw in training.
prompt_tokens = [t for t in context_tokens_fur_elise if t != CONFIG['token_eos']]
context = torch.tensor([prompt_tokens], dtype=torch.long, device=device)

# Generate continuation
generated_sequence = model.generate(context, max_new_tokens=512, eos_token_id=CONFIG['token_eos'], temperature=1.0)

# SOS/EOS/PAD are ids added on top of the tokenizer's own range (every id
# >= CONFIG['token_pad']), so they are removed before decoding.
event_tokens = [t for t in generated_sequence[0].tolist() if t < CONFIG['token_pad']]

# Convert the generated sequence back to MIDI data
generated_midi = midi_tokenizer.decode_midi(event_tokens)

# Save the generated MIDI file
output_midi_path = "generated_midi_from_furelise.mid"
generated_midi.write(output_midi_path)
print(f"Generated MIDI file saved to: {output_midi_path}")

first_tokens = [280, 376, 63, 286, 191, 376, 63, 286, 191, 376, 63, 286, 191, 376, 63, 302, 191, 376, 65, 302, 193, 376, 62, 317, 190, 376, 62, 286, 190, 376, 62, 286, 190, 376, 62, 286, 190, 376, 62, 302, 190, 376, 63, 302, 191, 376, 60, 317, 188, 376, 60, 286, 188, 376, 60, 286, 188, 376, 60, 286, 188, 376, 60, 302, 188, 376, 55, 302, 183, 376, 56, 317, 184, 376, 60, 286, 188, 376, 60, 286, 188, 376, 60, 286, 188, 376, 60, 302, 188, 376, 62, 302, 190, 376, 63, 355, 311, 191, 286, 376, 51, 286, 179, 376, 51, 286, 179, 376, 58, 349, 186, 376, 55, 286, 183, 376, 53, 286, 181, 376, 53, 286, 181, 376, 51, 286, 179, 376, 51, 286, 179, 376, 58, 349, 186, 376, 55, 286, 183, 376, 53, 286, 181, 376, 53, 286, 181, 376, 51, 286, 179, 376, 51, 286, 179, 376, 60, 349, 188, 376, 55, 286, 183, 376, 53, 286, 181, 376, 53, 286, 181, 376, 51, 286, 179, 376, 51, 286, 179, 376, 56, 317, 184, 376, 56, 317, 184, 376, 56, 317, 184, 376, 51, 286, 179, 376, 51, 286, 179, 376, 58, 349, 186, 376, 55, 286, 183, 376, 53, 286, 181, 376, 53, 286, 181, 376, 51, 286, 179, 376, 51, 286, 179, 376, 58, 349, 186, 376, 55, 286, 183, 376, 53, 286, 181, 376, 53, 286, 181, 376, 51, 286, 179, 376, 51, 286, 179, 376, 60, 349, 188, 376, 55, 286, 183, 376, 53, 286, 181, 376, 53, 286, 181, 376, 51, 286, 179, 376, 51, 286, 179, 376, 56, 317, 184, 376, 56, 317, 184, 376, 56, 286, 184, 376, 55, 355, 280, 183, 286, 376, 58, 286, 186, 376, 58, 286, 186, 376, 58, 286, 186, 376, 58, 302, 186, 376, 55, 302, 183, 376, 58, 317, 186, 376, 58, 286, 186, 376, 58, 286, 186, 376, 58, 286, 186, 376, 63, 302, 191, 376, 62, 302, 190, 376, 60, 317, 188, 376, 55, 286, 183, 376, 55, 286, 183, 376, 55, 286, 183, 376, 55, 302, 183, 376, 55, 302, 183, 376, 56, 317, 184, 376, 56, 286, 184, 376, 56, 286, 184, 376, 56, 286, 184, 376, 56, 302, 184, 376, 55, 302, 183, 376, 55, 317, 183, 376, 63, 286, 191, 376, 63, 286, 191, 376, 63, 286, 191, 376, 63, 302, 191, 376, 63, 302, 191, 376, 62, 317, 190, 376, 62, 286, 190, 376, 62, 286, 190, 
376, 62, 286, 190, 376, 62, 302, 190, 376, 63, 302, 191, 376, 60, 317, 188, 376, 55, 286, 183, 376, 55, 286, 183, 376, 55, 286, 183, 376, 55, 302, 183, 376, 55, 302, 183, 376, 56, 317, 184, 376, 56, 286, 184, 376, 56, 286, 184, 376, 56, 286, 184, 376, 56, 302, 184, 376, 55, 302, 183, 376, 55, 349, 183, 384, 67, 286, 195, 384, 67, 286, 195, 384, 70, 302, 198, 384, 63, 302, 191, 384, 62, 349, 190, 384, 67, 286, 195, 384, 67, 286, 195, 384]
# Prompt = SOS + the events above. No EOS is appended: the model is asked to
# continue the piece, and EOS would tell it the piece has already ended.
context_demons_imagine_dragons_tokens = [CONFIG['token_sos'],
                                         *first_tokens]

context = torch.tensor([context_demons_imagine_dragons_tokens], dtype=torch.long, device=device)

# Generate continuation
generated_sequence = model.generate(context, max_new_tokens=512, eos_token_id=CONFIG['token_eos'], temperature=1.0)

# SOS/EOS/PAD are ids added on top of the tokenizer's own range (every id
# >= CONFIG['token_pad']), so they are removed before decoding.
event_tokens = [t for t in generated_sequence[0].tolist() if t < CONFIG['token_pad']]

# Convert the generated sequence back to MIDI data
generated_midi = midi_tokenizer.decode_midi(event_tokens)

# Save the generated MIDI file
output_midi_path = "generated_midi_from_demons.mid"
generated_midi.write(output_midi_path)
print(f"Generated MIDI file saved to: {output_midi_path}")

first_tokens = [275, 376, 70, 275, 198, 376, 70, 275, 198, 376, 70, 275, 198, 376, 70, 285, 198, 376, 72, 285, 200, 376, 69, 295, 197, 376, 69, 275, 197, 376, 69, 275, 197, 376, 69, 275, 197, 376, 69, 285, 197, 376, 70, 285, 198, 376, 67, 295, 195, 376, 67, 275, 195, 376, 67, 275, 195, 376, 67, 275, 195, 376, 67, 285, 195, 376, 62, 285, 190, 376, 63, 295, 191, 376, 67, 275, 195, 376, 67, 275, 195, 376, 67, 275, 195, 376, 67, 285, 195, 376, 69, 285, 197, 376, 70, 355, 198, 275, 376, 58, 275, 186, 376, 58, 275, 186, 376, 65, 315, 193, 376, 62, 275, 190, 376, 60, 275, 188, 376, 60, 275, 188, 376, 58, 275, 186, 376, 58, 275, 186, 376, 65, 315, 193, 376, 62, 275, 190, 376, 60, 275, 188, 376, 60, 275, 188, 376, 58, 275, 186, 376, 58, 275, 186, 376, 67, 315, 195, 376, 62, 275, 190, 376, 60, 275, 188, 376, 60, 275, 188, 376, 58, 275, 186, 376, 58, 275, 186, 376, 63, 295, 191, 376, 63, 295, 191, 376, 63, 295, 191, 376, 58, 275, 186, 376, 58, 275, 186, 376, 65, 315, 193, 376, 62, 275, 190, 376, 60, 275, 188, 376, 60, 275, 188, 376, 58, 275, 186, 376, 58, 275, 186, 376, 65, 315, 193, 376, 62, 275, 190, 376, 60, 275, 188, 376, 60, 275, 188, 376, 58, 275, 186, 376, 58, 275, 186, 376, 67, 315, 195, 376, 62, 275, 190, 376, 60, 275, 188, 376, 60, 275, 188, 376, 58, 275, 186, 376, 58, 275, 186, 376, 63, 295, 191, 376, 63, 295, 191, 376, 63, 275, 191, 376, 62, 335, 190, 275, 376, 65, 275, 193, 376, 65, 275, 193, 376, 65, 275, 193, 376, 65, 285, 193, 376, 62, 285, 190, 376, 65, 295, 193, 376, 65, 275, 193, 376, 65, 275, 193, 376, 65, 275, 193, 376, 70, 285, 198, 376, 69, 285, 197, 376, 67, 295, 195, 376, 62, 275, 190, 376, 62, 275, 190, 376, 62, 275, 190, 376, 62, 285, 190, 376, 62, 285, 190, 376, 63, 295, 191, 376, 63, 275, 191, 376, 63, 275, 191, 376, 63, 275, 191, 376, 63, 285, 191, 376, 62, 285, 190, 376, 62, 295, 190, 376, 70, 275, 198, 376, 70, 275, 198, 376, 70, 275, 198, 376, 70, 285, 198, 376, 70, 285, 198, 376, 69, 295, 197, 376, 69, 275, 197, 376, 69, 275, 197, 376, 69, 
275, 197, 376, 69, 285, 197, 376, 70, 285, 198, 376, 67, 295, 195, 376, 62, 275, 190, 376, 62, 275, 190, 376, 62, 275, 190, 376, 62, 285, 190, 376, 62, 285, 190, 376, 63, 295, 191, 376, 63, 275, 191, 376, 63, 275, 191, 376, 63, 275, 191, 376, 63, 285, 191, 376, 62, 285, 190, 376, 62, 315, 190, 384, 74, 275, 202, 384, 74, 275, 202, 384, 77, 285, 205, 384, 70, 285, 198, 384, 69, 315, 197, 384, 74, 275, 202, 384, 74, 275, 202, 384, 77, 285]
# Prompt = SOS + the events above. No EOS is appended: the model is asked to
# continue the piece, and EOS would tell it the piece has already ended.
context_demons_imagine_dragons_fast_tokens = [CONFIG['token_sos'],
                                         *first_tokens]

context = torch.tensor([context_demons_imagine_dragons_fast_tokens], dtype=torch.long, device=device)

# Generate continuation
generated_sequence = model.generate(context, max_new_tokens=512, eos_token_id=CONFIG['token_eos'], temperature=1.0)

# SOS/EOS/PAD are ids added on top of the tokenizer's own range (every id
# >= CONFIG['token_pad']), so they are removed before decoding.
event_tokens = [t for t in generated_sequence[0].tolist() if t < CONFIG['token_pad']]

# Convert the generated sequence back to MIDI data
generated_midi = midi_tokenizer.decode_midi(event_tokens)

# Save the generated MIDI file
output_midi_path = "generated_midi_from_fast_demons.mid"
generated_midi.write(output_midi_path)
print(f"Generated MIDI file saved to: {output_midi_path}")

first_tokens = [376, 58, 315, 186, 376, 60, 275, 188, 376, 58, 295, 186, 376, 57, 295, 185, 376, 55, 295, 183, 376, 57, 295, 185, 376, 58, 295, 186, 376, 65, 295, 193, 376, 65, 335, 193, 335, 376, 62, 295, 190, 376, 63, 295, 191, 376, 62, 335, 190, 376, 58, 315, 186, 376, 60, 275, 188, 376, 58, 295, 186, 376, 57, 295, 185, 376, 55, 295, 183, 376, 57, 295, 185, 376, 58, 295, 186, 376, 65, 295, 193, 376, 65, 335, 193, 335, 376, 62, 295, 190, 376, 63, 295, 191, 376, 62, 335, 190, 376, 57, 315, 185, 376, 58, 275, 186, 376, 60, 295, 188, 376, 57, 295, 185, 376, 58, 335, 186, 355, 355, 355, 355, 376, 58, 315, 186, 376, 60, 275, 188, 376, 58, 295, 186, 376, 57, 295, 185, 376, 55, 295, 183, 376, 57, 295, 185, 376, 58, 295, 186, 376, 65, 295, 193, 376, 65, 335, 193, 335, 376, 65, 295, 193, 376, 65, 295, 193, 376, 65, 335, 193, 376, 58, 315, 186, 376, 60, 275, 188, 376, 58, 295, 186, 376, 57, 295, 185, 376, 55, 295, 183, 376, 57, 295, 185, 376, 58, 295, 186, 376, 65, 295, 193, 376, 65, 335, 193, 335, 376, 62, 295, 190, 376, 63, 295, 191, 376, 62, 335, 190, 376, 58, 315, 186, 376, 60, 275, 188, 376, 58, 295, 186, 376, 57, 295, 185, 376, 55, 295, 183, 376, 57, 295, 185, 376, 58, 295, 186, 376, 65, 295, 193, 376, 65, 335, 193, 335, 376, 62, 295, 190, 376, 63, 295, 191, 376, 62, 335, 190, 376, 57, 315, 185, 376, 58, 275, 186, 376, 60, 295, 188, 376, 57, 295, 185, 376, 58, 335, 186, 355, 355, 355, 355, 376, 58, 315, 186, 376, 60, 275, 188, 376, 58, 295, 186, 376, 57, 295, 185, 376, 55, 295, 183, 376, 57, 295, 185, 376, 58, 295, 186, 376, 65, 295, 193, 376, 65, 335, 193, 335, 376, 65, 295, 193, 376, 65, 295, 193, 376, 65, 335, 193]
# Prompt = SOS + the events above. No EOS is appended: the model is asked to
# continue the piece, and EOS would tell it the piece has already ended.
context_demons_low = [CONFIG['token_sos'],
                                         *first_tokens]

context = torch.tensor([context_demons_low], dtype=torch.long, device=device)

# Generate continuation
generated_sequence = model.generate(context, max_new_tokens=512, eos_token_id=CONFIG['token_eos'], temperature=1.0)

# SOS/EOS/PAD are ids added on top of the tokenizer's own range (every id
# >= CONFIG['token_pad']), so they are removed before decoding.
event_tokens = [t for t in generated_sequence[0].tolist() if t < CONFIG['token_pad']]

# Convert the generated sequence back to MIDI data
generated_midi = midi_tokenizer.decode_midi(event_tokens)

# Save the generated MIDI file
output_midi_path = "generated_midi_from_low.mid"
generated_midi.write(output_midi_path)
print(f"Generated MIDI file saved to: {output_midi_path}")

print(generated_sequence)