# 0. Load model

In [1]:
import json
import pytorch_lightning as pl
from einops import rearrange
import torch
import torchaudio
from torch.utils.data import DataLoader
from safetensors.torch import load_file

import shutil
from collections import OrderedDict
import safetensors
from datetime import datetime
import os
import torch.nn.functional as F

from stable_audio_tools import create_model_from_config, replace_audio, save_audio
from stable_audio_tools.data.dataset import VideoFeatDataset, collation_fn
from stable_audio_tools.training.training_wrapper import DiffusionCondTrainingWrapper
from stable_audio_tools.inference.generation import generate_diffusion_cond, generate_diffusion_cond_from_path


# Location of the JSON model configuration used by every cell below.
# Known variants -- config_t2a: t2a      config: v2a      config1: new v2a
model_config_file = './stable_audio_tools/configs/model_config.json'
with open(model_config_file) as config_handle:
    model_config = json.load(config_handle)

# Read the audio/video settings from the config (same lookup order as before,
# so a missing key still raises KeyError for the same entry).
sample_rate, sample_size, fps = (
    model_config[key] for key in ("sample_rate", "sample_size", "fps")
)
# The configured sample_size is overridden: the dataset config below receives None.
sample_size = None

model = create_model_from_config(model_config)

# Show the effective settings as the cell output.
sample_rate, fps, sample_size

  from .autonotebook import tqdm as notebook_tqdm


(44100, 22, None)

In [2]:
model_dir = './weight/StableAudio/2024-08-04 02:52:24'
model_name = 'epoch=27-step=2818'
# Checkpoint log. Columns: directory | notes and checkpoint name(s) | FPS | config.
# ("AS" / "VGG" presumably refer to the AudioSet / VGGSound feature sets used elsewhere in this notebook.)
# ./weight/StableAudio/lightning_logs/version_3/checkpoints  epoch=9-step=14620                                     FPS=22  config
# ./weight/StableAudio/2024-07-02 19:49:04                   epoch=2-step=427                                       FPS=22 basemodel config
# ./weight/StableAudio/2024-07-06 10:28:13               basemodel + pos_emb, trained on AS <epoch=30-step=58>   good time alignment and generation quality (best)   FPS=22 config
# ./weight/StableAudio/2024-07-06 21:34:56               bestmodel + pos_emb, trained on VGG  epoch=3-step=1023                                          FPS=22 config
# ./weight/StableAudio/2024-07-19 18:04:16               basemodel + pos_emb, retrained on AS  epoch=13-step=220   <epoch=88-step=220>    FPS=22  config
# ./weight/StableAudio/2024-07-20 13:34:51               basemodel + pos_emb, retrained on VGG  epoch=3-step=1817                          FPS=22  config
# ./weight/StableAudio/2024-07-22 19:08:45               t2a-crosscond & conditioner NOT loaded, rotary_cond_emb added, retrained on AS              FPS=8 sr=16000  config_rotebd
# ./weight/StableAudio/2024-07-24 23:06:33               bestmodel, training continued on AS     <epoch=70-step=304> <epoch=150-step=304>            FPS=22  config
# ./weight/StableAudio/2024-08-01 09:36:20               lastmodel<epoch=70-step=304>, training continued on ASVGG   <epoch=45-step=2818>  <epoch=29-step=2818>    FPS=22  config
# ./weight/StableAudio/2024-08-01 09:36:20               lastmodel<epoch=29-step=2818>, training continued on ASVGG
# (no directory recorded)                                t2a-crosscond & conditioner NOT loaded, rotary_cond_emb added, retrained on VGG             FPS=8 sr=16000  config_rotebd
# (no directory recorded)                                xxxxxxxxx + global_cond_ids ["time_cond", "seconds_total"], using prepend
# (no directory recorded)                                xxxxxxxxx + global_cond_ids ["time_cond", "seconds_total"], using concat

safetensors_path = f'{model_dir}/{model_name}.safetensors'

if os.path.exists(safetensors_path):
    # Converted weights already exist: load them directly.
    state_dict = load_file(safetensors_path)
else:
    # First use of this checkpoint: read the .ckpt, drop the first dotted component
    # of every parameter name, and cache the result next to it as .safetensors.
    # map_location='cpu' avoids depending on the GPU index the checkpoint was saved from.
    state_dict = torch.load(f'{model_dir}/{model_name}.ckpt', map_location='cpu')['state_dict']
    state_dict = OrderedDict([(".".join(key.split('.')[1:]), value)  for key, value in state_dict.items()])
    safetensors.torch.save_file(state_dict, safetensors_path)

