In [None]:
# Install the third-party packages this notebook relies on (notebook shell escape).
!pip install pytube pydub youtube-dl spleeter

## Algorithm

In [None]:
import IPython
import os
import pandas as pd
import re
import sys
import shutil
import torch
import torchaudio

from dataclasses import dataclass
from io import BytesIO

import plotly.express as px

from pydub import AudioSegment
from pytube import YouTube
import librosa

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Toggle: when True, get_wav_sr runs Spleeter on each file and aligns against the
# separated vocals.wav instead of the original mp3.
use_spleeter = True

In [None]:
# Pretrained torchaudio pipeline that supplies the acoustic model, its label set
# and the sample rate the audio is resampled to (see get_wav_sr).
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
labels = bundle.get_labels()
model = bundle.get_model().to(device)
# Maps every label character to its index in `labels`; used by get_tokens.
dictionary = {c: i for i, c in enumerate(labels)}

In [None]:
# Spleeter is imported and instantiated only when vocal isolation is enabled.
if use_spleeter:
  from spleeter.separator import Separator
  # Initialize the separator with the 'spleeter:2stems' configuration;
  # get_wav_sr later reads the vocals.wav file it writes.
  separator = Separator('spleeter:2stems')

In [None]:
def get_wave(aud):
    """Turn an audio segment into a mono float waveform tensor.

    Args:
        aud: object exposing ``set_channels`` and ``get_array_of_samples``
            (a pydub ``AudioSegment`` in this notebook).

    Returns:
        A ``torch.float`` tensor of shape ``(1, num_samples)`` holding the raw
        sample values (no rescaling is applied).
    """
    mono_samples = aud.set_channels(1).get_array_of_samples()
    samples = torch.tensor(mono_samples, dtype=torch.float)
    # Add the leading channel axis expected by the rest of the pipeline.
    return samples.reshape(1, samples.shape[0])

In [None]:
def get_wav_sr(file_dir, filename):
    """Load one audio file as a mono waveform at the model's sample rate.

    When the module-level ``use_spleeter`` flag is set, the file is first run
    through the Spleeter ``separator`` and the separated ``vocals.wav`` is
    loaded instead of the original mp3. The temporary Spleeter output folder
    is always removed afterwards, even if loading fails.

    Args:
        file_dir: directory containing the audio file.
        filename: name of the audio file inside ``file_dir``.

    Returns:
        The waveform tensor produced by ``get_wave``, resampled to
        ``bundle.sample_rate`` when the file uses a different rate. Despite
        the function name, only the waveform is returned.
    """
    file_path = os.path.join(file_dir, filename)

    stem_dir = None
    if use_spleeter:
        # Spleeter names its output folder after the input file's stem, so use
        # splitext rather than splitting on the literal ".mp3" substring (which
        # breaks for names such as "a.mp3.b.mp3" or "x.MP3").
        stem = os.path.splitext(filename)[0]
        stem_dir = f"/content/{stem}"
        separator.separate_to_file(file_path, "/content/")
        audio_path = f"{stem_dir}/vocals.wav"
        audio_format = "wav"
    else:
        audio_path = file_path
        audio_format = "mp3"

    try:
        # Load the audio file using pydub
        audio = AudioSegment.from_file(audio_path, format=audio_format)
        waveform = get_wave(audio)
        sr = audio.frame_rate
    finally:
        # Remove the whole Spleeter output folder (this also deletes
        # vocals.wav); a failure to clean up is reported but not fatal.
        if stem_dir is not None:
            try:
                shutil.rmtree(stem_dir)
            except OSError as e:
                print("Error: %s - %s." % (e.filename, e.strerror))

    # Resample
    if sr != bundle.sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, bundle.sample_rate)

    return waveform

