In [None]:
import os
import json
import music21 as m21
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.cuda.amp import autocast, GradScaler
import gc

# ============================================================================
# CHECKPOINT MANAGEMENT
# ============================================================================

class CheckpointManager:
    """Save and restore training checkpoints inside one directory.

    Two files are maintained in ``checkpoint_dir``:
    ``checkpoint_latest.pth`` (rewritten on every save) and
    ``model_best.pth`` (rewritten only when ``is_best`` is true).
    """
    def __init__(self, checkpoint_dir='checkpoints'):
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(checkpoint_dir, exist_ok=True)

    @staticmethod
    def _atomic_save(checkpoint, path):
        """Write ``checkpoint`` to ``path`` without ever leaving a partial file."""
        # Serialize to a sibling temp file first; os.replace then swaps it in
        # as a single rename, so an interrupted save cannot corrupt ``path``.
        tmp_path = path + '.tmp'
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)

    def save_checkpoint(self, model, optimizer, scheduler, scaler, epoch, loss, model_params, is_best=False):
        """Save checkpoint with all training state.

        Args:
            model, optimizer, scheduler, scaler: objects exposing
                ``state_dict()``; each state dict is stored in the checkpoint.
            epoch: epoch index stored under ``'epoch'``.
            loss: loss value stored under ``'loss'``.
            model_params: constructor arguments stored under
                ``'model_params'`` so the model can be rebuilt on load.
            is_best: when true, the checkpoint is also written to
                ``model_best.pth``.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'loss': loss,
            'model_params': model_params,
        }

        # Save latest checkpoint
        latest_path = os.path.join(self.checkpoint_dir, 'checkpoint_latest.pth')
        self._atomic_save(checkpoint, latest_path)
        print(f"✓ Checkpoint saved: {latest_path}")

        if is_best:
            best_path = os.path.join(self.checkpoint_dir, 'model_best.pth')
            self._atomic_save(checkpoint, best_path)
            print(f"✓ Best model saved: {best_path}")

    def load_checkpoint(self, map_location=None):
        """Load the latest checkpoint.

        Args:
            map_location: forwarded to ``torch.load``. When omitted and CUDA
                is unavailable, tensors are mapped to the CPU so a checkpoint
                written on a GPU machine can still be loaded.

        Returns:
            The checkpoint dict, or ``None`` when no latest checkpoint exists.
        """
        checkpoint_path = os.path.join(self.checkpoint_dir, 'checkpoint_latest.pth')
        if not os.path.exists(checkpoint_path):
            print("No checkpoint found, starting from scratch.")
            return None

        print(f"Loading checkpoint from: {checkpoint_path}")
        if map_location is None and not torch.cuda.is_available():
            map_location = 'cpu'
        return torch.load(checkpoint_path, map_location=map_location)

# ============================================================================
# CONFIGURATION
# ============================================================================
# Unzip data if not already done
# Extract the dataset archive only when the target folder is missing.
# NOTE: the `!unzip` line is IPython/Jupyter shell syntax, so this block only
# runs inside a notebook; the /content paths are presumably Google Colab's
# working directory — confirm before running elsewhere.
if not os.path.exists('/content/essen'):
    !unzip -q /content/deutschl.zip -d /content/

# Input folder walked for .krn files, and the output locations used by the
# preprocessing and training steps below.
KERN_DATASET_PATH = '/content/essen/europa/deutschl/test'
PREPROCESSED_DATA_PATH = "dataset"  # one encoded text file per accepted song
SINGLE_FILE_DATASET_PATH = "file_dataset"  # holds the merged encoded_songs.txt
MAPPING_PATH = "mapping.json"  # symbol -> integer vocabulary
CHECKPOINT_DIR = "checkpoints"

# Preprocessing & Model Hyperparameters
# Durations (in quarter lengths) a song may contain; a song with any other
# duration is skipped by preprocess().
ACCEPTABLE_DURATIONS = [0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0]
# Length of each training input window, and also the number of "/" symbols
# written between songs in the merged dataset file.
SEQUENCE_LENGTH = 64
BATCH_SIZE = 128 # Increased for better GPU utilization
EPOCHS = 5
LEARNING_RATE = 0.001
EMBEDDING_DIM = 256
# NOTE: MusicLSTM only reads HIDDEN_SIZES[0]; its layer count is fixed at 2.
HIDDEN_SIZES = [256, 256]
DROPOUT_RATE = 0.2

# Ensure directories exist
os.makedirs(PREPROCESSED_DATA_PATH, exist_ok=True)
os.makedirs(SINGLE_FILE_DATASET_PATH, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# ============================================================================
# DATA PREPROCESSING
# ============================================================================
def load_songs_in_kern(dataset_path):
    """Parse every ``.krn`` file found under ``dataset_path``.

    The directory tree is walked recursively. Files that music21 rejects with
    a ``ConverterException`` are reported with a warning and skipped.

    Returns:
        A list of the parsed music21 objects, in walk order.
    """
    parsed_songs = []
    for directory, _subdirs, filenames in os.walk(dataset_path):
        for filename in filenames:
            # Only Humdrum **kern files are of interest.
            if not filename.endswith(".krn"):
                continue
            kern_path = os.path.join(directory, filename)
            try:
                parsed = m21.converter.parse(kern_path)
            except m21.converter.ConverterException:
                print(f"Warning: Could not parse {filename}. Skipping.")
            else:
                parsed_songs.append(parsed)
    return parsed_songs

def has_acceptable_durations(song, acceptable_durations):
    """Return True when every note and rest has an allowed duration.

    Args:
        song: object whose ``flat.notesAndRests`` yields events exposing
            ``duration.quarterLength``.
        acceptable_durations: container of permitted quarter-length values.

    Returns:
        ``True`` if all events pass (also for a song with no events),
        ``False`` as soon as one event has a duration outside the container.
    """
    return all(
        event.duration.quarterLength in acceptable_durations
        for event in song.flat.notesAndRests
    )

def transpose(song):
    """Transpose ``song`` so its tonic becomes C (major) or A (minor).

    The key is first looked up at index 4 of the first measure of the first
    part. If that lookup fails, or the element found there is not a
    ``music21.key.Key``, the key is estimated with ``song.analyze("key")``.
    Any mode other than "minor" is treated as major.

    Returns:
        The result of ``song.transpose(...)`` with the computed interval.
    """
    key = None
    try:
        first_part = song.getElementsByClass(m21.stream.Part)[0]
        first_measure = first_part.getElementsByClass(m21.stream.Measure)[0]
        key = first_measure[4]
    except (IndexError, AttributeError):
        # No usable key at the expected position; fall through to analysis.
        pass

    if not isinstance(key, m21.key.Key):
        key = song.analyze("key")
    print(f"Transposing from key: {key}")

    # Minor keys are moved to A; major (and any other mode) to C.
    target_tonic = "A" if key.mode == "minor" else "C"
    shift = m21.interval.Interval(key.tonic, m21.pitch.Pitch(target_tonic))
    return song.transpose(shift)

def encode_song(song, time_step=0.25):
    """Encode a song as a space-separated time-series string.

    Each note or rest is expanded into ``duration / time_step`` symbols: the
    first is the MIDI pitch number (or ``"r"`` for a rest) and every following
    one is the hold marker ``"_"``. An event shorter than one time step
    contributes no symbols.

    Args:
        song: object whose ``flat.notesAndRests`` yields the events to encode.
        time_step: length of one symbol in quarter lengths. Defaults to 0.25
            (a sixteenth note), the value that was previously hard-coded.

    Returns:
        The encoded song; an empty string when the song has no events.

    Raises:
        ValueError: if ``time_step`` is not positive.
    """
    if time_step <= 0:
        raise ValueError(f"time_step must be positive, got {time_step!r}")

    encoded_song = []
    for event in song.flat.notesAndRests:
        symbol = "r" if isinstance(event, m21.note.Rest) else event.pitch.midi
        # int() truncates, so a duration that is not a whole multiple of
        # time_step is rounded down to the number of full steps.
        steps = int(event.duration.quarterLength / time_step)
        encoded_song.extend(
            str(symbol) if step == 0 else "_" for step in range(steps)
        )
    return " ".join(encoded_song)

def preprocess(dataset_path):
    """Load, filter, transpose, and encode all songs under ``dataset_path``.

    Songs containing a duration outside ``ACCEPTABLE_DURATIONS`` are skipped.
    Each remaining song is written to ``PREPROCESSED_DATA_PATH/<i>.txt``,
    where ``<i>`` is the song's position in the loaded list (so skipped songs
    leave gaps in the numbering).
    """
    print('Loading songs....')
    songs = load_songs_in_kern(dataset_path)
    print(f'Loaded {len(songs)} songs.')

    written = 0
    for index, raw_song in enumerate(songs):
        if has_acceptable_durations(raw_song, ACCEPTABLE_DURATIONS):
            encoded = encode_song(transpose(raw_song))
            target_file = os.path.join(PREPROCESSED_DATA_PATH, f"{index}.txt")
            with open(target_file, "w") as handle:
                handle.write(encoded)
            written += 1

    print(f"Preprocessed and saved {written} songs.")

def create_single_file_dataset(dataset_path, file_dataset_path, sequence_length):
    """Merge all encoded songs into one delimiter-separated string and file.

    Every file under ``dataset_path`` is read and followed by a space and
    ``sequence_length`` copies of ``"/ "``. Directories and files are visited
    in sorted order so the merged corpus is identical from run to run.

    Args:
        dataset_path: directory tree containing the encoded song files.
        file_dataset_path: directory that receives ``encoded_songs.txt``;
            created if it does not exist.
        sequence_length: number of ``"/"`` delimiter symbols after each song.

    Returns:
        The merged string, stripped of leading and trailing whitespace.
    """
    new_song_delimiter = "/ " * sequence_length

    # Collect the pieces and join once instead of growing a string with +=.
    pieces = []
    for path, dirs, files in os.walk(dataset_path):
        dirs.sort()  # in-place sort makes os.walk descend deterministically
        for file in sorted(files):
            with open(os.path.join(path, file), "r") as fp:
                pieces.append(fp.read() + " " + new_song_delimiter)
    songs = "".join(pieces).strip()

    os.makedirs(file_dataset_path, exist_ok=True)
    output_file_path = os.path.join(file_dataset_path, "encoded_songs.txt")
    with open(output_file_path, "w") as fp:
        fp.write(songs)
    return songs

def create_mapping(songs, mapping_path):
    """Build the symbol-to-integer vocabulary and store it as JSON.

    Args:
        songs: whitespace-separated string of encoded symbols.
        mapping_path: file that receives the mapping (indented JSON).

    Returns:
        Dict assigning each distinct symbol its index in sorted order.
    """
    unique_symbols = set(songs.split())

    mappings = {}
    for index, symbol in enumerate(sorted(unique_symbols)):
        mappings[symbol] = index

    with open(mapping_path, "w") as out_file:
        json.dump(mappings, out_file, indent=4)

    print("Vocabulary mapping created.")
    return mappings

# ============================================================================
# PYTORCH DATASET AND MODEL
# ============================================================================

class MusicDataset(Dataset):
    """PyTorch Dataset of sliding windows over an integer-encoded corpus.

    Item ``i`` is the pair ``(int_songs[i:i + sequence_length],
    int_songs[i + sequence_length])``: an input window and the symbol that
    follows it, both as ``torch.long`` tensors.
    """
    def __init__(self, int_songs, sequence_length):
        self.int_songs = int_songs
        self.sequence_length = sequence_length
        # Convert once up front; converting a Python list slice to a tensor on
        # every __getitem__ call is far slower than slicing a tensor.
        self._data = torch.tensor(int_songs, dtype=torch.long)

    def __len__(self):
        # Clamp at zero: a corpus shorter than one window has no samples, and
        # a negative length would make len() itself raise.
        return max(0, len(self.int_songs) - self.sequence_length)

    def __getitem__(self, index):
        size = len(self)
        if index < 0:
            index += size  # support Python-style negative indexing
        if not 0 <= index < size:
            raise IndexError(f"index out of range for dataset of size {size}")

        window_end = index + self.sequence_length
        # clone() hands out independent tensors, so callers cannot mutate the
        # shared corpus through a returned view.
        inputs = self._data[index:window_end].clone()
        target = self._data[window_end].clone()
        return inputs, target

class MusicLSTM(nn.Module):
    """Next-symbol classifier: embedding -> 2-layer LSTM -> dropout -> linear.

    NOTE: only ``hidden_sizes[0]`` is used as the LSTM width; any further
    entries are ignored and the number of LSTM layers is fixed at 2.
    """
    def __init__(self, vocabulary_size, embedding_dim, hidden_sizes, dropout_rate):
        super().__init__()
        lstm_width = hidden_sizes[0]
        self.embedding = nn.Embedding(vocabulary_size, embedding_dim)
        # One stacked nn.LSTM; ``dropout`` here acts between its two layers.
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=lstm_width,
            num_layers=2,
            dropout=dropout_rate,
            batch_first=True,
        )
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fc = nn.Linear(lstm_width, vocabulary_size)

    def forward(self, x):
        """Map a batch of index sequences to logits over the vocabulary."""
        embedded = self.embedding(x)
        sequence_out, _state = self.lstm(embedded)
        # Keep only the final time step (batch_first=True, so it is axis 1).
        final_step = self.dropout(sequence_out[:, -1, :])
        return self.fc(final_step)

# ============================================================================
# TRAINING
# ============================================================================

def train():
    """Train ``MusicLSTM`` on the merged dataset, resuming from a checkpoint.

    Reads the vocabulary from ``MAPPING_PATH`` and the corpus from
    ``SINGLE_FILE_DATASET_PATH/encoded_songs.txt``, then trains up to
    ``EPOCHS`` epochs. If a latest checkpoint exists, model, optimizer,
    scheduler and scaler state are restored and training continues from the
    following epoch. A checkpoint is written after every epoch; it is also
    stored as the best model when the epoch's average training loss improves.

    Raises:
        ValueError: if the corpus has no more than ``SEQUENCE_LENGTH``
            symbols, i.e. not a single training window.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Mixed precision and pinned host memory only make sense with CUDA.
    use_cuda = device.type == "cuda"

    # --- 1. Load Data and Mappings ---
    with open(MAPPING_PATH, "r") as fp:
        mappings = json.load(fp)
    vocabulary_size = len(mappings)

    encoded_songs_path = os.path.join(SINGLE_FILE_DATASET_PATH, "encoded_songs.txt")
    with open(encoded_songs_path, "r") as fp:
        songs = fp.read()

    int_songs = [mappings[symbol] for symbol in songs.split()]
    if len(int_songs) <= SEQUENCE_LENGTH:
        # Fail early with a clear message instead of a sampler error or a
        # division by zero further down.
        raise ValueError(
            f"Dataset has {len(int_songs)} symbols; need more than "
            f"SEQUENCE_LENGTH={SEQUENCE_LENGTH} to build a training sample."
        )

    # --- 2. Create Dataset and DataLoader ---
    dataset = MusicDataset(int_songs, SEQUENCE_LENGTH)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                            num_workers=2, pin_memory=use_cuda)

    # --- 3. Initialize Model and Training Components ---
    model_params = {
        'vocabulary_size': vocabulary_size,
        'embedding_dim': EMBEDDING_DIM,
        'hidden_sizes': HIDDEN_SIZES,
        'dropout_rate': DROPOUT_RATE
    }
    model = MusicLSTM(**model_params).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5) # Learning rate decay
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler(enabled=use_cuda)
    checkpoint_manager = CheckpointManager(CHECKPOINT_DIR)

    # --- 4. Load from Checkpoint if available ---
    start_epoch = 0
    best_loss = float('inf')
    checkpoint = checkpoint_manager.load_checkpoint()
    if checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # A scaler saved while disabled (e.g. on a CPU run) has an empty state
        # dict, which an enabled GradScaler refuses to load.
        scaler_state = checkpoint.get('scaler_state_dict')
        if scaler_state:
            scaler.load_state_dict(scaler_state)
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint.get('loss', float('inf'))
        print(f"Resuming training from epoch {start_epoch}")
        if start_epoch >= EPOCHS:
            print(f"Checkpoint already covers {start_epoch} of {EPOCHS} epochs; "
                  f"nothing left to train. Increase EPOCHS to continue.")

    # --- 5. Training Loop ---
    model.train()
    for epoch in range(start_epoch, EPOCHS):
        running_loss = 0.0
        correct_predictions = 0
        total_predictions = 0

        for i, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()

            with autocast(enabled=use_cuda):
                outputs = model(inputs)
                loss = criterion(outputs, targets)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs.detach(), 1)
            total_predictions += targets.size(0)
            correct_predictions += (predicted == targets).sum().item()

            if (i + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{EPOCHS}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}")

        epoch_loss = running_loss / len(dataloader)
        epoch_accuracy = correct_predictions / total_predictions
        print(f"--- End of Epoch [{epoch+1}/{EPOCHS}], Average Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f} ---")

        # Advance the LR schedule BEFORE saving, so the checkpoint holds the
        # state needed to start the next epoch. Saving first would replay one
        # scheduler step on every resume and shift the decay schedule.
        scheduler.step()

        # --- 6. Save Checkpoint ---
        is_best = epoch_loss < best_loss
        if is_best:
            best_loss = epoch_loss

        checkpoint_manager.save_checkpoint(
            model=model, optimizer=optimizer, scheduler=scheduler, scaler=scaler,
            epoch=epoch, loss=epoch_loss, model_params=model_params, is_best=is_best
        )
        gc.collect() # Garbage collection

    print("Training finished.")

