In [1]:
!pip install torch pretty_midi gdown wget

Collecting pretty_midi
  Downloading pretty_midi-0.2.11.tar.gz (5.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.6/5.6 MB 95.1 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Collecting wget
  Downloading wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... done
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.6/54.6 kB 6.3 MB/s eta 0:00:00
Building wheels for collected packages: pretty_midi, wget
  Building wheel for pretty_midi (setup.py) ... done
  Created wheel for pretty_midi: filename=pretty_midi-0.2.11-py3-none-any.whl size=5595886 sha256=0848d6044feb49ad6d30943c4dba40d25324a15c5d7928d5d307994063e4c089
  Stored in directory: /root/.cache/pip/wheels/f4/ad/93/a7042fe12668827574927ade9deec

In [2]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from hw2 import Composer
from midi2seq import process_midi_seq, seq2piano

# Configuration -- every knob a reader might want to tune for a run.
epoch = 200                           # number of training epochs
bsz = 32                              # mini-batch size
DATA_DIR = '/content/drive/MyDrive/'  # Colab Google Drive mount; change when running elsewhere
N_SEQUENCES = 50000                   # forwarded to process_midi_seq(n=...)
MAX_LEN = 100                         # forwarded to process_midi_seq(maxlen=...)
MODEL_PATH = 'composer_model.pth'
SEED = 0
# Fall back to CPU instead of crashing on a hard-coded `.cuda()` when no GPU is present.
# NOTE(review): hw2.Composer may itself assume CUDA -- confirm it is device-aware.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch so the DataLoader shuffle order (and presumably Composer's weight init)
# is reproducible across Restart & Run All.
torch.manual_seed(SEED)

# Load data
print("Loading MIDI data...")
piano_seq = torch.from_numpy(process_midi_seq(datadir=DATA_DIR, n=N_SEQUENCES, maxlen=MAX_LEN))
# Fail fast with a clear message rather than a ZeroDivisionError in the first epoch.
assert len(piano_seq) > 0, f"no sequences loaded from {DATA_DIR!r}"
loader = DataLoader(TensorDataset(piano_seq), shuffle=True, batch_size=bsz)
print(f"Loaded {len(piano_seq)} sequences\n")

# Initialize and train
print("Training model...")
cps = Composer()
for epoch_idx in range(epoch):
    total_loss = 0.0
    n_batches = 0
    for (batch,) in loader:  # TensorDataset samples are 1-tuples
        loss = cps.train(batch.to(DEVICE).long())
        # float() accepts a Python number or a 0-d tensor and drops any autograd
        # graph a returned tensor might carry, so the running sum does not grow memory.
        total_loss += float(loss)
        n_batches += 1
    print(f"Epoch {epoch_idx+1}/{epoch}, Avg Loss: {total_loss/n_batches:.4f}")

# Save model
# NOTE(review): the generation cell's Composer(load_trained=True) was recorded downloading
# weights to /content/composer_model.pth, which looks like this same file -- confirm
# the trained checkpoint is not being overwritten before sampling.
cps.save_model(MODEL_PATH)
print(f"\n✓ Model saved to '{MODEL_PATH}'")

Loading MIDI data...
Loaded 50519 sequences

Training model...
Epoch 1/200, Avg Loss: 2.8000
Epoch 2/200, Avg Loss: 2.4742
Epoch 3/200, Avg Loss: 2.3783
Epoch 4/200, Avg Loss: 2.3127
Epoch 5/200, Avg Loss: 2.2616
Epoch 6/200, Avg Loss: 2.2190
Epoch 7/200, Avg Loss: 2.1808
Epoch 8/200, Avg Loss: 2.1458
Epoch 9/200, Avg Loss: 2.1130
Epoch 10/200, Avg Loss: 2.0822
Epoch 11/200, Avg Loss: 2.0518
Epoch 12/200, Avg Loss: 2.0221
Epoch 13/200, Avg Loss: 1.9939
Epoch 14/200, Avg Loss: 1.9672
Epoch 15/200, Avg Loss: 1.9402
Epoch 16/200, Avg Loss: 1.9147
Epoch 17/200, Avg Loss: 1.8898
Epoch 18/200, Avg Loss: 1.8657
Epoch 19/200, Avg Loss: 1.8425
Epoch 20/200, Avg Loss: 1.8191
Epoch 21/200, Avg Loss: 1.7967
Epoch 22/200, Avg Loss: 1.7748
Epoch 23/200, Avg Loss: 1.7534
Epoch 24/200, Avg Loss: 1.7332
Epoch 25/200, Avg Loss: 1.7129
Epoch 26/200, Avg Loss: 1.6927
Epoch 27/200, Avg Loss: 1.6744
Epoch 28/200, Avg Loss: 1.6554
Epoch 29/200, Avg Loss: 1.6379
Epoch 30/200, Avg Loss: 1.6206
Epoch 31/200, Av

In [4]:
# Test generation: load a trained Composer, sample a sequence, render it to a MIDI file.
# NOTE(review): in the recorded run, Composer(load_trained=True) downloaded weights from
# Google Drive to /content/composer_model.pth -- presumably the same path the training
# cell saved to -- so this sample may come from the downloaded weights rather than the
# freshly trained ones. Confirm against hw2.Composer.
SAMPLE_PATH = 'piano1.midi'

cps2 = Composer(load_trained=True)
generated_seq = cps2.compose()   # raw event sequence from the model
midi = seq2piano(generated_seq)  # rendered PrettyMIDI-style object (has .write/.get_end_time)
midi.write(SAMPLE_PATH)
print(f"✓ Sample saved to '{SAMPLE_PATH}' ({midi.get_end_time():.2f}s)")

Downloading...
From (original): https://drive.google.com/uc?id=1Y720QxITeoCWmb1LPwa9pY4RC4DZTpYJ
From (redirected): https://drive.google.com/uc?id=1Y720QxITeoCWmb1LPwa9pY4RC4DZTpYJ&confirm=t&uuid=0a2a4669-b23b-45ee-9827-dc2242a4406e
To: /content/composer_model.pth
100%|██████████| 87.5M/87.5M [00:01<00:00, 74.2MB/s]


✓ Sample saved to 'piano1.midi' (28.55s)
