In [1]:
!pip install pretty_midi

Collecting pretty_midi
  Downloading pretty_midi-0.2.11.tar.gz (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m46.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pretty_midi
  Building wheel for pretty_midi (setup.py) ... [?25l[?25hdone
  Created wheel for pretty_midi: filename=pretty_midi-0.2.11-py3-none-any.whl size=5595886 sha256=4fa66efdd6e1f32429ec3d8201f7ccf7fd5bf0d403fd68c1805b2dea0f87d11c
  Stored in directory: /root/.cache/pip/wheels/09/e6/e6/29223dbea25e71e517b8791bf35cc9a7b872cb2ad284e30181
Successfully built pretty_midi
Installing collected packages: mido, pretty_

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR

import os
import time
import pretty_midi

In [3]:
TIME_STEP = 0.02
MAX_TIME_SHIFT = 100
MAX_DURATION = 200

def find_midi_files(root_dir):
    midi_files = []
    for dirpath, _, filenames in os.walk(root_dir):
        for f in filenames:
            if f.lower().endswith(('.mid', '.midi')):
                midi_files.append(os.path.join(dirpath, f))
    return midi_files

def midi_to_notes(path):
    pm = pretty_midi.PrettyMIDI(path)

    events = []

    note_events = []
    for inst in pm.instruments:
        if inst.is_drum:
            continue

        for note in inst.notes:
            duration = note.end - note.start
            note_events.append((note.start, note.pitch, duration))
    
    note_events.sort(key=lambda x: x[0])
    
    prev_time = 0.0
    i = 0
    
    while i < len(note_events):
        current_start_time = note_events[i][0]
        
        dt = current_start_time - prev_time
        steps = int(dt / TIME_STEP)
        while steps > 0:
            shift = min(steps, MAX_TIME_SHIFT)
            events.append(f"time_shift_{shift}")
            steps -= shift
        
        chord_notes = []
        while i < len(note_events) and abs(note_events[i][0] - current_start_time) < TIME_STEP / 2:
            _, pitch, duration = note_events[i]
            chord_notes.append((pitch, duration))
            i += 1
        
        for pitch, duration in chord_notes:
            duration_steps = int(duration / TIME_STEP)
            duration_steps = min(duration_steps, MAX_DURATION)
            duration_steps = max(duration_steps, 1)
            
            events.append(f"note_on_{pitch}")
            events.append(f"duration_{duration_steps}")
        
        prev_time = current_start_time

    return events

def build_vocab(event_lists):
    vocab = {"<pad>": 0, "<start>": 1, "<end>": 2}
    idx = 3

    for events in event_lists:
        for e in events:
            if e not in vocab:
                vocab[e] = idx
                idx += 1

    return vocab

def notes_to_tokens(events, vocab):
    return [vocab["<start>"]] + [vocab[e] for e in events if e in vocab] + [vocab["<end>"]]

def tokens_to_midi(tokens, vocab, save_path="out.mid"):
    inv_vocab = {i: e for e, i in vocab.items()}
    pm = pretty_midi.PrettyMIDI()
    inst = pretty_midi.Instrument(program=0)
    pm.instruments.append(inst)

    current_time = 0.0
    pending_notes = []

    for tok in tokens:
        if tok not in inv_vocab:
            continue

        e = inv_vocab[tok]
        
        if e in ("<pad>", "<start>", "<end>"):
            continue

        if e.startswith("time_shift_"):
            dt = int(e.split("_")[2]) * TIME_STEP
            current_time += dt

        elif e.startswith("note_on_"):
            pitch = int(e.split("_")[2])
            pending_notes.append([pitch, current_time, None])

        elif e.startswith("duration_"):
            duration_steps = int(e.split("_")[1])
            for note in reversed(pending_notes):
                if note[2] is None:
                    note[2] = duration_steps
                    end_time = note[1] + duration_steps * TIME_STEP
                    inst.notes.append(pretty_midi.Note(
                        velocity=80,
                        pitch=note[0],
                        start=note[1],
                        end=end_time
                    ))
                    break

    pm.write(save_path)

class MIDITokenDataset(Dataset):
    def __init__(self, token_sequences, seq_len=1024, pad_id=0):
        self.data = token_sequences
        self.seq_len = seq_len
        self.pad_id = pad_id

        self.chunks = []
        for seq in token_sequences:
            if len(seq) <= seq_len + 1:
                self.chunks.append((seq, 0, len(seq)))
            else:
                num_chunks = (len(seq) - seq_len - 1) // (seq_len // 2) + 1
                for i in range(num_chunks):
                    start = i * (seq_len // 2)
                    if start + seq_len + 1 > len(seq):
                        start = len(seq) - seq_len - 1
                    self.chunks.append((seq, start, start + seq_len + 1))

        self.total_tokens = sum(len(x) for x in token_sequences)
        print(f"Loaded dataset with {len(self.data)} files, {self.total_tokens} total tokens")
        print(f"Created {len(self.chunks)} training chunks")

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        seq, start, end = self.chunks[idx]
        chunk = seq[start:end]

        if len(chunk) < self.seq_len + 1:
            padded = chunk + [self.pad_id] * (self.seq_len + 1 - len(chunk))
            inp = padded[:-1]
            target = padded[1:]
        else:
            inp = chunk[:-1]
            target = chunk[1:]

        return torch.tensor(inp, dtype=torch.long), torch.tensor(target, dtype=torch.long)


def collate_batch(batch, pad_id=0):
    inps, targets = zip(*batch)

    max_len = max(x.size(0) for x in inps)

    padded_inps = []
    padded_targets = []

    for inp, tgt in zip(inps, targets):
        pad_len = max_len - inp.size(0)

        padded_inps.append(
            torch.cat([inp, torch.full((pad_len,), pad_id, dtype=torch.long)])
        )
        padded_targets.append(
            torch.cat([tgt, torch.full((pad_len,), pad_id, dtype=torch.long)])
        )

    return torch.stack(padded_inps), torch.stack(padded_targets)

def create_dataloader(token_sequences, seq_len=512, batch_size=128, pad_id=0, shuffle=True):
    dataset = MIDITokenDataset(
        token_sequences,
        seq_len=seq_len,
        pad_id=pad_id,
    )

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=lambda b: collate_batch(b, pad_id=pad_id),
        drop_last=True,
        num_workers=0,
    )
    return loader

In [4]:
MIDI_DIR = '/kaggle/input/classical-music-midi'
MIN_EVENTS = 50
MAX_EVENTS = 10000

midi_files = find_midi_files(MIDI_DIR)
print(f"Найдено файлов: {len(midi_files)}")

all_notes = []
skipped = 0

for file in midi_files:
    try:
        notes = midi_to_notes(file)
        
        if MIN_EVENTS <= len(notes) <= MAX_EVENTS:
            all_notes.append(notes)
        else:
            skipped += 1
            
    except Exception as e:
        print(f"\nОшибка в {file}: {e}")
        skipped += 1

print(f"Обработано файлов: {len(all_notes)}")
print(f"Пропущено файлов: {skipped}")

vocab = build_vocab(all_notes)
print(f"Размер словаря: {len(vocab)}")

token_sequences = []
for notes in all_notes:
    tokens = notes_to_tokens(notes, vocab)
    token_sequences.append(tokens)

total_tokens = sum(len(seq) for seq in token_sequences)
avg_length = total_tokens / len(token_sequences) if token_sequences else 0
print(f"\nВсего токенов: {total_tokens}")
print(f"Средняя длина последовательности: {avg_length:.2f}")

SEQ_LEN = 512
BATCH_SIZE = 128
PAD_ID = vocab["<pad>"]

train_loader = create_dataloader(
    token_sequences[:int(0.8 * len(token_sequences))],
    seq_len=SEQ_LEN,
    batch_size=BATCH_SIZE,
    pad_id=PAD_ID,
    shuffle=True
)

val_loader = create_dataloader(
    token_sequences[int(0.8 * len(token_sequences)):],
    seq_len=SEQ_LEN,
    batch_size=BATCH_SIZE,
    pad_id=PAD_ID,
    shuffle=True
)

Найдено файлов: 295
Обработано файлов: 239
Пропущено файлов: 56
Размер словаря: 388

Всего токенов: 936936
Средняя длина последовательности: 3920.23
Loaded dataset with 191 files, 772804 total tokens
Created 2731 training chunks
Loaded dataset with 48 files, 164132 total tokens
Created 573 training chunks


In [5]:
class MusicGRU(nn.Module):
    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512, num_layers=3, dropout=0.1):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)

        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout
        )

        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        emb = self.embedding(x)
        out, hidden = self.gru(emb, hidden)
        logits = self.fc(out)
        return logits, hidden

    def generate(self, start_token, max_len=1024, temperature=1.0, top_p=0.95):
        self.eval()

        device = next(self.parameters()).device
        
        tokens = [start_token]
        inp = torch.tensor([[start_token]], device=device)
        hidden = None

        for _ in range(max_len):
            logits, hidden = self.forward(inp, hidden)
            logits = logits[:, -1, :] / temperature

            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

            cutoff = cumulative_probs > top_p
            cutoff[..., 1:] = cutoff[..., :-1]
            cutoff[..., 0] = False
            sorted_logits[cutoff] = -1e9

            probs = torch.softmax(sorted_logits, dim=-1)
            next_token = sorted_indices[0, torch.multinomial(probs[0], 1)].item()
            
            tokens.append(next_token)
            inp = torch.tensor([[next_token]], device=device)

        return tokens

