In [None]:
import json
import os

import torchaudio
import tqdm

fpath = "/graft2/datasets/znovack/LJSpeech-1.1/wavs/"

# Collect per-file metadata for every LJSpeech wav. The schema matches
# birdvox_info.json ({'length', 'sample_rate', 'bits_per_sample', 'n_channels'}),
# which is what the statistics cell further down reads back from
# ljspeech_info.json. Storing a bare frame count per file breaks that cell.
ljspeech_info = {}
# sorted() -> deterministic key order in the written JSON; tqdm shows progress,
# so no per-file print is needed.
for fname in tqdm.tqdm(sorted(os.listdir(fpath))):
    if not fname.endswith(".wav"):
        continue
    info = torchaudio.info(os.path.join(fpath, fname))
    ljspeech_info[fname] = {
        'length': info.num_frames,
        'sample_rate': info.sample_rate,
        'bits_per_sample': info.bits_per_sample,
        'n_channels': info.num_channels,
    }

# write to json (use a distinct handle name so the loop variable is not shadowed)
with open("ljspeech_info.json", "w") as fp:
    json.dump(ljspeech_info, fp)

In [None]:
# for given path, get length, sampling rate, and bit depth
import torchaudio
import os
import tqdm
def get_wav_info(path):
    """Read header metadata for a single audio file via ``torchaudio.info``.

    Args:
        path: path to an audio file that torchaudio can open (this notebook
            passes .wav and .flac files).

    Returns:
        A 4-tuple ``(num_frames, sample_rate, bits_per_sample, num_channels)``.
    """
    meta = torchaudio.info(path)
    return (
        meta.num_frames,
        meta.sample_rate,
        meta.bits_per_sample,
        meta.num_channels,
    )

pth = '/mnt/arrakis_data/pnlong/lnac/birdvox/unit06/split_data/'

# Index every .flac/.wav file under `pth` by its basename. The split cell below
# parses recording ids out of these basenames. If the same basename appears in
# two sub-directories, the later one overwrites the earlier entry, so that case
# is reported instead of happening silently.
wav_info = {}
for root, dirs, files in tqdm.tqdm(os.walk(pth)):
    # Sorting in place makes os.walk (top-down) visit sub-directories in a fixed
    # order, so the JSON order and any duplicate resolution are reproducible.
    dirs.sort()
    for fname in sorted(files):
        if not fname.lower().endswith(('.flac', '.wav')):
            continue
        full_path = os.path.join(root, fname)
        try:
            length, sr, bps, ch = get_wav_info(full_path)
        except Exception as e:
            # Best-effort: report unreadable files and keep indexing the rest.
            print(f"Error reading {full_path}: {e}")
            continue
        key = fname  # identical to os.path.basename(full_path)
        if key in wav_info:
            print(f"Warning: duplicate file name {key} (at {full_path}); overwriting earlier entry")
        wav_info[key] = {'length': length, 'sample_rate': sr, 'bits_per_sample': bps, 'n_channels': ch}

import json
with open('birdvox_info.json', 'w') as fp:
    json.dump(wav_info, fp)

In [None]:
# Make a ~90/10 train/val split of the birdvox files.
# It can't be fully random: all chunks come from a handful of long (~12-hour)
# source recordings, so the sets of source recordings in train and val must be
# disjoint to avoid leakage.
import os
import random
import json

TRAIN_FRACTION = 0.9  # target share of *files* (not recordings) in train
SPLIT_SEED = 0        # fixed so re-running the cell reproduces the same split

# file names are like: 'unit06_[ogfilename]_000.flac'
with open('birdvox_info.json', 'r') as fp:
    wav_info = json.load(fp)
all_files = list(wav_info.keys())


def recording_id(fname):
    """Return the source-recording id of a chunk file name.

    Only the final '_<chunk>.<ext>' component is dropped, e.g.
    'unit06_[ogfilename]_000.flac' -> 'unit06_[ogfilename]'. Underscores
    inside the original file name therefore do not merge distinct recordings.
    """
    return fname.rsplit('_', 1)[0]


