In [381]:
import mido
import pretty_midi
import multiprocessing
import time
import json
import torch
import re
from copy import deepcopy
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataclasses import dataclass

# Option 2: Suppress warnings by setting the logging level to error
from transformers import logging
logging.set_verbosity_error()

@dataclass
class Note:
    """One played note: when it starts and ends, and its MIDI pitch, velocity and channel.

    `start` and `end` share one time unit per list of notes: the recording
    cell stores seconds, the tick-conversion cell stores MIDI ticks.
    """

    # Onset time of the note.
    start: float
    # Release time of the note (same unit and origin as `start`).
    end: float
    # MIDI note number.
    pitch: int
    # MIDI velocity captured at note-on.
    velocity: int
    # MIDI channel the note belongs to.
    channel: int

def conv_channel(channel):
    """Convert between a MIDI channel number and its one-character token symbol.

    Args:
        channel: An ``int`` channel number to encode, or a ``str`` symbol to decode.

    Returns:
        For an ``int``: the symbol for that channel, or ``"%"`` (the symbol of
        channel 0) when the channel has no entry in the table.
        For a ``str``: the channel number for that symbol, or ``0`` when the
        symbol is unknown.

    Raises:
        TypeError: If ``channel`` is neither an ``int`` nor a ``str``.
    """
    channel_map = {
        0: "%",
        1: "^",
        2: "&",
        3: "*",
        4: ";",
        5: ":",
        6: "'",
        7: '"',
        9: ")",
        10: "{",
        11: "}",
        12: "[",
        13: "]",
        14: "(",
    }

    if isinstance(channel, int):
        # Channels without a symbol (e.g. 8, 15) share channel 0's symbol.
        return channel_map.get(channel, "%")

    if isinstance(channel, str):
        # Invert the mapping to decode a symbol back to its channel number.
        symbol_map = {symbol: number for number, symbol in channel_map.items()}
        return symbol_map.get(channel, 0)

    # Raising a plain string is itself a TypeError in Python 3 and hides the
    # intended message, so raise a proper exception instance.
    raise TypeError("Wrong type for channel")

def conv_velocity(velocity):
    """Convert between a MIDI velocity and its one-character token symbol.

    Velocities are quantised into four buckets, each named by its upper bound
    (48, 60, 100, 127).

    Args:
        velocity: An ``int`` velocity to encode, or a ``str`` symbol to decode.

    Returns:
        For an ``int``: the symbol of the first bucket whose upper bound is
        greater than or equal to the velocity, or ``"@"`` when the velocity
        exceeds every bound.
        For a ``str``: the upper-bound velocity of that symbol's bucket.

    Raises:
        KeyError: If ``velocity`` is a ``str`` that is not a known symbol.
        TypeError: If ``velocity`` is neither an ``int`` nor a ``str``.
    """
    velocity_map = {
        48: "!",
        60: "@",
        100: "#",
        127: "$",
    }

    if isinstance(velocity, int):
        # Dicts iterate in insertion order, so the bounds are visited ascending
        # and the first match is the tightest bucket.
        for upper_bound, symbol in velocity_map.items():
            if velocity <= upper_bound:
                return symbol
        return "@"

    if isinstance(velocity, str):
        # Invert the mapping to decode a symbol back to its bucket velocity.
        symbol_map = {symbol: bound for bound, symbol in velocity_map.items()}
        return symbol_map[velocity]

    # Raising a plain string is itself a TypeError in Python 3 and hides the
    # intended message, so raise a proper exception instance.
    raise TypeError("Wrong type for velocity")

