# Same procedure as before

In [None]:
# Mount Google Drive at /content/drive so files stored there can be read by this notebook.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install pydub youtube-dl spleeter

In [None]:
import IPython
import os
import pandas as pd
import re
import sys
import shutil
import torch
import torchaudio

from dataclasses import dataclass
from io import BytesIO

from pydub import AudioSegment

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# When True, the downloaded audio is run through Spleeter and only the vocals stem is loaded.
use_spleeter = True

In [None]:
# Pretrained wav2vec 2.0 ASR bundle from torchaudio.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
# Labels of the bundle and its model, moved to the selected device.
labels = bundle.get_labels()
model = bundle.get_model().to(device)
# Map every label to its position in `labels`; get_tokens uses this to encode transcripts.
dictionary = {c: i for i, c in enumerate(labels)}

In [None]:
# Spleeter is imported and set up only when vocal separation is enabled.
if use_spleeter:
  from spleeter.separator import Separator
  # Initialize the separator with the 'spleeter:2stems' configuration
  separator = Separator('spleeter:2stems')

In [None]:
def get_wave(aud):
    """Turn an audio segment into a mono waveform tensor.

    The segment is down-mixed to a single channel and its raw sample values
    are copied into a float tensor with a leading channel axis, i.e. a tensor
    of shape (1, num_samples).
    """
    samples = aud.set_channels(1).get_array_of_samples()
    # unsqueeze(0) adds the channel axis in front of the sample axis.
    return torch.tensor(samples, dtype=torch.float).unsqueeze(0)

In [None]:
def get_wav_sr_from_yt_video_id(video_id):
    """Download the audio of a YouTube video and return it as a waveform.

    The audio is fetched with youtube-dl as "<video_id>.wav". When
    `use_spleeter` is set, the file is separated with Spleeter and only the
    vocals stem is loaded. The temporary files are removed afterwards and the
    waveform is resampled to `bundle.sample_rate` when necessary.

    Args:
        video_id: YouTube video id (the part after "https://youtu.be/").

    Returns:
        Tuple `(waveform, sr)`: the mono float tensor produced by `get_wave`
        and the sample rate of that returned tensor (i.e. the rate after
        resampling).
    """
    import subprocess

    # Download the audio using youtube-dl. Passing an argument list (no shell)
    # keeps an unusual video id from being interpreted as shell syntax.
    # check=False mirrors the previous behaviour of ignoring the exit status;
    # a failed download surfaces when the file is loaded below.
    subprocess.run(
        [
            "youtube-dl",
            "--extract-audio",
            "--audio-format", "wav",
            "--audio-quality", "0",
            "-o", "%(id)s.%(ext)s",
            "https://youtu.be/{}".format(video_id),
        ],
        check=False,
    )

    file_path = "{}.wav".format(video_id)
    audio_path = file_path

    if use_spleeter:
        separator.separate_to_file("/content/{}.wav".format(video_id), "/content/")
        audio_path = "/content/{}/vocals.wav".format(video_id)

    # Load the audio file using pydub
    audio = AudioSegment.from_file(audio_path, format="wav")

    waveform = get_wave(audio)
    sr = audio.frame_rate

    # Delete the downloaded file
    if os.path.isfile(file_path):
        os.remove(file_path)
    else:
        print("{} does not exist.".format(file_path))

    # The stems directory only exists when Spleeter ran, so only then is
    # there anything to clean up.
    if use_spleeter:
        try:
            shutil.rmtree("/content/{}".format(video_id))
        except OSError as e:
            print("Error: %s - %s." % (e.filename, e.strerror))

    # Resample, and report the rate the returned waveform actually has.
    if sr != bundle.sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, bundle.sample_rate)
        sr = bundle.sample_rate

    return waveform, sr

In [None]:
def clean_lyrics(lyrics):
    """Normalise raw lyrics into the upper-case, '|'-separated transcript form.

    Bracketed annotations such as "[Chorus]" are removed, typographic
    apostrophes become plain ones, and every character that is not a letter,
    apostrophe, '|' or '-' (including all whitespace) is replaced by '|'.
    """
    without_tags = re.sub(r"\[.*?\]", "", lyrics, flags=re.MULTILINE)
    plain_apostrophes = re.sub(r"’", "'", without_tags)
    separated = re.sub(r"[^a-zA-Z'’|-]|\s", "|", plain_apostrophes)
    return separated.upper()

