# Download YouTube Video

In [None]:
!pip install pytube pydub youtube-dl

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytube
  Downloading pytube-12.1.2-py3-none-any.whl (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.0/57.0 KB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pydub
  Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)
Collecting youtube-dl
  Downloading youtube_dl-2021.12.17-py2.py3-none-any.whl (1.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m26.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: youtube-dl, pydub, pytube
Successfully installed pydub-0.25.1 pytube-12.1.2 youtube-dl-2021.12.17


In [None]:
import itertools
import os
import re
import subprocess
import unicodedata
from dataclasses import dataclass
from io import BytesIO

import IPython
import pandas as pd
import torch
import torchaudio
from pydub import AudioSegment
from pytube import YouTube

In [None]:
# Run the acoustic model on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Pretrained wav2vec2 CTC pipeline from torchaudio; its weights are downloaded on first use.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
labels = bundle.get_labels()
# Character -> label index lookup used to turn a transcript into token ids.
dictionary = dict(zip(labels, range(len(labels))))
model = bundle.get_model().to(device)

Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth


  0%|          | 0.00/360M [00:00<?, ?B/s]

In [None]:
def get_wave(aud):
    """Convert a pydub ``AudioSegment`` into a mono float waveform tensor.

    The segment is mixed down to one channel and its integer PCM samples are
    scaled by the full-scale value of the sample width, giving floats in
    [-1, 1) -- the same range ``torchaudio.load`` returns by default.

    Args:
        aud: object with ``set_channels`` / ``get_array_of_samples`` /
            ``sample_width`` (a pydub ``AudioSegment``).

    Returns:
        Float tensor of shape ``(1, num_samples)``.
    """
    mono = aud.set_channels(1)
    samples = torch.tensor(mono.get_array_of_samples(), dtype=torch.float)
    # pydub yields signed integers whose magnitude depends on the sample width
    # (e.g. +-32768 for 16-bit); divide by that full-scale value.
    full_scale = float(1 << (8 * mono.sample_width - 1))
    return (samples / full_scale).unsqueeze(0)

In [None]:
def get_wav_sr_from_yt_video_id(video_id):
    """Download the audio of a YouTube video and return it as a model-ready waveform.

    The audio is fetched with the ``youtube-dl`` CLI into ``<video_id>.wav`` in the
    current directory, decoded with pydub, converted with ``get_wave``, and the
    temporary file is deleted again.

    Args:
        video_id: YouTube video id (see ``extract_video_id``).

    Returns:
        ``(waveform, sample_rate)``. ``waveform`` is resampled to
        ``bundle.sample_rate`` when needed, and ``sample_rate`` is the rate of the
        returned waveform (i.e. ``bundle.sample_rate``).

    Raises:
        subprocess.CalledProcessError: if ``youtube-dl`` exits with an error.
    """
    file_path = "{}.wav".format(video_id)

    # Download with youtube-dl. Arguments are passed as a list (no shell), so
    # characters in `video_id` can never be interpreted as shell syntax.
    subprocess.run(
        [
            "youtube-dl",
            "--extract-audio",
            "--audio-format", "wav",
            "--audio-quality", "0",
            "-o", "%(id)s.%(ext)s",
            "https://youtu.be/{}".format(video_id),
        ],
        check=True,
    )

    try:
        # Load the audio file using pydub
        audio = AudioSegment.from_file(file_path, format="wav")
    finally:
        # Delete the downloaded file even when decoding fails.
        if os.path.isfile(file_path):
            os.remove(file_path)
        else:
            print("{} does not exist.".format(file_path))

    waveform = get_wave(audio)
    sr = audio.frame_rate

    # Resample to the rate the pipeline expects.
    if sr != bundle.sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, bundle.sample_rate)

    # Report the rate of the waveform actually returned (it may have been resampled).
    return waveform, bundle.sample_rate

