# In [None]:
# 1. INSTALL DEPENDENCIES (run in Colab or Jupyter)
# !pip install pretty_midi miditok torch numpy

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import miditok
from miditoolkit import MidiFile

# 2. TOKENIZE MIDI FILES
class MIDIDataset(Dataset):
    """Fixed-length next-token training pairs built from a folder of MIDI files.

    Each item is ``(input_ids, target_ids)``: two 1-D ``torch.long`` tensors of
    length ``max_len - 1``, where ``target_ids`` is ``input_ids`` shifted left by
    one position. Token ids of every track in a file are concatenated into one
    stream, which is then right-padded with ``pad_id`` or truncated to ``max_len``.

    Args:
        midi_folder: Directory scanned (non-recursively) for ``.mid``/``.midi``
            files (extension match is case-insensitive).
        tokenizer: miditok tokenizer, called as ``tokenizer(MidiFile(path))``.
        max_len: Sequence length before the one-token input/target shift; >= 2.
        pad_id: Id used for right-padding (default 0, as before).

    Raises:
        ValueError: If ``max_len`` < 2.
    """

    _MIDI_SUFFIXES = (".mid", ".midi")

    def __init__(self, midi_folder, tokenizer, max_len=512, pad_id=0):
        if max_len < 2:
            raise ValueError(
                f"max_len must be >= 2 to form an input/target pair, got {max_len}"
            )
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.pad_id = pad_id
        self.midi_paths = []
        self.valid_midi_count = 0

        # Sorted so the dataset order (and therefore dataset[0]) is reproducible;
        # os.listdir returns entries in arbitrary order.
        midi_files = sorted(
            name
            for name in os.listdir(midi_folder)
            if name.lower().endswith(self._MIDI_SUFFIXES)
        )
        print(f"🟡 Found {len(midi_files)} MIDI files. Checking validity...")

        for file in midi_files:
            full_path = os.path.join(midi_folder, file)
            try:
                # Tokenize once up front purely as a validity check.
                tokens = self.tokenizer(MidiFile(full_path))
            except Exception as e:
                print(f"⚠️ Skipping file {file}: {e}")
                continue
            # A file with no tokens would only ever produce an all-padding sample.
            if not self._flatten_ids(tokens):
                print(f"⚠️ Skipping file {file}: no tokens")
                continue
            self.midi_paths.append(full_path)
            self.valid_midi_count += 1

        print(f"✅ {self.valid_midi_count} valid MIDI files loaded.")

    def __len__(self):
        return len(self.midi_paths)

    @staticmethod
    def _flatten_ids(tokens):
        """Return a flat list of token ids from a tokenizer result.

        Accepts either a single object with an ``ids`` attribute or an iterable
        of such objects (one per track). Elements without ``ids`` are ignored.
        Which shape miditok returns depends on tokenizer configuration/version.
        """
        if hasattr(tokens, "ids"):
            return list(tokens.ids)
        flat = []
        for seq in tokens:
            if hasattr(seq, "ids"):
                flat.extend(seq.ids)
        return flat

    def _pad_or_truncate(self, token_ids):
        """Right-pad ``token_ids`` with ``self.pad_id`` or cut it to ``self.max_len``."""
        missing = self.max_len - token_ids.numel()
        if missing <= 0:
            return token_ids[: self.max_len]
        padding = torch.full((missing,), self.pad_id, dtype=torch.long)
        return torch.cat([token_ids, padding])

    def __getitem__(self, idx):
        path = self.midi_paths[idx]
        try:
            tokens = self.tokenizer(MidiFile(path))
        except Exception as e:
            # Best effort: a file that fails now yields an all-padding pair
            # instead of aborting the whole epoch.
            print(f"⚠️ Error reading {path}: {e}")
            blank = torch.full((self.max_len - 1,), self.pad_id, dtype=torch.long)
            return blank, blank.clone()

        token_ids = torch.tensor(self._flatten_ids(tokens), dtype=torch.long)
        token_ids = self._pad_or_truncate(token_ids)

        # Next-token objective: predict token t+1 from tokens <= t.
        return token_ids[:-1], token_ids[1:]

# 3. DEFINE TRANSFORMER MODEL
class MusicTransformer(nn.Module):
    """Causal (decoder-only) Transformer language model over token ids.

    Args:
        vocab_size: Number of distinct token ids (embedding rows and output logits).
        d_model: Embedding / hidden size.
        nhead: Number of attention heads; must divide ``d_model``.
        num_layers: Number of stacked ``nn.TransformerEncoderLayer`` blocks.
        max_seq_len: Longest sequence the learned positional table supports
            (default 10000, the previously hard-coded value).
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, max_seq_len=10000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned absolute positional embedding, randomly initialised.
        self.pos_encoder = nn.Parameter(torch.randn(1, max_seq_len, d_model))
        # batch_first=True: activations are (batch, seq, d_model). With the default
        # (False) dim 0 is treated as the sequence, so attention would mix
        # different examples of the batch instead of positions of one sequence.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.decoder = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return next-token logits.

        Args:
            x: ``(batch, seq)`` long tensor of token ids.

        Returns:
            ``(batch, seq, vocab_size)`` logits; position ``t`` only attends to
            positions ``<= t``.

        Raises:
            ValueError: If ``seq`` exceeds the positional table length.
        """
        seq_len = x.size(1)
        if seq_len > self.pos_encoder.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.pos_encoder.size(1)}"
            )
        h = self.embedding(x) + self.pos_encoder[:, :seq_len]
        # Additive causal mask: -inf strictly above the diagonal blocks attention
        # to future positions, matching how the model is used for generation.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=x.device, dtype=h.dtype),
            diagonal=1,
        )
        h = self.transformer(h, mask=causal_mask)
        return self.decoder(h)