def str_to_mido(note_str):
    """Parse a '|'-separated note string into time-ordered mido note messages.

    Each token has the form
    ``<delta ticks><velocity symbol><duration ticks><channel symbol><pitch>``.
    Tokens that do not start with that pattern are skipped. Tick values are
    converted to seconds at 480 ticks per beat and a tempo of 500000
    microseconds per beat.

    Args:
        note_str: The encoded notes, e.g. ``"34$120%62|6$3880%31|"``.

    Returns:
        A list of ``mido.Message`` objects (one ``note_on`` and one ``note_off``
        per parsed token), sorted by ``time``. ``time`` holds the absolute
        offset in seconds from the start of the string, not a delta to the
        previous message.

    Raises:
        KeyError: Propagated from ``conv_velocity`` when a token carries an
            unknown velocity symbol.
    """
    # Compiled once instead of once per token.
    token_pattern = re.compile(r"(\d+)(\D)(\d+)(\D)(\d+)")

    mido_stack = []
    last_note = 0  # absolute onset (seconds) of the previously parsed note
    for token in note_str.split('|'):
        m = token_pattern.match(token)
        if not m:
            continue

        onset = last_note + mido.tick2second(int(m.group(1)), 480, 500000)
        duration = mido.tick2second(int(m.group(3)), 480, 500000)
        pitch = int(m.group(5))
        velocity = conv_velocity(m.group(2))
        # NOTE(review): the channel symbol (group 4) is parsed but not applied,
        # so every message uses mido's default channel — confirm this is intended.

        mido_stack.append(mido.Message(
            'note_on', note = pitch, velocity = velocity, time = onset
        ))
        mido_stack.append(mido.Message(
            'note_off', note = pitch, velocity = velocity, time = onset + duration
        ))

        last_note = onset

    # The playback loops only ever look at the head of the list, so a long
    # note's note_off must not sit in front of an earlier note_on. The sort is
    # stable, so messages with equal times keep their insertion order.
    mido_stack.sort(key = lambda msg: msg.time)
    return mido_stack

In [2]:
model_name = "kobimusic/esecutore-4-0619"

# Pick the best available accelerator. Falling back to 'mps' unconditionally
# fails on machines without Apple's Metal backend, so check it and fall back to CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.to(DEVICE)
# Trailing None keeps the notebook from printing the model repr as cell output.
None



In [None]:
# Record notes played on channel 0 of the keyboard while echoing every incoming
# message to the output port. Interrupt the kernel to stop recording.
notes = []
active_notes = {}  # pitch -> (onset seconds, velocity) for keys currently held

try:
    with mido.open_input('Launchkey Mini MK3 MIDI Port') as inport, mido.open_output('IAC Driver Bus 1') as outport:
        start_time = time.time()
        for msg in inport:
            if msg.type in ('note_on', 'note_off') and msg.channel == 0:
                elapsed = time.time() - start_time
                if msg.type == 'note_on' and msg.velocity > 0:
                    active_notes[msg.note] = (elapsed, msg.velocity)
                elif msg.note in active_notes:
                    # Either a note_off or a note_on with velocity 0 (both release a key).
                    # Releases for keys that are not being tracked are ignored
                    # instead of raising KeyError from pop().
                    start, velocity = active_notes.pop(msg.note)
                    notes.append(Note(
                        start = start,
                        end = elapsed,
                        pitch = msg.note,
                        velocity = velocity,
                        channel = 0
                    ))
            outport.send(msg)
except KeyboardInterrupt:
    # Interrupting the kernel is how recording is stopped; any other error
    # (e.g. a missing MIDI port) is allowed to surface.
    pass

notes.sort(key = lambda x: x.start)

if notes:
    # Shift all times so the first note starts at 0.
    offset = notes[0].start
    for note in notes:
        note.start -= offset
        note.end -= offset
else:
    print("No notes were recorded.")

In [25]:
# Re-express every recorded note's start/end (seconds) in MIDI ticks, using
# 24 ticks per beat and a tempo of 500000 microseconds per beat.
ticking_notes = [
    Note(
        start = mido.second2tick(recorded.start, 24, 500000),
        end = mido.second2tick(recorded.end, 24, 500000),
        pitch = recorded.pitch,
        velocity = recorded.velocity,
        channel = recorded.channel,
    )
    for recorded in notes
]

In [336]:
# Serialise the notes as "<delta><velocity symbol><duration><channel symbol><pitch>|",
# where delta is the gap between this note's start and the previous note's start.
encoded_notes = []
prev_start = 0
for note in ticking_notes:
    delta = note.start - prev_start
    duration = note.end - note.start
    encoded_notes.append(
        f"{delta}{conv_velocity(note.velocity)}{duration}{conv_channel(note.channel)}{note.pitch}|"
    )
    prev_start = note.start
str_note = ''.join(encoded_notes)

In [337]:
# Build the model prompt: the ". classical " prefix followed by the encoded notes
# (presumably a style/genre tag the model expects — TODO confirm).
tokens = f'. classical {str_note}'