# Entry point: run the full pipeline. The order matters because each step
# consumes the files written by the previous one.
if __name__ == "__main__":
    # Run preprocessing steps
    # 1) Parse, filter, transpose and encode the songs into per-song files.
    preprocess(KERN_DATASET_PATH)
    # 2) Merge the per-song files into one delimiter-separated corpus.
    songs_str = create_single_file_dataset(PREPROCESSED_DATA_PATH, SINGLE_FILE_DATASET_PATH, SEQUENCE_LENGTH)
    # 3) Derive the symbol -> integer vocabulary from that corpus.
    create_mapping(songs_str, MAPPING_PATH)

    # Start training
    # train() re-reads the corpus and mapping from disk.
    train()

Loading songs....
Loaded 12 songs.
Transposing from key: F major
Transposing from key: e minor
Transposing from key: F major
Transposing from key: C major
Transposing from key: b minor
Transposing from key: e minor
Transposing from key: e minor
Transposing from key: C major
Transposing from key: e minor
Transposing from key: g minor
Transposing from key: F major
Transposing from key: C major
Preprocessed and saved 12 songs.
Vocabulary mapping created.
Using device: cuda
Loading checkpoint from: checkpoints/checkpoint_latest.pth
Resuming training from epoch 5
Training finished.


  return self.iter().getElementsByClass(classFilterList)
  scaler = GradScaler()


