# In [None]:
import os, glob, json, librosa, numpy as np, pandas as pd, tensorflow as tf
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report
import tf2onnx          # ⇐ pip install tf2onnx
from tensorflow.keras import layers, models

import matplotlib.pyplot as plt
import random
from scipy.stats import ttest_ind
import shutil

# ---------------- PATHS ----------------
# The defaults are machine-specific absolute paths; set the environment
# variables below to run the notebook elsewhere without editing the source.
WAV_DIR   = os.environ.get(
    'CLICK_WAV_DIR',
    '/Users/shadowhusky/Desktop/CGVI/Term2/Thesis/SnapHandsDemo/SnapHandsDemo/output/files')
OUT_DIR   = os.environ.get(
    'CLICK_OUT_DIR',
    '/Users/shadowhusky/Desktop/CGVI/Term2/Thesis/SnapHandsDemo/SnapHandsDemo/Assets/Models')
CLICK_DIR = os.path.join(WAV_DIR, 'clicks')   # positive class (label 1)
NOISE_DIR = os.path.join(WAV_DIR, 'noise')    # negative class (label 0)

# ---------------- AUDIO PARAMS ----------------
SR          = 16_000           # 16 kHz is enough & faster
WIN_LEN     = 1024             # FFT window: 64 ms @ 16 kHz
HOP_LEN     = 256              # hop: 16 ms @ 16 kHz
N_MELS      = 64               # mel bands per frame
EXCERPT_SEC = 0.64             # excerpt length: 0.64 s * 16 kHz / 256 hop = 40 frames

def file_to_mels(file):
    """Load *file* and return its log-mel spectrogram with a fixed frame count.

    The audio is loaded mono at ``SR`` Hz and padded/trimmed to
    ``WIN_LEN + (n_frames - 1) * HOP_LEN`` samples. It is then turned into an
    ``N_MELS``-band mel power spectrogram (80 Hz - 8 kHz) and converted to dB
    relative to the spectrogram's own maximum (``ref=np.max``), so the loudest
    cell is 0 dB.

    ``n_frames`` is derived from ``EXCERPT_SEC``, ``SR`` and ``HOP_LEN`` (40 with
    the current settings), the same quantity the model's input width uses.
    Changing those constants therefore keeps features and model in agreement.

    Returns:
        np.ndarray of shape (N_MELS, n_frames).
    """
    n_frames = int(round(EXCERPT_SEC * SR / HOP_LEN))
    y, _ = librosa.load(file, sr=SR, mono=True)
    # Sample count giving exactly n_frames frames for an uncentred STFT.
    # melspectrogram's default center=True yields a few extra frames, which
    # the trim below removes.
    needed_len = WIN_LEN + (n_frames - 1) * HOP_LEN
    if len(y) < needed_len:
        y = np.pad(y, (0, needed_len - len(y)))
    else:
        y = y[:needed_len]
    S = librosa.feature.melspectrogram(
        y=y, sr=SR, n_fft=WIN_LEN, hop_length=HOP_LEN,
        n_mels=N_MELS, fmin=80, fmax=8_000)
    S_dB = librosa.power_to_db(S, ref=np.max)
    # Force exactly n_frames columns.
    if S_dB.shape[1] > n_frames:
        S_dB = S_dB[:, :n_frames]
    elif S_dB.shape[1] < n_frames:
        # Pad with the quietest value present, not 0: with ref=np.max, 0 dB is
        # the loudest level, so zero-padding would inject fake full-scale energy.
        pad_width = n_frames - S_dB.shape[1]
        S_dB = np.pad(S_dB, ((0, 0), (0, pad_width)), mode='constant',
                      constant_values=S_dB.min())
    return S_dB

def load_dataset():
    """Load every click and noise clip as a log-mel spectrogram.

    Clicks (label 1) come first, followed by noise (label 0). Within each
    folder, files are in sorted filename order.

    Returns:
        X: array (N, N_MELS, n_frames, 1), the ``file_to_mels`` output with a
           trailing channel axis added.
        y: float32 labels, shape (N,).
        names: list of file basenames aligned with X / y.

    Raises:
        ValueError: if neither folder contains any ``.wav`` files.
    """
    X, y, names = [], [], []
    for lbl, folder in [(1, CLICK_DIR), (0, NOISE_DIR)]:
        for f in sorted(glob.glob(os.path.join(folder, '*.wav'))):
            X.append(file_to_mels(f))
            y.append(lbl)
            names.append(os.path.basename(f))
    if not X:
        # np.stack([]) would otherwise fail with an unhelpful message.
        raise ValueError(
            f"No .wav files found in {CLICK_DIR!r} or {NOISE_DIR!r}")
    X = np.stack(X)[..., np.newaxis]               # (N, N_MELS, n_frames, 1)
    y = np.asarray(y, dtype='float32')
    return X, y, names

