In [1]:
# from moviepy.editor import VideoFileClip
# from PIL import Image
# mp4_file = '/home/chengxin/chengxin/stable-v2a/stable_audio_tools/data/Pv6BhKDXpHE_000026.mp4'
# video = VideoFileClip(mp4_file)


# video.write_videofile("Pv6BhKDXpHE_000026.mp4", codec='libx264') 

In [1]:
import json
import os
import shutil
import subprocess

import pytorch_lightning as pl
import torch
import torchaudio
from einops import rearrange
from safetensors.torch import load_file
from torch.utils.data import DataLoader

from stable_audio_tools import create_model_from_config
from stable_audio_tools.data.dataset import VideoFeatDataset, collation_fn
from stable_audio_tools.training.training_wrapper import DiffusionCondTrainingWrapper
from stable_audio_tools.inference.generation import generate_diffusion_cond


# Load the model/training configuration used by every later cell.
# JSON is UTF-8 by specification, so the encoding is stated explicitly instead
# of relying on the platform's default text encoding.
model_config_file = './stable_audio_tools/configs/model_config.json'
with open(model_config_file, encoding='utf-8') as f:
    model_config = json.load(f)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the model from the JSON configuration loaded into `model_config`.
model = create_model_from_config(model_config)
# The two explicit weight-loading alternatives below are currently disabled:
# model.load_state_dict(load_file('./weight/StableAudio/model.safetensors'), strict=False)
# model.load_state_dict(torch.load('./lightning_logs/version_0/checkpoints/epoch=99-step=13800.ckpt')['state_dict'], strict=False)




In [17]:
# Training data: AudioSet and VGGSound. Each list holds one directory per
# source dataset (assumes the two lists are matched by position — TODO confirm
# against VideoFeatDataset).
info_dirs = [
    './dataset/feature/train/AudioSet/10',
    './dataset/feature/train/VGGSound/10',
]
audio_dirs = [
    '/home/chengxin/chengxin/AudioSet/generated_audios/train/10',
    '/home/chengxin/chengxin/VGGSound/generated_audios/train/10',
]

# Keyword arguments forwarded to VideoFeatDataset.
ds_config = {
    'info_dirs': info_dirs,
    'audio_dirs': audio_dirs,
    'exts': 'wav',
    'sample_rate': 44100,
    'force_channels': "stereo",
}
# Keyword arguments forwarded to the DataLoader.
dl_config = {
    'batch_size': 10,
    'shuffle': True,
    'num_workers': 4,
    'persistent_workers': True,
    'pin_memory': True,
    'drop_last': False,  # the final batch may hold fewer than batch_size items
}

dataset = VideoFeatDataset(**ds_config)
dataloader = DataLoader(dataset=dataset, collate_fn=collation_fn, **dl_config)
# Last expression of the cell: display the number of items found.
len(dataset)

Found 204665 files


204665

# 2. Train

In [6]:
# Hyper-parameters come from the "training" section of the model config.
training_config = model_config.get('training', None)
if training_config is None:
    # Fail with a readable message instead of an AttributeError on `None.get`.
    raise ValueError("model_config has no 'training' section")

# Wrap the model in the Lightning training module; every option falls back to
# the default given here when the config does not set it.
training_wrapper = DiffusionCondTrainingWrapper(
    model,
    lr=training_config.get("learning_rate", None),
    mask_padding=training_config.get("mask_padding", False),
    mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0),
    use_ema=training_config.get("use_ema", True),
    log_loss_info=training_config.get("log_loss_info", False),
    optimizer_configs=training_config.get("optimizer_configs", None),
    pre_encoded=training_config.get("pre_encoded", False),
    cfg_dropout_prob=training_config.get("cfg_dropout_prob", 0.1),
    timestep_sampler=training_config.get("timestep_sampler", "uniform"),
)

# Resume from a previous Lightning checkpoint.
# map_location='cpu' keeps torch.load from restoring tensors onto the GPU they
# were saved from (which may be busy or absent); load_state_dict then copies
# the values into the wrapper's own parameters.
ckpt_path = './lightning_logs/version_0/checkpoints/epoch=99-step=13800.ckpt'
# strict=False tolerates missing/unexpected keys. This stays the last
# expression of the cell so the notebook displays the key-matching result.
training_wrapper.load_state_dict(
    torch.load(ckpt_path, map_location='cpu')['state_dict'],
    strict=False,
)

<All keys matched successfully>