In [None]:
def clean_lyrics(lyrics):
    """Normalize raw lyrics into the character string used for alignment.

    Steps:
      * drop bracketed annotations such as ``[Chorus]``;
      * map the typographic apostrophe ``’`` to ``'``;
      * strip accents (``é`` -> ``e``) so accented letters are kept as letters;
      * turn every run of other characters (spaces, punctuation, newlines,
        hyphens) into a single ``|`` word separator, with no separator at either end;
      * upper-case the result.

    Args:
        lyrics: raw lyrics text.

    Returns:
        Upper-case string containing only ``A-Z``, ``'`` and ``|``.
    """
    lyrics = re.sub(r"\[.*?\]", "", lyrics)
    lyrics = lyrics.replace("’", "'")
    # Decompose accented characters and drop the combining marks.
    lyrics = "".join(
        ch for ch in unicodedata.normalize("NFKD", lyrics) if not unicodedata.combining(ch)
    )
    # '-' is not kept as a character: NOTE(review) in the WAV2VEC2_ASR_BASE_960H label
    # set '-' is expected to be label 0, the default blank_id of get_trellis/backtrack
    # (confirm with bundle.get_labels()), so a hyphen is treated as a word separator.
    lyrics = re.sub(r"[^a-zA-Z']+", "|", lyrics)
    return lyrics.strip("|").upper()

In [None]:
def calculate_emission(waveform, amount_chunks=10):
    """Run ``model`` over ``waveform`` chunk by chunk and return log-probabilities.

    The waveform is split along axis 1 into ``amount_chunks`` consecutive pieces
    so a long song is not pushed through the model in a single call. The
    per-chunk outputs are concatenated along axis 1.

    Args:
        waveform: tensor sliced as ``waveform[:, start:end]``; the callers in this
            notebook pass a ``(1, num_samples)`` waveform.
        amount_chunks: number of pieces to split the waveform into (>= 1).

    Returns:
        CPU tensor ``torch.cat(chunks, dim=1)[0]`` of log-softmax scores --
        presumably ``(num_frames, num_labels)``; get_trellis reads axis 0 as frames.

    Raises:
        ValueError: if ``amount_chunks`` is smaller than 1.
    """
    if amount_chunks < 1:
        raise ValueError("amount_chunks must be at least 1")

    torch.cuda.empty_cache()

    length = waveform.shape[1]
    chunk_length = length // amount_chunks
    chunks = []
    with torch.inference_mode():
        for i in range(amount_chunks):
            start = i * chunk_length
            # The last chunk runs to the end so the `length % amount_chunks`
            # trailing samples are not dropped.
            end = length if i == amount_chunks - 1 else start + chunk_length
            if end <= start:
                # Happens only when length < amount_chunks.
                continue
            emissions, _ = model(waveform[:, start:end].to(device))
            # Move each chunk off the device right away so device memory only
            # ever holds one chunk's output.
            chunks.append(torch.log_softmax(emissions, dim=-1).cpu())

    return torch.cat(chunks, dim=1)[0]

In [None]:
def get_tokens(transcript):
    """Translate each character of ``transcript`` into its label index.

    Looks every character up in ``dictionary``; a character outside the
    model's label set raises ``KeyError``.
    """
    return list(map(dictionary.__getitem__, transcript))

