In [1]:
import torch
from demucs.pretrained import get_model
from demucs.apply import apply_model
import torchaudio
from pathlib import Path
import whisper
import difflib
import json
import re

In [None]:
import os

# =========================================
# Input song. Set the KARAOKE_FILE_PATH environment variable to process another
# file without editing this cell; the default is the author's local copy.
# Another song used during development: /home/michal/DEV/Karaoke/rower.mp3
FILE_PATH = Path(
    os.environ.get(
        "KARAOKE_FILE_PATH",
        "/home/michal/DEV/Karaoke/youre_the_one_that_i_want.mp3",
    )
)
# =========================================

CACHE_DIR = Path("cache")
# Every artifact derived from FILE_PATH lives in one per-song directory.
SONG_CACHE_DIR = CACHE_DIR / FILE_PATH.stem
VOCALS_FILE_PATH = SONG_CACHE_DIR / f"{FILE_PATH.stem}.vocals.mp3"
INSTRUMENTAL_FILE_PATH = SONG_CACHE_DIR / f"{FILE_PATH.stem}.instrumental.mp3"
LYRICS_FILE_PATH = SONG_CACHE_DIR / f"{FILE_PATH.stem}.lyrics.txt"
ALIGNED_LYRICS_FILE_PATH = SONG_CACHE_DIR / f"{FILE_PATH.stem}.aligned_lyrics.json"

# Create the per-song directory (not just CACHE_DIR) up front so the reference
# lyrics file read by get_lyrics() has a place to be dropped before the first run.
SONG_CACHE_DIR.mkdir(parents=True, exist_ok=True)

print(f"{CACHE_DIR=}")
print(f"{FILE_PATH=}")
print(f"{VOCALS_FILE_PATH=}")
print(f"{INSTRUMENTAL_FILE_PATH=}")
print(f"{LYRICS_FILE_PATH=}")
print(f"{ALIGNED_LYRICS_FILE_PATH=}")

In [3]:
def save_track(path: Path, audio: torch.Tensor, sr: int = 44100):
    """Write a waveform to disk, creating the destination folder first.

    Args:
        path: Destination file; how it is encoded is left to ``torchaudio.save``.
        audio: Waveform tensor passed unchanged to ``torchaudio.save``.
        sr: Sample rate recorded in the written file.
    """
    print("Saving " + str(path))
    destination_dir = path.parent
    destination_dir.mkdir(exist_ok=True, parents=True)
    torchaudio.save(path, audio, sample_rate=sr)

In [4]:
def _to_stereo(audio: torch.Tensor) -> torch.Tensor:
    """Return a two-channel version of a channels-first waveform."""
    if audio.shape[0] == 1:
        # Duplicate the mono channel. (Tensor has no ``.cat`` method; the
        # original ``audio.cat(...)`` raised AttributeError on mono input.)
        return audio.repeat(2, 1)
    # Keep only the first two channels of multichannel input (no-op for stereo).
    return audio[:2]


def _enhance_vocals(vocals: torch.Tensor) -> torch.Tensor:
    """Apply the notebook's ad-hoc vocal post-processing (numerically unchanged).

    - ``tanh(2x) / 2`` soft-clips the signal into (-0.5, 0.5) while staying
      close to ``x`` for small amplitudes.
    - Adding scaled differences against the signal delayed by 1 and 2 samples
      emphasises high frequencies. ``roll`` is circular, so the very first
      samples are mixed with the last ones.
    """
    vocals = torch.tanh(vocals * 2) / 2
    vocals = vocals + (vocals - vocals.roll(1, -1)) * 0.5
    vocals = vocals + (vocals - vocals.roll(2, -1)) * 0.2
    return vocals


