In [4]:
from transformers import (
    SeamlessM4TModel,
    AutoProcessor
)
import torch
import os
from pydub import AudioSegment
import numpy as np
import io
import soundfile as sf
import librosa
import torchaudio
import gradio as gr
import IPython.display as ipd

from lang_list_pkg import (
    TEXT_SOURCE_LANGUAGE_NAMES,
    S2ST_TARGET_LANGUAGE_NAMES,
    T2TT_TARGET_LANGUAGE_NAMES,
    S2TT_TARGET_LANGUAGE_NAMES,
    LANGUAGE_CODE_TO_NAME,
    LANG_TO_SPKR_ID
)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# URL save model and related layers
save_url = "./seamless_m4t_model"
model_name= "facebook/hf-seamless-m4t-medium"

In [7]:
# Load the model and related layers from huggingface
try:
  processor = AutoProcessor.from_pretrained(model_name)
  model = SeamlessM4TModel.from_pretrained(model_name)
except OSError as e:
  print(f"Error when load model and related layer: {e}")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# Load the model and related layers from local filesystem if model available in filesystem
try:
  processor = AutoProcessor.from_pretrained(save_url)
  model = SeamlessM4TModel.from_pretrained(save_url)
  print(f"The model and related layers were loaded from {save_url}")
except OSError as e:
  print(f"Error when load model and related layer from local filesystem: {e}")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


The model and related layers were loaded from ./seamless_m4t_model


In [8]:
# Save the model and related layers to a local file
model.save_pretrained(save_url)
processor.save_pretrained(save_url)

print(f"The model and related layers were saved from {save_url}")

Mô hình và các lớp liên quan đã được lưu vào ./seamless_m4t_model


Test model

In [51]:
# Load the audio from local filesystem
input_file = "data/audio/taunt.wav"
arr, org_sr = torchaudio.load(input_file)

AUDIO_SAMPLE_RATE = 16000.0
# Adjusts the input audio sample frequency to match the model's audio sampling rate
new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)

# Maximum length of input audio, in seconds
MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
if new_arr.shape[1] > max_length:
    new_arr = new_arr[:, :max_length]
    gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Run gpu if device have
# Convert and format audio data before feeding it into the model for processing
input_data = processor(audios = new_arr, sampling_rate=AUDIO_SAMPLE_RATE, return_tensors="pt").to(device)

In [27]:
tokens_ids = model.generate(
  **input_data, 
  generate_speech=False, 
  tgt_lang="cmn", 
  num_beams=5, 
  do_sample=True
)[0].cpu().squeeze().detach().tolist()

text_out = processor.decode(tokens_ids, skip_special_tokens=True)

print("Translate: ", text_out)

Translate:  我们美国人民为了形成一个更完美的联盟建立正义确保家庭平静提供共同的防御


In [24]:
output = model.generate(
  **input_data, 
  return_intermediate_token_ids=True, 
  tgt_lang="cmn", 
  num_beams=5, 
  do_sample=True, 
  spkr_id=10
)
        
waveform = output.waveform.cpu().squeeze().detach().numpy()
tokens_ids = output.sequences.cpu().squeeze().detach().tolist()

text_out = processor.decode(tokens_ids, skip_special_tokens=True)
# Audio playback
ipd.Audio(waveform, rate=AUDIO_SAMPLE_RATE)
# print("Audio", (AUDIO_SAMPLE_RATE, waveform))

In [25]:
output = model.generate(
  **input_data, 
  return_intermediate_token_ids=True, 
  tgt_lang="vie", 
  num_beams=5, 
  do_sample=True, 
  spkr_id=10
)
        
waveform = output.waveform.cpu().squeeze().detach().numpy()
tokens_ids = output.sequences.cpu().squeeze().detach().tolist()

text_out = processor.decode(tokens_ids, skip_special_tokens=True)
# Audio playback
ipd.Audio(waveform, rate=AUDIO_SAMPLE_RATE)

Functions of the seamlessM4T model

In [8]:
TASK_NAMES = [
    "S2ST (Speech to Speech translation)",
    "S2TT (Speech to Text translation)",
    "T2ST (Text to Speech translation)",
    "T2TT (Text to Text translation)",
    "ASR (Automatic Speech Recognition)",
] # Tasks that the model can perform

