# 0. Load model

In [3]:
import json
import pytorch_lightning as pl
from einops import rearrange
import torch
import torchaudio
from torch.utils.data import DataLoader
from safetensors.torch import load_file
import shutil
from collections import OrderedDict
import safetensors
from datetime import datetime
import os

from stable_audio_tools import create_model_from_config, replace_audio, save_audio
from stable_audio_tools.data.dataset import VideoFeatDataset, collation_fn
from stable_audio_tools.training.training_wrapper import DiffusionCondTrainingWrapper
from stable_audio_tools.inference.generation import generate_diffusion_cond, generate_diffusion_cond_from_path


# Location of the JSON file describing the model architecture and training setup.
model_config_file = './stable_audio_tools/configs/model_config.json'

# Parse the configuration into a plain dict.
with open(model_config_file) as f:
    model_config = json.loads(f.read())

# The sample rate is read from the config and reused by the dataset and
# generation cells further down.
sample_rate = model_config["sample_rate"]

# Instantiate the model from the parsed config (weights are loaded separately).
model = create_model_from_config(model_config)

# Last expression of the cell: display the sample rate.
sample_rate

44100

In [4]:
# Checkpoint selection: `model_dir` is the run directory, `model_name` is the
# checkpoint file stem (no extension).
model_dir = './weight/StableAudio/2024-07-06 10:28:13'
model_name = 'epoch=30-step=58'
# Notes on previous runs (directory, checkpoint stem, observations):
# ./weight/StableAudio/lightning_logs/version_3/checkpoints         epoch=9-step=14620
# ./weight/StableAudio/2024-07-02 19:49:04                  epoch=2-step=427
# ./weight/StableAudio/2024-07-03 15:16:01               after making the attn_bias weights trainable, epoch=5 and epoch=9 show some time-alignment behaviour
# ./weight/StableAudio/2024-07-03 21:46:54               after making the attn_bias weights trainable, epoch=18 handles multiple objects fairly well
# ./weight/StableAudio/2024-07-04 11:41:24               non-trainable attn_bias weights ranging from 0.5 to 1, epoch=39-step=59  epoch=3-step=55  epoch=21-step=57
# ./weight/StableAudio/2024-07-06 10:28:13               added pos_emb, epoch=30-step=58   time alignment and generation quality are good (best)
# ./weight/StableAudio/2024-07-06 21:34:56               same as the previous line plus VGG added for training, epoch=3-step=1023

safetensors_path = f'{model_dir}/{model_name}.safetensors'
ckpt_path = f'{model_dir}/{model_name}.ckpt'

if os.path.exists(safetensors_path):
    # A converted copy already exists: load it directly.
    state_dict = load_file(safetensors_path)
else:
    # No safetensors copy yet: read the .ckpt file, convert it, and store the
    # result next to it so later runs take the branch above.
    # An explicit existence check is used instead of a bare `except:` so that
    # genuine failures (corrupt file, interrupt, ...) are not silently masked.
    # map_location='cpu' keeps loading independent of the device the
    # checkpoint was saved from.
    state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
    # Drop the first dot-separated component of every key (everything up to and
    # including the first '.'); a key without any '.' becomes ''.
    state_dict = OrderedDict(
        (key.partition('.')[2], value) for key, value in state_dict.items()
    )
    safetensors.torch.save_file(state_dict, safetensors_path)

# strict=False tolerates missing/unexpected keys; the value displayed as the
# cell output reports any that did not match.
model.load_state_dict(state_dict, strict=False)



<All keys matched successfully>

# 2. Train

In [4]:
# Directories handed to VideoFeatDataset as `info_dirs` / `audio_dirs`.
# Alternative (AudioSet + VGGSound together):
# info_dirs = ['./dataset/feature/train/AudioSet/10', './dataset/feature/train/VGGSound/10']
# audio_dirs = ['/home/chengxin/chengxin/AudioSet/generated_audios/train/10', '/home/chengxin/chengxin/VGGSound/generated_audios/train/10']
info_dirs = ['./dataset/feature/train/AudioSet/10']
audio_dirs = ['/home/chengxin/chengxin/AudioSet/generated_audios/train/10']