In [6]:
class MusicGRU(nn.Module):
    def __init__(self, vocab_size, embedding_dim=256, hidden_dim=512, num_layers=3, dropout=0.3, pad_id=0):
        super().__init__()
        
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.pad_id = pad_id
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_id)
        
        self.gru = nn.GRU(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )
        
        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)
        
        self._init_weights()
    
    def _init_weights(self):
        nn.init.xavier_uniform_(self.embedding.weight)
        
        for name, param in self.gru.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)
        
        nn.init.xavier_uniform_(self.fc_out.weight)
        nn.init.zeros_(self.fc_out.bias)
    
    def forward(self, x, hidden=None):
        embedded = self.embedding(x)
        embedded = self.dropout(embedded)
        
        output, hidden = self.gru(embedded, hidden)
        output = self.dropout(output)
        
        logits = self.fc_out(output)
        
        return logits, hidden
    
    def init_hidden(self, batch_size, device):
        return torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device)
    
    @torch.no_grad()
    def generate(self, start_tokens, vocab, max_length=1000, temperature=1.0, top_k=None, top_p=None, device='cuda'):
        self.eval()
        
        if isinstance(start_tokens, list):
            tokens = torch.tensor([start_tokens], dtype=torch.long, device=device)
        else:
            tokens = start_tokens.unsqueeze(0) if start_tokens.dim() == 1 else start_tokens
        
        hidden = self.init_hidden(1, device)
        
        end_id = vocab.get("<end>", None)
        pad_id = vocab.get("<pad>", self.pad_id)
        
        generated = tokens[0].tolist()
        
        for _ in range(max_length):
            logits, hidden = self.forward(tokens, hidden)
            next_token_logits = logits[0, -1, :] / temperature
            next_token_logits[pad_id] = -float('inf')
            
            if top_k is not None:
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                next_token_logits[indices_to_remove] = -float('inf')
            
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = -float('inf')
            
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            generated.append(next_token.item())
            
            if end_id is not None and next_token.item() == end_id:
                break
            
            tokens = next_token.unsqueeze(0)
        
        return generated