print(safetensors_path)
# state_dict = load_file(f'./weight/StableAudio/model.safetensors')
# strict=False tolerates missing/unexpected keys; the returned report is shown as the cell output.
model.load_state_dict(state_dict, strict=False)

./weight/StableAudio/2024-08-04 02:52:24/epoch=27-step=2818.safetensors


<All keys matched successfully>

## 0.1 T2A

In [5]:
device = 6

# Replace the currently loaded weights with the base checkpoint, then move the model to `device`.
state_dict = load_file(f'./weight/StableAudio/model.safetensors')
model.load_state_dict(state_dict, strict=False)
model = model.to(device)

# Text prompt and timing conditioning
conditioning = dict(
    prompt=["laugh"],  # alternative prompt: 128 BPM tech house drum loop
    # feature=['/home/chengxin/chengxin/FoleyCrafter/examples/sora/0.mp4'],
    seconds_start=torch.tensor([0]),
    seconds_total=torch.tensor([10]),
)

# Sampler settings gathered in one place, then passed to the generator.
sampling_options = dict(
    steps=50,
    cfg_scale=7,
    conditioning=conditioning,
    sample_size=44100*47,
    sigma_min=0.3,
    sigma_max=500,
    sampler_type="k-dpm-fast",  # alternative: dpmpp-3m-sde
    device=device,
    seed=3044415654,
)
output = generate_diffusion_cond(model, **sampling_options)

# Flatten the batch into one long sequence per channel
output = rearrange(output, "b d n -> d (b n)")

# Peak-normalize, clip to [-1, 1], convert to int16 and write the file
peak = torch.max(torch.abs(output))
output = output.to(torch.float32).div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
torchaudio.save("output.wav", output, sample_rate)

KeyboardInterrupt: 

# 2. Train

In [None]:
# Feature directories and the matching audio directories used for training.
# The VGGSound entries are kept commented out so they can be re-enabled together.
info_dirs = [
    './dataset/feature/train/AudioSet/10',
    # './dataset/feature/train/VGGSound/10',
]
audio_dirs = [
    '/home/chengxin/chengxin/AudioSet/generated_audios/train/10',
    # '/home/chengxin/chengxin/VGGSound/generated_audios/train/10'
]
# info_dirs = ['./dataset/feature/train/AudioSet/10', './dataset/feature/train/VGGSound/10']
# audio_dirs = ['/home/chengxin/chengxin/AudioSet/generated_audios/train/10', '/home/chengxin/chengxin/VGGSound/generated_audios/train/10']

# Keyword arguments for VideoFeatDataset
ds_config = dict(
    info_dirs=info_dirs,
    audio_dirs=audio_dirs,
    exts='wav',
    sample_rate=sample_rate,
    sample_size=sample_size,
    fps=fps,
    force_channels="mono",
    limit_num=300,
)

# Keyword arguments for the DataLoader
dl_config = dict(
    batch_size=20,
    shuffle=False,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
    drop_last=False,
)

dataset = VideoFeatDataset(**ds_config)
dataloader = DataLoader(dataset=dataset, collate_fn=collation_fn, **dl_config)

Found 300 files


In [None]:
# Training hyper-parameters come from the "training" section of the model config.
training_config = model_config.get('training', None)
if training_config is None:
    # Fail with a clear message instead of an AttributeError on None further down.
    raise KeyError(f"'training' section not found in {model_config_file}")

# Wrap the model in the conditional-diffusion training wrapper; every option falls back
# to the default given here when the config does not define it.
training_wrapper = DiffusionCondTrainingWrapper(
            model=model,
            lr=training_config.get("learning_rate", None),
            optimizer_configs=training_config.get("optimizer_configs", None),
            pre_encoded=training_config.get("pre_encoded", False),
            cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1),
            timestep_sampler = training_config.get("timestep_sampler", "uniform"),
        )