# Keyword arguments forwarded verbatim to VideoFeatDataset.
ds_config = dict(
    info_dirs=info_dirs,
    audio_dirs=audio_dirs,
    exts='wav',
    sample_rate=sample_rate,
    force_channels="stereo",
)

# Keyword arguments forwarded verbatim to torch's DataLoader.
# NOTE(review): shuffle is off although this loader is passed to trainer.fit —
# confirm that a fixed sample order is intended for training.
dl_config = dict(
    batch_size=25,
    shuffle=False,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
    drop_last=False,
)

dataset = VideoFeatDataset(**ds_config)
dataloader = DataLoader(dataset, collate_fn=collation_fn, **dl_config)

Found 22050 files


In [5]:
# Build the Lightning training wrapper from the "training" section of the
# model config and run a short training session.
training_config = model_config.get('training')
if training_config is None:
    # Fail early with a readable message instead of an AttributeError on None
    # at the first `.get` below.
    raise ValueError("model_config has no 'training' section; cannot build the training wrapper")

training_wrapper = DiffusionCondTrainingWrapper(
    model=model,
    lr=training_config.get("learning_rate", None),
    optimizer_configs=training_config.get("optimizer_configs", None),
    pre_encoded=training_config.get("pre_encoded", False),
    cfg_dropout_prob=training_config.get("cfg_dropout_prob", 0.1),
    timestep_sampler=training_config.get("timestep_sampler", "uniform"),
)

