# Simple Real-Time Mridangam Inference

A streamlined real-time inference pipeline for mridangam stroke detection.

## Features:
- Real-time audio capture
- Simple onset detection
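- CNN model inference using trained weights from a `.pth` checkpoint (e.g. `modelOutput.pth`; the loading cell below uses `mridangam_model_final.pth`)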
- Streaming results with yield
- Console output of detected strokes

In [11]:
# CUDA Setup and Package Installation

# Probe the installed PyTorch build for GPU support before anything else runs.
import torch

print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    # CPU fallback: the help text is emitted with a single print(); the lines
    # are newline-joined so stdout is the same as one print() per line.
    print("\n".join((
        "⚠️  CUDA not available. Will use CPU.",
        "To enable CUDA, install PyTorch with CUDA support:",
        "",
        "For CUDA 11.8:",
        "!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118",
        "",
        "For CUDA 12.1:",
        "!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121",
        "",
        "Then restart the kernel and run this cell again.",
    )))
else:
    # GPU path: report the CUDA toolkit version of this build and the device name.
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU name: {torch.cuda.get_device_name()}")

# Install other required packages if needed
# Uncomment the lines below if you need to install packages:
# !pip install librosa pyaudio numpy matplotlib

CUDA available: False
⚠️  CUDA not available. Will use CPU.
To enable CUDA, install PyTorch with CUDA support:

For CUDA 11.8:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

For CUDA 12.1:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

Then restart the kernel and run this cell again.


In [12]:
import torch
import torch.nn as nn
import librosa
import numpy as np
import pyaudio
import threading
import queue
import time
from typing import Optional, Tuple
import warnings

# Suppress warnings for cleaner output
# NOTE: called without a category or module, this installs a blanket "ignore"
# filter, so every warning (including deprecation notices raised by the
# libraries imported above) stays hidden until the filters are changed.
warnings.filterwarnings('ignore')

# Report the library versions and GPU availability seen by this session.
print("✅ Libraries imported successfully!")
print(f"🔧 PyTorch version: {torch.__version__}")
print(f"🎵 Librosa version: {librosa.__version__}")
print(f"🎤 CUDA available: {torch.cuda.is_available()}")

✅ Libraries imported successfully!
🔧 PyTorch version: 2.7.1+cpu
🎵 Librosa version: 0.11.0
🎤 CUDA available: False


In [13]:
# CNN Model Architecture (same as training)
class MridangamCNN(nn.Module):
    """CNN mapping a single-channel spectrogram patch to stroke-class logits.

    ``features`` stacks three Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d ->
    Dropout stages and then averages dim 2 (height) down to size 1.
    ``classifier`` runs a Conv1d along the remaining axis, pools it globally
    and projects to ``num_classes`` logits.

    The position of every layer inside both ``nn.Sequential`` containers is
    significant: ``state_dict`` keys are derived from it (``features.0.weight``,
    ``features.5.weight``, ...), so saved weights only load while the order
    stays exactly as it is.

    Args:
        n_mels: accepted by the constructor but not used when building layers.
        num_classes: length of the output logit vector.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels, dropout):
        """Return the five layers of one convolutional stage, in order."""
        return [
            # 3x3 convolution with padding 1 keeps the spatial size.
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=(3, 3), padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            # 2x2 max-pool with stride 2 halves both spatial dimensions.
            nn.MaxPool2d(kernel_size=(2, 2), stride=2, padding=0),
            nn.Dropout(p=dropout),
        ]

    def __init__(self, n_mels = 128, num_classes=10):
        super().__init__()

        stage = self._conv_stage
        # Shapes below are per sample for a (1, 128, 128) input.
        self.features = nn.Sequential(
            *stage(1, 32, 0.3),    # -> (32, 64, 64)
            *stage(32, 64, 0.3),   # -> (64, 32, 32)
            *stage(64, 64, 0.4),   # -> (64, 16, 16)
            # Average dim 2 down to 1 and leave the last axis untouched.
            nn.AdaptiveAvgPool2d((1, None)),  # -> (64, 1, 16)
        )

        # Operates on the tensor after forward() has squeezed dim 2: (64, 16).
        self.classifier = nn.Sequential(
            nn.Conv1d(in_channels=64, out_channels=64, kernel_size=3, padding=1),  # -> (64, 16)
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool1d(1),  # -> (64, 1)
            nn.Dropout(p=0.4),
            nn.Flatten(),  # -> (64,)
            nn.Linear(64, num_classes),  # -> (num_classes,)
        )

    def forward(self, x):
        """Return raw (pre-softmax) logits of shape (batch, num_classes)."""
        pooled = self.features(x)
        # Drop the singleton height axis so Conv1d sees (batch, channels, length).
        sequence = pooled.squeeze(2)
        return self.classifier(sequence)


In [26]:
# Load the trained model
def load_model(model_path: str, device: torch.device) -> MridangamCNN:
    """Build a MridangamCNN, load its weights from ``model_path`` and return it.

    Args:
        model_path: path of the saved ``state_dict`` file.
        device: device the weights are mapped to and the model is moved to.

    Returns:
        The model on ``device``, switched to evaluation mode.

    Raises:
        Exception: any error raised while building or loading is printed and
            then re-raised unchanged.
    """
    try:
        # Architecture arguments must match the checkpoint being loaded.
        network = MridangamCNN(n_mels=128, num_classes=10)

        # map_location lets a checkpoint saved on another device load here.
        saved_weights = torch.load(model_path, map_location=device)
        network.load_state_dict(saved_weights)
        network.to(device)
        # Evaluation mode: dropout is disabled and batch-norm uses running stats.
        network.eval()

        print(f"✅ Model loaded successfully from {model_path}")
        print(f"🔧 Using device: {device}")

        return network

    except Exception as load_error:
        print(f"❌ Error loading model: {load_error}")
        raise

# Setup device and load model
# Prefer the GPU when this PyTorch build can see one, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE: absolute, machine-specific path - it must be edited before this cell
# can run on another machine.
model_path = r"C:\Users\aniru\Desktop\MridangamTranscription\inference\model_weights\mridangam_model_final.pth"  # Adjust path as needed

# Load model
# `model`, `device` and `class_names` are module-level globals that
# predict_stroke() reads directly, so this cell has to run before any inference.
model = load_model(model_path, device)

# Default class names (update these based on your training data)
# predict_stroke() indexes this list with the argmax of the model output, so
# entry i must be the label of logit i; the 10 entries match num_classes=10
# used in load_model().
class_names = ['bheem','cha', 'dheem', 'dhin','num', 'ta', 'tha', 'tham', 'thi', 'thom']
print(f"📊 Classes: {', '.join(class_names)}")

✅ Model loaded successfully from C:\Users\aniru\Desktop\MridangamTranscription\inference\model_weights\mridangam_model_final.pth
🔧 Using device: cpu
📊 Classes: bheem, cha, dheem, dhin, num, ta, tha, tham, thi, thom


In [27]:
# Audio processing functions
def detect_onset(audio_buffer: np.ndarray, sr: float) -> Optional[float]:
    """Return the time of the most recent strong onset in ``audio_buffer``.

    An onset-strength envelope is computed with librosa, peaks are picked with
    strict thresholds, and the last peak is accepted only when its envelope
    value is more than twice the envelope mean.

    Args:
        audio_buffer: mono audio samples to analyse.
        sr: sample rate of ``audio_buffer`` in Hz.

    Returns:
        Onset time in seconds measured from the start of ``audio_buffer``, or
        ``None`` when the buffer is shorter than 150 ms, no onset is found,
        the last onset is not strong enough, or detection raised an error
        (the error is printed, not propagated).
    """
    if len(audio_buffer) < sr * 0.15:  # Need at least 150ms of audio
        return None

    # One hop size shared by the envelope, the peak picker and the
    # time -> frame conversion below, so the three cannot drift apart.
    hop_length = 512

    try:
        # Use spectral flux for onset detection with stricter parameters
        onset_envelope = librosa.onset.onset_strength(
            y=audio_buffer,
            sr=sr,
            hop_length=hop_length,
            aggregate=np.median,  # Use median for more stable detection
            fmax=8000,  # Focus on relevant frequency range for mridangam
            n_mels=128
        )

        # Detect onsets with stricter thresholds
        onsets = librosa.onset.onset_detect(
            onset_envelope=onset_envelope,
            sr=sr,
            hop_length=hop_length,
            units='time',
            pre_max=5,   # Increased for clearer peaks
            post_max=5,
            pre_avg=5,
            post_avg=5,
            delta=0.3,   # Higher threshold for clearer onsets
            wait=10      # Longer wait between onsets
        )

        if len(onsets) == 0:
            return None

        # Get the most recent onset
        latest_onset = onsets[-1]

        # The time was derived from an integer frame index; converting it back
        # through floating point can yield e.g. 6.999..., which a plain int()
        # would floor to the previous frame and read the wrong envelope value.
        onset_idx = int(round(latest_onset * sr / hop_length))
        if onset_idx >= len(onset_envelope):
            return None

        # Only return onset if it's significantly stronger than average
        onset_strength = onset_envelope[onset_idx]
        avg_strength = np.mean(onset_envelope)
        if onset_strength > avg_strength * 2.0:  # Must be 2x stronger than average
            return latest_onset

    except Exception as e:
        # Best effort: a failed analysis is reported and treated as "no onset".
        print(f"Onset detection error: {e}")

    return None


def extract_mel_spectrogram(audio: np.ndarray, sr: float, n_mels=128, target_length=128) -> np.ndarray:
    """Compute a dB-scaled mel spectrogram with exactly ``target_length`` frames.

    Args:
        audio: audio samples of one segment.
        sr: sample rate of ``audio`` in Hz.
        n_mels: number of mel bands (rows of the result).
        target_length: number of time frames (columns of the result).

    Returns:
        Array of shape ``(n_mels, target_length)``; shorter inputs are padded
        on the right with zeros, longer ones are cut. On any error the message
        is printed and ``None`` is returned instead of an array.
    """
    try:
        power_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=sr,
            n_mels=n_mels,
            n_fft=2048,
            win_length=1024,
            hop_length=512,
        )

        # dB relative to the loudest bin, so the maximum value is 0 dB.
        spec_db = librosa.power_to_db(power_spec, ref=np.max)

        frame_count = spec_db.shape[1]
        if frame_count >= target_length:
            return spec_db[:, :target_length]

        # NOTE(review): the padding value 0 equals the 0 dB reference level
        # (the loudest bin), not silence - presumably this mirrors the
        # training preprocessing; confirm before changing.
        missing_frames = target_length - frame_count
        return np.pad(spec_db, ((0, 0), (0, missing_frames)), mode='constant')

    except Exception as e:
        print(f"Mel spectrogram extraction error: {e}")
        return None


def validate_audio_segment(audio_segment: np.ndarray, sr: float) -> bool:
    """Tell whether a segment is long and loud enough to be worth classifying.

    Args:
        audio_segment: audio samples of the candidate segment.
        sr: sample rate of ``audio_segment`` in Hz.

    Returns:
        ``False`` when the segment is shorter than 50 ms, its RMS is below
        0.01, or its absolute peak is below 0.05; ``True`` otherwise.
    """
    # Duration gate: fewer samples than 50 ms worth is rejected outright.
    if len(audio_segment) < 0.05 * sr:
        return False

    # Energy gate: root-mean-square level over the whole segment.
    rms_level = np.sqrt(np.mean(audio_segment ** 2))
    if rms_level < 0.01:
        return False

    # Amplitude gate: the largest absolute sample must reach 0.05.
    peak_level = np.max(np.abs(audio_segment))
    return not (peak_level < 0.05)


def predict_stroke(audio_segment: np.ndarray, sample_rate: int) -> Tuple[str, float]:
    """Classify one audio segment with the module-level ``model``.

    Reads the globals ``model``, ``device`` and ``class_names``.

    Args:
        audio_segment: audio samples around a detected onset.
        sample_rate: sample rate of ``audio_segment`` in Hz.

    Returns:
        ``(label, confidence)`` where ``confidence`` is the softmax probability
        of the chosen class. Instead of a class label the first element is
        ``"Weak_Signal"`` when validation rejects the segment, ``"Unknown"``
        when no spectrogram could be computed, or ``"Error"`` when anything
        raised; in those cases the confidence is ``0.0``.
    """
    try:
        # Gate 1: skip segments that are too short or too quiet.
        if not validate_audio_segment(audio_segment, sample_rate):
            return "Weak_Signal", 0.0

        # Gate 2: feature extraction reports failure by returning None.
        features = extract_mel_spectrogram(audio_segment, sample_rate)
        if features is None:
            return "Unknown", 0.0

        # Add batch and channel axes: (n_mels, time) -> (1, 1, n_mels, time).
        batch = torch.FloatTensor(features)[None, None].to(device)

        # Inference only, so no gradient bookkeeping is needed.
        with torch.no_grad():
            class_probs = torch.softmax(model(batch), dim=1)
            best_index = torch.argmax(class_probs, dim=1).item()
            best_prob = class_probs[0, best_index].item()

        return class_names[best_index], best_prob

    except Exception as e:
        print(f"Prediction error: {e}")
        return "Error", 0.0

# Cell-completion marker: it is only printed when every definition above ran
# without raising.
print("✅ Enhanced audio processing functions defined!")

✅ Enhanced audio processing functions defined!


In [28]:
# Simple audio listener class
class SimpleAudioListener:
    """Capture audio on a background thread and queue it as float32 chunks.

    Every chunk read from the PyAudio input stream is converted from int16 to
    float32 in [-1.0, 1.0) and pushed onto ``audio_queue``; consumers poll the
    queue with get_audio_chunk().

    Args:
        sample_rate: capture rate in Hz passed to PyAudio.
        chunk_size: frames read per blocking call, i.e. per queued chunk.
        channels: number of input channels requested from PyAudio. The buffer
            is not reshaped, so with more than one channel the samples of a
            chunk are presumably interleaved - verify before using > 1.
    """

    def __init__(self, sample_rate=22050, chunk_size=1024, channels=1):
        self.sample_rate = sample_rate
        self.chunk_size = chunk_size
        self.channels = channels
        self.audio_queue = queue.Queue()
        self.is_recording = False
        # Kept on the instance so stop_recording() can release them explicitly.
        self.stream = None
        self.recording_thread = None

        # Initialize PyAudio
        self.p = pyaudio.PyAudio()

    def start_recording(self):
        """Open the input stream and start the background reader thread.

        Raises:
            Exception: whatever ``PyAudio.open`` raises (for example when no
                input device is available); ``is_recording`` then stays False.
        """
        stream = self.p.open(
            format=pyaudio.paInt16,
            channels=self.channels,
            rate=self.sample_rate,
            input=True,
            frames_per_buffer=self.chunk_size
        )
        self.stream = stream
        # Only flag recording once the stream really exists, so a failed open
        # does not leave the listener claiming to record.
        self.is_recording = True

        def record_audio():
            while self.is_recording:
                try:
                    # Overflowed input is dropped silently instead of raising.
                    data = stream.read(self.chunk_size, exception_on_overflow=False)
                    # int16 -> float32 scaled by 2**15.
                    audio_chunk = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
                    self.audio_queue.put(audio_chunk)
                except Exception as e:
                    print(f"Audio recording error: {e}")
                    break

        self.recording_thread = threading.Thread(target=record_audio)
        # Daemon thread: it must not keep the interpreter alive on exit.
        self.recording_thread.daemon = True
        self.recording_thread.start()

        print(f"🎤 Started recording at {self.sample_rate}Hz...")

    def stop_recording(self):
        """Stop the reader thread, close the stream and terminate PyAudio.

        The PyAudio instance is terminated here, so create a new listener to
        record again.
        """
        self.is_recording = False

        # getattr keeps this safe on instances whose __init__ did not run.
        reader = getattr(self, 'recording_thread', None)
        if reader is not None:
            # Wait for the in-flight read to finish before touching the stream.
            reader.join()

        stream = getattr(self, 'stream', None)
        if stream is not None:
            try:
                stream.stop_stream()
                stream.close()
            except Exception as e:
                # Best effort: a failing close must not prevent terminate().
                print(f"Audio stream close error: {e}")
            self.stream = None

        self.p.terminate()
        print("🛑 Stopped recording")

    def get_audio_chunk(self, timeout=0.1):
        """Return the next queued chunk, or ``None`` if none arrives in ``timeout`` seconds."""
        try:
            return self.audio_queue.get(timeout=timeout)
        except queue.Empty:
            return None

# Cell-completion marker: it is only printed when the class definition above
# ran without raising.
print("✅ Audio listener class defined!")

✅ Audio listener class defined!


In [29]:
# Main real-time detection function
def run_realtime_detection(duration_seconds=30):
    """Listen on the audio input and print each confidently classified stroke.

    Chunks from a SimpleAudioListener are appended to a rolling buffer. Each
    time new audio arrives (and the cooldown since the last report is over)
    the most recent second is scanned with detect_onset(). The segment from
    50 ms before to 200 ms after the onset is passed to predict_stroke() and
    printed when its confidence is above 0.6.

    Args:
        duration_seconds: how long to listen before stopping.

    Returns:
        None. Results are printed. Ctrl+C stops early, and the listener is
        stopped in every case once recording has started.
    """

    # Audio settings
    sample_rate = 22050
    buffer_duration = 3.0  # Increased buffer for better onset detection
    buffer_size = int(sample_rate * buffer_duration)
    audio_buffer = np.zeros(buffer_size)
    window_size = int(1.0 * sample_rate)  # Onsets are searched in the last second

    # Initialize audio listener
    audio_listener = SimpleAudioListener(sample_rate=sample_rate)

    # Detection parameters - stricter settings
    last_detection_time = 0
    detection_cooldown = 0.3  # Longer cooldown to avoid rapid false detections
    confidence_threshold = 0.6  # Higher confidence threshold
    rejected_labels = ('Weak_Signal', 'Unknown', 'Error')

    # Segment extent around an onset
    pre_onset = int(0.05 * sample_rate)   # 50ms before
    post_onset = int(0.2 * sample_rate)   # 200ms after

    # An onset stays inside the sliding window for many iterations. Its
    # absolute sample position in the stream is remembered so that the same
    # stroke is reported once only. detect_onset() picks peaks with wait=10
    # frames (about 232 ms), so a 100 ms tolerance cannot merge two distinct
    # onsets.
    samples_received = 0
    last_reported_onset = None
    same_onset_tolerance = int(0.1 * sample_rate)

    print(f"🎤 Starting real-time detection for {duration_seconds} seconds...")
    print("🎵 Listening for CLEAR mridangam strokes only...")
    print("💡 Tip: Play mridangam strokes clearly and distinctly")
    print("=" * 60)

    audio_listener.start_recording()
    start_time = time.time()
    detection_count = 0

    try:
        while (time.time() - start_time) < duration_seconds:
            # Blocks for up to 50 ms, so the loop never busy-waits.
            chunk = audio_listener.get_audio_chunk(timeout=0.05)
            if chunk is None:
                # Nothing new arrived: re-analysing the same audio is pointless.
                continue

            # Update circular buffer: shift left and write the chunk at the end
            audio_buffer = np.roll(audio_buffer, -len(chunk))
            audio_buffer[-len(chunk):] = chunk
            samples_received += len(chunk)

            current_time = time.time()
            if current_time - last_detection_time <= detection_cooldown:
                continue

            # Look for onset in recent audio (last 1 second for better analysis)
            recent_audio = audio_buffer[-window_size:]
            onset_time = detect_onset(recent_audio, sample_rate)
            if onset_time is None:
                continue

            # detect_onset() measures time from the START of recent_audio, so
            # the onset sample index is used directly (not mirrored from the end).
            onset_sample = int(onset_time * sample_rate)

            # Position of this onset in the whole stream, for de-duplication.
            onset_position = samples_received - len(recent_audio) + onset_sample
            if (last_reported_onset is not None
                    and abs(onset_position - last_reported_onset) <= same_onset_tolerance):
                continue

            start_idx = max(0, onset_sample - pre_onset)
            end_idx = min(len(recent_audio), onset_sample + post_onset)
            stroke_segment = recent_audio[start_idx:end_idx]
            if len(stroke_segment) == 0:
                continue

            # Predict stroke
            stroke_name, confidence = predict_stroke(stroke_segment, sample_rate)

            # Only report high-confidence predictions from clear onsets.
            # Rejected or low-confidence results are dropped silently and the
            # onset is retried when more audio has arrived.
            if confidence <= confidence_threshold or stroke_name in rejected_labels:
                continue

            timestamp = current_time
            last_detection_time = current_time
            last_reported_onset = onset_position
            detection_count += 1

            # Format output
            time_str = time.strftime("%H:%M:%S", time.localtime(timestamp))
            elapsed = timestamp - start_time

            # Color coding based on confidence
            confidence_bar = "🟢" if confidence > 0.8 else "🟡" if confidence > 0.7 else "🔴"

            print(f"{time_str} | +{elapsed:5.1f}s | {confidence_bar} {stroke_name:>8} | {confidence:.1%} | #{detection_count}")

    except KeyboardInterrupt:
        print("\n🛑 Detection stopped by user")
    finally:
        audio_listener.stop_recording()
        print(f"✅ Detection completed - Found {detection_count} clear strokes")

# Cell-completion marker: it is only printed when the function definition
# above ran without raising.
print("✅ Enhanced real-time detection function ready!")

✅ Enhanced real-time detection function ready!


In [32]:
# Run the real-time detection
# Adjust duration as needed (in seconds)
# NOTE: this call blocks the cell until the duration has elapsed and opens an
# audio input stream; interrupt the kernel to stop early.
run_realtime_detection(duration_seconds=30)

🎤 Starting real-time detection for 30 seconds...
🎵 Listening for CLEAR mridangam strokes only...
💡 Tip: Play mridangam strokes clearly and distinctly
🎤 Started recording at 22050Hz...
11:13:47 | +  2.6s | 🟢       ta | 99.4% | #1
11:13:48 | +  3.8s | 🟢    dheem | 100.0% | #2
11:13:49 | +  5.1s | 🟢    dheem | 100.0% | #3
11:13:51 | +  6.4s | 🟢    dheem | 100.0% | #4
11:13:52 | +  7.7s | 🟢       ta | 98.1% | #5
11:13:53 | +  9.0s | 🟢    dheem | 100.0% | #6
11:13:55 | + 10.3s | 🟢    dheem | 98.7% | #7
11:13:56 | + 11.9s | 🟢       ta | 98.7% | #8
11:13:56 | + 12.2s | 🟢       ta | 93.5% | #9
11:13:57 | + 12.9s | 🟢      tha | 95.9% | #10
11:13:58 | + 13.3s | 🟢    dheem | 100.0% | #11
11:13:58 | + 13.6s | 🟢       ta | 99.7% | #12
11:13:58 | + 14.3s | 🟡     thom | 73.5% | #13
11:13:59 | + 14.6s | 🔴       ta | 62.6% | #14
11:13:59 | + 15.0s | 🟢       ta | 93.2% | #15
11:14:00 | + 16.0s | 🟢    dheem | 81.4% | #16
11:14:01 | + 16.9s | 🟢    dheem | 100.0% | #17
11:14:01 | + 17.2s | 🟢     thom | 100

In [None]:


# Enhanced Audio File Inference
def analyze_audio_file(audio_file_path, show_details=True, confidence_threshold=0.6):
    """Detect and classify mridangam strokes in an audio file.

    The file is loaded at 22050 Hz and onsets are picked from its
    onset-strength envelope. Every onset whose envelope value is more than
    twice the envelope mean is cut out (50 ms before to 200 ms after) and
    classified with predict_stroke(). A summary is printed at the end.

    Args:
        audio_file_path: path of the audio file to analyse.
        show_details: print one line per accepted stroke when True.
        confidence_threshold: minimum softmax confidence for a prediction to
            be kept (exclusive).

    Returns:
        List of dicts with keys ``'time'`` (seconds), ``'stroke'``,
        ``'confidence'`` and ``'onset_strength'`` (envelope value relative to
        the envelope mean), in onset order. An empty list is returned when
        nothing qualifies or when any error occurs (the error is printed).
    """
    try:
        # Load audio file
        audio, sr = librosa.load(audio_file_path, sr=22050)
        print(f"📁 Loaded: {audio_file_path}")
        print(f"🎵 Duration: {len(audio)/sr:.2f} seconds")
        print(f"🔊 Sample rate: {sr} Hz")
        print("=" * 60)

        # One hop size shared by the envelope, the peak picker and the
        # time -> frame conversion below, so the three cannot drift apart.
        hop_length = 512

        # Enhanced onset detection (same as real-time)
        onset_envelope = librosa.onset.onset_strength(
            y=audio,
            sr=sr,
            hop_length=hop_length,
            aggregate=np.median,  # Use median for more stable detection
            fmax=8000,  # Focus on relevant frequency range for mridangam
            n_mels=128
        )

        # Detect onsets with stricter thresholds
        onsets = librosa.onset.onset_detect(
            onset_envelope=onset_envelope,
            sr=sr,
            hop_length=hop_length,
            units='time',
            pre_max=5,   # Increased for clearer peaks
            post_max=5,
            pre_avg=5,
            post_avg=5,
            delta=0.3,   # Higher threshold for clearer onsets
            wait=10      # Longer wait between onsets
        )

        print(f"🎯 Found {len(onsets)} potential onsets")

        valid_predictions = []
        rejected_labels = ('Weak_Signal', 'Unknown', 'Error')

        # Loop invariants: the envelope mean and the segment extent.
        avg_strength = np.mean(onset_envelope)
        pre_onset = int(0.05 * sr)   # 50ms before
        post_onset = int(0.2 * sr)   # 200ms after

        # Process each onset with validation
        for i, onset_time in enumerate(onsets):
            # The time was derived from an integer frame index; converting it
            # back through floating point can yield e.g. 6.999..., which a
            # plain int() would floor to the previous frame.
            onset_idx = int(round(onset_time * sr / hop_length))
            if onset_idx >= len(onset_envelope):
                continue

            # Only process if it's a strong onset
            onset_strength = onset_envelope[onset_idx]
            if not onset_strength > avg_strength * 2.0:
                continue
            relative_strength = onset_strength / avg_strength

            # Extract segment around onset, clipped to the file boundaries
            onset_sample = int(onset_time * sr)
            start_idx = max(0, onset_sample - pre_onset)
            end_idx = min(len(audio), onset_sample + post_onset)
            segment = audio[start_idx:end_idx]
            if len(segment) == 0:
                continue

            stroke_name, confidence = predict_stroke(segment, sr)

            # Only include high-confidence predictions from validated onsets
            if confidence <= confidence_threshold or stroke_name in rejected_labels:
                continue

            valid_predictions.append({
                'time': onset_time,
                'stroke': stroke_name,
                'confidence': confidence,
                'onset_strength': relative_strength
            })

            if show_details:
                # Color coding based on confidence
                confidence_bar = "🟢" if confidence > 0.8 else "🟡" if confidence > 0.7 else "🔴"
                strength_indicator = "⚡" if onset_strength > avg_strength * 3.0 else "🔥"

                # The running number is the index among ALL detected onsets.
                print(f"  {i+1:2d}. {onset_time:6.2f}s | {confidence_bar} {stroke_name:>8} | {confidence:.1%} | {strength_indicator} {relative_strength:.1f}x")

        # Summary
        print("\n" + "=" * 60)
        print("📊 ANALYSIS SUMMARY")
        print("=" * 60)
        print(f"Total onsets detected: {len(onsets)}")
        print(f"Valid stroke predictions: {len(valid_predictions)}")
        print(f"Confidence threshold: {confidence_threshold:.1%}")

        if valid_predictions:
            # Count strokes by type
            stroke_counts = {}
            for pred in valid_predictions:
                stroke_counts[pred['stroke']] = stroke_counts.get(pred['stroke'], 0) + 1

            avg_confidence = sum(pred['confidence'] for pred in valid_predictions) / len(valid_predictions)

            print(f"Average confidence: {avg_confidence:.1%}")
            print("\nDetected strokes:")
            for stroke, count in sorted(stroke_counts.items()):
                print(f"  {stroke}: {count} times")

            # Timeline of strokes (first ten only, to keep the line readable)
            print("\nStroke timeline:")
            timeline = ' → '.join([f"{p['stroke']}({p['time']:.1f}s)" for p in valid_predictions[:10]])
            if len(valid_predictions) > 10:
                timeline += f" ... (+{len(valid_predictions)-10} more)"
            print(f"  {timeline}")
        else:
            print("❌ No clear mridangam strokes detected above threshold")
            print("💡 Try:")
            print("   - Lowering confidence_threshold (e.g., 0.4)")
            print("   - Using audio with clearer mridangam strokes")
            print("   - Checking if the audio contains mridangam sounds")

        return valid_predictions

    except Exception as e:
        # Best effort: any failure is reported and yields an empty result.
        print(f"❌ Error analyzing file: {e}")
        return []

# Cell-completion marker: it is only printed when the function definition
# above ran without raising.
print("✅ Enhanced audio file analysis function ready!")

In [None]:
# Analyze your audio file
# Replace the path below with your audio file path

# Example usage:
audio_file_path = r"C:\path\to\your\audio\file.wav"  # Change this path

# Uncomment and modify the path to analyze your audio file:
# results = analyze_audio_file(audio_file_path, show_details=True, confidence_threshold=0.6)

# You can also adjust the confidence threshold:
# results = analyze_audio_file(audio_file_path, show_details=True, confidence_threshold=0.4)  # Lower threshold
# results = analyze_audio_file(audio_file_path, show_details=True, confidence_threshold=0.8)  # Higher threshold

# Usage instructions, emitted with a single print(); the lines are
# newline-joined so stdout is the same as one print() per line.
print("\n".join((
    "💡 Instructions:",
    "1. Update the 'audio_file_path' variable with your audio file path",
    "2. Uncomment one of the 'analyze_audio_file' calls",
    "3. Run this cell to analyze your audio file",
    "",
    "Supported formats: .wav, .mp3, .flac, .m4a, .ogg",
    "Recommended: .wav files with clear mridangam recordings",
)))