In [None]:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import librosa
import soundfile as sf
import os
from google.colab import drive
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import gc

# Configure the Google Drive paths (since that's where our data is).

# Mount point passed to drive.mount() in setup().
DRIVE_MOUNT_PATH = "/content/drive"

# WAV_FILES_DIR is the directory scanned for input .wav files;
# OUTPUT_DIR receives the saved model, figures, audio and .npy files.
WAV_FILES_DIR = "/content/drive/MyDrive/CMPT final project/data/ds005876/stimuli"
OUTPUT_DIR    = "/content/drive/MyDrive/CMPT final project/data/stft_autoencoder"

# Audio / STFT settings used as default arguments by the loading and STFT
# helpers. Every clip is padded or cut to SAMPLE_RATE * DURATION samples.
SAMPLE_RATE = 22050
N_FFT       = 1024
HOP_LENGTH  = 256
DURATION    = 3

# Training hyperparameters passed to autoencoder.fit() in train_autoencoder().
EPOCHS      = 120
BATCH_SIZE  = 8

# SETUP

def setup():
    """Prepare the runtime before any data is loaded.

    Switches Keras to the ``mixed_float16`` policy, clears the Keras session,
    runs a garbage-collection pass, tries to mount Google Drive at
    ``DRIVE_MOUNT_PATH`` and makes sure ``OUTPUT_DIR`` exists.

    Mounting is best-effort: any exception raised by ``drive.mount`` is
    printed and execution continues.
    """
    keras.mixed_precision.set_global_policy("mixed_float16")
    keras.backend.clear_session()
    gc.collect()

    try:
        drive.mount(DRIVE_MOUNT_PATH)
        print("✓ Google Drive mounted")
    except Exception as mount_error:
        # Deliberately non-fatal; the message is only informational.
        print("Drive already mounted or error:", mount_error)

    os.makedirs(OUTPUT_DIR, exist_ok=True)


# code to convert the audio to stft

def load_and_preprocess_audio(path, sr=SAMPLE_RATE, duration=DURATION):
    """Load an audio file and force it to exactly ``int(sr * duration)`` samples.

    Args:
        path: Location of the audio file handed to ``librosa.load``.
        sr: Sample rate requested from ``librosa.load``.
        duration: Maximum length to read; also fixes the output length.

    Returns:
        A 1-D array of ``int(sr * duration)`` samples (zero-padded at the end
        when the file is shorter), or ``None`` if anything in the loading
        step raised; the error is printed in that case.
    """
    try:
        samples, _ = librosa.load(path, sr=sr, duration=duration)
        wanted = int(sr * duration)
        missing = wanted - len(samples)
        if missing > 0:
            # Short clip: append zeros so every clip has the same length.
            return np.pad(samples, (0, missing))
        return samples[:wanted]
    except Exception as err:
        print("Error loading", path, ":", err)
        return None

def audio_to_stft(audio, n_fft=N_FFT, hop_length=HOP_LENGTH):
    """Compute the STFT of ``audio`` and split it into magnitude and phase.

    Args:
        audio: Waveform passed to ``librosa.stft``.
        n_fft: FFT size forwarded to ``librosa.stft``.
        hop_length: Hop size forwarded to ``librosa.stft``.

    Returns:
        ``(magnitude, phase)`` where magnitude is ``np.abs`` and phase is
        ``np.angle`` of the complex STFT.
    """
    spectrum = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)
    return np.abs(spectrum), np.angle(spectrum)

def stft_to_audio(mag, phase, n_fft=N_FFT, hop_length=HOP_LENGTH):
    """Rebuild a waveform from a magnitude spectrogram and a phase array.

    Args:
        mag: Magnitude spectrogram (linear scale, not log).
        phase: Phase angles with the same shape as ``mag``.
        n_fft: FFT size the spectrogram was computed with. It is now forwarded
            to ``librosa.istft``; previously this argument was accepted but
            ignored, leaving librosa to infer the size from the array shape.
        hop_length: Hop size the spectrogram was computed with.

    Returns:
        The waveform produced by ``librosa.istft``.
    """
    # Recombine polar form into the complex STFT: mag * e^(j * phase).
    complex_spec = mag * np.exp(1j * phase)
    return librosa.istft(complex_spec, n_fft=n_fft, hop_length=hop_length)

