In [None]:
# Install Dependencies

# ffmpeg used under the hood for speed advantage:
!apt-get update && apt-get install -y ffmpeg # -y libsndfile1
# main audio processing packages:
!pip install torchaudio librosa soundfile pydub
# api setup and storage:
!pip install flask flask-cors pyngrok firebase-admin

In [None]:
# Cell 2: Imports
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from pyngrok import ngrok
import firebase_admin
from firebase_admin import credentials, storage

from google.colab import userdata
import tempfile
import json
import os
import io
import logging
import numpy as np
import math
import time
from pathlib import Path
import IPython.display as idp
from IPython.core.magic import register_cell_magic
import mimetypes
import matplotlib.pyplot as plt

import soundfile as sf
import pydub
import torchaudio
import torch
import torchaudio.transforms as T
import torchaudio.functional as F
import torch.nn.functional as FNN
import librosa

# Cell magic to disable a cell: start the cell with `%%skip` and its body is
# never executed.
@register_cell_magic
def skip(line, cell):
    """Discard the cell body without running it (``line`` and ``cell`` are ignored)."""
    return None

# Pick the compute device once; later cells move tensors onto `device`.
_has_cuda = torch.cuda.is_available()
if _has_cuda:
    print(f"GPU available: {torch.cuda.get_device_name(torch.cuda.current_device())}")

device = torch.device("cuda" if _has_cuda else "cpu")
print(f"Device: {device}")

In [None]:
# Cell 3: Firebase Setup
fire_cred = userdata.get('FB_CRED')  # service-account JSON stored as a Colab secret

# Only initialize if no apps exist (re-running this cell must not re-initialize)
if not firebase_admin._apps:
    # credentials.Certificate accepts the parsed service-account dict directly,
    # so the private key never has to be written to disk (a temp file would be
    # left behind if initialization raised before it was removed).
    cred = credentials.Certificate(json.loads(fire_cred))
    firebase_admin.initialize_app(cred, {
        'storageBucket': 'sample-prep-dbd20.firebasestorage.app'
    })

bucket = storage.bucket()

In [None]:
# Utility functions

def get_pitch_factor(original_pitch, target_pitch):
    """
    Calculate the factor needed to transpose from original pitch to target pitch.

    Returns ``target_pitch / original_pitch``, a positive factor where:
    - factor > 1 means transpose up
    - factor < 1 means transpose down

    Raises:
        ValueError: if either pitch is missing (None), non-positive or NaN.
            ``None`` is what get_main_pitch returns when no pitch was found.
    """
    if original_pitch is None:
        raise ValueError("Original pitch is missing (no pitch was detected)")
    # `not x > 0` also rejects NaN, which `x <= 0` would let through.
    if not original_pitch > 0:
        raise ValueError("Original pitch must be positive")
    if target_pitch is None or not target_pitch > 0:
        raise ValueError("Target pitch must be positive")
    return target_pitch / original_pitch

# Guess (and log) the MIME type of an audio file from its extension
def get_mime_type(path):
    """
    Print and return ``mimetypes.guess_type(path)``.

    The result is the ``(type, encoding)`` tuple; either element is ``None``
    when the extension is not recognised.
    """
    guess = mimetypes.guess_type(path)
    print(f"{path}: Mime type: {guess}")
    return guess

In [None]:
# Process 1: Normalize
def normalize(waveform):
    """
    Peak-normalize ``waveform`` so its largest absolute sample becomes 1.

    A silent (all-zero) waveform is returned as-is, avoiding a division by
    zero. Input and output peak levels are printed for debugging.
    """
    peak = waveform.abs().max()
    print(f"[NORMALIZE] Input peak amplitude: {peak}")

    if not peak > 0:
        return waveform

    scaled = waveform / peak
    print(f"[NORMALIZE] Output peak amplitude: {scaled.abs().max()}")
    return scaled

In [None]:
# Process 2: Trim Silence

