In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import librosa
import json
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
from tqdm import tqdm

class GuitarSetDataset(Dataset):
    """Frame-level transcription dataset for GuitarSet (https://guitarset.weebly.com/).

    Every item is a dict describing one excerpt of a recording:

    * ``'input'``: normalised log-mel spectrogram, shape ``(128, T)``
    * ``'piano_roll'``: per-pitch activations, shape ``(T, 128)``
    * ``'guitar_roll'``: per-string/fret activations, shape ``(T, 6, 21)``
    * ``'audio_path'``: path of the source recording as ``str``

    The targets are time-major so they have the same layout as the
    ``(batch, time, ...)`` logits they are compared against, and all three
    arrays share the same number of frames ``T``. With ``segment_size`` set,
    ``T`` is identical for every item, which is what the default
    ``DataLoader`` collation needs to stack a batch.
    """

    # MIDI pitch of each open string in standard tuning, string 0 = low E.
    OPEN_STRING_PITCHES = (40, 45, 50, 55, 59, 64)
    # Highest fret represented in the guitar roll (frets 0..MAX_FRET).
    MAX_FRET = 20
    # Audio variants of the official GuitarSet release, in order of preference.
    AUDIO_SUFFIXES = ('_hex_cln', '_hex', '_mic', '_mix')

    def __init__(self, data_dir, split='train', sample_rate=44100, hop_length=512, segment_size=5):
        """
        Args:
            data_dir: Root directory of GuitarSet
            split: 'train' (players 00-03, excerpts drawn at a random offset)
                or anything else (players 04-05, excerpt taken from the start
                of each recording so evaluation is deterministic)
            sample_rate: Audio sample rate
            hop_length: Hop length for STFT
            segment_size: Length of each excerpt in seconds. ``None`` or ``0``
                returns whole recordings, which then differ in length and can
                only be batched with ``batch_size=1`` or a custom collate_fn.
        """
        self.data_dir = Path(data_dir)
        self.split = split
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.segment_size = segment_size

        # All player IDs in GuitarSet
        all_players = ['00', '01', '02', '03', '04', '05']

        # Split by player so no guitarist appears in both sets
        if split == 'train':
            self.players = all_players[:4]  # Players 00-03 for training
        else:
            self.players = all_players[4:]  # Players 04-05 for testing

        self.file_pairs = self._get_file_pairs()

    def _get_file_pairs(self):
        """Get all matching (audio, annotation) file pairs for this split.

        The ``guitarist_<player>/**/*audio_hex_cln.wav`` layout is tried
        first. If it yields nothing, the flat layout of the official archives
        is searched instead (see ``_get_flat_layout_pairs``).
        """
        file_pairs = []

        for player in self.players:
            player_dir = self.data_dir / f"guitarist_{player}"

            # Sorted so the item order is reproducible across runs
            for audio_file in sorted(player_dir.glob("**/*audio_hex_cln.wav")):
                jams_file = audio_file.with_name(
                    audio_file.stem.replace('audio_hex_cln', 'jams')
                ).with_suffix('.jams')

                if jams_file.exists():
                    file_pairs.append((audio_file, jams_file))

        if file_pairs:
            return file_pairs
        return self._get_flat_layout_pairs()

    def _get_flat_layout_pairs(self):
        """Pair files named like the official release, wherever they live under data_dir.

        Audio is expected as ``<player>_<piece><suffix>.wav`` and its
        annotation as ``<player>_<piece>.jams``. Only the first audio variant
        from ``AUDIO_SUFFIXES`` that produces any pair is used, so one piece
        is never returned several times under different microphones.
        """
        jams_by_stem = {path.stem: path for path in sorted(self.data_dir.rglob('*.jams'))}

        for suffix in self.AUDIO_SUFFIXES:
            pairs = []
            for player in self.players:
                for audio_file in sorted(self.data_dir.rglob(f"{player}_*{suffix}.wav")):
                    piece = audio_file.stem[:-len(suffix)]
                    jams_file = jams_by_stem.get(piece)
                    if jams_file is not None:
                        pairs.append((audio_file, jams_file))
            if pairs:
                return pairs

        return []

    def __len__(self):
        return len(self.file_pairs)

    @staticmethod
    def _string_index(annotation, fallback):
        """Return the string number recorded in the annotation metadata, or ``fallback``."""
        try:
            return int(annotation.annotation_metadata.data_source)
        except (AttributeError, TypeError, ValueError):
            return fallback

    def _parse_jams(self, jams_path):
        """Parse a JAMS file into a list of note events.

        GuitarSet stores one ``note_midi`` annotation per string; the string
        number is read from the annotation's ``data_source`` (falling back to
        the annotation's position among the ``note_midi`` annotations). The
        fret is the distance in semitones from that string's open pitch.

        Returns:
            list of dicts with keys 'time', 'duration', 'value' (rounded MIDI
            pitch), 'string' and 'fret'.
        """
        import jams

        # jams.load only opens the file itself when given a str path
        jam = jams.load(str(jams_path))

        note_annotations = [a for a in jam.annotations if a.namespace == 'note_midi']

        notes_data = []
        for position, annotation in enumerate(note_annotations):
            string = self._string_index(annotation, position)
            if 0 <= string < len(self.OPEN_STRING_PITCHES):
                open_pitch = self.OPEN_STRING_PITCHES[string]
            else:
                open_pitch = None

            for note in annotation.data:
                # Annotated pitches are fractional; round instead of truncating
                pitch = int(round(float(note.value)))
                notes_data.append({
                    'time': float(note.time),
                    'duration': float(note.duration),
                    'value': pitch,
                    'string': string,
                    # -1 marks "unknown"; such notes are left out of the guitar roll
                    'fret': pitch - open_pitch if open_pitch is not None else -1
                })

        return notes_data

    def _note_frame_span(self, note, n_frames):
        """Return the (start, end) frame range of a note, clipped to ``[0, n_frames]``."""
        start_frame = int(note['time'] * self.sample_rate / self.hop_length)
        end_frame = int((note['time'] + note['duration']) * self.sample_rate / self.hop_length)
        return max(start_frame, 0), min(end_frame, n_frames)

    def _notes_to_piano_roll(self, notes, duration, n_frames=None):
        """Convert note events to a piano roll of shape ``(128, n_frames)``.

        ``n_frames`` defaults to ``duration * sample_rate / hop_length``; pass
        it explicitly to match the frame count of an existing spectrogram.
        """
        if n_frames is None:
            n_frames = int(duration * self.sample_rate / self.hop_length)
        piano_roll = np.zeros((128, n_frames), dtype=np.float32)

        for note in notes:
            start_frame, end_frame = self._note_frame_span(note, n_frames)

            # Out-of-range pitches are skipped; a negative index would wrap around
            if 0 <= note['value'] < 128 and start_frame < end_frame:
                piano_roll[note['value'], start_frame:end_frame] = 1.0

        return piano_roll

    def _create_guitar_roll(self, notes, duration, n_frames=None):
        """Create a string/fret roll of shape ``(6, 21, n_frames)``.

        ``n_frames`` behaves as in ``_notes_to_piano_roll``. Frets above
        ``MAX_FRET`` are clamped to ``MAX_FRET``; notes with a negative fret
        or an unknown string are skipped.
        """
        if n_frames is None:
            n_frames = int(duration * self.sample_rate / self.hop_length)

        # 6 strings x frets 0..MAX_FRET
        guitar_roll = np.zeros((6, self.MAX_FRET + 1, n_frames), dtype=np.float32)

        for note in notes:
            start_frame, end_frame = self._note_frame_span(note, n_frames)

            string = note.get('string', 0)
            fret = min(note.get('fret', 0), self.MAX_FRET)

            if 0 <= string < 6 and 0 <= fret <= self.MAX_FRET and start_frame < end_frame:
                guitar_roll[string, fret, start_frame:end_frame] = 1.0

        return guitar_roll

    def _segment_frames(self):
        """Number of frames in one excerpt, or ``None`` when excerpts are disabled."""
        if not self.segment_size:
            return None
        return max(1, int(round(self.segment_size * self.sample_rate / self.hop_length)))

    def _crop_or_pad(self, mel, piano_roll, guitar_roll):
        """Cut the same excerpt out of all three arrays (time is their last axis).

        Recordings shorter than one excerpt are padded at the end: the
        spectrogram with its minimum value, the targets with zeros.
        """
        seg_frames = self._segment_frames()
        if seg_frames is None:
            return mel, piano_roll, guitar_roll

        total = mel.shape[1]
        if total >= seg_frames:
            max_start = total - seg_frames
            if self.split == 'train' and max_start > 0:
                # torch's RNG is seeded separately in every DataLoader worker
                start = int(torch.randint(0, max_start + 1, (1,)).item())
            else:
                start = 0
            window = slice(start, start + seg_frames)
            return mel[:, window], piano_roll[:, window], guitar_roll[:, :, window]

        pad = seg_frames - total
        fill = float(mel.min()) if mel.size else 0.0
        mel = np.pad(mel, ((0, 0), (0, pad)), mode='constant', constant_values=fill)
        piano_roll = np.pad(piano_roll, ((0, 0), (0, pad)), mode='constant')
        guitar_roll = np.pad(guitar_roll, ((0, 0), (0, 0), (0, pad)), mode='constant')
        return mel, piano_roll, guitar_roll

    def __getitem__(self, idx):
        audio_path, jams_path = self.file_pairs[idx]

        # Load audio
        audio, sr = librosa.load(audio_path, sr=self.sample_rate, mono=True)
        duration = len(audio) / sr

        # Extract features
        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=sr,
            n_mels=128,
            hop_length=self.hop_length
        )
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        # Normalize
        mel_spec_db = (mel_spec_db - mel_spec_db.mean()) / (mel_spec_db.std() + 1e-8)

        # Build the targets with exactly as many frames as the spectrogram has,
        # so input and targets can never disagree about the sequence length.
        n_frames = mel_spec_db.shape[1]
        notes = self._parse_jams(jams_path)
        piano_roll = self._notes_to_piano_roll(notes, duration, n_frames=n_frames)
        guitar_roll = self._create_guitar_roll(notes, duration, n_frames=n_frames)

        mel_spec_db, piano_roll, guitar_roll = self._crop_or_pad(mel_spec_db, piano_roll, guitar_roll)

        return {
            'input': torch.from_numpy(np.ascontiguousarray(mel_spec_db, dtype=np.float32)),
            # Time-major targets: (T, 128) and (T, 6, 21)
            'piano_roll': torch.from_numpy(np.ascontiguousarray(piano_roll.T)),
            'guitar_roll': torch.from_numpy(np.ascontiguousarray(guitar_roll.transpose(2, 0, 1))),
            'audio_path': str(audio_path)
        }