X, y, names = load_dataset()
print("Dataset:", X.shape, y.shape)

# ---------------- DESCRIPTIVE STATS (keeps your old printout) ----------------
# Quick per-file proxies computed on the dB mel spectrograms, for gating insight.
# NOTE: despite the printed labels, "flatness" is the variance of all dB values
# in a file's spectrogram (not spectral flatness). "centroid" is a mel-bin-index
# centroid weighted by the dB values themselves (<= 0, since file_to_mels uses
# ref=np.max). Treat both as rough proxies only.
mel_db = X[..., 0]          # drop only the channel axis; squeeze() would also drop N when N == 1
flatness = mel_db.var(axis=(1, 2))
with np.errstate(divide='ignore', invalid='ignore'):
    # A spectrogram whose dB values sum to 0 (e.g. all zeros) gives NaN here.
    centroid = ((mel_db * np.arange(N_MELS)[:, None]).sum(axis=(1, 2))
                / mel_db.sum(axis=(1, 2)))
for lbl in [1, 0]:
    sub = (y == lbl)
    print("\nStats for", "CLICK" if lbl else "NOISE")
    if not sub.any():
        print("  (no samples)")
        continue
    for feat, name in [(flatness, 'flatness'), (centroid, 'centroid')]:
        # NaN-aware, so one degenerate clip does not turn the whole summary into NaN.
        mu, sig = np.nanmean(feat[sub]), np.nanstd(feat[sub])
        print(f"{name:9s}: mean={mu:7.3f}  std={sig:6.3f}  gate≈{mu-sig:.3f}")

# ---------------- MODEL ----------------
tf.keras.backend.clear_session()

# Input width in STFT frames. round() rather than int() so floating-point error
# in EXCERPT_SEC * SR / HOP_LEN can never truncate e.g. 39.999... down to 39.
input_frames = int(round(EXCERPT_SEC * SR / HOP_LEN))
if X.shape[1:] != (N_MELS, input_frames, 1):
    # Fail early with a readable message rather than a shape error inside fit().
    raise ValueError(
        f"Feature shape {X.shape[1:]} does not match model input "
        f"{(N_MELS, input_frames, 1)}; check EXCERPT_SEC/SR/HOP_LEN "
        f"against the frame count produced by file_to_mels().")

# Two conv blocks (3x3 conv + ReLU -> batch-norm -> 2x2 max-pool), then a dense
# head with dropout and a single sigmoid unit = probability of label 1 ("click").
inputs = layers.Input(shape=(N_MELS, input_frames, 1))
x = layers.Conv2D(16, (3, 3), padding='same', activation='relu')(inputs)
x = layers.BatchNormalization()(x)
x = layers.MaxPool2D((2, 2))(x)

x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(x)
x = layers.BatchNormalization()(x)
x = layers.MaxPool2D((2, 2))(x)

x = layers.Flatten()(x)
x = layers.Dense(64, activation='relu')(x)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(1, activation='sigmoid')(x)

model = models.Model(inputs, outputs)
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

model.summary()


# ---------------- TRAIN ----------------
# Keras' validation_split takes the LAST 20% of the arrays *before* any
# shuffling. load_dataset() returns all clicks first and all noise last.
# Without this permutation the validation set (and hence early stopping) would
# be dominated by, or consist entirely of, noise samples.
# The fixed seed keeps the split reproducible between runs.
SPLIT_SEED = 42
perm = np.random.default_rng(SPLIT_SEED).permutation(len(X))

HISTORY = model.fit(
    X[perm], y[perm],
    epochs=40,
    batch_size=32,
    validation_split=0.2,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)
    ],
    verbose=2)

print("\nDetailed metrics on train+val:")
# NOTE: this scores every sample, including the ~80% the model was fitted on,
# so it overstates generalisation. ravel() turns the (N, 1) output into (N,).
print(classification_report(y, (model.predict(X) > 0.5).astype(int).ravel()))

