In [1]:
audio_path = "/Users/fwtrv/Downloads/Jump Up, Super Star! NDC Festival Edition.wav"
lyrics_path = "/Users/fwtrv/Desktop/ECE-GY 7123/Final/Jump Up, Super Star! NDC Festival Edition.txt"
save_track_vocal_path = "track_vocal.wav"
# lang="zh-CN"
lang = "en-US"


In [2]:
# Library
import torch
import demucs
import librosa
import soundfile
import transformers

In [3]:
# Select device
if torch.cuda.is_available():
    device = "cuda"     # CUDA
elif torch.backends.mps.is_available(): 
    device = torch.device('mps')    # Apple Sillicon
else:
    device = torch.device("cpu")    # CPU

print("Device Selected:", device)

Device Selected: mps


# Audio File Preprocessing - Extract Vocals from Audio File

In [4]:
# Extract Vocals from Audio File
from demucs.pretrained import get_model
from demucs.apply import apply_model
from demucs.separate import load_track
ORIGINAL_SR = 44100
TARGET_SR = 16000

# Choose Demucs Model for Vocals Extraction
demucs_model = get_model(name="htdemucs", repo=None)
demucs_model.to(device)
demucs_model.eval()
vocals_source_idx = demucs_model.sources.index("vocals")
sample_rate = demucs_model.samplerate

# Load Aduio Track
audio_track = load_track(audio_path, 2, sample_rate)

# Extract Vocal
ref = audio_track.mean(0)
audio_track_nor = (audio_track - ref.mean()) / ref.std() # Normalization
with torch.no_grad():
    sources = apply_model(demucs_model, audio_track_nor[None], device=device, shifts=1, split=True, overlap=0.25, progress=False)
track_vocal = sources[0][vocals_source_idx].cpu().numpy()[0, ...]

# Post-processing
track_vocal = librosa.resample(track_vocal, orig_sr=ORIGINAL_SR, target_sr=TARGET_SR)

# Write to Output
soundfile.write(save_track_vocal_path, track_vocal, TARGET_SR)

# Lyrics Preprocessing

In [5]:
# Get plain lyrics from file
with open(lyrics_path, 'r') as file:
    lyrics_plain = file.read()

lyrics_processed = ""
if lang == "en-US":
    lyrics_processed = lyrics_plain.upper()
    lyrics_processed = lyrics_processed.replace(' ', '|')
    lyrics_processed = lyrics_processed.replace('\n', '|')
    lyrics_processed = lyrics_processed.replace('_', '\'')
    lyrics_processed = lyrics_processed.replace('’', '\'')

print(lyrics_processed)

HERE|WE|GO,|OFF|THE|RAILS|DON'T|YOU|KNOW|IT'S|TIME|TO|RAISE|OUR|SAILS|IT'S|FREEDOM|LIKE|YOU|NEVER|KNEW||DON'T|NEED|BAGS,|OR|A|PASS|SAY|THE|WORD|I'LL|BE|THERE|IN|A|FLASH|YOU|COULD|SAY|MY|HAT|IS|OFF|TO|YOU||OH|WE|CAN|ZOOM|ALL|THE|WAY|TO|THE|MOON|FROM|THIS|GREAT|WIDE|WACKY|WORLD|JUMP|WITH|ME,|GRAB|COINS|WITH|ME|OH|YEAH!||IT'S|TIME|TO|JUMP|UP|IN|THE|AIR|(JUMP|UP|IN|THE|AIR)|JUMP|UP,|DON'T|BE|SCARED|(JUMP|UP,|DON'T|BE|SCARED)|JUMP|UP|AND|YOUR|CARES|WILL|SOAR|AWAY|AND|IF|THE|DARK|CLOUDS|START|TO|SWIRL|(DARK|CLOUDS|START|TO|SWIRL)|DON'T|FEAR,|DON'T|SHED|A|TEAR,|'CAUSE|I'LL|BE|YOUR|1UP|GIRL||SO|LET'S|ALL|JUMP|UP|SUPER|HIGH|(JUMP|UP|SUPER|HIGH)|HIGH|UP|IN|THE|SKY|(HIGH|UP|IN|THE|SKY)|THERE'S|NO|POWER-UP|LIKE|DANCING|YOU|KNOW|THAT|YOU'RE|MY|SUPERSTAR|(YOU'RE|MY|SUPERSTAR)|NO|ONE|ELSE|CAN|TAKE|ME|THIS|FAR|I'M|FLIPPING|THE|SWITCH|GET|READY|FOR|THIS|OH,|LET'S|DO|THE|ODYSSEY||ODYSSEY,|YA|SEE!|ODYSSEY,|YA|SEE!|ODYSSEY,|YA|SEE!|ODYSSEY,|YA|SEE!|ODYSSEY,|YA|SEE!|ODYSSEY,|YA|SEE!|ODYSSEY,|YA|SEE!|ODYSSE

