In [None]:
import os
import numpy as np
import torch
import torchaudio
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import random
import time
import warnings
import math
import traceback
from einops import rearrange
from transformers import Wav2Vec2Model, WavLMModel, HubertModel
import torch.nn.functional as F

warnings.filterwarnings('ignore')

# =============================================
# 1. SETUP AND CONFIGURATION
# =============================================

def set_seed(seed=42):
    """Seed every RNG the notebook relies on and request deterministic cuDNN.

    Args:
        seed: Integer seed shared by Python's ``random``, NumPy and PyTorch
            (CPU generator and every CUDA device).
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    # Ask cuDNN for deterministic kernels so GPU runs are repeatable
    torch.backends.cudnn.deterministic = True
set_seed()

# Pick the compute device once; the training and inference code below reads it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Output directory for the model checkpoints saved during training
os.makedirs('checkpoints', exist_ok=True)

In [None]:
# =============================================
# 2. DATA LOADING AND EXPLORATION
# =============================================

# Load data

# Emotion label (as written in the CSVs) -> integer class index.
# The names appear to be Indonesian: marah=angry, jijik=disgust, takut=fear,
# bahagia=happy, netral=neutral, sedih=sad.
emotion_mapping = {
    label: index
    for index, label in enumerate(["marah", "jijik", "takut", "bahagia", "netral", "sedih"])
}

def load_data(train_csv="train.csv", test_csv="test.csv", val_size=0.15, seed=42):
    """Read the train/test CSVs and carve a stratified validation split.

    Args:
        train_csv: Path to the labelled training CSV. It needs a ``label``
            column, which is used for stratification.
        test_csv: Path to the test CSV.
        val_size: Fraction of the training rows held out for validation.
        seed: ``random_state`` for the split, for reproducibility.

    Returns:
        ``(train_df, test_df, train_data, val_data, emotion_mapping)``:
        ``train_df`` is the full training CSV, ``train_data``/``val_data`` its
        split, and ``emotion_mapping`` the module-level label -> index dict.
    """
    train_df = pd.read_csv(train_csv)
    test_df = pd.read_csv(test_csv)

    print(f"Training data shape: {train_df.shape}")
    print(f"Test data shape: {test_df.shape}")

    # Labels that have no class index would only raise a KeyError later, when
    # the datasets look them up mid-epoch; surface them here instead.
    unknown_labels = sorted(set(train_df['label'].astype(str)) - set(emotion_mapping))
    if unknown_labels:
        print(f"Warning: labels not in emotion_mapping: {unknown_labels}")

    # Split into train/val, preserving the class proportions
    train_data, val_data = train_test_split(
        train_df, test_size=val_size, random_state=seed, stratify=train_df['label']
    )

    print(f"Training data: {len(train_data)} samples")
    print(f"Validation data: {len(val_data)} samples")
    print(f"Test data: {len(test_df)} samples")

    return train_df, test_df, train_data, val_data, emotion_mapping

# Analyze data distributions
def analyze_data_distribution(train_data, val_data, test_df):
    """Plot and print the per-emotion class distribution of each data split.

    Draws one bar chart and one pie chart per labelled split (training,
    validation, and test when ``test_df`` has a ``label`` column). It also
    prints per-class counts and percentages for every labelled split.

    Args:
        train_data: Training split with a ``label`` column.
        val_data: Validation split with a ``label`` column.
        test_df: Test DataFrame; included only if it has a ``label`` column.
    """
    # (display name, per-label counts sorted by label) for every labelled split
    splits = [
        ('Training', train_data['label'].value_counts().sort_index()),
        ('Validation', val_data['label'].value_counts().sort_index()),
    ]
    if 'label' in test_df.columns:
        splits.append(('Test', test_df['label'].value_counts().sort_index()))

    _plot_distribution_bars(splits)
    _print_distribution_stats(splits)
    _plot_distribution_pies(splits)


def _plot_distribution_bars(splits):
    """Draw one bar chart per split, annotating each bar with its count."""
    fig, axes = plt.subplots(1, len(splits), figsize=(14, 8), squeeze=False)
    for ax, (name, dist) in zip(axes[0], splits):
        sns.barplot(x=dist.index, y=dist.values, palette='viridis', ax=ax)
        ax.set_title(f'{name} Emotion Distribution')
        ax.set_xlabel('Emotion')
        ax.set_ylabel('Count')
        ax.tick_params(axis='x', rotation=45)

        # Offset labels relative to the tallest bar so they sit just above the
        # bars at any dataset size (a fixed offset overlaps or floats away).
        offset = 0.01 * dist.values.max() if len(dist) else 0
        for i, count in enumerate(dist.values):
            ax.text(i, count + offset, str(count), ha='center', fontweight='bold')

    fig.tight_layout()
    plt.show()


def _print_distribution_stats(splits):
    """Print count and percentage of every label for each split."""
    print("Emotion Distribution:")
    for name, dist in splits:
        print(f"\n{name} Dataset:")
        percent = (dist / dist.sum() * 100).round(2)
        for emotion, count in dist.items():
            print(f"{emotion}: {count} samples ({percent[emotion]}%)")


def _plot_distribution_pies(splits):
    """Draw one pie chart of the label shares per split."""
    fig, axes = plt.subplots(1, len(splits), figsize=(18, 6), squeeze=False)
    for ax, (name, dist) in zip(axes[0], splits):
        ax.pie(dist.values, labels=dist.index, autopct='%1.1f%%', startangle=90, shadow=True)
        ax.set_title(f'{name} Data Emotion Distribution')
    fig.tight_layout()
    plt.show()


In [None]:
# =============================================
# 3. DATASET CLASSES
# =============================================


class RawAudioEmotionDataset(Dataset):
    """Dataset yielding fixed-length, normalised, mono 16 kHz waveforms.

    Each item is ``(waveform, label_tensor)`` when the row has a ``label``
    column, otherwise ``(waveform, audio_id)`` (test / inference data).
    ``waveform`` has shape ``(1, max_length)``.

    Args:
        df: DataFrame with an ``id`` column (audio file name relative to
            ``base_path``) and optionally a ``label`` column of emotion names.
        base_path: Directory containing the audio files; ``""`` means the
            current working directory.
        use_cache: Keep processed items in memory so later epochs skip decoding.
            The cache lives on the dataset object, so with ``num_workers > 0``
            every worker process keeps its own copy.
    """

    def __init__(self, df, base_path, use_cache=True):
        self.df = df
        self.base_path = base_path
        self.class2idx = {"marah":0, "jijik":1, "takut":2, "bahagia":3, "netral":4, "sedih":5}
        self.use_cache = use_cache
        self.cache = {}

        # Audio preprocessing settings
        self.target_sr = 16000  # Wav2Vec2 expects 16kHz audio
        self.max_length = 160000  # 10 seconds max at 16kHz

    def __len__(self):
        return len(self.df)

    def _audio_path(self, audio_id):
        """Return the file path of ``audio_id`` under ``base_path``."""
        # os.path.join: a naive "<base>/<id>" turns base_path="" into "/<id>"
        # (the filesystem root) and doubles the slash for "train/".
        return os.path.join(self.base_path, str(audio_id))

    def _fit_length(self, waveform):
        """Truncate (keeping the beginning) or right-pad with zeros to ``max_length``."""
        if waveform.shape[1] > self.max_length:
            return waveform[:, :self.max_length]
        padding_length = self.max_length - waveform.shape[1]
        return F.pad(waveform, (0, padding_length), "constant", 0)

    def _load_waveform(self, audio_path):
        """Load, downmix, resample, normalise and length-fit one audio file.

        Any failure is printed and replaced by an all-zero clip so that one
        unreadable file does not abort a whole epoch.
        """
        try:
            waveform, sample_rate = torchaudio.load(audio_path)

            # Convert multi-channel audio to mono
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            if sample_rate != self.target_sr:
                waveform = torchaudio.functional.resample(waveform, sample_rate, self.target_sr)

            # Per-clip zero-mean / unit-variance normalisation
            waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-8)

            return self._fit_length(waveform)
        except Exception as e:
            print(f"Error processing {audio_path}: {e}")
            return torch.zeros(1, self.max_length)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        audio_id = row['id']
        audio_path = self._audio_path(audio_id)

        if self.use_cache and audio_path in self.cache:
            return self.cache[audio_path]

        waveform = self._load_waveform(audio_path)

        # `in` on a row Series checks column names: test data has no label
        if 'label' in row:
            result = (waveform, torch.tensor(self.class2idx[row['label']]))
        else:
            result = (waveform, audio_id)

        if self.use_cache:
            self.cache[audio_path] = result

        return result


class FastRawAudioEmotionDataset(RawAudioEmotionDataset):
    """Faster variant: 3-second clips, centre-cropped when the audio is longer.

    Shorter clips are right-padded with zeros, exactly like the base class.
    """

    def __init__(self, df, base_path, use_cache=True):
        super().__init__(df, base_path, use_cache)
        self.max_length = 48000  # 3 seconds at 16kHz instead of 10s

    def _fit_length(self, waveform):
        """Take the centre ``max_length`` samples, or right-pad if shorter."""
        if waveform.shape[1] > self.max_length:
            start = (waveform.shape[1] - self.max_length) // 2
            return waveform[:, start:start + self.max_length]
        padding_length = self.max_length - waveform.shape[1]
        return F.pad(waveform, (0, padding_length), "constant", 0)
    
    

In [None]:
# =============================================
# 4. MODEL ARCHITECTURES
# =============================================


class Wav2Vec2ForSEROptimized(nn.Module):
    """Speech-emotion classifier: pretrained Wav2Vec2 encoder + attention pooling + MLP head.

    The pretrained checkpoint is loaded with ``AutoModelForAudioClassification``
    and fully frozen. Only the last ``num_layers_to_unfreeze`` transformer
    encoder layers are then made trainable. Frame-level encoder features are
    pooled with a learned attention weighting and classified by a new MLP.

    Args:
        num_classes: Number of output logits.
        pretrained_model: Hugging Face model id or local path to load.
        num_layers_to_unfreeze: How many of the final encoder layers to
            fine-tune. Values above the layer count unfreeze every layer;
            0 keeps the encoder fully frozen.
    """

    def __init__(self, num_classes=6, pretrained_model="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", num_layers_to_unfreeze=3):
        super().__init__()

        from transformers import AutoModelForAudioClassification

        # The full pretrained model stays registered as ``self.model`` so the
        # state_dict layout matches checkpoints saved by earlier runs.
        self.model = AutoModelForAudioClassification.from_pretrained(pretrained_model)

        # Base Wav2Vec2 encoder: the only part of self.model used in forward()
        self.wav2vec = self.model.wav2vec2

        # Freeze the whole pretrained model, not only the encoder. forward()
        # never touches the rest of self.model, so any other parameters it
        # holds (presumably the checkpoint's own classification head) would
        # otherwise be counted as trainable and handed to the optimizer
        # without ever receiving gradients.
        for param in self.model.parameters():
            param.requires_grad = False

        # Unfreeze the last N transformer layers for fine-tuning
        total_layers = len(self.wav2vec.encoder.layers)
        print(f"Model has {total_layers} transformer layers")
        print(f"Unfreezing the last {num_layers_to_unfreeze} transformer layers")

        for i in range(1, min(num_layers_to_unfreeze + 1, total_layers + 1)):
            layer_idx = total_layers - i
            print(f"Unfreezing layer {layer_idx}")
            for param in self.wav2vec.encoder.layers[layer_idx].parameters():
                param.requires_grad = True

        # Feature dimension produced by the encoder
        hidden_size = self.wav2vec.config.hidden_size

        # Attention pooling: one scalar score per frame
        self.attention = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

        # New classification head for our emotion classes
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.4),  # heavier dropout for regularization
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of waveforms.

        Args:
            x: Waveforms shaped ``(batch, 1, samples)``, as produced by the
               dataset classes.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        x = x.squeeze(1)  # [batch_size, audio_length]

        # Only the final hidden state is used, so the encoder is not asked to
        # return (and keep alive) every intermediate layer's output.
        hidden_states = self.wav2vec(x).last_hidden_state

        # Softmax over the frame axis, then an attention-weighted sum of frames
        attention_weights = F.softmax(self.attention(hidden_states), dim=1)
        attention_output = torch.sum(hidden_states * attention_weights, dim=1)

        return self.classifier(attention_output)

In [None]:
# =============================================
# 5. TRAINING UTILITIES
# =============================================


def _train_one_epoch(model, loader, criterion, optimizer, device, scaler, epoch):
    """Run one optimisation pass over ``loader``; return ``(mean loss, accuracy)``.

    The loss is averaged per sample, which assumes ``criterion`` uses mean
    reduction. When ``scaler`` is given, CUDA autocast + gradient scaling is used.
    """
    model.train()
    loss_sum, correct, total = 0.0, 0, 0

    batches = tqdm(loader, desc=f"Training epoch {epoch+1}")
    for batch_idx, (inputs, labels) in enumerate(batches, start=1):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        if scaler is not None:
            with torch.cuda.amp.autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Weight by batch size so a smaller final batch is not over-counted
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

        if batch_idx % 50 == 0:
            print(f"  Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")

    return loss_sum / max(1, total), correct / max(1, total)


def _evaluate(model, loader, criterion, device, use_amp):
    """Evaluate without gradients; return ``(mean per-sample loss, accuracy)``."""
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="Validation"):
            inputs, labels = inputs.to(device), labels.to(device)

            if use_amp:
                with torch.cuda.amp.autocast():
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
            else:
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    return loss_sum / max(1, total), correct / max(1, total)


def train_mixed_precision(model, train_loader, val_loader, criterion, optimizer,
                          num_epochs=3, device=None, scheduler=None,
                          checkpoint_path='checkpoints/best_wav2vec.pth'):
    """Train ``model``, keeping the weights with the best validation accuracy.

    Mixed precision (CUDA autocast + GradScaler) is used when ``device`` is
    a CUDA device. After every epoch the scheduler (if any) is stepped. The
    weights are checkpointed to ``checkpoint_path`` whenever validation
    accuracy improves. At the end the best checkpoint is loaded back into
    ``model``.

    Args:
        model: Module mapping an input batch to class logits.
        train_loader, val_loader: Iterables of ``(inputs, labels)`` batches.
        criterion: Loss function on ``(logits, labels)`` (mean reduction assumed).
        optimizer: Optimizer over the model's trainable parameters.
        num_epochs: Number of training epochs.
        device: Target device. ``None`` picks CUDA if available, else CPU.
        scheduler: Optional LR scheduler stepped once per epoch.
        checkpoint_path: Where the best weights are saved; its directory is
            created if needed.

    Returns:
        ``(model, history)``. ``history`` has per-epoch lists under
        ``train_loss``, ``val_loss``, ``train_acc`` and ``val_acc``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # AMP only applies when actually training on a CUDA device
    use_amp = torch.device(device).type == "cuda"
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # -inf guarantees the first epoch is saved. Otherwise an all-zero-accuracy
    # run would fall through to loading a stale checkpoint from an older run.
    best_val_acc = float('-inf')
    checkpoint_saved = False
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    print("Starting accelerated training...")

    for epoch in range(num_epochs):
        print(f"\nStarting Epoch {epoch+1}/{num_epochs}...")
        epoch_train_loss, epoch_train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, scaler, epoch
        )

        print(f"\nEvaluating epoch {epoch+1}...")
        epoch_val_loss, epoch_val_acc = _evaluate(model, val_loader, criterion, device, use_amp)

        history['train_loss'].append(epoch_train_loss)
        history['val_loss'].append(epoch_val_loss)
        history['train_acc'].append(epoch_train_acc)
        history['val_acc'].append(epoch_val_acc)

        if scheduler is not None:
            scheduler.step()

        print(f"Epoch {epoch+1}/{num_epochs} Summary:")
        print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"  Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f"  ✓ New best model saved with accuracy: {best_val_acc:.4f}")

    if checkpoint_saved:
        try:
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
            print("Loaded best model")
        except Exception as e:
            print(f"Error loading best model: {e}")
    else:
        print("No epochs were run; keeping the current model weights")

    return model, history