def trim_silence(waveform, threshold_db=-50.0, min_length_ms=50, sample_rate=None):
    """
    Trim leading and trailing silence using a sliding RMS energy threshold.

    Energy is measured on a mono mix (mean over channels), but the slice is
    taken from the input itself, so multichannel audio keeps all its channels.

    Args:
        waveform (torch.Tensor): Audio of shape (channels, samples).
        threshold_db (float): Frames whose RMS level (20*log10 of RMS) is at or
            below this value count as silence.
        min_length_ms (float): Length of the RMS analysis window. When
            ``sample_rate`` is given this is milliseconds of audio. Without it
            the historical behaviour is kept: the window is
            ``min_length_ms / 1000`` of the clip's total length (50 -> 5%).
        sample_rate (int, optional): Sample rate of ``waveform``; enables a
            true millisecond window.

    Returns:
        torch.Tensor: ``waveform[:, start:end + 1]``, where start/end are the
        first and last analysis frames above the threshold. ``waveform`` is
        returned unchanged when it is empty or no frame exceeds the threshold.
    """
    # Detect on a mono mix so every channel is trimmed by the same amount.
    mono = torch.mean(waveform, dim=0, keepdim=True) if waveform.size(0) > 1 else waveform

    num_samples = mono.size(1)
    if num_samples == 0:
        return waveform

    if sample_rate:
        frame_length = int(min_length_ms * sample_rate / 1000)
    else:
        frame_length = int(min_length_ms * num_samples / 1000)
    # avg_pool1d needs 1 <= kernel_size <= length; very short clips would give 0.
    frame_length = min(max(frame_length, 1), num_samples)

    # Sliding RMS (stride 1) in decibels; the epsilon keeps log10 finite on silence.
    rms = torch.sqrt(FNN.avg_pool1d(mono ** 2, kernel_size=frame_length, stride=1))
    db = 20 * torch.log10(rms + 1e-8)

    loud_frames = torch.nonzero((db > threshold_db)[0])
    if loud_frames.numel() == 0:
        return waveform

    start = int(loud_frames[0])
    end = int(loud_frames[-1])
    return waveform[:, start:end + 1]


In [None]:
# Process 3: Pitch Detection: currently using librosa PYIN algorithm

# needed for subprocess ?
import subprocess

def get_main_pitch(audio_path, min_note='C1', max_note='C7'):
    """
    Estimate the dominant pitch of an audio file with librosa's pYIN.

    WEBM input is first transcoded with the ``ffmpeg`` CLI to a temporary
    16-bit PCM WAV at 44.1 kHz; other formats go straight to ``librosa.load``.

    Parameters
    ----------
    audio_path : str or Path
        Path to the audio file.
    min_note : str
        Lowest note to consider (scientific pitch notation, e.g. 'C1').
    max_note : str
        Highest note to consider (scientific pitch notation, e.g. 'C7').

    Returns
    -------
    tuple
        ``(median_f0, note, confidence)``:
        - ``median_f0``: median frequency in Hz of voiced frames with
          voicing probability > 0.6;
        - ``note``: dict ``{'closest_note': str, 'freq': float}`` for the
          nearest note name and its frequency;
        - ``confidence``: mean voicing probability over all voiced frames.
        ``(None, None, 0.0)`` is returned when the audio is empty or no
        qualifying frame is found.

    Raises
    ------
    RuntimeError
        If the ffmpeg conversion fails.
    Exception
        Any other loading/analysis error is logged and re-raised.
    """
    audio_path = Path(audio_path)
    ext = audio_path.suffix.lower()

    try:
        # Scratch directory for the transcoded file; removed when the block exits.
        with tempfile.TemporaryDirectory() as scratch_dir:
            temp_wav = Path(scratch_dir) / 'temp.wav'

            if ext == '.webm':
                logger.info(f"Converting WEBM to WAV: {audio_path}")
                subprocess.run([
                    'ffmpeg', '-i', str(audio_path),
                    '-acodec', 'pcm_s16le',
                    '-ar', '44100',
                    str(temp_wav)
                ], check=True, capture_output=True)
                processing_path = temp_wav
            else:
                processing_path = audio_path

            y, sr = librosa.load(
                processing_path,
                sr=None,   # keep the file's native sample rate
                mono=True  # downmix to a single channel
            )

            # An empty signal has no peak and no frames to analyse.
            if y.size == 0:
                logger.warning(f"Empty audio, no pitch to detect: {audio_path}")
                return None, None, 0.0

            logger.info(f"Loaded audio: sample rate={sr}, duration={len(y)/sr:.2f}s")

            # Only rescale when the decoder produced samples outside [-1, 1].
            peak = np.abs(y).max()
            if peak > 1.0:
                y = y / peak

            f0, voiced_flag, voiced_probs = librosa.pyin(
                y,
                fmin=librosa.note_to_hz(min_note),
                fmax=librosa.note_to_hz(max_note),
                sr=sr,
                frame_length=2048,  # larger frames resolve lower frequencies
                win_length=1024,
                hop_length=512
            )

            # Keep voiced frames that pYIN is reasonably sure about.
            mask = voiced_flag & (voiced_probs > 0.6)
            f0_valid = f0[mask]

            if len(f0_valid) == 0:
                logger.warning("No valid pitch detected")
                return None, None, 0.0

            # The median is less sensitive than the mean to stray frames.
            median_f0 = float(np.median(f0_valid))

            closest_note = librosa.hz_to_note(median_f0)
            note = {'closest_note': closest_note, 'freq': librosa.note_to_hz(closest_note)}

            # Confidence averages over ALL voiced frames, not only the
            # > 0.6 subset used for the frequency estimate.
            confidence = float(np.mean(voiced_probs[voiced_flag]))

            print_pitch_details(audio_path, median_f0, note, confidence)

            return median_f0, note, confidence

    except subprocess.CalledProcessError as e:
        # errors='replace' so an undecodable byte can't mask the real failure.
        ffmpeg_stderr = e.stderr.decode(errors='replace') if e.stderr else ''
        logger.error(f"FFmpeg conversion failed: {ffmpeg_stderr}")
        raise RuntimeError("Audio conversion failed") from e
    except Exception as e:
        logger.error(f"Pitch detection failed: {str(e)}")
        raise

