In [39]:
# !pip install ffmpeg-python datasets evaluate jiwer gradio

In [40]:
from transformers import pipeline
import torch

# Run inference on the GPU when CUDA is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Checkpoint loaded into the automatic-speech-recognition pipeline.
model_id = 'openai/whisper-small'
model = pipeline('automatic-speech-recognition', model=model_id, device=device)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [41]:
# Keyword arguments expanded into every call of the speech-recognition pipeline.
model_conf = dict(
    max_new_tokens=258,
    generate_kwargs=dict(task='transcribe'),
    chunk_length_s=8,
    batch_size=64,
    return_timestamps=True,
)

In [42]:
import time

def time_took(f):
  """Decorator that reports how long each call to ``f`` takes.

  After every successful call, the elapsed time (in seconds, three decimals)
  is printed together with the wrapped function's name, followed by a
  separator line. The wrapped function's return value is passed through
  unchanged.

  Args:
    f: The callable to time.

  Returns:
    A wrapper with the same name, docstring and signature metadata as ``f``.
  """
  import functools  # local import so this cell does not depend on another cell

  # functools.wraps keeps __name__/__doc__/__wrapped__ intact, so tools that
  # introspect the decorated function still see the original.
  @functools.wraps(f)
  def timed_f(*args, **kwargs):
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # when the system clock is adjusted.
    start = time.perf_counter()
    output = f(*args, **kwargs)
    end = time.perf_counter()
    print(f'{f.__name__} took: {end-start:.3f} secs')
    print('----------')
    return output
  return timed_f

In [43]:
@time_took
def predict(audio, config=model_conf):
  """Transcribe ``audio`` with the module-level ``model`` pipeline.

  A copy of ``audio`` is handed to the pipeline, and ``config`` is expanded
  into keyword arguments for the call.

  Returns:
    Whatever the pipeline call returns.
  """
  return model(audio.copy(), **config)

In [44]:
import ffmpeg
import uuid

@time_took
def extract_audio_from_video(video_file):
  """Extract the audio track of ``video_file`` into a new MP3 file.

  Args:
    video_file: Path of the input video, passed to ``ffmpeg.input``.

  Returns:
    The name of the MP3 file that was written. The file is not deleted
    here; the caller is responsible for removing it.
  """
  # A random UUID in the name keeps repeated runs from overwriting each other.
  temp_audio_file = f'temp_{uuid.uuid4()}.mp3'
  audio = ffmpeg.input(video_file).audio
  ffmpeg.output(audio, temp_audio_file).run()
  return temp_audio_file

In [45]:
import librosa

@time_took
def convert_audio_to_array(audio_path):
  """Load the audio file at ``audio_path`` via ``librosa.load`` at 16 kHz.

  Returns:
    The ``(samples, sample_rate)`` pair produced by ``librosa.load``.
  """
  samples, sample_rate = librosa.load(audio_path, sr=16_000)
  return samples, sample_rate

In [46]:
def format_time_in_iso8601(secs):
  """Format a non-negative duration in seconds as ``HH:MM:SS.mmm``.

  The value is rounded to whole milliseconds *before* it is split into
  hours, minutes and seconds, so a value such as 59.9996 carries over into
  the minutes field ("00:01:00.000") instead of yielding an invalid
  "00:00:60.000".

  Args:
    secs: Duration in seconds (int or float).

  Returns:
    The formatted timestamp string. Hours are zero-padded to at least two
    digits and grow wider when needed.
  """
  total_ms = round(float(secs) * 1000)
  total_secs, ms = divmod(total_ms, 1000)
  total_mins, whole_secs = divmod(total_secs, 60)
  hrs, mins = divmod(total_mins, 60)
  return f"{hrs:02d}:{mins:02d}:{whole_secs:02d}.{ms:03d}"

