In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import librosa
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
from tqdm.notebook import tqdm
import glob
import re

# Locations of the audio recordings and of the per-patient diagnosis table
DATASET_PATH = '/content/ICBHI_dataset'
DIAGNOSIS_CSV = '/content/patient_diagnosis.csv'

# Audio front-end settings applied to every recording
SAMPLE_RATE = 22050  # sampling rate requested from librosa.load
DURATION = 5  # seconds
N_MELS = 128  # number of Mel bands in the spectrogram
HOP_LENGTH = 512  # hop length passed to the Mel spectrogram

# Diagnoses grouped by the class index they collapse into
# (position in this tuple == class index).
_DIAGNOSES_BY_CLASS = (
    ('Healthy',),
    ('COPD', 'Asthma', 'Bronchiectasis', 'Bronchiolitis'),
    ('LRTI', 'Pneumonia'),
    ('URTI',),
)

# Flat lookup: diagnosis name -> class index
class_mapping = {
    diagnosis_name: class_index
    for class_index, diagnosis_group in enumerate(_DIAGNOSES_BY_CLASS)
    for diagnosis_name in diagnosis_group
}

def load_diagnosis_data():
    """Read the diagnosis CSV into a ``{patient_id: diagnosis}`` dictionary.

    The file at ``DIAGNOSIS_CSV`` is read without a header row and its two
    columns are named ``patient_id`` and ``diagnosis``. Patient IDs are
    converted to strings before being used as keys. A summary of the
    diagnosis counts is printed along the way.

    Returns:
        The patient-to-diagnosis dictionary, or ``None`` if any step fails
        (the error is printed).
    """
    try:
        table = pd.read_csv(
            DIAGNOSIS_CSV,
            header=None,
            names=['patient_id', 'diagnosis'],
        )
        print(f"Loaded diagnosis data for {len(table)} patients")

        print("\nDiagnosis distribution in CSV:")
        print(table['diagnosis'].value_counts())

        # String keys so they compare equal to IDs parsed from filenames
        patient_keys = table['patient_id'].astype(str)
        return dict(zip(patient_keys, table['diagnosis']))
    except Exception as e:
        print(f"Error loading diagnosis data: {e}")
        return None

def find_wav_files():
    """Return every ``*.wav`` path found recursively under ``DATASET_PATH``.

    The number of files found is printed before returning.
    """
    search_pattern = os.path.join(DATASET_PATH, '**', '*.wav')
    matches = glob.glob(search_pattern, recursive=True)
    print(f"Found {len(matches)} WAV files")
    return matches

# Filename layouts tried in order. Each captures a 3-digit patient ID and uses
# digit-boundary checks so that three digits are never carved out of a longer
# number (which would attach the wrong patient's diagnosis to a recording).
_PATIENT_ID_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        r'(?<!\d)(\d{3})_',       # Format: 101_something.wav
        r'p(\d{3})(?!\d)',        # Format: p101_something.wav
        r'patient(\d{3})(?!\d)',  # Format: patient101_something.wav
        r'_(\d{3})\.wav',         # Format: something_101.wav
    )
)

# Last resort: the first run of exactly three digits anywhere in the name
_ISOLATED_THREE_DIGITS = re.compile(r'(?<!\d)\d{3}(?!\d)')


def extract_patient_id_from_filename(filename):
    """Extract the 3-digit patient ID from a WAV filename.

    Only the base name of ``filename`` is inspected, so digits in parent
    directories are ignored. The known filename layouts are tried in order;
    if none matches, the first standalone 3-digit run in the name is used.

    Args:
        filename: Path or bare name of the audio file.

    Returns:
        The patient ID as a 3-character string, or ``None`` when the name
        contains no standalone 3-digit group.
    """
    base_name = os.path.basename(filename)

    for pattern in _PATIENT_ID_PATTERNS:
        match = pattern.search(base_name)
        if match:
            return match.group(1)

    fallback = _ISOLATED_THREE_DIGITS.search(base_name)
    return fallback.group(0) if fallback else None

