In [None]:
from collections import defaultdict
import random
from collections import Counter
from typing import List, Callable, Dict
import pandas as pd
import os
import shutil
import wave
import contextlib
import time

In [None]:
# Placeholder locations: point these at your local copies before running.
# Reference file names must contain "JA_" / "CV_" (the dataset type is inferred from them).
data_root = r"path/to/your/data"

# Directories holding one "<utterance id>.wav" per selected utterance.
ja_path = f"{data_root}/ja/audio"
cv_path = f"{data_root}/cv/audio"

# Full STM references for each corpus.
ja_ref = f"{data_root}/reference/JA_reference.stm"
cv_ref = f"{data_root}/reference/CV_reference.stm"

# Root folder under which every selected set gets its own subfolder.
out_path = f"{data_root}/selection"

In [None]:
def parse_stm_file(file_path):
    """Read an STM file into a DataFrame with one row per timed segment.

    Lines with fewer than six whitespace-separated fields and lines whose
    speaker field is ``inter_segment_gap`` are skipped.

    Args:
        file_path: path to the STM file (read as UTF-8).

    Returns:
        pandas DataFrame with columns dataset (the file's basename), file,
        start, end and duration (end - start).
    """
    source_name = os.path.basename(file_path)
    rows = []

    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split()
            if len(fields) < 6 or fields[2] == 'inter_segment_gap':
                continue
            begin, finish = float(fields[3]), float(fields[4])
            rows.append(dict(
                dataset=source_name,
                file=fields[0],
                start=begin,
                end=finish,
                duration=finish - begin,
            ))

    return pd.DataFrame(rows)

In [None]:
# Age bins as (label, min_age, max_age), inclusive on both ends, one per decade.
# The labels are also matched as substrings against Common Voice age labels in
# map_to_age_bin, so keep their spelling ("fourties") exactly as the source data
# presumably uses it. Verify before "correcting".
_decade_labels = (
    "under ten", "teens", "twenties", "thirties", "fourties",
    "fifties", "sixties", "seventies", "eighties",
)
age_bins = [(label, 10 * i, 10 * i + 9) for i, label in enumerate(_decade_labels)]

# Canonical gender -> accepted lower-case spellings.
gender_bins = dict(
    male=["male", "m"],
    female=["female", "f"],
)


def map_to_age_bin(age_str, bins=None):
    """Map a raw age value to an age-bin label.

    Two forms are accepted:
      * a Common Voice style decade label such as "twenties" (substring match
        against the bin labels);
      * a numeric age string such as "23" or "23.5" (JASMIN).

    Numeric ages use half-open ranges ``min_age <= age < max_age + 1``. For
    integer ages this is identical to the inclusive ``min_age..max_age`` bins.
    Fractional ages such as 9.5 no longer fall into the gap between two bins
    and get silently dropped.

    Args:
        age_str: raw age value; None or "unknown" yield None.
        bins: optional list of (label, min_age, max_age) tuples; defaults to
            the module-level ``age_bins``.

    Returns:
        The matching bin label, or None if it cannot be determined.
    """
    if age_str is None or age_str == "unknown":
        return None
    if bins is None:
        bins = age_bins

    # direct match for common voice labels like "twenties"
    for label, _, _ in bins:
        if label in age_str:
            return label

    # numeric (e.g. jasmin) ages
    try:
        age = float(age_str)
    except ValueError:
        return None  # neither a known label nor a number

    for label, min_age, max_age in bins:
        if min_age <= age < max_age + 1:
            return label

    return None  # outside every bin (e.g. 90+)


def map_to_gender_bin(gender_str, bins=None):
    """Map a raw gender value to its canonical key ("male"/"female").

    Args:
        gender_str: raw gender string (matched case-insensitively). None or ""
            yields None. Previously None raised AttributeError; it can occur
            because empty metadata fields are cleaned to None.
        bins: optional mapping canonical key -> list of accepted lower-case
            spellings; defaults to the module-level ``gender_bins``.

    Returns:
        The canonical key, or None if the value matches no alias.
    """
    if not gender_str:
        return None
    if bins is None:
        bins = gender_bins

    normalized = gender_str.lower()
    for key, aliases in bins.items():
        if normalized in aliases:
            return key
    return None  # couldn't determine

