# ASR Demo Notebook

This notebook demonstrates how to use the DeepSpeech2-based ASR (Automatic Speech Recognition) model from this repository.

**Features:**
- Clone repository and install dependencies
- Download pre-trained model weights
- Run inference on custom data
- Calculate WER/CER metrics
- Transcribe a single audio file interactively (greedy and beam-search decoding)

---

## 1. Setup


In [None]:
# Clone the repository (replace with your actual repo URL)
!git clone https://github.com/YOUR_USERNAME/asr-project.git
%cd asr-project


In [None]:
# Install required packages
%pip install -q -r requirements.txt


In [3]:
import os

# Create directory for saved models
os.makedirs("saved", exist_ok=True)

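# Download model weights from Google Drive
# To use your own checkpoint, set MODEL_GDRIVE_ID to its Google Drive file ID
MODEL_GDRIVE_ID = "1j9e24ERS2cPsPC1zqabkJjH3E-72LcuW"  # <-- REPLACE THIS if needed

# Quote the URL: unquoted, zsh treats "?" as a glob pattern and fails with "no matches found"
!gdown "https://drive.google.com/uc?id={MODEL_GDRIVE_ID}" -O saved/model_best.pth

# Report success only if the checkpoint file actually exists
if os.path.exists("saved/model_best.pth"):
    print("Model downloaded successfully!")
else:
    print("ERROR: saved/model_best.pth not found; check the download output above.")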

zsh:1: no matches found: https://drive.google.com/uc?id=1j9e24ERS2cPsPC1zqabkJjH3E-72LcuW
Model downloaded successfully!

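Even when `gdown` exits without a shell error, a blocked Google Drive download (for example, a file that is not shared publicly) can leave behind an HTML page instead of the checkpoint. A quick look at the file's first bytes catches this before `torch.load` fails with a confusing error. This is a minimal sketch: `check_checkpoint` is a hypothetical helper, and it assumes the checkpoint was written by `torch.save`, which in recent PyTorch versions produces a zip archive (magic bytes `PK\x03\x04`), while the legacy format is a raw pickle starting with the pickle protocol opcode `\x80`.

```python
from pathlib import Path

ZIP_MAGIC = b"PK\x03\x04"  # zip-based format written by torch.save in recent PyTorch
PICKLE_PROTO = b"\x80"     # legacy torch.save output starts with a pickle PROTO opcode


def check_checkpoint(path: str) -> str:
    """Return a short verdict on whether `path` looks like a PyTorch checkpoint (hypothetical helper)."""
    p = Path(path)
    if not p.is_file():
        return "missing"
    with p.open("rb") as f:
        head = f.read(512)  # the first bytes are enough to identify the format
    if head.startswith(ZIP_MAGIC):
        return "zip checkpoint"
    if head.startswith(PICKLE_PROTO):
        return "legacy pickle checkpoint"
    if head.lstrip().lower().startswith((b"<!doctype html", b"<html")):
        return "html page (download was blocked or failed)"
    return "unknown format"


print(check_checkpoint("saved/model_best.pth"))
```

Anything other than one of the two checkpoint verdicts means the download should be repeated before moving on.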

## 2. Run Inference on Custom Data

### 2.1 Mount Google Drive and set data path

Prepare your data in the following format:
```
your_data/
├── audio/
│   ├── file1.wav
│   ├── file2.wav
│   └── ...
└── transcriptions/  (optional)
    ├── file1.txt
    ├── file2.txt
    └── ...
```
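Predictions appear to be matched to ground truth by file name (the metrics output in Section 3 scores `audio_0` against its own transcription), so every audio file should have a `.txt` file with the same stem. A minimal sketch to check the pairing before running inference, assuming the `audio/` and `transcriptions/` layout above; `check_pairs` and the set of accepted extensions are hypothetical, not part of the repository.

```python
from pathlib import Path

AUDIO_EXTS = {".wav", ".flac", ".mp3"}  # assumption: formats your torchaudio backend can read


def check_pairs(data_dir: str) -> dict:
    """Compare audio and transcription file stems under data_dir (hypothetical helper)."""
    root = Path(data_dir)
    # Raises FileNotFoundError if the audio/ folder does not exist
    audio = {p.stem for p in (root / "audio").iterdir() if p.suffix.lower() in AUDIO_EXTS}
    trans_dir = root / "transcriptions"
    texts = {p.stem for p in trans_dir.glob("*.txt")} if trans_dir.is_dir() else set()
    return {
        "audio": len(audio),
        "missing_transcription": sorted(audio - texts),  # audio without ground truth
        "orphan_transcription": sorted(texts - audio),   # ground truth without audio
    }
```

For example, `check_pairs("data/custom_data")` should report empty lists for both mismatch keys when the folder is ready for metric calculation.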


In [1]:
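# Mount Google Drive (Colab only; skipped when running locally)
try:
    from google.colab import drive
    drive.mount('/content/drive')
except ImportError:
    print("Not running in Colab: skipping Google Drive mount.")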


ModuleNotFoundError: No module named 'google.colab'

In [7]:
# Set path to your custom data
# Example: /content/drive/MyDrive/HSE_ASR_project/custom_data
# CUSTOM_DATA_PATH = "/content/drive/MyDrive/YOUR_DATA_FOLDER"  # <-- CHANGE THIS
CUSTOM_DATA_PATH = 'data/custom_data/'

print(f"Data path: {CUSTOM_DATA_PATH}")
if os.path.exists(os.path.join(CUSTOM_DATA_PATH, 'audio')):
    audio_files = os.listdir(os.path.join(CUSTOM_DATA_PATH, 'audio'))[:5]
    print(f"Audio files found: {audio_files}...")
else:
    print("ERROR: audio/ folder not found!")


Data path: data/custom_data/
Audio files found: ['audio_4.wav', 'audio_5.wav', 'audio_7.wav', 'audio_6.wav', 'audio_2.wav']...

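The interactive `transcribe` function in Section 4 resamples input to 16 kHz, which suggests the model expects 16 kHz audio. Whether `inference.py` resamples on its own is not shown here, so checking each file's sample rate and channel count up front can save a confusing run. A sketch using only Python's standard `wave` module; `describe_wavs` is a hypothetical helper, it reads PCM WAV only (other encodings raise `wave.Error`), and it flags anything that is not mono at the expected rate.

```python
import wave
from pathlib import Path


def describe_wavs(audio_dir: str, expected_sr: int = 16000) -> list:
    """Return (name, sample_rate, channels, seconds, ok) for each .wav in audio_dir."""
    rows = []
    for path in sorted(Path(audio_dir).glob("*.wav")):
        with wave.open(str(path), "rb") as w:
            sr = w.getframerate()
            ch = w.getnchannels()
            seconds = w.getnframes() / sr
        # ok is True only for mono files at the expected sample rate
        rows.append((path.name, sr, ch, round(seconds, 2), sr == expected_sr and ch == 1))
    return rows
```

Usage: `for row in describe_wavs(os.path.join(CUSTOM_DATA_PATH, "audio")): print(row)`.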

In [10]:
# Run inference on custom data
!python inference.py -cn=asr_inference_custom \
    "datasets.inference.data_dir={CUSTOM_DATA_PATH}" \
    inferencer.from_pretrained=saved/model_best.pth \
    inferencer.save_path=custom_results


DeepSpeech2(
  (conv_layers): ModuleList(
    (0): MaskConv2d(
      (conv): Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5))
    )
    (1): MaskConv2d(
      (conv): Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5))
    )
  )
  (conv_bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_activation): Hardtanh(min_val=0, max_val=20, inplace=True)
  (rnn_layers): ModuleList(
    (0): BatchRNN(
      (rnn): GRU(640, 256, batch_first=True, bidirectional=True)
    )
    (1-2): 2 x BatchRNN(
      (bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (rnn): GRU(512, 256, batch_first=True, bidirectional=True)
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (fc): Linear(in_features=512, out_features=29, bias=True)
)
All parameters: 4,012,797
Trainable parameters: 4,012,797
Loading model weights from: saved/model_best.pth ...
inference: 100%|████████████████████████████████

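The printed architecture can be cross-checked by hand. With the standard `Conv2d` output-size formula, `out = (in + 2*padding - kernel) // stride + 1`, 80 mel bins shrink to 40 and then 20 along the frequency axis, so the GRU input is 32 channels × 20 = 640, consistent with the `n_feats=80` used in Section 4; the time axis is halved by the first convolution only. The parameter count follows from standard PyTorch layer shapes (per direction, a GRU holds 3 gates × (input + hidden) × hidden weights plus two bias vectors of size 3 × hidden; BatchNorm contributes weight and bias). A pure-Python sketch that reproduces the reported total; the function names are illustrative, and the (batch, channel, frequency, time) input layout is inferred from the 640-dimensional GRU input.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a Conv2d along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_params(c_in: int, c_out: int, kh: int, kw: int) -> int:
    return c_in * c_out * kh * kw + c_out  # weights + bias


def gru_params(n_in: int, hidden: int, bidirectional: bool = True) -> int:
    per_dir = 3 * hidden * (n_in + hidden) + 2 * 3 * hidden  # W_ih, W_hh, b_ih, b_hh
    return per_dir * (2 if bidirectional else 1)


def time_frames(n_frames: int) -> int:
    """Spectrogram frames -> frames seen by the RNN (time stride 2, then 1)."""
    return conv_out(conv_out(n_frames, 11, 2, 5), 11, 1, 5)


# Frequency axis: 80 mel bins through the two conv layers printed above
freq = conv_out(conv_out(80, 41, 2, 20), 21, 2, 10)
rnn_input = 32 * freq

total = (
    conv_params(1, 32, 41, 11)
    + conv_params(32, 32, 21, 11)
    + 2 * 32                                # BatchNorm2d(32): weight + bias
    + gru_params(rnn_input, 256)            # first BatchRNN (no BatchNorm)
    + 2 * (2 * 512 + gru_params(512, 256))  # two BatchRNNs with BatchNorm1d(512)
    + 512 * 29 + 29                         # Linear(512, 29)
)
print(freq, rnn_input, total)  # prints: 20 640 4012797
```

The total matches the `All parameters: 4,012,797` line above.
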
## 3. View Predictions and Calculate Metrics


In [11]:
from pathlib import Path

# Show predictions
predictions_dir = Path("data/saved/custom_results/inference")

if predictions_dir.exists():
    pred_files = list(predictions_dir.glob("*.txt"))[:5]
    
    print("Sample predictions:\n")
    for pred_file in pred_files:
        prediction = pred_file.read_text(encoding="utf-8").strip()
        print(f"[{pred_file.stem}]")
        print(f"  {prediction}\n")
else:
    raise FileNotFoundError(f"No predictions found in {predictions_dir}. Please run inference first.")


Sample predictions:

[audio_7]
  nraseoer gol

[audio_6]
  norier qapcer qipool

[audio_4]
  krir'ar a ar aalrk

[audio_5]
  k f ar ahogw ol

[audio_1]
  nral



In [12]:
# Calculate metrics (if ground truth transcriptions are available)
GROUND_TRUTH_PATH = os.path.join(CUSTOM_DATA_PATH, "transcriptions")  # avoids a double slash
PREDICTIONS_PATH = "data/saved/custom_results/inference"

!python calc_metrics.py \
    --predictions "{PREDICTIONS_PATH}" \
    --ground_truth "{GROUND_TRUTH_PATH}" \
    --verbose


Loading predictions from: data/saved/custom_results/inference
Loading ground truth from: data/custom_data/transcriptions

Found 10 predictions
Found 10 ground truth files

[audio_0]
  GT:   жизнь других
  Pred: ga fcer gopck
  WER: 150.00%, CER: 108.33%

[audio_1]
  GT:   элтон джонс
  Pred: nral
  WER: 100.00%, CER: 100.00%

[audio_2]
  GT:   побег из шоушенко
  Pred: nherhaer f f ak gol
  WER: 166.67%, CER: 105.88%

[audio_3]
  GT:   дайан вормик
  Pred: nhraokqporan h ael
  WER: 150.00%, CER: 150.00%

[audio_4]
  GT:   мухаммед али
  Pred: krir'ar a ar aalrk
  WER: 200.00%, CER: 141.67%

[audio_5]
  GT:   дневник памяти
  Pred: k f ar ahogw ol
  WER: 250.00%, CER: 107.14%

[audio_6]
  GT:   рэйф файнс
  Pred: norier qapcer qipool
  WER: 150.00%, CER: 190.00%

[audio_7]
  GT:   нефть
  Pred: nraseoer gol
  WER: 200.00%, CER: 240.00%

[audio_8]
  GT:   золотая лихорадка
  Pred: nier ar acanh aer
  WER: 200.00%, CER: 94.12%

[audio_9]
  GT:   служебный роман
  Pred: moil'pololqller qxl

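WER and CER are edit distances normalized by the reference length, so insertions can push them above 100%: for `audio_5`, all 2 reference words are substituted and 3 more are inserted, giving 5 / 2 = 250%. Note also that every prediction above is written in Latin letters while the references are Russian (Cyrillic), so no character can ever match; the 29 output units of the printed `Linear(512, 29)` layer are consistent with an English character set plus the CTC blank, which likely explains the uniformly high error rates. A minimal sketch of the standard Levenshtein-based definitions (not necessarily how `calc_metrics.py` normalizes text), which reproduces the `audio_5` and `audio_1` scores above:

```python
def edit_distance(ref: list, hyp: list) -> int:
    """Levenshtein distance (substitution, insertion, deletion all cost 1)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(prev[j] + 1,             # deletion
                         cur[j - 1] + 1,          # insertion
                         prev[j - 1] + (r != h))  # substitution (or match)
        prev = cur
    return prev[-1]


def wer(ref: str, hyp: str) -> float:
    words = ref.split()
    return edit_distance(words, hyp.split()) / len(words)  # assumes a non-empty reference


def cer(ref: str, hyp: str) -> float:
    return edit_distance(list(ref), list(hyp)) / len(ref)


# audio_5 from the output above
ref, hyp = "дневник памяти", "k f ar ahogw ol"
print(f"WER: {wer(ref, hyp):.2%}, CER: {cer(ref, hyp):.2%}")  # WER: 250.00%, CER: 107.14%
```
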
## 4. Interactive Demo - Transcribe a Single Audio File


In [None]:
import torch
import torchaudio
from IPython.display import Audio, display
import sys
sys.path.insert(0, '.')

from src.model import DeepSpeech2
from src.text_encoder import TextEncoder
from src.transforms import MelSpectrogram

# Select device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Initialize components
text_encoder = TextEncoder()
mel_spec = MelSpectrogram()

# Load model
model = DeepSpeech2(
    n_feats=80,
    n_tokens=len(text_encoder),
).to(device)

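# The checkpoint also stores an omegaconf DictConfig, which PyTorch >= 2.6 rejects by default
# (weights_only=True). Disable it only for checkpoints from a trusted source.
checkpoint = torch.load("saved/model_best.pth", map_location=device, weights_only=False)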
model.load_state_dict(checkpoint["state_dict"])
model.eval()

print("Model loaded successfully!")


Using device: cpu


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL omegaconf.dictconfig.DictConfig was not an allowed global by default. Please use `torch.serialization.add_safe_globals([omegaconf.dictconfig.DictConfig])` or the `torch.serialization.safe_globals([omegaconf.dictconfig.DictConfig])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

In [None]:
def transcribe(audio_path: str, use_beam_search: bool = False, beam_size: int = 10):
    """Transcribe an audio file."""
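    # Load audio: torchaudio returns a (channels, samples) tensor
    audio, sr = torchaudio.load(audio_path)
    audio = audio.mean(dim=0)  # downmix to mono (identical to squeeze(0) for mono files)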
    
    # Resample if needed
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(sr, 16000)
        audio = resampler(audio)
    
    # Compute spectrogram
    spectrogram = mel_spec(audio)
    spectrogram = spectrogram.unsqueeze(0).to(device)
    spec_length = torch.tensor([spectrogram.shape[-1]]).to(device)
    
    # Run model
    with torch.no_grad():
        output = model(spectrogram=spectrogram, spectrogram_length=spec_length)
    
    log_probs = output["log_probs"]
    log_probs_length = output["log_probs_length"]
    
    # Decode
    if use_beam_search:
        return text_encoder.ctc_beam_search(log_probs, log_probs_length, beam_size=beam_size)[0]
    return text_encoder.ctc_decode(log_probs, log_probs_length)[0]

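`transcribe` delegates decoding to the repository's `TextEncoder.ctc_decode` and `TextEncoder.ctc_beam_search`, whose internals are not shown here. As a sketch of the standard techniques (not necessarily the repository's implementation): greedy CTC decoding takes the arg-max token per frame, merges repeats, and drops blanks, while CTC prefix beam search keeps the `beam_size` most probable output prefixes, summing over all alignments that produce each one. The code assumes per-frame probabilities as plain lists (the model returns `log_probs`, so exponentiate first), the blank at index 0, and a toy vocabulary; `VOCAB` and both function names are hypothetical.

```python
from collections import defaultdict

VOCAB = ["", "a", "b"]  # hypothetical toy vocabulary; index 0 is the CTC blank
BLANK = 0


def ctc_greedy(probs):
    """probs: list of per-frame probability lists. Arg-max, merge repeats, drop blanks."""
    best = [max(range(len(frame)), key=frame.__getitem__) for frame in probs]
    out, prev = [], None
    for idx in best:
        if idx != prev and idx != BLANK:
            out.append(VOCAB[idx])
        prev = idx
    return "".join(out)


def ctc_prefix_beam_search(probs, beam_size=10):
    """Keep the beam_size best prefixes; each tracks P(ends in blank), P(ends in non-blank)."""
    beams = {(): (1.0, 0.0)}  # prefix -> (p_blank, p_non_blank)
    for frame in probs:
        nxt = defaultdict(lambda: [0.0, 0.0])
        for prefix, (p_b, p_nb) in beams.items():
            for idx, p in enumerate(frame):
                if idx == BLANK:
                    nxt[prefix][0] += (p_b + p_nb) * p  # blank keeps the prefix unchanged
                    continue
                new_prefix = prefix + (idx,)
                if prefix and prefix[-1] == idx:
                    nxt[prefix][1] += p_nb * p     # repeat without blank collapses
                    nxt[new_prefix][1] += p_b * p  # repeat after blank is a new token
                else:
                    nxt[new_prefix][1] += (p_b + p_nb) * p
        ranked = sorted(nxt.items(), key=lambda kv: sum(kv[1]), reverse=True)
        beams = {k: tuple(v) for k, v in ranked[:beam_size]}
    best = max(beams.items(), key=lambda kv: sum(kv[1]))[0]
    return "".join(VOCAB[i] for i in best)


probs = [[0.6, 0.4, 0.0], [0.6, 0.4, 0.0]]
print(repr(ctc_greedy(probs)), repr(ctc_prefix_beam_search(probs, beam_size=3)))  # '' 'a'
```

In this two-frame example the single best path is blank, blank (0.36), so greedy returns an empty string, but the three paths that collapse to "a" sum to 0.64, which beam search finds. Practical decoders work in log space to avoid underflow on long utterances.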

In [None]:
# Upload an audio file and transcribe (the upload widget works in Colab only)
try:
    from google.colab import files
    print("Upload an audio file (wav, flac, or mp3):")
    audio_paths = list(files.upload().keys())
except ImportError:
    # Running locally: list the files to transcribe by hand
    audio_paths = [os.path.join(CUSTOM_DATA_PATH, "audio", "audio_0.wav")]

for filename in audio_paths:
    print(f"\nProcessing: {filename}")
    display(Audio(filename))
    
    print("\nTranscription (Greedy):")
    print(f"  {transcribe(filename, use_beam_search=False)}")
    
    print("\nTranscription (Beam Search):")
    print(f"  {transcribe(filename, use_beam_search=True, beam_size=10)}")
