### Balance Database

This notebook balances the database: it scans the full dataset and writes a metadata CSV file that contains the same number of randomly selected samples for each of the 4 classes (chewing, swallowing, resting, others).

In [None]:
import os
import random
import pandas as pd

In [None]:
# Root folder of the database. The balancing cell reads the `spectrograms`
# sub-folder and `metadata.csv` from here and writes `metadata_balanced.csv`
# back into the same folder.
# NOTE: this is a machine-specific absolute path; adjust it before running
# the notebook on another machine.
base_path = '/Users/jannisdaiber/Documents/Repos/github/ProjectMedicalWearables/Database'

In [18]:
# Build a class-balanced metadata file.
#
# Steps:
#   1. Collect every `.npy` spectrogram below `<base_path>/spectrograms/<participant>/`
#      and group the paths by class (the class is the file-name prefix).
#   2. Randomly draw the same number of files from every class, limited by
#      the size of the smallest class.
#   3. Look up the metadata row of every drawn file in `metadata.csv` and
#      write those rows to `metadata_balanced.csv`.
#
# No spectrogram files are copied or moved; only the metadata CSV is written.
source_folder = os.path.join(base_path, 'spectrograms')
metadata_file = os.path.join(base_path, 'metadata.csv')
destination_metadata_file = os.path.join(base_path, 'metadata_balanced.csv')

# Set to an int to make the random selection reproducible. `None` keeps the
# previous behaviour (a different selection on every run).
random_seed = None

# Columns written to the balanced metadata file, in output order.
metadata_columns = [
    'participant',
    'label',
    'spectrogram_path',
    'food_type',
    'augmented_flag',
    'wav_path',
]

class_files = {
    'chewing': [],
    'swallowing': [],
    'resting': [],
    'others': [],
}

# Collect all files from all participant folders. Directory listings are
# sorted because `os.listdir` returns entries in arbitrary order, which would
# otherwise make a seeded selection differ between machines.
participants = sorted(
    d for d in os.listdir(source_folder)
    if os.path.isdir(os.path.join(source_folder, d))
)
for participant in participants:
    participant_folder = os.path.join(source_folder, participant)
    for file in sorted(os.listdir(participant_folder)):
        if not file.endswith('.npy'):
            continue
        for class_name, files in class_files.items():
            if file.startswith(class_name):
                files.append(os.path.join(participant_folder, file))
                break  # a file belongs to exactly one class

# Determine the minimum class count
min_class_count = min(len(files) for files in class_files.values())

# Select an equal number of files from each class
rng = random.Random(random_seed)
selected_files = []
for files in class_files.values():
    selected_files.extend(rng.sample(files, min_class_count))

# Look up the metadata of the selected files. The metadata is indexed by
# spectrogram path once instead of scanning the whole table per file; when a
# path occurs several times the first row is used.
base_metadata = pd.read_csv(metadata_file)
metadata_by_path = (
    base_metadata
    .drop_duplicates(subset='spectrogram_path', keep='first')
    .set_index('spectrogram_path')
)

# Fail with an explicit message (before anything is written) when a selected
# spectrogram has no metadata row, e.g. because the paths stored in
# metadata.csv were created with a different base path.
missing_files = [file for file in selected_files if file not in metadata_by_path.index]
if missing_files:
    raise LookupError(
        f"{len(missing_files)} selected spectrogram(s) have no entry in "
        f"{metadata_file}; first missing path: {missing_files[0]}"
    )

updated_metadata = metadata_by_path.loc[selected_files].reset_index()[metadata_columns]
updated_metadata.to_csv(destination_metadata_file, index=False)