In [None]:
!pip install yt-dlp transformers huggingface_hub librosa soundfile noisereduce pydub deepmultilingualpunctuation sentence_transformers scikit-learn

import yt_dlp
from deepmultilingualpunctuation import PunctuationModel
import soundfile as sf
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import noisereduce as nr
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import librosa
import json
import torch
import string
from typing import Dict, List
from transformers import pipeline
from huggingface_hub import login
import re
import gc
import time
import os
from huggingface_hub import InferenceClient
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from pydub import AudioSegment

#### UTILS ###########
def split_at_center_comma(input_string):
    """
    Split a sentence in two at its middle comma.

    The positions of all commas are collected and the comma at index
    ``len(positions) // 2`` of that list is used as the cut point. The comma
    itself is dropped; whitespace around it is left untouched.

    Returns a ``(first_part, second_part)`` tuple. When the string holds no
    comma the result is ``(input_string, "")``.
    """
    positions = []
    for pos, ch in enumerate(input_string):
        if ch == ',':
            positions.append(pos)

    if len(positions) == 0:
        return input_string, ""

    cut = positions[len(positions) // 2]
    return input_string[:cut], input_string[cut + 1:]

def split_into_sentences(text):
    """
    Break ``text`` into sentences.

    The text is cut at every '.', '?' or '!' (the terminators are discarded),
    each piece is stripped of surrounding whitespace, and pieces that end up
    empty are left out. Returns the remaining pieces as a list of strings.
    """
    pieces = (piece.strip() for piece in re.split(r'[.?!]', text))
    return [piece for piece in pieces if piece]

def format_data_chunks(data: List[Dict]):
    """
    Strip punctuation from every word-level entry of ``data``, in place.

    Each entry is expected to be a dict with a string under ``'text'``.
    After punctuation removal, an entry whose text is exactly ``'000'`` is
    appended to the text of the entry before it and removed from the list
    (presumably a thousands group such as ",000" that was split off from
    the preceding number -- confirm against the recogniser output). Only the
    text is merged; other keys of the preceding entry are left untouched.

    Entries that do not have a string ``'text'`` are kept unchanged.

    Args:
        data: list of word-level entries; modified in place.

    Returns:
        None.
    """
    strip_punctuation = str.maketrans('', '', string.punctuation)

    kept = []
    # Last kept entry that is known to carry string text; a '000' entry can
    # only be folded into such an entry.
    merge_target = None

    for entry in data:
        try:
            cleaned_text = entry['text'].translate(strip_punctuation)
        except (KeyError, IndexError, TypeError, AttributeError):
            # Malformed entry: keep it as-is rather than failing the whole run.
            kept.append(entry)
            merge_target = None
            continue

        entry['text'] = cleaned_text

        # A leading '000' has nothing before it, so it is simply kept.
        if cleaned_text == '000' and merge_target is not None:
            merge_target['text'] += cleaned_text
            continue

        kept.append(entry)
        merge_target = entry

    # Rebuild the list in one step instead of deleting while iterating, so no
    # entry is skipped and the caller's list object is still the one updated.
    data[:] = kept

def check_if_chunk_is_long_enough(chunk):
    """
    Report whether a chunk fits within the 15-second limit.

    Despite the name, this returns True when the chunk's duration
    (``end_time - start_time``) is not greater than 15 seconds, and False
    when it is longer.
    """
    duration = chunk['end_time'] - chunk['start_time']
    return not duration > 15

def download_audio(url):
    '''
    Download the audio track of the video at ``url`` with yt-dlp.

    The best available audio format is requested and passed through the
    FFmpegExtractAudio postprocessor with WAV as the preferred codec; the
    output template is ``audio_extracted.%(ext)s``.

    Any exception raised during the download is caught and printed; nothing
    is returned and nothing is re-raised.
    '''
    extract_wav = {
        'key': 'FFmpegExtractAudio',
        'preferredcodec': 'wav',
        'preferredquality': '192',
    }
    options = dict(
        format='bestaudio/best',
        postprocessors=[extract_wav],
        outtmpl='audio_extracted.%(ext)s',
    )

    try:
        with yt_dlp.YoutubeDL(options) as downloader:
            downloader.download([url])
        print("Audio downloaded successfully.")
    except Exception as error:
        print(f"An error occurred: {error}")

def clean_chunks(chunks, max_duration=15, min_words=3):
    """
    Remove filler words from chunk texts and merge very short chunks.

    Processing happens in two passes:

    1. Every chunk's text is split on whitespace and filler words
       ("ah", "uh", "uhm", optionally followed by a comma, any letter case)
       are dropped. Chunks with no words left are discarded.
    2. A chunk with fewer than ``min_words`` words and a duration below
       ``max_duration`` is merged into the previously kept chunk when the sum
       of both durations stays below ``max_duration``; otherwise into the
       following chunk when the span from this chunk's start to the next
       chunk's end stays below ``max_duration``; otherwise it is kept alone.

    Filler removal is done for all chunks before any merging, so text pulled
    in from a following chunk is already cleaned.

    Args:
        chunks: list of dicts with 'text', 'start_time' and 'end_time'.
        max_duration: upper bound (exclusive), in seconds, for merged chunks.
        min_words: chunks with fewer words than this are merge candidates.

    Returns:
        A new list of chunk dicts. The input dicts are shallow-copied, so the
        caller's chunks are not modified.
    """
    # Tokens come from str.split(), so they never carry whitespace; only these
    # spellings can actually occur.
    filler_words = {'ah', 'uh', 'uhm', 'ah,', 'uh,', 'uhm,'}

    # Pass 1: strip fillers everywhere and drop chunks that become empty.
    cleaned = []
    for chunk in chunks:
        kept_words = [w for w in chunk['text'].split() if w.lower() not in filler_words]
        if kept_words:
            cleaned.append(dict(chunk, text=' '.join(kept_words)))

    # Pass 2: fold short chunks into a neighbour where the time budget allows.
    processed = []
    i = 0
    while i < len(cleaned):
        chunk = cleaned[i]
        duration = chunk['end_time'] - chunk['start_time']
        is_short = len(chunk['text'].split()) < min_words and duration < max_duration

        if is_short and processed:
            prev_chunk = processed[-1]
            combined = prev_chunk['end_time'] - prev_chunk['start_time'] + duration
            if combined < max_duration:
                prev_chunk['end_time'] = chunk['end_time']
                prev_chunk['text'] += ' ' + chunk['text']
                i += 1
                continue

        if is_short and i + 1 < len(cleaned):
            next_chunk = cleaned[i + 1]
            if next_chunk['end_time'] - chunk['start_time'] < max_duration:
                chunk['end_time'] = next_chunk['end_time']
                chunk['text'] += ' ' + next_chunk['text']
                processed.append(chunk)
                i += 2  # the following chunk has been absorbed
                continue

        processed.append(chunk)
        i += 1

    return processed

def reduce_to_less_than_15(chunks: List, timestamps):
    """
    Split each chunk in two at its middle comma, using word-level timestamps
    to place the boundary in time.

    For every chunk the text is split with ``split_at_center_comma``. The
    word-level entries whose timestamps lie inside the chunk's time range are
    collected, and the entry matching the last word of the first half is
    located. When that word occurs more than once, the occurrence closest to
    its expected position (the number of words in the first half) is used.

    The first half ends 0.01 s after that word ends; the second half starts
    0.01 s before the next word starts. The outer start/end times of the
    original chunk are preserved.

    A chunk that cannot be split (no comma, an empty half, or no matching
    word with a successor in range) is returned unchanged, so the caller can
    still detect and discard it by its duration.

    Args:
        chunks: list of dicts with 'text', 'start_time' and 'end_time'.
        timestamps: word-level entries, each with 'text' and a
            'timestamp' pair ``(start, end)``.

    Returns:
        A new list holding two chunks for every chunk that was split and the
        original chunk object for every chunk that was not.
    """
    strip_punctuation = str.maketrans('', '', string.punctuation)

    def _normalize_word(word):
        # Compare words without punctuation, surrounding spaces or case.
        return word.translate(strip_punctuation).strip().lower()

    def _normalize_text(text):
        # Hyphenated words are counted as separate words.
        tokens = (_normalize_word(tok) for tok in text.replace("-", " ").split())
        return [tok for tok in tokens if tok]

    reduced_chunks = []

    for chunk in chunks:
        begin_time = chunk['start_time']
        end_time = chunk['end_time']

        split_sentence1, split_sentence2 = split_at_center_comma(chunk['text'])
        first_half_words = _normalize_text(split_sentence1)

        words = []
        split_index = None
        if first_half_words and split_sentence2.strip():
            # Gather all words that lie within the chunk's time range.
            words = [
                entry for entry in timestamps
                if entry['timestamp'][0] >= begin_time and entry['timestamp'][1] <= end_time
            ]

            target = first_half_words[-1]
            expected_index = len(first_half_words) - 1
            # The match needs a following word to anchor the second half,
            # hence the search stops one short of the end.
            candidates = [
                k for k in range(len(words) - 1)
                if _normalize_word(words[k]['text']) == target
            ]
            if candidates:
                split_index = min(candidates, key=lambda k: abs(k - expected_index))

        if split_index is None:
            reduced_chunks.append(chunk)
            continue

        # End of the last word of the first half / start of the first word of
        # the second half, each with a small safety margin.
        end_time1 = words[split_index]['timestamp'][1] + 0.01
        start_time2 = words[split_index + 1]['timestamp'][0] - 0.01

        reduced_chunks.append({"text": split_sentence1, "start_time": begin_time, "end_time": end_time1})
        reduced_chunks.append({"text": split_sentence2, "start_time": start_time2, "end_time": end_time})

    return reduced_chunks


def prepare_for_output(chunks):
    """
    Annotate each chunk with its position and duration.

    Adds ``'chunk_id'`` (the chunk's index in the list) and
    ``'chunk_length'`` (``end_time - start_time``) to every dict in
    ``chunks``. The dicts are updated in place and the same list is returned.
    """
    for index, chunk in enumerate(chunks):
        chunk['chunk_id'] = index
        chunk['chunk_length'] = chunk['end_time'] - chunk['start_time']
    return chunks


##########UTILS END############

def create_proper_chunks(list_of_sentences: List[str], timestamps: List[Dict]):
    """
    Attach start/end times to sentences using word-level timestamps.

    Sentences are tokenised by replacing hyphens with spaces, splitting on
    whitespace and removing punctuation from each word. The resulting words
    are walked in lock-step with ``timestamps``: a sentence starts 0.25 s
    before its first word and ends 0.25 s after its last word. If the first
    or last word does not match the aligned timestamp entry
    (case-insensitively), the corresponding time stays 0 and both texts are
    printed as a diagnostic.

    A sentence longer than 15 s that contains a comma is split with
    ``split_at_center_comma`` and two chunks are emitted instead (the second
    half uses a 0.01 s margin). A sentence longer than 15 s without a comma
    produces no chunk.

    Args:
        list_of_sentences: the sentences, in spoken order.
        timestamps: word-level entries, each with 'text' and a
            'timestamp' pair ``(start, end)``; one entry per sentence word.

    Returns:
        A list of dicts with 'text', 'start_time' and 'end_time'.

    Raises:
        ValueError: if the total number of sentence words differs from the
            number of timestamp entries.
    """
    word_level_timestamps = list(timestamps)  # List of dictionaries
    strip_punctuation = str.maketrans('', '', string.punctuation)

    def split_sentence(sentence: str):
        # Split at hyphens first, then drop punctuation; words that become
        # empty are skipped.
        stripped = (word.translate(strip_punctuation) for word in sentence.replace("-", " ").split())
        return [word for word in stripped if word]

    def measure_span(words, first_index, padding):
        # Time span of `words`, aligned to the timestamps from `first_index`.
        span_start = span_end = 0
        last_position = len(words) - 1
        for position, word in enumerate(words):
            entry = word_level_timestamps[first_index + position]
            if word.lower().strip() == entry['text'].strip().lower():
                if position == 0:
                    span_start = entry['timestamp'][0] - padding
                if position == last_position:
                    span_end = entry['timestamp'][1] + padding
            else:
                # Report the sentence word next to the entry it was compared with.
                print(word)
                print(entry['text'])
        return span_start, span_end

    # Check if the number of words matches
    total_words = sum(len(split_sentence(sentence)) for sentence in list_of_sentences)
    if total_words != len(word_level_timestamps):
        raise ValueError(f"The total number of words in sentences does not match the number of word timestamps {total_words} != {len(word_level_timestamps)}")

    chunks = []
    word_counter = 0
    for sentence in list_of_sentences:
        words = split_sentence(sentence)
        begin_word_counter = word_counter
        start_time, end_time = measure_span(words, begin_word_counter, 0.25)
        word_counter += len(words)

        if end_time - start_time > 15:
            # Split the sentence in half at a comma. Without a comma there is
            # no split point and the sentence yields no chunk.
            if ',' in sentence:
                split_sentence1, split_sentence2 = split_at_center_comma(sentence)

                words1 = split_sentence(split_sentence1)
                words2 = split_sentence(split_sentence2)

                start_time1, end_time1 = measure_span(words1, begin_word_counter, 0.25)
                start_time2, end_time2 = measure_span(words2, begin_word_counter + len(words1), 0.01)

                chunks.append({'text': split_sentence1, 'start_time': start_time1, 'end_time': end_time1})
                chunks.append({'text': split_sentence2, 'start_time': start_time2, 'end_time': end_time2})
        else:
            chunks.append({'text': sentence, 'start_time': start_time, 'end_time': end_time})

    return chunks


#### Preprocessing audio
def preprocess_audio(input_file, output_file='audio_processed.wav', target_sr=16000):
    """
    Read an audio file, clean it up and write the result to ``output_file``.

    Steps, in order: multi-channel audio is averaged across axis 1 into a
    single channel; the signal is resampled to ``target_sr`` when its sample
    rate differs; noise reduction is applied; the result is written at
    ``target_sr``. Nothing is returned.
    """
    samples, source_sr = sf.read(input_file)

    # Collapse to a single channel when more than one dimension is present.
    is_multichannel = samples.ndim > 1
    if is_multichannel:
        samples = samples.mean(axis=1)

    # Bring the signal to the target sample rate only when it differs.
    needs_resampling = source_sr != target_sr
    if needs_resampling:
        samples = librosa.resample(samples, orig_sr=source_sr, target_sr=target_sr)

    denoised = nr.reduce_noise(y=samples, sr=target_sr)
    sf.write(output_file, denoised, target_sr)


#### Transcribing audio using Whisper
def transcribe_audio_with_timestamps(audio_file, model_name='openai/whisper-large-v3-turbo',chunk_length=30):
    """
    Run speech recognition on ``audio_file`` and return the pipeline result
    with word-level timestamps.

    The model named by ``model_name`` is loaded onto the GPU in float16 when
    CUDA is available, otherwise onto the CPU in float32. The audio is
    processed by an "automatic-speech-recognition" pipeline in windows of
    ``chunk_length`` seconds with ``return_timestamps='word'``. Afterwards the
    model is moved back to the CPU, the local reference is deleted and the
    CUDA cache is emptied.

    Returns whatever the pipeline call returns.
    """
    if torch.cuda.is_available():
        device = "cuda:0"
        torch_dtype = torch.float16
    else:
        device = "cpu"
        torch_dtype = torch.float32

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=False,
        use_safetensors=True,
    )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_name)

    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        chunk_length_s=chunk_length,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )

    result = asr_pipeline(audio_file, return_timestamps='word')

    # Release accelerator memory before returning to the caller.
    model.to('cpu')
    del model
    torch.cuda.empty_cache()

    return result