device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
model = MusicGRU(
    vocab_size=len(vocab),
    embedding_dim=256,
    hidden_dim=512,
    num_layers=3,
    dropout=0.3
).to(device)
    
print(f"{sum(p.numel() for p in model.parameters()):,} параметров")

4,632,964 параметров


In [7]:
def train_epoch(model, dataloader, optimizer, criterion, device, grad_clip=1.0, print_every=50):
    model.train()
    total_loss = 0
    total_tokens = 0
    
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits, _ = model(inputs)
        
        logits_flat = logits.view(-1, logits.size(-1))
        targets_flat = targets.view(-1)
        
        loss = criterion(logits_flat, targets_flat)
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        
        total_loss += loss.item()
        mask = targets_flat != model.pad_id
        total_tokens += mask.sum().item()
        
        if (batch_idx + 1) % print_every == 0:
            avg_loss = total_loss / (batch_idx + 1)
            print(f"  Batch {batch_idx + 1}/{len(dataloader)} | Loss: {avg_loss:.4f} | PPL: {np.exp(avg_loss):.2f}")
    
    avg_loss = total_loss / len(dataloader)
    
    return avg_loss


@torch.no_grad()
def validate(model, dataloader, criterion, device, print_every=50):
    model.eval()
    total_loss = 0
    
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        logits, _ = model(inputs)
        
        logits_flat = logits.view(-1, logits.size(-1))
        targets_flat = targets.view(-1)
        
        loss = criterion(logits_flat, targets_flat)
        total_loss += loss.item()
        
        if (batch_idx + 1) % print_every == 0:
            print(f"  Val Batch {batch_idx + 1}/{len(dataloader)} | Loss: {loss.item():.4f}")
    
    avg_loss = total_loss / len(dataloader)
    
    return avg_loss