# Single node, one GPU (device index 1), two epochs.
trainer = pl.Trainer(
    devices=[1],
    accelerator="gpu",
    num_nodes=1,
    max_epochs=2,
)
trainer.fit(training_wrapper, dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA H100 PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name      | Type                             | Params
---------------------------------------------------------------
0 | diffusion | ConditionedDiffusionModelWrapper | 1.2 B 
1 | losses    | MultiLoss                        | 0     
---------------------------------------------------------------
1.1 B     Trainable params
156 M     Non-trainable params
1.2 B     Total params
4,880.205 Total estimated model params size (MB)


Epoch 0:   3%|▎         | 24/882 [00:57<34:29,  0.41it/s, v_num=4] 

# 3. Test

In [9]:

# Directories handed to VideoFeatDataset as `info_dirs` / `audio_dirs` for the
# test run. Note that this cell rebinds `dataset` and `dataloader`.
info_dirs = ['./dataset/feature/test/VGGSound/10']
audio_dirs = ['/home/chengxin/chengxin/VGGSound/generated_audios/test/10']

# Keyword arguments forwarded verbatim to VideoFeatDataset.
ds_config = dict(
    info_dirs=info_dirs,
    audio_dirs=audio_dirs,
    exts='wav',
    sample_rate=sample_rate,
    force_channels="mono",
    limit_num=50,
)

# Keyword arguments forwarded verbatim to torch's DataLoader; one item per batch.
dl_config = dict(
    batch_size=1,
    shuffle=False,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
    drop_last=False,
)

dataset = VideoFeatDataset(**ds_config)
dataloader = DataLoader(dataset, collate_fn=collation_fn, **dl_config)

Found 50 files


In [10]:
# Generate audio for every batch of the current `dataloader`, save it as .wav
# under ./demo/<model_name>/, copy the source video next to it and write a
# "<name>_GEN<ext>" video whose audio track is the generated one.
device = 0
duration = 10  # generate 10 secs
output_dir = f"./demo/{model_name}"
count = 0
max_iter = 50  # stop after this many batches
os.makedirs(output_dir, exist_ok=True)
print(output_dir)

# Moving the model is loop-invariant: do it once instead of on every batch.
model.to(device)

for audio, conditioning in dataloader:
    output = generate_diffusion_cond(
        model=model,
        steps=150,
        cfg_scale=7,
        conditioning=conditioning,
        sample_size=sample_rate * duration,
        batch_size=len(audio),
        sigma_min=0.3,
        sigma_max=500,
        sampler_type="dpmpp-3m-sde",
        device=device
    )

    for idx in range(len(audio)):
        # Save generated audio
        video_path = conditioning['video_path'][idx].replace('../../', './')
        # Derive names from the file name only and split the extension
        # explicitly: str.replace('.mp4', ...) would touch every occurrence
        # (including inside output_dir) and would be a no-op for other
        # extensions, making the audio/_GEN paths collide with the copied video.
        video_name = os.path.basename(video_path)
        video_stem, video_ext = os.path.splitext(video_name)
        audio_path = f"{output_dir}/{video_stem}.wav"
        # Slice (rather than index) to keep the leading batch dimension.
        save_audio(output[idx:idx + 1], audio_path, sample_rate)

        # Replace the audio of original video to generated one
        moved_video_path = f"{output_dir}/{video_name}"
        shutil.copy(video_path, moved_video_path)
        generated_video_path = f"{output_dir}/{video_stem}_GEN{video_ext}"
        replace_audio(moved_video_path, audio_path, generated_video_path)

    count += 1
    if count >= max_iter:
        break
    

./demo/epoch=30-step=58
537874431


100%|██████████| 150/150 [00:02<00:00, 53.62it/s]


3383632475


100%|██████████| 150/150 [00:02<00:00, 53.37it/s]


525742486


100%|██████████| 150/150 [00:02<00:00, 53.34it/s]


411212823


100%|██████████| 150/150 [00:02<00:00, 53.70it/s]


2964436039


100%|██████████| 150/150 [00:02<00:00, 53.57it/s]


607246033


100%|██████████| 150/150 [00:03<00:00, 47.38it/s]


1097967394


100%|██████████| 150/150 [00:02<00:00, 54.16it/s]


443532012


100%|██████████| 150/150 [00:02<00:00, 54.10it/s]


3178057732


100%|██████████| 150/150 [00:02<00:00, 54.09it/s]


3384241858


100%|██████████| 150/150 [00:02<00:00, 54.01it/s]


2731777235


100%|██████████| 150/150 [00:02<00:00, 54.33it/s]


836307171


100%|██████████| 150/150 [00:02<00:00, 52.40it/s]


2367592603


100%|██████████| 150/150 [00:02<00:00, 53.66it/s]


2778303660


100%|██████████| 150/150 [00:02<00:00, 53.48it/s]


866891932


100%|██████████| 150/150 [00:02<00:00, 53.30it/s]


200200129


100%|██████████| 150/150 [00:02<00:00, 53.42it/s]


2004619552


100%|██████████| 150/150 [00:02<00:00, 53.67it/s]


4031522473


100%|██████████| 150/150 [00:02<00:00, 53.46it/s]


1067329041


100%|██████████| 150/150 [00:02<00:00, 53.46it/s]


2395199896


100%|██████████| 150/150 [00:02<00:00, 53.56it/s]


1800603755


100%|██████████| 150/150 [00:02<00:00, 53.61it/s]


2085814752


100%|██████████| 150/150 [00:02<00:00, 53.52it/s]


725892454


100%|██████████| 150/150 [00:02<00:00, 51.82it/s]


2447158760


100%|██████████| 150/150 [00:02<00:00, 53.07it/s]


2646659075


100%|██████████| 150/150 [00:03<00:00, 49.72it/s]


4261204720


100%|██████████| 150/150 [00:02<00:00, 53.79it/s]


31987609


100%|██████████| 150/150 [00:02<00:00, 53.43it/s]


3168171855


100%|██████████| 150/150 [00:02<00:00, 53.39it/s]


3671267215


100%|██████████| 150/150 [00:02<00:00, 53.06it/s]


1456427302


100%|██████████| 150/150 [00:02<00:00, 53.67it/s]


53995403


100%|██████████| 150/150 [00:03<00:00, 48.85it/s]


254090288


100%|██████████| 150/150 [00:02<00:00, 53.50it/s]


3982211108


100%|██████████| 150/150 [00:03<00:00, 49.63it/s]


3257052610


100%|██████████| 150/150 [00:02<00:00, 53.60it/s]


952849560


100%|██████████| 150/150 [00:02<00:00, 53.32it/s]


704193615


100%|██████████| 150/150 [00:02<00:00, 53.69it/s]


1072010845


100%|██████████| 150/150 [00:02<00:00, 53.91it/s]


2425010540


100%|██████████| 150/150 [00:02<00:00, 53.80it/s]


190320245


100%|██████████| 150/150 [00:02<00:00, 53.79it/s]


1515516189


100%|██████████| 150/150 [00:02<00:00, 53.65it/s]


4230306987


100%|██████████| 150/150 [00:02<00:00, 52.92it/s]


705722223


100%|██████████| 150/150 [00:03<00:00, 48.00it/s]


453880280


100%|██████████| 150/150 [00:02<00:00, 53.57it/s]


1659621928


100%|██████████| 150/150 [00:02<00:00, 53.72it/s]


3577936975


100%|██████████| 150/150 [00:02<00:00, 53.75it/s]


379966509


100%|██████████| 150/150 [00:02<00:00, 53.65it/s]


2322729482


100%|██████████| 150/150 [00:02<00:00, 53.91it/s]


3354515557


100%|██████████| 150/150 [00:02<00:00, 54.01it/s]


2568749535


100%|██████████| 150/150 [00:02<00:00, 53.95it/s]


2814994138


100%|██████████| 150/150 [00:02<00:00, 53.84it/s]


In [14]:
# NOTE: this whole cell is disabled (commented-out) code. As written, it would
# take the first batch of `dataloader`, copy each source video into
# `output_dir` and write a "_GEN.mp4" copy whose audio comes from an existing
# .wav under .../generated_audios/specvqgan/10/ — no audio is generated here.
# device = 0
# duration = 10  # generate 10 secs
# output_dir = "./demo"


# for audio, conditioning in dataloader:
#     for idx in range(dl_config['batch_size']):
#         # Save generated audio
#         video_path = conditioning['video_path'][idx].replace('../../', './')
#         audio_path = f"/home/chengxin/chengxin/AudioSet/generated_audios/specvqgan/10/{video_path.split('/')[-1].replace('.mp4', '.wav')}"
            
#         # Replace the audio of original video to generated one
#         moved_video_path = f"{output_dir}/{video_path.split('/')[-1]}"
#         shutil.copy(video_path, moved_video_path)
#         generated_video_path = moved_video_path.replace(".mp4","_GEN.mp4")
#         replace_audio(moved_video_path, audio_path, generated_video_path)

#     break

# 4. Sample

In [3]:
# Generate audio directly from a list of video files, save each result as .wav
# in `output_dir`, copy the source video there and write a "<name>_GEN<ext>"
# video whose audio track is the generated one.
device = 0
duration = 10
video_paths = ['/home/chengxin/chengxin/FoleyCrafter/examples/sora/0.mp4', '/home/chengxin/chengxin/FoleyCrafter/examples/sora/1.mp4']
output_dir = "./demo"
# Make sure the target directory exists so this cell also works on a fresh
# checkout where ./demo has not been created yet.
os.makedirs(output_dir, exist_ok=True)

output = generate_diffusion_cond_from_path(
    model=model.to(device),
    video_paths=video_paths,
    steps=100,
    cfg_scale=7,
    sample_rate=sample_rate,
    sample_size=sample_rate * duration,
    batch_size=len(video_paths),
    sigma_min=0.3,
    sigma_max=500,
    sampler_type="dpmpp-3m-sde",
    device=device
)

for idx, video_path in enumerate(video_paths):
    # Save generated audio
    # Split the extension explicitly: str.replace('.mp4', ...) would touch
    # every occurrence in the path and would be a no-op for other extensions,
    # making the audio/_GEN paths collide with the copied video.
    video_name = os.path.basename(video_path)
    video_stem, video_ext = os.path.splitext(video_name)
    audio_path = f"{output_dir}/{video_stem}.wav"
    # Slice (rather than index) to keep the leading batch dimension.
    save_audio(output[idx:idx + 1], audio_path, sample_rate)

    # Replace the audio of original video to generated one
    moved_video_path = f"{output_dir}/{video_name}"
    shutil.copy(video_path, moved_video_path)
    generated_video_path = f"{output_dir}/{video_stem}_GEN{video_ext}"
    replace_audio(moved_video_path, audio_path, generated_video_path)

2720942727
Extracting features from video:/home/chengxin/chengxin/FoleyCrafter/examples/sora/0.mp4
Extracting features from video:/home/chengxin/chengxin/FoleyCrafter/examples/sora/1.mp4


100%|██████████| 100/100 [00:04<00:00, 23.83it/s]