In [None]:
def infer_dataset_type(ref_path):
    """Infer which corpus an STM reference belongs to from its path.

    The file name is checked first for "cv_" / "ja_" (case-insensitive), so a
    directory name containing the other marker can no longer override it.
    Only if the file name has neither marker is the whole path searched,
    which was the previous behaviour.

    Args:
        ref_path: path to the reference file.

    Returns:
        "cv" or "ja".

    Raises:
        ValueError: if neither marker is found.
    """
    for candidate in (os.path.basename(ref_path).lower(), ref_path.lower()):
        if "cv_" in candidate:
            return "cv"
        if "ja_" in candidate:
            return "ja"
    raise ValueError(f"Cannot infer dataset type from path: {ref_path}")
    
    
def get_national_variant(filename):
    """Return "dutch" or "flemish" based on the JASMIN file-id prefix, else None.

    The first five characters decide: JApnl/JAqnl -> "dutch",
    JApvl/JAqvl -> "flemish"; anything else -> None.
    """
    variant_by_prefix = {
        "JApnl": "dutch",
        "JAqnl": "dutch",
        "JApvl": "flemish",
        "JAqvl": "flemish",
    }
    return variant_by_prefix.get(filename[:5])

In [None]:
def parse_metadata(metadata_str, dataset_type, filename):
    """Map the comma-separated speaker metadata of one STM line to a dict.

    Args:
        metadata_str: list of raw metadata strings, already split on commas
            (despite the parameter name).
        dataset_type: "cv" or "ja" (as returned by infer_dataset_type).
        filename: STM file id; for "ja" it determines the national variant.

    Returns:
        For "cv": {"age", "gender", "accent"}.
        For "ja": {"gender", "age", "national_variant", "first_language",
        "proficiency", "accent_region"}.
        Values are stripped and lower-cased; empty or missing values become None.

    Raises:
        ValueError: for an unknown dataset_type (previously None was returned
            and the caller failed later with a TypeError), or for a "ja" entry
            that does not have exactly five fields.
    """
    def clean(val):
        return val.strip().lower() if val else None

    if dataset_type == "cv":
        # CV fields: age, gender, accent. Missing trailing fields are padded
        # with None; extra fields are ignored.
        age, gender, accent = (metadata_str + [None] * 3)[:3]
        return {
            "age": clean(age),
            "gender": clean(gender),
            "accent": clean(accent)
        }

    if dataset_type == "ja":
        if len(metadata_str) != 5:
            raise ValueError(
                f"Expected 5 JASMIN metadata fields for {filename}, "
                f"got {len(metadata_str)}: {metadata_str}"
            )
        gender, age, first_language, proficiency, accent_region = metadata_str
        return {
            "gender": clean(gender),
            "age": clean(age),
            "national_variant": clean(get_national_variant(filename)),
            "first_language": clean(first_language),
            "proficiency": clean(proficiency),
            "accent_region": clean(accent_region)
        }

    raise ValueError(f"Unknown dataset type: {dataset_type!r}")