# Load Model

In [6]:
# Load Model
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, logging, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor

# Select Model
model_id = 'facebook/wav2vec2-large-960h-lv60-self'
if lang == "zh-CN":
    model_id = 'jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn'

# Load Model
Wav2Vec2_model = Wav2Vec2ForCTC.from_pretrained(model_id)
Wav2Vec2_processor = Wav2Vec2Processor.from_pretrained(model_id)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h-lv60-self and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Helper Function

In [7]:
def phoneme_recognizer(vocals, processor, model):
    vals = processor(vocals, return_tensors="pt", padding="longest", sampling_rate=16000).input_values
    duration_sec = vals.shape[1] / TARGET_SR
    with torch.no_grad():
        logits = model(vals).logits
    return logits.cpu().detach(), duration_sec

# Get Prediction

In [8]:
duration_sec = 0

# Segmentation
# Each column track_segments[:, i] contains a contiguous slice of the input track_vocal[i * hop_length : i * hop_length + frame_length]
track_segments = librosa.util.frame(track_vocal, frame_length=int(TARGET_SR * 15), hop_length=int(TARGET_SR * 15), axis=0)
print("Segementation from " + str(track_vocal.shape) + " to " + str(track_segments.shape))

# Get Raw Prediction
# logits: the vector of raw (non-normalized) predictions
duration_sec = 0
logits, duration_temp = phoneme_recognizer(track_segments[0], Wav2Vec2_processor, Wav2Vec2_model)
duration_sec += duration_temp
for seg in track_segments[1:]:
    logits_seg, duration_temp = phoneme_recognizer(seg, Wav2Vec2_processor, Wav2Vec2_model)
    duration_sec += duration_temp
    # Concatenates the logits
    logits = torch.cat((logits, logits_seg), dim=1)

# Normalize Results
emission = torch.log_softmax(logits, dim=-1)[0].cpu().detach()

# Get Prediction Results
pred = torch.argmax(logits, dim=-1)
transcription = Wav2Vec2_processor.batch_decode(pred)
print(transcription)


