In [1]:
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

In [4]:
import pickle
import random  # imported explicitly so the sampling below does not rely on the star import leaking it

from src.prepare_data import *

# Sampling configuration.
dataset_size = 10000  # number of (input, target) windows to draw
block_size = 8        # chords per input window

# Load the training harmonies once (the original cell loaded them twice).
harmonys_raw = load_harmonys("train")

# Keep only harmonies that can yield a block_size-chord input plus its
# one-step-shifted target, i.e. at least block_size + 1 chords. Filtering up
# front replaces the original retry loop, which rebuilt the candidate list on
# every draw, would spin forever if no harmony was long enough, and rejected
# harmonies of exactly block_size + 1 chords even though they contain one
# valid window.
eligible_harmonys = [
    harmony
    for harmony in harmonys_raw.values()
    if harmony is not None and len(harmony) > block_size
]
if not eligible_harmonys:
    raise ValueError(
        f"no harmony in the 'train' split has more than {block_size} chords"
    )

# Build the dataset as a list of (input_ids, target_ids) tensor pairs, where
# the target is the input shifted one chord to the right.
train_dataset = []
for sample_idx in range(dataset_size):
    harmony = random.choice(eligible_harmonys)
    # random.randint is inclusive on both ends, so the last admissible start
    # leaves exactly block_size + 1 chords for the window.
    start_idx = random.randint(0, len(harmony) - block_size - 1)
    window = harmony[start_idx : start_idx + block_size + 1]
    # Encode the window once and slice it, instead of encoding the
    # overlapping chords twice.
    encoded_window = [encode_chord_all(chord) for chord in window]
    encoded_x = torch.tensor(encoded_window[:-1])
    encoded_y = torch.tensor(encoded_window[1:])
    train_dataset.append((encoded_x, encoded_y))

In [5]:
# Sanity check: show the first training pair side by side, one position per
# line, as "input_id target_id".
x, y = train_dataset[0]
for position in range(min(len(x), len(y))):
    print(int(x[position]), int(y[position]))

497 22
22 497
497 402
402 877
877 22
22 687
687 212
212 499


In [6]:
# Build the GPT model that will be trained on the chord windows.
from mingpt.model import GPT

model_config = GPT.get_default_config()
# Override the defaults: the 'gpt-mini' preset, a vocabulary size taken from
# chord_vocab_size, and a context length equal to the block_size used when
# slicing the training windows.
for option_name, option_value in (
    ("model_type", 'gpt-mini'),
    ("vocab_size", chord_vocab_size),
    ("block_size", block_size),
):
    setattr(model_config, option_name, option_value)
model = GPT(model_config)

参数数量: 2.89M


In [7]:
# Set up the minGPT training loop for the model and dataset built above.
from mingpt.trainer import Trainer

train_config = Trainer.get_default_config()
train_config.num_workers = 0
train_config.max_iters = 5000
# NOTE(review): the previous comment here claimed the small model lets us
# "go a bit faster"; confirm that 5e-5 is the intended learning rate.
train_config.learning_rate = 5e-5
trainer = Trainer(train_config, model, train_dataset)

运行在设备： mps


In [15]:
# Put the model in training mode (train() defaults to mode=True); the trailing
# semicolon suppresses the module repr in the cell output.
model.train();

In [8]:
def batch_end_callback(trainer):
    """Print a progress line on every 100th training iteration.

    Args:
        trainer: object exposing ``iter_num`` (current iteration index),
            ``iter_dt`` (multiplied by 1000 and reported as milliseconds) and
            ``loss`` (an object whose ``.item()`` yields the loss as a number).

    Returns:
        None. On iterations that are not a multiple of 100 nothing is printed.
    """
    step = trainer.iter_num
    if step % 100 != 0:
        return
    elapsed_ms = trainer.iter_dt * 1000
    loss_value = trainer.loss.item()
    print(f"iter_dt {elapsed_ms:.2f}ms; iter {step}: train loss {loss_value:.5f}")
# Register the progress logger under the trainer's 'on_batch_end' hook.
trainer.set_callback('on_batch_end', batch_end_callback)

# Start training.
# NOTE(review): the saved output below shows this run was stopped by a
# KeyboardInterrupt after roughly 900 iterations, short of the configured
# max_iters, so later cells used a partially trained model.
trainer.run()

