# 0. Load model

In [1]:
import json
import os
import shutil
from collections import OrderedDict

import pytorch_lightning as pl
import safetensors
import torch
import torchaudio
from einops import rearrange
from safetensors.torch import load_file
from torch.utils.data import DataLoader

from stable_audio_tools import create_model_from_config, replace_audio, save_audio
from stable_audio_tools.data.dataset import VideoFeatDataset, collation_fn
from stable_audio_tools.training.training_wrapper import DiffusionCondTrainingWrapper
from stable_audio_tools.inference.generation import generate_diffusion_cond, generate_diffusion_cond_from_path


# Read the model configuration once; `sample_rate` taken from it is reused by
# the dataset, generation and audio-saving cells further down.
model_config_file = './stable_audio_tools/configs/model_config.json'
# JSON is UTF-8 by specification, so do not rely on the platform's default encoding.
with open(model_config_file, encoding='utf-8') as f:
    model_config = json.load(f)

sample_rate = model_config["sample_rate"]
# Bare expression so the notebook displays the value.
sample_rate

  from .autonotebook import tqdm as notebook_tqdm


44100

In [2]:
# Load the exported weights (safetensors). The version_2 export is kept for reference:
# state_dict = load_file('./weight/StableAudio/lightning_logs/version_2/checkpoints/epoch=9-step=12800.safetensors')
state_dict = load_file(
    './weight/StableAudio/lightning_logs/version_3/checkpoints/epoch=9-step=14620.safetensors'
)

# One-off conversion from the Lightning .ckpt to the .safetensors file loaded above.
# It drops the first dotted component of every key (presumably the training
# wrapper's attribute prefix -- confirm before reusing) and saves the result:
# state_dict = torch.load('./weight/StableAudio/lightning_logs/version_3/checkpoints/epoch=9-step=14620.ckpt')['state_dict']
# state_dict = OrderedDict([(".".join(key.split('.')[1:]), value)  for key, value in state_dict.items()])
# safetensors.torch.save_file(state_dict, './weight/StableAudio/lightning_logs/version_3/checkpoints/epoch=9-step=14620.safetensors')

# Build the model from the config, then load the weights. strict=True makes any
# missing or unexpected key an error. The call is the cell's last expression so
# its result is displayed.
model = create_model_from_config(model_config)
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

# 2. Train

In [None]:
# Info/feature directories and the matching audio directories for the training split.
info_dirs = ['./dataset/feature/train/AudioSet/10']
audio_dirs = ['/home/chengxin/chengxin/AudioSet/generated_audios/train/10']
# Alternative: AudioSet and VGGSound together.
# info_dirs = ['./dataset/feature/train/AudioSet/10', './dataset/feature/train/VGGSound/10']
# audio_dirs = ['/home/chengxin/chengxin/AudioSet/generated_audios/train/10', '/home/chengxin/chengxin/VGGSound/generated_audios/train/10']

# Keyword arguments forwarded to VideoFeatDataset.
ds_config = dict(
    info_dirs=info_dirs,
    audio_dirs=audio_dirs,
    exts='wav',
    sample_rate=sample_rate,
    force_channels="stereo",
)

# Keyword arguments forwarded to the DataLoader.
# NOTE(review): shuffle is disabled for the training loader -- confirm this is intended.
dl_config = dict(
    batch_size=1,
    shuffle=False,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
    drop_last=False,
)

dataset = VideoFeatDataset(**ds_config)
dataloader = DataLoader(dataset, collate_fn=collation_fn, **dl_config)

Found 10 files


10

In [7]:
# Training hyper-parameters live under the "training" key of the model config.
training_config = model_config.get('training')
if training_config is None:
    # Fail with a clear message instead of an AttributeError on None further down.
    raise ValueError(
        f"'training' section is missing from the model config ({model_config_file})"
    )

# Wrap the diffusion model in its Lightning training module. Every option falls
# back to the default given here when the config does not set it.
training_wrapper = DiffusionCondTrainingWrapper(
    model,
    lr=training_config.get("learning_rate", None),
    mask_padding=training_config.get("mask_padding", False),
    mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0),
    use_ema=training_config.get("use_ema", True),
    log_loss_info=training_config.get("log_loss_info", False),
    optimizer_configs=training_config.get("optimizer_configs", None),
    pre_encoded=training_config.get("pre_encoded", False),
    cfg_dropout_prob=training_config.get("cfg_dropout_prob", 0.1),
    timestep_sampler=training_config.get("timestep_sampler", "uniform"),
)

# Single-node run pinned to GPU index 1, for two epochs.
# NOTE(review): the inference cells use device 0 -- confirm the different index is intended.
trainer = pl.Trainer(
    devices=[1],
    accelerator="gpu",
    num_nodes=1,
    max_epochs=2,
)
trainer.fit(training_wrapper, dataloader)

# 3. Test

In [4]:

# Info/feature directories and the matching audio directories for the test split.
info_dirs = ['./dataset/feature/test/AudioSet/10']
audio_dirs = ['/home/chengxin/chengxin/AudioSet/generated_audios/test/10']

