<a href="https://colab.research.google.com/github/Huang-Yongzhi/musiclm-pytorch/blob/main/musiclm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

如果重新运行，需要重启，不然trainer会报错，认为有多个实例

# 一、加载数据
## 1.安装必要的包

In [1]:
!pip install you-get
!pip install yt-dlp
!sudo apt-get install ffmpeg

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).
0 upgraded, 0 newly installed, 0 to remove and 19 not upgraded.


## 2.加载数据集

**数据集内容**

调用的.csv文件内容如下
```
ytid,start_s,end_s,audioset_positive_labels,aspect_list,caption,author_id,is_balanced_subset,is_audioset_eval
-0Gj8-vB1q4,30,40,"/m/0140xf,/m/02cjck,/m/04rlf","['low quality', 'sustained strings melody', 'soft female vocal', 'mellow piano melody', 'sad', 'soulful', 'ballad']","The low quality recording features a ballad song that contains sustained strings, mellow piano melody and soft female vocal singing over it. It sounds sad and soulful, like something you would hear at Sunday services.",4,False,True
...
```

In [2]:
import pandas as pd

#  GitHub 上的 .csv 文件的 Raw 链接
url = "https://raw.githubusercontent.com/Huang-Yongzhi/musiclm-pytorch/c91b9f96775751aefa6f507fb304e7fd12182bf8/Data/musiccaps-public.csv"

df = pd.read_csv(url)

# 显示数据以验证加载正确
print(df.head())

# 指定文件保存路径
file_path = '/content/musiccaps-public.csv'

# 将 DataFrame 保存为 CSV 文件
df.to_csv(file_path, index=False)  # 设置index=False以防止添加额外的索引列


# 检查文件是否已保存
!ls /content


          ytid  start_s  end_s  \
0  -0Gj8-vB1q4       30     40   
1  -0SdAVK79lg       30     40   
2  -0vPFx-wRRI       30     40   
3  -0xzrMun0Rs       30     40   
4  -1LrH01Ei1w       30     40   

                            audioset_positive_labels  \
0                       /m/0140xf,/m/02cjck,/m/04rlf   
1  /m/0155w,/m/01lyv,/m/0342h,/m/042v_gx,/m/04rlf...   
2                                /m/025_jnm,/m/04rlf   
3                                 /m/01g90h,/m/04rlf   
4                                /m/02p0sh1,/m/04rlf   

                                         aspect_list  \
