# In [None]:
from google.colab import files
from google.colab.patches import cv2_imshow
from moviepy.editor import VideoFileClip
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import numpy as np
import cv2
import torch
import os

# Initialize Wav2Vec2 model and processor
def load_model_and_processor(model_name):
    """Load a pretrained Wav2Vec2 CTC model and its matching processor.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            processor and the model.

    Returns:
        A ``(model, processor)`` tuple.
    """
    # The processor is loaded first, then the model, both from the same name.
    wav2vec_processor = Wav2Vec2Processor.from_pretrained(model_name)
    ctc_model = Wav2Vec2ForCTC.from_pretrained(model_name)
    return ctc_model, wav2vec_processor

# Function to extract audio from video
def extract_audio(video_file, audio_file):
    """Write the audio track of ``video_file`` to ``audio_file``.

    Args:
        video_file: Path of the video to read.
        audio_file: Path of the audio file to create.

    Raises:
        ValueError: If the video has no audio track.
    """
    video = VideoFileClip(video_file)
    try:
        audio = video.audio
        if audio is None:
            # Without this check the failure would be an opaque
            # AttributeError on ``None.write_audiofile``.
            raise ValueError(f"No audio track found in {video_file!r}")
        audio.write_audiofile(audio_file)
    finally:
        # Always release the resources held by the clip, even on failure.
        video.close()

# Function to transcribe audio using Wav2Vec2
def transcribe_audio(model, processor, audio_file, return_word_offsets=False):
    """Transcribe an audio file with a Wav2Vec2 CTC model (greedy decoding).

    The audio is loaded with ``torchaudio``, resampled to 16 kHz if needed,
    averaged down to a single channel, and fed through ``model``.

    Args:
        model: CTC model; called with the processor output and expected to
            expose ``.logits`` and ``.config.inputs_to_logits_ratio``.
        processor: Wav2Vec2 processor used for feature extraction and decoding.
        audio_file: Path of the audio file to transcribe.
        return_word_offsets: When true, also return per-word timing
            information. Defaults to ``False`` (text only).

    Returns:
        The transcription string. When ``return_word_offsets`` is true, a
        ``(transcription, word_offsets)`` tuple instead, where
        ``word_offsets`` is a list of dicts with the keys ``"word"``,
        ``"start_time"`` and ``"end_time"`` (times in seconds from the start
        of the audio).
    """
    target_rate = 16000
    waveform, sample_rate = torchaudio.load(audio_file)

    # Ensure the waveform is a 2D tensor: [channel, samples]
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)

    # Resample to 16kHz if necessary
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
        waveform = resampler(waveform)

    # Wav2Vec2 expects a single channel (mono) input
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    inputs = processor(waveform.squeeze(), return_tensors="pt", sampling_rate=target_rate)

    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_ids = torch.argmax(logits, dim=-1)

    if not return_word_offsets:
        return processor.batch_decode(predicted_ids)[0]

    decoded = processor.batch_decode(predicted_ids, output_word_offsets=True)
    # Word offsets are expressed in logit frames; each frame spans
    # ``inputs_to_logits_ratio`` input samples at the 16 kHz input rate.
    seconds_per_frame = model.config.inputs_to_logits_ratio / target_rate
    word_offsets = [
        {
            "word": entry["word"],
            "start_time": entry["start_offset"] * seconds_per_frame,
            "end_time": entry["end_offset"] * seconds_per_frame,
        }
        for entry in decoded.word_offsets[0]
    ]
    return decoded.text[0], word_offsets

# Function to find where a specific word occurs in the transcription
def find_word_timestamp(word, transcription, word_offsets=None):
    """Locate occurrences of ``word`` (case-insensitive substring match).

    Args:
        word: Text to look for inside each transcribed word.
        transcription: Transcribed text; split on whitespace when
            ``word_offsets`` is not given.
        word_offsets: Optional list of dicts with ``"word"`` and
            ``"start_time"`` keys, as returned by ``transcribe_audio`` with
            ``return_word_offsets=True``.

    Returns:
        With ``word_offsets``: the start times (seconds) of the matching
        words. Without it: the positions (word indices) of the matching
        words in ``transcription.split()`` -- these are NOT times.
    """
    needle = word.lower()

    if word_offsets is None:
        words = transcription.split()
        print(words)
        timestamps = [i for i, w in enumerate(words) if needle in w.lower()]
    else:
        timestamps = [
            entry["start_time"]
            for entry in word_offsets
            if needle in entry["word"].lower()
        ]

    print(timestamps)
    return timestamps

# Function to skip to a specific time in the video
def skip_to_time(video_file, timestamp):
    """Seek ``video_file`` to ``timestamp`` and read frames from there on.

    Args:
        video_file: Path of the video to open.
        timestamp: Position to seek to, in seconds. Negative values are
            treated as the start of the video.

    Raises:
        OSError: If the video cannot be opened.
    """
    video = cv2.VideoCapture(video_file)
    try:
        if not video.isOpened():
            raise OSError(f"Could not open video file: {video_file!r}")

        fps = video.get(cv2.CAP_PROP_FPS)
        frame_number = max(0, int(fps * timestamp))

        video.set(cv2.CAP_PROP_POS_FRAMES, frame_number)

        while video.isOpened():
            ret, frame = video.read()
            if not ret:
                break
            # Frame display is currently disabled; re-enable the next line to
            # show frames via the google.colab.patches helper.
            # cv2_imshow(frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the capture even if seeking or reading raised.
        video.release()
        cv2.destroyAllWindows()

# Main function to search for a word in the video
def search_in_video_transformer(video_file, audio_file, model_name, search_word):
    """Transcribe a video's audio and seek to the first match of a word.

    Args:
        video_file: Path of the video to search.
        audio_file: Path where the extracted audio is written.
        model_name: Pretrained Wav2Vec2 model identifier.
        search_word: Word to look for in the transcription.
    """
    # Load Wav2Vec2 model and processor
    model, processor = load_model_and_processor(model_name)

    # Extract audio from video
    extract_audio(video_file, audio_file)

    # Transcribe audio, keeping per-word timing so the seek target is a real
    # time in seconds rather than a word position.
    transcription, word_offsets = transcribe_audio(
        model, processor, audio_file, return_word_offsets=True
    )
    print(f"Transcription: {transcription}")

    # Find the start time(s) of the specific word
    timestamps = find_word_timestamp(search_word, transcription, word_offsets)

    if not timestamps:
        print("Word not found in the video.")
        return

    # For simplicity, jump to the first occurrence only
    skip_to_time(video_file, timestamps[0])

# Example usage
if __name__ == "__main__":
    # Inputs for the demo run -- point these at your own files / search term.
    video_file = '/content/videoplayback.mp4'  # Must be uploaded beforehand
    audio_file = 'extracted_audio.wav'  # Destination for the extracted audio
    model_name = 'facebook/wav2vec2-large-960h-lv60-self'  # Pre-trained English ASR checkpoint
    search_word = 'tackling'

    search_in_video_transformer(
        video_file=video_file,
        audio_file=audio_file,
        model_name=model_name,
        search_word=search_word,
    )