In [None]:
# 1. Environment check: is a CUDA GPU visible, and which PyTorch / CUDA builds are installed?
import torch

print("CUDA available: {}".format(torch.cuda.is_available()))
print("Current PyTorch version: {}".format(torch.__version__))
print("Current CUDA version: {}".format(torch.version.cuda if torch.cuda.is_available() else 'None'))


In [None]:
# 2. Install Dependencies

!pip install fastapi uvicorn librosa python-multipart ffmpeg-python aiofiles soundfile pyngrok nest_asyncio
!apt-get install -y sox ffmpeg libportaudio2

#!pip install fastapi==0.109.0 uvicorn==0.27.0 torch==2.1.2 torchaudio==2.1.2 librosa==0.10.1 numpy==1.26.3 python-multipart==0.0.6 ffmpeg-python python-multipart aiofiles soundfile pyngrok

# ffmpeg used under the hood for speed advantage:
#!apt-get update && apt-get install -y ffmpeg # -y libsndfile1
# main audio processing packages:
#!pip install torch torchaudio librosa numpy
# api setup and storage:
#!pip install flask flask-cors pyngrok firebase-admin

In [None]:
# Install the sox tool and libsox headers; presumably needed by torchaudio.sox_effects,
# which process_audio_file uses for pitch tuning.
!apt-get update && apt-get install -y sox libsox-dev
# NOTE(review): unpinned torchaudio install — confirm it matches the installed torch build.
!pip install torchaudio --no-cache-dir

In [None]:
# 3. setup directories

!mkdir -p /content/sample-prepper/uploads
!mkdir -p /content/sample-prepper/output

In [None]:
# 4. Imports
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
import uvicorn
import asyncio
import nest_asyncio
from pyngrok import ngrok
from concurrent.futures import ThreadPoolExecutor
from google.colab import userdata
import logging
import json
import os
import io
import logging
import numpy as np
import math
import time
from pathlib import Path
import IPython.display as ipd
from IPython.core.magic import register_cell_magic
import mimetypes
import matplotlib.pyplot as plt
import torch
import torchaudio
import torchaudio.transforms as T
import torchaudio.functional as F
import librosa


# `%%skip` cell magic: lets a cell be disabled without deleting it
@register_cell_magic
def skip(line, cell):
    # Put `%%skip` on the first line of a cell; its body is received here and
    # deliberately never executed.
    del line, cell  # both arguments are intentionally ignored
    return None

# Report the GPU (if any) and pick the compute device used by the rest of the notebook
if torch.cuda.is_available():
    print(f"GPU available: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Device: {device}")

In [None]:
# Logging setup