# Load the data

def load_dataset(directory):
    """Load every .wav file in ``directory`` as a normalized log-magnitude STFT.

    Files are processed in sorted name order so that the returned arrays (and
    any seeded split performed on them afterwards) are reproducible;
    ``os.listdir`` alone gives no ordering guarantee. Files that fail to load
    are skipped.

    Args:
        directory: Folder that is scanned (non-recursively) for ``.wav`` files.

    Returns:
        A tuple ``(mags_norm, phases, filenames, (global_mean, global_std))``:
        ``mags_norm`` is the standardized ``log1p`` magnitude with a trailing
        channel axis added, ``phases`` holds the matching STFT phases,
        ``filenames`` lists the files actually loaded (same order), and the
        last item carries the statistics needed to undo the normalization.

    Raises:
        ValueError: If the directory contains no ``.wav`` files, or if none of
            them could be loaded.

    NOTE(review): the mean/std are computed over all loaded clips, i.e. before
    the caller's train/validation split.
    """
    wav_files = sorted(
        f for f in os.listdir(directory) if f.lower().endswith(".wav")
    )
    if not wav_files:
        raise ValueError(f"No .wav files in {directory}")

    print(f"Found {len(wav_files)} wav files")

    mags_log = []
    phases = []
    filenames = []

    for fn in wav_files:
        path = os.path.join(directory, fn)
        print("Processing:", fn)
        audio = load_and_preprocess_audio(path)
        if audio is None:
            # Loader already printed the error; skip the unreadable file.
            continue

        mag, phase = audio_to_stft(audio)

        # log1p compresses the dynamic range and is inverted with expm1.
        mags_log.append(np.log1p(mag))
        phases.append(phase)
        filenames.append(fn)

    if not filenames:
        # Without this guard the mean/std below would be NaN on an empty array.
        raise ValueError(
            f"None of the {len(wav_files)} .wav files in {directory} could be loaded"
        )

    mags_log = np.array(mags_log, dtype=np.float32)
    phases = np.array(phases, dtype=np.float32)

    print("Raw log-mag shape:", mags_log.shape)

    # global normalization (one mean/std for the whole dataset)
    global_mean = mags_log.mean()
    global_std = mags_log.std() + 1e-6  # epsilon avoids division by zero
    mags_norm = (mags_log - global_mean) / global_std

    # add channel dim for Conv2D
    mags_norm = mags_norm[..., np.newaxis]

    return mags_norm, phases, filenames, (global_mean, global_std)

# MODEL (pure conv autoencoder)

def build_stft_autoencoder(input_shape):
    """Build the convolutional autoencoder used on STFT log-magnitudes.

    Encoder: four Conv2D(3x3, relu) + MaxPooling2D(2) stages with 32, 64, 128
    and 256 filters; the last pooling layer is named ``"bottleneck"`` so an
    encoder sub-model can be cut out by name. Decoder: four stride-2
    Conv2DTranspose layers (256, 128, 64, 32 filters), a linear 1-channel
    Conv2D, and a final bilinear resize back to the input height and width
    (the ``padding="same"`` pool/upsample chain does not generally return the
    exact input size).

    Args:
        input_shape: ``(F, T, C)`` shape of one input example.

    Returns:
        An uncompiled ``keras.Model`` named ``"stft_autoencoder"``.
    """
    F, T, _ = input_shape

    enc_in = layers.Input(shape=input_shape, name="encoder_input")
    x = enc_in

    # Encoder
    for filters in (32, 64, 128):
        x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)
        x = layers.MaxPooling2D(2, padding="same")(x)

    x = layers.Conv2D(256, 3, activation="relu", padding="same")(x)
    encoded = layers.MaxPooling2D(2, padding="same", name="bottleneck")(x)

    # Decoder (mirror)
    x = encoded
    for filters in (256, 128, 64, 32):
        x = layers.Conv2DTranspose(
            filters, 3, strides=2, activation="relu", padding="same"
        )(x)

    # The output layers are pinned to float32: when a mixed-precision policy
    # is active, model outputs (and therefore the loss) should stay in full
    # precision for numerical stability.
    x = layers.Conv2D(
        1, 3, activation="linear", padding="same", dtype="float32"
    )(x)

    # Final resize back to the input size. A built-in Resizing layer replaces
    # the former Lambda so the saved model can be reloaded without having to
    # deserialize a Python closure.
    x = layers.Resizing(
        int(F),
        int(T),
        interpolation="bilinear",
        dtype="float32",
        name="resize_to_input",
    )(x)

    return keras.Model(enc_in, x, name="stft_autoencoder")