# Utility for logging and debugging pitch detection
def print_pitch_details(audio_path, freq, note, conf):
    """
    Print a human-readable summary of one pitch-detection result.

    ``note`` is the dict built by get_main_pitch (keys 'closest_note' and
    'freq'). When ``freq`` is None a single "no pitch" line is printed. Any
    formatting error is reported on stdout instead of being raised, so this
    logging helper never breaks its caller.
    """
    try:
        if freq is None:
            print(f"\nNo valid pitch detected in {audio_path}")
            return
        print(f"\nPitch detection results for {audio_path}:")
        print(f"Main frequency: {freq:.4f} Hz")
        print(f"Closest musical note is: {note['closest_note']}, with frequency ({note['freq']:.4f} Hz)")
        print(f"Confidence: {conf:.3f}")
    except Exception as err:
        print(f"\nError processing {audio_path}: {str(err)}")



In [None]:
# Process 4: Transposition

def transpose(waveform: torch.Tensor,
             sample_rate: int,
             factor: float,
             max_denominator: int = 1000) -> torch.Tensor:
    """
    Pitch-shift audio by resampling (varispeed: duration scales by 1/factor).

    The returned waveform must be played back at the ORIGINAL ``sample_rate``;
    at that rate every frequency is multiplied by ``factor``.

    Args:
        waveform (torch.Tensor): Input audio waveform tensor.
        sample_rate (int): Original sample rate of the audio. It is validated
            only; the resampling ratio depends on ``factor`` alone.
        factor (float): Factor to transpose by
            (e.g., 0.5 for one octave down, 2.0 for one octave up).
        max_denominator (int): ``factor`` is approximated by a fraction p/q
            with q <= max_denominator. Larger values are more exact but build
            a larger resampling kernel.

    Returns:
        torch.Tensor: Transposed waveform with roughly ``1 / factor`` times as
        many samples. Should be played back at original sample rate for the
        transposing effect.

    Raises:
        ValueError: If factor is <= 0 or too small to approximate.
        TypeError: If inputs are of incorrect type
    """
    from fractions import Fraction

    # Input validation
    if not isinstance(factor, (int, float)):
        raise TypeError("Factor must be a number")
    if factor <= 0:
        raise ValueError("Factor must be positive")
    if not isinstance(sample_rate, int):
        raise TypeError("Sample rate must be an integer")
    if not torch.is_tensor(waveform):
        raise TypeError("Waveform must be a torch.Tensor")

    # Resample's kernel size grows roughly with orig_freq * new_freq after gcd
    # reduction, so rates like 44100 -> 46721 (gcd 1) would need billions of
    # taps. Only the ratio matters, so use a small rational p/q ~= factor.
    ratio = Fraction(factor).limit_denominator(max_denominator)
    if ratio == 0:
        raise ValueError("Factor is too small to approximate as a resampling ratio")

    # Converting "from rate p to rate q" keeps q/p of the samples; played back
    # at the original rate, every frequency is multiplied by p/q == factor
    # (factor > 1 -> higher pitch, shorter sound).
    resampler = T.Resample(
        orig_freq=ratio.numerator,
        new_freq=ratio.denominator,
        dtype=waveform.dtype
    ).to(waveform.device)  # the kernel is a module buffer; keep it beside the audio

    return resampler(waveform)


