In [1]:
import os
import json
import glob
import numpy as np
import soundfile as sf
import librosa
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
from joblib import dump, load

# =========================
# CONFIG - FIXED VERSION
# =========================
# Locations of the train/test splits and of the folder that receives the
# trained model and its JSON config. The model folder is created on import.
DATASET_TRAIN_DIR = r"C:\Users\Jaiganesh\SoundGaurd\data\processed_data\train"
DATASET_TEST_DIR = r"C:\Users\Jaiganesh\SoundGaurd\data\processed_data\test"
MODEL_DIR = r"C:\Users\Jaiganesh\SoundGaurd\models_stage2"
os.makedirs(MODEL_DIR, exist_ok=True)

# Feature extraction parameters (sample rate, MFCC count, STFT framing).
SR, N_MFCC = 22050, 40
N_FFT, WIN_LENGTH, HOP_LENGTH = 1024, 512, 256
USE_DELTAS = True  # append first- and second-order deltas to the MFCCs
POOL_STATS = ["mean", "std"]  # per-coefficient statistics pooled over time

# 4-class labels: a class's integer id is its position in CLASS_NAMES.
CLASS_NAMES = ["non_threat", "glass_break", "scream", "gunshot"]
CLASS_MAPPING = {name: index for index, name in enumerate(CLASS_NAMES)}

# =========================
# UTILITIES - FIXED VERSION
# =========================
def _collect_wav_files(directory):
    """Return the sorted, de-duplicated ``*.wav``/``*.WAV`` paths directly inside *directory*."""
    found = set()
    # Both patterns can match the same file where glob matching is
    # case-insensitive, so the set removes the double count.
    for ext in ("*.wav", "*.WAV"):
        found.update(glob.glob(os.path.join(directory, ext)))
    # Sorting makes the order independent of string-hash randomisation, so
    # any seeded split performed on the result is reproducible between runs.
    return sorted(found)


def list_audio_files_multiclass(root_dir):
    """List all audio files under *root_dir* together with their 4-class label.

    Expected layout::

        root_dir/non_threat/*.wav            -> label 0
        root_dir/threat/glass_break/*.wav    -> label 1
        root_dir/threat/scream/*.wav         -> label 2
        root_dir/threat/gunshot/*.wav        -> label 3

    Missing directories are reported on stdout and skipped.

    Args:
        root_dir: Directory of one dataset split.

    Returns:
        A list of ``(path, label)`` tuples: non-threat files first, then the
        threat classes in label order, each group sorted by path. The list is
        empty when *root_dir* does not exist or holds no matching files.
    """
    wavs = []

    if not os.path.isdir(root_dir):
        print(f"[ERROR] Root directory doesn't exist: {root_dir}")
        return wavs

    # Non-threat files (label = 0)
    non_threat_path = os.path.join(root_dir, "non_threat")
    if os.path.isdir(non_threat_path):
        non_threat_files = _collect_wav_files(non_threat_path)
        wavs.extend((p, 0) for p in non_threat_files)
        print(f"[INFO] Found {len(non_threat_files)} non-threat files")
    else:
        print(f"[WARN] Non-threat directory missing: {non_threat_path}")

    # Threat files (labels = 1, 2, 3), one sub-folder per class
    threat_parent_path = os.path.join(root_dir, "threat")
    if os.path.isdir(threat_parent_path):
        for subfolder, label in [("glass_break", 1), ("scream", 2), ("gunshot", 3)]:
            subfolder_path = os.path.join(threat_parent_path, subfolder)
            if os.path.isdir(subfolder_path):
                threat_files = _collect_wav_files(subfolder_path)
                wavs.extend((p, label) for p in threat_files)
                print(f"[INFO] Found {len(threat_files)} {subfolder} files")
            else:
                print(f"[WARN] {subfolder} directory missing: {subfolder_path}")
    else:
        print(f"[WARN] Threat parent directory missing: {threat_parent_path}")

    # Final summary
    total_threats = sum(1 for _, label in wavs if label in (1, 2, 3))
    total_non_threats = sum(1 for _, label in wavs if label == 0)
    print(f"[INFO] Summary - Threats: {total_threats}, Non-threats: {total_non_threats}")
    print(f"[INFO] Total files: {len(wavs)}")

    return wavs