In [None]:
# Short training run (2 epochs) on a single node.
trainer_kwargs = dict(
    accelerator="gpu",
    devices=[1],  # presumably selects the GPU with index 1 — confirm against Lightning's `devices` docs
    num_nodes=1,
    max_epochs=2,
)
trainer = pl.Trainer(**trainer_kwargs)

# Train the wrapped model on whatever `dataloader` is currently bound to.
trainer.fit(training_wrapper, dataloader)

# 3. Sample

In [16]:
def replace_audio(input_video, input_audio, output_video):
    """Write a copy of ``input_video`` whose audio track comes from ``input_audio``.

    Runs ``ffmpeg`` to copy the first video stream of ``input_video`` without
    re-encoding, encode the first audio stream of ``input_audio`` as AAC, and
    write both to ``output_video``, overwriting it if it already exists.

    Args:
        input_video: Path of the file that supplies the video stream.
        input_audio: Path of the file that supplies the audio stream.
        output_video: Path of the file to write.

    Returns:
        True when ffmpeg exits successfully, False when it exits with a
        non-zero status. In the failure case the error is printed, not raised,
        so a batch loop can carry on with the next file.
    """
    ffmpeg_cmd = [
        'ffmpeg',
        # Global options first: ffmpeg's documentation asks for them to be
        # specified before any input or output file.
        '-y',                       # overwrite the output file without prompting
        '-loglevel', 'error',       # only report errors
        '-i', input_video,          # input 0: video file
        '-i', input_audio,          # input 1: audio file
        '-c:v', 'copy',             # copy the video stream, no re-encoding
        '-c:a', 'aac',              # encode the audio as AAC
        '-map', '0:v:0',            # take the video stream from input 0
        '-map', '1:a:0',            # take the audio stream from input 1
        '-strict', 'experimental',  # allow experimental encoders
        output_video,               # output file
    ]

    try:
        subprocess.run(ffmpeg_cmd, check=True)
    except subprocess.CalledProcessError as e:
        # The message reads "command execution failed".
        print(f"命令执行失败: {e}")
        return False
    return True



# Evaluation data: AudioSet test split only. This cell rebinds `info_dirs`,
# `audio_dirs`, `ds_config`, `dl_config`, `dataset` and `dataloader`, so any
# cell run after it sees the test-set objects.
info_dirs = ['./dataset/feature/test/AudioSet/10']
audio_dirs = ['/home/chengxin/chengxin/AudioSet/generated_audios/test/10']

# Keyword arguments forwarded to VideoFeatDataset.
ds_config = dict(
    info_dirs=info_dirs,
    audio_dirs=audio_dirs,
    exts='wav',
    sample_rate=44100,
    force_channels="stereo",
)

# Keyword arguments forwarded to the DataLoader.
dl_config = dict(
    batch_size=10,
    shuffle=True,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
    drop_last=False,
)

dataset = VideoFeatDataset(**ds_config)
dataloader = DataLoader(dataset, collate_fn=collation_fn, **dl_config)

Found 35621 files


In [15]:
# Generate audio for one batch of test videos and write, per video:
#   ./demo/<name>.wav      the generated audio
#   ./demo/<name>.mp4      a copy of the source video
#   ./demo/<name>_GEN.mp4  the source video with the generated audio muxed in
device = "cuda:0"
model = training_wrapper.diffusion
model.to(device)

# torchaudio.save and shutil.copy do not create missing directories.
os.makedirs("./demo", exist_ok=True)

# 441000 samples is 10 s at 44100 Hz (the rate used in ds_config). This fixed
# value is used for the demo instead of model_config["sample_size"].
sample_size = 441000
sample_rate = model_config["sample_rate"]

for audio, conditioning in dataloader:
    # With drop_last=False the final batch can be shorter than
    # dl_config['batch_size'], so size generation and the save loop from the
    # batch that was actually loaded.
    batch_size = len(conditioning['video_path'])

    output = generate_diffusion_cond(
        model,
        steps=100,
        cfg_scale=7,
        conditioning=conditioning,
        sample_size=sample_size,
        batch_size=batch_size,
        sigma_min=0.3,
        sigma_max=500,
        sampler_type="dpmpp-3m-sde",
        device=device
    )

    # Save every generated clip on its own and mux it into its source video.
    for idx in range(batch_size):
        # Dataset paths are stored relative to another directory ('../../');
        # rewrite them relative to the notebook's working directory.
        video_file = conditioning['video_path'][idx].replace('../../', './')
        video_name = os.path.basename(video_file)

        # `output` is indexed batch-first; one item is assumed to be
        # (channels, samples), the layout passed to torchaudio.save — TODO confirm.
        generated_audio = output[idx].to(torch.float32)
        # Peak-normalise to [-1, 1]; the floor avoids dividing by zero when the
        # generated clip is completely silent.
        peak = generated_audio.abs().max().clamp(min=1e-8)
        generated_audio = generated_audio.div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
        audio_file = f"./demo/{video_name.replace('.mp4', '.wav')}"
        torchaudio.save(audio_file, generated_audio, sample_rate)

        video_file_ = f"./demo/{video_name}"
        generated_video_file = video_file_.replace(".mp4", "_GEN.mp4")
        shutil.copy(video_file, video_file_)
        replace_audio(video_file_, audio_file, generated_video_file)
        print(conditioning['video_path'][idx], video_file_, audio_file, generated_video_file)

    # Demo only: stop after the first batch.
    break