def build_dataset(wav_files, diagnosis_dict):
    """Pair each WAV file with its patient's diagnosis and class index.

    Args:
        wav_files: Iterable of audio file paths.
        diagnosis_dict: Mapping of patient ID (string) to diagnosis name.

    Returns:
        A DataFrame with columns ``filename``, ``patient_id``, ``diagnosis``
        and ``diagnosis_class``, or ``None`` if no file could be matched.
        Files whose patient ID is missing or unknown are reported (first
        five) and skipped; files with a diagnosis outside ``class_mapping``
        trigger a warning and are skipped.
    """
    annotations = []
    missing_diagnosis = []

    print("Matching WAV files with diagnosis data...")
    for wav_file in tqdm(wav_files):
        patient_id = extract_patient_id_from_filename(wav_file)

        # No usable ID, or the ID is not in the diagnosis table
        if not patient_id or patient_id not in diagnosis_dict:
            missing_diagnosis.append(wav_file)
            continue

        diagnosis = diagnosis_dict[patient_id]

        # Diagnosis exists but is not one of the classes we train on
        if diagnosis not in class_mapping:
            print(f"Warning: Unrecognized diagnosis '{diagnosis}' for patient {patient_id}")
            continue

        record = {
            'filename': wav_file,
            'patient_id': patient_id,
            'diagnosis': diagnosis,
            'diagnosis_class': class_mapping[diagnosis],
        }
        annotations.append(record)

    if missing_diagnosis:
        print(f"\n{len(missing_diagnosis)} WAV files could not be matched with a diagnosis.")
        print("First few unmatched files:")
        for file in missing_diagnosis[:5]:
            print(f"  - {os.path.basename(file)}")

    if not annotations:
        print("\nNo files could be matched with diagnosis information.")
        return None

    result_df = pd.DataFrame(annotations)

    print(f"\nSuccessfully matched {len(result_df)} WAV files with diagnosis information")
    print("\nDiagnosis distribution in dataset:")
    print(result_df['diagnosis'].value_counts())

    return result_df

def extract_features(file_path):
    """Extract a min-max normalised Mel spectrogram (in dB) from an audio file.

    The audio is loaded at ``SAMPLE_RATE``, truncated or zero-padded to
    exactly ``DURATION`` seconds, converted to a Mel power spectrogram with
    ``N_MELS`` bands and ``HOP_LENGTH`` hop, expressed in dB relative to its
    maximum, and finally rescaled to the range [0, 1].

    Args:
        file_path: Path of the audio file to process.

    Returns:
        A 2-D array of normalised spectrogram values, or ``None`` if the file
        could not be processed (the error is printed).
    """
    try:
        y, sr = librosa.load(file_path, sr=SAMPLE_RATE)

        # Force every clip to the same length so all spectrograms share a shape
        target_length = int(SAMPLE_RATE * DURATION)
        if len(y) > target_length:
            y = y[:target_length]
        elif len(y) < target_length:
            y = np.pad(y, (0, target_length - len(y)), mode='constant')

        # Generate Mel spectrogram
        mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=N_MELS, hop_length=HOP_LENGTH)
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        # Min-max normalise. A constant spectrogram (e.g. a silent recording)
        # has zero range; dividing by it would fill the features with NaN,
        # so return an all-zero spectrogram instead.
        spec_min = mel_spec_db.min()
        spec_range = mel_spec_db.max() - spec_min
        if spec_range > 0:
            return (mel_spec_db - spec_min) / spec_range
        return np.zeros_like(mel_spec_db)
    except Exception as e:
        print(f"Error processing {file_path}: {e}")
        return None

def prepare_dataset(annotations_df, num_classes=None):
    """Extract features and one-hot labels for every annotated recording.

    Args:
        annotations_df: DataFrame with at least the columns ``filename``,
            ``diagnosis_class`` and ``patient_id``.
        num_classes: Width of the one-hot label encoding. Defaults to the
            number of classes implied by ``class_mapping`` so that every
            split (train, test, ...) is encoded with the same width even when
            a split happens to contain no sample of some class.

    Returns:
        A tuple ``(X, y, filenames, patient_ids)`` where ``X`` has shape
        (samples, height, width, 1), ``y`` is the one-hot label matrix, and
        the two lists are aligned with the rows of ``X``.

    Raises:
        ValueError: If ``annotations_df`` is empty/``None`` or no features
            could be extracted from any file.
    """
    if annotations_df is None or len(annotations_df) == 0:
        raise ValueError("No valid annotations found")

    if num_classes is None:
        # Deriving the width from the labels present in this subset would
        # give different widths (or out-of-range indices) for splits that
        # miss a class, so use the global class definition instead.
        num_classes = max(class_mapping.values()) + 1

    X = []
    y = []
    filenames = []
    patient_ids = []

    for _, row in tqdm(annotations_df.iterrows(), total=len(annotations_df), desc="Processing audio"):
        file_path = row['filename']
        if not os.path.exists(file_path):
            continue

        features = extract_features(file_path)
        if features is None:
            continue

        X.append(features)
        y.append(row['diagnosis_class'])
        filenames.append(file_path)
        patient_ids.append(row['patient_id'])

    if not X:
        raise ValueError("No valid features extracted")

    # Convert to numpy arrays
    X = np.array(X)
    y = np.array(y)

    # Reshape for CNN input (samples, height, width, channels)
    X = X.reshape(X.shape[0], X.shape[1], X.shape[2], 1)

    # Convert labels to categorical with a fixed, split-independent width
    y = to_categorical(y, num_classes=num_classes)

    return X, y, filenames, patient_ids

