<a href="https://colab.research.google.com/github/Sazim2019331087/voice_model/blob/main/CNN_RNN_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# --- IMPORTANT WARNING ---
# Shown first so anyone running the notebook sees the data-size caveats
# before spending time on installs and training.
_BANNER_RULE = "=" * 80
_WARNING_TEXT = (
    "WARNING: Training a deep learning model from scratch (without pre-training)",
    "         with only 147 audio files is highly challenging and prone to overfitting.",
    "         The model will likely have limited generalization to new voices not in your dataset.",
    "         For a robust, good-performing model, you typically need thousands of hours of audio data.",
    "         This code is provided for educational purposes to demonstrate the architecture.",
    "         Ensure you have a GPU runtime enabled in Colab, as CPU training will be extremely slow.",
)
print(_BANNER_RULE)
for _warning_line in _WARNING_TEXT:
    print(_warning_line)
print(_BANNER_RULE)

In [None]:
# NOTE(review): this cell overlaps the next one. The next cell repeats the
# pip/pandas/ffmpeg installs below and then installs pinned versions
# torch==2.0.1 / torchvision==0.15.2 / torchaudio==2.0.2, which replace whatever
# the unpinned torch line here fetches. Running only the next cell should be enough.
print("--- Installing required libraries ---")
!pip install --upgrade pip
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # Install PyTorch with CUDA support
!pip install pandas scikit-learn joblib tqdm # tqdm for progress bars
!pip install ffmpeg-python # Python wrapper for ffmpeg
!apt-get update && apt-get install -y ffmpeg # Install ffmpeg on Colab for audio processing

In [None]:
# ==============================================================================
# --- Step 1: Setup and Install Libraries ---
# ==============================================================================
# You need to run this cell before any other code.

print("--- Installing required libraries ---")
!pip install --upgrade pip
!pip install pandas scikit-learn joblib tqdm # pandas, scikit-learn, joblib, tqdm
!pip install ffmpeg-python # Python wrapper for ffmpeg
!apt-get update && apt-get install -y ffmpeg # Install ffmpeg on Colab for audio processing

# --- CRITICAL: Re-install PyTorch and TorchAudio to ensure CUDA version compatibility ---
# torch, torchvision and torchaudio are pinned together because each torchaudio /
# torchvision release is built against one specific torch release; mixing
# versions typically fails at import time. All three come from the CUDA 11.8
# (cu118) wheel index.
!pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
# torch==2.0.1 pairs with torchvision==0.15.2 and torchaudio==2.0.2 for cu118.

# Restart the runtime after this cell finishes. Otherwise the Python process
# keeps using the torch version it had already imported, not the one just installed.

print("Installation complete. Please RESTART YOUR COLAB RUNTIME (Runtime -> Restart session) and then run all cells from the beginning.")

In [None]:
# Step 2: Mount Google Drive and Load Data
# ==============================================================================
# This step connects your Colab notebook to your Google Drive to access the data.

from google.colab import drive
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import torchaudio
from torchaudio.transforms import MFCC # For MFCC feature extraction
import numpy as np
from tqdm.notebook import tqdm # For progress bars
import joblib # For saving and loading the speaker mapping and model

# Attach Google Drive at /content/drive so later cells can read the project folder.
print("\n--- Mounting Google Drive ---")
_DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(_DRIVE_MOUNT_POINT)

In [None]:
# --- Define the paths to your data ---
# IMPORTANT: these must mirror your Google Drive layout exactly. The expected layout is
# a 'project' folder in My Drive that holds training.csv and a 'voices' subfolder of audio files.
PROJECT_ROOT_DIR = '/content/drive/MyDrive/project'
CSV_PATH = os.path.join(PROJECT_ROOT_DIR, 'training.csv')
AUDIO_FOLDER_PATH = os.path.join(PROJECT_ROOT_DIR, 'voices')

# Fail fast on the first missing location, checked in this order: folder, CSV, audio dir.
for _required_path, _missing_message in (
    (PROJECT_ROOT_DIR, f"Error: Project folder '{PROJECT_ROOT_DIR}' not found. Please check the path and your Google Drive structure."),
    (CSV_PATH, f"Error: CSV file '{CSV_PATH}' not found. Please ensure it's in the correct location."),
    (AUDIO_FOLDER_PATH, f"Error: Audio folder '{AUDIO_FOLDER_PATH}' not found. Please check the path and upload your audio files."),
):
    if not os.path.exists(_required_path):
        raise FileNotFoundError(_missing_message)
