In [None]:
# Mount Google Drive at /content/drive so the dataset paths used below resolve.
from google.colab import drive
drive.mount('/content/drive')
# IPython shell escape (notebook-only syntax): install the JAMS annotation library.
!pip install jams

Mounted at /content/drive
Collecting jams
  Downloading jams-0.3.4.tar.gz (51 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.3/51.3 kB[0m [31m800.7 kB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting sortedcontainers>=2.0.0 (from jams)
  Downloading sortedcontainers-2.4.0-py2.py3-none-any.whl.metadata (10 kB)
Collecting mir_eval>=0.5 (from jams)
  Downloading mir_eval-0.7.tar.gz (90 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.7/90.7 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting future (from mir_eval>=0.5->jams)
  Downloading future-1.0.0-py3-none-any.whl.metadata (4.0 kB)
Downloading sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB)
Downloading future-1.0.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.3/491.3 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding 

In [None]:
#@title Detecting Frets from JAMS
import os
import jams
import numpy as np

def generate_tabs_from_jams(dataset_dir):
    """Write a text tab for every ``.jams`` file in ``<dataset_dir>/annotation``.

    The tabs are written to ``<dataset_dir>/Tabs``, which is created if it
    does not exist yet. Files without a ``.jams`` suffix are ignored.
    """
    source_dir = os.path.join(dataset_dir, "annotation")
    output_dir = os.path.join(dataset_dir, "Tabs")
    os.makedirs(output_dir, exist_ok=True)

    for entry in os.listdir(source_dir):
        if not entry.endswith(".jams"):
            continue
        generate_tab_for_file(os.path.join(source_dir, entry), output_dir)

def generate_tab_for_file(jams_file_path, tabs_dir):
    """Convert a single JAMS annotation file into a text tab inside ``tabs_dir``."""
    # Parse the annotation file.
    annotation = jams.load(jams_file_path)

    # Derive the per-string (time, fret) lists and hand them to the writer.
    per_string_notes = process_jam_to_tab(annotation)
    print_tab_data(per_string_notes, jams_file_path, tabs_dir)

def process_jam_to_tab(jam):
    """Collect ``(time, fret)`` pairs per string from the ``note_midi`` annotations.

    Every note's pitch is mapped with ``find_string_and_fret``; notes for which
    that lookup yields ``None`` are skipped. Each accepted note is also echoed
    to stdout. Returns a dict keyed 1-6 whose lists are sorted by time.
    """
    open_pitches = [40, 45, 50, 55, 59, 64]  # E2, A2, D3, G3, B3, E4
    tab_data = {string: [] for string in range(1, 7)}

    for annotation in jam.annotations:
        if annotation.namespace != 'note_midi':
            continue
        for note in annotation.data:
            string_number, fret_number = find_string_and_fret(note.value, open_pitches)
            if string_number is None or fret_number is None:
                continue
            rounded_fret = round(fret_number)
            print(f"Found fret number: {rounded_fret} on string {string_number} at time {note.time:.2f}")
            tab_data[string_number].append((note.time, rounded_fret))

    # Order every string's notes chronologically.
    for notes in tab_data.values():
        notes.sort(key=lambda pair: pair[0])

    return tab_data

def find_string_and_fret(midi_pitch, string_midi_pitches):
    """Return ``(string_number, fret)`` for the first string that can play ``midi_pitch``.

    Strings are tried in list order and numbered from 1. A string qualifies
    when the pitch is between 0 and 19 semitones above its open pitch.
    Returns ``(None, None)`` when no string qualifies.
    """
    for string_number, open_pitch in enumerate(string_midi_pitches, start=1):
        if open_pitch <= midi_pitch <= open_pitch + 19:
            return string_number, midi_pitch - open_pitch
    return None, None

def format_tab_line(tab_line):
    """Return ``tab_line`` joined into a string with a dash after each fret number.

    A fret number is one digit, or two digits when the following character is
    a digit as well. The dash is inserted only if some character follows the
    number and that character is not a bar line ``'|'``. Every other character
    is copied unchanged.
    """
    pieces = []
    total = len(tab_line)
    pos = 0

    while pos < total:
        if not tab_line[pos].isdigit():
            pieces.append(tab_line[pos])
            pos += 1
            continue

        # Consume the one- or two-digit fret number starting here.
        width = 2 if pos + 1 < total and tab_line[pos + 1].isdigit() else 1
        pieces.extend(tab_line[pos:pos + width])
        pos += width

        # Separate the number from whatever follows, unless that is a bar line.
        if pos < total and tab_line[pos] != '|':
            pieces.append('-')

    return ''.join(pieces)

# NOTE: bar lines are placed every ``bar_interval`` characters, not derived from
# the time signature, so they do not correspond to musical measures.
def add_bar_lines(tab_representation, bar_interval=25):
    """Insert ``'|'`` after every ``bar_interval`` characters of each tab line, in place.

    All lines are first right-padded with ``'-'`` to the length of the longest
    line so the bars fall in the same columns on every string. When the
    character just before a bar is a digit, a ``'-'`` is inserted between them
    so a fret number never touches a bar line. No bar is added after the final
    segment.

    Args:
        tab_representation: dict mapping a string key to its tab line (``str``
            or list of single characters). Each value is replaced by the
            resulting ``str``.
        bar_interval: number of characters per segment; must be positive.

    Returns:
        None. ``tab_representation`` is modified in place.

    Raises:
        ValueError: if ``bar_interval`` is not positive (it would otherwise
            never advance through the line).
    """
    if bar_interval <= 0:
        raise ValueError(f"bar_interval must be positive, got {bar_interval}")
    if not tab_representation:
        return

    max_len = max(len(line) for line in tab_representation.values())

    for string_number, line in list(tab_representation.items()):
        padded = list(line) + ['-'] * (max_len - len(line))
        barred = []
        for bar_start in range(0, max_len, bar_interval):
            segment = padded[bar_start:bar_start + bar_interval]
            barred.extend(segment)
            # Only close the segment with a bar when more characters follow.
            if bar_start + bar_interval < max_len:
                if segment[-1].isdigit():
                    barred.append('-')
                barred.append('|')
        tab_representation[string_number] = ''.join(barred)

def print_tab_data(tab_data, jams_file_path, tabs_dir):
    """Render ``tab_data`` as a six-line text tab and save it in ``tabs_dir``.

    Each column of the tab stands for 0.1 s. Frets are written at
    ``int(time * 10)``, the lines are passed through ``format_tab_line`` and
    ``add_bar_lines``, and the result is written to a ``.txt`` file named
    after the JAMS file.

    Args:
        tab_data: dict keyed 1-6 mapping to lists of ``(time, fret)``. Key 1 is
            the string with the lowest open pitch (40 / E2) and key 6 the
            highest (64 / E4), matching ``find_string_and_fret``.
        jams_file_path: path of the source ``.jams`` file; only its base name
            is used to build the output file name.
        tabs_dir: directory that receives the ``.txt`` file.
    """
    all_notes = [note for notes in tab_data.values() for note in notes]
    max_time = max((time for time, _ in all_notes), default=0)
    widest_fret = max((len(str(fret)) for _, fret in all_notes), default=0)

    # One column per 0.1 s plus room for the last fret number itself. Without
    # the extra columns the latest note starts at index == length and is
    # rejected by the bounds check below.
    tab_length = int(max_time * 10) + widest_fret

    # Initialize strings with dashes
    tab_representation = {string_number: ['-'] * tab_length for string_number in range(1, 7)}

    # Place frets on the appropriate positions
    for string_number, notes in tab_data.items():
        line = tab_representation[string_number]
        for time, fret in notes:
            position = int(time * 10)
            fret_str = str(fret)
            fret_len = len(fret_str)

            if position + fret_len <= len(line):
                line[position:position + fret_len] = list(fret_str)

                # Ensure dash before and after fret number
                if position + fret_len < len(line):
                    line[position + fret_len] = '-'
                if position > 0 and not line[position - 1].isdigit():
                    line[position - 1] = '-'

    # Format each tab line to ensure spacing between numbers
    for string_number in tab_representation:
        tab_representation[string_number] = format_tab_line(tab_representation[string_number])

    # Add bar lines at regular intervals
    add_bar_lines(tab_representation, bar_interval=25)

    # Labels follow the numbering used by find_string_and_fret: string 1 is the
    # low E (open pitch 40) and string 6 the high e (open pitch 64).
    string_labels = {1: "E|", 2: "A|", 3: "D|", 4: "G|", 5: "B|", 6: "e|"}

    # Save the tab data to a file in the Tabs directory
    base_filename = os.path.basename(jams_file_path).replace(".jams", ".txt")
    tab_filename = os.path.join(tabs_dir, base_filename)

    with open(tab_filename, 'w') as f:
        # Standard tab layout: highest-pitched string (6) on top, lowest (1) at the bottom.
        for string_number in range(6, 0, -1):
            f.write(f"{string_labels[string_number]}{tab_representation[string_number]}\n")

# Path to your dataset directory (expects Google Drive to be mounted at /content/drive)
dataset_dir = '/content/drive/MyDrive/Summer2024Research/GuitarSet'

# Call the function to process all JAMS files; the tabs are written to <dataset_dir>/Tabs
generate_tabs_from_jams(dataset_dir)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Found fret number: 16 on string 3 at time 0.80
Found fret number: 19 on string 2 at time 1.61
Found fret number: 17 on string 2 at time 3.23
Found fret number: 16 on string 2 at time 4.03
Found fret number: 16 on string 4 at time 6.92
Found fret number: 14 on string 4 at time 7.23
Found fret number: 18 on string 3 at time 7.54
Found fret number: 16 on string 4 at time 8.12
Found fret number: 19 on string 3 at time 8.23
Found fret number: 18 on string 3 at time 8.50
Found fret number: 18 on string 3 at time 17.70
Found fret number: 15 on string 5 at time 0.48
Found fret number: 19 on string 4 at time 1.02
Found fret number: 18 on string 4 at time 1.33
Found fret number: 18 on string 4 at time 1.80
Found fret number: 16 on string 4 at time 3.04
Found fret number: 16 on string 4 at time 3.42
Found fret number: 14 on string 4 at time 3.72
Found fret number: 14 on string 4 at time 4.31
Found fret number: 18 on string 3 at time

In [None]:
#@title Working & Current Attempt At Processing Tab (text) Files into Correct Format With Some Spacing Issues
import os
import jams
import numpy as np

def generate_tabs_from_jams(dataset_dir):
    """Write a text tab for each ``.jams`` file found in ``<dataset_dir>/annotation``.

    Output files go to ``<dataset_dir>/TabsNoBarsLongerDuration``, created on
    demand. Entries without a ``.jams`` suffix are skipped.
    """
    source_dir = os.path.join(dataset_dir, "annotation")
    output_dir = os.path.join(dataset_dir, "TabsNoBarsLongerDuration")
    os.makedirs(output_dir, exist_ok=True)

    jams_names = [entry for entry in os.listdir(source_dir) if entry.endswith(".jams")]
    for entry in jams_names:
        generate_tab_for_file(os.path.join(source_dir, entry), output_dir)

def generate_tab_for_file(jams_file_path, tabs_dir):
    """Load one JAMS file, extract its per-string notes and write the tab to ``tabs_dir``."""
    per_string_notes = process_jam_to_tab(jams.load(jams_file_path))
    print_tab_data(per_string_notes, jams_file_path, tabs_dir)

def process_jam_to_tab(jam):
    """Build ``{string_number: [(time, fret), ...]}`` from the ``note_midi`` annotations.

    Pitches are mapped with ``find_string_and_fret``; notes it cannot place
    are dropped and fret values are rounded to integers. The lists (keys 1-6)
    are returned sorted by time.
    """
    open_pitches = [40, 45, 50, 55, 59, 64]
    tab_data = {string: [] for string in range(1, 7)}

    for annotation in jam.annotations:
        if annotation.namespace != 'note_midi':
            continue
        for note in annotation.data:
            string_number, fret_number = find_string_and_fret(note.value, open_pitches)
            if string_number is None or fret_number is None:
                continue
            tab_data[string_number].append((note.time, round(fret_number)))

    for notes in tab_data.values():
        notes.sort(key=lambda pair: pair[0])

    return tab_data

def find_string_and_fret(midi_pitch, string_midi_pitches):
    """Map ``midi_pitch`` to ``(string_number, fret)`` on the first usable string.

    Strings are checked in list order and numbered from 1; a string is usable
    when the pitch lies within its open pitch and 19 semitones above it.
    ``(None, None)`` is returned when no string is usable.
    """
    string_number = 0
    for open_pitch in string_midi_pitches:
        string_number += 1
        too_low = midi_pitch < open_pitch
        too_high = midi_pitch > open_pitch + 19
        if not (too_low or too_high):
            return string_number, midi_pitch - open_pitch
    return None, None

def format_tab_line(tab_data, tab_length):
    """Render one string's notes as a tab line at 10 columns per second.

    Each note is placed at column ``int(time * 10)`` measured from the start
    of the line and is followed by one ``'-'`` separator. If that column is
    already taken by the previous note or its separator, the note is written
    immediately after it instead.

    Positions are computed from absolute time rather than from the gap to the
    previous note. Gap-based spacing did not account for the columns used by
    the fret digits themselves, so every note pushed all later notes to the
    right and the six strings of one file drifted out of alignment.

    Args:
        tab_data: iterable of ``(time, fret)`` pairs sorted by time.
        tab_length: not used; kept so existing callers keep working.

    Returns:
        The tab line as a string (empty when ``tab_data`` is empty).
    """
    dashes_per_second = 10
    pieces = []
    width = 0  # number of characters emitted so far

    for time, fret in tab_data:
        fret_str = str(fret)

        # Pad with dashes up to the note's absolute column.
        target_column = int(time * dashes_per_second)
        if target_column > width:
            pieces.append('-' * (target_column - width))
            width = target_column

        # Write the fret and a separator so adjacent numbers never touch.
        pieces.append(fret_str)
        pieces.append('-')
        width += len(fret_str) + 1

    return ''.join(pieces)


def print_tab_data(tab_data, jams_file_path, tabs_dir):
    """Format ``tab_data`` as a six-line text tab and write it into ``tabs_dir``.

    Every string's notes are rendered with ``format_tab_line``, the lines are
    right-padded with ``'-'`` to a common length, and the result is saved to a
    ``.txt`` file named after the JAMS file.

    Args:
        tab_data: dict keyed 1-6 mapping to lists of ``(time, fret)``. Key 1 is
            the string with the lowest open pitch (40 / E2) and key 6 the
            highest (64 / E4), matching ``find_string_and_fret``.
        jams_file_path: path of the source ``.jams`` file; only its base name
            is used to build the output file name.
        tabs_dir: directory that receives the ``.txt`` file.
    """
    max_time = max((time for notes in tab_data.values() for time, _ in notes), default=0)
    tab_length = int(max_time * 10)

    tab_representation = {
        string_number: format_tab_line(notes, tab_length)
        for string_number, notes in tab_data.items()
    }

    # Pad every line to the longest one so all rows end in the same column.
    max_len = max((len(line) for line in tab_representation.values()), default=0)
    for string_number, line in tab_representation.items():
        tab_representation[string_number] = line.ljust(max_len, '-')

    # Labels follow the numbering used by find_string_and_fret: string 1 is the
    # low E (open pitch 40) and string 6 the high e (open pitch 64).
    string_labels = {1: "E|", 2: "A|", 3: "D|", 4: "G|", 5: "B|", 6: "e|"}
    base_filename = os.path.basename(jams_file_path).replace(".jams", ".txt")
    tab_filename = os.path.join(tabs_dir, base_filename)

    with open(tab_filename, 'w') as f:
        # Standard tab layout: highest-pitched string (6) on top, lowest (1) at the bottom.
        for string_number in range(6, 0, -1):
            f.write(f"{string_labels[string_number]}{tab_representation[string_number]}\n")

# Dataset root on the mounted Google Drive.
dataset_dir = '/content/drive/MyDrive/Summer2024Research/GuitarSet'
# Regenerate all tabs with the functions defined in this cell
# (output folder: <dataset_dir>/TabsNoBarsLongerDuration).
generate_tabs_from_jams(dataset_dir)

#FINALLY WORKS!

In [None]:
#@title Spectrogram Processing


import os
import librosa
import numpy as np
import matplotlib.pyplot as plt

def generate_mel_spectrograms(dataset_dir, save_dir, target_length=None, n_fft=2048, hop_length=512, n_mels=512):
    """Compute a log-mel spectrogram for every ``.wav`` in ``dataset_dir`` and save it as ``.npy``.

    Each spectrogram is also plotted for visual inspection, and the largest
    unpadded frame count seen is printed at the end.

    Args:
        dataset_dir: directory containing the ``.wav`` files.
        save_dir: directory for the ``.npy`` outputs (created if missing).
        target_length: when truthy, every spectrogram is padded or truncated
            to this many time frames.
        n_fft: FFT window size passed to librosa.
        hop_length: hop size passed to librosa.
        n_mels: number of mel bands.

    Notes:
        ``n_fft`` and ``hop_length`` used to be accepted but never forwarded,
        so librosa's own defaults (2048 / 512) were what actually ran. They
        are forwarded now, and the defaults here are set to those values so a
        call that omits them produces the same frame counts as before.

        Padding uses the spectrogram's minimum dB value. With ``ref=np.max``
        the loudest bin is 0 dB, so padding with zeros appended full-scale
        energy instead of silence.
    """
    # Explicit import: older librosa releases do not load the display submodule on their own.
    import librosa.display

    os.makedirs(save_dir, exist_ok=True)

    max_length = 0  # Largest number of frames seen before padding/truncation

    for filename in sorted(os.listdir(dataset_dir)):
        if not filename.endswith(".wav"):
            continue
        wav_file_path = os.path.join(dataset_dir, filename)

        # Load the audio file at its native sampling rate
        y, sr = librosa.load(wav_file_path, sr=None)

        # If stereo audio, convert to mono
        if len(y.shape) == 2:
            y = librosa.to_mono(y)

        # Compute the mel-spectrogram with the requested analysis parameters
        mel_spectrogram = librosa.feature.melspectrogram(
            y=y, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
        )

        # Convert to log scale (dB), relative to this clip's peak
        mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)

        # Update the maximum length
        max_length = max(max_length, mel_spectrogram.shape[1])

        # Pad or truncate the spectrogram to the target length (if provided)
        if target_length:
            if mel_spectrogram.shape[1] < target_length:
                padding = target_length - mel_spectrogram.shape[1]
                mel_spectrogram = np.pad(
                    mel_spectrogram,
                    ((0, 0), (0, padding)),
                    mode='constant',
                    constant_values=mel_spectrogram.min(),  # pad with the quietest level
                )
            else:
                mel_spectrogram = mel_spectrogram[:, :target_length]

        # Save the mel-spectrogram as a .npy file
        npy_filename = filename.replace('.wav', '.npy')
        np.save(os.path.join(save_dir, npy_filename), mel_spectrogram)

        # Plot the spectrogram (for reference)
        plt.figure(figsize=(10, 4))
        librosa.display.specshow(mel_spectrogram, sr=sr, hop_length=hop_length, x_axis='time', y_axis='mel')
        plt.colorbar(format='%+2.0f dB')
        plt.title(f'Mel-Spectrogram: {filename}')
        plt.show()
        plt.close()

        # Display file name and spectrogram shape
        print(f"Processed {filename}, Spectrogram shape: {mel_spectrogram.shape}")

    # Print the maximum length found
    print(f"Largest number of time steps found: {max_length}")

# Example usage: read the microphone recordings and write one .npy spectrogram per file.
dataset_dir = '/content/drive/MyDrive/Summer2024Research/GuitarSet/audio/audio_mic'
output_dir = '/content/drive/MyDrive/Summer2024Research/GuitarSet/melSpectrograms'
# NOTE(review): the model cell below hard-codes an input width of 4000, so this value
# must stay in sync with it.
target_length = 4000  # Example target length (number of time steps), can be None to disable, change to 3937 if 4000 causes problems
generate_mel_spectrograms(dataset_dir, output_dir, target_length)

In [None]:
# Version where I manually one-hot encode

import os
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical

# Text tabs written by the tab-generation cell (one .txt per JAMS file).
tab_dir = "/content/drive/MyDrive/Summer2024Research/GuitarSet/TabsNoBarsLongerDuration"
# Spectrograms written by generate_mel_spectrograms (one .npy per recording).
spectrogram_dir = "/content/drive/MyDrive/Summer2024Research/GuitarSet/melSpectrograms"

def encode_tab_file(tab_file_path):
    """Encode a text tab file as a flat list of integer token ids.

    Token ids: string labels ``e B G D A E`` are 0-5, ``'|'`` is 6, ``'-'`` is
    7 and fret numbers 0-19 are 8-27. Two adjacent digits are read as a single
    two-digit fret. Anything not in the table (including two-digit numbers
    above 19) is encoded as 7.

    Lines are encoded in file order and concatenated. Lines that appear before
    the first line starting with a known token are ignored. If reading fails,
    the error is printed and whatever was encoded so far is returned. The
    final list is also printed.
    """
    token_ids = {label: idx for idx, label in enumerate(('e', 'B', 'G', 'D', 'A', 'E', '|', '-'))}
    token_ids.update({str(fret): 8 + fret for fret in range(20)})
    fallback_id = 7

    encoded_tabs = []

    try:
        with open(tab_file_path, 'r') as handle:
            seen_known_line = False
            for raw_line in handle:
                text = raw_line.strip()

                if text and text[0] in token_ids:
                    seen_known_line = True
                if not seen_known_line:
                    continue

                cursor = 0
                while cursor < len(text):
                    pair = text[cursor:cursor + 2]
                    # Prefer a two-digit fret when both characters are digits.
                    if len(pair) == 2 and pair.isdigit():
                        token = pair
                    else:
                        token = text[cursor]
                    encoded_tabs.append(token_ids.get(token, fallback_id))
                    cursor += len(token)

    except Exception as e:
        print(f"Error processing file {tab_file_path}: {e}")

    print(f"Final encoded tabs: {encoded_tabs}")
    return encoded_tabs



def load_data(spectrogram_dir, tab_dir):
    """Load paired spectrograms and encoded tabs.

    For every ``*.npy`` file in ``spectrogram_dir`` the matching tab is looked
    up in ``tab_dir`` by replacing ``_mic.npy`` with ``.txt``. A spectrogram is
    kept only when its tab file exists, so row ``i`` of both returned arrays
    always refers to the same recording. (Previously the spectrogram was
    appended before the check, which left the two arrays misaligned whenever a
    tab was missing.) Files are processed in sorted order so the result is
    reproducible.

    Encoded tabs are right-padded with 7 (the ``'-'`` token) to the longest
    tab. The padded length and both array shapes are printed.

    Args:
        spectrogram_dir: directory containing the ``.npy`` spectrograms.
        tab_dir: directory containing the ``.txt`` tabs.

    Returns:
        ``(spectrograms, encoded_tabs)`` as NumPy arrays of equal length.
    """
    spectrograms = []
    encoded_tabs = []

    for filename in sorted(os.listdir(spectrogram_dir)):
        if not filename.endswith('.npy'):
            continue

        tab_filename = filename.replace('_mic.npy', '.txt')
        tab_file_path = os.path.join(tab_dir, tab_filename)
        if not os.path.exists(tab_file_path):
            print(f"Skipping {filename}: no tab file at {tab_file_path}")
            continue

        # Append both halves of the pair together to keep them aligned.
        spectrograms.append(np.load(os.path.join(spectrogram_dir, filename)))
        encoded_tabs.append(encode_tab_file(tab_file_path))

    # Pad all encoded tabs to the maximum length
    max_tab_length = max((len(tab) for tab in encoded_tabs), default=0)
    print(max_tab_length)
    padded_tabs = [tab + [7] * (max_tab_length - len(tab)) for tab in encoded_tabs]

    spectrogram_array = np.array(spectrograms)
    tab_array = np.array(padded_tabs)

    print(f"Shape of spectrograms array: {spectrogram_array.shape}")
    print(f"Shape of encoded tabs array: {tab_array.shape}")

    return spectrogram_array, tab_array

# Load the data
X, y = load_data(spectrogram_dir, tab_dir)

# Min-max normalise X to [0, 1].
# The spectrograms are dB values relative to each clip's peak
# (power_to_db(..., ref=np.max)), so every value is <= 0 and the global maximum
# is 0. Dividing by np.max(X) therefore divided by zero and filled X with
# inf/NaN; rescale by the value range instead.
X = np.asarray(X)
x_min, x_max = X.min(), X.max()
if x_max > x_min:
    X = (X - x_min) / (x_max - x_min)
else:
    X = np.zeros_like(X)

# X should have shape (batch_size, height, width, 1)
X = np.expand_dims(X, axis=-1)

# Size of the token vocabulary built in encode_tab_file: ids 0-7 for the string
# labels, '|' and '-', plus ids 8-27 for frets 0-19. Using the fixed vocabulary
# size (instead of counting the labels that happen to occur) keeps the one-hot
# width equal to the model's 28-way output even when some fret never appears.
num_classes = 28
print("Num Classes:", num_classes)

y_one_hot = to_categorical(y, num_classes=num_classes)


# Split data into training and test sets
X_train, X_test, y_train_one_hot, y_test_one_hot = train_test_split(X, y_one_hot, test_size=0.2, random_state=42)

#num_classes = len(set(np.concatenate(y_train)))  # Adjust based on the actual unique values

#y_train_one_hot = to_categorical(y_train, num_classes=num_classes)
#y_test_one_hot = to_categorical(y_test, num_classes=num_classes)

# np.set_printoptions(threshold=np.inf)
# print()
# print("Contents of first from train set encoded categorically:")
# print(y_train_one_hot[0])

Final encoded tabs: [5, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7

In [None]:
#@title Custom Focal Categorical Cross-Entropy Loss (down-weights confidently classified examples; it applies no class-specific penalty for "-")

import tensorflow as tf
import tensorflow.keras.backend as K

def focal_categorical_crossentropy(gamma=2., alpha=.25):
    """Return a focal-loss function for one-hot targets.

    The returned ``focal_loss(y_true, y_pred)`` scales the per-class
    cross-entropy by ``alpha * (1 - p) ** gamma``, sums over the last axis and
    averages over everything that remains.
    """
    def focal_loss(y_true, y_pred):
        # Keep probabilities away from 0 and 1 so the log stays finite.
        eps = K.epsilon()
        probs = K.clip(y_pred, eps, 1. - eps)

        # Targets must share the predictions' dtype for the multiplication below.
        targets = tf.cast(y_true, probs.dtype)

        # Focal weighting applied to the element-wise cross-entropy.
        modulating = alpha * K.pow(1 - probs, gamma)
        per_class = modulating * (-targets * K.log(probs))

        # Sum over classes, then average.
        return K.mean(K.sum(per_class, axis=-1))

    return focal_loss


In [None]:
#@title Displays Our Classes by the Inverse of their Frequency

from sklearn.utils import class_weight
import numpy as np

# Assuming y_train_one_hot is your one-hot encoded labels

# Convert the one-hot encoded labels back to their original class labels
y_train_flat = np.argmax(y_train_one_hot, axis=-1).flatten()

# Calculate class weights for the classes that actually occur in the training labels
present_classes = np.unique(y_train_flat)
class_weights = class_weight.compute_class_weight(class_weight='balanced', classes=present_classes, y=y_train_flat)

# Map each class id to its weight. The weights are ordered like present_classes,
# so pair them with the real class ids; enumerate() would assign wrong ids as
# soon as any class is missing from the training split.
class_weights_dict = {int(label): float(weight) for label, weight in zip(present_classes, class_weights)}

# Display the class weights
print("Class Weights:")
for class_index, weight in class_weights_dict.items():
    print(f"Class {class_index}: {weight}")

Class Weights:
Class 0: 247.64285714285714
Class 1: 247.64285714285714
Class 2: 247.64285714285714
Class 3: 247.64285714285714
Class 4: 247.64285714285714
Class 5: 247.64285714285714
Class 6: 41.273809523809526
Class 7: 0.03669368204659326
Class 8: 186.70456245325354
Class 9: 161.36005171299288
Class 10: 177.8582116138226
Class 11: 135.59152634437805
Class 12: 114.8488612836439
Class 13: 84.20441895766571
Class 14: 97.56654289622826
Class 15: 78.11735252699108
Class 16: 68.11952517396644
Class 17: 49.254932912391475
Class 18: 56.38035008469791
Class 19: 52.02125664270084
Class 20: 46.493574222387785
Class 21: 34.57156706599266
Class 22: 12.444798963033127
Class 23: 8.768274263233693
Class 24: 9.711484593837534
Class 25: 10.343893090231017
Class 26: 10.241404775580536
Class 27: 60.08520880972439


In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Reshape, Dense, TimeDistributed, Add, MultiHeadAttention, LayerNormalization, Dropout, Layer
import numpy as np
from tensorflow.keras.losses import CategoricalFocalCrossentropy

# Focal loss with gamma=3.0 and a per-class alpha list of 28 entries, all 1.0
# (28 matches the width of the model's output layer below).
loss_fn = CategoricalFocalCrossentropy(gamma=3.0, alpha=[1.0] * 28)

# Define Positional Encoding Layer
class PositionalEncoding(Layer):
    """Add fixed sinusoidal positional encodings to a ``(batch, seq, d_model)`` input.

    The table is computed once at construction for ``sequence_length``
    positions and sliced to the input's actual sequence length in ``call``.

    Args:
        sequence_length: number of positions to precompute.
        d_model: feature dimension of the inputs.
        **kwargs: standard ``Layer`` arguments (``name``, ``dtype``, ...).
            They are forwarded so the layer can be rebuilt from its config
            when a saved model is loaded.
    """

    def __init__(self, sequence_length, d_model, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.d_model = d_model
        self.positional_encoding = self.get_positional_encoding(sequence_length, d_model)

    def get_positional_encoding(self, sequence_length, d_model):
        """Return the ``(1, sequence_length, d_model)`` float32 encoding table."""
        position = np.arange(sequence_length)[:, np.newaxis]
        div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
        angles = position * div_term
        positional_encoding = np.zeros((sequence_length, d_model))
        positional_encoding[:, 0::2] = np.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine columns.
        positional_encoding[:, 1::2] = np.cos(angles[:, :d_model // 2])
        return tf.cast(positional_encoding[np.newaxis, ...], dtype=tf.float32)

    def call(self, inputs):
        # Slice the table to the input's sequence length before adding.
        return inputs + self.positional_encoding[:, :tf.shape(inputs)[1], :]

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({"sequence_length": self.sequence_length, "d_model": self.d_model})
        return config

# Define Transformer Block
def transformer_block(inputs, num_heads, d_model, ff_dim, dropout=0.1):
    """Apply one post-norm transformer encoder block to ``inputs``.

    The block is self-attention followed by a two-layer feed-forward network.
    Each sub-layer is followed by dropout, a residual addition and layer
    normalization. ``d_model`` is used both as the attention ``key_dim`` and
    as the feed-forward output width.
    """
    # Self-attention sub-layer (query and value are both the block input).
    attended = MultiHeadAttention(num_heads=num_heads, key_dim=d_model)(inputs, inputs)
    attended = Dropout(dropout)(attended)
    attended = LayerNormalization(epsilon=1e-6)(attended + inputs)

    # Position-wise feed-forward sub-layer.
    hidden = Dense(ff_dim, activation='relu')(attended)
    projected = Dense(d_model)(hidden)
    projected = Dropout(dropout)(projected)
    normalized = LayerNormalization(epsilon=1e-6)(projected + attended)

    return normalized

# Model Architecture
input_layer = Input(shape=(512, 4000))  # Adjusted to 512 Mel bins

# Add a channel dimension for Conv2D
x = Reshape((512, 4000, 1))(input_layer)

# NOTE(review): every pooling layer below uses strides=(1, 2), so only the time
# axis is halved; the mel axis shrinks by one row per layer (512 -> 511 after
# the first pooling). Confirm this is the intended downsampling.

# First Convolution + Pooling layer
x = Conv2D(filters=32, kernel_size=(3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D(pool_size=(2, 2), strides=(1, 2))(x)  # Downsample

# Second Convolution + Pooling layer
x = Conv2D(filters=64, kernel_size=(3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D(pool_size=(2, 2), strides=(1, 2))(x)  # Downsample

# Third Convolution + Pooling layer
x = Conv2D(filters=128, kernel_size=(3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D(pool_size=(2, 2), strides=(1, 2))(x)  # Downsample

# Flatten and reshape for transformer
x = Flatten()(x)
flattened_size = x.shape[-1]
target_size = 6934 * 128

# Keep only the first target_size flattened features and view them as a
# (6934, 128) sequence.
# NOTE(review): this discards every feature past target_size, and the resulting
# sequence positions are slices of the flattened (mel, time, channel) tensor
# rather than time steps. Verify this is what the transformer should attend over.
crop_amount = flattened_size - target_size
if crop_amount >= 0:  # an exact fit (crop_amount == 0) is valid too
    x = Reshape((6934, 128))(x[:, :target_size])  # Crop to fit target
else:
    raise ValueError(f"Flattened size ({flattened_size}) is smaller than the target ({target_size}). Adjust architecture.")

# Add positional encoding
sequence_length, d_model = 6934, 128
x = PositionalEncoding(sequence_length, d_model)(x)

# Transformer blocks
num_heads = 8
ff_dim = 512
num_transformer_blocks = 4

for _ in range(num_transformer_blocks):
    x = transformer_block(x, num_heads=num_heads, d_model=d_model, ff_dim=ff_dim)

# Fully connected layers
x = TimeDistributed(Dense(512, activation='relu'))(x)

# Output layer with 28 classes
output_layer = TimeDistributed(Dense(28, activation='softmax'))(x)

# Compile the model
model = Model(inputs=input_layer, outputs=output_layer)
model.compile(optimizer='adam', loss=loss_fn)

# Model summary
model.summary()


Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_2 (InputLayer)        [(None, 512, 4000)]          0         []                            
                                                                                                  
 reshape_2 (Reshape)         (None, 512, 4000, 1)         0         ['input_2[0][0]']             
                                                                                                  
 conv2d_3 (Conv2D)           (None, 512, 4000, 32)        320       ['reshape_2[0][0]']           
                                                                                                  
 max_pooling2d_3 (MaxPoolin  (None, 511, 2000, 32)        0         ['conv2d_3[0][0]']            
 g2D)                                                                                       

In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from sklearn.utils import class_weight
import numpy as np


""" In case want to calculate class weights here instead of above
y_train_flat = np.argmax(y_train_one_hot, axis=-1).flatten()

# Calculate class weights
class_weights = class_weight.compute_class_weight(class_weight='balanced', classes=np.unique(y_train_flat), y=y_train_flat)

# Convert to a dictionary format to easily map class indices to their weights
class_weights_dict = dict(enumerate(class_weights))
"""


# Set up ModelCheckpoint: keep only the weights with the lowest validation loss.
# ModelCheckpoint has no min_delta option (that argument belongs to EarlyStopping);
# passing it is ignored by some Keras versions and rejected by others, so it is
# not passed here.
checkpoint = ModelCheckpoint(
    'focalLossAttempt1.h5',
    monitor='val_loss',
    save_best_only=True,
    mode='min'
)

# Set up EarlyStopping with min_delta
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min',
    min_delta=0.001,  # Stop training if the validation loss does not decrease by at least 0.001
    restore_best_weights=True
)


# Re-compile so that accuracy is reported alongside the loss
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

# NOTE(review): the test split is also used as validation data, so early stopping
# and checkpoint selection are tuned on the data reported as "test" below.
# NOTE(review): the targets are (batch, sequence, classes); confirm that the
# installed Keras version supports class_weight for targets of this rank.

# Train the model, let's experiment w/higher epochs & batch sizes soon
history = model.fit(
    X_train,
    y_train_one_hot,
    validation_data=(X_test, y_test_one_hot),  # Use validation data to monitor training
    epochs=50,  #Make sure to note accuracy changes w/epoch increases
    batch_size=8,  # Adjust batch size based on your GPU/CPU memory
    callbacks=[checkpoint, early_stopping],  # Use callbacks for better training
    verbose=1,  # Print training progress
    class_weight=class_weights_dict
)

# Evaluate the model on the test data
test_loss, test_accuracy = model.evaluate(X_test, y_test_one_hot)
print(f"Test Loss: {test_loss}")
print(f"Test Accuracy: {test_accuracy}")

# Save the trained model (separate file from the best-checkpoint file above)
model.save('focalLossAttempt1.h5.h5')


Epoch 1/50

  saving_api.save_model(


Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Test Loss: 2.687126398086548
Test Accuracy: 0.016234416514635086