# TRAINING

def train_autoencoder(X_train, X_val, input_shape):
    """Build, compile and fit the STFT autoencoder.

    The model is trained to reproduce its own input (inputs are also the
    targets) with MAE loss and MSE as an extra metric. Training stops early
    after 20 epochs without ``val_loss`` improvement (best weights restored),
    and the learning rate is halved after 10 stagnant epochs.

    Args:
        X_train: Training spectrograms.
        X_val: Validation spectrograms.
        input_shape: Shape of a single example, passed to
            ``build_stft_autoencoder``.

    Returns:
        ``(autoencoder, history)``: the fitted model and the ``History``
        object returned by ``fit``.
    """
    # Start from a clean Keras state before building a new model.
    tf.keras.backend.clear_session()
    gc.collect()

    autoencoder = build_stft_autoencoder(input_shape)

    autoencoder.compile(
        optimizer=keras.optimizers.Adam(1e-3),
        loss="mae",          # MAE on normalized log-mag
        metrics=["mse"],
    )

    # summary() prints the table itself and returns None, so it must not be
    # wrapped in print() (that emitted a stray "None" line).
    autoencoder.summary()
    print("Total params:", autoencoder.count_params())

    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=20,
        restore_best_weights=True,
    )
    lr_schedule = keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=10,
        min_lr=1e-7,
    )

    history = autoencoder.fit(
        X_train, X_train,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_data=(X_val, X_val),
        callbacks=[early_stopping, lr_schedule],
        verbose=1,
    )
    return autoencoder, history

# reconstructing the audio from the stft and creating plots

def reconstruct_clip(autoencoder, spec_norm, phase, norm_params):
    """Run one clip through the autoencoder and turn the result into audio.

    Args:
        autoencoder: Trained model exposing ``predict``.
        spec_norm: Normalized log-magnitude spectrogram of one clip, including
            the trailing channel axis.
        phase: Phase array of the same clip; it is reused unchanged for the
            inverse STFT.
        norm_params: ``(mean, std)`` used to undo the normalization.

    Returns:
        ``(audio, pred_log_mag)``: the reconstructed waveform and the
        de-normalized predicted log-magnitude (unclipped, float32).
    """
    mean, std = norm_params

    batch = spec_norm[np.newaxis, ...]
    pred_norm = autoencoder.predict(batch, verbose=0)[0][..., 0]
    # The model may emit half-precision values; do the inverse transform in
    # float32 so expm1 neither loses precision nor overflows early.
    pred_norm = np.asarray(pred_norm, dtype=np.float32)

    pred_log_mag = pred_norm * std + mean

    # invert log1p. The network output is unconstrained, so the predicted
    # log-magnitude can dip below 0, which would give a negative (invalid)
    # magnitude; clamp at 0 before inverting.
    pred_mag = np.expm1(np.maximum(pred_log_mag, 0.0))

    audio = stft_to_audio(pred_mag, phase)
    return audio, pred_log_mag

def save_audio(audio, path, sr=SAMPLE_RATE):
    """Write ``audio`` to ``path`` as float32 samples and report the location.

    Args:
        audio: Array of samples; converted with ``astype(np.float32)``.
        path: Destination file passed to ``soundfile.write``.
        sr: Sample rate stored in the file.
    """
    samples = audio.astype(np.float32)
    sf.write(path, samples, sr)
    print("Saved:", path)