def ensure_patient_stratification(annotations_df):
    """Split recordings into train/test sets without sharing patients.

    Patients (not files) are split 80/20 with a fixed random seed, stratified
    by diagnosis. Diagnoses carried by fewer than two patients cannot be
    stratified, so those patients are placed in the training set only.
    Summary statistics for both splits are printed.

    Args:
        annotations_df: DataFrame with ``patient_id`` and ``diagnosis``
            columns, one row per recording.

    Returns:
        A tuple ``(train_df, test_df)`` of row subsets of ``annotations_df``.
    """
    # One row per (patient, diagnosis) pair, with its recording count
    patient_groups = (
        annotations_df
        .groupby(['patient_id', 'diagnosis'])
        .size()
        .reset_index(name='count')
    )

    # How many patients carry each diagnosis
    patients_per_diagnosis = patient_groups.groupby('diagnosis').size()

    # Diagnoses with fewer than two patients cannot be stratified
    rare_diagnoses = patients_per_diagnosis[patients_per_diagnosis < 2].index.tolist()
    print(f"Rare diagnoses with only one patient: {rare_diagnoses}")

    is_rare = patient_groups['diagnosis'].isin(rare_diagnoses)
    common_rows = patient_groups[~is_rare]
    common_patients = common_rows['patient_id'].unique()
    rare_patients = patient_groups[is_rare]['patient_id'].unique()

    if len(common_patients) == 0:
        # Nothing to stratify on (unlikely): plain random split of all patients
        patient_train, patient_test = train_test_split(
            patient_groups['patient_id'].unique(),
            test_size=0.2,
            random_state=42,
        )
    else:
        # One diagnosis label per common patient, used as the stratification key
        strata = common_rows.groupby('patient_id')['diagnosis'].first().values
        patient_train, patient_test = train_test_split(
            common_patients,
            test_size=0.2,
            random_state=42,
            stratify=strata,
        )

        # Rare-diagnosis patients go to the training set only
        patient_train = np.concatenate([patient_train, rare_patients])

    # Select the recordings belonging to each group of patients
    train_df = annotations_df[annotations_df['patient_id'].isin(patient_train)]
    test_df = annotations_df[annotations_df['patient_id'].isin(patient_test)]

    print(f"Training set: {len(train_df)} files from {len(patient_train)} patients")
    print(f"Test set: {len(test_df)} files from {len(patient_test)} patients")

    # Report the class balance of both splits
    print("\nDiagnosis distribution in training set:")
    print(train_df['diagnosis'].value_counts())
    print("\nDiagnosis distribution in test set:")
    print(test_df['diagnosis'].value_counts())

    return train_df, test_df

def build_model(input_shape, num_classes):
    """Create and compile the spectrogram-classification CNN.

    The network is three Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages with
    32, 64 and 128 filters, followed by dropout, a 128-unit dense layer,
    dropout again and a softmax output layer.

    Args:
        input_shape: Shape of one input sample, passed to the first layer.
        num_classes: Number of units in the softmax output layer.

    Returns:
        The model compiled with the Adam optimizer, categorical
        cross-entropy loss and the accuracy metric.
    """
    network = models.Sequential()

    # Convolutional feature extractor
    network.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape))
    network.add(layers.MaxPooling2D((2, 2)))
    for n_filters in (64, 128):
        network.add(layers.Conv2D(n_filters, (3, 3), activation='relu'))
        network.add(layers.MaxPooling2D((2, 2)))
    network.add(layers.Dropout(0.3))

    # Classification head
    network.add(layers.Flatten())
    network.add(layers.Dense(128, activation='relu'))
    network.add(layers.Dropout(0.5))
    network.add(layers.Dense(num_classes, activation='softmax'))

    network.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return network