def predict(model, test_loader, device, emotion_mapping):
    """Predict an emotion label for every item yielded by ``test_loader``.

    Args:
        model: Module mapping an input batch to class logits.
        test_loader: Iterable of ``(inputs, ids)`` batches. ``ids`` may be a
            list (e.g. file names) or a tensor (default collation of numeric ids).
        device: Device the inputs are moved to.
        emotion_mapping: Label name -> class index.

    Returns:
        DataFrame with columns ``id`` and ``label`` (label names), in loader order.
    """
    reverse_mapping = {idx: name for name, idx in emotion_mapping.items()}
    model.eval()
    all_ids = []
    pred_emotions = []

    with torch.no_grad():
        for inputs, ids in tqdm(test_loader, desc="Predicting"):
            outputs = model(inputs.to(device))
            preds = outputs.argmax(dim=1).cpu().tolist()
            pred_emotions.extend(reverse_mapping[p] for p in preds)
            # Numeric ids arrive collated into a tensor; store plain Python values
            all_ids.extend(ids.tolist() if torch.is_tensor(ids) else ids)

    return pd.DataFrame({
        'id': all_ids,
        'label': pred_emotions
    })


def plot_training_history(history, title='', save_path=None):
    """Plot per-epoch loss and accuracy curves for training and validation.

    Args:
        history: Dict with equally long ``train_loss``, ``val_loss``,
            ``train_acc`` and ``val_acc`` lists, one entry per epoch (as
            returned by ``train_mixed_precision``).
        title: Prefix for both subplot titles.
        save_path: If given, the figure is also written to this path.
    """
    # 1-based so the x-axis reads as epoch numbers
    epochs = range(1, len(history['train_loss']) + 1)
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    for ax, key, metric in ((loss_ax, 'loss', 'Loss'), (acc_ax, 'acc', 'Accuracy')):
        ax.plot(epochs, history[f'train_{key}'], label='Train')
        ax.plot(epochs, history[f'val_{key}'], label='Validation')
        ax.set_title(f'{title} {metric}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric)
        ax.legend()

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path)
        print(f"Training history plot saved to {save_path}")

    plt.show()


