# Installation and imports

In [None]:
!pip install pretty-midi

In [None]:
import zipfile
import os
import librosa
import pretty_midi
from torch.utils.data import Dataset
import pandas as pd
from collections import defaultdict
from IPython.display import display, Audio
import matplotlib.pyplot as plt
import librosa.display
import random
import zipfile
import numpy as np
import urllib.request
import requests
from tqdm import tqdm

# Dataset class

In [None]:
class GuitarTECHSDataset(Dataset):
    """
    Dataset over the Guitar-TECHS recordings, optionally cut into time slices.

    Filters and Attributes:
      - root_dir (str): Base directory where the dataset is stored. If it does not exist, the
                        dataset is downloaded and extracted into it.
      - sr (int): Target audio sample rate (Hz) used for loading audio (default: 48000).
      - players (list | str): Which player folders to use (e.g., ['P1', 'P2']), or 'all' / ['all'].
      - content_types (list | str): Which content types to include (e.g., ['chords', 'scales',
                        'singlenotes', 'techniques'] for P1/P2, or ['music'] for P3), or 'all' / ['all'].
                        Content types a player does not have are skipped with a message.
      - modalities (list | str): Subset of ["directinput", "micamp", "exo", "ego"], or 'all' / ['all'].
      - slice_dur (float): If specified (in seconds), samples are segmented into slices of this length.
                           The last slice, if shorter than slice_dur, is padded with zeros.
      - slice_range (tuple): Alternatively, a fixed (start, end) time window (in seconds) applied to
                             every sample. Only one of slice_dur and slice_range may be set.
      - slice_overlap (float): Overlap (in seconds) between consecutive slices when slice_dur is set.
                               Must satisfy 0 <= slice_overlap < slice_dur (default: 0.0).
      - index (list[dict]): One entry per directinput recording found on disk
                            (keys: 'player', 'content_type', 'sample', 'chord_type').
      - expanded_index (list[tuple]): One (index position, start_sec, end_sec) entry per dataset item.
                            start/end are None when neither slice_dur nor slice_range is set.

    Main Methods:
      - __init__(): Validates the arguments, downloads & extracts the data (if not already present),
                    scans the directinput folders to build `index`, and builds `expanded_index`.
                    With slice_dur, slice boundaries are derived from the duration of each sample's
                    micamp recording; samples without a micamp file are left out.
      - __len__(): Returns the number of available slices (or full samples if no slicing is applied).
      - _get_base_dir(player, content): Returns "<root_dir>/<player>_<content.lower()>", or the nested
                                        folder of the same name inside it when that exists.
      - _get_midi_path(item): Returns the full path to the MIDI file corresponding to a given sample.
      - load_audio(path): Loads an audio file from a given path at the defined sample rate.
      - slice_audio(audio, start, end): Extracts the audio between the given start and end times,
                                        zero-padding at the end if the recording is too short.
      - parse_midi(midi_obj, start, end): Extracts note information for notes overlapping the
                                          [start, end) window. Each label has 'note', 'onset',
                                          'offset', 'string', and 'fret'.
      - pitch_to_fret(midi_note): Maps a MIDI note number to a fret number based on a default tuning
                                  (returns None if no string yields a fret in 0..24).
      - __getitem__(idx): Retrieves a single item (or None when the item has no MIDI labels):
            • Metadata: player, content_type, sample identifier, chord_type (if applicable)
            • Data: Sliced (and padded) arrays per requested modality (None if the file is missing)
            • Labels: MIDI note information for the slice
            • Timing: The slice start and end timestamps
    """

    AVAILABLE_PLAYERS = ('P1', 'P2', 'P3')
    AVAILABLE_CONTENT = {
        'P1': ('chords', 'scales', 'singlenotes', 'techniques'),
        'P2': ('chords', 'scales', 'singlenotes', 'techniques'),
        'P3': ('music',),
    }
    VALID_MODALITIES = ('directinput', 'micamp', 'exo', 'ego')

    # First "_"-separated token of a chord sample name -> chord type label.
    _CHORD_TYPES = {
        'Set1': '3-note chord',
        'Set2': '3-note chord',
        'Set3': '3-note chord',
        'Set4': '3-note chord',
        'Drop3': '4-note chord',
    }

    def __init__(self,
                 root_dir='Guitar-TECHS',
                 sr=48000,
                 players='all',
                 content_types='all',
                 modalities='all',
                 slice_dur=None,
                 slice_range=None,
                 slice_overlap=0.0):
        """
        Validate the configuration, make sure the data is on disk and build the indices.

        Raises:
            ValueError: if both slice_dur and slice_range are given, if the slicing values are
                        out of range, or if players / modalities contain unknown names.
        """
        super().__init__()

        # Validate everything before touching the disk so a bad argument never
        # triggers the multi-gigabyte download.
        if slice_dur is not None and slice_range is not None:
            raise ValueError("Cannot specify both slice_dur and slice_range.")
        if slice_overlap is None:
            slice_overlap = 0.0
        if slice_dur is not None:
            if slice_dur <= 0:
                raise ValueError("slice_dur must be positive.")
            if slice_overlap < 0:
                raise ValueError("slice_overlap must be non-negative.")
            if slice_overlap >= slice_dur:
                raise ValueError("slice_overlap must be less than slice_dur.")
            if int(slice_dur * sr) - int(slice_overlap * sr) < 1:
                raise ValueError("slice_dur - slice_overlap must span at least one audio sample.")
        if slice_range is not None:
            range_start, range_end = slice_range
            if range_start < 0 or range_end <= range_start:
                raise ValueError("slice_range must satisfy 0 <= start < end.")

        self.players = self._resolve_selection(players, self.AVAILABLE_PLAYERS, "Players")
        self.modalities = self._resolve_selection(modalities, self.VALID_MODALITIES, "Modalities")

        self.root_dir = root_dir
        if not os.path.exists(self.root_dir):
            self._download_and_extract_dataset()
        self.sr = sr
        self.slice_dur = slice_dur
        self.slice_range = slice_range
        self.slice_overlap = slice_overlap

        self.index = self._build_index(content_types)
        self.expanded_index = self._build_expanded_index()

    @staticmethod
    def _resolve_selection(requested, available, label):
        """
        Turn a user selection into a list of names.

        'all' or ['all'] selects everything in `available`; a single name may be given as a
        plain string. Raises ValueError if any requested name is not in `available`.
        """
        if isinstance(requested, str):
            requested = [requested]
        requested = list(requested)
        if requested == ['all']:
            return list(available)
        if any(name not in available for name in requested):
            raise ValueError(f"{label} must be a subset of {list(available)}")
        return requested

    def _build_index(self, content_types):
        """Scan the directinput folders and return one metadata dict per recording found."""
        if isinstance(content_types, str):
            content_types = [content_types]
        content_types = list(content_types)
        use_all_contents = content_types == ['all']

        prefix, suffix = 'directinput_', '.wav'
        index = []
        for player in self.players:
            valid_contents = self.AVAILABLE_CONTENT[player]
            selected_contents = valid_contents if use_all_contents else content_types
            for content in selected_contents:
                if content not in valid_contents:
                    print(f"Skipping content '{content}' for player '{player}' — not available in this player's dataset.")
                    continue
                di_dir = os.path.join(self._get_base_dir(player, content), 'audio', 'directinput')
                if not os.path.isdir(di_dir):
                    continue
                # Sorted so that integer indices are reproducible; os.listdir order is arbitrary.
                for fname in sorted(os.listdir(di_dir)):
                    if not (fname.startswith(prefix) and fname.endswith(suffix)):
                        continue
                    # The sample identifier is the file name without prefix and extension.
                    sample_value = fname[len(prefix):-len(suffix)]
                    chord_type = None
                    if content.lower() == 'chords':
                        chord_type = self._CHORD_TYPES.get(sample_value.split('_')[0])
                    index.append({
                        'player': player,
                        'content_type': content,
                        'sample': sample_value,
                        'chord_type': chord_type
                    })
        return index

    def _build_expanded_index(self):
        """
        Return the list of (index position, start_sec, end_sec) items served by the dataset.

        Without slice_dur every sample yields exactly one item, using slice_range as its window
        (or (None, None) for the full recording). With slice_dur the window count comes from the
        duration of the sample's micamp recording; samples without that file are skipped.
        """
        if self.slice_dur is None:
            if self.slice_range is not None:
                start, end = self.slice_range
            else:
                start, end = None, None
            return [(i, start, end) for i in range(len(self.index))]

        slice_samples = int(self.slice_dur * self.sr)
        hop_samples = slice_samples - int(self.slice_overlap * self.sr)

        expanded = []
        for i, sample_meta in enumerate(self.index):
            base_dir = self._get_base_dir(sample_meta['player'], sample_meta['content_type'])
            # use micamp for total length
            audio_path = os.path.join(base_dir, 'audio', 'micamp', f"micamp_{sample_meta['sample']}.wav")
            if not os.path.exists(audio_path):
                continue

            # Only the duration is needed here, so the audio itself is not decoded.
            duration = librosa.get_duration(path=audio_path)
            total_samples = int(round(duration * self.sr))
            for start_sec, end_sec in self._slice_windows(total_samples, slice_samples, hop_samples):
                expanded.append((i, start_sec, end_sec))
        return expanded

    def _slice_windows(self, total_samples, slice_samples, hop_samples):
        """
        Return the (start_sec, end_sec) windows that cover `total_samples` audio samples.

        Windows are `slice_samples` long and start every `hop_samples`. There are as many as
        needed for the last one to reach the end of the recording, so the last window may
        extend past it (slice_audio zero-pads that part). A recording shorter than one slice
        still yields one window; an empty recording yields none.
        """
        if total_samples <= 0:
            return []
        remainder = max(0, total_samples - slice_samples)
        n_windows = 1 + -(-remainder // hop_samples)  # 1 + ceil(remainder / hop_samples)
        windows = []
        for s in range(n_windows):
            start_sec = (s * hop_samples) / self.sr
            windows.append((start_sec, start_sec + self.slice_dur))
        return windows

    def _download_and_extract_dataset(self):
        """
        Downloads and extracts the Guitar-TECHS dataset if it's not already present.
        Uses a progress bar to show download progress.

        Raises whatever `requests` raises on connection problems or a non-success HTTP status.
        The temporary archive is removed whether or not the download succeeds.
        """
        print(f"{self.root_dir} not found. Downloading dataset...")

        zip_path = "dataset.zip"
        url = "https://zenodo.org/api/records/14963133/files-archive"

        # Known total size in bytes (3942.06 MB), used when the server sends no Content-Length.
        fallback_total_size = int(3942.06 * 1024 * 1024)
        block_size = 1024 * 1024  # 1 MiB chunks keep per-chunk overhead low

        try:
            with requests.get(url, stream=True, timeout=60) as response:
                # Fail here instead of saving an HTML error page as "dataset.zip".
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0)) or fallback_total_size

                with open(zip_path, 'wb') as f, tqdm(
                    desc="Downloading",
                    total=total_size,
                    unit='iB',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for data in response.iter_content(block_size):
                        f.write(data)
                        bar.update(len(data))

            print("Download complete. Extracting dataset...")
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(self.root_dir)
        finally:
            if os.path.exists(zip_path):
                os.remove(zip_path)

        self._extract_nested_zip(self.root_dir)
        print("Dataset downloaded and extracted successfully.")

    def _extract_nested_zip(self, root_dir):
        """
        Recursively extracts all zip files found within the directory tree starting at root_dir.
        Each archive is extracted into a sibling folder named after it (without ".zip").
        After extraction, the original zip files are removed.
        """
        for foldername, subfolders, filenames in os.walk(root_dir):
            for filename in filenames:
                if not filename.lower().endswith('.zip'):
                    continue
                zip_path = os.path.join(foldername, filename)
                extract_path = os.path.splitext(zip_path)[0]  # Folder name without .zip
                print("Extracting:", zip_path, "to", extract_path)
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(extract_path)
                os.remove(zip_path)
                # os.walk listed this folder's sub-directories before the extraction created a
                # new one; register it so archives nested inside it are visited as well.
                extracted_name = os.path.basename(extract_path)
                if extracted_name not in subfolders:
                    subfolders.append(extracted_name)

    def __len__(self):
        """Number of items: slices when slicing is configured, otherwise full samples."""
        return len(self.expanded_index)

    def _get_base_dir(self, player, content):
        """
        Constructs the base directory for a given player and content type.
        Expected naming: "<player>_<content.lower()>", and inside that folder may be a nested folder of the same name.
        """
        dir_name = f"{player}_{content.lower()}"
        candidate = os.path.join(self.root_dir, dir_name)
        nested = os.path.join(candidate, dir_name)
        return nested if os.path.exists(nested) else candidate

    def _get_midi_path(self, item):
        """Return the path "<base_dir>/midi/midi_<sample>.mid" for an index entry."""
        base_dir = self._get_base_dir(item['player'], item['content_type'])
        return os.path.join(base_dir, 'midi', f"midi_{item['sample']}.mid")

    def load_audio(self, path):
        """Load the file at `path` with librosa at `self.sr` and return the samples."""
        audio, _ = librosa.load(path, sr=self.sr)
        return audio

    def slice_audio(self, audio, start, end):
        """
        Returns a slice of the audio corresponding to [start, end) seconds.

        The slice length is slice_dur * sr when slice_dur is set, otherwise (end - start) * sr.
        If the recording ends before the slice does, the slice is padded with zeros at the end.
        """
        # round() rather than int(): start times are stored as floats and may sit a hair
        # below the exact sample position, which plain truncation would turn into "one early".
        start_sample = int(round(start * self.sr))
        if self.slice_dur:
            desired_length = int(self.slice_dur * self.sr)
        else:
            desired_length = int(round((end - start) * self.sr))
        end_sample = start_sample + desired_length
        segment = audio[start_sample: min(len(audio), end_sample)]
        if len(segment) < desired_length:
            segment = np.pad(segment, (0, desired_length - len(segment)), mode='constant')
        return segment

    def parse_midi(self, midi_obj, start=None, end=None):
        """
        Extracts MIDI note information from a PrettyMIDI object.

        Only notes that fall (at least partially) within the [start, end) window are kept and
        their times are clipped to the window. A bound given as None is left open. Onset and
        offset are relative to `start` (absolute when `start` is None). 'string' is the 1-based
        position of the note's instrument in `midi_obj.instruments`.
        """
        shift = 0.0 if start is None else start
        labels = []
        for string_index, instrument in enumerate(midi_obj.instruments):
            for note in instrument.notes:
                # Skip notes that lie entirely outside the window.
                if start is not None and note.end <= start:
                    continue
                if end is not None and note.start >= end:
                    continue
                onset = note.start if start is None else max(note.start, start)
                offset = note.end if end is None else min(note.end, end)
                labels.append({
                    'note': note.pitch,
                    'onset': onset - shift,
                    'offset': offset - shift,
                    'string': string_index + 1,
                    # NOTE(review): the fret is derived from the pitch alone, independently of
                    # 'string' above — confirm the two are meant to be unrelated.
                    'fret': self.pitch_to_fret(note.pitch)
                })
        return labels

    def pitch_to_fret(self, midi_note, tuning=(40, 45, 50, 55, 59, 64)):
        """
        Return the fret for `midi_note` on the last string of `tuning` (tried from the end of
        the sequence backwards) where the fret lies in 0..24, or None if there is none.
        """
        for string_midi in reversed(tuning):
            fret = midi_note - string_midi
            if 0 <= fret <= 24:
                return fret
        return None

    def __getitem__(self, idx):
        """
        Return the item at `idx` as a dict, or None when the item has no MIDI labels
        (MIDI file missing, or no note overlaps the slice).

        data[modality] is the (sliced) array for each requested modality, or None when the
        modality's file does not exist for this sample.
        """
        real_idx, start, end = self.expanded_index[idx]
        item = self.index[real_idx]
        base_dir = self._get_base_dir(item['player'], item['content_type'])

        data = {}
        # Load each modality.
        for dtype in self.modalities:
            if dtype in ['directinput', 'micamp']:
                folder = os.path.join('audio', dtype)
                ext = '.wav'
            elif dtype in ['exo', 'ego']:
                # NOTE(review): files under video/ are looked up with an .mp3 extension and
                # read through the audio loader — confirm this matches the files on disk.
                folder = os.path.join('video', dtype)
                ext = '.mp3'
            else:
                continue

            path = os.path.join(base_dir, folder, f"{dtype}_{item['sample']}{ext}")
            if os.path.exists(path):
                modality_data = self.load_audio(path)
                # Slice (and pad if needed) the audio for the desired time window.
                data[dtype] = self.slice_audio(modality_data, start, end) if start is not None else modality_data
            else:
                data[dtype] = None

        # Process MIDI labels for the corresponding time window.
        midi_path = self._get_midi_path(item)
        if os.path.exists(midi_path):
            midi_obj = pretty_midi.PrettyMIDI(midi_path)
            labels = self.parse_midi(midi_obj, start, end)
        else:
            labels = []

        # Items without any MIDI label are reported as None; callers must handle that.
        if not labels:
            return None

        # Return the sample dictionary including sample name and slice timestamps.
        return {
            'player': item['player'],
            'content_type': item['content_type'],
            'sample': item['sample'],
            'chord_type': item.get('chord_type'),
            'data': data,
            'labels': labels,
            'midi_path': midi_path,
            'slice_start': start,
            'slice_end': end
        }


# Initialising the dataset

In [None]:
# Build the dataset over every player, content type and modality, cut into
# 5-second slices with 1 second of overlap between consecutive slices.
dataset = GuitarTECHSDataset(
    root_dir='Guitar-TECHS',
    players=['all'],
    content_types=['all'],
    modalities=['all'],
    slice_dur=5,
    slice_overlap=1,
)

In [None]:
# Report how many items (slices, or full samples when unsliced) the dataset serves.
print(f"Number of samples/slices: {len(dataset)}")

# Looking at samples

In [None]:
# Get a sample from the dataset
sample = dataset[1000]

if sample is None:
    # The dataset returns None for slices that contain no MIDI note labels.
    print("Slice 1000 has no MIDI labels; pick a different index.")
else:
    # Print basic sample metadata
    print("Sample name:", sample['sample'])
    print("Player:", sample['player'])
    print("Content type:", sample['content_type'])
    # The key is always present; only chord samples carry an actual value.
    if sample.get('chord_type') is not None:
        print("Chord type:", sample['chord_type'])
    print("Slice time:", sample['slice_start'], "to", sample['slice_end'])

    # Display label information as a table
    labels_df = pd.DataFrame(sample['labels'])
    labels_df_sorted = labels_df.sort_values(by='onset').reset_index(drop=True)
    display(labels_df_sorted)

    # Play each modality if available (directinput, micamp, exo, ego)
    for modality in ['directinput', 'micamp', 'exo', 'ego']:
        modality_data = sample['data'].get(modality)
        if modality_data is not None:
            print(f"\nPlaying {modality} modality:")
            display(Audio(modality_data, rate=dataset.sr))
        else:
            print(f"\nModality '{modality}' is not available for this sample.")


In [None]:
# Get the consecutive/next sample from the dataset (to observe slice times overlap)
sample = dataset[1001]

if sample is None:
    # The dataset returns None for slices that contain no MIDI note labels.
    print("Slice 1001 has no MIDI labels; pick a different index.")
else:
    # Print basic sample metadata
    print("Sample name:", sample['sample'])
    print("Player:", sample['player'])
    print("Content type:", sample['content_type'])
    # The key is always present; only chord samples carry an actual value.
    if sample.get('chord_type') is not None:
        print("Chord type:", sample['chord_type'])
    print("Slice time:", sample['slice_start'], "to", sample['slice_end'])

    # Display label information as a table
    labels_df = pd.DataFrame(sample['labels'])
    labels_df_sorted = labels_df.sort_values(by='onset').reset_index(drop=True)
    display(labels_df_sorted)

    # Play each modality if available (directinput, micamp, exo, ego)
    for modality in ['directinput', 'micamp', 'exo', 'ego']:
        modality_data = sample['data'].get(modality)
        if modality_data is not None:
            print(f"\nPlaying {modality} modality:")
            display(Audio(modality_data, rate=dataset.sr))
        else:
            print(f"\nModality '{modality}' is not available for this sample.")