def plot_training_history(history):
    """Show accuracy and loss curves for training and validation.

    Args:
        history: Object whose ``history`` attribute is a dict holding the
            per-epoch series ``accuracy``, ``val_accuracy``, ``loss`` and
            ``val_loss``.
    """
    _, axes = plt.subplots(1, 2, figsize=(15, 5))

    # (train key, validation key, train label, validation label, title, y label)
    panels = (
        ('accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy',
         'Model Accuracy', 'Accuracy'),
        ('loss', 'val_loss', 'Training Loss', 'Validation Loss',
         'Model Loss', 'Loss'),
    )

    for axis, panel in zip(axes, panels):
        train_key, val_key, train_label, val_label, title, y_label = panel
        axis.plot(history.history[train_key], label=train_label)
        axis.plot(history.history[val_key], label=val_label)
        axis.set_title(title)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.legend()

    plt.tight_layout()
    plt.show()

def plot_confusion_matrix(y_true, y_pred, class_names):
    """Plot the confusion matrix as an annotated heatmap.

    Args:
        y_true: One-hot (or score) matrix of true labels, one row per sample.
        y_pred: Matrix of predicted class scores, one row per sample.
        class_names: Display names; ``class_names[i]`` labels class index
            ``i`` on both axes.
    """
    true_classes = np.argmax(y_true, axis=1)
    predicted_classes = np.argmax(y_pred, axis=1)

    # Pin the matrix to one row/column per named class. Without `labels`,
    # classes absent from both arrays are dropped and the matrix no longer
    # lines up with the tick labels.
    label_indices = list(range(len(class_names)))
    cm = confusion_matrix(true_classes, predicted_classes, labels=label_indices)

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.show()

def visualize_samples(X, y, filenames, patient_ids, class_names, n_samples=3):
    """Show up to ``n_samples`` random spectrograms for each class present.

    The figure has one row per class that has at least one sample and
    ``n_samples`` columns; a class with fewer samples leaves the remaining
    cells of its row empty.

    Args:
        X: Array indexed as ``X[sample, :, :, 0]`` to obtain a spectrogram.
        y: One-hot label matrix aligned with ``X``.
        filenames: Unused; kept so existing callers keep working.
        patient_ids: Patient ID per sample, shown in each subplot title.
        class_names: Display names indexed by class index.
        n_samples: Maximum number of spectrograms drawn per class.
    """
    y_classes = np.argmax(y, axis=1)

    # Sample indices for every class that actually occurs
    class_indices = {}
    for class_idx in range(len(class_names)):
        indices = np.where(y_classes == class_idx)[0]
        if len(indices) > 0:
            class_indices[class_idx] = indices

    if not class_indices:
        # Nothing to draw; a zero-row subplot grid would be invalid
        return

    n_rows = len(class_indices)
    # Height scales with the number of rows (classes), not with n_samples
    plt.figure(figsize=(15, n_rows * 3))

    for row, (class_idx, indices) in enumerate(class_indices.items()):
        # Pick n_samples random samples from this class
        sample_indices = np.random.choice(indices, size=min(n_samples, len(indices)), replace=False)

        for col, idx in enumerate(sample_indices):
            # Compute the cell from (row, col) so that a class with fewer
            # than n_samples items does not shift later classes into its row.
            plt.subplot(n_rows, n_samples, row * n_samples + col + 1)
            plt.imshow(X[idx, :, :, 0], aspect='auto', origin='lower', cmap='viridis')
            plt.title(f"Class: {class_names[class_idx]}\nPatient: {patient_ids[idx]}")

    plt.tight_layout()
    plt.show()