# ---------------- EXPORT ----------------
os.makedirs(OUT_DIR, exist_ok=True)
tflite_path = os.path.join(OUT_DIR, "click_cnn.tflite")
onnx_path = os.path.join(OUT_DIR, "click_cnn.onnx")

# 1) TensorFlow Lite flatbuffer.
#    Optimize.DEFAULT without a representative dataset performs post-training
#    dynamic-range quantisation (8-bit weights, float activations).
#    NOTE(review): the Unity side appears to consume the ONNX/.bytes export
#    written below; confirm whether this .tflite file is used anywhere.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with open(tflite_path, "wb") as fp:
    fp.write(tflite_model)
print("✔ saved", tflite_path)

# 2) ONNX (opset 13); tf2onnx writes the file itself via output_path.
model_proto, _ = tf2onnx.convert.from_keras(
    model,
    output_path=onnx_path,
    opset=13,
)
print("✔ saved", onnx_path)

# 3) Feature / normalisation metadata for the runtime C# wrapper.
#    NOTE: `mean` / `std` are single GLOBAL scalars over every dB value in X,
#    not per mel bin. The model above was also trained on the raw dB features
#    without any normalisation, so a runtime that applies (x - mean) / std
#    before inference would feed the model inputs it never saw during
#    training. TODO confirm how the wrapper uses these two fields.
#    `win_len` (STFT n_fft) and `n_frames` are recorded so the runtime can
#    rebuild the exact framing used in training; extra keys are additive.
meta_path = os.path.join(OUT_DIR, "click_cnn_meta.json")
meta = dict(mean=float(X.mean()), std=float(X.std()),
            n_mels=N_MELS, hop=HOP_LEN, sr=SR,
            excerpt_sec=EXCERPT_SEC,
            win_len=WIN_LEN, n_frames=int(X.shape[2]))
with open(meta_path, "w", encoding="utf-8") as f:
    json.dump(meta, f, indent=2)
print("✔ saved", meta_path)

# Export ONNX as .bytes for Unity TextAsset: a byte-for-byte copy of the ONNX
# file under the .bytes extension.
bytes_path = os.path.join(OUT_DIR, "click_cnn.bytes")
shutil.copy(onnx_path, bytes_path)
print("✔ saved", bytes_path)

# ═══════════════════════════════════════════════════════════════════════════════
# COMPREHENSIVE FREQUENCY SPECTRUM ANALYSIS FOR CLICK DETECTION
# ═══════════════════════════════════════════════════════════════════════════════


# Section banner: a rule line, the title, then another rule line.
banner = "=" * 80
for banner_line in (banner, "FREQUENCY SPECTRUM ANALYSIS FOR CLICK DETECTION", banner):
    print(banner_line)

# Extract frequency spectra from raw audio files for better analysis
def extract_spectrum(file_path, sr=48000, n_fft=1024):
    """Return the Hann-windowed magnitude spectrum of the first *n_fft* samples.

    The file is loaded mono by librosa at *sr* Hz (default 48 kHz; librosa
    resamples if the file's native rate differs).

    Audio shorter than *n_fft* samples is zero-padded *before* windowing.
    Previously, short files raised a broadcasting error because the truncated
    signal was multiplied by the full-length window.

    Args:
        file_path: path to an audio file readable by librosa.
        sr: load sample rate in Hz.
        n_fft: number of samples analysed (FFT length).

    Returns:
        (mag, freqs): rfft magnitudes, shape (n_fft // 2 + 1,), and the
        matching bin frequencies in Hz.
    """
    y, _ = librosa.load(file_path, sr=sr, mono=True)
    n = min(len(y), n_fft)
    blk = np.zeros(n_fft)
    blk[:n] = y[:n]                 # zero-pad short clips to n_fft
    blk *= np.hanning(n_fft)        # then window the full block
    mag = np.abs(np.fft.rfft(blk))
    freqs = np.fft.rfftfreq(n_fft, 1 / sr)
    return mag, freqs

# Collect spectra for all files (sorted so the order is deterministic)
click_files = sorted(glob.glob(os.path.join(CLICK_DIR, '*.wav')))
noise_files = sorted(glob.glob(os.path.join(NOISE_DIR, '*.wav')))
if not click_files:
    # Every panel below needs click spectra and the frequency axis.
    raise FileNotFoundError(f"No click .wav files found in {CLICK_DIR}")

