## First, run this cell to set up paths and import dependencies

In [None]:
import os

import matplotlib.pyplot as plt
import random
from tqdm import tqdm
from collections import defaultdict
import re
import random

# When the current working directory has no "notebooks" entry, move one level
# up using the IPython `%cd` magic.
# NOTE(review): presumably the notebook is started from inside notebooks/ and
# the `src.*` imports below need the project root as working directory — confirm.
if not os.path.exists(r"./notebooks"):
    %cd ..


from src.data_processing import split_into_clips, create_spectrogram, SOAAudioClips, save_mean_std, compute_mean_std_from_images, list_audio_files_recursively, save_spectrogram
from src.dataset_analysis import duration_statistics
from src.config import VALID_ACCESS_LABELS, TRAIN_DIR, TEST_DIR, VAL_DIR, DATA_DIR, DATASET_DIR, DATA_DIR_SPECIFIC
from collections import defaultdict

# Create the dataset directory and the train / validation / test output
# directories; directories that already exist are left untouched.
for output_dir in (DATASET_DIR, TRAIN_DIR, VAL_DIR, TEST_DIR):
    os.makedirs(output_dir, exist_ok=True)

# Fixed seed so the random choices made with `random` are repeatable
random.seed(42)

## 1  Load all .wav files from the dataset

In [None]:

# Names of the directories whose recordings are included in the dataset
allowed_dictionaries = [
    'ipadflat_confroom1',
    'ipadflat_office1',
    'ipad_balcony1',
    'ipad_bedroom1',
    'ipad_confroom1',
    'ipad_confroom2',
    'ipad_livingroom1',
    'ipad_office1',
    'ipad_office2',
    'iphone_balcony1',
    'iphone_bedroom1',
    'iphone_livingroom1',
]
print(len(allowed_dictionaries))

# Collect the audio files found under DATA_DIR, restricted to the directories above
wav_files_all = list_audio_files_recursively(DATA_DIR, allowed_dictionaries)
print(f"Found {len(wav_files_all)} .wav files in directory '{DATA_DIR}' in the following allowed directories: {allowed_dictionaries}")

In [None]:
# Bare expression: the notebook renders the collected file list as the cell output
wav_files_all

## 2 Split all .wav files into train/validation/test sets, balance the training set, and display statistics

In [None]:
# Maps each (speaker tag, script number) pair to the list of recordings for it
speaker_script_to_files = defaultdict(list)

# A file name must start with "<f|m><digits>_script<digits>_"; group 1 is the
# speaker tag and group 2 the script number
pattern = re.compile(r'([fm]\d+)_script(\d+)_')

for filepath in wav_files_all:
    # basename keeps the parsing independent of the platform's path separator
    filename = os.path.basename(filepath)
    match = pattern.match(filename)
    if not match:
        # Unparseable names are reported and left out of the grouping
        print(f"Filename {filename} does not match the expected pattern.")
        continue
    speaker_tag, script_number = match.group(1), int(match.group(2))
    speaker_script_to_files[(speaker_tag, script_number)].append(filepath)


In [None]:
# File lists for the three dataset splits
train_set = []
validate_set = []
test_set = []

# Training files kept separately per access class so the training set can be
# balanced afterwards
authorized_train_samples = []
unauthorized_train_samples = []

# Group the script numbers by speaker in a single pass over the keys
scripts_by_speaker = defaultdict(list)
for spk, script in speaker_script_to_files:
    scripts_by_speaker[spk].append(script)

# Collect all speakers and separate them by access class
all_speakers = set(scripts_by_speaker)
authorized_speakers = all_speakers.intersection(VALID_ACCESS_LABELS)
unauthorized_speakers = all_speakers - authorized_speakers

random.seed(42)  # For reproducibility

# Iterate the speakers in sorted order: the iteration order of a set of strings
# changes between interpreter sessions (hash randomization), which would make
# the seeded shuffles below hand different scripts to each speaker on every run.
for speaker in sorted(all_speakers):
    # Sort before shuffling so the outcome does not depend on the order in
    # which the files were listed.
    speaker_scripts = sorted(scripts_by_speaker[speaker])
    random.shuffle(speaker_scripts)

    # Target a 70 / 15 / 15 split of the scripts, with at least one script for
    # training and one for validation.
    num_scripts = len(speaker_scripts)
    num_train_scripts = max(1, int(0.7 * num_scripts))
    num_validate_scripts = max(1, int(0.15 * num_scripts))
    num_test_scripts = max(0, num_scripts - num_train_scripts - num_validate_scripts)

    # Rounding can leave nothing for the test split.  Hand one training script
    # over, but only when training keeps at least one script afterwards: a
    # speaker with just two scripts stays in train + validation instead of
    # losing its only training script.
    if num_test_scripts == 0 and num_train_scripts > 1:
        num_test_scripts = 1
        num_train_scripts -= 1

    # Assign scripts to sets; slices past the end of the list are simply empty
    validate_end = num_train_scripts + num_validate_scripts
    train_scripts = speaker_scripts[:num_train_scripts]
    validate_scripts = speaker_scripts[num_train_scripts:validate_end]
    test_scripts = speaker_scripts[validate_end:]

    if speaker in authorized_speakers:
        class_train_samples = authorized_train_samples
    else:
        class_train_samples = unauthorized_train_samples

    for script in train_scripts:
        files = speaker_script_to_files[(speaker, script)]
        train_set.extend(files)
        class_train_samples.extend(files)

    for script in validate_scripts:
        validate_set.extend(speaker_script_to_files[(speaker, script)])

    for script in test_scripts:
        test_set.extend(speaker_script_to_files[(speaker, script)])