iter_dt 0.00ms; iter 0: train loss 7.06429
iter_dt 39.60ms; iter 100: train loss 5.16013
iter_dt 39.81ms; iter 200: train loss 4.37626
iter_dt 40.12ms; iter 300: train loss 3.76619
iter_dt 38.94ms; iter 400: train loss 3.55210
iter_dt 39.41ms; iter 500: train loss 3.34458
iter_dt 39.10ms; iter 600: train loss 3.36191
iter_dt 39.35ms; iter 700: train loss 3.27103
iter_dt 38.36ms; iter 800: train loss 3.16141
iter_dt 38.70ms; iter 900: train loss 3.23616


KeyboardInterrupt: 

In [9]:
# Switch the model to evaluation mode before generating; the trailing
# semicolon suppresses the module repr in the cell output.
model.eval();

In [15]:
# Prompt for generation: three root-position triads described as chord dicts.
# NOTE(review): the third chord repeats the second chord's onset/offset (2-4);
# this looks like a copy-paste slip (4-6 expected) - confirm whether
# encode_chord_all reads these fields at all.
input_chords = [
    {
        "onset": 0,
        "offset": 2,
        "root_pitch_class": 0,
        "root_position_intervals": [4, 3],
        "inversion": 0,
    },
    {
        "onset": 2,
        "offset": 4,
        "root_pitch_class": 5,
        "root_position_intervals": [4, 3],
        "inversion": 0,
    },
    {
        "onset": 2,
        "offset": 4,
        "root_pitch_class": 9,
        "root_position_intervals": [3, 4],
        "inversion": 0,
    }
]

# Encode each chord dict to its integer id. The list gets its own name so that
# input_chords_encoded only ever refers to the final tensor.
input_chord_ids = [encode_chord_all(input_chord) for input_chord in input_chords]

# Create the prompt on the device the model's parameters live on, instead of
# hard-coding "mps", so the cell also runs on CPU or CUDA machines. The extra
# pair of brackets adds a leading dimension of size 1.
model_device = next(model.parameters()).device
input_chords_encoded = torch.tensor(
    [input_chord_ids], dtype=torch.long, device=model_device
)
print(input_chords_encoded)

tensor([[ 22, 497, 878]], device='mps:0')


In [16]:
# Extend the prompt by a fixed number of tokens with sampling disabled
# (do_sample=False); gradients are not needed for generation.
num_new_chords = 16
with torch.no_grad():
    result = model.generate(
        input_chords_encoded,
        num_new_chords,
        do_sample=False,
    )
print(result)

tensor([[ 22, 497, 878, 497,  22, 497,  22, 497,  22, 497,  22, 497,  22, 497,
          22, 497,  22, 497,  22]], device='mps:0')


In [17]:
# Collapse the generated tensor into a flat Python list of chord ids.
result_list = result.reshape(-1).tolist()
print(result_list)
# Decode every id back into a chord object.
result_chord = list(map(decode_chord_from_all_encoded, result_list))
print(result_chord)

[22, 497, 878, 497, 22, 497, 22, 497, 22, 497, 22, 497, 22, 497, 22, 497, 22, 497, 22]
[chord(notes=[C4, E4, G4], interval=[0, 0, 0], start_time=0), chord(notes=[F4, A4, C5], interval=[0, 0, 0], start_time=0), chord(notes=[A4, C5, E5], interval=[0, 0, 0], start_time=0), chord(notes=[F4, A4, C5], interval=[0, 0, 0], start_time=0), chord(notes=[C4, E4, G4], interval=[0, 0, 0], start_time=0), chord(notes=[F4, A4, C5], interval=[0, 0, 0], start_time=0), chord(notes=[C4, E4, G4], interval=[0, 0, 0], start_time=0), chord(notes=[F4, A4, C5], interval=[0, 0, 0], start_time=0), chord(notes=[C4, E4, G4], interval=[0, 0, 0], start_time=0), chord(notes=[F4, A4, C5], interval=[0, 0, 0], start_time=0), chord(notes=[C4, E4, G4], interval=[0, 0, 0], start_time=0), chord(notes=[F4, A4, C5], interval=[0, 0, 0], start_time=0), chord(notes=[C4, E4, G4], interval=[0, 0, 0], start_time=0), chord(notes=[F4, A4, C5], interval=[0, 0, 0], start_time=0), chord(notes=[C4, E4, G4], interval=[0, 0, 0], start_time=0

In [18]:
# Play the decoded chord sequence through musicpy.
# NOTE(review): wait=True presumably blocks until playback finishes - confirm
# against the musicpy documentation.
musicpy.play(result_chord,wait=True)