# Modeling 
1. Context
- Our ML task is to build models that take a MIDI file as input and generate an extension of it, so that we can keep adding to the current piece indefinitely.
- All of our models take the same input MIDI files and output an extended version with the same number of added notes (the added duration may differ depending on how fast the notes are played).
- Appropriate models: We chose a learned sequence model as our final model, since it can take the entire piece into account when generating the extension. We also wanted to compare it against simpler models: a random baseline and two Markov chain variants, which require far less compute in exchange for not being able to account for the whole piece.
- Optimizations: the two Markov chain models focus on perplexity, while the third model (LSTM) minimizes its cross-entropy loss during training.
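
Perplexity and cross-entropy are two views of the same quantity: perplexity is the exponential of the average negative log-likelihood per token. The sketch below illustrates this on made-up probabilities; the `perplexity` helper is hypothetical and is not part of the notebook code.

```python
import math

def perplexity(token_probs):
    """Perplexity given the probability the model assigned to each observed token."""
    # Average negative log-likelihood, i.e. the cross-entropy in nats
    nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(nll)

# A model that picks uniformly among 88 pitches assigns 1/88 to every note,
# so its perplexity equals the number of choices.
uniform = perplexity([1 / 88] * 20)

# A model that is usually confident about the observed note scores much lower.
confident = perplexity([0.5, 0.25, 0.5, 0.8])
```

Lower is better: a perplexity of 88 means the model is as uncertain as a uniform guess over 88 pitches.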

2. Discussion
- Baseline Model:
  - A trivial model that picks pitches and note lengths at random to extend our base MIDI file.
  - Pros: 
    - No training, fast
  - Cons:
    - No musical structure at all
    - Very poor perplexity
- Markov (No Seeding)
  - This Markov model is trained over our whole dataset to learn the unigrams, bigrams, and trigrams of notes and beats. However, it does not take the input MIDI file into account at all.
  - Pros
    - Simple implementation and fast production of output
    - The extended section will follow better musical structures learned in the dataset
  - Cons:
    - Can only learn short-context sequences to produce new notes
    - The musical structure may or may not differ from the original piece
- Markov (Seeding)
  - This Markov model learns the same unigrams, bigrams, and trigrams as the previous one. The difference is that generation is seeded with the last few notes of the input MIDI file, so the extension continues from the musical context where the file ends.
  - Pros:
    - Extension has similar musical structure to where the MIDI file ends
  - Cons: 
    - Still only learns short-context sequences
- LSTM
  - This model is a stacked LSTM next-token language model trained on the tokens produced by the REMI tokenizer.
  - Pros:
    - Can learn much longer contexts for the extension
    - Can learn rhythm and melody much better and account for when a change in rhythm is expected
  - Cons:
    - Complex, with long training times
    - Can overfit and produce less original music
- Complexity: the LSTM is the most complex, followed by the Markov models, then the random baseline.
- Efficiency: the random and Markov models train and produce extensions within a few seconds, while the LSTM took minutes to train.
- Implementation challenges:
  - Random: Has no way to predict rhythm or tempo at all
  - Markov: Can occasionally produce very different-sounding music as it continues generating notes
  - LSTM: Hyperparameters were difficult to tune because the model takes a long time to train

# Code Walk Through

# Imports


In [1]:
import os
import random
from collections import defaultdict
from glob import glob

import mido
import numpy as np
import pretty_midi
import torch
import torch.nn as nn
import torch.nn.functional as F
from midiutil import MIDIFile
from miditok import REMI, TokenizerConfig
from numpy.random import choice
from symusic import Score
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

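# Device setup: use the GPU when available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')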

# Setup and default functions

In [2]:
# Initial tokenization happening in this block
midi_files = glob('C:/Users/sugia/Desktop/UCSD/CSE 153/A2/melody/*.mid')
len(midi_files)
config = TokenizerConfig(num_velocities=1, use_chords=False, use_programs=False)
tokenizer = REMI(config)
tokenizer.train(vocab_size=1000, files_paths=midi_files)

In [3]:
midi = Score(midi_files[0])
tokens = tokenizer(midi)[0].tokens
tokens[:10]

['Bar_None',
 'Position_16',
 'Pitch_76',
 'Velocity_127',
 'Duration_1.0.8',
 'Position_24',
 'Pitch_74',
 'Velocity_127',
 'Duration_2.0.8',
 'Bar_None']
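
In the output above, each note appears as a `Position`, `Pitch`, `Velocity`, `Duration` run, with `Bar_None` marking bar lines. As a minimal sketch of how such a stream can be read, the hypothetical `group_notes` helper below (not part of the notebook) groups the tokens into `(position, pitch, duration)` tuples, assuming that four-token layout.

```python
def group_notes(tokens):
    """Group a REMI token list into (position, pitch, duration_string) tuples."""
    notes = []
    position = None
    pitch = None
    for tok in tokens:
        kind, _, value = tok.partition("_")
        if kind == "Bar":
            position = None          # a new bar resets the position
        elif kind == "Position":
            position = int(value)
        elif kind == "Pitch":
            pitch = int(value)
        elif kind == "Duration" and pitch is not None:
            notes.append((position, pitch, value))
            pitch = None
    return notes

sample = ['Bar_None', 'Position_16', 'Pitch_76', 'Velocity_127', 'Duration_1.0.8',
          'Position_24', 'Pitch_74', 'Velocity_127', 'Duration_2.0.8', 'Bar_None']
notes = group_notes(sample)
```

For the ten tokens shown, this yields two notes: pitch 76 at position 16 and pitch 74 at position 24.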

In [4]:
def note_extraction(midi_file):
    # Tokenize the file and keep the integer pitch of every Pitch_* token
    score = Score(midi_file)
    tokens = tokenizer(score)[0].tokens
    pitches = [int(t.split('_')[1]) for t in tokens if t.startswith("Pitch_")]
    return pitches

def note_frequency(midi_files):
    # Count how often each pitch occurs across all files
    freq = defaultdict(int)
    for file in midi_files:
        for pitch in note_extraction(file):
            freq[pitch] += 1
    return dict(freq)

def note_unigram_probability(midi_files):
    note_counts = note_frequency(midi_files)
    totalNotes = sum(note_counts.values())
    unigramProbabilities = {note : count/ totalNotes for note, count in note_counts.items()}
    
    return unigramProbabilities

def note_bigram_probability(midi_files):
    bigramTransitions = defaultdict(list)
    bigramTransitionProbabilities = defaultdict(list)

    # Count how often each pitch is followed by each other pitch
    bigramCounts = defaultdict(lambda: defaultdict(int))
    for file in midi_files:
        pitches = note_extraction(file)
        for i, j in zip(pitches, pitches[1:]):
            bigramCounts[i][j] += 1

    # Normalize the counts into per-note transition probabilities
    for prev, next_counts in bigramCounts.items():
        nextNotes = list(next_counts.keys())
        counts = list(next_counts.values())
        total = sum(counts)
        bigramTransitions[prev] = nextNotes
        bigramTransitionProbabilities[prev] = [c / total for c in counts]

    return bigramTransitions, bigramTransitionProbabilities

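_bigram_tables = None

def sample_next_note(note):
    # Build the bigram tables once and reuse them; recomputing them on every
    # call would re-tokenize the whole dataset.
    global _bigram_tables
    if _bigram_tables is None:
        _bigram_tables = note_bigram_probability(midi_files)
    bigramTransitions, bigramTransitionProbabilities = _bigram_tables
    if note not in bigramTransitions:
        raise KeyError(f"no bigram transitions recorded for note {note}")
    return random.choices(bigramTransitions[note], weights=bigramTransitionProbabilities[note], k=1)[0]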

def note_trigram_probability(midi_files):
    trigramTransitions = defaultdict(list)
    trigramTransitionProbabilities = defaultdict(list)
    
    trigram_counts = defaultdict(lambda: defaultdict(int))
    for file in midi_files:
        pitch = note_extraction(file)
        for i in range(2, len(pitch)):
            note = (pitch[i-2], pitch[i-1])
            trigram_counts[note][pitch[i]] += 1
    # Normalize the counts into transition probabilities for each note pair
    for note, next_dict in trigram_counts.items():
        notes  = list(next_dict.keys())
        counts = list(next_dict.values())
        total  = sum(counts)
        probs  = [c/total for c in counts]

        trigramTransitions[note] = notes
        trigramTransitionProbabilities[note] = probs

    return trigramTransitions, trigramTransitionProbabilities

duration2length = {
    '0.2.8': 2,  # sixteenth note, 0.25 beat in 4/4 time signature
    '0.4.8': 4,  # eighth note, 0.5 beat in 4/4 time signature
    '1.0.8': 8,  # quarter note, 1 beat in 4/4 time signature
    '2.0.8': 16, # half note, 2 beats in 4/4 time signature
    '4.0.4': 32, # whole note, 4 beats in 4/4 time signature
}

def beat_extraction(midi_file):
    score  = Score(midi_file)
    tokens = tokenizer(score)[0].tokens
    output = []

    for i, tok in enumerate(tokens):
        if tok.startswith("Position_"):
            position = int(tok.split("_",1)[1])
            # look ahead for the duration token
            if i+3 < len(tokens) and tokens[i+3].startswith("Duration_"):
                dur_str = tokens[i+3].split("_",1)[1]
                length = duration2length.get(dur_str, 0)
                # *** skip any zero‐length ***
                if length > 0:
                    output.append((position, length))
    return output

def beat_bigram_probability(midi_files):
    bigramBeatTransitions = defaultdict(list)
    bigramBeatTransitionProbabilities = defaultdict(list)
    counts = defaultdict(lambda: defaultdict(int))
    for file in midi_files:
        beat = beat_extraction(file)
        lengths = [length for _,length in beat]
        for i,j in zip(lengths, lengths[1:]):
            counts[i][j] += 1
    for prev, next in counts.items():
        nextVal = list(next.keys())
        cnts = list(next.values())
        total = sum(cnts)
        probs = [c/total for c in cnts]
        bigramBeatTransitions[prev] = nextVal
        bigramBeatTransitionProbabilities[prev] = probs
    
    return bigramBeatTransitions, bigramBeatTransitionProbabilities

def beat_pos_bigram_probability(midi_files):
    bigramBeatPosTransitions = defaultdict(list)
    bigramBeatPosTransitionProbabilities = defaultdict(list)
    
    # Count how often each note length occurs at each position in the bar
    counts = defaultdict(lambda: defaultdict(int))
    for file in midi_files:
        for position, length in beat_extraction(file):
            counts[position][length] += 1

    
    for position, length in counts.items():
        vals  = list(length.keys())
        cnts  = list(length.values())
        total = sum(cnts)
        probs = [c/total for c in cnts]
        bigramBeatPosTransitions[position]= vals
        bigramBeatPosTransitionProbabilities[position] = probs

    return bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities

def beat_trigram_probability(midi_files):
    # The context is (previous note length, current position) and the
    # predicted value is the current note length.
    trigram_counts = defaultdict(lambda: defaultdict(int))
    for file in midi_files:
        beats = beat_extraction(file)
        for (prev_pos, prev_len), (pos, length) in zip(beats, beats[1:]):
            pair = (prev_len, pos)
            trigram_counts[pair][length] += 1

    trigramBeatTransitions = {}
    trigramBeatTransitionProbabilities = {}
    for pair, next in trigram_counts.items():
        next_lengths = list(next.keys())
        counts = list(next.values())
        total = sum(counts)
        probs = [c/total for c in counts]
        trigramBeatTransitions[pair] = next_lengths
        trigramBeatTransitionProbabilities[pair] = probs
    return trigramBeatTransitions, trigramBeatTransitionProbabilities
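
All of the transition functions above share one shape: for each context they store a list of possible next values and a parallel list of probabilities. As a small illustration, the sketch below builds such tables from an in-memory pitch list instead of MIDI files; `bigram_tables` and the toy melody are hypothetical and only mirror the structure of `note_bigram_probability`.

```python
from collections import defaultdict

def bigram_tables(pitches):
    """Return (transitions, probabilities) dicts keyed by the previous pitch."""
    counts = defaultdict(lambda: defaultdict(int))
    for prev, nxt in zip(pitches, pitches[1:]):
        counts[prev][nxt] += 1

    transitions, probabilities = {}, {}
    for prev, nxt_counts in counts.items():
        total = sum(nxt_counts.values())
        transitions[prev] = list(nxt_counts)
        probabilities[prev] = [c / total for c in nxt_counts.values()]
    return transitions, probabilities

melody = [60, 62, 60, 64, 60, 62]   # toy data for illustration
T, P = bigram_tables(melody)
```

Here pitch 60 is followed by 62 twice and by 64 once, so `T[60]` lists both successors and `P[60]` holds their relative frequencies in the same order.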

# Baseline Random Model

In [None]:


def extend_random_midi(
    in_path: str,
    out_path: str,
    length: int = 100,
    track_idx: int = 1,
    channel: int = 0,
    velocity: int = 100
):
    # Load all the tracks
    mid   = mido.MidiFile(in_path)
    track = mid.tracks[track_idx]
    tpb   = mid.ticks_per_beat

    # remove only the final End-Of-Track
    if track and track[-1].is_meta and track[-1].type == "end_of_track":
        track.pop()

    # possible note lengths in 1/8-beat units (sixteenth note through whole note)
    choices_bl = [2, 4, 8, 16, 32]
    for _ in range(length):
        pitch = random.randint(21, 108)
        bl    = random.choice(choices_bl)
        ticks = int((bl / 8.0) * tpb)
        # based on the random choice add the note on and offs with those values
        track.append(mido.Message(
            "note_on",
            note=pitch,
            velocity=velocity,
            time=0,
            channel=channel
        ))
        track.append(mido.Message(
            "note_off",
            note=pitch,
            velocity=0,
            time=ticks,
            channel=channel
        ))

    # re-add End-Of-Track and save
    track.append(mido.MetaMessage("end_of_track", time=0))
    mid.save(out_path)

In [7]:
extend_random_midi(
    "melody/trimmed_20s/ashover4.mid",
    "ashover4_baseline_ext.mid",
    length=20,
    track_idx=1,
    channel=0,
    velocity=90
)
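
The note lengths used throughout are expressed in units of 1/8 beat, so converting one to MIDI ticks means dividing by 8 and multiplying by the file's ticks-per-beat resolution. A minimal sketch of that arithmetic follows; `length_to_ticks` is a hypothetical helper and 480 ticks per beat is an assumed value chosen for illustration.

```python
def length_to_ticks(length_units, ticks_per_beat):
    """Convert a note length in 1/8-beat units to MIDI ticks."""
    return int((length_units / 8.0) * ticks_per_beat)

tpb = 480  # assumed resolution; real files carry their own ticks_per_beat
ticks = {units: length_to_ticks(units, tpb) for units in [2, 4, 8, 16, 32]}
```

With this resolution a quarter note (8 units) lasts 480 ticks and a whole note (32 units) lasts 1920 ticks.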

# Markov Chain Extension (No Dependency on Current MIDI)

In [None]:
def generate_notes_and_beats(length):
    # Use the uni , bi, and trigrams of notes and beats here
    unigramProbabilities = note_unigram_probability(midi_files)
    bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)
    trigramTransitions, trigramTransitionProbabilities = note_trigram_probability(midi_files)
    
    notes = []
    for i in range(length):
        # Choose first note from unigram
        if i == 0:
            next = list(unigramProbabilities.keys())
            weights= list(unigramProbabilities.values())
        elif i == 1:
            # Choose second note from bigram
            prev = notes[-1]
            next = bigramTransitions.get(prev, list(unigramProbabilities.keys()))
            weights= bigramTransitionProbabilities.get(prev, list(unigramProbabilities.values()))
        else:
            # choose every other note using trigrams
            prev2, prev1 = notes[-2], notes[-1]
            pair = (prev2, prev1)
            if pair in trigramTransitions:
                next, weights = trigramTransitions[pair], trigramTransitionProbabilities[pair]
            elif prev1 in bigramTransitions:
                next, weights = bigramTransitions[prev1], bigramTransitionProbabilities[prev1]
            else:
                next, weights = list(unigramProbabilities.keys()), list(unigramProbabilities.values())

        notes.append(random.choices(next, weights, k=1)[0])
        
    bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities = beat_pos_bigram_probability(midi_files)
    pos = 0
    beats = []
    # The way we choose the beats is basically the same as above for notes except we are using the bigrams for beats and positions
    for _ in range(length):
        next = bigramBeatPosTransitions.get(pos, [8]) 
        weights= bigramBeatPosTransitionProbabilities.get(pos, None)
        if weights is None:
            weights = [1]*len(next)
        bl = random.choices(next, weights, k=1)[0]
        beats.append(bl)
        pos = (pos + bl) % 32
    return notes, beats



def extend_midi(
    input_midi_path: str,
    output_midi_path: str,
    length: int,
    track_idx: int = 1,
    channel: int   = 0,
    velocity: int  = 100
):
    # load your base file
    mid = mido.MidiFile(input_midi_path)

    # pick the track you want to append to and remove its EndOfTrack
    track = mid.tracks[track_idx]
    track[:] = [msg for msg in track
                if not (msg.is_meta and msg.type == 'end_of_track')]

    # generate your notes + beats
    notes, beats = generate_notes_and_beats(length)

    # append each new note
    for pitch, bl in zip(notes, beats):
        # convert the note length (in 1/8-beat units) to delta ticks
        dur_beats = bl / 8.0
        dur_ticks = int(dur_beats * mid.ticks_per_beat)

        # note_on at delta=0 (immediately after previous event)
        track.append(mido.Message('note_on',
                                  note=pitch,
                                  velocity=velocity,
                                  time=0,
                                  channel=channel))
        # note_off after dur_ticks
        track.append(mido.Message('note_off',
                                  note=pitch,
                                  velocity=0,
                                  time=dur_ticks,
                                  channel=channel))

    # finally close the track again
    track.append(mido.MetaMessage('end_of_track', time=0))

    # write out a brand-new file
    mid.save(output_midi_path)

In [64]:
extend_midi("melody/trimmed_20s/ashover4.mid", "ashover4_random_extended.mid", length=50)
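
The note generation above backs off from trigram to bigram to unigram statistics whenever a context was never seen in training. The sketch below isolates that idea on toy tables; `sample_with_backoff` and the table contents are hypothetical, and each table entry is stored as a `(next_values, probabilities)` pair for brevity.

```python
import random

def sample_with_backoff(history, tri, bi, uni, rng=random):
    """Pick the next note from the longest context that has been seen."""
    if len(history) >= 2 and tuple(history[-2:]) in tri:
        choices, weights = tri[tuple(history[-2:])]
    elif history and history[-1] in bi:
        choices, weights = bi[history[-1]]
    else:
        choices, weights = uni
    return rng.choices(choices, weights=weights, k=1)[0]

# Toy tables for illustration only
uni = ([60, 62, 64], [0.5, 0.3, 0.2])
bi = {60: ([62, 64], [0.75, 0.25])}
tri = {(60, 62): ([64], [1.0])}

rng = random.Random(0)
from_trigram = sample_with_backoff([60, 62], tri, bi, uni, rng)  # context (60, 62) is known
from_bigram = sample_with_backoff([64, 60], tri, bi, uni, rng)   # falls back to the bigram for 60
from_unigram = sample_with_backoff([64], tri, bi, uni, rng)      # 64 has no bigram entry
```

Falling back rather than failing keeps generation going even when the chain wanders into a note pair that never occurred in the dataset.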

# Markov Chain Extension + Seeding (Uses the Context of the Current MIDI)

In [None]:


def generate_notes_and_beats(length, unigramP, bigramT, bigramP,
                             trigramT, trigramP, beat_pos_T, beat_pos_P,
                             seed_pitches=None, seed_beats=None):
    # start with the seed from the midi file input
    notes = list(seed_pitches or [])
    beats = list(seed_beats   or [])

    # Choose each new note from the unigram, bigram, or trigram tables, conditioned on the notes where the MIDI ended
    for i in range(length):
       
        if len(notes) == 0:
            choices = list(unigramP.keys())
            weights = list(unigramP.values())
        elif len(notes) == 1:
            prev    = notes[-1]
            choices = bigramT.get(prev, list(unigramP.keys()))
            weights = bigramP.get(prev, list(unigramP.values()))
        else:
            prev2, prev1 = notes[-2], notes[-1]
            pair = (prev2, prev1)
            if pair in trigramT:
                choices = trigramT[pair]
                weights = trigramP[pair]
            else:
                choices = bigramT.get(prev1, list(unigramP.keys()))
                weights = bigramP.get(prev1, list(unigramP.values()))

        new_note = random.choices(choices, weights, k=1)[0]
        notes.append(new_note)

        # Beat generation just like in the previous model
        pos = sum(beats) % 32
        next_beats = beat_pos_T.get(pos, [8])
        next_wghts = beat_pos_P.get(pos, [1])
        new_bl = random.choices(next_beats, next_wghts, k=1)[0]
        beats.append(new_bl)

    return notes, beats

# Extract the last n_seed notes and beats from an existing track
def extract_seed(input_midi_path, track_idx=1, n_seed=2):
    mid = mido.MidiFile(input_midi_path)
    ticks_per_beat = mid.ticks_per_beat
    track = mid.tracks[track_idx]

    time_cursor = 0
    active = {}       # note_on time for each pitch
    all_pitches = []
    all_beats   = []

    # This part extracts the last few notes and beats in the midi file
    for msg in track:
        time_cursor += msg.time
        # note-on
        if msg.type == 'note_on' and msg.velocity > 0:
            active[msg.note] = time_cursor
        # note-off
        elif (msg.type == 'note_off' or (msg.type=='note_on' and msg.velocity==0)) \
             and msg.note in active:
            start = active.pop(msg.note)
            dt = time_cursor - start
            # convert dt to 1/8-beat units (the inverse of the bl/8.0 scaling used when writing notes)
            bl = int((dt / ticks_per_beat) * 8)
            if bl <= 0:
                bl = 1
            all_pitches.append(msg.note)
            all_beats.append(bl)

    # take the last n_seed values
    seed_pitches = all_pitches[-n_seed:]
    seed_beats   = all_beats[-n_seed:]
    return seed_pitches, seed_beats

In [None]:
def extend_with_continuation(in_path, out_path, length):
    # pull off the last two notes/beats as seed
    seed_n, seed_b = extract_seed(in_path, track_idx=1, n_seed=2)

    # precompute your uni, bi, trigram and beat-pos bigrams exactly as before
    U   = note_unigram_probability(midi_files)
    BT, BP = note_bigram_probability(midi_files)
    TT, TP = note_trigram_probability(midi_files)
    bPT, bPP = beat_pos_bigram_probability(midi_files)

    # generate new events continuing from the seed
    notes, beats = generate_notes_and_beats(length,
                                            U, BT, BP,
                                            TT, TP,
                                            bPT, bPP,
                                            seed_pitches=seed_n,
                                            seed_beats=seed_b)

    # now use mido to read, strip EndOfTrack, append your new notes, re-add EndOfTrack
    mid = mido.MidiFile(in_path)
    track = mid.tracks[1]
    track[:] = [m for m in track if not (m.is_meta and m.type=='end_of_track')]

    for pitch, bl in zip(notes[len(seed_n):], beats[len(seed_b):]):
        dur_ticks = int((bl/8.0) * mid.ticks_per_beat)
        track.append(mido.Message('note_on',  note=pitch, velocity=100, time=0,           channel=0))
        track.append(mido.Message('note_off', note=pitch, velocity=0,   time=dur_ticks, channel=0))

    track.append(mido.MetaMessage('end_of_track', time=0))
    mid.save(out_path)

In [67]:
extend_with_continuation("melody/trimmed_20s/ashover4.mid", "ashover4_seeding_extended.mid", length=50)
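
Seed extraction depends on pairing each note-on with its matching note-off while accumulating delta times. The sketch below shows that bookkeeping on plain tuples rather than `mido` messages, so it runs without a MIDI file; `last_notes` and the `(type, note, velocity, delta_ticks)` tuple layout are assumptions made for illustration.

```python
def last_notes(messages, ticks_per_beat, n_seed=2):
    """messages: (type, note, velocity, delta_ticks) tuples in track order."""
    now = 0
    started = {}                  # pitch -> absolute tick at which it started
    pitches, lengths = [], []
    for kind, note, velocity, delta in messages:
        now += delta
        if kind == "note_on" and velocity > 0:
            started[note] = now
        elif kind in ("note_on", "note_off") and note in started:
            duration = now - started.pop(note)
            pitches.append(note)
            # length in 1/8-beat units, never shorter than one unit
            lengths.append(max(1, int(duration / ticks_per_beat * 8)))
    return pitches[-n_seed:], lengths[-n_seed:]

track = [
    ("note_on", 60, 100, 0), ("note_off", 60, 0, 480),  # quarter note
    ("note_on", 62, 100, 0), ("note_on", 62, 0, 240),   # eighth note closed by a velocity-0 note_on
    ("note_on", 64, 100, 0), ("note_off", 64, 0, 960),  # half note
]
seed_pitches, seed_lengths = last_notes(track, ticks_per_beat=480)
```

Only the last two notes are returned, which matches the two-note context a trigram model can use.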

# LSTM Model for Extension

In [11]:


# helper function to convert midi to token ids
def midi_to_token_ids(path, tokenizer):
    score  = Score(path)
    tokens = tokenizer(score)[0].tokens
    # look up each token in the tokenizer’s vocab
    return [tokenizer.vocab[t] for t in tokens]


# Our model
class REMILanguageModel(nn.Module):
    def __init__(self, vocab_size, emb_size=256, hidden=512, nlayers=2):
        super().__init__()
        # This is the token embedding layer
        # map REMI token to a vector size 256
        # used to learn pitch, duration and bar
        self.embed = nn.Embedding(vocab_size, emb_size)
        # 2 Stacked LSTM layers
        self.lstm  = nn.LSTM(emb_size, hidden, nlayers, batch_first=True)
        # Takes each LSTM hidden-vector and linearly maps
        # it back to `vocab_size` logits for next-token classification.
        self.fc    = nn.Linear(hidden, vocab_size)

    def forward(self, x, hidden=None):
        # x: [B, T] token-ids
        e = self.embed(x)                  # embed the token ids
        out, hidden = self.lstm(e, hidden) # Run LSTM here
        logits = self.fc(out)        # Projecting outputs here
        return logits, hidden
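
To give a sense of the model's size, the parameter count can be worked out by hand from the layer shapes: an LSTM layer has four gates, each with input weights, recurrent weights, and two bias vectors. The sketch below does that arithmetic; `lstm_lm_parameters` is a hypothetical helper, and the vocabulary size of 1000 is an assumed value for illustration rather than the size the notebook's model ends up with.

```python
def lstm_lm_parameters(vocab_size, emb_size=256, hidden=512, nlayers=2):
    """Count trainable parameters of an embedding + stacked LSTM + linear head."""
    embedding = vocab_size * emb_size
    lstm = 0
    for layer in range(nlayers):
        in_size = emb_size if layer == 0 else hidden
        # four gates, each with input weights, recurrent weights and two bias vectors
        lstm += 4 * (in_size * hidden + hidden * hidden + 2 * hidden)
    output = hidden * vocab_size + vocab_size   # weights plus bias
    return embedding + lstm + output

total = lstm_lm_parameters(vocab_size=1000)     # 4,447,208 under these assumptions
```

Most of the parameters sit in the two LSTM layers, which is one reason training takes far longer than building the Markov tables.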


In [12]:

# This section is responsible for the training of our model
class SeqDataset(Dataset):
    def __init__(self, data, seq_len=128):
        # data: a single long list of token-IDs
        self.v, self.L = torch.tensor(data), len(data)
        self.seq_len = seq_len

    def __len__(self):
        return (self.L - 1) // self.seq_len

    def __getitem__(self, i):
        i0 = i * self.seq_len
        x  = self.v[i0 : i0 + self.seq_len]
        y  = self.v[i0 + 1 : i0 + 1 + self.seq_len]
        return x, y

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Training on:", device)

# assemble your corpus
all_ids = []
bar_id   = tokenizer.vocab["Bar_None"]

for f in midi_files:
    all_ids.extend(midi_to_token_ids(f, tokenizer))
    all_ids.append(bar_id)

ds     = SeqDataset(all_ids, seq_len=256)
loader = DataLoader(ds, batch_size=16, shuffle=True, drop_last=True)


model = REMILanguageModel(tokenizer.vocab_size).to(device)
opt   = torch.optim.Adam(model.parameters(), 1e-3)
ce    = nn.CrossEntropyLoss()
# Train for 20 epochs, minimizing the cross-entropy loss

for epoch in range(20):
    model.train()
    total_loss = 0.0
    for x, y in loader:
        # move each batch to the training device
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        opt.zero_grad()
        logits, _ = model(x)                          # now runs on GPU
        B, T, V    = logits.shape
        loss       = ce(logits.view(-1, V), y.view(-1))
        loss.backward()
        opt.step()

        total_loss += loss.item()

    print(f"Epoch {epoch}: {total_loss/len(loader):.4f}")

Training on: cuda
Epoch 0: 1.7900
Epoch 1: 1.0184
Epoch 2: 0.8503
Epoch 3: 0.7478
Epoch 4: 0.6852
Epoch 5: 0.6362
Epoch 6: 0.6040
Epoch 7: 0.5787
Epoch 8: 0.5564
Epoch 9: 0.5377
Epoch 10: 0.5198
Epoch 11: 0.5040
Epoch 12: 0.4897
Epoch 13: 0.4717
Epoch 14: 0.4560
Epoch 15: 0.4388
Epoch 16: 0.4223
Epoch 17: 0.4025
Epoch 18: 0.3853
Epoch 19: 0.3746
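
The training pairs come from cutting the token stream into fixed-length windows, where each target window is the input window shifted forward by one token. The sketch below reproduces that slicing on a plain list so the offsets are easy to inspect; `windows` is a hypothetical stand-in for `SeqDataset` without tensors.

```python
def windows(ids, seq_len):
    """Yield (input, target) pairs where the target is the input shifted by one token."""
    n = (len(ids) - 1) // seq_len     # same length rule as SeqDataset.__len__
    for i in range(n):
        start = i * seq_len
        yield ids[start:start + seq_len], ids[start + 1:start + 1 + seq_len]

pairs = list(windows(list(range(10)), seq_len=4))
```

Subtracting one before dividing guarantees that the last target window never runs past the end of the data.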


In [13]:


def sample_from_model(model, seed_ids, length, top_k=5, temperature=1.0, device="cuda"):
    model.eval()
    seq = seed_ids.copy()
    model = model.to(device)

    # Start generating the new tokens
    with torch.no_grad():
        hidden = None
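        # prime with every seed token except the last one; the last seed
        # token is fed in the first generation step below, so feeding it
        # here as well would present it to the LSTM twice
        for tok in seed_ids[:-1]:
            inp = torch.tensor([[tok]], device=device)
            _, hidden = model(inp, hidden)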

        for _ in range(length):
            # look at the last token, then get the next one
            inp    = torch.tensor([[seq[-1]]], device=device)
            logits, hidden = model(inp, hidden)      # [1,1,vocab]
            logits = (logits[0, -1] / temperature)    # [vocab]
            # Keep the top k logits, renormalize them with softmax, and sample one of them
            vals, idxs = torch.topk(logits, top_k)    # top logits
            probs      = F.softmax(vals, dim=0)       # top probabilities
            choice     = torch.multinomial(probs, 1)  # Sample the new choice
            new_id     = idxs[choice].item()
            seq.append(new_id) # Add the new token

    return seq
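
Top-k sampling with temperature can be illustrated without PyTorch. The sketch below keeps the `top_k` largest logits, applies a softmax to only those, and draws one index; `sample_top_k` and the example logits are hypothetical and use the standard library only.

```python
import math
import random

def sample_top_k(logits, top_k=5, temperature=1.0, rng=random):
    """Sample an index from the top_k largest logits after temperature scaling."""
    scaled = [x / temperature for x in logits]
    top = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:top_k]
    # softmax over the kept logits only (subtract the max for numerical stability)
    peak = scaled[top[0]]
    weights = [math.exp(scaled[i] - peak) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]

logits = [2.0, 0.5, -1.0, 3.0, 0.0, 1.0]
rng = random.Random(0)
draws = [sample_top_k(logits, top_k=2, rng=rng) for _ in range(200)]
```

With `top_k=2` only indices 3 and 0 can ever be drawn, and index 3 is drawn more often because its logit is larger. Raising the temperature flattens that preference; lowering it sharpens it.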


In [17]:
# Take the last 64 token IDs as context for how to start the extension
def extract_token_seed(path, tokenizer, n_seed=64):
    full_ids = midi_to_token_ids(path, tokenizer)
    # grab the final n_seed IDs
    return full_ids[-n_seed:]

def note_extraction(midi_or_tokens):
    # if someone passed in a bare list of tokens, just use it
    if isinstance(midi_or_tokens, list):
        tokens = midi_or_tokens
    else:
        # otherwise assume it's a file‐path or MidiFile and run Score()
        score  = Score(midi_or_tokens)
        tokens = tokenizer(score)[0].tokens

    pitches = [int(t.split("_",1)[1])
               for t in tokens
               if t.startswith("Pitch_")]
    return pitches

def beat_extraction(midi_or_tokens):
    if isinstance(midi_or_tokens, list):
        tokens = midi_or_tokens
    else:
        score  = Score(midi_or_tokens)
        tokens = tokenizer(score)[0].tokens

    output = []
    for i, tok in enumerate(tokens):
        if tok.startswith("Position_"):
            pos = int(tok.split("_",1)[1])
            # find the Duration_
            if i+3 < len(tokens) and tokens[i+3].startswith("Duration_"):
                dur_str = tokens[i+3].split("_",1)[1]
                bl      = duration2length.get(dur_str, 0)
                if bl > 0:
                    output.append((pos, bl))
    return output

id2token = {v:k for k,v in tokenizer.vocab.items()}

# Convert from tokens back into pitches and beats
def notes_beats_from_ids(ids: list[int]):
    # 1) decode the IDs back into REMI token strings
    events = [id2token[i] for i in ids]

    # 2) reuse the extractors, which accept a plain list of tokens.
    #    beat_extraction skips notes whose duration is not in duration2length,
    #    so the two lists can differ in length.
    pitches = note_extraction(events)
    beats   = beat_extraction(events)
    return pitches, beats
# add the new notes to the original MIDI
def extend_midi(
    in_path: str,
    out_path: str,
    new_pitches: list[int],
    new_beats:   list[int],
    track_idx:  int = 1,
    channel:    int = 0,
    velocity:   int = 100
):
    mid   = mido.MidiFile(in_path)
    track = mid.tracks[track_idx]
    tpb   = mid.ticks_per_beat

    # drop only the final End of track
    if track and track[-1].is_meta and track[-1].type=="end_of_track":
        track.pop()
    # Append the new beats and notes again
    for p, bl in zip(new_pitches, new_beats):
        ticks = int((bl/8.0)*tpb)
        track.append(mido.Message("note_on",  note=p, velocity=velocity, time=0,      channel=channel))
        track.append(mido.Message("note_off", note=p, velocity=0,       time=ticks, channel=channel))

    track.append(mido.MetaMessage("end_of_track", time=0))
    mid.save(out_path)


# Running everything to make the new file
if __name__ == "__main__":
    IN  = "melody/trimmed_20s/ashover4.mid"
    OUT = "song_neural_ext.mid"

    # a) grab the last 64 REMI-IDs as seed
    seed_ids = extract_token_seed(IN, tokenizer, n_seed=64)

    # b) sample 256 new IDs; the returned sequence starts with the seed,
    #    so slice it off to avoid appending the seed notes a second time
    sampled   = sample_from_model(model, seed_ids, length=256, top_k=5, temperature=1.0)
    generated = sampled[len(seed_ids):]

    # c) decode & extract pitch/beat
    new_p, new_b = notes_beats_from_ids(generated)

    # d) append them onto your original MIDI
    min_bl = 4   # eighth note (0.5 beat) in the duration2length units

    # unpack each (pos, bl) pair and clamp the length to at least min_bl
    lengths_only = [max(bl, min_bl) for (_, bl) in new_b]

    extend_midi(
        IN,
        OUT,
        new_p,          # [pitch1, pitch2, …]
        lengths_only,   # [len1,   len2,   …]
        track_idx=1,
        channel=0
    )
    print("Wrote →", OUT)

Wrote → song_neural_ext.mid
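
Because pitches and note lengths are extracted in two separate passes, a note whose duration is missing from `duration2length` is dropped from the length list but not from the pitch list, which can shift later pairings. One possible way to keep them aligned is to extract both in a single pass, as sketched below; `paired_notes` and its local copy of the duration table are hypothetical and are not what the notebook currently does.

```python
DURATION_TO_LENGTH = {'0.2.8': 2, '0.4.8': 4, '1.0.8': 8, '2.0.8': 16, '4.0.4': 32}

def paired_notes(tokens, default_length=None):
    """Return (pitch, length) pairs, pairing each Pitch_* with the next Duration_*."""
    pairs = []
    pitch = None
    for tok in tokens:
        if tok.startswith("Pitch_"):
            pitch = int(tok.split("_", 1)[1])
        elif tok.startswith("Duration_") and pitch is not None:
            length = DURATION_TO_LENGTH.get(tok.split("_", 1)[1], default_length)
            if length is not None:
                pairs.append((pitch, length))   # skip the whole note if the length is unknown
            pitch = None
    return pairs

tokens = ['Position_0', 'Pitch_60', 'Velocity_127', 'Duration_1.0.8',
          'Position_8', 'Pitch_62', 'Velocity_127', 'Duration_0.3.8',   # not in the table
          'Position_11', 'Pitch_64', 'Velocity_127', 'Duration_2.0.8']
pairs = paired_notes(tokens)
with_default = paired_notes(tokens, default_length=4)
```

Passing a `default_length` keeps notes with unrecognized durations instead of dropping them, which trades rhythmic accuracy for a complete melody.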