def extract_mfcc_features(wav_path):
    """Turn one audio file into a fixed-length pooled MFCC vector.

    The file is read with soundfile, averaged over axis 1 when it has more
    than one dimension, resampled to ``SR`` if needed, and converted to
    ``N_MFCC`` MFCCs. When ``USE_DELTAS`` is set, first- and second-order
    deltas are stacked below the MFCCs. Each statistic named in
    ``POOL_STATS`` ("mean" / "std"; other names are ignored) is then taken
    across axis 1 and the results are concatenated.

    Args:
        wav_path: Path of the audio file to read.

    Returns:
        A 1-D ``float32`` array, or ``None`` if any step raised (the error is
        printed, not propagated).
    """
    try:
        signal, native_sr = sf.read(wav_path)

        # Multi-dimensional input is reduced to 1-D by averaging over axis 1.
        if signal.ndim > 1:
            signal = np.mean(signal, axis=1)

        if native_sr != SR:
            signal = librosa.resample(signal, orig_sr=native_sr, target_sr=SR)

        mfcc = librosa.feature.mfcc(
            y=signal,
            sr=SR,
            n_mfcc=N_MFCC,
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            win_length=WIN_LENGTH,
        )

        blocks = [mfcc]
        if USE_DELTAS:
            blocks += [librosa.feature.delta(mfcc, order=k) for k in (1, 2)]
        stacked = np.vstack(blocks)

        # Map each supported statistic name to its reducer; unknown names in
        # POOL_STATS contribute nothing.
        reducers = {"mean": np.mean, "std": np.std}
        pooled = [
            reducers[stat](stacked, axis=1)
            for stat in POOL_STATS
            if stat in reducers
        ]

        return np.concatenate(pooled, axis=0).astype(np.float32)

    except Exception as err:
        print(f"[ERROR] Failed to process {wav_path}: {err}")
        return None

def build_multiclass_dataset(root_dir):
    """Extract features for every labelled audio file under *root_dir*.

    Files whose feature extraction returns ``None`` or a vector containing
    non-finite values are counted as failures and left out.

    Args:
        root_dir: Directory of one dataset split.

    Returns:
        ``(X, y, paths)``: the stacked feature matrix, an ``int64`` label
        array, and the list of source paths, all in matching order.

    Raises:
        RuntimeError: If no audio files are found, or none yields usable
            features.
    """
    items = list_audio_files_multiclass(root_dir)
    if not items:
        raise RuntimeError(f"No audio files found in {root_dir}")

    n_items = len(items)
    vectors, labels, paths = [], [], []
    failed_count = 0

    print(f"[INFO] Processing {n_items} files...")
    for position, (path, label) in enumerate(items, start=1):
        # Progress indicator every 500 files
        if position % 500 == 0:
            print(f"[INFO] Processed {position}/{n_items} files...")

        features = extract_mfcc_features(path)
        if features is None or not np.all(np.isfinite(features)):
            failed_count += 1
            continue

        vectors.append(features)
        labels.append(label)
        paths.append(path)

    if failed_count > 0:
        print(f"[WARN] Failed to process {failed_count} files")

    if not vectors:
        raise RuntimeError(f"No valid features extracted from {n_items} files")

    X = np.vstack(vectors)
    y = np.array(labels, dtype=np.int64)

    print(f"[INFO] Successfully processed {len(X)} files")
    print(f"[INFO] Feature shape: {X.shape}")

    # Per-class sample counts
    for class_id, class_name in enumerate(CLASS_NAMES):
        count = np.sum(y == class_id)
        print(f"[INFO] Class {class_id} ({class_name}): {count} samples")

    return X, y, paths

