# Installation and imports

In [None]:
!pip install pretty-midi

Collecting pretty-midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m21.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido>=1.1.16 (from pretty-midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pretty-midi
  Building wheel for pretty-midi (setup.py) ... [?25l[?25hdone
  Created wheel for pretty-midi: filename=pretty_midi-0.2.10-py3-none-any.whl size=5592286 sha256=f356d9b6ac0c2dfb54be216587a653691f15d2ba76318f2640791ac4c9485fb2
  Stored in directory: /root/.cache/pip/wheels/e6/95/ac/15ceaeb2823b04d8e638fd1495357adb8d26c00ccac9d7782e
Successfully built pretty-midi
Installing collected packages: mido, pretty-midi
Successf

In [None]:
import zipfile
import os
import librosa
import pretty_midi
from torch.utils.data import Dataset
import pandas as pd
from collections import defaultdict
from IPython.display import display, Audio
import matplotlib.pyplot as plt
import librosa.display
import random
import zipfile
import numpy as np
import urllib.request

# Dataset class

In [None]:
class GuitarTECHSDataset(Dataset):
    """
    PyTorch Dataset over the Guitar-TECHS recordings: audio streams plus per-note MIDI labels.

    Filters and Attributes:
      - root_dir (str): Base directory where the dataset is stored. If it does not exist, the Zenodo
                        archive is downloaded and extracted into it.
      - sr (int): Target audio sample rate (Hz) used for loading audio (default: 48000).
      - players (str or list): Which player folders to use (e.g., 'P1', ['P1', 'P2'], or 'all' / ['all'] for all players).
      - content_types (str or list): Which content types to include (e.g., ['chords', 'scales', 'singlenotes', 'techniques']
                                     for P1/P2, or ['music'] for P3, or 'all'). Content types a player does not have are skipped.
      - modalities (str or list): Subset of ["directinput", "micamp", "exo", "ego"] indicating which data streams to load (or 'all').
          - "directinput" and "micamp" are loaded as .wav files from the audio folder.
          - "exo" and "ego" are loaded as .mp3 files from the video folder.
      - slice_size (float): If specified (in seconds), samples are segmented into contiguous slices of this length.
                           The last slice, if shorter than slice_size, will be padded with zeros.
                           Recordings without a MIDI file are dropped in this mode (their duration is unknown).
      - slice_range (tuple): Alternatively, a fixed (start, end) time window for all samples (in seconds).
                             Only one of slice_size and slice_range may be set.

    Within each player/content folder, recordings are indexed in sorted file-name order, so dataset[i]
    refers to the same recording on every machine.

    Main Methods:
      - __init__(): Downloads & extracts the data (if not already present), scans the appropriate subfolders
                    for directinput files to build an index of samples, and builds an expanded index with slice
                    boundaries (using the MIDI file durations when slice_size is set).
      - __len__(): Returns the number of available slices (or full samples if no slicing is applied).
      - _get_base_dir(player, content): Constructs the expected base directory for a given player and content type.
                                          Follows the naming convention "<player>_<content.lower()>", using a nested
                                          folder if available.
      - _get_midi_path(item): Returns the full path to the MIDI file corresponding to a given sample.
      - load_audio(path): Loads an audio file from a given path at the defined sample rate.
      - slice_audio(audio, start, end): Extracts a segment of the audio between the given start and end times.
                                        If the slice is shorter than the expected duration, it pads the segment with zeros.
      - parse_midi(midi_obj, start, end): Extracts note information from a PrettyMIDI object for notes overlapping the [start, end) window.
                                         Each note label includes 'note', 'onset', 'offset', 'string', and 'fret'.
      - pitch_to_fret(midi_note): Maps a MIDI note number to the lowest playable fret under a tuning (None if out of range).
      - __getitem__(idx): Retrieves a single data slice (or None if the slice contains no MIDI notes), including:
            • Metadata: player, content_type, sample identifier, chord_type (None unless content is chords)
            • Data: Sliced (and padded) audio/video modalities
            • Labels: MIDI note information for the slice
            • Timing: The actual slice start and end timestamps
    """
    def __init__(self,
                 root_dir='Guitar-TECHS',
                 sr=48000,
                 players=['all'],  # never mutated: _as_list() copies it
                 content_types='all',
                 modalities='all',
                 slice_size=None,
                 slice_range=None):

        if slice_size and slice_range:
            raise ValueError("Cannot specify both slice_size and slice_range.")

        self.root_dir = root_dir
        self.sr = sr
        self.slice_size = slice_size
        self.slice_range = slice_range

        # Download and extract the dataset if it's not already present.
        if not os.path.exists(self.root_dir):
            self._download_dataset()

        # Define available players and content types.
        self.available_players = ['P1', 'P2', 'P3']
        self.available_content = {
            'P1': ['chords', 'scales', 'singlenotes', 'techniques'],
            'P2': ['chords', 'scales', 'singlenotes', 'techniques'],
            'P3': ['music']
        }
        self.valid_modalities = ['directinput', 'micamp', 'exo', 'ego']

        # Accept a single string as well as a list; 'all' / ['all'] selects everything.
        # (Without this, a plain string such as 'chords' would be iterated character by character.)
        players = self._as_list(players)
        content_types = self._as_list(content_types)
        modalities = self._as_list(modalities)

        self.players = self.available_players if players == ['all'] else players
        self.modalities = self.valid_modalities if modalities == ['all'] else modalities
        assert all(m in self.valid_modalities for m in self.modalities), \
            f"Modalities must be a subset of {self.valid_modalities}"

        # One metadata dict per recording, found by scanning the directinput folders.
        self.index = self._build_index(content_types)

        # (position in self.index, slice start, slice end) for every item __getitem__ can return.
        self.expanded_index = self._build_expanded_index()

    @staticmethod
    def _as_list(selection):
        """Normalize a selection given as a single string or an iterable of strings to a new list."""
        if isinstance(selection, str):
            return [selection]
        return list(selection)

    def _download_dataset(self):
        """Download the Zenodo archive, extract it into root_dir, then unpack any zips it contains."""
        print(f"{self.root_dir} not found. Downloading dataset...")
        zip_path = "dataset.zip"
        url = "https://zenodo.org/api/records/14963133/files-archive"
        try:
            urllib.request.urlretrieve(url, zip_path)
            print("Download complete. Extracting dataset...")
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(self.root_dir)
        finally:
            # Don't leave a (possibly partial) archive behind in the working directory on failure.
            if os.path.exists(zip_path):
                os.remove(zip_path)
        self._extract_nested_zip(self.root_dir)
        print("Dataset downloaded and extracted successfully.")

    @staticmethod
    def _chord_type(content, sample_value):
        """Infer the chord size from the sample-name prefix for 'chords' content; None otherwise."""
        if content.lower() != 'chords':
            return None
        prefix = sample_value.split('_')[0]
        if prefix in ('Set1', 'Set2', 'Set3', 'Set4'):
            return '3-note chord'
        if prefix == 'Drop3':
            return '4-note chord'
        return None

    def _build_index(self, content_types):
        """Return one metadata dict per 'directinput_<sample>.wav' file in the selected player/content folders."""
        prefix, suffix = 'directinput_', '.wav'
        index = []
        for player in self.players:
            valid_contents = self.available_content[player]
            selected_contents = valid_contents if content_types == ['all'] else content_types
            for content in selected_contents:
                if content not in valid_contents:
                    continue
                # Construct the base directory. Note: folder naming uses lower-case for content.
                base_dir = self._get_base_dir(player, content)
                di_dir = os.path.join(base_dir, 'audio', 'directinput')
                if not os.path.exists(di_dir):
                    continue
                # os.listdir order is filesystem-dependent; sorting keeps dataset[i] reproducible.
                for fname in sorted(os.listdir(di_dir)):
                    if not (fname.startswith(prefix) and fname.endswith(suffix)):
                        continue
                    # The sample identifier is the file name without its prefix/suffix
                    # (slicing, so occurrences inside the name itself are preserved).
                    sample_value = fname[len(prefix):-len(suffix)]
                    index.append({
                        'player': player,
                        'content_type': content,
                        'sample': sample_value,
                        'chord_type': self._chord_type(content, sample_value),
                    })
        return index

    def _build_expanded_index(self):
        """Return (index position, slice start, slice end) tuples; start/end are None when no slicing is applied."""
        if self.slice_size:
            expanded = []
            for i, sample_meta in enumerate(self.index):
                midi_path = self._get_midi_path(sample_meta)
                # Without a MIDI file the duration is unknown, so the recording is skipped.
                if not os.path.exists(midi_path):
                    continue
                total_duration = pretty_midi.PrettyMIDI(midi_path).get_end_time()
                # Ceiling division: the last slice may extend past the recording and is zero-padded later.
                num_slices = int(np.ceil(total_duration / self.slice_size))
                for s in range(num_slices):
                    expanded.append((i, s * self.slice_size, (s + 1) * self.slice_size))
            return expanded
        if self.slice_range:
            start, end = self.slice_range[0], self.slice_range[1]
            return [(i, start, end) for i in range(len(self.index))]
        # No slicing: entire sample is one slice.
        return [(i, None, None) for i in range(len(self.index))]

    def _extract_nested_zip(self, root_dir):
        """
        Recursively extracts all zip files found within the directory tree starting at root_dir,
        including zips that are themselves inside freshly extracted archives.
        Each archive is extracted into a folder named after it (without .zip) and then removed.
        """
        for foldername, subfolders, filenames in os.walk(root_dir):
            for filename in filenames:
                if not filename.endswith('.zip'):
                    continue
                zip_path = os.path.join(foldername, filename)
                extract_path = os.path.splitext(zip_path)[0]  # Folder name without .zip
                print("Extracting:", zip_path, "to", extract_path)
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(extract_path)
                os.remove(zip_path)
                # os.walk listed `subfolders` before this folder existed. In a top-down walk, appending
                # to the list in place makes the walk descend into it and unpack any zips inside.
                new_dir = os.path.basename(extract_path)
                if new_dir not in subfolders:
                    subfolders.append(new_dir)

    def __len__(self):
        return len(self.expanded_index)

    def _get_base_dir(self, player, content):
        """
        Constructs the base directory for a given player and content type.
        Expected naming: "<player>_<content.lower()>", and inside that folder may be a nested folder of the same name.
        """
        dir_name = f"{player}_{content.lower()}"
        candidate = os.path.join(self.root_dir, dir_name)
        nested = os.path.join(candidate, dir_name)
        return nested if os.path.exists(nested) else candidate

    def _get_midi_path(self, item):
        """Path of the MIDI label file ('midi/midi_<sample>.mid') for an index entry."""
        base_dir = self._get_base_dir(item['player'], item['content_type'])
        return os.path.join(base_dir, 'midi', f"midi_{item['sample']}.mid")

    def load_audio(self, path):
        """Load an audio file with librosa, resampled to self.sr; returns only the sample array."""
        audio, _ = librosa.load(path, sr=self.sr)
        return audio

    def slice_audio(self, audio, start, end):
        """
        Returns a slice of the audio corresponding to [start, end) seconds. When using slice_size,
        the slice length is slice_size * sr; otherwise it is (end - start) * sr. A segment shorter
        than that length (end of the recording) is padded with zeros at the end.
        """
        start_sample = int(start * self.sr)
        # Determine desired slice length in samples.
        desired_length = int(self.slice_size * self.sr) if self.slice_size else int((end - start) * self.sr)
        end_sample = start_sample + desired_length
        segment = audio[start_sample: min(len(audio), end_sample)]
        if len(segment) < desired_length:
            segment = np.pad(segment, (0, desired_length - len(segment)), mode='constant')
        return segment

    def parse_midi(self, midi_obj, start=None, end=None):
        """
        Extracts MIDI note information from a PrettyMIDI object.

        Only notes that overlap the [start, end) window are kept; either bound may be None (unbounded).
        Onset/offset are clipped to the window and expressed relative to `start` (absolute when start is None).
        'string' is the 1-based position of the note's instrument in midi_obj.instruments.

        NOTE(review): 'fret' comes from pitch_to_fret(), which ignores 'string'. The notebook outputs show
        disagreements (e.g. pitch 45 labelled string 6 but fret 0). Confirm the instrument -> string mapping
        of the MIDI files before using string/fret as joint targets.
        """
        labels = []
        origin = 0 if start is None else start
        for string_index, instrument in enumerate(midi_obj.instruments):
            for note in instrument.notes:
                # Skip notes that lie entirely outside the window.
                if start is not None and note.end <= start:
                    continue
                if end is not None and note.start >= end:
                    continue
                onset = note.start if start is None else max(note.start, start)
                offset = note.end if end is None else min(note.end, end)
                labels.append({
                    'note': note.pitch,
                    'onset': onset - origin,
                    'offset': offset - origin,
                    'string': string_index + 1,
                    'fret': self.pitch_to_fret(note.pitch)
                })
        return labels

    def pitch_to_fret(self, midi_note, tuning=(40, 45, 50, 55, 59, 64)):
        """
        Map a MIDI pitch to a fret number.

        `tuning` lists open-string MIDI pitches from lowest to highest (default: standard E2 A2 D3 G3 B3 E4).
        Strings are tried from the highest-pitched down, so the result is the lowest fret (0-24) at which the
        pitch is playable on any string, or None if it is out of range on every string.
        """
        for string_midi in reversed(tuning):
            fret = midi_note - string_midi
            if 0 <= fret <= 24:
                return fret
        return None

    def __getitem__(self, idx):
        """
        Return one (possibly sliced) sample as a dict, or None if its time window contains no MIDI notes.

        NOTE(review): a None item (e.g. a silent trailing slice) will likely make torch's default collate_fn
        fail inside a DataLoader. Use a collate_fn that drops None entries, or filter indices first.
        """
        real_idx, start, end = self.expanded_index[idx]
        item = self.index[real_idx]
        base_dir = self._get_base_dir(item['player'], item['content_type'])

        data = {}
        # Load each modality; missing files are reported as None.
        for dtype in self.modalities:
            if dtype in ['directinput', 'micamp']:
                folder = os.path.join('audio', dtype)
                ext = '.wav'
            elif dtype in ['exo', 'ego']:
                # The "video" modalities are stored as .mp3 audio tracks.
                folder = os.path.join('video', dtype)
                ext = '.mp3'
            else:
                continue

            path = os.path.join(base_dir, folder, f"{dtype}_{item['sample']}{ext}")
            if os.path.exists(path):
                modality_data = self.load_audio(path)
                # Slice (and pad if needed) the audio for the desired time window.
                data[dtype] = self.slice_audio(modality_data, start, end) if start is not None else modality_data
            else:
                data[dtype] = None

        # Process MIDI labels for the corresponding time window.
        midi_path = self._get_midi_path(item)
        if os.path.exists(midi_path):
            midi_obj = pretty_midi.PrettyMIDI(midi_path)
            labels = self.parse_midi(midi_obj, start, end)
        else:
            labels = []

        # No notes in this window (or no MIDI file): signal "no sample" to the caller.
        if not labels:
            return None

        # Return the sample dictionary including sample name and slice timestamps.
        return {
            'player': item['player'],
            'content_type': item['content_type'],
            'sample': item['sample'],
            'chord_type': item.get('chord_type'),
            'data': data,
            'labels': labels,
            'midi_path': midi_path,
            'slice_start': start,
            'slice_end': end
        }


# Initialising the dataset

In [None]:
# All players, content types and modalities, cut into 5-second windows
# (the last window of every recording is zero-padded).
dataset = GuitarTECHSDataset(
    root_dir='Guitar-TECHS',
    slice_size=5,
    modalities=['all'],
    content_types=['all'],
    players=['all'],
)

Guitar-TECHS not found. Downloading dataset...
Download complete. Extracting dataset...
Extracting: Guitar-TECHS/P2_scales.zip to Guitar-TECHS/P2_scales
Extracting: Guitar-TECHS/P1_singlenotes.zip to Guitar-TECHS/P1_singlenotes
Extracting: Guitar-TECHS/P2_techniques.zip to Guitar-TECHS/P2_techniques
Extracting: Guitar-TECHS/P1_scales.zip to Guitar-TECHS/P1_scales
Extracting: Guitar-TECHS/P2_singlenotes.zip to Guitar-TECHS/P2_singlenotes
Extracting: Guitar-TECHS/P1_chords.zip to Guitar-TECHS/P1_chords
Extracting: Guitar-TECHS/P2_chords.zip to Guitar-TECHS/P2_chords
Extracting: Guitar-TECHS/P1_techniques.zip to Guitar-TECHS/P1_techniques
Extracting: Guitar-TECHS/P3_music.zip to Guitar-TECHS/P3_music
Dataset downloaded and extracted successfully.


In [None]:
print(f"Number of samples/slices: {len(dataset)}")

Number of samples/slices: 3639


In [None]:
# Get a sample from the dataset (the dataset returns None for a slice without MIDI notes)
sample = dataset[0]
assert sample is not None, "dataset[0] has no MIDI notes in its time window; pick another index."

# Print basic sample metadata
print("Sample name:", sample['sample'])
print("Player:", sample['player'])
print("Content type:", sample['content_type'])
# The 'chord_type' key is always present; it is None for non-chord content.
if sample['chord_type'] is not None:
    print("Chord type:", sample['chord_type'])
print("Slice time:", sample['slice_start'], "to", sample['slice_end'])

# Display label information as a table, ordered by onset
labels_df = pd.DataFrame(sample['labels'])
labels_df_sorted = labels_df.sort_values(by='onset').reset_index(drop=True)
display(labels_df_sorted)

# Play each modality if available (directinput, micamp, exo, ego)
for modality in ['directinput', 'micamp', 'exo', 'ego']:
    modality_data = sample['data'].get(modality)
    if modality_data is not None:
        print(f"\nPlaying {modality} modality:")
        display(Audio(modality_data, rate=dataset.sr))
    else:
        print(f"\nModality '{modality}' is not available for this sample.")


Sample name: Set2_dim
Player: P1
Content type: chords
Chord type: 3-note chord
Slice time: 0 to 5


Unnamed: 0,note,onset,offset,string,fret
0,54,0.002083,3.601042,3,4
1,57,0.061458,3.614583,2,2
2,60,0.114583,3.630208,1,1
3,61,4.052083,5.0,1,2
4,58,4.140625,5.0,2,3
5,55,4.219792,5.0,3,0



Playing directinput modality:



Playing micamp modality:



Playing exo modality:



Playing ego modality:


# Another instance of the dataset

In [None]:
# Player 1 scales only, every modality, restricted to the 0-3 s window of each recording.
dataset = GuitarTECHSDataset(
    root_dir='Guitar-TECHS',
    slice_range=[0, 3],
    modalities=['all'],
    content_types=['scales'],
    players=['P1'],
)

Note: the dataset is not downloaded again, because the `Guitar-TECHS` folder already exists.

In [None]:
print(f"Number of samples/slices: {len(dataset)}")

Number of samples/slices: 12


In [None]:
# Get a sample from the dataset (the dataset returns None for a slice without MIDI notes)
sample = dataset[10]
assert sample is not None, "dataset[10] has no MIDI notes in its time window; pick another index."

# Print basic sample metadata
print("Sample name:", sample['sample'])
print("Player:", sample['player'])
print("Content type:", sample['content_type'])
# The 'chord_type' key is always present; it is None for non-chord content such as scales.
if sample['chord_type'] is not None:
    print("Chord type:", sample['chord_type'])
print("Slice time:", sample['slice_start'], "to", sample['slice_end'])

# Display label information as a table, ordered by onset
labels_df = pd.DataFrame(sample['labels'])
labels_df_sorted = labels_df.sort_values(by='onset').reset_index(drop=True)
display(labels_df_sorted)

# Play each modality if available (directinput, micamp, exo, ego)
for modality in ['directinput', 'micamp', 'exo', 'ego']:
    modality_data = sample['data'].get(modality)
    if modality_data is not None:
        print(f"\nPlaying {modality} modality:")
        display(Audio(modality_data, rate=dataset.sr))
    else:
        print(f"\nModality '{modality}' is not available for this sample.")


Sample name: F
Player: P1
Content type: scales
Chord type: None
Slice time: 0 to 3


Unnamed: 0,note,onset,offset,string,fret
0,41,0.045833,0.51875,6,1
1,43,0.55625,1.013542,6,3
2,45,1.051042,1.477083,6,0
3,46,1.532292,1.985417,5,1
4,48,2.029167,2.448958,5,3
5,50,2.5125,2.9875,5,0



Playing directinput modality:



Playing micamp modality:



Playing exo modality:



Playing ego modality:
