# Codebase - Custom Transformer Training

In [17]:
import torch
from codebase.model import Model
from codebase.data import Dataset
from codebase.train import train, train_exhaustively
from codebase.inference import generate, tokens_to_segs
from codebase.utils import load_dataset, visualize_model, visualize_teacher_forcing, plot_loss
from codebase.preprocessing import create_dataset

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'Using device: {device}')

Using device: cuda


In [None]:
# Build the training split from the raw MAESTRO v3.0.0 release.
# Expensive: the captured log finds 962 tracks and estimates well over an hour
# even with parallel workers, so prefer loading a pre-built chunk (next cell).
dataset = create_dataset(
    dataset_path="maestro-v3.0.0",
    split="train",
    track_idx=None,          # presumably None = every track (log: "Found 962 tracks") — confirm
    nocturnes=False,
    seg_fit_tightness=0.12,
    max_chunk_size_mb=400,
    num_workers=-1,          # presumably -1 = all available cores (log: 8 workers) — confirm
)

Loading paths...
Found 962 tracks
Using 8 parallel workers
[1/962] MIDI-Unprocessed_15_R1_2011_MID--AUDIO_R1-D6_09_Track09_wav.midi | 10 augmented versions | Avg: 8.8s/track | ETA: 140.2min
[2/962] MIDI-Unprocessed_19_R1_2006_01-07_ORIG_MID--AUDIO_19_R1_2006_04_Track04_wav.midi | 10 augmented versions | Avg: 9.7s/track | ETA: 155.6min
[3/962] MIDI-Unprocessed_09_R1_2011_MID--AUDIO_R1-D3_15_Track15_wav.midi | 10 augmented versions | Avg: 7.2s/track | ETA: 115.7min
[4/962] ORIG-MIDI_02_7_10_13_Group_MID--AUDIO_11_R3_2013_wav--3.midi | 10 augmented versions | Avg: 5.7s/track | ETA: 90.6min
[5/962] ORIG-MIDI_03_7_8_13_Group__MID--AUDIO_19_R2_2013_wav--3.midi | 10 augmented versions | Avg: 4.8s/track | ETA: 76.2min
[6/962] MIDI-Unprocessed_066_PIANO066_MID--AUDIO-split_07-07-17_Piano-e_3-02_wav--3.midi | 10 augmented versions | Avg: 4.2s/track | ETA: 66.7min
[7/962] MIDI-Unprocessed_03_R2_2008_01-03_ORIG_MID--AUDIO_03_R2_2008_wav--2.midi | 10 augmented versions | Avg: 3.7s/track | ETA: 58.4

In [18]:
# Load one pre-built chunk instead of regenerating the dataset from MAESTRO.
# This rebinds `dataset`, replacing anything the create_dataset cell produced.
# NOTE(review): a .pkl file is presumably unpickled — only load trusted files.
dataset = load_dataset("complete_dataset/chunk_16.pkl")
print("Dataset loaded:", len(dataset), "tracks")

Dataset loaded: 140 tracks


In [37]:
# Load pretrained weights and move them to the notebook's `device`, matching the
# fresh-model cells, which also call .to(device). For a torch nn.Module, .to()
# is a no-op when the parameters already live on that device.
model = Model.load('complete_30k.pt').to(device)

In [2]:
# Fresh, randomly initialised model on `device`.
# NOTE: this rebinds `model`, so running it after the checkpoint-loading cell
# discards the loaded weights.
# Positional hyperparameters are presumably (d_model, encoder layers,
# decoder layers, heads, d_ff, dropout) — TODO confirm against codebase.model.Model.
model = (
    Model(528, 6, 6, 8, 2112, 0.1)
    .to(device)
)

In [None]:
# Short training run on the currently loaded `dataset` chunk with the current `model`.
# With alpha=0 the logged Segment loss stays at 0.0000 and Total equals Param,
# so alpha presumably weights the segment term — confirm in codebase.train.train.
# NOTE(review): print_every=1 prints every step and floods the cell output.
# NOTE(review): model_path="nocturnes_unnormalized.pt" does not match the
# complete_dataset chunk loaded above — confirm this is the intended target.
train(
    model=model,
    dataset=dataset,
    device=device,
    batch_size=64,
    lr=1e-4,
    num_steps=500,
    alpha=0,
    print_every=1,
    model_path="nocturnes_unnormalized.pt",
)