In [341]:
import re
def convert_ppqn(data, factor=20):
    """Multiply each number that directly follows a velocity symbol by `factor`.

    Only digit runs immediately preceded by one of ``$``, ``@``, ``#`` or ``!``
    are rewritten; every other character of `data` is left as is. The product
    is truncated to an int.

    Args:
        data: The encoded note string.
        factor: Multiplier applied to each matched number.

    Returns:
        A copy of `data` with the matched numbers rescaled.
    """
    def _rescale(found):
        marker, ticks = found.group(1), int(found.group(2))
        scaled_ticks = int(ticks * factor)
        return marker + str(scaled_ticks)

    # One of the four velocity symbols, then the run of digits that follows it.
    marker_then_digits = re.compile(r'([$@#!])(\d+)')
    return marker_then_digits.sub(_rescale, data)
# Rebuild the prompt from `str_note` rather than rescaling `tokens` in place, so
# re-running this cell does not multiply the durations by the factor a second time.
# 480 // 24 converts from the 24 ticks-per-beat grid to 480 ticks per beat.
tokens = convert_ppqn(f'. classical {str_note}', factor = 480 // 24)

In [382]:
# Generate one token at a time and play each note as soon as its closing '|'
# has been produced.
decoded_tokens = tokenizer.encode(tokens)  # prompt token ids; later cells read this name
ins = torch.tensor([decoded_tokens], device=DEVICE)
init_shape = ins.size(1) - 1  # not used in this cell
attention_mask = torch.ones_like(ins)
generated_notes = 0
played_ = 0  # number of generated tokens that have already been played
i = 0
coll = ''

# Open the output port once instead of once per generated note.
with mido.open_output('IAC Driver Bus 1') as outport:
    while generated_notes <= (20):
        # `ins` grows by one token per step and `i` advances by one, so the
        # model always sees a window as long as the prompt, which keeps the
        # fixed-size attention mask valid.
        res = model.generate(
            ins[:, i:],
            attention_mask = attention_mask,
            use_cache=False,
            max_new_tokens=1,
            do_sample=True,
            temperature=0.89,
            top_p=1.0,
            num_return_sequences=1,
        )
        ins = torch.cat((ins, res[:, -1][:, None]), dim=1)
        i += 1

        # Tokens generated since the last played note; sliced and decoded once.
        pending_ids = ins[:, len(decoded_tokens) + played_:].cpu().detach()
        coll = tokenizer.batch_decode(pending_ids)[0]

        if '|' in coll:
            generated_notes += 1
            print(pending_ids.shape[1])
            print(coll, generated_notes)

            mido_stack = str_to_mido(coll)

            # Sleep until each message is due instead of spinning the CPU.
            st_time = time.monotonic()
            while mido_stack:
                remaining = mido_stack[0].time - (time.monotonic() - st_time)
                if remaining > 0:
                    time.sleep(remaining)
                outport.send(mido_stack.pop(0))

            played_ += pending_ids.shape[1]

6
34$120%62| 1
7
6$3880%31| 2
7
0!3820&41| 3
6
12$200%65| 4
7
1$3880%31| 5
7
2!3880&79| 6
6
0$240)36| 7
6
0!240)42| 8
6
0#240)55| 9
6
0$440)42| 10


KeyboardInterrupt: 

In [368]:
# Scratch check: decode a hand-written list of token ids back to text.
tokenizer.batch_decode(torch.tensor([[1433,    3,  940, 1238,    4, 1899,   91]]))[0]

'16$1020%60|'

In [317]:
# Decode everything generated after the prompt and play it back in one go.
# len(decoded_tokens) is the prompt length; no device tensor is needed to get it.
result_tokens = tokenizer.batch_decode(ins[:, len(decoded_tokens):].cpu().detach())[0]

mido_stack = str_to_mido(result_tokens)

with mido.open_output('IAC Driver Bus 1') as outport:
    st_time = time.monotonic()
    while mido_stack:
        # Sleep until the head message is due instead of spinning the CPU.
        remaining = mido_stack[0].time - (time.monotonic() - st_time)
        if remaining > 0:
            time.sleep(remaining)
        msg = mido_stack.pop(0)
        print(msg)
        outport.send(msg)

note_on channel=0 note=55 velocity=60 time=0.022916666666666665
note_off channel=0 note=55 velocity=60 time=0.46041666666666664


In [320]:
# Show the text generated so far, i.e. everything after the prompt tokens.
prompt_length = len(decoded_tokens)
generated_ids = ins[:, prompt_length:].cpu().detach()
result_tokens = tokenizer.batch_decode(generated_ids)[0]
result_tokens

'2'