def process_audio(file: Path):
    """Split ``file`` into a vocals track and an instrumental track with Demucs.

    The pretrained ``htdemucs`` model separates the song into stems. The
    "vocals" stem is post-processed and saved to ``VOCALS_FILE_PATH``; all the
    other stems are summed and saved to ``INSTRUMENTAL_FILE_PATH``. Both files
    are written at the model's sample rate.

    Args:
        file: Input audio file, loaded with ``torchaudio.load``.
    """
    model = get_model("htdemucs")
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    audio, sr = torchaudio.load(file)
    audio = _to_stereo(audio)

    # Demucs is trained at a fixed sample rate; feeding audio at another rate
    # silently degrades separation, so resample first.
    if sr != model.samplerate:
        audio = torchaudio.functional.resample(audio, sr, model.samplerate)
        sr = model.samplerate

    with torch.no_grad():
        # Add a batch dimension of 1 for apply_model.
        sources = apply_model(
            model, audio.unsqueeze(0).to(device), split=True, progress=True
        )
    # Drop the batch dimension and free GPU memory; indexed as sources[stem].
    sources = sources[0].cpu()

    vocals_idx = model.sources.index("vocals")
    vocals = sources[vocals_idx]
    # Instrumental = sum of every stem except vocals.
    other_idx = [i for i in range(len(model.sources)) if i != vocals_idx]
    instrumental = sources[other_idx].sum(dim=0)

    vocals = _enhance_vocals(vocals)

    print(
        f"Saving vocals shape: {vocals.shape}, instrumental shape: {instrumental.shape}"
    )
    save_track(VOCALS_FILE_PATH, vocals, sr)
    save_track(INSTRUMENTAL_FILE_PATH, instrumental, sr)

In [5]:
def get_lyrics() -> str:
    if LYRICS_FILE_PATH.exists():
        with open(LYRICS_FILE_PATH, "r") as f:
            lyrics = f.read()
            lyrics = re.sub(r"[\{\[\(].*?[\}\]\)]", "", lyrics)
            return lyrics
    else:
        return ""


def transcribe_audio(
    model_size: str = "large", audio_path: Path | None = None
) -> dict:
    """Transcribe a vocals track with OpenAI Whisper, including word timestamps.

    Args:
        model_size: Checkpoint name passed to ``whisper.load_model``.
        audio_path: Audio file to transcribe; defaults to ``VOCALS_FILE_PATH``
            (the separated vocals written by ``process_audio``).

    Returns:
        The dict returned by ``model.transcribe``. ``Lyrics`` reads its
        ``"segments"`` and their per-word ``"words"`` entries, which Whisper
        produces because ``word_timestamps=True``.
    """
    model = whisper.load_model(model_size)
    print("Transcribing audio...")
    source = VOCALS_FILE_PATH if audio_path is None else audio_path
    audio = whisper.load_audio(str(source))
    return model.transcribe(
        audio,
        word_timestamps=True,
        # Whisper option for skipping silent stretches when a hallucination is
        # suspected; it relies on the word timestamps requested above.
        hallucination_silence_threshold=0.1,
    )

In [6]:
from dataclasses import dataclass
from typing import TypeVar
# from rich import print


@dataclass
class Word:
    text: str
    start: float
    end: float


LyricsSlice = TypeVar("LyricsSlice")


@dataclass
class LyricsSlice:
    words: list[Word]
    raw_text: str | None = None

    def __str__(self):
        return " ".join(word.text.strip() for word in self.words)

    def __add__(self, other): ...

    def append(self, slice: LyricsSlice, rebase: bool = True) -> None:
        if rebase:
            slice.rebase(self.end)
        self.words += slice.words

    def stretch(self, value: float, backward: bool = False):
        duration = self.words[-1].end - self.words[0].start
        new_duration = duration + value
        scale = new_duration / duration

        for word in self.words:
            if not backward:
                word.start = (
                    self.words[0].start + (word.start - self.words[0].start) * scale
                )
                word.end = (
                    self.words[0].start + (word.end - self.words[0].start) * scale
                )
            else:
                word.start = (
                    self.words[-1].end - (self.words[-1].end - word.start) * scale
                )
                word.end = self.words[-1].end - (self.words[-1].end - word.end) * scale

    def shift(self, value: float):
        for word in self.words:
            word.start += value
            word.end += value

    def rebase(self, base: float):
        diff = base - self.start
        for word in self.words:
            word.start += diff
            word.end += diff

    @property
    def start(self):
        return self.words[0].start

    @property
    def end(self):
        return self.words[-1].end

In [None]:
import math

# Sanity check of LyricsSlice time editing on a toy example.
demo = LyricsSlice(
    [
        Word("Hello", 2.0, 3.0),
        Word("happy", 3.0, 4.5),
        Word("world", 4.6, 5.8),
    ],
)
demo_tail = LyricsSlice(
    [
        Word("some", 1.0, 2.0),
        Word("other", 2.0, 3.5),
        Word("text", 3.6, 4.8),
    ]
)
demo.shift(-1.0)  # now spans 1.0 .. 4.8 (duration 3.8)
demo.stretch(2.5, backward=True)  # end stays at 4.8, duration becomes 6.3
demo.rebase(0)  # now spans 0.0 .. 6.3
demo.append(demo_tail)  # tail (duration 3.8) is moved to start at 6.3