print(f"Successfully located project folder at: {PROJECT_ROOT_DIR}")

# Read the metadata CSV, then resolve each row's audio filename against the audio folder.
print("\n--- Loading data from CSV ---")
df = pd.read_csv(CSV_PATH)
df['audio_path'] = [os.path.join(AUDIO_FOLDER_PATH, audio_name) for audio_name in df['audio_file']]

# --- IMPORTANT: Filter out missing/unreadable audio files ---
# Every row is probed once here (file exists, decodes, yields a 2-D MFCC matrix)
# so that the Dataset used for training does not hit these failures mid-epoch.
print("\n--- Verifying audio file paths and formats... ---")
verified_data_for_df = []  # Rows (as dicts) that passed every check below

# Temporary MFCC transform with the same settings the training Dataset uses by
# default; it is only used to confirm each file produces features of the expected rank.
test_mfcc_transform = MFCC(
    sample_rate=16000,
    n_mfcc=40,
    melkwargs={
        'n_fft': 400,
        'hop_length': 160,
        'n_mels': 128
    }
)

problematic_files = []  # (audio_file, email, reason) for every rejected row

for _, row in tqdm(df.iterrows(), total=len(df), desc="Verifying audio files"):
    audio_file_path = row['audio_path']

    if not os.path.exists(audio_file_path):
        problematic_files.append((row['audio_file'], row['email'], "File Not Found"))
        continue

    try:
        # num_frames counts frames at the file's *native* sample rate. 32000 frames
        # is 2 s only for 16 kHz audio and shorter for higher rates, which is still
        # enough to prove the file decodes.
        waveform, sample_rate = torchaudio.load(audio_file_path, frame_offset=0, num_frames=16000 * 2)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)
        # Guarantee (channels, samples) BEFORE mixing down. Checking shape[0] on a
        # 1-D tensor would average across the sample axis instead of channels.
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        temp_mfcc_features = test_mfcc_transform(waveform)

        # Mono input gives (1, num_mfcc, num_frames); drop the channel axis.
        if temp_mfcc_features.ndim == 3 and temp_mfcc_features.shape[0] == 1:
            temp_mfcc_features = temp_mfcc_features.squeeze(0)

        if temp_mfcc_features.ndim != 2:
            raise ValueError(f"MFCC features for {audio_file_path} unexpected shape after squeeze: {temp_mfcc_features.shape}")

        verified_data_for_df.append(row.to_dict())

    except Exception as e:
        # Deliberately broad: any decode/transform failure just marks the file as bad.
        problematic_files.append((row['audio_file'], row['email'], f"Format Error: {e}"))

if problematic_files:
    print(f"\n--- {len(problematic_files)} Problematic audio files found and skipped ---")
    # Column names match the CSV's own 'audio_file' / 'email' columns.
    problematic_df = pd.DataFrame(problematic_files, columns=['audio_file', 'email', 'Reason'])
    try:
        print(problematic_df.to_markdown(index=False, numalign="left", stralign="left"))
    except ImportError:
        # DataFrame.to_markdown needs the optional 'tabulate' package; fall back to plain text.
        print(problematic_df.to_string(index=False))
    print("\nTip: Use `ffmpeg -i input.wav -ar 16000 -ac 1 -c:a pcm_s16le output_converted.wav` to convert problematic files.")
else:
    print("\nAll audio files verified successfully!")

# Build the working DataFrame from verified rows only; stop if nothing survived.
if len(verified_data_for_df) == 0:
    raise ValueError("No valid audio files found after verification. Please check your data.")

existing_files_df = pd.DataFrame(verified_data_for_df)
# Encode each distinct email as an integer class id and keep the id -> email lookup
# needed to decode model predictions.
_email_as_category = existing_files_df['email'].astype('category')
existing_files_df['speaker_id'] = _email_as_category.cat.codes
speaker_mapping = {code: email for code, email in enumerate(_email_as_category.cat.categories)}
num_speakers = len(speaker_mapping)

print(f"\n--- Speaker Mapping (Total Unique Speakers: {num_speakers}) ---")
print(speaker_mapping)