class GuitarTranscriber(nn.Module):
    """CNN + bidirectional LSTM that predicts frame-level pitch and string/fret logits.

    Input is a batch of spectrograms shaped ``(batch, n_mels, time)``. The
    convolutional stack pools along the frequency axis only, so the model
    emits exactly one prediction per input frame:

    * note logits of shape ``(batch, time, 128)``
    * guitar logits of shape ``(batch, time, 6, 21)``

    Both outputs are raw logits (no sigmoid applied).
    """

    def __init__(self, n_mels=128):
        """
        Args:
            n_mels: Number of frequency bins of the input spectrogram. It is
                halved by each of the four pooling layers, so it must be at
                least 16.
        """
        super(GuitarTranscriber, self).__init__()

        if n_mels < 16:
            raise ValueError(f"n_mels must be at least 16, got {n_mels}")

        def conv_block(in_channels, out_channels):
            # Pool over frequency only: pooling over time as well would leave
            # fewer output frames than input frames.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=(2, 1)),
            ]

        # Feature extraction layers (same layer order/indices as a flat
        # Sequential of four Conv-BN-ReLU-Pool groups)
        self.conv_layers = nn.Sequential(
            *conv_block(1, 32),
            *conv_block(32, 64),
            *conv_block(64, 128),
            *conv_block(128, 256),
        )

        # Bidirectional LSTM for temporal modeling
        self.lstm = nn.LSTM(
            input_size=256 * (n_mels // 16),  # channels x pooled frequency bins
            hidden_size=512,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.3
        )

        # Output layers for note detection
        self.note_classifier = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128)  # 128 MIDI notes
        )

        # Output layers for guitar-specific representation
        self.guitar_classifier = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 6 * 21)  # 6 strings x 21 frets (0-20)
        )

    def forward(self, x):
        """Map ``(batch, n_mels, time)`` spectrograms to ``(note_logits, guitar_logits)``."""
        batch_size, freq_bins, time_frames = x.shape

        # Add channel dimension: [batch, 1, freq, time]
        x = x.unsqueeze(1)

        # Feature extraction: [batch, 256, freq / 16, time]
        x = self.conv_layers(x)

        # Reshape for LSTM: [batch, channels, freq, time] -> [batch, time, channels*freq]
        x = x.permute(0, 3, 1, 2)
        x = x.reshape(batch_size, x.shape[1], -1)

        # LSTM for temporal modeling: [batch, time, 1024]
        x, _ = self.lstm(x)

        # Note prediction
        note_output = self.note_classifier(x)

        # Guitar-specific prediction; take the frame count from the tensor
        # itself rather than assuming it
        guitar_output = self.guitar_classifier(x)
        guitar_output = guitar_output.reshape(batch_size, x.shape[1], 6, 21)

        return note_output, guitar_output