assert str(demo) == "Hello happy world some other text"
assert math.isclose(demo.start, 0.0, abs_tol=1e-9)
assert math.isclose(demo.end, 6.3 + 3.8)
demo

In [8]:
class Lyrics:
    """Word-level transcript of a song that can be aligned with reference lyrics.

    Built from a Whisper transcription result (see ``transcribe_audio``). Each
    reference lyric line is mapped onto transcript words by fuzzy matching;
    the resulting ``LyricsSlice`` objects own copies of the transcript words,
    so adjusting their timing never alters ``self.words``.
    """

    # Punctuation stripped from transcript and lyric words before matching.
    _STRIP_CHARS = " ,.!?()[]"

    def __init__(self, transcript: dict, min_probability: float = 0.3) -> None:
        """
        Args:
            transcript: Transcription dict; reads
                ``transcript["segments"][i]["words"]`` entries with "word",
                "start", "end" and "probability" keys.
            min_probability: Words whose probability is not above this value
                are dropped as unreliable.
        """
        self.raw_lyrics: list[str] | None = None
        self.slices: list[LyricsSlice] = []
        self.words: list[Word] = []
        for segment in transcript["segments"]:
            for word in segment["words"]:
                text = word["word"].strip(self._STRIP_CHARS)
                # Pure-punctuation tokens become "" and can never match a lyric
                # word; skipping them keeps matching windows aligned.
                if word["probability"] > min_probability and text:
                    self.words.append(Word(text, word["start"], word["end"]))

    def __repr__(self) -> str:
        return " ".join(word.text for word in self.words)

    def __str__(self) -> str:
        return self.__repr__()

    @staticmethod
    def _detached(words: list[Word], raw_text: str | None = None) -> LyricsSlice:
        """Wrap copies of ``words`` so edits to the slice never touch the transcript."""
        return LyricsSlice([Word(w.text, w.start, w.end) for w in words], raw_text)

    def get_time_slice(self, start: float, end: float) -> LyricsSlice:
        """Return the words lying entirely within ``[start, end]``."""
        return self._detached(
            [word for word in self.words if word.start >= start and word.end <= end]
        )

    def get_words_slice(self, start: int, end: int) -> LyricsSlice:
        """Return the words at indices ``start:end`` of the transcript."""
        return self._detached(self.words[start:end])

    def get_text_slice(
        self,
        text: str,
        after: float | None = None,
        similarity_threshold: float = 0.5,
    ) -> LyricsSlice:
        """Find the transcript words that best match ``text``.

        A window of ``len(text.split())`` consecutive transcript words is slid
        over the transcript; each window is scored position-by-position with
        ``difflib.SequenceMatcher.ratio``. Only pairs scoring at least
        ``similarity_threshold`` count, and the window with the highest average
        score wins (earliest on ties).

        Args:
            text: Lyric text to locate.
            after: If given, only words starting at or after this time are searched.
            similarity_threshold: Minimum per-word ratio for a word to count.

        Returns:
            A slice with the matched words of the best window (unmatched
            positions omitted) and ``raw_text=text``; empty when nothing matches.
        """
        search_words = [w.strip(self._STRIP_CHARS).lower() for w in text.split()]
        search_words = [w for w in search_words if w]
        if not search_words:
            return LyricsSlice([], text)

        if after is None:
            search_space = self.words
        else:
            search_space = [w for w in self.words if w.start >= after]

        window_len = len(search_words)
        best_score = 0.0
        best_matched: list[Word] = []
        for start_idx in range(len(search_space) - window_len + 1):
            window = search_space[start_idx : start_idx + window_len]
            score = 0.0
            matched: list[Word] = []
            for search_word, word in zip(search_words, window):
                similarity = difflib.SequenceMatcher(
                    None, search_word, word.text.lower()
                ).ratio()
                if similarity >= similarity_threshold:
                    score += similarity
                    matched.append(word)

            # Normalise by the lyric length so partial matches score lower.
            avg_score = score / window_len
            if avg_score > best_score:
                best_score = avg_score
                best_matched = matched

        return self._detached(best_matched, text)

    def get_word_at_time(self, time: float) -> Word | None:
        """Return the first transcript word whose span contains ``time``, else None."""
        return next(
            (word for word in self.words if word.start <= time <= word.end), None
        )

    def align_lyrics(
        self, lyrics: str, similarity_threshold: float = 0.5
    ) -> list[LyricsSlice]:
        """Match every lyric line to transcript words, in order.

        Each line is searched after the end of the previous matched line; if
        that finds nothing, the whole transcript is searched instead (without
        advancing the cursor). ``self.slices`` gets exactly one slice per line
        of ``lyrics``; blank or unmatched lines get an empty slice.
        """
        self.raw_lyrics = lyrics.splitlines()
        self.slices = []
        last_slice_end = 0.0
        for line in self.raw_lyrics:
            line_slice = self.get_text_slice(
                line, after=last_slice_end, similarity_threshold=similarity_threshold
            )
            if line_slice.words:
                last_slice_end = line_slice.end
            elif line.strip():
                line_slice = self.get_text_slice(
                    line, similarity_threshold=similarity_threshold
                )
            self.slices.append(line_slice)
        return self.slices

    def check_time_overlap(self, slice1, slice2) -> bool:
        """True when the two (non-empty) slices overlap in time."""
        return slice1.end > slice2.start and slice1.start < slice2.end

    def check_time_continuation(self, slice1, slice2) -> bool:
        """True when ``slice2`` starts no earlier than ``slice1`` ends."""
        return slice1.end <= slice2.start

    def _timed_slices(self) -> list[LyricsSlice]:
        """Slices that matched at least one word (only those have start/end)."""
        return [s for s in self.slices if s.words]

    def check_alignment(self) -> bool:
        """Return True when consecutive matched slices are in order and disjoint.

        Empty slices (blank or unmatched lines) are skipped. The first problem
        found is printed.
        """
        timed = self._timed_slices()
        for current, following in zip(timed, timed[1:]):
            if self.check_time_overlap(current, following):
                print("Overlap detected:", current, following)
                return False
            if not self.check_time_continuation(current, following):
                print("Continuation problem:", current, following)
                return False
        return True

    def fix_alignment(self, max_try: int = 100) -> bool:
        """Push slices later until each starts where the previous one ends.

        Args:
            max_try: Maximum number of repair passes.

        Returns:
            True if the slices are aligned afterwards, False otherwise.
        """
        if self.check_alignment():
            return True
        print("Fixing alignment...")
        for _ in range(max_try):
            timed = self._timed_slices()
            for current, following in zip(timed, timed[1:]):
                if not self.check_time_continuation(current, following):
                    # Positive offset: move the later line FORWARD so it starts
                    # exactly at the end of the earlier one.
                    following.shift(current.end - following.start)
            if self.check_alignment():
                print("Alignment fixed!")
                return True
        print(f"Alignment could not be fixed in {max_try} passes")
        return False