def evaluate_multiclass(y_true, y_pred, title=""):
    """Print and return multi-class evaluation metrics.

    All metrics are computed over the full label set ``0..len(CLASS_NAMES)-1``
    so that the confusion matrix always has one row and column per entry of
    ``CLASS_NAMES`` — even when a class is absent from both ``y_true`` and
    ``y_pred`` (for example on a small validation split). Without the
    explicit label list the matrix would shrink and the table printing below
    would raise ``IndexError``.

    Args:
        y_true: Ground-truth integer class labels.
        y_pred: Predicted integer class labels, same length as ``y_true``.
        title: Heading printed above the report.

    Returns:
        Dict with keys ``"accuracy"``, ``"precision"``, ``"recall"`` and
        ``"f1"``; the last three are macro averages.
    """
    labels = list(range(len(CLASS_NAMES)))

    acc = accuracy_score(y_true, y_pred)
    # zero_division=0 yields the same values as the default but without the
    # UndefinedMetricWarning when a class has no predicted or true samples.
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, average="macro", zero_division=0
    )
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    print(f"\n=== {title} ===")
    print(f"Overall Accuracy: {acc:.4f}")
    print(f"Macro Precision: {prec:.4f}")
    print(f"Macro Recall: {rec:.4f}")
    print(f"Macro F1-Score: {f1:.4f}")

    print("\nConfusion Matrix:")
    print("Rows=True, Cols=Predicted")
    print("         ", end="")
    for name in CLASS_NAMES:
        print(f"{name:>12}", end="")
    print()

    for i, true_class in enumerate(CLASS_NAMES):
        print(f"{true_class:>12}: ", end="")
        for j in labels:
            print(f"{cm[i, j]:>8}", end="    ")
        print()

    print(f"\nDetailed Classification Report:")
    print(
        classification_report(
            y_true,
            y_pred,
            labels=labels,
            target_names=CLASS_NAMES,
            zero_division=0,
        )
    )

    return {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1}