In [None]:
def get_trellis(emission, tokens, blank_id=0):
    """Build the dynamic-programming trellis used to align ``tokens`` to ``emission``.

    Args:
        emission: 2-D tensor of per-frame scores indexed ``emission[frame, label]``.
        tokens: sequence of label indices to align (the transcript).
        blank_id: label index of the blank token.

    Returns:
        Tensor of shape ``(num_frame + 1, len(tokens) + 1)``. Entry ``[t, j]`` is the
        best cumulative score of having consumed the first ``j`` tokens after ``t``
        frames, where each frame either stays (blank) or advances one token.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Trellis has extra dimensions for both time axis and tokens.
    # The extra dim for tokens represents <SoS> (start-of-sentence)
    # The extra dim for time axis is for simplification of the code.
    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    # Remaining at <SoS> means emitting blanks, so accumulate the blank column.
    trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
    # No token can have been consumed before the first frame.
    trellis[0, 1:] = -float("inf")
    if num_tokens > 0:
        # Carried over from the torchaudio forced-alignment tutorial.
        # NOTE(review): the +inf cells this creates do not appear to be reachable
        # from backtrack()'s starting column; confirm before relying on them.
        trellis[-num_tokens:, 0] = float("inf")

    # Gather the per-frame blank and token scores once, outside the loop.
    blank_scores = emission[:, blank_id]
    token_scores = emission[:, tokens]
    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + blank_scores[t],
            # Score for changing to the next token
            trellis[t, :-1] + token_scores[t],
        )
    return trellis

## Find the most likely path (backtracking)

Once the trellis is generated, we will traverse it following the
elements with high probability.

We start from the last label index at the time step with the highest
score, then traverse back in time, picking either stay
($c_j \rightarrow c_j$) or transition
($c_j \rightarrow c_{j+1}$), whichever scores higher: the stay score
$k_{t, j+1} p(t+1, repeat)$ or the transition score
$k_{t, j} p(t+1, c_{j+1})$.

Backtracking stops once the path has transitioned back to the start of the label sequence.

The trellis matrix is used for path-finding, but for the final
probability of each segment, we take the frame-wise probability from
emission matrix.




In [None]:
@dataclass
class Point:
    """One step of the alignment path, in emission (non-trellis) coordinates."""

    token_index: int  # index into the token list / cleaned transcript
    time_index: int  # emission frame index
    score: float  # exp() of the emission value of the label emitted at this frame


def backtrack(trellis, emission, tokens, blank_id=0):
    """Trace the highest-scoring path back through ``trellis``.

    Args:
        trellis: output of ``get_trellis`` for the same ``emission`` / ``tokens``.
        emission: per-frame scores indexed ``emission[frame, label]``.
        tokens: label indices that were aligned.
        blank_id: label index of the blank token.

    Returns:
        List of ``Point`` in increasing time order, one per visited frame.

    Raises:
        ValueError: if the path does not get back to the first token before frame 0.
    """
    # Note:
    # j and t are indices for trellis, which has extra dimensions
    # for time and tokens at the beginning.
    # When referring to time frame index `T` in trellis,
    # the corresponding index in emission is `T-1`.
    # Similarly, when referring to token index `J` in trellis,
    # the corresponding index in transcript is `J-1`.
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Figure out if the current position was stay or change.
        # Score for the token staying the same from frame T-1 to T (blank emitted).
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score for the token changing from J-1 at T-1 to J at T.
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
        moved = changed > stayed

        # 2. Store the path with the frame-wise probability of the label actually
        #    emitted at this frame: the token on a change, the blank on a stay.
        prob = emission[t - 1, tokens[j - 1] if moved else blank_id].exp().item()
        # Return token index and time index in non-trellis coordinate.
        path.append(Point(j - 1, t - 1, prob))

        # 3. Update the token
        if moved:
            j -= 1
            if j == 0:
                break
    else:
        raise ValueError("Failed to align")
    return path[::-1]

In [None]:
# Merge the labels
@dataclass
class Segment:
    """A labelled span ``[start, end)`` with an average confidence score.

    ``merge_repeats`` builds character segments whose ``start``/``end`` are
    emission-frame indices (ints). ``merge_words`` builds word segments whose
    ``start``/``end`` are seconds (floats), hence the ``float`` annotation.
    """

    label: str
    start: float  # frame index (character segments) or seconds (word segments)
    end: float  # exclusive, same unit as `start`
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start}, {self.end})"

    @property
    def length(self):
        """Span length ``end - start`` in the unit of ``start``/``end``."""
        return self.end - self.start

    def _key(self):
        # Single source of truth for equality and hashing.
        return (self.label, self.start, self.end, self.score)

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        if not isinstance(other, Segment):
            return False
        return self._key() == other._key()


def merge_repeats(path, transcript):
    """Collapse consecutive path points with the same token into one ``Segment`` each.

    Args:
        path: sequence of points with ``token_index``, ``time_index`` and ``score``
            (as produced by ``backtrack``), in time order.
        transcript: the cleaned transcript the token indices refer to.

    Returns:
        List of character ``Segment``s. ``start``/``end`` are frame indices (end
        exclusive) and ``score`` is the mean of the grouped point scores.
    """
    segments = []
    for token_index, group in itertools.groupby(path, key=lambda point: point.token_index):
        points = list(group)
        score = sum(point.score for point in points) / len(points)
        segments.append(
            Segment(
                transcript[token_index],
                points[0].time_index,
                points[-1].time_index + 1,
                score,
            )
        )
    return segments

### Visualization



Next, we merge the character segments into words. The Wav2Vec2 model uses ``'|'``
as the word boundary, so we merge the segments between consecutive occurrences of
``'|'``.

Finally, we cut the original audio into per-word clips and listen to
them to check that the segmentation is correct.




In [None]:
# Merge words
def merge_words(segments, ratio, sr, separator="|"):
    """Group character segments into word segments split at ``separator``.

    Args:
        segments: character segments (``label``, ``start``/``end`` in frames,
            ``score``, ``length``), e.g. from ``merge_repeats``.
        ratio: number of waveform samples per emission frame.
        sr: waveform sample rate, used to turn sample offsets into seconds.
        separator: label marking a word boundary.

    Returns:
        List of ``Segment`` with ``start``/``end`` in seconds and a score equal to
        the length-weighted mean of the character scores. Empty groups (leading,
        trailing or repeated separators) produce no word.
    """
    words = []
    word_begin = 0
    # Visit one index past the end so a final word with no trailing separator is flushed.
    for pos in range(len(segments) + 1):
        if pos < len(segments) and segments[pos].label != separator:
            continue
        if pos > word_begin:
            chars = segments[word_begin:pos]
            label = "".join(seg.label for seg in chars)
            weighted_score = sum(seg.score * seg.length for seg in chars)
            total_length = sum(seg.length for seg in chars)
            first_sample = int(ratio * chars[0].start)
            last_sample = int(ratio * chars[-1].end)
            words.append(
                Segment(label, first_sample / sr, last_sample / sr, weighted_score / total_length)
            )
        word_begin = pos + 1
    return words

### Visualization



In [None]:
# A trick to embed the resulting audio to the generated file.
# `IPython.display.Audio` has to be the last call in a cell,
# and there should be only one call per cell.
def display_segment(waveform, i, segments=None):
    """Print the timing of word ``i`` and return an audio player for its clip.

    Args:
        waveform: tensor sliced as ``waveform[:, x0:x1]`` and sampled at
            ``bundle.sample_rate`` (as returned by ``get_wav_sr_from_yt_video_id``).
        i: index of the word in ``segments``.
        segments: word segments from ``merge_words``/``execute`` whose ``start``/``end``
            are in seconds. Defaults to the global ``word_segments``.

    Returns:
        ``IPython.display.Audio`` for the word's samples.
    """
    if segments is None:
        segments = word_segments
    word = segments[i]
    sample_rate = bundle.sample_rate
    # merge_words already converted start/end to seconds; convert back to sample
    # offsets (rounding guards against x.9999 float results).
    x0 = int(round(word.start * sample_rate))
    x1 = int(round(word.end * sample_rate))
    print(f"{word.label} ({word.score:.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
    segment = waveform[:, x0:x1]
    return IPython.display.Audio(segment.numpy(), rate=sample_rate)

In [None]:
def execute(audio, transcript):
    """Force-align ``transcript`` to the waveform ``audio``.

    Args:
        audio: waveform tensor, expected at ``bundle.sample_rate``; ``audio.size(1)``
            is its number of samples.
        transcript: raw lyrics, cleaned with ``clean_lyrics`` before alignment.

    Returns:
        ``(emission, tokens, trellis, path, segments, word_segments)`` -- every
        intermediate result; ``word_segments`` hold start/end in seconds.
    """
    cleaned = clean_lyrics(transcript)
    emission = calculate_emission(audio)
    tokens = get_tokens(cleaned)
    trellis = get_trellis(emission, tokens)
    path = backtrack(trellis, emission, tokens)
    char_segments = merge_repeats(path, cleaned)

    # Waveform samples per emission frame (the trellis has one extra leading row).
    samples_per_frame = audio.size(1) / (trellis.size(0) - 1)
    word_segments = merge_words(
        char_segments, ratio=samples_per_frame, sr=bundle.sample_rate
    )
    return emission, tokens, trellis, path, char_segments, word_segments

In [None]:
def execute_with_id(video_id, transcript):
    """Download the audio of ``video_id`` and align ``transcript`` to it via ``execute``."""
    waveform, _ = get_wav_sr_from_yt_video_id(video_id)
    return execute(waveform, transcript)

# Website

In [None]:
!pip install jupyter-dash
!pip install pytube
!pip install dash-player

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jupyter-dash
  Downloading jupyter_dash-0.4.2-py3-none-any.whl (23 kB)
Collecting nest-asyncio
  Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)
Collecting dash
  Downloading dash-2.8.0-py3-none-any.whl (9.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.9/9.9 MB[0m [31m78.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ansi2html
  Downloading ansi2html-1.8.0-py3-none-any.whl (16 kB)
Collecting retrying
  Downloading retrying-1.3.4-py3-none-any.whl (11 kB)
Collecting dash-html-components==2.0.0
  Downloading dash_html_components-2.0.0-py3-none-any.whl (4.1 kB)
Collecting dash-core-components==2.0.0
  Downloading dash_core_components-2.0.0-py3-none-any.whl (3.8 kB)
Collecting dash-table==5.0.0
  Downloading dash_table-5.0.0-py3-none-any.whl (3.9 kB)
Collecting jedi>=0.10
  Downloading jedi-0.18.2-py2.py3-none-any.whl (1.6 MB)
[2K     [90m━━━━

In [None]:
import re

In [None]:
def extract_video_id(link):
    """Return the 11-character YouTube video id contained in ``link``, or ``None``.

    Recognized forms include ``youtu.be/<id>``, ``/v/<id>``, ``/u/<x>/<id>``,
    ``/embed/<id>``, ``/shorts/<id>``, ``watch?v=<id>`` and ``&v=<id>``. The id
    must consist only of letters, digits, ``_`` and ``-``. This is checked strictly
    because the id is later embedded in a download command.

    Args:
        link: URL typed by the user; may be ``None`` or empty.

    Returns:
        The video id, or ``None`` if no valid id is found.
    """
    if not link:
        return None
    match_id = re.match(
        r"^.*(youtu\.be/|v/|u/\w/|embed/|shorts/|watch\?v=|&v=)([^#&?]*).*", link
    )
    if match_id:
        video_id = match_id.group(2)
        if re.fullmatch(r"[A-Za-z0-9_-]{11}", video_id):
            return video_id
    return None

In [None]:
from jupyter_dash import JupyterDash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output, State
import dash_player

The dash_core_components package is deprecated. Please replace
`import dash_core_components as dcc` with `from dash import dcc`
  import dash_core_components as dcc
The dash_html_components package is deprecated. Please replace
`import dash_html_components as html` with `from dash import html`
  import dash_html_components as html


In [None]:
# Base stylesheet for the page, then the JupyterDash app that serves the layout.
external_stylesheets = [
    "https://codepen.io/chriddyp/pen/bWLwgP.css",
]
app = JupyterDash(
    __name__,
    external_stylesheets=external_stylesheets,
)

In [None]:
# Page layout: link input and buttons on top, then the video player next to the
# transcript editor (before submit) / highlighted transcript (after submit).
app.layout = html.Div([
    html.H1("AI Karaoke"),
    # Yt link
    html.Div([
        dcc.Input(id="input_yt", placeholder="Input Youtube Link", style={'width': '600px', 'margin-right': '5px'}),
        html.Button('Submit', id='btn_submit', n_clicks=0),
        # NOTE(review): no active callback handles btn_reset yet.
        html.Button('Reset', id='btn_reset', n_clicks=0),
        html.Div(id="initial_message", children="Enter a YouTube link and transcript and press submit to load video", style={'display': 'block'}),
        html.Div(id="invalid_link_div", children="Invalid YouTube link", style={'color': 'red', 'display': 'none'}),
        html.Div(id="no_transcript_div", children="Please enter a transcript", style={'color': 'red', 'display': 'none'}),
    ], style={'margin-bottom': '30px'}),

    # video and transcript
    html.Div([
        dash_player.DashPlayer(id="player", url="", controls=True, width="70%", height="80%", style={'display': 'inline-block', 'margin-right': '10px'}),
        dcc.Textarea(id='input_transcript', value="", placeholder="Input transcript", style={'width': '29%', 'height': '80%', 'display': 'inline-block'}),
        # Read-only karaoke view; the five parts are filled by the clientside callback.
        html.Div(id='output_transcript', children=[
            html.Div(id='pre_transcript', style={'display': 'inline'}),
            html.Span(id='pre_highlight_word', style={'display': 'inline', 'background-color': '#ffffb3'}),
            html.Span(id='highlight_word', style={'display': 'inline', 'background-color': 'orange'}),
            html.Span(id='post_highlight_word', style={'display': 'inline', 'background-color': '#ffff99'}),
            html.Div(id='post_transcript', style={'display': 'inline'})
            ], style={'width': '29%', 'height': '80%', 'display': 'none', 'overflow': 'auto'})
    ], style={'height': '100vh', 'width': '100%', 'display': 'flex'}),

    # current video timestamp, only for debugging
    html.Div(id="div_current_time", style={"margin": "10px 0px"}),
    # Drives the highlight refresh. 100 ms is smooth enough for word-level
    # highlighting; the previous 2 ms fired the callback ~500 times per second.
    dcc.Interval(id='interval', interval=100, n_intervals=0),
    dcc.Store(id='clientside-store-data')
])

In [None]:
@app.callback(
    [Output('player', 'url'),
    Output('initial_message', 'style'),
    Output('invalid_link_div', 'style'),
    Output('no_transcript_div', 'style'),
    Output('input_transcript', 'style'),
    Output('output_transcript', 'style'),
    Output('clientside-store-data', 'data')],
    Input('btn_submit', 'n_clicks'),
    [State('input_yt', 'value'),
    State('input_transcript', 'value')]
)
def embed_video(n_clicks, link, transcript):
    """Handle Submit: validate the inputs, align the lyrics and switch to karaoke view.

    Args:
        n_clicks: number of Submit clicks (0 on the initial page load).
        link: YouTube URL from the input box; ``None`` until the user types.
        transcript: lyrics from the textarea.

    Returns:
        Tuple matching the callback outputs: player URL, styles of the
        initial/invalid/no-transcript messages, styles of the transcript textarea
        and of the karaoke view, and a list of word records (``label``, ``start``,
        ``end``, ``score``; times in seconds) for the clientside highlighter.
    """
    initial_message_style = {'display': 'block'}
    invalid_style = {'color': 'red', 'display': 'none'}
    no_transcript = {'color': 'red', 'display': 'none'}

    input_style = {'width': '29%', 'height': '80%', 'display': 'inline-block'}
    # Keep 'overflow': 'auto' (as in the layout) so long transcripts remain scrollable.
    output_style = {'width': '29%', 'height': '80%', 'display': 'none', 'overflow': 'auto'}

    word_records = []
    url = ""
    if n_clicks > 0:
        # TODO maybe add loading bar
        video_id = extract_video_id(link) if link else None
        if not video_id or not transcript:
            initial_message_style = {'display': 'none'}
            if not video_id:
                invalid_style = {'color': 'red', 'display': 'block'}
            if not transcript:
                no_transcript = {'color': 'red', 'display': 'block'}
        else:
            url = link
            # Hide the editor and show the highlighted transcript in its place.
            input_style = {**input_style, 'display': 'none'}
            output_style = {**output_style, 'display': 'inline-block'}
            _, _, _, _, _, word_segments = execute_with_id(video_id, transcript)
            word_records = [dict(vars(segment)) for segment in word_segments]
    return url, initial_message_style, invalid_style, no_transcript, input_style, output_style, word_records

# Browser-side highlighter: on every interval tick, split the aligned words into
# five buckets (sung, just sung, current, upcoming soon, upcoming) by the player time.
app.clientside_callback(
    """
    function highlightWords(n_intervals, current_time, input, data) {
        let pre_transcript = "";
        let pre_highlight_word = "";
        let highlight_word = "";
        let post_highlight_word = "";
        let post_transcript = "";
        if (!Array.isArray(data)) {
            return [pre_transcript, pre_highlight_word, highlight_word, post_highlight_word, post_transcript];
        }
        // Before playback the player reports no time; treat it as 0 so the whole
        // transcript is shown as upcoming instead of nothing.
        const time = (current_time == null) ? 0 : current_time;
        for (let i = 0; i < data.length; i++) {
            const word = data[i];
            const start = word['start'];
            const end = word['end'];
            // Each word lands in exactly one bucket, so it is never rendered twice.
            if (time < start - 1) {
                post_transcript += ' ' + word['label'];
            } else if (time < start) {
                post_highlight_word += ' ' + word['label'];
            } else if (time <= end) {
                highlight_word += word['label'];
            } else if (time <= end + 1) {
                pre_highlight_word += word['label'] + ' ';
            } else {
                pre_transcript += word['label'] + ' ';
            }
        }
        return [pre_transcript, pre_highlight_word, highlight_word, post_highlight_word, post_transcript];
    }
    """,
    [Output('pre_transcript', 'children'),
    Output('pre_highlight_word', 'children'),
    Output('highlight_word', 'children'),
    Output('post_highlight_word', 'children'),
    Output('post_transcript', 'children')],
    [Input('interval', 'n_intervals')],
    [State('player', 'currentTime'),
     State('input_transcript', 'value'),
     State('clientside-store-data', 'data')]
)

# @app.callback(
#     [Output('input_transcript', 'value'),
#     Output('output_transcript', 'value'),
#     Output('input_transcript', 'style'),
#     Output('output_transcript', 'style'),
#     Output('clientside-store-data', 'data'),
#     Output('btn_submit', 'n_clicks')],
#     Input('btn_reset', 'n_clicks')
# )
# def reset(n_clicks):
#     input_style = {'width': '29%', 'height': '80%', 'display': 'inline-block'}
#     output_style = {'width': '29%', 'height': '80%', 'display': 'none'}

#     dict_words = {}
#     input_transcript = ""
#     output_transcript = ""
#     return input_transcript, output_transcript, input_style, output_style, dict_words, 0

In [None]:
# Launch the app; mode='inline' asks JupyterDash to render it in this cell's output.
if __name__ == '__main__':
    app.run_server(mode='inline')
    # For a standalone server with the Dash debugger use: app.run_server(debug=True)

<IPython.core.display.Javascript object>

I have this thing where I get older but just never wiser
Midnights become my afternoons
When my depression works the graveyard shift
All of the people I've ghosted stand there in the room
I should not be left to my own devices
They come with prices and vices
I end up in crisis (tale as old as time)
I wake up screaming from dreaming
One day I'll watch as you're leaving
'Cause you got tired of my scheming
(For the last time)
It's me, hi, I'm the problem, it's me
At tea time, everybody agrees
I'll stare directly at the sun but never in the mirror
It must be exhausting always rooting for the anti-hero
Sometimes I feel like everybody is a sexy baby
And I'm a monster on the hill
Too big to hang out, slowly lurching toward your favorite city
Pierced through the heart, but never killed
Did you hear my covert narcissism I disguise as altruism
Like some kind of congressman? (Tale as old as time)
I wake up screaming from dreaming
One day I'll watch as you're leaving
And life will lose all its meaning
(For the last time)
It's me, hi, I'm the problem, it's me (I'm the problem, it's me)
At tea time, everybody agrees
I'll stare directly at the sun but never in the mirror
It must be exhausting always rooting for the anti-hero
I have this dream my daughter in-law kills me for the money
She thinks I left them in the will
The family gathers 'round and reads it and then someone screams out
"She's laughing up at us from hell"
It's me, hi, I'm the problem, it's me
It's me, hi, I'm the problem, it's me
It's me, hi, everybody agrees, everybody agrees
It's me, hi (hi), I'm the problem, it's me (I'm the problem, it's me)
At tea (tea) time (time), everybody agrees (everybody agrees)
I'll stare directly at the sun but never in the mirror
It must be exhausting always rooting for the anti-hero