In [None]:
def calculate_emission(waveform, amount_chunks=10):
    """Run the acoustic model over the waveform and return log-probabilities.

    The waveform is processed in `amount_chunks` consecutive pieces along its
    second axis; the per-chunk emissions are log-softmaxed and concatenated
    along the frame axis.

    Args:
        waveform: tensor indexed as [channel, sample].
        amount_chunks: number of pieces the waveform is split into.

    Returns:
        The emission of the first batch entry as a CPU tensor.
    """
    torch.cuda.empty_cache()

    length = waveform.shape[1]
    chunks_length = length // amount_chunks
    chunks = []
    with torch.inference_mode():
        for i in range(amount_chunks):
            start = i * chunks_length
            # The last chunk runs to the end of the waveform so that the
            # samples left over by the integer division are not dropped.
            end = length if i == amount_chunks - 1 else start + chunks_length
            emissions, _ = model(waveform[:, start:end].to(device))
            chunks.append(torch.log_softmax(emissions, dim=-1))

    return torch.cat(chunks, dim=1)[0].cpu().detach()

In [None]:
def get_tokens(transcript):
    """Encode a transcript as the list of label indices found in `dictionary`.

    Raises KeyError for a character that is not a key of `dictionary`.
    """
    return list(map(dictionary.__getitem__, transcript))