# =========================
# MAIN TRAINING
# =========================
# Script entry point: report the dataset layout, build the training set,
# compare four classifiers on a held-out validation split, refit the best one
# on all training data, evaluate it on the test split, and save the model
# together with a JSON config describing it.
if __name__ == "__main__":
    print("🎯 STAGE 2: MULTI-CLASS THREAT CLASSIFICATION (FIXED VERSION)")
    print("=" * 75)
    
    # Check folder structure first
    # (informational only: counts files per class folder without reading them)
    print("📁 DATASET STRUCTURE CHECK:")
    total_train_files = 0
    total_test_files = 0
    
    for split, split_dir in [("TRAIN", DATASET_TRAIN_DIR), ("TEST", DATASET_TEST_DIR)]:
        print(f"\n{split} Directory: {split_dir}")
        split_total = 0
        
        if os.path.isdir(split_dir):
            # Check non-threat
            non_threat_path = os.path.join(split_dir, "non_threat")
            if os.path.isdir(non_threat_path):
                # Count unique files (both .wav and .WAV)
                # A set is used so a file matched by both patterns counts once.
                wav_files = set(glob.glob(os.path.join(non_threat_path, "*.wav")))
                wav_files.update(glob.glob(os.path.join(non_threat_path, "*.WAV")))
                count = len(wav_files)
                print(f"  ✅ non_threat: {count} files")
                split_total += count
            else:
                print(f"  ❌ non_threat: MISSING")
            
            # Check threat subfolders
            threat_path = os.path.join(split_dir, "threat")
            if os.path.isdir(threat_path):
                for subfolder in ["glass_break", "scream", "gunshot"]:
                    subfolder_path = os.path.join(threat_path, subfolder)
                    if os.path.isdir(subfolder_path):
                        # Count unique files (both .wav and .WAV)
                        wav_files = set(glob.glob(os.path.join(subfolder_path, "*.wav")))
                        wav_files.update(glob.glob(os.path.join(subfolder_path, "*.WAV")))
                        count = len(wav_files)
                        print(f"  ✅ threat/{subfolder}: {count} files")
                        split_total += count
                    else:
                        print(f"  ❌ threat/{subfolder}: MISSING")
            else:
                print(f"  ❌ threat: MISSING")
        else:
            print(f"  ❌ Directory doesn't exist!")
        
        # Printed for every split, including ones whose directory is missing
        # (the total is then 0).
        print(f"  📊 {split} Total: {split_total} files")
        if split == "TRAIN":
            total_train_files = split_total
        else:
            total_test_files = split_total
    
    print(f"\n📊 DATASET SUMMARY:")
    print(f"   Train: {total_train_files} files")
    print(f"   Test: {total_test_files} files")
    print(f"   Grand Total: {total_train_files + total_test_files} files")
    
    print("\n" + "=" * 75)
    print("🔊 BUILDING 4-CLASS TRAINING DATASET...")
    # Raises RuntimeError if the training split yields no usable files.
    X_train_full, y_train_full, train_paths = build_multiclass_dataset(DATASET_TRAIN_DIR)
    
    # Validation split
    # 15% of the training data is held out, stratified by class label, and is
    # used only for choosing between the candidate models below.
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_full, y_train_full, 
        test_size=0.15, 
        random_state=42, 
        stratify=y_train_full
    )
    
    print("\n" + "=" * 75)
    print("🤖 TRAINING MULTI-CLASS MODELS...")
    
    # Models for multi-class classification
    # Every candidate is a Pipeline with steps named "scaler" and "clf";
    # predict_threat_type relies on those step names.
    models = {
        "LogisticRegression": Pipeline([
            ("scaler", StandardScaler()),
            ("clf", LogisticRegression(max_iter=2000, C=1.0, random_state=42))
        ]),
        
        "RandomForest": Pipeline([
            ("scaler", StandardScaler()),
            ("clf", RandomForestClassifier(n_estimators=100, random_state=42))
        ]),
        
        "MLP": Pipeline([
            ("scaler", StandardScaler()),
            ("clf", MLPClassifier(hidden_layer_sizes=(256, 128), max_iter=1000, random_state=42))
        ]),
        
        "LinearSVM": Pipeline([
            ("scaler", StandardScaler()),
            ("clf", LinearSVC(C=1.0, random_state=42, max_iter=3000))
        ])
    }
    
    # Train and validate all models
    # results maps model name -> metrics dict returned by evaluate_multiclass.
    results = {}
    
    for name, model in models.items():
        print(f"\n🔄 Training {name}...")
        model.fit(X_train, y_train)
        y_pred = model.predict(X_val)
        results[name] = evaluate_multiclass(y_val, y_pred, f"Validation - {name}")
    
    # Select best model
    # Chosen by the highest validation "f1" entry (macro F1).
    best_name = max(results, key=lambda x: results[x]["f1"])
    best_model = models[best_name]
    
    print(f"\n🏆 BEST MODEL: {best_name}")
    print(f"   Validation Macro F1: {results[best_name]['f1']:.4f}")
    
    # Retrain on full training set
    # Refits the same pipeline object on train + validation data.
    print(f"\n🔄 Retraining {best_name} on full training set...")
    best_model.fit(X_train_full, y_train_full)
    
    # Final test evaluation
    print("\n" + "=" * 75)
    print("🧪 FINAL 4-CLASS TEST EVALUATION...")
    X_test, y_test, test_paths = build_multiclass_dataset(DATASET_TEST_DIR)
    
    y_test_pred = best_model.predict(X_test)
    final_results = evaluate_multiclass(y_test, y_test_pred, "🎯 FINAL 4-CLASS RESULTS")
    
    # Save model and config
    # The file name embeds the lower-cased model name.
    model_filename = f"stage2_multiclass_{best_name.lower()}.joblib"
    model_path = os.path.join(MODEL_DIR, model_filename)
    dump(best_model, model_path)
    
    # The config records the feature parameters and class names used for
    # training; predict_threat_type reads "class_names" from it.
    config = {
        "model_type": best_name,
        "model_path": model_path,
        "feature_params": {
            "sr": SR,
            "n_mfcc": N_MFCC,
            "n_fft": N_FFT,
            "win_length": WIN_LENGTH,
            "hop_length": HOP_LENGTH,
            "use_deltas": USE_DELTAS,
            "pool_stats": POOL_STATS
        },
        "results": final_results,
        "class_names": CLASS_NAMES,
        "class_mapping": CLASS_MAPPING,
        "dataset_info": {
            # Counts of successfully processed files, not files on disk.
            "train_files": len(X_train_full),
            "test_files": len(X_test),
            "total_files": len(X_train_full) + len(X_test)
        }
    }
    
    config_path = os.path.join(MODEL_DIR, "stage2_config.json")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    
    print("\n" + "=" * 75)
    print("✅ STAGE 2 COMPLETE!")
    print(f"📁 Model saved: {model_path}")
    print(f"📁 Config saved: {config_path}")
    print(f"🎯 4-Class Accuracy: {final_results['accuracy']:.1%}")
    print(f"📊 Dataset Size: {len(X_train_full)} train + {len(X_test)} test = {len(X_train_full) + len(X_test)} total files")
    print(f"🎯 Your SoundGuard can now identify specific threat types!")