# --- Split data into training and testing sets ---
# A stratified split needs every speaker (class) to have at least two samples,
# one for each side. Otherwise train_test_split raises ValueError ("The least
# populated class in y has only 1 member"). Comparing only the total row count
# with 2 * num_speakers is not enough, because one single-sample speaker still
# breaks stratification. So the smallest per-speaker count is checked instead.
_min_samples_per_speaker = int(existing_files_df['speaker_id'].value_counts().min())

if _min_samples_per_speaker < 2:
    print("\nWARNING: Dataset has very few samples per speaker. Stratification might be difficult.")
    print(f"Total samples: {len(existing_files_df)}, Unique speakers: {num_speakers}, "
          f"Fewest samples for one speaker: {_min_samples_per_speaker}")
    # Non-stratified fallback: test set of ~20% of rows, capped at num_speakers, never empty.
    train_df, test_df = train_test_split(
        existing_files_df,
        test_size=max(1, min(int(0.2 * len(existing_files_df)), num_speakers)),
        random_state=42,
        stratify=None,
    )
    print("Proceeding with NON-STRATIFIED split due to limited samples per speaker.")
else:
    # test_size == num_speakers is the smallest test set a stratified split accepts,
    # since sklearn requires at least one test slot per class. Slots are allocated
    # roughly in proportion to class frequency, so with imbalanced speakers some
    # speaker may still end up with no test sample.
    train_df, test_df = train_test_split(
        existing_files_df,
        test_size=num_speakers,
        random_state=42,
        stratify=existing_files_df['speaker_id'],
    )
    print(f"Using STRATIFIED split with {len(test_df)} samples in test set.")

print("\n--- Dataset Split for Training and Testing ---")
print(f"Training samples: {len(train_df)}")
print(f"Testing samples: {len(test_df)}")

In [None]:
# Step 3: Create a Custom PyTorch Dataset with MFCCs
# ==============================================================================

class SpeakerDatasetMFCC(Dataset):
    """Dataset yielding ``(mfcc, speaker_id)`` pairs for the rows of a DataFrame.

    Each row must provide an ``'audio_path'`` and a ``'speaker_id'`` column.
    Audio goes through these steps:

    - resampled to ``target_sr``;
    - mixed down to mono;
    - zero-padded or truncated to ``max_len_sec`` seconds;
    - converted to a ``(num_mfcc, num_frames)`` tensor.

    Every item therefore has the same shape and batches stack directly. A file
    that fails to load yields ``(None, None)``, which ``collate_fn`` drops.
    """

    def __init__(self, dataframe, target_sr=16000, num_mfcc=40, n_fft=400, hop_length=160,
                 max_len_sec=30, n_mels=128):
        self.dataframe = dataframe
        self.target_sr = target_sr
        self.num_mfcc = num_mfcc
        self.max_len_sec = max_len_sec  # Fixed clip duration (seconds) for every item
        # int() keeps slicing valid if a fractional duration is passed.
        self.max_len_samples = int(self.max_len_sec * self.target_sr)
        # Resample transforms keyed by source sample rate. Building one computes a
        # filter kernel, so reuse it instead of rebuilding per item.
        self._resamplers = {}

        self.mfcc_transform = MFCC(
            sample_rate=target_sr, n_mfcc=num_mfcc,
            melkwargs={'n_fft': n_fft, 'hop_length': hop_length, 'n_mels': n_mels}
        )

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        audio_path = row['audio_path']
        label = row['speaker_id']

        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            if sample_rate != self.target_sr:
                waveform = self._get_resampler(sample_rate)(waveform)
            waveform = self._to_mono(waveform)
            waveform = self._pad_or_truncate(waveform)

            mfcc_features = self.mfcc_transform(waveform)
            # (1, num_mfcc, num_frames) -> (num_mfcc, num_frames), so a collated
            # batch is (batch, num_mfcc, num_frames) as Conv1d expects.
            if mfcc_features.ndim == 3 and mfcc_features.shape[0] == 1:
                mfcc_features = mfcc_features.squeeze(0)

            return mfcc_features, torch.tensor(label, dtype=torch.long)

        except Exception as e:
            # Best-effort: report a broken file and let collate_fn drop it rather
            # than aborting the whole epoch.
            print(f"Error processing {audio_path}: {e}. Skipping this sample.")
            return None, None

    def _get_resampler(self, orig_sr):
        """Return a (cached) Resample transform from ``orig_sr`` to ``self.target_sr``."""
        resampler = self._resamplers.get(orig_sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=self.target_sr)
            self._resamplers[orig_sr] = resampler
        return resampler

    @staticmethod
    def _to_mono(waveform):
        """Return a ``(1, samples)`` tensor: add a channel axis to 1-D input, average multi-channel input."""
        # Add the channel axis first so a 1-D tensor is never averaged across samples.
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform

    def _pad_or_truncate(self, waveform):
        """Right-zero-pad or truncate the last axis to exactly ``self.max_len_samples`` samples."""
        length = waveform.shape[-1]
        if length > self.max_len_samples:
            return waveform[..., :self.max_len_samples]
        if length < self.max_len_samples:
            return torch.nn.functional.pad(waveform, (0, self.max_len_samples - length))
        return waveform