In [None]:
def parse_reference_file(ref_path, min_duration=1, max_duration=20):
    """Parse a Common Voice or JASMIN STM reference into utterance dicts.

    The dataset type ("cv" / "ja") is inferred from ``ref_path``. A line is
    kept only if all of the following hold:
      * it has at least six whitespace-separated fields;
      * it is not an ``inter_segment_gap`` line;
      * its duration lies within [min_duration, max_duration];
      * its sixth field (``<...>``) has at least three comma-separated values;
      * it resolves to a known age bin and gender bin, and its accent is not
        "unknown".

    Args:
        ref_path: path to the STM reference file (read as UTF-8).
        min_duration: shortest segment kept, in the STM time unit (default 1).
        max_duration: longest segment kept, in the STM time unit (default 20).

    Returns:
        List of dicts with keys filename, speaker, duration, age_bin,
        gender_bin, national_variant, first_language, proficiency, accent
        and dataset_type. Keys that do not apply to the corpus are None.
    """
    dataset_type = infer_dataset_type(ref_path)
    parsed_entries = []

    with open(ref_path, 'r', encoding='utf-8') as infile:
        for line in infile:
            parts = line.strip().split()
            if len(parts) < 6:
                continue

            filename = parts[0]
            speaker = parts[2]
            speaker_data = parts[5]
            start, end = float(parts[3]), float(parts[4])
            duration = end - start

            if duration < min_duration or duration > max_duration:
                continue

            if speaker == "inter_segment_gap":
                continue

            # "<a,b,c>" -> ["a", "b", "c"]
            speaker_data = speaker_data.lower().strip("<>").split(",")
            if len(speaker_data) < 3:
                continue

            parsed_data = parse_metadata(speaker_data, dataset_type, filename)

            age_bin = map_to_age_bin(parsed_data["age"])
            gender_bin = map_to_gender_bin(parsed_data["gender"])
            national_variant = parsed_data.get("national_variant")
            # CV provides "accent", JASMIN "accent_region"; unify under "accent"
            accent = parsed_data.get("accent") or parsed_data.get("accent_region")
            first_language = parsed_data.get("first_language")
            proficiency = parsed_data.get("proficiency")

            if age_bin is None or gender_bin is None or accent == 'unknown':
                continue

            parsed_entries.append({
                "filename": filename,
                "speaker": speaker,
                "duration": duration,
                "age_bin": age_bin,
                "gender_bin": gender_bin,
                "national_variant": national_variant,
                "first_language": first_language,
                "proficiency": proficiency,
                "accent": accent,
                "dataset_type": dataset_type
            })

    return parsed_entries

# Parse both references once; every entry is a dict produced by parse_reference_file.
ja_data = parse_reference_file(ja_ref)
cv_data = parse_reference_file(cv_ref)

all_data = [*ja_data, *cv_data]

# Split JASMIN by file-id prefix: 'JAq' -> read speech, 'JAp' -> spontaneous speech.
ja_read_data = [entry for entry in ja_data if entry['filename'][:3] == 'JAq']
ja_spon_data = [entry for entry in ja_data if entry['filename'][:3] == 'JAp']

In [None]:
def _stratum_key(utt, condition_var):
	"""Return the stratum key 'gender[|national_variant][|<condition_var value>]' for one utterance."""
	key = utt['gender_bin']
	if 'national_variant' in utt:
		key += f"|{utt['national_variant']}"
	if condition_var and condition_var in utt:
		key += f"|{utt[condition_var]}"
	return key


def _fill_age_bin(utts_in_bin, bin_target, condition_var, tolerance, rng, used_ids, selected):
	"""Greedily pick unused utterances from every stratum of one age bin; return the seconds picked.

	``bin_target`` is split evenly over the bin's strata. Each stratum takes
	shuffled, not-yet-used utterances until its running total reaches
	target * (1 + tolerance). The check happens before adding, so a stratum can
	overshoot by one utterance. Picks are appended to ``selected``, and their
	filenames are added to ``used_ids``; both are updated in place.
	"""
	strata = defaultdict(list)
	for utt in utts_in_bin:
		strata[_stratum_key(utt, condition_var)].append(utt)

	max_duration = bin_target / len(strata) * (1 + tolerance)

	picked_total = 0
	for utts in strata.values():
		pool = utts.copy()
		rng.shuffle(pool)
		acc_duration_stratum = 0
		for utt in pool:
			if utt['filename'] in used_ids:
				continue
			if acc_duration_stratum >= max_duration:
				break
			selected.append(utt)
			used_ids.add(utt['filename'])
			acc_duration_stratum += utt['duration']
		picked_total += acc_duration_stratum
	return picked_total