In [None]:
# =============================================
# 6. TRAINING PIPELINES
# =============================================


def train_wav2vec_optimized(train_data, val_data, test_df, emotion_mapping,
                            num_epochs=7, batch_size=16):
    """Fine-tune the Wav2Vec2 SER model, then write test-set predictions.

    Builds 3-second raw-audio datasets and loaders, then trains with
    ``train_mixed_precision`` (AdamW + cosine LR schedule over ``num_epochs``).
    Afterwards it writes ``wav2vec_submission.csv`` and saves the training
    curves to ``wav2vec_history.png``.

    Args:
        train_data, val_data: Labelled DataFrames whose ``id`` files live in ``train/``.
        test_df: DataFrame whose ``id`` files live in ``test/``.
        emotion_mapping: Label name -> class index.
        num_epochs: Training epochs; also the cosine schedule period.
        batch_size: Batch size for every loader.

    Returns:
        ``(model, history)``; ``model`` holds the best-validation weights.
    """
    print("Creating optimized raw audio datasets...")
    raw_train_dataset = FastRawAudioEmotionDataset(train_data, base_path="train/", use_cache=True)
    raw_val_dataset = FastRawAudioEmotionDataset(val_data, base_path="train/", use_cache=True)
    # Test clips are read exactly once, so caching them would only hold memory
    raw_test_dataset = FastRawAudioEmotionDataset(test_df, base_path="test/", use_cache=False)

    # Pinned host memory only helps host->GPU copies (and warns on CPU-only)
    pin_memory = torch.device(device).type == "cuda"

    def make_loader(dataset, shuffle):
        """Build a single-process DataLoader with the shared settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=pin_memory
        )

    raw_train_loader = make_loader(raw_train_dataset, shuffle=True)
    raw_val_loader = make_loader(raw_val_dataset, shuffle=False)
    raw_test_loader = make_loader(raw_test_dataset, shuffle=False)

    print("Initializing optimized Wav2Vec2 model with multiple unfrozen layers...")
    model = Wav2Vec2ForSEROptimized(
        num_classes=len(emotion_mapping),
        pretrained_model="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition",
        num_layers_to_unfreeze=3
    )
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model has {total_params:,} parameters ({trainable_params:,} trainable)")

    # Label smoothing softens over-confident targets
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # Small LR since several pretrained encoder layers are being fine-tuned
    optimizer = optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=2.5e-5,
        weight_decay=0.02
    )

    # One cosine period over the whole run (stepped once per epoch)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=max(1, num_epochs),
        eta_min=1e-6
    )

    print(f"\nStarting enhanced Wav2Vec2 training for {num_epochs} epochs...")
    model, history = train_mixed_precision(
        model=model,
        train_loader=raw_train_loader,
        val_loader=raw_val_loader,
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=num_epochs,
        device=device,
        scheduler=scheduler
    )

    print("\nGenerating predictions with optimized Wav2Vec2 model...")
    submission_df = predict(model, raw_test_loader, device, emotion_mapping)
    submission_df.to_csv('wav2vec_submission.csv', index=False)
    print("Wav2Vec2 predictions saved to wav2vec_submission.csv")

    plot_training_history(history, 'Wav2Vec2 Model', 'wav2vec_history.png')

    return model, history

In [None]:
# Inference helpers: load a saved checkpoint and predict emotions for individual audio files

def load_model(model_path, num_classes=6):
    """Build ``Wav2Vec2ForSEROptimized`` and load trained weights from ``model_path``.

    Args:
        model_path: Path to a ``state_dict`` saved with ``torch.save``.
        num_classes: Number of output classes the checkpoint was trained with.

    Returns:
        ``(model, device)``: the model in eval mode on CUDA if available
        (else CPU), and that device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Wav2Vec2ForSEROptimized(num_classes=num_classes)

    # Load onto CPU first. Without map_location, a checkpoint saved from a GPU
    # run tries to restore its tensors onto CUDA and fails on CPU-only machines.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()

    print(f"Loaded Wav2Vec2 model from {model_path}")
    return model, device

