In [1]:
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
# Cell 1 - Core imports and logging setup
from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
import torchaudio
import yaml

from src import commons
from src.datasets.base_dataset import APPLY_NORMALIZATION, apply_preprocessing
from src.models import models


# Notebook-wide logger. Logging is configured at INFO level only when this
# named logger has no handlers attached yet.
LOGGER = logging.getLogger("inference_notebook")
if len(LOGGER.handlers) == 0:
    logging.basicConfig(level=logging.INFO)





In [2]:

# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
# Cell 2 - Default paths, device selection, and path-resolution helper
# NOTE(review): both defaults are absolute paths on one developer's machine;
# anyone else running the notebook has to edit them before executing.
DEFAULT_CONFIG_PATH = Path(
    "/Users/ahmedgamal/Downloads/deepfake-whisper-features-1/configs/training/whisper_specrnet.yaml"
)
DEFAULT_WEIGHTS_PATH = Path(
    "/Users/ahmedgamal/Downloads/whisper_specrnet/weights.pth"
)
# Alternative checkpoint location (currently disabled):
#   trained_models/whisper_specrnet/best_model.pth

# Use the GPU when torch reports CUDA as available, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


def _resolve_path(path: Union[str, Path, None]) -> Optional[Path]:
    """Safely convert incoming path-like objects into `Path` instances."""

    if path is None:
        return None

    path_obj = Path(path)
    if str(path_obj).strip() == "":
        return None
    return path_obj


In [3]:


# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
# Cell 3 - Configuration loading utilities
def load_inference_config(config_path: Union[str, Path]) -> Dict[str, Any]:
    """Load the YAML configuration that defines the inference model.

    As a side effect, the random seed is fixed via `commons.set_seed` using
    `data.seed` from the configuration (42 when the key is absent or null).

    Args:
        config_path: Location of the YAML configuration file.

    Returns:
        The parsed configuration mapping.

    Raises:
        FileNotFoundError: If `config_path` does not exist.
        ValueError: If the file does not contain a YAML mapping at the top
            level (for example, when the file is empty).
    """

    config_path = Path(config_path)
    if not config_path.exists():
        raise FileNotFoundError(
            f"Configuration file not found at '{config_path.resolve()}'"
        )

    LOGGER.info("Loading inference configuration from %s", config_path)
    with config_path.open("r", encoding="utf-8") as file:
        loaded = yaml.safe_load(file)

    # `yaml.safe_load` yields None for an empty document and may yield a list
    # or scalar; fail with a clear message instead of an AttributeError later.
    if not isinstance(loaded, dict):
        raise ValueError(
            f"Configuration file '{config_path}' must contain a YAML mapping, "
            f"got {type(loaded).__name__}"
        )
    config: Dict[str, Any] = loaded

    # A `data:` key with no value parses as None, so fall back to an empty
    # mapping before looking up the seed.
    data_section = config.get("data") or {}
    seed = data_section.get("seed")
    if seed is None:
        seed = 42
    commons.set_seed(seed)
    LOGGER.info("Random seed fixed at %s", seed)

    return config

In [4]:



# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
# Cell 4 - Model creation and weight restoration
def build_model(
    model_config: Dict[str, Any],
    weights_path: Optional[Union[str, Path]] = None,
    device: str = DEVICE,
) -> torch.nn.Module:
    """Create the configured network, restore weights if available, and return it in eval mode.

    Args:
        model_config: Mapping with a required "name" entry and an optional
            "parameters" entry; both are forwarded to `models.get_model`.
        weights_path: Optional location of a saved state dict. When it is
            missing or does not exist, a warning is logged and the model keeps
            its freshly initialised weights.
        device: Device string the model is created on and moved to; also used
            as `map_location` when loading the weights.

    Returns:
        The model after `eval()` has been called on it.
    """

    architecture = model_config["name"]
    hyperparameters = model_config.get("parameters", {})

    LOGGER.info("Creating model '%s' on %s", architecture, device)
    network = models.get_model(
        model_name=architecture, config=hyperparameters, device=device
    ).to(device)

    checkpoint_path = _resolve_path(weights_path)
    if checkpoint_path is None or not checkpoint_path.exists():
        LOGGER.warning(
            "Weights file not found. Inference will use randomly initialized weights."
        )
    else:
        LOGGER.info("Loading model weights from %s", checkpoint_path)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources (consider `weights_only=True` if compatible).
        restored_state = torch.load(checkpoint_path, map_location=device)
        network.load_state_dict(restored_state)

    network.eval()
    return network