AUDIO_SAMPLE_RATE = 16000.0 # Sampling rate of the input audio (Compatible with a fixed sampling rate parameter trained using the SeamlessM4T model)

MAX_INPUT_AUDIO_LENGTH = 60  # Maximum length of input audio, in seconds

DEFAULT_TARGET_LANGUAGE = "English" # Default output target language: English

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Run gpu if device have

In [10]:
def predict(
  task_name: str,
  audio_source: str,
  input_audio_mic: str | None,
  input_audio_file: str | None,
  input_text: str | None,
  source_language: str | None,
  target_language: str | None
) -> tuple[tuple[int, np.ndarray] | None, str]:
  task_name = task_name.split()[0]
  source_language_code = LANGUAGE_CODE_TO_NAME[source_language] if source_language else None
  target_language_code = LANGUAGE_CODE_TO_NAME[target_language]
  
  # Input cases are audio
  if task_name in ["S2ST", "S2TT", "ASR"]:
    # Check case is microphone audio and cases are different audio
    if audio_source == "microphone":
      input_data = input_audio_mic
    else: 
      input_data = input_audio_file
    
    # Load the input audio with a tensor containing the audio data and the original sampling frequency of the audio input
    # arr: Tensor([nums_channels, n_samples]), org_sr: int
    arr, org_sr = torchaudio.load(input_data)
    
    # Adjusts the input audio sample frequency to match the model's audio sampling rate
    new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
    
    max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
    # Check if the number of samples 
    if new_arr.shape[1] > max_length:
      new_arr = new_arr[:, :max_length]
      gr.warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")
    
    # Convert and format audio data before feeding it into the model for processing
    input_data = processor(
      audios=new_arr,
      sampling_rate=AUDIO_SAMPLE_RATE,
      return_tensors="pt"
    ).to(device)
  else: 
    input_data = processor(
      text=input_text,
      src_lang=source_language_code,
      return_tensors="pt"
    ).to(device)
    
  # Generate tokens from the model based on pre-preprared input data according to each task instance
  if task_name in ["S2TT", "T2TT"]:
    tokens_ids = model.generate(
      **input_data,
      generate_speech=False,
      tgt_lang=target_language_code,
      num_beams=5,
      do_sample=True
    )[0].cpu().squeeze().detach().tolist()
  else: 
    output = model.generate(
      **input_data,
      return_intermediate_token_ids=True,
      tgt_lang=target_language_code,
      num_beams=5,
      do_sample=True,
      spkr_id=LANG_TO_SPKR_ID[target_language_code][0]
    )
    
    waveform = output.waveform.cpu().squeeze().detach().numpy()
    tokens_ids = output.sequences.cpu().squeeze().detach().tolist()
  
  # Decode the tokens IDs into texts
  text_out = processor.decode(tokens_ids, skip_special_tokens=True)
  
  if task_name in ["S2ST", "T2ST"]:
    return (AUDIO_SAMPLE_RATE, waveform), text_out
  else:
    return None, text_out