def balanced_hierarchical_split(
	data,
	total_duration,
	condition_var=None,
	seed=42,
	used_ids=None,
	tolerance=0.2
):
	"""Select utterances totalling roughly ``total_duration`` hours, balanced by age bin, then by stratum.

	The target is split evenly over the age bins present in ``data``. Inside each
	age bin it is split evenly over strata keyed by gender_bin, national_variant
	(when that key exists) and, if given, the value of ``condition_var``.
	Native/non-native and read/spontaneous splits are done by the caller, which
	passes pre-filtered ``data``. Age bins that end below their target add to a
	shortfall. In one extra pass, that shortfall is spread over the age bins
	whose total duration (including already-used utterances) exceeds the per-bin
	target.

	NOTE(review): strata may overshoot by one utterance, and over-filled bins are
	not netted against the shortfall. The result can therefore exceed the request
	noticeably: the recorded run got ~11 h for a 5 h request.

	Args:
		data: list of utterance dicts with 'filename', 'duration' (seconds), 'age_bin' and 'gender_bin'.
		total_duration: requested total, in hours.
		condition_var: optional extra key to stratify on (e.g. "accent").
		seed: seed of the private RNG used for shuffling. Same sequence as the former
			global ``random.seed(seed)``, but the global RNG is no longer reseeded.
		used_ids: set of filenames to exclude; it is updated in place and returned.
		tolerance: relative slack on each stratum's target.

	Returns:
		(selected, used_ids): the chosen utterance dicts and the updated exclusion set.

	Raises:
		ValueError: if an utterance lacks 'age_bin' or 'gender_bin'.
	"""
	rng = random.Random(seed)
	total_duration *= 3600  # hours -> seconds

	if used_ids is None:
		used_ids = set()

	# group by age_bin (insertion order follows `data`)
	by_age = defaultdict(list)
	for utt in data:
		if 'age_bin' not in utt or 'gender_bin' not in utt:
			raise ValueError(f"Utterance missing required fields: {utt}")
		by_age[utt['age_bin']].append(utt)

	selected = []

	# Empty input previously raised ZeroDivisionError; now it just selects nothing.
	if by_age:
		target_per_age = total_duration / len(by_age)
		shortfall = 0

		for utts_in_age in by_age.values():
			picked = _fill_age_bin(utts_in_age, target_per_age, condition_var, tolerance, rng, used_ids, selected)
			if picked < target_per_age:
				shortfall += (target_per_age - picked)

		# redistribute shortfall
		if shortfall > 0:
			print(f"Redistributing shortfall of {shortfall:.1f}s")

			remaining_bins = [age_bin for age_bin in by_age if sum(u['duration'] for u in by_age[age_bin]) > target_per_age]
			if not remaining_bins:
				print("No bins available for redistribution.")
			else:
				redistributed_target = shortfall / len(remaining_bins)
				for age_bin in remaining_bins:
					_fill_age_bin(by_age[age_bin], redistributed_target, condition_var, tolerance, rng, used_ids, selected)

	total_selected = sum(u['duration'] for u in selected)
	print(f"Selected total: {total_selected:.1f}s ({total_selected/3600:.2f}h)")

	return selected, used_ids


In [None]:
used_ids = set()

# First-language codes that count as native for the JASMIN splits.
NATIVE_FIRST_LANGUAGES = ("dut", "tus", "dia")


def _has_native_l1(utt):
    """True when the entry's first_language is one of NATIVE_FIRST_LANGUAGES."""
    return utt.get("first_language") in NATIVE_FIRST_LANGUAGES


def ja_non_native(data):
    """Return the entries whose first_language is not a native code.

    NOTE(review): entries with a missing/None first_language also land here;
    confirm that is intended.
    """
    return [utt for utt in data if not _has_native_l1(utt)]


def ja_native(data):
    """Return the entries whose first_language is a native code."""
    return list(filter(_has_native_l1, data))