@time_took
def text_to_vtt(prediction_chunks, output_file, lang='en'):
  """Write transcription chunks to ``output_file`` as a WebVTT caption file.

  Args:
    prediction_chunks: Iterable of dicts, each with a ``'timestamp'`` entry
      whose first two items are the cue's start and end in seconds, and a
      ``'text'`` entry holding the cue text.
    output_file: Path of the file to write (created or overwritten, UTF-8).
    lang: Value written to the ``Language:`` header line. Defaults to 'en'.
  """
  with open(output_file, 'w', encoding='utf-8') as f:
    # The header block has to be closed by a blank line; without it the first
    # cue directly follows the header lines, which is not valid WebVTT.
    f.write(f'WEBVTT\nKind: captions\nLanguage: {lang}\n\n')
    for caption in prediction_chunks:
      # NOTE(review): a chunk whose end timestamp is None would raise in the
      # formatter below — confirm whether the upstream pipeline can emit one.
      start_time = format_time_in_iso8601(caption['timestamp'][0])
      end_time = format_time_in_iso8601(caption['timestamp'][1])
      text = caption['text']
      f.write(f'{start_time} --> {end_time}\n')
      # Each cue is terminated by a blank line.
      f.write(f'{text}\n\n')

In [47]:
import os

@time_took
def remove_temp_audio_file(audio_file):
  """Delete ``audio_file`` from disk; do nothing when the path does not exist."""
  if not os.path.exists(audio_file):
    return
  os.remove(audio_file)

In [51]:
@time_took
def create_video_caption(video_file, output_file):
  """Generate a caption file for ``video_file`` and write it to ``output_file``.

  Steps: extract the audio track to a temporary file, decode it to an array,
  delete the temporary file, transcribe the array, and write the resulting
  chunks with ``text_to_vtt``.

  Args:
    video_file: Path of the input video.
    output_file: Path of the caption file to write.
  """
  audio_path = extract_audio_from_video(video_file)
  try:
    audio_arr, _sr = convert_audio_to_array(audio_path)
  finally:
    # Clean up the temporary audio file even when decoding fails, so a bad
    # input does not leave stray files behind.
    remove_temp_audio_file(audio_path)
  predictions = predict(audio_arr)
  text_to_vtt(predictions['chunks'], output_file)

In [52]:
# Sample input video and the .vtt file its captions are written to.
video_file = 'Halves and fourths _ Geometry _ Early Math _ Khan Academy-0lSTXtwPuOU.mp4'
output_file = 'caption_en.vtt'

In [53]:
# Run the whole captioning pipeline on the sample video; each step prints its own timing.
create_video_caption(video_file, output_file)

extract_audio_from_video took: 2.704 secs
----------
convert_audio_to_array took: 0.467 secs
----------
remove_temp_audio_file took: 0.001 secs
----------
predict took: 6.304 secs
----------
text_to_vtt took: 0.001 secs
----------
create_video_caption took: 9.479 secs
----------


In [75]:
import gradio as gr

def genereate_caption(video_file):
  """Create captions for ``video_file`` and return both paths.

  The captions are always written to 'caption_en_test.vtt'.

  Returns:
    A ``(video_file, caption_path)`` tuple.
  """
  # NOTE(review): the function name is misspelled ("genereate"); it is kept
  # as-is because other code refers to it by this name.
  vtt_path = 'caption_en_test.vtt'
  create_video_caption(video_file, vtt_path)
  return (video_file, vtt_path)

# Gradio UI: a video input wired to the caption generator, with a video output.
demo = gr.Interface(
    fn=genereate_caption,
    inputs=gr.Video(),
    outputs=gr.Video(height=300, width=600),
    submit_btn='Create Caption',
    allow_flagging='never',
)

if __name__ == '__main__':
    demo.launch(debug=False)

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://5431ae4fb5706fa3fb.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


In [76]:
# import gradio as gr

# def play_video(video_file, caption_file):
#   def play():
#     return (video_file, caption_file)

#   demo = gr.Interface(play, None, gr.Video(height=300, width=600), submit_btn='Play',
#                       allow_flagging='never')

#   if __name__ == '__main__':
#     demo.launch()

In [77]:
# play_video(video_file, output_file)