# Single-node run on GPU index 1 for two epochs.
trainer = pl.Trainer(
    devices=[1],
    accelerator="gpu",
    num_nodes = 1,
    max_epochs=2,
)
trainer.fit(training_wrapper, dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name      | Type                             | Params
---------------------------------------------------------------
0 | diffusion | ConditionedDiffusionModelWrapper | 1.2 B 
1 | losses    | MultiLoss                        | 0     
---------------------------------------------------------------
1.1 B     Trainable params
156 M     Non-trainable params
1.2 B     Total params
4,855.713 Total estimated model params size (MB)


Epoch 1: 100%|██████████| 15/15 [01:00<00:00,  0.25it/s, v_num=11]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 15/15 [01:30<00:00,  0.17it/s, v_num=11]


# 3. Test

In [6]:
from stable_audio_tools.data.dataset import VideoFeatDataset, collation_fn
from torch.utils.data import DataLoader

# Feature directories for the test run (the VGGSound entry is kept for quick switching).
info_dirs = [
    # './dataset/feature/test/VGGSound/10',
    './dataset/feature/test/unav100/10'
]
# Audio directories are defined for reference but not passed to the dataset below.
audio_dirs = [
    # '/home/chengxin/chengxin/VGGSound/generated_audios/test/10',
    '/home/chengxin/chengxin/unav100/generated_audios/test'
]

# Keyword arguments for VideoFeatDataset
ds_config = dict(
    info_dirs=info_dirs,
    # audio_dirs=audio_dirs,
    exts='wav',
    sample_rate=sample_rate,
    fps=fps,
    force_channels="mono",
    limit_num=4,
)

# Keyword arguments for the DataLoader
dl_config = dict(
    batch_size=32,
    shuffle=False,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
    drop_last=False,
)

dataset = VideoFeatDataset(**ds_config)
dataloader = DataLoader(dataset=dataset, collate_fn=collation_fn, **dl_config)

Found 4 files


In [7]:
# model_dir = './weight/StableAudio/2024-07-06 10:28:13'
# model_name = 'epoch=30-step=58'

device = 0
# Results are written to a folder named after the checkpoint selected earlier.
output_dir = f"./demo/{model_name}"
os.makedirs(output_dir, exist_ok=True)
print(output_dir)

# Move the model once, instead of on every batch.
model = model.to(device)

for conditioning in dataloader:
    # Generate enough samples for the longest clip in the batch; shorter clips are trimmed below.
    seconds_total = max(conditioning['seconds_total'])

    output = generate_diffusion_cond(
        model = model,
        steps=150,
        cfg_scale=7,
        conditioning=conditioning,
        sample_size=int(sample_rate*seconds_total),
        batch_size=len(conditioning['feature']),
        sigma_min=0.3,
        sigma_max=500,
        sampler_type="dpmpp-3m-sde", # k-dpm-fast"
        device=device
    )

    for idx in range(len(conditioning['feature'])):
        # Map the path stored with the features to the location of the source video.
        # ('AuidoSet' is matched as well as 'AudioSet', as in the original check.)
        source_path = conditioning['video_path'][idx]
        if 'AuidoSet' in source_path or 'AudioSet' in source_path:
            parts = source_path.split('/')
            video_path = os.path.join('/home/chengxin/chengxin/AudioSet/dataset/', parts[-4], parts[-2], parts[-1])
        else:
            video_path = source_path.replace('../../../', '/home/chengxin/chengxin/')
        # video_path = conditioning['video_path'][idx].replace('../../', './')

        # Derive output names from the file stem so that only the real extension is
        # replaced; for .mp4 inputs the names are the same as before.
        video_name = os.path.basename(video_path)
        stem, ext = os.path.splitext(video_name)

        # Save generated audio, trimmed to this clip's own duration
        audio_path = f"{output_dir}/{stem}.wav"
        waveform = output[idx:1+idx,...,:int(conditioning['seconds_total'][idx]*sample_rate)]
        save_audio(waveform, audio_path, sample_rate)

        # Replace the audio of original video to generated one
        moved_video_path = f"{output_dir}/{video_name}"
        shutil.copy(video_path, moved_video_path)
        generated_video_path = f"{output_dir}/{stem}_GEN{ext}"
        replace_audio(moved_video_path, audio_path, generated_video_path)
    


./demo/epoch=80-step=254
tensor([60.0100])
1270040263


100%|██████████| 150/150 [00:21<00:00,  6.93it/s]


# 4. Sample

In [None]:
# Conditioning for two example videos, each 10 seconds long starting at 0.
conditioning = {
    'seconds_start': [0, 0],
    'seconds_total': [10, 10],
    'feature': ['/home/chengxin/chengxin/FoleyCrafter/examples/sora/0.mp4', '/home/chengxin/chengxin/FoleyCrafter/examples/sora/1.mp4']
}

# Take the clip length from this cell's own conditioning (longest entry), as the other
# generation cells do, instead of a value left behind by a previously executed cell.
seconds_total = max(conditioning['seconds_total'])

# NOTE: `device` is not set in this cell; it is whatever an earlier cell assigned.
output = generate_diffusion_cond(
    model = model.to(device),
    steps=50,
    cfg_scale=7,
    conditioning=conditioning,
    sample_size=int(sample_rate*seconds_total),
    batch_size=len(conditioning['feature']),
    sigma_min=0.3,
    sigma_max=500,
    sampler_type="k-dpm-fast",
    device=device
)

190309694
Extracting features from video:/home/chengxin/chengxin/FoleyCrafter/examples/sora/0.mp4
Extracting features from video:/home/chengxin/chengxin/FoleyCrafter/examples/sora/1.mp4


100%|██████████| 50/50 [00:01<00:00, 30.41it/s]


In [None]:
# device = 0
# seconds_total = 10
# video_paths = ['/home/chengxin/chengxin/FoleyCrafter/examples/sora/0.mp4', '/home/chengxin/chengxin/FoleyCrafter/examples/sora/1.mp4']
# output_dir = "./demo"

# conditioning = {'seconds_start':0,'seconds_total':10}
# output = generate_diffusion_cond_from_path(
#     model=model.to(device),
#     video_paths=video_paths,
#     conditioning=conditioning,
#     steps=100,
#     cfg_scale=7,
#     sample_rate=sample_rate,
#     sample_size=sample_rate*seconds_total,
#     batch_size=len(video_paths),
#     sigma_min=0.3,
#     sigma_max=500,
#     sampler_type="dpmpp-3m-sde",
#     device=device
# )

# for idx in range(len(video_paths)):
#     # Save generated audio
#     video_path = video_paths[idx]
#     audio_path = f"{output_dir}/{video_path.split('/')[-1].replace('.mp4', '.wav')}"
#     save_audio(output[idx:1+idx], audio_path, sample_rate)
        
#     # Replace the audio of original video to generated one
#     moved_video_path = f"{output_dir}/{video_path.split('/')[-1]}"
#     shutil.copy(video_path, moved_video_path)
#     generated_video_path = moved_video_path.replace(".mp4","_GEN.mp4")
#     replace_audio(moved_video_path, audio_path, generated_video_path)

In [None]:
import torch
from torch import nn
from torch.cuda.amp import autocast

class RotaryEmbedding(nn.Module):
    """Rotary position embedding: maps 1-D positions to rotation angles.

    Args:
        dim: number of rotated feature channels; `dim // 2` inverse frequencies are stored
            and the returned tables have `dim` columns (each frequency appears twice).
        use_xpos: when True, `forward` also returns a per-position scale table;
            otherwise the scale is the constant `1.`.
        scale_base: divisor applied to the centred positions when building the xpos scale.
        interpolation_factor: positions are divided by this value (must be >= 1).
        base: base of the geometric frequency progression.
        base_rescale_factor: rescales `base` (NTK-aware scaling, see link below).
    """

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        base *= base_rescale_factor ** (dim / (dim - 2))

        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            # Without xpos there is no scale table; forward() returns 1. instead.
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len):
        """Return `forward(t)` for the positions `0 .. seq_len - 1`."""
        device = self.inv_freq.device

        t = torch.arange(seq_len, device = device)
        return self.forward(t)

    @autocast(enabled = False)
    def forward(self, t):
        """Compute the angle table (and xpos scale) for a 1-D tensor of positions.

        Args:
            t: 1-D tensor of positions.

        Returns:
            `(freqs, scale)` where `freqs` has shape `(len(t), dim)`. `scale` is `1.`
            when xpos is disabled, otherwise a tensor of shape `(len(t), dim)`.
        """
        positions = t.to(torch.float32)

        # Only the angles use the interpolated positions.
        freqs = torch.einsum('i , j -> i j', positions / self.interpolation_factor, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        if self.scale is None:
            return freqs, 1.

        # The original code used an undefined `seq_len` here, so the xpos path raised a
        # NameError. Centre the positions around the middle of the covered range
        # (for positions 0..n-1 this is n // 2).
        max_pos = positions.max() + 1 if positions.numel() > 0 else positions.new_zeros(())
        power = (positions - torch.floor(max_pos / 2)) / self.scale_base
        scale = self.scale ** power.unsqueeze(-1)
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale

rotary_pos_emb = RotaryEmbedding(32)

In [None]:
from functools import reduce, partial
def rotate_half(x):
    """Split the last axis of `x` into two halves (a, b) and return them as (-b, a)."""
    halves = rearrange(x, '... (j d) -> ... j d', j = 2)
    first, second = halves.unbind(dim = -2)
    rotated = torch.cat((-second, first), dim = -1)
    return rotated

def apply_rotary_pos_emb(t, freqs, scale = 1):
    """Rotate the leading channels of `t` by the angles in `freqs`.

    Args:
        t: tensor whose second-to-last axis is the sequence axis and whose last axis
            holds the channels.
        freqs: angle table; its last axis gives the number of channels to rotate and its
            second-to-last axis the positions. Only the last `seq_len` positions are used.
        scale: scalar, or a table laid out like `freqs`, multiplied into the rotation.

    Returns:
        A tensor with the dtype of `t`: the first `freqs.shape[-1]` channels rotated,
        the remaining channels passed through unchanged.
    """
    out_dtype = t.dtype

    # cast to float32 if necessary for numerical stability
    dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
    rot_dim, seq_len = freqs.shape[-1], t.shape[-2]

    freqs, t = freqs.to(dtype), t.to(dtype)

    # Keep only the last `seq_len` positions. Slice the sequence axis explicitly:
    # `freqs[-seq_len:, :]` cut the batch axis when `freqs` was 3-D.
    freqs = freqs[..., -seq_len:, :]
    # A tensor scale has to be trimmed the same way, otherwise it no longer lines up
    # with `freqs` when the table is longer than the sequence.
    if torch.is_tensor(scale) and scale.ndim >= 2:
        scale = scale[..., -seq_len:, :]

    if t.ndim == 4 and freqs.ndim == 3:
        # Insert a singleton axis so batched angles broadcast over axis 1 of `t`.
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    # partial rotary embeddings, Wang et al. GPT-J
    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)

    t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)

    return torch.cat((t, t_unrotated), dim = -1)