In [None]:
import torch
import os
import json
from torch.utils.data import Dataset, DataLoader

# Assuming the necessary classes and functions (MusicDataset, MusicLSTM, etc.) are defined in the previous cell

# Evaluate the best saved checkpoint on the same corpus used for training.

# Load the mapping
MAPPING_PATH = "mapping.json"
with open(MAPPING_PATH, "r") as fp:
    mappings = json.load(fp)
vocabulary_size = len(mappings)

# Load the preprocessed integer songs
SINGLE_FILE_DATASET_PATH = "file_dataset"
encoded_songs_path = os.path.join(SINGLE_FILE_DATASET_PATH, "encoded_songs.txt")
with open(encoded_songs_path, "r") as fp:
    songs = fp.read()
int_songs = [mappings[symbol] for symbol in songs.split()]

# Create the dataset and dataloader for evaluation
SEQUENCE_LENGTH = 64
BATCH_SIZE = 128 # Use the same batch size or adjust as needed
dataset = MusicDataset(int_songs, SEQUENCE_LENGTH)
# Shuffle should be False for evaluation; pinned memory is only useful when a
# GPU is present.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2,
                        pin_memory=torch.cuda.is_available())

# Load the best model checkpoint
CHECKPOINT_DIR = "checkpoints"
best_model_path = os.path.join(CHECKPOINT_DIR, 'model_best.pth')

if not os.path.exists(best_model_path):
    print(f"Error: Best model checkpoint not found at {best_model_path}")
else:
    print(f"Loading best model from: {best_model_path}")
    # Pick the device first so map_location can remap the stored tensors; a
    # checkpoint written on a GPU otherwise fails to load on a CPU-only host.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(best_model_path, map_location=device)

    # Re-initialize the model with the saved parameters
    model_params = checkpoint['model_params']
    model = MusicLSTM(**model_params)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Set model to evaluation mode and move to device
    model.to(device)
    model.eval()

    print("Calculating accuracy...")
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad(): # No gradient calculation needed for evaluation
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)

            _, predicted = torch.max(outputs, 1)
            total_predictions += targets.size(0)
            correct_predictions += (predicted == targets).sum().item()

    if total_predictions == 0:
        # Nothing was evaluated (corpus shorter than one window); avoid a
        # ZeroDivisionError and do not report a meaningless accuracy.
        print("Error: No samples available for evaluation.")
    else:
        accuracy = correct_predictions / total_predictions
        print(f"Accuracy of the best model on the training dataset: {accuracy:.4f}")

Loading best model from: checkpoints/model_best.pth
Calculating accuracy...
Accuracy of the best model on the training dataset: 0.7715