# group by original recording
recording_dict = {}
for fname in all_files:
    recording_dict.setdefault(recording_id(fname), []).append(fname)

# Greedily fill train up to the target count. A whole recording that would
# overshoot the target goes to val, which keeps recordings disjoint.
# Sorting before a seeded shuffle makes the order independent of JSON key order.
all_recordings = sorted(recording_dict)
random.Random(SPLIT_SEED).shuffle(all_recordings)
n_files_per_rec = {rec: len(files) for rec, files in recording_dict.items()}

train_files = []
val_files = []
train_recs = set()
val_recs = set()
total_files = len(all_files)
train_target = int(TRAIN_FRACTION * total_files)
current_train_count = 0
for rec in all_recordings:
    rec_files = recording_dict[rec]
    if current_train_count + n_files_per_rec[rec] <= train_target:
        train_files.extend(rec_files)
        current_train_count += n_files_per_rec[rec]
        train_recs.add(rec)
    else:
        val_files.extend(rec_files)
        val_recs.add(rec)
print(f"Train files: {len(train_files)}, Val files: {len(val_files)}")
print(f"Train recordings: {train_recs}, Val recordings: {val_recs}")

# sanity checks: disjointness, completeness, and that neither side is empty
# (e.g. with a single source recording everything would land in val)
assert train_recs.isdisjoint(val_recs), "Train and Val recordings are not disjoint!"
assert len(train_files) + len(val_files) == total_files, "File count mismatch!"
assert set(train_files).isdisjoint(set(val_files)), "Train and Val files are not disjoint!"
assert train_files and val_files, "Split produced an empty train or val set!"

# make train info and val info jsons
train_info = {fname: wav_info[fname] for fname in train_files}
val_info = {fname: wav_info[fname] for fname in val_files}

# save
with open('birdvox_train_info.json', 'w') as fp:
    json.dump(train_info, fp)
with open('birdvox_val_info.json', 'w') as fp:
    json.dump(val_info, fp)

In [None]:
# Number of BirdVox files indexed (wav_info was reloaded from birdvox_info.json in the split cell).
n_birdvox_files = len(wav_info)
n_birdvox_files

In [None]:
import json
from collections import Counter

import numpy as np

# load
with open('ljspeech_info.json', 'r') as fp:
    wav_info = json.load(fp)

# ljspeech_info.json may store either a metadata dict per file or just the
# frame count (an int). Normalise ints to {'length': n} so both formats work;
# fields that are absent are reported as unavailable instead of raising.
records = [v if isinstance(v, dict) else {'length': v} for v in wav_info.values()]


def _column(recs, key):
    """Collect `key` from every record in `recs` that has it."""
    return [r[key] for r in recs if key in r]


def _summarize(label, values):
    """Print min/max/mean/std of `values`, or a notice when there are none."""
    if not values:  # np.min/np.max raise on empty input
        print(f"{label}: not available")
        return
    print("{}: min {}, max {}, mean {}, std {}".format(
        label, np.min(values), np.max(values), np.mean(values), np.std(values)))


# get statistics
lengths = _column(records, 'length')
sample_rates = _column(records, 'sample_rate')
bit_depths = _column(records, 'bits_per_sample')
channels = _column(records, 'n_channels')
for label, values in (("Length", lengths), ("Sample Rate", sample_rates),
                      ("Bit Depth", bit_depths), ("Channels", channels)):
    _summarize(label, values)

# get count of unique sample rates, bit depths, and channels
sr_counts = Counter(sample_rates)
bps_counts = Counter(bit_depths)
ch_counts = Counter(channels)
print("Sample Rate Counts:", sr_counts)
print("Bit Depth Counts:", bps_counts)
print("Channel Counts:", ch_counts)


In [None]:
# Number of LJSpeech files listed (wav_info was reassigned from ljspeech_info.json in the statistics cell).
n_ljspeech_files = len(wav_info)
n_ljspeech_files