# Batch builder that tolerates samples which failed to load (returned as (None, None)).
def collate_fn(batch):
    """Stack ``(mfcc, label)`` pairs into batch tensors, discarding failed samples.

    Items whose features are ``None`` are dropped. Returns ``(None, None)`` when
    nothing valid remains, so the training loop can skip the batch. Otherwise it
    returns ``(features, labels)``, each stacked along a new leading batch axis.
    """
    usable = list(filter(lambda pair: pair[0] is not None, batch))
    if len(usable) == 0:
        return None, None

    feature_list = [features for features, _ in usable]
    label_list = [target for _, target in usable]
    return torch.stack(feature_list), torch.stack(label_list)

# Wrap each split in a dataset and batch it. collate_fn drops samples that failed
# to load, and only the training loader shuffles.
train_dataset = SpeakerDatasetMFCC(train_df)
test_dataset = SpeakerDatasetMFCC(test_df)
_loader_options = dict(batch_size=4, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

print(f"\n--- DataLoader created. Total training batches: {len(train_loader)} ---")

In [None]:
# Step 4: Define a Combined CNN-RNN Model from Scratch
# ==============================================================================
# This model uses CNN layers to extract local features from MFCCs and RNN (GRU)
# layers to capture temporal dependencies, followed by a classification head.

class SpeakerCNN_RNN(nn.Module):
    """CNN front-end plus bidirectional GRU that classifies speakers from MFCC sequences.

    Input:  ``(batch, num_mfcc, frames)`` float tensor.
    Output: ``(batch, num_speakers)`` unnormalised logits.

    Three Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2) stages widen the channels
    from num_mfcc to 64, 128 and then 256. Each pool halves the time axis (floor),
    so ``frames`` must be at least 8. The GRU output is averaged over time, then
    passed through dropout and one linear layer. For 30 s of 16 kHz audio with
    hop_length 160, that is roughly 3000 MFCC frames, or about 375 GRU steps.
    """

    def __init__(self, num_speakers, num_mfcc=40, hidden_dim=128, rnn_layers=2, dropout_rate=0.3):
        super().__init__()

        # Built stage by stage in the same module order (Conv1d, BatchNorm1d, ReLU,
        # MaxPool1d per stage), so the Sequential indices conv_layers.0 ...
        # conv_layers.11 used as state_dict keys are unchanged.
        stage_channels = [(num_mfcc, 64), (64, 128), (128, 256)]
        conv_modules = []
        for in_ch, out_ch in stage_channels:
            conv_modules.append(nn.Conv1d(in_ch, out_ch, kernel_size=5, padding=2))  # length-preserving
            conv_modules.append(nn.BatchNorm1d(out_ch))
            conv_modules.append(nn.ReLU())
            conv_modules.append(nn.MaxPool1d(kernel_size=2))  # halves the time axis
        self.conv_layers = nn.Sequential(*conv_modules)

        cnn_feature_size = stage_channels[-1][1]  # 256 features per time step
        self.rnn = nn.GRU(
            input_size=cnn_feature_size,
            hidden_size=hidden_dim,
            num_layers=rnn_layers,
            bidirectional=True,  # both directions -> 2 * hidden_dim outputs per step
            batch_first=True,    # tensors are (batch, seq, feature)
        )
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc_layer = nn.Linear(2 * hidden_dim, num_speakers)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        conv_features = self.conv_layers(x)          # (B, 256, T')
        sequence = conv_features.transpose(1, 2)     # (B, T', 256) for the batch_first GRU
        rnn_out, _last_hidden = self.rnn(sequence)   # (B, T', 2 * hidden_dim)
        # AdaptiveAvgPool1d wants channels before time, so swap back, average over
        # time, and drop the resulting length-1 axis.
        utterance_embedding = self.global_pool(rnn_out.transpose(1, 2)).squeeze(-1)
        return self.fc_layer(self.dropout(utterance_embedding))

# Run on the GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"\n--- Initializing model on device: {device} ---")

# Input width comes from the dataset's MFCC count; output width from the number of speakers.
model = SpeakerCNN_RNN(num_speakers=num_speakers, num_mfcc=train_dataset.num_mfcc)
model = model.to(device)

# Cross-entropy on the model's raw logits; Adam with a deliberately small step size.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [None]:
# Step 5: Train the Model
# ==============================================================================

def _evaluate(model, loader, criterion, device, desc):
    """Run one no-grad pass over ``loader``.

    Returns ``(mean loss, accuracy %, samples seen)``. Batches that collate_fn
    returned as ``(None, None)`` are skipped and excluded from both averages.
    If no valid batch is seen, the loss and accuracy are reported as 0.0.
    """
    model.eval()  # dropout off, BatchNorm uses running statistics
    total_loss, correct, seen, batches = 0.0, 0, 0, 0
    with torch.no_grad():
        pbar = tqdm(loader, desc=desc)
        for inputs, labels in pbar:
            if inputs is None:
                continue
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item()
            batches += 1
            seen += labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            pbar.set_postfix({'test_loss': total_loss / batches, 'test_acc': 100 * correct / seen})

    mean_loss = total_loss / batches if batches else 0.0
    accuracy = 100 * correct / seen if seen else 0.0
    return mean_loss, accuracy, seen


def train_model(model, train_loader, test_loader, criterion, optimizer, num_epochs=100,
                save_dir=None):
    """Train ``model`` for ``num_epochs`` epochs, evaluating on ``test_loader`` after each one.

    Tensors are moved to the device that holds the model's parameters. Whenever
    test accuracy improves (the first evaluated epoch always counts), the
    state_dict is written to ``<save_dir>/speaker_cnn_rnn_best.pth``.
    ``save_dir`` defaults to ``PROJECT_ROOT_DIR/saved_models_scratch_mfcc_rnn``.
    The model is left in training mode.

    Returns:
        The best test accuracy in percent, or None if no epoch had any valid test samples.

    NOTE: the checkpoint is selected on the same split that is reported as the
    test score, so that score is optimistic. A separate validation split would avoid this.
    """
    if save_dir is None:
        save_dir = os.path.join(PROJECT_ROOT_DIR, 'saved_models_scratch_mfcc_rnn')
    # Follow the model instead of relying on a global device variable.
    device = next(model.parameters()).device

    print("\n--- Starting Training ---")
    best_accuracy = None

    for epoch in range(num_epochs):
        model.train()
        running_loss, correct, seen, batches = 0.0, 0, 0, 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Train)")
        for inputs, labels in pbar:
            if inputs is None:  # collate_fn found no loadable samples in this batch
                pbar.set_postfix_str("Skipping empty batch")
                continue

            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches += 1
            seen += labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            pbar.set_postfix({'loss': running_loss / batches, 'train_acc': 100 * correct / seen})

        # Average only over batches that actually ran, and guard against an epoch
        # in which every batch was skipped.
        epoch_train_loss = running_loss / batches if batches else 0.0
        epoch_train_accuracy = 100 * correct / seen if seen else 0.0

        epoch_test_loss, epoch_test_accuracy, test_seen = _evaluate(
            model, test_loader, criterion, device, desc=f"Epoch {epoch+1}/{num_epochs} (Test)"
        )

        print(f"Epoch {epoch+1} Summary: Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_accuracy:.2f}%, "
              f"Test Loss: {epoch_test_loss:.4f}, Test Acc: {epoch_test_accuracy:.2f}%")

        # Checkpoint on improvement. The first evaluated epoch always saves, so a
        # "best" file exists even if accuracy never rises above 0%.
        if test_seen and (best_accuracy is None or epoch_test_accuracy > best_accuracy):
            best_accuracy = epoch_test_accuracy
            os.makedirs(save_dir, exist_ok=True)
            model_save_path = os.path.join(save_dir, 'speaker_cnn_rnn_best.pth')
            torch.save(model.state_dict(), model_save_path)
            print(f"New best model saved with Test Accuracy: {best_accuracy:.2f}%")

    model.train()  # leave the model in training mode, as before
    return best_accuracy