# Inference function for specific threat identification
def predict_threat_type(wav_path, model_path, config_path):
    """Predict the specific threat type for a single WAV file.

    Args:
        wav_path: Path of the audio file to classify.
        model_path: Path of a model saved with ``joblib.dump``.
        config_path: Path of the JSON config written next to the model; its
            ``"class_names"`` list maps predicted label ids to names.

    Returns:
        ``(predicted_class, probabilities)``. ``probabilities`` is the
        per-class probability array when the model exposes ``predict_proba``
        and ``None`` otherwise. ``(None, None)`` is returned when feature
        extraction fails or produces non-finite values.
    """
    # SECURITY NOTE: joblib.load unpickles the file and can therefore execute
    # arbitrary code — only load model files from a trusted source.
    model = load(model_path)
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    features = extract_mfcc_features(wav_path)
    # Reject non-finite vectors just as dataset building does; otherwise the
    # model would raise on NaN/inf input.
    if features is None or not np.all(np.isfinite(features)):
        return None, None

    features = features.reshape(1, -1)
    prediction = model.predict(features)[0]

    # Ask the loaded model itself for probabilities: a Pipeline forwards the
    # call through its transformers to the final estimator, and a plain
    # estimator works too. Models without predict_proba (e.g. LinearSVC)
    # simply yield None.
    probabilities = None
    if hasattr(model, "predict_proba"):
        try:
            probabilities = model.predict_proba(features)[0]
        except (AttributeError, ValueError) as err:
            # Probabilities are optional; report the problem instead of
            # hiding it, and still return the class prediction.
            print(f"[WARN] Could not compute class probabilities: {err}")

    predicted_class = config["class_names"][int(prediction)]
    return predicted_class, probabilities

# Example usage:
# threat_type, probs = predict_threat_type("test.wav", "models_stage2/stage2_multiclass_logisticregression.joblib", "models_stage2/stage2_config.json")
# print(f"Predicted threat: {threat_type}")
# if probs is not None:
#     for i, class_name in enumerate(CLASS_NAMES):
#         print(f"  {class_name}: {probs[i]:.3f}")


🎯 STAGE 2: MULTI-CLASS THREAT CLASSIFICATION (FIXED VERSION)
📁 DATASET STRUCTURE CHECK:

TRAIN Directory: C:\Users\Jaiganesh\SoundGaurd\data\processed_data\train
  ✅ non_threat: 2100 files
  ✅ threat/glass_break: 700 files
  ✅ threat/scream: 700 files
  ✅ threat/gunshot: 700 files
  📊 TRAIN Total: 4200 files

TEST Directory: C:\Users\Jaiganesh\SoundGaurd\data\processed_data\test
  ✅ non_threat: 900 files
  ✅ threat/glass_break: 300 files
  ✅ threat/scream: 300 files
  ✅ threat/gunshot: 300 files
  📊 TEST Total: 1800 files

📊 DATASET SUMMARY:
   Train: 4200 files
   Test: 1800 files
   Grand Total: 6000 files

🔊 BUILDING 4-CLASS TRAINING DATASET...
[INFO] Found 2100 non-threat files
[INFO] Found 700 glass_break files
[INFO] Found 700 scream files
[INFO] Found 700 gunshot files
[INFO] Summary - Threats: 2100, Non-threats: 2100
[INFO] Total files: 4200
[INFO] Processing 4200 files...
[INFO] Processed 500/4200 files...
[INFO] Processed 1000/4200 files...
[INFO] Processed 1500/4200 files...