def train_model(
    model,
    train_loader,
    val_loader=None,
    num_epochs=50,
    learning_rate=0.001,
    weight_decay=1e-5,
    grad_clip=1.0,
    device='cuda',
    save_dir='checkpoints',
    save_every=5,
    patience=10,
    print_every=50
):
    os.makedirs(save_dir, exist_ok=True)
    
    optimizer = optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
        betas=(0.9, 0.98),
        eps=1e-9
    )
    
    if val_loader is not None:
        scheduler = ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=3,
            verbose=True
        )
    else:
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=num_epochs,
            eta_min=learning_rate * 0.01
        )
    
    criterion = nn.CrossEntropyLoss(ignore_index=model.pad_id)
    
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    
    start_time = time.time()
    
    for epoch in range(1, num_epochs + 1):
        epoch_start = time.time()
        
        print(f"\nЭпоха {epoch}/{num_epochs}")
        
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device, grad_clip, print_every)
        val_loss = validate(model, val_loader, criterion, device, print_every)
        scheduler.step(val_loss)
            
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
                
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, os.path.join(save_dir, 'best_model.pt'))
            print(f"Сохранена лучшая модель (val_loss: {val_loss:.4f})")
        else:
            epochs_without_improvement += 1
        
        epoch_time = time.time() - epoch_start
        
        print(f"Время: {epoch_time:.1f}s, LR: {optimizer.param_groups[0]['lr']:.6f}")
        print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        
        if epoch % save_every == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
            }, os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt'))
            print(f"Сохранён чекпоинт эпохи {epoch}")
        
        if val_loader is not None and epochs_without_improvement >= patience:
            print(f"\nEarly stopping после {epoch} эпох (patience={patience})")
            break
    
    total_time = time.time() - start_time
    
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, os.path.join(save_dir, 'final_model.pt'))


def load_checkpoint(model, checkpoint_path, optimizer=None):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    return checkpoint


train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=200,
    learning_rate=3e-4,
    device=device,
    save_dir='checkpoints',
    save_every=5,
    patience=15
)




Эпоха 1/200
Сохранена лучшая модель (val_loss: 4.6218)
Время: 9.6s, LR: 0.000300
Train Loss: 5.4184, Val Loss: 4.6218

Эпоха 2/200
Сохранена лучшая модель (val_loss: 4.3132)
Время: 8.9s, LR: 0.000300
Train Loss: 4.5676, Val Loss: 4.3132

Эпоха 3/200
Сохранена лучшая модель (val_loss: 4.1181)
Время: 8.9s, LR: 0.000300
Train Loss: 4.2925, Val Loss: 4.1181

Эпоха 4/200
Сохранена лучшая модель (val_loss: 3.9481)
Время: 8.9s, LR: 0.000300
Train Loss: 4.1028, Val Loss: 3.9481

Эпоха 5/200
Сохранена лучшая модель (val_loss: 3.8054)
Время: 8.9s, LR: 0.000300
Train Loss: 3.9457, Val Loss: 3.8054
Сохранён чекпоинт эпохи 5

Эпоха 6/200
Сохранена лучшая модель (val_loss: 3.2743)
Время: 8.9s, LR: 0.000300
Train Loss: 3.6430, Val Loss: 3.2743

Эпоха 7/200
Сохранена лучшая модель (val_loss: 3.0013)
Время: 8.9s, LR: 0.000300
Train Loss: 3.1652, Val Loss: 3.0013

Эпоха 8/200
Сохранена лучшая модель (val_loss: 2.9051)
Время: 8.9s, LR: 0.000300
Train Loss: 2.9777, Val Loss: 2.9051

Эпоха 9/200
Сохранена

In [8]:
for i in range(10):
    generated = model.generate(
        start_tokens=[vocab["<start>"]],
        vocab=vocab,
        max_length=1000,
        temperature=0.9,
        top_k=40,
        top_p=0.95,
        device=device
    )

    tokens_to_midi(generated, vocab, "generated" + str(i) + ".mid")