# Launch training. train_model may write a "best" checkpoint to Drive during the run.
NUM_EPOCHS = 100
train_model(model, train_loader, test_loader, criterion, optimizer, num_epochs=NUM_EPOCHS)


In [None]:
# Step 6: Save the Trained Model and Speaker Mapping (Final Save)
# ==============================================================================
# Even if a "best" model was saved, we save the final state as well.

# Persist the end-of-training weights, alongside any "best" checkpoint, plus the
# speaker_id -> email mapping needed to decode predictions later.
SAVE_DIR = os.path.join(PROJECT_ROOT_DIR, 'saved_models_scratch_mfcc_rnn')
os.makedirs(SAVE_DIR, exist_ok=True)

model_final_save_path, mapping_save_path = (
    os.path.join(SAVE_DIR, artifact_name)
    for artifact_name in ('speaker_cnn_rnn_final.pth', 'speaker_mapping.joblib')
)

joblib.dump(speaker_mapping, mapping_save_path)
torch.save(model.state_dict(), model_final_save_path)

print(f"\n--- Final Trained Model Saved to: {model_final_save_path} ---")
print(f"Speaker Mapping Saved to: {mapping_save_path}")


In [None]:
# ==============================================================================
# Step 7: Inference (Detect a Person from a New Audio File - Interactive Upload)
# Note: the input tensor is moved to the device of the model's parameters, because nn.Module has no `.device` attribute.
# ==============================================================================