In [None]:
# Stem separation (Demucs) and transcription (Whisper "large") are slow, so
# both results are cached in the song's cache directory. Delete those files to
# force a fresh run, e.g. after replacing the input song.
stems_cached = VOCALS_FILE_PATH.exists() and INSTRUMENTAL_FILE_PATH.exists()
if not stems_cached:
    process_audio(FILE_PATH)

transcript_cache = VOCALS_FILE_PATH.with_name(f"{FILE_PATH.stem}.transcript.json")
# `a` holds the transcript dict consumed by the next cell.
if stems_cached and transcript_cache.exists():
    a = json.loads(transcript_cache.read_text(encoding="utf-8"))
else:
    # Re-transcribe whenever the vocals were regenerated so the cached
    # transcript never describes an older vocals file.
    a = transcribe_audio()
    # default=float converts any numpy scalars json cannot encode natively.
    transcript_cache.write_text(json.dumps(a, default=float), encoding="utf-8")

In [None]:
from rich import print

# Word-level view of the transcript, printed for a quick visual check.
lyrics = Lyrics(a)
print(lyrics, end="\n\n")

# Match each reference lyric line to transcript words (difflib ratio >= 0.7),
# then repair lines whose timings overlap or run out of order.
reference_lyrics = get_lyrics()
lyrics.align_lyrics(reference_lyrics, similarity_threshold=0.7)
lyrics.fix_alignment()