# Main execution
def main():
    """Run the whole pipeline: load labels, match audio, split, train, evaluate, save.

    Each stage prints its progress; the function returns early (after
    printing a message) when the diagnosis CSV cannot be loaded, no WAV files
    are found, or no file can be matched with a diagnosis.
    """
    # Load diagnosis data from CSV
    diagnosis_dict = load_diagnosis_data()
    if not diagnosis_dict:
        print("Could not load diagnosis data. Please check the CSV file path.")
        return

    # Find all WAV files
    wav_files = find_wav_files()
    if not wav_files:
        print("No WAV files found in the dataset directory.")
        return

    # Build dataset by matching WAV files with diagnosis information
    annotations_df = build_dataset(wav_files, diagnosis_dict)
    if annotations_df is None or len(annotations_df) == 0:
        print("Could not match WAV files with diagnosis information.")
        return

    # Split data by patient to avoid data leakage
    train_df, test_df = ensure_patient_stratification(annotations_df)

    # Prepare training data
    print("\nPreparing training dataset...")
    X_train, y_train, train_filenames, train_patient_ids = prepare_dataset(train_df)

    # Prepare test data
    print("\nPreparing test dataset...")
    X_test, y_test, test_filenames, test_patient_ids = prepare_dataset(test_df)

    print(f"Training set shape: X: {X_train.shape}, y: {y_train.shape}")
    print(f"Test set shape: X: {X_test.shape}, y: {y_test.shape}")

    # Build one display name per class index. Several diagnoses share a class,
    # so their names are joined; class_names[i] must describe class index i
    # (this relies on the class indices being contiguous from 0).
    diagnoses_by_class = {}
    for diagnosis, class_idx in class_mapping.items():
        diagnoses_by_class.setdefault(class_idx, []).append(diagnosis)
    class_names = [
        '/'.join(diagnoses_by_class[class_idx])
        for class_idx in sorted(diagnoses_by_class)
    ]

    # Visualize some samples
    print("\nVisualizing sample spectrograms from training set:")
    visualize_samples(X_train, y_train, train_filenames, train_patient_ids, class_names)

    # Build and train model
    input_shape = X_train.shape[1:]
    num_classes = y_train.shape[1]
    print(f"\nBuilding model for {num_classes} classes...")
    model = build_model(input_shape, num_classes)
    model.summary()

    print("\nTraining model...")
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True
    )

    # NOTE(review): the test set doubles as the validation set, so early
    # stopping picks the weights that score best on the very data used for
    # the final evaluation below; the reported test accuracy is therefore
    # optimistic. Consider holding out a separate patient-level validation
    # split from the training data.
    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=50,
        batch_size=32,
        callbacks=[early_stopping]
    )

    # Evaluate model
    print("\nEvaluating model...")
    test_loss, test_acc = model.evaluate(X_test, y_test)
    print(f"Test accuracy: {test_acc:.4f}")
    y_pred = model.predict(X_test)

    plot_training_history(history)
    plot_confusion_matrix(y_test, y_pred, class_names)

    # Per-class precision / recall / F1
    y_true_classes = np.argmax(y_test, axis=1)
    y_pred_classes = np.argmax(y_pred, axis=1)
    print("\nClassification report:")
    print(classification_report(
        y_true_classes,
        y_pred_classes,
        labels=list(range(len(class_names))),
        target_names=class_names,
        zero_division=0,
    ))

    # Save model
    model.save('/content/icbhi_diagnosis_model.h5')
    print("\nModel saved to /content/icbhi_diagnosis_model.h5")

if __name__ == "__main__":
    main()

In [None]:
%%writefile app.py
import streamlit as st
import numpy as np
import librosa
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import tempfile

# Audio front-end settings used when turning an upload into a spectrogram
SAMPLE_RATE = 22050  # sampling rate requested from librosa.load
DURATION = 5  # seconds
N_MELS = 128  # number of Mel bands in the spectrogram
HOP_LENGTH = 512  # hop length passed to the Mel spectrogram

# Display name for each model output index (position == class index)
class_mapping = dict(enumerate((
    "Healthy",
    "COPD/Asthma/Bronchiectasis/Bronchiolitis",
    "LRTI/Pneumonia",
    "URTI",
)))

# Function to extract features from an audio file
def extract_features(file_path):
    """Extract a min-max normalised Mel spectrogram (in dB) from an audio file.

    The audio is loaded at ``SAMPLE_RATE``, truncated or zero-padded to
    exactly ``DURATION`` seconds, converted to a Mel power spectrogram with
    ``N_MELS`` bands and ``HOP_LENGTH`` hop, expressed in dB relative to its
    maximum, and rescaled to the range [0, 1].

    Args:
        file_path: Path of the audio file to process.

    Returns:
        A 2-D array of normalised spectrogram values, or ``None`` if the file
        could not be processed (the error is shown in the Streamlit page).
    """
    try:
        y, sr = librosa.load(file_path, sr=SAMPLE_RATE)

        # Force the clip to the fixed length the model input is built from
        target_length = int(SAMPLE_RATE * DURATION)
        if len(y) > target_length:
            y = y[:target_length]
        elif len(y) < target_length:
            y = np.pad(y, (0, target_length - len(y)), mode='constant')

        mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=N_MELS, hop_length=HOP_LENGTH)
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        # Min-max normalise. A constant spectrogram (e.g. a silent upload)
        # has zero range; dividing by it would hand NaNs to the model, so
        # return an all-zero spectrogram instead.
        spec_min = mel_spec_db.min()
        spec_range = mel_spec_db.max() - spec_min
        if spec_range > 0:
            return (mel_spec_db - spec_min) / spec_range
        return np.zeros_like(mel_spec_db)
    except Exception as e:
        st.error(f"Error processing audio: {e}")
        return None