# Keyword arguments forwarded to VideoFeatDataset.
ds_config = dict(
    info_dirs=info_dirs,
    audio_dirs=audio_dirs,
    exts='wav',
    sample_rate=sample_rate,
    force_channels="stereo",
)

# Keyword arguments forwarded to the DataLoader; 20 clips are generated per batch.
dl_config = dict(
    batch_size=20,
    shuffle=False,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
    drop_last=False,
)

# These rebind `dataset` / `dataloader` from the training section to the test split.
dataset = VideoFeatDataset(**ds_config)
dataloader = DataLoader(dataset, collate_fn=collation_fn, **dl_config)

Found 20280 files


In [5]:
device = 0
duration = 10  # seconds of audio to generate per clip
output_dir = "./demo"
# Make sure the destination exists before anything is written into it.
os.makedirs(output_dir, exist_ok=True)

# Moving the model is loop-invariant, so do it once up front.
generation_model = model.to(device)

for audio, conditioning in dataloader:
    output = generate_diffusion_cond(
        model=generation_model,
        steps=150,
        cfg_scale=7,
        conditioning=conditioning,
        sample_size=sample_rate * duration,
        batch_size=len(audio),
        sigma_min=0.3,
        sigma_max=500,
        sampler_type="dpmpp-3m-sde",
        device=device,
    )

    # Iterate over the items actually present in this batch (the same count that
    # was passed to the generator) rather than the configured batch size: the
    # final batch can be shorter because drop_last is False, and `dl_config`
    # may have been rebound by another cell.
    for idx in range(len(audio)):
        # Save generated audio
        video_path = conditioning['video_path'][idx].replace('../../', './')
        video_name = video_path.split('/')[-1]
        audio_path = f"{output_dir}/{video_name.replace('.mp4', '.wav')}"
        # Slice (rather than index) so the leading dimension is kept.
        save_audio(output[idx:idx + 1], audio_path, sample_rate)

        # Replace the audio of original video to generated one
        moved_video_path = f"{output_dir}/{video_name}"
        shutil.copy(video_path, moved_video_path)
        generated_video_path = moved_video_path.replace(".mp4", "_GEN.mp4")
        replace_audio(moved_video_path, audio_path, generated_video_path)

    # Demo only: stop after the first batch.
    break


3733028847


100%|██████████| 150/150 [00:16<00:00,  9.22it/s]


In [14]:
# Disabled cell, kept for reference. It appears to attach pre-generated audio from
# the "specvqgan" directory to copies of the first batch of test videos; it does
# not run the model. TODO confirm the purpose before re-enabling.
# device = 0
# duration = 10  # generate 10 secs
# output_dir = "./demo"


# for audio, conditioning in dataloader:
#     for idx in range(dl_config['batch_size']):
#         # Save generated audio
#         video_path = conditioning['video_path'][idx].replace('../../', './')
#         audio_path = f"/home/chengxin/chengxin/AudioSet/generated_audios/specvqgan/10/{video_path.split('/')[-1].replace('.mp4', '.wav')}"

#         # Replace the audio of original video to generated one
#         moved_video_path = f"{output_dir}/{video_path.split('/')[-1]}"
#         shutil.copy(video_path, moved_video_path)
#         generated_video_path = moved_video_path.replace(".mp4","_GEN.mp4")
#         replace_audio(moved_video_path, audio_path, generated_video_path)

#     break

# 4. Sample

In [7]:
device = 0
duration = 10  # seconds of audio to generate per clip
video_paths = ['./dataset/video/test/AudioSet/10/-0YUDn-1yII_p.mp4', './dataset/video/test/AudioSet/10/-0jeONf82dE_p.mp4']
output_dir = "./demo"
# Make sure the destination exists before anything is written into it.
os.makedirs(output_dir, exist_ok=True)

# Generate one clip per input video, straight from the video files.
output = generate_diffusion_cond_from_path(
    model=model.to(device),
    video_paths=video_paths,
    steps=100,
    cfg_scale=7,
    sample_rate=sample_rate,
    sample_size=sample_rate * duration,
    batch_size=len(video_paths),
    sigma_min=0.3,
    sigma_max=500,
    sampler_type="dpmpp-3m-sde",
    device=device,
)

for idx, video_path in enumerate(video_paths):
    # Save generated audio
    video_name = video_path.split('/')[-1]
    audio_path = f"{output_dir}/{video_name.replace('.mp4', '.wav')}"
    # Slice (rather than index) so the leading dimension is kept.
    save_audio(output[idx:idx + 1], audio_path, sample_rate)

    # Replace the audio of original video to generated one
    moved_video_path = f"{output_dir}/{video_name}"
    shutil.copy(video_path, moved_video_path)
    generated_video_path = moved_video_path.replace(".mp4", "_GEN.mp4")
    replace_audio(moved_video_path, audio_path, generated_video_path)

2862089307
Extracting features from video:./dataset/video/test/AudioSet/10/-0YUDn-1yII_p.mp4
Extracting features from video:./dataset/video/test/AudioSet/10/-0jeONf82dE_p.mp4


100%|██████████| 100/100 [00:01<00:00, 52.09it/s]