from google.colab import files
import soundfile as sf
import numpy as np
import os # Ensure os is imported for path operations

# --- Inference helper: preprocess one file like the training Dataset, then classify ---
def predict_speaker_from_audio(model, audio_file_path, speaker_mapping,
                               target_sr=16000, num_mfcc=40, n_fft=400, hop_length=160,
                               n_mels=128, max_len_sec=30):
    """Predict which enrolled speaker produced ``audio_file_path``.

    The preprocessing mirrors SpeakerDatasetMFCC's defaults:

    - resample to ``target_sr``;
    - mix down to mono;
    - zero-pad or truncate to ``max_len_sec`` seconds;
    - compute an MFCC with ``n_mels`` mel bands.

    All keyword defaults must match the settings the model was trained with.

    Returns:
        ``(email, confidence)``: the mapped label of the arg-max class and its
        softmax probability. Returns ``(None, None)`` if anything fails; the
        error is printed.
    """
    model.eval()  # dropout off, BatchNorm uses running statistics
    # n_mels is passed explicitly so inference cannot silently drift from the
    # training transform (which sets n_mels=128) by relying on a library default.
    mfcc_transform = MFCC(
        sample_rate=target_sr, n_mfcc=num_mfcc,
        melkwargs={'n_fft': n_fft, 'hop_length': hop_length, 'n_mels': n_mels}
    )
    max_len_samples = int(max_len_sec * target_sr)

    try:
        if not os.path.exists(audio_file_path):
            raise FileNotFoundError(f"Audio file not found: {audio_file_path}")

        waveform, sample_rate = torchaudio.load(audio_file_path)
        if sample_rate != target_sr:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)
            waveform = resampler(waveform)
        # Guarantee (channels, samples) before mixing channels down to mono.
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Same fixed length as training: truncate long clips, zero-pad short ones.
        num_samples = waveform.shape[1]
        if num_samples > max_len_samples:
            waveform = waveform[:, :max_len_samples]
        elif num_samples < max_len_samples:
            waveform = torch.nn.functional.pad(waveform, (0, max_len_samples - num_samples))

        mfcc_features = mfcc_transform(waveform)
        # (1, num_mfcc, num_frames) -> (num_mfcc, num_frames)
        if mfcc_features.ndim == 3 and mfcc_features.shape[0] == 1:
            mfcc_features = mfcc_features.squeeze(0)

        # Add a batch axis and move to the device of the model's weights
        # (nn.Module itself has no .device attribute).
        input_tensor = mfcc_features.unsqueeze(0).to(next(model.parameters()).device)

        with torch.no_grad():
            probabilities = torch.softmax(model(input_tensor), dim=1)
            confidence, predicted_id_tensor = torch.max(probabilities, 1)

        predicted_id = predicted_id_tensor.item()
        predicted_confidence = confidence.item()
        predicted_email = speaker_mapping[predicted_id]

        return predicted_email, predicted_confidence

    except Exception as e:
        # Best-effort for an interactive cell: report the error and signal failure with (None, None).
        print(f"Error during inference for {audio_file_path}: {e}")
        return None, None