In [None]:
# Audio Processing Pipeline entry point

C4_FREQ = 261.6255  # default tuning target in Hz (C4)

def process_audio_file(file_path, options, sample_rate=None):
    """
    Process audio file according to specified options.

    Steps, each optional and applied in this order: resample to
    ``sample_rate``, peak-normalize, trim silence, tune to ``target_pitch``.

    Args:
        file_path (str or Path): Path to input audio file. Pitch detection for
            tuning re-reads this file, i.e. it analyses the unprocessed audio.
        options (dict): Processing options (missing flags count as False):
            - normalize (bool): Whether to normalize audio
            - trim (bool): Whether to trim silence
            - tune (bool): Whether to apply pitch tuning
            - target_pitch (float): Target pitch frequency in Hz when tuning;
              missing or None means C4_FREQ
        sample_rate (int, optional): Target sample rate for processing

    Returns:
        tuple: (processed_waveform, sample_rate). The waveform is on the
        global ``device``.

    Raises:
        ValueError: If tuning is requested but no pitch can be detected.
    """
    waveform, sr = torchaudio.load(file_path)
    waveform = waveform.to(device)

    if sample_rate and sample_rate != sr:
        # The resampling kernel is a module buffer: move it next to the audio
        # or the convolution fails when `device` is CUDA.
        resampler = torchaudio.transforms.Resample(sr, sample_rate).to(waveform.device)
        waveform = resampler(waveform)
        sr = sample_rate

    if options.get('normalize', False):
        waveform = normalize(waveform)

    if options.get('trim', False):
        waveform = trim_silence(waveform)

    if options.get('tune', False):
        detected_pitch, note, confidence = get_main_pitch(file_path)
        if detected_pitch is None:
            raise ValueError("Could not detect a pitch to tune from")

        target = options.get('target_pitch')
        if target is None:
            target = C4_FREQ

        factor = get_pitch_factor(detected_pitch, target)
        print(f"[TUNE] detected={detected_pitch:.2f} Hz target={target} Hz factor={factor:.4f}")
        waveform = transpose(waveform, sr, factor=factor)

    return waveform, sr

In [None]:
# API Setup

app = Flask(__name__)
CORS(app)  # enable CORS on every route

# Fixed working directories under Colab's content directory
BASE_DIR = Path('/content')
UPLOAD_DIR = BASE_DIR / 'uploads'
OUTPUT_DIR = BASE_DIR / 'outputs'

UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(f"Upload directory created at: {UPLOAD_DIR}")
print(f"Output directory created at: {OUTPUT_DIR}")

# Module logger. Only ERROR and above pass the level filter, so the info()
# call below (like info/warning calls elsewhere in the notebook) is dropped.
logger = logging.getLogger(__name__)
logger.setLevel(logging.ERROR)
logger.info(f"Upload directory: {UPLOAD_DIR}")