def cv_native(data, accent="nederlands-nederlands"):
    """Keep Common Voice entries whose accent equals ``accent`` (case-insensitive).

    Args:
        data: list of utterance dicts (as produced by parse_reference_file).
        accent: accent value to keep; defaults to "nederlands-nederlands".

    Returns:
        The matching entries, in input order. Entries with no accent (missing
        key or None) are excluded; previously they raised AttributeError on
        ``None.lower()``.
    """
    wanted = accent.lower()
    return [utt for utt in data if (utt.get("accent") or "").lower() == wanted]

# Build the evaluation sets in a fixed order. total_duration is in hours.
# Every call adds its picks to the shared `used_ids`, so a file taken by an
# earlier set is never reused by a later one; reordering the calls changes
# every set after the moved one. Set0 keeps its capitalised name because the
# export cell refers to it. No set5 is built in this cell.

# 0: native Common Voice speech (baseline)
Set0, used_ids = balanced_hierarchical_split(
    cv_native(cv_data), total_duration=10.0, used_ids=used_ids,
)

# 1: JASMIN read speech, non-native speakers
set1, used_ids = balanced_hierarchical_split(
    ja_non_native(ja_read_data), total_duration=9.0, used_ids=used_ids,
)

# 2: JASMIN read speech, native speakers, additionally stratified by accent
set2, used_ids = balanced_hierarchical_split(
    ja_native(ja_read_data), total_duration=9.0, used_ids=used_ids, condition_var="accent",
)

# 3: JASMIN spontaneous speech, non-native speakers
set3, used_ids = balanced_hierarchical_split(
    ja_non_native(ja_spon_data), total_duration=7.0, used_ids=used_ids,
)

# 4: JASMIN spontaneous speech, native speakers, additionally stratified by accent
set4, used_ids = balanced_hierarchical_split(
    ja_native(ja_spon_data), total_duration=5.0, used_ids=used_ids, condition_var="accent",
)

# 6: all Common Voice data (not only native), minus the files already taken by Set0
set6, used_ids = balanced_hierarchical_split(
    cv_data, total_duration=10.0, used_ids=used_ids,
)

Redistributing shortfall of 3442.2s
Selected total: 33128.3s (9.20h)
Redistributing shortfall of 7938.6s
Selected total: 32719.0s (9.09h)
Redistributing shortfall of 8573.6s
Selected total: 22192.0s (6.16h)
Redistributing shortfall of 4463.7s
Selected total: 16505.4s (4.58h)
Redistributing shortfall of 10449.2s
Selected total: 39633.5s (11.01h)
Redistributing shortfall of 16010.5s
Selected total: 32692.2s (9.08h)


In [None]:
def generate_stm_and_copy_audio(selected_set, ref_file, out_dir, ref_name, audio_dir, out_root=None):
    """Write an STM reference for a selected set and copy its audio files next to it.

    Args:
        selected_set: utterance dicts; only their 'filename' key is used.
        ref_file: full STM reference to take the transcript lines from.
        out_dir: name of the output folder, created under ``out_root``.
        ref_name: base name (without ".stm") of the STM file to write.
        audio_dir: directory containing "<filename>.wav" for every utterance.
        out_root: parent directory of ``out_dir``; defaults to the global ``out_path``.

    Side effects:
        Creates the output folder, writes "<ref_name>.stm" (UTF-8) into it,
        copies each found WAV file there, and prints a summary.
    """
    start_time = time.time()

    root = out_path if out_root is None else out_root
    out_dir = os.path.join(root, out_dir)
    os.makedirs(out_dir, exist_ok=True)

    # load reference lines (same encoding as parse_reference_file)
    with open(ref_file, 'r', encoding='utf-8') as f:
        ref_transcripts = f.readlines()

    print(f"Loaded {len(ref_transcripts)} reference lines.")

    # Index by exact utterance id (first STM field), keeping the first line per id.
    # The former per-utterance `line.startswith(filename)` scan was
    # O(selected x reference) and matched the wrong line when one id is a prefix
    # of another (e.g. "X1" vs "X10").
    transcript_by_id = {}
    for line in ref_transcripts:
        fields = line.split(maxsplit=1)
        if fields:
            transcript_by_id.setdefault(fields[0], line)

    # create stm file
    stm_file_path = os.path.join(out_dir, f"{ref_name}.stm")
    matched_transcripts = 0
    missing_transcripts = 0
    missing_audio_files = 0

    with open(stm_file_path, 'w', encoding='utf-8') as stm_file:
        for utt in selected_set:
            transcript = transcript_by_id.get(utt['filename'])
            if transcript:
                # the last reference line may lack a newline; don't glue the next entry onto it
                stm_file.write(transcript if transcript.endswith("\n") else transcript + "\n")
                matched_transcripts += 1
            else:
                print(f"Warning: No transcript found for {utt['filename']}")
                missing_transcripts += 1

            # copy audio file
            audio_file_path = os.path.join(audio_dir, utt['filename'] + ".wav")
            if os.path.exists(audio_file_path):
                shutil.copy(audio_file_path, out_dir)
            else:
                print(f"Warning: No audio file found for {utt['filename']} in {audio_dir}")
                missing_audio_files += 1

    elapsed_time = time.time() - start_time

    print(f".stm file and audio files for {out_dir} generated successfully.")
    print(f"Processed {len(selected_set)} utterances.")
    print(f"Matched {matched_transcripts} transcripts.")
    print(f"Missing {missing_transcripts} transcripts.")
    print(f"Missing {missing_audio_files} audio files.")
    print(f"Elapsed time: {elapsed_time:.2f} seconds.")

