# Distributed Multi-Node, Multi-GPU Audio Transcription in ML Container Runtime

In [None]:
# Import python packages
import streamlit as st
import pandas as pd
import torch
# We can also use Snowpark for our analyses!
from typing import Dict
from pathlib import Path
import numpy as np
import shutil
from snowflake.snowpark.context import get_active_session
from snowflake.ml.ray.datasource import SFStageBinaryFileDataSource
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from snowflake.ml.runtime_cluster import scale_cluster, get_nodes
from snowflake.ml.ray.datasink import SnowflakeTableDatasink
import ray
import subprocess
import logging
# Active Snowpark session; later cells read the current database and schema from it.
session = get_active_session()

### Start with one node

In [None]:
# Connect to Ray. ignore_reinit_error=True lets this cell be re-run without
# failing when Ray has already been initialised in this session.
ray.init(ignore_reinit_error=True)
# Count only the nodes Ray currently reports as alive.
num_nodes = sum(1 for node in ray.nodes() if node["Alive"])
print(num_nodes)

### Scale up to 5 nodes

In [None]:
# Asynchronous scaling - function returns immediately after request is accepted
# NOTE(review): because the call does not wait, the extra nodes may not have
# joined yet when later cells run — confirm the node count before relying on it.
scale_cluster(5, is_async=True)

### Control Ray logging

In [None]:
def configure_ray_logger() -> None:
    """Quiet Ray's logging and progress output.

    Raises the "ray" and "ray.data" loggers, and the root logger, to CRITICAL,
    then turns off verbose progress reporting and per-operator progress bars
    on the current Ray Data context.
    """
    # None selects the root logger.
    for logger_name in ("ray", "ray.data", None):
        logging.getLogger(logger_name).setLevel(logging.CRITICAL)

    # Progress output is controlled on Ray Data's context, not through logging.
    data_context = ray.data.DataContext.get_current()
    data_context.execution_options.verbose_progress = False
    data_context.enable_operator_progress_bars = False

# Apply the quiet logging settings defined above for the rest of the notebook.
configure_ray_logger()

In [None]:
# Print the installed ffmpeg version.
# NOTE(review): presumably a check that audio decoding is available to the
# transcription pipeline — confirm.
! ffmpeg -version

In [None]:
# Report how many GPUs the cluster exposes. The "GPU" key may be absent on a
# CPU-only cluster, so fall back to 0 instead of raising KeyError.
print(int(ray.cluster_resources().get("GPU", 0)))

### View the audio files in the Snowflake stage

In [None]:
-- List the files currently in the AUDIO_FILES_STAGE stage.
ls @AUDIO_FILES_STAGE

In [None]:
# Describe the binary files to read: every file matching *.flac in
# @AUDIO_FILES_STAGE, using the session's current database and schema.
audio_source = SFStageBinaryFileDataSource(
    stage_location = "@AUDIO_FILES_STAGE/",
    database = session.get_current_database(),
    schema = session.get_current_schema(),
    file_pattern = "*.flac"
)

# Load audio files into a ray dataset
audio_dataset = ray.data.read_datasource(audio_source)

In [None]:
# Show a single record as a quick check of what was loaded.
audio_dataset.show(1)

In [None]:
# Count the records in the dataset (presumably one per matched stage file — confirm).
audio_dataset.count()

### Configure the Whisper model

In [None]:
# Whisper checkpoint to load, and the batch size shared by the pipeline and map_batches.
model_id = "openai/whisper-large-v3"
batch_size = 30

# Choose device and precision together: half precision on CUDA, full precision on CPU.
is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    device, torch_dtype = torch.device("cuda"), torch.float16
else:
    device, torch_dtype = torch.device("cpu"), torch.float32

print(device, torch_dtype, sep="\n")

### Distributed inference

In [None]:
import pandas as pd
import tempfile
import os

class TranscribeAudioUpdated:
    """Callable batch transformer that transcribes audio with a Whisper ASR pipeline.

    The model is loaded once in ``__init__`` and reused for every batch handed
    to ``__call__``, which is the shape ``Dataset.map_batches`` expects when it
    is given a class.

    Reads the module-level settings ``model_id``, ``torch_dtype``, ``device``
    and ``batch_size``.
    """

    def __init__(self):
        """Load the model and processor and build the speech-recognition pipeline."""
        # initialize model here so that model can be put into correct GPU/node
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
        )
        model.to(device)
        processor = AutoProcessor.from_pretrained(model_id)
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            max_new_tokens=128,
            chunk_length_s=30,
            batch_size=batch_size,
            return_timestamps=True,
            torch_dtype=torch_dtype,
            device=device,
            generate_kwargs={"language": "english"}
        )

    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
        """Transcribe every audio clip in ``batch``.

        Args:
            batch: Frame with a ``file_binary`` column holding the raw bytes
                of one audio file per row.

        Returns:
            The same frame, modified in place: a new ``outputs`` column holds
            the stripped transcription text and ``file_binary`` is dropped.

        Raises:
            RuntimeError: If the pipeline returns a different number of
                predictions than there are rows in ``batch``.
        """
        temp_files = []
        try:
            # Write each binary to a temporary file.
            for binary_content in batch["file_binary"]:
                # Use an appropriate suffix (e.g., .wav or .flac) based on your audio format.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".flac") as tmp_file:
                    # Record the path before writing so the cleanup below still
                    # removes the file if the write fails.
                    temp_files.append(tmp_file.name)
                    tmp_file.write(binary_content)

            # Use the temporary file paths for inference.
            predictions = self.pipe(temp_files)
            # Explicit check rather than assert, which is stripped under -O.
            if len(predictions) != len(batch):
                raise RuntimeError(
                    f"Expected {len(batch)} transcriptions but the pipeline "
                    f"returned {len(predictions)}"
                )
            batch["outputs"] = [
                str(prediction["text"]).strip() for prediction in predictions
            ]
            batch.drop(columns=["file_binary"], inplace=True)
        finally:
            # Clean up temporary files. Removal is best effort: a file that is
            # already gone must not mask an error raised above.
            for file_path in temp_files:
                try:
                    os.remove(file_path)
                except OSError:
                    pass
        return batch

In [None]:
# Apply the transcriber to the dataset in pandas batches of batch_size rows.
# The class itself is passed (not an instance), with concurrency=5 and
# num_gpus=1.
# NOTE(review): 5 matches the node count requested from scale_cluster —
# presumably one GPU worker per node; confirm.
transcribed_ds = audio_dataset.map_batches(TranscribeAudioUpdated,
        batch_size=batch_size,
        batch_format='pandas',
        concurrency=5,
        num_gpus=1,
)

In [None]:
-- Remove any output table left over from a previous run before writing new results.
drop table if exists WHISPER_DEMO_OUTPUT

In [None]:
# Sink targeting WHISPER_DEMO_OUTPUT in the session's current database and schema.
# NOTE(review): auto_create_table=True presumably creates the table when it
# does not exist — confirm against the datasink documentation.
datasink = SnowflakeTableDatasink(
    table_name="WHISPER_DEMO_OUTPUT",
    database=session.get_current_database(),
    schema=session.get_current_schema(),
    auto_create_table=True
)

In [None]:
# Write the transcriptions to the Snowflake table.
# NOTE(review): Ray datasets are typically lazy, so this is presumably where
# the transcription actually runs — confirm.
transcribed_ds.write_datasink(datasink)

In [None]:
# Preview rows from the output table.
session.table("WHISPER_DEMO_OUTPUT").show()