# AUDIO/VISUAL SPEECH RECOGNITION

**Note** This tutorial requires `mediapipe` or `retinaface` detector. Please refer to [preparation](../preparation#setup) for installation.

**Note** To run this tutorial, please make sure you are in tutorials folder.

In [1]:
import sys
sys.path.insert(0, "../")

In [2]:
import os
import torch
import torchaudio
import torchvision

## 1. Build an inference pipeline

The InferencePipeline carries out the following three steps:

1. Load audio or video data
2. Run pre-processing functions
3. Run inference

In [3]:
import os
from lightning import ModelModule
from datamodule.transforms import AudioTransform, VideoTransform
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" # the device to load the model onto

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import argparse
parser = argparse.ArgumentParser()
args, _ = parser.parse_known_args(args=[])

In [5]:
class InferencePipeline(torch.nn.Module):
    def __init__(self, args, ckpt_path,lm_path, detector="retinaface"):
        super(InferencePipeline, self).__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            lm_path,
            torch_dtype="auto",
            device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(lm_path)
        self.modality = args.modality
        if self.modality == "audio":
            self.audio_transform = AudioTransform(subset="test")
        elif self.modality == "video":
            if detector == "mediapipe":
                from preparation.detectors.mediapipe.detector import LandmarksDetector
                from preparation.detectors.mediapipe.video_process import VideoProcess
                self.landmarks_detector = LandmarksDetector()
                self.video_process = VideoProcess(convert_gray=False)
            elif detector == "retinaface":
                from preparation.detectors.retinaface.detector import LandmarksDetector
                from preparation.detectors.retinaface.video_process import VideoProcess
                self.landmarks_detector = LandmarksDetector(device="cuda:0")
                self.video_process = VideoProcess(convert_gray=False)
            self.video_transform = VideoTransform(subset="test")

        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
        self.modelmodule = ModelModule(args)
        self.modelmodule.model.load_state_dict(ckpt)
        self.modelmodule.eval()

    def load_video(self, data_filename):
        return torchvision.io.read_video(data_filename, pts_unit="sec")[0].numpy()

    def forward(self, data_filename):
        data_filename = os.path.abspath(data_filename)
        assert os.path.isfile(data_filename), f"data_filename: {data_filename} does not exist."


        if self.modality == "video":
            video = self.load_video(data_filename)
            landmarks = self.landmarks_detector(video)
            video = self.video_process(video, landmarks)
            video = torch.tensor(video)
            video = video.permute((0, 3, 1, 2))
            video = self.video_transform(video)
            with torch.no_grad():
                transcript = self.modelmodule(video)
            prompt = f"通过语境将语音识别的拼音识别成通顺的中文，并加上标点：{transcript.lower()}"
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ]
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            model_inputs = self.tokenizer([text], return_tensors="pt").to(device)
            
            generated_ids = self.model.generate(
                model_inputs.input_ids,
                max_new_tokens=512
            )
            generated_ids = [
                output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]
            
            response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return response

    def load_audio(self, data_filename):
        waveform, sample_rate = torchaudio.load(data_filename, normalize=True)
        return waveform, sample_rate

    def load_video(self, data_filename):
        return torchvision.io.read_video(data_filename, pts_unit="sec")[0].numpy()

    def audio_process(self, waveform, sample_rate, target_sample_rate=16000):
        if sample_rate != target_sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, target_sample_rate
            )
        waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform

In [6]:
data_filename = "/home/agent_h/data/datasets/autoavsr_bili_splitted_sampled/lrs3/lrs3_video_seg16s/test/202308071746_bilibili_xiaochanwan_008_000085/202308071746_bilibili_xiaochanwan_008_000085.mp4"
setattr(args, 'modality', 'video')
model_path = "/home/agent_h/data/ckpts/0204_0217_autoavsr_cnv2/epoch=65.pth"
lm_path = "/home/agent_h/data/llms/Qwen2.5-72B-Instruct-AWQ"
pipeline = InferencePipeline(args, model_path,lm_path, detector="retinaface")

We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 11/11 [00:07<00:00,  1.50it/s]


In [8]:
transcript = pipeline(data_filename)
print(transcript)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


尊重自己，比大而有爱或得人的声好。那今天这期视频到这就结束了，大家看完不要忘记一键三连感谢，这样的话不要忘记关注。大家拜拜。