In [None]:
def clean_lyrics(lyrics):
    """Normalise raw lyrics into the upper-case, ``|``-separated transcript.

    Steps, in order:
      1. drop spans matching ``[<TAB>...]``;
      2. delete typographic and ASCII apostrophes (so "don't" -> "dont");
      3. delete the literal four-character text ``\\xa0``
         (NOTE(review): this pattern does not match an actual non-breaking
         space character; that one is turned into ``|`` by step 5 — confirm
         this is intended);
      4. collapse runs of whitespace;
      5. replace every character that is not an ASCII letter, ``|`` or ``-``
         (and every whitespace character) with ``|``;
      6. collapse repeated ``|`` and strip one trailing ``|``.

    Hyphens are preserved. NOTE(review): confirm ``-`` is meant to be looked
    up as a regular label by ``get_tokens``.

    Args:
        lyrics: raw lyrics text. May be empty.

    Returns:
        The cleaned transcript in upper case; ``""`` when nothing is left.
    """
    lyrics = re.sub(r"\[\t.*\n?\]", "", lyrics, flags=re.MULTILINE)
    lyrics = re.sub(r"’", "", lyrics)
    lyrics = re.sub(r"'", "", lyrics)
    lyrics = re.sub(r"\\xa0", "", lyrics)
    lyrics = re.sub(r"\s\s+", " ", lyrics)
    lyrics = re.sub(r"[^a-zA-Z|-]|\s", "|", lyrics)
    lyrics = re.sub(r"\|\|+", "|", lyrics)
    # endswith() is safe on an empty string, unlike lyrics[-1].
    if lyrics.endswith("|"):
        lyrics = lyrics[:-1]
    return lyrics.upper()

In [None]:
def calculate_emission(waveform, amount_chunks=10):
    """Compute frame-wise log-probabilities for a waveform, chunk by chunk.

    The waveform is split along its sample axis into ``amount_chunks``
    consecutive pieces that are fed through the module-level ``model``
    independently (each chunk is processed without context from its
    neighbours). The last chunk absorbs the remainder of the integer
    division, so no trailing samples are dropped; chunks that would be empty
    (waveform shorter than ``amount_chunks`` samples) are skipped.

    Args:
        waveform: 2-D tensor indexed as ``[channel, sample]``.
        amount_chunks: number of pieces the waveform is split into.

    Returns:
        The log-softmax emissions of the first batch element, concatenated
        over all chunks along the frame axis, on the CPU.
    """
    torch.cuda.empty_cache()

    length = waveform.shape[1]
    chunks_length = length // amount_chunks
    chunks = []
    with torch.inference_mode():
        for i in range(amount_chunks):
            start = i * chunks_length
            # The final chunk runs to the very end so the `length % amount_chunks`
            # leftover samples are included.
            end = length if i == amount_chunks - 1 else start + chunks_length
            if end <= start:
                continue
            emissions, _ = model(waveform[:, start:end].to(device))
            chunks.append(torch.log_softmax(emissions, dim=-1))

    return torch.cat(chunks, dim=1)[0].cpu().detach()

In [None]:
def get_tokens(transcript):
    """Map every character of ``transcript`` to its index in ``dictionary``.

    Raises:
        KeyError: if a character is not a key of ``dictionary``.
    """
    return list(map(dictionary.__getitem__, transcript))