In [None]:
# Sizes of the two access classes in the training split
num_authorized_samples = len(authorized_train_samples)
num_unauthorized_samples = len(unauthorized_train_samples)

# Surplus of the larger class, and the size both classes end up with
difference = abs(num_authorized_samples - num_unauthorized_samples)
balanced_size = min(num_authorized_samples, num_unauthorized_samples)

# Randomly drop files from the larger class until both classes are equally large
if num_authorized_samples < num_unauthorized_samples:
    random.shuffle(unauthorized_train_samples)
    unauthorized_train_samples = unauthorized_train_samples[:balanced_size]
else:
    # Authorized class is at least as large (unlikely given the dataset)
    random.shuffle(authorized_train_samples)
    authorized_train_samples = authorized_train_samples[:balanced_size]

# Rebuild the training set from the balanced class lists
train_set = authorized_train_samples + unauthorized_train_samples


In [None]:
def compute_statistics(dataset, name):
    """Print sample and speaker counts for one dataset split.

    Args:
        dataset: Sequence of audio file paths belonging to the split.
        name: Label of the split, used in the printed header.

    The speaker tag is read from each file name with the module-level
    ``pattern``.  Files whose name does not match still count towards the
    total but are left out of the speaker and access-class counts.
    """
    total_samples = len(dataset)

    # Number of files per speaker tag
    speaker_sample_counts = defaultdict(int)
    for filepath in dataset:
        match = pattern.match(os.path.basename(filepath))
        if not match:
            continue
        speaker_sample_counts[match.group(1)] += 1

    speakers = set(speaker_sample_counts)

    # Split the per-speaker totals by access class
    authorized_count = sum(
        count
        for tag, count in speaker_sample_counts.items()
        if tag in VALID_ACCESS_LABELS
    )
    unauthorized_count = sum(speaker_sample_counts.values()) - authorized_count

    print(f"--- {name} Set Statistics ---")
    print(f"Total Samples: {total_samples}")
    print(f"Total Speakers: {len(speakers)}")
    print(f"Authorized Samples: {authorized_count}")
    print(f"Unauthorized Samples: {unauthorized_count}")
    print(f"Authorized to Unauthorized Ratio: {authorized_count}:{unauthorized_count}")
    print("\nSamples per Speaker:")
    for speaker in sorted(speaker_sample_counts):
        print(f"  {speaker}: {speaker_sample_counts[speaker]}")
    print()


In [None]:
# Print the statistics of every split, in train / validation / test order
for split_files, split_label in (
    (train_set, "Training"),
    (validate_set, "Validation"),
    (test_set, "Test"),
):
    compute_statistics(split_files, split_label)


## 3 Display files info

In [None]:
# Wrap the file list of each split in SOAAudioClips.
# Each variable is named after the split it holds: the validation wrapper is
# built from validate_set and the test wrapper from test_set.  The processing
# cell further down pairs soa_val_full_clips with VAL_DIR and
# soa_test_full_clips with TEST_DIR, so a mismatch here would write the
# spectrograms of each split into the other split's directory.
soa_train_full_clips = SOAAudioClips(train_set)
soa_val_full_clips = SOAAudioClips(validate_set)
soa_test_full_clips = SOAAudioClips(test_set)

# Duration statistics of the full-length recordings, per split
print("\nDataset Statistics:")
print("Training set:")
print(duration_statistics(soa_train_full_clips.clips))

print("Validation set:")
print(duration_statistics(soa_val_full_clips.clips))

print("Test set:")
print(duration_statistics(soa_test_full_clips.clips))





## 4 Process each dataset split by converting 3-second clips into spectrograms

In [None]:
def process_split(soa_full_clips, output_subdir):
    """Cut every recording of a split into clips and save one spectrogram per clip.

    Args:
        soa_full_clips: Iterable yielding ``(file_path, full_clip)`` pairs.
        output_subdir: Directory that receives the generated PNG images.

    After all recordings are processed, the ``duration_statistics`` of every
    clip produced for the split are printed.  Nothing is returned.
    """
    clips_of_split = []
    for file_path, full_clip in tqdm(soa_full_clips):
        clips = split_into_clips(full_clip)
        clips_of_split.extend(clips)
        for i, clip in enumerate(clips):
            # Image name: source file name up to its first dot, the clip index
            # within the recording, and a fixed "_clip.png" suffix
            output_path = os.path.join(output_subdir, f"{os.path.basename(file_path).split('.')[0]}_{i}_clip.png")
            save_spectrogram(create_spectrogram(clip), output_path)
    print(duration_statistics(clips_of_split))

# Convert each split into spectrogram images inside its output directory.
# NOTE(review): confirm that soa_val_full_clips / soa_test_full_clips really
# wrap the validation / test file lists, so each directory gets its own split.
for heading, split_clips, split_dir in (
    ("Preprocessed Train Dataset:", soa_train_full_clips, TRAIN_DIR),
    ("\nPreprocessed Validation Dataset:", soa_val_full_clips, VAL_DIR),
    ("\nPreprocessed Test Dataset:", soa_test_full_clips, TEST_DIR),
):
    print(heading)
    process_split(split_clips, split_dir)

## 5 Mean and Standard deviation of training dataset

In [None]:
# Mean and standard deviation computed from the images under TRAIN_DIR
mean, std = compute_mean_std_from_images(TRAIN_DIR)
print(f"Mean: {mean}, Standard deviation: {std}")

# Store both values in a JSON file inside the dataset directory
scaling_params_path = f"{DATASET_DIR}/scaling_params.json"
save_mean_std(mean, std, scaling_params_path)