# Load the default configuration once and build the shared inference model.
INFERENCE_CONFIG = load_inference_config(DEFAULT_CONFIG_PATH)
INFERENCE_MODEL = build_model(
    model_config=INFERENCE_CONFIG["model"],
    # A `checkpoint:` key that is present but empty parses as None, so guard
    # with `or {}` before reading "path"; an unset/empty path falls back to
    # DEFAULT_WEIGHTS_PATH.
    weights_path=(INFERENCE_CONFIG.get("checkpoint") or {}).get("path")
    or DEFAULT_WEIGHTS_PATH,
    device=DEVICE,
)


INFO:inference_notebook:Loading inference configuration from /Users/ahmedgamal/Downloads/deepfake-whisper-features-1/configs/training/whisper_specrnet.yaml
INFO:inference_notebook:Random seed fixed at 42
INFO:inference_notebook:Creating model 'whisper_specrnet' on cpu
  checkpoint = torch.load(WHISPER_MODEL_WEIGHTS_PATH)
  from .autonotebook import tqdm as notebook_tqdm
INFO:inference_notebook:Loading model weights from /Users/ahmedgamal/Downloads/whisper_specrnet/weights.pth
  state_dict = torch.load(weights_file, map_location=device)


In [None]:
# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
# Cell 5 - Audio loading and preprocessing helpers
def load_waveform(audio_path: Union[str, Path]) -> torch.Tensor:
    """Read an audio file and run it through the dataset preprocessing step.

    Args:
        audio_path: Location of the audio file on disk.

    Returns:
        The preprocessed waveform cast to float.

    Raises:
        FileNotFoundError: If `audio_path` does not exist.
    """

    source = Path(audio_path)
    if not source.exists():
        raise FileNotFoundError(f"Audio file not found at '{source.resolve()}'")

    # Normalisation on load follows the flag exported by the dataset module.
    raw_audio, native_rate = torchaudio.load(
        str(source), normalize=APPLY_NORMALIZATION
    )
    # Only the waveform is needed; the second returned value is discarded.
    processed, _ = apply_preprocessing(raw_audio, native_rate)

    return processed.float()


def prepare_batch(waveform: torch.Tensor, device: str = DEVICE) -> torch.Tensor:
    """Add a leading dimension to 1-D input and move the tensor to `device`.

    Tensors with any other number of dimensions are moved to `device` as-is.
    """

    batch = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform
    return batch.to(device)


In [None]:

# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
# Cell 6 - Inference helpers and probability conversion
def predict_probability(
    model: torch.nn.Module,
    batch_waveform: torch.Tensor,
) -> float:
    """Run one forward pass and return the sigmoid of the model output as a float.

    The model is switched to eval mode and evaluated without gradient
    tracking. The result must contain exactly one element after squeezing
    dimension 1, because it is converted with `.item()`.

    NOTE(review): this notebook treats the value as the bonafide (real)
    probability - confirm that matches the label convention used in training.
    """

    model.eval()
    with torch.no_grad():
        raw_scores = model(batch_waveform)
        bonafide_probability = torch.sigmoid(raw_scores.squeeze(1))

    scalar = bonafide_probability.detach().cpu().item()
    return float(scalar)


def probability_to_label(probability: float, threshold: float = 0.5) -> str:
    """Map a bonafide probability to "real" or "fake".

    Returns "real" when `probability` is greater than or equal to
    `threshold`, and "fake" in every other case.
    """

    if probability >= threshold:
        return "real"
    return "fake"



In [None]:




# -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
# Cell 7 - Public API to classify external audio files
def classify_audio_file(
    audio_path: Union[str, Path],
    *,
    model: torch.nn.Module = INFERENCE_MODEL,
    device: str = DEVICE,
    decision_threshold: float = 0.5,
) -> str:
    """Label an audio file as "real" (bonafide) or "fake" (spoof).

    Args:
        audio_path: Audio file to classify.
        model: Network used for scoring; defaults to the notebook-level model.
        device: Device the input batch is moved to before the forward pass.
        decision_threshold: Minimum score that is labelled "real".

    Returns:
        The label string produced by `probability_to_label`.
    """

    batch = prepare_batch(load_waveform(audio_path), device=device)
    score = predict_probability(model, batch)
    verdict = probability_to_label(score, threshold=decision_threshold)

    LOGGER.info(
        "File '%s' classified as %s with probability %.3f", audio_path, verdict, score
    )
    return verdict


# Names this notebook exposes as its public helpers.
__all__ = [
    # end-to-end entry point
    "classify_audio_file",
    # configuration and model construction
    "load_inference_config", "build_model",
    # audio loading, scoring and labelling
    "load_waveform", "predict_probability", "probability_to_label",
    # batching
    "prepare_batch",
]