def punctuate(text: str):
    """
    Restore punctuation in ``text``.

    A new ``PunctuationModel`` is instantiated on every call and the result
    of its ``restore_punctuation`` method is returned.
    """
    punctuation_model = PunctuationModel()
    punctuated = punctuation_model.restore_punctuation(text)
    return punctuated

def merge_chunks(data, embedding_model="sentence-transformers/all-MiniLM-L6-v2", similarity_threshold=0.2, max_duration=15):
    """
    Greedily merge neighbouring chunks whose texts are semantically similar.

    Starting from each not-yet-consumed chunk, the following chunks are
    absorbed one after another for as long as (a) the span from the first
    chunk's start to the candidate's end does not exceed ``max_duration`` and
    (b) the cosine similarity between the current merged text's embedding and
    the candidate's embedding is at least ``similarity_threshold``. After each
    merge the embedding of the merged text is recomputed.

    Args:
        data (list of dict): chunks with 'start_time', 'end_time' and 'text'.
            The first chunk of every merged group is updated in place.
        embedding_model (str, optional): sentence-transformers model name.
            Defaults to "sentence-transformers/all-MiniLM-L6-v2".
        similarity_threshold (float, optional): minimum cosine similarity for
            a merge. Defaults to 0.2.
        max_duration (float, optional): maximum duration of a merged chunk in
            seconds. Defaults to 15.

    Returns:
        list of dict: the merged chunks, in input order.
    """
    encoder = SentenceTransformer(embedding_model)
    vectors = encoder.encode([item["text"] for item in data])

    merged_chunks = []
    total = len(data)
    start = 0
    while start < total:
        anchor = data[start]
        anchor_vector = vectors[start]
        nxt = start + 1

        while nxt < total:
            candidate = data[nxt]

            # Stop as soon as the merged span would become too long.
            if candidate["end_time"] - anchor["start_time"] > max_duration:
                break

            score = cosine_similarity(
                anchor_vector.reshape(1, -1), vectors[nxt].reshape(1, -1)
            )[0][0]

            # Stop at the first neighbour that is not similar enough.
            if not (score >= similarity_threshold):
                break

            anchor["end_time"] = candidate["end_time"]
            anchor["text"] += " " + candidate["text"]
            # Compare later candidates against the merged text.
            anchor_vector = encoder.encode([anchor["text"]])[0]
            nxt += 1

        merged_chunks.append(anchor)
        start = nxt

    return merged_chunks

