## First, run this cell to set up paths and import dependencies

In [None]:
import os

import matplotlib.pyplot as plt

from tqdm import tqdm

# If there is no "./notebooks" folder in the current directory, the notebook is
# presumably running from inside notebooks/ itself; move up one level (assumed to
# be the repo root) so the `src` package imported below resolves.
# NOTE: `%cd` is IPython magic, so this cell only works under Jupyter/IPython.
if not os.path.exists("./notebooks"):
    %cd ..

from src.data_processing import load_audio, split_into_clips, create_spectrogram, prepare_datasets, list_audio_files, SOAAudioClips, save_mean_std, compute_mean_std_from_images
from src.dataset_analysis import plot_spectrogram, duration_statistics
from src.config import VALID_ACCESS_LABELS, TRAIN_DIR, TEST_DIR, VAL_DIR, DATA_DIR, DATASET_DIR

# Create the dataset root and the per-split output folders (no-op if they exist).
# The dataset root comes first so that any split folder nested inside it can be created.
for output_dir in (DATASET_DIR, TRAIN_DIR, VAL_DIR, TEST_DIR):
    os.makedirs(output_dir, exist_ok=True)


## 1. Find all .wav files in the data directory and report how many there are

In [None]:
# Gather the recordings to work with and report how many were found.
wav_files = list_audio_files(DATA_DIR)
n_wav_files = len(wav_files)
print(f"Found {n_wav_files} .wav files in directory '{DATA_DIR}'")

## 2. Duration statistics for authorized vs. unauthorized speakers

In [None]:
# Partition recordings by speaker. The speaker id is the part of the file name
# before the first underscore; ids listed in VALID_ACCESS_LABELS are "authorized".
authorized_speakers_files = []
unauthorized_speakers_files = []

for wav_path in wav_files:
    speaker_id = os.path.basename(wav_path).partition('_')[0]
    bucket = (
        authorized_speakers_files
        if speaker_id in VALID_ACCESS_LABELS
        else unauthorized_speakers_files
    )
    bucket.append(wav_path)

# Report duration statistics for each group (header first, then the stats).
print("Authorized speakers recordings:")
soa_authorized = SOAAudioClips(authorized_speakers_files)
authorized_stats = duration_statistics(soa_authorized.clips)
print(authorized_stats)

print("\nUnauthorized speakers recordings:")
soa_unauthorized = SOAAudioClips(unauthorized_speakers_files)
unauthorized_stats = duration_statistics(soa_unauthorized.clips)
print(unauthorized_stats)

## 3. Split files into train, validation, and test sets

In [None]:
# prepare_datasets yields three file lists, unpacked here as (train, validation, test).
dataset_splits = prepare_datasets(DATA_DIR)
train_files, val_files, test_files = dataset_splits
print(f"Training files: {len(train_files)} | Validation files: {len(val_files)} | Test files: {len(test_files)}")

def save_spectrogram(spectrogram, output_path):
    """Write ``spectrogram`` to ``output_path`` as a grayscale image.

    The array should be 2-D for the ``"gray"`` colormap to apply. ``plt.imsave``
    is called without ``vmin``/``vmax``, so each image is scaled to its own
    min/max value range.
    """
    plt.imsave(fname=output_path, arr=spectrogram, cmap="gray")


## 4. Calculate and display duration statistics for the raw (unclipped) dataset splits

In [None]:
# Wrap each split's file list so its full-length recordings can be iterated and measured.
soa_train_full_clips = SOAAudioClips(train_files)
soa_test_full_clips = SOAAudioClips(test_files)
soa_val_full_clips = SOAAudioClips(val_files)

print("\nDataset Statistics:")
# Pair each header with its own split explicitly. Previously the validation
# header printed test-set stats and vice versa.
for split_header, soa_split_clips in (
    ("Training set:", soa_train_full_clips),
    ("Validation set:", soa_val_full_clips),
    ("Test set:", soa_test_full_clips),
):
    print(split_header)
    print(duration_statistics(soa_split_clips.clips))

## 5. Process each dataset split: convert every recording to a spectrogram and save random 3-second crops of it as images

In [None]:
import torch
from torchvision.transforms import transforms
from src.config import HOP_LENGTH, SAMPLE_RATE
from src.dataset_analysis import duration_statistics_spectrograms