# --- Reload the trained model and speaker mapping for inference ---
# Rebuilds the model from the checkpoints written by the training/saving cells
# (best first, then final). Requires the earlier cells that define
# PROJECT_ROOT_DIR, SpeakerCNN_RNN, train_dataset, device, speaker_mapping and
# existing_files_df.
_MODEL_DIR = os.path.join(PROJECT_ROOT_DIR, 'saved_models_scratch_mfcc_rnn')
_mapping_path = os.path.join(_MODEL_DIR, 'speaker_mapping.joblib')
_checkpoint_candidates = [
    os.path.join(_MODEL_DIR, 'speaker_cnn_rnn_best.pth'),
    os.path.join(_MODEL_DIR, 'speaker_cnn_rnn_final.pth'),
]

if os.path.exists(_mapping_path):
    # joblib files are pickles: only load mapping files you created yourself.
    loaded_speaker_mapping = joblib.load(_mapping_path)
else:
    # Final-save cell not run: fall back to the mapping built from the CSV.
    loaded_speaker_mapping = speaker_mapping

# num_mfcc must match the value used for training (SpeakerDatasetMFCC.num_mfcc).
_inference_num_mfcc = train_dataset.num_mfcc
loaded_model = SpeakerCNN_RNN(num_speakers=len(loaded_speaker_mapping),
                              num_mfcc=_inference_num_mfcc).to(device)

_checkpoint_path = next((path for path in _checkpoint_candidates if os.path.exists(path)), None)
if _checkpoint_path is None:
    raise FileNotFoundError(
        f"No trained model found in {_MODEL_DIR}. Run the training and saving cells first."
    )
loaded_model.load_state_dict(torch.load(_checkpoint_path, map_location=device))
loaded_model.eval()
print(f"Loaded model weights from: {_checkpoint_path}")


# --- Interactive File Upload for Prediction ---
print("\n--- Upload an audio file from your PC for speaker detection ---")
uploaded_files = files.upload()  # Opens a browser file dialog; returns {filename: bytes}

if uploaded_files:
    uploaded_file_name = next(iter(uploaded_files))
    # files.upload() writes the file relative to the current working directory
    # (/content by default in Colab), so resolve it against cwd, not a fixed prefix.
    uploaded_file_path = os.path.abspath(uploaded_file_name)

    print(f"\nUploaded file: {uploaded_file_name}")
    print(f"File saved to: {uploaded_file_path}")

    print("\n--- Performing Inference on the uploaded audio file ---")

    detected_email, confidence = predict_speaker_from_audio(
        loaded_model, uploaded_file_path, loaded_speaker_mapping,
        num_mfcc=_inference_num_mfcc  # Same num_mfcc as used for training
    )

    if detected_email:
        # Look up a display name from the training metadata when the CSV provides one.
        detected_name = "Unknown"
        if 'name' in existing_files_df.columns:
            matching_names = existing_files_df.loc[existing_files_df['email'] == detected_email, 'name']
            if not matching_names.empty:
                detected_name = matching_names.iloc[0]

        print("\n--- Detection Result ---")
        print(f"The detected person is: {detected_name}")
        print(f"Corresponding Email ID: {detected_email}")
        print(f"Confidence: {confidence:.4f}")

    else:
        print("Detection failed for the uploaded file. Please check the audio file and its format.")
        print("Ensure it's a clear recording of one of the trained speakers.")

else:
    print("No file was uploaded.")