<a href="https://colab.research.google.com/github/AdityaVerma126/copd-asthma-classifier/blob/main/COPD_ASTHMA_CLASSIFIER_CODE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 📘 Project Title: *Lung Sound Classification with Deep Learning*

---

## 📌 Objective

## 🔧 Setup & Installations

## 📂 Data Loading and Preprocessing

## 🎛️ Feature Extraction: MFCC

## 🧪 Data Augmentation

## 🏗️ Model Building: GRU

## ⚙️ Training

## 📈 Evaluation

## 📊 Visualize Performance

## ✅ Conclusion


## 📚 Import Required Libraries


In [None]:
# =============================================
# 📚 IMPORTING REQUIRED LIBRARIES
# =============================================

# Data manipulation
import pandas as pd
import numpy as np

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Audio processing
import librosa
import librosa.display

# File operations
import os
import shutil
from os import listdir
from os.path import isfile, join

# Model utilities
from sklearn.model_selection import train_test_split


In [None]:
!pip install openpyxl



## 📥 Downloading Dataset from KaggleHub

We use `kagglehub` to download the **lung sound dataset** hosted on Kaggle. This streamlines data access and keeps the notebook modular.


In [None]:
# ============================================
# 📥 DOWNLOADING DATASET FROM KAGGLEHUB
# ============================================

import kagglehub

# Download the latest version of the lung dataset
path = kagglehub.dataset_download("arashnic/lung-dataset")

print("✅ Path to dataset files:", path)


Using Colab cache for faster access to the 'lung-dataset' dataset.
✅ Path to dataset files: /kaggle/input/lung-dataset


## 📊 Exploratory Data Analysis (EDA): Lung Disease Distribution

We begin by analyzing the distribution of different lung disease diagnoses based on the annotated dataset. Plotly is used for interactive visualization.


In [None]:
# ====================================================
# 📊 LOAD & VISUALIZE LUNG DISEASE DISTRIBUTION
# ====================================================

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# 🔹 Load diagnosis annotation file
diagnosis_data = pd.read_excel(
    r"/kaggle/input/lung-dataset/Data annotation.xlsx",
    usecols='B, E',
    names=['Sex', 'Disease']
)

# 🔹 Preview data (rendered only when it is the last expression in a cell)
diagnosis_data.head(4)

# 🔹 Count frequency of diseases (computed once, reused for the plot)
disease_counts = diagnosis_data.Disease.value_counts()

# 🔹 Print disease label counts
print("🧬 Disease Counts:\n")
print(disease_counts)

# 🔹 Create an interactive bar chart
fig = px.bar(
    x=disease_counts.index,
    y=disease_counts.values,
    title='🦠 Distribution of Lung Disease Diagnoses',
    labels={'x': 'Disease Type', 'y': 'Count'},
    color=disease_counts.values,
    color_continuous_scale='viridis'
)

# 🔹 Customize layout
fig.update_layout(
    xaxis_tickangle=-90,
    showlegend=False,
    height=500,
    font=dict(size=12)
)

# 🔹 Enhance hover tooltips
fig.update_traces(
    hovertemplate='<b>%{x}</b><br>Count: %{y}<extra></extra>'
)

# 🔹 Show plot
fig.show()


🧬 Disease Counts:

Disease
N                                 35
Asthma                            17
asthma                            15
heart failure                     15
COPD                               8
pneumonia                          5
Lung Fibrosis                      4
BRON                               3
Heart Failure                      3
Plueral Effusion                   2
Heart Failure + COPD               2
Heart Failure + Lung Fibrosis      1
Asthma and lung fibrosis           1
copd                               1
Name: count, dtype: int64
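
The counts above list several diagnoses twice because of letter case (`Asthma`/`asthma`, `COPD`/`copd`, `heart failure`/`Heart Failure`). A minimal sketch of merging them, assuming case and surrounding whitespace are the only differences that should be collapsed. The `raw_counts` dictionary restates the printed output, and `merged` is a name introduced here for illustration:

```python
from collections import Counter

# Counts copied from the value_counts() output above
raw_counts = {
    "N": 35, "Asthma": 17, "asthma": 15, "heart failure": 15, "COPD": 8,
    "pneumonia": 5, "Lung Fibrosis": 4, "BRON": 3, "Heart Failure": 3,
    "Plueral Effusion": 2, "Heart Failure + COPD": 2,
    "Heart Failure + Lung Fibrosis": 1, "Asthma and lung fibrosis": 1, "copd": 1,
}

# Merge labels that differ only by case or surrounding whitespace
merged = Counter()
for label, count in raw_counts.items():
    merged[label.strip().lower()] += count

for label, count in merged.most_common():
    print(f"{label:32s} {count}")
```

On the DataFrame itself, the equivalent normalization would be `diagnosis_data['Disease'].str.strip().str.lower()` before calling `value_counts()`. Spelling variants such as `Plueral Effusion` and combined diagnoses would still need an explicit mapping.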


## 🧠 Custom Audio File Handler Class

This section defines a robust `AudioFileHandler` class that performs:

- 🔍 Recursive file discovery
- 🧪 Validation based on size and duration
- 📊 Label-based file organization
- 📝 Manifest creation
- 📈 Audio statistics summary
- 🗂️ Compatibility functions for both simple and advanced use cases
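
The recursive discovery step can be sketched in a few lines with `pathlib`. This is a simplified illustration, not the class defined in the next cell: `find_audio` is a hypothetical helper, it matches extensions case-insensitively, and it skips the size and duration checks.

```python
import tempfile
from pathlib import Path


def find_audio(base, suffixes=(".wav", ".mp3", ".flac")):
    """Return sorted paths of files under `base` whose extension is in `suffixes`."""
    base = Path(base)
    return sorted(
        str(p) for p in base.rglob("*")
        if p.is_file() and p.suffix.lower() in suffixes
    )


# Build a small throwaway directory tree to exercise the helper
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "sub").mkdir()
    (root / "a.wav").write_bytes(b"\0" * 16)
    (root / "sub" / "b.WAV").write_bytes(b"\0" * 16)  # upper-case extension
    (root / "notes.txt").write_text("not audio")
    found = [Path(p).name for p in find_audio(root)]

print(found)
```

Comparing the lower-cased suffix means a file named `b.WAV` is still picked up, which a case-sensitive pattern such as `**/*.wav` may miss on Linux.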


In [None]:
import os
from os import listdir
from os.path import isfile, join, exists, getsize
from pathlib import Path
import pandas as pd
from collections import Counter, defaultdict
import logging
from typing import List, Dict, Tuple, Optional
import glob
import librosa
from tqdm import tqdm
import warnings

# Configure logging
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s %(message)s',
                   datefmt='%H:%M:%S')
logger = logging.getLogger(__name__)

class AudioFileHandler:
    """Enhanced audio file handling with validation, organization, and analysis."""

    def __init__(self, base_path: str, supported_formats: Optional[List[str]] = None):
        """
        Initialize the AudioFileHandler.

        Args:
            base_path: Base directory path containing audio files
            supported_formats: List of supported audio formats (default: common formats)
        """
        self.base_path = Path(base_path)
        self.supported_formats = supported_formats or ['.wav', '.mp3', '.flac', '.m4a', '.aac', '.ogg']
        self.file_info = {}
        self.stats = {
            'total_files': 0,
            'valid_files': 0,
            'invalid_files': 0,
            'by_format': Counter(),
            'by_label': Counter(),
            'file_sizes': [],
            'durations': []
        }

        logger.info("🔧 Initialized AudioFileHandler")
        logger.info(f"   📁 Base path: {self.base_path}")
        logger.info(f"   🎵 Supported formats: {self.supported_formats}")

    def discover_audio_files(self,
                           recursive: bool = True,
                           validate_files: bool = True,
                           min_size_kb: float = 1.0,
                           max_size_mb: float = 100.0) -> List[str]:
        """
        Discover and validate audio files in the directory.

        Args:
            recursive: Search subdirectories recursively
            validate_files: Validate file integrity
            min_size_kb: Minimum file size in KB
            max_size_mb: Maximum file size in MB

        Returns:
            List of valid audio file paths
        """
        logger.info("🔍 Discovering audio files...")

        # Check if base path exists
        if not self.base_path.exists():
            logger.error(f"❌ Base path does not exist: {self.base_path}")
            return []

        # Find all audio files
        all_files = []

        if recursive:
            # Recursive search
            for format_ext in self.supported_formats:
                pattern = f"**/*{format_ext}"
                files = list(self.base_path.glob(pattern))
                all_files.extend([str(f) for f in files])
        else:
            # Non-recursive search (original approach)
            try:
                filenames = [f for f in listdir(str(self.base_path))
                           if isfile(join(str(self.base_path), f))]

                # Filter by supported formats
                audio_files = [f for f in filenames
                             if any(f.lower().endswith(ext) for ext in self.supported_formats)]

                all_files = [join(str(self.base_path), f) for f in audio_files]

            except Exception as e:
                logger.error(f"❌ Error reading directory: {e}")
                return []

        logger.info(f"📁 Found {len(all_files)} potential audio files")

        if not all_files:
            logger.warning("⚠️ No audio files found!")
            return []

        # Validate and filter files
        valid_files = []

        if validate_files:
            logger.info("🔍 Validating audio files...")
            valid_files = self._validate_audio_files(all_files, min_size_kb, max_size_mb)
        else:
            # Basic size filtering without audio validation
            for filepath in all_files:
                if self._check_file_size(filepath, min_size_kb, max_size_mb):
                    valid_files.append(filepath)

        # Update statistics
        self.stats['total_files'] = len(all_files)
        self.stats['valid_files'] = len(valid_files)
        self.stats['invalid_files'] = len(all_files) - len(valid_files)

        self._analyze_files(valid_files)
        self._print_discovery_summary()

        return sorted(valid_files)

    def _validate_audio_files(self, filepaths: List[str],
                            min_size_kb: float, max_size_mb: float) -> List[str]:
        """Validate audio files for integrity and properties."""
        valid_files = []

        for filepath in tqdm(filepaths, desc="Validating files"):
            try:
                # Check file size
                if not self._check_file_size(filepath, min_size_kb, max_size_mb):
                    continue

                # Try to load audio metadata (quick check)
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    duration = librosa.get_duration(path=filepath)

                if duration > 0:
                    self.file_info[filepath] = {
                        'size_mb': getsize(filepath) / (1024 * 1024),
                        'duration': duration,
                        'format': Path(filepath).suffix.lower()
                    }
                    valid_files.append(filepath)
                    self.stats['durations'].append(duration)

            except Exception as e:
                logger.debug(f"❌ Invalid audio file {os.path.basename(filepath)}: {e}")
                continue

        return valid_files

    def _check_file_size(self, filepath: str, min_size_kb: float, max_size_mb: float) -> bool:
        """Check if file size is within acceptable range."""
        try:
            size_bytes = getsize(filepath)
            size_kb = size_bytes / 1024
            size_mb = size_bytes / (1024 * 1024)

            self.stats['file_sizes'].append(size_mb)

            return min_size_kb <= size_kb <= (max_size_mb * 1024)

        except Exception:
            return False

    def _analyze_files(self, filepaths: List[str]):
        """Analyze file distribution by format and label."""
        for filepath in filepaths:
            # Count by format
            format_ext = Path(filepath).suffix.lower()
            self.stats['by_format'][format_ext] += 1

            # Count by label (extracted from filename)
            label = self._extract_label_from_filename(filepath)
            self.stats['by_label'][label] += 1

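    def _extract_label_from_filename(self, filepath: str) -> str:
        """Extract label from filename or directory structure."""
        filename = os.path.basename(filepath).lower()
        parent_dir = os.path.basename(os.path.dirname(filepath)).lower()

        # Define label patterns (lowercase, because the names are lowercased above)
        label_patterns = {
            'Asthma': ['asthma'],
            'COPD': ['copd'],
            # 'Pneumonia': ['pneumonia'],  # disabled
            'Healthy': ['healthy', 'normal', 'n', 'control']
        }

        def matches(text: str, pattern: str) -> bool:
            # Short codes such as 'n' must equal a whole token; as a plain
            # substring, 'n' would match almost every name (e.g. "pneumonia").
            # Note: this still checks every token in the name, not one field.
            if len(pattern) <= 2:
                tokens = ''.join(c if c.isalnum() else ' ' for c in text).split()
                return pattern in tokens
            return pattern in text

        # Check filename first
        for label, patterns in label_patterns.items():
            if any(matches(filename, pattern) for pattern in patterns):
                return label

        # Check parent directory
        for label, patterns in label_patterns.items():
            if any(matches(parent_dir, pattern) for pattern in patterns):
                return label

        return 'Unknown'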

    def organize_by_labels(self, filepaths: List[str]) -> Dict[str, List[str]]:
        """Organize files by their extracted labels."""
        organized = defaultdict(list)

        for filepath in filepaths:
            label = self._extract_label_from_filename(filepath)
            organized[label].append(filepath)

        logger.info("📊 Files organized by labels:")
        for label, files in organized.items():
            logger.info(f"   {label}: {len(files)} files")

        return dict(organized)

    def get_file_statistics(self) -> Dict:
        """Get comprehensive file statistics."""
        stats = self.stats.copy()

        if self.stats['file_sizes']:
            stats['size_stats'] = {
                'mean_mb': sum(self.stats['file_sizes']) / len(self.stats['file_sizes']),
                'min_mb': min(self.stats['file_sizes']),
                'max_mb': max(self.stats['file_sizes']),
                'total_mb': sum(self.stats['file_sizes'])
            }

        if self.stats['durations']:
            stats['duration_stats'] = {
                'mean_sec': sum(self.stats['durations']) / len(self.stats['durations']),
                'min_sec': min(self.stats['durations']),
                'max_sec': max(self.stats['durations']),
                'total_sec': sum(self.stats['durations'])
            }

        return stats

    def create_file_manifest(self, filepaths: List[str], save_path: Optional[str] = None) -> pd.DataFrame:
        """Create a detailed manifest of all audio files."""
        manifest_data = []

        for filepath in filepaths:
            info = self.file_info.get(filepath, {})

            manifest_data.append({
                'filepath': filepath,
                'filename': os.path.basename(filepath),
                'directory': os.path.dirname(filepath),
                'label': self._extract_label_from_filename(filepath),
                'format': info.get('format', Path(filepath).suffix.lower()),
                'size_mb': info.get('size_mb', getsize(filepath) / (1024 * 1024)),
                'duration_sec': info.get('duration', None),
                'exists': os.path.exists(filepath)
            })

        df = pd.DataFrame(manifest_data)

        if save_path:
            df.to_csv(save_path, index=False)
            logger.info(f"💾 File manifest saved to: {save_path}")

        return df

    def _print_discovery_summary(self):
        """Print file discovery summary."""
        logger.info("📊 File Discovery Summary:")
        logger.info(f"   📁 Total files found: {self.stats['total_files']}")
        logger.info(f"   ✅ Valid files: {self.stats['valid_files']}")
        logger.info(f"   ❌ Invalid files: {self.stats['invalid_files']}")

        if self.stats['by_format']:
            logger.info("   📄 By format:")
            for format_ext, count in self.stats['by_format'].most_common():
                logger.info(f"      {format_ext}: {count} files")

        if self.stats['by_label']:
            logger.info("   🏷️ By label:")
            for label, count in self.stats['by_label'].most_common():
                logger.info(f"      {label}: {count} files")

        if self.stats['file_sizes']:
            avg_size = sum(self.stats['file_sizes']) / len(self.stats['file_sizes'])
            total_size = sum(self.stats['file_sizes'])
            logger.info(f"   💾 Average file size: {avg_size:.2f} MB")
            logger.info(f"   💾 Total dataset size: {total_size:.2f} MB")

        if self.stats['durations']:
            avg_duration = sum(self.stats['durations']) / len(self.stats['durations'])
            total_duration = sum(self.stats['durations'])
            logger.info(f"   ⏱️ Average duration: {avg_duration:.2f} seconds")
            logger.info(f"   ⏱️ Total duration: {total_duration/60:.1f} minutes")

# Enhanced file discovery functions
def discover_audio_files(base_path: str,
                        recursive: bool = True,
                        validate_files: bool = True,
                        supported_formats: List[str] = None,
                        min_size_kb: float = 1.0,
                        max_size_mb: float = 100.0,
                        create_manifest: bool = False,
                        manifest_path: str = "audio_manifest.csv") -> List[str]:
    """
    Enhanced audio file discovery function.

    Args:
        base_path: Directory path containing audio files
        recursive: Search subdirectories recursively
        validate_files: Validate audio file integrity
        supported_formats: List of supported audio formats
        min_size_kb: Minimum file size in KB
        max_size_mb: Maximum file size in MB
        create_manifest: Create CSV manifest of files
        manifest_path: Path to save manifest CSV

    Returns:
        List of valid audio file paths
    """
    handler = AudioFileHandler(base_path, supported_formats)

    filepaths = handler.discover_audio_files(
        recursive=recursive,
        validate_files=validate_files,
        min_size_kb=min_size_kb,
        max_size_mb=max_size_mb
    )

    if create_manifest and filepaths:
        handler.create_file_manifest(filepaths, manifest_path)

    return filepaths

def get_organized_file_paths(base_path: str, **kwargs) -> Tuple[List[str], Dict[str, List[str]]]:
    """
    Get file paths organized by labels.

    Returns:
        Tuple of (all_filepaths, organized_by_label_dict)
    """
    handler = AudioFileHandler(base_path)
    filepaths = handler.discover_audio_files(**kwargs)
    organized = handler.organize_by_labels(filepaths)

    return filepaths, organized

# Simple function that maintains your original interface
def get_audio_filepaths(mypath: str,
                       recursive: bool = False,
                       validate: bool = True,
                       show_info: bool = True) -> List[str]:
    """
    Simple function that maintains compatibility with your original code.

    Args:
        mypath: Path to audio files directory
        recursive: Search subdirectories
        validate: Validate audio files
        show_info: Print file information

    Returns:
        List of audio file paths
    """
    handler = AudioFileHandler(mypath)

    filepaths = handler.discover_audio_files(
        recursive=recursive,
        validate_files=validate
    )

    if show_info and filepaths:
        print(f"\n📁 Found {len(filepaths)} audio files:")

        # Show first few files
        for i, filepath in enumerate(filepaths[:5]):
            print(f"   {i+1}. {os.path.basename(filepath)}")

        if len(filepaths) > 5:
            print(f"   ... and {len(filepaths) - 5} more files")

        # Show file distribution
        organized = handler.organize_by_labels(filepaths)
        print("\n🏷️ Distribution by labels:")
        for label, files in organized.items():
            print(f"   {label}: {len(files)} files")

    return filepaths

# Example usage with your original structure improved
if __name__ == "__main__":
    # Your original approach - enhanced
    mypath = "/kaggle/input/lung-dataset/Audio Files"

    print("🔄 Original approach (enhanced):")
    filepaths = get_audio_filepaths(mypath, validate=True, show_info=True)

    # Advanced usage
    print("\n🚀 Advanced approach:")
    filepaths_advanced = discover_audio_files(
        base_path=mypath,
        recursive=True,          # Search subdirectories
        validate_files=True,     # Validate audio integrity
        min_size_kb=1.0,        # Minimum 1KB files
        max_size_mb=50.0,       # Maximum 50MB files
        create_manifest=True,    # Create CSV manifest
        manifest_path="lung_audio_manifest.csv"
    )

    # Organized by labels
    print("\n📊 Organized by labels:")
    all_files, organized_files = get_organized_file_paths(mypath)

    for label, files in organized_files.items():
        print(f"{label}: {len(files)} files")
        if files:
            print(f"   Sample: {os.path.basename(files[0])}")

🔄 Original approach (enhanced):


Validating files:   0%|          | 0/336 [00:00<?, ?it/s]

## 🛠 Audio Preprocessing: Frame Segmentation & Baseline Wander Removal

In this section, we:

- ⚙️ Configure core parameters (frame length, overlap, sampling rate)
- ✂️ Segment raw lung sound signals into overlapping frames
- 🌀 Remove low-frequency baseline wander via DFT filtering
- 🚀 Provide a parallelized preprocessing pipeline for entire datasets
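
To make the segmentation step concrete, the sketch below works through the frame arithmetic for the default settings used in the next cell (20 s frames, 40% overlap, 22050 Hz). `frame_layout` is a hypothetical helper written for this illustration, and the 60-second recording length is an assumed example value.

```python
def frame_layout(n_samples, sr, frame_length=20.0, overlap_ratio=0.4):
    """Return (frame_size, step_size, n_frames) for a signal of n_samples."""
    frame_size = int(sr * frame_length)                         # samples per frame
    step_size = max(1, int(frame_size * (1 - overlap_ratio)))   # hop between frame starts
    if n_samples < frame_size:
        return frame_size, step_size, 1                         # short signal: one frame
    n_frames = (n_samples - frame_size) // step_size + 1
    return frame_size, step_size, n_frames


# A 60-second recording at 22050 Hz (assumed example length)
layout = frame_layout(n_samples=60 * 22050, sr=22050)
print(layout)  # (441000, 264600, 4)
```

Each frame starts 12 s after the previous one, so the four frames cover 0-20 s, 12-32 s, 24-44 s and 36-56 s. The final 4 s of the recording do not fill a complete frame and are dropped.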


In [None]:
# ══════════════════════════════════════════════════════════════════════
# 🫁 LUNG SOUND PREPROCESSING PIPELINE
# ══════════════════════════════════════════════════════════════════════
# Advanced preprocessing system for lung sound analysis with frame segmentation
# and baseline wander removal using Discrete Fourier Transform (DFT)
# ══════════════════════════════════════════════════════════════════════

import numpy as np
import librosa
import os
from tqdm import tqdm
from typing import Dict, Tuple, List
import logging
from multiprocessing import Pool, cpu_count
import warnings

# Configure logging for monitoring preprocessing progress
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Suppress librosa warnings to keep output clean
warnings.filterwarnings("ignore", category=UserWarning)

# ===========================================
# 🔹 CONFIGURATION PARAMETERS
# ===========================================
class Config:
    FRAME_LENGTH = 20  # seconds - Duration of each audio frame for analysis
    OVERLAP_RATIO = 0.4  # 40% overlap between consecutive frames
    SAMPLING_RATE = 22050  # Hz - Standard sampling rate for audio processing
    BASELINE_CUTOFF = 1  # Hz - Frequency threshold for baseline wander removal
    N_PROCESSES = max(1, cpu_count() - 1)  # Leave one core free for system stability


# ===========================================
# 🔹 FRAME SEGMENTATION FUNCTION
# ===========================================
def segment_signal(
    signal: np.ndarray,
    sr: int,
    frame_length: float = Config.FRAME_LENGTH,
    overlap_ratio: float = Config.OVERLAP_RATIO
) -> np.ndarray:
    """
    Splits lung sound signal into frames of specified length with given overlap.

    Parameters:
        signal: Raw lung sound signal (1D NumPy array)
        sr: Sampling rate of the signal (in Hz)
        frame_length: Frame duration in seconds (default=20 sec)
        overlap_ratio: Fraction of overlap between frames, in [0, 1) (default=0.4)

    Returns:
        np.ndarray: Array of framed signals (n_frames, frame_size)

    Raises:
        ValueError: If input parameters are invalid
    """
    # Input validation - ensure all parameters are valid
    if len(signal) == 0:
        raise ValueError("Input signal is empty")
    if sr <= 0:
        raise ValueError("Sampling rate must be positive")
    if frame_length <= 0:
        raise ValueError("Frame length must be positive")
    if not 0 <= overlap_ratio < 1:
        raise ValueError("Overlap ratio must be in [0, 1)")

    # Convert frame duration from seconds to number of samples
    frame_size = int(sr * frame_length)  # Convert seconds to samples

    # Handle edge case where signal is shorter than desired frame length
    if frame_size > len(signal):
        logger.warning(f"Signal length ({len(signal)/sr:.2f}s) is shorter than frame length ({frame_length}s)")
        return np.array([signal])  # Return entire signal as single frame

    # Calculate step size based on overlap ratio (how much to advance for next frame)
    # Clamp to at least 1 sample so a tiny frame with high overlap cannot yield a zero step
    step_size = max(1, int(frame_size * (1 - overlap_ratio)))  # Step between frames

    # Calculate total number of frames that can be extracted
    n_frames = max(1, (len(signal) - frame_size) // step_size + 1)

    # Vectorized frame creation using numpy's sliding window view for efficiency
    starts = np.arange(0, (n_frames - 1) * step_size + 1, step_size)
    frames = np.lib.stride_tricks.sliding_window_view(signal, frame_size)[starts]

    return frames


# ===========================================
# 🔹 BASELINE WANDER REMOVAL USING DFT
# ===========================================
def remove_baseline_wander(
    signal: np.ndarray,
    sr: int,
    cutoff_freq: float = Config.BASELINE_CUTOFF
) -> np.ndarray:
    """
    Removes baseline wander noise (0-1 Hz) using Discrete Fourier Transform (DFT).

    Parameters:
        signal: Input frame signal (1D array)
        sr: Sampling rate (in Hz)
        cutoff_freq: Frequency below which DFT coefficients are removed (default=1 Hz)

    Returns:
        np.ndarray: Filtered signal after inverse DFT

    Raises:
        ValueError: If input parameters are invalid
    """
    # Input validation to ensure valid parameters
    if len(signal) == 0:
        raise ValueError("Input signal is empty")
    if sr <= 0:
        raise ValueError("Sampling rate must be positive")
    if cutoff_freq <= 0:
        raise ValueError("Cutoff frequency must be positive")

    # DFT processing: transform signal to frequency domain
    M = len(signal)  # DFT length
    freqs = np.fft.fftfreq(M, d=1/sr)  # Frequency bins corresponding to DFT coefficients
    dft_coeffs = np.fft.fft(signal)  # Compute forward DFT

    # Find indices of low-frequency components and set them to zero
    # This removes baseline wander by eliminating frequencies below cutoff
    dft_coeffs[np.abs(freqs) < cutoff_freq] = 0

    # Reconstruct signal using inverse DFT (take real part to avoid complex numbers)
    filtered_signal = np.fft.ifft(dft_coeffs).real

    return filtered_signal


# ===========================================
# 🔹 PROCESS SINGLE FILE (FOR PARALLEL PROCESSING)
# ===========================================
def _process_single_file(args: Tuple[str, str]) -> Tuple[str, np.ndarray]:
    """
    Helper function for parallel processing of a single file.

    Parameters:
        args: Tuple of (filename, file_path)

    Returns:
        Tuple of (filename, processed_frames)
    """
    filename, file_path = args
    try:
        # Load the lung sound signal using librosa with specified sampling rate
        signal, sr = librosa.load(file_path, sr=Config.SAMPLING_RATE)

        # Segment signal into overlapping frames for analysis
        frames = segment_signal(signal, sr)

        # Apply baseline wander removal to each frame using DFT filtering
        processed_frames = np.array([remove_baseline_wander(frame, sr) for frame in frames])

        return filename, processed_frames
    except Exception as e:
        # Log any errors that occur during processing
        logger.error(f"Error processing {filename}: {str(e)}")
        return filename, None


# ===========================================
# 🔹 PREPROCESSING FUNCTION FOR DATASET
# ===========================================
def preprocess_dataset(
    folder_path: str,
    parallel: bool = True
) -> Tuple[Dict[str, np.ndarray], int]:
    """
    Processes all .wav files in a folder:
    - Segments each audio file into frames
    - Removes baseline wander using DFT filtering

    Parameters:
        folder_path: Path to dataset folder containing .wav files
        parallel: Whether to use parallel processing (default=True)

    Returns:
        Tuple of (processed_data, sampling_rate) where:
        - processed_data: Dictionary with file names as keys and processed frames as values
        - sampling_rate: The sampling rate used for processing

    Raises:
        FileNotFoundError: If folder_path doesn't exist
    """
    # Check if the dataset folder exists
    if not os.path.exists(folder_path):
        raise FileNotFoundError(f"Dataset folder not found: {folder_path}")

    processed_data = {}  # Store preprocessed frames

    # Get all .wav files in the dataset folder
    wav_files = [f for f in os.listdir(folder_path) if f.endswith(".wav")]
    if not wav_files:
        logger.warning(f"No .wav files found in {folder_path}")
        return processed_data, Config.SAMPLING_RATE

    logger.info(f"🔍 Found {len(wav_files)} .wav files. Processing...")

    # Prepare arguments for parallel processing (filename, full_path pairs)
    file_args = [(f, os.path.join(folder_path, f)) for f in wav_files]

    # Choose processing method based on parallel flag and number of files
    if parallel and len(wav_files) > 1:
        # Parallel processing using multiprocessing Pool
        with Pool(processes=min(Config.N_PROCESSES, len(wav_files))) as pool:
            results = list(tqdm(
                pool.imap(_process_single_file, file_args),
                total=len(wav_files),
                desc="Processing Files"
            ))
    else:
        # Sequential processing (useful for debugging or small datasets)
        results = []
        for args in tqdm(file_args, desc="Processing Files"):
            results.append(_process_single_file(args))

    # Collect results from processing and store successful ones
    for filename, frames in results:
        if frames is not None:
            processed_data[filename] = frames

    logger.info("✅ Preprocessing completed!")
    return processed_data, Config.SAMPLING_RATE


# ===========================================
# ✅ MAIN EXECUTION
# ===========================================
if __name__ == "__main__":
    try:
        # Set the path to your lung sound dataset
        dataset_path = r"/kaggle/input/lung-dataset/Audio Files"  # Change to your dataset path

        # Execute the preprocessing pipeline
        preprocessed_data, sampling_rate = preprocess_dataset(dataset_path)

        # Display results and statistics if processing was successful
        if preprocessed_data:
            # Example: Access preprocessed frames of a specific file
            example_filename = next(iter(preprocessed_data))
            example_frames = preprocessed_data[example_filename]

            # Log processing results and frame statistics
            logger.info(
                f"🔹 Processed frames for {example_filename}: "
                f"{example_frames.shape} (n_frames={example_frames.shape[0]}, "
                f"frame_size={example_frames.shape[1]})"
            )
            logger.info(f"🔹 Sampling rate: {sampling_rate} Hz")
        else:
            logger.warning("No valid data was processed")
    except Exception as e:
        logger.error(f"Error during preprocessing: {str(e)}")
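
The effect of the DFT-based baseline removal can be checked on a synthetic signal. The sketch below is an illustration under stated assumptions: it uses `np.fft.rfft` (the real-input transform) rather than the full `np.fft.fft` used in the cell above, and the 1000 Hz sampling rate, 50 Hz tone and 0.5 Hz drift are made-up test values chosen so that each component falls exactly on a DFT bin.

```python
import numpy as np

sr = 1000                                   # Hz, illustrative sampling rate
t = np.arange(2 * sr) / sr                  # 2 seconds of samples
tone = np.sin(2 * np.pi * 50 * t)           # 50 Hz component that should survive
drift = 0.5 + 0.3 * np.sin(2 * np.pi * 0.5 * t)  # DC offset plus 0.5 Hz wander
noisy = tone + drift

# Zero every DFT coefficient below the 1 Hz cutoff, then invert the transform
spectrum = np.fft.rfft(noisy)
freqs = np.fft.rfftfreq(len(noisy), d=1 / sr)
spectrum[freqs < 1.0] = 0
cleaned = np.fft.irfft(spectrum, n=len(noisy))

max_error = float(np.max(np.abs(cleaned - tone)))
print(f"Largest deviation from the clean tone: {max_error:.2e}")
```

Because the drift components sit exactly on the 0 Hz and 0.5 Hz bins here, the filter removes them almost perfectly. On real recordings the wander is rarely bin-aligned, so some leakage around the cutoff should be expected.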

# 🎧 Lung Sound Augmentation & Balancing Module
A Python module for audio data augmentation and class balancing, designed for respiratory sound classification tasks. It includes several augmentation strategies, such as pitch shifting, noise injection, and reverb, organized as small reusable functions.

---

## 🧩 Features Implemented:
- 🔼 Pitch Shift
- 🌫️ Gaussian Noise
- 📢 Volume Scaling
- ✂️ Crop & Pad
- 📡 Sine Interference
- 🔊 Simple Reverb
- ⚖️ Dataset Balancing by Class
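
A minimal sketch of how class balancing could be planned, assuming the goal is to bring every class up to the size of the largest one by adding augmented copies. `augmentation_plan` and the label counts below are illustrative and are not taken from the module:

```python
from collections import Counter


def augmentation_plan(labels):
    """Return how many augmented samples each class needs to match the largest class."""
    counts = Counter(labels)
    target = max(counts.values())
    return {label: target - n for label, n in counts.items()}


# Illustrative label list (made-up class sizes)
labels = ["Healthy"] * 35 + ["Asthma"] * 32 + ["COPD"] * 9
plan = augmentation_plan(labels)
print(plan)  # {'Healthy': 0, 'Asthma': 3, 'COPD': 26}
```

Each deficit could then be filled by applying randomly chosen augmentations from the list above to samples of that class. Augmented copies should be generated only from training data so that no variant of a test recording leaks into training.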

---

In [None]:
# ============================================================
# 🎛️ DATA AUGMENTATION MODULE: NOISE, PITCH, REVERB & BALANCING
# ============================================================

import numpy as np
import librosa
import random
from datetime import datetime
from typing import Tuple, Dict, List

# ===========================================
# 📝 LOGGING FUNCTION
# ===========================================
def log(msg):
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}")


# ===========================================
# 🎨 DATA AUGMENTATION FUNCTIONS
# ===========================================
def pitch_shift(signal: np.ndarray, sr: int, steps: int) -> np.ndarray:
    """Shift pitch by `steps` semitones, keeping the original length."""
    try:
        if steps == 0:
            return signal
        shifted = librosa.effects.pitch_shift(signal, sr=sr, n_steps=steps)
        return librosa.util.fix_length(shifted, size=len(signal))
    except Exception as e:
        log(f"⚠️ Pitch shift error: {e}")
        return signal

def add_gaussian_noise(signal: np.ndarray, noise_level: float) -> np.ndarray:
    """Add Gaussian noise to the signal."""
    noise = np.random.normal(0, noise_level, signal.shape)
    return signal + noise

def random_volume_scaling(signal: np.ndarray, min_gain: float = 0.8, max_gain: float = 1.2) -> np.ndarray:
    """Scale volume randomly within given gain range."""
    try:
        gain = np.random.uniform(min_gain, max_gain)
        scaled = signal * gain
        max_val = np.max(np.abs(scaled))
        if max_val > 1.0:
            scaled = scaled / max_val * 0.95
        return scaled.astype(np.float32)
    except Exception as e:
        log(f"⚠️ Volume scaling error: {e}")
        return signal.copy()

def random_crop_and_pad(signal: np.ndarray, crop_ratio: float = 0.9) -> np.ndarray:
    """Randomly crop and pad the signal to original length."""
    try:
        if len(signal) < 10:
            return signal.copy()
        crop_len = int(len(signal) * crop_ratio)
        start = np.random.randint(0, len(signal) - crop_len + 1)
        cropped = signal[start:start + crop_len]
        padded = librosa.util.fix_length(cropped, size=len(signal))
        return padded.astype(np.float32)
    except Exception as e:
        log(f"⚠️ Crop-pad error: {e}")
        return signal.copy()

def add_sine_interference(signal: np.ndarray, sr: int, freq_range: Tuple[int, int] = (50, 300)) -> np.ndarray:
    """Overlay a faint sine wave at random frequency."""
    try:
        freq = np.random.uniform(*freq_range)
        amplitude = np.random.uniform(0.001, 0.01)
        t = np.arange(len(signal)) / sr
        sine_wave = amplitude * np.sin(2 * np.pi * freq * t)
        return (signal + sine_wave).astype(np.float32)
    except Exception as e:
        log(f"⚠️ Sine interference error: {e}")
        return signal.copy()

def add_simple_reverb(signal: np.ndarray, sr: int, room_size: float = 0.3, damping: float = 0.5) -> np.ndarray:
    """Add a simple delay-based reverb effect."""
    try:
        delay_samples = int(sr * room_size * 0.05)
        if delay_samples >= len(signal) or delay_samples == 0:
            return signal.copy()
        delayed = np.zeros_like(signal)
        delayed[delay_samples:] = signal[:-delay_samples] * (1 - damping)
        reverb_signal = signal + delayed * 0.2
        max_val = np.max(np.abs(reverb_signal))
        if max_val > 1.0:
            reverb_signal = reverb_signal / max_val * 0.95
        return reverb_signal.astype(np.float32)
    except Exception as e:
        log(f"⚠️ Reverb error: {e}")
        return signal.copy()


# ===========================================
# 🧩 AUGMENTOR CLASS
# ===========================================
class AudioAugmentor:
    """Apply each augmentation transform independently, returning one variant per transform."""
    def __init__(self, sr: int):
        self.sr = sr

    def augment(self, signal: np.ndarray) -> List[np.ndarray]:
        return [
            pitch_shift(signal, self.sr, steps=random.randint(-3, 3)),
            add_gaussian_noise(signal, noise_level=random.uniform(0.002, 0.01)),
            random_volume_scaling(signal, min_gain=0.85, max_gain=1.15),
            random_crop_and_pad(signal, crop_ratio=random.uniform(0.85, 0.95)),
            add_sine_interference(signal, self.sr, freq_range=(60, 250)),
            add_simple_reverb(signal, self.sr, room_size=random.uniform(0.2, 0.4), damping=0.5)
        ]


# ===========================================
# ⚖️ BALANCING & AUGMENTING DATASET
# ===========================================
def balance_and_augment_dataset(
    preprocessed_data: Dict[str, np.ndarray],
    sr: int,
    target_counts: Dict[str, int],
    augment: bool = True
) -> Dict[str, np.ndarray]:
    """
    Balance classes by label and optionally augment to reach target counts.
    """
    class_frame_map: Dict[str, List[np.ndarray]] = {}
    result: Dict[str, np.ndarray] = {}
    augmentor = AudioAugmentor(sr)

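    log(f"🔄 Balancing and augmenting {len(preprocessed_data)} files...")

    # Group frames by label. Filenames follow "<id>_<Diagnosis>,...", so the
    # diagnosis is the first comma-separated token after the first underscore
    # (the same parsing used in the feature-extraction cell).
    for filename, frames in preprocessed_data.items():
        parts = filename.split("_")
        if len(parts) < 2:
            log(f"⚠️ Unexpected filename format, skipping: {filename}")
            continue
        label = parts[1].split(",")[0].lower().strip()
        class_frame_map.setdefault(label, []).extend(frames)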

    # Process each label
    for label, frames in class_frame_map.items():
        frames = [f for f in frames if isinstance(f, np.ndarray) and f.size > 0]
        total = len(frames)
        if total == 0:
            log(f"⚠️ No frames for label '{label}', skipping.")
            continue

        target = target_counts.get(label, total)
        random.shuffle(frames)

        # Select base frames
        selected = frames[:min(target, total)]
        augmented = []

        if augment:
            for frame in selected:
                augmented.extend(augmentor.augment(frame))
        else:
            augmented = selected.copy()

        # Trim or pad the sample list to match the target. With augment=True the
        # list holds augmented variants only (six per selected frame).
        if len(augmented) > target:
            augmented = augmented[:target]
        elif len(augmented) < target:
            augmented.extend(random.choices(augmented, k=target - len(augmented)))

        # np.stack only accepts `dtype` on NumPy >= 1.24, so cast explicitly
        result[label] = np.stack(augmented).astype(np.float32)
        log(f"✅ {label.capitalize()}: {total} → {len(augmented)} samples")

    return result


# ===========================================
# 🚀 USAGE EXAMPLE
# ===========================================
if __name__ == "__main__":
    target_counts = {
        "asthma": 150,
        "copd":   135,
        "n":      150,
        # "pneumonia": 75
    }

    try:
        # Ensure `preprocessed_data` and `sampling_rate` are defined beforehand
        if 'preprocessed_data' not in globals() or 'sampling_rate' not in globals():
            raise NameError("Define `preprocessed_data` and `sampling_rate` before running.")

        balanced_augmented_data = balance_and_augment_dataset(
            preprocessed_data, sampling_rate, target_counts, augment=True
        )

        for label, data in balanced_augmented_data.items():
            log(f"🔹 {label} → Final shape: {data.shape}")

    except NameError as ne:
        log(f"❌ Setup Error: {ne}")
    except Exception as e:
        log(f"❌ Runtime Error: {e}")
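
The balancing step can also be read as "undersample large classes, top up small ones". The sketch below is one possible variant of that idea, not the cell's exact logic: it keeps the original items and only generates variants to fill the shortfall, whereas the cell above keeps augmented copies only. The helper name `top_up_to_target` and the `make_variant` callback are hypothetical.

```python
import random
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def top_up_to_target(items: Sequence[T], target: int,
                     make_variant: Callable[[T], T], seed: int = 0) -> List[T]:
    """Return exactly `target` items: originals first, then generated variants."""
    if not items:
        raise ValueError("cannot balance an empty class")
    rng = random.Random(seed)
    pool = list(items)
    rng.shuffle(pool)
    selected = pool[:target]            # undersample when the class is too large
    out = list(selected)
    while len(out) < target:            # oversample with variants when too small
        out.append(make_variant(rng.choice(selected)))
    return out


# Toy usage: numbers stand in for audio frames, scaling stands in for augmentation
small_class = top_up_to_target([1.0, 2.0, 3.0], target=5, make_variant=lambda x: x * 1.1)
large_class = top_up_to_target(list(range(10)), target=4, make_variant=lambda x: x)
print(len(small_class), len(large_class))
```

Keeping the originals in the output is a design choice: it guarantees that real recordings remain in every class, at the cost of fewer augmented samples.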


# 📈 Lung Sound Signal Visualization (Waveform Plots)
This module visualizes preprocessed lung sound signals using **Plotly**, showcasing the first frame of each selected audio file. It's useful for:
- Verifying segmentation quality
- Comparing waveform characteristics across disease labels
- Visual inspection before feeding data into ML models

---
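
Two small pieces of arithmetic drive the plot layout: where each waveform lands in a two-column grid, and how sample indices map to seconds. The sketch below isolates both in plain Python. The helper names are hypothetical, and the 4 kHz rate in the usage lines is an assumed value.

```python
from typing import List, Tuple


def grid_shape(num_plots: int, cols: int = 2) -> Tuple[int, int]:
    """Rows and columns needed to place `num_plots` subplots."""
    if num_plots <= 1:
        return 1, 1
    rows = (num_plots + cols - 1) // cols   # ceiling division
    return rows, cols


def grid_position(index: int, cols: int) -> Tuple[int, int]:
    """1-based (row, col) of the subplot at 0-based `index`, filled row by row."""
    return index // cols + 1, index % cols + 1


def frame_time_axis(n_samples: int, sr: int) -> List[float]:
    """Time stamp in seconds of each sample in a frame: sample k occurs at k / sr."""
    return [k / sr for k in range(n_samples)]


rows, cols = grid_shape(7)
print(rows, cols)                  # layout for seven files
print(grid_position(4, cols))      # position of the fifth file
print(frame_time_axis(4, 4000))    # first four time stamps at an assumed 4 kHz
```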

In [None]:

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import random
import librosa
import numpy as np
import os

# --- Start of code to define preprocessed_data and sampling_rate ---
# This block lets the visualization cell run on its own by loading a small
# sample of files. It relies on Config, AudioFileHandler, segment_signal and
# remove_baseline_wander, which must already be defined by earlier cells.
# WARNING: it rebinds `preprocessed_data` to a small sample. If the full
# preprocessing step has already run, remove this block (down to the matching
# "End of code" marker) to keep the full dataset.

# Expected structure: {'filename1.wav': array_of_frames_1, 'filename2.wav': array_of_frames_2, ...}
preprocessed_data = {}
sampling_rate = Config.SAMPLING_RATE  # Use the sampling rate from your config

# Assuming you have run the preprocessing step and have preprocessed_data defined
# Example of loading a few files for visualization if preprocessed_data is not available
# You would need to adjust the path and potentially the file names.
# This is just a placeholder/example.
mypath = "/kaggle/input/lung-dataset/Audio Files"
handler = AudioFileHandler(mypath)
all_filepaths = handler.discover_audio_files(recursive=True, validate_files=True)

# Select a few diverse files for visualization
selected_files_for_viz = {}
labels_to_find = ['copd', 'asthma', 'n', 'other']
found_labels = set()

for filepath in all_filepaths:
    label = handler._extract_label_from_filename(filepath).lower()
    if label in labels_to_find and label not in found_labels:
        try:
            # Load and preprocess the file manually for visualization purposes
            signal, sr = librosa.load(filepath, sr=Config.SAMPLING_RATE)
            frames = segment_signal(signal, sr)
            processed_frames = np.array([remove_baseline_wander(frame, sr) for frame in frames])
            if processed_frames.shape[0] > 0:
                 preprocessed_data[os.path.basename(filepath)] = processed_frames
                 selected_files_for_viz[label] = os.path.basename(filepath) # Store one file per label
                 found_labels.add(label)
                 if len(found_labels) == len(labels_to_find):
                     break # Stop once we have one file for each target label
        except Exception as e:
            print(f"Could not process {os.path.basename(filepath)} for visualization: {e}")

# Top up with randomly chosen files until there are 10 in total
if len(preprocessed_data) < 10:
    random.shuffle(all_filepaths)
    files_added = len(preprocessed_data)
    for filepath in all_filepaths:
         if files_added >= 10:
             break
         filename = os.path.basename(filepath)
         if filename not in preprocessed_data:
             try:
                signal, sr = librosa.load(filepath, sr=Config.SAMPLING_RATE)
                frames = segment_signal(signal, sr)
                processed_frames = np.array([remove_baseline_wander(frame, sr) for frame in frames])
                if processed_frames.shape[0] > 0:
                    preprocessed_data[filename] = processed_frames
                    files_added += 1
             except Exception as e:
                 print(f"Could not process {os.path.basename(filepath)} for visualization: {e}")

# Ensure we have at least one file to visualize
if not preprocessed_data:
    print("Could not load any audio files for visualization.")
# --- End of code to define preprocessed_data and sampling_rate ---


# --- Visualization Code ---
if preprocessed_data:
    # Select up to 10 files for visualization
    viz_files = list(preprocessed_data.keys())[:10]
    num_plots = len(viz_files)

    if num_plots > 0:
        # Determine grid size for subplots
        rows = (num_plots + 1) // 2 if num_plots > 1 else 1
        cols = 2 if num_plots > 1 else 1

        fig = make_subplots(rows=rows, cols=cols,
                            subplot_titles=[f"{fname} ({handler._extract_label_from_filename(fname)})" for fname in viz_files],
                            shared_xaxes=True)

        for i, filename in enumerate(viz_files):
            frames = preprocessed_data[filename]
            # Use the first frame for plotting as an example
            signal_to_plot = frames[0] if frames.shape[0] > 0 else np.array([])

            if signal_to_plot.size > 0:
                # Sample k occurs at k / sampling_rate seconds
                time = np.arange(len(signal_to_plot)) / sampling_rate

                # Add trace to subplot
                row_idx = i // cols + 1
                col_idx = i % cols + 1

                fig.add_trace(go.Scatter(x=time, y=signal_to_plot, mode='lines', name=filename),
                              row=row_idx, col=col_idx)

                # Label the axes of this subplot
                fig.update_xaxes(title_text="Time (s)", row=row_idx, col=col_idx)
                fig.update_yaxes(title_text="Amplitude", row=row_idx, col=col_idx)

        fig.update_layout(height=rows * 400, width=1000,
                          title_text="Sample Lung Sound Signals (First Frame)",
                          showlegend=False) # Hide individual trace legends

        fig.show()
    else:
        print("No valid audio frames available for visualization.")
else:
    print("Preprocessing data is empty. Cannot generate plots.")

In [None]:
from collections import Counter
from imblearn.over_sampling import SMOTE
from sklearn.preprocessing import LabelEncoder, StandardScaler
import joblib
import numpy as np

In [None]:
!pip install PyWavelets

# 🔍 Advanced Lung Sound Feature Extraction Pipeline (ASTHMA, COPD, NORMAL)

This notebook implements a feature extraction pipeline for classifying lung sounds as **Asthma**, **COPD**, or **Normal**. Each recording is reduced to a fixed-length vector built from the feature blocks listed below.

---

## 🧠 Key Functionalities

### ✅ Augmentation Techniques:
- `add_noise()`: Add Gaussian noise to improve model robustness.
- `shift()`: Time-shift audio to simulate real-world variance.
- `stretch()`: Time-stretching for temporal distortions.
- `pitch_shift()`: Modify pitch to increase dataset diversity.

---
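
As an illustration of the time-shift idea from the list above, here is a minimal NumPy sketch of a bounded circular shift. It assumes the shift is capped at a quarter of the signal length; the helper name `bounded_time_shift` is hypothetical. Note that `np.roll` wraps samples around the end of the array rather than padding with silence.

```python
import numpy as np


def bounded_time_shift(data: np.ndarray, shift_max: int,
                       rng: np.random.Generator) -> np.ndarray:
    """Circularly shift `data` by a random offset in [-max_shift, max_shift]."""
    if len(data) == 0:
        return data
    max_shift = min(shift_max, len(data) // 4)   # cap at a quarter of the length
    if max_shift <= 0:
        return data.copy()                       # too short to shift
    offset = int(rng.integers(-max_shift, max_shift + 1))
    return np.roll(data, offset)


rng = np.random.default_rng(0)
x = np.arange(16.0)
shifted = bounded_time_shift(x, shift_max=1600, rng=rng)
too_short = bounded_time_shift(np.arange(3.0), shift_max=10, rng=rng)
print(shifted)
```

Because the shift is circular, the output contains the same samples as the input in a rotated order, so the overall energy of the signal is unchanged.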

### ✅ Feature Extraction Modules:

| Feature Type                   | Description                                                                              |
|--------------------------------|------------------------------------------------------------------------------------------|
| 🎵 **Advanced MFCCs**          | 40 MFCCs + Delta + Delta-Delta, each summarized by mean, std, and skewness over time.    |
| 📉 **Fourier-Bessel Entropy**  | Entropy of the energy distribution within each of 10 frequency bands.                    |
| 📊 **Enhanced Mel-Spectrogram** | 128-bin Mel-spectrogram summarized per bin by mean, std, max, and min.                   |
| 📈 **Wavelet Features**        | Multi-resolution transient detection using `db4` wavelet decomposition.                  |
| 🔁 **Sequence Features**       | Frame-wise MFCC temporal variation for attention/transformer models.                     |
| 🎼 **Spectral Features**       | Centroid, bandwidth, rolloff, flatness, ZCR, chroma, and tonnetz (skipped or approximated at low sampling rates). |
| 🔣 **Fourier-Bessel Coeffs**   | Original signal decomposition based on Fourier-Bessel transforms.                        |

---
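
The table implies a fixed feature-vector length, which is worth checking before training a model on it. The arithmetic below is a sketch based on the constants used in the extraction cell (`n_mfcc=40`, `n_mels=128`, `wavelet_levels=5`, `fbse_bands=10`, `fb_coeffs=20`). It assumes the wavelet block has `(levels + 1) * 4` entries, since `pywt.wavedec` returns one approximation array plus one detail array per level.

```python
# Constants mirrored from the feature-extraction cell
n_mfcc, n_mels, wavelet_levels, fbse_bands, fb_coeffs = 40, 128, 5, 10, 20

block_sizes = {
    "advanced_mfcc": n_mfcc * 3 * 3,          # (MFCC, delta, delta-delta) x (mean, std, skew)
    "fbse": fbse_bands,                       # one entropy value per band
    "mel_stats": n_mels * 4,                  # mean, std, max, min per Mel bin
    "wavelet": (wavelet_levels + 1) * 4,      # 4 statistics per coefficient array
    "sequence": 13 * 2,                       # frame variation + long-term mean of 13 MFCCs
    "spectral": 7,                            # centroid, bandwidth, rolloff, flatness, ZCR, chroma, tonnetz
    "fourier_bessel": fb_coeffs,
}

total = sum(block_sizes.values())
for name, size in block_sizes.items():
    print(f"{name:>15}: {size}")
print(f"{'total':>15}: {total}")
```

If a fallback path in any extractor returns a vector of a different length, the stacked feature matrix becomes ragged, so comparing each block against this tally is a cheap sanity check.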

## 🚀 Pipeline Highlights

- ✅ Supports `asthma`, `copd`, and `normal` only.
- ✅ Error-handling for empty, invalid, or corrupt audio files.
- ✅ Auto-parsing of disease label from filename.
- ✅ Feature normalization using `StandardScaler`.
- ✅ Augmented samples are included for better generalization.
- ✅ Informative logging and final summary of class balance and feature breakdown.

---
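
The label auto-parsing mentioned above can be sketched as a small pure-Python function. It assumes filenames of the form `<id>_<Diagnosis>,<other fields>.wav` and mirrors the mapping used in the extraction cell. The function name `parse_label` and the example filenames are illustrative only.

```python
import os
from typing import Optional

# Mapping mirrored from the extraction cell (short codes map to "normal")
LABEL_MAP = {"asthma": "asthma", "copd": "copd", "n": "normal", "c": "normal", "normal": "normal"}


def parse_label(path: str) -> Optional[str]:
    """Return 'asthma', 'copd', or 'normal', or None when the label is unsupported."""
    filename = os.path.basename(path.replace("\\", "/"))
    parts = filename.split("_")
    if len(parts) < 2:
        return None                               # no "<id>_<Diagnosis>" structure
    token = parts[1].split(",")[0].lower().strip()
    if token in LABEL_MAP:                        # exact match first
        return LABEL_MAP[token]
    for key in ("asthma", "copd"):                # then substring match for compound labels
        if key in token:
            return key
    return None


print(parse_label("/data/BP1_Asthma,I E W,P L L,70,M.wav"))
print(parse_label("EP2_N,N,P R M,18,F.wav"))
print(parse_label("DP3_heart failure,C,P L U,60,M.wav"))
```

Returning `None` for unsupported labels lets the caller skip the file and count it, rather than raising inside the extraction loop.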

## 🛠️ Usage Guide

To run the pipeline, make sure your `filepaths` list contains the full paths to your `.wav` files.

```python
X_data, y_data, scaler = run_enhanced_feature_extraction(filepaths)
joblib.dump(scaler, 'scaler.pkl')  # Save the scaler for consistent inference
```


In [None]:
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
import warnings
from scipy import stats
from scipy.signal import hilbert
import pywt
from sklearn.preprocessing import StandardScaler
import joblib

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

# Enhanced function to add noise
def add_noise(data, noise_level=0.005):
    """Add Gaussian noise to audio data with improved stability."""
    if len(data) == 0:
        return data
    noise = np.random.randn(len(data)) * noise_level
    noisy_data = data + noise
    # Clip to the valid [-1, 1] range
    return np.clip(noisy_data, -1.0, 1.0)

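# Enhanced function to shift the audio
def shift(data, shift_max, sampling_rate):
    """Circularly shift audio by a random offset of at most `shift_max` samples.

    `sampling_rate` is unused and kept only for call compatibility.
    """
    if len(data) == 0:
        return data
    # Cap the shift at a quarter of the signal length
    max_shift = min(shift_max, len(data) // 4)
    if max_shift <= 0:
        # np.random.randint raises ValueError when low >= high
        return data
    offset = np.random.randint(low=-max_shift, high=max_shift + 1)
    return np.roll(data, offset)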

# Enhanced function to stretch the audio
def stretch(data, rate=0.8):
    """Time stretch audio with error handling."""
    try:
        if len(data) == 0:
            return data
        # Ensure rate is within reasonable bounds
        rate = np.clip(rate, 0.5, 2.0)
        stretched = librosa.effects.time_stretch(data, rate=rate)
        return stretched
    except Exception as e:
        print(f"Warning: Time stretch failed, returning original data: {e}")
        return data

# Enhanced function to change pitch
def pitch_shift(data, sampling_rate, n_steps=2):
    """Pitch shift audio with improved error handling."""
    try:
        if len(data) == 0:
            return data
        # Ensure n_steps is within reasonable bounds
        n_steps = np.clip(n_steps, -12, 12)
        shifted = librosa.effects.pitch_shift(data, sr=sampling_rate, n_steps=n_steps)
        return shifted
    except Exception as e:
        print(f"Warning: Pitch shift failed, returning original data: {e}")
        return data

# NEW: Advanced MFCC Features with Delta and Delta-Delta
def extract_advanced_mfcc_features(data, sampling_rate, n_mfcc=40):
    """Extract MFCCs with delta and delta-delta features plus statistical moments."""
    try:
        if len(data) == 0:
            return np.zeros(n_mfcc * 9)  # (MFCC, delta, delta-delta) x (mean, std, skew)

        # Extract MFCCs
        mfccs = librosa.feature.mfcc(
            y=data,
            sr=sampling_rate,
            n_mfcc=n_mfcc,
            n_fft=2048,
            hop_length=512
        )

        if mfccs.shape[1] == 0:
            return np.zeros(n_mfcc * 9)

        # Compute Delta (first derivative) and Delta-Delta (second derivative)
        delta_mfccs = librosa.feature.delta(mfccs)
        delta2_mfccs = librosa.feature.delta(mfccs, order=2)

        # Statistical moments across time for each coefficient
        mfcc_mean = np.mean(mfccs, axis=1)
        mfcc_std = np.std(mfccs, axis=1)
        mfcc_skew = stats.skew(mfccs, axis=1)

        delta_mean = np.mean(delta_mfccs, axis=1)
        delta_std = np.std(delta_mfccs, axis=1)
        delta_skew = stats.skew(delta_mfccs, axis=1)

        delta2_mean = np.mean(delta2_mfccs, axis=1)
        delta2_std = np.std(delta2_mfccs, axis=1)
        delta2_skew = stats.skew(delta2_mfccs, axis=1)

        # Combine all MFCC-based features
        advanced_mfcc_features = np.concatenate([
            mfcc_mean, mfcc_std, mfcc_skew,
            delta_mean, delta_std, delta_skew,
            delta2_mean, delta2_std, delta2_skew
        ])

        # Handle NaN or infinite values
        advanced_mfcc_features = np.nan_to_num(advanced_mfcc_features, nan=0.0, posinf=0.0, neginf=0.0)
        return advanced_mfcc_features

    except Exception as e:
        print(f"Warning: Advanced MFCC extraction failed: {e}")
        return np.zeros(n_mfcc * 9)  # 9 = 3 features * 3 statistics

# NEW: Fourier-Bessel Spectral Entropy (FBSE)
def extract_fbse_features(data, sampling_rate, n_bands=10):
    """Extract band-wise entropy features (named FBSE in this pipeline).

    Note: the entropy is computed from the STFT power spectrum, not from a
    Fourier-Bessel expansion.
    """
    try:
        if len(data) == 0:
            return np.zeros(n_bands)

        # Power spectrogram, shape (n_freq_bins, n_frames)
        psd = np.abs(librosa.stft(data)) ** 2

        # Divide frequency range into bands
        freq_bands = np.linspace(0, sampling_rate//2, n_bands + 1)
        entropy_features = []

        for i in range(n_bands):
            # Get frequency band indices
            start_idx = int(freq_bands[i] * len(psd) / (sampling_rate//2))
            end_idx = int(freq_bands[i+1] * len(psd) / (sampling_rate//2))

            if end_idx > start_idx:
                # Mean power of the band in each frame, shape (n_frames,)
                band_psd = np.mean(psd[start_idx:end_idx], axis=0)
                # Normalize across frames to form a probability distribution
                band_psd_norm = band_psd / (np.sum(band_psd) + 1e-10)
                # Shannon entropy of the band's power distribution over time
                entropy = -np.sum(band_psd_norm * np.log(band_psd_norm + 1e-10))
                entropy_features.append(float(entropy))
            else:
                entropy_features.append(0.0)

        fbse_features = np.array(entropy_features)
        fbse_features = np.nan_to_num(fbse_features, nan=0.0, posinf=0.0, neginf=0.0)
        return fbse_features

    except Exception as e:
        print(f"Warning: FBSE extraction failed: {e}")
        return np.zeros(n_bands)

# NEW: Enhanced Mel-Spectrogram with 2D features for CNN
def extract_enhanced_melspectrogram(data, sampling_rate, n_mels=128):
    """Extract enhanced Mel-spectrogram features suitable for 2D CNN processing."""
    try:
        if len(data) == 0:
            return np.zeros(n_mels * 4)  # Statistical features

        # Enhanced mel-spectrogram computation
        mel_spec = librosa.feature.melspectrogram(
            y=data,
            sr=sampling_rate,
            n_mels=n_mels,
            n_fft=2048,
            hop_length=512,
            fmax=sampling_rate//2
        )

        if mel_spec.size == 0:
            return np.zeros(n_mels * 4)

        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        # Extract statistical features from 2D mel-spectrogram
        if mel_spec_db.shape[1] > 0:
            mel_mean = np.mean(mel_spec_db, axis=1)
            mel_std = np.std(mel_spec_db, axis=1)
            mel_max = np.max(mel_spec_db, axis=1)
            mel_min = np.min(mel_spec_db, axis=1)

            # Combine statistical features
            mel_features = np.concatenate([mel_mean, mel_std, mel_max, mel_min])
        else:
            mel_features = np.zeros(n_mels * 4)

        # Handle NaN or infinite values
        mel_features = np.nan_to_num(mel_features, nan=0.0, posinf=0.0, neginf=0.0)
        return mel_features

    except Exception as e:
        print(f"Warning: Enhanced Mel-spectrogram extraction failed: {e}")
        return np.zeros(n_mels * 4)

# NEW: Wavelet Features for transient detection
def extract_wavelet_features(data, wavelet='db4', levels=5):
    """Extract wavelet features for transient detection."""
    try:
        if len(data) == 0:
            return np.zeros((levels + 1) * 4)  # wavedec returns levels + 1 arrays, 4 stats each

        # Perform wavelet decomposition
        coeffs = pywt.wavedec(data, wavelet, level=levels)

        wavelet_features = []
        for coeff in coeffs:
            if len(coeff) > 0:
                # Statistical features for each decomposition level
                wavelet_features.extend([
                    np.mean(np.abs(coeff)),  # Mean absolute value
                    np.std(coeff),           # Standard deviation
                    np.max(np.abs(coeff)),   # Maximum absolute value
                    np.sum(coeff**2)         # Energy
                ])
            else:
                wavelet_features.extend([0.0, 0.0, 0.0, 0.0])

        wavelet_features = np.array(wavelet_features)
        wavelet_features = np.nan_to_num(wavelet_features, nan=0.0, posinf=0.0, neginf=0.0)
        return wavelet_features

    except Exception as e:
        print(f"Warning: Wavelet feature extraction failed: {e}")
        return np.zeros((levels + 1) * 4)

# NEW: Sequence-based features for transformer/attention models
def extract_sequence_features(data, sampling_rate, frame_length=2048, hop_length=512):
    """Extract sequence-based features suitable for attention mechanisms."""
    try:
        if len(data) == 0:
            return np.zeros(26)  # Summary statistics

        # Extract frame-wise MFCCs for sequence modeling
        mfccs = librosa.feature.mfcc(
            y=data,
            sr=sampling_rate,
            n_mfcc=13,
            n_fft=frame_length,
            hop_length=hop_length
        )

        if mfccs.shape[1] == 0:
            return np.zeros(26)

        # Temporal dynamics features
        # 1. Frame-to-frame variation
        frame_variations = np.mean(np.abs(np.diff(mfccs, axis=1)), axis=1)

        # 2. Long-term average of each coefficient
        long_term_mean = np.mean(mfccs, axis=1)

        # Combine sequence features
        sequence_features = np.concatenate([frame_variations, long_term_mean])
        sequence_features = np.nan_to_num(sequence_features, nan=0.0, posinf=0.0, neginf=0.0)
        return sequence_features

    except Exception as e:
        print(f"Warning: Sequence feature extraction failed: {e}")
        return np.zeros(26)

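# Enhanced Fourier-Bessel Feature Extraction (Original)
def fourier_bessel_features(data, sampling_rate, n_coeff):
    """Project the signal onto `n_coeff` cosine basis functions over normalized time.

    Note: despite the name, this uses cosines rather than Bessel functions, so
    the outputs are cosine-series coefficients.
    """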
    if len(data) == 0:
        return np.zeros(n_coeff)

    t = np.arange(len(data)) / sampling_rate
    fb_coeff = np.zeros(n_coeff)

    # Normalize time for better numerical stability
    t_norm = t / np.max(t) if np.max(t) > 0 else t

    for i in range(n_coeff):
        j = i + 1
        # Project the signal onto the j-th cosine and average over samples
        cosine_term = np.cos(2 * np.pi * j * t_norm)
        fb_coeff[i] = np.sum(data * cosine_term) / len(data)

    # Handle NaN or infinite values
    fb_coeff = np.nan_to_num(fb_coeff, nan=0.0, posinf=0.0, neginf=0.0)
    return fb_coeff


# IMPROVED: Spectral Features with better tonnetz handling
def extract_spectral_features(data, sampling_rate):
    """Extract spectral features with improved tonnetz handling for low sampling rates."""
    try:
        if len(data) == 0:
            return np.zeros(7)

        # Basic spectral features (these work fine with any sampling rate)
        spectral_centroid = librosa.feature.spectral_centroid(y=data, sr=sampling_rate)
        spectral_bandwidth = librosa.feature.spectral_bandwidth(y=data, sr=sampling_rate)
        spectral_rolloff = librosa.feature.spectral_rolloff(y=data, sr=sampling_rate)
        spectral_flatness = librosa.feature.spectral_flatness(y=data)
        zero_crossing_rate = librosa.feature.zero_crossing_rate(data)

        # Chroma features (work with any sampling rate)
        chroma = librosa.feature.chroma_stft(y=data, sr=sampling_rate)
        chroma_mean = np.mean(chroma)

        # IMPROVED: Tonnetz with proper frequency limit handling
        try:
            # Tonnetz needs sufficient bandwidth below the Nyquist frequency,
            # so fall back to an approximation (or skip it) at low sampling rates

            if sampling_rate >= 8000:  # Safe threshold for full tonnetz
                tonnetz = librosa.feature.tonnetz(y=data, sr=sampling_rate)
                tonnetz_mean = np.mean(tonnetz)
            elif sampling_rate >= 4000:  # Limited tonnetz for medium sampling rates
                # Use chromagram-based approach for lower sampling rates
                chroma_cqt = librosa.feature.chroma_cqt(
                    y=data,
                    sr=sampling_rate,
                    fmin=librosa.note_to_hz('C1'),
                    n_chroma=12
                )
                # Approximate tonnetz using chroma features
                tonnetz_mean = np.mean(chroma_cqt) * 0.5  # Scale factor to approximate tonnetz range
            else:  # Very low sampling rates - skip tonnetz
                print(f"Info: Skipping tonnetz for sampling rate {sampling_rate}Hz (too low)")
                tonnetz_mean = 0.0

        except Exception as e:
            # Suppress the specific Nyquist frequency warning since we handle it
            if "Nyquist frequency" not in str(e):
                print(f"Warning: Tonnetz feature failed: {e}")
            tonnetz_mean = 0.0

        spectral_features = np.array([
            np.mean(spectral_centroid),
            np.mean(spectral_bandwidth),
            np.mean(spectral_rolloff),
            np.mean(spectral_flatness),
            np.mean(zero_crossing_rate),
            chroma_mean,
            tonnetz_mean
        ])

        spectral_features = np.nan_to_num(spectral_features, nan=0.0, posinf=0.0, neginf=0.0)
        return spectral_features

    except Exception as e:
        print(f"Warning: Spectral feature extraction failed: {e}")
        return np.zeros(7)


# Enhanced Feature Extraction Function with all new features
def feature_extraction(dir_):
    """Enhanced feature extraction with all advanced features for maximum accuracy."""
    X_Features = []
    y_Labels = []
    X_Sequences = []  # For transformer/attention models

    # Feature dimensions
    n_mfcc = 40
    fb_coeffs = 20
    n_mels = 128
    wavelet_levels = 5
    fbse_bands = 10

    # Statistics tracking
    processed_files = 0
    skipped_files = 0
    augmented_samples = 0

    print("🚀 Starting ADVANCED feature extraction for ASTHMA, COPD, and NORMAL classes...")
    print("📊 Features being extracted:")
    print("   • Advanced MFCCs with Delta & Delta-Delta + Statistical Moments")
    print("   • Fourier-Bessel Spectral Entropy (FBSE)")
    print("   • Enhanced Mel-Spectrograms (128 bins) for 2D CNN")
    print("   • Wavelet Features for Transient Detection")
    print("   • Sequence Features for Attention Mechanisms")
    print("   • Comprehensive Spectral Features (with fixed tonnetz)")
    print("   • Original Fourier-Bessel Coefficients")
    print("   • Target Classes: ASTHMA, COPD, NORMAL")

    for soundDir in dir_:
        try:
            # Extract the disease name from the filename ("<id>_<Diagnosis>,...")
            try:
                filename = soundDir.split('/')[-1] if '/' in soundDir else soundDir.split('\\')[-1]
                parts = filename.split('_')
                if len(parts) < 2:
                    print(f"⚠️  Invalid filename format: {filename}")
                    skipped_files += 1
                    continue

                disease_part = parts[1].split(',')[0].lower().strip()

                # Disease mapping for ASTHMA, COPD, and NORMAL only
                disease_mapping = {
                    'asthma': 'asthma',
                    'copd': 'copd',
                    'n': 'normal',
                    'c': 'normal',
                    'normal': 'normal'
                }

                # Exact match first, then substring match for compound labels
                if disease_part in disease_mapping:
                    disease = disease_mapping[disease_part]
                elif 'asthma' in disease_part:
                    disease = 'asthma'
                elif 'copd' in disease_part:
                    disease = 'copd'
                else:
                    print(f"⚠️  Skipping unsupported disease label '{disease_part}' in: {filename}")
                    print("    Only processing: asthma, copd, normal")
                    skipped_files += 1
                    continue

            except Exception as e:
                print(f"⚠️  Failed to parse filename {soundDir}: {e}")
                skipped_files += 1
                continue

            # Disease validation for ASTHMA, COPD, and NORMAL only
            valid_diseases = ["asthma", "copd", "normal"]
            if disease not in valid_diseases:
                print(f"⚠️  Skipping invalid disease label '{disease}' in: {soundDir}")
                print(f"    Only processing: {valid_diseases}")
                skipped_files += 1
                continue

            # Enhanced audio loading with error handling
            # (sr=None keeps each file's native sampling rate)
            try:
                data, sampling_rate = librosa.load(soundDir, sr=None)
                if len(data) == 0:
                    print(f"⚠️  Empty audio file: {soundDir}")
                    skipped_files += 1
                    continue
            except Exception as e:
                print(f"❌ Failed to load audio file {soundDir}: {e}")
                skipped_files += 1
                continue

            # EXTRACT ALL ADVANCED FEATURES
            print(f"🔄 Processing: {filename[:50]}... (SR: {sampling_rate}Hz)")

            # 1. Advanced MFCCs with Delta and Delta-Delta
            advanced_mfcc_features = extract_advanced_mfcc_features(data, sampling_rate, n_mfcc)

            # 2. Fourier-Bessel Spectral Entropy (FBSE)
            fbse_features = extract_fbse_features(data, sampling_rate, fbse_bands)

            # 3. Enhanced Mel-Spectrogram Features
            enhanced_mel_features = extract_enhanced_melspectrogram(data, sampling_rate, n_mels)

            # 4. Wavelet Features
            wavelet_features = extract_wavelet_features(data, 'db4', wavelet_levels)

            # 5. Sequence Features
            sequence_features = extract_sequence_features(data, sampling_rate)

            # 6. FIXED: Spectral Features (with proper tonnetz handling)
            spectral_features = extract_spectral_features(data, sampling_rate)

            # 7. Original Fourier-Bessel Features
            fb_features = fourier_bessel_features(data, sampling_rate, fb_coeffs)

            # COMBINE ALL FEATURES
            combined_features = np.concatenate([
                advanced_mfcc_features,  # Advanced MFCCs with deltas + stats
                fbse_features,           # Fourier-Bessel Spectral Entropy
                enhanced_mel_features,   # Enhanced Mel-Spectrogram
                wavelet_features,        # Wavelet features
                sequence_features,       # Sequence features
                spectral_features,       # Spectral features (fixed)
                fb_features             # Original FB features
            ])

            # Final validation
            combined_features = np.nan_to_num(combined_features, nan=0.0, posinf=0.0, neginf=0.0)

            # Append Original Features
            X_Features.append(combined_features)
            y_Labels.append(disease)

            # **ENHANCED AUGMENTATION with all feature types**
            augmentations = [
                (add_noise, 0.002),
                (shift, 1600),
                (stretch, 0.9),
                (pitch_shift, 2)
            ]

            aug_count = 0
            for aug_func, aug_param in augmentations:
                try:
                    # Enhanced augmentation with proper parameter passing
                    if aug_func == shift:
                        data_aug = aug_func(data, aug_param, sampling_rate)
                    elif aug_func == pitch_shift:
                        data_aug = aug_func(data, sampling_rate, aug_param)
                    else:
                        data_aug = aug_func(data, aug_param)

                    # Validate augmented data
                    if len(data_aug) == 0 or np.all(data_aug == 0):
                        continue

                    # Extract ALL features for augmented data
                    advanced_mfcc_aug = extract_advanced_mfcc_features(data_aug, sampling_rate, n_mfcc)
                    fbse_aug = extract_fbse_features(data_aug, sampling_rate, fbse_bands)
                    enhanced_mel_aug = extract_enhanced_melspectrogram(data_aug, sampling_rate, n_mels)
                    wavelet_aug = extract_wavelet_features(data_aug, 'db4', wavelet_levels)
                    sequence_aug = extract_sequence_features(data_aug, sampling_rate)
                    spectral_aug = extract_spectral_features(data_aug, sampling_rate)  # Fixed version
                    fb_aug = fourier_bessel_features(data_aug, sampling_rate, fb_coeffs)

                    # Combine all augmented features
                    combined_features_aug = np.concatenate([
                        advanced_mfcc_aug, fbse_aug, enhanced_mel_aug,
                        wavelet_aug, sequence_aug, spectral_aug, fb_aug
                    ])

                    combined_features_aug = np.nan_to_num(combined_features_aug, nan=0.0, posinf=0.0, neginf=0.0)

                    # Append Augmented Data
                    X_Features.append(combined_features_aug)
                    y_Labels.append(disease)
                    aug_count += 1

                except Exception as e:
                    print(f"Warning: Augmentation {aug_func.__name__} failed: {e}")
                    continue

            processed_files += 1
            augmented_samples += aug_count

            if processed_files % 25 == 0:  # Progress update every 25 files
                print(f"✅ Processed {processed_files} files so far...")

        except Exception as e:
            print(f"❌ Error processing {soundDir}: {e}")
            skipped_files += 1
            continue

    # Enhanced final validation and conversion
    if len(X_Features) == 0:
        print("❌ No features extracted!")
        # Return three values so callers can always unpack (X, y, scaler)
        return np.array([]), np.array([]), None

    X_data = np.array(X_Features)
    y_data = np.array(y_Labels)

    # Feature normalization for better model performance
    scaler = StandardScaler()
    X_data_normalized = scaler.fit_transform(X_data)

    # Final statistics and validation
    print(f"\n🎯 ADVANCED Feature Extraction Summary (ASTHMA, COPD, NORMAL):")
    print(f"   • Successfully processed files: {processed_files}")
    print(f"   • Skipped files: {skipped_files}")
    print(f"   • Original samples: {processed_files}")
    print(f"   • Augmented samples: {augmented_samples}")
    print(f"   • Total samples: {len(X_Features)}")
    print(f"   • Feature dimensionality: {X_data.shape[1]}")
    print(f"   • Features normalized: ✅")
    print(f"   • Classes: ASTHMA, COPD, NORMAL")

    # Detailed feature breakdown
    feature_breakdown = {
        'Advanced MFCCs (with deltas & stats)': n_mfcc * 9,
        'FBSE Features': fbse_bands,
        'Enhanced Mel-Spectrogram': n_mels * 4,
        'Wavelet Features': wavelet_levels * 4 + 4,  # +4 for approximation coeffs
        'Sequence Features': 26,
        'Spectral Features (Fixed)': 7,
        'Original FB Features': fb_coeffs
    }

    print(f"\n📊 Feature Type Breakdown:")
    total_expected = 0
    for feature_type, count in feature_breakdown.items():
        print(f"   • {feature_type}: {count} features")
        total_expected += count
    print(f"   • Total Expected: {total_expected}")
    print(f"   • Actual Total: {X_data.shape[1]}")

    # Class distribution
    unique_labels, counts = np.unique(y_data, return_counts=True)
    print(f"\n🏷️  Class Distribution:")
    for label, count in zip(unique_labels, counts):
        percentage = (count / len(y_data)) * 100
        print(f"   • {label}: {count} samples ({percentage:.1f}%)")

    # Feature quality validation on the matrix that is actually returned
    nan_count = np.sum(np.isnan(X_data_normalized))
    inf_count = np.sum(np.isinf(X_data_normalized))
    if nan_count > 0 or inf_count > 0:
        print(f"⚠️  Found {nan_count} NaN and {inf_count} infinite values (cleaned)")
        X_data_normalized = np.nan_to_num(X_data_normalized, nan=0.0, posinf=0.0, neginf=0.0)

    print(f"\n🎉 ADVANCED feature extraction completed successfully!")
    print(f"✅ Final Feature Matrix Shape: {X_data_normalized.shape}")
    print(f"✅ Final Label Vector Shape: {y_data.shape}")
    print(f"🎯 Ready for model training!")

    return X_data_normalized, y_data, scaler

# Usage example with error handling
def run_enhanced_feature_extraction(filepaths):
    """Run the enhanced feature extraction with comprehensive error handling."""
    try:
        X_data, y_data, scaler = feature_extraction(filepaths)

        if len(X_data) == 0:
            print("❌ No features were extracted. Please check your file paths and audio files.")
            return None, None, None

        print(f"\n🔥 FEATURE EXTRACTION COMPLETE!")
        print(f"💪 Enhanced features ready for machine learning models:")
        print(f"   • Traditional ML: Use X_data directly")
        print(f"   • Deep Learning: Consider reshaping for CNN/RNN architectures")
        print(f"   • Transformer Models: Use sequence-based features")

        return X_data, y_data, scaler

    except NameError as e:
        print(f"❌ Missing required variables: {e}")
        print("Please ensure 'filepaths' variable is defined with your audio file paths.")
        return None, None, None
    except Exception as e:
        print(f"❌ Unexpected error during feature extraction: {e}")
        return None, None, None

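# Run the extraction (requires `filepaths` to be defined) and persist the fitted scaler
import joblib

X_data, y_data, scaler = run_enhanced_feature_extraction(filepaths)
if scaler is not None:  # extraction returns None values on failure
    joblib.dump(scaler, 'scaler.pkl')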
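
`feature_extraction` fits `StandardScaler` on every sample, including the augmented copies, before any train/test split. If the returned matrix is split afterwards, test-set statistics have already influenced the scaling, and augmented copies of one recording can end up on both sides of the split. The sketch below shows one way to avoid both effects. It assumes access to the unscaled feature matrix and to a hypothetical `groups` array holding one id per source recording; neither is returned by the function above, and the toy data only stands in for real features.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit
from sklearn.preprocessing import StandardScaler

# Toy stand-in data: 6 recordings, each with 1 original + 4 augmented samples.
rng = np.random.default_rng(0)
n_recordings, copies, n_features = 6, 5, 8
X_raw = rng.normal(size=(n_recordings * copies, n_features))
y = np.repeat(["asthma", "copd", "normal"], 2 * copies)

# Hypothetical group ids: every augmented copy shares the id of its source recording.
groups = np.repeat(np.arange(n_recordings), copies)

# Split by recording so no augmented copy of a training file reaches the test set.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.33, random_state=42)
train_idx, test_idx = next(splitter.split(X_raw, y, groups=groups))

# Fit the scaler on training rows only, then apply it to both partitions.
scaler = StandardScaler().fit(X_raw[train_idx])
X_train = scaler.transform(X_raw[train_idx])
X_test = scaler.transform(X_raw[test_idx])

print("train recordings:", sorted(set(groups[train_idx].tolist())))
print("test recordings:", sorted(set(groups[test_idx].tolist())))
```

With real data, `groups` could be collected inside the extraction loop by appending the file index alongside each feature vector; that bookkeeping is an assumption and not part of the original function.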

In [None]:
DISEASE_COLORS = {
    'normal': '#2E8B57',     # Sea Green
    'asthma': '#FF6B6B',     # Coral Red
    'copd': '#4ECDC4',       # Turquoise
    'pneumonia': '#FFE66D',  # Golden Yellow
    'bronchitis': '#A8E6CF'  # Light Green
}
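
The plotting code looks colors up with `DISEASE_COLORS.get(disease, '#636EFA')`, so a label that is not a palette key (for example `'healthy'` instead of `'normal'`) silently falls back to the default color. A minimal sketch of a lookup that normalizes such labels first; `LABEL_ALIASES` and `color_for` are hypothetical names introduced here for illustration and do not exist in the notebook.

```python
# Palette copied from the cell above so the snippet runs on its own.
DISEASE_COLORS = {
    'normal': '#2E8B57',
    'asthma': '#FF6B6B',
    'copd': '#4ECDC4',
    'pneumonia': '#FFE66D',
    'bronchitis': '#A8E6CF',
}

# Hypothetical alias table: alternative spellings mapped onto palette keys.
LABEL_ALIASES = {'healthy': 'normal'}
DEFAULT_COLOR = '#636EFA'  # same fallback value the plotting functions use


def color_for(label: str) -> str:
    """Return the palette color for a label, resolving aliases and case."""
    key = label.strip().lower()
    key = LABEL_ALIASES.get(key, key)
    return DISEASE_COLORS.get(key, DEFAULT_COLOR)


print(color_for('Healthy'))  # resolves to the 'normal' entry
print(color_for('copd'))
print(color_for('unknown'))  # falls back to DEFAULT_COLOR
```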

In [None]:
# ═══════════════════════════════════════════════════════════════════════════════════
# 📊 ENHANCED LUNG DISEASE FEATURE VISUALIZATION
# ═══════════════════════════════════════════════════════════════════════════════════
# Advanced visualization system for lung sound feature analysis across different diseases
# Fixed subplot compatibility issues and enhanced three-class comparison
# ═══════════════════════════════════════════════════════════════════════════════════

import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import numpy as np
import pandas as pd

# Disease-specific color palette for consistent visualization
DISEASE_COLORS = {
    'normal': '#2E8B57',     # Sea Green
    'asthma': '#FF6B6B',     # Coral Red
    'copd': '#4ECDC4',       # Turquoise
    'pneumonia': '#FFE66D',  # Golden Yellow
    'bronchitis': '#A8E6CF'  # Light Green
}

# Feature-specific visualization styles
FEATURE_STYLES = {
    'Advanced MFCCs': {'type': 'heatmap', 'colorscale': 'Viridis'},
    'FBSE': {'type': 'radar', 'colorscale': 'Plasma'},
    'Spectral Features': {'type': 'bar', 'colorscale': 'Cividis'},
    'Wavelet Features': {'type': 'line', 'colorscale': 'Turbo'},
    'Sequence Features': {'type': 'scatter', 'colorscale': 'Inferno'}
}

def create_enhanced_feature_visualization(preprocessed_data, sampling_rate, handler):
    """
    🎨 Creates comprehensive multi-graph visualization of lung disease features

    Parameters:
        preprocessed_data: Dictionary of preprocessed audio frames
        sampling_rate: Audio sampling rate
        handler: Feature extraction handler object
    """

    if not preprocessed_data:
        print("❌ Preprocessing data is not available. Cannot generate feature plots.")
        return

    # ───────────────────────────────────────────────────────────────────────────────────
    # 📋 DISEASE SAMPLE SELECTION (Focus on 3 classes)
    # ───────────────────────────────────────────────────────────────────────────────────
    # DISEASE_COLORS and the extraction labels use 'normal'; treat 'healthy' as an alias
    diseases_to_plot = ['normal', 'asthma', 'copd']
    sample_files = {}

    # Select representative samples for each disease
    for filename, frames in preprocessed_data.items():
        if frames.shape[0] > 0:
            label = handler._extract_label_from_filename(filename).lower()
            if label == 'healthy':
                label = 'normal'
            if label in diseases_to_plot and label not in sample_files:
                sample_files[label] = filename
            if len(sample_files) == len(diseases_to_plot):
                break

    if not sample_files:
        print("❌ Could not find sample files for the specified diseases.")
        return

    print(f"📋 Found samples for diseases: {list(sample_files.keys())}")

    # ───────────────────────────────────────────────────────────────────────────────────
    # 🔧 FEATURE EXTRACTION CONFIGURATION
    # ───────────────────────────────────────────────────────────────────────────────────
    features_config = [
        ('Advanced MFCCs', extract_advanced_mfcc_features, 360, 'heatmap'),
        ('FBSE', extract_fbse_features, 10, 'line'),
        ('Spectral Features', extract_spectral_features, 7, 'bar'),
        ('Wavelet Features', extract_wavelet_features, 24, 'line'),
        ('Sequence Features', extract_sequence_features, 26, 'scatter')
    ]

    # ═══════════════════════════════════════════════════════════════════════════════════
    # 🎨 CREATE MASTER VISUALIZATION DASHBOARD
    # ═══════════════════════════════════════════════════════════════════════════════════

    # Main comparison dashboard
    create_feature_comparison_dashboard(sample_files, preprocessed_data, sampling_rate, features_config)

    # Individual detailed visualizations for each feature type
    for feature_name, feature_func, expected_dim, viz_type in features_config:
        create_individual_feature_visualization(
            sample_files, preprocessed_data, sampling_rate,
            feature_name, feature_func, viz_type
        )

    # Statistical comparison charts
    create_statistical_comparison_charts(sample_files, preprocessed_data, sampling_rate, features_config)

    # Create comprehensive three-class comparison
    create_three_class_comparison(sample_files, preprocessed_data, sampling_rate, features_config)


def create_feature_comparison_dashboard(sample_files, preprocessed_data, sampling_rate, features_config):
    """🏆 Creates comprehensive feature comparison dashboard (FIXED VERSION)"""

    num_diseases = len(sample_files)
    num_features = len(features_config)

    # Create subplot layout with compatible types
    fig = make_subplots(
        rows=num_diseases,
        cols=num_features,
        subplot_titles=[f"{disease.upper()} - {feature_name}"
                       for disease in sample_files.keys()
                       for feature_name, _, _, _ in features_config],
        horizontal_spacing=0.08,
        vertical_spacing=0.15
    )

    for i, (disease, filename) in enumerate(sample_files.items()):
        frames = preprocessed_data[filename]
        if frames.shape[0] > 0:
            signal_to_analyze = frames[0]  # Use first frame
            disease_color = DISEASE_COLORS.get(disease, '#636EFA')

            for j, (feature_name, feature_func, expected_dim, viz_type) in enumerate(features_config):
                try:
                    # Extract features based on function requirements
                    if feature_func == extract_wavelet_features:
                        feature_vector = feature_func(signal_to_analyze)
                    else:
                        feature_vector = feature_func(signal_to_analyze, sampling_rate)

                    if feature_vector.size > 0:
                        plot_data = prepare_plot_data(feature_vector)

                        # Create compatible visualizations for xy subplots
                        if 'mfcc' in feature_name.lower():
                            # Bar chart for MFCCs (compatible with xy subplot)
                            add_enhanced_bar_trace(fig, plot_data[:20], disease_color, f'{disease} MFCC', i+1, j+1)

                        elif 'spectral' in feature_name.lower():
                            # Regular bar chart for spectral features (compatible)
                            add_enhanced_bar_trace(fig, plot_data, disease_color, f'{disease} Spectral', i+1, j+1)

                        elif 'wavelet' in feature_name.lower():
                            # Line plot for wavelet features (compatible)
                            add_enhanced_line_trace(fig, plot_data, disease_color, f'{disease} Wavelet', i+1, j+1)

                        elif 'fbse' in feature_name.lower():
                            # Line plot for FBSE features
                            add_enhanced_line_trace(fig, plot_data, disease_color, f'{disease} FBSE', i+1, j+1)

                        else:
                            # Scatter plot for sequence features
                            add_enhanced_scatter_trace(fig, plot_data, disease_color, f'{disease} {feature_name}', i+1, j+1)

                except Exception as e:
                    print(f"❌ Error processing {feature_name} for {disease}: {e}")
                    # Add error placeholder
                    fig.add_trace(
                        go.Scatter(
                            x=[0], y=[0],
                            mode='text',
                            text=[f"Error processing<br>{feature_name}"],
                            textfont=dict(color='red', size=10),
                            showlegend=False
                        ),
                        row=i+1, col=j+1
                    )

    # Enhanced layout styling
    fig.update_layout(
        height=num_diseases * 350,
        width=num_features * 280,
        title={
            'text': "🫁 Lung Disease Feature Analysis: Normal vs COPD vs Asthma",
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 22, 'color': '#2C3E50'}
        },
        showlegend=False,
        plot_bgcolor='rgba(248, 249, 250, 0.8)',
        paper_bgcolor='white',
        font=dict(family="Arial, sans-serif", size=10)
    )

    # Update subplot axes
    for i in range(1, num_diseases + 1):
        for j in range(1, num_features + 1):
            fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray', row=i, col=j)
            fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray', row=i, col=j)

    fig.show()


def create_individual_feature_visualization(sample_files, preprocessed_data, sampling_rate, feature_name, feature_func, viz_type):
    """🎯 Creates detailed individual feature visualization for three classes"""

    fig = go.Figure()

    # Extract features for all diseases
    all_features = {}
    for disease, filename in sample_files.items():
        frames = preprocessed_data[filename]
        if frames.shape[0] > 0:
            signal_to_analyze = frames[0]

            try:
                if feature_func == extract_wavelet_features:
                    feature_vector = feature_func(signal_to_analyze)
                else:
                    feature_vector = feature_func(signal_to_analyze, sampling_rate)

                all_features[disease] = prepare_plot_data(feature_vector)
            except Exception as e:
                print(f"❌ Error extracting {feature_name} for {disease}: {e}")
                continue

    # Create comparative visualization
    if all_features:
        for disease, features in all_features.items():
            disease_color = DISEASE_COLORS.get(disease, '#636EFA')

            fig.add_trace(go.Scatter(
                y=features,
                mode='lines+markers',
                name=f'{disease.capitalize()}',
                line=dict(color=disease_color, width=3),
                marker=dict(size=6, opacity=0.8),
                hovertemplate=f'<b>{disease.capitalize()}</b><br>Index: %{{x}}<br>Value: %{{y:.3f}}<extra></extra>'
            ))

    fig.update_layout(
        title=f"📈 {feature_name} - Three-Class Comparison",
        xaxis_title="Feature Index",
        yaxis_title="Feature Value",
        height=500,
        width=800,
        plot_bgcolor='rgba(248, 249, 250, 0.8)',
        paper_bgcolor='white',
        font=dict(family="Arial, sans-serif", size=12),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    fig.show()


def create_statistical_comparison_charts(sample_files, preprocessed_data, sampling_rate, features_config):
    """📊 Creates statistical comparison charts for three classes"""

    # Feature statistics comparison
    stats_fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=['Feature Magnitudes Comparison', 'Feature Distributions', 'Disease Separation', 'Feature Statistics'],
        specs=[[{"type": "bar"}, {"type": "box"}],
               [{"type": "scatter"}, {"type": "bar"}]]
    )

    # Collect all feature statistics
    disease_stats = {}
    feature_means = {}

    for disease, filename in sample_files.items():
        frames = preprocessed_data[filename]
        if frames.shape[0] > 0:
            signal_to_analyze = frames[0]
            disease_features = []

            for feature_name, feature_func, _, _ in features_config:
                try:
                    if feature_func == extract_wavelet_features:
                        feature_vector = feature_func(signal_to_analyze)
                    else:
                        feature_vector = feature_func(signal_to_analyze, sampling_rate)

                    processed_features = prepare_plot_data(feature_vector)[:15]  # Take first 15 features
                    disease_features.extend(processed_features)
                except Exception:  # skip failing feature types without swallowing KeyboardInterrupt
                    continue

            disease_stats[disease] = np.array(disease_features) if disease_features else np.array([0])
            feature_means[disease] = np.mean(disease_features) if disease_features else 0

    # Add statistical visualizations
    if disease_stats:
        diseases = list(disease_stats.keys())

        # 1. Feature magnitude comparison (first 10 features)
        for i, (disease, features) in enumerate(disease_stats.items()):
            stats_fig.add_trace(
                go.Bar(
                    x=[f'F{j+1}' for j in range(min(10, len(features)))],
                    y=features[:10],
                    name=f'{disease.capitalize()}',
                    marker_color=DISEASE_COLORS.get(disease, '#636EFA'),
                    opacity=0.8,
                    offsetgroup=i
                ),
                row=1, col=1
            )

        # 2. Feature distributions (box plot)
        for disease, features in disease_stats.items():
            stats_fig.add_trace(
                go.Box(
                    y=features,
                    name=f'{disease.capitalize()}',
                    marker_color=DISEASE_COLORS.get(disease, '#636EFA'),
                    boxpoints='outliers'
                ),
                row=1, col=2
            )

        # 3. Disease separation visualization
        x_coords = [np.mean(disease_stats[d]) for d in diseases]
        y_coords = [np.std(disease_stats[d]) for d in diseases]

        stats_fig.add_trace(
            go.Scatter(
                x=x_coords,
                y=y_coords,
                mode='markers+text',
                text=[d.capitalize() for d in diseases],
                textposition='top center',
                marker=dict(
                    size=25,
                    color=[DISEASE_COLORS.get(d, '#636EFA') for d in diseases],
                    opacity=0.8,
                    line=dict(width=3, color='white')
                ),
                name='Disease Clusters',
                showlegend=False
            ),
            row=2, col=1
        )

        # 4. Summary statistics
        means = [np.mean(disease_stats[d]) for d in diseases]
        stds = [np.std(disease_stats[d]) for d in diseases]

        stats_fig.add_trace(
            go.Bar(
                x=[f'{d.capitalize()} Mean' for d in diseases],
                y=means,
                name='Mean Values',
                marker_color=[DISEASE_COLORS.get(d, '#636EFA') for d in diseases],
                opacity=0.8
            ),
            row=2, col=2
        )

    stats_fig.update_layout(
        height=800,
        width=1200,
        title={
            'text': "📈 Statistical Analysis: Normal vs COPD vs Asthma",
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 20, 'color': '#2C3E50'}
        },
        plot_bgcolor='rgba(248, 249, 250, 0.8)',
        paper_bgcolor='white',
        showlegend=True
    )

    # Update subplot labels
    stats_fig.update_xaxes(title_text="Features", row=1, col=1)
    stats_fig.update_yaxes(title_text="Magnitude", row=1, col=1)
    stats_fig.update_yaxes(title_text="Feature Values", row=1, col=2)
    stats_fig.update_xaxes(title_text="Mean Feature Value", row=2, col=1)
    stats_fig.update_yaxes(title_text="Standard Deviation", row=2, col=1)
    stats_fig.update_yaxes(title_text="Mean Value", row=2, col=2)

    stats_fig.show()


def create_three_class_comparison(sample_files, preprocessed_data, sampling_rate, features_config):
    """🎯 Creates comprehensive three-class comparison visualization"""

    # Create a comprehensive comparison figure
    fig = make_subplots(
        rows=2, cols=3,
        subplot_titles=['Feature Radar Comparison', 'Distribution Comparison', 'Feature Correlation',
                       'Classification Boundaries', 'Feature Importance', 'Summary Statistics'],
        specs=[[{"type": "scatterpolar"}, {"type": "violin"}, {"type": "heatmap"}],
               [{"type": "scatter"}, {"type": "bar"}, {"type": "table"}]]
    )

    # Extract comprehensive features for all diseases
    all_disease_features = {}
    feature_names = []

    for disease, filename in sample_files.items():
        frames = preprocessed_data[filename]
        if frames.shape[0] > 0:
            signal_to_analyze = frames[0]
            combined_features = []

            for i, (feature_name, feature_func, _, _) in enumerate(features_config):
                try:
                    if feature_func == extract_wavelet_features:
                        feature_vector = feature_func(signal_to_analyze)
                    else:
                        feature_vector = feature_func(signal_to_analyze, sampling_rate)

                    processed_features = prepare_plot_data(feature_vector)
                    # Take representative features from each type
                    if len(processed_features) > 0:
                        combined_features.append(np.mean(processed_features))  # Mean of each feature type
                        if disease == list(sample_files.keys())[0]:  # Only add names once
                            feature_names.append(feature_name)
                except Exception as e:
                    print(f"❌ Error in comprehensive analysis for {feature_name}: {e}")
                    continue

            all_disease_features[disease] = np.array(combined_features)

    if all_disease_features and len(feature_names) > 0:
        # 1. Radar chart comparison
        for disease, features in all_disease_features.items():
            fig.add_trace(go.Scatterpolar(
                r=features,
                theta=feature_names,
                fill='toself',
                name=f'{disease.capitalize()}',
                line_color=DISEASE_COLORS.get(disease, '#636EFA'),
                opacity=0.6
            ), row=1, col=1)

        # 2. Distribution comparison
        for disease, features in all_disease_features.items():
            fig.add_trace(go.Violin(
                y=features,
                name=f'{disease.capitalize()}',
                box_visible=True,
                meanline_visible=True,
                fillcolor=DISEASE_COLORS.get(disease, '#636EFA'),
                opacity=0.6
            ), row=1, col=2)

        # 3. Feature correlation matrix
        diseases = list(all_disease_features.keys())
        if len(diseases) >= 2:
            feature_matrix = np.array([all_disease_features[d] for d in diseases])
            correlation_matrix = np.corrcoef(feature_matrix)

            fig.add_trace(go.Heatmap(
                z=correlation_matrix,
                x=[d.capitalize() for d in diseases],
                y=[d.capitalize() for d in diseases],
                colorscale='RdBu',
                zmid=0,
                showscale=True
            ), row=1, col=3)

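        # 4. Class separation: mean vs. standard deviation of each class's features
        #    (a rough two-number summary per class, not PCA or a fitted decision boundary)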
        means = [np.mean(all_disease_features[d]) for d in diseases]
        stds = [np.std(all_disease_features[d]) for d in diseases]

        fig.add_trace(go.Scatter(
            x=means,
            y=stds,
            mode='markers+text',
            text=[d.capitalize() for d in diseases],
            textposition='top center',
            marker=dict(
                size=30,
                color=[DISEASE_COLORS.get(d, '#636EFA') for d in diseases],
                opacity=0.8,
                line=dict(width=3, color='white')
            ),
            name='Disease Separation'
        ), row=2, col=1)

        # 5. Feature importance (variance)
        feature_importance = []
        for i in range(len(feature_names)):
            variance = np.var([all_disease_features[d][i] for d in diseases if i < len(all_disease_features[d])])
            feature_importance.append(variance)

        fig.add_trace(go.Bar(
            x=feature_names,
            y=feature_importance,
            marker_color='rgba(55, 128, 191, 0.8)',
            name='Feature Importance'
        ), row=2, col=2)

        # 6. Summary table
        table_data = []
        for disease in diseases:
            features = all_disease_features[disease]
            table_data.append([
                disease.capitalize(),
                f"{np.mean(features):.3f}",
                f"{np.std(features):.3f}",
                f"{np.max(features):.3f}",
                f"{np.min(features):.3f}"
            ])

        fig.add_trace(go.Table(
            header=dict(values=['Disease', 'Mean', 'Std Dev', 'Max', 'Min'],
                       fill_color='lightblue',
                       align='left'),
            cells=dict(values=list(zip(*table_data)),
                      fill_color='white',
                      align='left')
        ), row=2, col=3)

    fig.update_layout(
        height=1000,
        width=1400,
        title={
            'text': "🎯 Comprehensive Three-Class Analysis: Normal vs COPD vs Asthma",
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 24, 'color': '#2C3E50'}
        },
        plot_bgcolor='rgba(248, 249, 250, 0.8)',
        paper_bgcolor='white',
        showlegend=True
    )

    fig.show()


# ═══════════════════════════════════════════════════════════════════════════════════
# 🛠️ UTILITY FUNCTIONS
# ═══════════════════════════════════════════════════════════════════════════════════

def prepare_plot_data(feature_vector, max_points=30):
    """🔧 Prepares feature data for plotting"""
    plot_data = feature_vector.flatten()
    plot_data = np.nan_to_num(plot_data, nan=0.0, posinf=0.0, neginf=0.0)

    if len(plot_data) > max_points:
        indices = np.linspace(0, len(plot_data) - 1, max_points, dtype=int)
        plot_data = plot_data[indices]

    return plot_data


def add_enhanced_bar_trace(fig, data, color, name, row, col):
    """📊 Adds enhanced bar trace with styling"""
    fig.add_trace(
        go.Bar(
            y=data,
            marker=dict(
                color=color,
                opacity=0.8,
                line=dict(color='white', width=1)
            ),
            name=name,
            showlegend=False,
            hovertemplate='<b>%{fullData.name}</b><br>Index: %{x}<br>Value: %{y:.3f}<extra></extra>'
        ),
        row=row, col=col
    )


def add_enhanced_line_trace(fig, data, color, name, row, col):
    """📈 Adds enhanced line trace with styling"""
    fig.add_trace(
        go.Scatter(
            y=data,
            mode='lines+markers',
            line=dict(color=color, width=2),
            marker=dict(size=4, opacity=0.8),
            name=name,
            showlegend=False,
            hovertemplate='<b>%{fullData.name}</b><br>Index: %{x}<br>Value: %{y:.3f}<extra></extra>'
        ),
        row=row, col=col
    )


def add_enhanced_scatter_trace(fig, data, color, name, row, col):
    """🎯 Adds enhanced scatter trace with styling"""
    fig.add_trace(
        go.Scatter(
            y=data,
            mode='markers',
            marker=dict(
                color=color,
                size=6,
                opacity=0.8,
                line=dict(color='white', width=1)
            ),
            name=name,
            showlegend=False,
            hovertemplate='<b>%{fullData.name}</b><br>Index: %{x}<br>Value: %{y:.3f}<extra></extra>'
        ),
        row=row, col=col
    )


# ═══════════════════════════════════════════════════════════════════════════════════
# 🚀 MAIN EXECUTION
# ═══════════════════════════════════════════════════════════════════════════════════

# Execute the enhanced visualization
if 'preprocessed_data' in globals() and preprocessed_data:
    print("🎨 Generating enhanced lung disease feature visualizations...")
    create_enhanced_feature_visualization(preprocessed_data, sampling_rate, handler)
    print("✅ Visualization generation completed!")
else:
    print("❌ Preprocessed data is not available. Cannot generate feature plots.")
    print("📋 Please ensure the following variables are defined:")
    print("   - preprocessed_data: Dictionary of preprocessed audio frames")
    print("   - sampling_rate: Audio sampling rate")
    print("   - handler: Feature extraction handler object")
    print("   - Feature extraction functions: extract_advanced_mfcc_features, extract_fbse_features, etc.")

In [None]:
y_data

# 🧪 SMOTE Class Balancing for Imbalanced Lung Sound Data

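This section applies **SMOTE (Synthetic Minority Over-sampling Technique)** to balance the dataset: minority classes are oversampled with synthetic examples until they match the majority class, which helps keep the model from favoring the most frequent diagnosis.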
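
As a rough illustration of what SMOTE does internally, the sketch below creates one synthetic sample by interpolating between a minority-class point and one of its nearest same-class neighbours. It is a simplified, pure-Python approximation written for this explanation: the function name `smote_like_sample` and the toy points are hypothetical, and the `imblearn` implementation differs in its details.

```python
import math
import random

def smote_like_sample(minority_points, k=2, rng=None):
    """Create one synthetic point between a random minority sample and one of its k nearest neighbours."""
    rng = rng or random.Random(42)
    base = rng.choice(minority_points)
    # Rank the other minority points by Euclidean distance to the base point
    others = [p for p in minority_points if p is not base]
    others.sort(key=lambda p: math.dist(base, p))
    neighbour = rng.choice(others[:k])
    # Interpolate: base + gap * (neighbour - base), with gap drawn from [0, 1)
    gap = rng.random()
    return [b + gap * (n - b) for b, n in zip(base, neighbour)]

# Toy 2-D minority class (hypothetical values, not taken from the dataset)
minority = [[1.0, 1.0], [1.5, 1.2], [0.8, 1.4], [1.2, 0.9]]
synthetic = smote_like_sample(minority)
print(synthetic)
```

Because each synthetic point is a convex combination of two real minority samples, it always lies on the line segment between them.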

---

## 🔧 Required Imports

```python
from imblearn.over_sampling import SMOTE
from sklearn.preprocessing import LabelEncoder
from collections import Counter
import joblib
```


In [None]:


from collections import Counter

import joblib
from imblearn.over_sampling import SMOTE
from sklearn.preprocessing import LabelEncoder

# Check if X_data and y_data exist
if 'X_data' not in locals() or 'y_data' not in locals():
    print("Error: X_data or y_data is not defined. Please run the feature extraction step first.")
else:
    try:
        print("🔄 Applying SMOTE for class balancing...")

        # Encode labels to numerical values for SMOTE
        label_encoder = LabelEncoder()
        y_encoded = label_encoder.fit_transform(y_data)
        # Persist the fitted encoder so predictions can be decoded later
        joblib.dump(label_encoder, 'label_encoder.pkl')

        # Check current class distribution to determine which classes need balancing
        original_counts = Counter(y_encoded)
        print("📊 Original Class Distribution (Encoded):", original_counts)

        # Set up SMOTE
        # Determine sampling strategy: oversample all minority classes
        # You can customize this if you only want to oversample specific classes
        sampling_strategy = 'auto' # Oversample all minority classes to make them equal to the majority class

        smote = SMOTE(sampling_strategy=sampling_strategy, random_state=42)

        # Apply SMOTE
        X_res, y_res_encoded = smote.fit_resample(X_data, y_encoded)

        # Decode the balanced labels back to original strings
        y_res = label_encoder.inverse_transform(y_res_encoded)

        print("✅ SMOTE application complete!")
        print(f"📊 Original dataset shape: {X_data.shape}, {y_data.shape}")
        print(f"📊 Resampled dataset shape: {X_res.shape}, {y_res.shape}")

        # Update the variables with the balanced data
        X_data_balanced = X_res
        y_data_balanced = y_res

        # Print the new class distribution
        balanced_counts = Counter(y_data_balanced)
        print("📊 Balanced Class Distribution:", balanced_counts)

    except ValueError as ve:
        print(f"❌ Error applying SMOTE: {ve}")
        print("This can happen when a class has too few samples: SMOTE needs more than k_neighbors samples per class (default k_neighbors=5).")
        print("Consider removing classes with very few samples before applying SMOTE, or reducing k_neighbors if appropriate.")
    except Exception as e:
        print(f"❌ An unexpected error occurred during SMOTE application: {e}")



In [None]:

# Check if balanced data exists
if 'X_data_balanced' in locals() and 'y_data_balanced' in locals():
    print("\n🎨 Visualizing balanced data distribution after SMOTE...")

    # Create a DataFrame for easier plotting
    df_balanced = pd.DataFrame(X_data_balanced)
    df_balanced['label'] = y_data_balanced

    # Count the occurrences of each label in the balanced data
    balanced_counts = Counter(y_data_balanced)
    labels = list(balanced_counts.keys())
    counts = list(balanced_counts.values())

    # Create a bar chart to visualize the class distribution
    fig_bar = px.bar(
        x=labels,
        y=counts,
        color=labels,  # Use labels for coloring the bars
        color_discrete_map={
            'normal': DISEASE_COLORS.get('normal', '#2E8B57'),
            'asthma': DISEASE_COLORS.get('asthma', '#FF6B6B'),
            'copd': DISEASE_COLORS.get('copd', '#4ECDC4'),
            # Add other classes if necessary, using DISEASE_COLORS or other colors
        },
        labels={'x': 'Disease Label', 'y': 'Number of Samples'},
        title='<b>Balanced Dataset Class Distribution After SMOTE</b>',
        template='plotly_white' # Use a clean template
    )

    # Enhance the layout
    fig_bar.update_layout(
        title_font_size=20,
        xaxis_title_font_size=14,
        yaxis_title_font_size=14,
        uniformtext_minsize=8,
        uniformtext_mode='hide',
        hovermode='x unified' # Group hover info
    )

    # Show the plot
    fig_bar.show()

    print("✅ Visualization of balanced data completed!")

else:
    print("❌ Balanced data (X_data_balanced, y_data_balanced) not found.")
    print("📋 Please ensure SMOTE was applied successfully before attempting to visualize.")


In [None]:
y_data

In [None]:
y_data_balanced

In [None]:
from collections import Counter
from imblearn.over_sampling import SMOTE
from sklearn.preprocessing import LabelEncoder, StandardScaler
import joblib
import numpy as np

In [None]:
x_mfccs=X_data_balanced
y_mfccs = y_data_balanced

In [None]:
x_mfccs

In [None]:
y_mfccs

# 🎧 Interactive MFCC Visualization using Plotly

This function generates a **dynamic MFCC heatmap** using `plotly.express`, ideal for analyzing and visually comparing lung sound signals.
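
The time axis of such a heatmap comes from mapping each MFCC frame index to seconds: frame `i` starts at `i * hop_length / sr`. The sketch below reproduces that arithmetic in plain Python for the `hop_length=512` used in this section. The 4000 Hz sampling rate and 5-second duration are assumed example values, not figures read from the dataset, and the frame count assumes librosa's default centered framing.

```python
def frame_times(n_frames, sr, hop_length=512):
    """Return the start time in seconds of each analysis frame."""
    return [i * hop_length / sr for i in range(n_frames)]

def expected_frame_count(n_samples, hop_length=512):
    """Frame count under centered framing: one frame per hop, plus the first frame."""
    return 1 + n_samples // hop_length

sr = 4000                 # assumed sampling rate in Hz
n_samples = 5 * sr        # assumed 5-second recording
n_frames = expected_frame_count(n_samples)
times = frame_times(n_frames, sr)
print(n_frames, times[:3], times[-1])
```

A lower sampling rate with the same `hop_length` gives fewer, wider frames, which is why short lung sound clips produce heatmaps with relatively few time columns.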

---

## 📦 Required Imports

```python
import os
import librosa
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
```


In [None]:


import plotly.express as px
import plotly.graph_objects as go

def plot_mfcc_advanced(filepath):
    """
    Plots the MFCC features of the audio file using Plotly for interactive visualization.
    Args:
        filepath (str): The path to the audio file.
    """
    try:
        # Load audio
        audio, sr = librosa.load(filepath, sr=None) # Use native sampling rate
        if len(audio) == 0:
            print(f"⚠️  Audio file is empty: {filepath}")
            return

        # Extract MFCCs
        # Using parameters that might be more robust for various audio lengths
        n_mfcc = 40
        n_fft = 2048
        hop_length = 512
        n_mels = 128

        # Handle short audio files
        if len(audio) < n_fft:
            print(f"⚠️  Audio is shorter than one FFT window ({n_fft} samples); zero-padding: {filepath}")
            # Pad the audio if it's too short
            audio = np.pad(audio, (0, n_fft - len(audio)), 'constant')


        mfccs = librosa.feature.mfcc(
            y=audio,
            sr=sr,
            n_mfcc=n_mfcc,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            fmax=sr//2
        )

        # Handle potential empty MFCC array for very short sounds even after padding
        if mfccs.shape[1] == 0:
            print(f"⚠️  MFCC extraction resulted in an empty array for: {filepath}")
            return

        # MFCCs are already computed from a log-scaled (dB) mel spectrogram, so they
        # are plotted directly; applying power_to_db to them would distort the values.

        # Create time and MFCC coefficient index labels
        time_axis = librosa.times_like(mfccs, sr=sr, hop_length=hop_length)
        mfcc_coeffs = [f'MFCC {i+1}' for i in range(n_mfcc)]

        # Create interactive heatmap using Plotly
        fig = px.imshow(mfccs,
                        aspect="auto",
                        x=time_axis,
                        y=mfcc_coeffs,
                        labels=dict(x="Time (s)", y="MFCC Coefficient", color="MFCC Value"),
                        title=f'MFCC Heatmap: {os.path.basename(filepath)}',
                        color_continuous_scale='Viridis') # Choose a pleasant color scale

        # Update layout for better readability
        fig.update_layout(
            title_x=0.5,
            yaxis_title="MFCC Coefficient Index",
            xaxis_title="Time (s)",
            hovermode='closest' # Show tooltip on hover
        )

        # Show plot
        fig.show()

    except FileNotFoundError:
        print(f"❌ File not found: {filepath}")
    except Exception as e:
        print(f"❌ An error occurred during MFCC plotting for {filepath}: {e}")

# Example usage:
# Assuming filepaths is a list of audio file paths from previous code
if 'filepaths' in locals() and filepaths:
    # Select one file to plot
    example_file_to_plot = filepaths[0] # Replace with the path to your desired file
    print(f"Generating Plotly MFCC heatmap for: {os.path.basename(example_file_to_plot)}")
    plot_mfcc_advanced(example_file_to_plot)
else:
    print("⚠️  'filepaths' variable not found or is empty. Cannot plot example MFCC.")
    print("Please ensure the file discovery step has been successfully executed.")


In [None]:
# Plot the remaining features separately, as line plots (waveform-style) or heatmaps

# Function to plot Spectral Centroid
def plot_spectral_centroid(filepath):
    """Plots the Spectral Centroid of the audio file."""
    try:
        audio, sr = librosa.load(filepath, sr=None)
        if len(audio) == 0:
            print(f"⚠️  Audio file is empty: {filepath}")
            return

        spectral_centroids = librosa.feature.spectral_centroid(y=audio, sr=sr)[0]
        time_axis = librosa.times_like(spectral_centroids, sr=sr)

        fig = px.line(x=time_axis, y=spectral_centroids,
                      title=f'Spectral Centroid: {os.path.basename(filepath)}',
                      labels={'x': 'Time (s)', 'y': 'Spectral Centroid (Hz)'})

        fig.update_layout(title_x=0.5)
        fig.show()

    except FileNotFoundError:
        print(f"❌ File not found: {filepath}")
    except Exception as e:
        print(f"❌ An error occurred during Spectral Centroid plotting for {filepath}: {e}")


# Function to plot Zero Crossing Rate
def plot_zero_crossing_rate(filepath):
    """Plots the Zero Crossing Rate of the audio file."""
    try:
        audio, sr = librosa.load(filepath, sr=None)
        if len(audio) == 0:
            print(f"⚠️  Audio file is empty: {filepath}")
            return

        zero_crossings = librosa.feature.zero_crossing_rate(audio)[0]
        time_axis = librosa.times_like(zero_crossings, sr=sr)

        fig = px.line(x=time_axis, y=zero_crossings,
                      title=f'Zero Crossing Rate: {os.path.basename(filepath)}',
                      labels={'x': 'Time (s)', 'y': 'Zero Crossing Rate'})

        fig.update_layout(title_x=0.5)
        fig.show()

    except FileNotFoundError:
        print(f"❌ File not found: {filepath}")
    except Exception as e:
        print(f"❌ An error occurred during Zero Crossing Rate plotting for {filepath}: {e}")

# Function to plot Chroma Features (Heatmap)
def plot_chroma_features(filepath):
    """Plots the Chroma Features of the audio file as a heatmap."""
    try:
        audio, sr = librosa.load(filepath, sr=None)
        if len(audio) == 0:
            print(f"⚠️  Audio file is empty: {filepath}")
            return

        # Ensure sampling rate is sufficient for chroma
        if sr < 8000:
            print(f"⚠️  Sampling rate {sr}Hz is too low for meaningful Chroma features. Skipping: {filepath}")
            return

        chromagram = librosa.feature.chroma_stft(y=audio, sr=sr)
        time_axis = librosa.times_like(chromagram, sr=sr)
        chroma_notes = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']

        fig = px.imshow(chromagram,
                        aspect="auto",
                        x=time_axis,
                        y=chroma_notes,
                        labels=dict(x="Time (s)", y="Chroma Note", color="Intensity"),
                        title=f'Chroma Features: {os.path.basename(filepath)}',
                        color_continuous_scale='Plasma')

        fig.update_layout(title_x=0.5, yaxis_title="Pitch Class")
        fig.show()

    except FileNotFoundError:
        print(f"❌ File not found: {filepath}")
    except Exception as e:
        print(f"❌ An error occurred during Chroma plotting for {filepath}: {e}")

# Function to plot Mel-Spectrogram
def plot_mel_spectrogram(filepath):
    """Plots the Mel-Spectrogram of the audio file as a heatmap."""
    try:
        audio, sr = librosa.load(filepath, sr=None)
        if len(audio) == 0:
            print(f"⚠️  Audio file is empty: {filepath}")
            return

        n_fft = 2048
        hop_length = 512

        # Handle short audio files
        if len(audio) < n_fft:
            print(f"⚠️  Audio file is too short for Mel-Spectrogram: {filepath}")
            audio = np.pad(audio, (0, n_fft - len(audio)), 'constant')

        mel_spectrogram = librosa.feature.melspectrogram(
            y=audio,
            sr=sr,
            n_fft=n_fft,
            hop_length=hop_length
        )

        if mel_spectrogram.size == 0:
            print(f"⚠️  Mel-Spectrogram extraction resulted in an empty array for: {filepath}")
            return


        mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)
        time_axis = librosa.times_like(mel_spectrogram_db, sr=sr, hop_length=hop_length)
        freq_axis = librosa.mel_frequencies(n_mels=mel_spectrogram.shape[0], fmin=0, fmax=sr/2)

        fig = px.imshow(mel_spectrogram_db,
                        aspect="auto",
                        x=time_axis,
                        y=freq_axis,
                        labels=dict(x="Time (s)", y="Mel Frequency (Hz)", color="Amplitude (dB)"),
                        title=f'Mel-Spectrogram: {os.path.basename(filepath)}',
                        color_continuous_scale='Jet') # Another common color scale

        # The axis values come from librosa.mel_frequencies, which returns Hz
        fig.update_layout(title_x=0.5, yaxis_title="Mel-Spaced Frequency (Hz)")
        fig.show()

    except FileNotFoundError:
        print(f"❌ File not found: {filepath}")
    except Exception as e:
        print(f"❌ An error occurred during Mel-Spectrogram plotting for {filepath}: {e}")

# Example usage:
# Assuming filepaths is a list of audio file paths from previous code
if 'filepaths' in locals() and filepaths:
    # Select one file to plot
    example_file_to_plot = filepaths[0] # Replace with the path to your desired file
    print(f"\nGenerating visualizations for other features for: {os.path.basename(example_file_to_plot)}")

    plot_spectral_centroid(example_file_to_plot)
    plot_zero_crossing_rate(example_file_to_plot)
    plot_chroma_features(example_file_to_plot) # May be skipped if SR is too low
    plot_mel_spectrogram(example_file_to_plot)

else:
    print("\n⚠️  'filepaths' variable not found or is empty. Cannot plot example features.")
    print("Please ensure the file discovery step has been successfully executed.")


In [None]:
# Plot the graphs above for one sample each of asthma, COPD, and normal

# Select one file for each of the target diseases (asthma, copd, normal)
# Ensure these files exist in your filepaths list and are correctly labeled

asthma_file = None
copd_file = None
normal_file = None

# Iterate through filepaths to find one sample for each class
if 'filepaths' in locals() and filepaths:
    for f in filepaths:
        filename = os.path.basename(f).lower()
        if 'asthma' in filename and asthma_file is None:
            asthma_file = f
        elif 'copd' in filename and copd_file is None:
            copd_file = f
        elif ('_n_' in filename or '_normal_' in filename or '_c_' in filename) and normal_file is None:
             normal_file = f # Assuming '_n_' or '_normal_' or '_c_' denotes normal
        if asthma_file and copd_file and normal_file:
            break

    # Check if samples were found
    if asthma_file:
        print(f"Generating plots for Asthma sample: {os.path.basename(asthma_file)}")
        plot_mfcc_advanced(asthma_file)
        plot_spectral_centroid(asthma_file)
        plot_zero_crossing_rate(asthma_file)
        plot_chroma_features(asthma_file)
        plot_mel_spectrogram(asthma_file)
    else:
        print("❌ Could not find a sample file for Asthma.")

    if copd_file:
        print(f"\nGenerating plots for COPD sample: {os.path.basename(copd_file)}")
        plot_mfcc_advanced(copd_file)
        plot_spectral_centroid(copd_file)
        plot_zero_crossing_rate(copd_file)
        plot_chroma_features(copd_file)
        plot_mel_spectrogram(copd_file)
    else:
        print("❌ Could not find a sample file for COPD.")

    if normal_file:
        print(f"\nGenerating plots for Normal sample: {os.path.basename(normal_file)}")
        plot_mfcc_advanced(normal_file)
        plot_spectral_centroid(normal_file)
        plot_zero_crossing_rate(normal_file)
        plot_chroma_features(normal_file)
        plot_mel_spectrogram(normal_file)
    else:
        print("❌ Could not find a sample file for Normal.")

else:
    print("⚠️  'filepaths' variable not found or is empty. Cannot plot features.")
    print("Please ensure the file discovery step has been successfully executed.")


In [None]:
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

def plot_augmented_data_distribution(labels):
    """
    Plots the count of each disease in the dataset with enhanced visualization using Plotly.

    Args:
        labels (array-like): List or array of disease labels.

    Returns:
        dict: A dictionary containing the count of each unique disease.
    """
    unique_labels, counts = np.unique(labels, return_counts=True)
    data_count = dict(zip(unique_labels, counts))

    # Map labels for display: 'n' to 'Healthy', others remain as disease names
    display_labels = []
    for label in unique_labels:
        if label == 'n':
            display_labels.append('Healthy')
        else:
            display_labels.append(str(label))

    # Create professional color palette with gradient and complementary colors
    if len(unique_labels) <= 10:
        # Professional color palette for up to 10 categories
        professional_colors = [
            '#2E86AB',  # Ocean Blue
            '#A23B72',  # Deep Rose
            '#F18F01',  # Amber Orange
            '#C73E1D',  # Crimson Red
            '#6A994E',  # Forest Green
            '#7209B7',  # Royal Purple
            '#F77F00',  # Burnt Orange
            '#FCBF49',  # Golden Yellow
            '#003566',  # Navy Blue
            '#06FFA5'   # Mint Green
        ]
        colors = professional_colors[:len(unique_labels)]
    else:
        # For more than 10 categories, use a smooth gradient
        colors = px.colors.sample_colorscale(
            'viridis',
            [i/(len(unique_labels)-1) for i in range(len(unique_labels))]
        )

    # Create the main bar chart
    fig = go.Figure()

    # Add bar chart with enhanced styling
    fig.add_trace(go.Bar(
        x=display_labels,
        y=counts,
        text=[f'{count}' for count in counts],
        textposition='outside',
        textfont=dict(size=12, color='black', family='Arial Black'),
        marker=dict(
            color=colors,
            line=dict(color='rgba(255, 255, 255, 0.8)', width=2.5),
            opacity=0.85,
            # Add subtle gradient effect
            pattern=dict(
                shape="",
                bgcolor="rgba(255, 255, 255, 0.1)"
            )
        ),
        hovertemplate='<b>%{x}</b><br>' +
                      'Count: %{y}<br>' +
                      '<extra></extra>',
        name='Disease Count'
    ))

    # Calculate statistics
    total_samples = sum(counts)
    unique_classes = len(unique_labels)

    # Update layout with enhanced styling
    fig.update_layout(
        title=dict(
            text="Distribution of Diseases in Augmented Data",
            x=0.5,
            font=dict(size=18, color='#2c3e50', family='Arial Black')
        ),
        xaxis=dict(
            title=dict(
                text="Diseases",
                font=dict(size=14, color='#2c3e50', family='Arial Black')
            ),
            tickfont=dict(size=12, color='#34495e'),
            tickangle=45,
            showgrid=False,
            showline=True,
            linewidth=2,
            linecolor='#bdc3c7'
        ),
        yaxis=dict(
            title=dict(
                text="Count",
                font=dict(size=14, color='#2c3e50', family='Arial Black')
            ),
            tickfont=dict(size=12, color='#34495e'),
            showgrid=True,
            gridwidth=1,
            gridcolor='rgba(189, 195, 199, 0.3)',
            showline=True,
            linewidth=2,
            linecolor='#bdc3c7'
        ),
        plot_bgcolor='rgba(248, 249, 250, 0.95)',
        paper_bgcolor='#FEFEFE',
        font=dict(family='Arial'),
        showlegend=False,
        margin=dict(l=80, r=80, t=100, b=120),
        height=600,
        width=1000
    )

    # Add statistics annotation with professional styling
    fig.add_annotation(
        text=f"<b>Statistics</b><br>Total Samples: {total_samples}<br>Unique Classes: {unique_classes}",
        xref="paper", yref="paper",
        x=0.02, y=0.98,
        xanchor="left", yanchor="top",
        showarrow=False,
        bgcolor="rgba(46, 134, 171, 0.15)",
        bordercolor="#2E86AB",
        borderwidth=2,
        borderpad=12,
        font=dict(size=11, color='#2c3e50', family='Arial Bold')
    )

    # Apply a uniform outline width to all bar traces
    fig.update_traces(
        marker=dict(
            line=dict(width=2),
        ),
        selector=dict(type="bar")
    )

    # Show the interactive plot
    fig.show()

    return data_count

# Example usage:
# labels = np.array(['disease1', 'n', 'disease2', 'n', 'disease1'])  # Your dataset labels
# disease_counts = plot_augmented_data_distribution(labels)
# print(disease_counts)

In [None]:
plot_augmented_data_distribution(y_mfccs)

# 🧬 One-Hot Encoding for Disease Labels (Normal, Asthma, COPD)

This function performs **efficient, vectorized one-hot encoding** of lung disease labels:
- `'normal'` → `[1, 0, 0]`
- `'asthma'` → `[0, 1, 0]`
- `'copd'` → `[0, 0, 1]`
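
For reference, the mapping above can be inverted by taking the index of the 1 in each row. The following is a small pure-Python sketch of the round trip; `CLASS_ORDER`, `one_hot`, and `decode` are illustrative names that are not defined elsewhere in the notebook.

```python
CLASS_ORDER = ['normal', 'asthma', 'copd']  # column order used by the encoding above

def one_hot(label):
    """Encode a single label; unknown labels raise instead of yielding an all-zero row."""
    if label not in CLASS_ORDER:
        raise ValueError(f"Unknown label: {label!r}")
    return [1 if label == name else 0 for name in CLASS_ORDER]

def decode(row):
    """Map a one-hot (or probability) row back to its class name via the largest entry."""
    return CLASS_ORDER[row.index(max(row))]

labels = ['normal', 'copd', 'asthma']
encoded = [one_hot(label) for label in labels]
decoded = [decode(row) for row in encoded]
print(encoded)
print(decoded)
```

The same `decode` logic applies to model outputs later on, since a softmax row can be mapped to a class name by its largest entry.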

---

In [None]:
import numpy as np

def encode_disease_labels(y_data_balanced):
    """
    Efficient one-hot encoding for disease labels using vectorized operations.

    Args:
        y_data_balanced (array-like): Input labels to encode
                                    Expected labels: ['normal', 'asthma', 'copd']

    Returns:
        numpy.ndarray: One-hot encoded labels (n_samples, 3)
                      [1,0,0] for 'normal'
                      [0,1,0] for 'asthma'
                      [0,0,1] for 'copd'
    """
    # Convert to numpy array and flatten
    y_flat = np.array(y_data_balanced).flatten()
    n_samples = len(y_flat)

    # Pre-allocate output array (3 classes, not 4)
    Y_data = np.zeros((n_samples, 3), dtype=np.float64)

    # Vectorized encoding using boolean indexing
    Y_data[y_flat == 'normal', 0] = 1   # [1,0,0]
    Y_data[y_flat == 'asthma', 1] = 1   # [0,1,0]
    Y_data[y_flat == 'copd', 2] = 1     # [0,0,1]

    return Y_data

# Example usage:
# y_data_balanced = ['normal', 'asthma', 'copd', 'normal', 'asthma']
Y_data = encode_disease_labels(y_data_balanced)



In [None]:
X_data_balanced.shape, Y_data.shape

In [None]:
X_data=X_data_balanced

In [None]:
Y_data   #=> normal  , asthma , copd

# 🔄 GRU-Ready Data Split & Preparation Pipeline

This function handles the complete pipeline for preparing the dataset to train a GRU-based deep learning model. It includes:
- Stratified **Train / Validation / Test** splits, with a check that the ratios sum to 1
- Reshaping of the inputs to the `(batch_size, timesteps, features)` layout a GRU expects
- Optional label reshaping that adds a time dimension (leave it disabled for classification)
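
The two-stage split works because the second `test_size` is expressed relative to the held-out portion rather than the whole dataset. The sketch below walks through that arithmetic for the 75 / 17.5 / 7.5 ratios used in this notebook. The sample count of 1,000 is an assumed round number for illustration, the helper names are hypothetical, and actual counts from `train_test_split` may differ by one because of rounding.

```python
def two_stage_split_sizes(n_samples, val_ratio=0.175, test_ratio=0.075):
    """Return the second-stage test_size and approximate (train, val, test) counts."""
    # Stage 1: hold out validation + test together
    holdout_ratio = val_ratio + test_ratio
    n_holdout = round(n_samples * holdout_ratio)
    n_train = n_samples - n_holdout
    # Stage 2: the test fraction is taken relative to the held-out part only
    second_test_size = test_ratio / holdout_ratio
    n_test = round(n_holdout * second_test_size)
    n_val = n_holdout - n_test
    return second_test_size, n_train, n_val, n_test

def gru_input_shape(n_rows, n_features):
    """Shape after inserting a single timestep axis: (samples, 1, features)."""
    return (n_rows, 1, n_features)

second_test_size, n_train, n_val, n_test = two_stage_split_sizes(1000)
print(round(second_test_size, 3), n_train, n_val, n_test)
print(gru_input_shape(n_train, 959))
```

With a single timestep per sample, the GRU sees each feature vector as a sequence of length one, so the recurrent layer is not modelling temporal structure within a recording.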

---

## 📦 Required Imports

```python
import numpy as np
from sklearn.model_selection import train_test_split
```


In [None]:
import numpy as np
from sklearn.model_selection import train_test_split

def split_and_prepare_for_gru(X_data, Y_data, train_ratio=0.75, val_ratio=0.175, test_ratio=0.075,
                             random_state=10, reshape_labels=False):
    """
    Complete pipeline: Split data and prepare for GRU training in one function.

    Args:
        X_data: Feature data
        Y_data: Label data
        train_ratio: Training set ratio (default: 0.75)
        val_ratio: Validation set ratio (default: 0.175)
        test_ratio: Test set ratio (default: 0.075)
        random_state: Random seed for reproducibility
        reshape_labels: Whether to add time dimension to labels (False for classification)

    Returns:
        tuple: (x_train_gru, x_val_gru, x_test_gru, y_train_gru, y_val_gru, y_test_gru)
    """

    # Step 1: Data Splitting
    print("Step 1: Splitting data...")

    # Validate ratios
    total_ratio = train_ratio + val_ratio + test_ratio
    if not np.isclose(total_ratio, 1.0):
        raise ValueError(f"Ratios must sum to 1.0, got {total_ratio}")

    # First split: separate training from (validation + test)
    val_test_ratio = val_ratio + test_ratio
    X_train, X_temp, y_train, y_temp = train_test_split(
        X_data, Y_data,
        test_size=val_test_ratio,
        random_state=random_state,
        stratify=np.argmax(Y_data, axis=1) if Y_data.ndim > 1 else Y_data
    )

    # Second split: separate validation from test
    test_ratio_adjusted = test_ratio / val_test_ratio
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp,
        test_size=test_ratio_adjusted,
        random_state=random_state,
        stratify=np.argmax(y_temp, axis=1) if y_temp.ndim > 1 else y_temp
    )

    # Print split statistics
    total_samples = len(X_train) + len(X_val) + len(X_test)
    print(f"\nData Split Statistics:")
    print("=" * 40)
    print(f"Training:   {len(X_train)/total_samples:.1%} ({len(X_train):,} samples)")
    print(f"Validation: {len(X_val)/total_samples:.1%} ({len(X_val):,} samples)")
    print(f"Testing:    {len(X_test)/total_samples:.1%} ({len(X_test):,} samples)")
    print(f"Total:      {total_samples:,} samples")

    print("\nOriginal shapes after splitting:")
    print(f"Features: Train={X_train.shape}, Val={X_val.shape}, Test={X_test.shape}")
    print(f"Labels:   Train={y_train.shape}, Val={y_val.shape}, Test={y_test.shape}")

    # Step 2: GRU Data Preparation
    print("\nStep 2: Preparing data for GRU...")

    # Reshape features for GRU (add time dimension)
    # GRU expects: (batch_size, timesteps, features)
    x_train_gru = np.expand_dims(X_train, axis=1)
    x_val_gru = np.expand_dims(X_val, axis=1)
    x_test_gru = np.expand_dims(X_test, axis=1)

    # Handle labels based on reshape_labels parameter
    if reshape_labels:
        y_train_gru = np.expand_dims(y_train, axis=1)
        y_val_gru = np.expand_dims(y_val, axis=1)
        y_test_gru = np.expand_dims(y_test, axis=1)
    else:
        # Keep labels as-is for standard classification
        y_train_gru = y_train
        y_val_gru = y_val
        y_test_gru = y_test

    print("\nFinal GRU-ready shapes:")
    print(f"Features: Train={x_train_gru.shape}, Val={x_val_gru.shape}, Test={x_test_gru.shape}")
    print(f"Labels:   Train={y_train_gru.shape}, Val={y_val_gru.shape}, Test={y_test_gru.shape}")

    # Validation checks
    assert x_train_gru.shape[0] == y_train_gru.shape[0], "Mismatch in training samples"
    assert x_val_gru.shape[0] == y_val_gru.shape[0], "Mismatch in validation samples"
    assert x_test_gru.shape[0] == y_test_gru.shape[0], "Mismatch in test samples"

    # Class distribution (if labels are one-hot encoded)
    if y_train.ndim > 1 and y_train.shape[1] > 1:
        # Column order must match encode_disease_labels: [normal, asthma, copd]
        class_names = ['Healthy', 'Asthma', 'COPD']
        print("\nClass Distribution:")
        for i, class_name in enumerate(class_names):
            train_count = np.sum(y_train[:, i])
            val_count = np.sum(y_val[:, i])
            test_count = np.sum(y_test[:, i])
            total_class = train_count + val_count + test_count
            print(f"  {class_name:>10}: Train={train_count:>3.0f} | Val={val_count:>3.0f} | Test={test_count:>3.0f} | Total={total_class:>3.0f}")

    print("\n✓ Data splitting and GRU preparation completed successfully!")

    return x_train_gru, x_val_gru, x_test_gru, y_train_gru, y_val_gru, y_test_gru

# Alternative: Enhanced version of original approach (matching your exact ratios)
def split_and_prepare_original_enhanced(X_data, Y_data, random_state=10):
    """
    Enhanced version matching your original splitting ratios exactly.
    """
    print("Using original enhanced splitting approach...")

    # First split: 82.5% train+test, 17.5% validation
    X_temp, X_val, y_temp, y_val = train_test_split(
        X_data, Y_data,
        test_size=0.175,
        random_state=random_state,
        stratify=np.argmax(Y_data, axis=1) if Y_data.ndim > 1 else Y_data
    )

    # Second split: 7.5% of the total for test (~9.1% of the remaining 82.5%)
    X_train, X_test, y_train, y_test = train_test_split(
        X_temp, y_temp,
        test_size=0.075/0.825,  # Adjusted ratio
        random_state=random_state,
        stratify=np.argmax(y_temp, axis=1) if y_temp.ndim > 1 else y_temp
    )

    # Prepare for GRU
    x_train_gru = np.expand_dims(X_train, axis=1)
    x_val_gru = np.expand_dims(X_val, axis=1)
    x_test_gru = np.expand_dims(X_test, axis=1)

    y_train_gru = y_train
    y_val_gru = y_val
    y_test_gru = y_test

    # Print results
    total_samples = len(X_train) + len(X_val) + len(X_test)
    print(f"\nSplit ratios achieved:")
    print(f"Training:   {len(X_train)/total_samples:.1%} ({len(X_train)} samples)")
    print(f"Validation: {len(X_val)/total_samples:.1%} ({len(X_val)} samples)")
    print(f"Testing:    {len(X_test)/total_samples:.1%} ({len(X_test)} samples)")

    print(f"\nGRU-ready shapes:")
    print(f"x_train_gru: {x_train_gru.shape}")
    print(f"x_val_gru:   {x_val_gru.shape}")
    print(f"x_test_gru:  {x_test_gru.shape}")
    print(f"y_train_gru: {y_train_gru.shape}")
    print(f"y_val_gru:   {y_val_gru.shape}")
    print(f"y_test_gru:  {y_test_gru.shape}")

    return x_train_gru, x_val_gru, x_test_gru, y_train_gru, y_val_gru, y_test_gru

# Usage Examples:

# Method 1: Complete pipeline with custom ratios (recommended)
x_train_gru, x_val_gru, x_test_gru, y_train_gru, y_val_gru, y_test_gru = split_and_prepare_for_gru(
    X_data, Y_data,
    train_ratio=0.75,
    val_ratio=0.175,
    test_ratio=0.075,
    random_state=10,
    reshape_labels=False  # Set to False for classification tasks
)

# Method 2: Original enhanced approach
# x_train_gru, x_val_gru, x_test_gru, y_train_gru, y_val_gru, y_test_gru = split_and_prepare_original_enhanced(
#     X_data, Y_data, random_state=10
# )

# Your data is now ready for GRU training!
print(f"\n🎉 Final shapes ready for GRU model:")
print(f"Features: {x_train_gru.shape}, {x_val_gru.shape}, {x_test_gru.shape}")
print(f"Labels:   {y_train_gru.shape}, {y_val_gru.shape}, {y_test_gru.shape}")
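
The `test_size=0.075/0.825` argument in `split_and_prepare_original_enhanced` follows from splitting in two stages: once 17.5% of the data is set aside for validation, the test set has to be drawn from the remaining 82.5%, so 7.5% of the total becomes roughly 9.1% of what is left. The short check below reproduces that arithmetic in plain Python. The helper name `two_stage_fractions` is made up for this illustration and is not part of the notebook.

```python
def two_stage_fractions(val_fraction, test_fraction):
    """Return (train, val, test, stage2_test_size) for a two-stage split.

    Stage 1 holds out `val_fraction` of all samples for validation.
    Stage 2 holds out `test_fraction / (1 - val_fraction)` of the remainder,
    which corresponds to `test_fraction` of all samples.
    """
    remaining = 1.0 - val_fraction
    stage2_test_size = test_fraction / remaining
    test = remaining * stage2_test_size
    train = remaining - test
    return train, val_fraction, test, stage2_test_size


train, val, test, stage2 = two_stage_fractions(val_fraction=0.175, test_fraction=0.075)
print(f"stage-2 test_size = {stage2:.4f}")
print(f"train={train:.3f}, val={val:.3f}, test={test:.3f}")
```

With the ratios used here, the second-stage `test_size` comes out near 0.0909 while the overall split stays at 75% / 17.5% / 7.5%.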

In [None]:
x_train_gru

In [None]:
y_train_gru

In [None]:
!pip install keras-tuner --upgrade --quiet


In [None]:
pip install tensorflow tensorflow-addons keras-tuner

# **MODEL 1**
# 🫁 Lung Sound Classification Training Pipeline – ⚙️ FIXED & Optimized

This pipeline builds a **deep neural network** for **lung disease detection** (Asthma, COPD, Healthy) using **GRU** and **Conv1D** layers together with several **regularization techniques**. It expects **pre-split data** in the following format:

- **Input Shape**: `(samples, 1, 959)`
- **Labels**: One-hot encoded for 3 classes (`[Healthy, Asthma, COPD]`)
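
As a quick illustration of the label format, the sketch below builds one-hot rows from integer class indices and decodes them again with `argmax`. It uses NumPy only and assumes the column order listed above (Healthy, Asthma, COPD); the `one_hot` helper and the sample labels are invented for the example.

```python
import numpy as np

CLASS_NAMES = ["Healthy", "Asthma", "COPD"]  # assumed column order, as listed above


def one_hot(indices, num_classes):
    """Turn integer class indices into one-hot rows."""
    return np.eye(num_classes, dtype=np.float32)[indices]


labels = np.array([2, 0, 1, 2])            # COPD, Healthy, Asthma, COPD
y = one_hot(labels, len(CLASS_NAMES))      # shape (4, 3)
decoded = [CLASS_NAMES[i] for i in np.argmax(y, axis=1)]
print(y.shape, decoded)
```

Whatever order is used when the labels are encoded has to be reused everywhere class names are printed or plotted, otherwise reports and confusion matrices will be mislabeled.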

---

## ✅ Key Features

- 💡 **Hybrid Conv1D + Bi-GRU** architecture for local + temporal pattern learning
- 🧠 Customizable architecture via config dictionary
- 🧪 Regularization: `Dropout`, `L1/L2`, and `BatchNorm`
- 📈 Training callbacks: `EarlyStopping`, `ReduceLROnPlateau`, `ModelCheckpoint`, `CSVLogger`, `LearningRateScheduler`
- 🎯 Metrics: Keras `Precision`/`Recall` during training, plus a **weighted F1 score** computed with scikit-learn at evaluation
- 📊 Visual training history + Confusion Matrix
- 📦 Easily pluggable into your data pipeline
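
The weighted F1 score averages the per-class F1 values, using each class's number of true samples (its support) as the weight. The sketch below computes it from a confusion matrix with NumPy only; the function name and the example matrix are invented for illustration and are not results from this notebook.

```python
import numpy as np


def weighted_f1_from_confusion(cm):
    """Support-weighted F1 from a confusion matrix (rows = true, cols = predicted)."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    support = cm.sum(axis=1)        # true samples per class
    predicted = cm.sum(axis=0)      # predicted samples per class
    # Guard against division by zero for classes that never occur / are never predicted
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, support, out=np.zeros_like(tp), where=support > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return float(np.sum(f1 * support) / support.sum())


cm_example = [[8, 1, 1],
              [2, 6, 2],
              [0, 1, 9]]
score = weighted_f1_from_confusion(cm_example)
print(f"weighted F1 = {score:.3f}")
```

Because the weights are the class supports, a large class dominates this number; a macro average (plain mean of per-class F1) is the usual alternative when minority classes matter equally.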

---

## 🔧 Model Architecture Highlights

- `Conv1D → Bi-GRU(128) → Bi-GRU(64) → Dense(256 → 128) → Output(Softmax)`
- Support for:
  - **Bidirectional RNNs**
  - **Multiple pooling strategies**: `avg`, `max`, `both`
  - **Dynamic learning rate scheduling**
- Input: `1 × 959` (time × features)
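
To make the pooling options concrete, the NumPy sketch below mimics what global average and global max pooling compute over the time axis, and how `'both'` concatenates them. The function `global_pool` is a stand-in written for this example, not the Keras layers themselves. It also shows a consequence of a sequence length of 1: average and max pooling then return the same vector, so `'both'` simply duplicates the features.

```python
import numpy as np


def global_pool(x, strategy="both"):
    """Pool a (batch, time, features) array over the time axis."""
    avg = x.mean(axis=1)
    mx = x.max(axis=1)
    if strategy == "avg":
        return avg
    if strategy == "max":
        return mx
    return np.concatenate([avg, mx], axis=-1)  # 'both'


rng = np.random.default_rng(0)
multi_step = rng.normal(size=(4, 10, 8))    # 10 time steps
single_step = rng.normal(size=(4, 1, 8))    # 1 time step, like the (1, 959) input

print(global_pool(multi_step).shape)        # (4, 16)
pooled = global_pool(single_step)
print(np.allclose(pooled[:, :8], pooled[:, 8:]))  # True: avg == max when time == 1
```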

---

## 🚀 How to Use

### 1. Load your GRU-ready data

```python
# Expected shapes
x_train_gru.shape  # e.g. (1181, 1, 959)
y_train_gru.shape  # (1181, 3)
x_val_gru.shape    # (275, 1, 959)
y_val_gru.shape    # (275, 3)
x_test_gru.shape   # (119, 1, 959)
y_test_gru.shape   # (119, 3)
```


In [None]:
# Lung Sound Classification Training Pipeline - FIXED VERSION
# Ready to use with your pre-split data

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Dense, Dropout, BatchNormalization, GRU,
    Bidirectional, GlobalAveragePooling1D, GlobalMaxPooling1D,
    Concatenate, Conv1D, LeakyReLU
)
from tensorflow.keras.regularizers import l2, l1_l2
from tensorflow.keras.optimizers import AdamW
from tensorflow.keras.callbacks import (
    EarlyStopping, ReduceLROnPlateau, ModelCheckpoint,
    CSVLogger, LearningRateScheduler
)
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns

class OptimizedLungSoundClassifier:
    """
    Optimized Neural Network for Lung Sound Classification
    Designed for your specific dataset: (samples, 1, 959) -> 3 classes
    """

    def __init__(self, input_shape=(1, 959), num_classes=3):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.model = None
        self.history = None
        self.class_weights = None

    def create_model(self, config=None):
        """Create optimized model architecture"""
        if config is None:
            config = {
                'conv_filters': 64,
                'gru_units_1': 128,
                'gru_units_2': 64,
                'dense_units_1': 256,
                'dense_units_2': 128,
                'dropout_rate': 0.4,
                'l2_reg': 0.001,
                'use_bidirectional': True,
                'use_conv1d': True,
                'pooling_strategy': 'both'
            }

        # Input layer
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs

        # Optional Conv1D for local pattern extraction
        if config['use_conv1d']:
            x = Conv1D(
                filters=config['conv_filters'],
                kernel_size=5,
                padding='same',
                kernel_regularizer=l2(config['l2_reg'])
            )(x)
            x = BatchNormalization()(x)
            x = LeakyReLU(alpha=0.1)(x)
            x = Dropout(config['dropout_rate'] * 0.5)(x)

        # First RNN layer
        if config['use_bidirectional']:
            x = Bidirectional(
                GRU(
                    config['gru_units_1'],
                    return_sequences=True,
                    kernel_regularizer=l2(config['l2_reg']),
                    recurrent_regularizer=l2(config['l2_reg'] * 0.5),
                    dropout=config['dropout_rate'] * 0.3,
                    recurrent_dropout=config['dropout_rate'] * 0.3
                ),
                name='bi_gru_1'
            )(x)
        else:
            x = GRU(
                config['gru_units_1'] * 2,
                return_sequences=True,
                kernel_regularizer=l2(config['l2_reg']),
                recurrent_regularizer=l2(config['l2_reg'] * 0.5),
                dropout=config['dropout_rate'] * 0.3,
                recurrent_dropout=config['dropout_rate'] * 0.3,
                name='gru_1'
            )(x)

        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)

        # Second RNN layer
        if config['use_bidirectional']:
            x = Bidirectional(
                GRU(
                    config['gru_units_2'],
                    return_sequences=True,
                    kernel_regularizer=l2(config['l2_reg']),
                    recurrent_regularizer=l2(config['l2_reg'] * 0.5),
                    dropout=config['dropout_rate'] * 0.3,
                    recurrent_dropout=config['dropout_rate'] * 0.3
                ),
                name='bi_gru_2'
            )(x)
        else:
            x = GRU(
                config['gru_units_2'] * 2,
                return_sequences=True,
                kernel_regularizer=l2(config['l2_reg']),
                recurrent_regularizer=l2(config['l2_reg'] * 0.5),
                dropout=config['dropout_rate'] * 0.3,
                recurrent_dropout=config['dropout_rate'] * 0.3,
                name='gru_2'
            )(x)

        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)

        # Global pooling
        if config['pooling_strategy'] == 'both':
            avg_pool = GlobalAveragePooling1D()(x)
            max_pool = GlobalMaxPooling1D()(x)
            x = Concatenate()([avg_pool, max_pool])
        elif config['pooling_strategy'] == 'avg':
            x = GlobalAveragePooling1D()(x)
        else:
            x = GlobalMaxPooling1D()(x)

        # Dense layers
        x = Dense(
            config['dense_units_1'],
            kernel_regularizer=l1_l2(l1=config['l2_reg']*0.5, l2=config['l2_reg'])
        )(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(config['dropout_rate'])(x)

        x = Dense(
            config['dense_units_2'],
            kernel_regularizer=l1_l2(l1=config['l2_reg']*0.5, l2=config['l2_reg'])
        )(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(config['dropout_rate'])(x)

        # Output layer
        outputs = Dense(
            self.num_classes,
            activation='softmax',
            kernel_regularizer=l2(config['l2_reg']),
            name='classification_output'
        )(x)

        model = Model(inputs=inputs, outputs=outputs, name='LungSoundClassifier')
        return model

    def compute_class_weights(self, y_train):
        """Compute class weights for balanced training"""
        y_indices = np.argmax(y_train, axis=1)
        classes = np.unique(y_indices)
        class_weights = compute_class_weight('balanced', classes=classes, y=y_indices)
        self.class_weights = dict(zip(classes, class_weights))

        print(f"📊 Class weights computed:")
        class_names = ['Healthy', 'Asthma', 'COPD']
        for i, weight in self.class_weights.items():
            print(f"   • {class_names[i]}: {weight:.3f}")

        return self.class_weights

    def create_callbacks(self, model_name='lung_sound_model'):
        """Create training callbacks"""

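        def scheduler(epoch, lr):
            # Keras passes the *current* learning rate, so scale it only at the
            # boundary epochs; multiplying on every epoch would compound the decay.
            # Cumulative factors relative to the initial rate (before any
            # ReduceLROnPlateau reduction): 0.5, 0.25, 0.1.
            if epoch == 20:
                return lr * 0.5
            if epoch == 50:
                return lr * 0.5
            if epoch == 80:
                return lr * 0.4
            return lr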

        callbacks = [
            EarlyStopping(
                monitor='val_loss',
                patience=25,
                restore_best_weights=True,
                verbose=1,
                min_delta=0.0001
            ),
            ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=8,
                min_lr=1e-7,
                verbose=1
            ),
            ModelCheckpoint(
                filepath=f'{model_name}_best.keras',
                monitor='val_accuracy',
                save_best_only=True,
                save_weights_only=False,
                verbose=1,
                mode='max'
            ),
            CSVLogger(f'{model_name}_training_log.csv', append=True),
            LearningRateScheduler(scheduler, verbose=0)
        ]

        return callbacks

    def compile_model(self, model, learning_rate=0.001):
        """Compile model with optimizer and metrics - FIXED VERSION"""

        # F1 for class index 1 only (one-vs-rest), written with TensorFlow ops.
        # Kept for reference; it is not passed to model.compile() below.
        def f1_score_metric(y_true, y_pred):
            """F1 score for class 1 that works inside TensorFlow's computation graph"""
            # Convert predictions to class indices
            y_pred_classes = tf.argmax(y_pred, axis=1)
            y_true_classes = tf.argmax(y_true, axis=1)

            # Calculate confusion matrix components
            tp = tf.reduce_sum(tf.cast(
                tf.logical_and(
                    tf.equal(y_true_classes, y_pred_classes),
                    tf.equal(y_true_classes, 1)  # Assuming class 1 for binary-like F1
                ), tf.float32))

            fp = tf.reduce_sum(tf.cast(
                tf.logical_and(
                    tf.not_equal(y_true_classes, y_pred_classes),
                    tf.equal(y_pred_classes, 1)
                ), tf.float32))

            fn = tf.reduce_sum(tf.cast(
                tf.logical_and(
                    tf.not_equal(y_true_classes, y_pred_classes),
                    tf.equal(y_true_classes, 1)
                ), tf.float32))

            # Calculate precision and recall
            precision = tp / (tp + fp + tf.keras.backend.epsilon())
            recall = tp / (tp + fn + tf.keras.backend.epsilon())

            # Calculate F1 score
            f1 = 2 * precision * recall / (precision + recall + tf.keras.backend.epsilon())

            return f1

        # Plain categorical accuracy under a custom name
        def weighted_categorical_accuracy(y_true, y_pred):
            """Categorical accuracy (no weighting is applied, despite the name)"""
            return tf.keras.metrics.categorical_accuracy(y_true, y_pred)

        optimizer = AdamW(
            learning_rate=learning_rate,
            weight_decay=0.01,
            clipnorm=1.0
        )

        # Use standard metrics that are guaranteed to work
        model.compile(
            optimizer=optimizer,
            loss='categorical_crossentropy',
            metrics=[
                'accuracy',
                tf.keras.metrics.Precision(name='precision'),
                tf.keras.metrics.Recall(name='recall'),
                weighted_categorical_accuracy
            ]
        )

        return model

    def train(self, x_train, y_train, x_val, y_val,
              epochs=100, batch_size=32, config=None):
        """Train the model"""

        print("🚀 Starting training...")
        print(f"   • Training samples: {x_train.shape[0]}")
        print(f"   • Validation samples: {x_val.shape[0]}")
        print(f"   • Input shape: {x_train.shape[1:]}")

        # Create and compile model
        self.model = self.create_model(config)
        self.model = self.compile_model(self.model)

        # Compute class weights
        class_weights = self.compute_class_weights(y_train)

        # Create callbacks
        callbacks = self.create_callbacks()

        # Print model summary
        print(f"\n🏗️ Model Architecture:")
        print(f"   • Total parameters: {self.model.count_params():,}")
        self.model.summary()

        # Train model
        self.history = self.model.fit(
            x_train, y_train,
            validation_data=(x_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=1,
            shuffle=True
        )

        print("✅ Training completed!")
        return self.history

    def evaluate_model(self, x_test, y_test):
        """Evaluate model performance"""

        if self.model is None:
            print("❌ Model not trained yet!")
            return

        # Predictions
        y_pred_proba = self.model.predict(x_test, verbose=0)
        y_pred = np.argmax(y_pred_proba, axis=1)
        y_true = np.argmax(y_test, axis=1)

        # Metrics
        evaluation_results = self.model.evaluate(x_test, y_test, verbose=0)
        test_loss = evaluation_results[0]
        test_acc = evaluation_results[1]
        test_prec = evaluation_results[2] if len(evaluation_results) > 2 else 0
        test_rec = evaluation_results[3] if len(evaluation_results) > 3 else 0

        # Calculate F1 score manually using sklearn
        test_f1 = f1_score(y_true, y_pred, average='weighted')

        print(f"\n📊 Test Set Performance:")
        print(f"   • Test Loss: {test_loss:.4f}")
        print(f"   • Test Accuracy: {test_acc:.4f}")
        print(f"   • Test Precision: {test_prec:.4f}")
        print(f"   • Test Recall: {test_rec:.4f}")
        print(f"   • Test F1-Score: {test_f1:.4f}")

        # Classification report
        class_names = ['Healthy', 'Asthma', 'COPD']
        print(f"\n📋 Detailed Classification Report:")
        print(classification_report(y_true, y_pred, target_names=class_names))

        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred)

        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                   xticklabels=class_names, yticklabels=class_names)
        plt.title('Confusion Matrix - Test Set')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.show()

        return {
            'accuracy': test_acc,
            'precision': test_prec,
            'recall': test_rec,
            'f1_score': test_f1,
            'predictions': y_pred_proba,
            'confusion_matrix': cm
        }

    def plot_training_history(self):
        """Plot training history"""

        if self.history is None:
            print("❌ No training history available!")
            return

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Accuracy
        axes[0, 0].plot(self.history.history['accuracy'], label='Train Accuracy', color='blue')
        axes[0, 0].plot(self.history.history['val_accuracy'], label='Val Accuracy', color='orange')
        axes[0, 0].set_title('Model Accuracy')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Accuracy')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Loss
        axes[0, 1].plot(self.history.history['loss'], label='Train Loss', color='blue')
        axes[0, 1].plot(self.history.history['val_loss'], label='Val Loss', color='orange')
        axes[0, 1].set_title('Model Loss')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Loss')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

        # Precision
        if 'precision' in self.history.history:
            axes[1, 0].plot(self.history.history['precision'], label='Train Precision', color='blue')
            axes[1, 0].plot(self.history.history['val_precision'], label='Val Precision', color='orange')
            axes[1, 0].set_title('Precision')
            axes[1, 0].set_xlabel('Epoch')
            axes[1, 0].set_ylabel('Precision')
            axes[1, 0].legend()
            axes[1, 0].grid(True, alpha=0.3)

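        # Learning Rate (logged as 'lr' by Keras 2 and 'learning_rate' by Keras 3)
        lr_key = 'lr' if 'lr' in self.history.history else 'learning_rate'
        if lr_key in self.history.history:
            axes[1, 1].plot(self.history.history[lr_key], label='Learning Rate', color='green')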
            axes[1, 1].set_title('Learning Rate Schedule')
            axes[1, 1].set_xlabel('Epoch')
            axes[1, 1].set_ylabel('Learning Rate')
            axes[1, 1].legend()
            axes[1, 1].grid(True, alpha=0.3)
            axes[1, 1].set_yscale('log')
        else:
            # Plot recall if learning rate not available
            if 'recall' in self.history.history:
                axes[1, 1].plot(self.history.history['recall'], label='Train Recall', color='blue')
                axes[1, 1].plot(self.history.history['val_recall'], label='Val Recall', color='orange')
                axes[1, 1].set_title('Recall')
                axes[1, 1].set_xlabel('Epoch')
                axes[1, 1].set_ylabel('Recall')
                axes[1, 1].legend()
                axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()


# Training Pipeline - Ready to use with your data
def train_lung_sound_classifier(x_train_gru, y_train_gru, x_val_gru, y_val_gru, x_test_gru, y_test_gru):
    """
    Complete training pipeline for your lung sound data

    Parameters:
    - x_train_gru: (1181, 1, 959) - Training features
    - y_train_gru: (1181, 3) - Training labels (one-hot encoded)
    - x_val_gru: (275, 1, 959) - Validation features
    - y_val_gru: (275, 3) - Validation labels
    - x_test_gru: (119, 1, 959) - Test features
    - y_test_gru: (119, 3) - Test labels
    """

    print("🎯 Lung Sound Classification Training Pipeline - FIXED VERSION")
    print("=" * 65)

    # Verify data shapes
    print(f"📊 Data shapes:")
    print(f"   • Train: {x_train_gru.shape} -> {y_train_gru.shape}")
    print(f"   • Val:   {x_val_gru.shape} -> {y_val_gru.shape}")
    print(f"   • Test:  {x_test_gru.shape} -> {y_test_gru.shape}")

    # Initialize classifier
    classifier = OptimizedLungSoundClassifier(input_shape=(1, 959), num_classes=3)

    # Configuration for your dataset
    config = {
        'conv_filters': 64,
        'gru_units_1': 96,        # Reduced for better generalization
        'gru_units_2': 48,        # Reduced for better generalization
        'dense_units_1': 128,     # Reduced to prevent overfitting
        'dense_units_2': 64,      # Reduced to prevent overfitting
        'dropout_rate': 0.5,      # Higher dropout for regularization
        'l2_reg': 0.01,           # Strong L2 regularization
        'use_bidirectional': True,
        'use_conv1d': True,
        'pooling_strategy': 'both'
    }

    # Train the model
    print(f"\n🚀 Starting training with optimized configuration...")
    history = classifier.train(
        x_train_gru, y_train_gru,
        x_val_gru, y_val_gru,
        epochs=150,     # Sufficient epochs with early stopping
        batch_size=16,  # Smaller batch size for stable training
        config=config
    )

    # Evaluate on test set
    print(f"\n🧪 Evaluating on test set...")
    test_results = classifier.evaluate_model(x_test_gru, y_test_gru)

    # Plot training history
    print(f"\n📈 Plotting training history...")
    classifier.plot_training_history()

    return classifier, history, test_results


# Simple usage example with placeholder variables
# Replace these with your actual variable names
def example_usage():
    """
    Example of how to use the fixed training pipeline
    """
    print("📝 Example Usage:")
    print("1. Make sure your data variables are loaded:")
    print("   - x_train_gru, y_train_gru (training data)")
    print("   - x_val_gru, y_val_gru (validation data)")
    print("   - x_test_gru, y_test_gru (test data)")
    print("\n2. Run the training pipeline:")
    print("   classifier, history, results = train_lung_sound_classifier(")
    print("       x_train_gru, y_train_gru,")
    print("       x_val_gru, y_val_gru,")
    print("       x_test_gru, y_test_gru")
    print("   )")
    print("\n3. Check results:")
    print("   print(f'Final Test Accuracy: {results[\"accuracy\"]:.4f}')")
    print("   print(f'Final Test F1-Score: {results[\"f1_score\"]:.4f}')")

if __name__ == "__main__":
    example_usage()

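    # Run the full pipeline (requires the *_gru arrays created by the split cell above).
    # Note: `model` here is the OptimizedLungSoundClassifier wrapper, not a bare Keras model.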
    model, history, results = train_lung_sound_classifier(
        x_train_gru, y_train_gru,
        x_val_gru, y_val_gru,
        x_test_gru, y_test_gru
    )

    # Expected performance: 85-95% accuracy with good generalization
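
`compute_class_weights` above delegates to scikit-learn's `compute_class_weight('balanced', ...)`, which assigns each class the weight `n_samples / (n_classes * count_of_that_class)`. The NumPy-only sketch below reproduces that formula so the effect on a skewed label set is easy to see. The helper name and the label counts are invented for the example.

```python
import numpy as np


def balanced_class_weights(y_one_hot):
    """Balanced class weights: n_samples / (n_classes * samples_in_class)."""
    y_idx = np.argmax(y_one_hot, axis=1)
    classes, counts = np.unique(y_idx, return_counts=True)
    weights = len(y_idx) / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}


# Hypothetical label set: 6 samples of class 0, 3 of class 1, 1 of class 2
y_idx_demo = np.array([0] * 6 + [1] * 3 + [2] * 1)
y_demo = np.eye(3)[y_idx_demo]
weights = balanced_class_weights(y_demo)
print(weights)
```

Rare classes receive proportionally larger weights, so their errors contribute more to the loss when the dictionary is passed to `fit` through `class_weight`.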

# 📊 ROC & AUC Analysis – Lung Sound Classification

This module provides **visual diagnostic tools** to evaluate your trained classifier’s performance using **ROC Curves** and **AUC scores** for each class.

---

## ✅ Features

- 🎯 **Main ROC Curve** (One-vs-Rest) with Micro & Macro averages
- 🔬 **Detailed Subplot Analysis** (1-vs-Rest per class + Combined View)
- 📈 **AUC Comparison Bar Chart**
- 📦 Integrated function for complete analysis in 1 line
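
In a one-vs-rest analysis, each class column of the one-hot labels is treated as its own binary problem and scored against the matching column of predicted probabilities. The sketch below computes such an AUC by hand (threshold sweep plus the trapezoidal rule) with NumPy only. The function `binary_roc_auc` and the toy probabilities are invented for the example; the notebook itself relies on scikit-learn's `roc_curve` and `auc`.

```python
import numpy as np


def binary_roc_auc(y_true, y_score):
    """ROC AUC for one class via a threshold sweep and the trapezoidal rule."""
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score, dtype=float)
    n_pos, n_neg = y_true.sum(), (~y_true).sum()
    tpr, fpr = [0.0], [0.0]
    for t in np.unique(y_score)[::-1]:           # thresholds from high to low
        predicted_pos = y_score >= t
        tpr.append((predicted_pos & y_true).sum() / n_pos)
        fpr.append((predicted_pos & ~y_true).sum() / n_neg)
    tpr, fpr = np.array(tpr), np.array(fpr)
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2.0))


# Toy one-hot labels and predicted probabilities for 3 classes
y_true = np.array([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]])
y_score = np.array([
    [0.7, 0.2, 0.1],
    [0.4, 0.5, 0.1],
    [0.3, 0.6, 0.1],
    [0.2, 0.3, 0.5],
    [0.1, 0.2, 0.7],
    [0.2, 0.2, 0.6],
])

per_class = [binary_roc_auc(y_true[:, i], y_score[:, i]) for i in range(3)]
macro_auc = float(np.mean(per_class))
print(per_class, round(macro_auc, 3))
```

The plain mean of the per-class values is one common macro summary; the micro average instead pools all class columns into a single binary problem before computing the curve.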

---

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.preprocessing import label_binarize
from itertools import cycle
import seaborn as sns

def plot_roc_curves(classifier, x_test_gru, y_test_gru, class_names=None):
    """
    Plot ROC curves for multi-class classification using OptimizedLungSoundClassifier

    Parameters:
    - classifier: Trained OptimizedLungSoundClassifier object
    - x_test_gru: Test features (119, 1, 959)
    - y_test_gru: Test labels one-hot encoded (119, 3)
    - class_names: List of class names (default: ["Normal", "Asthma", "COPD"])
    """

    if class_names is None:
        class_names = ["Normal", "Asthma", "COPD"]

    n_classes = len(class_names)

    print("🎯 Generating ROC Curves...")
    print(f"   • Test samples: {x_test_gru.shape[0]}")
    print(f"   • Classes: {class_names}")

    # Get prediction probabilities using the correct method
    if hasattr(classifier, 'model') and classifier.model is not None:
        y_score = classifier.model.predict(x_test_gru, verbose=0)
    else:
        raise ValueError("Classifier model not found. Make sure the model is trained.")

    y_true = y_test_gru  # Already one-hot encoded

    print(f"   • Prediction shape: {y_score.shape}")
    print(f"   • True labels shape: {y_true.shape}")

    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()

    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_true[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Compute micro-average ROC curve and ROC area
    fpr["micro"], tpr["micro"], _ = roc_curve(y_true.ravel(), y_score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Compute macro-average ROC curve and ROC area
    # First aggregate all false positive rates
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

    # Then interpolate all ROC curves at these points
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

    # Finally average it and compute AUC
    mean_tpr /= n_classes

    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    # Plot all ROC curves
    plt.figure(figsize=(12, 10))

    # Define colors for each class
    colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'red', 'green', 'purple'])

    # Plot ROC curve for each class
    for i, color in zip(range(n_classes), colors):
        plt.plot(fpr[i], tpr[i], color=color, lw=2,
                 label=f'{class_names[i]} (AUC = {roc_auc[i]:.3f})')

    # Plot micro-average ROC curve
    plt.plot(fpr["micro"], tpr["micro"],
             label=f'Micro-average (AUC = {roc_auc["micro"]:.3f})',
             color='deeppink', linestyle=':', linewidth=3)

    # Plot macro-average ROC curve
    plt.plot(fpr["macro"], tpr["macro"],
             label=f'Macro-average (AUC = {roc_auc["macro"]:.3f})',
             color='navy', linestyle=':', linewidth=3)

    # Plot diagonal (random classifier)
    plt.plot([0, 1], [0, 1], 'k--', lw=2, alpha=0.8, label='Random Classifier')

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=14)
    plt.ylabel('True Positive Rate', fontsize=14)
    plt.title('ROC Curves - Lung Sound Classification\n(One-vs-Rest)', fontsize=16, fontweight='bold')
    plt.legend(loc="lower right", fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Print AUC scores
    print(f"\n📊 AUC Scores Summary:")
    print("=" * 40)
    for i, class_name in enumerate(class_names):
        performance = "Excellent" if roc_auc[i] > 0.9 else "Good" if roc_auc[i] > 0.8 else "Fair" if roc_auc[i] > 0.7 else "Poor"
        print(f"   • {class_name:8}: {roc_auc[i]:.3f} ({performance})")

    print(f"   • {'Micro-avg':8}: {roc_auc['micro']:.3f}")
    print(f"   • {'Macro-avg':8}: {roc_auc['macro']:.3f}")

    # Overall AUC using sklearn's built-in function (alternative calculation)
    try:
        overall_auc = roc_auc_score(y_true, y_score, multi_class='ovr', average='weighted')
        print(f"   • {'Weighted':8}: {overall_auc:.3f}")
    except Exception as e:
        print(f"   • Weighted calculation failed: {e}")

    return fpr, tpr, roc_auc


def plot_detailed_roc_analysis(classifier, x_test_gru, y_test_gru, class_names=None):
    """
    Create a detailed ROC analysis with individual plots for each class
    """

    if class_names is None:
        class_names = ["Normal", "Asthma", "COPD"]

    n_classes = len(class_names)

    # Get predictions
    y_score = classifier.model.predict(x_test_gru, verbose=0)
    y_true = y_test_gru

    # Create subplots
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Detailed ROC Analysis - Lung Sound Classification', fontsize=16, fontweight='bold')

    # Individual ROC curves
    colors = ['blue', 'red', 'green']

    for i in range(n_classes):
        row = i // 2
        col = i % 2

        fpr, tpr, _ = roc_curve(y_true[:, i], y_score[:, i])
        roc_auc = auc(fpr, tpr)

        axes[row, col].plot(fpr, tpr, color=colors[i], lw=3,
                           label=f'{class_names[i]} (AUC = {roc_auc:.3f})')
        axes[row, col].plot([0, 1], [0, 1], 'k--', lw=2, alpha=0.6)
        axes[row, col].set_xlim([0.0, 1.0])
        axes[row, col].set_ylim([0.0, 1.05])
        axes[row, col].set_xlabel('False Positive Rate', fontsize=12)
        axes[row, col].set_ylabel('True Positive Rate', fontsize=12)
        axes[row, col].set_title(f'{class_names[i]} vs Rest', fontsize=14, fontweight='bold')
        axes[row, col].legend(loc="lower right", fontsize=11)
        axes[row, col].grid(True, alpha=0.3)

        # Add text box with additional metrics
        axes[row, col].text(0.6, 0.2, f'AUC: {roc_auc:.3f}\nSamples: {int(np.sum(y_true[:, i]))}',
                           bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7),
                           fontsize=10)

    # Combined plot in the last subplot
    for i in range(n_classes):
        fpr, tpr, _ = roc_curve(y_true[:, i], y_score[:, i])
        roc_auc = auc(fpr, tpr)
        axes[1, 1].plot(fpr, tpr, color=colors[i], lw=2,
                       label=f'{class_names[i]} (AUC = {roc_auc:.3f})')

    axes[1, 1].plot([0, 1], [0, 1], 'k--', lw=2, alpha=0.6, label='Random')
    axes[1, 1].set_xlim([0.0, 1.0])
    axes[1, 1].set_ylim([0.0, 1.05])
    axes[1, 1].set_xlabel('False Positive Rate', fontsize=12)
    axes[1, 1].set_ylabel('True Positive Rate', fontsize=12)
    axes[1, 1].set_title('All Classes Combined', fontsize=14, fontweight='bold')
    axes[1, 1].legend(loc="lower right", fontsize=10)
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


def plot_auc_comparison(classifier, x_test_gru, y_test_gru, class_names=None):
    """
    Create a bar plot comparing AUC scores across classes
    """

    if class_names is None:
        class_names = ["Normal", "Asthma", "COPD"]

    # Get predictions and calculate AUC for each class
    y_score = classifier.model.predict(x_test_gru, verbose=0)
    y_true = y_test_gru

    auc_scores = []
    for i in range(len(class_names)):
        fpr, tpr, _ = roc_curve(y_true[:, i], y_score[:, i])
        auc_scores.append(auc(fpr, tpr))

    # Create bar plot
    plt.figure(figsize=(10, 6))
    bars = plt.bar(class_names, auc_scores,
                   color=['skyblue', 'lightcoral', 'lightgreen'],
                   alpha=0.8, edgecolor='black', linewidth=1.5)

    # Add value labels on bars
    for bar, score in zip(bars, auc_scores):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                f'{score:.3f}', ha='center', va='bottom', fontweight='bold', fontsize=12)

    plt.ylim(0, 1.1)
    plt.ylabel('AUC Score', fontsize=14)
    plt.xlabel('Class', fontsize=14)
    plt.title('AUC Score Comparison Across Classes', fontsize=16, fontweight='bold')
    plt.grid(True, alpha=0.3, axis='y')

    # Add horizontal line at 0.5 (random performance)
    plt.axhline(y=0.5, color='red', linestyle='--', alpha=0.7, label='Random Performance')
    plt.axhline(y=0.8, color='orange', linestyle='--', alpha=0.7, label='Good Performance')
    plt.axhline(y=0.9, color='green', linestyle='--', alpha=0.7, label='Excellent Performance')

    plt.legend()
    plt.tight_layout()
    plt.show()

    return auc_scores


# Complete ROC Analysis Function
def complete_roc_analysis(classifier, x_test_gru, y_test_gru, class_names=None):
    """
    Run complete ROC analysis with all visualizations
    """

    print("🎯 Starting Complete ROC Analysis...")
    print("=" * 50)

    # 1. Main ROC curves plot
    print("\n1️⃣ Generating main ROC curves...")
    fpr, tpr, roc_auc = plot_roc_curves(classifier, x_test_gru, y_test_gru, class_names)

    # 2. Detailed analysis
    print("\n2️⃣ Generating detailed ROC analysis...")
    plot_detailed_roc_analysis(classifier, x_test_gru, y_test_gru, class_names)

    # 3. AUC comparison
    print("\n3️⃣ Generating AUC comparison...")
    auc_scores = plot_auc_comparison(classifier, x_test_gru, y_test_gru, class_names)

    print("\n‚úÖ ROC Analysis Complete!")

    return fpr, tpr, roc_auc, auc_scores


# Example usage:

# After training your model:
# classifier, history, results = train_lung_sound_classifier(
#     x_train_gru, y_train_gru,
#     x_val_gru, y_val_gru,
#     x_test_gru, y_test_gru
# )

# Generate ROC curves and AUC analysis:
fpr, tpr, roc_auc, auc_scores = complete_roc_analysis(
    model, x_test_gru, y_test_gru,
    class_names=["Normal", "Asthma", "COPD"]
)

# Or just the main ROC plot:
plot_roc_curves(model, x_test_gru, y_test_gru)

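The per-class AUC values reported by the ROC analysis treat each class one-vs-rest. As a rough cross-check that does not depend on the plotting helpers, AUC can also be read as the probability that a randomly chosen positive sample scores higher than a randomly chosen negative one, with ties counting as half. The sketch below assumes one-hot labels and a score matrix shaped like a softmax output; `one_vs_rest_auc` and the toy arrays are hypothetical and not part of the notebook.

```python
import numpy as np

def one_vs_rest_auc(y_true_onehot, y_score):
    """Per-class AUC via pairwise comparison of positive and negative scores."""
    y_true_onehot = np.asarray(y_true_onehot)
    y_score = np.asarray(y_score, dtype=float)
    aucs = []
    for i in range(y_true_onehot.shape[1]):
        pos = y_score[y_true_onehot[:, i] == 1, i]  # scores of samples in class i
        neg = y_score[y_true_onehot[:, i] == 0, i]  # scores of all other samples
        if len(pos) == 0 or len(neg) == 0:
            aucs.append(float("nan"))  # AUC is undefined without both groups
            continue
        diff = pos[:, None] - neg[None, :]  # every positive vs. every negative
        aucs.append(float(np.mean((diff > 0) + 0.5 * (diff == 0))))
    return aucs

# Toy data (illustrative only): 4 samples, 3 classes
y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]])
y_prob = np.array([
    [0.7, 0.2, 0.1],
    [0.2, 0.6, 0.2],
    [0.1, 0.3, 0.6],
    [0.5, 0.3, 0.2],
])
print(one_vs_rest_auc(y_true, y_prob))
```

The pairwise form is quadratic in the number of samples, so it is only meant for small sanity checks; for real test sets the `roc_curve` and `auc` calls used in the analysis functions remain the practical choice.
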

# 🫁 Lung Sound Classification System – For Unseen Audio Data

A complete prediction pipeline using a pre-trained deep learning model for **asthma, COPD, and normal** lung sound detection from audio files.

---
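The pipeline feeds the model one fixed-length feature vector per recording. As a quick sanity check, the sketch below adds up the block sizes implied by the default parameters of the `LungSoundPredictor` extractor methods; the dictionary and variable names (`feature_blocks`, `total_dim`) are illustrative only. The total should agree with the scaler's feature count and with the `input_shape=(1, 959)` used for the second model.

```python
# Default extractor parameters used by LungSoundPredictor
n_mfcc, n_mels, wavelet_levels, fbse_bands, fb_coeffs = 40, 128, 5, 10, 20

feature_blocks = {
    "advanced_mfcc": n_mfcc * 9,          # mean/std/skew of MFCC, delta, delta-delta
    "fbse": fbse_bands,                   # one entropy value per frequency band
    "enhanced_mel": n_mels * 4,           # mean/std/max/min per mel band
    "wavelet": (wavelet_levels + 1) * 4,  # 4 statistics per coefficient array
    "sequence": 13 * 2,                   # frame variation + long-term mean of 13 MFCCs
    "spectral": 7,                        # centroid, bandwidth, rolloff, flatness, ZCR, chroma, tonnetz
    "fourier_bessel": fb_coeffs,          # one coefficient per basis function
}

total_dim = sum(feature_blocks.values())
for name, size in feature_blocks.items():
    print(f"{name:>15}: {size}")
print(f"{'total':>15}: {total_dim}")
```

If the total ever differs from `len(scaler.mean_)`, the extractor defaults and the training-time settings have drifted apart, and predictions should not be trusted until they match again.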

## ✅ Key Components

- Loads trained Keras model (`.keras`)
- Loads `StandardScaler` for feature normalization
- Extracts advanced, multi-modal audio features
- Supports both **single** and **batch predictions**
- Gives **confidence levels**, **visual summaries**, and **class probabilities**

---
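To make the "confidence levels" item concrete, here is a minimal sketch of how a softmax output can be turned into a label, a confidence value, and a High/Medium/Low band. The 0.8 and 0.6 thresholds and the class order mirror the ones used in `predict_single_file`; the helper name `summarize_prediction` is hypothetical and not part of the notebook.

```python
import numpy as np

CLASS_NAMES = ["Normal", "Asthma", "COPD"]  # same order as LungSoundPredictor.class_names

def summarize_prediction(probabilities, class_names=CLASS_NAMES):
    """Map one row of class probabilities to a label and a confidence band."""
    probabilities = np.asarray(probabilities, dtype=float)
    idx = int(np.argmax(probabilities))
    confidence = float(probabilities[idx])
    if confidence >= 0.8:
        level = "High"
    elif confidence >= 0.6:
        level = "Medium"
    else:
        level = "Low"
    return {
        "predicted_class": class_names[idx],
        "confidence": confidence,
        "confidence_level": level,
    }

print(summarize_prediction([0.10, 0.85, 0.05]))
print(summarize_prediction([0.50, 0.30, 0.20]))
```

Softmax confidence is not a calibrated probability, so the bands are best read as a rough triage signal rather than a clinical certainty.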

## ⚙️ Class: `LungSoundPredictor`

```python
predictor = LungSoundPredictor(
    model_path='lung_sound_model_best.keras',
    scaler_path='scaler.pkl'
)
```


In [None]:
import numpy as np
import librosa
import joblib
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import register_keras_serializable
import warnings
from scipy import stats
from scipy.signal import hilbert
import pywt
from sklearn.preprocessing import StandardScaler
import os

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

# Define the custom metric that was used during training
@register_keras_serializable()
def weighted_categorical_accuracy(y_true, y_pred):
    """Custom weighted categorical accuracy metric"""
    return tf.keras.metrics.categorical_accuracy(y_true, y_pred)

class LungSoundPredictor:
    """
    Complete Lung Sound Classification System for Unseen Data
    Supports: ASTHMA, COPD, NORMAL classification
    """

    def __init__(self, model_path='lung_sound_model_best.keras', scaler_path='scaler.pkl'):
        """
        Initialize the predictor with trained model and scaler

        Parameters:
        - model_path: Path to saved Keras model
        - scaler_path: Path to saved StandardScaler
        """
        self.model_path = model_path
        self.scaler_path = scaler_path
        self.model = None
        self.scaler = None
        self.class_names = ['Normal', 'Asthma', 'COPD']

        # Load model and scaler
        self.load_model_and_scaler()

    def load_model_and_scaler(self):
        """Load the trained model and feature scaler"""
        try:
            # Load the trained model with custom objects
            if os.path.exists(self.model_path):
                custom_objects = {
                    'weighted_categorical_accuracy': weighted_categorical_accuracy
                }
                self.model = load_model(self.model_path, custom_objects=custom_objects)
                print(f"✅ Model loaded successfully from: {self.model_path}")
            else:
                print(f"❌ Model file not found: {self.model_path}")
                print("Please ensure you have trained and saved the model first.")
                return False

            # Load the scaler
            if os.path.exists(self.scaler_path):
                self.scaler = joblib.load(self.scaler_path)
                print(f"✅ Scaler loaded successfully from: {self.scaler_path}")
            else:
                print(f"❌ Scaler file not found: {self.scaler_path}")
                print("Please ensure you have the scaler.pkl file from training.")
                return False

            return True

        except Exception as e:
            print(f"❌ Error loading model or scaler: {e}")

            # Alternative loading method - compile=False
            try:
                print("🔄 Trying alternative loading method...")
                self.model = load_model(self.model_path, compile=False)

                # Recompile the model with standard metrics
                self.model.compile(
                    optimizer='adam',
                    loss='categorical_crossentropy',
                    metrics=['accuracy']
                )
                print("✅ Model loaded successfully with alternative method")

                # Load scaler
                if os.path.exists(self.scaler_path):
                    self.scaler = joblib.load(self.scaler_path)
                    print(f"✅ Scaler loaded successfully from: {self.scaler_path}")
                    return True
                else:
                    print(f"❌ Scaler file not found: {self.scaler_path}")
                    return False

            except Exception as e2:
                print(f"❌ Alternative loading method also failed: {e2}")
                return False

    # Feature extraction functions (same as your training code)
    def add_noise(self, data, noise_level=0.005):
        """Add Gaussian noise to audio data with improved stability."""
        if len(data) == 0:
            return data
        noise = np.random.randn(len(data)) * noise_level
        noisy_data = data + noise
        return np.clip(noisy_data, -1.0, 1.0)

    def extract_advanced_mfcc_features(self, data, sampling_rate, n_mfcc=40):
        """Extract MFCCs with delta and delta-delta features plus statistical moments."""
        try:
            if len(data) == 0:
                return np.zeros(n_mfcc * 9)

            # Extract MFCCs
            mfccs = librosa.feature.mfcc(
                y=data,
                sr=sampling_rate,
                n_mfcc=n_mfcc,
                n_fft=2048,
                hop_length=512
            )

            if mfccs.shape[1] == 0:
                return np.zeros(n_mfcc * 9)

            # Compute Delta and Delta-Delta
            delta_mfccs = librosa.feature.delta(mfccs)
            delta2_mfccs = librosa.feature.delta(mfccs, order=2)

            # Statistical moments
            mfcc_mean = np.mean(mfccs, axis=1)
            mfcc_std = np.std(mfccs, axis=1)
            mfcc_skew = stats.skew(mfccs, axis=1)

            delta_mean = np.mean(delta_mfccs, axis=1)
            delta_std = np.std(delta_mfccs, axis=1)
            delta_skew = stats.skew(delta_mfccs, axis=1)

            delta2_mean = np.mean(delta2_mfccs, axis=1)
            delta2_std = np.std(delta2_mfccs, axis=1)
            delta2_skew = stats.skew(delta2_mfccs, axis=1)

            # Combine all MFCC-based features
            advanced_mfcc_features = np.concatenate([
                mfcc_mean, mfcc_std, mfcc_skew,
                delta_mean, delta_std, delta_skew,
                delta2_mean, delta2_std, delta2_skew
            ])

            return np.nan_to_num(advanced_mfcc_features, nan=0.0, posinf=0.0, neginf=0.0)

        except Exception as e:
            print(f"Warning: Advanced MFCC extraction failed: {e}")
            return np.zeros(n_mfcc * 9)

    def extract_fbse_features(self, data, sampling_rate, n_bands=10):
        """Extract per-band entropy features from the STFT power spectrum (an approximation of Fourier-Bessel spectral entropy)."""
        try:
            if len(data) == 0:
                return np.zeros(n_bands)

            # Compute power spectral density
            stft = librosa.stft(data)
            psd = np.abs(stft)**2

            # Divide frequency range into bands
            freq_bands = np.linspace(0, sampling_rate//2, n_bands + 1)
            entropy_features = []

            for i in range(n_bands):
                start_idx = int(freq_bands[i] * len(psd) / (sampling_rate//2))
                end_idx = int(freq_bands[i+1] * len(psd) / (sampling_rate//2))

                if end_idx > start_idx:
                    band_psd = np.mean(psd[start_idx:end_idx], axis=0)
                    band_psd_norm = band_psd / (np.sum(band_psd) + 1e-10)
                    entropy = -np.sum(band_psd_norm * np.log(band_psd_norm + 1e-10))
                    entropy_features.append(np.mean(entropy))
                else:
                    entropy_features.append(0.0)

            return np.nan_to_num(np.array(entropy_features), nan=0.0, posinf=0.0, neginf=0.0)

        except Exception as e:
            print(f"Warning: FBSE extraction failed: {e}")
            return np.zeros(n_bands)

    def extract_enhanced_melspectrogram(self, data, sampling_rate, n_mels=128):
        """Extract enhanced Mel-spectrogram features."""
        try:
            if len(data) == 0:
                return np.zeros(n_mels * 4)

            mel_spec = librosa.feature.melspectrogram(
                y=data,
                sr=sampling_rate,
                n_mels=n_mels,
                n_fft=2048,
                hop_length=512,
                fmax=sampling_rate//2
            )

            if mel_spec.size == 0:
                return np.zeros(n_mels * 4)

            mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

            if mel_spec_db.shape[1] > 0:
                mel_mean = np.mean(mel_spec_db, axis=1)
                mel_std = np.std(mel_spec_db, axis=1)
                mel_max = np.max(mel_spec_db, axis=1)
                mel_min = np.min(mel_spec_db, axis=1)
                mel_features = np.concatenate([mel_mean, mel_std, mel_max, mel_min])
            else:
                mel_features = np.zeros(n_mels * 4)

            return np.nan_to_num(mel_features, nan=0.0, posinf=0.0, neginf=0.0)

        except Exception as e:
            print(f"Warning: Enhanced Mel-spectrogram extraction failed: {e}")
            return np.zeros(n_mels * 4)

    def extract_wavelet_features(self, data, wavelet='db4', levels=5):
        """Extract wavelet features for transient detection."""
        try:
            if len(data) == 0:
                return np.zeros(levels * 4 + 4)

            coeffs = pywt.wavedec(data, wavelet, level=levels)
            wavelet_features = []

            for coeff in coeffs:
                if len(coeff) > 0:
                    wavelet_features.extend([
                        np.mean(np.abs(coeff)),
                        np.std(coeff),
                        np.max(np.abs(coeff)),
                        np.sum(coeff**2)
                    ])
                else:
                    wavelet_features.extend([0.0, 0.0, 0.0, 0.0])

            return np.nan_to_num(np.array(wavelet_features), nan=0.0, posinf=0.0, neginf=0.0)

        except Exception as e:
            print(f"Warning: Wavelet feature extraction failed: {e}")
            return np.zeros(levels * 4 + 4)

    def extract_sequence_features(self, data, sampling_rate, frame_length=2048, hop_length=512):
        """Extract sequence-based features."""
        try:
            if len(data) == 0:
                return np.zeros(26)

            mfccs = librosa.feature.mfcc(
                y=data,
                sr=sampling_rate,
                n_mfcc=13,
                n_fft=frame_length,
                hop_length=hop_length
            )

            if mfccs.shape[1] == 0:
                return np.zeros(26)

            frame_variations = np.mean(np.abs(np.diff(mfccs, axis=1)), axis=1)
            long_term_mean = np.mean(mfccs, axis=1)
            sequence_features = np.concatenate([frame_variations, long_term_mean])

            return np.nan_to_num(sequence_features, nan=0.0, posinf=0.0, neginf=0.0)

        except Exception as e:
            print(f"Warning: Sequence feature extraction failed: {e}")
            return np.zeros(26)

    def extract_spectral_features(self, data, sampling_rate):
        """Extract spectral features with improved tonnetz handling."""
        try:
            if len(data) == 0:
                return np.zeros(7)

            spectral_centroid = librosa.feature.spectral_centroid(y=data, sr=sampling_rate)
            spectral_bandwidth = librosa.feature.spectral_bandwidth(y=data, sr=sampling_rate)
            spectral_rolloff = librosa.feature.spectral_rolloff(y=data, sr=sampling_rate)
            spectral_flatness = librosa.feature.spectral_flatness(y=data)
            zero_crossing_rate = librosa.feature.zero_crossing_rate(data)
            chroma = librosa.feature.chroma_stft(y=data, sr=sampling_rate)
            chroma_mean = np.mean(chroma)

            # Improved tonnetz handling
            try:
                if sampling_rate >= 8000:
                    tonnetz = librosa.feature.tonnetz(y=data, sr=sampling_rate)
                    tonnetz_mean = np.mean(tonnetz)
                elif sampling_rate >= 4000:
                    chroma_cqt = librosa.feature.chroma_cqt(
                        y=data,
                        sr=sampling_rate,
                        fmin=librosa.note_to_hz('C1'),
                        n_chroma=12
                    )
                    tonnetz_mean = np.mean(chroma_cqt) * 0.5
                else:
                    tonnetz_mean = 0.0
            except Exception:
                tonnetz_mean = 0.0

            spectral_features = np.array([
                np.mean(spectral_centroid),
                np.mean(spectral_bandwidth),
                np.mean(spectral_rolloff),
                np.mean(spectral_flatness),
                np.mean(zero_crossing_rate),
                chroma_mean,
                tonnetz_mean
            ])

            return np.nan_to_num(spectral_features, nan=0.0, posinf=0.0, neginf=0.0)

        except Exception as e:
            print(f"Warning: Spectral feature extraction failed: {e}")
            return np.zeros(7)

    def fourier_bessel_features(self, data, sampling_rate, n_coeff=20):
        """Project the signal onto cosine basis functions (a simplified stand-in for Fourier-Bessel coefficients)."""
        if len(data) == 0:
            return np.zeros(n_coeff)

        t = np.arange(len(data)) / sampling_rate
        fb_coeff = np.zeros(n_coeff)
        t_norm = t / np.max(t) if np.max(t) > 0 else t

        for i in range(n_coeff):
            j = i + 1
            cosine_term = np.cos(2 * np.pi * j * t_norm)
            fb_coeff[i] = np.sum(data * cosine_term) / len(data)

        return np.nan_to_num(fb_coeff, nan=0.0, posinf=0.0, neginf=0.0)

    def extract_features_from_audio(self, audio_file_path):
        """
        Extract all features from a single audio file

        Parameters:
        - audio_file_path: Path to the audio file

        Returns:
        - features: Normalized feature vector ready for prediction
        """
        try:
            # Load audio file
            data, sampling_rate = librosa.load(audio_file_path, sr=None)

            if len(data) == 0:
                print(f"❌ Empty audio file: {audio_file_path}")
                return None

            print(f"🎵 Processing: {os.path.basename(audio_file_path)}")
            print(f"   • Duration: {len(data)/sampling_rate:.2f}s")
            print(f"   • Sampling Rate: {sampling_rate}Hz")

            # Extract all features (same as training)
            n_mfcc = 40
            fb_coeffs = 20
            n_mels = 128
            wavelet_levels = 5
            fbse_bands = 10

            # Extract each feature type
            advanced_mfcc_features = self.extract_advanced_mfcc_features(data, sampling_rate, n_mfcc)
            fbse_features = self.extract_fbse_features(data, sampling_rate, fbse_bands)
            enhanced_mel_features = self.extract_enhanced_melspectrogram(data, sampling_rate, n_mels)
            wavelet_features = self.extract_wavelet_features(data, 'db4', wavelet_levels)
            sequence_features = self.extract_sequence_features(data, sampling_rate)
            spectral_features = self.extract_spectral_features(data, sampling_rate)
            fb_features = self.fourier_bessel_features(data, sampling_rate, fb_coeffs)

            # Combine all features
            combined_features = np.concatenate([
                advanced_mfcc_features,
                fbse_features,
                enhanced_mel_features,
                wavelet_features,
                sequence_features,
                spectral_features,
                fb_features
            ])

            # Final validation
            combined_features = np.nan_to_num(combined_features, nan=0.0, posinf=0.0, neginf=0.0)

            print(f"   • Features extracted: {len(combined_features)} dimensions")

            # Normalize using training scaler
            if self.scaler is not None:
                features_normalized = self.scaler.transform(combined_features.reshape(1, -1))
                return features_normalized[0]
            else:
                print("❌ Scaler not loaded. Cannot normalize features.")
                return None

        except Exception as e:
            print(f"❌ Error extracting features from {audio_file_path}: {e}")
            return None

    def predict_single_file(self, audio_file_path, show_confidence=True):
        """
        Predict lung condition for a single audio file

        Parameters:
        - audio_file_path: Path to the audio file
        - show_confidence: Whether to show confidence scores

        Returns:
        - prediction_result: Dictionary with prediction details
        """
        if self.model is None or self.scaler is None:
            print("❌ Model or scaler not loaded properly.")
            return None

        # Extract features
        features = self.extract_features_from_audio(audio_file_path)
        if features is None:
            return None

        try:
            # Reshape for model input (1, 1, feature_size) format
            features_reshaped = features.reshape(1, 1, -1)

            print(f"   • Input shape: {features_reshaped.shape}")

            # Make prediction
            prediction_proba = self.model.predict(features_reshaped, verbose=0)
            predicted_class_idx = np.argmax(prediction_proba[0])
            predicted_class = self.class_names[predicted_class_idx]
            confidence = prediction_proba[0][predicted_class_idx]

            # Prepare result
            result = {
                'file': os.path.basename(audio_file_path),
                'predicted_class': predicted_class,
                'confidence': confidence,
                'all_probabilities': {
                    self.class_names[i]: prediction_proba[0][i]
                    for i in range(len(self.class_names))
                }
            }

            # Display results
            print(f"\n🎯 Prediction Results:")
            print(f"   • File: {result['file']}")
            print(f"   • Predicted Class: {predicted_class}")
            print(f"   • Confidence: {confidence:.3f} ({confidence*100:.1f}%)")

            if show_confidence:
                print(f"   • Detailed Probabilities:")
                for class_name, prob in result['all_probabilities'].items():
                    print(f"     - {class_name}: {prob:.3f} ({prob*100:.1f}%)")

            # Confidence interpretation
            if confidence >= 0.8:
                confidence_level = "High"
                emoji = "🟢"
            elif confidence >= 0.6:
                confidence_level = "Medium"
                emoji = "🟡"
            else:
                confidence_level = "Low"
                emoji = "🔴"

            print(f"   • Confidence Level: {emoji} {confidence_level}")

            return result

        except Exception as e:
            print(f"❌ Error during prediction: {e}")
            return None

    def predict_multiple_files(self, audio_files_list):
        """
        Predict lung conditions for multiple audio files

        Parameters:
        - audio_files_list: List of audio file paths

        Returns:
        - results: List of prediction results
        """
        results = []

        print(f"🔄 Processing {len(audio_files_list)} audio files...")
        print("=" * 50)

        for i, audio_file in enumerate(audio_files_list):
            print(f"\n[{i+1}/{len(audio_files_list)}] Processing: {os.path.basename(audio_file)}")

            result = self.predict_single_file(audio_file, show_confidence=False)
            if result is not None:
                results.append(result)

        # Summary
        if results:
            print(f"\n📊 Summary of {len(results)} successful predictions:")
            class_counts = {}
            for result in results:
                pred_class = result['predicted_class']
                class_counts[pred_class] = class_counts.get(pred_class, 0) + 1

            for class_name, count in class_counts.items():
                percentage = (count / len(results)) * 100
                print(f"   • {class_name}: {count} files ({percentage:.1f}%)")

        return results

    def get_model_info(self):
        """Display information about the loaded model"""
        if self.model is not None:
            print("🏥 Model Information:")
            print(f"   • Input shape: {self.model.input_shape}")
            print(f"   • Output shape: {self.model.output_shape}")
            print(f"   • Total parameters: {self.model.count_params():,}")
            print(f"   • Classes: {self.class_names}")
            if self.scaler is not None:
                print(f"   • Feature dimensions: {len(self.scaler.mean_)}")
        else:
            print("❌ No model loaded")


# Simple usage functions
def predict_single_audio(audio_file_path, model_path='/content/lung_sound_model_best.keras', scaler_path='/content/scaler.pkl'):
    """
    Simple function to predict a single audio file

    Parameters:
    - audio_file_path: Path to your audio file
    - model_path: Path to saved model (default: '/content/lung_sound_model_best.keras')
    - scaler_path: Path to saved scaler (default: '/content/scaler.pkl')

    Returns:
    - prediction_result: Dictionary with prediction details
    """
    predictor = LungSoundPredictor(model_path, scaler_path)
    return predictor.predict_single_file(audio_file_path)

def predict_audio_directory(directory_path, model_path='/content/lung_sound_model_best.keras', scaler_path='scaler.pkl'):
    """
    Simple function to predict all audio files in a directory

    Parameters:
    - directory_path: Directory containing audio files
    - model_path: Path to saved model
    - scaler_path: Path to saved scaler

    Returns:
    - results: List of prediction results
    """
    import glob

    # Find all audio files in directory
    audio_extensions = ['*.wav', '*.mp3', '*.flac', '*.m4a']
    audio_files = []

    for ext in audio_extensions:
        audio_files.extend(glob.glob(os.path.join(directory_path, ext)))
        audio_files.extend(glob.glob(os.path.join(directory_path, ext.upper())))

    if not audio_files:
        print(f"❌ No audio files found in: {directory_path}")
        return []

    print(f"📁 Found {len(audio_files)} audio files in: {directory_path}")

    predictor = LungSoundPredictor(model_path, scaler_path)
    return predictor.predict_multiple_files(audio_files)


# Example usage
if __name__ == "__main__":
    print("🫁 Lung Sound Classification for Unseen Data")
    print("=" * 50)

    # Example 1: Predict a single file
    print("\n📖 Example 1: Predict single file")
    result = predict_single_audio('/content/BP50_N,N,P R L ,27,M.wav')

    # Example 2: Predict multiple files
    print("\n📖 Example 2: Predict all files in directory")
    print("results = predict_audio_directory('path/to/your/audio/directory')")

    # Example 3: Using the class directly
    print("\n📖 Example 3: Using the class directly")
    print("predictor = LungSoundPredictor()")
    print("predictor.get_model_info()  # Show model details")
    print("result = predictor.predict_single_file('audio_file.wav')")

    print("\n✅ Ready to use! Make sure you have:")
    print("   1. Your trained model file (lung_sound_model_best.keras)")
    print("   2. Your scaler file (scaler.pkl)")
    print("   3. Audio files to classify")

    # Test if files exist
    print("\n🔍 Checking for required files...")
    if os.path.exists('lung_sound_model_best.keras'):
        print("   ✅ Model file found")
    else:
        print("   ❌ Model file 'lung_sound_model_best.keras' not found")

    if os.path.exists('scaler.pkl'):
        print("   ✅ Scaler file found")
    else:
        print("   ❌ Scaler file 'scaler.pkl' not found")

# ans ==> /content/142_1b1_Pl_mc_LittC2SE.wav

## **Model 2**
# 🩺 Improved Lung Sound Classification - Focus on Generalization

> **Goal:** Build a robust, generalizable model for classifying lung sounds into `Healthy`, `Asthma`, and `COPD` categories using deep learning and ensemble techniques.

---
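One common way to combine several classifiers is soft voting: average the per-class probabilities of each model and take the argmax of the mean. Whether the notebook combines its `simple`, `attention`, and `lstm` variants exactly this way is an assumption; the function `soft_vote` and the toy probability arrays below are hypothetical and only illustrate the idea.

```python
import numpy as np

def soft_vote(prob_list, weights=None):
    """Average class probabilities across models and pick the most likely class."""
    # Stack to shape (n_models, n_samples, n_classes)
    probs = np.stack([np.asarray(p, dtype=float) for p in prob_list])
    avg = np.average(probs, axis=0, weights=weights)  # optional per-model weights
    return avg, np.argmax(avg, axis=1)

# Toy softmax outputs from three models for two recordings (illustrative only)
simple_probs = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]]
attention_probs = [[0.5, 0.4, 0.1], [0.1, 0.3, 0.6]]
lstm_probs = [[0.7, 0.2, 0.1], [0.2, 0.3, 0.5]]

avg_probs, labels = soft_vote([simple_probs, attention_probs, lstm_probs])
print(avg_probs)
print(labels)
```

In practice each array would come from `model.predict(x)` on the same inputs, and the optional `weights` could favor the variant with the best validation score.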

## 🚀 Highlights

- ✅ **Simplified models** to reduce overfitting
- 🔁 **Data augmentation** to boost generalization
- 📊 **Attention mechanism** for better temporal understanding
- 🧠 **Ensemble support** for robust prediction
- 🛠️ Conservative training with early stopping and LR reduction

---
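To give a feel for the "LR reduction" part of the conservative training setup, the sketch below works through the arithmetic of repeated plateau reductions. It uses the values that appear in the class (`learning_rate=0.0005`, `factor=0.5`, `min_lr=1e-6`) and the rule that each reduction sets the new rate to `max(old_lr * factor, min_lr)`; the helper `lr_after_reductions` is a hypothetical name, and the number of reductions that actually happen depends on how validation loss behaves.

```python
def lr_after_reductions(initial_lr=0.0005, factor=0.5, min_lr=1e-6, n_reductions=10):
    """Learning rate after 0..n_reductions plateau-triggered reductions."""
    lrs = [initial_lr]
    for _ in range(n_reductions):
        # Each reduction multiplies by `factor` but never goes below `min_lr`
        lrs.append(max(lrs[-1] * factor, min_lr))
    return lrs

schedule = lr_after_reductions()
for k, lr in enumerate(schedule):
    print(f"after {k:2d} reductions: {lr:.3e}")
```

With these settings the rate halves eight times before the ninth reduction would fall below the floor and is clipped to `1e-6`. Because early stopping uses a patience of 15 epochs and the LR scheduler a patience of 5, a fully stalled run is likely to stop long before the floor is reached.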

## 🏗️ Class: `ImprovedLungSoundClassifier`

### 🔧 Initialization
```python
classifier = ImprovedLungSoundClassifier(input_shape=(1, 959), num_classes=3)
```


In [None]:
# Improved Lung Sound Classification - Focus on Generalization
# Addresses overfitting issues for better unseen data performance

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Dense, Dropout, BatchNormalization, GRU, LSTM,
    Bidirectional, GlobalAveragePooling1D, GlobalMaxPooling1D,
    Concatenate, Conv1D, LeakyReLU, SpatialDropout1D,
    MultiHeadAttention, LayerNormalization
)
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (
    EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
)
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns

class ImprovedLungSoundClassifier:
    """
    Improved Neural Network for Lung Sound Classification
    Focus: Better generalization on unseen data
    """

    def __init__(self, input_shape=(1, 959), num_classes=3):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.model = None
        self.history = None
        self.class_weights = None

    def create_simple_model(self):
        """Create a simpler, more generalizable model"""
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs

        # Simple feature extraction with lighter regularization
        x = Conv1D(filters=32, kernel_size=7, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = SpatialDropout1D(0.2)(x)

        # Single RNN layer to reduce complexity
        x = Bidirectional(
            GRU(64, return_sequences=True, dropout=0.3, recurrent_dropout=0.3)
        )(x)
        x = BatchNormalization()(x)

        # Global pooling
        avg_pool = GlobalAveragePooling1D()(x)
        max_pool = GlobalMaxPooling1D()(x)
        x = Concatenate()([avg_pool, max_pool])

        # Simpler dense layers
        x = Dense(64, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(0.5)(x)

        # Output layer
        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)

        model = Model(inputs=inputs, outputs=outputs, name='SimpleLungClassifier')
        return model

    def create_attention_model(self):
        """Create attention-based model for better feature learning"""
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs

        # Light conv preprocessing
        x = Conv1D(filters=32, kernel_size=5, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = SpatialDropout1D(0.2)(x)

        # RNN processing
        x = Bidirectional(GRU(48, return_sequences=True, dropout=0.2, recurrent_dropout=0.2))(x)
        x = LayerNormalization()(x)

        # Self-attention mechanism
        attention = MultiHeadAttention(
            num_heads=4,
            key_dim=48,
            dropout=0.2
        )(x, x)
        x = LayerNormalization()(x + attention)  # Residual connection

        # Global pooling
        x = GlobalAveragePooling1D()(x)

        # Classification head
        x = Dense(32, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(0.4)(x)

        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)

        model = Model(inputs=inputs, outputs=outputs, name='AttentionLungClassifier')
        return model

    def create_ensemble_ready_model(self, model_variant='simple'):
        """Create different model variants for ensemble"""
        if model_variant == 'simple':
            return self.create_simple_model()
        elif model_variant == 'attention':
            return self.create_attention_model()
        elif model_variant == 'lstm':
            return self.create_lstm_model()
        else:
            return self.create_simple_model()

    def create_lstm_model(self):
        """LSTM variant for ensemble diversity"""
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs

        # Feature extraction
        x = Conv1D(filters=32, kernel_size=3, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = SpatialDropout1D(0.2)(x)

        # LSTM layers
        x = Bidirectional(LSTM(48, return_sequences=True, dropout=0.3, recurrent_dropout=0.3))(x)
        x = BatchNormalization()(x)

        # Pooling
        x = GlobalAveragePooling1D()(x)

        # Classification
        x = Dense(48, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(0.4)(x)

        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)

        model = Model(inputs=inputs, outputs=outputs, name='LSTMLungClassifier')
        return model

    def compile_model(self, model, learning_rate=0.0005):
        """Compile with conservative settings for better generalization"""
        optimizer = Adam(
            learning_rate=learning_rate,
            clipnorm=1.0
        )

        model.compile(
            optimizer=optimizer,
            loss='categorical_crossentropy',
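            # Metric objects avoid relying on string aliases that differ across Keras versions
            metrics=[
                'accuracy',
                tf.keras.metrics.Precision(name='precision'),
                tf.keras.metrics.Recall(name='recall')
            ]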
        )

        return model

    def compute_class_weights(self, y_train):
        """Compute balanced class weights"""
        y_indices = np.argmax(y_train, axis=1)
        classes = np.unique(y_indices)
        class_weights = compute_class_weight('balanced', classes=classes, y=y_indices)
        self.class_weights = dict(zip(classes, class_weights))

        print(f"📊 Class weights:")
        class_names = ['Healthy', 'Asthma', 'COPD']
        for i, weight in self.class_weights.items():
            print(f"   • {class_names[i]}: {weight:.3f}")

        return self.class_weights

    def create_callbacks(self, model_name='improved_lung_model'):
        """Conservative callbacks for better generalization"""
        callbacks = [
            EarlyStopping(
                monitor='val_loss',
                patience=15,  # Shorter patience to prevent overfitting
                restore_best_weights=True,
                verbose=1,
                min_delta=0.001
            ),
            ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=5,  # Reduce LR more aggressively
                min_lr=1e-6,
                verbose=1
            ),
            ModelCheckpoint(
                filepath=f'{model_name}_best.keras',
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1,
                mode='max'
            )
        ]

        return callbacks

    def train_with_data_augmentation(self, x_train, y_train, x_val, y_val,
                                   model_type='simple', epochs=80, batch_size=32):
        """Train with data augmentation for better generalization"""

        print(f"🚀 Training {model_type} model with data augmentation...")
        print(f"   • Training samples: {x_train.shape[0]}")
        print(f"   • Validation samples: {x_val.shape[0]}")

        # Data augmentation
        x_train_aug, y_train_aug = self.augment_data(x_train, y_train)
        print(f"   • Augmented training samples: {x_train_aug.shape[0]}")

        # Create model
        self.model = self.create_ensemble_ready_model(model_type)
        self.model = self.compile_model(self.model)

        # Compute class weights
        class_weights = self.compute_class_weights(y_train_aug)

        # Callbacks
        callbacks = self.create_callbacks(f'{model_type}_lung_model')

        print(f"\n🏗️ Model Architecture ({model_type}):")
        print(f"   • Total parameters: {self.model.count_params():,}")

        # Train
        self.history = self.model.fit(
            x_train_aug, y_train_aug,
            validation_data=(x_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=1,
            shuffle=True
        )

        print("✅ Training completed!")
        return self.history

    def augment_data(self, x_data, y_data, augment_factor=0.5):
        """Simple data augmentation techniques"""
        augmented_x = []
        augmented_y = []

        # Original data
        augmented_x.append(x_data)
        augmented_y.append(y_data)

        n_augment = int(len(x_data) * augment_factor)
        indices = np.random.choice(len(x_data), n_augment, replace=True)

        for idx in indices:
            sample = x_data[idx].copy()
            label = y_data[idx].copy()

            # Random noise addition (5% of signal std)
            noise_level = 0.05 * np.std(sample)
            sample += np.random.normal(0, noise_level, sample.shape)

            # Random scaling (±10%)
            scale_factor = np.random.uniform(0.9, 1.1)
            sample *= scale_factor

            augmented_x.append(sample[np.newaxis, :])
            augmented_y.append(label[np.newaxis, :])

        return np.vstack(augmented_x), np.vstack(augmented_y)

    def evaluate_model(self, x_test, y_test):
        """Comprehensive evaluation"""
        if self.model is None:
            print("❌ Model not trained yet!")
            return

        # Predictions
        y_pred_proba = self.model.predict(x_test, verbose=0)
        y_pred = np.argmax(y_pred_proba, axis=1)
        y_true = np.argmax(y_test, axis=1)

        # Metrics
        test_loss, test_acc, test_prec, test_rec = self.model.evaluate(x_test, y_test, verbose=0)
        test_f1 = f1_score(y_true, y_pred, average='weighted')

        print("\n📊 Test Performance:")
        print(f"   • Accuracy: {test_acc:.4f}")
        print(f"   • Precision: {test_prec:.4f}")
        print(f"   • Recall: {test_rec:.4f}")
        print(f"   • F1-Score: {test_f1:.4f}")

        # Per-class metrics
        class_names = ['Healthy', 'Asthma', 'COPD']
        print("\n📋 Classification Report:")
        print(classification_report(y_true, y_pred, target_names=class_names))

        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                   xticklabels=class_names, yticklabels=class_names)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.show()

        return {
            'accuracy': test_acc,
            'precision': test_prec,
            'recall': test_rec,
            'f1_score': test_f1,
            'predictions': y_pred_proba,
            'confusion_matrix': cm
        }

    def plot_training_history(self):
        """Plot training curves to check for overfitting"""
        if self.history is None:
            print("❌ No training history!")
            return

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Accuracy
        axes[0, 0].plot(self.history.history['accuracy'], label='Train', color='blue')
        axes[0, 0].plot(self.history.history['val_accuracy'], label='Val', color='orange')
        axes[0, 0].set_title('Accuracy')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Loss
        axes[0, 1].plot(self.history.history['loss'], label='Train', color='blue')
        axes[0, 1].plot(self.history.history['val_loss'], label='Val', color='orange')
        axes[0, 1].set_title('Loss')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

        # Precision
        axes[1, 0].plot(self.history.history['precision'], label='Train', color='blue')
        axes[1, 0].plot(self.history.history['val_precision'], label='Val', color='orange')
        axes[1, 0].set_title('Precision')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

        # Recall
        axes[1, 1].plot(self.history.history['recall'], label='Train', color='blue')
        axes[1, 1].plot(self.history.history['val_recall'], label='Val', color='orange')
        axes[1, 1].set_title('Recall')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()


class EnsembleLungClassifier:
    """Ensemble approach for robust predictions"""

    def __init__(self, input_shape=(1, 959), num_classes=3):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.models = []
        self.model_types = ['simple', 'attention', 'lstm']

    def train_ensemble(self, x_train, y_train, x_val, y_val, epochs=60):
        """Train ensemble of diverse models"""
        print("🎯 Training Ensemble of Models...")

        for model_type in self.model_types:
            print(f"\n🔄 Training {model_type} model...")

            classifier = ImprovedLungSoundClassifier(self.input_shape, self.num_classes)
            history = classifier.train_with_data_augmentation(
                x_train, y_train, x_val, y_val,
                model_type=model_type,
                epochs=epochs,
                batch_size=32
            )

            self.models.append(classifier.model)
            print(f"✅ {model_type} model trained!")

        return self.models

    def predict_ensemble(self, x_test):
        """Make ensemble predictions"""
        if not self.models:
            print("❌ No models trained!")
            return None

        predictions = []
        for model in self.models:
            pred = model.predict(x_test, verbose=0)
            predictions.append(pred)

        # Average predictions
        ensemble_pred = np.mean(predictions, axis=0)
        return ensemble_pred

    def evaluate_ensemble(self, x_test, y_test):
        """Evaluate ensemble performance"""
        ensemble_pred = self.predict_ensemble(x_test)
        if ensemble_pred is None:
            return None

        y_pred = np.argmax(ensemble_pred, axis=1)
        y_true = np.argmax(y_test, axis=1)

        # Calculate metrics
        from sklearn.metrics import accuracy_score, precision_score, recall_score

        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='weighted')
        recall = recall_score(y_true, y_pred, average='weighted')
        f1 = f1_score(y_true, y_pred, average='weighted')

        print("\n🎯 Ensemble Performance:")
        print(f"   • Accuracy: {accuracy:.4f}")
        print(f"   • Precision: {precision:.4f}")
        print(f"   • Recall: {recall:.4f}")
        print(f"   • F1-Score: {f1:.4f}")

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'predictions': ensemble_pred
        }


# Improved Training Pipeline
def train_improved_lung_classifier(x_train, y_train, x_val, y_val, x_test, y_test,
                                 use_ensemble=False):
    """
    Improved training pipeline focused on generalization

    Key improvements:
    1. Simpler architectures to reduce overfitting
    2. Data augmentation for better generalization
    3. Conservative training settings
    4. Ensemble option for robust predictions
    """

    print("🎯 Improved Lung Sound Classification Pipeline")
    print("Focus: Better generalization on unseen data")
    print("=" * 60)

    if use_ensemble:
        # Train ensemble
        ensemble = EnsembleLungClassifier()
        models = ensemble.train_ensemble(x_train, y_train, x_val, y_val)
        results = ensemble.evaluate_ensemble(x_test, y_test)
        return ensemble, results
    else:
        # Train single improved model
        classifier = ImprovedLungSoundClassifier()

        # Try simple model first
        history = classifier.train_with_data_augmentation(
            x_train, y_train, x_val, y_val,
            model_type='simple',
            epochs=150,
            batch_size=32
        )

        # Evaluate
        results = classifier.evaluate_model(x_test, y_test)

        # Plot training curves
        classifier.plot_training_history()

        return classifier, results


# Usage instructions
def usage_example():
    """How to use the improved classifier"""
    print("\n📝 Usage Example:")
    print("# For single improved model:")
    print("classifier, results = train_improved_lung_classifier(")
    print("    x_train_gru, y_train_gru, x_val_gru, y_val_gru, x_test_gru, y_test_gru")
    print(")")
    print("\n# For ensemble approach (better but slower):")
    print("ensemble, results = train_improved_lung_classifier(")
    print("    x_train_gru, y_train_gru, x_val_gru, y_val_gru, x_test_gru, y_test_gru,")
    print("    use_ensemble=True")
    print(")")

    print("\n💡 Key Improvements:")
    print("• Simpler architecture to prevent overfitting")
    print("• Data augmentation for better generalization")
    print("• Conservative training with early stopping")
    print("• Ensemble option for robust predictions")
    print("• Better regularization strategies")

if __name__ == "__main__":
    # usage_example()
    classifier, results = train_improved_lung_classifier(
        x_train_gru, y_train_gru, x_val_gru, y_val_gru, x_test_gru, y_test_gru)
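
The augmentation step in `augment_data` (additive noise at 5% of each sample's standard deviation, followed by a random gain of ±10%) can also be written without the Python loop. The following is a minimal vectorized sketch, not the notebook's own implementation: `augment_batch`, its `seed` argument, and the random demo arrays are hypothetical names introduced here for illustration.

```python
import numpy as np

def augment_batch(x, y, augment_factor=0.5, noise_frac=0.05,
                  scale_range=(0.9, 1.1), seed=0):
    """Return the originals plus noisy, rescaled copies of randomly chosen samples."""
    rng = np.random.default_rng(seed)
    n_new = int(len(x) * augment_factor)
    idx = rng.choice(len(x), size=n_new, replace=True)

    # Fancy indexing returns a copy, so the originals are never modified
    picked = x[idx].astype(np.float64)

    # Noise level per sample: a fraction of that sample's own standard deviation
    sample_axes = tuple(range(1, picked.ndim))
    noise_std = noise_frac * picked.std(axis=sample_axes, keepdims=True)
    picked = picked + rng.normal(size=picked.shape) * noise_std

    # One random gain per sample, broadcast over the remaining axes
    gain_shape = (n_new,) + (1,) * (picked.ndim - 1)
    picked = picked * rng.uniform(*scale_range, size=gain_shape)

    return np.concatenate([x, picked]), np.concatenate([y, y[idx]])

# Demo on random data shaped like the (samples, 1, 959) inputs used above
x_demo = np.random.default_rng(1).normal(size=(10, 1, 959))
y_demo = np.eye(3)[np.arange(10) % 3]  # one-hot labels for 3 classes
x_aug, y_aug = augment_batch(x_demo, y_demo)
print(x_aug.shape, y_aug.shape)  # (15, 1, 959) (15, 3)
```

Seeding the generator makes the augmented set reproducible between runs, which the `np.random` calls in the class do not guarantee on their own.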

#### AUC and ROC curves
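
Before the full evaluation class, here is a minimal sketch of the one-vs-rest idea it relies on: each class is treated as its own binary problem, and a ROC curve is computed from that class's probability column. The labels and scores below are made-up toy values chosen for illustration, not results from the lung sound model.

```python
import numpy as np
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize

class_names = ['Healthy', 'Asthma', 'COPD']

# Toy integer labels and predicted probabilities (illustrative values only)
y_true = np.array([0, 0, 1, 1, 2, 2])
y_score = np.array([
    [0.70, 0.20, 0.10],
    [0.50, 0.30, 0.20],
    [0.20, 0.60, 0.20],
    [0.55, 0.25, 0.20],
    [0.10, 0.20, 0.70],
    [0.30, 0.30, 0.40],
])

# One indicator column per class: column i is 1 where the true label is i
y_bin = label_binarize(y_true, classes=[0, 1, 2])

per_class_auc = {}
for i, name in enumerate(class_names):
    # Class i versus everything else, scored by the class-i probability
    fpr, tpr, _ = roc_curve(y_bin[:, i], y_score[:, i])
    per_class_auc[name] = auc(fpr, tpr)

for name, value in per_class_auc.items():
    print(f"{name}: AUC = {value:.3f}")
```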

In [None]:
# AUC and ROC Curve Analysis for Lung Sound Classification
# Comprehensive evaluation metrics for multi-class classification

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    roc_curve, auc, roc_auc_score,
    precision_recall_curve, average_precision_score
)
from sklearn.preprocessing import label_binarize
from itertools import cycle
import seaborn as sns

class ModelEvaluationMetrics:
    """
    Comprehensive evaluation metrics including ROC curves and AUC scores
    for multi-class lung sound classification
    """

    def __init__(self, class_names=('Healthy', 'Asthma', 'COPD')):
        # Tuple default avoids sharing one mutable list across instances
        self.class_names = list(class_names)
        self.n_classes = len(self.class_names)
        self.colors = cycle(['blue', 'red', 'green', 'orange', 'purple'])

    def compute_roc_auc(self, y_true, y_pred_proba, plot=True):
        """
        Compute ROC curves and AUC scores for multi-class classification

        Parameters:
        -----------
        y_true : array-like, shape = [n_samples]
            True class labels (integer encoded)
        y_pred_proba : array-like, shape = [n_samples, n_classes]
            Predicted class probabilities
        plot : bool, default=True
            Whether to plot ROC curves

        Returns:
        --------
        dict : Dictionary containing AUC scores and ROC data
        """

        # Binarize the output for multi-class ROC
        y_true_bin = label_binarize(y_true, classes=range(self.n_classes))

        # For binary classification, label_binarize returns a single column
        # of shape (n_samples, 1), so add the complementary column explicitly
        if self.n_classes == 2:
            y_true_bin = np.column_stack([1 - y_true_bin, y_true_bin])

        # Compute ROC curve and ROC area for each class
        fpr = dict()
        tpr = dict()
        roc_auc = dict()

        for i in range(self.n_classes):
            fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_pred_proba[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])

        # Compute micro-average ROC curve and ROC area
        fpr["micro"], tpr["micro"], _ = roc_curve(y_true_bin.ravel(), y_pred_proba.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

        # Compute macro-average ROC curve and ROC area
        # First aggregate all false positive rates
        all_fpr = np.unique(np.concatenate([fpr[i] for i in range(self.n_classes)]))

        # Then interpolate all ROC curves at these points
        mean_tpr = np.zeros_like(all_fpr)
        for i in range(self.n_classes):
            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

        # Finally average it and compute AUC
        mean_tpr /= self.n_classes

        fpr["macro"] = all_fpr
        tpr["macro"] = mean_tpr
        roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

        # Plot ROC curves
        if plot:
            self.plot_roc_curves(fpr, tpr, roc_auc)

        # Print AUC scores
        print("\n📊 ROC AUC Scores:")
        print(f"   • Micro-average AUC: {roc_auc['micro']:.4f}")
        print(f"   • Macro-average AUC: {roc_auc['macro']:.4f}")
        print("\n   Per-class AUC:")
        for i, class_name in enumerate(self.class_names):
            print(f"   • {class_name}: {roc_auc[i]:.4f}")

        return {
            'fpr': fpr,
            'tpr': tpr,
            'roc_auc': roc_auc,
            'micro_auc': roc_auc['micro'],
            'macro_auc': roc_auc['macro']
        }

    def plot_roc_curves(self, fpr, tpr, roc_auc):
        """Plot ROC curves for multi-class classification"""

        plt.figure(figsize=(12, 8))

        # Plot ROC curve for each class
        colors = cycle(['blue', 'red', 'green', 'orange', 'purple'])
        for i, color in zip(range(self.n_classes), colors):
            plt.plot(fpr[i], tpr[i], color=color, lw=2,
                    label=f'{self.class_names[i]} (AUC = {roc_auc[i]:.3f})')

        # Plot micro-average ROC curve
        plt.plot(fpr["micro"], tpr["micro"],
                label=f'Micro-average (AUC = {roc_auc["micro"]:.3f})',
                color='deeppink', linestyle=':', linewidth=3)

        # Plot macro-average ROC curve
        plt.plot(fpr["macro"], tpr["macro"],
                label=f'Macro-average (AUC = {roc_auc["macro"]:.3f})',
                color='navy', linestyle=':', linewidth=3)

        # Plot random classifier line
        plt.plot([0, 1], [0, 1], 'k--', lw=2, label='Random Classifier')

        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate', fontsize=12)
        plt.ylabel('True Positive Rate', fontsize=12)
        plt.title('ROC Curves - Lung Sound Classification', fontsize=14, fontweight='bold')
        plt.legend(loc="lower right", fontsize=10)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    def compute_precision_recall_auc(self, y_true, y_pred_proba, plot=True):
        """
        Compute Precision-Recall curves and AUC scores

        Parameters:
        -----------
        y_true : array-like, shape = [n_samples]
            True class labels (integer encoded)
        y_pred_proba : array-like, shape = [n_samples, n_classes]
            Predicted class probabilities
        plot : bool, default=True
            Whether to plot PR curves

        Returns:
        --------
        dict : Dictionary containing PR AUC scores and curve data
        """

        # Binarize the output for multi-class PR curves
        y_true_bin = label_binarize(y_true, classes=range(self.n_classes))

        # For binary classification, label_binarize returns a single column
        # of shape (n_samples, 1), so add the complementary column explicitly
        if self.n_classes == 2:
            y_true_bin = np.column_stack([1 - y_true_bin, y_true_bin])

        # Compute Precision-Recall curve and average precision for each class
        precision = dict()
        recall = dict()
        pr_auc = dict()

        for i in range(self.n_classes):
            precision[i], recall[i], _ = precision_recall_curve(y_true_bin[:, i], y_pred_proba[:, i])
            pr_auc[i] = average_precision_score(y_true_bin[:, i], y_pred_proba[:, i])

        # Compute micro-average precision-recall curve
        precision["micro"], recall["micro"], _ = precision_recall_curve(
            y_true_bin.ravel(), y_pred_proba.ravel())
        pr_auc["micro"] = average_precision_score(y_true_bin, y_pred_proba, average="micro")

        # Compute macro-average
        pr_auc["macro"] = average_precision_score(y_true_bin, y_pred_proba, average="macro")

        # Plot PR curves
        if plot:
            self.plot_precision_recall_curves(precision, recall, pr_auc)

        # Print PR AUC scores
        print("\n📊 Precision-Recall AUC Scores:")
        print(f"   • Micro-average PR-AUC: {pr_auc['micro']:.4f}")
        print(f"   • Macro-average PR-AUC: {pr_auc['macro']:.4f}")
        print("\n   Per-class PR-AUC:")
        for i, class_name in enumerate(self.class_names):
            print(f"   • {class_name}: {pr_auc[i]:.4f}")

        return {
            'precision': precision,
            'recall': recall,
            'pr_auc': pr_auc,
            'micro_pr_auc': pr_auc['micro'],
            'macro_pr_auc': pr_auc['macro']
        }

    def plot_precision_recall_curves(self, precision, recall, pr_auc):
        """Plot Precision-Recall curves for multi-class classification"""

        plt.figure(figsize=(12, 8))

        # Plot PR curve for each class
        colors = cycle(['blue', 'red', 'green', 'orange', 'purple'])
        for i, color in zip(range(self.n_classes), colors):
            plt.plot(recall[i], precision[i], color=color, lw=2,
                    label=f'{self.class_names[i]} (AP = {pr_auc[i]:.3f})')

        # Plot micro-average PR curve
        plt.plot(recall["micro"], precision["micro"],
                label=f'Micro-average (AP = {pr_auc["micro"]:.3f})',
                color='deeppink', linestyle=':', linewidth=3)

        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('Recall', fontsize=12)
        plt.ylabel('Precision', fontsize=12)
        plt.title('Precision-Recall Curves - Lung Sound Classification', fontsize=14, fontweight='bold')
        plt.legend(loc="lower left", fontsize=10)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    def comprehensive_evaluation(self, y_true, y_pred_proba):
        """
        Comprehensive evaluation including both ROC and PR curves

        Parameters:
        -----------
        y_true : array-like, shape = [n_samples]
            True class labels (integer encoded)
        y_pred_proba : array-like, shape = [n_samples, n_classes]
            Predicted class probabilities

        Returns:
        --------
        dict : Dictionary containing all evaluation metrics
        """

        print("🎯 Comprehensive Model Evaluation")
        print("=" * 50)

        # ROC Analysis
        roc_results = self.compute_roc_auc(y_true, y_pred_proba, plot=True)

        # Precision-Recall Analysis
        pr_results = self.compute_precision_recall_auc(y_true, y_pred_proba, plot=True)

        # Combined results
        results = {
            'roc_auc': roc_results,
            'pr_auc': pr_results,
            'summary': {
                'micro_roc_auc': roc_results['micro_auc'],
                'macro_roc_auc': roc_results['macro_auc'],
                'micro_pr_auc': pr_results['micro_pr_auc'],
                'macro_pr_auc': pr_results['macro_pr_auc']
            }
        }

        return results

    def plot_combined_metrics(self, y_true, y_pred_proba):
        """Plot ROC and PR curves side by side"""

        # Compute metrics
        roc_results = self.compute_roc_auc(y_true, y_pred_proba, plot=False)
        pr_results = self.compute_precision_recall_auc(y_true, y_pred_proba, plot=False)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        # ROC Curves
        colors = cycle(['blue', 'red', 'green', 'orange', 'purple'])
        for i, color in zip(range(self.n_classes), colors):
            ax1.plot(roc_results['fpr'][i], roc_results['tpr'][i], color=color, lw=2,
                    label=f'{self.class_names[i]} (AUC = {roc_results["roc_auc"][i]:.3f})')

        ax1.plot(roc_results['fpr']["micro"], roc_results['tpr']["micro"],
                label=f'Micro-avg (AUC = {roc_results["roc_auc"]["micro"]:.3f})',
                color='deeppink', linestyle=':', linewidth=3)

        ax1.plot([0, 1], [0, 1], 'k--', lw=2, label='Random')
        ax1.set_xlim([0.0, 1.0])
        ax1.set_ylim([0.0, 1.05])
        ax1.set_xlabel('False Positive Rate')
        ax1.set_ylabel('True Positive Rate')
        ax1.set_title('ROC Curves')
        ax1.legend(loc="lower right", fontsize=9)
        ax1.grid(True, alpha=0.3)

        # PR Curves
        colors = cycle(['blue', 'red', 'green', 'orange', 'purple'])
        for i, color in zip(range(self.n_classes), colors):
            ax2.plot(pr_results['recall'][i], pr_results['precision'][i], color=color, lw=2,
                    label=f'{self.class_names[i]} (AP = {pr_results["pr_auc"][i]:.3f})')

        ax2.plot(pr_results['recall']["micro"], pr_results['precision']["micro"],
                label=f'Micro-avg (AP = {pr_results["pr_auc"]["micro"]:.3f})',
                color='deeppink', linestyle=':', linewidth=3)

        ax2.set_xlim([0.0, 1.0])
        ax2.set_ylim([0.0, 1.05])
        ax2.set_xlabel('Recall')
        ax2.set_ylabel('Precision')
        ax2.set_title('Precision-Recall Curves')
        ax2.legend(loc="lower left", fontsize=9)
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()


def evaluate_lung_classifier_with_curves(classifier, x_test, y_test):
    """
    Evaluate the trained lung classifier with ROC and PR curves

    Parameters:
    -----------
    classifier : ImprovedLungSoundClassifier
        Trained classifier object
    x_test : array-like
        Test features
    y_test : array-like
        Test labels (one-hot encoded)

    Returns:
    --------
    dict : Comprehensive evaluation results
    """

    if classifier.model is None:
        print("❌ Model not trained yet!")
        return None

    # Get predictions
    y_pred_proba = classifier.model.predict(x_test, verbose=0)
    y_true = np.argmax(y_test, axis=1)

    # Initialize evaluator
    evaluator = ModelEvaluationMetrics(class_names=['Healthy', 'Asthma', 'COPD'])

    # Comprehensive evaluation
    results = evaluator.comprehensive_evaluation(y_true, y_pred_proba)

    # Combined plot
    print("\n📊 Combined ROC and PR Curves:")
    evaluator.plot_combined_metrics(y_true, y_pred_proba)

    return results


def evaluate_ensemble_with_curves(ensemble_classifier, x_test, y_test):
    """
    Evaluate ensemble classifier with ROC and PR curves

    Parameters:
    -----------
    ensemble_classifier : EnsembleLungClassifier
        Trained ensemble classifier
    x_test : array-like
        Test features
    y_test : array-like
        Test labels (one-hot encoded)

    Returns:
    --------
    dict : Comprehensive evaluation results
    """

    # Get ensemble predictions
    y_pred_proba = ensemble_classifier.predict_ensemble(x_test)
    if y_pred_proba is None:
        print("❌ Ensemble not trained yet!")
        return None

    y_true = np.argmax(y_test, axis=1)

    # Initialize evaluator
    evaluator = ModelEvaluationMetrics(class_names=['Healthy', 'Asthma', 'COPD'])

    # Comprehensive evaluation
    results = evaluator.comprehensive_evaluation(y_true, y_pred_proba)

    # Combined plot
    print("\n📊 Ensemble - Combined ROC and PR Curves:")
    evaluator.plot_combined_metrics(y_true, y_pred_proba)

    return results


def compare_models_curves(models_dict, x_test, y_test):
    """
    Compare multiple models using ROC curves

    Parameters:
    -----------
    models_dict : dict
        Dictionary of {'model_name': model} pairs
    x_test : array-like
        Test features
    y_test : array-like
        Test labels (one-hot encoded)
    """

    plt.figure(figsize=(12, 8))

    y_true = np.argmax(y_test, axis=1)
    colors = cycle(['blue', 'red', 'green', 'orange', 'purple', 'brown'])

    for (model_name, model), color in zip(models_dict.items(), colors):
        # Get predictions
        if hasattr(model, 'predict_ensemble'):
            y_pred_proba = model.predict_ensemble(x_test)
        else:
            y_pred_proba = model.predict(x_test, verbose=0)

        # Compute micro-average AUC (class count taken from the one-hot labels)
        n_classes = np.asarray(y_test).shape[1]
        y_true_bin = label_binarize(y_true, classes=range(n_classes))
        micro_auc = roc_auc_score(y_true_bin, y_pred_proba, average='micro')

        # Compute micro-average ROC curve
        fpr_micro, tpr_micro, _ = roc_curve(y_true_bin.ravel(), y_pred_proba.ravel())

        plt.plot(fpr_micro, tpr_micro, color=color, lw=2,
                label=f'{model_name} (AUC = {micro_auc:.3f})')

    plt.plot([0, 1], [0, 1], 'k--', lw=2, label='Random Classifier')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate', fontsize=12)
    plt.title('Model Comparison - ROC Curves (Micro-Average)', fontsize=14, fontweight='bold')
    plt.legend(loc="lower right", fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


# Usage Examples
def usage_examples():
    """
    Example usage of the ROC and AUC evaluation functions
    """

    print("\n📝 Usage Examples:")
    print("=" * 50)

    print("\n1. Single Model Evaluation:")
    results = evaluate_lung_classifier_with_curves(classifier, x_test_gru, y_test_gru)


    print("\n2. Ensemble Model Evaluation:")
    print("```python")
    print("# After training ensemble")
    print("results = evaluate_ensemble_with_curves(ensemble, x_test, y_test)")
    print("```")

    print("\n3. Model Comparison:")
    print("```python")
    print("models_dict = {")
    print("    'Simple Model': classifier1.model,")
    print("    'Attention Model': classifier2.model,")
    print("    'Ensemble': ensemble_classifier")
    print("}")
    print("compare_models_curves(models_dict, x_test, y_test)")
    print("```")

    print("\n4. Custom Evaluation:")
    print("```python")
    print("evaluator = ModelEvaluationMetrics()")
    print("results = evaluator.comprehensive_evaluation(y_true, y_pred_proba)")
    print("```")

    print("\n💡 Key Metrics Explained:")
    print("• ROC AUC: Area under ROC curve (0.5 = random, 1.0 = perfect)")
    print("• PR AUC: Area under Precision-Recall curve (accounts for class imbalance)")
    print("• Micro-average: Global metric across all classes")
    print("• Macro-average: Average of per-class metrics")
    print("• Higher values indicate better performance")

if __name__ == "__main__":
    usage_examples()
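
To make the micro versus macro distinction from the metric notes concrete, the sketch below computes both averages on synthetic data and checks them against their definitions. It is an illustration under stated assumptions: the labels and probabilities are randomly generated, and the macro value here is the plain mean of per-class AUCs as returned by `roc_auc_score`, which is closely related to, but not necessarily identical with, the interpolated macro curve built in `compute_roc_auc`.

```python
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize

# Synthetic 3-class problem (assumed data, not the lung sound test set)
rng = np.random.default_rng(0)
y_true = rng.integers(0, 3, size=200)
logits = rng.normal(size=(200, 3)) + 1.5 * np.eye(3)[y_true]
proba = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

y_bin = label_binarize(y_true, classes=[0, 1, 2])

# Macro: compute one AUC per class, then take the unweighted mean
per_class_auc = [roc_auc_score(y_bin[:, i], proba[:, i]) for i in range(3)]
macro_auc = roc_auc_score(y_bin, proba, average='macro')

# Micro: pool every (sample, class) decision into one binary problem
micro_auc = roc_auc_score(y_bin, proba, average='micro')
pooled_auc = roc_auc_score(y_bin.ravel(), proba.ravel())

print(f"per-class: {np.round(per_class_auc, 3)}")
print(f"macro: {macro_auc:.3f}  (mean of per-class: {np.mean(per_class_auc):.3f})")
print(f"micro: {micro_auc:.3f}  (pooled decisions:  {pooled_auc:.3f})")
```

Because the micro-average pools all decisions, frequent classes dominate it. The macro-average weights every class equally, which is usually the more cautious summary when classes are imbalanced.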

# **Unseen data prediction**
# 🫁 Lung Sound Classification (Normal | Asthma | COPD)

This section performs lung sound classification on unseen recordings using a pre-trained deep learning model.

---
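
As a small illustration of the last step of such a pipeline, the sketch below turns one softmax output row into a readable label with a confidence value. `decode_prediction` is a hypothetical helper, the probabilities are made up, and the class order is assumed to match the `['Normal', 'Asthma', 'COPD']` order used by the predictor class.

```python
import numpy as np

# Assumed class order; it must match the order used when the model was trained
CLASS_NAMES = ['Normal', 'Asthma', 'COPD']

def decode_prediction(proba, class_names=CLASS_NAMES):
    """Map one row of class probabilities to a label, confidence, and full breakdown."""
    proba = np.asarray(proba, dtype=float).ravel()
    idx = int(np.argmax(proba))
    return {
        'label': class_names[idx],
        'confidence': float(proba[idx]),
        'probabilities': dict(zip(class_names, proba.tolist())),
    }

# A model's predict() call returns shape (1, n_classes) for a single recording
demo = decode_prediction([[0.1, 0.7, 0.2]])
print(demo['label'], round(demo['confidence'], 2))  # Asthma 0.7
```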

In [None]:
import numpy as np
import librosa
import joblib
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import register_keras_serializable
import warnings
from scipy import stats
from scipy.signal import hilbert
import pywt
from sklearn.preprocessing import StandardScaler
import os

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

# Define the custom metric that was used during training
@register_keras_serializable()
def weighted_categorical_accuracy(y_true, y_pred):
    """Custom weighted categorical accuracy metric"""
    return tf.keras.metrics.categorical_accuracy(y_true, y_pred)

class LungSoundPredictor:
    """
    Complete Lung Sound Classification System for Unseen Data
    Supports: ASTHMA, COPD, NORMAL classification
    """

    def __init__(self, model_path='lung_sound_model_best.keras', scaler_path='scaler.pkl'):
        """
        Initialize the predictor with trained model and scaler

        Parameters:
        - model_path: Path to saved Keras model
        - scaler_path: Path to saved StandardScaler
        """
        self.model_path = model_path
        self.scaler_path = scaler_path
        self.model = None
        self.scaler = None
        self.class_names = ['Normal', 'Asthma', 'COPD']

        # Load model and scaler
        self.load_model_and_scaler()

    def load_model_and_scaler(self):
        """Load the trained model and feature scaler"""
        try:
            # Load the trained model with custom objects
            if os.path.exists(self.model_path):
                custom_objects = {
                    'weighted_categorical_accuracy': weighted_categorical_accuracy
                }
                self.model = load_model(self.model_path, custom_objects=custom_objects)
                print(f"✅ Model loaded successfully from: {self.model_path}")
            else:
                print(f"❌ Model file not found: {self.model_path}")
                print("Please ensure you have trained and saved the model first.")
                return False

            # Load the scaler
            if os.path.exists(self.scaler_path):
                self.scaler = joblib.load(self.scaler_path)
                print(f"✅ Scaler loaded successfully from: {self.scaler_path}")
            else:
                print(f"❌ Scaler file not found: {self.scaler_path}")
                print("Please ensure you have the scaler.pkl file from training.")
                return False

            return True

        except Exception as e:
            print(f"❌ Error loading model or scaler: {e}")

            # Alternative loading method - compile=False
            try:
                print("🔄 Trying alternative loading method...")
                self.model = load_model(self.model_path, compile=False)

                # Recompile the model with standard metrics
                self.model.compile(
                    optimizer='adam',
                    loss='categorical_crossentropy',
                    metrics=['accuracy']
                )
                print("✅ Model loaded successfully with alternative method")

                # Load scaler
                if os.path.exists(self.scaler_path):
                    self.scaler = joblib.load(self.scaler_path)
                    print(f"✅ Scaler loaded successfully from: {self.scaler_path}")
                    return True
                else:
                    print(f"❌ Scaler file not found: {self.scaler_path}")
                    return False

            except Exception as e2:
                print(f"❌ Alternative loading method also failed: {e2}")
                return False

    # Feature extraction functions (same as your training code)
    def add_noise(self, data, noise_level=0.005):
        """Add Gaussian noise to audio data with improved stability."""
        if len(data) == 0:
            return data
        noise = np.random.randn(len(data)) * noise_level
        noisy_data = data + noise
        return np.clip(noisy_data, -1.0, 1.0)

    def extract_advanced_mfcc_features(self, data, sampling_rate, n_mfcc=40):
        """Extract MFCCs with delta and delta-delta features plus statistical moments."""
        try:
            if len(data) == 0:
                return np.zeros(n_mfcc * 9)

            # Extract MFCCs
            mfccs = librosa.feature.mfcc(
                y=data,
                sr=sampling_rate,
                n_mfcc=n_mfcc,
                n_fft=2048,
                hop_length=512
            )

            if mfccs.shape[1] == 0:
                return np.zeros(n_mfcc * 9)

            # Compute Delta and Delta-Delta
            delta_mfccs = librosa.feature.delta(mfccs)
            delta2_mfccs = librosa.feature.delta(mfccs, order=2)

            # Statistical moments
            mfcc_mean = np.mean(mfccs, axis=1)
            mfcc_std = np.std(mfccs, axis=1)
            mfcc_skew = stats.skew(mfccs, axis=1)

            delta_mean = np.mean(delta_mfccs, axis=1)
            delta_std = np.std(delta_mfccs, axis=1)
            delta_skew = stats.skew(delta_mfccs, axis=1)

            delta2_mean = np.mean(delta2_mfccs, axis=1)
            delta2_std = np.std(delta2_mfccs, axis=1)
            delta2_skew = stats.skew(delta2_mfccs, axis=1)

            # Combine all MFCC-based features
            advanced_mfcc_features = np.concatenate([
                mfcc_mean, mfcc_std, mfcc_skew,
                delta_mean, delta_std, delta_skew,
                delta2_mean, delta2_std, delta2_skew
            ])

            return np.nan_to_num(advanced_mfcc_features, nan=0.0, posinf=0.0, neginf=0.0)

        except Exception as e:
            print(f"Warning: Advanced MFCC extraction failed: {e}")
            return np.zeros(n_mfcc * 9)

    def extract_fbse_features(self, data, sampling_rate, n_bands=10):
        """Extract band-wise entropy features from the STFT power spectrum (the "FBSE" feature group)."""
        try:
            if len(data) == 0:
                return np.zeros(n_bands)

            # Compute power spectral density
            stft = librosa.stft(data)
            psd = np.abs(stft)**2

            # Divide frequency range into bands
            freq_bands = np.linspace(0, sampling_rate//2, n_bands + 1)
            entropy_features = []

            for i in range(n_bands):
                start_idx = int(freq_bands[i] * len(psd) / (sampling_rate//2))
                end_idx = int(freq_bands[i+1] * len(psd) / (sampling_rate//2))

                if end_idx > start_idx:
                    band_psd = np.mean(psd[start_idx:end_idx], axis=0)
                    band_psd_norm = band_psd / (np.sum(band_psd) + 1e-10)
                    entropy = -np.sum(band_psd_norm * np.log(band_psd_norm + 1e-10))
                    entropy_features.append(np.mean(entropy))
                else:
                    entropy_features.append(0.0)

            return np.nan_to_num(np.array(entropy_features), nan=0.0, posinf=0.0, neginf=0.0)

        except Exception as e:
            print(f"Warning: FBSE extraction failed: {e}")
            return np.zeros(n_bands)

    def extract_enhanced_melspectrogram(self, data, sampling_rate, n_mels=128):
        """Extract enhanced Mel-spectrogram features."""
        try:
            if len(data) == 0:
                return np.zeros(n_mels * 4)

            mel_spec = librosa.feature.melspectrogram(
                y=data,
                sr=sampling_rate,
                n_mels=n_mels,
                n_fft=2048,
                hop_length=512,
                fmax=sampling_rate//2
            )

            if mel_spec.size == 0:
                return np.zeros(n_mels * 4)

            mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

            if mel_spec_db.shape[1] > 0:
                mel_mean = np.mean(mel_spec_db, axis=1)
                mel_std = np.std(mel_spec_db, axis=1)
                mel_max = np.max(mel_spec_db, axis=1)
                mel_min = np.min(mel_spec_db, axis=1)
                mel_features = np.concatenate([mel_mean, mel_std, mel_max, mel_min])
            else:
                mel_features = np.zeros(n_mels * 4)

            return np.nan_to_num(mel_features, nan=0.0, posinf=0.0, neginf=0.0)

        except Exception as e:
            print(f"Warning: Enhanced Mel-spectrogram extraction failed: {e}")
            return np.zeros(n_mels * 4)

    def extract_wavelet_features(self, data, wavelet='db4', levels=5):
        """Extract wavelet features for transient detection."""
        try:
            if len(data) == 0:
                return np.zeros(levels * 4 + 4)

            coeffs = pywt.wavedec(data, wavelet, level=levels)
            wavelet_features = []

            for coeff in coeffs:
                if len(coeff) > 0:
                    wavelet_features.extend([
                        np.mean(np.abs(coeff)),
                        np.std(coeff),
                        np.max(np.abs(coeff)),
                        np.sum(coeff**2)
                    ])
                else:
                    wavelet_features.extend([0.0, 0.0, 0.0, 0.0])

            return np.nan_to_num(np.array(wavelet_features), nan=0.0, posinf=0.0, neginf=0.0)

        except Exception as e:
            print(f"Warning: Wavelet feature extraction failed: {e}")
            return np.zeros(levels * 4 + 4)

    def extract_sequence_features(self, data, sampling_rate, frame_length=2048, hop_length=512):
        """Extract sequence-based features."""
        try:
            if len(data) == 0:
                return np.zeros(26)

            mfccs = librosa.feature.mfcc(
                y=data,
                sr=sampling_rate,
                n_mfcc=13,
                n_fft=frame_length,
                hop_length=hop_length
            )

            if mfccs.shape[1] == 0:
                return np.zeros(26)

            frame_variations = np.mean(np.abs(np.diff(mfccs, axis=1)), axis=1)
            long_term_mean = np.mean(mfccs, axis=1)
            sequence_features = np.concatenate([frame_variations, long_term_mean])

            return np.nan_to_num(sequence_features, nan=0.0, posinf=0.0, neginf=0.0)

        except Exception as e:
            print(f"Warning: Sequence feature extraction failed: {e}")
            return np.zeros(26)

    def extract_spectral_features(self, data, sampling_rate):
        """Extract spectral features with improved tonnetz handling."""
        try:
            if len(data) == 0:
                return np.zeros(7)

            spectral_centroid = librosa.feature.spectral_centroid(y=data, sr=sampling_rate)
            spectral_bandwidth = librosa.feature.spectral_bandwidth(y=data, sr=sampling_rate)
            spectral_rolloff = librosa.feature.spectral_rolloff(y=data, sr=sampling_rate)
            spectral_flatness = librosa.feature.spectral_flatness(y=data)
            zero_crossing_rate = librosa.feature.zero_crossing_rate(data)
            chroma = librosa.feature.chroma_stft(y=data, sr=sampling_rate)
            chroma_mean = np.mean(chroma)

            # Improved tonnetz handling
            try:
                if sampling_rate >= 8000:
                    tonnetz = librosa.feature.tonnetz(y=data, sr=sampling_rate)
                    tonnetz_mean = np.mean(tonnetz)
                elif sampling_rate >= 4000:
                    chroma_cqt = librosa.feature.chroma_cqt(
                        y=data,
                        sr=sampling_rate,
                        fmin=librosa.note_to_hz('C1'),
                        n_chroma=12
                    )
                    tonnetz_mean = np.mean(chroma_cqt) * 0.5
                else:
                    tonnetz_mean = 0.0
            except Exception:
                tonnetz_mean = 0.0

            spectral_features = np.array([
                np.mean(spectral_centroid),
                np.mean(spectral_bandwidth),
                np.mean(spectral_rolloff),
                np.mean(spectral_flatness),
                np.mean(zero_crossing_rate),
                chroma_mean,
                tonnetz_mean
            ])

            return np.nan_to_num(spectral_features, nan=0.0, posinf=0.0, neginf=0.0)

        except Exception as e:
            print(f"Warning: Spectral feature extraction failed: {e}")
            return np.zeros(7)

    def fourier_bessel_features(self, data, sampling_rate, n_coeff=20):
        """Fourier-Bessel-style features: projections of the signal onto cosine terms."""
        if len(data) == 0:
            return np.zeros(n_coeff)

        t = np.arange(len(data)) / sampling_rate
        fb_coeff = np.zeros(n_coeff)
        t_norm = t / np.max(t) if np.max(t) > 0 else t

        for i in range(n_coeff):
            j = i + 1
            cosine_term = np.cos(2 * np.pi * j * t_norm)
            fb_coeff[i] = np.sum(data * cosine_term) / len(data)

        return np.nan_to_num(fb_coeff, nan=0.0, posinf=0.0, neginf=0.0)

    def extract_features_from_audio(self, audio_file_path):
        """
        Extract all features from a single audio file

        Parameters:
        - audio_file_path: Path to the audio file

        Returns:
        - features: Normalized feature vector ready for prediction
        """
        try:
            # Load audio file
            data, sampling_rate = librosa.load(audio_file_path, sr=None)

            if len(data) == 0:
                print(f"❌ Empty audio file: {audio_file_path}")
                return None

            print(f"🎵 Processing: {os.path.basename(audio_file_path)}")
            print(f"   • Duration: {len(data)/sampling_rate:.2f}s")
            print(f"   • Sampling Rate: {sampling_rate}Hz")

            # Extract all features (same as training)
            n_mfcc = 40
            fb_coeffs = 20
            n_mels = 128
            wavelet_levels = 5
            fbse_bands = 10

            # Extract each feature type
            advanced_mfcc_features = self.extract_advanced_mfcc_features(data, sampling_rate, n_mfcc)
            fbse_features = self.extract_fbse_features(data, sampling_rate, fbse_bands)
            enhanced_mel_features = self.extract_enhanced_melspectrogram(data, sampling_rate, n_mels)
            wavelet_features = self.extract_wavelet_features(data, 'db4', wavelet_levels)
            sequence_features = self.extract_sequence_features(data, sampling_rate)
            spectral_features = self.extract_spectral_features(data, sampling_rate)
            fb_features = self.fourier_bessel_features(data, sampling_rate, fb_coeffs)

            # Combine all features
            combined_features = np.concatenate([
                advanced_mfcc_features,
                fbse_features,
                enhanced_mel_features,
                wavelet_features,
                sequence_features,
                spectral_features,
                fb_features
            ])

            # Final validation
            combined_features = np.nan_to_num(combined_features, nan=0.0, posinf=0.0, neginf=0.0)

            print(f"   • Features extracted: {len(combined_features)} dimensions")

            # Normalize using training scaler
            if self.scaler is not None:
                features_normalized = self.scaler.transform(combined_features.reshape(1, -1))
                return features_normalized[0]
            else:
                print("❌ Scaler not loaded. Cannot normalize features.")
                return None

        except Exception as e:
            print(f"❌ Error extracting features from {audio_file_path}: {e}")
            return None

    def predict_single_file(self, audio_file_path, show_confidence=True):
        """
        Predict lung condition for a single audio file

        Parameters:
        - audio_file_path: Path to the audio file
        - show_confidence: Whether to show confidence scores

        Returns:
        - prediction_result: Dictionary with prediction details
        """
        if self.model is None or self.scaler is None:
            print("❌ Model or scaler not loaded properly.")
            return None

        # Extract features
        features = self.extract_features_from_audio(audio_file_path)
        if features is None:
            return None

        try:
            # Reshape for model input (1, 1, feature_size) format
            features_reshaped = features.reshape(1, 1, -1)

            print(f"   • Input shape: {features_reshaped.shape}")

            # Make prediction
            prediction_proba = self.model.predict(features_reshaped, verbose=0)
            predicted_class_idx = np.argmax(prediction_proba[0])
            predicted_class = self.class_names[predicted_class_idx]
            confidence = prediction_proba[0][predicted_class_idx]

            # Prepare result
            result = {
                'file': os.path.basename(audio_file_path),
                'predicted_class': predicted_class,
                'confidence': confidence,
                'all_probabilities': {
                    self.class_names[i]: prediction_proba[0][i]
                    for i in range(len(self.class_names))
                }
            }

            # Display results
            print("\n🎯 Prediction Results:")
            print(f"   • File: {result['file']}")
            print(f"   • Predicted Class: {predicted_class}")
            print(f"   • Confidence: {confidence:.3f} ({confidence*100:.1f}%)")

            if show_confidence:
                print("   • Detailed Probabilities:")
                for class_name, prob in result['all_probabilities'].items():
                    print(f"     - {class_name}: {prob:.3f} ({prob*100:.1f}%)")

            # Confidence interpretation
            if confidence >= 0.8:
                confidence_level = "High"
                emoji = "🟢"
            elif confidence >= 0.6:
                confidence_level = "Medium"
                emoji = "🟡"
            else:
                confidence_level = "Low"
                emoji = "🔴"

            print(f"   • Confidence Level: {emoji} {confidence_level}")

            return result

        except Exception as e:
            print(f"❌ Error during prediction: {e}")
            return None

    def predict_multiple_files(self, audio_files_list):
        """
        Predict lung conditions for multiple audio files

        Parameters:
        - audio_files_list: List of audio file paths

        Returns:
        - results: List of prediction results
        """
        results = []

        print(f"🔄 Processing {len(audio_files_list)} audio files...")
        print("=" * 50)

        for i, audio_file in enumerate(audio_files_list):
            print(f"\n[{i+1}/{len(audio_files_list)}] Processing: {os.path.basename(audio_file)}")

            result = self.predict_single_file(audio_file, show_confidence=False)
            if result is not None:
                results.append(result)

        # Summary
        if results:
            print(f"\n📊 Summary of {len(results)} successful predictions:")
            class_counts = {}
            for result in results:
                pred_class = result['predicted_class']
                class_counts[pred_class] = class_counts.get(pred_class, 0) + 1

            for class_name, count in class_counts.items():
                percentage = (count / len(results)) * 100
                print(f"   • {class_name}: {count} files ({percentage:.1f}%)")

        return results

    def get_model_info(self):
        """Display information about the loaded model"""
        if self.model is not None:
            print("🏥 Model Information:")
            print(f"   • Input shape: {self.model.input_shape}")
            print(f"   • Output shape: {self.model.output_shape}")
            print(f"   • Total parameters: {self.model.count_params():,}")
            print(f"   • Classes: {self.class_names}")
            if self.scaler is not None:
                print(f"   • Feature dimensions: {len(self.scaler.mean_)}")
        else:
            print("❌ No model loaded")


# Simple usage functions
def predict_single_audio(audio_file_path, model_path='/content/simple_lung_model_best.keras', scaler_path='/content/scaler.pkl'):
    """
    Simple function to predict a single audio file

    Parameters:
    - audio_file_path: Path to your audio file
    - model_path: Path to saved model (default: '/content/simple_lung_model_best.keras')
    - scaler_path: Path to saved scaler (default: '/content/scaler.pkl')

    Returns:
    - prediction_result: Dictionary with prediction details
    """
    predictor = LungSoundPredictor(model_path, scaler_path)
    return predictor.predict_single_file(audio_file_path)

def predict_audio_directory(directory_path, model_path='/content/lung_sound_model_best.keras', scaler_path='scaler.pkl'):
    """
    Simple function to predict all audio files in a directory

    Parameters:
    - directory_path: Directory containing audio files
    - model_path: Path to saved model
    - scaler_path: Path to saved scaler

    Returns:
    - results: List of prediction results
    """
    import glob

    # Find all audio files in directory
    audio_extensions = ['*.wav', '*.mp3', '*.flac', '*.m4a']
    audio_files = []

    for ext in audio_extensions:
        audio_files.extend(glob.glob(os.path.join(directory_path, ext)))
        audio_files.extend(glob.glob(os.path.join(directory_path, ext.upper())))

    if not audio_files:
        print(f"❌ No audio files found in: {directory_path}")
        return []

    print(f"📁 Found {len(audio_files)} audio files in: {directory_path}")

    predictor = LungSoundPredictor(model_path, scaler_path)
    return predictor.predict_multiple_files(audio_files)


# Example usage
if __name__ == "__main__":
    print("🫁 Lung Sound Classification for Unseen Data")
    print("=" * 50)

    # Example 1: Predict a single file
    print("\n📖 Example 1: Predict single file")
    result = predict_single_audio('/content/BP50_N,N,P R L ,27,M.wav')

    # Example 2: Predict multiple files
    print("\n📖 Example 2: Predict all files in directory")
    print("results = predict_audio_directory('path/to/your/audio/directory')")

    # Example 3: Using the class directly
    print("\n📖 Example 3: Using the class directly")
    print("predictor = LungSoundPredictor()")
    print("predictor.get_model_info()  # Show model details")
    print("result = predictor.predict_single_file('audio_file.wav')")

    print("\n✅ Ready to use! Make sure you have:")
    print("   1. Your trained model file (lung_sound_model_best.keras)")
    print("   2. Your scaler file (scaler.pkl)")
    print("   3. Audio files to classify")

    # Test if files exist
    print("\n🔍 Checking for required files...")
    if os.path.exists('lung_sound_model_best.keras'):
        print("   ✅ Model file found")
    else:
        print("   ❌ Model file 'lung_sound_model_best.keras' not found")

    if os.path.exists('scaler.pkl'):
        print("   ✅ Scaler file found")
    else:
        print("   ❌ Scaler file 'scaler.pkl' not found")


# ans ==> /content/142_1b1_Pl_mc_LittC2SE.wav
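
The feature vector assembled in `extract_features_from_audio` has a fixed length, and that length has to agree with both the saved scaler and the model's input shape. The sketch below tallies it from the default parameters used in the method; the dictionary keys are illustrative names, not identifiers from the notebook.

```python
# Sketch: tally the length of the combined feature vector.
# Parameter values mirror the defaults set inside extract_features_from_audio.
n_mfcc, n_mels, wavelet_levels, fbse_bands, fb_coeffs = 40, 128, 5, 10, 20

feature_dims = {
    "advanced_mfcc": n_mfcc * 9,          # mean/std/skew of MFCC, delta, delta-delta
    "fbse": fbse_bands,                   # one entropy value per frequency band
    "mel": n_mels * 4,                    # mean/std/max/min per mel band
    "wavelet": (wavelet_levels + 1) * 4,  # 4 statistics for each of levels + 1 coefficient arrays
    "sequence": 13 * 2,                   # frame variation + long-term mean of 13 MFCCs
    "spectral": 7,                        # centroid, bandwidth, rolloff, flatness, ZCR, chroma, tonnetz
    "fourier_bessel": fb_coeffs,          # one coefficient per cosine term
}

total_dims = sum(feature_dims.values())
print(total_dims)  # 959
```

The total matches the `(1, 959)` input shape used by the model classes later in the notebook, so a mismatch reported by `scaler.transform` usually points to a changed extraction parameter.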

#### AUC-ROC curve
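
The cells in this part of the notebook report confusion matrices and a classification report. A one-vs-rest ROC analysis could be sketched as follows. This is a minimal example on synthetic data: `one_vs_rest_roc` is a hypothetical helper, and the demo arrays stand in for the one-hot test labels and the softmax outputs of the model.

```python
import numpy as np
from sklearn.metrics import roc_curve, auc

def one_vs_rest_roc(y_true_onehot, y_score, class_names):
    """Return {class_name: (fpr, tpr, auc)}, treating each class as positive in turn."""
    curves = {}
    for i, name in enumerate(class_names):
        fpr, tpr, _ = roc_curve(y_true_onehot[:, i], y_score[:, i])
        curves[name] = (fpr, tpr, auc(fpr, tpr))
    return curves

# Synthetic stand-ins (assumption): 60 samples, 20 per class.
rng = np.random.default_rng(0)
labels = np.arange(60) % 3
y_true_demo = np.eye(3)[labels]
logits = rng.normal(size=(60, 3)) + 2.0 * y_true_demo   # true class gets a higher logit
y_score_demo = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

roc_curves = one_vs_rest_roc(y_true_demo, y_score_demo, ["Normal", "Asthma", "COPD"])
for name, (fpr, tpr, roc_auc) in roc_curves.items():
    print(f"{name}: AUC = {roc_auc:.3f}")
```

Each `(fpr, tpr)` pair can be passed to `plt.plot` to draw the curve for that class; with the notebook's own arrays, the one-hot labels and the predicted probabilities would replace the synthetic inputs.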

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Define class labels
classes = ["Normal", "Asthma", "COPD"]

# Get predictions
preds = model.model.predict(x_test_gru)
classpreds = [np.argmax(t) for t in preds]
y_testclass = [np.argmax(t) for t in y_test_gru]

# Compute confusion matrix
cm = confusion_matrix(y_testclass, classpreds)

# Normalize confusion matrix to display accuracy per class
cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100  # Convert to percentage

# Plot confusion matrix
plt.figure(figsize=(8, 6), dpi=80, facecolor='w', edgecolor='k')
ax = sns.heatmap(cm_percent, cmap='Blues', annot=True, fmt='.2f', xticklabels=classes, yticklabels=classes)

# Labels and title
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix with Class-wise Accuracy (%)')

plt.show()


In [None]:
# Combine counts and percentages into annotation labels
labels = np.array([[f'{int(cm[i,j])}\n{cm_percent[i,j]:.1f}%' for j in range(len(classes))] for i in range(len(classes))])

plt.figure(figsize=(8, 6))
sns.heatmap(cm_percent, cmap='Blues', annot=labels, fmt='', xticklabels=classes, yticklabels=classes)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix (Counts & Accuracy %)')
plt.show()


In [None]:
from sklearn.metrics import classification_report

print(classification_report(y_testclass, classpreds, target_names=classes))
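
A note on reading `cm_percent`: dividing each row of the confusion matrix by its row total gives, on the diagonal, the share of each true class that was predicted correctly. That is per-class recall, which generally differs from overall accuracy. The small example below uses a made-up matrix (`cm_demo` is hypothetical, not a result from this notebook) to show the difference.

```python
import numpy as np

# Hypothetical 3-class confusion matrix (rows = true class, columns = predicted class).
cm_demo = np.array([
    [40,  5,  5],
    [ 4, 30,  6],
    [ 2,  8, 50],
])

# Row-normalise to percentages; the maximum() guards against a class with no true samples.
row_totals = cm_demo.sum(axis=1, keepdims=True)
cm_percent_demo = 100.0 * cm_demo / np.maximum(row_totals, 1)

per_class_recall = np.diag(cm_percent_demo)                   # percent of each true class recovered
overall_accuracy = 100.0 * np.trace(cm_demo) / cm_demo.sum()  # percent of all samples correct

print(per_class_recall)
print(overall_accuracy)
```

The per-class values here should agree with the `recall` column that `classification_report` prints for the same predictions.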

# 🤖🫁 Model 2 – Ensemble-Based Lung Sound Classification

**Model 2** applies **ensemble learning** to lung sound classification. It combines the predictions of several base learners (the code below defines a Conv1D + BiGRU model, an attention model, and a BiLSTM model) with the aim of better generalization and more reliable predictions.

---

## 🔍 What Does It Detect?
This ensemble model classifies respiratory sounds into:
- ✅ **Normal**
- ⚠️ **Asthma**
- 🚨 **COPD**

---

## 🧠 Why Ensemble Learning?
Ensemble techniques help improve robustness by:
- 🔁 Combining multiple deep learning models
- 🧪 Reducing variance and overfitting
- 📈 Aiming for better performance on **unseen clinical data**
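
One common way to combine base learners is soft voting: average the class probabilities of the individual models and take the arg-max. The sketch below illustrates that rule with made-up probabilities. The combination rule actually used by this notebook is not shown in this part, so treat `soft_vote` and the arrays `p_gru`, `p_attn`, and `p_lstm` as hypothetical.

```python
import numpy as np

def soft_vote(probabilities, weights=None):
    """Average per-model class probabilities; return (mean_probs, predicted_class)."""
    stacked = np.stack(probabilities, axis=0)          # (n_models, n_samples, n_classes)
    mean_probs = np.average(stacked, axis=0, weights=weights)
    return mean_probs, np.argmax(mean_probs, axis=1)

# Hypothetical softmax outputs of three base models for two recordings.
p_gru = np.array([[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]])
p_attn = np.array([[0.5, 0.4, 0.1], [0.1, 0.3, 0.6]])
p_lstm = np.array([[0.7, 0.2, 0.1], [0.2, 0.3, 0.5]])

mean_probs, predicted = soft_vote([p_gru, p_attn, p_lstm])
print(predicted)  # [0 2]
```

For the second recording the models disagree (one favours class 1, two favour class 2), and the averaged probabilities settle on class 2. Passing `weights` would let a stronger base model count for more.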

---


In [None]:
# Improved Lung Sound Classification - Focus on Generalization
# Addresses overfitting issues for better unseen data performance

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Dense, Dropout, BatchNormalization, GRU, LSTM,
    Bidirectional, GlobalAveragePooling1D, GlobalMaxPooling1D,
    Concatenate, Conv1D, LeakyReLU, SpatialDropout1D,
    MultiHeadAttention, LayerNormalization
)
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (
    EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
)
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns

class ImprovedLungSoundClassifier:
    """
    Improved Neural Network for Lung Sound Classification
    Focus: Better generalization on unseen data
    """

    def __init__(self, input_shape=(1, 959), num_classes=3):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.model = None
        self.history = None
        self.class_weights = None

    def create_simple_model(self):
        """Create a simpler, more generalizable model"""
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs

        # Simple feature extraction with lighter regularization
        x = Conv1D(filters=32, kernel_size=7, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = SpatialDropout1D(0.2)(x)

        # Single RNN layer to reduce complexity
        x = Bidirectional(
            GRU(64, return_sequences=True, dropout=0.3, recurrent_dropout=0.3)
        )(x)
        x = BatchNormalization()(x)

        # Global pooling
        avg_pool = GlobalAveragePooling1D()(x)
        max_pool = GlobalMaxPooling1D()(x)
        x = Concatenate()([avg_pool, max_pool])

        # Simpler dense layers
        x = Dense(64, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(0.5)(x)

        # Output layer
        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)

        model = Model(inputs=inputs, outputs=outputs, name='SimpleLungClassifier')
        return model

    def create_attention_model(self):
        """Create attention-based model for better feature learning"""
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs

        # Light conv preprocessing
        x = Conv1D(filters=32, kernel_size=5, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = SpatialDropout1D(0.2)(x)

        # RNN processing
        x = Bidirectional(GRU(48, return_sequences=True, dropout=0.2, recurrent_dropout=0.2))(x)
        x = LayerNormalization()(x)

        # Self-attention mechanism
        attention = MultiHeadAttention(
            num_heads=4,
            key_dim=48,
            dropout=0.2
        )(x, x)
        x = LayerNormalization()(x + attention)  # Residual connection

        # Global pooling
        x = GlobalAveragePooling1D()(x)

        # Classification head
        x = Dense(32, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(0.4)(x)

        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)

        model = Model(inputs=inputs, outputs=outputs, name='AttentionLungClassifier')
        return model

    def create_ensemble_ready_model(self, model_variant='simple'):
        """Create different model variants for ensemble"""
        if model_variant == 'simple':
            return self.create_simple_model()
        elif model_variant == 'attention':
            return self.create_attention_model()
        elif model_variant == 'lstm':
            return self.create_lstm_model()
        else:
            return self.create_simple_model()

    def create_lstm_model(self):
        """LSTM variant for ensemble diversity"""
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs

        # Feature extraction
        x = Conv1D(filters=32, kernel_size=3, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = SpatialDropout1D(0.2)(x)

        # LSTM layers
        x = Bidirectional(LSTM(48, return_sequences=True, dropout=0.3, recurrent_dropout=0.3))(x)
        x = BatchNormalization()(x)

        # Pooling
        x = GlobalAveragePooling1D()(x)

        # Classification
        x = Dense(48, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(0.4)(x)

        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)

        model = Model(inputs=inputs, outputs=outputs, name='LSTMLungClassifier')
        return model

    def compile_model(self, model, learning_rate=0.0005):
        """Compile with conservative settings for better generalization"""
        optimizer = Adam(
            learning_rate=learning_rate,
            clipnorm=1.0
        )

        model.compile(
            optimizer=optimizer,
            loss='categorical_crossentropy',
            metrics=['accuracy', 'precision', 'recall']
        )

        return model

    def compute_class_weights(self, y_train):
        """Compute balanced class weights"""
        y_indices = np.argmax(y_train, axis=1)
        classes = np.unique(y_indices)
        class_weights = compute_class_weight('balanced', classes=classes, y=y_indices)
        self.class_weights = dict(zip(classes, class_weights))

        print("📊 Class weights:")
        class_names = ['Healthy', 'Asthma', 'COPD']
        for i, weight in self.class_weights.items():
            print(f"   • {class_names[i]}: {weight:.3f}")

        return self.class_weights

    def create_callbacks(self, model_name='improved_lung_model'):
        """Conservative callbacks for better generalization"""
        callbacks = [
            EarlyStopping(
                monitor='val_loss',
                patience=15,  # Shorter patience to prevent overfitting
                restore_best_weights=True,
                verbose=1,
                min_delta=0.001
            ),
            ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=5,  # Reduce LR more aggressively
                min_lr=1e-6,
                verbose=1
            ),
            ModelCheckpoint(
                filepath=f'{model_name}_best.keras',
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1,
                mode='max'
            )
        ]

        return callbacks

    def train_with_data_augmentation(self, x_train, y_train, x_val, y_val,
                                   model_type='simple', epochs=80, batch_size=32):
        """Train with data augmentation for better generalization"""

        print(f"🚀 Training {model_type} model with data augmentation...")
        print(f"   • Training samples: {x_train.shape[0]}")
        print(f"   • Validation samples: {x_val.shape[0]}")

        # Data augmentation
        x_train_aug, y_train_aug = self.augment_data(x_train, y_train)
        print(f"   • Augmented training samples: {x_train_aug.shape[0]}")

        # Create model
        self.model = self.create_ensemble_ready_model(model_type)
        self.model = self.compile_model(self.model)

        # Compute class weights
        class_weights = self.compute_class_weights(y_train_aug)

        # Callbacks
        callbacks = self.create_callbacks(f'{model_type}_lung_model')

        print(f"\n🏗️ Model Architecture ({model_type}):")
        print(f"   • Total parameters: {self.model.count_params():,}")

        # Train
        self.history = self.model.fit(
            x_train_aug, y_train_aug,
            validation_data=(x_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=1,
            shuffle=True
        )

        print("✅ Training completed!")
        return self.history

    def augment_data(self, x_data, y_data, augment_factor=0.5):
        """Simple data augmentation techniques"""
        augmented_x = []
        augmented_y = []

        # Original data
        augmented_x.append(x_data)
        augmented_y.append(y_data)

        n_augment = int(len(x_data) * augment_factor)
        indices = np.random.choice(len(x_data), n_augment, replace=True)

        for idx in indices:
            sample = x_data[idx].copy()
            label = y_data[idx].copy()

            # Random noise addition (5% of signal std)
            noise_level = 0.05 * np.std(sample)
            sample += np.random.normal(0, noise_level, sample.shape)

            # Random scaling (±10%)
            scale_factor = np.random.uniform(0.9, 1.1)
            sample *= scale_factor

            augmented_x.append(sample[np.newaxis, :])
            augmented_y.append(label[np.newaxis, :])

        return np.vstack(augmented_x), np.vstack(augmented_y)

    def evaluate_model(self, x_test, y_test):
        """Comprehensive evaluation"""
        if self.model is None:
            print("❌ Model not trained yet!")
            return

        # Predictions
        y_pred_proba = self.model.predict(x_test, verbose=0)
        y_pred = np.argmax(y_pred_proba, axis=1)
        y_true = np.argmax(y_test, axis=1)

        # Metrics
        test_loss, test_acc, test_prec, test_rec = self.model.evaluate(x_test, y_test, verbose=0)
        test_f1 = f1_score(y_true, y_pred, average='weighted')

        print("\n📊 Test Performance:")
        print(f"   • Accuracy: {test_acc:.4f}")
        print(f"   • Precision: {test_prec:.4f}")
        print(f"   • Recall: {test_rec:.4f}")
        print(f"   • F1-Score: {test_f1:.4f}")

        # Per-class metrics
        class_names = ['Healthy', 'Asthma', 'COPD']
        print("\n📋 Classification Report:")
        print(classification_report(y_true, y_pred, target_names=class_names))

        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                   xticklabels=class_names, yticklabels=class_names)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.show()

        return {
            'accuracy': test_acc,
            'precision': test_prec,
            'recall': test_rec,
            'f1_score': test_f1,
            'predictions': y_pred_proba,
            'confusion_matrix': cm
        }

    def plot_training_history(self):
        """Plot training curves to check for overfitting"""
        if self.history is None:
            print("❌ No training history!")
            return

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Accuracy
        axes[0, 0].plot(self.history.history['accuracy'], label='Train', color='blue')
        axes[0, 0].plot(self.history.history['val_accuracy'], label='Val', color='orange')
        axes[0, 0].set_title('Accuracy')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Loss
        axes[0, 1].plot(self.history.history['loss'], label='Train', color='blue')
        axes[0, 1].plot(self.history.history['val_loss'], label='Val', color='orange')
        axes[0, 1].set_title('Loss')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

        # Precision
        axes[1, 0].plot(self.history.history['precision'], label='Train', color='blue')
        axes[1, 0].plot(self.history.history['val_precision'], label='Val', color='orange')
        axes[1, 0].set_title('Precision')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

        # Recall
        axes[1, 1].plot(self.history.history['recall'], label='Train', color='blue')
        axes[1, 1].plot(self.history.history['val_recall'], label='Val', color='orange')
        axes[1, 1].set_title('Recall')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()


class EnsembleLungClassifier:
    """Ensemble approach for robust predictions"""

    def __init__(self, input_shape=(1, 959), num_classes=3):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.models = []
        self.model_types = ['simple', 'attention', 'lstm']

    def train_ensemble(self, x_train, y_train, x_val, y_val, epochs=60):
        """Train ensemble of diverse models"""
        print("🎯 Training Ensemble of Models...")

        for model_type in self.model_types:
            print(f"\n🔄 Training {model_type} model...")

            classifier = ImprovedLungSoundClassifier(self.input_shape, self.num_classes)
            history = classifier.train_with_data_augmentation(
                x_train, y_train, x_val, y_val,
                model_type=model_type,
                epochs=epochs,
                batch_size=32
            )

            self.models.append(classifier.model)
            print(f"✅ {model_type} model trained!")

        return self.models

    def predict_ensemble(self, x_test):
        """Make ensemble predictions"""
        if not self.models:
            print("❌ No models trained!")
            return None

        predictions = []
        for model in self.models:
            pred = model.predict(x_test, verbose=0)
            predictions.append(pred)

        # Average predictions
        ensemble_pred = np.mean(predictions, axis=0)
        return ensemble_pred

    def evaluate_ensemble(self, x_test, y_test):
        """Evaluate ensemble performance"""
        ensemble_pred = self.predict_ensemble(x_test)
        if ensemble_pred is None:
            return None

        y_pred = np.argmax(ensemble_pred, axis=1)
        y_true = np.argmax(y_test, axis=1)

        # Calculate metrics
        from sklearn.metrics import accuracy_score, precision_score, recall_score

        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='weighted')
        recall = recall_score(y_true, y_pred, average='weighted')
        f1 = f1_score(y_true, y_pred, average='weighted')

        print(f"\n🎯 Ensemble Performance:")
        print(f"   • Accuracy: {accuracy:.4f}")
        print(f"   • Precision: {precision:.4f}")
        print(f"   • Recall: {recall:.4f}")
        print(f"   • F1-Score: {f1:.4f}")

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'predictions': ensemble_pred
        }


# Improved Training Pipeline
def train_improved_lung_classifier(x_train, y_train, x_val, y_val, x_test, y_test,
                                 use_ensemble=True):
    """
    Improved training pipeline focused on generalization

    Key improvements:
    1. Simpler architectures to reduce overfitting
    2. Data augmentation for better generalization
    3. Conservative training settings
    4. Ensemble option for robust predictions
    """

    print("🎯 Improved Lung Sound Classification Pipeline")
    print("Focus: Better generalization on unseen data")
    print("=" * 60)

    if use_ensemble:
        # Train ensemble
        ensemble = EnsembleLungClassifier()
        models = ensemble.train_ensemble(x_train, y_train, x_val, y_val)
        results = ensemble.evaluate_ensemble(x_test, y_test)
        return ensemble, results
    else:
        # Train single improved model
        classifier = ImprovedLungSoundClassifier()

        # Try simple model first
        history = classifier.train_with_data_augmentation(
            x_train, y_train, x_val, y_val,
            model_type='simple',
            epochs=150,
            batch_size=32
        )

        # Evaluate
        results = classifier.evaluate_model(x_test, y_test)

        # Plot training curves
        classifier.plot_training_history()

        return classifier, results


# Usage instructions
def usage_example():
    """How to use the improved classifier"""
    print("\n📝 Usage Example:")

    # For a single improved model, pass use_ensemble=False:
    # classifier, results = train_improved_lung_classifier(
    #     x_train_gru, y_train_gru, x_val_gru, y_val_gru, x_test_gru, y_test_gru,
    #     use_ensemble=False)

    # Ensemble approach (better but slower):
    classifier, results = train_improved_lung_classifier(
        x_train_gru, y_train_gru, x_val_gru, y_val_gru, x_test_gru, y_test_gru,
        use_ensemble=True)

    print("\n💡 Key Improvements:")
    print("• Simpler architecture to prevent overfitting")
    print("• Data augmentation for better generalization")
    print("• Conservative training with early stopping")
    print("• Ensemble option for robust predictions")
    print("• Better regularization strategies")

if __name__ == "__main__":
    usage_example()
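
`predict_ensemble` above combines the models by averaging their softmax outputs (often called soft voting). The sketch below illustrates that step in plain NumPy. The three probability arrays are made-up values standing in for the outputs of the `simple`, `attention`, and `lstm` models, and the helper name `soft_vote` is hypothetical.

```python
import numpy as np

CLASS_NAMES = ['Healthy', 'Asthma', 'COPD']

def soft_vote(prob_list):
    """Average per-model class probabilities; each array has shape (n_samples, n_classes)."""
    return np.mean(np.stack(prob_list, axis=0), axis=0)

# Made-up softmax outputs for two recordings from three models
p_simple = np.array([[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]])
p_attention = np.array([[0.5, 0.4, 0.1], [0.1, 0.3, 0.6]])
p_lstm = np.array([[0.7, 0.2, 0.1], [0.2, 0.2, 0.6]])

avg = soft_vote([p_simple, p_attention, p_lstm])
labels = avg.argmax(axis=1)
print([CLASS_NAMES[i] for i in labels])  # ['Healthy', 'COPD']
```

For the second recording, one model on its own would predict Asthma, but the averaged probabilities favor COPD. This is the kind of disagreement the ensemble is meant to smooth out.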

# 🧠✨ Model 2 – BiGRU + Attention-Based Lung Sound Classification

This notebook uses a **BiGRU model with multi-head self-attention** to classify lung sound recordings into clinical categories. Attention helps the model **focus on important parts** of the signal, much as a doctor listens for subtle patterns.

---

## 🔍 Target Classes:
- ✅ **Normal**
- 🌬️ **Asthma**
- 😮‍💨 **COPD**
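
In the classifier code these classes are indexed in the order `['Healthy', 'Asthma', 'COPD']`, and training uses scikit-learn's `'balanced'` class weights. As a minimal sketch of what that heuristic computes, the weight of each class is `n_samples / (n_classes * count_of_that_class)`. The label counts below are invented for illustration and are not taken from the dataset, and `balanced_class_weights` is a hypothetical helper name.

```python
import numpy as np

def balanced_class_weights(y_onehot):
    """Reproduce the 'balanced' heuristic: n_samples / (n_classes * class_count)."""
    y_idx = y_onehot.argmax(axis=1)
    classes, counts = np.unique(y_idx, return_counts=True)
    weights = len(y_idx) / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}

# Invented example: 4 Healthy, 2 Asthma, 6 COPD recordings, one-hot encoded
y_demo = np.eye(3)[[0] * 4 + [1] * 2 + [2] * 6]
demo_weights = balanced_class_weights(y_demo)
print(demo_weights)
```

Rarer classes receive larger weights, so their errors count for more in the loss.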

---

## 🌟 Why Attention Models?
Attention mechanisms enable:
- 🎯 Focus on the most relevant acoustic features
- 🔄 Better modeling of temporal dynamics than plain sequential models
- 📈 Potentially improved accuracy with fewer parameters than large CNN stacks
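
To make the "focus" idea concrete, here is a minimal NumPy sketch of single-head scaled dot-product self-attention. It illustrates the general mechanism and is not the Keras `MultiHeadAttention` layer used in the code. The sizes and random projection matrices are assumptions chosen for the example.

```python
import numpy as np

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention; x has shape (timesteps, features)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])           # (timesteps, timesteps)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)    # softmax over the keys
    return weights @ v, weights

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 8))                           # 6 hypothetical frames, 8 features each
w_q, w_k, w_v = (rng.normal(size=(8, 4)) for _ in range(3))

context, weights = self_attention(x, w_q, w_k, w_v)
print(context.shape, weights.shape)                   # (6, 4) (6, 6)
```

Each row of `weights` sums to 1 and says how much one frame draws on every other frame. Attention only has something to choose between when there is more than one timestep. For a length-1 sequence, such as the default `input_shape=(1, 959)`, the weight matrix is 1×1 and its single entry is 1.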

---

In [None]:
# Improved Lung Sound Classification - Focus on Generalization
# Addresses overfitting issues for better unseen data performance

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Dense, Dropout, BatchNormalization, GRU, LSTM,
    Bidirectional, GlobalAveragePooling1D, GlobalMaxPooling1D,
    Concatenate, Conv1D, LeakyReLU, SpatialDropout1D,
    MultiHeadAttention, LayerNormalization
)
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (
    EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
)
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns

class ImprovedLungSoundClassifier:
    """
    Improved Neural Network for Lung Sound Classification
    Focus: Better generalization on unseen data
    """

    def __init__(self, input_shape=(1, 959), num_classes=3):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.model = None
        self.history = None
        self.class_weights = None

    def create_simple_model(self):
        """Create a simpler, more generalizable model"""
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs

        # Simple feature extraction with lighter regularization
        x = Conv1D(filters=32, kernel_size=7, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = SpatialDropout1D(0.2)(x)

        # Single RNN layer to reduce complexity
        x = Bidirectional(
            GRU(64, return_sequences=True, dropout=0.3, recurrent_dropout=0.3)
        )(x)
        x = BatchNormalization()(x)

        # Global pooling
        avg_pool = GlobalAveragePooling1D()(x)
        max_pool = GlobalMaxPooling1D()(x)
        x = Concatenate()([avg_pool, max_pool])

        # Simpler dense layers
        x = Dense(64, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(0.5)(x)

        # Output layer
        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)

        model = Model(inputs=inputs, outputs=outputs, name='SimpleLungClassifier')
        return model

    def create_attention_model(self):
        """Create attention-based model for better feature learning"""
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs

        # Light conv preprocessing
        x = Conv1D(filters=32, kernel_size=5, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = SpatialDropout1D(0.2)(x)

        # RNN processing
        x = Bidirectional(GRU(48, return_sequences=True, dropout=0.2, recurrent_dropout=0.2))(x)
        x = LayerNormalization()(x)

        # Self-attention mechanism
        attention = MultiHeadAttention(
            num_heads=4,
            key_dim=48,
            dropout=0.2
        )(x, x)
        x = LayerNormalization()(x + attention)  # Residual connection

        # Global pooling
        x = GlobalAveragePooling1D()(x)

        # Classification head
        x = Dense(32, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(0.4)(x)

        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)

        model = Model(inputs=inputs, outputs=outputs, name='AttentionLungClassifier')
        return model

    def create_ensemble_ready_model(self, model_variant='attention'):
        """Create different model variants for ensemble"""
        if model_variant == 'simple':
            return self.create_simple_model()
        elif model_variant == 'attention':
            return self.create_attention_model()
        elif model_variant == 'lstm':
            return self.create_lstm_model()
        else:
            return self.create_simple_model()

    def create_lstm_model(self):
        """LSTM variant for ensemble diversity"""
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs

        # Feature extraction
        x = Conv1D(filters=32, kernel_size=3, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = SpatialDropout1D(0.2)(x)

        # LSTM layers
        x = Bidirectional(LSTM(48, return_sequences=True, dropout=0.3, recurrent_dropout=0.3))(x)
        x = BatchNormalization()(x)

        # Pooling
        x = GlobalAveragePooling1D()(x)

        # Classification
        x = Dense(48, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(0.4)(x)

        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)

        model = Model(inputs=inputs, outputs=outputs, name='LSTMLungClassifier')
        return model

    def compile_model(self, model, learning_rate=0.0005):
        """Compile with conservative settings for better generalization"""
        optimizer = Adam(
            learning_rate=learning_rate,
            clipnorm=1.0
        )

        model.compile(
            optimizer=optimizer,
            loss='categorical_crossentropy',
            metrics=[
                'accuracy',
                tf.keras.metrics.Precision(name='precision'),
                tf.keras.metrics.Recall(name='recall')
            ]
        )

        return model

    def compute_class_weights(self, y_train):
        """Compute balanced class weights"""
        y_indices = np.argmax(y_train, axis=1)
        classes = np.unique(y_indices)
        class_weights = compute_class_weight('balanced', classes=classes, y=y_indices)
        self.class_weights = dict(zip(classes, class_weights))

        print(f"📊 Class weights:")
        class_names = ['Healthy', 'Asthma', 'COPD']
        for i, weight in self.class_weights.items():
            print(f"   • {class_names[i]}: {weight:.3f}")

        return self.class_weights

    def create_callbacks(self, model_name='improved_lung_model'):
        """Conservative callbacks for better generalization"""
        callbacks = [
            EarlyStopping(
                monitor='val_loss',
                patience=15,  # Shorter patience to prevent overfitting
                restore_best_weights=True,
                verbose=1,
                min_delta=0.001
            ),
            ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=5,  # Reduce LR more aggressively
                min_lr=1e-6,
                verbose=1
            ),
            ModelCheckpoint(
                filepath=f'{model_name}_best.keras',
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1,
                mode='max'
            )
        ]

        return callbacks

    def train_with_data_augmentation(self, x_train, y_train, x_val, y_val,
                                   model_type='attention', epochs=80, batch_size=32):
        """Train with data augmentation for better generalization"""

        print(f"🚀 Training {model_type} model with data augmentation...")
        print(f"   • Training samples: {x_train.shape[0]}")
        print(f"   • Validation samples: {x_val.shape[0]}")

        # Data augmentation
        x_train_aug, y_train_aug = self.augment_data(x_train, y_train)
        print(f"   • Augmented training samples: {x_train_aug.shape[0]}")

        # Create model
        self.model = self.create_ensemble_ready_model(model_type)
        self.model = self.compile_model(self.model)

        # Compute class weights
        class_weights = self.compute_class_weights(y_train_aug)

        # Callbacks
        callbacks = self.create_callbacks(f'{model_type}_lung_model')

        print(f"\n🏗️ Model Architecture ({model_type}):")
        print(f"   • Total parameters: {self.model.count_params():,}")

        # Train
        self.history = self.model.fit(
            x_train_aug, y_train_aug,
            validation_data=(x_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=1,
            shuffle=True
        )

        print("✅ Training completed!")
        return self.history

    def augment_data(self, x_data, y_data, augment_factor=0.5):
        """Simple data augmentation techniques"""
        augmented_x = []
        augmented_y = []

        # Original data
        augmented_x.append(x_data)
        augmented_y.append(y_data)

        n_augment = int(len(x_data) * augment_factor)
        indices = np.random.choice(len(x_data), n_augment, replace=True)

        for idx in indices:
            sample = x_data[idx].copy()
            label = y_data[idx].copy()

            # Random noise addition (5% of signal std)
            noise_level = 0.05 * np.std(sample)
            sample += np.random.normal(0, noise_level, sample.shape)

            # Random scaling (±10%)
            scale_factor = np.random.uniform(0.9, 1.1)
            sample *= scale_factor

            augmented_x.append(sample[np.newaxis, :])
            augmented_y.append(label[np.newaxis, :])

        return np.vstack(augmented_x), np.vstack(augmented_y)

    def evaluate_model(self, x_test, y_test):
        """Comprehensive evaluation"""
        if self.model is None:
            print("❌ Model not trained yet!")
            return

        # Predictions
        y_pred_proba = self.model.predict(x_test, verbose=0)
        y_pred = np.argmax(y_pred_proba, axis=1)
        y_true = np.argmax(y_test, axis=1)

        # Metrics
        test_loss, test_acc, test_prec, test_rec = self.model.evaluate(x_test, y_test, verbose=0)
        test_f1 = f1_score(y_true, y_pred, average='weighted')

        print(f"\n📊 Test Performance:")
        print(f"   • Accuracy: {test_acc:.4f}")
        print(f"   • Precision: {test_prec:.4f}")
        print(f"   • Recall: {test_rec:.4f}")
        print(f"   • F1-Score: {test_f1:.4f}")

        # Per-class metrics
        class_names = ['Healthy', 'Asthma', 'COPD']
        print(f"\n📋 Classification Report:")
        print(classification_report(y_true, y_pred, target_names=class_names))

        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                   xticklabels=class_names, yticklabels=class_names)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.show()

        return {
            'accuracy': test_acc,
            'precision': test_prec,
            'recall': test_rec,
            'f1_score': test_f1,
            'predictions': y_pred_proba,
            'confusion_matrix': cm
        }

    def plot_training_history(self):
        """Plot training curves to check for overfitting"""
        if self.history is None:
            print("❌ No training history!")
            return

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Accuracy
        axes[0, 0].plot(self.history.history['accuracy'], label='Train', color='blue')
        axes[0, 0].plot(self.history.history['val_accuracy'], label='Val', color='orange')
        axes[0, 0].set_title('Accuracy')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Loss
        axes[0, 1].plot(self.history.history['loss'], label='Train', color='blue')
        axes[0, 1].plot(self.history.history['val_loss'], label='Val', color='orange')
        axes[0, 1].set_title('Loss')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

        # Precision
        axes[1, 0].plot(self.history.history['precision'], label='Train', color='blue')
        axes[1, 0].plot(self.history.history['val_precision'], label='Val', color='orange')
        axes[1, 0].set_title('Precision')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

        # Recall
        axes[1, 1].plot(self.history.history['recall'], label='Train', color='blue')
        axes[1, 1].plot(self.history.history['val_recall'], label='Val', color='orange')
        axes[1, 1].set_title('Recall')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()


class EnsembleLungClassifier:
    """Ensemble approach for robust predictions"""

    def __init__(self, input_shape=(1, 959), num_classes=3):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.models = []
        self.model_types = ['simple', 'attention', 'lstm']

    def train_ensemble(self, x_train, y_train, x_val, y_val, epochs=60):
        """Train ensemble of diverse models"""
        print("🎯 Training Ensemble of Models...")

        for model_type in self.model_types:
            print(f"\n🔄 Training {model_type} model...")

            classifier = ImprovedLungSoundClassifier(self.input_shape, self.num_classes)
            history = classifier.train_with_data_augmentation(
                x_train, y_train, x_val, y_val,
                model_type=model_type,
                epochs=epochs,
                batch_size=32
            )

            self.models.append(classifier.model)
            print(f"✅ {model_type} model trained!")

        return self.models

    def predict_ensemble(self, x_test):
        """Make ensemble predictions"""
        if not self.models:
            print("❌ No models trained!")
            return None

        predictions = []
        for model in self.models:
            pred = model.predict(x_test, verbose=0)
            predictions.append(pred)

        # Average predictions
        ensemble_pred = np.mean(predictions, axis=0)
        return ensemble_pred

    def evaluate_ensemble(self, x_test, y_test):
        """Evaluate ensemble performance"""
        ensemble_pred = self.predict_ensemble(x_test)
        if ensemble_pred is None:
            return None

        y_pred = np.argmax(ensemble_pred, axis=1)
        y_true = np.argmax(y_test, axis=1)

        # Calculate metrics
        from sklearn.metrics import accuracy_score, precision_score, recall_score

        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='weighted')
        recall = recall_score(y_true, y_pred, average='weighted')
        f1 = f1_score(y_true, y_pred, average='weighted')

        print(f"\n🎯 Ensemble Performance:")
        print(f"   • Accuracy: {accuracy:.4f}")
        print(f"   • Precision: {precision:.4f}")
        print(f"   • Recall: {recall:.4f}")
        print(f"   • F1-Score: {f1:.4f}")

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'predictions': ensemble_pred
        }


# Improved Training Pipeline
def train_improved_lung_classifier(x_train, y_train, x_val, y_val, x_test, y_test,
                                 use_ensemble=False):
    """
    Improved training pipeline focused on generalization

    Key improvements:
    1. Simpler architectures to reduce overfitting
    2. Data augmentation for better generalization
    3. Conservative training settings
    4. Ensemble option for robust predictions
    """

    print("🎯 Improved Lung Sound Classification Pipeline")
    print("Focus: Better generalization on unseen data")
    print("=" * 60)

    if use_ensemble:
        # Train ensemble
        ensemble = EnsembleLungClassifier()
        models = ensemble.train_ensemble(x_train, y_train, x_val, y_val)
        results = ensemble.evaluate_ensemble(x_test, y_test)
        return ensemble, results
    else:
        # Train single improved model
        classifier = ImprovedLungSoundClassifier()

        # Train the attention model
        history = classifier.train_with_data_augmentation(
            x_train, y_train, x_val, y_val,
            model_type='attention',
            epochs=150,
            batch_size=32
        )

        # Evaluate
        results = classifier.evaluate_model(x_test, y_test)

        # Plot training curves
        classifier.plot_training_history()

        return classifier, results


# Usage instructions
def usage_example():
    """How to use the improved classifier"""
    print("\n📝 Usage Example:")

    # Single improved model (use_ensemble defaults to False):
    classifier, results = train_improved_lung_classifier(
        x_train_gru, y_train_gru, x_val_gru, y_val_gru, x_test_gru, y_test_gru)

    # For the ensemble approach (better but slower), pass use_ensemble=True:
    # classifier, results = train_improved_lung_classifier(
    #     x_train_gru, y_train_gru, x_val_gru, y_val_gru, x_test_gru, y_test_gru,
    #     use_ensemble=True)

    print("\n💡 Key Improvements:")
    print("• Simpler architecture to prevent overfitting")
    print("• Data augmentation for better generalization")
    print("• Conservative training with early stopping")
    print("• Ensemble option for robust predictions")
    print("• Better regularization strategies")

if __name__ == "__main__":
    usage_example()
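
The `augment_data` method above adds Gaussian noise at 5% of each sample's standard deviation and rescales by a random factor in [0.9, 1.1], one sample at a time. As an alternative sketch, the same two perturbations can be applied to the whole batch at once with NumPy broadcasting. The function name `augment_batch` and its `seed` parameter are hypothetical and not part of the notebook.

```python
import numpy as np

def augment_batch(x, y, augment_factor=0.5, noise_frac=0.05,
                  scale_range=(0.9, 1.1), seed=None):
    """Append noisy, rescaled copies of randomly chosen samples to (x, y)."""
    rng = np.random.default_rng(seed)
    n_aug = int(len(x) * augment_factor)
    idx = rng.choice(len(x), size=n_aug, replace=True)

    picked = x[idx].astype(np.float64)
    sample_axes = tuple(range(1, picked.ndim))

    # Noise scaled to 5% of each picked sample's own standard deviation
    noise_std = noise_frac * picked.std(axis=sample_axes, keepdims=True)
    noisy = picked + rng.normal(size=picked.shape) * noise_std

    # One random gain per picked sample, broadcast over its remaining axes
    scales = rng.uniform(*scale_range, size=(n_aug,) + (1,) * (picked.ndim - 1))

    x_out = np.concatenate([x, noisy * scales], axis=0)
    y_out = np.concatenate([y, y[idx]], axis=0)
    return x_out, y_out

# Synthetic stand-in data with the notebook's default per-sample shape (1, 959)
x_demo = np.random.default_rng(1).normal(size=(10, 1, 959))
y_demo = np.eye(3)[np.arange(10) % 3]
x_out, y_out = augment_batch(x_demo, y_demo, seed=0)
print(x_out.shape, y_out.shape)  # (15, 1, 959) (15, 3)
```

The original samples come first in the output, followed by the augmented copies, which matches the ordering produced by `augment_data`.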

# 🧠📉 Model 2 – LSTM-Based Lung Sound Classification

This notebook uses a **Long Short-Term Memory (LSTM)** model for classifying lung sound recordings into disease categories. LSTM networks are powerful in handling **sequential and temporal data**, making them ideal for analyzing respiratory sound signals.

---

## 🔍 Classification Targets:
- ✅ **Normal**
- 🌬️ **Asthma**
- 😮‍💨 **COPD**

---

## 💡 Why LSTM?
- 🔁 Remembers patterns over time (important for wheezes/crackles)
- ⏳ Well suited to sequential signal modeling
- 🧠 Lightweight compared to CNN+Attention hybrids
- ✅ Proven on physiological time series such as ECG and audio
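
The "remembers patterns over time" property comes from the LSTM's gated cell state. Below is a minimal NumPy sketch of one LSTM step run over a short sequence, using the standard LSTM equations. The weights are random and the sizes are hypothetical, so this illustrates the mechanism rather than the Keras `LSTM` layer trained in this notebook.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM step; W: (features, 4*units), U: (units, 4*units), b: (4*units,)."""
    z = x_t @ W + h_prev @ U + b
    i, f, g, o = np.split(z, 4)            # input, forget, candidate, output
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    g = np.tanh(g)
    c_t = f * c_prev + i * g               # keep part of the old memory, add new content
    h_t = o * np.tanh(c_t)                 # expose a gated view of the memory
    return h_t, c_t

rng = np.random.default_rng(0)
features, units, timesteps = 8, 4, 10      # hypothetical sizes
W = rng.normal(scale=0.1, size=(features, 4 * units))
U = rng.normal(scale=0.1, size=(units, 4 * units))
b = np.zeros(4 * units)

h, c = np.zeros(units), np.zeros(units)
for x_t in rng.normal(size=(timesteps, features)):
    h, c = lstm_step(x_t, h, c, W, U, b)

print(h.shape, c.shape)                    # (4,) (4,)
```

The forget gate `f` decides how much of the previous cell state survives each step, which is what lets information persist across many timesteps.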

---

In [None]:
# Improved Lung Sound Classification - Focus on Generalization
# Addresses overfitting issues for better unseen data performance

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Dense, Dropout, BatchNormalization, GRU, LSTM,
    Bidirectional, GlobalAveragePooling1D, GlobalMaxPooling1D,
    Concatenate, Conv1D, LeakyReLU, SpatialDropout1D,
    MultiHeadAttention, LayerNormalization
)
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (
    EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
)
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns

class ImprovedLungSoundClassifier:
    """
    Improved Neural Network for Lung Sound Classification
    Focus: Better generalization on unseen data
    """

    def __init__(self, input_shape=(1, 959), num_classes=3):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.model = None
        self.history = None
        self.class_weights = None

    def create_simple_model(self):
        """Create a simpler, more generalizable model"""
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs

        # Simple feature extraction with lighter regularization
        x = Conv1D(filters=32, kernel_size=7, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = SpatialDropout1D(0.2)(x)

        # Single RNN layer to reduce complexity
        x = Bidirectional(
            GRU(64, return_sequences=True, dropout=0.3, recurrent_dropout=0.3)
        )(x)
        x = BatchNormalization()(x)

        # Global pooling
        avg_pool = GlobalAveragePooling1D()(x)
        max_pool = GlobalMaxPooling1D()(x)
        x = Concatenate()([avg_pool, max_pool])

        # Simpler dense layers
        x = Dense(64, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(0.5)(x)

        # Output layer
        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)

        model = Model(inputs=inputs, outputs=outputs, name='SimpleLungClassifier')
        return model

    def create_attention_model(self):
        """Create attention-based model for better feature learning"""
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs

        # Light conv preprocessing
        x = Conv1D(filters=32, kernel_size=5, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = SpatialDropout1D(0.2)(x)

        # RNN processing
        x = Bidirectional(GRU(48, return_sequences=True, dropout=0.2, recurrent_dropout=0.2))(x)
        x = LayerNormalization()(x)

        # Self-attention mechanism
        attention = MultiHeadAttention(
            num_heads=4,
            key_dim=48,
            dropout=0.2
        )(x, x)
        x = LayerNormalization()(x + attention)  # Residual connection

        # Global pooling
        x = GlobalAveragePooling1D()(x)

        # Classification head
        x = Dense(32, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(0.4)(x)

        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)

        model = Model(inputs=inputs, outputs=outputs, name='AttentionLungClassifier')
        return model

    def create_ensemble_ready_model(self, model_variant='lstm'):
        """Create different model variants for ensemble"""
        if model_variant == 'simple':
            return self.create_simple_model()
        elif model_variant == 'attention':
            return self.create_attention_model()
        elif model_variant == 'lstm':
            return self.create_lstm_model()
        else:
            return self.create_simple_model()

    def create_lstm_model(self):
        """LSTM variant for ensemble diversity"""
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs

        # Feature extraction
        x = Conv1D(filters=32, kernel_size=3, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = SpatialDropout1D(0.2)(x)

        # LSTM layers
        x = Bidirectional(LSTM(48, return_sequences=True, dropout=0.3, recurrent_dropout=0.3))(x)
        x = BatchNormalization()(x)

        # Pooling
        x = GlobalAveragePooling1D()(x)

        # Classification
        x = Dense(48, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(0.4)(x)

        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)

        model = Model(inputs=inputs, outputs=outputs, name='LSTMLungClassifier')
        return model

    def compile_model(self, model, learning_rate=0.0005):
        """Compile with conservative settings for better generalization"""
        optimizer = Adam(
            learning_rate=learning_rate,
            clipnorm=1.0
        )

        model.compile(
            optimizer=optimizer,
            loss='categorical_crossentropy',
            metrics=[
                'accuracy',
                tf.keras.metrics.Precision(name='precision'),
                tf.keras.metrics.Recall(name='recall')
            ]
        )

        return model

    def compute_class_weights(self, y_train):
        """Compute balanced class weights"""
        y_indices = np.argmax(y_train, axis=1)
        classes = np.unique(y_indices)
        class_weights = compute_class_weight('balanced', classes=classes, y=y_indices)
        self.class_weights = dict(zip(classes, class_weights))

        print(f"📊 Class weights:")
        class_names = ['Healthy', 'Asthma', 'COPD']
        for i, weight in self.class_weights.items():
            print(f"   • {class_names[i]}: {weight:.3f}")

        return self.class_weights

    def create_callbacks(self, model_name='improved_lung_model'):
        """Conservative callbacks for better generalization"""
        callbacks = [
            EarlyStopping(
                monitor='val_loss',
                patience=15,  # Shorter patience to prevent overfitting
                restore_best_weights=True,
                verbose=1,
                min_delta=0.001
            ),
            ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=5,  # Reduce LR more aggressively
                min_lr=1e-6,
                verbose=1
            ),
            ModelCheckpoint(
                filepath=f'{model_name}_best.keras',
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1,
                mode='max'
            )
        ]

        return callbacks

    def train_with_data_augmentation(self, x_train, y_train, x_val, y_val,
                                   model_type='lstm', epochs=80, batch_size=32):
        """Train with data augmentation for better generalization"""

        print(f"🚀 Training {model_type} model with data augmentation...")
        print(f"   • Training samples: {x_train.shape[0]}")
        print(f"   • Validation samples: {x_val.shape[0]}")

        # Data augmentation
        x_train_aug, y_train_aug = self.augment_data(x_train, y_train)
        print(f"   • Augmented training samples: {x_train_aug.shape[0]}")

        # Create model
        self.model = self.create_ensemble_ready_model(model_type)
        self.model = self.compile_model(self.model)

        # Compute class weights
        class_weights = self.compute_class_weights(y_train_aug)

        # Callbacks
        callbacks = self.create_callbacks(f'{model_type}_lung_model')

        print(f"\n🏗️ Model Architecture ({model_type}):")
        print(f"   • Total parameters: {self.model.count_params():,}")

        # Train
        self.history = self.model.fit(
            x_train_aug, y_train_aug,
            validation_data=(x_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=1,
            shuffle=True
        )

        print("✅ Training completed!")
        return self.history

    def augment_data(self, x_data, y_data, augment_factor=0.5):
        """Simple data augmentation techniques"""
        augmented_x = []
        augmented_y = []

        # Original data
        augmented_x.append(x_data)
        augmented_y.append(y_data)

        n_augment = int(len(x_data) * augment_factor)
        indices = np.random.choice(len(x_data), n_augment, replace=True)

        for idx in indices:
            sample = x_data[idx].copy()
            label = y_data[idx].copy()

            # Random noise addition (5% of signal std)
            noise_level = 0.05 * np.std(sample)
            sample += np.random.normal(0, noise_level, sample.shape)

            # Random scaling (±10%)
            scale_factor = np.random.uniform(0.9, 1.1)
            sample *= scale_factor

            augmented_x.append(sample[np.newaxis, :])
            augmented_y.append(label[np.newaxis, :])

        return np.vstack(augmented_x), np.vstack(augmented_y)

    def evaluate_model(self, x_test, y_test):
        """Comprehensive evaluation"""
        if self.model is None:
            print("❌ Model not trained yet!")
            return

        # Predictions
        y_pred_proba = self.model.predict(x_test, verbose=0)
        y_pred = np.argmax(y_pred_proba, axis=1)
        y_true = np.argmax(y_test, axis=1)

        # Metrics
        test_loss, test_acc, test_prec, test_rec = self.model.evaluate(x_test, y_test, verbose=0)
        test_f1 = f1_score(y_true, y_pred, average='weighted')

        print(f"\n📊 Test Performance:")
        print(f"   • Accuracy: {test_acc:.4f}")
        print(f"   • Precision: {test_prec:.4f}")
        print(f"   • Recall: {test_rec:.4f}")
        print(f"   • F1-Score: {test_f1:.4f}")

        # Per-class metrics
        class_names = ['Healthy', 'Asthma', 'COPD']
        print(f"\n📋 Classification Report:")
        print(classification_report(y_true, y_pred, target_names=class_names))

        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                   xticklabels=class_names, yticklabels=class_names)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.show()

        return {
            'accuracy': test_acc,
            'precision': test_prec,
            'recall': test_rec,
            'f1_score': test_f1,
            'predictions': y_pred_proba,
            'confusion_matrix': cm
        }

    def plot_training_history(self):
        """Plot training curves to check for overfitting"""
        if self.history is None:
            print("❌ No training history!")
            return

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Accuracy
        axes[0, 0].plot(self.history.history['accuracy'], label='Train', color='blue')
        axes[0, 0].plot(self.history.history['val_accuracy'], label='Val', color='orange')
        axes[0, 0].set_title('Accuracy')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Loss
        axes[0, 1].plot(self.history.history['loss'], label='Train', color='blue')
        axes[0, 1].plot(self.history.history['val_loss'], label='Val', color='orange')
        axes[0, 1].set_title('Loss')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

        # Precision
        axes[1, 0].plot(self.history.history['precision'], label='Train', color='blue')
        axes[1, 0].plot(self.history.history['val_precision'], label='Val', color='orange')
        axes[1, 0].set_title('Precision')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

        # Recall
        axes[1, 1].plot(self.history.history['recall'], label='Train', color='blue')
        axes[1, 1].plot(self.history.history['val_recall'], label='Val', color='orange')
        axes[1, 1].set_title('Recall')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()


class EnsembleLungClassifier:
    """Ensemble approach for robust predictions"""

    def __init__(self, input_shape=(1, 959), num_classes=3):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.models = []
        self.model_types = ['simple', 'attention', 'lstm']

    def train_ensemble(self, x_train, y_train, x_val, y_val, epochs=60):
        """Train ensemble of diverse models"""
        print("🎯 Training Ensemble of Models...")

        for model_type in self.model_types:
            print(f"\n🔄 Training {model_type} model...")

            classifier = ImprovedLungSoundClassifier(self.input_shape, self.num_classes)
            history = classifier.train_with_data_augmentation(
                x_train, y_train, x_val, y_val,
                model_type=model_type,
                epochs=epochs,
                batch_size=32
            )

            self.models.append(classifier.model)
            print(f"✅ {model_type} model trained!")

        return self.models

    def predict_ensemble(self, x_test):
        """Make ensemble predictions"""
        if not self.models:
            print("❌ No models trained!")
            return None

        predictions = []
        for model in self.models:
            pred = model.predict(x_test, verbose=0)
            predictions.append(pred)

        # Average predictions
        ensemble_pred = np.mean(predictions, axis=0)
        return ensemble_pred

    def evaluate_ensemble(self, x_test, y_test):
        """Evaluate ensemble performance"""
        ensemble_pred = self.predict_ensemble(x_test)
        if ensemble_pred is None:
            return None

        y_pred = np.argmax(ensemble_pred, axis=1)
        y_true = np.argmax(y_test, axis=1)

        # Calculate metrics
        from sklearn.metrics import accuracy_score, precision_score, recall_score

        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='weighted')
        recall = recall_score(y_true, y_pred, average='weighted')
        f1 = f1_score(y_true, y_pred, average='weighted')

        print(f"\n🎯 Ensemble Performance:")
        print(f"   • Accuracy: {accuracy:.4f}")
        print(f"   • Precision: {precision:.4f}")
        print(f"   • Recall: {recall:.4f}")
        print(f"   • F1-Score: {f1:.4f}")

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'predictions': ensemble_pred
        }


# Improved Training Pipeline
def train_improved_lung_classifier(x_train, y_train, x_val, y_val, x_test, y_test,
                                 use_ensemble=False):
    """
    Improved training pipeline focused on generalization

    Key improvements:
    1. Simpler architectures to reduce overfitting
    2. Data augmentation for better generalization
    3. Conservative training settings
    4. Ensemble option for robust predictions
    """

    print("🎯 Improved Lung Sound Classification Pipeline")
    print("Focus: Better generalization on unseen data")
    print("=" * 60)

    if use_ensemble:
        # Train ensemble
        ensemble = EnsembleLungClassifier()
        models = ensemble.train_ensemble(x_train, y_train, x_val, y_val)
        results = ensemble.evaluate_ensemble(x_test, y_test)
        return ensemble, results
    else:
        # Train single improved model
        classifier = ImprovedLungSoundClassifier()

        # Train the LSTM variant (set model_type to 'simple' or 'attention' to compare)
        history = classifier.train_with_data_augmentation(
            x_train, y_train, x_val, y_val,
            model_type='lstm',
            epochs=150,
            batch_size=32
        )

        # Evaluate
        results = classifier.evaluate_model(x_test, y_test)

        # Plot training curves
        classifier.plot_training_history()

        return classifier, results


# Usage instructions
def usage_example():
    """Train the single improved model and print usage notes."""
    print("\n📝 Usage Example:")

    # Single improved model: runs training on the splits prepared earlier in the notebook
    print("# Training single improved model...")
    classifier, results = train_improved_lung_classifier(
        x_train_gru, y_train_gru, x_val_gru, y_val_gru, x_test_gru, y_test_gru
    )

    # Ensemble approach (slower); uncomment to run:
    # ensemble, ensemble_results = train_improved_lung_classifier(
    #     x_train_gru, y_train_gru, x_val_gru, y_val_gru, x_test_gru, y_test_gru,
    #     use_ensemble=True
    # )
    print("\n# For the ensemble approach (slower), pass use_ensemble=True")

    print("\n💡 Key Improvements:")
    print("• Simpler architecture to prevent overfitting")
    print("• Data augmentation for better generalization")
    print("• Conservative training with early stopping")
    print("• Ensemble option for robust predictions")
    print("• Better regularization strategies")

if __name__ == "__main__":
    usage_example()
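
`predict_ensemble` above combines the three models by averaging their softmax outputs (soft voting). The sketch below shows that averaging in isolation with NumPy only. The `soft_vote` helper and the probability arrays are illustrative assumptions standing in for real model outputs; they are not part of the notebook's pipeline.

```python
import numpy as np

def soft_vote(probabilities):
    """Average per-model class probabilities; return (mean_probs, predicted_labels)."""
    stacked = np.stack(probabilities, axis=0)  # shape: (n_models, n_samples, n_classes)
    mean_probs = stacked.mean(axis=0)          # shape: (n_samples, n_classes)
    return mean_probs, mean_probs.argmax(axis=1)

# Made-up softmax outputs for two samples and three classes (Healthy, Asthma, COPD)
p_simple = np.array([[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]])
p_attention = np.array([[0.5, 0.4, 0.1], [0.1, 0.3, 0.6]])
p_lstm = np.array([[0.7, 0.2, 0.1], [0.2, 0.3, 0.5]])

mean_probs, labels = soft_vote([p_simple, p_attention, p_lstm])
print(labels)
```

In this toy input the first model alone would pick class 1 for the second sample, while the averaged probabilities favour class 2. Averaging keeps each row a valid probability distribution, so the result can be passed to the same `argmax`-based metrics used elsewhere in the notebook.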

In [None]:
# Plot classification results with Plotly for all three Model 2 approaches (simple, attention, LSTM)

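from itertools import cycle
from typing import Dict, List

# Plotly objects used by plot_classification_plotly below
import plotly.graph_objects as go
from plotly.subplots import make_subplots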
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.utils.class_weight import compute_class_weight
import tensorflow as tf
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.layers import (
    Input, Dense, Dropout, BatchNormalization, GRU, LSTM,
    Bidirectional, GlobalAveragePooling1D, GlobalMaxPooling1D,
    Concatenate, Conv1D, LeakyReLU, SpatialDropout1D,
    MultiHeadAttention, LayerNormalization
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import (
    EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
)
from tensorflow.keras.utils import register_keras_serializable


# Define the custom metric if not already defined
@register_keras_serializable()
def weighted_categorical_accuracy(y_true, y_pred):
    """Categorical accuracy registered under a custom name (no class weighting is applied here)"""
    return tf.keras.metrics.categorical_accuracy(y_true, y_pred)

# Re-define the ImprovedLungSoundClassifier class if necessary (based on the last code block)
# Or assume it's already defined and accessible in the environment
class ImprovedLungSoundClassifier:
    """
    Improved Neural Network for Lung Sound Classification
    Focus: Better generalization on unseen data
    """

    def __init__(self, input_shape=(1, 959), num_classes=3):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.model = None
        self.history = None
        self.class_weights = None

    def create_simple_model(self):
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs
        x = Conv1D(filters=32, kernel_size=7, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = SpatialDropout1D(0.2)(x)
        x = Bidirectional(
            GRU(64, return_sequences=True, dropout=0.3, recurrent_dropout=0.3)
        )(x)
        x = BatchNormalization()(x)
        avg_pool = GlobalAveragePooling1D()(x)
        max_pool = GlobalMaxPooling1D()(x)
        x = Concatenate()([avg_pool, max_pool])
        x = Dense(64, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(0.5)(x)
        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)
        model = Model(inputs=inputs, outputs=outputs, name='SimpleLungClassifier')
        return model

    def create_attention_model(self):
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs
        x = Conv1D(filters=32, kernel_size=5, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = SpatialDropout1D(0.2)(x)
        x = Bidirectional(GRU(48, return_sequences=True, dropout=0.2, recurrent_dropout=0.2))(x)
        x = LayerNormalization()(x)
        attention = MultiHeadAttention(
            num_heads=4,
            key_dim=48,
            dropout=0.2
        )(x, x)
        x = LayerNormalization()(x + attention)
        x = GlobalAveragePooling1D()(x)
        x = Dense(32, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(0.4)(x)
        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)
        model = Model(inputs=inputs, outputs=outputs, name='AttentionLungClassifier')
        return model

    def create_lstm_model(self):
        inputs = Input(shape=self.input_shape, name='lung_sound_input')
        x = inputs
        x = Conv1D(filters=32, kernel_size=3, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = SpatialDropout1D(0.2)(x)
        x = Bidirectional(LSTM(48, return_sequences=True, dropout=0.3, recurrent_dropout=0.3))(x)
        x = BatchNormalization()(x)
        x = GlobalAveragePooling1D()(x)
        x = Dense(48, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = Dropout(0.4)(x)
        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)
        model = Model(inputs=inputs, outputs=outputs, name='LSTMLungClassifier')
        return model

    def create_ensemble_ready_model(self, model_variant='simple'):
        if model_variant == 'simple':
            return self.create_simple_model()
        elif model_variant == 'attention':
            return self.create_attention_model()
        elif model_variant == 'lstm':
            return self.create_lstm_model()
        else:
            return self.create_simple_model()

    def compile_model(self, model, learning_rate=0.0005):
        optimizer = Adam(
            learning_rate=learning_rate,
            clipnorm=1.0
        )
        model.compile(
            optimizer=optimizer,
            loss='categorical_crossentropy',
            metrics=['accuracy', 'precision', 'recall']
        )
        return model

    def compute_class_weights(self, y_train):
        y_indices = np.argmax(y_train, axis=1)
        classes = np.unique(y_indices)
        class_weights = compute_class_weight('balanced', classes=classes, y=y_indices)
        self.class_weights = dict(zip(classes, class_weights))
        return self.class_weights

    def create_callbacks(self, model_name='improved_lung_model'):
        callbacks = [
            EarlyStopping(
                monitor='val_loss',
                patience=15,
                restore_best_weights=True,
                verbose=1,
                min_delta=0.001
            ),
            ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=5,
                min_lr=1e-6,
                verbose=1
            ),
            ModelCheckpoint(
                filepath=f'{model_name}_best.keras',
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1,
                mode='max'
            )
        ]
        return callbacks

    def augment_data(self, x_data, y_data, augment_factor=0.5):
        """Simple data augmentation techniques"""
        augmented_x = []
        augmented_y = []

        # Original data
        augmented_x.append(x_data)
        augmented_y.append(y_data)

        n_augment = int(len(x_data) * augment_factor)
        indices = np.random.choice(len(x_data), n_augment, replace=True)

        for idx in indices:
            sample = x_data[idx].copy()
            label = y_data[idx].copy()

            # Random noise addition (5% of signal std)
            noise_level = 0.05 * np.std(sample)
            sample += np.random.normal(0, noise_level, sample.shape)

            # Random scaling (±10%)
            scale_factor = np.random.uniform(0.9, 1.1)
            sample *= scale_factor

            augmented_x.append(sample[np.newaxis, :])
            augmented_y.append(label[np.newaxis, :])

        return np.vstack(augmented_x), np.vstack(augmented_y)

    def train_with_data_augmentation(self, x_train, y_train, x_val, y_val,
                                   model_type='simple', epochs=80, batch_size=32):

        print(f"🚀 Training {model_type} model with data augmentation...")
        print(f"   • Training samples: {x_train.shape[0]}")
        print(f"   • Validation samples: {x_val.shape[0]}")

        x_train_aug, y_train_aug = self.augment_data(x_train, y_train)
        print(f"   • Augmented training samples: {x_train_aug.shape[0]}")

        self.model = self.create_ensemble_ready_model(model_type)
        self.model = self.compile_model(self.model)

        class_weights = self.compute_class_weights(y_train_aug)

        callbacks = self.create_callbacks(f'{model_type}_lung_model')

        print(f"\n🏗️ Model Architecture ({model_type}):")
        print(f"   • Total parameters: {self.model.count_params():,}")

        self.history = self.model.fit(
            x_train_aug, y_train_aug,
            validation_data=(x_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=1,
            shuffle=True
        )

        print("✅ Training completed!")
        return self.history

    def evaluate_model(self, x_test, y_test):
        if self.model is None:
            print("❌ Model not trained yet!")
            return

        y_pred_proba = self.model.predict(x_test, verbose=0)
        y_pred = np.argmax(y_pred_proba, axis=1)
        y_true = np.argmax(y_test, axis=1)

        test_loss, test_acc, test_prec, test_rec = self.model.evaluate(x_test, y_test, verbose=0)
        test_f1 = f1_score(y_true, y_pred, average='weighted')

        print(f"\n📊 Test Performance:")
        print(f"   • Accuracy: {test_acc:.4f}")
        print(f"   • Precision: {test_prec:.4f}")
        print(f"   • Recall: {test_rec:.4f}")
        print(f"   • F1-Score: {test_f1:.4f}")

        class_names = ['Healthy', 'Asthma', 'COPD']
        print(f"\n📋 Classification Report:")
        print(classification_report(y_true, y_pred, target_names=class_names))

        cm = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                   xticklabels=class_names, yticklabels=class_names)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.show()

        return {
            'accuracy': test_acc,
            'precision': test_prec,
            'recall': test_rec,
            'f1_score': test_f1,
            'predictions': y_pred_proba,
            'confusion_matrix': cm
        }

    def plot_training_history(self):
        if self.history is None:
            print("❌ No training history!")
            return

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        axes[0, 0].plot(self.history.history['accuracy'], label='Train', color='blue')
        axes[0, 0].plot(self.history.history['val_accuracy'], label='Val', color='orange')
        axes[0, 0].set_title('Accuracy')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        axes[0, 1].plot(self.history.history['loss'], label='Train', color='blue')
        axes[0, 1].plot(self.history.history['val_loss'], label='Val', color='orange')
        axes[0, 1].set_title('Loss')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

        axes[1, 0].plot(self.history.history['precision'], label='Train', color='blue')
        axes[1, 0].plot(self.history.history['val_precision'], label='Val', color='orange')
        axes[1, 0].set_title('Precision')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

        axes[1, 1].plot(self.history.history['recall'], label='Train', color='blue')
        axes[1, 1].plot(self.history.history['val_recall'], label='Val', color='orange')
        axes[1, 1].set_title('Recall')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

class EnsembleLungClassifier:
    def __init__(self, input_shape=(1, 959), num_classes=3):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.models = []
        self.model_types = ['simple', 'attention', 'lstm']

    def train_ensemble(self, x_train, y_train, x_val, y_val, epochs=60):
        print("🎯 Training Ensemble of Models...")

        for model_type in self.model_types:
            print(f"\n🔄 Training {model_type} model...")

            classifier = ImprovedLungSoundClassifier(self.input_shape, self.num_classes)
            history = classifier.train_with_data_augmentation(
                x_train, y_train, x_val, y_val,
                model_type=model_type,
                epochs=epochs,
                batch_size=32
            )

            self.models.append(classifier.model)
            print(f"✅ {model_type} model trained!")

        return self.models

    def predict_ensemble(self, x_test):
        if not self.models:
            print("❌ No models trained!")
            return None

        predictions = []
        for model in self.models:
            pred = model.predict(x_test, verbose=0)
            predictions.append(pred)

        ensemble_pred = np.mean(predictions, axis=0)
        return ensemble_pred

    def evaluate_ensemble(self, x_test, y_test):
        ensemble_pred = self.predict_ensemble(x_test)
        if ensemble_pred is None:
            return None

        y_pred = np.argmax(ensemble_pred, axis=1)
        y_true = np.argmax(y_test, axis=1)

        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='weighted')
        recall = recall_score(y_true, y_pred, average='weighted')
        f1 = f1_score(y_true, y_pred, average='weighted')

        print(f"\n🎯 Ensemble Performance:")
        print(f"   • Accuracy: {accuracy:.4f}")
        print(f"   • Precision: {precision:.4f}")
        print(f"   • Recall: {recall:.4f}")
        print(f"   • F1-Score: {f1:.4f}")

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'predictions': ensemble_pred
        }


# Define the plotting function using Plotly
def plot_classification_plotly(model_results: Dict[str, Dict], class_names: List[str]):
    """
    Plots classification results (Confusion Matrix and Metrics) for multiple models using Plotly.

    Parameters:
    - model_results: A dictionary where keys are model names (e.g., 'Simple', 'Attention', 'LSTM')
                     and values are dictionaries containing 'confusion_matrix' and 'report' (from classification_report).
    - class_names: List of class names.
    """
    print("📊 Generating Plotly Classification Plots...")

    # Create subplots: one row for confusion matrices, one row for metrics bars
    fig = make_subplots(
        rows=2, cols=len(model_results),
        specs=[[{'type': 'heatmap'}] * len(model_results),
               [{'type': 'bar'}] * len(model_results)],
        subplot_titles=[f'Confusion Matrix: {name}' for name in model_results.keys()] +
                       [f'Metrics: {name}' for name in model_results.keys()],
        vertical_spacing=0.1,
        horizontal_spacing=0.05
    )

    row_cm = 1
    row_metrics = 2
    col = 1

    for model_name, results in model_results.items():
        cm = results.get('confusion_matrix')
        report_str = results.get('report') # Get the string output from classification_report

        if cm is not None:
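            # Confusion matrix heatmap: colour by row-normalised fraction, annotate with count and %
            row_sums = cm.sum(axis=1, keepdims=True)
            # Guard against classes with no true samples (row sum of zero)
            cm_percent = np.divide(cm.astype(float), row_sums,
                                   out=np.zeros(cm.shape, dtype=float), where=row_sums > 0)
            # Plotly renders '<br>' (not '\n') as a line break in trace text
            cm_text = np.array([[f'{cm[i, j]}<br>({cm_percent[i, j]*100:.1f}%)' for j in range(len(class_names))] for i in range(len(class_names))])

            heatmap = go.Heatmap(
                z=cm_percent,
                x=class_names,
                y=class_names,
                colorscale='Blues',
                colorbar=dict(title='Row fraction'),
                text=cm_text,
                texttemplate="%{text}",
                customdata=cm,  # raw counts for the hover label
                hovertemplate='True: %{y}<br>Predicted: %{x}<br>Count: %{customdata}<br>Share of true class: %{z:.1%}<extra></extra>'
            )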
            fig.add_trace(heatmap, row=row_cm, col=col)

        if report_str:
            # Parse classification report string
            report_data = {}
            lines = report_str.split('\n')
            # Locate the header row; the per-class metric rows follow it
            header_end_index = 0
            for i, line in enumerate(lines):
                 if "precision" in line and "recall" in line and "f1-score" in line:
                    header_end_index = i + 1
                    break

            metrics_lines = lines[header_end_index:]
            # Filter out empty lines and support line
            metrics_lines = [line for line in metrics_lines if line.strip() and not line.strip().startswith('support')]

            for line in metrics_lines:
                 parts = line.split()
                 if len(parts) >= 4: # Expect class, precision, recall, f1-score
                     class_label = parts[0]
                     if class_label in class_names:
                         try:
                             precision_val = float(parts[1])
                             recall_val = float(parts[2])
                             f1_val = float(parts[3])
                             report_data[class_label] = {'precision': precision_val, 'recall': recall_val, 'f1-score': f1_val}
                         except ValueError:
                             continue # Skip if parsing fails

            # Metrics Bar Chart Plotly
            metrics = ['precision', 'recall', 'f1-score']
            class_metrics = {metric: [] for metric in metrics}

            for class_name in class_names:
                if class_name in report_data:
                     for metric in metrics:
                         class_metrics[metric].append(report_data[class_name][metric])
                else: # Handle cases where a class might have no predictions
                     for metric in metrics:
                         class_metrics[metric].append(0.0) # Append 0 or NaN if appropriate


            # Only the first column adds legend entries; legendgroup ties the same
            # metric together across model columns so one click toggles all of them
            show_legend = (col == 1)

            fig.add_trace(go.Bar(
                x=class_names,
                y=class_metrics['precision'],
                name='Precision',
                legendgroup='Precision',
                showlegend=show_legend,
                marker_color='rgba(58,200,225,.5)',
                hovertemplate='Class: %{x}<br>Precision: %{y:.2f}<extra></extra>'
            ), row=row_metrics, col=col)

            fig.add_trace(go.Bar(
                x=class_names,
                y=class_metrics['recall'],
                name='Recall',
                legendgroup='Recall',
                showlegend=show_legend,
                marker_color='rgba(200,58,58,.5)',
                hovertemplate='Class: %{x}<br>Recall: %{y:.2f}<extra></extra>'
            ), row=row_metrics, col=col)

            fig.add_trace(go.Bar(
                x=class_names,
                y=class_metrics['f1-score'],
                name='F1-Score',
                legendgroup='F1-Score',
                showlegend=show_legend,
                marker_color='rgba(58,200,58,.5)',
                hovertemplate='Class: %{x}<br>F1-Score: %{y:.2f}<extra></extra>'
            ), row=row_metrics, col=col)

        # Update layout for the column
        fig.update_yaxes(title_text="True Label", row=row_cm, col=col)
        fig.update_xaxes(title_text="Predicted Label", row=row_cm, col=col)
        fig.update_yaxes(title_text="Score", range=[0, 1], row=row_metrics, col=col)
        fig.update_xaxes(title_text="Class", row=row_metrics, col=col)


        col += 1 # Move to the next column for the next model

    fig.update_layout(
        title_text="Model 2 Classification Performance Comparison (Plotly)",
        height=800,
        showlegend=True # Show one legend for all metric bars
    )

    # Adjust legend positioning to avoid overlap
    fig.update_layout(legend=dict(orientation="h", yanchor="bottom", y=-0.1, xanchor="center", x=0.5))

    fig.show()


# Assumes x_train_gru, y_train_gru, x_val_gru, y_val_gru, x_test_gru and y_test_gru
# are available from previous steps

# Re-train each model individually to get per-model test results
print("\nRetraining individual models for comparison...")

# Simple Model
simple_classifier = ImprovedLungSoundClassifier()
simple_classifier.train_with_data_augmentation(
    x_train_gru, y_train_gru, x_val_gru, y_val_gru,
    model_type='simple', epochs=150, batch_size=32
)
simple_results = simple_classifier.evaluate_model(x_test_gru, y_test_gru)
# Capture the classification report string
y_true_simple = np.argmax(y_test_gru, axis=1)
y_pred_proba_simple = simple_classifier.model.predict(x_test_gru, verbose=0)
y_pred_simple = np.argmax(y_pred_proba_simple, axis=1)
simple_report_str = classification_report(y_true_simple, y_pred_simple, target_names=["Healthy", "Asthma", "COPD"])
simple_results['report'] = simple_report_str


# Attention Model
attention_classifier = ImprovedLungSoundClassifier()
attention_classifier.train_with_data_augmentation(
    x_train_gru, y_train_gru, x_val_gru, y_val_gru,
    model_type='attention', epochs=150, batch_size=32
)
attention_results = attention_classifier.evaluate_model(x_test_gru, y_test_gru)
# Capture the classification report string
y_true_attention = np.argmax(y_test_gru, axis=1)
y_pred_proba_attention = attention_classifier.model.predict(x_test_gru, verbose=0)
y_pred_attention = np.argmax(y_pred_proba_attention, axis=1)
attention_report_str = classification_report(y_true_attention, y_pred_attention, target_names=["Healthy", "Asthma", "COPD"])
attention_results['report'] = attention_report_str


# LSTM Model
lstm_classifier = ImprovedLungSoundClassifier()
lstm_classifier.train_with_data_augmentation(
    x_train_gru, y_train_gru, x_val_gru, y_val_gru,
    model_type='lstm', epochs=150, batch_size=32
)
lstm_results = lstm_classifier.evaluate_model(x_test_gru, y_test_gru)
# Capture the classification report string
y_true_lstm = np.argmax(y_test_gru, axis=1)
y_pred_proba_lstm = lstm_classifier.model.predict(x_test_gru, verbose=0)
y_pred_lstm = np.argmax(y_pred_proba_lstm, axis=1)
lstm_report_str = classification_report(y_true_lstm, y_pred_lstm, target_names=["Healthy", "Asthma", "COPD"])
lstm_results['report'] = lstm_report_str


# Prepare data for Plotly
model2_comparison_results = {
    'Simple Model': simple_results,
    'Attention Model': attention_results,
    'LSTM Model': lstm_results,
}

class_names = ["Healthy", "Asthma", "COPD"]

# Generate Plotly classification plots
plot_classification_plotly(model2_comparison_results, class_names)
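
Row-normalising a confusion matrix divides each row by its sum, which is undefined when a class has no true samples in the test set. The sketch below shows one guarded way to do it with NumPy. `row_normalize` is a hypothetical helper name and the matrix is invented for illustration; neither comes from the notebook's results.

```python
import numpy as np

def row_normalize(cm):
    """Divide each row by its sum; rows that sum to zero stay all-zero instead of NaN."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # `where` skips the division for empty rows, leaving the zeros from `out`
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)

# Invented 3-class matrix whose middle class has no true samples
example_cm = [[8, 2, 0],
              [0, 0, 0],
              [1, 1, 8]]
normalized = row_normalize(example_cm)
print(normalized)
```

Each non-empty row of the result sums to 1, so the values can be read as the share of a true class assigned to each predicted class.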