# 4. TRAINING LOOP
# Select the compute device: CUDA when PyTorch reports a usable GPU, else CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Training on device:", device)
def train(model, dataloader, vocab_size, device, epochs=5, lr=1e-4, pad_id=0):
    """Train ``model`` with next-token cross-entropy and Adam.

    Args:
        model: Module mapping a ``(batch, seq)`` long tensor to
            ``(batch, seq, vocab_size)`` logits.
        dataloader: Re-iterable source of ``(input_ids, target_ids)`` batches
            (e.g. a ``DataLoader`` or a list).
        vocab_size: Size of the logit dimension.
        device: Device the model and each batch are moved to.
        epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.
        pad_id: Target id excluded from the loss (MIDIDataset right-pads with 0
            by default); ``None`` scores every position.

    Returns:
        List with the mean batch loss of each epoch (``nan`` for an epoch in
        which no batch contained a scorable target).
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # -100 is CrossEntropyLoss's default ignore_index, i.e. "ignore nothing" here.
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100 if pad_id is None else pad_id)

    epoch_losses = []
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        n_batches = 0
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            # A batch made only of padding has no scored targets: the mean loss
            # would be NaN and its gradients would corrupt the weights.
            if pad_id is not None and not (y != pad_id).any():
                continue

            optimizer.zero_grad()
            output = model(x)
            loss = loss_fn(output.reshape(-1, vocab_size), y.reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        avg_loss = total_loss / n_batches if n_batches else float("nan")
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")
    return epoch_losses

# 5. GENERATE MUSIC

def generate(model, tokenizer, seed_ids, max_len=10000, device=None,
             context_len=512, temperature=None):
    """Autoregressively extend ``seed_ids`` and decode the result with ``tokenizer``.

    Args:
        model: Module mapping ``(1, seq)`` token ids to ``(1, seq, vocab)`` logits.
        tokenizer: Object whose ``decode`` accepts a list holding one id sequence.
        seed_ids: Non-empty sequence of prompt token ids; it is not modified.
        max_len: Number of NEW tokens to generate.
        device: Device for the model input. ``None`` (default) uses the device of
            the model's first parameter, or CPU if it has none, so a model
            trained on CPU no longer fails on a hard-coded ``'cuda'``.
        context_len: Only the last ``context_len`` tokens are fed to the model
            at each step (512 matches MIDIDataset's default ``max_len``).
        temperature: ``None`` or ``<= 0`` selects greedy argmax decoding (the
            original behaviour); otherwise sample from
            ``softmax(logits / temperature)``.

    Returns:
        Whatever ``tokenizer.decode`` returns for the seed plus generated tokens.

    Raises:
        ValueError: If ``seed_ids`` is empty or ``context_len`` < 1.
    """
    generated = list(seed_ids)  # copy: never mutate the caller's seed
    if not generated:
        raise ValueError("seed_ids must contain at least one token")
    if context_len < 1:
        raise ValueError(f"context_len must be >= 1, got {context_len}")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    model.eval()
    with torch.no_grad():
        for _ in range(max_len):
            window = torch.tensor(
                generated[-context_len:], dtype=torch.long, device=device
            ).unsqueeze(0)
            next_logits = model(window)[0, -1]
            if temperature is None or temperature <= 0:
                next_token = int(torch.argmax(next_logits))
            else:
                probs = torch.softmax(next_logits / temperature, dim=-1)
                next_token = int(torch.multinomial(probs, num_samples=1))
            generated.append(next_token)

    return tokenizer.decode([generated])

# 6. EXECUTE EVERYTHING

if __name__ == '__main__':
    midi_dir = "lofi_midis/"  # <- your 1200 midi files here
    tokenizer = miditok.REMI()
    dataset = MIDIDataset(midi_dir, tokenizer)
    print(f"Total valid files for training: {len(dataset)}")
    # Fail early with a clear message: with no usable files, training has no
    # batches and dataset[0] (the generation seed) would raise IndexError.
    if len(dataset) == 0:
        raise SystemExit(f"No valid MIDI files found in {midi_dir!r}; nothing to train on.")

    batch_size = 4
    # drop_last only when at least one full batch exists; otherwise a small
    # dataset would yield zero batches.
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=len(dataset) >= batch_size,
    )

    vocab_size = tokenizer.vocab_size
    model = MusicTransformer(vocab_size)

    train(model, dataloader, vocab_size, device)

    seed = dataset[0][0][:100].tolist()  # take first 100 tokens of one file
    # Generate on the same device the model was trained on (CPU-only machines
    # have no 'cuda' device).
    new_midi = generate(model, tokenizer, seed, device=device)
    new_midi.dump("generated_lofi.mid")