# Cell 7: Process Endpoint
@app.route('/process', methods=['POST'])
def process_audio():
    """
    POST /process: run the processing pipeline on an uploaded audio file.

    Expects multipart form data with a ``file`` part and an optional
    ``options`` field containing a JSON object that overrides the defaults
    below. The result is written as WAV and returned either as the audio body
    (default) or, when ``saveToFirebase`` is true and ``returnType`` is
    ``'url'``, as JSON ``{'download_url': ...}``, a signed URL valid for 5 minutes.

    Returns 400 for a missing/unnamed file or malformed options, and
    500 with JSON ``{'error': ...}`` if processing fails.
    """
    import datetime
    import uuid

    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    # Defaults; any subset can be overridden by the JSON `options` form field.
    options = {
        'normalize': True,
        'trim': True,
        'tune': True,
        'saveToFirebase': False,
        'returnType': 'blob'
    }

    if 'options' in request.form:
        try:
            user_options = json.loads(request.form['options'])
        except json.JSONDecodeError:
            return jsonify({'error': 'Invalid options format'}), 400
        # A JSON array/number would otherwise crash dict.update with a 500.
        if not isinstance(user_options, dict):
            return jsonify({'error': 'Invalid options format'}), 400
        options.update(user_options)

    file = request.files['file']

    # basename() drops client-supplied directory components (path traversal).
    safe_filename = os.path.basename(file.filename or '')
    if not safe_filename:
        return jsonify({'error': 'No file provided'}), 400

    # Per-request prefix so concurrent uploads with the same name cannot
    # overwrite each other. The original suffix is kept because the pipeline
    # chooses its decoding path from the file extension.
    request_id = uuid.uuid4().hex
    input_path = UPLOAD_DIR / f"{request_id}_{safe_filename}"
    output_path = OUTPUT_DIR / f"{request_id}_processed_{safe_filename}"

    file.save(str(input_path))

    try:
        waveform, sample_rate = process_audio_file(str(input_path), options)

        # The waveform may live on the GPU; write from a CPU copy. The data is
        # always WAV, whatever extension the upload had.
        outpath_str = str(output_path)
        torchaudio.save(outpath_str, waveform.cpu(), sample_rate, format="wav")  # wav for testing

        if options['saveToFirebase']:
            blob = bucket.blob(f'processed_{safe_filename}')
            blob.upload_from_filename(outpath_str, content_type='audio/wav')

            if options['returnType'] == 'url':
                # An int `expiration` is an absolute Unix timestamp for the
                # default (V2) signer, so a bare 300 would already be expired;
                # pass a duration instead.
                download_url = blob.generate_signed_url(
                    expiration=datetime.timedelta(seconds=300)
                )
                return jsonify({'download_url': download_url})

        # Otherwise return the audio itself. Read it into memory so the cleanup
        # in `finally` cannot interfere with streaming the response.
        return send_file(io.BytesIO(output_path.read_bytes()), mimetype='audio/wav')

    except Exception as e:
        logger.error(f"Processing failed: {str(e)}")
        return jsonify({'error': str(e)}), 500
    finally:
        input_path.unlink(missing_ok=True)
        output_path.unlink(missing_ok=True)



In [None]:
# Cell 8: Run the Server

SERVER_PORT = 5000
NGROK_DOMAIN = "singular-roughy-humane.ngrok-free.app"

# Expose the local Flask port through a fixed ngrok domain.
ngrok.set_auth_token(userdata.get('NGROK_SECRET'))
public_url = ngrok.connect(SERVER_PORT, domain=NGROK_DOMAIN)
print(f' * Public URL: {public_url.public_url}')

# Start server (blocks this cell while serving). The logger was already
# configured in the API setup cell, so it is not reconfigured here.
app.run(port=SERVER_PORT)

In [None]:
%%skip

# Test cell: Normalization
# Each case is peak-normalized and checked: non-silent input must end with a
# peak of 1 and unchanged signs; silent input must pass through untouched.

normalization_cases = [
    ("High amplitude", torch.tensor([[2.5, -3.0, 1.5]])),
    ("Normal amplitude", torch.tensor([[0.5, -0.7, 0.3]])),
    ("Zero amplitude", torch.zeros((1, 3))),
    ("Negative amplitude", torch.tensor([[-2.5, 0.5, -3.0]])),
    ("Stereo", torch.tensor([[2.0, -3.0, 1.0],
                             [1.5, -2.0, 0.5]])),
    ("Float64", torch.tensor([[2.5, -3.0, 1.5]], dtype=torch.float64)),
]

for name, original in normalization_cases:
    print(f"\n{name}:")
    normalized = normalize(original)
    print(f"Original: {original}")
    print(f"Normalized: {normalized}")
    peak = torch.abs(normalized).max()
    print(f"Max abs value: {peak}")

    assert normalized.shape == original.shape, name
    assert normalized.dtype == original.dtype, name
    if torch.abs(original).max() > 0:
        assert torch.isclose(peak, torch.ones((), dtype=peak.dtype)), name
        assert torch.equal(torch.sign(normalized), torch.sign(original)), name
    else:
        assert torch.equal(normalized, original), name

print("\nAll normalization checks passed.")

In [None]:
%%skip

# Test Cell: Silence Trimming

