In [2]:
# imports
import os
import subprocess
from enum import Enum
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor

import ffmpeg
import numpy as np
import pandas as pd
from tqdm import tqdm

In [11]:
DATASET_ORIGIN      = Path("./datasets/CBU0521DD_stories")
DATASET_NORMALIZED  = Path("./datasets/normalized")
DATASET_DENOISED    = Path("./datasets/denoised")
DATASET_NORM_DENOISED = Path("./datasets/denoised_norm")
DATASET_SRT         = Path("./datasets/transcription")

DATASET_VOCAL_EMBED = Path("./datasets/vocal_embedded")
DATASET_TEXT_EMBED  = Path("./datasets/text_embedded")
DATASET_ATTR_CSV    = Path("./datasets/dataset_attr.csv")

dataset_attr_df = pd.read_csv(Path('CBU0521DD_stories_attributes.csv'))
dataset_attr_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 100 entries, 0 to 99
Data columns (total 3 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   filename    100 non-null    object
 1   Language    100 non-null    object
 2   Story_type  100 non-null    object
dtypes: object(3)
memory usage: 2.5+ KB


In [7]:
dataset_attr_df.head()

Unnamed: 0,filename,Language,Story_type
0,00001.wav,Chinese,True Story
1,00002.wav,Chinese,True Story
2,00003.wav,Chinese,True Story
3,00004.wav,Chinese,True Story
4,00005.wav,Chinese,True Story


In [12]:
for i in range(len(dataset_attr_df)):
    dataset_attr_df.loc[i, "label"] = dataset_attr_df.loc[i, "Story_type"] == "True Story"
dataset_attr_df.to_csv(DATASET_ATTR_CSV, index=False)

In [3]:
def loudness_norm(audios: list, output: Path, out_fmt: str = "wav"):
    output.mkdir(parents=True, exist_ok=True)
    pbar = tqdm(total=len(audios))
    def norm(audio: Path):
        out_path = output.joinpath(audio.name)
        try:
            (
                ffmpeg
                    .input(str(audio), f=audio.suffix.lstrip('.'))
                    .filter("loudnorm")
                    .output(str(out_path), **{'ac': 1, 'ar': 48000, 'af': 'silence_threshold=-50dB:cutoff=0.01:lowpass=3000:highpass=200'}, f=out_fmt)
                    .overwrite_output()
                    .run()
            )
        except Exception as e:
            print(f"Error: {e} on audio {str(audio)}")
        pbar.update()

    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
        results = executor.map(norm, audios)
        results = [result for result in results if result is not None]
        return results


In [4]:
############################################################
#            ____  ___    _   __________________           #
#           / __ \/   |  / | / / ____/ ____/ __ \          #
#          / / / / /| | /  |/ / / __/ __/ / /_/ /          #
#         / /_/ / ___ |/ /|  / /_/ / /___/ _, _/           #
#        /_____/_/  |_/_/ |_/\____/_____/_/ |_|            #
#               Time Consuming block(~ 2m)                 #
############################################################
original_audios = list(dataset_attr_df["filename"])
original_audios = [DATASET_ORIGIN.joinpath(audio) for audio in original_audios]
loudness_norm(original_audios, DATASET_NORMALIZED)

100%|██████████| 100/100 [01:43<00:00,  1.04s/it]


[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None]

In [6]:
############################################################
#                       Vocal Denoise                      #
#            ____  ___    _   __________________           #
#           / __ \/   |  / | / / ____/ ____/ __ \          #
#          / / / / /| | /  |/ / / __/ __/ / /_/ /          #
#         / /_/ / ___ |/ /|  / /_/ / /___/ _, _/           #
#        /_____/_/  |_/_/ |_/\____/_____/_/ |_|            #
#               Time Consuming block(~ 30m)                #
############################################################
MSST_ROOT_PATH = Path("./tools/Music-Source-Separation-Training")
msst_infer_paths = {
    "python":       MSST_ROOT_PATH.joinpath("workenv/python.exe"),
    ".py":          MSST_ROOT_PATH.joinpath("inference.py"),
    "config":       MSST_ROOT_PATH.joinpath("configs/model_bs_roformer_ep_317_sdr_12.9755.yaml"),
    "input_dir":    DATASET_NORMALIZED,
    "out_dir":      DATASET_DENOISED,
    "model_ckpt":   MSST_ROOT_PATH.joinpath("pretrain/model_bs_roformer_ep_317_sdr_12.9755.ckpt"),
}

msst_infer_paths["out_dir"].mkdir(parents=True, exist_ok=True)

subprocess.run([
    msst_infer_paths["python"], "-u", msst_infer_paths[".py"],
    "--config_path",        str(msst_infer_paths["config"]),
    "--store_dir",          str(msst_infer_paths["out_dir"]),
    "--input_folder",       str(msst_infer_paths["input_dir"]),
    "--start_check_point",  str(msst_infer_paths["model_ckpt"]),
    "--model_type",         "bs_roformer",
    "--device_ids",         "0",
])

for audio in msst_infer_paths["out_dir"].iterdir():
    new_name = audio.name.rstrip("_vocals.wav") + ".wav"
    audio.rename(msst_infer_paths["out_dir"].joinpath(new_name))

"ok"

CompletedProcess(args=[WindowsPath('tools/Music-Source-Separation-Training/workenv/python.exe'), '-u', WindowsPath('tools/Music-Source-Separation-Training/inference.py'), '--config_path', 'tools\\Music-Source-Separation-Training\\configs\\model_bs_roformer_ep_317_sdr_12.9755.yaml', '--store_dir', 'datasets\\tmp', '--input_folder', 'datasets\\test', '--start_check_point', 'tools\\Music-Source-Separation-Training\\pretrain\\model_bs_roformer_ep_317_sdr_12.9755.ckpt', '--model_type', 'bs_roformer', '--device_ids', '0'], returncode=0)

In [7]:
############################################################
#            ____  ___    _   __________________           #
#           / __ \/   |  / | / / ____/ ____/ __ \          #
#          / / / / /| | /  |/ / / __/ __/ / /_/ /          #
#         / /_/ / ___ |/ /|  / /_/ / /___/ _, _/           #
#        /_____/_/  |_/_/ |_/\____/_____/_/ |_|            #
#               Time Consuming block(~ 2m)                 #
############################################################
denoised_audios = list(msst_infer_paths["out_dir"].iterdir())
loudness_norm(denoised_audios, DATASET_NORM_DENOISED)

100%|██████████| 100/100 [01:57<00:00,  1.18s/it]


[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None]

In [4]:
# wrapping to https://github.com/openai/whisper
from datetime import timedelta
class Timing:
    _hours      :int = None
    _minutes    :int = None
    _seconds    :int = None
    _mseconds   :int = None
    _origin     :float = None

    def __init__(self, seconds: float):
        self._origin = seconds
        delta = timedelta(seconds=seconds)
        total_seconds = int(delta.total_seconds())
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds_int = divmod(remainder, 60)

        self._hours = hours
        self._minutes = minutes
        self._seconds = seconds_int
        self._mseconds = int((seconds - total_seconds) * 1000)


    def to_string(self) -> str:
        return f"{self._hours:02}:{self._minutes:02}:{self._seconds:02},{self._mseconds:03}"

    def __str__(self):
        return self.to_string()

    def __eq__(self, time):
        if not isinstance(time, Timing): return False
        return (self._origin - time._origin) < 0.0001

    def __lt__(self, time):
        if not isinstance(time, Timing): return None
        return (self._origin - time._origin) < 0

    def __gt__(self, time):
        if not isinstance(time, Timing): return None
        return (self._origin - time._origin) > 0

    def hours(self)-> int:
        return self._hours

    def minutes(self) -> int:
        return self._minutes

    def seconds(self) -> int:
        return self._seconds

    def milliseconds(self) -> int:
        return self._mseconds

import whisper
from pydub import AudioSegment

class VttModelType(Enum):
    TINY = 0
    BASE = 1
    SMALL = 2
    MEDIUM = 3
    LARGE = 4
    TURBO = 5

    @staticmethod
    def names():
        return "tiny", "base", "small", "medium", "large", "turbo"

    def __init__(self, value):
        if value >= len(self.names()) or self.value < 0:
            raise ValueError("Invalid model size")
        self.val = value

    @staticmethod
    def from_string(s: str):
        try:
            idx = VttModelType.names().index(s)
        except ValueError:
            raise ValueError("Invalid model size")

        return VttModelType(idx)

    def to_string(self):
        return self.names()[self.val]

    def __str__(self):
        return self.to_string()

class VTT:
    def __init__(self, model_size: VttModelType, init_model: bool = True):
        self.model_size = model_size
        if init_model: self.model_init()

    def model_init(self, in_memory: bool = True):
        self.model = whisper.load_model(self.model_size.to_string(), in_memory=in_memory)

    def transcribe(self, audio_path: Path, lang: str = "zh", verbose=True) -> map:
        verbose = False if verbose is True else None
        result = self.model.transcribe(str(audio_path), verbose=verbose, language=lang)
        return map(lambda r: {"start": r["start"], "end": r["end"], "text": r["text"]}, result["segments"])

    @staticmethod
    def get_audio_time(audio_path: Path) -> Timing:
        audio = AudioSegment.from_file(audio_path, format=audio_path.suffix[1:])
        return Timing(float(len(audio)) / 1000)
    @staticmethod
    def write_srt(segments: map, audio_len: Timing, output_path: Path):
        with output_path.open("w", encoding="utf-8") as f:
            for i, segment in enumerate(segments):
                start = Timing(segment["start"])
                end = Timing(segment["end"])
                if end > audio_len: break
                f.write(f"{i+1}\n{start.to_string()} --> {end.to_string()}\n{segment['text']}\n\n")

    def transcribe_to_srt(self, audio: Path, srt: Path, lang: str = "Mandarin", verbose=True):
        if not audio.exists(): raise ValueError("Audio file does not exist")
        if not srt.exists(): srt.touch()
        # print("[INFO] Parsing audio...")
        audio_len = VTT.get_audio_time(audio)
        # print("[INFO] Start transcribing...")
        segments = self.transcribe(audio, lang, verbose)
        # print("[INFO] Transcribe done, Writing SRT file...")
        VTT.write_srt(segments, audio_len, srt)
        # print("[INFO] SRT file written!")

def srt_read_to_str_list(srt: Path) -> list[str]:
    with open(srt, "r", encoding="utf-8") as f:
        res = []
        line_code = 1
        state = 0       # 0: line_code, 1: time_duration, 2: text
        while True:
            line = f.readline()
            if not line: break
            line = line.rstrip()
            match state:
                case 0:
                    if line.isdigit():
                        if line_code != int(line): continue
                        line_code += 1
                        state = 1
                case 1:
                    if state == 1:
                        state = 2
                case 2:
                    if state == 2:
                        res.append(line)
                        state = 0
                case x:
                    raise ValueError(f"Invalid state {x}")
        return res


In [5]:
dataset_attr_df

Unnamed: 0,filename,Language,Story_type
0,00001.wav,Chinese,True Story
1,00002.wav,Chinese,True Story
2,00003.wav,Chinese,True Story
3,00004.wav,Chinese,True Story
4,00005.wav,Chinese,True Story
...,...,...,...
95,00096.wav,English,True Story
96,00097.wav,Chinese,Deceptive Story
97,00098.wav,Chinese,True Story
98,00099.wav,Chinese,Deceptive Story


In [14]:
############################################################
#            ____  ___    _   __________________           #
#           / __ \/   |  / | / / ____/ ____/ __ \          #
#          / / / / /| | /  |/ / / __/ __/ / /_/ /          #
#         / /_/ / ___ |/ /|  / /_/ / /___/ _, _/           #
#        /_____/_/  |_/_/ |_/\____/_____/_/ |_|            #
#               Time Consuming block(~ 4h)                 #
############################################################
# Voice Transcription
vtt = VTT(VttModelType.LARGE)
DATASET_SRT.mkdir(parents=True, exist_ok=True)
for i in tqdm(range(len(dataset_attr_df))):
    filename = dataset_attr_df.iloc[i]["filename"]
    audio_path = DATASET_NORM_DENOISED.joinpath(filename)
    srt_path = DATASET_SRT.joinpath(".".join(filename.split(".")[:-1]) + ".srt")
    audio_lang = "zh" if dataset_attr_df.iloc[i]["Language"] == "Chinese" else "en"
    vtt.transcribe_to_srt(audio_path, srt_path, audio_lang)

del vtt

  checkpoint = torch.load(fp, map_location=device)
  0%|          | 0/100 [00:00<?, ?it/s]
  0%|          | 0/20505 [00:00<?, ?frames/s][A
 14%|█▍        | 2826/20505 [00:25<02:38, 111.37frames/s][A
 14%|█▍        | 2826/20505 [00:39<02:38, 111.37frames/s][A
 28%|██▊       | 5642/20505 [00:50<02:13, 111.47frames/s][A
 28%|██▊       | 5642/20505 [01:09<02:13, 111.47frames/s][A
 42%|████▏     | 8642/20505 [01:14<01:41, 116.95frames/s][A
 42%|████▏     | 8642/20505 [01:29<01:41, 116.95frames/s][A
 57%|█████▋    | 11642/20505 [01:37<01:11, 123.48frames/s][A
 57%|█████▋    | 11642/20505 [01:49<01:11, 123.48frames/s][A
 71%|███████▏  | 14612/20505 [02:01<00:47, 123.05frames/s][A
 71%|███████▏  | 14612/20505 [02:19<00:47, 123.05frames/s][A
 86%|████████▌ | 17612/20505 [02:25<00:23, 123.55frames/s][A
 86%|████████▌ | 17612/20505 [02:39<00:23, 123.55frames/s][A
100%|██████████| 20505/20505 [02:44<00:00, 124.41frames/s][A
 51%|█████     | 51/100 [02:45<02:38,  3.24s/it]
  0%|      

In [6]:
############################################################
#            ____  ___    _   __________________           #
#           / __ \/   |  / | / / ____/ ____/ __ \          #
#          / / / / /| | /  |/ / / __/ __/ / /_/ /          #
#         / /_/ / ___ |/ /|  / /_/ / /___/ _, _/           #
#        /_____/_/  |_/_/ |_/\____/_____/_/ |_|            #
#               Time Consuming block(~ 20m)                #
############################################################
# Vocal Tokenize & Embedding
WAV_TKNZR_ROOT_PATH = Path("./tools/WavTokenizer")
# Original WavTokenizer does not have a satisfying CLI inference interface
# Personal realization as below

wav_tknzr_infer_paths = {
    "python":       WAV_TKNZR_ROOT_PATH.joinpath("workenv/python.exe"),
    ".py":          WAV_TKNZR_ROOT_PATH.joinpath("inference.py"),
    "config":       WAV_TKNZR_ROOT_PATH.joinpath("configs/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"),
    "input_dir":    DATASET_NORM_DENOISED,
    "out_dir":      DATASET_VOCAL_EMBED,
    "model_ckpt":   WAV_TKNZR_ROOT_PATH.joinpath("pretrain/wavtokenizer_large_speech_320_24k.ckpt"),
}

wav_tknzr_infer_paths["out_dir"].mkdir(parents=True, exist_ok=True)

subprocess.run([
    wav_tknzr_infer_paths["python"], wav_tknzr_infer_paths[".py"],
    "--config_path",        str(wav_tknzr_infer_paths["config"]),
    "--store_dir",          str(wav_tknzr_infer_paths["out_dir"]),
    "--input_folder",       str(wav_tknzr_infer_paths["input_dir"]),
    "--start_check_point",  str(wav_tknzr_infer_paths["model_ckpt"]),
])


CompletedProcess(args=[WindowsPath('tools/WavTokenizer/workenv/python.exe'), WindowsPath('tools/WavTokenizer/inference.py'), '--config_path', 'tools\\WavTokenizer\\configs\\wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml', '--store_dir', 'datasets\\vocal_embedded', '--input_folder', 'datasets\\denoised_norm', '--start_check_point', 'tools\\WavTokenizer\\pretrain\\wavtokenizer_large_speech_320_24k.ckpt'], returncode=0)

In [24]:
all_vocal_tokens = np.asarray([])
for token in tqdm(DATASET_VOCAL_EMBED.iterdir()):
    all_vocal_tokens = np.concatenate((all_vocal_tokens, np.load(token)[0][0]), axis=0)
print(all_vocal_tokens.shape)

token_dict = {}
for token in all_vocal_tokens:
    if token not in token_dict:
        token_dict[token] = 1
        continue
    token_dict[token] += 1

token_dict = sorted(token_dict.items(), key=lambda x: x[1], reverse=True)
print(len(token_dict))

100it [00:00, 855.88it/s]


(1046664,)
1824


In [18]:
import re
import torch
import torch.utils.data as Data
from transformers.modeling_utils import SpecificPreTrainedModelType
from transformers import BertTokenizer, BertModel, PreTrainedTokenizerBase

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [40]:
from torch import Tensor
import torch.nn.functional as F

def is_punctuation(s: str):
    pattern = r'[!"#$%&\'()*+,-./:;<=>?@[\\\]^_`{|}~！“”‘’（）【】《》〈〉；：，。？、]$'
    return bool(re.search(pattern, s))

# def get_text_embedding_bert(tknzr: PreTrainedTokenizerBase, model: SpecificPreTrainedModelType, sentences: list[str], lang: str = "zh"):
#     text = "[CLS] "
#     punctuation_dict = { "zh": "，", "en": "," }
#     for sentence in sentences:
#         sentence = sentence.strip()
#         sentence = sentence.replace("[CLS]", "")
#         sentence = sentence.replace("[SEP]", "")
#         text += sentence
#         if not is_punctuation(sentence):
#             text += punctuation_dict[lang]
#         text += "[SEP]"
#
#     inputs = tknzr(text, return_tensors="pt").to(DEVICE)
#     with torch.no_grad():
#         outputs = model(**inputs)
#     return outputs

def get_text_embedding_qwen(tknzr: PreTrainedTokenizerBase, model: SpecificPreTrainedModelType, sentences: list[str], lang: str = "zh"):
    def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]


    text = ""
    punctuation_dict = { "zh": "，", "en": "," }
    for sentence in sentences:
        sentence = sentence.strip()
        text += sentence
        if not is_punctuation(sentence):
            text += punctuation_dict[lang]
    inputs = tknzr(text, max_length=2048, padding=True, truncation=True, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
        _pool = last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
        _pool = F.normalize(_pool, p=2, dim=1)
        _embedding = F.normalize(outputs.last_hidden_state, p=2, dim=1)
        # scores = (embeddings[:2] @ embeddings[2:].T) * 100
    return _embedding, _pool, inputs["attention_mask"]

# loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)
# class TextDataset(Data.dataset):
#     def __init__(self, srt_path: Path, transform=None):
#         super(TextDataset, self).__init__()
#         self.srt_paths = [p for p in srt_path.iterdir() if p.suffix == ".srt"]
#         self.transform = transform
#
#     def __len__(self):
#         return len(self.srt_paths)
#
#     def __getitem__(self, idx):
#         srt_path = self.srt_paths[idx]

In [44]:
############################################################
#            ____  ___    _   __________________           #
#           / __ \/   |  / | / / ____/ ____/ __ \          #
#          / / / / /| | /  |/ / / __/ __/ / /_/ /          #
#         / /_/ / ___ |/ /|  / /_/ / /___/ _, _/           #
#        /_____/_/  |_/_/ |_/\____/_____/_/ |_|            #
#               Time Consuming block(~ 4m)                 #
############################################################
# tokenizer_zh    = BertTokenizer.from_pretrained('google-bert/bert-base-chinese')
# tokenizer_en    = BertTokenizer.from_pretrained('google-bert/bert-base-uncased')
# bert_model_zh   = BertModel.from_pretrained('google-bert/bert-base-chinese').to(DEVICE)
# bert_model_en   = BertModel.from_pretrained('google-bert/bert-base-uncased').to(DEVICE)

# Load model directly
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True)
model = AutoModel.from_pretrained("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True).to(DEVICE)

print("Qwen Model Loaded")
DATASET_TEXT_EMBED.mkdir(parents=True, exist_ok=True)

for srt in tqdm(list(DATASET_SRT.iterdir())):
    lang = "zh" if dataset_attr_df.loc[dataset_attr_df["filename"] == srt.stem + ".wav", "Language"].item() == "Chinese" else "en"
    sentences = srt_read_to_str_list(srt)
    embedding, pool, _ = get_text_embedding_qwen(tokenizer, model, sentences, "zh")
    save_path = DATASET_TEXT_EMBED.joinpath(srt.stem + ".npz")
    np.savez(save_path, embedding=embedding.cpu().numpy(), pool=pool.cpu().numpy())


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  5.09it/s]


Qwen Model Loaded


100%|██████████| 100/100 [03:00<00:00,  1.81s/it]


In [13]:
DATASET_MIXED_EMBED = Path("./datasets/embedded")
DATASET_MIXED_EMBED.mkdir(parents=True, exist_ok=True)

max_len = 0
for filename in tqdm(dataset_attr_df["filename"]):
    filename_stem = ".".join(filename.split(".")[:-1])
    vocal_embedded_path = DATASET_VOCAL_EMBED.joinpath(filename_stem + ".npy")
    text_embedded_path = DATASET_TEXT_EMBED.joinpath(filename_stem + ".npz")
    vocal_embed = np.load(vocal_embedded_path)
    text_embed = np.load(text_embedded_path)["embedding"]
    text_embed = np.transpose(text_embed, (0, 2, 1))

    dim_vocal   = vocal_embed.shape[1]
    dim_text    =  text_embed.shape[1]
    seq_len_vocal   = vocal_embed.shape[2]
    seq_len_text    =  text_embed.shape[2]

    total_dim = dim_vocal + dim_text
    seq_len = seq_len_vocal + seq_len_text

    mixed_embed = np.zeros((vocal_embed.shape[0], total_dim, seq_len))
    mixed_embed[:, :dim_vocal, :seq_len_vocal] = vocal_embed
    mixed_embed[:, dim_vocal:, seq_len_vocal:] = text_embed

    # vocal_embed.shape (bs, dim, sequence_len)
    # text_embed.shape  (bs, sequence_len, dim)

    np.save(DATASET_MIXED_EMBED.joinpath(filename_stem + ".npy"), mixed_embed)
    max_len = max(max_len, mixed_embed.shape[2])

100%|██████████| 100/100 [00:20<00:00,  4.94it/s]


In [14]:
max_len

16300

(9157, 1, 2048)




KeyboardInterrupt: 