def plot_spectrograms(orig_log, recon_log):
    """Show the original and reconstructed log-magnitudes side by side.

    The figure is written to ``stft_comparison.png`` inside ``OUTPUT_DIR`` and
    then displayed.

    Args:
        orig_log: 2-D log-magnitude of the original clip.
        recon_log: 2-D log-magnitude produced by the autoencoder.
    """
    _, axes = plt.subplots(1, 2, figsize=(12, 4))

    panels = (
        (orig_log, "Original log-mag"),
        (recon_log, "Reconstructed log-mag"),
    )
    for axis, (spectrogram, title) in zip(axes, panels):
        image = axis.imshow(
            spectrogram, aspect="auto", origin="lower", cmap="magma"
        )
        axis.set_title(title)
        plt.colorbar(image, ax=axis)

    for axis in axes:
        axis.set_xlabel("Time")
        axis.set_ylabel("Freq bin")

    plt.tight_layout()
    target = os.path.join(OUTPUT_DIR, "stft_comparison.png")
    plt.savefig(target, dpi=120)
    plt.show()

# MAIN

def _plot_training_curves(history):
    """Plot train/val loss (MAE) and MSE per epoch; save as training.png."""
    curves = history.history

    plt.figure(figsize=(10, 4))

    plt.subplot(1, 2, 1)
    plt.plot(curves["loss"], label="Train")
    plt.plot(curves["val_loss"], label="Val")
    plt.xlabel("Epoch")
    plt.ylabel("MAE")
    plt.title("Loss")
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(curves["mse"], label="Train MSE")
    plt.plot(curves["val_mse"], label="Val MSE")
    plt.xlabel("Epoch")
    plt.ylabel("MSE")
    plt.title("MSE")
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "training.png"), dpi=120)
    plt.show()


def main():
    """Run the full pipeline: setup, load, split, train, evaluate, export.

    Steps, in order: prepare the runtime, load the STFT dataset, make an
    80/20 train/validation split (seeded), train and save the autoencoder,
    plot training curves, reconstruct the first validation clip (spectrogram
    figure, audio file, inline playback), and save the validation arrays for
    the later PCA analysis.

    Returns:
        ``(autoencoder, history, X_val, F_val)``: the trained model, its
        training history, the validation spectrograms and their file names.
    """
    print("="*70)
    print("STFT AUDIO AUTOENCODER WITH ORIGINAL PHASE (FIXED)")
    print("="*70)

    print("[1] Setup...")
    setup()

    print("\n[2] Loading dataset (STFT)...")
    specs_norm, phases, filenames, norm_params = load_dataset(WAV_FILES_DIR)
    print("Specs_norm shape:", specs_norm.shape)
    print("Phases shape:", phases.shape)

    print("\n[3] Train/val split...")
    # Spectrograms, phases and names are split together so indices stay aligned.
    X_train, X_val, P_train, P_val, F_train, F_val = train_test_split(
        specs_norm, phases, filenames, test_size=0.2, random_state=42
    )
    print("Train:", len(X_train), "Val:", len(X_val))

    # Drop the unsplit copies before training to reduce peak memory.
    del specs_norm, phases
    gc.collect()

    print("\n[4] Training autoencoder...")
    input_shape = X_train.shape[1:]
    autoencoder, history = train_autoencoder(X_train, X_val, input_shape)

    model_path = os.path.join(OUTPUT_DIR, "stft_autoencoder.keras")
    autoencoder.save(model_path)
    print("Model saved to:", model_path)

    print("\n[5] Plot training curves...")
    _plot_training_curves(history)

    print("\n[6] Reconstruct one validation clip...")
    idx = 0
    spec_norm = X_val[idx]
    phase = P_val[idx]
    fname = F_val[idx]

    audio_recon, logmag_recon = reconstruct_clip(
        autoencoder, spec_norm, phase, norm_params
    )

    # Undo the normalization on the input for a like-for-like comparison.
    mean, std = norm_params
    logmag_orig = spec_norm[..., 0] * std + mean

    plot_spectrograms(logmag_orig, logmag_recon)

    out_path = os.path.join(OUTPUT_DIR, "reconstructed_" + fname)
    save_audio(audio_recon, out_path)

    print("\n▶ Playing original and reconstructed...")
    from IPython.display import Audio, display
    orig_audio = load_and_preprocess_audio(os.path.join(WAV_FILES_DIR, fname))
    if orig_audio is not None:
        display(Audio(orig_audio, rate=SAMPLE_RATE))
    else:
        # The loader returns None on failure; a playback problem must not
        # abort the run before the validation data below is saved.
        print("Original clip could not be reloaded; skipping its playback:", fname)
    display(Audio(audio_recon, rate=SAMPLE_RATE))

    # SAVE VALIDATION DATA FOR PCA ANALYSIS
    print("\n[7] Saving validation data for PCA analysis...")
    np.save(os.path.join(OUTPUT_DIR, "X_val.npy"), X_val)
    np.save(os.path.join(OUTPUT_DIR, "F_val.npy"), F_val)
    print("✓ Validation data saved")

    print("\nDone.")
    print("Outputs:", OUTPUT_DIR)
    return autoencoder, history, X_val, F_val