# Predict an emotion label (and its softmax confidence) for each given audio file
def infer_emotion(model, audio_paths, base_dir=""):
    """Predict an emotion label and its confidence for each audio file.

    Args:
        model: Trained classifier (e.g. from ``load_model``) returning logits
            over the six emotion classes.
        audio_paths: Iterable of audio file paths or names.
        base_dir: Directory holding the files. When given, only the basename
            of each entry of ``audio_paths`` is used and looked up in
            ``base_dir``. When empty, each path is used as given, relative to
            the current working directory.

    Returns:
        ``(emotions, probabilities)``: predicted label names and the softmax
        probability of each prediction, in input order.
    """
    reverse_mapping = {0: "marah", 1: "jijik", 2: "takut", 3: "bahagia", 4: "netral", 5: "sedih"}
    device = next(model.parameters()).device
    audio_paths = list(audio_paths)

    if base_dir:
        file_ids = [os.path.basename(path) for path in audio_paths]
        root_dir = base_dir
    else:
        # Keep the paths intact and root them at "." so they resolve against
        # the CWD. An empty base path is unsafe: naive "<base>/<id>" joining
        # yields "/<id>", i.e. the filesystem root.
        file_ids = [str(path) for path in audio_paths]
        root_dir = "."

    example_dataset = FastRawAudioEmotionDataset(
        pd.DataFrame({'id': file_ids}),
        base_path=root_dir
    )

    example_loader = DataLoader(
        example_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0
    )

    results = []
    probabilities = []
    model.eval()
    with torch.no_grad():
        for inputs, _ in example_loader:
            outputs = model(inputs.to(device))
            probs = torch.nn.functional.softmax(outputs, dim=1)
            max_prob, preds = torch.max(probs, 1)
            results.append(reverse_mapping[preds.item()])
            probabilities.append(max_prob.item())

    return results, probabilities


