## First, run this cell to set up paths and import dependencies

In [None]:
# Step up one directory before importing; presumably the notebook lives in a
# sub-folder of the project root, where `src/` and `./data` are found — confirm.
# NOTE(review): the path is relative, so re-running this cell moves up one
# more level each time.
%cd ..

import os
from src.data_processing import load_audio, split_into_clips, create_spectrogram, prepare_datasets
from src.dataset_analysis import plot_spectrogram, dataset_summary
import matplotlib.pyplot as plt
from tqdm import tqdm

# Locations used throughout the notebook
data_dir = "./data"  # input: folder holding the .wav files
output_dir = "./datasets"  # output: root folder for the processed spectrograms

# Create the output root plus one sub-folder per split; folders that already
# exist are left untouched.
os.makedirs(output_dir, exist_ok=True)
for _split_name in ("train", "val", "test"):
    os.makedirs(f"{output_dir}/{_split_name}", exist_ok=True)


## 1. List all .wav files in the data directory and report how many were found

In [None]:
# Collect the path of every entry in data_dir whose name ends in ".wav"
# (case-sensitive match, non-recursive, in os.listdir order).
wav_files = [
    os.path.join(data_dir, entry)
    for entry in os.listdir(data_dir)
    if entry.endswith('.wav')
]

# Quick sanity check on how much audio was discovered
print(f"Found {len(wav_files)} .wav files in directory '{data_dir}'")

## 2. Split files into train, validation, and test sets

In [None]:
# Partition the recordings under data_dir into train/validation/test lists.
# prepare_datasets is handed the directory itself (not the wav_files list
# built earlier); split ratios and any shuffling are decided inside it.
train_files, val_files, test_files = prepare_datasets(data_dir)

# Report the size of each split as a quick sanity check
print(f"Training files: {len(train_files)} | Validation files: {len(val_files)} | Test files: {len(test_files)}")

def save_spectrogram(spectrogram, output_path):
    """Write a spectrogram array to an image file.

    Args:
        spectrogram: Array of spectrogram values to render.
        output_path: Destination file path for the image.

    The image is written with matplotlib's ``imsave`` using the ``'gray'``
    colormap.
    """
    plt.imsave(fname=output_path, arr=spectrogram, cmap='gray')


## 3. Process each dataset split by converting 2-second clips into spectrograms

In [None]:
def process_split(file_list, output_subdir):
    """Convert every audio file of one split into per-clip spectrogram images.

    Each file is loaded, cut into clips by ``split_into_clips`` (clip length
    is decided by that helper), and every clip is turned into a spectrogram
    that is saved as ``<file name without extension>_<clip index>_clip.png``
    inside ``output_subdir``.

    Args:
        file_list: Iterable of audio file paths belonging to the split.
        output_subdir: Directory that receives the PNG files. It is created
            if it does not exist yet.
    """
    # Create the destination on demand so the function also works when it is
    # called for a folder that was not prepared beforehand.
    os.makedirs(output_subdir, exist_ok=True)

    for file_path in tqdm(file_list, desc=f"Processing {output_subdir}"):
        audio, sr = load_audio(file_path)

        # splitext removes only the final extension, so "rec.01.wav" keeps
        # the stem "rec.01". Splitting on the first dot would map
        # "rec.01.wav" and "rec.02.wav" to the same stem and overwrite output.
        stem = os.path.splitext(os.path.basename(file_path))[0]

        for i, clip in enumerate(split_into_clips(audio)):
            spectrogram = create_spectrogram(clip, sr)
            output_path = os.path.join(output_subdir, f"{stem}_{i}_clip.png")
            save_spectrogram(spectrogram, output_path)

# Generate the spectrogram images for each split in turn: train, val, test.
for _split_files, _split_name in (
    (train_files, "train"),
    (val_files, "val"),
    (test_files, "test"),
):
    process_split(_split_files, f"{output_dir}/{_split_name}")


## 4. Calculate and display statistics about the dataset

In [None]:

# Print a summary for each split. dataset_summary receives the list of source
# audio file paths of the split, not the generated spectrogram images.
print("\nDataset Statistics:")
print("Training set:")
dataset_summary(train_files)

print("\nValidation set:")
dataset_summary(val_files)

print("\nTest set:")
# NOTE(review): this is the last expression in the cell, so a non-None return
# value would also be rendered by the notebook — confirm dataset_summary
# only prints.
dataset_summary(test_files)


## 5. Visualize some spectrogram examples

In [None]:
# Show up to three spectrograms from the training split.
# Only .png files are considered, so stray entries (e.g. hidden system files)
# are never handed to plt.imread. The listing is sorted because os.listdir
# returns entries in arbitrary order, which made the chosen samples vary
# between runs.
sample_spectrogram_paths = [
    os.path.join(f"{output_dir}/train", f)
    for f in sorted(os.listdir(f"{output_dir}/train"))
    if f.endswith(".png")
][:3]
for path in sample_spectrogram_paths:
    spectrogram = plt.imread(path)
    plot_spectrogram(spectrogram, title=f"Spectrogram from {path}")