In [7]:
# The function converts the audio sound to the required audio
def process_s2st(
  input_audio_file: str,
  target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
  return predict(
    task_name="S2ST",
    audio_source="file",
    input_audio_mic=None,
    input_audio_file=input_audio_file,
    input_text=None,
    source_language=None,
    target_language=target_language
  )
  
# The function converts the audio sound to the required text
def process_s2tt(
  input_audio_file: str,
  target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
  return predict(
    task_name="S2TT",
    audio_source="file",
    input_audio_mic=None,
    input_audio_file=input_audio_file,
    input_text=None,
    source_language=None,
    target_language=target_language
  )
  
# The function converts the text to the required audio
def process_t2st(
  input_text: str,
  source_language: str,
  target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
  return predict(
    task_name="T2ST",
    audio_source="",
    input_audio_mic=None,
    input_audio_file=None,
    input_text=input_text,
    source_language=source_language,
    target_language=target_language
  )
  
# The function converts the text to the required text
def process_t2tt(
  input_text: str,
  source_language: str,
  target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
  return predict(
    task_name="T2TT",
    audio_source="",
    input_audio_mic=None,
    input_audio_file=None,
    input_text=input_text,
    source_language=source_language,
    target_language=target_language
  )
  
# The function automatic speech recognition
def process_asr(
  input_audio_file: str,
  target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
  return predict(
    task_name="ASR",
    audio_source="file",
    input_audio_mic=None,
    input_audio_file=input_audio_file,
    input_text=None,
    source_language=None,
    target_language=target_language
  )

In [9]:
audio, text_out = process_s2st("data/audio/taunt.wav", "Vietnamese")
print("Audio: ", audio)
print("Text: ", text_out)

Audio:  (16000.0, array([-1.5829524e-05, -1.1388641e-04, -2.8213311e-05, ...,
       -7.0676891e-05,  9.5056232e-05, -7.0123926e-05], dtype=float32))
Text:  Bây giờ đi đi, nếu không tôi sẽ chế giễu bạn lần thứ hai.


In [11]:
ipd.Audio(audio[1], rate=audio[0])

Build Machine Learning Web Apps

In [2]:
def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
  mic = audio_source == "microphone"
  return (
    gr.update(visible=mic, value=None),  # input_audio_mic
    gr.update(visible=not mic, value=None),  # input_audio_file
  )
  
def update_input_ui(task_name: str) -> tuple[dict, dict, dict, dict]:
  task_name = task_name.split()[0]
  if task_name == "S2ST":
    return (
      gr.update(visible=True),  # audio_box
      gr.update(visible=False),  # input_text
      gr.update(visible=False),  # source_language
      gr.update(
        visible=True, choices=S2ST_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
      ),  # target_language
    )
  elif task_name == "S2TT":
    return (
      gr.update(visible=True),  # audio_box
      gr.update(visible=False),  # input_text
      gr.update(visible=False),  # source_language
      gr.update(
        visible=True, choices=S2TT_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
      ),  # target_language
    )
  elif task_name == "T2ST":
    return (
      gr.update(visible=False),  # audio_box
      gr.update(visible=True),  # input_text
      gr.update(visible=True),  # source_language
      gr.update(
        visible=True, choices=S2ST_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
      ),  # target_language
    )
  elif task_name == "T2TT":
    return (
      gr.update(visible=False),  # audio_box
      gr.update(visible=True),  # input_text
      gr.update(visible=True),  # source_language
      gr.update(
        visible=True, choices=T2TT_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
      ),  # target_language
    )
  elif task_name == "ASR":
    return (
      gr.update(visible=True),  # audio_box
      gr.update(visible=False),  # input_text
      gr.update(visible=False),  # source_language
      gr.update(
        visible=True, choices=S2TT_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
      ),  # target_language
    )
  else:
    raise ValueError(f"Unknown task: {task_name}")

def update_output_ui(task_name: str) -> tuple[dict, dict]:
  task_name = task_name.split()[0]
  if task_name in ["S2ST", "T2ST"]:
    return (
      gr.update(visible=True, value=None), # Output audio
      gr.update(value=None) # Output text
    )
  elif task_name in ["S2TT", "T2TT", "ASR"]:
    return (
      gr.update(visible=False, value=None), # Output audio
      gr.update(value=None) # Output text
    )
  else:
    raise ValueError(f"Unknown task: {task_name}")
  
def update_example_ui(task_name: str) -> tuple[dict, dict, dict, dict, dict]:
  task_name = task_name.split()[0]
  return (
    gr.update(visible=task_name == "S2ST"), 
    gr.update(visible=task_name == "S2TT"),
    gr.update(visible=task_name == "T2ST"),
    gr.update(visible=task_name == "T2TT"),
    gr.update(visible=task_name == "ASR")
  )

In [5]:
DESCRIPTION = """# SeamlessM4T
[SeamlessM4T](https://github.com/facebookresearch/seamless_communication) is designed to provide high-quality
translation, allowing people from different linguistic communities to communicate effortlessly through speech and text.
This unified model enables multiple tasks like Speech-to-Speech (S2ST), Speech-to-Text (S2TT), Text-to-Speech (T2ST)
translation and more, without relying on multiple separate models.
"""

CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1"

In [9]:
with gr.Blocks(css="style.css") as demo:
  gr.Markdown(DESCRIPTION)
  gr.DuplicateButton(
    value="Duplicate Space for private use",
    elem_id="duplicate-button",
    visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
  )
  with gr.Group():
    task_name = gr.Dropdown(
      label="Task",
      choices=TASK_NAMES,
      value=TASK_NAMES[0],
    )
    with gr.Row():
      source_language = gr.Dropdown(
        label="Source language",
        choices=TEXT_SOURCE_LANGUAGE_NAMES,
        value="English",
        visible=False
      )
      target_language = gr.Dropdown(
        label="Target language",
        choices=S2ST_TARGET_LANGUAGE_NAMES,
        value=DEFAULT_TARGET_LANGUAGE
      )
    with gr.Row() as audio_box:
      audio_source = gr.Radio(
        label="Audio source",
        choices=["file", "microphone"],
        value="file"
      )
      input_audio_mic = gr.Audio(
        label="Input speech",
        type="filepath",
        source="microphone",
        visible=False
      )
      input_audio_file = gr.Audio(
        label="Input speech",
        type="filepath",
        source="upload",
        visible=True
      )
    input_text = gr.Textbox(label="Input text", visible=False)
    btn = gr.Button("Translate")
    with gr.Column():
      output_audio = gr.Audio(
        label="Translated speech",
        autoplay=False,
        streaming=False,
        type="numpy"
      )
      output_text = gr.Textbox(label="Translated text")
  with gr.Row(visible=True) as s2st_example_row:
    s2st_example = gr.Examples(
      examples=[
        ["assets/sample_input.mp3", "English"],
        ["assets/sample_input.mp3", "English"],
      ],
      inputs=[input_audio_file, target_language],
      outputs=[output_audio, output_text],
      fn=process_s2st,
      cache_examples=CACHE_EXAMPLES,
    )
  with gr.Row(visible=False) as s2tt_example_row:
    s2tt_examples = gr.Examples(
      examples=[
        ["assets/sample_input.mp3", "English"],
        ["assets/sample_input.mp3", "English"]
      ],
      inputs=[input_audio_file, target_language],
      outputs=[output_audio, output_text],
      fn=process_s2tt,
      cache_examples=CACHE_EXAMPLES,
    )
  with gr.Row(visible=False) as t2st_example_row:
    t2st_examples = gr.Examples(
      examples=[
        ["My favorite animal is the elephant.", "English", "Vietnamese"],
        ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
      ],
      inputs=[input_text, source_language, target_language],
      outputs=[output_audio, output_text],
      fn=process_t2st,
      cache_examples=CACHE_EXAMPLES,
    )
  with gr.Row(visible=False) as t2tt_example_row:
    t2tt_examples = gr.Examples(
      examples=[
        ["My favorite animal is the elephant.", "English", "Vietnamese"],
        ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
      ],
      inputs=[input_text, source_language, target_language],
      outputs=[output_audio, output_text],
      fn=process_t2tt,
      cache_examples=CACHE_EXAMPLES,
    )
  with gr.Row(visible=False) as asr_example_row:
    asr_examples = gr.Examples(
      examples=[
        ["assets/sample_input.mp3", "English"],
        ["assets/sample_input_2.mp3", "Vietnamese"],
      ],
      inputs=[input_audio_file, target_language],
      outputs=[output_audio, output_text],
      fn=process_asr,
      cache_examples=CACHE_EXAMPLES,
    )
  audio_source.change(
    fn=update_audio_ui,
    inputs=audio_source,
    outputs=[
      input_audio_mic,
      input_audio_file
    ],
    queue=False,
    api_name=False
  )
  task_name.change(
    fn=update_input_ui,
    inputs=task_name,
    outputs=[
      audio_box,
      input_text,
      source_language,
      target_language
    ],
    queue=False,
    api_name=False
  ).then(
    fn=update_output_ui,
    inputs=task_name,
    outputs=[output_audio, output_text],
    queue=False,
    api_name=False,
  ).then(
    fn=update_example_ui,
    inputs=task_name,
    outputs=[
      s2st_example_row,
      s2tt_example_row,
      t2st_example_row,
      t2tt_example_row,
      asr_example_row,
    ],
    queue=False,
    api_name=False,
  )
  btn.click(
    fn=predict,
    inputs=[
      task_name,
      audio_source,
      input_audio_mic,
      input_audio_file,
      input_text,
      source_language,
      target_language,
    ],
    outputs=[output_audio, output_text],
    api_name="run",
  )
demo.queue(max_size=50).launch()

TypeError: Audio.__init__() got an unexpected keyword argument 'source'