In [None]:
# =============================================
# 7. MAIN EXECUTION
# =============================================

def main():
    """Run the SER pipeline with Wav2Vec2 model.

    Loads and explores the data, then trains via ``train_wav2vec_optimized``.
    That function also writes the submission CSV and draws/saves the
    training-history figure. Finally the best validation accuracy is reported.
    """
    print("Loading data...")
    train_df, test_df, train_data, val_data, emotion_mapping = load_data()

    print("Analyzing data distribution...")
    analyze_data_distribution(train_data, val_data, test_df)

    print("\n" + "="*50)
    print("TRAINING OPTIMIZED WAV2VEC2 MODEL")
    print("="*50)
    model, history = train_wav2vec_optimized(train_data, val_data, test_df, emotion_mapping)

    # The submission file train_wav2vec_optimized actually writes. It already
    # plots (and saves) the training history, so the curves are not redrawn here.
    result_file = 'wav2vec_submission.csv'
    model_name = 'Optimized Wav2Vec2'

    if history['val_acc']:
        best_acc = max(history['val_acc'])
        print(f"\nBest {model_name} Accuracy: {best_acc:.4f}")

    print(f"Speech Emotion Recognition with {model_name} completed successfully!")
    print(f"Results saved to {result_file}")


In [None]:

# Entry point: executing this cell (or the exported script) runs the full
# training pipeline defined in main().
if __name__ == "__main__":
    main()