0  ['low quality', 'sustained strings melody', 's...   
1  ['guitar song', 'piano backing', 'simple percu...   
2  ['amateur recording', 'finger snipping', 'male...   
3  ['backing track', 'jazzy', 'digital drums', 'p...   
4  ['rubab instrument', 'repetitive melody on dif...   

                                             caption  author_id  \
0  The low quality recording features a ballad so...  

**解释**：
数据集是一个包含音频信息和描述的元数据文件，格式类似于 CSV。每行包含一个 YouTube 音频的标识符（ytid），音频的开始和结束时间（start_s 和 end_s），音频标签（audioset_positive_labels）和其他相关信息。

使用如 youtube-dl 这类工具来下载视频，然后使用音频处理库（例如 librosa 或 pydub）来裁剪音频。以下是一个大致的步骤指南：

## 使用 yt-dlp
下载 YouTube 音频
首先，您需要安装 yt-dlp。在 Colab 中，可以使用以下命令安装：

In [3]:
# !pip install yt-dlp

测试yt-dlp

In [4]:
# !yt-dlp -f140 -x --audio-format mp3 https://www.youtube.com/watch?v=-0vPFx-wRRI

In [5]:
import subprocess
import os
import glob

ytid = "-0xzrMun0Rs"  # 示例 YouTube 视频ID
start_s = 10  # 裁剪开始时间（秒）
end_s = 20    # 裁剪结束时间（秒）
audio_output_dir = '/content/downloaded_audios'  # 音频输出目录

# 确保输出目录存在
os.makedirs(audio_output_dir, exist_ok=True)

video_url = f'https://www.youtube.com/watch?v={ytid}'
temp_audio_path_pattern  = os.path.join(audio_output_dir, f'{ytid}_temp.*')
output_audio_path = os.path.join(audio_output_dir, f'{ytid}.wav')


try:
    # 下载音频
    subprocess.run(['yt-dlp', '-x', '--audio-format', 'wav', '-o', temp_audio_path_pattern, video_url], check=True)

    # 查找下载的音频文件
    downloaded_files = glob.glob(temp_audio_path_pattern)
    if not downloaded_files:
        raise Exception("Downloaded audio file not found.")
    downloaded_audio_path = downloaded_files[0]  # 取得实际下载的文件路径

    # 使用 ffmpeg 裁剪音频
    subprocess.run(['ffmpeg', '-i', downloaded_audio_path, '-ss', str(start_s), '-to', str(end_s), '-c', 'copy', output_audio_path], check=True)
    os.remove(downloaded_audio_path)  # 删除临时文件
    print(f"Audio downloaded and trimmed: {output_audio_path}")
except subprocess.CalledProcessError as e:
    print(f"Error: {e}")


Error: Command '['ffmpeg', '-i', '/content/downloaded_audios/-0xzrMun0Rs_temp.*.wav', '-ss', '10', '-to', '20', '-c', 'copy', '/content/downloaded_audios/-0xzrMun0Rs.wav']' returned non-zero exit status 1.


在 yt-dlp 命令中使用了 --postprocessor-args 选项，但没有指定具体哪个后处理器（post-processor）应该使用这些参数。
我们需要参数应用于 ffmpeg 音频裁剪处理器。

In [6]:
import subprocess
import os
import pandas as pd
from concurrent.futures import ThreadPoolExecutor

def download_audio(ytid, audio_output_dir):
    video_url = f'https://www.youtube.com/watch?v={ytid}'
    temp_audio_path_pattern  = os.path.join(audio_output_dir, f'{ytid}_temp.*')
    output_audio_path = os.path.join(audio_output_dir, f'{ytid}.wav')

    try:
        # 下载音频
        subprocess.run(['yt-dlp', '-x', '--audio-format', 'wav', '-o', temp_audio_path_pattern, video_url], check=True)
        # 查找下载的音频文件
        downloaded_files = glob.glob(temp_audio_path_pattern)
        if not downloaded_files:
            raise Exception("Downloaded audio file not found.")
        return downloaded_files[0], output_audio_path  # 返回实际下载的文件路径
    except subprocess.CalledProcessError as e:
        print(f"Error downloading audio for {ytid}: {e}")
        return None, None

def trim_audio(temp_audio_path, output_audio_path, start_s, end_s):
    if not os.path.exists(temp_audio_path):
        raise FileNotFoundError(f"File not found: {temp_audio_path}")

    try:
        # 使用 ffmpeg 裁剪音频
        subprocess.run(['ffmpeg', '-i', temp_audio_path, '-ss', str(start_s), '-to', str(end_s), '-c', 'copy', output_audio_path], check=True)
        os.remove(temp_audio_path)  # 删除临时文件
        print(f"Trimmed audio to {output_audio_path}.")
    except subprocess.CalledProcessError as e:
        print(f"Error trimming audio: {e}\nOutput: {e.stdout.decode()}\nError: {e.stderr.decode()}")

def download_and_trim_audio(ytid, start_s, end_s, audio_output_dir):
    temp_audio_path, output_audio_path = download_audio(ytid, audio_output_dir)
    if temp_audio_path and output_audio_path:
        trim_audio(temp_audio_path, output_audio_path, start_s, end_s)

# 加载CSV文件
csv_file = '/content/musiccaps-public.csv'  # colab
# csv_file = '/kaggle/input/musiccaps/musiccaps-public.csv'  # kaggle

df = pd.read_csv(csv_file)

audio_output_dir = './downloaded_audios'  # 音频输出目录

# 确保输出目录存在
os.makedirs(audio_output_dir, exist_ok=True)

# 使用线程池
with ThreadPoolExecutor(max_workers=10) as executor:
    for index, row in df.iterrows():
        if index >= 20:
            break
        executor.submit(download_and_trim_audio, row['ytid'], row['start_s'], row['end_s'], audio_output_dir)


## 2. 使用Youtube-dl会报错,改用you-get
### 但是you-get不能给出固定文件名，不适合多线程

In [7]:
# !pip install you-get

测试一下You-Get

In [8]:
# !you-get -i 'https://www.youtube.com/watch?v=jNQXAC9IVRw'

检查可用格式：运行 you-get 命令带 -i 选项（用于信息查看模式），查看该视频支持的所有可用格式。这样可以帮助您了解是否有特定的音频格式可供下载。执行命令如下：

In [9]:
# !you-get -i "https://www.youtube.com/watch?v=-0Gj8-vB1q4"


查看命令行是否正确

In [10]:
# !you-get --no-caption -o "./downloaded_videos" --itag=160 "https://www.youtube.com/watch?v=-0Gj8-vB1q4"


##  提取音频
安装**ffmpeg** 或其他类似工具来从下载的视频文件中提取音频。

In [11]:
# !sudo apt-get install ffmpeg


## 使用无损编码器
编解码器是 aac，这是 .m4a 文件的典型编解码器。

要生成 .wav 文件，应该使用无损的编解码器（例如，pcm_s16le）

In [12]:
# import subprocess
# import librosa
# import soundfile as sf
# import os
# import pandas as pd
# import glob # 用于文件路径名的模式匹配
# from datetime import datetime
# from concurrent.futures import ThreadPoolExecutor


# def get_latest_file_in_dir(directory):
#     """ 获取指定目录中最新的文件 """
#     list_of_files = glob.glob(os.path.join(directory, '*'))
#     if not list_of_files:  # 如果目录为空
#         return None
#     latest_file = max(list_of_files, key=os.path.getmtime)
#     return latest_file



# def download_lowest_resolution_video(ytid, video_output_dir):
#     video_url = f'https://www.youtube.com/watch?v={ytid}'
#     try:
#       # 使用 you-get 下载分辨率最低的视频
#         subprocess.run(['you-get', '--no-caption', '-o', video_output_dir, '--itag=160', video_url], check=True)
#         print (f"Download {ytid} video.")
#     except subprocess.CalledProcessError as e:
#         print(f"Error downloading video {ytid}: {e}")
#         return None
#     # 查找下载的视频文件
#     return get_latest_file_in_dir(video_output_dir)



# def extract_audio_from_video(video_path, output_audio_path):
#     try:
#         result = subprocess.run(['ffmpeg', '-i', video_path, '-vn', '-acodec', 'pcm_s16le', output_audio_path], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
#         file_name = video_path.split('/')[-1]
#         print (f"Extract {file_name} video.")
#     except subprocess.CalledProcessError as e:
#         print(f"Error extracting audio from video {video_path}: {e}\nOutput: {e.stdout.decode()}\nError: {e.stderr.decode()}")



# def download_and_extract(ytid, video_output_dir, audio_output_dir):
#     downloaded_video = download_lowest_resolution_video(ytid, video_output_dir) # 下载视频
#     if downloaded_video:
#         audio_path = os.path.join(audio_output_dir, f'{ytid}.wav')
#         extract_audio_from_video(downloaded_video, audio_path)


# # 加载CSV文件
# csv_file = '/kaggle/input/musiccaps/musiccaps-public.csv' # musiccaps-public.csv
# df = pd.read_csv(csv_file)

# video_output_dir = './downloaded_videos' # .downloaded_videos
# audio_output_dir = './downloaded_audios' # ./downloaded_audios

# # 确保输出目录存在
# os.makedirs(video_output_dir, exist_ok=True)
# os.makedirs(audio_output_dir, exist_ok=True)

# # 初始化 test_n
# # test_n = 0

# # 遍历CSV文件，下载视频并提取音频
# # 使用线程池
# with ThreadPoolExecutor(max_workers=5) as executor:
#     for index, row in df.iterrows():
#         if index >= 20:
#             break
#         executor.submit(download_and_extract, row['ytid'], video_output_dir, audio_output_dir)
# #         if test_n >= 20:
# #             break
# #         test_n += 1

# 二、下载训练模型的文件内容

hubert_base_ls960.pt 文件是一个预训练的模型权重文件，用于 **HuBERT （Hidden Unit BERT）模型**。HuBERT 是由Facebook AI 研究团队开发的一种**自监督学习的语音识别模型**。它是基于 BERT 架构的，专门针对语音处理任务进行了优化。

In [13]:
import requests

def download_file(url, filename):
    response = requests.get(url)
    response.raise_for_status()  # 检查请求是否成功

    with open(filename, 'wb') as f:
        f.write(response.content)

# 设置文件的URL和你想要保存的文件名
file_url = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt"
file_name = "hubert_base_ls960.pt"

# 下载文件
download_file(file_url, file_name)

# 设置文件的URL和你想要保存的文件名
file_url = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin"
file_name = "hubert_base_ls960_L9_km500.bin"

# 下载文件
download_file(file_url, file_name)



SemanticTransformerTrainer（这可能是一个音频处理或自然语言处理相关的训练器）

In [14]:
# 不是我们要的音频链接
# import requests

# url = "https://github.com/hsfzxjy/models.storage/releases/download/HRNet-OCR/hrnet_cs_8090_torch11.pth"
# response = requests.get(url)
# response.raise_for_status()

# file_name = url.split('/')[-1]

# with open(file_name, 'wb') as f:
#   f.write(response.content)






# 前面都是数据准备，现在才是模型相关的。安装MusicLM必要的包

In [15]:
# 重启也要运行
!pip install musiclm-pytorch
!pip install --upgrade tensorflow tensorflow-io
!pip install audiolm_pytorch



# Usage
`MuLaN` first needs to be trained

In [16]:
# 重启也要运行

import array
import torchaudio
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy
import soundfile
from scipy.io.wavfile import read
import torch
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer
import os
import pathlib
import numpy
import pandas
from musiclm_pytorch import MuLaNEmbedQuantizer


# 验证环境
如果没有CUDA环境，就要将后面的.cuda()去掉

In [17]:
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


# 加载三个transformer
`**MusicLM**` 用了三个transformer，`MuLan(Audio)`, `w2v-BERT`, `SoundStream`
![image.png](attachment:b35cb9db-0628-477f-af00-cf4b6437741d.png)

**声学模块SoundStream**：端到端神经音频编解码器，能提供更高质量的音频，并扩展至编码不同的声音类型

**语义模块w2v-BERT**：使用该模型的掩码语言建模(MLM)模块的中间层。在预训练和冻结模型之后，从第7层提取embedding，并使用学习到的k-means质心对embedding进行量化。该模块主要起到提取语义词元的作用

**音频文本对MuLan**：采用双塔并行编码器架构，使用对比损失进行训练，在音乐音频和文本之间形成共享嵌入空间。即`音频信号和文本语义都具有基于离散标记的同质表示`


In [18]:
# 重启也要运行
import torch
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer

audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

# get a ton of <sound, text> pairs and train

wavs = torch.randn(2, 1024)
texts = torch.randint(0, 20000, (2, 256))

loss = mulan(wavs, texts)
loss.backward()

# after much training, you can embed sounds and text into a joint embedding space
# for conditioning the audio LM

embeds = mulan.get_audio_latents(wavs)  # during training

embeds = mulan.get_text_latents(texts)  # during inference

spectrogram yielded shape of (65, 86), but had to be cropped to (64, 80) to be patchified for transformer


To obtain the conditioning embeddings for the three transformers that are a part of AudioLM, you must use the `MuLaNEmbedQuantizer` as so

# MusicLM

In [19]:
# 重启也要运行
from musiclm_pytorch import MuLaNEmbedQuantizer

# setup the quantizer with the namespaced conditioning embeddings, unique per quantizer as well as namespace (per transformer)

quantizer = MuLaNEmbedQuantizer(
    mulan = mulan,                          # pass in trained mulan from above
    conditioning_dims = (1024, 1024, 1024), # say all three transformers have model dimensions of 1024
    namespaces = ('semantic', 'coarse', 'fine')
)

# now say you want the conditioning embeddings for semantic transformer

wavs = torch.randn(2, 1024)
conds = quantizer(wavs = wavs, namespace = 'semantic') # (2, 8, 1024) - 8 is number of quantizers

# 5.train Hubert模型，audioLM transformer training, train AudioLM,
To train (or finetune) the three transformers that are a part of `AudioLM`, you simply follow the instructions over at `audiolm-pytorch` for training, but pass in the `MulanEmbedQuantizer` instance to the training classes under the keyword `audio_conditioner`

ex. `SemanticTransformerTrainer`

验证路径是否正确

In [20]:
import os
print("Current working directory:", os.getcwd())
print("Files in './downloaded_audios':", os.listdir('./downloaded_audios'))


Current working directory: /content
Files in './downloaded_audios': ['-5f6hjZf9Yw.wav', '-7B9tPuIP-w.wav', '-0xzrMun0Rs_temp.*.wav', '-1UWSisR2zo_temp.*.wav', '-88me9bBzrk_temp.*.wav', '-5f6hjZf9Yw_temp.*.wav', '-5FoeegAgvU_temp.*.wav', '-0vPFx-wRRI_temp.*.wav', '-1LrH01Ei1w.wav', '-0SdAVK79lg.wav', '-7wUQP6G5EQ.wav', '-6HBGg1cAI0_temp.*.wav', '-4NLarMj4xU.wav', '-5xOcMJpTUk.wav', '-6pcgdLfb_A.wav', '-0SdAVK79lg_temp.*.wav', '-7B9tPuIP-w_temp.*.wav', '-6QGvxvaTkI_temp.*.wav', '-7wUQP6G5EQ_temp.*.wav', '-0xzrMun0Rs.wav', '-6pcgdLfb_A_temp.*.wav', '-0Gj8-vB1q4_temp.*.wav', '-5xOcMJpTUk_temp.*.wav', '-8C-gydUbR8.wav', '-5FoeegAgvU.wav', '-4SYC2YgzL8_temp.*.wav', '-6HBGg1cAI0.wav', '-1OlgJWehn8.wav', '-1UWSisR2zo.wav', '-1LrH01Ei1w_temp.*.wav', '-88me9bBzrk.wav', '-3Kv4fdm7Uk_temp.*.wav', '-4SYC2YgzL8.wav', '-6QGvxvaTkI.wav', '-4NLarMj4xU_temp.*.wav', '-0vPFx-wRRI.wav', '-1OlgJWehn8_temp.*.wav', '-8C-gydUbR8_temp.*.wav', '-3Kv4fdm7Uk.wav', '-0Gj8-vB1q4.wav']



函数封装：将创建和使用训练器的代码封装在一个函数中。这有助于在函数结束时自动销毁局部变量，包括训练器实例。但这只在训练器没有修改全局状态或在内部维持静态/全局变量的情况下有效。
或者每次注释掉一个transformer，然后重启运行

防止触发

AssertionError: only one Trainer can be instantiated at a time for training



# 多线程解决one trainer问题

生成`train_semantic_transformer.py` 防止trainer冲突

In [21]:
# # train_semantic_transformer.py
# import torch
# from audiolm_pytorch import HubertWithKmeans
# from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer
# from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer
# from audiolm_pytorch import FineTransformer, FineTransformerTrainer
# from audiolm_pytorch import AudioLMSoundStream, AudioLM
# import gc  # 导入垃圾回收模块

# # 公共变量
# checkpoint_path = 'hubert_base_ls960.pt'
# kmeans_path = 'hubert_base_ls960_L9_km500.bin'

# audio_output_dir = './downloaded_audios'
# batch_size = 1
# data_max_length = 320 * 32
# num_train_steps = 1

# # 函数：训练 SemanticTransformer
# def train_semantic_transformer():
#     wav2vec = HubertWithKmeans(checkpoint_path=checkpoint_path, kmeans_path=kmeans_path)   # 每个函数中重新创建 wav2vec，后面会删掉
#     soundstream = AudioLMSoundStream()
#     semantic_transformer = SemanticTransformer(num_semantic_tokens=wav2vec.codebook_size, dim=1024, depth=6, audio_text_condition=True).cuda()
#     trainer = SemanticTransformerTrainer(transformer=semantic_transformer, wav2vec=wav2vec, audio_conditioner=quantizer, folder=audio_output_dir, batch_size=batch_size, data_max_length=data_max_length, num_train_steps=num_train_steps)
#     trainer.train()
#     torch.save(semantic_transformer.state_dict(), 'semantic_transformer.pth')
#     del semantic_transformer, trainer, wav2vec
#     gc.collect()  # 执行垃圾回收



# # 依次训练每个模型
# train_semantic_transformer()


`train_coarse_transformer.py`

In [22]:
# # train_coarse_transformer.py
# import torch
# from audiolm_pytorch import HubertWithKmeans
# from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer
# from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer
# from audiolm_pytorch import FineTransformer, FineTransformerTrainer
# from audiolm_pytorch import AudioLMSoundStream, AudioLM
# import gc  # 导入垃圾回收模块

# # 公共变量
# checkpoint_path = 'hubert_base_ls960.pt'
# kmeans_path = 'hubert_base_ls960_L9_km500.bin'

# audio_output_dir = './downloaded_audios'
# batch_size = 1
# data_max_length = 320 * 32
# num_train_steps = 1

# # 函数：训练 CoarseTransformer
# def train_coarse_transformer():
#     wav2vec = HubertWithKmeans(checkpoint_path=checkpoint_path, kmeans_path=kmeans_path)   # 每个函数中重新创建 wav2vec，后面会删掉
#     soundstream = AudioLMSoundStream()

#     coarse_transformer = CoarseTransformer(num_semantic_tokens=wav2vec.codebook_size, codebook_size=1024, num_coarse_quantizers=4, dim=1024, depth=6, audio_text_condition=True).cuda()
#     trainer = CoarseTransformerTrainer(transformer=coarse_transformer, codec=soundstream, wav2vec=wav2vec, audio_conditioner=quantizer, folder=audio_output_dir, batch_size=batch_size, data_max_length=data_max_length, num_train_steps=num_train_steps)
#     trainer.train()
#     torch.save(coarse_transformer.state_dict(), 'coarse_transformer.pth')
#     del coarse_transformer, trainer, wav2vec, soundstream
#     gc.collect()

# train_coarse_transformer()





`train_fine_transformer.py`

In [23]:
# # train_fine_transformer.py

# import torch
# from audiolm_pytorch import HubertWithKmeans
# from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer
# from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer
# from audiolm_pytorch import FineTransformer, FineTransformerTrainer
# from audiolm_pytorch import AudioLMSoundStream, AudioLM
# import gc  # 导入垃圾回收模块

# # 公共变量
# checkpoint_path = 'hubert_base_ls960.pt'
# kmeans_path = 'hubert_base_ls960_L9_km500.bin'

# audio_output_dir = './downloaded_audios'
# batch_size = 1
# data_max_length = 320 * 32
# num_train_steps = 1


# # 函数：训练 FineTransformer
# def train_fine_transformer():
#     soundstream = AudioLMSoundStream()

#     fine_transformer = FineTransformer(num_coarse_quantizers=4, num_fine_quantizers=8, codebook_size=1024, dim=1024, depth=6, audio_text_condition=True).cuda()
#     trainer = FineTransformerTrainer(transformer=fine_transformer, codec=soundstream, folder=audio_output_dir, batch_size=batch_size, data_max_length=data_max_length, num_train_steps=num_train_steps)
#     trainer.train()
#     torch.save(fine_transformer.state_dict(), 'fine_transformer.pth')
#     del fine_transformer, trainer, soundstream
#     gc.collect()

# train_fine_transformer()

强制覆盖文件

In [24]:
!wget -O train_semantic_transformer.py https://raw.githubusercontent.com/Huang-Yongzhi/musiclm-pytorch/main/train_semantic_transformer.py
!wget -O train_coarse_transformer.py https://raw.githubusercontent.com/Huang-Yongzhi/musiclm-pytorch/main/train_coarse_transformer.py
!wget -O train_fine_transformer.py https://raw.githubusercontent.com/Huang-Yongzhi/musiclm-pytorch/main/train_fine_transformer.py


!pip install tensorboardX

%run train_semantic_transformer.py


# !python train_semantic_transformer.py
# !python train_coarse_transformer.py
# !python train_fine_transformer.py

--2023-11-13 14:49:51--  https://raw.githubusercontent.com/Huang-Yongzhi/musiclm-pytorch/main/train_semantic_transformer.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2623 (2.6K) [text/plain]
Saving to: ‘train_semantic_transformer.py’


2023-11-13 14:49:51 (42.3 MB/s) - ‘train_semantic_transformer.py’ saved [2623/2623]

--2023-11-13 14:49:51--  https://raw.githubusercontent.com/Huang-Yongzhi/musiclm-pytorch/main/train_coarse_transformer.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2631 (2.6K) [text/plain]
Saving to: ‘tr

In [25]:
%run train_coarse_transformer.py


AssertionError: ignored

In [26]:
%run train_fine_transformer.py


AssertionError: ignored

In [27]:
!ls

downloaded_audios		results			     train_fine_transformer.py
hubert_base_ls960_L9_km500.bin	sample_data		     train_semantic_transformer.py
hubert_base_ls960.pt		semantic_transformer.pth
musiccaps-public.csv		train_coarse_transformer.py


In [None]:
import torch
from audiolm_pytorch import HubertWithKmeans
from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer
from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer
from audiolm_pytorch import FineTransformer, FineTransformerTrainer
from audiolm_pytorch import AudioLMSoundStream, AudioLM

# 创建并加载 AudioLM 实例
wav2vec = HubertWithKmeans(checkpoint_path=checkpoint_path, kmeans_path=kmeans_path) # 前面被删掉了
soundstream = AudioLMSoundStream()

semantic_transformer = SemanticTransformer(num_semantic_tokens=wav2vec.codebook_size, dim=1024, depth=6, audio_text_condition=True).cuda()
coarse_transformer = CoarseTransformer(num_semantic_tokens=wav2vec.codebook_size, codebook_size=1024, num_coarse_quantizers=4, dim=1024, depth=6, audio_text_condition=True).cuda()
fine_transformer = FineTransformer(num_coarse_quantizers=4, num_fine_quantizers=8, codebook_size=1024, dim=1024, depth=6, audio_text_condition=True).cuda()

# 加载模型状态
semantic_transformer.load_state_dict(torch.load('semantic_transformer.pth'))
coarse_transformer.load_state_dict(torch.load('coarse_transformer.pth'))
fine_transformer.load_state_dict(torch.load('fine_transformer.pth'))

audiolm = AudioLM(wav2vec=wav2vec, codec=soundstream, semantic_transformer=semantic_transformer, coarse_transformer=coarse_transformer, fine_transformer=fine_transformer)


In [None]:
# # 重启也要运行
# import torch
# from audiolm_pytorch import HubertWithKmeans
# from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer, SoundStream
# from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer
# from audiolm_pytorch import FineTransformer, FineTransformerTrainer
# from audiolm_pytorch import AudioLMSoundStream, MusicLMSoundStream
# from unittest import mock
# import multiprocessing as mp
# import gc  # 导入垃圾回收模块


# # 设置多进程启动方式为 'spawn'
# mp.set_start_method('spawn', force=True)

# def train_semantic_transformer_process(checkpoint_path, kmeans_path):
#     wav2vec = HubertWithKmeans(checkpoint_path=checkpoint_path, kmeans_path=kmeans_path)
#     semantic_transformer = SemanticTransformer(
#         num_semantic_tokens = wav2vec.codebook_size,
#         dim = 1024,
#         depth = 6,
#         audio_text_condition = True      # this must be set to True (same for CoarseTransformer and FineTransformers)
#     ).cuda()

#     trainer = SemanticTransformerTrainer(
#         transformer = semantic_transformer,
#         wav2vec = wav2vec,
#         audio_conditioner = quantizer,   # pass in the MulanEmbedQuantizer instance above
#         folder ='./downloaded_audios',
#         batch_size = 1,
#         data_max_length = 320 * 32,
#         num_train_steps = 1
#     )

#     trainer.train()
#     torch.save(semantic_transformer.state_dict(), '/content/semantic_transformer.pth') # 保存模型
# #     del trainer  # 显式删除实例，删除好像没有效果，需要存下来重启


# def train_coarse_transformer_process(checkpoint_path, kmeans_path):
#     wav2vec = HubertWithKmeans(checkpoint_path=checkpoint_path, kmeans_path=kmeans_path)
#     soundstream = MusicLMSoundStream()

#     coarse_transformer = CoarseTransformer(
#         num_semantic_tokens = wav2vec.codebook_size,
#         codebook_size = 1024,
#         num_coarse_quantizers = 4,
#         dim = 1024,
#         depth = 6,
#         audio_text_condition = True
#     ).cuda()

#     with mock.patch('builtins.input', return_value='n'):
#         trainer = CoarseTransformerTrainer(
#             transformer = coarse_transformer,
#             codec = soundstream,
#             wav2vec = wav2vec,
#             audio_conditioner = quantizer,
#             folder = './downloaded_audios/',
#             batch_size = 1,
#             data_max_length = 320 * 32,
#             num_train_steps = 1
#         )
#         trainer.train()
#         # 保存模型状态字典
#         torch.save(coarse_transformer.state_dict(), '/content/coarse_transformer.pth') # 保存模型


# def train_fine_transformer_process():
#     soundstream = MusicLMSoundStream()
#     fine_transformer = FineTransformer(
#         num_coarse_quantizers = 4,
#         num_fine_quantizers = 8,
#         codebook_size = 1024,
#         dim = 1024,
#         depth = 6,
#         audio_text_condition = True
#     ).cuda()

#     with mock.patch('builtins.input', return_value='n'):
#         trainer = FineTransformerTrainer(
#             transformer = fine_transformer,
#             codec = soundstream,
#             folder = './downloaded_audios/',
#             batch_size = 1,
#             data_max_length = 320 * 32,
#             num_train_steps = 1,
#             audio_conditioner = quantizer
#         )

#         trainer.train()
#         torch.save(fine_transformer.state_dict(), '/content/fine_transformer.pth') # 保存模型



# def run_training_process(train_function, *args):
#     process = mp.Process(target=train_function, args=args)
#     process.start()
#     process.join()  # 等待进程结束



# def train_all_transformers():

#     # 加载 wav2vec 的参数
#     checkpoint_path = 'hubert_base_ls960.pt'
#     kmeans_path = 'hubert_base_ls960_L9_km500.bin'

#     # 多线程运行
#     run_training_process(train_semantic_transformer_process, checkpoint_path, kmeans_path)
#     run_training_process(train_coarse_transformer_process, checkpoint_path, kmeans_path)
#     run_training_process(train_fine_transformer_process)

# train_all_transformers()


下面是单线程运行，但是需要每次调用一个网络训练。

In [None]:
# # 重启也要运行
# import torch
# from audiolm_pytorch import HubertWithKmeans
# from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer, SoundStream
# from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer
# from audiolm_pytorch import FineTransformer, FineTransformerTrainer
# from audiolm_pytorch import AudioLMSoundStream, MusicLMSoundStream
# from unittest import mock
# import multiprocessing


# def train_semantic_transformer(wav2vec):

#     semantic_transformer = SemanticTransformer(
#         num_semantic_tokens = wav2vec.codebook_size,
#         dim = 1024,
#         depth = 6,
#         audio_text_condition = True      # this must be set to True (same for CoarseTransformer and FineTransformers)
#     ).cuda()

#     trainer = SemanticTransformerTrainer(
#         transformer = semantic_transformer,
#         wav2vec = wav2vec,
#         audio_conditioner = quantizer,   # pass in the MulanEmbedQuantizer instance above
#         folder ='./downloaded_audios',
#         batch_size = 1,
#         data_max_length = 320 * 32,
#         num_train_steps = 1
#     )

#     trainer.train()
#     torch.save(semantic_transformer.state_dict(), 'semantic_transformer.pth') # 保存模型
#     del trainer  # 显式删除实例，删除好像没有效果，需要存下来重启


# def train_coarse_transformer(wav2vec, soundstream):

#     coarse_transformer = CoarseTransformer(
#         num_semantic_tokens = wav2vec.codebook_size,
#         codebook_size = 1024,
#         num_coarse_quantizers = 4,
#         dim = 1024,
#         depth = 6,
#         audio_text_condition = True
#     ).cuda()

#     with mock.patch('builtins.input', return_value='n'):
#         trainer = CoarseTransformerTrainer(
#             transformer = coarse_transformer,
#             codec = soundstream,
#             wav2vec = wav2vec,
#             audio_conditioner = quantizer,
#             folder = './downloaded_audios/',
#             batch_size = 1,
#             data_max_length = 320 * 32,
#             num_train_steps = 1
#         )
#         trainer.train()
#         # 保存模型状态字典
#         torch.save(coarse_transformer.state_dict(), 'coarse_transformer.pth') # 保存模型

# def train_fine_transformer(soundstream):
#     fine_transformer = FineTransformer(
#         num_coarse_quantizers = 4,
#         num_fine_quantizers = 8,
#         codebook_size = 1024,
#         dim = 1024,
#         depth = 6,
#         audio_text_condition = True
#     ).cuda()

#     with mock.patch('builtins.input', return_value='n'):
#         trainer = FineTransformerTrainer(
#             transformer = fine_transformer,
#             codec = soundstream,
#             folder = './downloaded_audios/',
#             batch_size = 1,
#             data_max_length = 320 * 32,
#             num_train_steps = 1,
#             audio_conditioner = quantizer
#         )

#         trainer.train()
#         torch.save(fine_transformer.state_dict(), 'fine_transformer.pth') # 保存模型

# def train_all_transformers():
#      # soundstream = SoundStream.init_and_load_from('/path/to/trained/soundstream.pt')
#     soundstream = MusicLMSoundStream()
#     # 重新加载 wav2vec 和 soundstream
#     wav2vec = HubertWithKmeans(
#         checkpoint_path='hubert_base_ls960.pt',
#         kmeans_path='hubert_base_ls960_L9_km500.bin'
#     )

#    # 训练 SemanticTransformer
#     train_semantic_transformer(wav2vec)
#     # 训练 CoarseTransformer
#     train_coarse_transformer(wav2vec, soundstream)

#     # 训练 FineTransformer
#     train_fine_transformer(soundstream)

# train_all_transformers()


调用训练模型

In [None]:
# train_semantic_transformer()


# **重启后**，重新创建 semantic_transformer 实例

In [None]:
# # 重启也要运行
# # 重启后导入必要的库
# import torch
# from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, AudioLM

# # 重新创建和初始化模型
# wav2vec = HubertWithKmeans(
#     checkpoint_path='hubert_base_ls960.pt',
#     kmeans_path='hubert_base_ls960_L9_km500.bin'
# )
# semantic_transformer = SemanticTransformer(
#     num_semantic_tokens=wav2vec.codebook_size,
#     dim=1024, depth=6,
#     audio_text_condition=True).cuda()

# # 加载之前保存的模型状态
# semantic_transformer.load_state_dict(torch.load('semantic_transformer.pth'))

In [None]:
# train_coarse_transformer()


# **重启后**，重新创建 coarse_transformer 实例

In [None]:
# # 重启也要运行
# import torch
# from audiolm_pytorch import CoarseTransformer, AudioLMSoundStream

# # 重新创建模型架构
# coarse_transformer = CoarseTransformer(
#     num_semantic_tokens = wav2vec.codebook_size,
#     codebook_size = 1024,
#     num_coarse_quantizers = 4,
#     dim = 1024,
#     depth = 6,
#     audio_text_condition = True
# ).cuda()

# # 加载之前保存的状态字典
# coarse_transformer.load_state_dict(torch.load('coarse_transformer.pth'))


In [None]:
# train_fine_transformer()

# **重启后**，重新创建 fine_transformer 实例

In [None]:
# # 重启也要运行
# import torch
# from audiolm_pytorch import CoarseTransformer, AudioLMSoundStream

# # 重新创建模型架构
# fine_transformer = FineTransformer(
#     num_coarse_quantizers = 4,
#     num_fine_quantizers = 8,
#     codebook_size = 1024,
#     dim = 1024,
#     depth = 6,
#     audio_text_condition = True
# ).cuda()

# # 加载之前保存的状态字典
# fine_transformer.load_state_dict(torch.load('fine_transformer.pth'))


In [None]:
# import torch
# from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, CoarseTransformer, FineTransformer, AudioLM, MusicLMSoundStream

# # 重新创建 wav2vec 实例
# wav2vec = HubertWithKmeans(
#     checkpoint_path='hubert_base_ls960.pt',
#     kmeans_path='hubert_base_ls960_L9_km500.bin'
# )

# # 重新创建 soundstream 实例
# soundstream = MusicLMSoundStream()

# # 重新创建 SemanticTransformer 实例
# semantic_transformer = SemanticTransformer(
#     num_semantic_tokens=wav2vec.codebook_size,
#     dim=1024, depth=6,
#     audio_text_condition=True
# ).cuda()

# # 加载之前保存的 SemanticTransformer 状态
# semantic_transformer.load_state_dict(torch.load('semantic_transformer.pth'))

# # 重新创建 CoarseTransformer 实例
# coarse_transformer = CoarseTransformer(
#     num_semantic_tokens=wav2vec.codebook_size,
#     codebook_size=1024,
#     num_coarse_quantizers=4,
#     dim=1024,
#     depth=6,
#     audio_text_condition=True
# ).cuda()

# # 加载之前保存的 CoarseTransformer 状态
# coarse_transformer.load_state_dict(torch.load('coarse_transformer.pth'))

# # 重新创建 FineTransformer 实例
# fine_transformer = FineTransformer(
#     num_coarse_quantizers=4,
#     num_fine_quantizers=8,
#     codebook_size=1024,
#     dim=1024,
#     depth=6,
#     audio_text_condition=True
# ).cuda()

# # 加载之前保存的 FineTransformer 状态
# fine_transformer.load_state_dict(torch.load('fine_transformer.pth'))

# # 创建 AudioLM 实例
# audiolm = AudioLM(
#     wav2vec=wav2vec,
#     codec=soundstream,
#     semantic_transformer=semantic_transformer,
#     coarse_transformer=coarse_transformer,
#     fine_transformer=fine_transformer
# )


# 7. musiclm

In [None]:
# you need the trained AudioLM (audio_lm) from above
# with the MulanEmbedQuantizer (mulan_embed_quantizer)

from musiclm_pytorch import MusicLM

musiclm = MusicLM(
    audio_lm = audiolm,
    mulan_embed_quantizer = quantizer
).cuda()



In [None]:
music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples = 1) # sample 4 and pick the top match with mulan

In [None]:
torch.save(music, 'generated_music.pt')

In [None]:
output_path = "out.wav"
sample_rate = 44100
torchaudio.save(output_path, music.cpu() , sample_rate)