In [None]:
def get_trellis(emission, tokens, blank_id=0):
    """Build the alignment trellis for `tokens` over the frames of `emission`.

    Args:
        emission: tensor indexed as [frame, label].
        tokens: list of label indices of the transcript.
        blank_id: label index used for "staying" on the current token.

    Returns:
        Tensor of shape (num_frame + 1, num_tokens + 1) where entry [t, j]
        is the best score of having consumed j tokens after t frames.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Trellis has extra dimensions for both time axis and tokens.
    # The extra dim for tokens represents <SoS> (start-of-sentence)
    # The extra dim for time axis is for simplification of the code.
    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    # Before the first token only blanks are emitted, so the first column
    # accumulates the score of the blank label (not a hard-coded column 0).
    trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
    trellis[0, -num_tokens:] = -float("inf")
    trellis[-num_tokens:, 0] = float("inf")

    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis

In [None]:
@dataclass
class Point:
    # Index into the transcript (non-trellis coordinate).
    token_index: int
    # Index into the emission frames (non-trellis coordinate).
    time_index: int
    # Frame-wise probability of the label chosen at this point.
    score: float


def backtrack(trellis, emission, tokens, blank_id=0):
    """Recover the best alignment path by walking the trellis backwards.

    Args:
        trellis: trellis as produced by `get_trellis`.
        emission: tensor indexed as [frame, label].
        tokens: list of label indices of the transcript.
        blank_id: label index used for "staying" on the current token.

    Returns:
        List of `Point`, ordered by increasing time index.

    Raises:
        ValueError: if the first frame is reached before all tokens have
            been consumed.
    """
    # Note:
    # j and t are indices for trellis, which has extra dimensions
    # for time and tokens at the beginning.
    # When referring to time frame index `T` in trellis,
    # the corresponding index in emission is `T-1`.
    # Similarly, when referring to token index `J` in trellis,
    # the corresponding index in transcript is `J-1`.
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Figure out if the current position was stay or change
        # Score for the token staying the same from time frame T-1 to T.
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score for the token changing from J-1 at T-1 to J at T.
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
        is_change = changed > stayed

        # 2. Store the path with frame-wise probability.
        # A "stay" is scored with the blank label, so its probability is read
        # from the `blank_id` column rather than a hard-coded column 0.
        label = tokens[j - 1] if is_change else blank_id
        prob = emission[t - 1, label].exp().item()
        # Return token index and time index in non-trellis coordinate.
        path.append(Point(j - 1, t - 1, prob))

        # 3. Update the token
        if is_change:
            j -= 1
            if j == 0:
                break
    else:
        # The loop ran out of frames without consuming every token.
        raise ValueError("Failed to align")
    return path[::-1]

In [None]:
# Merge the labels
@dataclass
class Segment:
    """A labelled span [start, end) together with a score.

    Equality and hashing are both based on the tuple
    (label, start, end, score).
    """

    label: str
    start: int
    end: int
    score: float

    def _key(self):
        # Single source of truth for equality and hashing.
        return (self.label, self.start, self.end, self.score)

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start}, {self.end})"

    @property
    def length(self):
        # Extent of the span.
        return self.end - self.start

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        # Anything that is not a Segment compares unequal.
        return isinstance(other, Segment) and self._key() == other._key()



def merge_repeats(path, transcript):
    """Collapse consecutive path points with the same token into segments.

    Each run of points sharing a `token_index` becomes one `Segment` labelled
    with the corresponding transcript character, spanning from the run's
    first time index to one past its last, scored with the mean point score.
    """
    segments = []
    total = len(path)
    run_start = 0
    while run_start < total:
        token_index = path[run_start].token_index

        # Advance to the first point that belongs to a different token.
        run_end = run_start
        while run_end < total and path[run_end].token_index == token_index:
            run_end += 1

        run = path[run_start:run_end]
        mean_score = sum(point.score for point in run) / len(run)
        segments.append(
            Segment(
                transcript[token_index],
                run[0].time_index,
                run[-1].time_index + 1,
                mean_score,
            )
        )
        run_start = run_end
    return segments

In [None]:
# Merge words
def merge_words(segments, ratio, sr, separator="|"):
    """Group character segments into word segments split at `separator`.

    Every non-empty run of segments between separators becomes one `Segment`
    whose label is the concatenated characters, whose score is the
    length-weighted mean of the character scores, and whose start/end are
    the frame positions scaled by `ratio` and divided by `sr`.
    """

    def build_word(chars):
        # Length-weighted average of the character scores.
        weighted = sum(seg.score * seg.length for seg in chars)
        total_length = sum(seg.length for seg in chars)
        text = "".join([seg.label for seg in chars])

        first_sample = int(ratio * chars[0].start)
        last_sample = int(ratio * chars[-1].end)
        return Segment(text, first_sample / sr, last_sample / sr, weighted / total_length)

    words = []
    pending = []
    for seg in segments:
        if seg.label != separator:
            pending.append(seg)
            continue
        # A separator closes the current word; empty runs are skipped.
        if pending:
            words.append(build_word(pending))
            pending = []

    # Flush a word that is not followed by a separator.
    if pending:
        words.append(build_word(pending))
    return words

In [None]:
def execute(audio, transcript):
    """Align a transcript to a waveform and return every intermediate result.

    Returns the tuple
    (emission, tokens, trellis, path, segments, word_segments).
    """
    cleaned = clean_lyrics(transcript)
    emission = calculate_emission(audio)
    tokens = get_tokens(cleaned)
    trellis = get_trellis(emission, tokens)
    path = backtrack(trellis, emission, tokens)
    segments = merge_repeats(path, cleaned)

    # Number of audio samples per trellis frame (the trellis has one extra row).
    num_frames = trellis.size(0) - 1
    samples_per_frame = audio.size(1) / num_frames

    word_segments = merge_words(segments, ratio=samples_per_frame, sr=bundle.sample_rate)
    return emission, tokens, trellis, path, segments, word_segments

In [None]:
def execute_with_id(video_id, transcript):
    """Fetch the audio for `video_id` and align `transcript` to it.

    The sample rate returned by the download helper is not used here.
    """
    waveform, _ = get_wav_sr_from_yt_video_id(video_id)
    return execute(waveform, transcript)

# Evaluation

In [None]:
import matplotlib.pyplot as plt
import math

In [None]:
# Directory that holds eval_list.csv and one "<ID>.csv" annotation file per song.
csv_dir = "/content/drive/MyDrive/ASR-Praktikum/4/eval"
# Vocal separation stays enabled for the evaluation run.
use_spleeter = True

In [None]:
# Filled by calculate_IoU: ground-truth word duration (rounded up to 0.1) -> list of IoU values.
dict_hist = {}

In [None]:
def IoU(truth, pred):
    """Return the intersection-over-union of two (start, end) intervals.

    Returns 0.0 when the union is empty (e.g. two identical zero-length
    intervals) instead of dividing by zero.
    """
    total = union(truth, pred)
    if total <= 0:
        return 0.0
    return intersect(truth, pred) / total


def intersect(truth, pred):
    """Return the length of the overlap of two (start, end) intervals (0 if none)."""
    start_truth, end_truth = truth
    start_pred, end_pred = pred

    return max(0, min(end_truth, end_pred) - max(start_truth, start_pred))


def union(truth, pred):
    """Return the total length covered by two (start, end) intervals."""
    start_truth, end_truth = truth
    start_pred, end_pred = pred

    if intersect(truth, pred) > 0:
        # Overlapping intervals cover exactly their enclosing span.
        return max(end_truth, end_pred) - min(start_truth, start_pred)
    # Disjoint (or merely touching) intervals: the gap between them must not
    # be counted, so add the two lengths instead of taking the span.
    return (end_truth - start_truth) + (end_pred - start_pred)

In [None]:
def round_to_next_0_1(n):
    """Round `n` up to the next multiple of 0.1.

    The scaled value is rounded to 9 decimals before taking the ceiling so
    that floating-point noise (e.g. 1.3 - 1.0 == 0.30000000000000004) does
    not push an exact tenth into the next bucket.
    """
    return math.ceil(round(n * 10, 9)) / 10

In [None]:
# Predicted words whose duration (end - start) is at most this value are left out of the IoU average.
THRESHOLD = 0.2

In [None]:
import pandas as pd

def calculate_IoU(csv_dir):
    """Evaluate the alignment against the annotated songs in `csv_dir`.

    `eval_list.csv` is read for the columns 'ID' and 'Lyrics'; for every row
    the file "<ID>.csv" provides the ground-truth 'start' and 'end' of each
    word. The i-th annotated word is compared with the i-th predicted word.
    Predicted words no longer than `THRESHOLD` are skipped. Every computed
    IoU is also appended to the module-level `dict_hist`, keyed by the
    ground-truth duration rounded up to 0.1.

    Args:
        csv_dir: directory containing eval_list.csv and the per-song files.

    Returns:
        Tuple `(avg_iou, dict_id2iou)`: the mean of the per-song averages and
        a dict mapping each song id to its average IoU. A song without any
        evaluated word contributes 0.0; an empty list yields 0.0.
    """
    accum_iou_avg = 0
    dict_id2iou = {}
    df = pd.read_csv(os.path.join(csv_dir, 'eval_list.csv'))

    for _, row in df.iterrows():
        video_id = row['ID']
        df_song = pd.read_csv(os.path.join(csv_dir, video_id + '.csv'))
        _, _, _, _, _, word_segments = execute_with_id(video_id, row['Lyrics'])

        ious = []
        for i, (_, row_song) in enumerate(df_song.iterrows()):
            word = word_segments[i]
            if word.end - word.start <= THRESHOLD:
                # Too short to be evaluated; excluded from the average.
                continue

            start_truth = float(row_song['start'])
            end_truth = float(row_song['end'])
            iou = IoU((start_truth, end_truth), (word.start, word.end))
            ious.append(iou)

            key = round_to_next_0_1(end_truth - start_truth)
            dict_hist.setdefault(key, []).append(iou)

        # Guard against songs where every word was filtered out.
        avg = sum(ious) / len(ious) if ious else 0.0
        dict_id2iou[video_id] = avg
        accum_iou_avg += avg

    num_songs = len(df.index)
    avg_iou = accum_iou_avg / num_songs if num_songs else 0.0
    return avg_iou, dict_id2iou

In [None]:
# Run the evaluation over every song listed in csv_dir.
avg_iou, dict_id2iou = calculate_IoU(csv_dir)

In [None]:
# Display the mean of the per-song IoU averages.
avg_iou

In [None]:
# Display the average IoU of every song, keyed by song id.
dict_id2iou

## Plot mean IoU per segment length

In [None]:
# Mean IoU for every duration bucket collected in dict_hist.
mean_dict = {key: sum(value) / len(value) for key, value in dict_hist.items()}

# Remove outliers for plotting: keep only buckets in the range [0, 3).
mean_dict = {k: v for k, v in mean_dict.items() if k >= 0 and k < 3}

# NOTE(review): `keys` and `values` are not used in the rest of this cell.
keys = list(mean_dict.keys())
values = list(mean_dict.values())

# Bucket keys in ascending order, with their mean IoU in matching order.
sorted_keys = sorted(mean_dict.keys())
sorted_values = [mean_dict[key] for key in sorted_keys]

fig, ax = plt.subplots()

# One bar per bucket, placed at the bucket value on the x axis.
for key, value in zip(sorted_keys, sorted_values):
    ax.bar(key, value, width=0.05)

# Label every bar with the number of IoU samples in its bucket.
for i, v in enumerate(sorted_values):
    plt.annotate(len(dict_hist[sorted_keys[i]]), (sorted_keys[i], v), xytext=(0, 10), textcoords='offset points', ha='center', va='bottom')


ax.set_xlabel('Length of Segment [s]')
ax.set_ylabel('IoU')
ax.set_title('')

plt.show()