# AUDIO/VISUAL SPEECH RECOGNITION

**Note:** This tutorial requires either the `mediapipe` or the `retinaface` detector. Please refer to [preparation](../preparation#setup) for installation instructions.

**Note:** To run this tutorial, make sure your working directory is the `tutorials` folder.

In [1]:
import sys
# Make the parent directory importable so that project packages such as
# `lightning`, `datamodule` and `preparation` resolve. "../" is resolved
# against the current working directory, which is presumably why the notebook
# must be started from the tutorials folder.
sys.path.insert(0, "../")

In [2]:
import os
import torch
import torchaudio
import torchvision

## 1. Build an inference pipeline

The `InferencePipeline` carries out the following three steps:

1. Load audio or video data
2. Run pre-processing functions
3. Run inference

In [3]:
import os
from lightning import ModelModule
from datamodule.transforms import AudioTransform, VideoTransform

In [4]:
import argparse

# Build an empty argument namespace for the notebook. No flags are registered
# and an explicit empty argv is parsed, so the Jupyter kernel's own sys.argv
# is never consulted. Later cells attach attributes such as `modality` to it.
parser = argparse.ArgumentParser()
args, _ = parser.parse_known_args([])

In [5]:
class InferencePipeline(torch.nn.Module):
    """Load a media file, pre-process it, and run it through a ModelModule.

    Args:
        args: Namespace-like object. ``args.modality`` selects the branch
            (``"audio"`` or ``"video"``); the whole object is also passed to
            ``ModelModule``.
        ckpt_path: Path to a checkpoint that is loaded onto the CPU and passed
            to ``self.modelmodule.model.load_state_dict``.
        detector: Landmarks detector used for video input, ``"mediapipe"`` or
            ``"retinaface"``. Ignored when the modality is ``"audio"``.
        retinaface_device: Device string handed to the RetinaFace
            ``LandmarksDetector``; defaults to ``"cuda:0"``.

    Raises:
        ValueError: If ``args.modality`` or ``detector`` is not supported.
    """

    def __init__(self, args, ckpt_path, detector="retinaface", retinaface_device="cuda:0"):
        super().__init__()
        self.modality = args.modality
        if self.modality == "audio":
            self.audio_transform = AudioTransform(subset="test")
        elif self.modality == "video":
            # Detector packages are imported lazily so only the selected one
            # has to be installed.
            if detector == "mediapipe":
                from preparation.detectors.mediapipe.detector import LandmarksDetector
                from preparation.detectors.mediapipe.video_process import VideoProcess
                self.landmarks_detector = LandmarksDetector()
            elif detector == "retinaface":
                from preparation.detectors.retinaface.detector import LandmarksDetector
                from preparation.detectors.retinaface.video_process import VideoProcess
                self.landmarks_detector = LandmarksDetector(device=retinaface_device)
            else:
                # Fail fast instead of erroring later in forward() on a
                # missing landmarks detector.
                raise ValueError(
                    f"Unsupported detector {detector!r}; expected 'mediapipe' or 'retinaface'."
                )
            self.video_process = VideoProcess(convert_gray=False)
            self.video_transform = VideoTransform(subset="test")
        else:
            # forward() only knows how to handle these two modalities.
            raise ValueError(
                f"Unsupported modality {self.modality!r}; expected 'audio' or 'video'."
            )

        # Deserialize every tensor onto the CPU, whatever device it was saved from.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        self.modelmodule = ModelModule(args)
        self.modelmodule.model.load_state_dict(ckpt)
        self.modelmodule.eval()

    def forward(self, data_filename):
        """Transcribe ``data_filename`` with the configured modality.

        Returns:
            Whatever ``self.modelmodule`` returns for the pre-processed input
            (printed as a text transcript in this notebook).

        Raises:
            FileNotFoundError: If ``data_filename`` is not an existing file.
        """
        data_filename = os.path.abspath(data_filename)
        # An explicit raise (unlike `assert`) survives `python -O`.
        if not os.path.isfile(data_filename):
            raise FileNotFoundError(f"data_filename: {data_filename} does not exist.")

        if self.modality == "audio":
            audio, sample_rate = self.load_audio(data_filename)
            audio = self.audio_process(audio, sample_rate)
            # (1, time) -> (time, 1) before the audio transform.
            audio = audio.transpose(1, 0)
            audio = self.audio_transform(audio)
            with torch.no_grad():
                transcript = self.modelmodule(audio)
        elif self.modality == "video":
            video = self.load_video(data_filename)
            landmarks = self.landmarks_detector(video)
            video = self.video_process(video, landmarks)
            video = torch.tensor(video)
            # Assumes video_process keeps torchvision's (T, H, W, C) layout;
            # reorder to (T, C, H, W). NOTE(review): confirm against VideoProcess.
            video = video.permute((0, 3, 1, 2))
            video = self.video_transform(video)
            with torch.no_grad():
                transcript = self.modelmodule(video)
        else:
            # Unreachable for instances built through __init__, which
            # validates the modality.
            raise ValueError(f"Unsupported modality {self.modality!r}.")

        return transcript

    def load_audio(self, data_filename):
        """Return ``(waveform, sample_rate)`` as read by ``torchaudio.load``."""
        waveform, sample_rate = torchaudio.load(data_filename, normalize=True)
        return waveform, sample_rate

    def load_video(self, data_filename):
        """Return the decoded video frames of ``data_filename`` as a NumPy array."""
        return torchvision.io.read_video(data_filename, pts_unit="sec")[0].numpy()

    def audio_process(self, waveform, sample_rate, target_sample_rate=16000):
        """Resample ``waveform`` to ``target_sample_rate`` and downmix to mono.

        Averages over dim 0 (the channel axis of ``torchaudio.load`` output)
        while keeping it, so the result has a single leading channel.
        """
        if sample_rate != target_sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, target_sample_rate
            )
        waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform

## 2. Get a video from the samples folder

In [6]:
# NOTE(review): machine-specific absolute path; change it to wherever the
# sample video lives on your system.
data_filename = "/home/asish/LAALM/samples/video/bbaf2n.mpg"

## 3. VSR inference

### 3.1 Download a pre-trained model (or point to a local copy)

In [7]:
# NOTE(review): machine-specific absolute path to a locally stored VSR
# checkpoint; replace with the location of your own downloaded copy.
model_path = "/home/asish/LAALM/auto_avsr/pretrained_models/vsr_trlrs2lrs3vox2avsp_base.pth"

### 3.2 Initialize VSR pipeline

In [8]:
# Switch the shared args to lip reading and build the pipeline with the
# MediaPipe landmarks detector.
args.modality = "video"
pipeline = InferencePipeline(args, model_path, detector="mediapipe")

I0000 00:00:1766841012.294169 1428160 gl_context_egl.cc:85] Successfully initialized EGL. Major : 1 Minor: 5
I0000 00:00:1766841012.297280 1428335 gl_context.cc:344] GL version: 3.2 (OpenGL ES 3.2 Mesa 25.0.7-0ubuntu0.24.04.2), renderer: Mesa Intel(R) Graphics (RPL-S)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
I0000 00:00:1766841012.301094 1428160 gl_context_egl.cc:85] Successfully initialized EGL. Major : 1 Minor: 5
I0000 00:00:1766841012.303216 1428345 gl_context.cc:344] GL version: 3.2 (OpenGL ES 3.2 Mesa 25.0.7-0ubuntu0.24.04.2), renderer: Mesa Intel(R) Graphics (RPL-S)


### 3.3 Run inference

In [9]:
# Run visual speech recognition on the sample clip and show the prediction.
transcript = pipeline(data_filename)
print(transcript)



BIMBO F2 NOW


## 4. ASR inference

### 4.1 Download a pre-trained model

In [None]:
# Download the ASR checkpoint into the current directory (IPython shell escape).
!wget http://www.doc.ic.ac.uk/~pm4115/autoAVSR/asr_trlrs3_base.pth -O ./asr_trlrs3_base.pth
model_path = "./asr_trlrs3_base.pth"

### 4.2 Initialize ASR pipeline

In [None]:
args.modality = "audio"
# The detector argument is only consulted for video, so the default is kept.
pipeline = InferencePipeline(args, model_path)

### 4.3 Run inference

In [None]:
# NOTE(review): "input.mp4" is a placeholder that is not created anywhere in
# this notebook; point it at a file with an audio track (possibly
# data_filename, if that clip contains audio).
transcript = pipeline("input.mp4")
print(transcript)