def process_split(soa_full_clips, output_subdir, *, clip_time_seconds=3, clip_count=100, seed=None):
    """Cut random fixed-length crops from each recording's spectrogram and save them as PNGs.

    For every ``(file_path, full_clip)`` pair yielded by ``soa_full_clips``:

    * the whole recording is converted with ``create_spectrogram``;
    * ``clip_count`` random windows of ``clip_time_seconds`` seconds are cropped
      along the width (time) axis, keeping the full height;
    * each crop is written to ``output_subdir`` as ``<stem>_<i>_clip.png``.

    Duration statistics over all crops are printed at the end.

    Args:
        soa_full_clips: Iterable of ``(file_path, audio)`` pairs, e.g. an ``SOAAudioClips``.
        output_subdir: Existing directory that receives the PNG files.
        clip_time_seconds: Length of each crop in seconds (default 3).
        clip_count: Number of random crops taken per recording (default 100).
        seed: If not ``None``, ``torch.manual_seed(seed)`` is called first so the
            random crop positions are reproducible (useful for val/test splits).
            This reseeds torch's global RNG. ``None`` (the default) keeps the
            previous non-deterministic behaviour.

    Raises:
        ValueError: From ``RandomCrop`` when a recording's spectrogram is
            narrower than the requested crop width.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Crop width in spectrogram columns: each column advances HOP_LENGTH samples.
    # Loop-invariant, so computed once.
    crop_width = int(clip_time_seconds * SAMPLE_RATE / HOP_LENGTH)

    spectrograms = []
    for file_path, full_clip in tqdm(soa_full_clips):
        # Add a leading channel dim -> (1, H, W), presumably (1, freq, time).
        full_spectrogram = torch.tensor(create_spectrogram(full_clip)).unsqueeze(0)
        crop_height = int(full_spectrogram.shape[1])  # keep the full height
        transform = transforms.RandomCrop((crop_height, crop_width))

        # splitext strips only the final extension. The old split('.')[0] truncated
        # "a.b.wav" to "a", so distinct recordings could overwrite each other's PNGs.
        stem = os.path.splitext(os.path.basename(file_path))[0]
        for i in range(clip_count):
            spectrogram = transform(full_spectrogram).squeeze(0).numpy()
            spectrograms.append(spectrogram)
            output_path = os.path.join(output_subdir, f"{stem}_{i}_clip.png")
            save_spectrogram(spectrogram, output_path)
    print(duration_statistics_spectrograms(spectrograms))

# Crop and save spectrograms for every split, each under its own header.
for split_header, soa_split_clips, split_dir in (
    ("Preprocessed Train Dataset:", soa_train_full_clips, TRAIN_DIR),
    ("\nPreprocessed Validation Dataset:", soa_val_full_clips, VAL_DIR),
    ("\nPreprocessed Test Dataset:", soa_test_full_clips, TEST_DIR),
):
    print(split_header)
    process_split(soa_split_clips, split_dir)


## 6. Mean and standard deviation of the training dataset

In [None]:
# Normalization statistics are taken from the training images only (presumably
# so validation/test data does not leak into the scaling), then persisted.
# NOTE(review): every image file in TRAIN_DIR is included, so PNGs left over from
# an earlier run with a different split would also be counted. Clear the folder
# when re-splitting.
mean, std = compute_mean_std_from_images(TRAIN_DIR)
print(f"Mean: {mean}, Standard deviation: {std}")
# Build the path with os.path.join, matching the path handling elsewhere in this notebook.
save_mean_std(mean, std, os.path.join(DATASET_DIR, "scaling_params.json"))

## 7. Visualize some spectrogram examples

In [None]:
# Show a few training spectrograms.
# os.listdir order is arbitrary, so sort for a reproducible selection, and keep
# only .png files so stray entries (e.g. OS metadata files) are not passed to imread.
sample_names = sorted(name for name in os.listdir(TRAIN_DIR) if name.endswith(".png"))[:3]
sample_spectrogram_paths = [os.path.join(TRAIN_DIR, name) for name in sample_names]
for path in sample_spectrogram_paths:
    # NOTE(review): PNGs written by plt.imsave with a colormap are RGBA, so this
    # array presumably has a trailing 4-channel axis. Confirm plot_spectrogram expects that.
    spectrogram = plt.imread(path)
    plot_spectrogram(spectrogram, title=f"Spectrogram from {path}")