In [1]:
import torch as t
import torch.distributions as dist

import mido
from mido import Message, MidiFile, MidiTrack, MetaMessage

from os.path import exists
import time

In [2]:
tf = 10
alpha = 0.0001
n_epochs = 1000

s_st = 30.
s_ts = 12.
s_ex = 15.

r_pc = 50.
r_ve = 50.
r_int= 50.

# p_st = t.tensor([1.]*tf)
# p_ts = t.tensor([0.]*tf)
# p_ex = t.tensor([0.]*tf)
# P_st = dist.Bernoulli(p_st)
# P_ts = dist.Bernoulli(p_ts)
# P_ex = dist.Bernoulli(p_ex)
# p_st = P_st.sample()
# p_ts = P_ts.sample()
# p_ex = P_ex.sample()

p_st = t.tensor([1., 1., 0., 1., 0., 1., 1., 0., 1., 1.])
p_ts = t.tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 0.])
p_ex = t.tensor([0., 0., 1., 0., 0., 0., 0., 1., 0., 0.])


u_pc = t.rand(tf, requires_grad=True)*23
u_ve = t.rand(tf, requires_grad=True)

x_st_0 = t.tensor(1.)
x_ts_0 = t.tensor(1.)
x_ex_0 = t.tensor(1.)

# 状態方程式からｘを求める

In [3]:
def calc_x(u_pc, u_ve, x_st_0, x_ts_0, x_ex_0, tf):
    x_st = []
    x_ts = []
    x_ex = []

    x_st.append(x_st_0)
    # x_st.append(-t.abs(u_pc[1]-u_pc[0]) % 2 - t.abs(u_ve[1]-u_ve[0]))
    for i in range(1,tf):
        x_st.append(t.abs((t.abs(u_pc[i]-u_pc[i-1]) % 2 - 1)) - t.abs(u_ve[i]-u_ve[i-1]))
    
    x_ts.append(x_ts_0)
    # x_ts.append(x_st[0] * t.abs(u_pc[1]-u_pc[0]) % 2)
    for i in range(1,tf):
        x_ts.append(-x_st[i-1] * t.abs((t.abs(u_pc[i]-u_pc[i-1]) % 2 -1)))
    
    x_ex.append(x_ex_0)
    # x_ex.append( (u_pc[1]-u_pc[0]) + (u_ve[1]-u_ve[0]))
    for i in range(1,tf):
        x_ex.append( (u_pc[i]-u_pc[i-1]) + (u_ve[i]-u_ve[i-1]))

    x_st = t.stack(x_st)
    x_ts = t.stack(x_ts)
    x_ex = t.stack(x_ex)

    return x_st, x_ts, x_ex

# ステージコストを求める

In [4]:
def calc_L(u_pc, u_ve, x_st, x_ts, x_ex, p_st, p_ts, p_ex, s_st, s_ts, s_ex, r_pc, r_ve, tf):
    L = []
    for i in range(tf):
        bad_feeling = -p_st[i] * s_st * (x_st[i])\
                      -p_ts[i] * s_ts * (x_ts[i])\
                      -p_ex[i] * s_ex * (x_ex[i])
        penalty = r_pc * (-t.min(t.tensor(0), u_pc[i]) - t.min(t.tensor(0), 23-u_pc[i]))\
                + r_int* (t.abs(u_pc[i]-t.round(u_pc[i])))\
                + r_ve * (-t.min(t.tensor(0), u_ve[i]) - t.min(t.tensor(0), 1-u_ve[i]))

        L.append(bad_feeling + penalty)
    L = t.stack(L)
    return L
        

# λを求めて、Hを計算する

In [5]:
def calc_H(L, x_st, x_ts, x_ex, tf):
    H_reversed = []
    lamda_st = 0.
    lamda_ts = 0.
    lamda_ex = 0.
    for i in range(tf-2,-1,-1):
        H = L[i] + lamda_st * x_st[i+1] + lamda_ts * x_ts[i+1] + lamda_ex * x_ex[i+1]
        H_reversed.append(H)
        x_st.retain_grad()
        x_ts.retain_grad()
        x_ex.retain_grad()
        H.backward(retain_graph=True)
        # H.backward()
        lamda_st = x_st.grad[i]
        lamda_ts = x_ts.grad[i]
        lamda_ex = x_ex.grad[i]
    H_list = t.stack(H_reversed)
    H = t.flip(H_list,[0])
    return H

# 学習反復

In [6]:
for i in range(n_epochs):
    new_u_pc = t.zeros(tf)
    new_u_ve = t.zeros(tf)
    x_st, x_ts, x_ex = calc_x(u_pc, u_ve, x_st_0, x_ts_0, x_ex_0, tf)
    L = calc_L(u_pc, u_ve, x_st, x_ts, x_ex, p_st, p_ts, p_ex, s_st, s_ts, s_ex, r_pc, r_ve, tf)
    H = calc_H(L, x_st, x_ts, x_ex, tf)
    for j in range(tf-1):
        u_pc.retain_grad()
        u_ve.retain_grad()
        H[j].backward(retain_graph=True)
        new_u_pc[j] = u_pc[j] - alpha * u_pc.grad[j]
        new_u_ve[j] = u_ve[j] - alpha * u_ve.grad[j]
    u_pc = new_u_pc.detach()
    u_ve = new_u_ve.detach()
    u_pc = t.tensor(u_pc,requires_grad=True)
    u_ve = t.tensor(u_ve,requires_grad=True)

    if i % 100 == 0 or i == n_epochs-1:
        print(i)

0
100
200
300
400
500
600
700
800
900
999


In [7]:
u_pc = t.round(u_pc)
print(u_pc)
print(u_ve)

tensor([10., 10.,  5.,  8., 13., 19.,  5.,  9., 23.,  0.],
       grad_fn=<RoundBackward>)
tensor([ 0.0151, -0.0031,  0.9890,  1.0011,  0.0148,  0.0262, -0.0052,  1.0084,
         0.9851,  0.0000], requires_grad=True)


In [8]:
def save_as_midi(song, velocitys, path="", name="default.mid", BPM = 120, interval = 240):
    mid = MidiFile()
    track = MidiTrack()
    mid.tracks.append(track)
    track.append(MetaMessage('set_tempo', tempo=mido.bpm2tempo(BPM)))
    for i, tones in enumerate(song):
        which_tone = (tones == 1).nonzero().reshape(-1)
        if len(which_tone) == 0:
            track.append(Message('note_on', note=0, velocity=0, time=0))
            track.append(Message('note_off', note=0, time=interval))
        else:
            for which in which_tone:
                track.append(Message('note_on', note=int(which), velocity= 47+int(80*velocitys[i]), time=0))
            for which in which_tone:
                track.append(Message('note_off', note=int(which), time=interval))
    mid.save(os.path.join(path, name))


def listen_midi(name, path=""):
    ports = mido.get_output_names()
    print("START")
    with mido.open_output(ports[0]) as outport:
        for msg in mido.MidiFile(os.path.join(path,name)):
            time.sleep(msg.time)
            if not msg.is_meta:
                # print(outport, msg)
                outport.send(msg)
    print("END")

In [9]:
shift = 50
song = t.zeros(tf,88)
for i in range(tf):
    num = int(u_pc[i].detach().numpy()) + shift
    song[i][num] = 1.

In [10]:
save_as_midi(song, u_ve, path="", name="Balance.mid", BPM = 120, interval = 360)

In [11]:
listen_midi(name="Balance.mid", path="")

START
END