Using provided model
Creating dataloader...
Starting training loop...
Step 1/500
  Total: 0.6375 | Segment: 0.0000 | Param: 0.6375
  Height: 0.0653 | Amount: 0.0124 | Time: 0.9889
Step 2/500
  Total: 0.9985 | Segment: 0.0000 | Param: 0.9985
  Height: 0.0972 | Amount: 0.0155 | Time: 1.5771
Step 3/500
  Total: 0.7791 | Segment: 0.0000 | Param: 0.7791
  Height: 0.0873 | Amount: 0.0161 | Time: 1.1768
Step 4/500
  Total: 0.7765 | Segment: 0.0000 | Param: 0.7765
  Height: 0.0724 | Amount: 0.0150 | Time: 1.2335
Step 5/500
  Total: 0.7534 | Segment: 0.0000 | Param: 0.7534
  Height: 0.0764 | Amount: 0.0137 | Time: 1.1738
Step 6/500
  Total: 0.7340 | Segment: 0.0000 | Param: 0.7340
  Height: 0.0707 | Amount: 0.0134 | Time: 1.1584
Step 7/500
  Total: 0.7313 | Segment: 0.0000 | Param: 0.7313
  Height: 0.0770 | Amount: 0.0118 | Time: 1.1312
Step 8/500
  Total: 0.6545 | Segment: 0.0000 | Param: 0.6545
  Height: 0.0719 | Amount: 0.0144 | Time: 0.9928
Step 9/500
  Total: 0.6410 | Segment: 0.0000 | Par

KeyboardInterrupt: 

In [None]:
# Full multi-chunk training run, starting from a fresh model, over every chunk
# in complete_dataset. The log reports 4000 steps/chunk x 49 chunks x
# 2 rotations = 392000 total steps.
# NOTE: the model is constructed inline, so it is NOT bound to the notebook's
# `model` variable. Reload it afterwards from the checkpoint (presumably written
# to model_path).
train_exhaustively(
    batch_size=50,
    lr=3e-4,
    num_steps=4000,          # per chunk, per the "Total steps" log line
    device=device,           # was hard-coded "cuda"; must match the device the model is moved to
    model=Model(528, 6, 6, 8, 2112, 0.1).to(device),
    print_every=2000,
    model_path="complete_392000.pt",
    dataset_path="complete_dataset",
    accumulation_steps=2,    # log: effective_batch=50x2=100
    num_rotations=2,
    num_workers=4,           # NOTE(review): log reports workers=3 — confirm how this is adjusted
    alpha=0.7,
    add_checkpoints=10000,
    record_loss=100
    )

# Progress note: 14700 (presumably steps) reached after 190 min.

Found 49 chunk files in complete_dataset
Training: warmup=19600, effective_batch=50x2=100, workers=3, rotations=2
Total steps: 4000 steps/chunk × 49 chunks × 2 rotations = 392000 steps
Optimizations: mixed_precision=ON, cudnn_benchmark=ON
Loading chunk 1/49...
Creating dataloader for chunk 1...
Starting training on chunk 1...


KeyboardInterrupt: 

In [38]:
# Teacher-forcing diagnostic for the current `model` on `dataset`.
# The returned figure is bound to `fig` (later cells reuse that name) and shown immediately.
fig = visualize_teacher_forcing(model, dataset)
fig.show()

In [40]:
# Overview figure of the current `model` on `dataset` (num_plots=12).
# generate=True presumably samples from the model — confirm in codebase.utils.
# NOTE(review): seed=None presumably means a fresh random seed on each run, so
# the figure is not reproducible; pass a fixed seed when comparing checkpoints.
# The figure is displayed by the separate `fig.show()` cell that follows.
fig = visualize_model(
    model=model,
    dataset=dataset,
    generate=True,
    seed=None,
    num_plots=12,
    show_notes=False,
    exclude_context=False,
)

In [16]:
# Show whichever figure `fig` currently holds. In top-to-bottom order, that is
# the visualize_model figure from the previous cell.
# NOTE(review): this cell was executed before that one (In [16] vs In [40]);
# consider folding it into the previous cell so each cell shows its own result.
fig.show()

In [43]:
# Plot the loss history stored in loss.pkl (presumably written during training
# via record_loss — confirm).
# NOTE(review): the meaning of the positional `False` is not visible here; pass
# it by keyword once the parameter name in codebase.utils.plot_loss is confirmed.
plot_loss("loss.pkl", False)