Segementation from (4152960,) to (17, 240000)
["BARAPARATA PAPARAT HERE WE GO OF THE RAILS YOU KNOW IT'S TIME TO RAISE OURSELVS IT'S FREEDOM LIKE YOU NEVER KNEW E AD O THE PAST SAY THE WORD I'LL BE ANA FLASH YOU COULD SAY MY HATES OYOU O WE CONSUME ON THE WAY TO THE MOONTHIS GREAT WIDE WAY WORLD JUMP WITH ME GROW HORNS WITH ME  IT'S TIME TO JUMP UP IN THE A JUMPUP DON'T BE STRAT TUMP UP AND YOUR CHAIRS WILL SOR AWAY AND IN THE DAROFTE O O IT I'LL BE O ONE AGIR SO LET OM JUMP UP SUPERA HIGH UP IN THE SKY THE SNOW OEOP MY ANY YOU KNOW THAT YOU'RE MY SUPERSTI NO WHENIN POPING THE PIGE GET READY FOR THELES O TE ODES I MUST SAY YE S I SAY HAS I SAY EA SA I SAY HEA SA I SAY HESAY  SAY HS SSS S SPIN THE WHEEL TAKE A CHANCE EVERY TURN HE STARTS A NEW ROMANCE A NEW WORLD'S CALLING OUT TO YOU TAKE A TURN OF THE PATH FIND A NEW ADDITION TO THE CAST YOU KNOW THAT ANY CAPTAIN NEEDS A CREW TAKEIN STAMO DIFFERENT POINTS OF VIEW JUMP WITH ME GRAND COINS WITH ME A COME ON JUMP U IN THE AREP WITHOUT A C

# Alignment

In [9]:
from dataclasses import dataclass
from typing import List

SEPARATOR = '|'


@dataclass
class Segment:
    label: str
    start: int
    end: int

@dataclass
class Point:
    token_index: int
    time_index: int
    prob: float

def seconds_to_lrc(seconds, is_word = True):
    minutes = int(seconds // 60)
    seconds = seconds % 60
    hundredths = int((seconds % 1) * 100)
    seconds = int(seconds)
    formated = f"{minutes:02d}:{seconds:02d}.{hundredths:02d}"
    return f"<{formated}>" if is_word else f"[{formated}]"

@dataclass
class Word:
    label: str
    start: float
    end: float

    def __repr__(self):
        return f"{seconds_to_lrc(self.start)} {self.label}"

class LrcFormatter():
    @staticmethod
    def words2lrc(words: List[Word], original_lyrics: str, lang="en-US"):
        lrc = ""
        counter = 0
        word_end = None
        for line in original_lyrics.splitlines():
            if line == '': continue
            if word_end:
                lrc += f"\n{seconds_to_lrc(word_end, False)}"
            else:
                lrc += "[00:00.00]"
            if lang == "en-US":
                splitted_words = line.split(' ')
            elif lang == "zh-CN":
                splitted_words = line

            for original_word in splitted_words:
                if original_word == '': continue
                word = words[counter]
                word.label = original_word
                lrc += f" {word}"
                word_end = word.end
                counter+=1
        return lrc

class Aligner():
    @staticmethod
    def align(emission, tokens, blank_id=0):
        trellis = Aligner.get_trellis(emission, tokens, blank_id=blank_id)
        path = Aligner.backtrack(trellis, emission, tokens, blank_id=blank_id)
        return path

    @staticmethod
    def get_trellis(emission, tokens, blank_id=0):
        num_frame = emission.size(0)
        num_tokens = len(tokens)
        trellis = torch.empty((num_frame + 1, num_tokens + 1))
        trellis[0, 0] = 0
        trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
        trellis[0, -num_tokens:] = -float("inf")
        trellis[-num_tokens:, 0] = float("inf")

        for t in range(num_frame):
            trellis[t + 1, 1:] = torch.maximum(
                # Score for staying at the same token
                trellis[t, 1:] + emission[t, blank_id],
                # Score for changing to the next token
                trellis[t, :-1] + emission[t, tokens],
            )
        return trellis

    @staticmethod
    def backtrack(trellis, emission, tokens, blank_id=0):
        j = trellis.size(1) - 1
        t_start = torch.argmax(trellis[:, j]).item()

        path = []
        for t in range(t_start, 0, -1):
            stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
            changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
            prob = emission[t - 1, tokens[j - 1]
                            if changed > stayed else 0].exp().item()
            path.append(Point(j - 1, t - 1, prob))
            if changed > stayed:
                j -= 1
                if j == 0:
                    break
        else:
            raise Exception("Failed")
        return path[::-1]

    def get_words_from_path(text, path, frame_duration):
        # Skip repeating char
        i1, i2 = 0, 0
        segments = []
        while i1 < len(path):
            while i2 < len(path) and path[i1].token_index == path[i2].token_index:
                i2 += 1
            segments.append(
                Segment(
                    text[path[i1].token_index],
                    path[i1].time_index,
                    path[i2 - 1].time_index + 1
                )
            )
            i1 = i2
        if lang == "en-US":
            return Aligner.merge_en(segments, frame_duration)
        elif lang == "zh-CN":
            return [Word(s.label, s.start * frame_duration, s.end * frame_duration) for s in segments]
        

    def merge_en(segments, frame_duration, separator=SEPARATOR):
        # Merge chars to word
        words = []
        i1, i2 = 0, 0
        while i1 < len(segments):
            if i2 >= len(segments) or segments[i2].label == separator:
                if i1 != i2:
                    segs = segments[i1:i2]
                    word = "".join([seg.label for seg in segs])
                    words.append(Word(
                        word, segments[i1].start * frame_duration, segments[i2 - 1].end * frame_duration))
                i1 = i2 + 1
                i2 = i1
            else:
                i2 += 1
        return words

In [10]:
# Prepares text labels for the CTC
lyrics_tokens = Wav2Vec2_processor.tokenizer(lyrics_processed).input_ids

path = Aligner.align(emission, lyrics_tokens)

words = Aligner.get_words_from_path(
            text=lyrics_processed, path=path, frame_duration=Wav2Vec2_model.config.inputs_to_logits_ratio / TARGET_SR)

# lrc
lrc = LrcFormatter.words2lrc(words, lyrics_plain)
# with open(f'results.lrc', 'w+') as fp:
#     fp.write(lrc)

In [11]:
words[0].label

'Here'