def _run_transcriber_epoch(model, loader, device, criterion_notes, criterion_guitar,
                           optimizer=None, desc=None):
    """Run one pass over ``loader`` and return the mean combined loss per batch.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one the pass is evaluation-only (eval mode, no
    gradients).
    """
    training = optimizer is not None
    model.train(training)
    running_loss = 0.0

    with torch.set_grad_enabled(training):
        for batch in tqdm(loader, desc=desc):
            inputs = batch['input'].to(device)
            piano_roll_targets = batch['piano_roll'].to(device)
            guitar_roll_targets = batch['guitar_roll'].to(device)

            # Forward pass
            note_outputs, guitar_outputs = model(inputs)

            # BCEWithLogitsLoss needs targets shaped exactly like the logits
            note_loss = criterion_notes(note_outputs, piano_roll_targets)
            guitar_loss = criterion_guitar(guitar_outputs, guitar_roll_targets)

            # Combined loss (weighted as needed)
            loss = note_loss + 0.5 * guitar_loss

            if training:
                optimizer.zero_grad()
                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()

            running_loss += loss.item()

    return running_loss / len(loader)


def train_guitar_transcriber(data_dir, num_epochs=50, batch_size=16, learning_rate=0.001, num_workers=4):
    """Train a GuitarTranscriber on GuitarSet and return ``(model, history)``.

    Args:
        data_dir: Root directory of GuitarSet
        num_epochs: Number of passes over the training split
        batch_size: Items per batch
        learning_rate: Initial Adam learning rate (halved when the validation
            loss stops improving for 5 epochs)
        num_workers: DataLoader worker processes; use 0 to load in the main
            process

    Returns:
        The trained model and a dict with per-epoch 'train_loss' and 'val_loss'.

    Raises:
        ValueError: if either split contains no recordings.

    Side effects: writes 'guitar_transcriber_best.pth',
    'guitar_transcriber_final.pth' and 'training_history.png' to the current
    working directory.
    """
    # Create datasets and dataloaders
    train_dataset = GuitarSetDataset(data_dir, split='train')
    test_dataset = GuitarSetDataset(data_dir, split='test')

    # Fail early with an actionable message instead of deep inside the loop
    if len(train_dataset) == 0 or len(test_dataset) == 0:
        raise ValueError(
            f"No GuitarSet recordings found under {data_dir!r} "
            f"(train: {len(train_dataset)}, test: {len(test_dataset)})"
        )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Model, loss function and optimizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GuitarTranscriber().to(device)

    criterion_notes = nn.BCEWithLogitsLoss()
    criterion_guitar = nn.BCEWithLogitsLoss()

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # The scheduler's `verbose` flag is deprecated; learning-rate changes are
    # reported explicitly below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    # Training loop
    history = {'train_loss': [], 'val_loss': []}
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        avg_train_loss = _run_transcriber_epoch(
            model, train_loader, device, criterion_notes, criterion_guitar,
            optimizer=optimizer, desc=f'Epoch {epoch+1}/{num_epochs} - Training'
        )
        avg_val_loss = _run_transcriber_epoch(
            model, test_loader, device, criterion_notes, criterion_guitar,
            desc=f'Epoch {epoch+1}/{num_epochs} - Validation'
        )

        # Update learning rate
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(avg_val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after != lr_before:
            print(f'Reducing learning rate: {lr_before:.2e} -> {lr_after:.2e}')

        # Print progress
        print(f'Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f} - Val Loss: {avg_val_loss:.4f}')

        # Save history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'guitar_transcriber_best.pth')

    # Save final model
    torch.save(model.state_dict(), 'guitar_transcriber_final.pth')

    # Plot training history
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(history['train_loss'], label='Train Loss')
    ax.plot(history['val_loss'], label='Validation Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.set_title('Training History')
    fig.savefig('training_history.png')
    plt.close(fig)  # release the figure; repeated runs would otherwise accumulate them

    return model, history


In [6]:
def evaluate_model(model, test_loader, device):
    """Compute frame-level precision, recall and F1 of the note predictions.

    Args:
        model: Network returning ``(note_logits, guitar_logits)``; only the
            note logits are scored.
        test_loader: Iterable of batches with 'input' and 'piano_roll'
            entries. The piano-roll targets must have the same shape as the
            note logits.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        dict with 'precision', 'recall' and 'f1'.

    Raises:
        ValueError: if the loader yields no batches, or if predictions and
            targets disagree in shape.
    """
    model.eval()

    all_preds = []
    all_targets = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Evaluating'):
            inputs = batch['input'].to(device)

            # Forward pass
            note_outputs, _ = model(inputs)

            # Sigmoid probabilities thresholded at 0.5
            note_preds = (torch.sigmoid(note_outputs) > 0.5).float().cpu().numpy()

            # Targets are only compared on the CPU, so they never need to
            # visit the device.
            targets = batch['piano_roll'].cpu().numpy()

            if note_preds.shape != targets.shape:
                raise ValueError(
                    f"Prediction shape {note_preds.shape} does not match "
                    f"target shape {targets.shape}"
                )

            # Flatten each batch to [frames, notes] before collecting it, so
            # batches with different numbers of frames can still be combined.
            all_preds.append(note_preds.reshape(-1, note_preds.shape[-1]))
            all_targets.append(targets.reshape(-1, targets.shape[-1]))

    if not all_preds:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    # Concatenate all batches
    all_preds = np.concatenate(all_preds, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)

    # Calculate metrics
    precision, recall, f1 = calculate_frame_metrics(all_targets, all_preds)

    print(f"Frame-level Metrics:")
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall: {recall:.4f}")
    print(f"  F1 Score: {f1:.4f}")

    return {'precision': precision, 'recall': recall, 'f1': f1}

def calculate_frame_metrics(targets, predictions):
    """Frame-level precision, recall and F1 over binary note activations.

    Both arrays are flattened to two dimensions, keeping their last axis.
    Only entries equal to exactly 1 count as active and only entries equal to
    exactly 0 count as inactive. A small epsilon in every denominator keeps
    the result finite when there are no positives.

    Returns:
        (precision, recall, f1)
    """
    eps = 1e-8

    # Collapse every leading axis: [..., notes] -> [frames, notes]
    truth = targets.reshape(-1, targets.shape[-1])
    guess = predictions.reshape(-1, predictions.shape[-1])

    truth_on, truth_off = truth == 1, truth == 0
    guess_on, guess_off = guess == 1, guess == 0

    # Confusion counts
    hits = np.sum(truth_on & guess_on)
    false_alarms = np.sum(truth_off & guess_on)
    misses = np.sum(truth_on & guess_off)

    precision = hits / (hits + false_alarms + eps)
    recall = hits / (hits + misses + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)

    return precision, recall, f1

def transcribe_guitar_audio(model, audio_path, output_dir="transcriptions"):
    """Transcribe a guitar audio file using the trained model.

    The recording is converted to a normalised log-mel spectrogram
    (44.1 kHz, 128 mel bands, hop 512), run through ``model``, and the
    thresholded note predictions are written as a MIDI file together with a
    PNG overview.

    Args:
        model: Network returning ``(note_logits, guitar_logits)`` shaped
            ``(1, T, 128)`` and ``(1, T, 6, 21)`` for a single spectrogram.
        audio_path: Path of the recording to transcribe.
        output_dir: Directory for the results; created if missing.

    Returns:
        Path of the written MIDI file.
    """
    device = next(model.parameters()).device
    model.eval()

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    file_name = os.path.splitext(os.path.basename(audio_path))[0]

    # Load and preprocess audio
    hop_length = 512
    y, sr = librosa.load(audio_path, sr=44100, mono=True)

    # Extract features
    mel_spec = librosa.feature.melspectrogram(
        y=y,
        sr=sr,
        n_mels=128,
        hop_length=hop_length
    )
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    # Normalize
    mel_spec_db = (mel_spec_db - mel_spec_db.mean()) / (mel_spec_db.std() + 1e-8)

    # Add the batch dimension expected by the model
    features = torch.FloatTensor(mel_spec_db).unsqueeze(0).to(device)

    # Forward pass
    with torch.no_grad():
        note_outputs, guitar_outputs = model(features)

        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # drop the time axis of a one-frame recording
        note_probs = torch.sigmoid(note_outputs).squeeze(0)      # (T, 128)
        guitar_probs = torch.sigmoid(guitar_outputs).squeeze(0)  # (T, 6, 21)

    # Convert to binary predictions using threshold
    note_preds = (note_probs > 0.5).float().cpu().numpy()

    # The helpers below index guitar arrays as [string, fret, frame], so move
    # the time axis last: (T, 6, 21) -> (6, 21, T)
    guitar_probs = guitar_probs.permute(1, 2, 0).cpu().numpy()
    guitar_preds = (guitar_probs > 0.5).astype(np.float32)

    # Convert frame-level predictions to note events
    midi_data = frames_to_midi(note_preds, sr=sr, hop_length=hop_length)

    # Pick string/fret from the probabilities rather than the thresholded
    # values, so a choice can still be made when nothing exceeds 0.5
    add_guitar_information(midi_data, guitar_probs, sr=sr, hop_length=hop_length)

    # Save MIDI file
    midi_output_path = os.path.join(output_dir, f"{file_name}.mid")
    midi_data.write(midi_output_path)

    # Create visualization
    visualization_path = os.path.join(output_dir, f"{file_name}_transcription.png")
    create_transcription_visualization(
        mel_spec_db,
        note_preds,
        guitar_preds,
        visualization_path,
        sr=sr,
        hop_length=hop_length
    )

    return midi_output_path

def frames_to_midi(note_frames, sr=44100, hop_length=512, velocity=100, min_frames=2):
    """Convert frame-wise predictions to MIDI notes.

    Args:
        note_frames: Array of shape ``(frames, pitches)``; values greater
            than 0 mark an active pitch in that frame.
        sr: Sample rate the frames were computed at.
        hop_length: Hop length (in samples) between consecutive frames.
        velocity: MIDI velocity given to every note.
        min_frames: Shortest run of active frames that becomes a note.

    Returns:
        A ``pretty_midi.PrettyMIDI`` object holding one acoustic-guitar
        instrument with one note per sufficiently long run of active frames.
        Only pitches 40-88 (E2 to E6) are considered.
    """
    # Imported here because the module is not imported at file level
    import pretty_midi

    midi_data = pretty_midi.PrettyMIDI()

    # Add acoustic guitar instrument
    guitar = pretty_midi.Instrument(program=24)  # 24 = Acoustic Guitar (nylon)

    # Helper function to convert frame index to time
    def frame_to_time(frame_idx):
        return float(frame_idx * hop_length / sr)

    note_frames = np.asarray(note_frames)

    # Only pitches inside the guitar range (E2 to E6)
    for pitch in range(40, min(89, note_frames.shape[1])):
        # Zero-pad both ends so runs touching the first or last frame still
        # produce a rising and a falling edge
        active = (note_frames[:, pitch] > 0).astype(np.int8)
        edges = np.diff(np.concatenate(([0], active, [0])))

        begins = np.where(edges == 1)[0]   # first active frame of each run
        ends = np.where(edges == -1)[0]    # first inactive frame after each run (exclusive)

        # Create notes for each run
        for start, end in zip(begins, ends):
            # Only create notes with minimum duration
            if end - start >= min_frames:
                note = pretty_midi.Note(
                    velocity=velocity,
                    pitch=pitch,
                    start=frame_to_time(start),
                    end=frame_to_time(end)
                )
                guitar.notes.append(note)

    midi_data.instruments.append(guitar)
    return midi_data

def add_guitar_information(midi_data, guitar_preds, sr=44100, hop_length=512):
    """Annotate each note of the first instrument with a string and fret.

    For every note, the activations in ``guitar_preds`` are averaged over the
    frames the note spans, and the string/fret pair with the highest positive
    average is stored on the note object as ``note.string`` and ``note.fret``.
    If no pair has a positive average, string 0 / fret 0 is stored. Notes
    that span no valid frame are left untouched.

    These are plain Python attributes on the in-memory note objects; they are
    set here for later use by other code and nothing in this function writes
    them to a MIDI file.

    Args:
        midi_data: Object with an ``instruments`` list whose first entry has
            a ``notes`` list of objects with ``start``/``end`` times in seconds.
        guitar_preds: Array indexed as ``[string, fret, frame]`` (6 strings,
            21 frets).
        sr: Sample rate the frames were computed at.
        hop_length: Hop length (in samples) between consecutive frames.
    """
    if len(midi_data.instruments) == 0:
        return

    guitar = midi_data.instruments[0]
    guitar_preds = np.asarray(guitar_preds)
    n_frames = guitar_preds.shape[2]

    for note in guitar.notes:
        # Round rather than truncate: a time computed as frame * hop / sr can
        # convert back to a hair below the original integer frame index
        start_frame = int(round(note.start * sr / hop_length))
        end_frame = int(round(note.end * sr / hop_length))

        # Limit to valid frame indices; the end is exclusive, so it may equal n_frames
        start_frame = max(0, min(start_frame, n_frames - 1))
        end_frame = max(0, min(end_frame, n_frames))

        # Skip if no valid frames
        if start_frame >= end_frame:
            continue

        # Average activation of every string/fret pair over the note: (6, 21)
        mean_activation = guitar_preds[:6, :21, start_frame:end_frame].mean(axis=2)

        best_string, best_fret = 0, 0
        if mean_activation.size and mean_activation.max() > 0:
            # argmax returns the first maximum in string-major order
            best_string, best_fret = divmod(int(mean_activation.argmax()), mean_activation.shape[1])

        note.string = int(best_string)
        note.fret = int(best_fret)

def create_transcription_visualization(mel_spec, note_preds, guitar_preds, output_path, sr=44100, hop_length=512):
    """Save a three-panel figure summarising a transcription.

    Args:
        mel_spec: Spectrogram of shape ``(n_mels, frames)``.
        note_preds: Note activations of shape ``(frames, 128)``; row index of
            the transposed array is the MIDI pitch.
        guitar_preds: String/fret activations indexed as
            ``[string, fret, frame]``; summed over frets for display.
        output_path: Where the image is written.
        sr: Sample rate the frames were computed at.
        hop_length: Hop length (in samples) between consecutive frames.
    """
    # The display submodule is not guaranteed to be loaded by `import librosa`
    import librosa.display

    fig, (ax_mel, ax_notes, ax_strings) = plt.subplots(3, 1, figsize=(15, 10))

    try:
        # Plot mel spectrogram
        ax_mel.set_title("Mel Spectrogram")
        mel_img = librosa.display.specshow(
            mel_spec,
            sr=sr,
            hop_length=hop_length,
            x_axis='time',
            y_axis='mel',
            ax=ax_mel
        )
        fig.colorbar(mel_img, ax=ax_mel, format='%+2.0f dB')

        # Plot transcribed notes. specshow has no 'midi' axis type; one row
        # per semitone starting at MIDI pitch 0 is shown with a note-name axis.
        ax_notes.set_title("Transcribed Notes")
        notes_img = librosa.display.specshow(
            note_preds.T,
            sr=sr,
            hop_length=hop_length,
            x_axis='time',
            y_axis='cqt_note',
            fmin=librosa.midi_to_hz(0),
            bins_per_octave=12,
            ax=ax_notes
        )
        fig.colorbar(notes_img, ax=ax_notes)

        # Plot guitar-specific information (combine strings and frets)
        ax_strings.set_title("Guitar String/Fret Activations")
        # Sum over frets to get string activations for visualization
        string_activations = np.sum(guitar_preds, axis=1)
        strings_img = librosa.display.specshow(
            string_activations,
            sr=sr,
            hop_length=hop_length,
            x_axis='time',
            ax=ax_strings
        )
        ax_strings.set_yticks(np.arange(6))
        ax_strings.set_yticklabels(['E', 'A', 'D', 'G', 'B', 'e'])
        ax_strings.set_ylabel('String')
        fig.colorbar(strings_img, ax=ax_strings)

        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if plotting or saving failed
        plt.close(fig)

def download_guitarset():
    """Download the GuitarSet archives from Zenodo and extract them into ./guitarset.

    Archives that are already present and are valid zip files are not
    downloaded again, so the function can be re-run after an interruption.
    If a download fails, the older ``zenodo.org/record/`` URL form is tried
    once before manual instructions are printed. Nothing is returned; progress
    and problems are reported with ``print``.
    """
    import os
    import urllib.error
    import urllib.parse
    import urllib.request
    import zipfile

    # Current URLs for GuitarSet (as of 2024)
    urls = [
        'https://zenodo.org/records/3371780/files/audio_hex-pickup_original.zip?download=1',
        'https://zenodo.org/records/3371780/files/annotation.zip?download=1'
    ]

    target_dir = 'guitarset'

    # Create directory
    os.makedirs(target_dir, exist_ok=True)

    def fetch_and_extract(source_url, archive_path):
        """Download ``source_url`` to ``archive_path`` (unless already there) and unpack it."""
        # A partially downloaded file is not a valid zip and is fetched again
        if os.path.exists(archive_path) and zipfile.is_zipfile(archive_path):
            print(f"Using existing archive {archive_path}")
        else:
            urllib.request.urlretrieve(source_url, archive_path)

        print(f"Extracting {archive_path}...")
        with zipfile.ZipFile(archive_path, 'r') as zip_ref:
            zip_ref.extractall(target_dir)

    def print_manual_steps():
        print("1. Visit https://guitarset.weebly.com/")
        print("2. Download the dataset files manually")
        print("3. Extract them to a folder named 'guitarset'")

    def has_files(extension):
        """True if any file below target_dir ends with ``extension``."""
        return any(
            name.lower().endswith(extension)
            for _, _, names in os.walk(target_dir)
            for name in names
        )

    # Download and extract
    for url in urls:
        # Name the archive after the URL path only; keeping the
        # '?download=1' query would put a '?' into the file name
        file_name = os.path.join(target_dir, os.path.basename(urllib.parse.urlparse(url).path))
        print(f"Downloading {url}...")

        try:
            fetch_and_extract(url, file_name)

        except urllib.error.URLError as e:  # includes HTTPError
            print(f"Error downloading {url}: {e}")
            print("Attempting alternative download method...")

            # Try alternative method or sources
            alternative_source = url.replace('zenodo.org/records/', 'zenodo.org/record/')
            try:
                print(f"Trying alternative URL: {alternative_source}")
                fetch_and_extract(alternative_source, file_name)

            except Exception as e2:
                print(f"Alternative download also failed: {e2}")
                print("\nSuggested manual download steps:")
                print_manual_steps()

    # Check if files were downloaded: either the expected folders exist or
    # both annotations and audio can be found somewhere below the target
    folders_present = os.path.exists('guitarset/annotation') and os.path.exists('guitarset/audio')
    if folders_present or (has_files('.jams') and has_files('.wav')):
        print("Download complete.")
    else:
        print("\nDownload may have been incomplete. Please try the manual download steps:")
        print_manual_steps()



In [7]:
def main():
    """Interactive entry point: optionally download data, train or load a model, transcribe a file.

    All choices are read with ``input``. Answers are stripped of surrounding
    whitespace; 'y' and 'yes' (any case) count as yes, anything else as no.
    """
    def ask_yes_no(prompt):
        return input(prompt).strip().lower() in ('y', 'yes')

    # Check if CUDA is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Option to download dataset
    if ask_yes_no("Download GuitarSet dataset? (y/n): "):
        download_guitarset()

    # Set dataset path
    data_dir = 'guitarset'

    # Training options
    if ask_yes_no("Train a new model? (y/n): "):
        model, history = train_guitar_transcriber(
            data_dir=data_dir,
            num_epochs=30,
            batch_size=16,
            learning_rate=0.001
        )
    else:
        # Load pre-trained model
        model = GuitarTranscriber().to(device)
        model_path = input("Enter path to pre-trained model (or 'guitar_transcriber_best.pth' for default): ").strip()
        if not model_path:
            model_path = 'guitar_transcriber_best.pth'

        model.load_state_dict(torch.load(model_path, map_location=device))

    # Transcribe a sample file; an empty answer skips transcription
    test_audio = input("Enter path to guitar audio file to transcribe: ").strip()
    if test_audio:
        output_midi = transcribe_guitar_audio(model, test_audio)
        print(f"Transcription saved to {output_midi}")

# Start the interactive workflow only when this code is executed directly
# (as a script or notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()

Using device: cpu
Download GuitarSet dataset? (y/n): y
Downloading https://zenodo.org/records/3371780/files/audio_hex-pickup_original.zip?download=1...


KeyboardInterrupt: 