# RUN

if __name__ == "__main__":
    # Keep main()'s results as module-level names: the PCA analysis code
    # further down reads autoencoder, X_val and F_val as globals.
    autoencoder, history, X_val, F_val = main()

The following is the code for PCA analysis

In [None]:
# PCA Analysis
# Projects the autoencoder's bottleneck activations for the validation clips
# onto principal components. Relies on the module-level names autoencoder,
# X_val and F_val.
from sklearn.decomposition import PCA

# Create encoder: the sub-model from the input up to the "bottleneck" layer
print("Creating encoder...")
encoder = keras.Model(
    inputs=autoencoder.input,
    outputs=autoencoder.get_layer('bottleneck').output
)

# Extract latent vectors (one flattened row per validation clip)
print("Extracting latent vectors...")
latent_vectors = encoder.predict(X_val, verbose=1)
latent_flat = latent_vectors.reshape(len(latent_vectors), -1)
print(f"Latent shape: {latent_flat.shape}")

# Do PCA; the component count cannot exceed the sample or feature count
print("\nRunning PCA...")
pca = PCA(n_components=min(50, len(latent_flat), latent_flat.shape[1]))
latent_pca = pca.fit_transform(latent_flat)
var_ratio = pca.explained_variance_ratio_
n_kept = len(var_ratio)

# Print results
print("\nVariance Explained:")
cumsum = np.cumsum(var_ratio)
for i, (ratio, running) in enumerate(zip(var_ratio[:10], cumsum[:10]), start=1):
    print(f"PC{i}: {ratio*100:.2f}% | Cumulative: {running*100:.2f}%")

for thresh in [0.90, 0.95, 0.99]:
    reached = cumsum >= thresh
    if reached.any():
        # argmax on a boolean array gives the first True index.
        n_comp = int(np.argmax(reached)) + 1
        print(f"{thresh*100:.0f}% variance needs {n_comp} components")
    else:
        # argmax would return 0 here and wrongly report "1 components".
        print(f"{thresh*100:.0f}% variance not reached with the {n_kept} "
              f"components kept (cumulative: {cumsum[-1]*100:.2f}%)")

# Plot
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

axes[0].bar(range(1, n_kept + 1), var_ratio)
axes[0].set_xlabel('Component')
axes[0].set_ylabel('Variance Explained')
axes[0].set_title('Scree Plot')
axes[0].grid(True)

axes[1].plot(range(1, len(cumsum)+1), cumsum, 'b-', linewidth=2)
axes[1].axhline(y=0.90, color='r', linestyle='--', label='90%')
axes[1].axhline(y=0.95, color='g', linestyle='--', label='95%')
axes[1].set_xlabel('Number of Components')
axes[1].set_ylabel('Cumulative Variance')
axes[1].set_title('Cumulative Variance')
axes[1].legend()
axes[1].grid(True)

if n_kept >= 2:
    axes[2].scatter(latent_pca[:, 0], latent_pca[:, 1], alpha=0.6, s=100)
    for i, fname in enumerate(F_val):
        axes[2].annotate(fname[:10], (latent_pca[i, 0], latent_pca[i, 1]),
                        fontsize=7, alpha=0.6)
    axes[2].set_xlabel(f'PC1 ({var_ratio[0]*100:.1f}%)')
    axes[2].set_ylabel(f'PC2 ({var_ratio[1]*100:.1f}%)')
else:
    # A 2-D scatter needs two components; with a single one there is no PC2.
    axes[2].text(0.5, 0.5, 'Fewer than 2 components available',
                 ha='center', va='center', transform=axes[2].transAxes)
axes[2].set_title('First 2 Principal Components')
axes[2].grid(True)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'pca_analysis.png'), dpi=150)
plt.show()

print("\n✅ Done!")