In [None]:
# Build the rotary angle table for positions 0..99; the second return value (the scale) is discarded.
freqs, _ = rotary_pos_emb.forward_from_seq_len(100)
# Display the table as the cell output.
freqs

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 5.6234e-01, 3.1623e-01,  ..., 5.6234e-04, 3.1623e-04,
         1.7783e-04],
        [2.0000e+00, 1.1247e+00, 6.3246e-01,  ..., 1.1247e-03, 6.3246e-04,
         3.5566e-04],
        ...,
        [9.7000e+01, 5.4547e+01, 3.0674e+01,  ..., 5.4547e-02, 3.0674e-02,
         1.7249e-02],
        [9.8000e+01, 5.5109e+01, 3.0990e+01,  ..., 5.5109e-02, 3.0990e-02,
         1.7427e-02],
        [9.9000e+01, 5.5672e+01, 3.1307e+01,  ..., 5.5672e-02, 3.1307e-02,
         1.7605e-02]])

# 5. Time Cond

In [2]:
from stable_audio_tools.data.dataset import VideoFeatDataset, collation_fn
from torch.utils.data import DataLoader

# Feature directories for the time-conditioning experiment (AudioSet test features).
info_dirs = [
    './dataset/feature_t/test/AudioSet/10',
    # './dataset/feature/test/unav100/10'
]
# Audio directories are defined for reference but not passed to the dataset below.
audio_dirs = [
    '/home/chengxin/chengxin/AudioSet/generated_audios/test/10',
    # '/home/chengxin/chengxin/unav100/generated_audios/test'
]