1269009644


100%|██████████| 100/100 [00:05<00:00, 18.43it/s]


../../dataset/video/test/AudioSet/10/xUqfBPhdLOM_p.mp4 ./demo/xUqfBPhdLOM_p.mp4 ./demo/xUqfBPhdLOM_p.wav ./demo/xUqfBPhdLOM_p_GEN.mp4
../../dataset/video/test/AudioSet/10/fcLK-nPc7-w_p.mp4 ./demo/fcLK-nPc7-w_p.mp4 ./demo/fcLK-nPc7-w_p.wav ./demo/fcLK-nPc7-w_p_GEN.mp4
../../dataset/video/test/AudioSet/10/YpjJNavdrmc_p.mp4 ./demo/YpjJNavdrmc_p.mp4 ./demo/YpjJNavdrmc_p.wav ./demo/YpjJNavdrmc_p_GEN.mp4
../../dataset/video/test/AudioSet/10/COoJtPFTkhU_p.mp4 ./demo/COoJtPFTkhU_p.mp4 ./demo/COoJtPFTkhU_p.wav ./demo/COoJtPFTkhU_p_GEN.mp4
../../dataset/video/test/AudioSet/10/SM10Hrl5d7U_p.mp4 ./demo/SM10Hrl5d7U_p.mp4 ./demo/SM10Hrl5d7U_p.wav ./demo/SM10Hrl5d7U_p_GEN.mp4
../../dataset/video/test/AudioSet/10/DHk6BiokRC4_p.mp4 ./demo/DHk6BiokRC4_p.mp4 ./demo/DHk6BiokRC4_p.wav ./demo/DHk6BiokRC4_p_GEN.mp4
../../dataset/video/test/AudioSet/10/UrUJ9tUJynA_p.mp4 ./demo/UrUJ9tUJynA_p.mp4 ./demo/UrUJ9tUJynA_p.wav ./demo/UrUJ9tUJynA_p_GEN.mp4
../../dataset/video/test/AudioSet/10/wRnPuL0kuXY_p.mp4 ./demo/

In [8]:
# Ad-hoc inspection: show the video path at position `idx` of the current batch.
# Relies on `conditioning` and `idx` left in the kernel namespace by the sampling loop.
conditioning['video_path'][idx]

'./dataset/video/test/AudioSet/10/--4gqARaEJE_p.mp4'

# 4. Split Steps

In [None]:
from stable_audio_tools.models.diffusion import DiTWrapper, ConditionedDiffusionModelWrapper
from stable_audio_tools.models.conditioners import create_multi_conditioner_from_conditioning_config
from stable_audio_tools.models.autoencoders import create_autoencoder_from_config
from stable_audio_tools.models.pretransforms import AutoencoderPretransform

# Step 1 of assembling the model by hand: the diffusion transformer.

# Values every later step needs; both must be present in the loaded config.
io_channels = model_config["model"].get('io_channels')
sample_rate = model_config.get('sample_rate')
assert io_channels is not None, "Must specify io_channels in model config"
assert sample_rate is not None, "Must specify sample_rate in config"

# Hand-written diffusion section (not read from model_config).
diffusion_config = {
    'cross_attention_cond_ids': ['feature'],
    'type': 'dit',
    'config': {
        'io_channels': 64,
        'embed_dim': 1536,
        'depth': 24,
        'num_heads': 24,
        'cond_token_dim': 768,
        'global_cond_dim': 1536,
        'project_cond_tokens': False,
        'transformer_type': 'continuous_transformer',
    },
}

# Objective defaults to 'v' because the dict above does not set it.
diffusion_objective = diffusion_config.get('diffusion_objective', 'v')

# Which conditioner ids feed which input of the diffusion model; every list
# that the dict above does not define defaults to a fresh empty list.
cross_attention_ids, global_cond_ids, input_concat_ids, prepend_cond_ids = (
    diffusion_config.get(ids_key, [])
    for ids_key in (
        'cross_attention_cond_ids',
        'global_cond_ids',
        'input_concat_ids',
        'prepend_cond_ids',
    )
)