def generate_test_audio_with_middle_silence(sample_rate=16000, duration_seconds=3, frequency=440):
    """
    Generate test audio with silences at start, middle, and end.

    The clip is split into thirds:
    - first third: 0.5 s of silence, then a sine tone;
    - middle third: silence;
    - last third: a sine tone, then 0.5 s of silence.
    Tones have amplitude 0.5 and ``frequency`` Hz (default A4 = 440).

    Returns:
        tuple: (waveform of shape (1, samples), sample_rate)

    Raises:
        ValueError: if a third of the clip is not longer than 0.5 s.
    """
    total_samples = int(sample_rate * duration_seconds)
    silence_samples = int(sample_rate * 0.5)  # 0.5 seconds of silence
    third = total_samples // 3
    tone_samples = third - silence_samples
    if tone_samples <= 0:
        raise ValueError("duration_seconds is too short for 0.5 s silences in each third")

    def tone(num_samples):
        # Sample on the true time grid n / sample_rate so the tone really is
        # `frequency` Hz (squeezing duration/3 s into fewer samples would not).
        t = torch.arange(num_samples) / sample_rate
        return 0.5 * torch.sin(2 * math.pi * frequency * t)

    waveform = torch.zeros(1, total_samples)
    waveform[0, silence_samples:third] = tone(tone_samples)
    waveform[0, total_samples - third:total_samples - silence_samples] = tone(tone_samples)

    return waveform, sample_rate