def setup_minimal_logger(name: str = 'audio', level: int = logging.INFO) -> logging.Logger:
    """Configure a minimal console logger for the core processing steps.

    Safe to call repeatedly (e.g. when this cell is re-run): handlers left on the
    logger by a previous call are removed and closed before exactly one fresh
    console handler is attached.

    Args:
        name (str): Name of the logger to configure (default ``'audio'``).
        level (int): Logging level applied to both the logger and its handler.

    Returns:
        logging.Logger: The configured logger.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Remove (and close) handlers from earlier runs instead of just dropping them
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # Don't also forward records to the root logger: if anything in the kernel has
    # configured root handlers, every message would otherwise be printed twice.
    logger.propagate = False

    # Simple console handler printing only the message text
    console = logging.StreamHandler()
    console.setLevel(level)
    console.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(console)

    return logger

# Initialize logger
log = setup_minimal_logger()

In [None]:
# Cell 3: Firebase Setup
# These names were used below but never imported anywhere in the notebook.
import firebase_admin
from firebase_admin import credentials, storage

# Service-account JSON stored as a Colab secret
fire_cred = userdata.get('FB_CRED')

# Only initialize once per kernel: re-running this cell must not create a second app.
# get_app() raises ValueError while no default app exists (public API, unlike _apps).
try:
    firebase_admin.get_app()
except ValueError:
    # credentials.Certificate accepts the parsed service-account dict directly, so
    # the secret never has to be written to (and possibly left behind on) disk.
    cred = credentials.Certificate(json.loads(fire_cred))
    firebase_admin.initialize_app(cred, {
        'storageBucket': 'sample-prep-dbd20.firebasestorage.app'
    })

bucket = storage.bucket()

In [None]:
# main.py

# Fixed working directories: uploads come in, processed files go out
BASE_DIR = Path('/content/sample-prepper')
UPLOAD_DIR = BASE_DIR / 'uploads'
# Processed files belong in the dedicated 'output' folder (the one the setup
# cell creates), not in the project root.
OUTPUT_DIR = BASE_DIR / 'output'

# Create directories (idempotent; parents=True in case the setup cell was skipped)
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log.info(f"Using device: {device}")

def get_mime_type(path: str) -> tuple:
    """Guess the MIME type of an audio file from its name and log the result.

    Args:
        path (str): Path (or bare file name) of the audio file.

    Returns:
        tuple: ``(mime_type, encoding)`` exactly as ``mimetypes.guess_type``
        reports it; either element may be ``None`` if it cannot be guessed.
    """
    guessed = mimetypes.guess_type(path)
    log.info(f"{path}: Mime type: {guessed}")
    return guessed

def get_output_format(input_format: str, requested_format: str = None) -> str:
    """Determine the output format based on input and requested format.

    Formats are matched case-insensitively and may carry surrounding whitespace
    or a leading dot (as returned by ``Path.suffix``), so ``'.MP3'`` counts as
    ``'mp3'``.

    Args:
        input_format (str): Original file format (e.g., 'mp3', '.webm'); may be
            None or empty if unknown.
        requested_format (str): Optional requested output format.

    Returns:
        str: One of 'wav', 'mp3', 'webm'. The requested format wins if it is
        supported, then the input format, and 'wav' otherwise.
    """
    valid_formats = {'wav', 'mp3', 'webm'}

    def _clean(fmt):
        # None/'' -> ''; ' .MP3' -> 'mp3'
        return (fmt or '').strip().lstrip('.').lower()

    requested = _clean(requested_format)
    if requested in valid_formats:
        return requested

    # Default to input format, or wav if input format not supported
    source = _clean(input_format)
    return source if source in valid_formats else 'wav'


def normalize(waveform: torch.Tensor) -> torch.Tensor:
    """Peak-normalize audio so its loudest sample has magnitude 1.

    The whole tensor is divided by its largest absolute value. The peak levels
    before and after scaling are logged. A waveform whose peak is not positive
    (all zeros) is returned unchanged, since there is nothing to scale.

    Args:
        waveform (torch.Tensor): Audio samples to normalize.

    Returns:
        torch.Tensor: A new scaled tensor, or the original tensor object when no
        scaling was possible.
    """
    peak_in = waveform.abs().max()
    log.info(f"[NORMALIZE] Input peak amplitude: {peak_in}")

    if not peak_in > 0:
        return waveform

    scaled = waveform / peak_in
    log.info(f"[NORMALIZE] Output peak amplitude: {scaled.abs().max()}")
    return scaled

def get_pitch_factor(original_pitch: float, target_pitch: float) -> float:
    """Return the frequency ratio that moves ``original_pitch`` onto ``target_pitch``.

    A result of 2.0 means "up one octave" and 0.5 means "down one octave".

    Args:
        original_pitch (float): Detected pitch in Hz; must be positive.
        target_pitch (float): Desired pitch in Hz.

    Returns:
        float: ``target_pitch / original_pitch``.

    Raises:
        ValueError: If ``original_pitch`` is zero or negative.
    """
    is_invalid = original_pitch <= 0
    if is_invalid:
        raise ValueError("Original pitch must be positive")
    ratio = target_pitch / original_pitch
    return ratio

def transpose_torch(waveform: torch.Tensor, sample_rate: int, factor: float) -> torch.Tensor:
    """Transpose audio by resampling ("tape-speed" pitch shift).

    The waveform is resampled from ``sample_rate`` to ``sample_rate / factor``.
    When the result is played back at the original ``sample_rate``, its pitch
    and tempo are both multiplied by ``factor``.

    Args:
        waveform (torch.Tensor): Input audio waveform (any device).
        sample_rate (int): Original sample rate in Hz.
        factor (float): Pitch adjustment factor (>1 raises pitch, <1 lowers it).

    Returns:
        torch.Tensor: Transposed waveform on the same device as the input.

    Raises:
        TypeError: If inputs are of wrong type
        ValueError: If factor is not positive, or so large that the target
            sample rate would fall below 1 Hz.
    """
    if not isinstance(factor, (int, float)):
        raise TypeError("Factor must be a number")
    if factor <= 0:
        raise ValueError("Factor must be positive")
    if not isinstance(sample_rate, int):
        raise TypeError("Sample rate must be an integer")
    if not torch.is_tensor(waveform):
        raise TypeError("Waveform must be a torch.Tensor")

    # Round rather than truncate, e.g. 44100 / 1.1 -> 40091 instead of 40090
    resample_rate = int(round(sample_rate / factor))
    if resample_rate < 1:
        raise ValueError("Factor too large for the given sample rate")

    # NOTE(review): Resample's kernel size grows with new_freq / gcd(orig, new);
    # ratios without a large common divisor can be slow and memory-hungry.
    resampler = T.Resample(
        orig_freq=sample_rate,
        new_freq=resample_rate,
        dtype=waveform.dtype
    )
    # The filter kernel is built on the CPU; move the module next to the waveform
    # so CUDA inputs don't fail with a device mismatch.
    resampler = resampler.to(waveform.device)

    return resampler(waveform)

def trim_silence(waveform: torch.Tensor, threshold_db: float = -50.0,
                min_length_ms: float = 50, sr: int = 44100) -> torch.Tensor:
    """Trim silence from the start and end of audio using frame-wise RMS energy.

    The signal is cut into non-overlapping frames of ``min_length_ms``. Frames
    whose RMS level (in dB, pooled over all channels) does not exceed
    ``threshold_db`` count as silent. Everything before the first loud frame and
    after the last loud frame is removed along the last (time) axis. If the last
    complete frame is still loud, the short remainder after it is kept.

    Works for 1-D ``(samples,)`` and multi-channel ``(channels, samples)`` input.
    All channels are trimmed to the same span, so they stay aligned. The result
    keeps the input's dtype and device.

    Args:
        waveform (torch.Tensor): Input audio waveform
        threshold_db (float): Threshold in decibels below which audio is considered silence
        min_length_ms (float): Analysis frame length in milliseconds
        sr (int): Sample rate of the audio; must match ``waveform``, otherwise
            frames will not actually be ``min_length_ms`` long

    Returns:
        torch.Tensor: Trimmed waveform (a slice of the input). The input is
        returned unchanged if it is shorter than one frame, entirely silent, or
        if trimming fails.
    """
    try:
        # numpy copy used only for the energy analysis; output is sliced from the input
        audio_np = waveform.detach().cpu().numpy()
        # Collapse to (channels, samples); 1-D input becomes a single channel
        audio_np = audio_np.reshape(-1, audio_np.shape[-1])
        num_samples = audio_np.shape[-1]

        # Calculate frame length
        frame_length = int(min_length_ms * sr / 1000)

        # Quick returns: degenerate frame size, or audio shorter than one frame
        if frame_length <= 0 or num_samples < frame_length:
            return waveform

        # (channels, frames, frame_length); the incomplete tail frame is not analysed.
        # float64 avoids overflow when squaring integer samples.
        num_frames = num_samples // frame_length
        frames = audio_np[:, :num_frames * frame_length].reshape(
            audio_np.shape[0], num_frames, frame_length).astype(np.float64)

        # Vectorized RMS per frame across all channels, in dB (epsilon avoids log(0))
        rms = np.sqrt(np.mean(np.square(frames), axis=(0, 2)))
        db = 20 * np.log10(rms + 1e-8)

        loud = np.flatnonzero(db > threshold_db)
        if loud.size == 0:
            return waveform

        # Calculate trim points
        start = int(loud[0]) * frame_length
        if loud[-1] == num_frames - 1:
            # Still loud in the last full frame: don't chop the short tail after it
            end = num_samples
        else:
            end = (int(loud[-1]) + 1) * frame_length

        return waveform[..., start:end]

    except Exception as e:
        log.warning(f"Silence trimming failed: {str(e)}")
        return waveform


def get_main_pitch(audio_data, sr, min_note='C1', max_note='C7'):
    """Estimate the dominant pitch of a whole clip with librosa's pYIN tracker.

    Args:
        audio_data: Audio samples as a numpy array or torch tensor (any device).
            NOTE(review): presumably mono 1-D — confirm for multi-channel callers.
        sr (int): Sample rate of ``audio_data`` in Hz.
        min_note (str): Lowest note pYIN searches for (e.g. 'C1').
        max_note (str): Highest note pYIN searches for (e.g. 'C7').

    Returns:
        tuple: ``(median_f0, note, confidence)``
            - median_f0 (float): median frequency (Hz) of frames that are voiced
              with probability > 0.6.
            - note (dict): ``{'closest_note': str, 'freq': ...}``, the nearest
              note name and that note's frequency in Hz.
            - confidence (float): mean voicing probability over all voiced frames.
        ``(None, None, 0.0)`` when no frame passes the filter.

    Raises:
        Exception: Any error from the pitch tracker is logged and re-raised.
    """
    try:
        # librosa needs a CPU numpy array; detach()/cpu() also handle CUDA tensors
        # and tensors that require grad, where a bare .numpy() raises.
        if torch.is_tensor(audio_data):
            audio_data = audio_data.detach().cpu().numpy()

        # Calculate pitch using PYIN algorithm
        f0, voiced_flag, voiced_probs = librosa.pyin(
            audio_data,
            fmin=librosa.note_to_hz(min_note),
            fmax=librosa.note_to_hz(max_note),
            sr=sr,
            frame_length=2048,
            win_length=1024,
            hop_length=512
        )

        # Filter out unvoiced and low probability segments
        mask = voiced_flag & (voiced_probs > 0.6)
        f0_valid = f0[mask]

        if len(f0_valid) == 0:
            log.warning("No valid pitch detected")
            return None, None, 0.0

        # Median is robust to octave errors / outliers in individual frames
        median_f0 = float(np.median(f0_valid))

        # Convert to note
        closest_note = librosa.hz_to_note(median_f0)
        note_freq = librosa.note_to_hz(closest_note)
        note = {'closest_note': closest_note, 'freq': note_freq}

        # Confidence = mean voicing probability over every voiced frame
        confidence = float(np.mean(voiced_probs[voiced_flag]))

        return median_f0, note, confidence

    except Exception as e:
        log.error(f"Pitch detection failed: {str(e)}")
        raise
## ________________ MUCH FASTER: analyses only a specified time segment of the audio ________________

def detect_pitch_optimized(audio_data, sr, min_note='C1', max_note='C7', analysis_ratio=0.2):
    """Estimate the main pitch from a window of the audio (faster than get_main_pitch).

    Audio above 22.05 kHz is first downsampled to 22.05 kHz. Only a window of
    ``analysis_ratio`` of the clip is analysed; it starts 1/8 of the way in and
    is clamped to the end of the clip. If that window is shorter than one pYIN
    frame, the whole clip is analysed instead.

    Args:
        audio_data: Input audio as a numpy array or torch tensor (any device).
            NOTE(review): assumed mono 1-D — ``len()`` is used as the sample count.
        sr: Sample rate in Hz.
        min_note: Minimum note to detect (e.g. 'C1').
        max_note: Maximum note to detect (e.g. 'C7').
        analysis_ratio: Ratio of total audio length to analyze; clamped to [0.0, 1.0].

    Returns:
        tuple: ``(median_f0, note, confidence)``. ``note`` is
        ``{'closest_note': str, 'freq': ...}``, and ``confidence`` is the mean
        voicing probability of the voiced frames. Returns ``(None, None, 0.0)``
        for empty audio or when no frame passes the voicing filter.

    Raises:
        Exception: Any error from resampling / pitch tracking is logged and re-raised.
    """
    frame_length = 1024  # pYIN frame size (half of get_main_pitch's 2048, for speed)

    try:
        # librosa needs a CPU numpy array; detach()/cpu() also handle CUDA tensors
        if torch.is_tensor(audio_data):
            audio_data = audio_data.detach().cpu().numpy()

        # Downsample for pitch detection if sample rate is high
        if sr > 22050:
            target_sr = 22050
            audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=target_sr)
            sr = target_sr

        total_length = len(audio_data)
        if total_length == 0:
            log.warning("No valid pitch detected (empty audio)")
            return None, None, 0.0

        # Window = analysis_ratio of the clip, starting 1/8 in (skips the attack).
        # Clamp so the window can never be negative or run past the end of the clip.
        ratio = min(max(float(analysis_ratio), 0.0), 1.0)
        start_pos = total_length // 8
        end_pos = min(start_pos + int(total_length * ratio), total_length)
        audio_segment = audio_data[start_pos:end_pos]

        # Window too short for even one analysis frame: use the whole clip instead
        if len(audio_segment) < frame_length:
            start_pos, end_pos = 0, total_length
            audio_segment = audio_data

        log.info(f"Analyzing {len(audio_segment) / total_length * 100:.1f}% of audio "
                 f"({len(audio_segment)/sr:.3f}s out of {total_length/sr:.3f}s total) "
                 f"from position {start_pos/sr:.3f}s to {end_pos/sr:.3f}s")

        # Calculate pitch using PYIN with smaller frames than get_main_pitch
        f0, voiced_flag, voiced_probs = librosa.pyin(
            audio_segment,
            fmin=librosa.note_to_hz(min_note),
            fmax=librosa.note_to_hz(max_note),
            sr=sr,
            frame_length=frame_length,       # 1024
            win_length=frame_length // 2,    # 512
            hop_length=frame_length // 4     # 256
        )

        min_voiced_prob = 0.3  # frames at or below this voicing probability are ignored

        # Filter out unvoiced and low probability segments
        mask = voiced_flag & (voiced_probs > min_voiced_prob)
        f0_valid = f0[mask]

        if len(f0_valid) == 0:
            log.warning("No valid pitch detected")
            return None, None, 0.0

        # Median is robust to octave errors / outliers in individual frames
        median_f0 = float(np.median(f0_valid))

        # Convert to note
        closest_note = librosa.hz_to_note(median_f0)
        note_freq = librosa.note_to_hz(closest_note)
        note = {'closest_note': closest_note, 'freq': note_freq}

        # Confidence = mean voicing probability over every voiced frame
        confidence = float(np.mean(voiced_probs[voiced_flag]))

        return median_f0, note, confidence

    except Exception as e:
        log.error(f"Pitch detection failed: {str(e)}")
        raise


# Main processing function
def process_audio_file(file_path, options):
    """Load an audio file and apply the processing steps requested in ``options``.

    Steps run in a fixed order: mix-down to mono, peak-normalize, trim silence,
    then tune to the target pitch.

    Args:
        file_path: Path of the audio file (loaded with ``torchaudio.load``).
        options (dict): Processing flags. Recognised keys are 'normalize', 'trim'
            and 'tune' (bools) and, for tuning, 'target_pitch' in Hz (default
            261.6255, i.e. C4). Missing flags are treated as False.

    Returns:
        list[tuple[str, torch.Tensor, int]]: A single ``(label, waveform, sample_rate)``
        entry. It is ``("sox", ...)`` when tuning was applied, otherwise
        ``("original", ...)`` holding the (possibly normalized/trimmed) audio.

    Raises:
        Exception: Anything raised while loading or processing is logged and re-raised.
    """
    start_time = time.time()
    results = []

    try:
        # Load and down-mix to mono (channel average) so later steps see one channel
        waveform, sr = torchaudio.load(file_path)
        waveform = waveform.to(device)
        if waveform.size(0) > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Flags are read with .get so a missing key means "off" instead of a KeyError
        if options.get('normalize'):
            # normalize() leaves silent audio untouched instead of dividing by a
            # zero peak (which would fill the output with NaNs)
            waveform = normalize(waveform)

        if options.get('trim'):
            # Pass the real sample rate; trim_silence otherwise assumes 44.1 kHz
            # and its analysis frames would have the wrong duration
            waveform = trim_silence(waveform, sr=sr)

        if options.get('tune'):
            audio_np = waveform[0].cpu().numpy()
            detected_pitch, note, confidence = detect_pitch_optimized(audio_np, sr)

            if detected_pitch and confidence > 0.6:
                target = options.get('target_pitch', 261.6255)  # C4
                factor = get_pitch_factor(detected_pitch, target)

                # Only shift by up to one octave in either direction
                if 0.5 <= factor <= 2.0:
                    log.info(f"{note['closest_note']} → {librosa.hz_to_note(target)}")
                    # sox "speed" scales pitch (and tempo) by factor; "rate" resamples back to sr
                    effects = [["speed", str(factor)], ["rate", str(sr)]]
                    processed_waveform, processed_sr = torchaudio.sox_effects.apply_effects_tensor(
                        waveform.cpu(), sr, effects)
                    results.append(("sox", processed_waveform, processed_sr))
                else:
                    log.info("Pitch adjustment too large")
            else:
                log.info("No clear pitch detected")

        if not results:
            results.append(("original", waveform, sr))

        log.info(f"Done in {time.time() - start_time:.1f}s")
        return results

    except Exception as e:
        log.error(f"Error: {str(e)}")
        raise

    finally:
        # Release cached GPU memory between requests (nothing to do on CPU-only runtimes)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# Initialize FastAPI app

app = FastAPI()

# CORS is wide open for development.
# TODO before deploying: list the real client origin(s) and only the methods in use.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; tighten both when allow_origins is restricted.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=86400,  # browsers may cache preflight responses for 24 hours
)
app.add_middleware(CORSMiddleware, **_cors_settings)


# Root liveness probe
@app.get("/")
async def root():
    return {"message": "Sample Prepper API running"}


# Health check on GET /process (the actual work is done by POST /process)
@app.get("/process")
async def health_check():
    return {"status": "ok"}

# Create thread pool for CPU-intensive tasks
thread_pool = ThreadPoolExecutor(max_workers=4)

async def run_in_thread(func, *args, **kwargs):
    """Run a blocking callable on ``thread_pool`` without blocking the event loop.

    Args:
        func: Callable to execute in a worker thread.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns. Exceptions raised by ``func`` propagate to the
        awaiting caller.
    """
    import functools

    # get_running_loop() is the recommended way to reach the loop from inside a coroutine
    loop = asyncio.get_running_loop()
    if kwargs:
        # run_in_executor only forwards positional arguments
        return await loop.run_in_executor(thread_pool, functools.partial(func, *args, **kwargs))
    return await loop.run_in_executor(thread_pool, func, *args)



def _remove_files(*paths):
    """Best-effort deletion of temporary files; ``None`` entries are skipped."""
    for path in paths:
        if path is None:
            continue
        try:
            Path(path).unlink(missing_ok=True)
        except OSError as e:
            log.warning(f"Could not remove {path}: {e}")


@app.post("/process")
async def process_audio_endpoint(
    file: UploadFile = File(...),
    options: str = Form(default=None)
):
    """Process an uploaded audio file and return the result as a 16-bit PCM WAV.

    Form fields:
        file: The audio file to process.
        options: Optional JSON object merged over the defaults below, e.g.
            ``{"normalize": true, "trim": true, "tune": true}``.

    Returns:
        FileResponse: The processed audio (``audio/wav``). The temporary output
        file is deleted by a background task after the response has been sent.

    Raises:
        HTTPException: 400 for malformed ``options``, 500 if processing fails.
    """
    input_path = None
    output_path = None
    handed_off = False  # True once the response owns cleanup of output_path

    try:
        process_options = {
            'normalize': False,
            'trim': False,
            'tune': False,
            'returnType': 'blob',
            'outputFormat': 'wav'
        }

        if options:
            try:
                user_options = json.loads(options)
            except json.JSONDecodeError as e:
                raise HTTPException(status_code=400, detail=f"Invalid options JSON: {e}")
            if not isinstance(user_options, dict):
                raise HTTPException(status_code=400, detail="options must be a JSON object")
            process_options.update(user_options)

        # Keep only the base name of the client-supplied filename so it cannot
        # contain directory components
        safe_name = Path(file.filename or "").name or "upload"
        timestamp = int(time.time() * 1000)
        input_path = UPLOAD_DIR / f"{timestamp}_{safe_name}"
        input_path.write_bytes(await file.read())

        results = await run_in_thread(process_audio_file, str(input_path), process_options)

        if not results:
            raise HTTPException(status_code=500, detail="Processing failed")

        # The result is always encoded as 16-bit PCM WAV, so give it a .wav name;
        # keeping e.g. '.mp3' would make torchaudio pick the wrong container and
        # mislabel the download.
        output_name = f"processed_{timestamp}_{Path(safe_name).stem}.wav"
        output_path = OUTPUT_DIR / output_name
        _, waveform, sr = results[0]

        torchaudio.save(
            str(output_path),
            waveform.cpu(),
            sr,
            format="wav",
            encoding="PCM_S",
            bits_per_sample=16
        )

        # Tasks only run if attached to the response; they execute after the
        # file has been streamed to the client.
        cleanup = BackgroundTasks()
        cleanup.add_task(_remove_files, output_path)
        response = FileResponse(
            path=output_path,
            filename=output_name,
            media_type='audio/wav',
            background=cleanup
        )
        handed_off = True
        return response

    except HTTPException:
        # Already carries the right status code / detail
        raise

    except Exception as e:
        log.error(f"Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

    finally:
        # The upload is no longer needed once processing has finished or failed.
        # The output is deleted here only if it was never handed to a response.
        _remove_files(input_path, None if handed_off else output_path)



In [None]:
from pyngrok import ngrok
import nest_asyncio
import uvicorn

PORT = 5001
NGROK_DOMAIN = "singular-roughy-humane.ngrok-free.app"

# Start ngrok (auth token comes from Colab secrets, never hard-coded)
ngrok.set_auth_token(userdata.get('NGROK_SECRET'))

# Close tunnels left over from a previous run of this cell so the reserved
# domain is free again; otherwise re-running the cell fails on connect.
for stale_tunnel in ngrok.get_tunnels():
    ngrok.disconnect(stale_tunnel.public_url)

ngrok_tunnel = ngrok.connect(PORT, domain=NGROK_DOMAIN)
print('Public URL:', ngrok_tunnel.public_url)

# Let uvicorn start its own loop inside the notebook's already-running event loop
nest_asyncio.apply()
try:
    uvicorn.run(app, port=PORT, host='0.0.0.0', workers=1)
finally:
    # Release the tunnel when the server stops (e.g. the cell is interrupted)
    ngrok.disconnect(ngrok_tunnel.public_url)