diffusion_model_config = diffusion_config['config']
diffusion_model = DiTWrapper(**diffusion_model_config)


In [None]:
# Step 2: the conditioners. Two are configured: a numeric 'duration'
# conditioner and a 'video_feature' conditioner registered under the id
# 'feature', whose default input key is 'video_path'.
duration_conditioner = {
    'id': 'duration',
    'type': 'number',
    'config': {'min_val': 0, 'max_val': 512},
}
video_feature_conditioner = {
    'id': 'feature',
    'type': 'video_feature',
    'config': {},
}

conditioning_config = {
    'configs': [duration_conditioner, video_feature_conditioner],
    'cond_dim': 768,
    'default_keys': {'feature': 'video_path'},
}
conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)

In [None]:

# Step 3: the autoencoder pretransform.

# Oobleck encoder. Its latent_dim (128) is twice the decoder's (64);
# presumably the VAE bottleneck consumes the extra half — TODO confirm.
oobleck_encoder = {
    'type': 'oobleck',
    'requires_grad': False,
    'config': {
        'in_channels': 2,
        'channels': 128,
        'c_mults': [1, 2, 4, 8, 16],
        'strides': [2, 4, 4, 8, 8],
        'latent_dim': 128,
        'use_snake': True,
    },
}
# Oobleck decoder, mirroring the encoder's channel multipliers and strides.
oobleck_decoder = {
    'type': 'oobleck',
    'config': {
        'out_channels': 2,
        'channels': 128,
        'c_mults': [1, 2, 4, 8, 16],
        'strides': [2, 4, 4, 8, 8],
        'latent_dim': 64,
        'use_snake': True,
        'final_tanh': False,
    },
}

pretransform_config = {
    'type': 'autoencoder',
    'iterate_batch': True,
    'config': {
        'encoder': oobleck_encoder,
        'decoder': oobleck_decoder,
        'bottleneck': {'type': 'vae'},
        'latent_dim': 64,
        # 2048 equals the product of the strides above (2 * 4 * 4 * 8 * 8).
        'downsampling_ratio': 2048,
        'io_channels': 2,
    },
}

autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
autoencoder = create_autoencoder_from_config(autoencoder_config)

# `scale` defaults to 1.0 and the three boolean switches default to False when
# pretransform_config does not set them (only 'iterate_batch' is set above).
pretransform = AutoencoderPretransform(
    autoencoder,
    scale=pretransform_config.get("scale", 1.0),
    **{
        flag: pretransform_config.get(flag, False)
        for flag in ("model_half", "iterate_batch", "chunked")
    },
)

# Put the pretransform in eval mode; gradients stay off unless 'enable_grad' is set.
pretransform.enable_grad = pretransform_config.get('enable_grad', False)
pretransform.eval().requires_grad_(pretransform.enable_grad)

# Smallest input length: the autoencoder's downsampling ratio times the
# diffusion model's patch size.
min_input_length = pretransform.downsampling_ratio * diffusion_model.model.patch_size

In [None]:
# Step 4: combine the diffusion model, conditioners and pretransform built in
# the previous cells into one conditioned diffusion model. This rebinds `model`.
extra_kwargs = {"diffusion_objective": diffusion_objective}
wrapper_kwargs = dict(
    min_input_length=min_input_length,
    sample_rate=sample_rate,
    cross_attn_cond_ids=cross_attention_ids,
    global_cond_ids=global_cond_ids,
    input_concat_ids=input_concat_ids,
    prepend_cond_ids=prepend_cond_ids,
    pretransform=pretransform,
    io_channels=io_channels,
)
model = ConditionedDiffusionModelWrapper(
    diffusion_model,
    conditioner,
    **wrapper_kwargs,
    **extra_kwargs,
)

In [None]:
# Short training run (2 epochs) on a single node with devices=1.
# NOTE(review): this fits `training_wrapper`, which was constructed around the
# earlier `model`; the `model` assembled step by step in this section is never
# passed to a training wrapper here — confirm that is intended.
# NOTE(review): `dataloader` is whatever an earlier cell last bound it to; if
# the sampling cells ran, that is the test-set loader — confirm before training.
trainer_kwargs = dict(
    accelerator="gpu",
    devices=1,
    num_nodes=1,
    max_epochs=2,
)
trainer = pl.Trainer(**trainer_kwargs)

trainer.fit(training_wrapper, dataloader)