In [None]:
# (subset, reference STM, output folder, STM base name, audio source), exported in this order.
# NOTE(review): '6. hallucination mix/speech' contains a path separator, so that set is
# written to a nested "speech" folder inside "6. hallucination mix". Confirm this is intended.
export_plan = [
    (set1, ja_ref, '1. nonnative-read', '1_reference', ja_path),
    (set2, ja_ref, '2. native-read', '2_reference', ja_path),
    (set3, ja_ref, '3. nonnative-spon', '3_reference', ja_path),
    (set4, ja_ref, '4. native-spon', '4_reference', ja_path),
    (Set0, cv_ref, '0. baseline', '0_reference', cv_path),
    (set6, cv_ref, '6. hallucination mix/speech', '6_reference', cv_path),
]

for subset, reference, folder, stm_name, audio_source in export_plan:
    generate_stm_and_copy_audio(subset, reference, folder, stm_name, audio_source)

Loaded 201394 reference transcripts.
.stm file and audio files for C:\Users\Topicus\Documents\Datasets\selection\1. nonnative-read generated successfully.
Processed 16271 utterances.
Matched 16271 transcripts.
Missing 0 transcripts.
Missing 0 audio files.
Elapsed time: 569.12 seconds.
Loaded 201394 reference transcripts.
.stm file and audio files for C:\Users\Topicus\Documents\Datasets\selection\2. native-read generated successfully.
Processed 16012 utterances.
Matched 16012 transcripts.
Missing 0 transcripts.
Missing 0 audio files.
Elapsed time: 730.92 seconds.
Loaded 201394 reference transcripts.
.stm file and audio files for C:\Users\Topicus\Documents\Datasets\selection\3. nonnative-spon generated successfully.
Processed 12199 utterances.
Matched 12199 transcripts.
Missing 0 transcripts.
Missing 0 audio files.
Elapsed time: 227.86 seconds.
Loaded 201394 reference transcripts.
.stm file and audio files for C:\Users\Topicus\Documents\Datasets\selection\4. native-spon generated success

'C:\\Users\\Topicus\\Documents\\Datasets\\selection\\6. hallucination mix/speech\\6_reference.stm'