# Function to plot spectrogram
def plot_spectrogram(spec, title="Mel Spectrogram"):
    """Draw a Mel spectrogram and return the matplotlib figure.

    Args:
        spec: 2-D spectrogram to display. The caller in this app passes the
            output of ``extract_features``, i.e. values rescaled to [0, 1].
        title: Title shown above the plot.

    Returns:
        The matplotlib ``Figure`` holding the plot; the caller is responsible
        for rendering and closing it.
    """
    # `librosa.display` is a submodule that older librosa releases do not
    # load on a plain `import librosa`, so import it explicitly.
    import librosa.display

    fig, ax = plt.subplots(figsize=(10, 4))
    img = librosa.display.specshow(spec, x_axis='time', y_axis='mel', sr=SAMPLE_RATE, hop_length=HOP_LENGTH, ax=ax)
    ax.set_title(title)
    # The plotted values are min-max normalised, not decibels, so the
    # colorbar must not carry a dB unit.
    fig.colorbar(img, ax=ax, format='%.1f', label='Normalized level')
    return fig

@st.cache_resource
def load_model():
    """Load and return the Keras model stored in ``icbhi_diagnosis_model.h5``.

    The path is relative to the working directory of the Streamlit process.
    The function is wrapped in ``st.cache_resource``.
    """
    return tf.keras.models.load_model('icbhi_diagnosis_model.h5')

def main():
    """Render the Streamlit page: upload a WAV file, show its spectrogram, classify it.

    The uploaded audio is written to a temporary directory so it can be read
    from a file path; that directory is removed again whether or not feature
    extraction succeeds.
    """
    st.title("Detecting Respiratory Diseases")
    st.write("""
    ## Upload a respiratory sound recording to classify the condition
    This application uses a deep learning model trained on the ICBHI dataset to predict respiratory conditions.
    """)

    # Load model
    try:
        model = load_model()
        st.success("Model loaded successfully!")
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return

    # File uploader
    uploaded_file = st.file_uploader("Upload a WAV file", type=['wav'])
    if uploaded_file is None:
        return

    # The context manager guarantees cleanup even when extraction fails or
    # an exception escapes; previously the directory was only removed on the
    # success path.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_file_path = os.path.join(temp_dir, "temp_audio.wav")

        with open(temp_file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        st.audio(uploaded_file, format='audio/wav')

        with st.spinner('Processing audio...'):
            features = extract_features(temp_file_path)

            if features is not None:
                st.subheader("Mel Spectrogram")
                fig = plot_spectrogram(features)
                st.pyplot(fig)
                # Release the figure; otherwise one figure per rerun
                # accumulates in the long-lived server process.
                plt.close(fig)

                # Add batch and channel axes: (1, height, width, 1)
                features = features.reshape(1, features.shape[0], features.shape[1], 1)

                with st.spinner('Classifying...'):
                    prediction = model.predict(features)
                    predicted_class = np.argmax(prediction[0])
                    prediction_proba = prediction[0][predicted_class] * 100

                st.subheader("Prediction Results")
                st.write(f"**Predicted Condition:** {class_mapping[predicted_class]}")
                st.write(f"**Confidence:** {prediction_proba:.2f}%")

                st.subheader("Class Probabilities")
                for class_id, class_name in class_mapping.items():
                    st.write(f"{class_name}: {prediction[0][class_id]*100:.2f}%")

if __name__ == "__main__":
    main()

In [None]:
# Set up ngrok
from pyngrok import ngrok

# Set your auth token - REPLACE WITH YOUR ACTUAL TOKEN
ngrok.set_auth_token("2v8XsTmzMl5e639nypEY1cVJyZr_3nfL3ZpPLiPp5x2RKhjuc")

# Terminate any existing tunnels
ngrok.kill()

# Run the Streamlit app
!streamlit run app.py &>/dev/null &

# Create a tunnel to the Streamlit port
public_url = ngrok.connect(8501)
print(f"Public URL: {public_url}")