print(f"Analyzing {len(click_files)} click files and {len(noise_files)} noise files...")

click_results = [extract_spectrum(f) for f in click_files]
noise_results = [extract_spectrum(f) for f in noise_files]
# extract_spectrum uses fixed sr/n_fft, so every file shares one frequency axis.
freqs = click_results[0][1]
click_spectra = np.array([mag for mag, _ in click_results])
noise_spectra = np.array([mag for mag, _ in noise_results])

# Calculate average spectra (zeros stand in for noise when there is none)
avg_click = np.mean(click_spectra, axis=0)
std_click = np.std(click_spectra, axis=0)
avg_noise = np.mean(noise_spectra, axis=0) if len(noise_spectra) > 0 else np.zeros_like(avg_click)
std_noise = np.std(noise_spectra, axis=0) if len(noise_spectra) > 0 else np.zeros_like(avg_click)

# Comprehensive frequency analysis laid out on a 3x3 grid of panels
plt.figure(figsize=(20, 15))

# 1. Mean magnitude spectrum per class, with a ±1 std band
ax_avg = plt.subplot(3, 3, 1)
ax_avg.semilogy(freqs, avg_click, 'b-',
                label=f'Clicks (n={len(click_files)})', linewidth=2)
ax_avg.fill_between(freqs, avg_click - std_click, avg_click + std_click,
                    alpha=0.3, color='blue')
if len(noise_spectra) > 0:
    ax_avg.semilogy(freqs, avg_noise, 'r-',
                    label=f'Noise (n={len(noise_files)})', linewidth=2)
    ax_avg.fill_between(freqs, avg_noise - std_noise, avg_noise + std_noise,
                        alpha=0.3, color='red')
ax_avg.set_xlabel('Frequency (Hz)')
ax_avg.set_ylabel('Magnitude')
ax_avg.set_title('Average Frequency Spectra')
ax_avg.legend()
ax_avg.grid(True, alpha=0.3)
ax_avg.set_xlim(0, 8000)

# 2. Energy distribution by frequency bands
plt.subplot(3, 3, 2)
bands = [(0, 600), (600, 1000), (1000, 2000), (2000, 4000), (4000, 8000)]
band_names = ['<600Hz', '600-1k', '1-2k', '2-4k', '4-8k']

click_band_energies = []
noise_band_energies = []

for band_idx, (low, high) in enumerate(bands):
    # Half-open [low, high) so a bin lying exactly on a shared edge is counted
    # in only one band; the last band keeps its upper edge inclusive.
    upper_ok = (freqs <= high) if band_idx == len(bands) - 1 else (freqs < high)
    band_mask = (freqs >= low) & upper_ok
    # Per-file band energy (sum of squared magnitudes), averaged over files.
    click_band_energies.append(np.mean(np.sum(click_spectra[:, band_mask] ** 2, axis=1)))
    if len(noise_spectra) > 0:
        noise_band_energies.append(np.mean(np.sum(noise_spectra[:, band_mask] ** 2, axis=1)))

x = np.arange(len(band_names))
width = 0.35
plt.bar(x - width/2, click_band_energies, width, label='Clicks', alpha=0.8)
if len(noise_band_energies) > 0:
    plt.bar(x + width/2, noise_band_energies, width, label='Noise', alpha=0.8)
plt.xlabel('Frequency Band')
plt.ylabel('Average Energy')
plt.title('Energy Distribution by Band')
plt.xticks(x, band_names)
plt.legend()
plt.yscale('log')

# 3. Overlay of up to 10 randomly chosen individual click spectra
plt.subplot(3, 3, 3)
n_pick = min(10, len(click_spectra))
sample_indices = random.sample(range(len(click_spectra)), n_pick)
for idx in sample_indices:
    plt.semilogy(freqs, click_spectra[idx], alpha=0.6, linewidth=1)
plt.title('Individual Click Spectra (sample)')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Magnitude')
plt.grid(True, alpha=0.3)
plt.xlim(0, 8000)