####################




if __name__ == "__main__":

    # Download the video's audio track (written as audio_extracted.wav)
    video_url = 'https://www.youtube.com/watch?v=zdmEzqpAG70'
    download_audio(video_url)

    # Preprocess the audio (written as audio_processed.wav)
    preprocess_audio('audio_extracted.wav')

    # Transcription with word-level timestamps; the result has a "text" entry
    # (full transcript) and a "chunks" entry (one dict per word).
    transcription = transcribe_audio_with_timestamps('audio_processed.wav',)

    text = transcription["text"]   # Entire transcribed text as a single string
    punctuated_text = punctuate(text)
    transcription["text"] = punctuated_text

    # Split the punctuated string into a list of sentences
    list_of_sentences = split_into_sentences(transcription["text"])

    # Remove punctuation from the word-level entries so they can be compared
    # with the sentence words.
    format_data_chunks(transcription["chunks"])

    # Each resulting chunk has "text" (a sentence), 'start_time' and 'end_time'.
    chunks = create_proper_chunks(list_of_sentences, transcription["chunks"])

    # Collect all the chunks that are longer than 15s
    # (check_if_chunk_is_long_enough returns False for those).
    oversized_chunks = [chunk for chunk in chunks if not check_if_chunk_is_long_enough(chunk)]

    # Reduce the big chunks into smaller chunks by dividing at a comma
    reduced_chunks = reduce_to_less_than_15(oversized_chunks, transcription["chunks"])

    for reduced in reduced_chunks:
        if check_if_chunk_is_long_enough(reduced):
            chunks.append(reduced)
        else:
            # Still longer than 15s: split once more and keep the pieces that
            # now fit (the pieces themselves, not the oversized parent).
            further_reduced = reduce_to_less_than_15([reduced], transcription["chunks"])
            for piece in further_reduced:
                if check_if_chunk_is_long_enough(piece):
                    chunks.append(piece)

    # Drop every chunk that is still longer than 15s; the oversized originals
    # were only collected above, not removed.
    filtered_chunks = [chunk for chunk in chunks if check_if_chunk_is_long_enough(chunk)]

    # The split pieces were appended at the end; restore chronological order
    # because the following steps merge chunks that are adjacent in the list.
    filtered_chunks.sort(key=lambda chunk: chunk['start_time'])

    # Remove filler words and join very short chunks with a neighbour
    processed_chunks = clean_chunks(filtered_chunks)

    # Merge semantically similar neighbouring chunks within the 15s limit
    merged_data = merge_chunks(processed_chunks)

    output = prepare_for_output(merged_data)
    print(output)