In [None]:
def get_trellis(emission, tokens, blank_id=0):
    """Build the alignment trellis for ``tokens`` over the emission frames.

    ``trellis[t, j]`` holds the best cumulative score of having consumed the
    first ``j`` tokens after ``t`` frames.

    Args:
        emission: 2-D tensor indexed as ``[frame, label]`` (log-probabilities
            in this notebook).
        tokens: list of label indices of the transcript.
        blank_id: label index used for "stay on the current token".

    Returns:
        Tensor of shape ``(num_frame + 1, num_tokens + 1)``.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Trellis has extra dimensions for both time axis and tokens.
    # The extra dim for tokens represents <SoS> (start-of-sentence)
    # The extra dim for time axis is for simplification of the code.
    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    # Staying at <SoS> accumulates the blank score; use `blank_id` here so the
    # first column agrees with the "stay" term in the loop below.
    trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
    # No token can have been emitted before the first frame.
    trellis[0, 1:] = -float("inf")
    trellis[-num_tokens:, 0] = float("inf")

    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis

In [None]:
@dataclass
class Point:
    """One step of the alignment path, in non-trellis coordinates."""

    # Index of the token in the transcript.
    token_index: int
    # Index of the frame in the emission matrix.
    time_index: int
    # Frame-wise probability of the label chosen at this step.
    score: float


def backtrack(trellis, emission, tokens, blank_id=0):
    """Recover the best alignment path by walking the trellis backwards.

    Args:
        trellis: tensor built by ``get_trellis``; it has one extra leading row
            (time) and one extra leading column (tokens).
        emission: 2-D tensor indexed as ``[frame, label]``.
        tokens: list of label indices of the transcript.
        blank_id: label index used for "stay on the current token".

    Returns:
        List of ``Point`` ordered by increasing time.

    Raises:
        ValueError: if the first frame is reached before all tokens have been
            consumed.
    """
    # Note:
    # j and t are indices for trellis, which has extra dimensions
    # for time and tokens at the beginning.
    # When referring to time frame index `T` in trellis,
    # the corresponding index in emission is `T-1`.
    # Similarly, when referring to token index `J` in trellis,
    # the corresponding index in transcript is `J-1`.
    j = trellis.size(1) - 1
    # Start from the frame where the score of having consumed every token peaks.
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Figure out if the current position was stay or change.
        # Score for staying on token J from trellis frame T-1 to T.
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score for moving from token J-1 at frame T-1 to token J at frame T.
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
        did_change = changed > stayed

        # 2. Store the path with frame-wise probability. A "stay" step is
        # scored with the blank label, so it must use `blank_id`, not column 0.
        label = tokens[j - 1] if did_change else blank_id
        prob = emission[t - 1, label].exp().item()
        # Return token index and time index in non-trellis coordinate.
        path.append(Point(j - 1, t - 1, prob))

        # 3. Update the token
        if did_change:
            j -= 1
            if j == 0:
                break
    else:
        raise ValueError("Failed to align")
    return path[::-1]

In [None]:

# Merge the labels
@dataclass
class Segment:
    """A labelled span ``[start, end)`` with a confidence score.

    Equality and hashing are both based on the full
    ``(label, start, end, score)`` tuple; comparing against anything that is
    not a ``Segment`` yields ``False``.
    """

    label: str
    start: int
    end: int
    score: float

    def _as_tuple(self):
        # Single place that defines which fields make up a segment's identity.
        return (self.label, self.start, self.end, self.score)

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start}, {self.end})"

    @property
    def length(self):
        """Extent of the span, ``end - start``."""
        return self.end - self.start

    def __hash__(self):
        return hash(self._as_tuple())

    def __eq__(self, other):
        return isinstance(other, Segment) and self._as_tuple() == other._as_tuple()



def merge_repeats(path, transcript):
    """Collapse consecutive path points of the same token into one Segment.

    Args:
        path: sequence of points exposing ``token_index``, ``time_index`` and
            ``score``, ordered by time.
        transcript: string indexed by ``token_index`` to obtain the label.

    Returns:
        List of ``Segment``; each spans from the first point's ``time_index``
        to the last point's ``time_index + 1`` and carries the mean score of
        the run.
    """
    segments = []
    total = len(path)
    run_start = 0
    while run_start < total:
        token_index = path[run_start].token_index
        # Extend the run while the following points refer to the same token.
        run_end = run_start + 1
        while run_end < total and path[run_end].token_index == token_index:
            run_end += 1

        mean_score = sum(path[k].score for k in range(run_start, run_end)) / (run_end - run_start)
        segment = Segment(
            transcript[token_index],
            path[run_start].time_index,
            path[run_end - 1].time_index + 1,
            mean_score,
        )
        segments.append(segment)
        run_start = run_end
    return segments

In [None]:
# Merge words
def _word_from_segments(segs, ratio, sr):
    """Fuse a non-empty run of character segments into one word Segment."""
    word = "".join([seg.label for seg in segs])
    # Length-weighted mean of the character scores.
    score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)

    # Frame index -> sample index (truncated) -> seconds.
    x0 = int(ratio * segs[0].start)
    x1 = int(ratio * segs[-1].end)
    return Segment(word, x0 / sr, x1 / sr, score)


def merge_words(segments, ratio, sr, separator="|"):
    """Group character segments into words and convert their bounds to seconds.

    Args:
        segments: character-level segments in time order.
        ratio: factor that scales a segment's ``start``/``end`` to a sample index.
        sr: sample rate used to turn sample indices into seconds.
        separator: label that marks a word boundary; separator segments are
            dropped and empty groups produce no word.

    Returns:
        List of word-level ``Segment`` whose ``start``/``end`` are in seconds.
    """
    words = []
    pending = []
    for seg in segments:
        if seg.label != separator:
            pending.append(seg)
            continue
        if pending:
            words.append(_word_from_segments(pending, ratio, sr))
        pending = []
    # Flush the trailing word when the input does not end with a separator.
    if pending:
        words.append(_word_from_segments(pending, ratio, sr))
    return words

In [None]:
def execute(audio, transcript):
    """Run the whole alignment pipeline for one song.

    Args:
        audio: waveform tensor indexed as ``[channel, sample]``.
        transcript: raw lyrics; they are cleaned with ``clean_lyrics`` first.

    Returns:
        Tuple ``(emission, tokens, trellis, path, segments, word_segments)``
        with every intermediate result of the pipeline.
    """
    cleaned = clean_lyrics(transcript)
    emission = calculate_emission(audio)
    tokens = get_tokens(cleaned)
    trellis = get_trellis(emission, tokens)
    path = backtrack(trellis, emission, tokens)
    segments = merge_repeats(path, cleaned)

    # Number of audio samples per emission frame (the trellis has one extra row).
    num_frames = trellis.size(0) - 1
    samples_per_frame = audio.size(1) / num_frames

    word_segments = merge_words(segments, ratio=samples_per_frame, sr=bundle.sample_rate)
    return (emission, tokens, trellis, path, segments, word_segments)

In [None]:
def get_timestamps(word_segments):
    """Extract per-word begin and end times from word records.

    Args:
        word_segments: sequence of mappings with ``'start'`` and ``'end'`` keys.

    Returns:
        ``(word_beginnings, word_endings)``. A word's ending is the next word's
        ``'start'``; only the last word uses its own ``'end'``.
    """
    word_beginnings = [segment['start'] for segment in word_segments]
    # Every word "ends" where its successor begins.
    word_endings = word_beginnings[1:]
    if word_segments:
        word_endings.append(word_segments[-1]['end'])
    return word_beginnings, word_endings

# Load test dataset

In [None]:
import pandas as pd

In [None]:
#### TODO: upload the dataset to your Google Drive before running this cell
# Mount Google Drive at /content/drive so the test set can be read from it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
def _sorted_listing(directory):
    """Return the entries of ``directory`` in sorted order."""
    return sorted(os.listdir(directory))


def get_dataset(dir, sr=16000):
    """Load waveforms, word timestamps and raw lyrics of the test set.

    The four sub-folders are paired purely by sorted file order, so they must
    contain the same number of files; this is verified before any audio is
    processed.

    Args:
        dir: dataset root containing ``mp3``, ``annotations/words``,
            ``lyrics_raw`` and ``lyrics``.
        sr: unused; kept so existing calls keep working.

    Returns:
        ``(X, y_wb, y_we, transcripts)``:
        ``X`` is the list of waveforms from ``get_wav_sr``;
        ``y_wb`` is, per song, the ``word_start`` column of its annotation CSV;
        ``y_we`` is, per song, the word ends, defined as the next word's start
        (the last word ends 0.5 after its start);
        ``transcripts`` is the list of raw lyrics strings.

    Raises:
        AssertionError: if the sub-folders do not hold the same number of files.
    """
    dir_waveforms = os.path.join(dir, 'mp3')
    dir_y = os.path.join(dir, 'annotations/words')
    dir_transcripts = os.path.join(dir, 'lyrics_raw')
    dir_words = os.path.join(dir, 'lyrics')

    # Offset added to the last word's start to obtain its end.
    last_word_duration = 0.5

    waveform_files = _sorted_listing(dir_waveforms)
    transcript_files = _sorted_listing(dir_transcripts)
    word_files = _sorted_listing(dir_words)
    annotation_files = _sorted_listing(dir_y)

    # Fail fast: a count mismatch would silently pair songs with the wrong
    # lyrics/annotations, and audio processing is the expensive step.
    assert len(word_files) == len(transcript_files), \
        "'lyrics' and 'lyrics_raw' must contain the same number of files"
    assert len(waveform_files) == len(transcript_files) == len(annotation_files), \
        "'mp3', 'lyrics_raw' and 'annotations/words' must contain the same number of files"

    X = []
    for filename in waveform_files:
        # waveform
        print(f"Process {filename}")
        X.append(get_wav_sr(dir_waveforms, filename))

    transcripts = []
    for filename in transcript_files:
        # lyrics
        with open(os.path.join(dir_transcripts, filename)) as f:
            transcript = f.read()
        transcripts.append(transcript)
        print(repr(transcript))

    word_lengths = []
    for filename in word_files:
        # words (one per line); only their count per folder is checked
        with open(os.path.join(dir_words, filename)) as f:
            word_lengths.append([len(line) for line in f.readlines()])

    assert len(word_lengths) == len(transcripts)

    y_wb = []
    y_we = []
    for filename in annotation_files:
        # label (list of timestamps for every word)
        df = pd.read_csv(os.path.join(dir_y, filename))
        timestamps = df['word_start'].to_list()
        y_wb.append(timestamps)
        if timestamps:
            y_we.append(timestamps[1:] + [timestamps[-1] + last_word_duration])
        else:
            # An empty annotation file has no last word to extend.
            y_we.append([])

    return X, y_wb, y_we, transcripts


In [None]:
sr = bundle.sample_rate
# X           = list of waveforms (PyTorch tensors), one per song
# y_wb        = per song, list of word-start timestamps
# y_we        = per song, list of word-end timestamps (derived from the next word's start)
# transcripts = per song, the raw lyrics text
#### TODO: set the path to the test set
X, y_wb, y_we, transcripts = get_dataset("/content/drive/MyDrive/hsr/testset", sr=sr)

In [None]:
# Show the first raw transcript with escape sequences made visible.
print(repr(transcripts[0]))

In [None]:
# Show the same transcript after clean_lyrics, i.e. the form that gets aligned.
print(clean_lyrics(transcripts[0]))

In [None]:
# Show the first raw transcript as plain text.
print(transcripts[0])

In [None]:
# Play back the second loaded waveform (the separated vocals when use_spleeter is True).
IPython.display.Audio(X[1], rate=sr)

# Evaluation

In [None]:
import IPython
import os
import numpy as np

In [None]:
def evaluate_average_absolute_error(y_pred, y_true, tolerance=0.3):
    """Mean absolute deviation between predicted and true timestamps.

    Args:
        y_pred: predicted timestamps.
        y_true: reference timestamps; must have the same length as ``y_pred``.
        tolerance: unused by this metric.

    Raises:
        AssertionError: if the two sequences differ in length.
    """
    assert len(y_true) == len(y_pred)
    truth = np.array(y_true)
    predicted = np.array(y_pred)
    return np.mean(np.abs(truth - predicted))

In [None]:
def evaluate_accuracy(y_pred, y_true, tolerance=0.3):
    """Fraction of predictions strictly closer than ``tolerance`` to the truth.

    Args:
        y_pred: predicted timestamps.
        y_true: reference timestamps.
        tolerance: maximum (exclusive) absolute deviation counted as correct.
    """
    truth = np.array(y_true)
    predicted = np.array(y_pred)
    within_tolerance = np.abs(truth - predicted) < tolerance
    return np.mean(within_tolerance)

In [None]:
def evaluate_iou(y_pred_wb, y_pred_we, y_true_wb, y_true_we):
    """Mean intersection-over-union of predicted and true word intervals.

    Args:
        y_pred_wb: predicted word beginnings.
        y_pred_we: predicted word endings.
        y_true_wb: reference word beginnings.
        y_true_we: reference word endings.

    Returns:
        The mean IoU over all ``len(y_pred_wb)`` interval pairs.

    Raises:
        AssertionError: if an interval does not satisfy ``begin < end``.
    """
    ious = []

    for idx in range(len(y_pred_wb)):
        predicted = (y_pred_wb[idx], y_pred_we[idx])
        reference = (y_true_wb[idx], y_true_we[idx])

        # (a1, a2) is the interval that starts first (the reference on ties).
        if predicted[0] < reference[0]:
            (a1, a2), (b1, b2) = predicted, reference
        else:
            (a1, a2), (b1, b2) = reference, predicted

        assert (a1 < a2) & (b1 < b2)

        if a2 < b1:
            # Disjoint intervals: nothing shared, union is the sum of both lengths.
            intersection = 0
            union = a2 - a1 + b2 - b1
        else:
            # Overlapping (or touching) intervals.
            intersection = min(a2, b2) - max(a1, b1)
            union = max(a2, b2) - min(a1, b1)

        iou = intersection / union
        ious.append(iou)

        assert iou >= 0.0

    return np.mean(ious)

In [None]:
def evaluate(waveforms, y_true_wb, y_true_we, transcripts, sr=16000):
    """Align every song and report AAE, accuracy and IoU of the word timings.

    A song is skipped (with a printed notice) when the number of predicted
    words differs from the number of annotated words, so the returned lists
    can be shorter than ``waveforms``.

    Args:
        waveforms: list of waveform tensors.
        y_true_wb: per song, the reference word beginnings.
        y_true_we: per song, the reference word endings.
        transcripts: per song, the raw lyrics.
        sr: unused by this function.

    Returns:
        ``(aaes, accs, ious)``: the per-song metric values of the songs that
        were not skipped.
    """
    aaes, accs, ious = [], [], []

    for i, waveform in enumerate(waveforms):
        print(f"Process audio {i}")
        _, _, _, _, _, word_segments = execute(waveform, transcripts[i])
        dict_words = pd.DataFrame([vars(f) for f in word_segments]).to_dict('records')

        y_pred_wb, y_pred_we = get_timestamps(dict_words)

        if len(y_true_wb[i]) == len(y_pred_wb):
            # Average absolute error
            aae = evaluate_average_absolute_error(y_pred_wb, y_true_wb[i])
            aaes.append(aae)

            # Accuracy
            acc = evaluate_accuracy(y_pred_wb, y_true_wb[i])
            accs.append(acc)

            # Intersection over union
            iou = evaluate_iou(y_pred_wb, y_pred_we, y_true_wb[i], y_true_we[i])
            ious.append(iou)

            print(f"AAE of audio {i}: {aae}")
            print(f"Accuracy of audio {i}: {acc}")
            print(f"IOU of audio {i}: {iou}")
        else:
            # Word counts disagree, so the metrics cannot be computed pairwise.
            print('Skipping because of formating mismatch')

    # Aggregate over the songs that were actually evaluated.
    print(f"AAE: {np.mean(aaes)}")
    print(f"Accuracy: {np.mean(accs)}")
    print(f"IOU: {np.mean(ious)}")

    return aaes, accs, ious

In [None]:
# Evaluate the whole test set; each list holds one value per song that was not skipped.
aaes, accs, ious = evaluate(X, y_wb, y_we, transcripts)

In [None]:
# Average absolute error (x) against the position in `aaes` (y).
# Songs skipped by `evaluate` are absent, so y is not necessarily the dataset index.
fig = px.scatter(x=aaes, y=np.arange(0, len(aaes)))
fig.show()

In [None]:
# Accuracy (x) against the position in `accs` (y).
# Songs skipped by `evaluate` are absent, so y is not necessarily the dataset index.
fig = px.scatter(x=accs, y=np.arange(0, len(accs)))
fig.show()

In [None]:
# Intersection over union (x) against the position in `ious` (y).
# Songs skipped by `evaluate` are absent, so y is not necessarily the dataset index.
fig = px.scatter(x=ious, y=np.arange(0, len(ious)))
fig.show()