In [None]:
# Use the %pip magic so the package is removed from the running kernel's
# environment; -y answers pip's confirmation prompt, which would otherwise
# leave the cell waiting for input.
%pip uninstall -y torch


In [5]:
import torch
import onnxruntime
import pretty_midi

# Helper function to convert MIDI to text (as provided)
def midi_to_text(midi_file):
    """Serialize every note of a MIDI file as one line of space-separated text.

    Each line holds these fields, in this order::

        start_time end_time instrument_name note_name start_beat end_beat
        note_value duration velocity

    Args:
        midi_file: Value passed straight to ``pretty_midi.PrettyMIDI``.

    Returns:
        The note lines joined with newlines; an empty string when the file
        contains no notes.
    """
    midi_data = pretty_midi.PrettyMIDI(midi_file)

    def first_beat_from(time):
        """Return the first entry of ``get_beats(time)``, or 0 if it is empty."""
        beats = midi_data.get_beats(time)
        # get_beats returns an array; a bare ``if beats`` raises ValueError
        # ("truth value of an array ... is ambiguous") as soon as it holds
        # more than one element, so test the length explicitly. Calling it
        # once per endpoint also halves the work per note.
        return beats[0] if len(beats) > 0 else 0

    text_data = []
    for instrument in midi_data.instruments:
        # The name is the same for every note of the instrument.
        # NOTE: names may contain whitespace, so a consumer must not assume
        # that a line always splits into exactly nine fields.
        instrument_name = instrument.name if instrument.name else 'Unnamed'

        for note in instrument.notes:
            start_time = note.start
            end_time = note.end
            note_name = pretty_midi.note_number_to_name(note.pitch)
            start_beat = first_beat_from(start_time)
            end_beat = first_beat_from(end_time)
            note_value = note.pitch
            duration = end_time - start_time
            velocity = note.velocity

            text_data.append(
                f"{start_time} {end_time} {instrument_name} {note_name} {start_beat} {end_beat} {note_value} {duration} {velocity}"
            )

    return '\n'.join(text_data)

# Helper function to convert text back to MIDI (as provided)
def text_to_midi(text_data):
    """Build a single-instrument ``PrettyMIDI`` object from note lines.

    Every line is expected to carry the whitespace-separated fields::

        start_time end_time instrument_name note_name start_beat end_beat
        note_value duration velocity

    Only ``start_time``, ``end_time``, ``note_value`` (pitch) and ``velocity``
    are used; all notes are placed on one instrument with ``program=0``.

    Lines that are blank, have fewer than nine fields, or whose numeric
    fields cannot be parsed are skipped, and a single ``UserWarning``
    reports how many were dropped. This keeps empty input, trailing
    newlines and a truncated final line from aborting the conversion.

    Args:
        text_data: Newline-separated note lines.

    Returns:
        A ``pretty_midi.PrettyMIDI`` containing one instrument with the
        parsed notes (possibly none).
    """
    import warnings

    midi_data = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=0)
    skipped = 0

    for line in text_data.split('\n'):
        fields = line.split()
        if not fields:
            # Blank line (empty input or trailing newline): nothing to parse.
            continue
        if len(fields) < 9:
            skipped += 1
            continue

        # Read the numeric columns from both ends of the line so that an
        # instrument name containing spaces (which widens the middle of the
        # line) does not shift them.
        start_time, end_time = fields[0], fields[1]
        note_value, velocity = fields[-3], fields[-1]

        try:
            note = pretty_midi.Note(
                velocity=int(velocity),
                pitch=int(note_value),
                start=float(start_time),
                end=float(end_time),
            )
        except ValueError:
            # Non-numeric text where a number was expected.
            skipped += 1
            continue

        instrument.notes.append(note)

    if skipped:
        warnings.warn(
            f"text_to_midi skipped {skipped} malformed line(s)",
            stacklevel=2,
        )

    midi_data.instruments.append(instrument)
    return midi_data

# Function to predict the next notes based on a text input
def predict_next_notes(onnx_model_path, text_data, max_length=100, tokenizer=None):
    """Greedily extend ``text_data`` with tokens predicted by an ONNX model.

    The prompt is tokenized, then the model is run repeatedly on the growing
    sequence; at each step the highest-scoring token at the last position is
    appended. Generation stops after ``max_length`` new tokens or as soon as
    the tokenizer's EOS token is produced.

    Args:
        onnx_model_path: Path passed to ``onnxruntime.InferenceSession``.
        text_data: Prompt text to continue.
        max_length: Maximum number of tokens to generate (not the total
            sequence length).
        tokenizer: Object providing ``encode(text, return_tensors="pt")``,
            ``decode(ids, skip_special_tokens=True)`` and ``eos_token_id``.
            When omitted, a module-level variable named ``tokenizer`` is
            used, as in the original implementation.

    Returns:
        The decoded text of the prompt followed by the generated tokens.

    Raises:
        NameError: If no tokenizer is passed and no module-level
            ``tokenizer`` exists.
    """
    if tokenizer is None:
        # Backward-compatible fallback to the global the function used to
        # depend on implicitly; fail with a clear message if it is missing.
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise NameError(
                "name 'tokenizer' is not defined; pass tokenizer=... or "
                "define a module-level `tokenizer`"
            )

    # Load the ONNX model
    ort_session = onnxruntime.InferenceSession(onnx_model_path)

    # Loop invariants: look them up once instead of on every step.
    # NOTE(review): only the model's first input is fed; an export that also
    # requires e.g. an attention mask would need extra inputs — confirm.
    input_name = ort_session.get_inputs()[0].name
    eos_token_id = tokenizer.eos_token_id

    # Tokenize the prompt; this tensor grows by one token per iteration.
    generated_ids = tokenizer.encode(text_data, return_tensors="pt")

    for _ in range(max_length):
        # Run the model on the current sequence
        ort_outs = ort_session.run(None, {input_name: generated_ids.numpy()})

        # Only the last position decides the next token, so slice first and
        # take the argmax over that single step rather than the whole
        # sequence. Assumes the first output is (batch, sequence, vocab).
        last_step_scores = torch.as_tensor(ort_outs[0])[:, -1]
        next_token_id = last_step_scores.argmax(dim=-1, keepdim=True)
        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

        # Stop at end-of-sequence (.item() requires a batch of one).
        if next_token_id.item() == eos_token_id:
            break

    # Decode the generated sequence back to text
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)




AttributeError: partially initialized module 'torch' has no attribute 'version' (most likely due to a circular import)

In [None]:
# Example usage: Load a MIDI file, predict next notes, and convert back to MIDI
midi_file_path = "path_to_your_midi_file.mid"
onnx_model_path = "fine_tuned_gpt2.onnx"

# Convert the MIDI file to text format
text_data = midi_to_text(midi_file_path)

# Predict the next notes in text format (the result contains the original
# text followed by the generated continuation)
predicted_text = predict_next_notes(onnx_model_path, text_data)

# Convert the predicted text back to MIDI format
predicted_midi = text_to_midi(predicted_text)

# Save the generated MIDI. PrettyMIDI objects are written to disk with
# write(); they have no save() method.
predicted_midi.write("predicted_output.mid")