In [None]:
def validate_audio_files(main_dir, skip_transcript_check=("5",)):
    """Check the WAV files of every set folder directly under ``main_dir`` and print a report.

    Each immediate subfolder is expected to be named "<prefix>. <description>"
    and to hold "<prefix>_reference.stm" next to its .wav files. Only WAV files
    directly inside the subfolder are checked. Failures are counted per WAV file:
      - unique: repeated file name. Kept for the report's shape; a single
        directory listing cannot actually contain duplicates.
      - format: the file cannot be opened as WAV. This covers wave.Error and
        EOFError; truncated or empty files used to crash the run.
      - channels: the file is not mono.
      - rate: the sample rate is not 16000 Hz.
      - transcript: the file's id (name without extension) is not the first
        field of any line in the folder's reference STM.

    Args:
        main_dir: directory containing one subfolder per set.
        skip_transcript_check: folder prefixes whose transcript check is
            skipped. The default ("5",) preserves the previously hard-coded
            exception.

    Raises:
        FileNotFoundError: if a folder that needs the transcript check has
            WAV files but no reference STM.
    """
    summary = {}

    for subfolder in [entry.path for entry in os.scandir(main_dir) if entry.is_dir()]:
        subfolder_name = os.path.basename(subfolder)
        set_prefix = subfolder_name.split('.')[0]
        ref_file = os.path.join(subfolder, f"{set_prefix}_reference.stm")
        audio_files = [f for f in os.listdir(subfolder) if f.endswith('.wav')]

        # Read the reference once per folder (not once per WAV file) and index it
        # by the first STM field. This way "ab1" no longer "matches" a line that
        # merely contains "ab1", such as the line for "ab10". The reference is only
        # read when it is needed.
        known_ids = None
        if audio_files and set_prefix not in skip_transcript_check:
            with open(ref_file, 'r', encoding='utf-8', errors='replace') as f:
                known_ids = {fields[0] for fields in (line.split(maxsplit=1) for line in f) if fields}

        unique_files = set()
        failed_checks = {
            "unique": 0,
            "format": 0,
            "channels": 0,
            "rate": 0,
            "transcript": 0
        }

        for audio_file in audio_files:
            file_path = os.path.join(subfolder, audio_file)

            # check for unique filenames
            if audio_file in unique_files:
                failed_checks["unique"] += 1
            else:
                unique_files.add(audio_file)

            # check audio format: mono, 16 kHz
            try:
                with wave.open(file_path, 'rb') as wav_file:
                    if wav_file.getnchannels() != 1:
                        failed_checks["channels"] += 1
                    if wav_file.getframerate() != 16000:
                        failed_checks["rate"] += 1
            except (wave.Error, EOFError):
                failed_checks["format"] += 1

            # check for transcript in reference file
            if known_ids is not None and os.path.splitext(audio_file)[0] not in known_ids:
                failed_checks["transcript"] += 1

        summary[subfolder_name] = (failed_checks, len(audio_files))

    # print summary
    for subfolder_name, (checks, total) in summary.items():
        print(f"Subfolder: {subfolder_name}")
        for check, count in checks.items():
            print(f"  {check}: {count} files failed")
        print(f"  Total files checked: {total}")
        print()


# Re-open every generated set folder under the output root and report format / transcript problems.
validate_audio_files(main_dir=out_path)

Subfolder: 0. baseline
  unique: 0 files failed
  format: 0 files failed
  channels: 0 files failed
  rate: 0 files failed
  transcript: 0 files failed
  Total files checked: 9109

Subfolder: 1. nonnative-read
  unique: 0 files failed
  format: 0 files failed
  channels: 0 files failed
  rate: 0 files failed
  transcript: 0 files failed
  Total files checked: 16271

Subfolder: 2. native-read
  unique: 0 files failed
  format: 0 files failed
  channels: 0 files failed
  rate: 0 files failed
  transcript: 0 files failed
  Total files checked: 16012

Subfolder: 3. nonnative-spon
  unique: 0 files failed
  format: 0 files failed
  channels: 0 files failed
  rate: 0 files failed
  transcript: 0 files failed
  Total files checked: 12199

Subfolder: 4. native-spon
  unique: 0 files failed
  format: 0 files failed
  channels: 0 files failed
  rate: 0 files failed
  transcript: 0 files failed
  Total files checked: 9076

Subfolder: 5. hallucination noise
  unique: 0 files failed
  format: 0 fil