In [None]:
# Example usage: run the saved model on a few training/test clips and plot their waveforms
def _pick_demo_files(csv_path, audio_dir, kind, n=5):
    """Choose up to ``n`` audio file names (relative to ``audio_dir``) for the demo.

    With a ``label`` column, one random clip per distinct label is taken first
    (up to ``n`` labels), topped up with other random clips. Without labels,
    up to ``n`` random rows are taken. If the CSV cannot be used, the first
    ``n`` ``*.wav`` files in ``audio_dir`` are returned instead.
    """
    from glob import glob

    try:
        df = pd.read_csv(csv_path)
        if 'label' in df.columns:
            picked = [
                str(df[df['label'] == emotion].sample(1)['id'].values[0])
                for emotion in df['label'].unique()[:n]
            ]
            # Top up with clips that were not already chosen
            rest = df[~df['id'].astype(str).isin(picked)]
            extra = rest.sample(min(n - len(picked), len(rest)))
            picked.extend(str(file_id) for file_id in extra['id'].values)
            return picked[:n]
        return [str(file_id) for file_id in df.sample(min(n, len(df)))['id'].values]
    except Exception as e:
        print(f"Error loading {audio_dir} files from CSV: {e}")

    # glob yields "<audio_dir>/<name>"; keep only the name because callers
    # join it with audio_dir again.
    files = [os.path.basename(p) for p in sorted(glob(os.path.join(audio_dir, '*.wav')))[:n]]
    if not files:
        print(f"Could not find {kind} files automatically.")
    return files