# Keyword arguments for VideoFeatDataset
ds_config = dict(
    info_dirs=info_dirs,
    # audio_dirs=audio_dirs,
    exts='wav',
    sample_rate=sample_rate,
    fps=fps,
    force_channels="mono",
    limit_num=32,
)

# Keyword arguments for the DataLoader
dl_config = dict(
    batch_size=32,
    shuffle=False,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
    drop_last=False,
)

dataset = VideoFeatDataset(**ds_config)
dataloader = DataLoader(dataset=dataset, collate_fn=collation_fn, **dl_config)

Found 32 files


In [3]:
# NOTE(review): these names are reassigned here, but this cell does not load any weights;
# within this cell they only determine the name of the output folder.
model_dir = './weight/StableAudio/2024-07-06 10:28:13'
model_name = 'epoch=30-step=58'

device = 7
output_dir = f"./demo/{model_name}"
os.makedirs(output_dir, exist_ok=True)
print(output_dir)

for conditioning in dataloader:
    # Run without the start-time conditioning.
    del conditioning['seconds_start']
    # Size the generation for the longest clip in the batch.
    seconds_total = max(conditioning['seconds_total'])

    # Sampler arguments, listed in the same order in which they were evaluated before.
    sampler_kwargs = dict(
        model=model.to(device),
        steps=50,
        cfg_scale=7,
        conditioning=conditioning,
        sample_size=int(sample_rate*seconds_total),
        batch_size=len(conditioning['feature']),
        sigma_min=0.3,
        sigma_max=500,
        sampler_type="k-dpm-fast",
        device=device,
    )
    output = generate_diffusion_cond(**sampler_kwargs)

    # Only the first batch is generated.
    break



./demo/epoch=30-step=58
3642395041


100%|██████████| 50/50 [00:08<00:00,  5.92it/s]