def plot_waveforms_with_db(waveform, processed_waveform, db, threshold_db, sample_rate):
    """
    Plot original, processed waveforms, and the dB levels.

    Args:
        waveform: original audio, shape (channels, samples); channel 0 is plotted.
        processed_waveform: trimmed audio, same layout; channel 0 is plotted.
        db: 1-D sequence of RMS levels in dB.
        threshold_db: silence threshold, drawn as a dashed red line.
        sample_rate: sample rate of both waveforms in Hz.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(15, 12))

    # Plot original
    time_orig = torch.arange(waveform.size(1)) / sample_rate
    ax1.plot(time_orig, waveform[0])
    ax1.set_title('Original Waveform')
    ax1.set_ylabel('Amplitude')

    # Plot processed
    time_proc = torch.arange(processed_waveform.size(1)) / sample_rate
    ax2.plot(time_proc, processed_waveform[0])
    ax2.set_title('Processed Waveform')
    ax2.set_ylabel('Amplitude')

    # Spread the dB frames over the clip duration, derived from the waveform
    # itself instead of a notebook-global `duration_seconds`.
    duration = waveform.size(1) / sample_rate
    time_db = torch.arange(len(db)) / (len(db) / duration)
    ax3.plot(time_db, db)
    ax3.axhline(y=threshold_db, color='r', linestyle='--', label=f'Threshold ({threshold_db} dB)')
    ax3.set_title('RMS Energy (dB)')
    ax3.set_xlabel('Time (s)')
    ax3.set_ylabel('dB')
    ax3.legend()

    plt.tight_layout()
    plt.show()

# Generate test audio with middle silence
duration_seconds = 3
waveform, sample_rate = generate_test_audio_with_middle_silence(duration_seconds=duration_seconds)

# Process the audio and capture intermediate values
threshold_db = -50.0
min_length_ms = 50

# Recompute the RMS curve the way trim_silence does, for plotting only
if waveform.size(0) > 1:
    waveform = torch.mean(waveform, dim=0, keepdim=True)
frame_length = int(min_length_ms * waveform.size(1) / 1000)
# avg_pool1d lives in torch.nn.functional (FNN), not torchaudio.functional (F)
rms = torch.sqrt(FNN.avg_pool1d(waveform ** 2, kernel_size=frame_length, stride=1))
db = 20 * torch.log10(rms + 1e-8)[0]

# Process audio
start_time = time.perf_counter()
processed_waveform = trim_silence(waveform, threshold_db, min_length_ms)
processing_time = time.perf_counter() - start_time

# Plot with dB levels
plot_waveforms_with_db(waveform, processed_waveform, db, threshold_db, sample_rate)

# Print analysis
original_duration = waveform.size(1) / sample_rate
processed_duration = processed_waveform.size(1) / sample_rate
print(f"\nDuration Analysis:")
print(f"Original audio duration: {original_duration:.2f} seconds")
print(f"Processed audio duration: {processed_duration:.2f} seconds")
print(f"\nProcessing time: {processing_time*1000:.2f} ms")

# Play the audio (IPython.display is imported as `idp`)
print("\nPlaying original audio:")
idp.display(idp.Audio(waveform.numpy(), rate=sample_rate))
print("\nPlaying processed audio:")
idp.display(idp.Audio(processed_waveform.numpy(), rate=sample_rate))

In [None]:
%%skip

# Test Cell: Pitch Detection


def generate_sine_wave(frequency, duration=1.0, sample_rate=44100, amplitude=0.5):
    """Generate a sine wave for testing.

    Samples are taken at t = n / sample_rate (endpoint excluded), so the tone
    has exactly ``frequency`` Hz at ``sample_rate``.

    Args:
        frequency (float): The frequency of the sine wave in Hz
        duration (float): Duration of the signal in seconds
        sample_rate (int): Sample rate in Hz
        amplitude (float): Amplitude of the signal

    Returns:
        tuple: (samples, sample_rate)
    """
    num_samples = int(sample_rate * duration)
    # endpoint=False keeps the spacing at exactly 1 / sample_rate.
    t = np.linspace(0, duration, num_samples, endpoint=False)
    samples = amplitude * np.sin(2 * np.pi * frequency * t)
    return samples, sample_rate



def test_pitch_detection(tolerance_cents=50.0):
    """
    Run get_main_pitch on synthetic sine tones C4..C5 and report the accuracy.

    Each tone is written to a temporary WAV file and analysed. Per-note errors
    (Hz and cents), timing statistics, and how many notes fall within
    ``tolerance_cents`` are printed.

    Returns:
        list[dict]: one record per note where a pitch was detected.
    """
    # Test frequencies (in Hz)
    note_freqs = {
        'C4': 261.63,
        'D4': 293.66,
        'E4': 329.63,
        'F4': 349.23,
        'G4': 392.00,
        'A4': 440.00,
        'B4': 493.88,
        'C5': 523.25,
    }
    results = []
    execution_times = []

    print("Starting pitch detection test...")
    print("-" * 50)

    # TemporaryDirectory removes the test files even if detection raises.
    with tempfile.TemporaryDirectory() as tmp:
        temp_dir = Path(tmp)
        for note, freq in note_freqs.items():
            samples, sample_rate = generate_sine_wave(freq)
            temp_file = temp_dir / f"{note}_{freq}Hz.wav"
            sf.write(temp_file, samples, sample_rate)

            start_time = time.perf_counter()
            # get_main_pitch returns (frequency, note dict, confidence).
            detected_freq, _, _ = get_main_pitch(str(temp_file))
            elapsed = time.perf_counter() - start_time
            execution_times.append(elapsed)

            if detected_freq is None:
                print(f"Warning: No pitch detected for {note}")
                print("-" * 50)
                continue

            error_cents = 1200 * np.log2(detected_freq / freq)
            error_hz = detected_freq - freq
            results.append({
                'note': note,
                'expected_freq': freq,
                'detected_freq': detected_freq,
                'error_hz': error_hz,
                'error_cents': error_cents
            })

            print(f"Note: {note}")
            print(f"Expected: {freq:.1f} Hz")
            print(f"Detected: {detected_freq:.1f} Hz")
            print(f"Error: {error_hz:.1f} Hz ({error_cents:.1f} cents)")
            print(f"Execution time: {elapsed:.3f} seconds")
            print("-" * 50)

    if execution_times:
        print(f"\nTiming Statistics:")
        print(f"Mean execution time: {np.mean(execution_times):.3f} seconds")
        print(f"Std dev of execution time: {np.std(execution_times):.3f} seconds")

    passed = sum(abs(r['error_cents']) <= tolerance_cents for r in results)
    print(f"\n{passed}/{len(note_freqs)} notes detected within ±{tolerance_cents:g} cents")
    return results

# Run the test
test_pitch_detection()

## Extra test draft for pitch detection - ignore ##
# Runs pitch detection over local piano samples such as /content/sPno_C3.webm;
# samples that are not present are reported and skipped.

base = 'sPno_'
note_list = ['C3', 'D3', 'E3', 'F3', 'G3', 'A3', 'B3', 'C4', 'D4', 'E4', 'F4', 'G4', 'A4', 'B4']
base_path = f'/content/{base}'
extensions = ['webm']

for note in note_list:
    print(f"Sample: {base}{note}")
    for ext in extensions:
        sample_path = f'{base_path}{note}.{ext}'
        if not os.path.exists(sample_path):
            print(f"  Missing sample, skipped: {sample_path}")
            continue
        mimetype = get_mime_type(sample_path)

        get_main_pitch(sample_path)