def _existing_files(files, audio_dir, kind):
    """Return the names in ``files`` that exist under ``audio_dir``; warn about the rest."""
    found = []
    for name in files:
        path = os.path.join(audio_dir, name)
        if os.path.exists(path):
            found.append(name)
        else:
            print(f"Warning: {kind} file not found: {path}")
    return found


def _run_demo_inference(model, files, audio_dir, kind):
    """Run ``infer_emotion`` on ``files`` and print one result line per file."""
    if not files:
        return [], []
    print(f"\nProcessing {kind.lower()} files...")
    emotions, probs = infer_emotion(model, files, audio_dir)

    print(f"\n{kind} File Inference Results:")
    for name, emotion, prob in zip(files, emotions, probs):
        print(f"File: {name} -> Emotion: {emotion} (Confidence: {prob:.2f})")
    return emotions, probs


def _plot_demo_waveforms(entries):
    """Plot one waveform per entry, stacked vertically in a single figure.

    Args:
        entries: List of ``(tag, audio_dir, name, emotion, prob)`` tuples.
    """
    if not entries:
        return
    fig = plt.figure(figsize=(15, 10))
    for pos, (tag, audio_dir, name, emotion, prob) in enumerate(entries, start=1):
        ax = fig.add_subplot(len(entries), 1, pos)
        try:
            waveform, _ = torchaudio.load(os.path.join(audio_dir, name))
            ax.plot(waveform[0].numpy())
            ax.set_title(f"{tag}: {name} - {emotion} (Conf: {prob:.2f})")
            ax.set_ylim([-1, 1])  # Standardize y-axis across panels
        except Exception as e:
            # Axes coordinates keep the message centred even if data was drawn
            ax.text(0.5, 0.5, f"Error loading audio: {e}",
                    horizontalalignment='center', verticalalignment='center',
                    transform=ax.transAxes)
            ax.axis('off')
    fig.tight_layout()
    plt.show()


if __name__ == "__main__":
    # Best checkpoint written by train_mixed_precision
    model_path = 'checkpoints/best_wav2vec.pth'

    try:
        model, device = load_model(model_path)

        train_files = _pick_demo_files('train.csv', 'train', 'training')
        test_files = _pick_demo_files('test.csv', 'test', 'test')

        print(f"\nSelected Training Files:")
        for file in train_files:
            print(f"- {file}")

        print(f"\nSelected Test Files:")
        for file in test_files:
            print(f"- {file}")

        valid_train_files = _existing_files(train_files, "train", "Training")
        valid_test_files = _existing_files(test_files, "test", "Test")

        if not valid_train_files and not valid_test_files:
            raise FileNotFoundError("No valid audio files found to process")

        train_emotions, train_probs = _run_demo_inference(model, valid_train_files, "train", "Training")
        test_emotions, test_probs = _run_demo_inference(model, valid_test_files, "test", "Test")

        _plot_demo_waveforms(
            [("TRAIN", "train", f, e, p) for f, e, p in zip(valid_train_files, train_emotions, train_probs)]
            + [("TEST", "test", f, e, p) for f, e, p in zip(valid_test_files, test_emotions, test_probs)]
        )

    except Exception as e:
        print(f"Error during inference: {e}")
        traceback.print_exc()