# 4. Statistical significance of frequency differences
plt.subplot(3, 3, 4)
# A two-sample t-test needs at least two observations in each group.
if len(noise_spectra) > 1 and len(click_spectra) > 1:
    # One independent two-sample t-test per frequency bin, computed for all
    # bins at once (axis=0 runs across files).
    # NOTE: no multiple-comparison correction is applied. Under the null,
    # roughly 5% of bins will still fall below p = 0.05 by chance.
    _, p_values = ttest_ind(click_spectra, noise_spectra, axis=0)

    significant_mask = p_values < 0.05
    plt.semilogy(freqs, p_values, 'k-', alpha=0.7)
    plt.axhline(y=0.05, color='r', linestyle='--', label='p=0.05')
    plt.fill_between(freqs, 0, 1, where=significant_mask, alpha=0.3, color='green',
                     label='Significant difference')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('p-value')
    plt.title('Statistical Significance by Frequency')
    plt.legend()
    plt.xlim(0, 8000)
    plt.ylim(1e-10, 1)

# 5. Spectral-feature histograms; panels 5-8 are filled in once the feature
#    helper below has been defined.
plt.subplot(335)
# Calculate traditional spectral features
def calc_features(spectra, freqs):
    """Compute four hand-crafted features for each magnitude spectrum.

    Args:
        spectra: sequence/array of magnitude spectra, shape (n_files, n_bins).
        freqs: bin frequencies in Hz, shape (n_bins,).

    Returns:
        np.ndarray of shape (n_files, 4) with columns
        [hi/lo ratio (log10), centroid (Hz), flatness, RMS]:

        * hi/lo ratio: log10 of energy in 1-8 kHz over energy at <= 600 Hz.
        * centroid: magnitude-weighted mean frequency.
        * flatness: geometric mean over arithmetic mean of the magnitudes.
        * RMS: root-mean-square of the magnitudes.

        An all-zero spectrum yields 0 for centroid and flatness instead of
        NaN/inf. An empty input yields an array of shape (0, 4).
    """
    spectra = np.asarray(spectra, dtype=float)
    freqs = np.asarray(freqs, dtype=float)
    if spectra.size == 0:
        return np.empty((0, 4))
    spectra = np.atleast_2d(spectra)

    # Band masks depend only on freqs, so build them once, not per spectrum.
    hi_mask = (freqs >= 1000) & (freqs <= 8000)
    lo_mask = freqs <= 600
    power = spectra ** 2

    # Hi/Lo ratio (epsilons keep log10 finite when a band is silent)
    hi_energy = power[:, hi_mask].sum(axis=1)
    lo_energy = power[:, lo_mask].sum(axis=1) + 1e-9
    ratio = np.log10(hi_energy / lo_energy + 1e-9)

    # Spectral centroid; 0 where the spectrum has no magnitude at all
    total = spectra.sum(axis=1)
    centroid = np.divide((spectra * freqs).sum(axis=1), total,
                         out=np.zeros_like(total), where=total != 0)

    # Spectral flatness; 0 where the arithmetic mean is zero
    geometric_mean = np.exp(np.mean(np.log(spectra + 1e-6), axis=1))
    arithmetic_mean = spectra.mean(axis=1)
    flatness = np.divide(geometric_mean, arithmetic_mean,
                         out=np.zeros_like(arithmetic_mean),
                         where=arithmetic_mean != 0)

    # RMS
    rms = np.sqrt(power.mean(axis=1))

    return np.column_stack([ratio, centroid, flatness, rms])

click_features = calc_features(click_spectra, freqs)
if len(noise_spectra) > 0:
    noise_features = calc_features(noise_spectra, freqs)

    feature_names = ['Hi/Lo Ratio (log10)', 'Centroid (Hz)', 'Flatness', 'RMS']

    for i, name in enumerate(feature_names):
        plt.subplot(3, 3, 5+i)
        click_col = click_features[:, i]
        noise_col = noise_features[:, i]
        # Use one set of bin edges for both classes. With independent automatic
        # bins, the two density histograms have different bin widths and ranges
        # and cannot be compared bar-for-bar. Non-finite values are excluded
        # when choosing the edges.
        pooled = np.concatenate([click_col, noise_col])
        pooled = pooled[np.isfinite(pooled)]
        bins = np.histogram_bin_edges(pooled, bins=20) if pooled.size else 20
        plt.hist(click_col, bins=bins, alpha=0.7, label='Clicks', density=True)
        plt.hist(noise_col, bins=bins, alpha=0.7, label='Noise', density=True)
        plt.xlabel(name)
        plt.ylabel('Density')
        plt.title(f'{name} Distribution')
        plt.legend()
        plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()